BlockFmhaFwdV3Pipeline< Problem_, Policy_ > Struct Template Reference

BlockFmhaFwdV3Pipeline< Problem_, Policy_ > Struct Template Reference

Composable Kernel: ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ > Struct Template Reference
ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ > Struct Template Reference

#include <block_fmha_fwd_v3_pipeline.hpp>

Public Types

using Problem = ck_tile::remove_cvref_t<Problem_>
using Policy = ck_tile::remove_cvref_t<Policy_>
using QDataType = ck_tile::remove_cvref_t<typename Problem::QDataType>
using KDataType = ck_tile::remove_cvref_t<typename Problem::KDataType>
using VDataType = ck_tile::remove_cvref_t<typename Problem::VDataType>
using SaccDataType = ck_tile::remove_cvref_t<typename Problem::SaccDataType>
using SMPLComputeDataType = ck_tile::remove_cvref_t<typename Problem::SMPLComputeDataType>
using LSEDataType = ck_tile::remove_cvref_t<typename Problem::LSEDataType>
using PDataType = ck_tile::remove_cvref_t<typename Problem::PDataType>
using OaccDataType = ck_tile::remove_cvref_t<typename Problem::OaccDataType>
using ODataType = ck_tile::remove_cvref_t<typename Problem::ODataType>
using FmhaMask = ck_tile::remove_cvref_t<typename Problem::FmhaMask>
using BlockFmhaShape = ck_tile::remove_cvref_t<typename Problem::BlockFmhaShape>

Public Member Functions

template<typename QDramBlockWindowTmp, typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename LSEDramBlockWindowTmp, typename QElementFunction, typename KElementFunction, typename VElementFunction, typename LSEElementFunction, typename SAccElementFunction, typename PComputeElementFunction, typename OAccElementFunction>
CK_TILE_DEVICE auto operator() (const QDramBlockWindowTmp &q_dram_block_window_tmp, const QElementFunction &q_element_func, const KDramBlockWindowTmp &k_dram_block_window_tmp, const KElementFunction &k_element_func, const VDramBlockWindowTmp &v_dram_block_window_tmp, const VElementFunction &v_element_func, LSEDramBlockWindowTmp &lse_dram_window_tmp, const LSEElementFunction &lse_element_func, const SAccElementFunction &s_acc_element_func, const PComputeElementFunction &p_compute_element_func, const OAccElementFunction &o_acc_element_func, FmhaMask mask, float scale_s, void *smem_ptr) const
template<typename QDramBlockWindowTmp, typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename LSEDramBlockWindowTmp>
CK_TILE_DEVICE auto operator() (const QDramBlockWindowTmp &q_dram_block_window_tmp, const KDramBlockWindowTmp &k_dram_block_window_tmp, const VDramBlockWindowTmp &v_dram_block_window_tmp, LSEDramBlockWindowTmp &lse_dram_block_window_tmp, FmhaMask mask, float scale_s, void *smem_ptr) const

Static Public Member Functions

static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize ()
template<ck_tile::index_t MPerBlock, ck_tile::index_t NPerBlock>
static CK_TILE_DEVICE constexpr auto MakeSimpleLdsDesc ()
template<ck_tile::index_t MPerBlock>
static CK_TILE_DEVICE constexpr auto MakeSimpleLdsDesc1D ()
template<typename DataType, typename Descriptor>
static CK_TILE_DEVICE constexpr auto make_lds_tile_window (void *base, const Descriptor &desc)
template<uint16_t Vmcnt, uint8_t Lgkmcnt, uint8_t Expcnt = 7>
static CK_TILE_DEVICE constexpr void s_waitcnt ()
template<uint16_t Vmcnt>
static CK_TILE_DEVICE constexpr void s_waitcnt_vmcnt ()
template<uint8_t Lgkmcnt>
static CK_TILE_DEVICE constexpr void s_waitcnt_lgkmcnt ()

Static Public Attributes

static constexpr ck_tile::index_t kBlockSize = Problem::kBlockSize
static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0
static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0
static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0
static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1
static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1
static constexpr ck_tile::index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim
static constexpr ck_tile::index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim
static constexpr bool kIsGroupMode = Problem::kIsGroupMode
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV
static constexpr bool kStoreLSE = Problem::kStoreLSE
static constexpr ck_tile::index_t kAlignmentQ
static constexpr ck_tile::index_t kAlignmentK
static constexpr ck_tile::index_t kAlignmentV
static constexpr ck_tile::index_t kAlignmentO
static constexpr ck_tile::index_t kBlockPerCu

