BlockFmhaPipelineQRKSVSWholeKPrefetch< Problem_, Policy_ > Struct Template Reference
Public Types |
Public Member Functions |
Static Public Member Functions |
Static Public Attributes |
List of all members
ck_tile::BlockFmhaPipelineQRKSVSWholeKPrefetch< Problem_, Policy_ > Struct Template Reference
#include <block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp>
Public Types | |
| using | Problem = remove_cvref_t<Problem_> |
| using | Policy = remove_cvref_t<Policy_> |
| using | QDataType = remove_cvref_t<typename Problem::QDataType> |
| using | KDataType = remove_cvref_t<typename Problem::KDataType> |
| using | VDataType = remove_cvref_t<typename Problem::VDataType> |
| using | SaccDataType = remove_cvref_t<typename Problem::SaccDataType> |
| using | SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType> |
| using | BiasDataType = remove_cvref_t<typename Problem::BiasDataType> |
| using | RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType> |
| using | LSEDataType = remove_cvref_t<typename Problem::LSEDataType> |
| using | PDataType = remove_cvref_t<typename Problem::PDataType> |
| using | OaccDataType = remove_cvref_t<typename Problem::OaccDataType> |
| using | ODataType = remove_cvref_t<typename Problem::ODataType> |
| using | FmhaMask = remove_cvref_t<typename Problem::FmhaMask> |
| using | AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant> |
| using | BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape> |
| using | VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout> |
| using | DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout> |
Public Member Functions | |
| template<typename QDramBlockWindowTmp, typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp, typename QElementFunction, typename KElementFunction, typename VElementFunction, typename BiasElementFunction, typename LSEElementFunction, typename SAccElementFunction, typename PComputeElementFunction, typename OAccElementFunction, typename PositionEncoding, typename AttentionVariantParams, typename BlockIndices> | |
| CK_TILE_HOST_DEVICE auto | operator() (const QDramBlockWindowTmp &q_dram_block_window_tmp, const QElementFunction &q_element_func, const KDramBlockWindowTmp &k_dram_block_window_tmp, const KElementFunction &k_element_func, const VDramBlockWindowTmp &v_dram_block_window_tmp, const VElementFunction &v_element_func, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, const BiasElementFunction &bias_element_func, RandValDramBlockWindowTmp &randval_dram_block_window_tmp, LSEDramBlockWindowTmp &lse_dram_window_tmp, const LSEElementFunction &lse_element_func, const SAccElementFunction &s_acc_element_func, const PComputeElementFunction &p_compute_element_func, const OAccElementFunction &o_acc_element_func, FmhaMask mask, PositionEncoding position_encoding, float scale_s, const AttentionVariant &, const AttentionVariantParams &, const BlockIndices &, void *smem_ptr, DropoutType &dropout) const |
| template<typename QDramBlockWindowTmp, typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp, typename PositionEncoding, typename AttentionVariantParams, typename BlockIndices> | |
| CK_TILE_HOST_DEVICE auto | operator() (const QDramBlockWindowTmp &q_dram_block_window_tmp, const KDramBlockWindowTmp &k_dram_block_window_tmp, const VDramBlockWindowTmp &v_dram_block_window_tmp, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, RandValDramBlockWindowTmp &randval_dram_block_window_tmp, LSEDramBlockWindowTmp &lse_dram_block_window_tmp, FmhaMask mask, PositionEncoding position_encoding, float scale_s, const AttentionVariant &variant, const AttentionVariantParams &variant_params, const BlockIndices &block_indices, void *smem_ptr, DropoutType &dropout) const |
Static Public Member Functions | |
| static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t | GetSmemSize () |
Static Public Attributes | |
| static constexpr bool | kQLoadOnce = true |
| static constexpr index_t | kBlockSize = Problem::kBlockSize |
| static constexpr index_t | kM0 = BlockFmhaShape::kM0 |
| static constexpr index_t | kN0 = BlockFmhaShape::kN0 |
| static constexpr index_t | kK0 = BlockFmhaShape::kK0 |
| static constexpr index_t | kN1 = BlockFmhaShape::kN1 |
| static constexpr index_t | kK1 = BlockFmhaShape::kK1 |
| static constexpr index_t | kQKHeaddim = BlockFmhaShape::kQKHeaddim |
| static constexpr index_t | kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim |
| static constexpr bool | kIsGroupMode = Problem::kIsGroupMode |
| static constexpr bool | kPadSeqLenQ = Problem::kPadSeqLenQ |
| static constexpr bool | kPadSeqLenK = Problem::kPadSeqLenK |
| static constexpr bool | kPadHeadDimQ = Problem::kPadHeadDimQ |
| static constexpr bool | kPadHeadDimV = (kQKHeaddim < kSubQKHeaddim) ? 1 : Problem::kPadHeadDimV |
| static constexpr auto | BiasEnum = Problem::BiasEnum |
| static constexpr bool | kStoreLSE = Problem::kStoreLSE |
| static constexpr bool | kHasDropout = Problem::kHasDropout |
| static constexpr bool | kHasLogitsSoftCap = Problem::kHasLogitsSoftCap |
| static constexpr index_t | kAlignmentQ |
| static constexpr index_t | kAlignmentK |
| static constexpr index_t | kAlignmentV |
| static constexpr index_t | kAlignmentO |
| static constexpr index_t | kAlignmentBias |
| static constexpr index_t | kBlockPerCu |
| static constexpr const char * | name = "qr_async" |
Member Typedef Documentation
◆ AttentionVariant
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSWholeKPrefetch< Problem_, Policy_ >::AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant> |
◆ BiasDataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSWholeKPrefetch< Problem_, Policy_ >::BiasDataType = remove_cvref_t<typename Problem::BiasDataType> |
◆ BlockFmhaShape
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSWholeKPrefetch< Problem_, Policy_ >::BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape> |
◆ DropoutType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSWholeKPrefetch< Problem_, Policy_ >::DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout> |
◆ FmhaMask
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSWholeKPrefetch< Problem_, Policy_ >::FmhaMask = remove_cvref_t<typename Problem::FmhaMask> |
◆ KDataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSWholeKPrefetch< Problem_, Policy_ >::KDataType = remove_cvref_t<typename Problem::KDataType> |
◆ LSEDataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSWholeKPrefetch< Problem_, Policy_ >::LSEDataType = remove_cvref_t<typename Problem::LSEDataType> |
◆ OaccDataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSWholeKPrefetch< Problem_, Policy_ >::OaccDataType = remove_cvref_t<typename Problem::OaccDataType> |
◆ ODataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSWholeKPrefetch< Problem_, Policy_ >::ODataType = remove_cvref_t<typename Problem::ODataType> |
◆ PDataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSWholeKPrefetch< Problem_, Policy_ >::PDataType = remove_cvref_t<typename Problem::PDataType> |
◆ Policy
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSWholeKPrefetch< Problem_, Policy_ >::Policy = remove_cvref_t<Policy_> |
◆ Problem
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSWholeKPrefetch< Problem_, Policy_ >::Problem = remove_cvref_t<Problem_> |
◆ QDataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSWholeKPrefetch< Problem_, Policy_ >::QDataType = remove_cvref_t<typename Problem::QDataType> |
◆ RandValOutputDataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSWholeKPrefetch< Problem_, Policy_ >::RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType> |
◆ SaccDataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSWholeKPrefetch< Problem_, Policy_ >::SaccDataType = remove_cvref_t<typename Problem::SaccDataType> |
◆ SMPLComputeDataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSWholeKPrefetch< Problem_, Policy_ >::SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType> |
◆ VDataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSWholeKPrefetch< Problem_, Policy_ >::VDataType = remove_cvref_t<typename Problem::VDataType> |
◆ VLayout
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSWholeKPrefetch< Problem_, Policy_ >::VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout> |
Member Function Documentation
◆ GetSmemSize()
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
|
inline static constexpr |
◆ operator()() [1/2]
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
template<typename QDramBlockWindowTmp, typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp, typename PositionEncoding, typename AttentionVariantParams, typename BlockIndices>
|
inline |
◆ operator()() [2/2]
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
template<typename QDramBlockWindowTmp, typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp, typename QElementFunction, typename KElementFunction, typename VElementFunction, typename BiasElementFunction, typename LSEElementFunction, typename SAccElementFunction, typename PComputeElementFunction, typename OAccElementFunction, typename PositionEncoding, typename AttentionVariantParams, typename BlockIndices>
|
inline |
NOTICE: the bias might be a materialized mask that includes -inf values; this case needs special consideration.
Member Data Documentation
◆ BiasEnum
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
|
static constexpr |
◆ kAlignmentBias
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
|
static constexpr |
Initial value:
=
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>()
static constexpr bool kPadSeqLenK
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:64
◆ kAlignmentK
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
|
static constexpr |
Initial value:
=
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>()
static constexpr index_t kPadHeadDimQ
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:52
◆ kAlignmentO
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
|
static constexpr |
Initial value:
=
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>()
static constexpr bool kPadHeadDimV
Definition block_fmha_bwd_dot_do_o.hpp:24
◆ kAlignmentQ
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
|
static constexpr |
Initial value:
=
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>()
◆ kAlignmentV
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
|
static constexpr |
Initial value:
= []() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
else
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
}()
remove_cvref_t< Policy_ > Policy
Definition block_fmha_fwd_appendkv_pipeline.hpp:16
remove_cvref_t< Problem_ > Problem
Definition block_fmha_fwd_appendkv_pipeline.hpp:15
◆ kBlockPerCu
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
|
static constexpr |
Initial value:
= []() {
if constexpr(Problem::kBlockPerCu != -1)
return Problem::kBlockPerCu;
else
{
{
return 2;
}
{
return 2;
}
{
return 1;
else
return 2;
}
{
return 1;
}
else
{
return 1;
};
}
}()
@ ELEMENTWISE_BIAS
Definition block_attention_bias_enum.hpp:14
static constexpr index_t kQKHeaddim
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:46
static constexpr index_t kQKHeaddim
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:45
static constexpr auto BiasEnum
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:55
◆ kBlockSize
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
|
static constexpr |
◆ kHasDropout
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
|
static constexpr |
◆ kHasLogitsSoftCap
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
|
static constexpr |
◆ kIsGroupMode
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
|
static constexpr |
◆ kK0
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
|
static constexpr |
◆ kK1
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
|
static constexpr |
◆ kM0
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
|
static constexpr |
◆ kN0
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
|
static constexpr |
◆ kN1
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
|
static constexpr |
◆ kPadHeadDimQ
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
|
static constexpr |
◆ kPadHeadDimV
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
|
static constexpr |
◆ kPadSeqLenK
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
|
static constexpr |
◆ kPadSeqLenQ
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
|
static constexpr |
◆ kQKHeaddim
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
|
static constexpr |
◆ kQLoadOnce
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
|
static constexpr |
◆ kStoreLSE
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
|
static constexpr |
◆ kSubQKHeaddim
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
|
static constexpr |
◆ name
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
|
static constexpr |
The documentation for this struct was generated from the following file: