transform_conv_bwd_data_to_gemm.hpp Source File

transform_conv_bwd_data_to_gemm.hpp Source File#

Composable Kernel: transform_conv_bwd_data_to_gemm.hpp Source File
transform_conv_bwd_data_to_gemm.hpp
Go to the documentation of this file.
1
2// SPDX-License-Identifier: MIT
3// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
4
5#pragma once
6#include "ck_tile/core.hpp"
8
9namespace ck_tile {
10
11template <index_t NDimSpatial,
13 index_t VectorSizeA,
14 index_t VectorSizeB,
15 index_t VectorSizeC,
16 bool SplitN = false,
17 typename ADataType = float,
18 typename CDataType = float,
19 index_t NumGroupsToMerge = 1,
20 typename IndexType = index_t>
22{
23 private:
// Compile-time integer constants (number<v>) used in this struct both as indices into the
// length/stride/pad containers and as literal 0/1 initial values for members.
24 static constexpr auto I0 = number<0>{};
25 static constexpr auto I1 = number<1>{};
26 static constexpr auto I2 = number<2>{};
27 static constexpr auto I3 = number<3>{};
28 static constexpr auto I4 = number<4>{};
29 static constexpr auto I5 = number<5>{};
30
31 template <typename ConvDimsType>
32 static long_index_t calculate_element_space_size_impl(const ConvDimsType& lengths,
33 const ConvDimsType& strides,
34 index_t i)
35 {
36 long_index_t acc = 1;
37 for(; i < (NDimSpatial + 3); i++)
38 {
39 acc +=
40 static_cast<long_index_t>(lengths[i] - I1) * static_cast<long_index_t>(strides[i]);
41 }
42
43 return acc;
44 }
45
46 template <typename ConvDimsType>
47 static IndexType GetSplitedNSize(const ConvDimsType& c_g_n_k_wos_lengths,
48 const ConvDimsType& a_g_n_c_wis_lengths)
49 {
50
51 // Calculate strides internally assuming contiguous memory layout
52 ConvDimsType c_g_n_k_wos_strides, a_g_n_c_wis_strides;
53 const index_t num_dims = c_g_n_k_wos_strides.size();
54
55 // Calculate strides for input tensor (innermost to outermost),
56 // Don't include outermost dimension G since it's gemm batch.
57 a_g_n_c_wis_strides[num_dims - 1] = 1;
58 for(index_t i = num_dims - 2; i >= 1; i--)
59 {
60 a_g_n_c_wis_strides[i] = a_g_n_c_wis_strides[i + 1] * a_g_n_c_wis_lengths[i + 1];
61 }
62
63 // Calculate strides for output tensor,
64 // Don't include outermost dimension G since it's gemm batch.
65 c_g_n_k_wos_strides[num_dims - 1] = 1;
66 for(index_t i = num_dims - 2; i >= 1; i--)
67 {
68 c_g_n_k_wos_strides[i] = c_g_n_k_wos_strides[i + 1] * c_g_n_k_wos_lengths[i + 1];
69 }
70
71 const long_index_t a_element_space_size =
72 calculate_element_space_size_impl(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, I1);
73 const long_index_t c_element_space_size =
74 calculate_element_space_size_impl(c_g_n_k_wos_lengths, c_g_n_k_wos_strides, I1);
75 const long_index_t element_space_size = ck_tile::max(
76 a_element_space_size * sizeof(ADataType), c_element_space_size * sizeof(CDataType));
77
78 constexpr long_index_t TwoGB = (long_index_t{1} << 31);
79
80 const IndexType N = c_g_n_k_wos_lengths[I1];
81
82 if(element_space_size > TwoGB)
83 {
84 // Minimum divisor of N to not exceed 2GB
85 const auto divisor = ck_tile::integer_divide_ceil(element_space_size, TwoGB);
86
87 if(divisor <= static_cast<double>(N))
88 {
89 // Find least divisor of N larger than element_space_size / TwoGB
90 // Iterate up to sqrt(N). There are no divisors above this value.
91 for(IndexType least_divisor = divisor; least_divisor * least_divisor <= N;
92 least_divisor++)
93 {
94 if(N % least_divisor == 0)
95 {
96 return N / least_divisor;
97 }
98 }
99 // Not found, process one Convolution N per block
100 return 1;
101 }
102 else
103 {
104 // Split Convolution's N dimension into N workgroups. However
105 // this still might not result in sufficiently small tensor,
106 // but at least later on we could divide the image as well.
107 return 1;
108 }
109 }
110 else
111 {
112 // Split N is not needed.
113 return N;
114 }
115 }
116
117 public:
118 // Public getter methods for Split-N support
119 CK_TILE_HOST constexpr IndexType GetN() const { return N_; }
120 CK_TILE_HOST constexpr IndexType GetOriginalN() const { return original_N_; }
121
123
// Converting constructor: builds this transformer from another object that exposes the same
// member names, casting every copied value to this instantiation's IndexType.
// Copies G/N/original N, the input, output and filter extents, K/C, and the convolution
// strides, dilations and left/right pads.
// NOTE(review): the tilde index members set by the other constructors (IdxZTilde_,
// IdxYTilde_, IdxXTilde_) are not part of this initializer list - confirm that is intended.
124 template <typename TransformConvBwdDataToGemmBase>
126 TransformConvBwdDataToGemm(const TransformConvBwdDataToGemmBase& transform_conv_to_gemm_base)
127 : G_{static_cast<IndexType>(transform_conv_to_gemm_base.G_)},
128 N_{static_cast<IndexType>(transform_conv_to_gemm_base.N_)},
129 original_N_{static_cast<IndexType>(transform_conv_to_gemm_base.original_N_)},
130 Di_{static_cast<IndexType>(transform_conv_to_gemm_base.Di_)},
131 Hi_{static_cast<IndexType>(transform_conv_to_gemm_base.Hi_)},
132 Wi_{static_cast<IndexType>(transform_conv_to_gemm_base.Wi_)},
133 Do_{static_cast<IndexType>(transform_conv_to_gemm_base.Do_)},
134 Ho_{static_cast<IndexType>(transform_conv_to_gemm_base.Ho_)},
135 Wo_{static_cast<IndexType>(transform_conv_to_gemm_base.Wo_)},
136 Z_{static_cast<IndexType>(transform_conv_to_gemm_base.Z_)},
137 Y_{static_cast<IndexType>(transform_conv_to_gemm_base.Y_)},
138 X_{static_cast<IndexType>(transform_conv_to_gemm_base.X_)},
139 K_{static_cast<IndexType>(transform_conv_to_gemm_base.K_)},
140 C_{static_cast<IndexType>(transform_conv_to_gemm_base.C_)},
141 ConvStrideD_{static_cast<IndexType>(transform_conv_to_gemm_base.ConvStrideD_)},
142 ConvStrideH_{static_cast<IndexType>(transform_conv_to_gemm_base.ConvStrideH_)},
143 ConvStrideW_{static_cast<IndexType>(transform_conv_to_gemm_base.ConvStrideW_)},
144 ConvDilationD_{static_cast<IndexType>(transform_conv_to_gemm_base.ConvDilationD_)},
145 ConvDilationH_{static_cast<IndexType>(transform_conv_to_gemm_base.ConvDilationH_)},
146 ConvDilationW_{static_cast<IndexType>(transform_conv_to_gemm_base.ConvDilationW_)},
147 InLeftPadD_{static_cast<IndexType>(transform_conv_to_gemm_base.InLeftPadD_)},
148 InLeftPadH_{static_cast<IndexType>(transform_conv_to_gemm_base.InLeftPadH_)},
149 InLeftPadW_{static_cast<IndexType>(transform_conv_to_gemm_base.InLeftPadW_)},
150 InRightPadD_{static_cast<IndexType>(transform_conv_to_gemm_base.InRightPadD_)},
151 InRightPadH_{static_cast<IndexType>(transform_conv_to_gemm_base.InRightPadH_)},
152 InRightPadW_{static_cast<IndexType>(transform_conv_to_gemm_base.InRightPadW_)}
153 {
154 }
155
// Constructor for the 1D case (enabled only when NDim == 1).
// G and N come from a_g_n_c_wis_lengths[0] and [1]; the W extents come from index 3 of the
// three length arrays; K is c_g_n_k_wos_lengths[2] and C is b_g_k_c_xs_lengths[2].
// The D/H extents, filter sizes, strides and dilations are set to 1 and the D/H pads to 0.
// W stride/dilation/pads are element 0 of the spatial arrays; tildes[0] becomes IdxXTilde_,
// while IdxZTilde_ and IdxYTilde_ are set to I1.
156 template <typename ConvDimsType,
157 typename ConvSpatialDimsType,
158 index_t NDim = NDimSpatial,
159 typename std::enable_if<NDim == 1, bool>::type = false>
160 CK_TILE_HOST TransformConvBwdDataToGemm(const ConvDimsType& a_g_n_c_wis_lengths,
161 const ConvDimsType& b_g_k_c_xs_lengths,
162 const ConvDimsType& c_g_n_k_wos_lengths,
163 const ConvSpatialDimsType& conv_filter_strides,
164 const ConvSpatialDimsType& conv_filter_dilations,
165 const ConvSpatialDimsType& input_left_pads,
166 const ConvSpatialDimsType& input_right_pads,
167 const ConvSpatialDimsType& tildes)
168 : G_{a_g_n_c_wis_lengths[I0]},
169 N_{a_g_n_c_wis_lengths[I1]},
170 Di_{I1},
171 Hi_{I1},
172 Wi_{a_g_n_c_wis_lengths[I3]},
173 Do_{I1},
174 Ho_{I1},
175 Wo_{c_g_n_k_wos_lengths[I3]},
176 Z_{I1},
177 Y_{I1},
178 X_{b_g_k_c_xs_lengths[I3]},
179 K_{c_g_n_k_wos_lengths[I2]},
180 C_{b_g_k_c_xs_lengths[I2]},
181 ConvStrideD_{I1},
182 ConvStrideH_{I1},
183 ConvStrideW_{conv_filter_strides[I0]},
184 ConvDilationD_{I1},
185 ConvDilationH_{I1},
186 ConvDilationW_{conv_filter_dilations[I0]},
187 InLeftPadD_{I0},
188 InLeftPadH_{I0},
189 InLeftPadW_{input_left_pads[I0]},
190 InRightPadD_{I0},
191 InRightPadH_{I0},
192 InRightPadW_{input_right_pads[I0]},
193 IdxZTilde_{I1},
194 IdxYTilde_{I1},
195 IdxXTilde_{tildes[I0]}
196 {
197
198 // Store original N
199 original_N_ = a_g_n_c_wis_lengths[I1];
200
// With SplitN, N_ is replaced by the value GetSplitedNSize returns; otherwise it keeps the
// full batch size read from the lengths.
201 if constexpr(SplitN)
202 {
203 N_ = GetSplitedNSize(c_g_n_k_wos_lengths, a_g_n_c_wis_lengths);
204 }
205 else
206 {
207 N_ = a_g_n_c_wis_lengths[I1];
208 }
209
214 }
215
// Constructor for the 2D case (enabled only when NDim == 2).
// G and N come from a_g_n_c_wis_lengths[0] and [1]; the H/W extents come from indices 3 and 4
// of the three length arrays; K is c_g_n_k_wos_lengths[2] and C is b_g_k_c_xs_lengths[2].
// The D extent, filter size, stride and dilation are set to 1 and the D pads to 0.
// H/W strides, dilations, pads and tilde indices are elements 0 and 1 of the spatial arrays;
// IdxZTilde_ is set to I1.
216 template <typename ConvDimsType,
217 typename ConvSpatialDimsType,
218 index_t NDim = NDimSpatial,
219 typename std::enable_if<NDim == 2, bool>::type = false>
220 CK_TILE_HOST TransformConvBwdDataToGemm(const ConvDimsType& a_g_n_c_wis_lengths,
221 const ConvDimsType& b_g_k_c_xs_lengths,
222 const ConvDimsType& c_g_n_k_wos_lengths,
223 const ConvSpatialDimsType& conv_filter_strides,
224 const ConvSpatialDimsType& conv_filter_dilations,
225 const ConvSpatialDimsType& input_left_pads,
226 const ConvSpatialDimsType& input_right_pads,
227 const ConvSpatialDimsType& tildes)
228 : G_{a_g_n_c_wis_lengths[I0]},
229 N_{a_g_n_c_wis_lengths[I1]},
230 Di_{I1},
231 Hi_{a_g_n_c_wis_lengths[I3]},
232 Wi_{a_g_n_c_wis_lengths[I4]},
233 Do_{I1},
234 Ho_{c_g_n_k_wos_lengths[I3]},
235 Wo_{c_g_n_k_wos_lengths[I4]},
236 Z_{I1},
237 Y_{b_g_k_c_xs_lengths[I3]},
238 X_{b_g_k_c_xs_lengths[I4]},
239 K_{c_g_n_k_wos_lengths[I2]},
240 C_{b_g_k_c_xs_lengths[I2]},
241 ConvStrideD_{I1},
242 ConvStrideH_{conv_filter_strides[I0]},
243 ConvStrideW_{conv_filter_strides[I1]},
244 ConvDilationD_{I1},
245 ConvDilationH_{conv_filter_dilations[I0]},
246 ConvDilationW_{conv_filter_dilations[I1]},
247 InLeftPadD_{I0},
248 InLeftPadH_{input_left_pads[I0]},
249 InLeftPadW_{input_left_pads[I1]},
250 InRightPadD_{I0},
251 InRightPadH_{input_right_pads[I0]},
252 InRightPadW_{input_right_pads[I1]},
253 IdxZTilde_{I1},
254 IdxYTilde_{tildes[I0]},
255 IdxXTilde_{tildes[I1]}
256 {
257
258 // Store original N
259 original_N_ = a_g_n_c_wis_lengths[I1];
260
// With SplitN, N_ is replaced by the value GetSplitedNSize returns; otherwise it keeps the
// full batch size read from the lengths.
261 if constexpr(SplitN)
262 {
263 N_ = GetSplitedNSize(c_g_n_k_wos_lengths, a_g_n_c_wis_lengths);
264 }
265 else
266 {
267 N_ = a_g_n_c_wis_lengths[I1];
268 }
269
278 }
279
// Constructor for the 3D case (enabled only when NDim == 3).
// G and N come from a_g_n_c_wis_lengths[0] and [1]; the D/H/W extents come from indices 3, 4
// and 5 of the three length arrays; K is c_g_n_k_wos_lengths[2] and C is
// b_g_k_c_xs_lengths[2]. D/H/W strides, dilations, pads and tilde indices are elements 0, 1
// and 2 of the spatial arrays.
// NOTE(review): `tildes` is marked [[maybe_unused]] although the initializer list reads it;
// the attribute looks unnecessary in the code shown here.
280 template <typename ConvDimsType,
281 typename ConvSpatialDimsType,
282 index_t NDim = NDimSpatial,
283 typename std::enable_if<NDim == 3, bool>::type = false>
284 CK_TILE_HOST TransformConvBwdDataToGemm(const ConvDimsType& a_g_n_c_wis_lengths,
285 const ConvDimsType& b_g_k_c_xs_lengths,
286 const ConvDimsType& c_g_n_k_wos_lengths,
287 const ConvSpatialDimsType& conv_filter_strides,
288 const ConvSpatialDimsType& conv_filter_dilations,
289 const ConvSpatialDimsType& input_left_pads,
290 const ConvSpatialDimsType& input_right_pads,
291 [[maybe_unused]] const ConvSpatialDimsType& tildes)
292 : G_{a_g_n_c_wis_lengths[I0]},
293 N_{a_g_n_c_wis_lengths[I1]},
294 Di_{a_g_n_c_wis_lengths[I3]},
295 Hi_{a_g_n_c_wis_lengths[I4]},
296 Wi_{a_g_n_c_wis_lengths[I5]},
297 Do_{c_g_n_k_wos_lengths[I3]},
298 Ho_{c_g_n_k_wos_lengths[I4]},
299 Wo_{c_g_n_k_wos_lengths[I5]},
300 Z_{b_g_k_c_xs_lengths[I3]},
301 Y_{b_g_k_c_xs_lengths[I4]},
302 X_{b_g_k_c_xs_lengths[I5]},
303 K_{c_g_n_k_wos_lengths[I2]},
304 C_{b_g_k_c_xs_lengths[I2]},
305 ConvStrideD_{conv_filter_strides[I0]},
306 ConvStrideH_{conv_filter_strides[I1]},
307 ConvStrideW_{conv_filter_strides[I2]},
308 ConvDilationD_{conv_filter_dilations[I0]},
309 ConvDilationH_{conv_filter_dilations[I1]},
310 ConvDilationW_{conv_filter_dilations[I2]},
311 InLeftPadD_{input_left_pads[I0]},
312 InLeftPadH_{input_left_pads[I1]},
313 InLeftPadW_{input_left_pads[I2]},
314 InRightPadD_{input_right_pads[I0]},
315 InRightPadH_{input_right_pads[I1]},
316 InRightPadW_{input_right_pads[I2]},
317 IdxZTilde_{tildes[I0]},
318 IdxYTilde_{tildes[I1]},
319 IdxXTilde_{tildes[I2]}
320 {
321
322 // Store original N
323 original_N_ = a_g_n_c_wis_lengths[I1];
324
// With SplitN, N_ is replaced by the value GetSplitedNSize returns; otherwise it keeps the
// full batch size read from the lengths.
325 if constexpr(SplitN)
326 {
327 N_ = GetSplitedNSize(c_g_n_k_wos_lengths, a_g_n_c_wis_lengths);
328 }
329 else
330 {
331 N_ = a_g_n_c_wis_lengths[I1];
332 }
333
346 }
347
348#if 0 // TODO: Enable these functionalities
// Currently compiled out (inside `#if 0`).
// Returns true when both the input (A) and output (C) descriptors occupy at most 2^31 bytes.
// Each size is 1 + sum((extent - 1) * stride) over N, the three spatial dimensions and the
// channel dimension, multiplied by the element size.
// NOTE(review): the stride members used here (NStrideTensorA_, DiStride_, ...) are not
// declared in the part of the struct visible in this listing - confirm before enabling.
349 __host__ bool AreDescriptorsSmallerThan2GB() const
350 {
351 constexpr long_index_t TwoGB = (long_index_t{1} << 31);
352
353 const long_index_t in_desc_space_size =
354 I1 + (N_ - I1) * NStrideTensorA_ + (Di_ - I1) * DiStride_ + (Hi_ - I1) * HiStride_ +
355 (Wi_ - I1) * WiStride_ + (C_ - I1) * CStrideTensorA_;
356 const long_index_t out_desc_space_size =
357 I1 + (N_ - I1) * NStrideTensorC_ + (Do_ - I1) * DoStride_ + (Ho_ - I1) * HoStride_ +
358 (Wo_ - I1) * WoStride_ + (K_ - I1) * KStrideTensorC_;
359
360 bool is_a_descriptor_smaller_than_2GB = (in_desc_space_size * sizeof(ADataType)) <= TwoGB;
361 bool is_c_descriptor_smaller_than_2GB = (out_desc_space_size * sizeof(CDataType)) <= TwoGB;
362
363 return is_a_descriptor_smaller_than_2GB && is_c_descriptor_smaller_than_2GB;
364 }
365
// Currently compiled out (inside `#if 0`).
// Splits this problem into a "left" and a "right" transformer by halving the output extent
// along one spatial dimension. D is tried first, then H, then W; only the first dimension
// whose split condition holds is split.
// Returns a tuple of (left transformer, right transformer, a_grid_ptr_base advanced to the
// start of the right half, c_grid_ptr_base advanced to the start of the right half).
// If no dimension can be split, both transformers are plain copies of *this and both
// pointers are returned unchanged (the offsets stay 0).
// NOTE(review): the stride members (DiStride_, DoStride_, ...) and `math::min` are not
// declared in the part of the file visible in this listing - confirm before enabling.
366 __host__ auto SplitConvProblem(const ADataType* a_grid_ptr_base,
367 CDataType* c_grid_ptr_base) const
368 {
369 // Create copies
370 auto conv_to_gemm_transformer_left = *this;
371 auto conv_to_gemm_transformer_right = *this;
372 IndexType a_right_offset = 0;
373 IndexType c_right_offset = 0;
374 // Calculate real filter size
375 const IndexType z_eff = (Z_ - 1) * ConvDilationD_ + 1;
376 const IndexType y_eff = (Y_ - 1) * ConvDilationH_ + 1;
377 const IndexType x_eff = (X_ - 1) * ConvDilationW_ + 1;
378 // Calculate start position in input for right tensor
379 const IndexType di_right_transformer_start_idx = (Do_ / 2) * ConvStrideD_;
380 const IndexType hi_right_transformer_start_idx = (Ho_ / 2) * ConvStrideH_;
381 const IndexType wi_right_transformer_start_idx = (Wo_ / 2) * ConvStrideW_;
382 // Calculate last position in input for left tensor
383 const IndexType di_left_transformer_end_idx = (Do_ / 2 - 1) * ConvStrideD_ + z_eff;
384 const IndexType hi_left_transformer_end_idx = (Ho_ / 2 - 1) * ConvStrideH_ + y_eff;
385 const IndexType wi_left_transformer_end_idx = (Wo_ / 2 - 1) * ConvStrideW_ + x_eff;
386 // Allow to split if whole left padding will be in left tensor and right padding in right
387 // tensor
388 const bool is_possible_to_split_d = Do_ != 1 &&
389 di_right_transformer_start_idx > InLeftPadD_ &&
390 di_left_transformer_end_idx <= (InLeftPadD_ + Di_);
391 const bool is_possible_to_split_h = Ho_ != 1 &&
392 hi_right_transformer_start_idx > InLeftPadH_ &&
393 hi_left_transformer_end_idx <= (InLeftPadH_ + Hi_);
394 const bool is_possible_to_split_w = Wo_ != 1 &&
395 wi_right_transformer_start_idx > InLeftPadW_ &&
396 wi_left_transformer_end_idx <= (InLeftPadW_ + Wi_);
397
398 if(is_possible_to_split_d)
399 {
400 // Apply new sizes
401 // Split output on half
402 conv_to_gemm_transformer_left.Do_ = Do_ / 2;
403 conv_to_gemm_transformer_right.Do_ = Do_ - Do_ / 2;
404 // Assign left padding to left convolution
405 conv_to_gemm_transformer_left.InLeftPadD_ = InLeftPadD_;
406 conv_to_gemm_transformer_right.InLeftPadD_ = 0;
407 // Assign right padding to right convolution
408 conv_to_gemm_transformer_left.InRightPadD_ = 0;
409 conv_to_gemm_transformer_right.InRightPadD_ = InRightPadD_;
410 // Calculate new input size
411 conv_to_gemm_transformer_left.Di_ = di_left_transformer_end_idx - InLeftPadD_;
412 conv_to_gemm_transformer_right.Di_ =
413 math::min(Di_ - (di_right_transformer_start_idx - InLeftPadD_),
414 (conv_to_gemm_transformer_right.Do_ - 1) * ConvStrideD_ + z_eff);
415 ;
416 // Calculate offsets
417 a_right_offset = ((Do_ / 2) * ConvStrideD_ - InLeftPadD_) * DiStride_;
418 c_right_offset = (Do_ / 2) * DoStride_;
419 }
420 else if(is_possible_to_split_h)
421 {
// Same steps as the D split above, applied to the H dimension.
422 conv_to_gemm_transformer_left.Ho_ = Ho_ / 2;
423 conv_to_gemm_transformer_right.Ho_ = Ho_ - Ho_ / 2;
424
425 conv_to_gemm_transformer_left.InLeftPadH_ = InLeftPadH_;
426 conv_to_gemm_transformer_right.InLeftPadH_ = 0;
427
428 conv_to_gemm_transformer_left.InRightPadH_ = 0;
429 conv_to_gemm_transformer_right.InRightPadH_ = InRightPadH_;
430
431 conv_to_gemm_transformer_left.Hi_ = hi_left_transformer_end_idx - InLeftPadH_;
432 conv_to_gemm_transformer_right.Hi_ =
433 math::min(Hi_ - (hi_right_transformer_start_idx - InLeftPadH_),
434 (conv_to_gemm_transformer_right.Ho_ - 1) * ConvStrideH_ + y_eff);
435 a_right_offset = ((Ho_ / 2) * ConvStrideH_ - InLeftPadH_) * HiStride_;
436 c_right_offset = (Ho_ / 2) * HoStride_;
437 }
438 else if(is_possible_to_split_w)
439 {
// Same steps as the D split above, applied to the W dimension.
440 conv_to_gemm_transformer_left.Wo_ = Wo_ / 2;
441 conv_to_gemm_transformer_right.Wo_ = Wo_ - Wo_ / 2;
442
443 conv_to_gemm_transformer_left.InLeftPadW_ = InLeftPadW_;
444 conv_to_gemm_transformer_right.InLeftPadW_ = 0;
445
446 conv_to_gemm_transformer_left.InRightPadW_ = 0;
447 conv_to_gemm_transformer_right.InRightPadW_ = InRightPadW_;
448
449 conv_to_gemm_transformer_left.Wi_ = wi_left_transformer_end_idx - InLeftPadW_;
450 conv_to_gemm_transformer_right.Wi_ =
451 math::min(Wi_ - (wi_right_transformer_start_idx - InLeftPadW_),
452 (conv_to_gemm_transformer_right.Wo_ - 1) * ConvStrideW_ + x_eff);
453
454 a_right_offset = ((Wo_ / 2) * ConvStrideW_ - InLeftPadW_) * WiStride_;
455 c_right_offset = (Wo_ / 2) * WoStride_;
456 }
457 // Return left transform, right transformer, right offset to Input and right offset to
458 // Output
459 return ck_tile::make_tuple(conv_to_gemm_transformer_left,
460 conv_to_gemm_transformer_right,
461 a_grid_ptr_base + a_right_offset,
462 c_grid_ptr_base + c_right_offset);
463 }
464#endif
465
466 template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 1, bool>::type = false>
468 {
469 // NWGK
470 const index_t NStride = Wo_ * G_ * K_;
471 const index_t WoStride = G_ * K_;
472 constexpr auto KStride = I1;
473
474 // TODO Add support for NumGroupsToMerge > 1
475
477 make_tuple(NStride, WoStride, KStride),
479 I1);
480 }
481
482 template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 1, bool>::type = false>
484 {
485 // GKXC
487 make_tuple(K_, X_, C_), make_tuple(X_ * C_, C_, I1), number<VectorSizeB>{}, I1);
488 }
489
490 template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 1, bool>::type = false>
492 {
493 // NWGC
494 const index_t NStride = Wi_ * G_ * C_;
495 const index_t WiStride = G_ * C_; // GC?
496 constexpr auto CStride = I1;
497
498 // TODO Add support for NumGroupsToMerge > 1
500 make_tuple(NStride, WiStride, CStride),
502 I1);
503 }
504
505 template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
507 {
508 // NHWGK
509 const index_t NStride = Ho_ * Wo_ * G_ * K_;
510 const index_t HoStride = Wo_ * G_ * K_;
511 const index_t WoStride = G_ * K_;
512 constexpr auto KStride = I1;
513
514 // TODO Add support for NumGroupsToMerge > 1
515
517 make_tuple(NStride, HoStride, WoStride, KStride),
519 I1);
520 }
521
522 template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
524 {
525 // NHWGC
526 const index_t NStride = Hi_ * Wi_ * G_ * C_;
527 const index_t HiStride = Wi_ * G_ * C_;
528 const index_t WiStride = G_ * C_;
529 constexpr auto CStride = I1;
530
531 // TODO Add support for NumGroupsToMerge > 1
533 make_tuple(NStride, HiStride, WiStride, CStride),
535 I1);
536 }
537
538 template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
540 {
541 // GKYXC
543 make_tuple(C_ * X_ * Y_, C_ * X_, C_, I1),
545 I1);
546 }
547
548 template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
550 {
551 // NDHWGK
552 const index_t NStride = Do_ * Ho_ * Wo_ * G_ * K_;
553 const index_t DoStride = Ho_ * Wo_ * G_ * K_;
554 const index_t HoStride = Wo_ * G_ * K_;
555 const index_t WoStride = G_ * K_;
556 constexpr auto KStride = I1;
557
558 // TODO Add support for NumGroupsToMerge > 1
560 make_tuple(N_, Do_, Ho_, Wo_, K_),
561 make_tuple(NStride, DoStride, HoStride, WoStride, KStride),
563 I1);
564 }
565
566 template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
568 {
569 const index_t NStride = Di_ * Hi_ * Wi_ * G_ * C_;
570 const index_t DiStride = Hi_ * Wi_ * G_ * C_;
571 const index_t HiStride = Wi_ * G_ * C_;
572 const index_t WiStride = G_ * C_;
573 constexpr auto CStride = I1;
574
575 // TODO Add support for NumGroupsToMerge > 1
577 make_tuple(N_, Di_, Hi_, Wi_, C_),
578 make_tuple(NStride, DiStride, HiStride, WiStride, CStride),
580 I1);
581 }
582
583 template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
585 {
586 // GKZYXC
588 make_tuple(K_, Z_, Y_, X_, C_),
589 make_tuple(C_ * X_ * Y_ * Z_, C_ * X_ * Y_, C_ * X_, C_, I1),
591 I1);
592 }
593 // TODO: implement ck_tile::tensor_layout::convolution that describes packed/strided dimensions as
594 // properties
595
596 template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 1, bool>::type = false>
597 CK_TILE_HOST auto
598 MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N([[maybe_unused]] const index_t GemmKBatch) const
599 {
600 // only work on HTilde and WTilde that contribute to non-padding area of input tensor
601 const auto IWTildeSliceBegin = integer_divide_floor(
603
604 const auto IWTildeSliceEnd =
606
607 const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
608
609 // GemmK is different for each GEMM
610 const auto XDotSlice = integer_divide_ceil(X_ - IdxXTilde_, XTilde_);
611
612 const auto out_grid_desc = make_out_grid_desc<NDimSpatial>();
613 const auto in_grid_desc = make_in_grid_desc<NDimSpatial>();
614 const auto wei_grid_desc = make_wei_grid_desc<NDimSpatial>();
615
616 // A: output tensor comes in K_M
617 const auto out_n_wop_k_grid_desc =
618 transform_tensor_descriptor(out_grid_desc,
620 make_pad_transform(Wo_, I0, I0),
624
625 const auto out_n_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
626 out_n_wop_k_grid_desc,
633
634 const auto out_n_xdotslice_wtildeslice_k_grid_desc = transform_tensor_descriptor(
635 out_n_xdot_wtilde_k_grid_desc,
637 make_slice_transform(XDot_, I0, XDotSlice),
638 make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice),
642
643 const auto out_gemmm_gemmkraw_grid_desc = transform_tensor_descriptor(
644 out_n_xdotslice_wtildeslice_k_grid_desc,
646 make_merge_transform(make_tuple(N_, WTildeSlice))),
649
650 // B: weight tensor comes in K_N
651 const auto wei_k_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
652 wei_grid_desc,
659
660 const auto wei_k_xdotslice_c_grid_desc = transform_tensor_descriptor(
661 wei_k_xdot_xtilde_c_grid_desc,
663 make_slice_transform(XDot_, I0, XDotSlice),
668
669 const auto wei_gemmn_gemmkraw_grid_desc =
670 transform_tensor_descriptor(wei_k_xdotslice_c_grid_desc,
675
676 // c: input
677 const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
678 in_grid_desc,
684
685 const auto in_n_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
686 in_n_wip_c_grid_desc,
693
694 const auto in_n_wtildeslice_c_grid_desc = transform_tensor_descriptor(
695 in_n_xtilde_wtilde_c_grid_desc,
698 make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice),
702
703 const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
704 in_n_wtildeslice_c_grid_desc,
709
710 return make_tuple(out_gemmm_gemmkraw_grid_desc,
711 wei_gemmn_gemmkraw_grid_desc,
712 in_gemmmraw_gemmnraw_grid_desc);
713 }
714
715 template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
716 CK_TILE_HOST auto
717 MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N([[maybe_unused]] const index_t GemmKBatch) const
718 {
719 // only work on HTilde and WTilde that contribute to non-padding area of input tensor
720 const auto IHTildeSliceBegin = integer_divide_floor(
722 const auto IWTildeSliceBegin = integer_divide_floor(
724
725 const auto IHTildeSliceEnd =
727 const auto IWTildeSliceEnd =
729
730 const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
731 const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
732
733 // GemmK is different for each GEMM
734 const auto YDotSlice = integer_divide_ceil(Y_ - IdxYTilde_, YTilde_);
735 const auto XDotSlice = integer_divide_ceil(X_ - IdxXTilde_, XTilde_);
736
737 const auto out_grid_desc = make_out_grid_desc<NDimSpatial>();
738 const auto in_grid_desc = make_in_grid_desc<NDimSpatial>();
739 const auto wei_grid_desc = make_wei_grid_desc<NDimSpatial>();
740
741 // A: output tensor comes in K_M
742 const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
743 out_grid_desc,
745 make_pad_transform(Ho_, I0, I0),
746 make_pad_transform(Wo_, I0, I0),
750
751 const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
752 out_n_hop_wop_k_grid_desc,
761
762 const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc =
764 out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
766 make_slice_transform(YDot_, I0, YDotSlice),
767 make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice),
768 make_slice_transform(XDot_, I0, XDotSlice),
769 make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice),
772 sequence<1>{},
773 sequence<2>{},
774 sequence<3>{},
775 sequence<4>{},
776 sequence<5>{}),
778 sequence<1>{},
779 sequence<2>{},
780 sequence<3>{},
781 sequence<4>{},
782 sequence<5>{}));
783
784 const auto out_gemmm_gemmkraw_grid_desc = transform_tensor_descriptor(
785 out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc,
786 make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)),
787 make_merge_transform(make_tuple(N_, HTildeSlice, WTildeSlice))),
790
791 // B: weight tensor comes in K_N
792 const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
793 wei_grid_desc,
802
803 const auto wei_k_ydotslice_xdotslice_c_grid_desc =
804 transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
806 make_slice_transform(YDot_, I0, YDotSlice),
807 make_slice_transform(XDot_, I0, XDotSlice),
812 sequence<1>{},
813 sequence<3>{},
814 sequence<2>{},
815 sequence<4>{},
816 sequence<5>{}),
818 sequence<1>{},
819 sequence<2>{},
820 sequence<>{},
821 sequence<>{},
822 sequence<3>{}));
823
824 const auto wei_gemmn_gemmkraw_grid_desc = transform_tensor_descriptor(
825 wei_k_ydotslice_xdotslice_c_grid_desc,
826 make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)),
830
831 // c: input
832 const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
833 in_grid_desc,
840
841 const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
842 in_n_hip_wip_c_grid_desc,
851
852 const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
853 in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
856 make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice),
858 make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice),
861 sequence<1>{},
862 sequence<2>{},
863 sequence<3>{},
864 sequence<4>{},
865 sequence<5>{}),
867 sequence<>{},
868 sequence<1>{},
869 sequence<>{},
870 sequence<2>{},
871 sequence<3>{}));
872
873 const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
874 in_n_htildeslice_wtildeslice_c_grid_desc,
875 make_tuple(make_merge_transform(make_tuple(N_, HTildeSlice, WTildeSlice)),
879
880 return make_tuple(out_gemmm_gemmkraw_grid_desc,
881 wei_gemmn_gemmkraw_grid_desc,
882 in_gemmmraw_gemmnraw_grid_desc);
883 }
884
885 template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
886 CK_TILE_HOST auto
887 MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N([[maybe_unused]] const index_t GemmKBatch) const
888 {
889 // only work on DTilde, HTilde and WTilde that contribute to non-padding area of input
890 // tensor
891 const auto IDTildeSliceBegin = integer_divide_floor(
893 const auto IHTildeSliceBegin = integer_divide_floor(
895 const auto IWTildeSliceBegin = integer_divide_floor(
897
898 const auto IDTildeSliceEnd =
900 const auto IHTildeSliceEnd =
902 const auto IWTildeSliceEnd =
904
905 const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin;
906 const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
907 const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
908
909 // GemmK is different for each GEMM
910 const auto ZDotSlice = integer_divide_ceil(Z_ - IdxZTilde_, ZTilde_);
911 const auto YDotSlice = integer_divide_ceil(Y_ - IdxYTilde_, YTilde_);
912 const auto XDotSlice = integer_divide_ceil(X_ - IdxXTilde_, XTilde_);
913
914 const auto out_grid_desc = make_out_grid_desc<NDimSpatial>();
915 const auto in_grid_desc = make_in_grid_desc<NDimSpatial>();
916 const auto wei_grid_desc = make_wei_grid_desc<NDimSpatial>();
917
918 // A: output tensor comes in K_M
919 const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
920 out_grid_desc,
922 make_pad_transform(Do_, I0, I0),
923 make_pad_transform(Ho_, I0, I0),
924 make_pad_transform(Wo_, I0, I0),
928
929 const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
930 out_n_hop_wop_k_grid_desc,
944 sequence<7>{}));
945
946 const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc =
948 out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
950 make_slice_transform(ZDot_, I0, ZDotSlice),
951 make_slice_transform(DTilde_, IDTildeSliceBegin, DTildeSlice),
952 make_slice_transform(YDot_, I0, YDotSlice),
953 make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice),
954 make_slice_transform(XDot_, I0, XDotSlice),
955 make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice),
958 sequence<1>{},
959 sequence<2>{},
960 sequence<3>{},
961 sequence<4>{},
962 sequence<5>{},
963 sequence<6>{},
964 sequence<7>{}),
966 sequence<1>{},
967 sequence<2>{},
968 sequence<3>{},
969 sequence<4>{},
970 sequence<5>{},
971 sequence<6>{},
972 sequence<7>{}));
973
974 const auto out_gemmm_gemmkraw_grid_desc = transform_tensor_descriptor(
975 out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc,
976 make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K_)),
977 make_merge_transform(make_tuple(N_, DTildeSlice, HTildeSlice, WTildeSlice))),
980
981 // B: weight tensor comes in K_N
982 const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
983 wei_grid_desc,
997 sequence<7>{}));
998
999 const auto wei_k_ydotslice_xdotslice_c_grid_desc =
1000 transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
1002 make_slice_transform(ZDot_, I0, ZDotSlice),
1003 make_slice_transform(YDot_, I0, YDotSlice),
1004 make_slice_transform(XDot_, I0, XDotSlice),
1010 sequence<1>{},
1011 sequence<3>{},
1012 sequence<5>{},
1013 sequence<2>{},
1014 sequence<4>{},
1015 sequence<6>{},
1016 sequence<7>{}),
1018 sequence<1>{},
1019 sequence<2>{},
1020 sequence<3>{},
1021 sequence<>{},
1022 sequence<>{},
1023 sequence<>{},
1024 sequence<4>{}));
1025
1026 const auto wei_gemmn_gemmkraw_grid_desc = transform_tensor_descriptor(
1027 wei_k_ydotslice_xdotslice_c_grid_desc,
1028 make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K_)),
1032
1033 // c: input
1034 const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
1035 in_grid_desc,
1043
1044 const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
1045 in_n_hip_wip_c_grid_desc,
1059 sequence<7>{}));
1060
1061 const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
1062 in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
1065 make_slice_transform(DTilde_, IDTildeSliceBegin, DTildeSlice),
1067 make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice),
1069 make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice),
1072 sequence<1>{},
1073 sequence<2>{},
1074 sequence<3>{},
1075 sequence<4>{},
1076 sequence<5>{},
1077 sequence<6>{},
1078 sequence<7>{}),
1080 sequence<>{},
1081 sequence<1>{},
1082 sequence<>{},
1083 sequence<2>{},
1084 sequence<>{},
1085 sequence<3>{},
1086 sequence<4>{}));
1087
1088 const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
1089 in_n_htildeslice_wtildeslice_c_grid_desc,
1090 make_tuple(make_merge_transform(make_tuple(N_, DTildeSlice, HTildeSlice, WTildeSlice)),
1094
1095 return make_tuple(out_gemmm_gemmkraw_grid_desc,
1096 wei_gemmn_gemmkraw_grid_desc,
1097 in_gemmmraw_gemmnraw_grid_desc);
1098 }
1099
// Problem description captured by the constructors.
// G_: group count; N_: batch size in use (the Split-N value when SplitN is enabled);
// original_N_: batch size as passed in.
// Di_/Hi_/Wi_ are read from a_g_n_c_wis_lengths, Do_/Ho_/Wo_ from c_g_n_k_wos_lengths and
// Z_/Y_/X_ from b_g_k_c_xs_lengths; K_ is c_g_n_k_wos_lengths[2], C_ is b_g_k_c_xs_lengths[2].
// ZDot_/YDot_/XDot_ are used as the source lengths of slice transforms in the descriptor
// builders; where they are assigned is not visible in this listing.
1100 IndexType G_, N_, original_N_;
1101 IndexType Di_, Hi_, Wi_;
1102 IndexType Do_, Ho_, Wo_;
1103 IndexType Z_, Y_, X_;
1104 IndexType K_, C_;
1113 IndexType ZDot_, YDot_, XDot_;
1114};
1115
1116} // namespace ck_tile
#define CK_TILE_HOST
Definition config.hpp:40
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
ConvolutionSpecialization
Definition convolution_specialization.hpp:11
CK_TILE_HOST_DEVICE constexpr auto integer_divide_floor(X x, Y y)
Definition tile/core/numeric/math.hpp:143
CK_TILE_HOST_DEVICE constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition coordinate_transform.hpp:1629
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1615
CK_TILE_HOST_DEVICE constexpr auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, bool_constant< SkipIsValidCheck >=bool_constant< false >{})
Definition coordinate_transform.hpp:1565
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_descriptor.hpp:203
int64_t long_index_t
Definition integer.hpp:11
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr index_t gcd(index_t x, index_t y)
Definition tile/core/numeric/math.hpp:268
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
CK_TILE_HOST_DEVICE constexpr T min(T x)
Definition tile/core/numeric/math.hpp:210
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_slice_transform(const LowLength &low_length, const SliceBegin &slice_begin, const SliceEnd &slice_end)
Definition coordinate_transform.hpp:1647
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
CK_TILE_HOST_DEVICE constexpr auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition coordinate_transform.hpp:1594
CK_TILE_HOST TransformConvBwdDataToGemm(const TransformConvBwdDataToGemmBase &transform_conv_to_gemm_base)
Definition transform_conv_bwd_data_to_gemm.hpp:126
CK_TILE_HOST TransformConvBwdDataToGemm(const ConvDimsType &a_g_n_c_wis_lengths, const ConvDimsType &b_g_k_c_xs_lengths, const ConvDimsType &c_g_n_k_wos_lengths, const ConvSpatialDimsType &conv_filter_strides, const ConvSpatialDimsType &conv_filter_dilations, const ConvSpatialDimsType &input_left_pads, const ConvSpatialDimsType &input_right_pads, const ConvSpatialDimsType &tildes)
Definition transform_conv_bwd_data_to_gemm.hpp:160
CK_TILE_HOST auto make_wei_grid_desc() const
Definition transform_conv_bwd_data_to_gemm.hpp:483
CK_TILE_HOST constexpr TransformConvBwdDataToGemm()
Definition transform_conv_bwd_data_to_gemm.hpp:122
CK_TILE_HOST constexpr IndexType GetN() const
Definition transform_conv_bwd_data_to_gemm.hpp:119
CK_TILE_HOST constexpr IndexType GetOriginalN() const
Definition transform_conv_bwd_data_to_gemm.hpp:120
CK_TILE_HOST auto make_in_grid_desc() const
Definition transform_conv_bwd_data_to_gemm.hpp:491
CK_TILE_HOST auto make_out_grid_desc() const
Definition transform_conv_bwd_data_to_gemm.hpp:467
CK_TILE_HOST auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(const index_t GemmKBatch) const
Definition transform_conv_bwd_data_to_gemm.hpp:598
Definition tile/core/container/sequence.hpp:49