24template <
typename GridwiseGemm,
25 typename A0B0B1DataType,
29 typename A0ElementwiseOperation,
30 typename B0ElementwiseOperation,
31 typename CDE0ElementwiseOperation,
32 typename B1ElementwiseOperation,
33 typename CDE1ElementwiseOperation,
34 typename A0GridDesc_AK0_M_AK1,
35 typename B0GridDesc_BK0_N_BK1,
36 typename D0sGridDesc_M_N,
37 typename B1GridDesc_BK0_N_BK1,
38 typename D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
39 typename E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
40 typename Block2E1TileMap,
41 typename ComputeBasePtrOfStridedBatch,
42 bool HasMainKBlockLoop>
44#if CK_USE_LAUNCH_BOUNDS
48 const A0B0B1DataType* __restrict__ p_a0_grid,
49 const A0B0B1DataType* __restrict__ p_b0_grid,
50 D0sPointer p_d0s_grid,
51 const A0B0B1DataType* __restrict__ p_b1_grid,
52 D1sPointer p_d1s_grid,
53 E1DataType* __restrict__ p_e1_grid,
54 const A0ElementwiseOperation a0_element_op,
55 const B0ElementwiseOperation b0_element_op,
56 const CDE0ElementwiseOperation cde0_element_op,
57 const B1ElementwiseOperation b1_element_op,
58 const CDE1ElementwiseOperation cde1_element_op,
59 const A0GridDesc_AK0_M_AK1 a0_grid_desc_ak0_m_ak1,
60 const B0GridDesc_BK0_N_BK1 b0_grid_desc_bk0_n_bk1,
61 const D0sGridDesc_M_N d0s_griddesc_m_n,
62 const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
63 const D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
64 d1s_grid_desc_mblock_mperblock_nblock_nperblock,
65 const E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
66 e1_grid_desc_mblock_mperblock_nblock_nperblock,
67 const Block2E1TileMap block_2_e1tile_map,
69 const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
71#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
72 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
74 __shared__
char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
75 const index_t num_blocks_per_batch =
76 __builtin_amdgcn_readfirstlane(
get_grid_size() / batch_count);
78 __builtin_amdgcn_readfirstlane(
get_block_1d_id() / num_blocks_per_batch);
80 const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
81 static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
82 const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
83 static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(g_idx)));
84 const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
85 static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
86 const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
87 static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
89 static_for<0, p_d0s_grid.Size(), 1>{}([&](
auto In) {
90 const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
91 static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx, In)));
92 p_d0s_grid(In) = p_d0s_grid(In) + d0_batch_offset;
95 static_for<0, p_d1s_grid.Size(), 1>{}([&](
auto In) {
96 const long_index_t d1_batch_offset = __builtin_amdgcn_readfirstlane(
97 static_cast<long_index_t>(compute_base_ptr_of_batch.GetD1BasePtr(g_idx, In)));
98 p_d1s_grid(In) = p_d1s_grid(In) + d1_batch_offset;
101 GridwiseGemm::template Run<HasMainKBlockLoop>(
102 p_a0_grid + a_batch_offset,
103 p_b0_grid + b_batch_offset,
105 p_b1_grid + b1_batch_offset,
107 p_e1_grid + c_batch_offset,
114 a0_grid_desc_ak0_m_ak1,
115 b0_grid_desc_bk0_n_bk1,
117 b1_grid_desc_bk0_n_bk1,
118 d1s_grid_desc_mblock_mperblock_nblock_nperblock,
119 e1_grid_desc_mblock_mperblock_nblock_nperblock,
134 ignore = a0_grid_desc_ak0_m_ak1;
135 ignore = b0_grid_desc_bk0_n_bk1;
136 ignore = d0s_griddesc_m_n;
137 ignore = b1_grid_desc_bk0_n_bk1;
138 ignore = d1s_grid_desc_mblock_mperblock_nblock_nperblock;
139 ignore = e1_grid_desc_mblock_mperblock_nblock_nperblock;
140 ignore = block_2_e1tile_map;
142 ignore = compute_base_ptr_of_batch;
149template <
typename A0Layout,
157 typename Acc0DataType,
158 typename D0sDataType,
160 typename Acc1DataType,
161 typename C1ShuffleDataType,
162 typename D1sDataType,
164 typename A0ElementwiseOperation,
165 typename B0ElementwiseOperation,
166 typename CDE0ElementwiseOperation,
167 typename B1ElementwiseOperation,
168 typename CDE1ElementwiseOperation,
174 index_t NumGemm0KPrefetchStage,
189 typename A0BlockTransferThreadClusterLengths_AK0_M_AK1,
190 typename A0BlockTransferThreadClusterArrangeOrder,
191 typename A0BlockTransferSrcAccessOrder,
192 index_t A0BlockTransferSrcVectorDim,
193 index_t A0BlockTransferSrcScalarPerVector,
194 index_t A0BlockTransferDstScalarPerVector_AK1,
195 bool A0BlockLdsExtraM,
196 typename B0BlockTransferThreadClusterLengths_BK0_N_BK1,
197 typename B0BlockTransferThreadClusterArrangeOrder,
198 typename B0BlockTransferSrcAccessOrder,
199 index_t B0BlockTransferSrcVectorDim,
200 index_t B0BlockTransferSrcScalarPerVector,
201 index_t B0BlockTransferDstScalarPerVector_BK1,
202 bool B0BlockLdsExtraN,
203 index_t CDE0BlockTransferSrcVectorDim,
204 index_t CDE0BlockTransferSrcScalaerPerVector,
205 typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
206 typename B1BlockTransferThreadClusterArrangeOrder,
207 typename B1BlockTransferSrcAccessOrder,
208 index_t B1BlockTransferSrcVectorDim,
209 index_t B1BlockTransferSrcScalarPerVector,
210 index_t B1BlockTransferDstScalarPerVector_BK1,
211 bool B1BlockLdsExtraN,
212 index_t C1ShuffleMXdlPerWavePerShuffle,
213 index_t C1ShuffleGemm0NXdlPerWavePerShuffle,
214 typename CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
215 index_t CDE1ShuffleBlockTransferScalarPerVector_NPerBlock,
230 A0ElementwiseOperation,
231 B0ElementwiseOperation,
232 CDE0ElementwiseOperation,
233 B1ElementwiseOperation,
234 CDE1ElementwiseOperation>
269 Gemm0MPerBlock, Gemm0NPerBlock, Gemm0KPerBlock};
273 Gemm0MPerBlock, Gemm1NPerBlock, Gemm1KPerBlock};
278 const auto a0_grid_desc_mraw_kraw = [&]() {
291 return gemm0_padder.PadADescriptor_M_K(a0_grid_desc_mraw_kraw);
297 const auto b0_grid_desc_nraw_kraw = [&]() {
310 return gemm0_padder.PadBDescriptor_N_K(b0_grid_desc_nraw_kraw);
314 template <
typename DLay>
317 const auto d0_grid_desc_mraw_nraw = [&]() {
330 return gemm0_padder.PadCDescriptor_M_N(d0_grid_desc_mraw_nraw);
336 const auto b1_grid_desc_nraw_kraw = [&]() {
349 return gemm1_padder.PadBDescriptor_N_K(b1_grid_desc_nraw_kraw);
353 template <
typename ELay>
356 const auto e1_grid_desc_mraw_nraw = [&]() {
369 return gemm1_padder.PadCDescriptor_M_N(e1_grid_desc_mraw_nraw);
373 const std::array<index_t, NumD1Tensor>& NRaws,
374 const std::array<index_t, NumD1Tensor>& DsStride)
386 const std::array<index_t, NumD1Tensor>& NRaws,
387 const std::array<index_t, NumD1Tensor>& DsStride)
402 std::array<index_t, NumD0Tensor> BatchStrideD0s,
404 std::array<index_t, NumD1Tensor> BatchStrideD1s,
406 : BatchStrideA0_(BatchStrideA0),
407 BatchStrideB0_(BatchStrideB0),
408 BatchStrideD0s_(BatchStrideD0s),
409 BatchStrideB1_(BatchStrideB1),
410 BatchStrideD1s_(BatchStrideD1s),
411 BatchStrideE1_(BatchStrideE1)
417 return g_idx *
static_cast<long_index_t>(BatchStrideA0_);
422 return g_idx *
static_cast<long_index_t>(BatchStrideB0_);
429 return g_idx *
static_cast<long_index_t>(BatchStrideD0s_[d1_idx]);
434 return g_idx *
static_cast<long_index_t>(BatchStrideB1_);
439 return g_idx *
static_cast<long_index_t>(BatchStrideE1_);
445 return g_idx *
static_cast<long_index_t>(BatchStrideD1s_[d1_idx]);
451 std::array<index_t, NumD0Tensor> BatchStrideD0s_;
453 std::array<index_t, NumD1Tensor> BatchStrideD1s_;
465 template <index_t Gemm0MXdlPerWave_>
474 A0ElementwiseOperation,
475 B0ElementwiseOperation,
476 CDE0ElementwiseOperation,
477 B1ElementwiseOperation,
478 CDE1ElementwiseOperation,
486 NumGemm0KPrefetchStage,
501 A0BlockTransferThreadClusterLengths_AK0_M_AK1,
502 A0BlockTransferThreadClusterArrangeOrder,
503 A0BlockTransferSrcAccessOrder,
504 A0BlockTransferSrcVectorDim,
505 A0BlockTransferSrcScalarPerVector,
506 A0BlockTransferDstScalarPerVector_AK1,
509 B0BlockTransferThreadClusterLengths_BK0_N_BK1,
510 B0BlockTransferThreadClusterArrangeOrder,
511 B0BlockTransferSrcAccessOrder,
512 B0BlockTransferSrcVectorDim,
513 B0BlockTransferSrcScalarPerVector,
514 B0BlockTransferDstScalarPerVector_BK1,
517 CDE0BlockTransferSrcVectorDim,
518 CDE0BlockTransferSrcScalaerPerVector,
519 B1BlockTransferThreadClusterLengths_BK0_N_BK1,
520 B1BlockTransferThreadClusterArrangeOrder,
521 B1BlockTransferSrcAccessOrder,
522 B1BlockTransferSrcVectorDim,
523 B1BlockTransferSrcScalarPerVector,
524 B1BlockTransferDstScalarPerVector_BK1,
527 C1ShuffleMXdlPerWavePerShuffle,
528 C1ShuffleGemm0NXdlPerWavePerShuffle,
529 CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
530 CDE1ShuffleBlockTransferScalarPerVector_NPerBlock,
549 const B0DataType* p_b0_grid,
550 std::array<const void*, NumD0Tensor> p_d0s_grid,
551 const B1DataType* p_b1_grid,
552 std::array<const void*, NumD1Tensor> p_d1s_grid,
553 E1DataType* p_e1_grid,
561 std::array<index_t, NumD0Tensor> StrideD0s,
563 std::array<index_t, NumD1Tensor> StrideD1s,
567 std::array<index_t, NumD0Tensor> BatchStrideD0s,
569 std::array<index_t, NumD1Tensor> BatchStrideD1s,
571 A0ElementwiseOperation a0_element_op,
572 B0ElementwiseOperation b0_element_op,
573 CDE0ElementwiseOperation cde0_element_op,
574 B1ElementwiseOperation b1_element_op,
575 CDE1ElementwiseOperation cde1_element_op)
628 p_d0s_grid_(i) =
static_cast<const D0DataType*
>(p_d0s_grid[i]);
640 p_d1s_grid_(i) =
static_cast<const D1DataType*
>(p_d1s_grid[i]);
690 template <
typename Gr
idwiseGemm>
699 throw std::runtime_error(
"wrong! GridwiseGemm has invalid setting");
702 auto e1_grid_desc_mblock_mperblock_nblock_nperblock =
703 GridwiseGemm::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
706 auto d1s_grid_desc_mblock_mperblock_nblock_nperblock =
707 GridwiseGemm::MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
716 auto launch_kernel = [&](
auto has_main_k_block_loop_) {
720 typename GridwiseGemm::D0sGridPointer,
721 typename GridwiseGemm::D1sGridPointer,
723 A0ElementwiseOperation,
724 B0ElementwiseOperation,
725 CDE0ElementwiseOperation,
726 B1ElementwiseOperation,
727 CDE1ElementwiseOperation,
732 typename GridwiseGemm::D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
733 typename GridwiseGemm::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
734 typename GridwiseGemm::DefaultBlock2E1TileMap,
735 ComputeBasePtrOfStridedBatch,
736 has_main_k_block_loop_>;
758 d1s_grid_desc_mblock_mperblock_nblock_nperblock,
759 e1_grid_desc_mblock_mperblock_nblock_nperblock,
767 if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
773 return launch_kernel(integral_constant<bool, false>{});
800 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
811 template <
typename RefLayout,
typename DsLayout, const index_t NumDTensor>
814 static bool valid =
true;
881 const B0DataType* p_b0,
882 std::array<const void*, NumD0Tensor> p_d0s,
883 const B1DataType* p_b1,
884 std::array<const void*, NumD1Tensor> p_d1s,
893 std::array<index_t, NumD0Tensor> StrideD0s,
895 std::array<index_t, NumD1Tensor> StrideD1s,
899 std::array<index_t, NumD0Tensor> BatchStrideD0s,
901 std::array<index_t, NumD1Tensor> BatchStrideD1s,
903 A0ElementwiseOperation a0_element_op,
904 B0ElementwiseOperation b0_element_op,
905 CDE0ElementwiseOperation cde0_element_op,
906 B1ElementwiseOperation b1_element_op,
907 CDE1ElementwiseOperation cde1_element_op)
917 StrideE1, BatchStrideA0,
918 BatchStrideB0, BatchStrideD0s,
919 BatchStrideB1, BatchStrideD1s,
920 BatchStrideE1, a0_element_op,
921 b0_element_op, cde0_element_op,
922 b1_element_op, cde1_element_op};
928 std::unique_ptr<BaseArgument>
931 std::array<const void*, NumD0Tensor> p_d0s,
933 std::array<const void*, NumD1Tensor> p_d1s,
942 std::array<ck::index_t, NumD0Tensor> StrideD0s,
944 std::array<ck::index_t, NumD1Tensor> StrideD1s,
948 std::array<ck::index_t, NumD0Tensor> BatchStrideD0s,
950 std::array<ck::index_t, NumD1Tensor> BatchStrideD1s,
952 A0ElementwiseOperation a0_element_op,
953 B0ElementwiseOperation b0_element_op,
954 CDE0ElementwiseOperation cde0_element_op,
955 B1ElementwiseOperation b1_element_op,
956 CDE1ElementwiseOperation cde1_element_op)
override
958 return std::make_unique<Argument>(
static_cast<const A0DataType*
>(p_a0),
959 static_cast<const B0DataType*
>(p_b0),
961 static_cast<const B1DataType*
>(p_b1),
963 static_cast<E1DataType*
>(p_e1),
991 return std::make_unique<Invoker>(
Invoker{});
997 auto str = std::stringstream();
1000 str <<
"DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle"
1002 << BlockSize <<
", "
1003 << Gemm0MPerBlock <<
", "
1004 << Gemm0NPerBlock <<
", "
1005 << Gemm0KPerBlock <<
", "
1009 << Gemm0MPerXdl <<
", "
1010 << Gemm0NPerXdl <<
", "
1011 << Gemm0MXdlPerWave <<
", "
1012 << Gemm0NXdlPerWave <<
", "
1013 << Gemm1NXdlPerWave <<
"> ";
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
__global__ void kernel_batched_gemm_gemm_xdl_cshuffle_v1(const A0B0B1DataType *__restrict__ p_a0_grid, const A0B0B1DataType *__restrict__ p_b0_grid, D0sPointer p_d0s_grid, const A0B0B1DataType *__restrict__ p_b1_grid, D1sPointer p_d1s_grid, E1DataType *__restrict__ p_e1_grid, const A0ElementwiseOperation a0_element_op, const B0ElementwiseOperation b0_element_op, const CDE0ElementwiseOperation cde0_element_op, const B1ElementwiseOperation b1_element_op, const CDE1ElementwiseOperation cde1_element_op, const A0GridDesc_AK0_M_AK1 a0_grid_desc_ak0_m_ak1, const B0GridDesc_BK0_N_BK1 b0_grid_desc_bk0_n_bk1, const D0sGridDesc_M_N d0s_griddesc_m_n, const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, const D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock d1s_grid_desc_mblock_mperblock_nblock_nperblock, const E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e1_grid_desc_mblock_mperblock_nblock_nperblock, const Block2E1TileMap block_2_e1tile_map, const index_t batch_count, const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:47
Definition convolution_backward_data_specialization.hpp:7
__device__ index_t get_grid_size()
Definition get_id.hpp:49
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
@ Default
Definition loop_scheduler.hpp:16
int64_t long_index_t
Definition ck.hpp:300
Definition ck/stream_config.hpp:10
Definition gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp:86
ck::GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< A0DataType, Acc0DataType, D0sDataType, Acc1DataType, C1ShuffleDataType, D1sDataType, E1DataType, A0ElementwiseOperation, B0ElementwiseOperation, CDE0ElementwiseOperation, B1ElementwiseOperation, CDE1ElementwiseOperation, InMemoryDataOperationEnum::Set, A0GridDesc_M_K, B0GridDesc_N_K, D0sGridDesc_M_N, B1GridDesc_N_K, D1sGridDesc_M_N, E1GridDesc_M_N, NumGemm0KPrefetchStage, BlockSize, Gemm0MPerBlock, Gemm0NPerBlock, Gemm0KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, A0K1, B0K1, B1K1, Gemm0MPerXdl, Gemm0NPerXdl, Gemm0MXdlPerWave_, Gemm0NXdlPerWave, Gemm1NXdlPerWave, A0BlockTransferThreadClusterLengths_AK0_M_AK1, A0BlockTransferThreadClusterArrangeOrder, A0BlockTransferSrcAccessOrder, A0BlockTransferSrcVectorDim, A0BlockTransferSrcScalarPerVector, A0BlockTransferDstScalarPerVector_AK1, true, A0BlockLdsExtraM, B0BlockTransferThreadClusterLengths_BK0_N_BK1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_BK1, true, B0BlockLdsExtraN, CDE0BlockTransferSrcVectorDim, CDE0BlockTransferSrcScalaerPerVector, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, C1ShuffleMXdlPerWavePerShuffle, C1ShuffleGemm0NXdlPerWavePerShuffle, CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDE1ShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched >::CheckValidity static __host__ constexpr bool CheckValidity(const A0GridDesc_M_K &a0_grid_desc_m_k, const B0GridDesc_N_K &b0_grid_desc_n_k, const B1GridDesc_N_K &b1_grid_desc_n_k, const E1GridDesc_M_N &e1_grid_desc_m_n, const Block2E1TileMap &block_2_e1tile_map)
Definition gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp:286
ck::GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< A0DataType, Acc0DataType, D0sDataType, Acc1DataType, C1ShuffleDataType, D1sDataType, E1DataType, A0ElementwiseOperation, B0ElementwiseOperation, CDE0ElementwiseOperation, B1ElementwiseOperation, CDE1ElementwiseOperation, InMemoryDataOperationEnum::Set, A0GridDesc_M_K, B0GridDesc_N_K, D0sGridDesc_M_N, B1GridDesc_N_K, D1sGridDesc_M_N, E1GridDesc_M_N, NumGemm0KPrefetchStage, BlockSize, Gemm0MPerBlock, Gemm0NPerBlock, Gemm0KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, A0K1, B0K1, B1K1, Gemm0MPerXdl, Gemm0NPerXdl, Gemm0MXdlPerWave_, Gemm0NXdlPerWave, Gemm1NXdlPerWave, A0BlockTransferThreadClusterLengths_AK0_M_AK1, A0BlockTransferThreadClusterArrangeOrder, A0BlockTransferSrcAccessOrder, A0BlockTransferSrcVectorDim, A0BlockTransferSrcScalarPerVector, A0BlockTransferDstScalarPerVector_AK1, true, A0BlockLdsExtraM, B0BlockTransferThreadClusterLengths_BK0_N_BK1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_BK1, true, B0BlockLdsExtraN, CDE0BlockTransferSrcVectorDim, CDE0BlockTransferSrcScalaerPerVector, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, C1ShuffleMXdlPerWavePerShuffle, C1ShuffleGemm0NXdlPerWavePerShuffle, CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDE1ShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched >::D0sGridPointer decltype(MakeD0sGridPointer()) D0sGridPointer
Definition gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp:553
ck::GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< A0DataType, Acc0DataType, D0sDataType, Acc1DataType, C1ShuffleDataType, D1sDataType, E1DataType, A0ElementwiseOperation, B0ElementwiseOperation, CDE0ElementwiseOperation, B1ElementwiseOperation, CDE1ElementwiseOperation, InMemoryDataOperationEnum::Set, A0GridDesc_M_K, B0GridDesc_N_K, D0sGridDesc_M_N, B1GridDesc_N_K, D1sGridDesc_M_N, E1GridDesc_M_N, NumGemm0KPrefetchStage, BlockSize, Gemm0MPerBlock, Gemm0NPerBlock, Gemm0KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, A0K1, B0K1, B1K1, Gemm0MPerXdl, Gemm0NPerXdl, Gemm0MXdlPerWave_, Gemm0NXdlPerWave, Gemm1NXdlPerWave, A0BlockTransferThreadClusterLengths_AK0_M_AK1, A0BlockTransferThreadClusterArrangeOrder, A0BlockTransferSrcAccessOrder, A0BlockTransferSrcVectorDim, A0BlockTransferSrcScalarPerVector, A0BlockTransferDstScalarPerVector_AK1, true, A0BlockLdsExtraM, B0BlockTransferThreadClusterLengths_BK0_N_BK1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_BK1, true, B0BlockLdsExtraN, CDE0BlockTransferSrcVectorDim, CDE0BlockTransferSrcScalaerPerVector, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, C1ShuffleMXdlPerWavePerShuffle, C1ShuffleGemm0NXdlPerWavePerShuffle, CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDE1ShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched >::MakeDefaultB1GridDescriptor_BK0_N_BK1 __host__ static __device__ constexpr auto MakeDefaultB1GridDescriptor_BK0_N_BK1(const B1GridDesc_N_K &b1_grid_desc_n_k)
Definition gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp:443
ck::GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< A0DataType, Acc0DataType, D0sDataType, Acc1DataType, C1ShuffleDataType, D1sDataType, E1DataType, A0ElementwiseOperation, B0ElementwiseOperation, CDE0ElementwiseOperation, B1ElementwiseOperation, CDE1ElementwiseOperation, InMemoryDataOperationEnum::Set, A0GridDesc_M_K, B0GridDesc_N_K, D0sGridDesc_M_N, B1GridDesc_N_K, D1sGridDesc_M_N, E1GridDesc_M_N, NumGemm0KPrefetchStage, BlockSize, Gemm0MPerBlock, Gemm0NPerBlock, Gemm0KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, A0K1, B0K1, B1K1, Gemm0MPerXdl, Gemm0NPerXdl, Gemm0MXdlPerWave_, Gemm0NXdlPerWave, Gemm1NXdlPerWave, A0BlockTransferThreadClusterLengths_AK0_M_AK1, A0BlockTransferThreadClusterArrangeOrder, A0BlockTransferSrcAccessOrder, A0BlockTransferSrcVectorDim, A0BlockTransferSrcScalarPerVector, A0BlockTransferDstScalarPerVector_AK1, true, A0BlockLdsExtraM, B0BlockTransferThreadClusterLengths_BK0_N_BK1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_BK1, true, B0BlockLdsExtraN, CDE0BlockTransferSrcVectorDim, CDE0BlockTransferSrcScalaerPerVector, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, C1ShuffleMXdlPerWavePerShuffle, C1ShuffleGemm0NXdlPerWavePerShuffle, CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDE1ShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched >::DefaultBlock2E1TileMap remove_cvref_t< decltype(MakeDefaultBlock2E1TileMap(E1GridDesc_M_N{}))> DefaultBlock2E1TileMap
Definition gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp:520
ck::GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< A0DataType, Acc0DataType, D0sDataType, Acc1DataType, C1ShuffleDataType, D1sDataType, E1DataType, A0ElementwiseOperation, B0ElementwiseOperation, CDE0ElementwiseOperation, B1ElementwiseOperation, CDE1ElementwiseOperation, InMemoryDataOperationEnum::Set, A0GridDesc_M_K, B0GridDesc_N_K, D0sGridDesc_M_N, B1GridDesc_N_K, D1sGridDesc_M_N, E1GridDesc_M_N, NumGemm0KPrefetchStage, BlockSize, Gemm0MPerBlock, Gemm0NPerBlock, Gemm0KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, A0K1, B0K1, B1K1, Gemm0MPerXdl, Gemm0NPerXdl, Gemm0MXdlPerWave_, Gemm0NXdlPerWave, Gemm1NXdlPerWave, A0BlockTransferThreadClusterLengths_AK0_M_AK1, A0BlockTransferThreadClusterArrangeOrder, A0BlockTransferSrcAccessOrder, A0BlockTransferSrcVectorDim, A0BlockTransferSrcScalarPerVector, A0BlockTransferDstScalarPerVector_AK1, true, A0BlockLdsExtraM, B0BlockTransferThreadClusterLengths_BK0_N_BK1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_BK1, true, B0BlockLdsExtraN, CDE0BlockTransferSrcVectorDim, CDE0BlockTransferSrcScalaerPerVector, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, C1ShuffleMXdlPerWavePerShuffle, C1ShuffleGemm0NXdlPerWavePerShuffle, CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDE1ShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched >::MakeDefaultB0GridDescriptor_BK0_N_BK1 __host__ static __device__ constexpr auto MakeDefaultB0GridDescriptor_BK0_N_BK1(const B0GridDesc_N_K &b0_grid_desc_n_k)
Definition gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp:386
ck::GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< A0DataType, Acc0DataType, D0sDataType, Acc1DataType, C1ShuffleDataType, D1sDataType, E1DataType, A0ElementwiseOperation, B0ElementwiseOperation, CDE0ElementwiseOperation, B1ElementwiseOperation, CDE1ElementwiseOperation, InMemoryDataOperationEnum::Set, A0GridDesc_M_K, B0GridDesc_N_K, D0sGridDesc_M_N, B1GridDesc_N_K, D1sGridDesc_M_N, E1GridDesc_M_N, NumGemm0KPrefetchStage, BlockSize, Gemm0MPerBlock, Gemm0NPerBlock, Gemm0KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, A0K1, B0K1, B1K1, Gemm0MPerXdl, Gemm0NPerXdl, Gemm0MXdlPerWave_, Gemm0NXdlPerWave, Gemm1NXdlPerWave, A0BlockTransferThreadClusterLengths_AK0_M_AK1, A0BlockTransferThreadClusterArrangeOrder, A0BlockTransferSrcAccessOrder, A0BlockTransferSrcVectorDim, A0BlockTransferSrcScalarPerVector, A0BlockTransferDstScalarPerVector_AK1, true, A0BlockLdsExtraM, B0BlockTransferThreadClusterLengths_BK0_N_BK1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_BK1, true, B0BlockLdsExtraN, CDE0BlockTransferSrcVectorDim, CDE0BlockTransferSrcScalaerPerVector, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, C1ShuffleMXdlPerWavePerShuffle, C1ShuffleGemm0NXdlPerWavePerShuffle, CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDE1ShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched >::MakeDefaultA0GridDescriptor_AK0_M_AK1 __host__ static __device__ constexpr auto MakeDefaultA0GridDescriptor_AK0_M_AK1(const A0GridDesc_M_K &a0_grid_desc_m_k)
Definition gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp:369
ck::GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< A0DataType, Acc0DataType, D0sDataType, Acc1DataType, C1ShuffleDataType, D1sDataType, E1DataType, A0ElementwiseOperation, B0ElementwiseOperation, CDE0ElementwiseOperation, B1ElementwiseOperation, CDE1ElementwiseOperation, InMemoryDataOperationEnum::Set, A0GridDesc_M_K, B0GridDesc_N_K, D0sGridDesc_M_N, B1GridDesc_N_K, D1sGridDesc_M_N, E1GridDesc_M_N, NumGemm0KPrefetchStage, BlockSize, Gemm0MPerBlock, Gemm0NPerBlock, Gemm0KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, A0K1, B0K1, B1K1, Gemm0MPerXdl, Gemm0NPerXdl, Gemm0MXdlPerWave_, Gemm0NXdlPerWave, Gemm1NXdlPerWave, A0BlockTransferThreadClusterLengths_AK0_M_AK1, A0BlockTransferThreadClusterArrangeOrder, A0BlockTransferSrcAccessOrder, A0BlockTransferSrcVectorDim, A0BlockTransferSrcScalarPerVector, A0BlockTransferDstScalarPerVector_AK1, true, A0BlockLdsExtraM, B0BlockTransferThreadClusterLengths_BK0_N_BK1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_BK1, true, B0BlockLdsExtraN, CDE0BlockTransferSrcVectorDim, CDE0BlockTransferSrcScalaerPerVector, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, C1ShuffleMXdlPerWavePerShuffle, C1ShuffleGemm0NXdlPerWavePerShuffle, CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDE1ShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched >::D1sGridPointer decltype(MakeD1sGridPointer()) D1sGridPointer
Definition gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp:554
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition utility/integral_constant.hpp:20
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:547
CDE1ElementwiseOperation cde1_element_op_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:678
A0GridDesc_M_K a0_grid_desc_m_k_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:658
const B0DataType * p_b0_grid_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:651
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:668
index_t batch_count_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:681
const B1DataType * p_b1_grid_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:653
B0GridDesc_N_K b0_grid_desc_n_k_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:659
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:682
A0GridDesc_AK0_M_AK1 a0_grid_desc_ak0_m_ak1_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:666
GridwiseGemm64::D0sGridPointer p_d0s_grid_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:652
Argument(const A0DataType *p_a0_grid, const B0DataType *p_b0_grid, std::array< const void *, NumD0Tensor > p_d0s_grid, const B1DataType *p_b1_grid, std::array< const void *, NumD1Tensor > p_d1s_grid, E1DataType *p_e1_grid, index_t MRaw, index_t NRaw, index_t KRaw, index_t Gemm1NRaw, index_t Batch, index_t StrideA0, index_t StrideB0, std::array< index_t, NumD0Tensor > StrideD0s, index_t StrideB1, std::array< index_t, NumD1Tensor > StrideD1s, index_t StrideE1, index_t BatchStrideA0, index_t BatchStrideB0, std::array< index_t, NumD0Tensor > BatchStrideD0s, index_t BatchStrideB1, std::array< index_t, NumD1Tensor > BatchStrideD1s, index_t BatchStrideE1, A0ElementwiseOperation a0_element_op, B0ElementwiseOperation b0_element_op, CDE0ElementwiseOperation cde0_element_op, B1ElementwiseOperation b1_element_op, CDE1ElementwiseOperation cde1_element_op)
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:548
B1ElementwiseOperation b1_element_op_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:677
E1GridDesc_M_N e1_grid_desc_m_n_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:663
B1GridDesc_N_K b1_grid_desc_n_k_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:661
B0ElementwiseOperation b0_element_op_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:675
E1DataType * p_e1_grid_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:655
D1sGridDesc_M_N d1s_grid_desc_m_n_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:662
D0sGridDesc_M_N d0s_grid_desc_m_n_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:660
A0ElementwiseOperation a0_element_op_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:674
CDE0ElementwiseOperation cde0_element_op_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:676
GridwiseGemm64::D1sGridPointer p_d1s_grid_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:654
GridwiseGemm64::DefaultBlock2E1TileMap block_2_e1tile_map_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:671
B0GridDesc_BK0_N_BK1 b0_grid_desc_bk0_n_bk1_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:667
const A0DataType * p_a0_grid_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:650
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:687
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:777
DeviceOp::Argument Argument
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:688
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:797
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:691
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:235
static constexpr index_t NumD1Tensor
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:254
remove_cvref_t< decltype(MakeD0sGridDescriptor_M_N({}, {}, {}))> D0sGridDesc_M_N
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:459
static auto MakeB1GridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:334
static constexpr auto Gemm0MXdlPerWave32
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:245
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultB0GridDescriptor_BK0_N_BK1( B0GridDesc_N_K{}))> B0GridDesc_BK0_N_BK1
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:538
remove_cvref_t< decltype(MakeD1sGridDescriptor_M_N({}, {}, {}))> D1sGridDesc_M_N
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:461
static auto MakeB0GridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:295
static constexpr auto I7
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:263
static bool CheckDLayout()
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:812
static constexpr auto I8
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:264
static constexpr auto gemm0_padder
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:267
static constexpr bool IsValidCompilationParameter()
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:804
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:875
GridwiseGemmBase< Gemm0MXdlPerWave32 > GridwiseGemm32
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:533
static auto MakeA0GridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA0)
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:276
static constexpr auto I4
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:260
static auto MakeD1sGridDescriptor_M_N(const std::array< index_t, NumD1Tensor > &MRaws, const std::array< index_t, NumD1Tensor > &NRaws, const std::array< index_t, NumD1Tensor > &DsStride)
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:385
GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< A0DataType, Acc0DataType, D0sDataType, Acc1DataType, C1ShuffleDataType, D1sDataType, E1DataType, A0ElementwiseOperation, B0ElementwiseOperation, CDE0ElementwiseOperation, B1ElementwiseOperation, CDE1ElementwiseOperation, InMemoryDataOperationEnum::Set, A0GridDesc_M_K, B0GridDesc_N_K, D0sGridDesc_M_N, B1GridDesc_N_K, D1sGridDesc_M_N, E1GridDesc_M_N, NumGemm0KPrefetchStage, BlockSize, Gemm0MPerBlock, Gemm0NPerBlock, Gemm0KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, A0K1, B0K1, B1K1, Gemm0MPerXdl, Gemm0NPerXdl, Gemm0MXdlPerWave_, Gemm0NXdlPerWave, Gemm1NXdlPerWave, A0BlockTransferThreadClusterLengths_AK0_M_AK1, A0BlockTransferThreadClusterArrangeOrder, A0BlockTransferSrcAccessOrder, A0BlockTransferSrcVectorDim, A0BlockTransferSrcScalarPerVector, A0BlockTransferDstScalarPerVector_AK1, true, A0BlockLdsExtraM, B0BlockTransferThreadClusterLengths_BK0_N_BK1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_BK1, true, B0BlockLdsExtraN, CDE0BlockTransferSrcVectorDim, CDE0BlockTransferSrcScalaerPerVector, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, C1ShuffleMXdlPerWavePerShuffle, C1ShuffleGemm0NXdlPerWavePerShuffle, CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDE1ShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched > GridwiseGemmBase
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:466
GridwiseGemmBase< math::max(Gemm0MXdlPerWave64, 1)> GridwiseGemm64
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:532
static auto MakeD0GridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideD0)
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:315
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle DeviceOp
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:236
static auto MakeInvoker()
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:925
static constexpr auto I0
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:256
static auto MakeArgument(const A0DataType *p_a0, const B0DataType *p_b0, std::array< const void *, NumD0Tensor > p_d0s, const B1DataType *p_b1, std::array< const void *, NumD1Tensor > p_d1s, E1DataType *p_e1, index_t MRaw, index_t NRaw, index_t KRaw, index_t Gemm1NRaw, index_t Batch, index_t StrideA0, index_t StrideB0, std::array< index_t, NumD0Tensor > StrideD0s, index_t StrideB1, std::array< index_t, NumD1Tensor > StrideD1s, index_t StrideE1, index_t BatchStrideA0, index_t BatchStrideB0, std::array< index_t, NumD0Tensor > BatchStrideD0s, index_t BatchStrideB1, std::array< index_t, NumD1Tensor > BatchStrideD1s, index_t BatchStrideE1, A0ElementwiseOperation a0_element_op, B0ElementwiseOperation b0_element_op, CDE0ElementwiseOperation cde0_element_op, B1ElementwiseOperation b1_element_op, CDE1ElementwiseOperation cde1_element_op)
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:880
static constexpr auto I2
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:258
decltype(MakeB1GridDescriptor_N_K(1, 1, 1)) B1GridDesc_N_K
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:460
static auto MakeD0sGridDescriptor_M_N(const std::array< index_t, NumD1Tensor > &MRaws, const std::array< index_t, NumD1Tensor > &NRaws, const std::array< index_t, NumD1Tensor > &DsStride)
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:372
static constexpr auto gemm1_padder
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:271
std::string GetTypeString() const override
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:995
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultA0GridDescriptor_AK0_M_AK1( A0GridDesc_M_K{}))> A0GridDesc_AK0_M_AK1
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:535
static bool IsSupportedArgument(const Argument &arg)
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:824
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a0, const void *p_b0, std::array< const void *, NumD0Tensor > p_d0s, const void *p_b1, std::array< const void *, NumD1Tensor > p_d1s, void *p_e1, index_t MRaw, index_t NRaw, index_t KRaw, index_t Gemm1NRaw, index_t Batch, index_t StrideA0, index_t StrideB0, std::array< ck::index_t, NumD0Tensor > StrideD0s, index_t StrideB1, std::array< ck::index_t, NumD1Tensor > StrideD1s, index_t StrideE1, index_t BatchStrideA0, index_t BatchStrideB0, std::array< ck::index_t, NumD0Tensor > BatchStrideD0s, index_t BatchStrideB1, std::array< ck::index_t, NumD1Tensor > BatchStrideD1s, index_t BatchStrideE1, A0ElementwiseOperation a0_element_op, B0ElementwiseOperation b0_element_op, CDE0ElementwiseOperation cde0_element_op, B1ElementwiseOperation b1_element_op, CDE1ElementwiseOperation cde1_element_op) override
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:929
static constexpr auto I1
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:257
static constexpr auto Gemm0MXdlPerWave64
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:238
static constexpr auto I9
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:265
decltype(MakeA0GridDescriptor_M_K(1, 1, 1)) A0GridDesc_M_K
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:457
decltype(MakeE1GridDescriptor_M_N< E1Layout >(1, 1, 1)) E1GridDesc_M_N
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:462
decltype(MakeB0GridDescriptor_N_K(1, 1, 1)) B0GridDesc_N_K
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:458
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultB1GridDescriptor_BK0_N_BK1( B1GridDesc_N_K{}))> B1GridDesc_BK0_N_BK1
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:541
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:989
static constexpr auto I5
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:261
static constexpr auto I6
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:262
static constexpr auto I3
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:259
static constexpr index_t NumD0Tensor
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:253
static auto MakeE1GridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE1)
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:354
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:399
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:415
__host__ __device__ constexpr auto GetD1BasePtr(index_t g_idx, Number< I > d1_idx) const
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:443
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:432
__host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:420
__host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx, Number< I > d1_idx) const
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:426
ComputeBasePtrOfStridedBatch(index_t BatchStrideA0, index_t BatchStrideB0, std::array< index_t, NumD0Tensor > BatchStrideD0s, index_t BatchStrideB1, std::array< index_t, NumD1Tensor > BatchStrideD1s, index_t BatchStrideE1)
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:400
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:437
Definition device_batched_gemm_multiple_d_gemm_multiple_d.hpp:33
Definition matrix_padder.hpp:204
#define CK_ENV(name)
Definition utility/env.hpp:129