device_batched_gemm_e_permute_xdl.hpp Source File

device_batched_gemm_e_permute_xdl.hpp Source File#

Composable Kernel: device_batched_gemm_e_permute_xdl.hpp Source File
device_batched_gemm_e_permute_xdl.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3#pragma once
4
5#include <iostream>
6#include <sstream>
7
18
19namespace ck {
20namespace tensor_operation {
21namespace device {
22
23/*
24 * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
25 *
26 * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix
27 * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly
28 * strided batches, but we can easily extend to other layouts. The returned offset can be either \p
29 * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB
30#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
31 * limitations.
32 *
33 * \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and
34 * returns the 2D index of the tile that it computes. \see
35 * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
36 * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
37 * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
38 * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
39 * impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for
40 * \link
41 * DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of
42 * pointer offset into \p ComputePtrOffsetOfStridedBatch.
43 *
44 * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes.
45 * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to
46 * realize BatchedGemmCPermute and GroupedGemm (and the corresponding GEMM fusion).
47 *
48 */
// Kernel entry point: each workgroup derives the batch index it belongs to from its 1D block id,
// offsets the A/B/E base pointers for that batch, and forwards to GridwiseGemm::Run.
49template <typename GridwiseGemm,
50 typename ABDataType,
51 typename EDataType,
52 typename AGridDesc_AK0_M_AK1,
53 typename BGridDesc_BK0_N_BK1,
54 typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
55 typename AElementwiseOperation,
56 typename BElementwiseOperation,
57 typename CDEElementwiseOperation,
58 typename ComputePtrOffsetOfBatch,
59 typename Block2ETileMap,
60 bool HasMainKBlockLoop>
61__global__ void
62#if CK_USE_LAUNCH_BOUNDS
64#endif
65 kernel_batched_gemm_e_permute_xdl(const ABDataType* __restrict__ p_a_grid,
66 const ABDataType* __restrict__ p_b_grid,
67 EDataType* __restrict__ p_e_grid,
68 const index_t batch_count,
69 const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
70 const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
71 const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
72 e_grid_desc_mblock_mperblock_nblock_nperblock,
73 const AElementwiseOperation a_element_op,
74 const BElementwiseOperation b_element_op,
75 const CDEElementwiseOperation cde_element_op,
76 const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
77 const Block2ETileMap block_2_etile_map)
78{
// The real body is only compiled when one of __gfx9__/__gfx11__/__gfx12__ is defined; otherwise
// the #else branch below just assigns every parameter to `ignore`.
79#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
80 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
81 {
// Workgroups per batch = total grid size / batch_count (integer division). The host side
// (Invoker::RunImp) launches grid_size = tiles-per-batch * BatchCount_.
82 const index_t num_blocks_per_batch =
83 __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
// Batch index served by this workgroup.
84 const index_t g_idx =
85 __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
86
// Per-batch element offsets into A, B and E, widened to long_index_t before use.
// NOTE(review): if the toolchain's __builtin_amdgcn_readfirstlane only accepts a 32-bit
// operand, the long_index_t value would be narrowed here - confirm against the compiler in use.
87 const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
88 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
89 const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
90 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
// The E offset comes from GetCPtrOffset (the helper keeps the "C" name).
91 const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
92 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
93
// __shared__ scratch buffer whose byte size is supplied by GridwiseGemm.
94 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
95
// Run the gridwise GEMM on this batch's slices; an empty ck::Tuple<>{} is passed where no
// extra tensors are supplied.
96 GridwiseGemm::template Run<HasMainKBlockLoop, InMemoryDataOperationEnum::Set>(
97 p_a_grid + a_batch_offset,
98 p_b_grid + b_batch_offset,
100 p_e_grid + e_batch_offset,
101 p_shared,
102 a_element_op,
103 b_element_op,
104 cde_element_op,
105 a_grid_desc_ak0_m_ak1,
106 b_grid_desc_bk0_n_bk1,
107 ck::Tuple<>{},
108 e_grid_desc_mblock_mperblock_nblock_nperblock,
109 block_2_etile_map);
110 }
111#else
// Unsupported target: touch every parameter so none is reported as unused.
112 ignore = p_a_grid;
113 ignore = p_b_grid;
114 ignore = p_e_grid;
115 ignore = batch_count;
116 ignore = a_grid_desc_ak0_m_ak1;
117 ignore = b_grid_desc_bk0_n_bk1;
118 ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
119 ignore = a_element_op;
120 ignore = b_element_op;
121 ignore = cde_element_op;
122 ignore = compute_ptr_offset_of_batch;
123 ignore = block_2_etile_map;
124#endif
125}
126
127template <typename ALayout,
128 typename BLayout,
129 typename ELayout,
130 typename ADataType,
131 typename BDataType,
132 typename AccDataType,
133 typename CShuffleDataType,
134 typename EDataType,
135 typename AElementwiseOperation,
136 typename BElementwiseOperation,
137 typename CDEElementwiseOperation,
138 GemmSpecialization GemmSpec,
139 index_t NumPrefetch,
140 index_t BlockSize,
141 index_t MPerBlock,
142 index_t NPerBlock,
143 index_t KPerBlock,
144 index_t AK1,
145 index_t BK1,
146 index_t MPerXDL,
147 index_t NPerXDL,
148 index_t MXdlPerWave,
149 index_t NXdlPerWave,
150 typename ABlockTransferThreadClusterLengths_K0_M_K1,
151 typename ABlockTransferThreadClusterArrangeOrder,
152 typename ABlockTransferSrcAccessOrder,
153 index_t ABlockTransferSrcVectorDim,
154 index_t ABlockTransferSrcScalarPerVector,
155 index_t ABlockTransferDstScalarPerVector_K1,
156 index_t ABlockLdsExtraM,
157 typename BBlockTransferThreadClusterLengths_K0_N_K1,
158 typename BBlockTransferThreadClusterArrangeOrder,
159 typename BBlockTransferSrcAccessOrder,
160 index_t BBlockTransferSrcVectorDim,
161 index_t BBlockTransferSrcScalarPerVector,
162 index_t BBlockTransferDstScalarPerVector_K1,
163 index_t BBlockLdsExtraN,
164 index_t CShuffleMXdlPerWavePerShuffle,
165 index_t CShuffleNXdlPerWavePerShuffle,
166 typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
167 index_t CDEBlockTransferScalarPerVector_NPerBlock,
170 BLayout,
171 ELayout,
172 ADataType,
173 BDataType,
174 EDataType,
175 AElementwiseOperation,
176 BElementwiseOperation,
177 CDEElementwiseOperation>
178{
180
182 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
183 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
184
185 static constexpr auto I0 = Number<0>{};
186 static constexpr auto I1 = Number<1>{};
187 static constexpr auto I2 = Number<2>{};
188
189 static constexpr auto matrix_padder =
190 MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
191
// Builds the (MRaw, KRaw) tensor descriptor for A from its leading stride and returns it after
// padding through matrix_padder.PadADescriptor_M_K.
// NOTE(review): the conditions selecting between the two stride layouts below (listing lines
// 195 and 200) are not present in this extract - presumably they test ALayout; confirm.
192 static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
193 {
194 const auto a_grid_desc_mraw_kraw = [&]() {
196 {
// Strides (StrideA, 1): consecutive K elements are adjacent in memory.
197 return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
198 make_tuple(StrideA, I1));
199 }
201 {
// Strides (1, StrideA): consecutive M elements are adjacent in memory.
202 return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
203 make_tuple(I1, StrideA));
204 }
205 }();
206
207 return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
208 }
209
// Builds the (NRaw, KRaw) tensor descriptor for B from its leading stride and returns it after
// padding through matrix_padder.PadBDescriptor_N_K.
// Note the parameter order is (KRaw, NRaw) while the descriptor dimensions are (NRaw, KRaw).
// NOTE(review): the conditions selecting between the two stride layouts below (listing lines
// 213 and 218) are not present in this extract - presumably they test BLayout; confirm.
210 static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
211 {
212 const auto b_grid_desc_nraw_kraw = [&]() {
214 {
// Strides (1, StrideB): consecutive N elements are adjacent in memory.
215 return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
216 make_tuple(I1, StrideB));
217 }
219 {
// Strides (StrideB, 1): consecutive K elements are adjacent in memory.
220 return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
221 make_tuple(StrideB, I1));
222 }
223 }();
224
225 return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
226 }
227
// Builds the (MRaw, NRaw) tensor descriptor for E with caller-supplied strides for both
// dimensions and returns it after padding through matrix_padder.PadCDescriptor_M_N.
228 static auto
229 MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t stride_M, index_t stride_N)
230 {
231 const auto e_grid_desc_mraw_nraw =
232 make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), make_tuple(stride_M, stride_N));
233
234 return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
235 }
236
238 index_t G1,
239 index_t MRaw,
240 index_t NRaw,
241 index_t stride_G0,
242 index_t stride_G1,
243 index_t stride_M,
244 index_t stride_N)
245 {
// Unpadded 4-D descriptor with lengths (G0, G1, MRaw, NRaw) and one stride per dimension.
246 const auto e_grid_desc_g0_g1_mraw_nraw = [&]() {
248 make_tuple(G0, G1, MRaw, NRaw),
249 make_tuple(stride_G0, stride_G1, stride_M, stride_N));
250 }();
251
// Round MRaw/NRaw up to the next multiple of MPerBlock/NPerBlock.
252 const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
253 const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
254
// Number of elements to append on the right of each dimension to reach the rounded size.
255 const auto MPad = M - MRaw;
256 const auto NPad = N - NRaw;
257
// The padding applied depends on GemmSpec, resolved at compile time. The second half of each
// condition and the surrounding transform calls are missing from this extract.
258 if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
260 {
261 // pad M and N
263 e_grid_desc_g0_g1_mraw_nraw,
266 make_right_pad_transform(MRaw, MPad),
267 make_right_pad_transform(NRaw, NPad)),
270 }
271 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
273 {
274 // pad M, but not N
276 e_grid_desc_g0_g1_mraw_nraw,
279 make_right_pad_transform(MRaw, MPad),
283 }
284 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
286 {
287 // pad N, but not M
289 e_grid_desc_g0_g1_mraw_nraw,
293 make_right_pad_transform(NRaw, NPad)),
296 }
297 else
298 {
299 // not pad M or N
300 return e_grid_desc_g0_g1_mraw_nraw;
301 }
302 }
303
304 using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
305 using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
306 using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1, 1));
307 using EGridDesc_G0_G1_M_N = decltype(MakeEGridDescriptor_G0_G1_M_N(1, 1, 1, 1, 1, 1, 1, 1));
308
310 {
312 index_t Batchstride_B,
313 EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n)
314 : Batchstride_A_(Batchstride_A),
315 Batchstride_B_(Batchstride_B),
316 e_grid_desc_g0_g1_m_n_(e_grid_desc_g0_g1_m_n)
317 {
318 }
319
// Returns the offset of batch g_idx in A: g_idx * Batchstride_A_. The stride is widened to
// long_index_t first, so the product is computed in long_index_t.
320 __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
321 {
322 return g_idx * static_cast<long_index_t>(Batchstride_A_);
323 }
324
// Returns the offset of batch g_idx in B: g_idx * Batchstride_B_. The stride is widened to
// long_index_t first, so the product is computed in long_index_t.
325 __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
326 {
327 return g_idx * static_cast<long_index_t>(Batchstride_B_);
328 }
329
// Returns the offset of batch g_idx in the output tensor. The flat batch index is split into
// (b0, b1) using the descriptor's dimension-1 length, and the descriptor computes the offset of
// element (b0, b1, 0, 0).
// NOTE(review): the width of CalculateOffset's result is not visible here; if it is index_t the
// value is only widened to long_index_t on return - confirm.
330 __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
331 {
332 const index_t G1 = e_grid_desc_g0_g1_m_n_.GetLength(I1);
333 index_t b0 = g_idx / G1;
334 index_t b1 = g_idx - b0 * G1; // g_idx % G1
335 return e_grid_desc_g0_g1_m_n_.CalculateOffset(make_multi_index(b0, b1, 0, 0));
336 }
337
338 private:
339 index_t Batchstride_A_;
340 index_t Batchstride_B_;
341 EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_;
342 };
343
344 using ComputeDataType = ADataType;
345
346 // GridwiseGemm
347 template <index_t NXdlPerWave_>
349 ADataType,
350 BDataType,
352 AccDataType,
353 CShuffleDataType,
354 ck::Tuple<>, // DsDataType,
355 EDataType, // EDataType,
356 AElementwiseOperation,
357 BElementwiseOperation,
358 CDEElementwiseOperation,
361 Tuple<>,
363 NumPrefetch,
364 BlockSize,
365 MPerBlock,
366 NPerBlock,
367 KPerBlock,
368 AK1,
369 BK1,
370 MPerXDL,
371 NPerXDL,
372 MXdlPerWave,
373 NXdlPerWave_,
374 ABlockTransferThreadClusterLengths_K0_M_K1,
375 ABlockTransferThreadClusterArrangeOrder,
376 ABlockTransferSrcAccessOrder,
377 ABlockTransferSrcVectorDim,
378 ABlockTransferSrcScalarPerVector,
379 ABlockTransferDstScalarPerVector_K1,
380 false, // AThreadTransferSrcResetCoordinateAfterRun,
381 ABlockLdsExtraM,
382 BBlockTransferThreadClusterLengths_K0_N_K1,
383 BBlockTransferThreadClusterArrangeOrder,
384 BBlockTransferSrcAccessOrder,
385 BBlockTransferSrcVectorDim,
386 BBlockTransferSrcScalarPerVector,
387 BBlockTransferDstScalarPerVector_K1,
388 false, // BThreadTransferSrcResetCoordinateAfterRun,
389 BBlockLdsExtraN,
390 CShuffleMXdlPerWavePerShuffle,
391 CShuffleNXdlPerWavePerShuffle,
392 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
393 CDEBlockTransferScalarPerVector_NPerBlock,
394 LoopSched>;
397
399 remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
400 AGridDesc_M_K{}))>;
402 remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
403 BGridDesc_N_K{}))>;
404
406 decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
407 EGridDesc_M_N{}));
408 using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
409
410 // Argument
411 struct Argument : public BaseArgument
412 {
// Stores the raw pointers, batch count and element-wise operators, and builds the tensor
// descriptors, the per-batch offset helper and the block-to-E-tile map used at launch time.
// Several initializer lines (listing lines 432-434, 439, 441, 443) are missing from this extract.
413 Argument(const ADataType* p_a_grid,
414 const BDataType* p_b_grid,
415 EDataType* p_e_grid,
416 index_t M,
417 index_t N,
418 index_t K,
419 index_t stride_A,
420 index_t stride_B,
421 index_t batch_stride_A,
422 index_t batch_stride_B,
423 BatchedGemmEPermuteDesc batched_gemm_e_permute_desc,
424 index_t BatchCount,
425 AElementwiseOperation a_element_op,
426 BElementwiseOperation b_element_op,
427 CDEElementwiseOperation cde_element_op)
428 : p_a_grid_{p_a_grid},
429 p_b_grid_{p_b_grid},
430 p_e_grid_{p_e_grid},
431 BatchCount_(BatchCount),
// 2-D E descriptor built from the M_/N_ lengths and strides of batched_gemm_e_permute_desc.
435 DeviceOp::MakeEGridDescriptor_M_N(batched_gemm_e_permute_desc.M_,
436 batched_gemm_e_permute_desc.N_,
437 batched_gemm_e_permute_desc.stride_M_,
438 batched_gemm_e_permute_desc.stride_N_)},
// These initializers read a_grid_desc_m_k_ / b_grid_desc_n_k_, so they rely on those members
// being declared (and therefore initialized) earlier in the struct.
440 GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
442 GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
// 4-D (G0, G1, M, N) E descriptor; the per-batch offset helper below uses it to locate each batch.
444 DeviceOp::MakeEGridDescriptor_G0_G1_M_N(batched_gemm_e_permute_desc.G0_,
445 batched_gemm_e_permute_desc.G1_,
446 batched_gemm_e_permute_desc.M_,
447 batched_gemm_e_permute_desc.N_,
448 batched_gemm_e_permute_desc.stride_G0_,
449 batched_gemm_e_permute_desc.stride_G1_,
450 batched_gemm_e_permute_desc.stride_M_,
451 batched_gemm_e_permute_desc.stride_N_)},
452 compute_ptr_offset_of_batch_{batch_stride_A, batch_stride_B, e_grid_desc_g0_g1_m_n_},
453 block_2_etile_map_{GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
454 a_element_op_{a_element_op},
455 b_element_op_{b_element_op},
456 cde_element_op_{cde_element_op}
457 {
458 }
459
// Writes the A, B and E 2-D descriptors to std::cout, one per line. The E descriptor is printed
// under the label "C[M, N]".
460 void Print() const
461 {
462 std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
463 std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
464 std::cout << "C[M, N]: " << e_grid_desc_m_n_ << std::endl;
465 }
466
467 // private:
468 // pointers
469 const ADataType* p_a_grid_;
470 const BDataType* p_b_grid_;
471 EDataType* p_e_grid_;
472
473 // batch count
475
476 // tensor descriptors for problem definition
480
481 // tensor descriptors for block/thread-wise copy
486
487 // for calculating Batch offset
489
490 // block-to-e-tile map
492
493 // element-wise op
494 AElementwiseOperation a_element_op_;
495 BElementwiseOperation b_element_op_;
496 CDEElementwiseOperation cde_element_op_;
497 };
498
499 // Invoker
500 struct Invoker : public BaseInvoker
501 {
503
// Validates the argument against the given GridwiseGemm, computes the launch grid, and launches
// kernel_batched_gemm_e_permute_xdl through launch_and_time_kernel, returning its result.
// Throws std::runtime_error when GridwiseGemm::CheckValidity rejects the descriptors.
504 template <typename GridwiseGemm>
505 float RunImp(const typename GridwiseGemm::Argument& arg,
506 const StreamConfig& stream_config = StreamConfig{})
507 {
508 if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
509 arg.b_grid_desc_n_k_,
510 ck::Tuple<>{},
511 arg.e_grid_desc_m_n_,
512 arg.block_2_etile_map_))
513 {
514 throw std::runtime_error(
515 "wrong! GridwiseBatchedGemmCPermute_km_kn_m0m1n0n1_xdlops_v2r3 has invalid "
516 "setting");
517 }
// The block-wise E descriptor is rebuilt here for the GridwiseGemm variant in use.
518 auto e_grid_desc_mblock_mperblock_nblock_nperblock =
519 GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
520 arg.e_grid_desc_m_n_);
// Total workgroups = grid size for one E matrix * number of batches.
521 const index_t grid_size =
522 arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.BatchCount_;
523
// K is recovered as the product of dimensions 0 and 2 of the A descriptor.
524 const auto K =
525 arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
526
// The lambda turns the run-time "has main K block loop" decision into a template argument of
// the kernel. Listing lines 532-533 and 539 of the argument list are missing from this extract.
527 auto launch_kernel = [&](auto has_main_k_block_loop_) {
528 const auto kernel = kernel_batched_gemm_e_permute_xdl<
529 GridwiseGemm,
530 ADataType, // TODO: distinguish A/B datatype
531 EDataType,
534 typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
535 AElementwiseOperation,
536 BElementwiseOperation,
537 CDEElementwiseOperation,
538 ComputePtrOffsetOfStridedBatch,
540 has_main_k_block_loop_>;
541
// One 1-D grid of grid_size workgroups with BlockSize threads each; 0 is passed as lds_byte.
542 return launch_and_time_kernel(stream_config,
543 kernel,
544 dim3(grid_size),
545 dim3(BlockSize),
546 0,
547 arg.p_a_grid_,
548 arg.p_b_grid_,
549 arg.p_e_grid_,
550 arg.BatchCount_,
551 arg.a_grid_desc_ak0_m_ak1_,
552 arg.b_grid_desc_bk0_n_bk1_,
553 e_grid_desc_mblock_mperblock_nblock_nperblock,
554 arg.a_element_op_,
555 arg.b_element_op_,
556 arg.cde_element_op_,
557 arg.compute_ptr_offset_of_batch_,
558 arg.block_2_etile_map_);
559 };
560
561 if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
562 {
563 return launch_kernel(integral_constant<bool, true>{});
564 }
565 else
566 {
567 return launch_kernel(integral_constant<bool, false>{});
568 }
569 }
570
572
573 // polymorphic
// Type-erased entry point: downcasts p_arg to Argument and forwards to the Run overload taking
// an Argument (NOTE(review): that overload is not visible in this extract - presumably
// generated by INVOKER_RUN_IMPL; confirm).
// The dynamic_cast result is dereferenced without a null check, so p_arg must really point to
// an Argument.
574 float Run(const BaseArgument* p_arg,
575 const StreamConfig& stream_config = StreamConfig{}) override
576 {
577 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
578 }
579 };
580
// Compile-time validity hook; currently a placeholder that accepts every configuration.
581 static constexpr bool IsValidCompilationParameter()
582 {
583 // TODO: properly implement this check
584 return true;
585 }
586
// Returns whether this device op can run the given argument: an early guard returns false,
// otherwise the result of GridwiseGemm::CheckValidity on the argument's descriptors is returned.
// NOTE(review): the guard condition (listing line 589) and some CheckValidity arguments are
// missing from this extract - presumably the guard is a device-capability check; confirm.
587 static bool IsSupportedArgument(const Argument& arg)
588 {
590 {
591 return false;
592 }
593
594 return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
596 ck::Tuple<>{},
599 }
600
601 // polymorphic
// Type-erased overload: downcasts p_arg to Argument and forwards to the static overload.
// The dynamic_cast result is dereferenced without a null check, so p_arg must really point to
// an Argument.
602 bool IsSupportedArgument(const BaseArgument* p_arg) override
603 {
604 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
605 }
606
// Typed factory: forwards all parameters, in order, to the Argument constructor and returns the
// Argument by value.
607 static auto MakeArgument(const ADataType* p_a,
608 const BDataType* p_b,
609 EDataType* p_e,
610 index_t M,
611 index_t N,
612 index_t K,
613 index_t stride_A,
614 index_t stride_B,
615 index_t batch_stride_A,
616 index_t batch_stride_B,
617 BatchedGemmEPermuteDesc batched_gemm_e_permute_desc,
618 index_t BatchCount,
619 AElementwiseOperation a_element_op,
620 BElementwiseOperation b_element_op,
621 CDEElementwiseOperation cde_element_op)
622 {
623 return Argument{p_a,
624 p_b,
625 p_e,
626 M,
627 N,
628 K,
629 stride_A,
630 stride_B,
631 batch_stride_A,
632 batch_stride_B,
633 batched_gemm_e_permute_desc,
634 BatchCount,
635 a_element_op,
636 b_element_op,
637 cde_element_op};
638 }
639
// Typed factory: returns a value-initialized Invoker.
640 static auto MakeInvoker() { return Invoker{}; }
641
642 // polymorphic
// Type-erased factory: casts the void pointers to the A/B/E element types, forwards every
// parameter to the Argument constructor, and returns the result as std::unique_ptr<BaseArgument>.
// The caller is responsible for the pointers actually referring to data of those types.
643 std::unique_ptr<BaseArgument>
644 MakeArgumentPointer(const void* p_a,
645 const void* p_b,
646 void* p_e,
647 index_t M,
648 index_t N,
649 index_t K,
650 index_t stride_A,
651 index_t stride_B,
652 index_t batch_stride_A,
653 index_t batch_stride_B,
654 BatchedGemmEPermuteDesc batched_gemm_e_permute_desc,
655 index_t BatchCount,
656 AElementwiseOperation a_element_op,
657 BElementwiseOperation b_element_op,
658 CDEElementwiseOperation cde_element_op) override
659 {
660 return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
661 static_cast<const BDataType*>(p_b),
662 static_cast<EDataType*>(p_e),
663 M,
664 N,
665 K,
666 stride_A,
667 stride_B,
668 batch_stride_A,
669 batch_stride_B,
670 batched_gemm_e_permute_desc,
671 BatchCount,
672 a_element_op,
673 b_element_op,
674 cde_element_op);
675 }
676
677 // polymorphic
// Type-erased factory: heap-allocates an Invoker and returns it as std::unique_ptr<BaseInvoker>.
// The new Invoker is constructed from a temporary Invoker{} passed to make_unique.
678 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
679 {
680 return std::make_unique<Invoker>(Invoker{});
681 }
682
683 // polymorphic
// Returns a short identifier for this instance, of the form
// "DeviceBatchedGemmEPermuteXdl<BlockSize, MPerBlock, NPerBlock, KPerBlock>" with the four
// template values substituted. The other template parameters are not included in the string.
684 std::string GetTypeString() const override
685 {
686 auto str = std::stringstream();
687
688 // clang-format off
689 str << "DeviceBatchedGemmEPermuteXdl"
690 << "<"
691 << BlockSize << ", "
692 << MPerBlock << ", "
693 << NPerBlock << ", "
694 << KPerBlock
695 << ">";
696 // clang-format on
697
698 return str.str();
699 }
700};
701
702} // namespace device
703} // namespace tensor_operation
704} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
__global__ void kernel_batched_gemm_e_permute_xdl(const ABDataType *__restrict__ p_a_grid, const ABDataType *__restrict__ p_b_grid, EDataType *__restrict__ p_e_grid, const index_t batch_count, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2ETileMap block_2_etile_map)
Definition device_batched_gemm_e_permute_xdl.hpp:65
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ NPadding
Definition gemm_specialization.hpp:15
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
@ NKPadding
Definition gemm_specialization.hpp:19
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__device__ index_t get_grid_size()
Definition get_id.hpp:49
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
constexpr bool is_same_v
Definition type.hpp:283
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_multiple_d_xdl_cshuffle.hpp:78
Definition utility/sequence.hpp:43
Definition utility/tuple.hpp:186
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition utility/integral_constant.hpp:20
Definition device_base.hpp:197
Definition device_batched_gemm_e_permute.hpp:12
Definition device_batched_gemm_e_permute.hpp:27
Definition device_batched_gemm_e_permute_xdl.hpp:412
void Print() const
Definition device_batched_gemm_e_permute_xdl.hpp:460
EDataType * p_e_grid_
Definition device_batched_gemm_e_permute_xdl.hpp:471
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition device_batched_gemm_e_permute_xdl.hpp:482
BGridDesc_N_K b_grid_desc_n_k_
Definition device_batched_gemm_e_permute_xdl.hpp:478
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition device_batched_gemm_e_permute_xdl.hpp:483
CDEElementwiseOperation cde_element_op_
Definition device_batched_gemm_e_permute_xdl.hpp:496
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_
Definition device_batched_gemm_e_permute_xdl.hpp:488
EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_
Definition device_batched_gemm_e_permute_xdl.hpp:485
EGridDesc_M_N e_grid_desc_m_n_
Definition device_batched_gemm_e_permute_xdl.hpp:479
Argument(const ADataType *p_a_grid, const BDataType *p_b_grid, EDataType *p_e_grid, index_t M, index_t N, index_t K, index_t stride_A, index_t stride_B, index_t batch_stride_A, index_t batch_stride_B, BatchedGemmEPermuteDesc batched_gemm_e_permute_desc, index_t BatchCount, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_batched_gemm_e_permute_xdl.hpp:413
const ADataType * p_a_grid_
Definition device_batched_gemm_e_permute_xdl.hpp:469
index_t BatchCount_
Definition device_batched_gemm_e_permute_xdl.hpp:474
const BDataType * p_b_grid_
Definition device_batched_gemm_e_permute_xdl.hpp:470
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock
Definition device_batched_gemm_e_permute_xdl.hpp:484
AGridDesc_M_K a_grid_desc_m_k_
Definition device_batched_gemm_e_permute_xdl.hpp:477
Block2ETileMap block_2_etile_map_
Definition device_batched_gemm_e_permute_xdl.hpp:491
BElementwiseOperation b_element_op_
Definition device_batched_gemm_e_permute_xdl.hpp:495
AElementwiseOperation a_element_op_
Definition device_batched_gemm_e_permute_xdl.hpp:494
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
Definition device_batched_gemm_e_permute_xdl.hpp:320
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
Definition device_batched_gemm_e_permute_xdl.hpp:330
ComputePtrOffsetOfStridedBatch(index_t Batchstride_A, index_t Batchstride_B, EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n)
Definition device_batched_gemm_e_permute_xdl.hpp:311
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
Definition device_batched_gemm_e_permute_xdl.hpp:325
Definition device_batched_gemm_e_permute_xdl.hpp:501
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_batched_gemm_e_permute_xdl.hpp:505
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_batched_gemm_e_permute_xdl.hpp:574
DeviceOp::Argument Argument
Definition device_batched_gemm_e_permute_xdl.hpp:502
Definition device_batched_gemm_e_permute_xdl.hpp:178
GridwiseGemmMultipleD_xdl_cshuffle< ADataType, BDataType, ComputeDataType, AccDataType, CShuffleDataType, ck::Tuple<>, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, AGridDesc_M_K, BGridDesc_N_K, Tuple<>, EGridDesc_M_N, NumPrefetch, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched > GridwiseGemmBase
Definition device_batched_gemm_e_permute_xdl.hpp:348
decltype(MakeEGridDescriptor_G0_G1_M_N(1, 1, 1, 1, 1, 1, 1, 1)) EGridDesc_G0_G1_M_N
Definition device_batched_gemm_e_permute_xdl.hpp:307
static auto MakeInvoker()
Definition device_batched_gemm_e_permute_xdl.hpp:640
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_e, index_t M, index_t N, index_t K, index_t stride_A, index_t stride_B, index_t batch_stride_A, index_t batch_stride_B, BatchedGemmEPermuteDesc batched_gemm_e_permute_desc, index_t BatchCount, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) override
Definition device_batched_gemm_e_permute_xdl.hpp:644
static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
Definition device_batched_gemm_e_permute_xdl.hpp:210
std::string GetTypeString() const override
Definition device_batched_gemm_e_permute_xdl.hpp:684
static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t stride_M, index_t stride_N)
Definition device_batched_gemm_e_permute_xdl.hpp:229
static constexpr auto NXdlPerWave32
Definition device_batched_gemm_e_permute_xdl.hpp:183
decltype(MakeAGridDescriptor_M_K(1, 1, 1)) AGridDesc_M_K
Definition device_batched_gemm_e_permute_xdl.hpp:304
static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
Definition device_batched_gemm_e_permute_xdl.hpp:192
static constexpr auto I1
Definition device_batched_gemm_e_permute_xdl.hpp:186
static constexpr auto matrix_padder
Definition device_batched_gemm_e_permute_xdl.hpp:189
DeviceBatchedGemmEPermuteXdl DeviceOp
Definition device_batched_gemm_e_permute_xdl.hpp:179
remove_cvref_t< decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1( AGridDesc_M_K{}))> AGridDesc_AK0_M_AK1
Definition device_batched_gemm_e_permute_xdl.hpp:398
decltype(MakeBGridDescriptor_N_K(1, 1, 1)) BGridDesc_N_K
Definition device_batched_gemm_e_permute_xdl.hpp:305
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_batched_gemm_e_permute_xdl.hpp:182
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, EDataType *p_e, index_t M, index_t N, index_t K, index_t stride_A, index_t stride_B, index_t batch_stride_A, index_t batch_stride_B, BatchedGemmEPermuteDesc batched_gemm_e_permute_desc, index_t BatchCount, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_batched_gemm_e_permute_xdl.hpp:607
static constexpr auto I0
Definition device_batched_gemm_e_permute_xdl.hpp:185
static constexpr bool IsValidCompilationParameter()
Definition device_batched_gemm_e_permute_xdl.hpp:581
remove_cvref_t< decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1( BGridDesc_N_K{}))> BGridDesc_BK0_N_BK1
Definition device_batched_gemm_e_permute_xdl.hpp:401
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_batched_gemm_e_permute_xdl.hpp:602
decltype(MakeEGridDescriptor_M_N(1, 1, 1, 1)) EGridDesc_M_N
Definition device_batched_gemm_e_permute_xdl.hpp:306
static constexpr auto I2
Definition device_batched_gemm_e_permute_xdl.hpp:187
typename GridwiseGemm::DefaultBlock2ETileMap Block2ETileMap
Definition device_batched_gemm_e_permute_xdl.hpp:408
decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( EGridDesc_M_N{})) EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_batched_gemm_e_permute_xdl.hpp:405
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_batched_gemm_e_permute_xdl.hpp:678
static bool IsSupportedArgument(const Argument &arg)
Definition device_batched_gemm_e_permute_xdl.hpp:587
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_batched_gemm_e_permute_xdl.hpp:395
static auto MakeEGridDescriptor_G0_G1_M_N(index_t G0, index_t G1, index_t MRaw, index_t NRaw, index_t stride_G0, index_t stride_G1, index_t stride_M, index_t stride_N)
Definition device_batched_gemm_e_permute_xdl.hpp:237
ADataType ComputeDataType
Definition device_batched_gemm_e_permute_xdl.hpp:344
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_batched_gemm_e_permute_xdl.hpp:396
Definition matrix_padder.hpp:180