25template <
typename GridwiseGemm,
26 bool HasMainKBlockLoop,
31#if CK_USE_LAUNCH_BOUNDS
37#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
38 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
40 __shared__
char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
42 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
58template <
typename GridwiseGemm,
59 bool HasMainKBlockLoop,
64#if CK_USE_LAUNCH_BOUNDS
70#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
71 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
75 __shared__
char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
76 __shared__
char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
78 GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
95template <
typename ALayout,
100 typename AccDataType,
101 typename CShuffleDataType,
104 typename AElementwiseOperation,
105 typename BElementwiseOperation,
106 typename CElementwiseOperation,
118 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
119 typename ABlockTransferThreadClusterArrangeOrder,
120 typename ABlockTransferSrcAccessOrder,
121 index_t ABlockTransferSrcVectorDim,
122 index_t ABlockTransferSrcScalarPerVector,
123 index_t ABlockTransferDstScalarPerVector_AK1,
124 bool AThreadTransferSrcResetCoordinateAfterRun,
126 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
127 typename BBlockTransferThreadClusterArrangeOrder,
128 typename BBlockTransferSrcAccessOrder,
129 index_t BBlockTransferSrcVectorDim,
130 index_t BBlockTransferSrcScalarPerVector,
131 index_t BBlockTransferDstScalarPerVector_BK1,
132 bool BThreadTransferSrcResetCoordinateAfterRun,
134 index_t CShuffleMXdlPerWavePerShuffle,
135 index_t CShuffleNXdlPerWavePerShuffle,
136 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
137 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
140 typename ComputeTypeA = CDataType,
141 typename ComputeTypeB = ComputeTypeA>
172 return static_cast<const ADataType_*
>(
nullptr);
183 return static_cast<const BDataType_*
>(
nullptr);
194 return static_cast<const DDataType*
>(
nullptr);
246 auto K_t = K_Batch * KPerBlock;
247 return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
252 auto K_t = K_Batch * KPerBlock;
253 return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
258 auto K_t = K_Batch * KPerBlock;
259 return (K + K_t - 1) / K_t * KPerBlock;
265 auto K_t = K_Batch * KReadVec;
266 return (K + K_t - 1) / K_t * KReadVec;
279 template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl,
typename TileDesc_K0_MN_K1>
297 const auto a_grid_desc_mraw_kraw = [&]() {
310 if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
311 GemmSpec == GemmSpecialization::MNKPadding)
314 const auto a_grid_desc_m_k =
328 return a_grid_desc_ak0_m_ak1;
330 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
331 GemmSpec == GemmSpecialization::MNPadding)
335 a_grid_desc_mraw_kraw,
341 return a_grid_desc_ak0_m_ak1;
343 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
344 GemmSpec == GemmSpecialization::NKPadding)
348 a_grid_desc_mraw_kraw,
360 return a_grid_desc_ak0_m_ak1;
366 a_grid_desc_mraw_kraw,
372 return a_grid_desc_ak0_m_ak1;
376 __host__ __device__
static auto
381 const std::array<index_t, NumATensor>& StrideAs,
394 const auto b_grid_desc_nraw_kraw = [&]() {
407 if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
408 GemmSpec == GemmSpecialization::MNKPadding)
411 const auto b_grid_desc_n_k =
425 return b_grid_desc_bk0_n_bk1;
427 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
428 GemmSpec == GemmSpecialization::MNPadding)
432 b_grid_desc_nraw_kraw,
438 return b_grid_desc_bk0_n_bk1;
440 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
441 GemmSpec == GemmSpecialization::MKPadding)
445 b_grid_desc_nraw_kraw,
457 return b_grid_desc_bk0_n_bk1;
463 b_grid_desc_nraw_kraw,
469 return b_grid_desc_bk0_n_bk1;
473 __host__ __device__
static auto
478 const std::array<index_t, NumBTensor>& StrideBs,
488 template <
typename ABlockDesc_AK0_M_AK1>
489 __host__ __device__
static constexpr auto
492 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
497 template <
typename BBlockDesc_BK0_N_BK1>
498 __host__ __device__
static constexpr auto
501 constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
506 __host__ __device__
static auto
509 const auto c_grid_desc_mraw_nraw = [&]() {
522 if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
523 GemmSpec == GemmSpecialization::MNKPadding)
532 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
533 GemmSpec == GemmSpecialization::MKPadding)
537 c_grid_desc_mraw_nraw,
542 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
543 GemmSpec == GemmSpecialization::NKPadding)
547 c_grid_desc_mraw_nraw,
555 return c_grid_desc_mraw_nraw;
572 std::array<index_t, NumATensor> StrideAs_,
573 std::array<index_t, NumBTensor> StrideBs_,
574 std::array<index_t, NumDTensor> StrideDs_,
598 std::cout <<
"problem {" <<
"M:" <<
M <<
", " <<
"N:" <<
N <<
", " <<
"K:" <<
K <<
", "
600 <<
", " <<
"KP:" <<
KPadded <<
", " <<
"AK0:" <<
AK0 <<
", " <<
"BK0:" <<
BK0
601 <<
", " <<
"MBlock: " <<
MBlock <<
", " <<
"NBlock: " <<
NBlock <<
"}"
628 __host__
Argument(std::array<const void*, NumATensor> p_as_grid_,
629 std::array<const void*, NumBTensor> p_bs_grid_,
630 std::array<const void*, NumDTensor> p_ds_grid_,
635 std::array<index_t, NumATensor> StrideAs_,
636 std::array<index_t, NumBTensor> StrideBs_,
637 std::array<index_t, NumDTensor> StrideDs_,
643 :
Problem{M_, N_, K_, StrideAs_, StrideBs_, StrideDs_, StrideC_, k_batch_},
647 p_c_grid{static_cast<CDataType*>(p_c_grid_)},
658 p_as_grid(i) =
static_cast<const ADataType_*
>(p_as_grid_[i]);
666 p_bs_grid(i) =
static_cast<const BDataType_*
>(p_bs_grid_[i]);
674 p_ds_grid(i) =
static_cast<const DDataType_*
>(p_ds_grid_[i]);
725 struct SplitKBatchOffsetMultiABD
727 __device__ SplitKBatchOffsetMultiABD(AsGridPointer& p_as_grid,
728 BsGridPointer& p_bs_grid,
732 using ALayout_ = remove_cvref_t<tuple_element_t<i.value, AsLayout>>;
733 if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout_>)
735 as_k_split_offset[i] = blockIdx.z * karg.KRead;
737 else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout_>)
739 as_k_split_offset[i] = blockIdx.z * karg.KRead * karg.StrideAs[i];
742 p_as_grid_(i) = p_as_grid[i] + as_k_split_offset[i];
745 static_for<0, NumBTensor, 1>{}([&](
auto i) {
747 if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout_>)
749 bs_k_split_offset[i] = blockIdx.z * karg.KRead * karg.StrideBs[i];
751 else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout_>)
753 bs_k_split_offset[i] = blockIdx.z * karg.KRead;
756 p_bs_grid_(i) = p_bs_grid[i] + bs_k_split_offset[i];
759 if(blockIdx.z <
static_cast<uint32_t>(karg.KBatch - 1))
765 karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
769 AsGridPointer p_as_grid_;
770 BsGridPointer p_bs_grid_;
771 std::array<index_t, NumATensor> as_k_split_offset;
772 std::array<index_t, NumBTensor> bs_k_split_offset;
778 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
779 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
780 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
782 if constexpr(ABlockLdsExtraM)
792 constexpr auto MLdsLayer = 32 * 4 / KPerBlock /
sizeof(
LDSTypeA) < 1
794 : 32 * 4 / KPerBlock /
sizeof(
LDSTypeA);
809 a_lds_block_desc_permuted,
817 a_lds_block_desc_ak0_mldslayer_m_ak1,
825 return a_lds_block_desc_ak0_m_ak1;
832 constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I1);
833 constexpr auto M1 = MPerBlock / M0;
835 constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I0);
836 constexpr auto K0PerThreadWrite =
AK0Number / KThreadWrite;
837 constexpr auto KThreadRead = WaveSize / MPerXdl;
838 constexpr auto K0PerThreadRead =
AK0Number / KThreadRead;
843 constexpr auto KThreadReadPerm =
844 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
845 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
859 Number<kfold * M0 / mpair>{},
878 a_lds_block_desc_permuted,
900 a_lds_block_desc_unmerged,
903 Number<KThreadWrite / kfold / KThreadReadPerm>{},
912 return a_lds_block_desc_ak0_m_ak1;
918 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
919 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
920 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
922 if constexpr(BBlockLdsExtraN)
931 constexpr auto NLdsLayer = 32 * 4 / KPerBlock /
sizeof(
LDSTypeB) < 1
933 : 32 * 4 / KPerBlock /
sizeof(
LDSTypeB);
949 b_lds_block_desc_permuted,
957 b_lds_block_desc_bk0_nldslayer_n_bk1,
965 return b_lds_block_desc_bk0_n_bk1;
969 constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(
I1);
970 constexpr auto N1 = NPerBlock / N0;
972 constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(
I0);
973 constexpr auto K0PerThreadWrite =
BK0Number / KThreadWrite;
974 constexpr auto KThreadRead = WaveSize / NPerXdl;
975 constexpr auto K0PerThreadRead =
BK0Number / KThreadRead;
980 constexpr auto KThreadReadPerm =
981 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
982 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
996 Number<kfold * N0 / npair>{},
1015 b_lds_block_desc_permuted,
1037 b_lds_block_desc_unmerged,
1040 Number<KThreadWrite / kfold / KThreadReadPerm>{},
1049 return b_lds_block_desc_bk0_n_bk1;
1055 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1056 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1058 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1065 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
1083 ABlockTransferSrcScalarPerVector,
1084 BBlockTransferSrcScalarPerVector,
1104 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1107 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
1110 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1113 constexpr auto c_block_size =
1114 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
1117 b_block_space_size_aligned *
sizeof(
LDSTypeB)),
1118 c_block_size *
sizeof(CShuffleDataType));
1126 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
1127 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
1128 "Invalid tuning param!");
1135 if(!(karg.M % MPerBlock == 0))
1139 std::cout <<
"Arg M value is not a multiple of MPerBlock! M: " << karg.M <<
" "
1140 << __FILE__ <<
":" << __LINE__ <<
", in function: " << __func__
1152 if(!(karg.N % NPerBlock == 0))
1156 std::cout <<
"Arg N value is not a multiple of NPerBlock! N: " << karg.N <<
" "
1157 << __FILE__ <<
":" << __LINE__ <<
", in function: " << __func__
1170 auto K_t = karg.KBatch * KPerBlock;
1171 if(!(karg.K % K_t == 0))
1175 std::cout <<
"Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1176 << karg.K <<
" " << __FILE__ <<
":" << __LINE__
1177 <<
", in function: " << __func__ << std::endl;
1185 auto K_t = karg.KBatch * KReadVec;
1187 if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1195 if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1199 std::cout <<
"Arg K (" << karg.K
1200 <<
") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1201 << ABlockTransferSrcScalarPerVector <<
" )! " << __FILE__ <<
":"
1202 << __LINE__ <<
", in function: " << __func__ << std::endl;
1209 if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1213 std::cout <<
"Arg M (" << karg.M
1214 <<
") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1215 << ABlockTransferSrcScalarPerVector <<
" )! " << __FILE__ <<
":"
1216 << __LINE__ <<
", in function: " << __func__ << std::endl;
1224 if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1228 std::cout <<
"Arg N (" << karg.N
1229 <<
") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1230 << BBlockTransferSrcScalarPerVector <<
" )! " << __FILE__ <<
":"
1231 << __LINE__ <<
", in function: " << __func__ << std::endl;
1238 if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1242 std::cout <<
"Arg K (" << karg.K
1243 <<
") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1244 << BBlockTransferSrcScalarPerVector <<
" )! " << __FILE__ <<
":"
1245 << __LINE__ <<
", in function: " << __func__ << std::endl;
1253 if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
1257 std::cout <<
"Arg N (" << karg.N
1258 <<
") value is not a multiple of "
1259 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1260 << CShuffleBlockTransferScalarPerVector_NPerBlock <<
" )! "
1261 << __FILE__ <<
":" << __LINE__ <<
", in function: " << __func__
1269 if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
1273 std::cout <<
"Arg M (" << karg.M
1274 <<
") value is not a multiple of "
1275 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1276 << CShuffleBlockTransferScalarPerVector_NPerBlock <<
" )! "
1277 << __FILE__ <<
":" << __LINE__ <<
", in function: " << __func__
1285 const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1289 if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1301 const index_t num_loop = K / KPerBlock;
1303 return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1308 const index_t num_loop = K / KPerBlock;
1310 return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1313 template <
typename CGr
idDesc>
1315 const CGridDesc& c_grid_desc_m_n,
index_t MBlock,
index_t NBlock)
1324 return c_grid_desc_mblock_mperblock_nblock_nperblock;
1327 template <
typename DsGr
idDesc>
1329 const DsGridDesc& ds_grid_desc_m_n,
index_t MBlock,
index_t NBlock)
1334 ds_grid_desc_m_n[i], MBlock, NBlock);
1346 template <
bool HasMainKBlockLoop,
1352 CDataType* p_c_grid,
1354 const Problem& problem,
1355 const AElementwiseOperation& a_element_op,
1356 const BElementwiseOperation& b_element_op,
1357 const CElementwiseOperation& c_element_op)
1371 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
1373 problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0);
1375 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
1377 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1379 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1382 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
1387 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs[j]);
1391 const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1393 ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
1403 p_as_grid[i], as_grid_desc_ak0_m_ak1[i].GetElementSpaceSize());
1410 p_bs_grid[i], bs_grid_desc_bk0_n_bk1[i].GetElementSpaceSize());
1415 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1420 p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
1425 const auto block_2_ctile_map =
Block2CTileMap{problem.M, problem.N, 4};
1427 const auto block_work_idx =
1430 if(!block_2_ctile_map.ValidCTileIndex(
1432 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(
I0),
1433 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(
I2))))
1438 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[
I0]);
1439 const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[
I1]);
1442 const index_t m_block_data_idx_on_grid =
1443 __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
1445 const index_t n_block_data_idx_on_grid =
1446 __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1459 auto a_blockwise_copy =
1461 AElementwiseOperation,
1465 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1466 ABlockTransferThreadClusterArrangeOrder,
1469 decltype(a_grid_desc_ak0_m_ak1),
1470 decltype(a_block_desc_ak0_m_ak1),
1471 ABlockTransferSrcAccessOrder,
1473 ABlockTransferSrcVectorDim,
1475 ABlockTransferSrcScalarPerVector,
1476 ABlockTransferDstScalarPerVector_AK1,
1479 AThreadTransferSrcResetCoordinateAfterRun,
1481 BlockwiseGemmPipe::GlobalBufferNum>(
1482 a_grid_desc_ak0_m_ak1,
1485 a_block_desc_ak0_m_ak1,
1489 const auto idx_as_block_begin =
1497 decltype(as_grid_desc_ak0_m_ak1),
1498 decltype(
tie(a_block_desc_ak0_m_ak1)),
1499 AElementwiseOperation,
1502 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1503 ABlockTransferThreadClusterArrangeOrder,
1504 ABlockTransferSrcAccessOrder,
1506 ABlockTransferSrcVectorDim,
1508 ABlockTransferSrcScalarPerVector,
1509 ABlockTransferDstScalarPerVector_AK1,
1512 BlockwiseGemmPipe::GlobalBufferNum>{as_grid_desc_ak0_m_ak1,
1514 tie(a_block_desc_ak0_m_ak1),
1521 auto b_blockwise_copy =
1523 BElementwiseOperation,
1527 BBlockTransferThreadClusterLengths_BK0_N_BK1,
1528 BBlockTransferThreadClusterArrangeOrder,
1531 decltype(b_grid_desc_bk0_n_bk1),
1532 decltype(b_block_desc_bk0_n_bk1),
1533 BBlockTransferSrcAccessOrder,
1535 BBlockTransferSrcVectorDim,
1537 BBlockTransferSrcScalarPerVector,
1538 BBlockTransferDstScalarPerVector_BK1,
1541 BThreadTransferSrcResetCoordinateAfterRun,
1543 BlockwiseGemmPipe::GlobalBufferNum>(
1544 b_grid_desc_bk0_n_bk1,
1547 b_block_desc_bk0_n_bk1,
1551 const auto idx_bs_block_begin =
1559 decltype(bs_grid_desc_bk0_n_bk1),
1560 decltype(
tie(b_block_desc_bk0_n_bk1)),
1561 BElementwiseOperation,
1564 BBlockTransferThreadClusterLengths_BK0_N_BK1,
1565 BBlockTransferThreadClusterArrangeOrder,
1566 BBlockTransferSrcAccessOrder,
1568 BBlockTransferSrcVectorDim,
1570 BBlockTransferSrcScalarPerVector,
1571 BBlockTransferDstScalarPerVector_BK1,
1574 BlockwiseGemmPipe::GlobalBufferNum>{bs_grid_desc_bk0_n_bk1,
1576 tie(b_block_desc_bk0_n_bk1),
1584 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1588 static_cast<LDSTypeA*
>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1593 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1599 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1601 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1603 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1604 (as_grid_desc_ak0_m_ak1[
I0].GetLength(
I0) * as_grid_desc_ak0_m_ak1[
I0].GetLength(
I2)) /
1608 a_block_desc_ak0_m_ak1,
1612 a_block_slice_copy_step,
1613 bs_grid_desc_bk0_n_bk1,
1614 b_block_desc_bk0_n_bk1,
1618 b_block_slice_copy_step,
1620 num_k_block_main_loop);
1624 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1625 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1628 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1629 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1632 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1633 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1637 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1638 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1640 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I0);
1641 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I1);
1642 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I2);
1643 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I3);
1644 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I4);
1645 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I5);
1646 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I6);
1647 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I7);
1649 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1653 static_cast<CShuffleDataType*
>(p_shared),
1654 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1657 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1677 const auto c_thread_mtx_on_block =
1678 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(
I0,
I0,
I0,
I0);
1680 const index_t m_thread_data_on_block = c_thread_mtx_on_block[
I0];
1681 const index_t n_thread_data_on_block = c_thread_mtx_on_block[
I1];
1683 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1689 const auto m_thread_data_on_block_idx =
1690 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1693 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1699 const auto n_thread_data_on_block_idx =
1700 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1704 auto c_thread_copy_vgpr_to_lds =
1707 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1708 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1710 Sequence<CShuffleMXdlPerWavePerShuffle,
1711 CShuffleNXdlPerWavePerShuffle,
1724 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1727 m_thread_data_on_block_idx[
I1],
1728 n_thread_data_on_block_idx[
I1],
1729 m_thread_data_on_block_idx[
I2],
1730 m_thread_data_on_block_idx[
I3],
1731 m_thread_data_on_block_idx[
I4],
1732 n_thread_data_on_block_idx[
I2]),
1739 CElementwiseOperation,
1740 CGlobalMemoryDataOperation,
1742 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1744 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
1745 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1749 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1750 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1753 CShuffleBlockTransferScalarPerVector_NPerBlock,
1756 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1758 c_grid_desc_mblock_mperblock_nblock_nperblock,
1762 using EDataType = CDataType;
1766 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1768 {
return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
1773 tie(c_shuffle_block_buf),
1775 {
return ds_grid_buf[i]; },
1787 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1788 c_grid_desc_mblock_mperblock_nblock_nperblock;
1790 using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
1791 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
1792 const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
1793 const auto CDEShuffleBlockTransferScalarPerVector_NPerBlock =
1794 CShuffleBlockTransferScalarPerVector_NPerBlock;
1800 decltype(c_ds_desc_refs),
1801 decltype(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
1802 CElementwiseOperation,
1806 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1808 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
1809 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1815 CDEShuffleBlockTransferScalarPerVector_NPerBlock,
1816 CDEShuffleBlockTransferScalarPerVector_NPerBlock,
1823 idx_c_ds_block_begin,
1824 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1831 constexpr auto sfc_c_vgpr =
1834 Sequence<CShuffleMXdlPerWavePerShuffle,
1835 CShuffleNXdlPerWavePerShuffle,
1843 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1846 constexpr auto sfc_c_global =
1850 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1852 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1855 static_assert(num_access == sfc_c_global.GetNumOfAccess(),
"wrong!");
1859 constexpr auto sfc_cde_block =
1863 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1865 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1867 static_assert(num_access == sfc_cde_block.GetNumOfAccess(),
"wrong!");
1875 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1876 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1878 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1879 c_shuffle_block_buf);
1886 c_shuffle_block_copy_lds_to_global.Run(
1887 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1888 c_shuffle_block_buf,
1889 c_grid_desc_mblock_mperblock_nblock_nperblock,
1892 if constexpr(access_id < num_access - 1)
1894 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1897 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1898 c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1902 cde_block_copy_lds_and_global.Run(
1905 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1908 if constexpr(access_id < num_access - 1)
1910 constexpr auto cde_lds_and_global_step =
1911 sfc_cde_block.GetForwardStep(access_id);
1915 cde_block_copy_lds_and_global.MoveSrcSliceWindow(
1916 c_ds_desc_refs, i +
I1, cde_lds_and_global_step);
1920 cde_block_copy_lds_and_global.MoveDstSliceWindow(
1921 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1923 cde_lds_and_global_step);
1931 template <
bool HasMainKBlockLoop,
1937 CDataType* p_c_grid,
1940 const Problem& problem,
1941 const AElementwiseOperation& a_element_op,
1942 const BElementwiseOperation& b_element_op,
1943 const CElementwiseOperation& c_element_op)
1950 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
1952 problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0);
1954 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
1956 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1958 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1961 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
1963 const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1965 ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
1974 p_as_grid[i], as_grid_desc_ak0_m_ak1[i].GetElementSpaceSize());
1981 p_bs_grid[i], bs_grid_desc_bk0_n_bk1[i].GetElementSpaceSize());
1986 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1991 p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
1996 const auto block_2_ctile_map =
Block2CTileMap{problem.M, problem.N, 4};
1998 const auto block_work_idx =
2001 if(!block_2_ctile_map.ValidCTileIndex(
2003 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(
I0),
2004 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(
I2))))
2009 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[
I0]);
2010 const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[
I1]);
2013 const index_t m_block_data_idx_on_grid =
2014 __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
2016 const index_t n_block_data_idx_on_grid =
2017 __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
2030 auto a_blockwise_copy =
2032 AElementwiseOperation,
2036 ABlockTransferThreadClusterLengths_AK0_M_AK1,
2037 ABlockTransferThreadClusterArrangeOrder,
2040 decltype(a_grid_desc_ak0_m_ak1),
2041 decltype(a_block_desc_ak0_m_ak1),
2042 ABlockTransferSrcAccessOrder,
2044 ABlockTransferSrcVectorDim,
2046 ABlockTransferSrcScalarPerVector,
2047 ABlockTransferDstScalarPerVector_AK1,
2050 AThreadTransferSrcResetCoordinateAfterRun,
2052 BlockwiseGemmPipe::GlobalBufferNum>(
2053 a_grid_desc_ak0_m_ak1,
2056 a_block_desc_ak0_m_ak1,
2060 const auto idx_as_block_begin =
2068 decltype(as_grid_desc_ak0_m_ak1),
2069 decltype(
tie(a_block_desc_ak0_m_ak1)),
2070 AElementwiseOperation,
2073 ABlockTransferThreadClusterLengths_AK0_M_AK1,
2074 ABlockTransferThreadClusterArrangeOrder,
2075 ABlockTransferSrcAccessOrder,
2077 ABlockTransferSrcVectorDim,
2079 ABlockTransferSrcScalarPerVector,
2080 ABlockTransferDstScalarPerVector_AK1,
2083 BlockwiseGemmPipe::GlobalBufferNum>{as_grid_desc_ak0_m_ak1,
2085 tie(a_block_desc_ak0_m_ak1),
2093 auto b_blockwise_copy =
2095 BElementwiseOperation,
2099 BBlockTransferThreadClusterLengths_BK0_N_BK1,
2100 BBlockTransferThreadClusterArrangeOrder,
2103 decltype(b_grid_desc_bk0_n_bk1),
2104 decltype(b_block_desc_bk0_n_bk1),
2105 BBlockTransferSrcAccessOrder,
2107 BBlockTransferSrcVectorDim,
2109 BBlockTransferSrcScalarPerVector,
2110 BBlockTransferDstScalarPerVector_BK1,
2113 BThreadTransferSrcResetCoordinateAfterRun,
2115 BlockwiseGemmPipe::GlobalBufferNum>(
2116 b_grid_desc_bk0_n_bk1,
2119 b_block_desc_bk0_n_bk1,
2123 const auto idx_bs_block_begin =
2131 decltype(bs_grid_desc_bk0_n_bk1),
2132 decltype(
tie(b_block_desc_bk0_n_bk1)),
2133 BElementwiseOperation,
2136 BBlockTransferThreadClusterLengths_BK0_N_BK1,
2137 BBlockTransferThreadClusterArrangeOrder,
2138 BBlockTransferSrcAccessOrder,
2140 BBlockTransferSrcVectorDim,
2142 BBlockTransferSrcScalarPerVector,
2143 BBlockTransferDstScalarPerVector_BK1,
2146 BlockwiseGemmPipe::GlobalBufferNum>{bs_grid_desc_bk0_n_bk1,
2148 tie(b_block_desc_bk0_n_bk1),
2155 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
2158 static_cast<LDSTypeA*
>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2161 static_cast<LDSTypeB*
>(p_shared_0) +
2163 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2166 static_cast<LDSTypeA*
>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2169 static_cast<LDSTypeB*
>(p_shared_1) +
2171 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2173 auto a_block_bufs =
make_tuple(a_block_buf_ping, a_block_buf_pong);
2174 auto b_block_bufs =
make_tuple(b_block_buf_ping, b_block_buf_pong);
2180 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
2182 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
2184 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
2185 (as_grid_desc_ak0_m_ak1[
I0].GetLength(
I0) * as_grid_desc_ak0_m_ak1[
I0].GetLength(
I2)) /
2189 a_block_desc_ak0_m_ak1,
2193 a_block_slice_copy_step,
2194 bs_grid_desc_bk0_n_bk1,
2195 b_block_desc_bk0_n_bk1,
2199 b_block_slice_copy_step,
2201 num_k_block_main_loop);
2205 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2206 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2209 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
2210 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
2213 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
2214 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2218 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
2219 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2221 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I0);
2222 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I1);
2223 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I2);
2224 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I3);
2225 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I4);
2226 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I5);
2227 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I6);
2228 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I7);
2230 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2234 static_cast<CShuffleDataType*
>(p_shared_0),
2235 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2238 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2258 const auto c_thread_mtx_on_block =
2259 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(
I0,
I0,
I0,
I0);
2261 const index_t m_thread_data_on_block = c_thread_mtx_on_block[
I0];
2262 const index_t n_thread_data_on_block = c_thread_mtx_on_block[
I1];
2264 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
2270 const auto m_thread_data_on_block_idx =
2271 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
2274 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
2280 const auto n_thread_data_on_block_idx =
2281 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
2285 auto c_thread_copy_vgpr_to_lds =
2288 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2289 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2291 Sequence<CShuffleMXdlPerWavePerShuffle,
2292 CShuffleNXdlPerWavePerShuffle,
2305 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2308 m_thread_data_on_block_idx[
I1],
2309 n_thread_data_on_block_idx[
I1],
2310 m_thread_data_on_block_idx[
I2],
2311 m_thread_data_on_block_idx[
I3],
2312 m_thread_data_on_block_idx[
I4],
2313 n_thread_data_on_block_idx[
I2]),
2320 CElementwiseOperation,
2321 CGlobalMemoryDataOperation,
2323 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2325 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
2326 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
2330 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2331 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
2334 CShuffleBlockTransferScalarPerVector_NPerBlock,
2337 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2339 c_grid_desc_mblock_mperblock_nblock_nperblock,
2343 using EDataType = CDataType;
2347 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2349 {
return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
2354 tie(c_shuffle_block_buf),
2356 {
return ds_grid_buf[i]; },
2368 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
2369 c_grid_desc_mblock_mperblock_nblock_nperblock;
2371 using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
2372 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
2373 const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
2374 const auto CDEShuffleBlockTransferScalarPerVector_NPerBlock =
2375 CShuffleBlockTransferScalarPerVector_NPerBlock;
2381 decltype(c_ds_desc_refs),
2382 decltype(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
2383 CElementwiseOperation,
2387 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2389 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
2390 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
2396 CDEShuffleBlockTransferScalarPerVector_NPerBlock,
2397 CDEShuffleBlockTransferScalarPerVector_NPerBlock,
2404 idx_c_ds_block_begin,
2405 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2412 constexpr auto sfc_c_vgpr =
2415 Sequence<CShuffleMXdlPerWavePerShuffle,
2416 CShuffleNXdlPerWavePerShuffle,
2425 constexpr auto sfc_c_global =
2429 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2431 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2433 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2435 static_assert(num_access == sfc_c_global.GetNumOfAccess(),
"wrong!");
2439 constexpr auto sfc_cde_block =
2443 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2445 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2453 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2454 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2456 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2457 c_shuffle_block_buf);
2464 c_shuffle_block_copy_lds_to_global.Run(
2465 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2466 c_shuffle_block_buf,
2467 c_grid_desc_mblock_mperblock_nblock_nperblock,
2470 if constexpr(access_id < num_access - 1)
2472 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
2475 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
2476 c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
2480 cde_block_copy_lds_and_global.Run(
2483 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2486 if constexpr(access_id < num_access - 1)
2488 constexpr auto cde_lds_and_global_step =
2489 sfc_cde_block.GetForwardStep(access_id);
2493 cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2494 c_ds_desc_refs, i +
I1, cde_lds_and_global_step);
2498 cde_block_copy_lds_and_global.MoveDstSliceWindow(
2499 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2501 cde_lds_and_global_step);
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_)
Definition device_base.hpp:178
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ NPadding
Definition gemm_specialization.hpp:15
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
@ NKPadding
Definition gemm_specialization.hpp:19
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition utility/sequence.hpp:928
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__host__ __device__ constexpr auto container_concat(const X &x, const Ys &... ys)
Definition utility/container_helper.hpp:320
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition utility/tuple.hpp:218
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
__host__ __device__ constexpr auto make_xor_with_modulo_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:185
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Full
Definition blkgemmpipe_scheduler.hpp:49
__global__ void kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:62
constexpr auto BlockGemmPipeline_Selector()
Definition blockwise_gemm_pipeline_wmma_selector.hpp:32
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__global__ void kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:38
typename sequence_merge< Sx, Sy >::type sequence_merge_t
Definition utility/sequence.hpp:925
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple< X &... > &tx, const Tuple< Y &... > &ty)
Definition tuple_helper.hpp:42
unsigned int uint32_t
Definition stdint.h:126
Definition block_to_ctile_map.hpp:271
__host__ static __device__ constexpr index_t CalculateGridSize(index_t M, index_t N)
Definition block_to_ctile_map.hpp:283
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:716
const BElementwiseOperation b_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:643
AsGridPointer p_as_grid
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:678
BsGridPointer p_bs_grid
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:679
CDataType * p_c_grid
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:760
const AElementwiseOperation a_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:642
__host__ Argument(std::array< const void *, NumATensor > p_as_grid_, std::array< const void *, NumBTensor > p_bs_grid_, std::array< const void *, NumDTensor > p_ds_grid_, void *p_c_grid_, index_t M_, index_t N_, index_t K_, std::array< index_t, NumATensor > StrideAs_, std::array< index_t, NumBTensor > StrideBs_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideC_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:628
const CElementwiseOperation c_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:644
DsGridPointer p_ds_grid
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:680
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:641
index_t N
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:695
index_t NPadded
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:702
index_t KBatch
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:700
std::array< index_t, NumATensor > StrideAs
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:609
__host__ Problem(index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t KBatch_, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:642
CElementwiseOperation c_element_op_
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:711
index_t BK0
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:706
std::array< index_t, NumBTensor > StrideBs
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:610
index_t M
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:694
index_t NBlock
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:708
index_t MPadded
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:701
index_t K
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:696
std::array< index_t, NumDTensor > StrideDs
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:611
index_t KPadded
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:704
__host__ Problem(index_t M_, index_t N_, index_t K_, std::array< index_t, NumATensor > StrideAs_, std::array< index_t, NumBTensor > StrideBs_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideC_, index_t KBatch_)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:569
index_t StrideC
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:699
index_t MBlock
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:707
BElementwiseOperation b_element_op_
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:710
index_t AK0
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:705
index_t KRead
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:703
AElementwiseOperation a_element_op_
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:709
__host__ void Print() const
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:596
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:765
index_t a_k_split_offset
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:814
index_t b_k_split_offset
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:815
__device__ SplitKBatchOffset(Argument &karg)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:690
"Universal" GEMM kernel with SplitK support.
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:247
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::Run static __device__ void Run(AsGridPointer &p_as_grid, BsGridPointer &p_bs_grid, DsGridPointer &p_ds_grid, CDataType *p_c_grid, void *p_shared, const Problem &problem, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CElementwiseOperation &c_element_op)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:1349
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::NumATensor static constexpr index_t NumATensor
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:162
__host__ static __device__ auto MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:451
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::CalculateKRead static __host__ auto CalculateKRead(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:262
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::is_scale_mfma static constexpr auto is_scale_mfma
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:273
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::LDSTypeA BlkGemmPipeSched LDSTypeA
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:153
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::Run_2Lds static __device__ void Run_2Lds(AsGridPointer &p_as_grid, BsGridPointer &p_bs_grid, DsGridPointer &p_ds_grid, CDataType *p_c_grid, void *p_shared_0, void *p_shared_1, const Problem &problem, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CElementwiseOperation &c_element_op)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:1934
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::MakeGemmMmaTileDescriptor __host__ static __device__ constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1 &)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:280
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::NumDTensor static constexpr index_t NumDTensor
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:164
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::CalculateKPadded static __host__ auto CalculateKPadded(index_t K)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:239
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::NumBTensor static constexpr index_t NumBTensor
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:163
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::CalculateMPadded static __host__ auto CalculateMPadded(index_t M)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:229
static constexpr auto BK1Number
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:261
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::is_single_rate_mfma static constexpr bool is_single_rate_mfma
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:264
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::GetSharedMemoryNumberOfByte static __device__ constexpr index_t GetSharedMemoryNumberOfByte()
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:1094
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::BlockwiseGemmPipe remove_cvref_t< decltype(BlockGemmPipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, ADataType, BDataType, BlkGemmPipeSched, GemmAccDataType, decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()), decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave, KPack >())> BlockwiseGemmPipe
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:1112
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1 static __device__ constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:776
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::ThisThreadBlock ThisThreadBlock< BlockSize > ThisThreadBlock
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:283
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::MakeAsGridPointer static constexpr auto MakeAsGridPointer()
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:166
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::CalculateAK0Padded static __host__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:244
static constexpr auto I2
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:250
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::KPack static constexpr index_t KPack
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:274
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::lcm_AK1_BK1 static constexpr auto lcm_AK1_BK1
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:263
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::MakeDsGridPointer static constexpr auto MakeDsGridPointer()
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:188
static constexpr auto I7
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:255
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::CalculateKBlockLoopTailNum static __host__ constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:1306
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::CalculateHasMainKBlockLoop static __host__ constexpr bool CalculateHasMainKBlockLoop(index_t K)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:1299
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::DsGridPointer decltype(MakeDsGridPointer()) DsGridPointer
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:201
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::MakeBsGridPointer static constexpr auto MakeBsGridPointer()
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:177
static constexpr auto I5
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:253
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::MakeBGridDescriptor_BK0_N_BK1 static __device__ auto MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:391
static constexpr auto AK1Number
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:260
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock static __device__ constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:1053
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::CalculateGridSize static __host__ auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:224
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::CalculateMBlock static __host__ auto CalculateMBlock(index_t M)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:269
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::MakeAMmaTileDescriptor_M0_M1_M2_K __host__ static __device__ constexpr auto MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1 &)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:490
__host__ static __device__ constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc &c_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:1413
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::MakeAsGridDescriptor_AK0_M_AK1 __host__ static __device__ auto MakeAsGridDescriptor_AK0_M_AK1(const index_t M, const index_t MPad, const index_t K, const index_t KPad, const std::array< index_t, NumATensor > &StrideAs, const index_t AK0)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:377
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::CalculateNPadded static __host__ auto CalculateNPadded(index_t N)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:234
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1 static __device__ constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:916
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::CalculateBK0Padded static __host__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:250
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::MakeDsGridDescriptor_M_N __host__ static __device__ auto MakeDsGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, std::array< index_t, NumDTensor > StrideDs)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:559
static __device__ void Run(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, void *p_shared, const Problem &problem, const AGridDesc_AK0_M_K1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 &b_grid_desc_bk0_n_bk1, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:1437
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::MakeBMmaTileDescriptor_N0_N1_N2_K __host__ static __device__ constexpr auto MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1 &)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:499
static constexpr auto I6
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:254
static constexpr auto I1
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:249
static constexpr auto I0
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:248
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::BsGridPointer decltype(MakeBsGridPointer()) BsGridPointer
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:200
static constexpr auto I3
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:251
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::CheckValidity static __host__ constexpr bool CheckValidity(const Argument &karg)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:1202
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::MakeBsGridDescriptor_BK0_N_BK1 __host__ static __device__ auto MakeBsGridDescriptor_BK0_N_BK1(const index_t K, const index_t KPad, const index_t N, const index_t NPad, const std::array< index_t, NumBTensor > &StrideBs, const index_t BK0)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:474
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::Block2CTileMap BlockToCTileMap_Grouped_M00_N0_M01Adapt< 8, MPerBlock, NPerBlock > Block2CTileMap
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:1428
static constexpr auto I4
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:252
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::DsGridDesc_M_N remove_cvref_t< decltype(MakeDsGridDescriptor_M_N(0, 0, 0, 0, {}))> DsGridDesc_M_N
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:1339
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::LDSTypeB BlkGemmPipelineVer LDSTypeB
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:154
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::AsGridPointer decltype(MakeAsGridPointer()) AsGridPointer
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:199
__host__ static __device__ auto MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:369
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::CalculateNBlock static __host__ auto CalculateNBlock(index_t N)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:274
static constexpr auto BK0Number
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:259
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock static __device__ constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc &c_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:1314
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock static __device__ constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc &ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:1328
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::AK0Number static constexpr auto AK0Number
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:258
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::MakeAGridDescriptor_AK0_M_AK1 static __device__ auto MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:294
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::CalculateKPadded __host__ static __device__ auto CalculateKPadded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:256
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::MakeCGridDescriptor_M_N __host__ static __device__ auto MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:507
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition xdlops_gemm.hpp:1208
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition thread_group_tensor_slice_transfer_v6r1.hpp:34
Definition thread_group_tensor_slice_transfer_v7r2.hpp:47
Definition threadwise_tensor_slice_transfer.hpp:39
Definition utility/tuple.hpp:117
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
#define CK_ENV(name)
Definition utility/env.hpp:129