Member Typedef Documentation

◆ BlockFmhaShape

template<typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
using ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::BlockFmhaShape = ck_tile::remove_cvref_t<typename Problem::BlockFmhaShape>

◆ FmhaMask

template<typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
using ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::FmhaMask = ck_tile::remove_cvref_t<typename Problem::FmhaMask>

◆ KDataType

template<typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
using ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::KDataType = ck_tile::remove_cvref_t<typename Problem::KDataType>

◆ LSEDataType

template<typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
using ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::LSEDataType = ck_tile::remove_cvref_t<typename Problem::LSEDataType>

◆ OaccDataType

template<typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
using ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::OaccDataType = ck_tile::remove_cvref_t<typename Problem::OaccDataType>

◆ ODataType

template<typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
using ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::ODataType = ck_tile::remove_cvref_t<typename Problem::ODataType>

◆ PDataType

template<typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
using ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::PDataType = ck_tile::remove_cvref_t<typename Problem::PDataType>

◆ Policy

template<typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
using ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::Policy = ck_tile::remove_cvref_t<Policy_>

◆ Problem

template<typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
using ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::Problem = ck_tile::remove_cvref_t<Problem_>

◆ QDataType

template<typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
using ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::QDataType = ck_tile::remove_cvref_t<typename Problem::QDataType>

◆ SaccDataType

template<typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
using ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::SaccDataType = ck_tile::remove_cvref_t<typename Problem::SaccDataType>

◆ SMPLComputeDataType

template<typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
using ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::SMPLComputeDataType = ck_tile::remove_cvref_t<typename Problem::SMPLComputeDataType>

◆ VDataType

template<typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
using ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::VDataType = ck_tile::remove_cvref_t<typename Problem::VDataType>

Member Function Documentation

◆ GetSmemSize()

template<typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
CK_TILE_HOST_DEVICE constexpr ck_tile::index_t ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::GetSmemSize ( )
inline static constexpr

◆ make_lds_tile_window()

template<typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
template<typename DataType, typename Descriptor>
CK_TILE_DEVICE constexpr auto ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::make_lds_tile_window ( void * base,
const Descriptor & desc )
inline static constexpr

◆ MakeSimpleLdsDesc()

template<typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
template<ck_tile::index_t MPerBlock, ck_tile::index_t NPerBlock>
CK_TILE_DEVICE constexpr auto ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::MakeSimpleLdsDesc ( )
inline static constexpr

◆ MakeSimpleLdsDesc1D()

template<typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
template<ck_tile::index_t MPerBlock>
CK_TILE_DEVICE constexpr auto ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::MakeSimpleLdsDesc1D ( )
inline static constexpr

◆ operator()() [1/2]

template<typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
template<typename QDramBlockWindowTmp, typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename LSEDramBlockWindowTmp>
CK_TILE_DEVICE auto ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::operator() ( const QDramBlockWindowTmp & q_dram_block_window_tmp,
const KDramBlockWindowTmp & k_dram_block_window_tmp,
const VDramBlockWindowTmp & v_dram_block_window_tmp,
LSEDramBlockWindowTmp & lse_dram_block_window_tmp,
FmhaMask mask,
float scale_s,
void * smem_ptr ) const
inline

◆ operator()() [2/2]

template<typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
template<typename QDramBlockWindowTmp, typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename LSEDramBlockWindowTmp, typename QElementFunction, typename KElementFunction, typename VElementFunction, typename LSEElementFunction, typename SAccElementFunction, typename PComputeElementFunction, typename OAccElementFunction>
CK_TILE_DEVICE auto ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::operator() ( const QDramBlockWindowTmp & q_dram_block_window_tmp,
const QElementFunction & q_element_func,
const KDramBlockWindowTmp & k_dram_block_window_tmp,
const KElementFunction & k_element_func,
const VDramBlockWindowTmp & v_dram_block_window_tmp,
const VElementFunction & v_element_func,
LSEDramBlockWindowTmp & lse_dram_window_tmp,
const LSEElementFunction & lse_element_func,
const SAccElementFunction & s_acc_element_func,
const PComputeElementFunction & p_compute_element_func,
const OAccElementFunction & o_acc_element_func,
FmhaMask mask,
float scale_s,
void * smem_ptr ) const
inline

FIXME: use the future-predicting method to move the window

FIXME: use the future-predicting method to move the window

