codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp Source File

codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp Source File

Composable Kernel: codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp Source File
codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#ifndef CK_CODE_GEN_RTC
7#include <functional>
8#include <iostream>
9#include <iterator>
10#include <numeric>
11#include <sstream>
12#include <stdio.h>
13
16#endif
17
30
31namespace ck {
32namespace tensor_operation {
33namespace device {
34
35namespace {
36
37/*
38 * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
39 *
40 * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix
41 * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly
42 * strided batched, but we can easily extend to other layouts. The returned offset can be either \p
43 * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB
44 * limitations.
45 *
46 * \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and
47 * returns the 2D index of the tile that it computes. \see
48 * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
49 *
50 * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
51 * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
52 * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
53 * impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for
54 * \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the
55 * computing of pointer offset into \p ComputePtrOffsetOfStridedBatch.
56 *
57 * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes.
58 * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to
59 * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion).
60 *
61 */
62template <typename GridwiseGemm,
63 typename AsPointer, // tuples if multi AB, pointers if no
64 typename BsPointer,
65 typename DsPointer,
66 typename EDataType,
67 typename AElementwiseOperation,
68 typename BElementwiseOperation,
69 typename CDEElementwiseOperation,
70 typename AGridDesc_AK0_M_AK1,
71 typename BGridDesc_BK0_N_BK1,
72 typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
73 typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
74 typename Block2ETileMap,
75 typename ComputePtrOffsetOfBatch,
76 bool HasMainKBlockLoop,
77 bool isMultiA,
78 bool isMultiB>
79__device__ void device_grouped_conv_fwd_multiple_abd_xdl_cshuffle(
80 AsPointer p_as_grid,
81 BsPointer p_bs_grid,
82 DsPointer p_ds_grid,
83 EDataType* __restrict__ p_e_grid,
84 const AElementwiseOperation a_element_op,
85 const BElementwiseOperation b_element_op,
86 const CDEElementwiseOperation cde_element_op,
87 const index_t batch_count,
88 const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
89 const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
90 const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
91 ds_grid_desc_mblock_mperblock_nblock_nperblock,
92 const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
93 e_grid_desc_mblock_mperblock_nblock_nperblock_,
94 const Block2ETileMap block_2_ctile_map,
95 const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
96{
97#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
98 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
99 {
100 // offset base pointer for each work-group
101 const index_t num_blocks_per_batch =
102 __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
103 const index_t g_idx =
104 __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
105
106 const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
107 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
108 const auto& ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
109
110 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
111
112 DsPointer p_ds_grid_grp;
113
114 static constexpr index_t NumDTensor =
115 DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();
116
117 static_for<0, NumDTensor, 1>{}(
118 [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
119
120 if constexpr(isMultiA || isMultiB)
121 {
122 AsPointer p_as_grid_grp;
123 BsPointer p_bs_grid_grp;
124
125 const auto& as_batch_offset = compute_ptr_offset_of_batch.GetAsPtrOffset(g_idx);
126
127 static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size();
128 static_for<0, NumATensor, 1>{}(
129 [&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i]; });
130
131 const auto& bs_batch_offset = compute_ptr_offset_of_batch.GetBsPtrOffset(g_idx);
132
133 static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size();
134 static_for<0, NumBTensor, 1>{}(
135 [&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_batch_offset[i]; });
136
137 GridwiseGemm::template Run<HasMainKBlockLoop>(
138 p_as_grid_grp,
139 p_bs_grid_grp,
140 p_ds_grid_grp,
141 p_e_grid + e_batch_offset,
142 p_shared,
143 a_element_op,
144 b_element_op,
145 cde_element_op,
146 a_grid_desc_k0_m_k1,
147 b_grid_desc_k0_n_k1,
148 ds_grid_desc_mblock_mperblock_nblock_nperblock,
149 e_grid_desc_mblock_mperblock_nblock_nperblock_,
150 block_2_ctile_map);
151 }
152 else
153 {
154 const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
155 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
156 const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
157 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
158
159 GridwiseGemm::template Run<HasMainKBlockLoop, InMemoryDataOperationEnum::Set>(
160 p_as_grid + a_batch_offset,
161 p_bs_grid + b_batch_offset,
162 p_ds_grid_grp,
163 p_e_grid + e_batch_offset,
164 p_shared,
165 a_element_op,
166 b_element_op,
167 cde_element_op,
168 a_grid_desc_k0_m_k1,
169 b_grid_desc_k0_n_k1,
170 ds_grid_desc_mblock_mperblock_nblock_nperblock,
171 e_grid_desc_mblock_mperblock_nblock_nperblock_,
172 block_2_ctile_map);
173 }
174 }
175#else
176 ignore = p_as_grid;
177 ignore = p_bs_grid;
178 ignore = p_ds_grid;
179 ignore = p_e_grid;
180 ignore = batch_count;
181 ignore = a_grid_desc_k0_m_k1;
182 ignore = b_grid_desc_k0_n_k1;
183 ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
184 ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_;
185 ignore = a_element_op;
186 ignore = b_element_op;
187 ignore = cde_element_op;
188 ignore = compute_ptr_offset_of_batch;
189 ignore = block_2_ctile_map;
190#endif
191}
192
193template <typename GridwiseGemm,
194 typename AsPointer, // tuples if multi AB, pointers if no
195 typename BsPointer,
196 typename DsPointer,
197 typename EDataType,
198 typename AElementwiseOperation,
199 typename BElementwiseOperation,
200 typename CDEElementwiseOperation,
201 typename AGridDesc_AK0_M_AK1,
202 typename BGridDesc_BK0_N_BK1,
203 typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
204 typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
205 typename Block2ETileMap,
206 typename ComputePtrOffsetOfBatch,
207 bool HasMainKBlockLoop,
208 bool isMultiA,
209 bool isMultiB>
210__global__ void
211#if CK_USE_LAUNCH_BOUNDS
213#endif
214 kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle(
215 AsPointer p_as_grid,
216 BsPointer p_bs_grid,
217 DsPointer p_ds_grid,
218 EDataType* __restrict__ p_e_grid,
219 const AElementwiseOperation a_element_op,
220 const BElementwiseOperation b_element_op,
221 const CDEElementwiseOperation cde_element_op,
222 const index_t batch_count,
223 const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
224 const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
225 const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
226 ds_grid_desc_mblock_mperblock_nblock_nperblock,
227 const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
228 e_grid_desc_mblock_mperblock_nblock_nperblock_,
229 const Block2ETileMap block_2_ctile_map,
230 const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
231{
232
233 device_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
234 GridwiseGemm,
235 AsPointer, // tuples if multi AB, pointers if no
236 BsPointer,
237 DsPointer,
238 EDataType,
239 AElementwiseOperation,
240 BElementwiseOperation,
241 CDEElementwiseOperation,
242 AGridDesc_AK0_M_AK1,
243 BGridDesc_BK0_N_BK1,
244 DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
245 EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
246 Block2ETileMap,
247 ComputePtrOffsetOfBatch,
248 HasMainKBlockLoop,
249 isMultiA,
250 isMultiB>(p_as_grid,
251 p_bs_grid,
252 p_ds_grid,
253 *p_e_grid,
254 a_element_op,
255 b_element_op,
256 cde_element_op,
257 batch_count,
258 a_grid_desc_k0_m_k1,
259 b_grid_desc_k0_n_k1,
260 ds_grid_desc_mblock_mperblock_nblock_nperblock,
261 e_grid_desc_mblock_mperblock_nblock_nperblock_,
262 block_2_ctile_map,
263 compute_ptr_offset_of_batch);
264}
265
266} // namespace
267
// Detection alias for is_detected<is_tuple, T>. It names the return type of T::IsTuple(), so it
// is well-formed exactly when T exposes a callable IsTuple() member. That is how data-type
// tuples are told apart from single data types in the template defaults below.
// The standard headers are not included under CK_CODE_GEN_RTC, so that build uses ck::declval.
template <typename T>
using is_tuple = decltype(
#ifdef CK_CODE_GEN_RTC
    ck::declval<T&>()
#else
    std::declval<T&>()
#endif
        .IsTuple());
275
276//
277// @brief Device Convolution operation.
278//
279// Supports:
280// @li Forward convolution with up to 3 spatial dimensions
281// @li Input tensor in GNWC data format
282// @li Weight tensor in GKXC data format
283// @li Output tensor in GNWK data format
284//
285// 1D:
286// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C]
287// 2D:
288// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
289// 3D:
290// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
291//
292template <index_t NDimSpatial,
293 typename ALayout,
294 typename BLayout,
295 typename DsLayout,
296 typename ELayout,
297 typename ADataType,
298 typename BDataType,
299 typename AccDataType,
300 typename CShuffleDataType,
301 typename DsDataType,
302 typename EDataType,
303 typename AElementwiseOperation,
304 typename BElementwiseOperation,
305 typename CDEElementwiseOperation,
306 ConvolutionForwardSpecialization ConvForwardSpecialization,
307 GemmSpecialization GemmSpec,
308 index_t NumGemmKPrefetchStage,
309 index_t BlockSize,
310 index_t MPerBlock,
311 index_t NPerBlock,
312 index_t KPerBlock,
313 index_t AK1,
314 index_t BK1,
315 index_t MPerXDL,
316 index_t NPerXDL,
317 index_t MXdlPerWave,
318 index_t NXdlPerWave,
319 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
320 typename ABlockTransferThreadClusterArrangeOrder,
321 typename ABlockTransferSrcAccessOrder,
322 index_t ABlockTransferSrcVectorDim,
323 index_t ABlockTransferSrcScalarPerVector,
324 index_t ABlockTransferDstScalarPerVector_AK1,
325 index_t ABlockLdsExtraM,
326 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
327 typename BBlockTransferThreadClusterArrangeOrder,
328 typename BBlockTransferSrcAccessOrder,
329 index_t BBlockTransferSrcVectorDim,
330 index_t BBlockTransferSrcScalarPerVector,
331 index_t BBlockTransferDstScalarPerVector_BK1,
332 index_t BBlockLdsExtraN,
333 index_t CShuffleMXdlPerWavePerShuffle,
334 index_t CShuffleNXdlPerWavePerShuffle,
335 typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
336 index_t CDEBlockTransferScalarPerVector_NPerBlock,
337 typename ComputeDataType =
338 decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
339 Number<0>,
340 ADataType>()), // ComputeType is InputType by default (first
341 // in tuple for MultiAB), unpack if tuple was
342 // passed
345 : public DeviceGroupedConvFwdMultipleABD<NDimSpatial,
346 ALayout,
347 BLayout,
348 DsLayout,
349 ELayout,
350 ADataType,
351 BDataType,
352 DsDataType,
353 EDataType,
354 AElementwiseOperation,
355 BElementwiseOperation,
356 CDEElementwiseOperation,
357 ComputeDataType>
358{
// Per-wave N XDL counts for the wave64 and wave32 builds (from GetNXdlPerWave). A value <= 0
// disables that variant: the Argument constructor and IsSupportedArgument only use a variant
// when its count is > 0.
361 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
362 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
363
366
// Number of A / B / D tensors. For A and B the count comes from GetNumABTensors, driven by
// the isMultiA / isMultiB flags (presumably true when the data type is a tuple - confirm).
367 static constexpr index_t NumATensor = GetNumABTensors<isMultiA, ADataType>();
368 static constexpr index_t NumBTensor = GetNumABTensors<isMultiB, BDataType>();
369 static constexpr index_t NumDTensor = DsDataType::Size();
370
// Compile-time index constants.
371 static constexpr auto I0 = Number<0>{};
372 static constexpr auto I1 = Number<1>{};
373 static constexpr auto I2 = Number<2>{};
374 static constexpr auto I3 = Number<3>{};
375
377
// Pads raw GEMM descriptors according to GemmSpec, with MPerBlock / NPerBlock / KPerBlock as
// the tile sizes (presumably padding M, N and/or K up to multiples of them).
378 static constexpr auto matrix_padder =
379 MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
380
381 template <typename ALay>
382 __host__ __device__ static auto
383 MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
384 {
385 const auto in_gemmmraw_gemmkraw_desc =
386 conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
387
388 const auto in_gemmm_gemmk_desc =
389 matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
390
391 return in_gemmm_gemmk_desc;
392 }
393
394 template <typename BLay>
395 __host__ __device__ static auto
396 MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
397 {
398 const auto wei_gemmnraw_gemmkraw_desc =
399 conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
400
401 const auto wei_gemmn_gemmk_desc =
402 matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
403
404 return wei_gemmn_gemmk_desc;
405 }
406
407 template <typename ELay>
408 __host__ __device__ static auto
409 MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
410 {
411 const auto out_gemmmraw_gemmnraw_desc =
412 conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
413
414 const auto out_gemmm_gemmn_desc =
415 matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
416
417 return out_gemmm_gemmn_desc;
418 }
419
// Builds one M x N GEMM descriptor per D tensor and returns them as a tuple indexed like
// DsLayout. Each entry comes from MakeEGridDescriptor_M_N with that D's layout.
// NOTE(review): every D is described by the single transformer passed in, so D-specific
// strides are not applied here. The Argument constructor instead builds a per-D transformer
// from ds_g_n_k_wos_strides. Presumably this helper only feeds type deduction - confirm
// before relying on its values.
420 // Shape of Ds and E must be aligned. Strides can be different.
421 // Pass e_g_n_k_wos_lengths for logical broadcast.
422 static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
423 {
424 return generate_tuple(
425 [&](auto i) {
426 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
427
428 return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer);
429 },
431 }
432
433 // desc for problem definition
443
444 // If we are using multiAB and one of the template data-type parameters is not a tuple,
445 // wrap it in a tuple.
448
// Template-argument list for the gridwise GEMM variant that takes InMemoryDataOperationEnum::Set
// as a class template argument. This matches the device function, which passes
// InMemoryDataOperationEnum::Set to Run only in the single-A/B path, so this list is
// presumably the multi-A/B/D one. A hard-coded `false` precedes ABlockLdsExtraM and
// BBlockLdsExtraN; its meaning is defined by the gridwise GEMM's template signature.
// NOTE(review): both lists pass NXdlPerWave (the class template parameter). However, the
// alias that follows is templated on NXdlPerWave_, and the class derives separate
// NXdlPerWave64 / NXdlPerWave32 values. If these macros are expanded inside that alias, the
// wave32 and wave64 GEMMs get the same NXdlPerWave - confirm whether NXdlPerWave_ was meant.
449#define GridwiseGemmMultiABDTemplateParameters \
450 GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
451 EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
452 InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \
453 KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \
454 ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \
455 ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \
456 ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \
457 ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
458 BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \
459 BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \
460 BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
461 CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
462 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
463 CDEBlockTransferScalarPerVector_NPerBlock, LoopSched
464
// Same list without the InMemoryDataOperationEnum::Set argument, for the other gridwise GEMM
// variant (the single-A/B path, whose Run receives Set as a template argument instead).
465#define GridwiseGemmTemplateParameters \
466 GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
467 EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
468 NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, \
469 NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, \
470 ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \
471 ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \
472 ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \
473 BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \
474 BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \
475 BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \
476 BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
477 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
478 CDEBlockTransferScalarPerVector_NPerBlock, LoopSched
479 // Use appropriate gridwise gemm
480 template <index_t NXdlPerWave_>
487
488 // If ADataTypes or BDataTypes is tuple, user has to pass ck::Array with pointers.
491 // Use a Tuple for GridPointer in both cases so it can be filled in the body of the Argument
492 // constructor (a single const pointer would have to be set in the initializer list).
494 decltype(GetAGridPointer < isMultiA || isMultiB, GridwiseGemm64, ADataType > ())>;
496 decltype(GetBGridPointer < isMultiA || isMultiB, GridwiseGemm64, BDataType > ())>;
497
498 // desc for blockwise copy
500 remove_cvref_t<decltype(GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(
501 AGridDesc_M_K{}))>;
503 remove_cvref_t<decltype(GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(
504 BGridDesc_N_K{}))>;
506 decltype(GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
507 DsGridDesc_M_N{}))>;
509 decltype(GridwiseGemm64::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
510 EGridDesc_M_N{}))>;
511
512 // block-to-e-tile map
514 remove_cvref_t<decltype(GridwiseGemm64::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
515
516 // Argument
517 struct Argument
518 {
// Builds the block-tiled (MBlock_MPerBlock_NBlock_NPerBlock) E and Ds descriptors with the
// given GridwiseGemm variant. This happens only when GridwiseGemm::CheckValidity accepts the
// problem; otherwise nothing is built here.
519 template <typename GridwiseGemm>
520 __host__ __device__ void init_ds_e_grid_desc()
521 {
522 if constexpr(isMultiA || isMultiB)
523 {
// Multi A/B: CheckValidity expects tuples, so the single A (M x K) and B (N x K)
// descriptors are replicated NumATensor / NumBTensor times. Despite their names, these
// locals hold the M_K / N_K descriptors, not AK0_M_AK1 / BK0_N_BK1 ones.
524 const auto as_grid_desc_ak0_m_ak1 =
525 generate_tuple([&](auto) { return a_grid_desc_m_k_; }, Number<NumATensor>{});
526 const auto bs_grid_desc_bk0_n_bk1 =
527 generate_tuple([&](auto) { return b_grid_desc_n_k_; }, Number<NumBTensor>{});
528
529 if(GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1,
530 bs_grid_desc_bk0_n_bk1,
534 {
536 GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
538
540 GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
542 }
543 }
544 else
545 {
// Single A/B: the descriptors are checked as they are.
546 if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
551 {
553 GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
555
557 GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
559 }
560 }
561 }
// Builds the kernel argument from type-erased pointers and G/N/C/spatial lengths and strides:
// - keeps a copy of every shape/stride array (read back by IsSupportedArgument),
// - casts p_as / p_bs / p_ds / p_e to the GEMM data types,
// - records the group (index 0) stride of every tensor in compute_ptr_offset_of_batch_,
// - builds the block-copy descriptors and the E tile map with GridwiseGemm64 helpers,
//   whatever the runtime wave size, then fills the Ds/E block descriptors for the variant
//   that matches the wave size.
562 __device__ __host__ Argument(
563 APointers p_as,
564 BPointers p_bs,
566 void* p_e,
567 const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
568 const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
569 const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
570 const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
571 const ck::Array<ck::Array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
572 const ck::Array<ck::Array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
573 const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
574 const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
575 const ck::Array<index_t, NDimSpatial>& conv_filter_strides,
576 const ck::Array<index_t, NDimSpatial>& conv_filter_dilations,
577 const ck::Array<index_t, NDimSpatial>& input_left_pads,
578 const ck::Array<index_t, NDimSpatial>& input_right_pads,
579 const AElementwiseOperation& a_element_op,
580 const BElementwiseOperation& b_element_op,
581 const CDEElementwiseOperation& cde_element_op)
582 : p_as_grid_{},
583 p_bs_grid_{},
584 p_ds_grid_{},
585 p_e_grid_{static_cast<EDataType*>(p_e)},
586 num_group_{a_g_n_c_wis_lengths[0]},
587 conv_to_gemm_transformer_{a_g_n_c_wis_lengths,
588 a_g_n_c_wis_strides,
589 b_g_k_c_xs_lengths,
590 b_g_k_c_xs_strides,
591 e_g_n_k_wos_lengths,
592 e_g_n_k_wos_strides,
593 conv_filter_strides,
594 conv_filter_dilations,
595 input_left_pads,
596 input_right_pads},
605 GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
607 GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
610 block_2_etile_map_{GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
612 a_element_op_{a_element_op},
613 b_element_op_{b_element_op},
614 cde_element_op_{cde_element_op},
615 a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
616 a_g_n_c_wis_strides_{a_g_n_c_wis_strides},
617 b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
618 b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
619 ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths},
620 ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides},
621 e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
622 e_g_n_k_wos_strides_{e_g_n_k_wos_strides},
623 conv_filter_strides_{conv_filter_strides},
624 conv_filter_dilations_{conv_filter_dilations},
625 input_left_pads_{input_left_pads},
626 input_right_pads_{input_right_pads}
627 {
628 // A/B/E Batch Stride
// (the per-group stride is element 0, the G dimension, of each strides array)
629 if constexpr(isMultiA || isMultiB)
630 {
631 static_for<0, NumATensor, 1>{}([&](auto i) {
632 // Init compute_ptr_offset_of_batch_ for multiple AB
633 compute_ptr_offset_of_batch_.BatchStrideA_(i) = a_g_n_c_wis_strides[0];
634
635 // Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data
636 // type is not tuple)
637 using DataType = remove_cvref_t<tuple_element_t<i.value, GemmADataType>>;
638 // It is possible that one of the AB is a pointer and one is a tuple.
639 // Then also use multiAB but we have to cast single pointer instead of tuple of
640 // pointer.
641 if constexpr(isMultiA)
642 {
643 // p_as is tuple
644 p_as_grid_(i) = static_cast<const DataType*>(p_as[i.value]);
645 }
646 else
647 {
648 // if MultiB and not MultiA then p_as is single pointer
649 p_as_grid_(i) = static_cast<const DataType*>(p_as);
650 }
651 });
652 static_for<0, NumBTensor, 1>{}([&](auto i) {
653 // Init compute_ptr_offset_of_batch_ for multiple AB
654 compute_ptr_offset_of_batch_.BatchStrideB_(i) = b_g_k_c_xs_strides[0];
655
656 using DataType = remove_cvref_t<tuple_element_t<i.value, GemmBDataType>>;
657 // It is possible that one of the AB is a pointer and one is a tuple.
658 // Then also use multiAB but we have to cast single pointer instead of tuple of
659 // pointer.
660 if constexpr(isMultiB)
661 {
662 // p_bs is tuple
663 p_bs_grid_(i) = static_cast<const DataType*>(p_bs[i.value]);
664 }
665 else
666 {
667 // if MultiA and not MultiB then p_bs is single pointer
668 p_bs_grid_(i) = static_cast<const DataType*>(p_bs);
669 }
670 });
671 }
672 else
673 {
674 compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0];
675 compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
676
677 // p_as and p_bs are pointers
678 p_as_grid_(I0) = static_cast<const ADataType*>(p_as);
679 p_bs_grid_(I0) = static_cast<const BDataType*>(p_bs);
680 }
681
682 // populate pointer, batch stride, desc for Ds
683 static_for<0, NumDTensor, 1>{}([&](auto i) {
684 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
685 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
686
687 // D pointer
688 p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);
689
690 // D batch stride
691 compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
692
// Per-D transformer: same A/B geometry and E lengths as the main transformer, but this
// D's own strides, so a D may be strided (e.g. broadcast) differently from E.
693 ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
694 a_g_n_c_wis_strides,
695 b_g_k_c_xs_lengths,
696 b_g_k_c_xs_strides,
697 e_g_n_k_wos_lengths,
698 ds_g_n_k_wos_strides[i],
699 conv_filter_strides,
700 conv_filter_dilations,
701 input_left_pads,
702 input_right_pads};
703
704 // D desc
706 DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer_d);
707 });
708 compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0];
709
710 // populate desc for Ds/E
// Only the gridwise-GEMM variant matching the runtime wave size is used, and only if its
// NXdlPerWave is > 0 (presumably via init_ds_e_grid_desc<GridwiseGemm64> /
// <GridwiseGemm32> - confirm).
711 if(get_warp_size() == 64)
712 {
713 if constexpr(NXdlPerWave64 > 0)
714 {
716 }
717 }
718 else
719 {
720 if constexpr(NXdlPerWave32 > 0)
721 {
723 }
724 }
725 }
726
727 // private:
728 // pointers (tuple if multi AB, pointer if no)
731 typename GridwiseGemm64::DsGridPointer p_ds_grid_;
732 EDataType* p_e_grid_;
733
734 // tensor descriptors for problem definition
736
738
743
744 // tensor descriptors for block/thread-wise copy
750
751 // block-to-e-tile map
753
754 // for computing batch offset
755 ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>
757
758 // element-wise op
759 AElementwiseOperation a_element_op_;
760 BElementwiseOperation b_element_op_;
761 CDEElementwiseOperation cde_element_op_;
762
763 // for checking IsSupportedArgument()
776 };
// Returns GridwiseGemm::CheckValidity for this argument's GEMM descriptors. In the multi-A/B
// case, the single A (M x K) and B (N x K) descriptors are first replicated into tuples of
// NumATensor / NumBTensor entries.
777 template <typename GridwiseGemm>
778 static __device__ __host__ bool check_gemm_validity(const Argument& arg)
779 {
780 if constexpr(isMultiA || isMultiB)
781 {
782 // Generate tuples with the same descriptors
783 const auto as_grid_desc_ak0_m_ak1 =
784 generate_tuple([&](auto) { return arg.a_grid_desc_m_k_; }, Number<NumATensor>{});
785 const auto bs_grid_desc_bk0_n_bk1 =
786 generate_tuple([&](auto) { return arg.b_grid_desc_n_k_; }, Number<NumBTensor>{});
787 return GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1,
788 bs_grid_desc_bk0_n_bk1,
792 }
793 else
794 {
795 return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
800 }
801 }
// Returns true only if this instance can run the problem described by arg. Checks are applied
// in order and the first failure returns false: the filter specialization, then vector-access
// divisibility of A, B, Ds and E, then the gridwise GEMM for the current wave size.
802 static __device__ __host__ bool IsSupportedArgument(const Argument& arg)
803 {
804 namespace ctc = tensor_layout::convolution;
805
806 // check ConvolutionForwardSpecialization
// Filter1x1* specializations: every spatial filter length (b lengths from index 3 on) must
// be 1 with zero left/right padding; the Stride1 variant additionally requires unit strides.
807 if constexpr(ConvForwardSpecialization ==
809 {
810 // check if it's 1x1, stride=1 conv
811 for(index_t i = 0; i < NDimSpatial; ++i)
812 {
813 const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
814 const index_t ConvStride = arg.conv_filter_strides_[i];
815 const index_t LeftPad = arg.input_left_pads_[i];
816 const index_t RightPad = arg.input_right_pads_[i];
817
818 if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
819 {
820 return false;
821 }
822 }
823 }
824 else if constexpr(ConvForwardSpecialization ==
826 {
827 // check if it's 1x1 conv
828 for(index_t i = 0; i < NDimSpatial; ++i)
829 {
830 const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
831 const index_t LeftPad = arg.input_left_pads_[i];
832 const index_t RightPad = arg.input_right_pads_[i];
833
834 if(!(X == 1 && LeftPad == 0 && RightPad == 0))
835 {
836 return false;
837 }
838 }
839 }
840
// A: vectorized reads must run along dim 2, and C (a lengths[2]) must be divisible by
// ABlockTransferSrcScalarPerVector. Layouts outside the accepted set are rejected.
841 // check vector access of A
842 // FIXME: layout
848 {
849 const index_t C = arg.a_g_n_c_wis_lengths_[2];
850
851 if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0))
852 {
853 return false;
854 }
855 }
856 else
857 {
858 return false;
859 }
860
// B: same rule, using C from b lengths[2] and BBlockTransferSrcScalarPerVector.
861 // check vector access of B
862 // FIXME: layout
868
869 {
870 const index_t C = arg.b_g_k_c_xs_lengths_[2];
871
872 if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0))
873 {
874 return false;
875 }
876 }
877 else
878 {
879 return false;
880 }
881
// Ds: K (lengths[2]) must be divisible by CDEBlockTransferScalarPerVector_NPerBlock. A
// G_K-layout D must match E in G and K; any other D must have exactly E's lengths.
// The result is accumulated in `valid` because static_for cannot return early.
882 // check vector access of Ds
883 bool valid = true;
884
885 static_for<0, NumDTensor, 1>{}([&](auto i) {
886 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
887 // FIXME: layout
893 {
894 const index_t K = arg.ds_g_n_k_wos_lengths_[i][2];
895
896 if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
897 {
898 valid = false;
899 }
900
901 if constexpr(is_same_v<DLayout, ctc::G_K>)
902 {
903 // G and K must be the same
904 if(arg.ds_g_n_k_wos_lengths_[i][0] != arg.e_g_n_k_wos_lengths_[0] ||
906 {
907 valid = false;
908 }
909 }
910 else
911 {
912 // E and D must have the same shape
913 for(index_t d = 0; d < NDimSpatial + 3; d++)
914 {
915 if(arg.ds_g_n_k_wos_lengths_[i][d] != arg.e_g_n_k_wos_lengths_[d])
916 {
917 valid = false;
918 }
919 }
920 }
921 }
922 else
923 {
924 valid = false;
925 }
926 });
927
928 if(!valid)
929 {
930 return false;
931 }
932
// E: K (e lengths[2]) must be divisible by CDEBlockTransferScalarPerVector_NPerBlock.
933 // check vector access of E
939 {
940 const index_t K = arg.e_g_n_k_wos_lengths_[2];
941
942 if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
943 {
944 return false;
945 }
946 }
947 else
948 {
949 return false;
950 }
951
// Defer to the gridwise-GEMM variant matching the runtime wave size. If that variant is
// compiled out (NXdlPerWave <= 0), control falls through to `return false`.
952 // check Gridwise GEMM
953 if(get_warp_size() == 64)
954 {
955 if constexpr(NXdlPerWave64 > 0)
956 {
958 }
959 }
960 else
961 {
962 if constexpr(NXdlPerWave32 > 0)
963 {
965 }
966 }
967 return false;
968 }
969
// Factory for Argument taking 32-bit (index_t) lengths, strides, dilations and pads.
// All parameters are forwarded unchanged to the Argument constructor.
970 static __device__ __host__ auto MakeArgument(
971 APointers p_as,
972 BPointers p_bs,
974 void* p_e,
975 const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
976 const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
977 const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
978 const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
979 const ck::Array<ck::Array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
980 const ck::Array<ck::Array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
981 const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
982 const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
983 const ck::Array<index_t, NDimSpatial>& conv_filter_strides,
984 const ck::Array<index_t, NDimSpatial>& conv_filter_dilations,
985 const ck::Array<index_t, NDimSpatial>& input_left_pads,
986 const ck::Array<index_t, NDimSpatial>& input_right_pads,
987 const AElementwiseOperation& a_element_op,
988 const BElementwiseOperation& b_element_op,
989 const CDEElementwiseOperation& cde_element_op)
990 {
991 return Argument{p_as,
992 p_bs,
993 p_ds,
994 p_e,
995 a_g_n_c_wis_lengths,
996 a_g_n_c_wis_strides,
997 b_g_k_c_xs_lengths,
998 b_g_k_c_xs_strides,
999 ds_g_n_k_wos_lengths,
1000 ds_g_n_k_wos_strides,
1001 e_g_n_k_wos_lengths,
1002 e_g_n_k_wos_strides,
1003 conv_filter_strides,
1004 conv_filter_dilations,
1005 input_left_pads,
1006 input_right_pads,
1007 a_element_op,
1008 b_element_op,
1009 cde_element_op};
1010 }
1011
// Factory for Argument taking 64-bit (long_index_t) lengths, strides, dilations and pads.
// Every array is converted to an index_t copy with array_convert (Ds per tensor), and the
// copies are then passed to the Argument constructor.
// NOTE(review): no range check is visible before the 64 -> 32-bit conversion, so values
// above the index_t range would not be representable - confirm callers never pass them,
// or validate before converting.
1012 static __device__ __host__ auto MakeArgument(
1013 APointers p_as,
1014 BPointers p_bs,
1016 void* p_e,
1017 const ck::Array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
1018 const ck::Array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
1019 const ck::Array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
1020 const ck::Array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
1023 const ck::Array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
1024 const ck::Array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
1025 const ck::Array<long_index_t, NDimSpatial>& conv_filter_strides,
1026 const ck::Array<long_index_t, NDimSpatial>& conv_filter_dilations,
1027 const ck::Array<long_index_t, NDimSpatial>& input_left_pads,
1028 const ck::Array<long_index_t, NDimSpatial>& input_right_pads,
1029 const AElementwiseOperation& a_element_op,
1030 const BElementwiseOperation& b_element_op,
1031 const CDEElementwiseOperation& cde_element_op)
1032 {
1033 ck::Array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
1034 ck::Array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
1035 ck::Array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
1036 ck::Array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
1039 ck::Array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
1040 ck::Array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
1041 ck::Array<index_t, NDimSpatial> conv_filter_strides_i32;
1042 ck::Array<index_t, NDimSpatial> conv_filter_dilations_i32;
1043 ck::Array<index_t, NDimSpatial> input_left_pads_i32;
1044 ck::Array<index_t, NDimSpatial> input_right_pads_i32;
1045
1046 array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
1047 array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
1048 array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
1049 array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
1050 for(index_t d = 0; d < NumDTensor; d++)
1051 {
1052 array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
1053 array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
1054 }
1055 array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
1056 array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
1057 array_convert(conv_filter_strides_i32, conv_filter_strides);
1058 array_convert(conv_filter_dilations_i32, conv_filter_dilations);
1059 array_convert(input_left_pads_i32, input_left_pads);
1060 array_convert(input_right_pads_i32, input_right_pads);
1061
1062 return Argument{p_as,
1063 p_bs,
1064 p_ds,
1065 p_e,
1066 a_g_n_c_wis_lengths_i32,
1067 a_g_n_c_wis_strides_i32,
1068 b_g_k_c_xs_lengths_i32,
1069 b_g_k_c_xs_strides_i32,
1070 ds_g_n_k_wos_lengths_i32,
1071 ds_g_n_k_wos_strides_i32,
1072 e_g_n_k_wos_lengths_i32,
1073 e_g_n_k_wos_strides_i32,
1074 conv_filter_strides_i32,
1075 conv_filter_dilations_i32,
1076 input_left_pads_i32,
1077 input_right_pads_i32,
1078 a_element_op,
1079 b_element_op,
1080 cde_element_op};
1081 }
1082};
1083
1084} // namespace device
1085} // namespace tensor_operation
1086} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition tensor_operation/gpu/device/tensor_layout.hpp:42
Definition convolution_backward_data_specialization.hpp:8
GemmSpecialization
Definition gemm_specialization.hpp:11
decltype(std::declval< T & >().IsTuple()) is_tuple
Definition device_grouped_conv_fwd_multiple_abd.hpp:23
ConvolutionForwardSpecialization
Definition convolution_forward_specialization.hpp:15
@ Filter1x1Stride1Pad0
Definition convolution_forward_specialization.hpp:18
@ Filter1x1Pad0
Definition convolution_forward_specialization.hpp:17
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__device__ index_t get_grid_size()
Definition get_id.hpp:49
int32_t index_t
Definition ck.hpp:299
typename conditional< predicate, X, Y >::type conditional_t
Definition utility/functional.hpp:115
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition is_detected.hpp:34
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
__host__ __device__ void array_convert(std::array< Y, NumElems > &y, const std::array< X, NumElems > &x)
Definition utility/type_convert.hpp:2466
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
LoopScheduler
Definition loop_scheduler.hpp:15
int64_t long_index_t
Definition ck.hpp:300
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition utility/array.hpp:14
Definition gridwise_gemm_multiple_abd_xdl_cshuffle.hpp:77
Definition gridwise_gemm_multiple_d_xdl_cshuffle.hpp:78
Definition multi_index_transform.hpp:196
Definition multi_index_transform.hpp:284
Definition functional2.hpp:33
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:25
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:518
ck::Array< index_t, NDimSpatial > conv_filter_strides_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:772
ck::Array< index_t, NDimSpatial+3 > e_g_n_k_wos_lengths_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:770
CDEElementwiseOperation cde_element_op_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:761
ck::Array< index_t, NDimSpatial+3 > b_g_k_c_xs_strides_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:767
GridwiseGemm64::DsGridPointer p_ds_grid_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:731
ComputePtrOffsetOfStridedBatch< NumATensor, NumBTensor, NumDTensor > compute_ptr_offset_of_batch_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:756
ck::Array< index_t, NDimSpatial > input_left_pads_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:774
ck::Array< index_t, NDimSpatial+3 > e_g_n_k_wos_strides_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:771
AGridDesc_M_K a_grid_desc_m_k_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:739
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:746
__device__ __host__ Argument(APointers p_as, BPointers p_bs, const ck::Array< const void *, NumDTensor > &p_ds, void *p_e, const ck::Array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const ck::Array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const ck::Array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const ck::Array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const ck::Array< ck::Array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const ck::Array< ck::Array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const ck::Array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const ck::Array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const ck::Array< index_t, NDimSpatial > &conv_filter_strides, const ck::Array< index_t, NDimSpatial > &conv_filter_dilations, const ck::Array< index_t, NDimSpatial > &input_left_pads, const ck::Array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op)
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:562
EDataType * p_e_grid_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:732
ck::Array< index_t, NDimSpatial+3 > a_g_n_c_wis_lengths_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:764
Block2ETileMap block_2_etile_map_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:752
ck::Array< index_t, NDimSpatial > input_right_pads_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:775
EGridDesc_M_N e_grid_desc_m_n_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:742
BGridPointer p_bs_grid_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:730
index_t num_group_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:735
AElementwiseOperation a_element_op_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:759
ck::Array< index_t, NDimSpatial+3 > b_g_k_c_xs_lengths_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:766
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:749
AGridPointer p_as_grid_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:729
BGridDesc_N_K b_grid_desc_n_k_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:740
ConvToGemmFwdTransformer conv_to_gemm_transformer_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:737
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:748
BElementwiseOperation b_element_op_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:760
ck::Array< ck::Array< index_t, NDimSpatial+3 >, NumDTensor > ds_g_n_k_wos_strides_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:769
__host__ __device__ void init_ds_e_grid_desc()
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:520
ck::Array< index_t, NDimSpatial > conv_filter_dilations_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:773
ck::Array< ck::Array< index_t, NDimSpatial+3 >, NumDTensor > ds_g_n_k_wos_lengths_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:768
DsGridDesc_M_N ds_grid_desc_m_n_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:741
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:745
ck::Array< index_t, NDimSpatial+3 > a_g_n_c_wis_strides_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:765
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:358
static constexpr auto I1
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:372
remove_cvref_t< decltype(MakeEGridDescriptor_M_N< ELayout >(dummy_conv_to_gemm_transformer))> EGridDesc_M_N
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:441
static constexpr ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:434
ck::conditional_t< isMultiA, ck::Array< const void *, NumATensor > &, const void * > APointers
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:489
__host__ static __device__ auto MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:383
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1( AGridDesc_M_K{}))> AGridDesc_AK0_M_AK1
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:499
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))> Block2ETileMap
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:513
static constexpr auto NXdlPerWave32
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:362
static __device__ __host__ auto MakeArgument(APointers p_as, BPointers p_bs, const ck::Array< const void *, NumDTensor > &p_ds, void *p_e, const ck::Array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const ck::Array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const ck::Array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const ck::Array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const ck::Array< ck::Array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const ck::Array< ck::Array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const ck::Array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const ck::Array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const ck::Array< long_index_t, NDimSpatial > &conv_filter_strides, const ck::Array< long_index_t, NDimSpatial > &conv_filter_dilations, const ck::Array< long_index_t, NDimSpatial > &input_left_pads, const ck::Array< long_index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op)
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1012
remove_cvref_t< decltype(MakeAGridDescriptor_M_K< ALayout >(dummy_conv_to_gemm_transformer))> AGridDesc_M_K
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:435
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:361
static constexpr auto I0
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:371
static constexpr bool isMultiA
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:364
TransformConvFwdToGemm< NDimSpatial, ConvForwardSpecialization > ConvToGemmFwdTransformer
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:376
remove_cvref_t< decltype(GridwiseGemm64::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( EGridDesc_M_N{}))> EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:508
remove_cvref_t< decltype(MakeBGridDescriptor_N_K< BLayout >(dummy_conv_to_gemm_transformer))> BGridDesc_N_K
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:437
__host__ static __device__ auto MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:396
ck::conditional_t< isMultiA||isMultiB, GridwiseGemmMultipleABD_xdl_cshuffle< GridwiseGemmMultiABDTemplateParameters >, GridwiseGemmMultipleD_xdl_cshuffle< GridwiseGemmTemplateParameters > > GridwiseGemmBase
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:481
ck::conditional_t<!isMultiB &&isMultiA, Tuple< BDataType >, BDataType > GemmBDataType
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:447
remove_cvref_t< decltype(GetBGridPointer< isMultiA||isMultiB, GridwiseGemm64, BDataType >())> BGridPointer
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:495
static constexpr index_t NumATensor
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:367
ck::conditional_t<!isMultiA &&isMultiB, Tuple< ADataType >, ADataType > GemmADataType
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:446
static constexpr index_t NumDTensor
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:369
static __device__ __host__ auto MakeArgument(APointers p_as, BPointers p_bs, const ck::Array< const void *, NumDTensor > &p_ds, void *p_e, const ck::Array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const ck::Array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const ck::Array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const ck::Array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const ck::Array< ck::Array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const ck::Array< ck::Array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const ck::Array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const ck::Array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const ck::Array< index_t, NDimSpatial > &conv_filter_strides, const ck::Array< index_t, NDimSpatial > &conv_filter_dilations, const ck::Array< index_t, NDimSpatial > &input_left_pads, const ck::Array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op)
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:970
static constexpr auto I2
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:373
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1( BGridDesc_N_K{}))> BGridDesc_BK0_N_BK1
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:502
ck::conditional_t< isMultiB, ck::Array< const void *, NumBTensor > &, const void * > BPointers
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:490
remove_cvref_t< decltype(GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DsGridDesc_M_N{}))> DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:505
static constexpr bool isMultiB
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:365
remove_cvref_t< decltype(GetAGridPointer< isMultiA||isMultiB, GridwiseGemm64, ADataType >())> AGridPointer
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:493
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:485
static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:422
static __device__ __host__ bool IsSupportedArgument(const Argument &arg)
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:802
CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle DeviceOp
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:359
remove_cvref_t< decltype(MakeDsGridDescriptor_M_N(dummy_conv_to_gemm_transformer))> DsGridDesc_M_N
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:439
static constexpr auto matrix_padder
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:378
static constexpr auto I3
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:374
static __device__ __host__ bool check_gemm_validity(const Argument &arg)
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:778
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:486
static constexpr index_t NumBTensor
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:368
__host__ static __device__ auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:409
Grouped Convolution Forward.
Definition device_grouped_conv_fwd_multiple_abd.hpp:73
Definition matrix_padder.hpp:180