BlockFmhaFwdSplitKVCombinePipeline< Problem_, Policy_ > Struct Template Reference

BlockFmhaFwdSplitKVCombinePipeline< Problem_, Policy_ > Struct Template Reference

Composable Kernel: ck_tile::BlockFmhaFwdSplitKVCombinePipeline< Problem_, Policy_ > Struct Template Reference
ck_tile::BlockFmhaFwdSplitKVCombinePipeline< Problem_, Policy_ > Struct Template Reference

#include <block_fmha_fwd_splitkv_combine_pipeline.hpp>

Public Types

using Problem = remove_cvref_t<Problem_>
using Policy = remove_cvref_t<Policy_>
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>
using ODataType = remove_cvref_t<typename Problem::ODataType>

Public Member Functions

template<typename LSEaccDramBlockWindowTmp, typename OaccDramBlockWindowTmp, typename LSEDramBlockWindowTmp, typename LSEElementFunction, typename OaccElementFunction>
CK_TILE_HOST_DEVICE auto operator() (const LSEaccDramBlockWindowTmp &lse_acc_dram_block_window_tmp, const OaccDramBlockWindowTmp &o_acc_dram_block_window_tmp, LSEDramBlockWindowTmp &lse_dram_window_tmp, const LSEElementFunction &lse_element_func, const OaccElementFunction &o_acc_element_func, index_t num_splits, void *smem_ptr) const
template<typename LSEaccDramBlockWindow, typename OaccDramBlockWindow, typename LSEDramBlockWindow>
CK_TILE_HOST_DEVICE auto operator() (const LSEaccDramBlockWindow &lse_acc_dram_block_window, const OaccDramBlockWindow &o_acc_dram_block_window, LSEDramBlockWindow &lse_dram_block_window, index_t num_splits, void *smem_ptr) const

Static Public Member Functions

static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize ()

Static Public Attributes

static constexpr index_t kNumWarps = Problem::kNumWarps
static constexpr index_t kBlockSize = Problem::kBlockSize
static constexpr index_t kHeadDimV = Problem::kHeadDimV
static constexpr index_t kM0 = Problem::kM0
static constexpr index_t kN1 = Problem::kN1
static constexpr bool kIsGroupMode = Problem::kIsGroupMode
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV
static constexpr bool kStoreLSE = Problem::kStoreLSE
static constexpr index_t kMaxSplits = Problem::kMaxSplits
static constexpr index_t kAlignmentLSE
static constexpr index_t kAlignmentLSEacc = kAlignmentLSE
static constexpr index_t kAlignmentOacc
static constexpr index_t kAlignmentO
static constexpr index_t kBlockPerCu
static constexpr const char * name = "unused"

Member Typedef Documentation

◆ LSEDataType

template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
using ck_tile::BlockFmhaFwdSplitKVCombinePipeline< Problem_, Policy_ >::LSEDataType = remove_cvref_t<typename Problem::LSEDataType>

◆ OaccDataType

template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
using ck_tile::BlockFmhaFwdSplitKVCombinePipeline< Problem_, Policy_ >::OaccDataType = remove_cvref_t<typename Problem::OaccDataType>

◆ ODataType

template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
using ck_tile::BlockFmhaFwdSplitKVCombinePipeline< Problem_, Policy_ >::ODataType = remove_cvref_t<typename Problem::ODataType>

◆ Policy

template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
using ck_tile::BlockFmhaFwdSplitKVCombinePipeline< Problem_, Policy_ >::Policy = remove_cvref_t<Policy_>

◆ Problem

template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
using ck_tile::BlockFmhaFwdSplitKVCombinePipeline< Problem_, Policy_ >::Problem = remove_cvref_t<Problem_>

Member Function Documentation

◆ GetSmemSize()

template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
CK_TILE_HOST_DEVICE constexpr ck_tile::index_t ck_tile::BlockFmhaFwdSplitKVCombinePipeline< Problem_, Policy_ >::GetSmemSize ( )
inline static constexpr

◆ operator()() [1/2]

template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
template<typename LSEaccDramBlockWindow, typename OaccDramBlockWindow, typename LSEDramBlockWindow>
CK_TILE_HOST_DEVICE auto ck_tile::BlockFmhaFwdSplitKVCombinePipeline< Problem_, Policy_ >::operator() ( const LSEaccDramBlockWindow & lse_acc_dram_block_window,
const OaccDramBlockWindow & o_acc_dram_block_window,
LSEDramBlockWindow & lse_dram_block_window,
index_t num_splits,
void * smem_ptr ) const
inline

◆ operator()() [2/2]

template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
template<typename LSEaccDramBlockWindowTmp, typename OaccDramBlockWindowTmp, typename LSEDramBlockWindowTmp, typename LSEElementFunction, typename OaccElementFunction>
CK_TILE_HOST_DEVICE auto ck_tile::BlockFmhaFwdSplitKVCombinePipeline< Problem_, Policy_ >::operator() ( const LSEaccDramBlockWindowTmp & lse_acc_dram_block_window_tmp,
const OaccDramBlockWindowTmp & o_acc_dram_block_window_tmp,
LSEDramBlockWindowTmp & lse_dram_window_tmp,
const LSEElementFunction & lse_element_func,
const OaccElementFunction & o_acc_element_func,
index_t num_splits,
void * smem_ptr ) const
inline