TODO: remove the sp_delta and use sp_compute directly

TODO: move some fmha_alu1() code here if necessary

Note: The compiler keeps moving the following instructions elsewhere because 'l' is first consumed later. To anchor them here, we rewrite the final addition in inline assembly to create a dependency, forcing the dependent instructions to be emitted at this point.

Note: The compiler keeps sinking the conversion instructions because the result 'p' is only consumed later. To anchor them here, we rewrite the cast_tile() call as inline assembly, forcing the conversions to be emitted at this point.

Note: Place fmha_alu1() at the end of the phase. The surrounding inline assembly can interfere with the behavior of sched_group_barrier(), so ending the phase here avoids unintended reordering.

NOTICE: Use inline asm v_pk_mul_f32 to reduce latency. The fmha_alu_D_upd() call should be placed at the end of a phase.

TODO: find a better way to map the fmha_alu(0,96) call

◆ s_waitcnt()

template<typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
template<uint16_t Vmcnt, uint8_t Lgkmcnt, uint8_t Expcnt = 7>
CK_TILE_DEVICE constexpr void ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::s_waitcnt ( )
inline static constexpr

◆ s_waitcnt_lgkmcnt()

template<typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
template<uint8_t Lgkmcnt>
CK_TILE_DEVICE constexpr void ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::s_waitcnt_lgkmcnt ( )
inline static constexpr

◆ s_waitcnt_vmcnt()

template<typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
template<uint16_t Vmcnt>
CK_TILE_DEVICE constexpr void ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::s_waitcnt_vmcnt ( )
inline static constexpr

Member Data Documentation

◆ kAlignmentK

template<typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
ck_tile::index_t ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::kAlignmentK
static constexpr
Initial value:
=
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>()
static constexpr index_t kPadHeadDimQ
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:52

◆ kAlignmentO

template<typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
ck_tile::index_t ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::kAlignmentO
static constexpr
Initial value:
=
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>()
static constexpr bool kPadHeadDimV
Definition block_fmha_bwd_dot_do_o.hpp:24

◆ kAlignmentQ

template<typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
ck_tile::index_t ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::kAlignmentQ
static constexpr
Initial value:
=
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>()

◆ kAlignmentV

template<typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
ck_tile::index_t ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::kAlignmentV
static constexpr
Initial value:
=
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>()
static constexpr index_t kPadHeadDimV
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:53

◆ kBlockPerCu

template<typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
ck_tile::index_t ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::kBlockPerCu
static constexpr
Initial value:
= []() {
if constexpr(Problem::kBlockPerCu != -1)
return Problem::kBlockPerCu;
else
{
return 2;
}
}()

◆ kBlockSize

template<typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
ck_tile::index_t ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::kBlockSize = Problem::kBlockSize
static constexpr

◆ kIsGroupMode

template<typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
bool ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::kIsGroupMode = Problem::kIsGroupMode
static constexpr

◆ kK0

template<typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
ck_tile::index_t ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::kK0 = BlockFmhaShape::kK0
static constexpr

◆ kK1

template<typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
ck_tile::index_t ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::kK1 = BlockFmhaShape::kK1
static constexpr

◆ kM0

template<typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
ck_tile::index_t ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::kM0 = BlockFmhaShape::kM0
static constexpr

◆ kN0

template<typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
ck_tile::index_t ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::kN0 = BlockFmhaShape::kN0
static constexpr

◆ kN1

template<typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
ck_tile::index_t ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::kN1 = BlockFmhaShape::kN1
static constexpr

◆ kPadHeadDimQ

template<typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
bool ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::kPadHeadDimQ = Problem::kPadHeadDimQ
static constexpr

◆ kPadHeadDimV

template<typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
bool ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::kPadHeadDimV = Problem::kPadHeadDimV
static constexpr

◆ kPadSeqLenK

template<typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
bool ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::kPadSeqLenK = Problem::kPadSeqLenK
static constexpr

◆ kPadSeqLenQ

template<typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
bool ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::kPadSeqLenQ = Problem::kPadSeqLenQ
static constexpr

◆ kQKHeaddim

template<typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
ck_tile::index_t ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::kQKHeaddim = BlockFmhaShape::kQKHeaddim
static constexpr

◆ kStoreLSE

template<typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
bool ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::kStoreLSE = Problem::kStoreLSE
static constexpr

◆ kSubQKHeaddim

template<typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
ck_tile::index_t ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim
static constexpr

The documentation for this struct was generated from the following file: