13template <
typename DsDataType,
16 typename CShuffleDataType,
23 index_t CShuffleMRepeatPerShuffle,
24 index_t CShuffleNRepeatPerShuffle,
25 typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
26 typename CDEShuffleBlockTransferScalarPerVectors,
27 typename CDEElementwiseOperation,
29 typename BlockwiseGemmPipe,
42 CShuffleMRepeatPerShuffle,
43 CShuffleNRepeatPerShuffle,
44 CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
45 CDEShuffleBlockTransferScalarPerVectors,
46 CDEElementwiseOperation,
61 CShuffleMRepeatPerShuffle,
62 CShuffleNRepeatPerShuffle,
63 CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
64 CDEShuffleBlockTransferScalarPerVectors,
65 CDEElementwiseOperation,
78 template <
typename DoPads, index_t MPerTile, index_t NPerTile>
81 const auto grid_desc_m_n =
84 grid_desc_m_n,
make_tuple(MPerTile, NPerTile), DoPads{});
87 template <
typename DoPads, index_t MPerTile, index_t NPerTile>
92 const auto grid_desc_m_n =
95 grid_desc_m_n,
make_tuple(MPerTile, NPerTile), DoPads{});
98 template <
typename Gr
idDescriptor_M_N>
99 __host__ __device__
static constexpr auto
102 const auto M = grid_desc_m_n.GetLength(
I0);
103 const auto NBlock = grid_desc_m_n.GetLength(
I1);
104 const auto MBlock = M / MPerBlock;
113 return grid_desc_mblock_mperblock_nblock;
123 EDataType* p_welford_var_grid_,
124 int32_t* p_welford_count_grid_,
143 typename DsGridPointer,
144 typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
145 typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>
146 __device__
void Run(CThreadBuf& c_thread_buf,
147 DsGridPointer p_ds_grid,
150 const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
151 ds_grid_desc_mblock_mperblock_nblock_nperblock,
152 const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
153 e_grid_desc_mblock_mperblock_nblock_nperblock,
154 CDEElementwiseOperation& cde_element_op,
163 ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
168 p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
170 auto mean_var_grid_desc_mblock_mperblock_nblock =
178 p_welford_var_grid, mean_var_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
180 auto count_grid_desc_mblock_mperblock_nblock =
186 constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
190 static_cast<CShuffleDataType*
>(p_shared),
191 c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
192 .GetElementSpaceSize());
196 tie(c_shuffle_block_buf),
198 {
return ds_grid_buf[i]; },
211 constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
213 GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
217 tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
219 {
return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
223 auto cde_shuffle_block_copy_lds_and_global =
226 e_grid_desc_mblock_mperblock_nblock_nperblock,
233 c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
237 constexpr index_t PostShuffleThreadSliceSize_M =
238 (CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma) /
239 CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(
I1);
241 constexpr index_t PostShuffleThreadSliceSize_N =
242 (CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma) /
243 CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(
I3);
245 constexpr auto PostShuffleThreadSliceSize_M_N =
249 constexpr auto post_shuffle_thread_desc_m_n =
256 post_shuffle_thread_desc_m_n.GetElementSpaceSize());
258 using PostShuffleThreadClusterSize_M_N =
Sequence<
259 CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(
I1),
260 CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(
I3)>;
262 constexpr auto post_shuffle_thread_cluster_desc =
265 const auto post_shuffle_thread_cluster_idx =
266 post_shuffle_thread_cluster_desc.CalculateBottomIndex(
269 const auto post_shuffle_thread_data_idx_begin =
270 post_shuffle_thread_cluster_idx * PostShuffleThreadSliceSize_M_N;
275 constexpr auto thread_welford_dst_desc_m =
279 decltype(thread_welford_src_desc_m_k),
280 decltype(thread_welford_dst_desc_m)>;
284 PostShuffleThreadClusterSize_M_N,
288 constexpr int num_shuffleM =
289 MPerBlock / (CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma);
291 constexpr int num_shuffleN =
292 NPerBlock / (CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma);
295 thread_welford_dst_desc_m.GetElementSpaceSize()));
297 using welford_count_vgpr_type =
299 thread_welford_dst_desc_m.GetElementSpaceSize()));
306 int max_count = PostShuffleThreadSliceSize_N * num_shuffleN;
307 const auto nblock = mean_var_grid_desc_mblock_mperblock_nblock.GetLength(
I2);
310 if(block_n_id % nblock == nblock - 1)
312 constexpr index_t NPerShuffleBlock =
313 CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma;
315 int NPerBlockTail =
NRaw - NPerBlock * (nblock - 1);
317 PostShuffleThreadSliceSize_N * (post_shuffle_thread_cluster_idx[
I1] + 1);
318 int shuffle_step = 0;
319 while(thread_max_len <= NPerBlockTail && shuffle_step < num_shuffleN)
322 thread_max_len += NPerShuffleBlock;
326 if(thread_max_len - NPerBlockTail > PostShuffleThreadSliceSize_N)
328 else if(NPerBlockTail > thread_max_len)
329 delta = PostShuffleThreadSliceSize_N;
331 delta = PostShuffleThreadSliceSize_N - thread_max_len + NPerBlockTail;
333 max_count = shuffle_step * PostShuffleThreadSliceSize_N + delta;
338 threadwise_welfords(i).max_count_ = max_count;
340 thread_welford_dst_desc_m.GetElementSpaceSize());
343 thread_welford_dst_desc_m.GetElementSpaceSize());
346 thread_welford_dst_desc_m.GetElementSpaceSize());
351 welford_count_thread_bufs(i)(j) = 0;
355 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
357 static_assert(num_access == sfc_cde_global.GetNumOfAccess(),
"wrong!");
360 int shuffleM_index = __builtin_amdgcn_readfirstlane(0);
366 c_thread_copy_vgpr_to_lds.Run(
367 c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
368 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
370 c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
371 c_shuffle_block_buf);
377 cde_shuffle_block_copy_lds_and_global.RunRead(c_ds_desc_refs, c_ds_buf_refs);
380 cde_shuffle_block_copy_lds_and_global.RunWriteAndStoreVgpr(
381 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
383 tie(post_shuffle_thread_desc_m_n),
386 if constexpr(access_id < num_access - 1)
388 constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id);
391 cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow(
392 c_ds_desc_refs, i +
I1, cde_global_step);
396 cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow(
397 tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step);
401 auto& threadwise_welford = threadwise_welfords(shuffleM_index);
402 auto& mean_thread_buf = mean_thread_bufs(shuffleM_index);
403 auto& var_thread_buf = var_thread_bufs(shuffleM_index);
405 threadwise_welford.Run(e_thread_buf, mean_thread_buf, var_thread_buf);
407 if constexpr(access_id < num_access - 1)
409 constexpr auto de_global_step = sfc_cde_global.GetForwardStep(access_id);
410 constexpr int shuffleMInc =
412 c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetLength(
414 shuffleM_index = __builtin_amdgcn_readfirstlane(shuffleM_index + shuffleMInc);
420 auto& mean_thread_buf = mean_thread_bufs(i);
421 auto& var_thread_buf = var_thread_bufs(i);
422 auto& count_thread_buf = welford_count_thread_bufs(i);
426 count_thread_buf(j) = threadwise_welfords(i).cur_count_;
430 if(post_shuffle_thread_cluster_idx[
I1] == 0)
435 constexpr int shuffleMPerBlock =
436 c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetLength(
441 shuffleMPerBlock * i + post_shuffle_thread_data_idx_begin[
I0],
447 decltype(thread_welford_desc_I_m_I),
448 decltype(mean_var_grid_desc_mblock_mperblock_nblock),
456 true>{mean_var_grid_desc_mblock_mperblock_nblock,
457 mean_var_count_thread_copy_index,
460 mean_var_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
463 mean_var_grid_desc_mblock_mperblock_nblock,
466 mean_var_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
469 mean_var_grid_desc_mblock_mperblock_nblock,
474 if(i == 0 && block_m_id == 0 && post_shuffle_thread_cluster_idx[
I0] == 0)
479 decltype(thread_welford_desc_I_m_I),
480 decltype(count_grid_desc_mblock_mperblock_nblock),
488 false>{count_grid_desc_mblock_mperblock_nblock,
489 mean_var_count_thread_copy_index,
492 count_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
495 count_grid_desc_mblock_mperblock_nblock,
496 welford_count_grid_buf);
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr auto PadTensorDescriptor(const TensorDesc &desc, const TileLengths &tile_lengths, DoPads)
Definition matrix_padder.hpp:19
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition utility/tuple.hpp:218
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple< X &... > &tx, const Tuple< Y &... > &ty)
Definition tuple_helper.hpp:42
signed int int32_t
Definition stdint.h:123
Definition utility/array.hpp:14
Definition blockwise_welford.hpp:25
static __device__ void Run(T &mean_value, T &var_value, CountDataType &count)
Definition blockwise_welford.hpp:51
Definition epilogue_cshuffle_v3_wmma_base.hpp:29
static constexpr auto I2
Definition epilogue_cshuffle_v3_wmma_base.hpp:32
static constexpr index_t NumDTensor
Definition epilogue_cshuffle_v3_wmma_base.hpp:38
static constexpr auto I0
Definition epilogue_cshuffle_v3_wmma_base.hpp:30
static constexpr auto I3
Definition epilogue_cshuffle_v3_wmma_base.hpp:33
static __device__ auto GetLDSToVmemEpilogueDescriptor(CDsDescRefs &c_ds_desc_refs, EGridDesc &e_grid_desc_mblock_mperblock_nblock_nperblock, CDEElementwiseOperation &cde_element_op, const index_t &block_m_id, const index_t &block_n_id)
Definition epilogue_cshuffle_v3_wmma_base.hpp:204
SpaceFillingCurve< Sequence< MRepeat, 1, 1, NRepeat, 1, 1, BlockwiseGemmPipe::MAccVgprs >, Sequence< 0, 1, 2, 3, 4, 5, 6 >, Sequence< CShuffleMRepeatPerShuffle, 1, 1, CShuffleNRepeatPerShuffle, 1, 1, BlockwiseGemmPipe::MAccVgprs > > SpaceFillingCurveVgpr
Definition epilogue_cshuffle_v3_wmma_base.hpp:42
static __device__ constexpr auto GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
Definition epilogue_cshuffle_v3_wmma_base.hpp:63
static constexpr auto I1
Definition epilogue_cshuffle_v3_wmma_base.hpp:31
static __device__ auto GetVgprToLDSEpilogueDescriptor()
Definition epilogue_cshuffle_v3_wmma_base.hpp:118
SpaceFillingCurve< Sequence< 1, MPerBlock, 1, NPerBlock >, Sequence< 0, 2, 1, 3 >, Sequence< 1, CShuffleMRepeatPerShuffle *BlockwiseGemmPipe::MWaves *MPerWmma, 1, CShuffleNRepeatPerShuffle *BlockwiseGemmPipe::NWaves *NPerWmma > > SpaceFillingCurveVmem
Definition epilogue_cshuffle_v3_wmma_base.hpp:53
static __device__ constexpr auto GetCShuffleLDSDescriptor()
Definition epilogue_cshuffle_v3_wmma_base.hpp:78
__host__ static __device__ auto MakeCountDescriptor_M_N(index_t M, index_t N)
Definition epilogue_cshuffle_v3_welford_wmma.hpp:88
static constexpr auto I2
Definition epilogue_cshuffle_v3_wmma_base.hpp:32
EDataType * p_welford_var_grid
Definition epilogue_cshuffle_v3_welford_wmma.hpp:503
GemmCountGridDesc_M_N gemm_count_grid_desc_m_nblock
Definition epilogue_cshuffle_v3_welford_wmma.hpp:507
EpilogueCShuffleBase< DsDataType, EDataType, AccDataType, CShuffleDataType, MPerBlock, NPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, CDEElementwiseOperation, ThisThreadBlock, BlockwiseGemmPipe > Base
Definition epilogue_cshuffle_v3_welford_wmma.hpp:50
EDataType * p_welford_mean_grid
Definition epilogue_cshuffle_v3_welford_wmma.hpp:502
index_t NRaw
Definition epilogue_cshuffle_v3_welford_wmma.hpp:505
static constexpr auto I0
Definition epilogue_cshuffle_v3_wmma_base.hpp:30
static constexpr auto I3
Definition epilogue_cshuffle_v3_wmma_base.hpp:33
decltype(MakeMeanVarDescriptor_M_N< Sequence< true, false >, MPerBlock, 1 >(1, 1)) GemmMeanVarGridDesc_M_N
Definition epilogue_cshuffle_v3_welford_wmma.hpp:116
__device__ void Run(CThreadBuf &c_thread_buf, DsGridPointer p_ds_grid, EDataType *p_e_grid, void *p_shared, const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &e_grid_desc_mblock_mperblock_nblock_nperblock, CDEElementwiseOperation &cde_element_op, const index_t &block_m_id, const index_t &block_n_id)
Definition epilogue_cshuffle_v3_welford_wmma.hpp:146
int32_t * p_welford_count_grid
Definition epilogue_cshuffle_v3_welford_wmma.hpp:504
__host__ static __device__ auto MakeMeanVarDescriptor_M_N(index_t M, index_t N)
Definition epilogue_cshuffle_v3_welford_wmma.hpp:79
__device__ EpilogueWelfordCShuffle(EDataType *p_welford_mean_grid_, EDataType *p_welford_var_grid_, int32_t *p_welford_count_grid_, index_t MRaw_, index_t NRaw_)
Definition epilogue_cshuffle_v3_welford_wmma.hpp:122
static __device__ constexpr auto GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
Definition epilogue_cshuffle_v3_wmma_base.hpp:63
decltype(MakeCountDescriptor_M_N< Sequence< true, false >, MPerBlock, 1 >(1, 1)) GemmCountGridDesc_M_N
Definition epilogue_cshuffle_v3_welford_wmma.hpp:119
static constexpr auto I1
Definition epilogue_cshuffle_v3_wmma_base.hpp:31
static __device__ auto GetVgprToLDSEpilogueDescriptor()
Definition epilogue_cshuffle_v3_wmma_base.hpp:118
GemmMeanVarGridDesc_M_N gemm_mean_var_grid_desc_m_nblock
Definition epilogue_cshuffle_v3_welford_wmma.hpp:506
__host__ static __device__ constexpr auto MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(const GridDescriptor_M_N &grid_desc_m_n)
Definition epilogue_cshuffle_v3_welford_wmma.hpp:100
static __device__ constexpr auto GetCShuffleLDSDescriptor()
Definition epilogue_cshuffle_v3_wmma_base.hpp:78
Definition utility/sequence.hpp:43
Definition thread_group.hpp:12
Definition threadwise_tensor_slice_transfer.hpp:39
Definition threadwise_welford.hpp:18
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340