device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp Source File

device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp Source File#

Composable Kernel: device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp Source File
device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
19
20namespace ck {
21
// Device kernel entry point for the GEMM + multiple-D + multiple-R operation.
//
// The body only dispatches. When compiled for gfx9 / gfx11 / gfx12 and
// GridwiseGemm::IsValidCompilationParameter<>() is true, it declares a __shared__ byte buffer of
// GridwiseGemm::GetSharedMemoryNumberOfByte() bytes and forwards every argument, plus that
// buffer, to GridwiseGemm::Run<HasMainKBlockLoop>. For every other target each parameter is
// assigned to `ignore` and nothing else happens.
//
// NOTE(review): the launch-bounds attribute and the line carrying the kernel name are absent from
// this listing; confirm both against the full source before editing the signature.
22template <typename GridwiseGemm,
23 typename FloatAB,
24 typename FloatDsPointer,
25 typename FloatE,
26 typename FloatRsPointer,
27 typename AElementwiseOperation,
28 typename BElementwiseOperation,
29 typename CDEElementwiseOperation,
30 typename QsElementwiseOperation,
31 typename RsElementwiseOperation,
32 typename AGridDesc_AK0_M_AK1,
33 typename BGridDesc_BK0_N_BK1,
34 typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
35 typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
36 typename RsGridDescriptor_MBlock_MPerBlock,
37 typename Block2ETileMap,
38 bool HasMainKBlockLoop>
39__global__ void
40#if CK_USE_LAUNCH_BOUNDS
42#endif
44 const FloatAB* __restrict__ p_a_grid,
45 const FloatAB* __restrict__ p_b_grid,
46 FloatDsPointer p_ds_grid,
47 FloatE* __restrict__ p_e_grid,
48 FloatRsPointer p_rs_grid,
49 const AElementwiseOperation a_element_op,
50 const BElementwiseOperation b_element_op,
51 const CDEElementwiseOperation cde_element_op,
52 const QsElementwiseOperation qs_element_op,
53 const RsElementwiseOperation rs_element_op,
54 const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
55 const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
56 const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
57 ds_grid_desc_mblock_mperblock_nblock_nperblock,
58 const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
59 e_grid_desc_mblock_mperblock_nblock_nperblock,
60 const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock,
61 const Block2ETileMap block_2_etile_map)
62{
// Only these target families get a real kernel body.
63#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
// Compile-time gate: parameter sets rejected by GridwiseGemm leave the body empty.
64 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
65 {
// Statically sized shared-memory scratch buffer handed to GridwiseGemm::Run.
66 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
67
68 GridwiseGemm::template Run<HasMainKBlockLoop>(
69 p_a_grid,
70 p_b_grid,
71 p_ds_grid,
72 p_e_grid,
73 p_rs_grid,
74 p_shared,
75 a_element_op,
76 b_element_op,
77 cde_element_op,
78 qs_element_op,
79 rs_element_op,
80 a_grid_desc_ak0_m_ak1,
81 b_grid_desc_bk0_n_bk1,
82 ds_grid_desc_mblock_mperblock_nblock_nperblock,
83 e_grid_desc_mblock_mperblock_nblock_nperblock,
84 rs_grid_desc_mblock_mperblock,
85 block_2_etile_map);
86 }
87#else
// Other targets: touch every parameter through `ignore` (presumably to avoid
// unused-parameter diagnostics) and do no work.
88 ignore = p_a_grid;
89 ignore = p_b_grid;
90 ignore = p_ds_grid;
91 ignore = p_e_grid;
92 ignore = p_rs_grid;
93 ignore = a_element_op;
94 ignore = b_element_op;
95 ignore = cde_element_op;
96 ignore = qs_element_op;
97 ignore = rs_element_op;
98 ignore = a_grid_desc_ak0_m_ak1;
99 ignore = b_grid_desc_bk0_n_bk1;
100 ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
101 ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
102 ignore = rs_grid_desc_mblock_mperblock;
103 ignore = block_2_etile_map;
104#endif
105}
106
107} // namespace ck
108
109namespace ck {
110namespace tensor_operation {
111namespace device {
112
113// GEMM:
114// input : A[AK0, M, AK1]
115// input : B[BK0, N, BK1]
116// input : D0[M, N], D1[M, N], ...
117// output : E[M, N]
118// output : R0[M], R1[M], ...
119// C = a_op(A) * b_op(B)
120// E = cde_op(C, D0, D1, ...)
121// Q0 = reduce0(q_op0(E)), Q1 = reduce1(q_op1(E)), ...
122// R0 = r_op0(Q0), R1 = r_op1(Q1), ...
123// Assume:
124// D0, D1, ... and E have the same layout
125template <typename ALayout,
126 typename BLayout,
127 typename DELayout,
128 typename ADataType,
129 typename BDataType,
130 typename GemmAccDataType,
131 typename CShuffleDataType,
132 typename DsDataType,
133 typename EDataType,
134 typename ReduceAccDataType,
135 typename RsDataType,
136 typename AElementwiseOperation,
137 typename BElementwiseOperation,
138 typename CDEElementwiseOperation,
139 typename QsElementwiseOperation,
140 typename RsElementwiseOperation,
141 typename ThreadReduceOperations,
142 typename RsGlobalMemoryDataOperation,
143 GemmSpecialization GemmSpec,
144 index_t NumGemmKPrefetchStage,
145 index_t BlockSize,
146 index_t MPerBlock,
147 index_t NPerBlock,
148 index_t KPerBlock,
149 index_t AK1,
150 index_t BK1,
151 index_t MPerXDL,
152 index_t NPerXDL,
153 index_t MXdlPerWave,
154 index_t NXdlPerWave,
155 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
156 typename ABlockTransferThreadClusterArrangeOrder,
157 typename ABlockTransferSrcAccessOrder,
158 index_t ABlockTransferSrcVectorDim,
159 index_t ABlockTransferSrcScalarPerVector,
160 index_t ABlockTransferDstScalarPerVector_AK1,
161 bool ABlockLdsExtraM,
162 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
163 typename BBlockTransferThreadClusterArrangeOrder,
164 typename BBlockTransferSrcAccessOrder,
165 index_t BBlockTransferSrcVectorDim,
166 index_t BBlockTransferSrcScalarPerVector,
167 index_t BBlockTransferDstScalarPerVector_BK1,
168 bool BBlockLdsExtraN,
169 index_t CShuffleMXdlPerWavePerShuffle,
170 index_t CShuffleNXdlPerWavePerShuffle,
171 typename CDRThreadTransferClusterLengths_MPerBlock_NPerBlock,
172 index_t CDEReduceThreadTransferScalarPerVector_NPerBlock,
173 index_t RThreadTransferDstScalarPerVector_MPerBlock,
176 : public DeviceGemmMultipleDMultipleR<ALayout,
177 BLayout,
178 DELayout,
179 ADataType,
180 BDataType,
181 DsDataType,
182 EDataType,
183 RsDataType,
184 AElementwiseOperation,
185 BElementwiseOperation,
186 CDEElementwiseOperation,
187 QsElementwiseOperation,
188 RsElementwiseOperation>
189{
191
193 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
194 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
195
196 static constexpr index_t NumDTensor = DsDataType::Size();
197 static constexpr index_t NumRTensor = RsDataType::Size();
198
199 static constexpr auto I0 = Number<0>{};
200 static constexpr auto I1 = Number<1>{};
201 static constexpr auto I2 = Number<2>{};
202 static constexpr auto I3 = Number<3>{};
203
204 static constexpr auto matrix_padder =
205 MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
206
// Builds the descriptor of A viewed as an (MRaw, KRaw) matrix and returns it after padding
// through matrix_padder.PadADescriptor_M_K.
// Two stride arrangements are visible below: (StrideA, 1) and (1, StrideA). The conditions that
// choose between them are on lines absent from this listing -- presumably a compile-time check
// on ALayout; confirm against the full source.
207 static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
208 {
209 const auto a_grid_desc_mraw_kraw = [&]() {
211 {
// M index strided by StrideA, K index with unit stride.
212 return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
213 make_tuple(StrideA, I1));
214 }
216 {
// M index with unit stride, K index strided by StrideA.
217 return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
218 make_tuple(I1, StrideA));
219 }
220 }();
221
222 return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
223 }
224
// Builds the descriptor of B viewed as an (NRaw, KRaw) matrix and returns it after padding
// through matrix_padder.PadBDescriptor_N_K.
// Note the parameter order is (KRaw, NRaw, StrideB) while the descriptor lengths are (NRaw, KRaw).
// Two stride arrangements are visible below: (1, StrideB) and (StrideB, 1). The conditions that
// choose between them are on lines absent from this listing -- presumably a compile-time check
// on BLayout; confirm against the full source.
225 static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
226 {
227 const auto b_grid_desc_nraw_kraw = [&]() {
229 {
// N index with unit stride, K index strided by StrideB.
230 return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
231 make_tuple(I1, StrideB));
232 }
234 {
// N index strided by StrideB, K index with unit stride.
235 return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
236 make_tuple(StrideB, I1));
237 }
238 }();
239
240 return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
241 }
242
// Builds the descriptor of an (MRaw, NRaw) output-shaped matrix and returns it after padding
// through matrix_padder.PadCDescriptor_M_N.
// Two stride arrangements are visible below: (StrideE, 1) and (1, StrideE). The conditions that
// choose between them are on lines absent from this listing -- presumably a compile-time check
// on DELayout; confirm against the full source.
243 static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
244 {
245 const auto e_grid_desc_mraw_nraw = [&]() {
247 {
// M index strided by StrideE, N index with unit stride.
248 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
249 make_tuple(StrideE, I1));
250 }
252 {
// M index with unit stride, N index strided by StrideE.
253 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
254 make_tuple(I1, StrideE));
255 }
256 }();
257
258 return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
259 }
260
261 // assume R is a packed tensor
263 {
264 const auto r_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
265
266 const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
267 const auto MPad = M - MRaw;
268
269 if constexpr(GemmSpec == GemmSpecialization::MPadding ||
270 GemmSpec == GemmSpecialization::MNPadding ||
271 GemmSpec == GemmSpecialization::MKPadding ||
273 {
274 // pad M
275 return transform_tensor_descriptor(r_grid_desc_mraw,
279 }
280 else
281 {
282 // not pad M
283 return r_grid_desc_mraw;
284 }
285 }
286
// Descriptor types, deduced from the return types of the Make*GridDescriptor_* functions.
// The literal 1 arguments appear only inside decltype, so they serve type deduction only.
287 using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
288 using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
289 using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1));
290 using RGridDesc_M = decltype(MakeRGridDescriptor_M(1));
291
292 // GridwiseGemm
293 template <index_t NXdlPerWave_>
295 ADataType, // TODO: distinguish A/B datatype
296 GemmAccDataType,
297 CShuffleDataType,
298 DsDataType,
299 EDataType,
300 ReduceAccDataType,
301 RsDataType,
302 AElementwiseOperation,
303 BElementwiseOperation,
304 CDEElementwiseOperation,
305 QsElementwiseOperation,
306 RsElementwiseOperation,
307 ThreadReduceOperations,
309 RsGlobalMemoryDataOperation,
314 NumGemmKPrefetchStage,
315 BlockSize,
316 MPerBlock,
317 NPerBlock,
318 KPerBlock,
319 AK1,
320 BK1,
321 MPerXDL,
322 NPerXDL,
323 MXdlPerWave,
324 NXdlPerWave_,
325 ABlockTransferThreadClusterLengths_AK0_M_AK1,
326 ABlockTransferThreadClusterArrangeOrder,
327 ABlockTransferSrcAccessOrder,
328 ABlockTransferSrcVectorDim,
329 ABlockTransferSrcScalarPerVector,
330 ABlockTransferDstScalarPerVector_AK1,
331 false,
332 ABlockLdsExtraM,
333 BBlockTransferThreadClusterLengths_BK0_N_BK1,
334 BBlockTransferThreadClusterArrangeOrder,
335 BBlockTransferSrcAccessOrder,
336 BBlockTransferSrcVectorDim,
337 BBlockTransferSrcScalarPerVector,
338 BBlockTransferDstScalarPerVector_BK1,
339 false,
340 BBlockLdsExtraN,
341 CShuffleMXdlPerWavePerShuffle,
342 CShuffleNXdlPerWavePerShuffle,
343 CDRThreadTransferClusterLengths_MPerBlock_NPerBlock,
344 CDEReduceThreadTransferScalarPerVector_NPerBlock,
345 RThreadTransferDstScalarPerVector_MPerBlock,
346 LoopSched>;
349
352 AGridDesc_M_K{}))>;
355 BGridDesc_N_K{}))>;
356
358
359 // Argument
// Argument
// Host-side bundle of everything a launch needs: typed tensor pointers, raw M/N sizes, the D
// strides, grid descriptors, the block-to-E-tile map and the element-wise operators.
// The constructor receives type-erased pointers and casts them to the element types named by
// the enclosing operation's template parameters.
// NOTE(review): several initializer and member lines are absent from this listing; the comments
// here describe only the lines that are visible.
360 struct Argument : public BaseArgument
361 {
362 Argument(const void* p_a_grid,
363 const void* p_b_grid,
364 std::array<const void*, NumDTensor> p_ds_grid,
365 void* p_e_grid,
366 std::array<void*, NumRTensor> p_rs_grid,
367 index_t MRaw,
368 index_t NRaw,
369 index_t KRaw,
370 index_t StrideA,
371 index_t StrideB,
372 std::array<index_t, NumDTensor> StrideDs,
373 index_t StrideE,
374 AElementwiseOperation a_element_op,
375 BElementwiseOperation b_element_op,
376 CDEElementwiseOperation cde_element_op,
377 QsElementwiseOperation qs_element_op,
378 RsElementwiseOperation rs_element_op)
379 : p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
380 p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
// The D and R pointer tuples start value-initialized; the constructor body fills them.
381 p_ds_grid_{}, // FIXME
382 p_e_grid_{static_cast<EDataType*>(p_e_grid)},
383 p_rs_grid_{}, // FIXME
384 MRaw_(MRaw),
385 NRaw_(NRaw),
// The block-level descriptors and the tile map are built with the GridwiseGemm64 helpers.
391 GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
393 GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
394 block_2_etile_map_{GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
395 a_element_op_{a_element_op},
396 b_element_op_{b_element_op},
397 cde_element_op_{cde_element_op},
398 qs_element_op_{qs_element_op},
399 rs_element_op_{rs_element_op}
400 {
// For each D tensor: cast its pointer to the matching element type of DsDataType and
// record its stride.
401 static_for<0, NumDTensor, 1>{}([&](auto i) {
402 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
403 p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
404 stride_ds_[i] = StrideDs[i];
405 });
// For each R tensor: cast its pointer to the matching element type of RsDataType.
406 static_for<0, NumRTensor, 1>{}([&](auto i) {
407 using RDataType = remove_cvref_t<tuple_element_t<i.value, RsDataType>>;
408 p_rs_grid_(i) = static_cast<RDataType*>(p_rs_grid[i]);
409 });
410 }
411
412 // private:
413 // pointers
414 const ADataType* p_a_grid_;
415 const BDataType* p_b_grid_;
417 EDataType* p_e_grid_;
// One stride per D tensor, copied from StrideDs in the constructor body.
421 std::array<index_t, NumDTensor> stride_ds_;
422 // tensor descriptors
427 // tensor descriptors for block/thread-wise copy
430 // block-to-e-tile map
432
433 // element-wise op
434 AElementwiseOperation a_element_op_;
435 BElementwiseOperation b_element_op_;
436 CDEElementwiseOperation cde_element_op_;
437 QsElementwiseOperation qs_element_op_;
438 RsElementwiseOperation rs_element_op_;
439 };
440
441 // Invoker
442 struct Invoker : public BaseInvoker
443 {
445
// Launches the kernel for one concrete GridwiseGemm variant and returns the value produced by
// launch_and_time_kernel.
// Steps visible in this listing:
//   1. Throw std::runtime_error when GridwiseGemm::CheckValidity rejects the descriptors.
//   2. Derive the block-tiled descriptors for E, for every D and for every R.
//   3. Ask the block-to-E-tile map for the grid size.
//   4. Pick the HasMainKBlockLoop instantiation from K at run time and launch it.
// NOTE(review): many lines of this function are absent from the listing; the comments describe
// only what the visible lines establish.
446 template <typename GridwiseGemm>
447 float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
448 {
449 if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
452 arg.r_grid_desc_m_,
454 {
455 throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
456 }
458 typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
460 ds_grid_desc_mblock_mperblock_nblock_nperblock = {};
461
462 StaticallyIndexedArray<typename GridwiseGemm::RGridDescriptor_MBlock_MPerBlock,
464 rs_grid_desc_mblock_mperblock = {};
465
466 auto e_grid_desc_mblock_mperblock_nblock_nperblock =
467 GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
468 arg.e_grid_desc_m_n_);
469
// Each D descriptor goes through the same E-descriptor tiling helper as E itself.
470 static_for<0, NumDTensor, 1>{}([&](auto i) {
471 const auto d_grid_desc_m_n =
473 ds_grid_desc_mblock_mperblock_nblock_nperblock(i) =
474 GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
475 d_grid_desc_m_n);
476 });
477
478 static_for<0, NumRTensor, 1>{}([&](auto i) {
479 rs_grid_desc_mblock_mperblock(i) =
481 });
482
483 const index_t grid_size =
484 arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
485
// K is the product of the lengths of dimensions 0 and 2 of a_grid_desc_ak0_m_ak1_
// (AK0 and AK1, going by the descriptor's name).
486 const auto K =
487 arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
488
// The bool arrives wrapped in an integral_constant so it can be used as a template argument.
489 auto launch_kernel = [&](auto has_main_k_block_loop) {
490 constexpr bool has_main_loop = has_main_k_block_loop.value;
491
493 GridwiseGemm,
494 ADataType, // TODO: distinguish A/B datatype
495 typename GridwiseGemm::DsGridPointer,
496 EDataType,
497 typename GridwiseGemm::RsGridPointer,
498 AElementwiseOperation,
499 BElementwiseOperation,
500 CDEElementwiseOperation,
501 QsElementwiseOperation,
502 RsElementwiseOperation,
506 typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
507 NumDTensor>,
508 typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
510 typename GridwiseGemm::RGridDescriptor_MBlock_MPerBlock,
511 NumRTensor>,
512 typename GridwiseGemm::DefaultBlock2ETileMap,
513 has_main_loop>;
514
// Launch with grid_size blocks of BlockSize threads; the literal 0 sits in the
// position right after the block dimensions.
515 return launch_and_time_kernel(stream_config,
516 kernel,
517 dim3(grid_size),
518 dim3(BlockSize),
519 0,
520 arg.p_a_grid_,
521 arg.p_b_grid_,
522 arg.p_ds_grid_,
523 arg.p_e_grid_,
524 arg.p_rs_grid_,
525 arg.a_element_op_,
526 arg.b_element_op_,
527 arg.cde_element_op_,
528 arg.qs_element_op_,
529 arg.rs_element_op_,
532 ds_grid_desc_mblock_mperblock_nblock_nperblock,
533 e_grid_desc_mblock_mperblock_nblock_nperblock,
534 rs_grid_desc_mblock_mperblock,
536 };
537
538 float ave_time = 0;
539
// Run-time decision mapped onto the two compile-time kernel instantiations.
540 if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
541 {
542 ave_time = launch_kernel(integral_constant<bool, true>{});
543 }
544 else
545 {
546 ave_time = launch_kernel(integral_constant<bool, false>{});
547 }
548
549 return ave_time;
550 }
551
553
554 // polymorphic
555 float Run(const BaseArgument* p_arg,
556 const StreamConfig& stream_config = StreamConfig{}) override
557 {
558 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
559 }
560 };
561
// Reports whether this operation can run the given argument.
// Visible structure: an early `return false` guarded by a condition that is absent from this
// listing, then a branch on get_warp_size() == 64. Each side is compiled only when its
// NXdlPerWave64 / NXdlPerWave32 value is positive; otherwise control reaches the final
// `return false`.
// NOTE(review): the statements inside the two `if constexpr` bodies are mostly missing here;
// presumably they return a GridwiseGemm validity check on the descriptors -- confirm against
// the full source.
562 static bool IsSupportedArgument(const Argument& arg)
563 {
565 {
566 return false;
567 }
568 if(get_warp_size() == 64)
569 {
570 if constexpr(NXdlPerWave64 > 0)
571 {
575 arg.r_grid_desc_m_,
577 }
578 }
579 else
580 {
581 if constexpr(NXdlPerWave32 > 0)
582 {
586 arg.r_grid_desc_m_,
588 }
589 }
// Reached when the selected NXdlPerWave value is not positive.
590 return false;
591 }
592
593 // polymorphic
594 bool IsSupportedArgument(const BaseArgument* p_arg) override
595 {
596 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
597 }
598
599 static auto MakeArgument(const void* p_a,
600 const void* p_b,
601 std::array<const void*, NumDTensor> p_ds,
602 void* p_e,
603 std::array<void*, NumRTensor> p_rs,
604 index_t MRaw,
605 index_t NRaw,
606 index_t KRaw,
607 index_t StrideA,
608 index_t StrideB,
609 std::array<index_t, NumDTensor> StrideDs,
610 index_t StrideE,
611 AElementwiseOperation a_element_op,
612 BElementwiseOperation b_element_op,
613 CDEElementwiseOperation cde_element_op,
614 QsElementwiseOperation qs_element_op,
615 RsElementwiseOperation rs_element_op)
616 {
617 return Argument{p_a,
618 p_b,
619 p_ds,
620 p_e,
621 p_rs,
622 MRaw,
623 NRaw,
624 KRaw,
625 StrideA,
626 StrideB,
627 StrideDs,
628 StrideE,
629 a_element_op,
630 b_element_op,
631 cde_element_op,
632 qs_element_op,
633 rs_element_op};
634 }
635
636 static auto MakeInvoker() { return Invoker{}; }
637
638 // polymorphic
639 std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
640 const void* p_b,
641 std::array<const void*, NumDTensor> p_ds,
642 void* p_e,
643 std::array<void*, NumRTensor> p_rs,
644 index_t MRaw,
645 index_t NRaw,
646 index_t KRaw,
647 index_t StrideA,
648 index_t StrideB,
649 std::array<index_t, NumDTensor> StrideDs,
650 index_t StrideE,
651 AElementwiseOperation a_element_op,
652 BElementwiseOperation b_element_op,
653 CDEElementwiseOperation cde_element_op,
654 QsElementwiseOperation qs_element_op,
655 RsElementwiseOperation rs_element_op) override
656 {
657 return std::make_unique<Argument>(p_a,
658 p_b,
659 p_ds,
660 p_e,
661 p_rs,
662 MRaw,
663 NRaw,
664 KRaw,
665 StrideA,
666 StrideB,
667 StrideDs,
668 StrideE,
669 a_element_op,
670 b_element_op,
671 cde_element_op,
672 qs_element_op,
673 rs_element_op);
674 }
675
676 // polymorphic
677 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
678 {
679 return std::make_unique<Invoker>(Invoker{});
680 }
681
682 // polymorphic
683 std::string GetTypeString() const override
684 {
685 auto str = std::stringstream();
686
687 // clang-format off
688 str << "DeviceGemmMultipleDMultipleR_Xdl_CShuffle"
689 << "<"
690 << BlockSize << ", "
691 << MPerBlock << ", "
692 << NPerBlock << ", "
693 << KPerBlock << ", "
694 << AK1 << ", "
695 << BK1 << ", "
696 << getGemmSpecializationString(GemmSpec) << ", "
697 << MPerXDL << ", "
698 << NPerXDL << ", "
699 << MXdlPerWave << ", "
700 << NXdlPerWave << ", "
701 << ABlockTransferSrcScalarPerVector << ", "
702 << BBlockTransferSrcScalarPerVector << ", "
703 << CShuffleMXdlPerWavePerShuffle << ", "
704 << CShuffleNXdlPerWavePerShuffle
705 << ">";
706 // clang-format on
707
708 return str.str();
709 }
710};
711
712} // namespace device
713} // namespace tensor_operation
714} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
Definition ck.hpp:268
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__global__ void kernel_gemm_multiple_d_multiple_r_xdl_cshuffle(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatDsPointer p_ds_grid, FloatE *__restrict__ p_e_grid, FloatRsPointer p_rs_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const QsElementwiseOperation qs_element_op, const RsElementwiseOperation rs_element_op, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock, const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock, const Block2ETileMap block_2_etile_map)
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:43
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:74
__host__ static __device__ constexpr bool CheckValidity(const AGridDesc_M_K &a_grid_desc_m_k, const BGridDesc_N_K &b_grid_desc_n_k, const EGridDesc_M_N &e_grid_desc_m_n, const RGridDesc_M &r_grid_desc_m, const Block2ETileMap &block_2_etile_map)
Definition gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:208
Definition utility/sequence.hpp:43
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:361
std::array< index_t, NumDTensor > stride_ds_
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:421
RGridDesc_M r_grid_desc_m_
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:426
Block2ETileMap block_2_etile_map_
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:431
AElementwiseOperation a_element_op_
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:434
CDEElementwiseOperation cde_element_op_
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:436
AGridDesc_M_K a_grid_desc_m_k_
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:423
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:429
Argument(const void *p_a_grid, const void *p_b_grid, std::array< const void *, NumDTensor > p_ds_grid, void *p_e_grid, std::array< void *, NumRTensor > p_rs_grid, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, QsElementwiseOperation qs_element_op, RsElementwiseOperation rs_element_op)
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:362
index_t MRaw_
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:419
QsElementwiseOperation qs_element_op_
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:437
GridwiseGemm64::DsGridPointer p_ds_grid_
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:416
RsElementwiseOperation rs_element_op_
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:438
BGridDesc_N_K b_grid_desc_n_k_
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:424
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:428
index_t NRaw_
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:420
EGridDesc_M_N e_grid_desc_m_n_
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:425
const BDataType * p_b_grid_
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:415
GridwiseGemm64::RsGridPointer p_rs_grid_
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:418
EDataType * p_e_grid_
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:417
BElementwiseOperation b_element_op_
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:435
const ADataType * p_a_grid_
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:414
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:443
DeviceOp::Argument Argument
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:444
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:447
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:555
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:189
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:562
static auto MakeArgument(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, std::array< void *, NumRTensor > p_rs, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, QsElementwiseOperation qs_element_op, RsElementwiseOperation rs_element_op)
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:599
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:594
static constexpr auto matrix_padder
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:204
decltype(MakeAGridDescriptor_M_K(1, 1, 1)) AGridDesc_M_K
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:287
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1( AGridDesc_M_K{}))> AGridDesc_AK0_M_AK1
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:350
static constexpr auto NXdlPerWave32
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:194
decltype(MakeEGridDescriptor_M_N(1, 1, 1)) EGridDesc_M_N
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:289
DeviceGemmMultipleDMultipleR_Xdl_CShuffle DeviceOp
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:190
static constexpr auto I1
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:200
static constexpr index_t NumRTensor
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:197
static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:225
std::string GetTypeString() const override
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:683
static constexpr auto I2
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:201
static auto MakeInvoker()
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:636
static auto MakeRGridDescriptor_M(index_t MRaw)
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:262
static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:207
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1( BGridDesc_N_K{}))> BGridDesc_BK0_N_BK1
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:353
GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, QsElementwiseOperation, RsElementwiseOperation, ThreadReduceOperations, InMemoryDataOperationEnum::Set, RsGlobalMemoryDataOperation, AGridDesc_M_K, BGridDesc_N_K, EGridDesc_M_N, RGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDRThreadTransferClusterLengths_MPerBlock_NPerBlock, CDEReduceThreadTransferScalarPerVector_NPerBlock, RThreadTransferDstScalarPerVector_MPerBlock, LoopSched > GridwiseGemmBase
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:294
static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:243
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, std::array< void *, NumRTensor > p_rs, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, QsElementwiseOperation qs_element_op, RsElementwiseOperation rs_element_op) override
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:639
static constexpr auto I3
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:202
typename GridwiseGemm64::DefaultBlock2ETileMap Block2ETileMap
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:357
static constexpr index_t NumDTensor
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:196
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:348
decltype(MakeRGridDescriptor_M(1)) RGridDesc_M
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:290
static constexpr auto I0
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:199
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:193
decltype(MakeBGridDescriptor_N_K(1, 1, 1)) BGridDesc_N_K
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:288
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:347
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:677
Definition device_gemm_multiple_d_multiple_r.hpp:41
Definition matrix_padder.hpp:180