fmha_fwd_kernel.hpp Source File
fmha_fwd_kernel.hpp
Go to the documentation of this file.
81 template <> struct t2s<ck_tile::fp8_t, ck_tile::bf16_t> { static constexpr const char * name = "fp8bf16"; };
82 template <> struct t2s<ck_tile::fp8_t, ck_tile::fp32_t> { static constexpr const char * name = "fp8fp32"; };
108 "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
109 "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
110 "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" +
111 "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
112 (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
113 "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
114 (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
115 (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kUseTrLoad ? "_trload" : "_ntrload");
#define _TS_
#define _SS_
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tensor_view.hpp:471
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
CK_TILE_HOST_DEVICE constexpr auto make_generic_attention_mask_from_lr_window(index_t left_size, index_t right_size, index_t y_total, index_t x_total, bool is_top_left=true)
Definition block_masking.hpp:632
@ ELEMENTWISE_BIAS
Definition block_attention_bias_enum.hpp:14
@ SYSTEM_NT1
Definition tile/core/arch/amd_buffer_addressing.hpp:1419
CK_TILE_HOST_DEVICE constexpr auto make_unmerge_transform(const UpLengths &up_lengths, bool_constant< Use24BitIntegerCalculation >=bool_constant< false >{})
Definition coordinate_transform.hpp:1622
CK_TILE_HOST_DEVICE auto make_alibi_from_lr_mask(DataType slope, index_t window_left_size, index_t window_right_size, index_t y_total, index_t x_total, GenericAttentionMaskEnum mask_enum)
Definition block_position_encoding.hpp:148
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1662
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_HOST_DEVICE constexpr auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition tensor_view.hpp:530
CK_TILE_DEVICE constexpr auto make_null_tile_window(const WindowLengths &window_lengths)
Definition null_tile_window.hpp:66
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1609
@ MASK_FROM_TOP_LEFT
Definition block_masking.hpp:15
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_view.hpp:511
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition block_position_encoding.hpp:48
Definition block_attention_bias_enum.hpp:19
Definition block_dropout.hpp:53
Definition block_position_encoding.hpp:137
Definition fmha_fwd_kernel.hpp:330
ck_tile::index_t kv_head_idx
Definition fmha_fwd_kernel.hpp:333
ck_tile::index_t batch_idx
Definition fmha_fwd_kernel.hpp:331
ck_tile::index_t qo_head_idx
Definition fmha_fwd_kernel.hpp:332
Definition fmha_fwd_kernel.hpp:194
ck_tile::index_t alibi_slope_stride
Definition fmha_fwd_kernel.hpp:197
const void * alibi_slope_ptr
Definition fmha_fwd_kernel.hpp:196
Definition fmha_fwd_kernel.hpp:189
ck_tile::index_t batch_stride_bias
Definition fmha_fwd_kernel.hpp:190
Definition fmha_fwd_kernel.hpp:270
ck_tile::index_t batch_stride_randval
Definition fmha_fwd_kernel.hpp:271
Definition fmha_fwd_kernel.hpp:291
ck_tile::index_t batch_stride_o
Definition fmha_fwd_kernel.hpp:295
const int32_t * cu_seqlen_k_ptr
Definition fmha_fwd_kernel.hpp:300
ck_tile::index_t batch_stride_q
Definition fmha_fwd_kernel.hpp:292
ck_tile::index_t batch_stride_k
Definition fmha_fwd_kernel.hpp:293
const int32_t * cu_seqlen_q_ptr
Definition fmha_fwd_kernel.hpp:299
ck_tile::index_t batch_stride_v
Definition fmha_fwd_kernel.hpp:294
Definition fmha_fwd_kernel.hpp:182
const void * bias_ptr
Definition fmha_fwd_kernel.hpp:183
ck_tile::index_t stride_bias
Definition fmha_fwd_kernel.hpp:184
ck_tile::index_t nhead_stride_bias
Definition fmha_fwd_kernel.hpp:185
Definition fmha_fwd_kernel.hpp:235
void init_dropout(float p_drop, const uint64_t *seed_ptr, const uint64_t *offset_ptr)
Definition fmha_fwd_kernel.hpp:248
float rp_undrop
Definition fmha_fwd_kernel.hpp:260
ck_tile::index_t stride_randval
Definition fmha_fwd_kernel.hpp:265
ck_tile::index_t nhead_stride_randval
Definition fmha_fwd_kernel.hpp:266
void * rand_val_ptr
Definition fmha_fwd_kernel.hpp:263
void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
Definition fmha_fwd_kernel.hpp:236
bool is_store_randval
Definition fmha_fwd_kernel.hpp:262
uint8_t p_undrop_in_uint8_t
Definition fmha_fwd_kernel.hpp:261
Definition fmha_fwd_kernel.hpp:131
ck_tile::index_t nhead_stride_k
Definition fmha_fwd_kernel.hpp:154
ck_tile::index_t seqlen_k
Definition fmha_fwd_kernel.hpp:138
ck_tile::index_t nhead_stride_o
Definition fmha_fwd_kernel.hpp:156
ck_tile::index_t nhead_ratio_qk
Definition fmha_fwd_kernel.hpp:145
ck_tile::index_t num_head_q
Definition fmha_fwd_kernel.hpp:142
ck_tile::index_t hdim_q
Definition fmha_fwd_kernel.hpp:139
const void * v_ptr
Definition fmha_fwd_kernel.hpp:134
const void * k_ptr
Definition fmha_fwd_kernel.hpp:133
ck_tile::index_t nhead_stride_q
Definition fmha_fwd_kernel.hpp:153
ck_tile::index_t stride_k
Definition fmha_fwd_kernel.hpp:149
ck_tile::index_t stride_o
Definition fmha_fwd_kernel.hpp:151
ck_tile::index_t stride_v
Definition fmha_fwd_kernel.hpp:150
ck_tile::index_t hdim_v
Definition fmha_fwd_kernel.hpp:140
ck_tile::index_t nhead_stride_v
Definition fmha_fwd_kernel.hpp:155
const void * q_ptr
Definition fmha_fwd_kernel.hpp:132
ck_tile::index_t seqlen_q
Definition fmha_fwd_kernel.hpp:137
ck_tile::index_t stride_q
Definition fmha_fwd_kernel.hpp:148
Definition fmha_fwd_kernel.hpp:214
ck_tile::index_t batch_stride_lse
Definition fmha_fwd_kernel.hpp:217
void * lse_ptr
Definition fmha_fwd_kernel.hpp:215
ck_tile::index_t nhead_stride_lse
Definition fmha_fwd_kernel.hpp:216
Definition fmha_fwd_kernel.hpp:221
bool is_drop_seed_offset_from_host
Definition fmha_fwd_kernel.hpp:231
ValueOrPointer< uint64_t > drop_seed
Definition fmha_fwd_kernel.hpp:229
ValueOrPointer< uint64_t > drop_offset
Definition fmha_fwd_kernel.hpp:230
Definition fmha_fwd_kernel.hpp:124
Definition fmha_fwd_kernel.hpp:208
float scale_o
Definition fmha_fwd_kernel.hpp:210
float scale_p
Definition fmha_fwd_kernel.hpp:209
Definition fmha_fwd_kernel.hpp:316
const int32_t * seqlen_q_ptr
Definition fmha_fwd_kernel.hpp:319
const int32_t * seqstart_q_ptr
Definition fmha_fwd_kernel.hpp:317
const int32_t * seqlen_k_ptr
Definition fmha_fwd_kernel.hpp:320
const int32_t * cu_seqlen_k_ptr
Definition fmha_fwd_kernel.hpp:324
const int32_t * cu_seqlen_q_ptr
Definition fmha_fwd_kernel.hpp:323
const int32_t * seqstart_k_ptr
Definition fmha_fwd_kernel.hpp:318
float logits_soft_cap
Definition fmha_fwd_kernel.hpp:177
FmhaFwdLogitsSoftCapKargs()=default
float logits_soft_cap_rcp
Definition fmha_fwd_kernel.hpp:178
void init_logits_soft_cap(float logits_soft_cap_)
Definition fmha_fwd_kernel.hpp:163
Definition fmha_fwd_kernel.hpp:201
ck_tile::GenericAttentionMaskEnum mask_type
Definition fmha_fwd_kernel.hpp:204
ck_tile::index_t window_size_right
Definition fmha_fwd_kernel.hpp:203
ck_tile::index_t window_size_left
Definition fmha_fwd_kernel.hpp:203
Definition fmha_fwd_kernel.hpp:275
ck_tile::index_t min_seqlen_q
Definition fmha_fwd_kernel.hpp:276
static constexpr const char * name
Definition fmha_fwd_kernel.hpp:78
static constexpr const char * name
Definition fmha_fwd_kernel.hpp:80
static constexpr const char * name
Definition fmha_fwd_kernel.hpp:77
static constexpr const char * name
Definition fmha_fwd_kernel.hpp:81
static constexpr const char * name
Definition fmha_fwd_kernel.hpp:82
static constexpr const char * name
Definition fmha_fwd_kernel.hpp:79
static constexpr const char * name
Definition fmha_fwd_kernel.hpp:76
Definition fmha_fwd_kernel.hpp:75
Definition fmha_fwd_kernel.hpp:27
static constexpr bool kHasDropout
Definition fmha_fwd_kernel.hpp:56
static CK_TILE_HOST std::string GetName()
Definition fmha_fwd_kernel.hpp:85
ck_tile::remove_cvref_t< typename FmhaPipeline::RandValOutputDataType > RandValOutputDataType
Definition fmha_fwd_kernel.hpp:40
static constexpr bool kIsAvailable
Definition fmha_fwd_kernel.hpp:70
static constexpr bool kDoFp8StaticQuant
Definition fmha_fwd_kernel.hpp:57
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, const std::tuple< const void *, const void * > &drop_seed_offset, const void *cu_seqlen_q_ptr=nullptr, const void *cu_seqlen_k_ptr=nullptr)
Definition fmha_fwd_kernel.hpp:579
ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType > KDataType
Definition fmha_fwd_kernel.hpp:37
std::conditional_t< kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs > Kargs
Definition fmha_fwd_kernel.hpp:327
static constexpr ck_tile::index_t kBlockPerCu
Definition fmha_fwd_kernel.hpp:32
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargsImpl(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_q_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q, float p_drop, bool s_randval, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * > > drop_seed_offset, const void *cu_seqlen_q_ptr=nullptr, const void *cu_seqlen_k_ptr=nullptr)
Definition fmha_fwd_kernel.hpp:675
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, const std::tuple< uint64_t, uint64_t > &drop_seed_offset, const void *cu_seqlen_q_ptr=nullptr, const void *cu_seqlen_k_ptr=nullptr)
Definition fmha_fwd_kernel.hpp:482
ck_tile::remove_cvref_t< typename FmhaPipeline::ODataType > ODataType
Definition fmha_fwd_kernel.hpp:43
ck_tile::remove_cvref_t< typename FmhaPipeline::VLayout > VLayout
Definition fmha_fwd_kernel.hpp:46
static constexpr ck_tile::index_t kBlockSize
Definition fmha_fwd_kernel.hpp:30
ck_tile::remove_cvref_t< typename FmhaPipeline::BiasDataType > BiasDataType
Definition fmha_fwd_kernel.hpp:39
ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType > VDataType
Definition fmha_fwd_kernel.hpp:38
static constexpr ck_tile::index_t kBlockPerCuInput
Definition fmha_fwd_kernel.hpp:34
static constexpr bool kPadHeadDimV
Definition fmha_fwd_kernel.hpp:52
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_q_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q, float p_drop, bool s_randval, const std::tuple< uint64_t, uint64_t > &drop_seed_offset, const void *cu_seqlen_q_ptr=nullptr, const void *cu_seqlen_k_ptr=nullptr)
Definition fmha_fwd_kernel.hpp:817
static CK_TILE_DEVICE constexpr auto GetTileIndex(const Kargs &kargs)
Definition fmha_fwd_kernel.hpp:1017
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_q_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q, float p_drop, bool s_randval, const std::tuple< const void *, const void * > &drop_seed_offset, const void *cu_seqlen_q_ptr=nullptr, const void *cu_seqlen_k_ptr=nullptr)
Definition fmha_fwd_kernel.hpp:906
static constexpr bool kSkipMinSeqlenQ
Definition fmha_fwd_kernel.hpp:58
static constexpr std::string_view kPipelineName
Definition fmha_fwd_kernel.hpp:72
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition fmha_fwd_kernel.hpp:1094
ck_tile::remove_cvref_t< typename FmhaPipeline::LSEDataType > LSEDataType
Definition fmha_fwd_kernel.hpp:42
ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType > QDataType
Definition fmha_fwd_kernel.hpp:36
static CK_TILE_HOST dim3 BlockSize()
Definition fmha_fwd_kernel.hpp:1082
CK_TILE_DEVICE void run_(Kargs kargs) const
Definition fmha_fwd_kernel.hpp:1105
ck_tile::remove_cvref_t< typename FmhaPipeline::AttentionVariant > AttentionVariant
Definition fmha_fwd_kernel.hpp:60
static CK_TILE_HOST constexpr auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_, ck_tile::index_t hdim_v_, bool has_padded_seqlen_k=false)
Definition fmha_fwd_kernel.hpp:992
static constexpr bool kUseAsyncCopy
Definition fmha_fwd_kernel.hpp:64
ck_tile::remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition fmha_fwd_kernel.hpp:28
static constexpr bool kPadHeadDimQ
Definition fmha_fwd_kernel.hpp:51
ck_tile::remove_cvref_t< typename FmhaPipeline::SaccDataType > SaccDataType
Definition fmha_fwd_kernel.hpp:44
static constexpr bool kPadSeqLenQ
Definition fmha_fwd_kernel.hpp:49
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaMask > FmhaMask
Definition fmha_fwd_kernel.hpp:61
static constexpr bool kHasLogitsSoftCap
Definition fmha_fwd_kernel.hpp:53
ck_tile::remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition fmha_fwd_kernel.hpp:29
static constexpr bool kPadSeqLenK
Definition fmha_fwd_kernel.hpp:50
static constexpr bool kIsGroupMode
Definition fmha_fwd_kernel.hpp:48
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargsImpl(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * > > drop_seed_offset, const void *cu_seqlen_q_ptr=nullptr, const void *cu_seqlen_k_ptr=nullptr)
Definition fmha_fwd_kernel.hpp:338
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition fmha_fwd_kernel.hpp:1099
Definition variants.hpp:63
Definition block_dropout.hpp:39
Definition variants.hpp:51
Definition unary_element_function.hpp:12
Definition tile/core/utility/functional.hpp:86
Definition coordinate_transform.hpp:1392
Definition unary_element_function.hpp:56
Definition tile/core/numeric/math.hpp:28
Definition tile/core/container/sequence.hpp:49
Definition fmha_fwd_kernel.hpp:224
T val
Definition fmha_fwd_kernel.hpp:225
const T * ptr
Definition fmha_fwd_kernel.hpp:226