28template <
typename GridwiseGemm,
29 typename BatchedGemmArg,
30 bool HasMainKBlockLoop,
35#if CK_USE_LAUNCH_BOUNDS
40#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
41 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
43 __shared__
char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
45 const index_t g_idx = blockIdx.z % karg.Batch;
46 const index_t k_idx = blockIdx.z / karg.Batch;
48 const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
49 const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
50 const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
51 const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
53 auto splitk_batch_offset =
typename GridwiseGemm::SplitKBatchOffset(karg, k_idx);
58 karg.p_ds_grid(i) = karg.p_ds_grid(i) + ds_batch_offset[i];
61 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
62 karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset,
63 karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset,
65 karg.p_c_grid + c_batch_offset,
77template <
typename GridwiseGemm,
78 typename BatchedGemmArg,
79 bool HasMainKBlockLoop,
84#if CK_USE_LAUNCH_BOUNDS
89#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
90 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
94 __shared__
char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
95 __shared__
char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
97 const index_t g_idx = blockIdx.z % karg.Batch;
98 const index_t k_idx = blockIdx.z / karg.Batch;
100 const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
101 const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
102 const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
103 const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
105 auto splitk_batch_offset =
typename GridwiseGemm::SplitKBatchOffset(karg, k_idx);
110 karg.p_ds_grid(i) = karg.p_ds_grid(i) + ds_batch_offset[i];
113 GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
114 karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset,
115 karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset,
117 karg.p_c_grid + c_batch_offset,
130namespace tensor_operation {
133template <
typename ALayout,
141 typename GemmAccDataType,
142 typename CShuffleDataType,
143 typename AElementwiseOperation,
144 typename BElementwiseOperation,
145 typename CElementwiseOperation,
157 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
158 typename ABlockTransferThreadClusterArrangeOrder,
159 typename ABlockTransferSrcAccessOrder,
160 index_t ABlockTransferSrcVectorDim,
161 index_t ABlockTransferSrcScalarPerVector,
162 index_t ABlockTransferDstScalarPerVector_AK1,
163 bool ABlockLdsExtraM,
164 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
165 typename BBlockTransferThreadClusterArrangeOrder,
166 typename BBlockTransferSrcAccessOrder,
167 index_t BBlockTransferSrcVectorDim,
168 index_t BBlockTransferSrcScalarPerVector,
169 index_t BBlockTransferDstScalarPerVector_BK1,
170 bool BBlockLdsExtraN,
171 index_t CShuffleMXdlPerWavePerShuffle,
172 index_t CShuffleNXdlPerWavePerShuffle,
173 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
174 typename CDEShuffleBlockTransferScalarPerVectors,
177 typename ComputeTypeA = ADataType,
178 typename ComputeTypeB = BDataType,
179 typename LDSTypeA = ComputeTypeA,
180 typename LDSTypeB = ComputeTypeB>
190 AElementwiseOperation,
191 BElementwiseOperation,
192 CElementwiseOperation>
203 template <index_t NXdlPerWave_>
215 AElementwiseOperation,
216 BElementwiseOperation,
217 CElementwiseOperation,
229 ABlockTransferThreadClusterLengths_AK0_M_AK1,
230 ABlockTransferThreadClusterArrangeOrder,
231 ABlockTransferSrcAccessOrder,
232 ABlockTransferSrcVectorDim,
233 ABlockTransferSrcScalarPerVector,
234 ABlockTransferDstScalarPerVector_AK1,
237 BBlockTransferThreadClusterLengths_BK0_N_BK1,
238 BBlockTransferThreadClusterArrangeOrder,
239 BBlockTransferSrcAccessOrder,
240 BBlockTransferSrcVectorDim,
241 BBlockTransferSrcScalarPerVector,
242 BBlockTransferDstScalarPerVector_BK1,
245 CShuffleMXdlPerWavePerShuffle,
246 CShuffleNXdlPerWavePerShuffle,
247 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
248 CDEShuffleBlockTransferScalarPerVectors,
264 std::array<ck::index_t, NumDTensor> BatchStrideDs,
266 : BatchStrideA_(BatchStrideA),
267 BatchStrideB_(BatchStrideB),
268 BatchStrideDs_(BatchStrideDs),
269 BatchStrideC_(BatchStrideC)
275 return static_cast<long_index_t>(BatchStrideA_) * g_idx;
280 return static_cast<long_index_t>(BatchStrideB_) * g_idx;
285 std::array<long_index_t, NumDTensor> ds_offset_;
288 ds_offset_[i] =
static_cast<long_index_t>(BatchStrideDs_[i]) * g_idx;
296 return static_cast<long_index_t>(BatchStrideC_) * g_idx;
302 std::array<ck::index_t, NumDTensor> BatchStrideDs_;
306 template <
typename Gr
idwiseGemm>
314 const BDataType* p_b_grid_,
315 std::array<const void*, NumDTensor> p_ds_grid_,
316 CDataType* p_e_grid_,
322 std::array<index_t, NumDTensor> StrideDs_,
326 const std::array<ck::index_t, NumDTensor>& BatchStrideDs_,
329 AElementwiseOperation a_element_op_,
330 BElementwiseOperation b_element_op_,
331 CElementwiseOperation c_element_op_,
350 BatchStrideA_, BatchStrideB_, BatchStrideDs_, BatchStrideE_}
360 constexpr int dynamic_smem_size = 0;
361 int max_occupancy = 0;
363 constexpr index_t minimum_occupancy = []() {
370 return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
451 template <
typename Gr
idwiseGemm>
455 if(stream_config.log_level_ > 0)
462 throw std::runtime_error(
"wrong! GridwiseGemm has invalid setting");
466 std::tie(gdx, gdy, gdz) =
471 index_t k_grain = arg.KBatch * KPerBlock;
472 index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
476 const auto Run = [&](
const auto& kernel) {
477 if(stream_config.flush_cache)
480 std::array<std::size_t, NumDTensor> DsSize;
482 BatchGemmArgument arg_ =
reinterpret_cast<const BatchGemmArgument&
>(arg);
485 arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
487 arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
490 a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
sizeof(ADataType) * arg.
Batch;
492 b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
sizeof(BDataType) * arg.
Batch;
495 arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
499 DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() *
sizeof(DDataType);
503 stream_config.rotating_count,
507 rotating_mem.Print();
509 auto run_flush_cache = [&]() {
517 hipMemsetAsync(arg_.p_c_grid,
519 arg.
Batch * arg_.M * arg_.N *
sizeof(CDataType),
520 stream_config.stream_id_));
534 const auto clear_workspace = [&]() {
537 hipMemsetAsync(arg.p_c_grid,
539 arg.
Batch * arg.M * arg.N *
sizeof(CDataType),
540 stream_config.stream_id_));
543 BatchGemmArgument arg_ =
reinterpret_cast<const BatchGemmArgument&
>(arg);
554 constexpr index_t minimum_occupancy = []() {
561 return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
569 if(has_main_k_block_loop)
625 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
640 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
656 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
672 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
688 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
703 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
745 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
760 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
776 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
792 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
808 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
823 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
985 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
1029 reinterpret_cast<const typename GridwiseGemm32::Argument&
>(arg));
1043 std::array<const void*, NumDTensor> p_ds,
1051 std::array<index_t, NumDTensor> StrideDs,
1055 const std::array<ck::index_t, NumDTensor>& BatchStrideDs,
1057 AElementwiseOperation a_element_op,
1058 BElementwiseOperation b_element_op,
1059 CElementwiseOperation c_element_op,
1062 return Argument{
static_cast<const ADataType*
>(p_a),
1063 static_cast<const BDataType*
>(p_b),
1065 static_cast<CDataType*
>(p_e),
1087 std::unique_ptr<BaseArgument>
1090 const std::array<const void*, NumDTensor>& p_ds,
1098 const std::array<ck::index_t, NumDTensor>& StrideDs,
1102 const std::array<ck::index_t, NumDTensor>& BatchStrideDs,
1104 AElementwiseOperation a_element_op,
1105 BElementwiseOperation b_element_op,
1106 CElementwiseOperation c_element_op,
1109 return std::make_unique<Argument>(
static_cast<const ADataType*
>(p_a),
1110 static_cast<const BDataType*
>(p_b),
1112 static_cast<CDataType*
>(p_e),
1134 return std::make_unique<Invoker>(
Invoker{});
1140 auto str = std::stringstream();
1142 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
1146 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
1154 str <<
"DeviceBatchedGemmXdlUniversal"
1157 << std::string(ALayout::name)[0]
1158 << std::string(BLayout::name)[0]
1159 << std::string(CLayout::name)[0]
1162 << BlockSize <<
", "
1164 << MPerBlock<<
"x"<<NPerBlock<<
"x"<<KPerBlock <<
", "
1166 << MPerXDL<<
"x"<<NPerXDL <<
", "
1168 << MXdlPerWave<<
"x" << NXdlPerWave<<
", "
1170 << ABlockTransferSrcScalarPerVector<<
"x"<<BBlockTransferSrcScalarPerVector<<
", "
1171 <<
"BlkGemmPipelineScheduler: "
1172 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] <<
", "
1173 <<
"BlkGemmPipelineVersion: "
1174 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] <<
", "
1175 <<
"BlkGemmPipelinePrefetchStages: "
1176 << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
void hip_check_error(hipError_t x)
Definition host_utility/hip_check_error.hpp:10
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:91
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ MNKPadding
Definition gemm_specialization.hpp:20
@ NKPadding
Definition gemm_specialization.hpp:19
void flush_icache()
Definition flush_cache.hpp:383
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition flush_cache.hpp:398
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ One
Definition blkgemmpipe_scheduler.hpp:37
@ Seven
Definition blkgemmpipe_scheduler.hpp:43
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Four
Definition blkgemmpipe_scheduler.hpp:40
@ Two
Definition blkgemmpipe_scheduler.hpp:38
@ Full
Definition blkgemmpipe_scheduler.hpp:49
@ Three
Definition blkgemmpipe_scheduler.hpp:39
@ Five
Definition blkgemmpipe_scheduler.hpp:41
@ Six
Definition blkgemmpipe_scheduler.hpp:42
__global__ void kernel_batched_gemm_xdl_cshuffle_v3_multi_d(BatchedGemmArg karg)
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:38
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
__global__ void kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds(BatchedGemmArg karg)
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:87
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
int64_t long_index_t
Definition ck.hpp:300
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
bool is_bf16_atomic_supported()
Definition host_utility/device_prop.hpp:108
Definition ck/stream_config.hpp:10
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::Argument::Argument __host__ Argument()=default
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:157
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::CalculateKBlockLoopTailNum __host__ static __device__ constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:1381
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::MakeBGridDescriptor_BK0_N_BK1 __host__ static __device__ auto MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:435
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::CalculateHasMainKBlockLoop __host__ static __device__ constexpr bool CalculateHasMainKBlockLoop(index_t K)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:1374
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::MakeDsGridDescriptor_M_N __host__ static __device__ auto MakeDsGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, std::array< index_t, NumDTensor > StrideDs)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:603
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::MakeAGridDescriptor_AK0_M_AK1 __host__ static __device__ auto MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:353
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::CalculateGridSize static __host__ auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:219
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::CheckValidity static __host__ constexpr bool CheckValidity(const Argument &karg)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:1186
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:260
ComputePtrOffsetOfStridedBatch()=default
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:273
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:278
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, index_t BatchStrideB, std::array< ck::index_t, NumDTensor > BatchStrideDs, index_t BatchStrideC)
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:262
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:283
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:294
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:357
int max_occupancy_
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:445
ActiveWorkgroupsPerCU()
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:358
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:308
ArgumentBase(const ADataType *p_a_grid_, const BDataType *p_b_grid_, std::array< const void *, NumDTensor > p_ds_grid_, CDataType *p_e_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideE_, index_t BatchStrideA_, index_t BatchStrideB_, const std::array< ck::index_t, NumDTensor > &BatchStrideDs_, index_t BatchStrideE_, index_t Batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_, index_t KBatch_)
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:313
index_t Batch
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:309
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:310
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:450
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:452
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:982
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:193
std::string GetTypeString() const override
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:1138
static constexpr auto NXdlPerWave32
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:196
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:1036
GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB > GridwiseGemmBase
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:204
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:1132
static constexpr index_t NumDTensor
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:198
GridwiseGemm64 GridwiseGemm
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:257
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:195
CDataType CDataType_
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:200
static constexpr bool IsValidCompilationParameter()
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:989
ArgumentBase< GridwiseGemm64 > Argument
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:354
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:255
CDEShuffleBlockTransferScalarPerVectors CDEShuffleBlockTransferScalarPerVectors_
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:199
static ck::index_t GetMaxOccupancy()
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:1182
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, index_t M, index_t N, index_t K, index_t Batch, index_t StrideA, index_t StrideB, const std::array< ck::index_t, NumDTensor > &StrideDs, index_t StrideE, index_t BatchStrideA, index_t BatchStrideB, const std::array< ck::index_t, NumDTensor > &BatchStrideDs, index_t BatchStrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op, index_t KBatch=1) override
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:1088
static bool IsSupportedArgument(const Argument &arg)
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:995
static auto MakeInvoker()
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:1084
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:256
static auto MakeArgument(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, index_t M, index_t N, index_t K, index_t Batch, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideE, index_t BatchStrideA, index_t BatchStrideB, const std::array< ck::index_t, NumDTensor > &BatchStrideDs, index_t BatchStrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op, index_t KBatch=1)
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:1041
Definition device_batched_gemm_multi_d.hpp:68
Definition flush_cache.hpp:174