Member Data Documentation

◆ kAlignmentLSE

template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
index_t ck_tile::BlockFmhaFwdSplitKVCombinePipeline< Problem_, Policy_ >::kAlignmentLSE
static constexpr
Initial value:
=
kPadSeqLenQ ? 1 : Policy::template GetAlignmentLSE<Problem>()
static constexpr bool kPadSeqLenQ
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:64

◆ kAlignmentLSEacc

template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
index_t ck_tile::BlockFmhaFwdSplitKVCombinePipeline< Problem_, Policy_ >::kAlignmentLSEacc = kAlignmentLSE
static constexpr

◆ kAlignmentO

template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
index_t ck_tile::BlockFmhaFwdSplitKVCombinePipeline< Problem_, Policy_ >::kAlignmentO
static constexpr
Initial value:
=
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>()
static constexpr bool kPadHeadDimV
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:65

◆ kAlignmentOacc

template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
index_t ck_tile::BlockFmhaFwdSplitKVCombinePipeline< Problem_, Policy_ >::kAlignmentOacc
static constexpr
Initial value:
=
kPadHeadDimV ? 1 : Policy::template GetAlignmentOacc<Problem>()
static constexpr bool kPadHeadDimV
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:65

◆ kBlockPerCu

template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
index_t ck_tile::BlockFmhaFwdSplitKVCombinePipeline< Problem_, Policy_ >::kBlockPerCu
static constexpr
Initial value:
= []() {
if constexpr(Problem::kBlockPerCu != -1)
return Problem::kBlockPerCu;
else
{
if constexpr(kHeadDimV <= 32)
{
constexpr std::array occupancy{3, 3, 3, 3, 3, 1};
return occupancy[detail::log2<kMaxSplits>::value - 2];
}
else if constexpr(kHeadDimV <= 128)
{
constexpr std::array occupancy{3, 3, 3, 3, 2, 1};
return occupancy[detail::log2<kMaxSplits>::value - 2];
}
else if constexpr(kHeadDimV <= 256)
{
constexpr std::array occupancy{2, 2, 2, 2, 2, 1};
return occupancy[detail::log2<kMaxSplits>::value - 2];
}
}
}()
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:48
static constexpr index_t kHeadDimV
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:59
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:13

◆ kBlockSize

template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
index_t ck_tile::BlockFmhaFwdSplitKVCombinePipeline< Problem_, Policy_ >::kBlockSize = Problem::kBlockSize
static constexpr

◆ kHeadDimV

template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
index_t ck_tile::BlockFmhaFwdSplitKVCombinePipeline< Problem_, Policy_ >::kHeadDimV = Problem::kHeadDimV
static constexpr

◆ kIsGroupMode

template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
bool ck_tile::BlockFmhaFwdSplitKVCombinePipeline< Problem_, Policy_ >::kIsGroupMode = Problem::kIsGroupMode
static constexpr

◆ kM0

template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
index_t ck_tile::BlockFmhaFwdSplitKVCombinePipeline< Problem_, Policy_ >::kM0 = Problem::kM0
static constexpr

◆ kMaxSplits

template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
index_t ck_tile::BlockFmhaFwdSplitKVCombinePipeline< Problem_, Policy_ >::kMaxSplits = Problem::kMaxSplits
static constexpr

◆ kN1

template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
index_t ck_tile::BlockFmhaFwdSplitKVCombinePipeline< Problem_, Policy_ >::kN1 = Problem::kN1
static constexpr

◆ kNumWarps

template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
index_t ck_tile::BlockFmhaFwdSplitKVCombinePipeline< Problem_, Policy_ >::kNumWarps = Problem::kNumWarps
static constexpr

◆ kPadHeadDimV

template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
bool ck_tile::BlockFmhaFwdSplitKVCombinePipeline< Problem_, Policy_ >::kPadHeadDimV = Problem::kPadHeadDimV
static constexpr

◆ kPadSeqLenQ

template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
bool ck_tile::BlockFmhaFwdSplitKVCombinePipeline< Problem_, Policy_ >::kPadSeqLenQ = Problem::kPadSeqLenQ
static constexpr

◆ kStoreLSE

template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
bool ck_tile::BlockFmhaFwdSplitKVCombinePipeline< Problem_, Policy_ >::kStoreLSE = Problem::kStoreLSE
static constexpr

◆ name

template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
const char* ck_tile::BlockFmhaFwdSplitKVCombinePipeline< Problem_, Policy_ >::name = "unused"
static constexpr

The documentation for this struct was generated from the following file:

block_fmha_fwd_splitkv_combine_pipeline.hpp