transform_conv_fwd_to_gemm.hpp Source File

transform_conv_fwd_to_gemm.hpp Source File#

Composable Kernel: transform_conv_fwd_to_gemm.hpp Source File
tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp
Go to the documentation of this file.
1
2// SPDX-License-Identifier: MIT
3// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
4
5#pragma once
6
12
13namespace ck {
14namespace tensor_operation {
15
16template <index_t NDimSpatial,
17 device::ConvolutionForwardSpecialization ConvForwardSpecialization,
18 bool SplitN = false,
19 typename ADataType = float,
20 typename CDataType = float,
21 index_t NumGroupsToMerge = 1,
22 typename IndexType = index_t,
23 bool CTranspose = false>
25{
26 private:
27 static constexpr auto I0 = Number<0>{};
28 static constexpr auto I1 = Number<1>{};
29 static constexpr auto I2 = Number<2>{};
30 static constexpr auto I3 = Number<3>{};
31 static constexpr auto I4 = Number<4>{};
32 static constexpr auto I5 = Number<5>{};
33
34 template <typename ConvDimsType>
35 static long_index_t calculate_element_space_size_impl(const ConvDimsType& lengths,
36 const ConvDimsType& strides,
37 index_t i)
38 {
39 long_index_t acc = 1;
40 for(; i < (NDimSpatial + 3); i++)
41 {
42 acc +=
43 static_cast<long_index_t>(lengths[i] - I1) * static_cast<long_index_t>(strides[i]);
44 }
45
46 return acc;
47 }
48
49 template <typename ConvDimsType>
50 static IndexType GetSplitedNSize(const ConvDimsType& a_g_n_c_wis_lengths,
51 const ConvDimsType& a_g_n_c_wis_strides,
52 const ConvDimsType& c_g_n_k_wos_lengths,
53 const ConvDimsType& c_g_n_k_wos_strides)
54 {
55 const long_index_t a_element_space_size =
56 calculate_element_space_size_impl(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, I1);
57 const long_index_t c_element_space_size =
58 calculate_element_space_size_impl(c_g_n_k_wos_lengths, c_g_n_k_wos_strides, I1);
59 const long_index_t element_space_size = math::max(a_element_space_size * sizeof(ADataType),
60 c_element_space_size * sizeof(CDataType));
61 constexpr long_index_t TwoGB = (long_index_t{1} << 31); // 2GB threshold
62
63 const IndexType N = a_g_n_c_wis_lengths[I1];
64
65 if(element_space_size > TwoGB)
66 {
67 // Minimum divisor of N to not exceed 2GB
68 const auto divisor = math::integer_divide_ceil(element_space_size, TwoGB);
69
70 if(divisor <= static_cast<double>(N))
71 {
72 // Find least divisor of N larger than element_space_size / TwoGB
73 // Iterate up to sqrt(N). There are no divisors above this value.
74 for(IndexType least_divisor = divisor; least_divisor * least_divisor <= N;
75 least_divisor++)
76 {
77 if(N % least_divisor == 0)
78 {
79 return N / least_divisor;
80 }
81 }
82 // Not found, process one Convolution N per block
83 return 1;
84 }
85 else
86 {
87 // Split Convolution's N dimension into N workgroups. However
88 // this still might not result in sufficiently small tensor,
89 // but at least later on we could divide the image as well.
90 return 1;
91 }
92 }
93 else
94 {
95 // Split N is not needed.
96 return N;
97 }
98 }
99
100 public:
// Default constructor with an empty body; it performs no explicit member
// initialization. Callable from both host and device code.
101 __host__ __device__ constexpr TransformConvFwdToGemm() {}
102
103 template <typename TransformConvFwdToGemmBase>
104 __host__ __device__
105 TransformConvFwdToGemm(const TransformConvFwdToGemmBase& transform_conv_fwd_to_gemm_base)
106 : N_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.N_)},
107 Di_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Di_)},
108 Hi_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Hi_)},
109 Wi_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Wi_)},
110 Do_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Do_)},
111 Ho_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Ho_)},
112 Wo_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Wo_)},
113 Z_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Z_)},
114 Y_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Y_)},
115 X_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.X_)},
116 K_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.K_)},
117 C_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.C_)},
118 DiStride_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.DiStride_)},
119 HiStride_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.HiStride_)},
120 WiStride_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.WiStride_)},
121 DoStride_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.DoStride_)},
122 HoStride_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.HoStride_)},
123 WoStride_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.WoStride_)},
124 XStride_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.XStride_)},
125 CStrideTensorA_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.CStrideTensorA_)},
126 CStrideTensorB_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.CStrideTensorB_)},
127 KStrideTensorB_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.KStrideTensorB_)},
128 KStrideTensorC_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.KStrideTensorC_)},
129 NStrideTensorA_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.NStrideTensorA_)},
130 NStrideTensorC_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.NStrideTensorC_)},
131 GStrideTensorA_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.GStrideTensorA_)},
132 GStrideTensorB_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.GStrideTensorB_)},
133 GStrideTensorC_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.GStrideTensorC_)},
134 ConvStrideD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvStrideD_)},
135 ConvStrideH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvStrideH_)},
136 ConvStrideW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvStrideW_)},
137 ConvDilationD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvDilationD_)},
138 ConvDilationH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvDilationH_)},
139 ConvDilationW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvDilationW_)},
140 InLeftPadD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InLeftPadD_)},
141 InLeftPadH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InLeftPadH_)},
142 InLeftPadW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InLeftPadW_)},
143 InRightPadD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadD_)},
144 InRightPadH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadH_)},
145 InRightPadW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadW_)},
146 ZYX_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ZYX_)}
147 {
148 }
149
150 template <typename ConvDimsType,
151 typename ConvSpatialDimsType,
152 index_t NDim = NDimSpatial,
154 __host__ __device__ TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths,
155 const ConvDimsType& a_g_n_c_wis_strides,
156 const ConvDimsType& b_g_k_c_xs_lengths,
157 const ConvDimsType& b_g_k_c_xs_strides,
158 const ConvDimsType& c_g_n_k_wos_lengths,
159 const ConvDimsType& c_g_n_k_wos_strides,
160 const ConvSpatialDimsType& conv_filter_strides,
161 const ConvSpatialDimsType& conv_filter_dilations,
162 const ConvSpatialDimsType& input_left_pads,
163 const ConvSpatialDimsType& input_right_pads)
164 : Di_{I1},
165 Hi_{I1},
166 Wi_{a_g_n_c_wis_lengths[I3]},
167 Do_{I1},
168 Ho_{I1},
169 Wo_{c_g_n_k_wos_lengths[I3]},
170 Z_{I1},
171 Y_{I1},
172 X_{b_g_k_c_xs_lengths[I3]},
173 K_{c_g_n_k_wos_lengths[I2]},
174 C_{b_g_k_c_xs_lengths[I2]},
175 DiStride_{I1},
176 HiStride_{I1},
177 WiStride_{a_g_n_c_wis_strides[I3]},
178 DoStride_{I1},
179 HoStride_{I1},
180 WoStride_{c_g_n_k_wos_strides[I3]},
181 XStride_{b_g_k_c_xs_strides[I3]},
182 CStrideTensorA_{a_g_n_c_wis_strides[I2]},
183 CStrideTensorB_{b_g_k_c_xs_strides[I2]},
184 KStrideTensorB_{b_g_k_c_xs_strides[I1]},
185 KStrideTensorC_{c_g_n_k_wos_strides[I2]},
186 NStrideTensorA_{a_g_n_c_wis_strides[I1]},
187 NStrideTensorC_{c_g_n_k_wos_strides[I1]},
188 GStrideTensorA_{a_g_n_c_wis_strides[I0]},
189 GStrideTensorB_{b_g_k_c_xs_strides[I0]},
190 GStrideTensorC_{c_g_n_k_wos_strides[I0]},
191 ConvStrideD_{I1},
192 ConvStrideH_{I1},
193 ConvStrideW_{conv_filter_strides[I0]},
194 ConvDilationD_{I1},
195 ConvDilationH_{I1},
196 ConvDilationW_{conv_filter_dilations[I0]},
197 InLeftPadD_{I0},
198 InLeftPadH_{I0},
199 InLeftPadW_{input_left_pads[I0]},
200 InRightPadD_{I0},
201 InRightPadH_{I0},
202 InRightPadW_{input_right_pads[I0]},
203 ZYX_{X_}
204 {
205#ifdef CK_CODE_GEN_RTC
208#else
213#endif
214 if constexpr(SplitN)
215 {
216 N_ = GetSplitedNSize(
217 a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides);
218 }
219 else
220 {
221 N_ = c_g_n_k_wos_lengths[I1];
222 }
223 }
224
225 template <typename ConvDimsType,
226 typename ConvSpatialDimsType,
227 index_t NDim = NDimSpatial,
229 __host__ __device__ TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths,
230 const ConvDimsType& a_g_n_c_wis_strides,
231 const ConvDimsType& b_g_k_c_xs_lengths,
232 const ConvDimsType& b_g_k_c_xs_strides,
233 const ConvDimsType& c_g_n_k_wos_lengths,
234 const ConvDimsType& c_g_n_k_wos_strides,
235 const ConvSpatialDimsType& conv_filter_strides,
236 const ConvSpatialDimsType& conv_filter_dilations,
237 const ConvSpatialDimsType& input_left_pads,
238 const ConvSpatialDimsType& input_right_pads)
239 : Di_{I1},
240 Hi_{a_g_n_c_wis_lengths[I3]},
241 Wi_{a_g_n_c_wis_lengths[I4]},
242 Do_{I1},
243 Ho_{c_g_n_k_wos_lengths[I3]},
244 Wo_{c_g_n_k_wos_lengths[I4]},
245 Z_{I1},
246 Y_{b_g_k_c_xs_lengths[I3]},
247 X_{b_g_k_c_xs_lengths[I4]},
248 K_{c_g_n_k_wos_lengths[I2]},
249 C_{b_g_k_c_xs_lengths[I2]},
250 DiStride_{I1},
251 HiStride_{a_g_n_c_wis_strides[I3]},
252 WiStride_{a_g_n_c_wis_strides[I4]},
253 DoStride_{I1},
254 HoStride_{c_g_n_k_wos_strides[I3]},
255 WoStride_{c_g_n_k_wos_strides[I4]},
256 XStride_{b_g_k_c_xs_strides[I4]},
257 CStrideTensorA_{a_g_n_c_wis_strides[I2]},
258 CStrideTensorB_{b_g_k_c_xs_strides[I2]},
259 KStrideTensorB_{b_g_k_c_xs_strides[I1]},
260 KStrideTensorC_{c_g_n_k_wos_strides[I2]},
261 NStrideTensorA_{a_g_n_c_wis_strides[I1]},
262 NStrideTensorC_{c_g_n_k_wos_strides[I1]},
263 GStrideTensorA_{a_g_n_c_wis_strides[I0]},
264 GStrideTensorB_{b_g_k_c_xs_strides[I0]},
265 GStrideTensorC_{c_g_n_k_wos_strides[I0]},
266 ConvStrideD_{I1},
267 ConvStrideH_{conv_filter_strides[I0]},
268 ConvStrideW_{conv_filter_strides[I1]},
269 ConvDilationD_{I1},
270 ConvDilationH_{conv_filter_dilations[I0]},
271 ConvDilationW_{conv_filter_dilations[I1]},
272 InLeftPadD_{I0},
273 InLeftPadH_{input_left_pads[I0]},
274 InLeftPadW_{input_left_pads[I1]},
275 InRightPadD_{I0},
276 InRightPadH_{input_right_pads[I0]},
277 InRightPadW_{input_right_pads[I1]},
278 ZYX_{Y_ * X_}
279 {
280#ifdef CK_CODE_GEN_RTC
283#else
288#endif
289 if constexpr(SplitN)
290 {
291 N_ = GetSplitedNSize(
292 a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides);
293 }
294 else
295 {
296 N_ = c_g_n_k_wos_lengths[I1];
297 }
298 }
299
300 template <typename ConvDimsType,
301 typename ConvSpatialDimsType,
302 index_t NDim = NDimSpatial,
304 __host__ __device__ TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths,
305 const ConvDimsType& a_g_n_c_wis_strides,
306 const ConvDimsType& b_g_k_c_xs_lengths,
307 const ConvDimsType& b_g_k_c_xs_strides,
308 const ConvDimsType& c_g_n_k_wos_lengths,
309 const ConvDimsType& c_g_n_k_wos_strides,
310 const ConvSpatialDimsType& conv_filter_strides,
311 const ConvSpatialDimsType& conv_filter_dilations,
312 const ConvSpatialDimsType& input_left_pads,
313 const ConvSpatialDimsType& input_right_pads)
314 : Di_{a_g_n_c_wis_lengths[I3]},
315 Hi_{a_g_n_c_wis_lengths[I4]},
316 Wi_{a_g_n_c_wis_lengths[I5]},
317 Do_{c_g_n_k_wos_lengths[I3]},
318 Ho_{c_g_n_k_wos_lengths[I4]},
319 Wo_{c_g_n_k_wos_lengths[I5]},
320 Z_{b_g_k_c_xs_lengths[I3]},
321 Y_{b_g_k_c_xs_lengths[I4]},
322 X_{b_g_k_c_xs_lengths[I5]},
323 K_{c_g_n_k_wos_lengths[I2]},
324 C_{b_g_k_c_xs_lengths[I2]},
325 DiStride_{a_g_n_c_wis_strides[I3]},
326 HiStride_{a_g_n_c_wis_strides[I4]},
327 WiStride_{a_g_n_c_wis_strides[I5]},
328 DoStride_{c_g_n_k_wos_strides[I3]},
329 HoStride_{c_g_n_k_wos_strides[I4]},
330 WoStride_{c_g_n_k_wos_strides[I5]},
331 XStride_{b_g_k_c_xs_strides[I5]},
332 CStrideTensorA_{a_g_n_c_wis_strides[I2]},
333 CStrideTensorB_{b_g_k_c_xs_strides[I2]},
334 KStrideTensorB_{b_g_k_c_xs_strides[I1]},
335 KStrideTensorC_{c_g_n_k_wos_strides[I2]},
336 NStrideTensorA_{a_g_n_c_wis_strides[I1]},
337 NStrideTensorC_{c_g_n_k_wos_strides[I1]},
338 GStrideTensorA_{a_g_n_c_wis_strides[I0]},
339 GStrideTensorB_{b_g_k_c_xs_strides[I0]},
340 GStrideTensorC_{c_g_n_k_wos_strides[I0]},
341 ConvStrideD_{conv_filter_strides[I0]},
342 ConvStrideH_{conv_filter_strides[I1]},
343 ConvStrideW_{conv_filter_strides[I2]},
344 ConvDilationD_{conv_filter_dilations[I0]},
345 ConvDilationH_{conv_filter_dilations[I1]},
346 ConvDilationW_{conv_filter_dilations[I2]},
347 InLeftPadD_{input_left_pads[I0]},
348 InLeftPadH_{input_left_pads[I1]},
349 InLeftPadW_{input_left_pads[I2]},
350 InRightPadD_{input_right_pads[I0]},
351 InRightPadH_{input_right_pads[I1]},
352 InRightPadW_{input_right_pads[I2]},
353 ZYX_{Z_ * Y_ * X_}
354 {
355#ifdef CK_CODE_GEN_RTC
358#else
363#endif
364 if constexpr(SplitN)
365 {
366 N_ = GetSplitedNSize(
367 a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides);
368 }
369 else
370 {
371 N_ = c_g_n_k_wos_lengths[I1];
372 }
373 }
374
375 __host__ bool AreDescriptorsSmallerThan2GB() const
376 {
377 constexpr long_index_t TwoGB = (long_index_t{1} << 31);
378
379 const long_index_t in_desc_space_size =
380 I1 + (N_ - I1) * NStrideTensorA_ + (Di_ - I1) * DiStride_ + (Hi_ - I1) * HiStride_ +
381 (Wi_ - I1) * WiStride_ + (C_ - I1) * CStrideTensorA_;
382 const long_index_t out_desc_space_size =
383 I1 + (N_ - I1) * NStrideTensorC_ + (Do_ - I1) * DoStride_ + (Ho_ - I1) * HoStride_ +
384 (Wo_ - I1) * WoStride_ + (K_ - I1) * KStrideTensorC_;
385
386 bool is_a_descriptor_smaller_than_2GB = (in_desc_space_size * sizeof(ADataType)) <= TwoGB;
387 bool is_c_descriptor_smaller_than_2GB = (out_desc_space_size * sizeof(CDataType)) <= TwoGB;
388
389 return is_a_descriptor_smaller_than_2GB && is_c_descriptor_smaller_than_2GB;
390 }
391
392 template <typename DsPointer>
393 __host__ auto SplitConvProblem(const ADataType* a_grid_ptr_base,
394 DsPointer& ds_grid_ptr_base,
395 CDataType* c_grid_ptr_base) const
396 {
397 // Create copies
398 auto conv_to_gemm_transformer_left = *this;
399 auto conv_to_gemm_transformer_right = *this;
400 IndexType a_right_offset = 0;
401 IndexType c_right_offset = 0;
402 // Calculate real filter size
403 const IndexType z_eff = (Z_ - 1) * ConvDilationD_ + 1;
404 const IndexType y_eff = (Y_ - 1) * ConvDilationH_ + 1;
405 const IndexType x_eff = (X_ - 1) * ConvDilationW_ + 1;
406 // Calculate start position in input for right tensor
407 const IndexType di_right_transformer_start_idx = (Do_ / 2) * ConvStrideD_;
408 const IndexType hi_right_transformer_start_idx = (Ho_ / 2) * ConvStrideH_;
409 const IndexType wi_right_transformer_start_idx = (Wo_ / 2) * ConvStrideW_;
410 // Calculate last position in input for left tensor
411 const IndexType di_left_transformer_end_idx = (Do_ / 2 - 1) * ConvStrideD_ + z_eff;
412 const IndexType hi_left_transformer_end_idx = (Ho_ / 2 - 1) * ConvStrideH_ + y_eff;
413 const IndexType wi_left_transformer_end_idx = (Wo_ / 2 - 1) * ConvStrideW_ + x_eff;
414 // Allow to split if whole left padding will be in left tensor and right padding in right
415 // tensor
416 const bool is_possible_to_split_d = Do_ != 1 &&
417 di_right_transformer_start_idx > InLeftPadD_ &&
418 di_left_transformer_end_idx <= (InLeftPadD_ + Di_);
419 const bool is_possible_to_split_h = Ho_ != 1 &&
420 hi_right_transformer_start_idx > InLeftPadH_ &&
421 hi_left_transformer_end_idx <= (InLeftPadH_ + Hi_);
422 const bool is_possible_to_split_w = Wo_ != 1 &&
423 wi_right_transformer_start_idx > InLeftPadW_ &&
424 wi_left_transformer_end_idx <= (InLeftPadW_ + Wi_);
425
426 if(is_possible_to_split_d)
427 {
428 // Apply new sizes
429 // Split output on half
430 conv_to_gemm_transformer_left.Do_ = Do_ / 2;
431 conv_to_gemm_transformer_right.Do_ = Do_ - Do_ / 2;
432 // Assign left padding to left convolution
433 conv_to_gemm_transformer_left.InLeftPadD_ = InLeftPadD_;
434 conv_to_gemm_transformer_right.InLeftPadD_ = 0;
435 // Assign right padding to right convolution
436 conv_to_gemm_transformer_left.InRightPadD_ = 0;
437 conv_to_gemm_transformer_right.InRightPadD_ = InRightPadD_;
438 // Calculate new input size
439 conv_to_gemm_transformer_left.Di_ = di_left_transformer_end_idx - InLeftPadD_;
440 conv_to_gemm_transformer_right.Di_ =
441 math::min(Di_ - (di_right_transformer_start_idx - InLeftPadD_),
442 (conv_to_gemm_transformer_right.Do_ - 1) * ConvStrideD_ + z_eff);
443 ;
444 // Calcualte offsets
445 a_right_offset = ((Do_ / 2) * ConvStrideD_ - InLeftPadD_) * DiStride_;
446 c_right_offset = (Do_ / 2) * DoStride_;
447 }
448 else if(is_possible_to_split_h)
449 {
450 conv_to_gemm_transformer_left.Ho_ = Ho_ / 2;
451 conv_to_gemm_transformer_right.Ho_ = Ho_ - Ho_ / 2;
452
453 conv_to_gemm_transformer_left.InLeftPadH_ = InLeftPadH_;
454 conv_to_gemm_transformer_right.InLeftPadH_ = 0;
455
456 conv_to_gemm_transformer_left.InRightPadH_ = 0;
457 conv_to_gemm_transformer_right.InRightPadH_ = InRightPadH_;
458
459 conv_to_gemm_transformer_left.Hi_ = hi_left_transformer_end_idx - InLeftPadH_;
460 conv_to_gemm_transformer_right.Hi_ =
461 math::min(Hi_ - (hi_right_transformer_start_idx - InLeftPadH_),
462 (conv_to_gemm_transformer_right.Ho_ - 1) * ConvStrideH_ + y_eff);
463 a_right_offset = ((Ho_ / 2) * ConvStrideH_ - InLeftPadH_) * HiStride_;
464 c_right_offset = (Ho_ / 2) * HoStride_;
465 }
466 else if(is_possible_to_split_w)
467 {
468 conv_to_gemm_transformer_left.Wo_ = Wo_ / 2;
469 conv_to_gemm_transformer_right.Wo_ = Wo_ - Wo_ / 2;
470
471 conv_to_gemm_transformer_left.InLeftPadW_ = InLeftPadW_;
472 conv_to_gemm_transformer_right.InLeftPadW_ = 0;
473
474 conv_to_gemm_transformer_left.InRightPadW_ = 0;
475 conv_to_gemm_transformer_right.InRightPadW_ = InRightPadW_;
476
477 conv_to_gemm_transformer_left.Wi_ = wi_left_transformer_end_idx - InLeftPadW_;
478 conv_to_gemm_transformer_right.Wi_ =
479 math::min(Wi_ - (wi_right_transformer_start_idx - InLeftPadW_),
480 (conv_to_gemm_transformer_right.Wo_ - 1) * ConvStrideW_ + x_eff);
481
482 a_right_offset = ((Wo_ / 2) * ConvStrideW_ - InLeftPadW_) * WiStride_;
483 c_right_offset = (Wo_ / 2) * WoStride_;
484 }
485
486 static constexpr index_t NumDTensor = DsPointer::Size();
487 const auto ds_grid_right_ptr = generate_tuple(
488 [&](auto i) { return ds_grid_ptr_base(i) + c_right_offset; }, Number<NumDTensor>{});
489
490 // Return left transform, right transformer, right offset to Input and right offset to
491 // Output
492 return ck::make_tuple(conv_to_gemm_transformer_left,
493 conv_to_gemm_transformer_right,
494 a_grid_ptr_base + a_right_offset,
495 ds_grid_right_ptr,
496 c_grid_ptr_base + c_right_offset);
497 }
498
499 // TODO: implement ck::tensor_layout::convolution that describes packed/strided dimensions as
500 // properties
501 template <typename ALayout,
502 typename ck::enable_if<NDimSpatial == 1 &&
506 bool>::type = false>
507 __host__ __device__ auto MakeADescriptor_M_K() const
508 {
509 if constexpr(ConvForwardSpecialization ==
511 {
512 if constexpr(NumGroupsToMerge == 1)
513 {
514 const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor(
515 make_tuple(N_, Wo_, C_),
518 in_gemmm_gemmk_desc,
523 }
524 else
525 {
526 const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
527 make_tuple(N_, Wo_, NumGroupsToMerge, C_),
529
531 in_gemmm_groups_gemmk_desc,
532 make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)),
536 }
537 }
538 else if constexpr(ConvForwardSpecialization ==
540 {
541 if constexpr(NumGroupsToMerge == 1)
542 {
543
544 const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
546
547 const auto in_n_wip_c_desc = transform_tensor_descriptor(
548 in_n_wi_c_desc,
553
554 const auto in_n_x_wo_c_desc = transform_tensor_descriptor(
555 in_n_wip_c_desc,
561
563 in_n_x_wo_c_desc,
568 }
569 else
570 {
571 const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
572 make_tuple(N_, Wi_, NumGroupsToMerge),
574
575 const auto in_n_wip_c_desc = transform_tensor_descriptor(
576 in_n_wi_c_desc,
579 make_pass_through_transform(NumGroupsToMerge)),
582
583 const auto in_n_x_wo_c_desc = transform_tensor_descriptor(
584 in_n_wip_c_desc,
588 make_pass_through_transform(NumGroupsToMerge)),
591
593 in_n_x_wo_c_desc,
594 make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)),
598 }
599 }
600 else if constexpr(ConvForwardSpecialization ==
602 {
603 if constexpr(NumGroupsToMerge == 1)
604 {
605 const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
606 make_tuple(N_, Wi_, C_),
608
609 const auto in_n_wo_c_desc = transform_tensor_descriptor(
610 in_n_wi_c_desc,
616
618 in_n_wo_c_desc,
623 }
624 else
625 {
626 const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
627 make_tuple(N_, Wi_, NumGroupsToMerge, C_),
629
630 const auto in_n_wo_c_desc = transform_tensor_descriptor(
631 in_n_wi_c_desc,
634 make_pass_through_transform(NumGroupsToMerge),
638
640 in_n_wo_c_desc,
641 make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)),
645 }
646 }
647 else
648 {
649 if constexpr(NumGroupsToMerge == 1)
650 {
651 const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
652 make_tuple(N_, Wi_, C_),
654
655 const auto in_n_wip_c_desc = transform_tensor_descriptor(
656 in_n_wi_c_desc,
662
663 const auto in_n_x_wo_c_desc = transform_tensor_descriptor(
664 in_n_wip_c_desc,
671
673 in_n_x_wo_c_desc,
678 }
679 else
680 {
681 const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
682 make_tuple(N_, Wi_, NumGroupsToMerge, C_),
684
685 const auto in_n_wip_c_desc = transform_tensor_descriptor(
686 in_n_wi_c_desc,
689 make_pass_through_transform(NumGroupsToMerge),
693
694 const auto in_n_x_wo_c_desc = transform_tensor_descriptor(
695 in_n_wip_c_desc,
699 make_pass_through_transform(NumGroupsToMerge),
703
705 in_n_x_wo_c_desc,
706 make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)),
710 }
711 }
712 }
713
714 template <typename ALayout,
715 typename ck::enable_if<NDimSpatial == 2 &&
719 bool>::type = false>
720 __host__ __device__ auto MakeADescriptor_M_K() const
721
722 {
723 if constexpr(ConvForwardSpecialization ==
725 {
726 if constexpr(NumGroupsToMerge == 1)
727 {
728 const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor(
729 make_tuple(N_, Ho_, Wo_, C_),
731
733 in_gemmm_gemmk_desc,
738 }
739 else
740 {
741 const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
742 make_tuple(N_, Ho_, Wo_, NumGroupsToMerge, C_),
745
747 in_gemmm_groups_gemmk_desc,
748 make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)),
752 }
753 }
754 else if constexpr(ConvForwardSpecialization ==
756 {
757 if constexpr(NumGroupsToMerge == 1)
758 {
759 const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
761
762 const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
763 in_n_hi_wi_c_desc,
769
770 const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor(
771 in_n_hip_wip_c_desc,
779
781 in_n_y_ho_x_wo_c_desc,
786 }
787 else
788 {
789 const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor(
790 make_tuple(N_, Hi_, Wi_, NumGroupsToMerge),
792
793 const auto in_n_hip_wip_groups_c_desc = transform_tensor_descriptor(
794 in_n_hi_wi_groups_c_desc,
798 make_pass_through_transform(NumGroupsToMerge)),
801
802 const auto in_n_y_ho_x_wo_groups_c_desc = transform_tensor_descriptor(
803 in_n_hip_wip_groups_c_desc,
809 make_pass_through_transform(NumGroupsToMerge)),
812
814 in_n_y_ho_x_wo_groups_c_desc,
815 make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)),
819 }
820 }
821 else if constexpr(ConvForwardSpecialization ==
823 {
824 if constexpr(NumGroupsToMerge == 1)
825 {
826 const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
827 make_tuple(N_, Hi_, Wi_, C_),
829
830 const auto in_n_ho_wo_c_desc = transform_tensor_descriptor(
831 in_n_hi_wi_c_desc,
838
840 in_n_ho_wo_c_desc,
845 }
846 else
847 {
848 const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor(
849 make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_),
852
853 const auto in_n_ho_wo_groups_c_desc = transform_tensor_descriptor(
854 in_n_hi_wi_groups_c_desc,
858 make_pass_through_transform(NumGroupsToMerge),
864
866 in_n_ho_wo_groups_c_desc,
867 make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)),
871 }
872 }
873 else
874 {
875 if constexpr(NumGroupsToMerge == 1)
876 {
877 const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
878 make_tuple(N_, Hi_, Wi_, C_),
880
881 const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
882 in_n_hi_wi_c_desc,
889
890 const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor(
891 in_n_hip_wip_c_desc,
900
902 in_n_y_ho_x_wo_c_desc,
907 }
908 else
909 {
910
911 const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor(
912 make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_),
915
916 const auto in_n_hip_wip_groups_c_desc = transform_tensor_descriptor(
917 in_n_hi_wi_groups_c_desc,
921 make_pass_through_transform(NumGroupsToMerge),
927
928 const auto in_n_y_ho_x_wo_groups_c_desc = transform_tensor_descriptor(
929 in_n_hip_wip_groups_c_desc,
935 make_pass_through_transform(NumGroupsToMerge),
942 Sequence<5>{},
943 Sequence<6>{}));
944
946 in_n_y_ho_x_wo_groups_c_desc,
947 make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)),
951 }
952 }
953 }
954
955 template <typename ALayout,
956 typename ck::enable_if<
960 bool>::type = false>
961 __host__ __device__ auto MakeADescriptor_M_K() const
962
963 {
964 if constexpr(ConvForwardSpecialization ==
966 {
967 if constexpr(NumGroupsToMerge == 1)
968 {
969 const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor(
970 make_tuple(N_, Do_, Ho_, Wo_, C_),
972
974 in_gemmm_gemmk_desc,
979 }
980 else
981 {
982 const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
983 make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge, C_),
985 DiStride_,
986 HiStride_,
987 WiStride_,
990
992 in_gemmm_groups_gemmk_desc,
994 make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)),
998 }
999 }
1000 else if constexpr(ConvForwardSpecialization ==
1002 {
1003 if constexpr(NumGroupsToMerge == 1)
1004 {
1005 const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
1006 make_tuple(N_, Di_, Hi_, Wi_),
1008
1009 const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
1010 in_n_di_hi_wi_c_desc,
1017
1018 const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
1019 in_n_hip_wip_c_desc,
1028 make_tuple(
1030
1032 in_n_z_do_y_ho_x_wo_c_desc,
1033 make_tuple(
1038 }
1039 else
1040 {
1041 const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
1042 make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge),
1044
1045 const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
1046 in_n_di_hi_wi_c_desc,
1051 make_pass_through_transform(NumGroupsToMerge)),
1052 make_tuple(
1054 make_tuple(
1056
1057 const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
1058 in_n_hip_wip_c_desc,
1066 make_pass_through_transform(NumGroupsToMerge)),
1067 make_tuple(
1073 Sequence<7>{}));
1074
1076 in_n_z_do_y_ho_x_wo_c_desc,
1077 make_tuple(
1078 make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)),
1082 }
1083 }
1084 else if constexpr(ConvForwardSpecialization ==
1086 {
1087 if constexpr(NumGroupsToMerge == 1)
1088 {
1089 const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
1090 make_tuple(N_, Di_, Hi_, Wi_, C_),
1092
1093 const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor(
1094 in_n_di_hi_wi_c_desc,
1100 make_tuple(
1102 make_tuple(
1104
1106 in_n_do_ho_wo_c_desc,
1111 }
1112 else
1113 {
1114 const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
1115 make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge, C_),
1117 DiStride_,
1118 HiStride_,
1119 WiStride_,
1122
1123 const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor(
1124 in_n_di_hi_wi_c_desc,
1129 make_pass_through_transform(NumGroupsToMerge),
1132 Sequence<1>{},
1133 Sequence<2>{},
1134 Sequence<3>{},
1135 Sequence<4>{},
1136 Sequence<5>{}),
1138 Sequence<1>{},
1139 Sequence<2>{},
1140 Sequence<3>{},
1141 Sequence<4>{},
1142 Sequence<5>{}));
1143
1145 in_n_do_ho_wo_c_desc,
1146 make_tuple(
1147 make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)),
1151 }
1152 }
1153 else
1154 {
1155 if constexpr(NumGroupsToMerge == 1)
1156 {
1157 const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
1158 make_tuple(N_, Di_, Hi_, Wi_, C_),
1160
1161 const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
1162 in_n_di_hi_wi_c_desc,
1168 make_tuple(
1170 make_tuple(
1172
1173 const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
1174 in_n_hip_wip_c_desc,
1183 make_tuple(
1189 Sequence<7>{}));
1190
1192 in_n_z_do_y_ho_x_wo_c_desc,
1197 }
1198 else
1199 {
1200 const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
1201 make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge, C_),
1203 DiStride_,
1204 HiStride_,
1205 WiStride_,
1208
1209 const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
1210 in_n_di_hi_wi_c_desc,
1215 make_pass_through_transform(NumGroupsToMerge),
1218 Sequence<1>{},
1219 Sequence<2>{},
1220 Sequence<3>{},
1221 Sequence<4>{},
1222 Sequence<5>{}),
1224 Sequence<1>{},
1225 Sequence<2>{},
1226 Sequence<3>{},
1227 Sequence<4>{},
1228 Sequence<5>{}));
1229
1230 const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
1231 in_n_hip_wip_c_desc,
1239 make_pass_through_transform(NumGroupsToMerge),
1242 Sequence<1>{},
1243 Sequence<2>{},
1244 Sequence<3>{},
1245 Sequence<4>{},
1246 Sequence<5>{}),
1251 Sequence<7>{},
1252 Sequence<8>{}));
1253
1255 in_n_z_do_y_ho_x_wo_c_desc,
1256 make_tuple(
1257 make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)),
1261 }
1262 }
1263 }
1264
1265 template <typename ALayout,
1266 typename ck::enable_if<NDimSpatial == 1 &&
1268 bool>::type = false>
1269 __host__ __device__ auto MakeADescriptor_M_K() const
1270 {
1271 static_assert(NumGroupsToMerge == 1);
1272 static_assert(ConvForwardSpecialization ==
1274
1275 const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor(
1277
1279 in_gemmm_gemmk_desc,
1283 }
1284
1285 template <typename ALayout,
1286 typename ck::enable_if<NDimSpatial == 2 &&
1288 bool>::type = false>
1289 __host__ __device__ auto MakeADescriptor_M_K() const
1290 {
1291 static_assert(NumGroupsToMerge == 1);
1292 static_assert(ConvForwardSpecialization ==
1294
1295 const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor(
1297
1299 in_gemmm_gemmk_desc,
1304 }
1305
1306 template <typename ALayout,
1307 typename ck::enable_if<NDimSpatial == 3 &&
1309 bool>::type = false>
1310 __host__ __device__ auto MakeADescriptor_M_K() const
1311 {
1312 static_assert(NumGroupsToMerge == 1);
1313 static_assert(ConvForwardSpecialization ==
1315
1316 const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor(
1318
1320 in_gemmm_gemmk_desc,
1325 }
1326
1327 template <typename BLayout,
1331 bool>::type = false>
1332 __host__ __device__ auto MakeBDescriptor_N_K() const
1333 {
1334 static_assert(ConvForwardSpecialization ==
1336 ConvForwardSpecialization ==
1338 static_assert(NumGroupsToMerge == 1);
1340 }
1341
// MakeBDescriptor_N_K overload selected by a BLayout condition that sits on
// lines missing from this listing (1343-1345).
// The body picks one of four descriptor constructions at compile time, keyed on
// ConvForwardSpecialization and on whether NumGroupsToMerge == 1.
// NOTE(review): several lines are missing (the embedded numbers skip), including
// the specialization being compared, the strides, and most return statements.
1342 template <typename BLayout,
1346 bool>::type = false>
1347 __host__ __device__ auto MakeBDescriptor_N_K() const
1348 {
// First branch: taken for one specific specialization (value on missing line 1350).
1349 if constexpr(ConvForwardSpecialization ==
1351 {
// Compile-time K-dimension length: Number<3> for NDimSpatial == 1; the
// alternative for other dimensionalities is on the missing line 1355.
1352 using FilterSizeNumType =
1353 ck::conditional_t<NDimSpatial == 1,
1354 Number<3>,
1356
1357 if constexpr(NumGroupsToMerge == 1)
1358 {
// Packed (K_, FilterSize) descriptor; no explicit strides are given.
1359 return make_naive_tensor_descriptor_packed(make_tuple(K_, FilterSizeNumType{}));
1360 }
1361 else
1362 {
1363
// 3-D (K_, groups, FilterSize) descriptor; its strides are on missing lines.
1364 const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor(
1365 make_tuple(K_, NumGroupsToMerge, FilterSizeNumType{}),
// Merge (K_, groups) into a single dimension and pass FilterSize through unchanged.
1368 wei_gemmn_groups_gemmk_desc,
1369 make_tuple(make_merge_transform(make_tuple(K_, NumGroupsToMerge)),
1370 make_pass_through_transform(FilterSizeNumType{})),
1373 }
1374 }
1375 else
1376 {
1377 if constexpr(NumGroupsToMerge == 1)
1378 {
// The statement for this branch (line 1379) is missing from this listing.
1380 }
1381 else
1382 {
// 3-D (K_, groups, ZYX_ * C_) descriptor; its strides are on missing lines.
1383 const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor(
1384 make_tuple(K_, NumGroupsToMerge, ZYX_ * C_),
// Merge (K_, groups) into a single dimension; the second transform is on a
// missing line.
1387 wei_gemmn_groups_gemmk_desc,
1388 make_tuple(make_merge_transform(make_tuple(K_, NumGroupsToMerge)),
1392 }
1393 }
1394 }
1395
// MakeBDescriptor_N_K overload selected by a BLayout condition that sits on
// lines missing from this listing (1398-1403).
// Builds a naive descriptor, transforms it, and returns the transformed
// descriptor.
// NOTE(review): the descriptor lengths/strides (line 1408) and the transform
// list (lines 1412-1414) are not visible here.
1396 template <
1397 typename BLayout,
1404 bool>::type = false>
1405 __host__ __device__ auto MakeBDescriptor_N_K() const
1406 {
// Source descriptor; going by its name it is indexed (k, yx, c) -- TODO confirm
// against the missing argument line.
1407 const auto wei_k_yx_c_desc = make_naive_tensor_descriptor(
1409
// Reshape the source descriptor into the form that is returned.
1410 const auto wei_gemmn_gemmk_desc = transform_tensor_descriptor(
1411 wei_k_yx_c_desc,
1415
1416 return wei_gemmn_gemmk_desc;
1417 }
1418
// One of several MakeCDescriptor_M_N overloads. NDimSp defaults to the class
// template parameter NDimSpatial; the enable_if condition that selects this
// overload is on the missing line 1423.
// The body chooses between two constructions at compile time on CTranspose.
// NOTE(review): the statements of both branches (lines 1429-1430, 1434-1435)
// are missing from this listing.
1419 template <
1420 typename CLayout,
1421 index_t NDimSp = NDimSpatial,
1422
1424 bool>::type = false>
1425 __host__ __device__ auto MakeCDescriptor_M_N() const
1426 {
1427 if constexpr(CTranspose)
1428 {
1431 }
1432 else
1433 {
1436 }
1437 }
1438
// One of several MakeCDescriptor_M_N overloads. NDimSp defaults to the class
// template parameter NDimSpatial; the enable_if condition that selects this
// overload is on the missing line 1443.
// The body chooses between two constructions at compile time on CTranspose.
// NOTE(review): the statements of both branches (lines 1449-1450, 1454-1455)
// are missing from this listing.
1439 template <
1440 typename CLayout,
1441 index_t NDimSp = NDimSpatial,
1442
1444 bool>::type = false>
1445 __host__ __device__ auto MakeCDescriptor_M_N() const
1446 {
1447 if constexpr(CTranspose)
1448 {
1451 }
1452 else
1453 {
1456 }
1457 }
1458
// One of several MakeCDescriptor_M_N overloads. NDimSp defaults to the class
// template parameter NDimSpatial; the enable_if condition that selects this
// overload is on the missing line 1463.
// The body chooses between two constructions at compile time on CTranspose.
// NOTE(review): the statements of both branches (lines 1469-1470, 1474-1475)
// are missing from this listing.
1459 template <
1460 typename CLayout,
1461 index_t NDimSp = NDimSpatial,
1462
1464 bool>::type = false>
1465 __host__ __device__ auto MakeCDescriptor_M_N() const
1466 {
1467 if constexpr(CTranspose)
1468 {
1471 }
1472 else
1473 {
1476 }
1477 }
1478
// MakeCDescriptor_M_N overload enabled when NDimSp == 1 (plus a CLayout
// condition on the missing lines 1483-1485). CTranspose is not supported.
// With NumGroupsToMerge == 1 the descriptor comes from statements on missing
// lines (1493-1494). Otherwise the descriptor is built in four steps:
// naive (N, Wo, groups, K, 1) -> pad the trailing 1 -> xor -> merge to two dims.
// NOTE(review): strides, dimension-id Sequences and the final return are on
// lines missing from this listing.
1479 template <typename CLayout,
1480 index_t NDimSp = NDimSpatial,
1481
1482 typename ck::enable_if<NDimSp == 1 &&
1486 bool>::type = false>
1487 __host__ __device__ auto MakeCDescriptor_M_N() const
1488 {
1489 static_assert(CTranspose == false);
// Product of batch and output width; despite the name there is no Do/Ho factor in 1D.
1490 const IndexType NDoHoWo = N_ * Wo_;
1491 if constexpr(NumGroupsToMerge == 1)
1492 {
1495 }
1496 else
1497 {
// Source descriptor with an extra trailing dimension of length 1.
1498 const auto nhwo_groups_k_1_desc = make_naive_tensor_descriptor(
1499 make_tuple(N_, Wo_, NumGroupsToMerge, K_, 1),
1500 make_tuple(
1502 // Pad the trailing length-1 dimension (left 0, right NumGroupsToMerge - 1) up to NumGroupsToMerge
1503 const auto padded_desc = transform_tensor_descriptor(
1504 nhwo_groups_k_1_desc,
1506 make_pass_through_transform(NumGroupsToMerge),
1508 make_pad_transform(1, 0, NumGroupsToMerge - 1)),
1511 // We need only the matrices on the diagonal. XOR returns 0 for equal
1512 // values, so a matrix that is not on the diagonal will be stored in padding.
1513 // To avoid a modulo after the xor we assume that the number of groups to merge is a power of 2.
1514 static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
1515 NumGroupsToMerge == 8 || NumGroupsToMerge == 16 ||
1516 NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
1517 const auto unmerged_padded_desc = transform_tensor_descriptor(
1518 padded_desc,
1520 make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
1524 // Merge to M = (NDoHoWo, groups) and N = (K_, groups)
1526 unmerged_padded_desc,
1527 make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)),
1528 make_merge_transform(make_tuple(K_, NumGroupsToMerge))),
1531 }
1532 }
1533
// MakeCDescriptor_M_N overload enabled when NDimSp == 2 (plus a CLayout
// condition on the missing lines 1538-1540). CTranspose is not supported.
// With NumGroupsToMerge == 1 the descriptor comes from statements on missing
// lines (1548-1549). Otherwise the descriptor is built in four steps:
// naive (N, Ho, Wo, groups, K, 1) -> pad the trailing 1 -> xor -> merge to two dims.
// NOTE(review): some strides, the dimension-id Sequences and the final return
// are on lines missing from this listing.
1534 template <typename CLayout,
1535 index_t NDimSp = NDimSpatial,
1536
1537 typename ck::enable_if<NDimSp == 2 &&
1541 bool>::type = false>
1542 __host__ __device__ auto MakeCDescriptor_M_N() const
1543 {
1544 static_assert(CTranspose == false);
// Product of batch and the two output spatial lengths.
1545 const IndexType NDoHoWo = N_ * Ho_ * Wo_;
1546 if constexpr(NumGroupsToMerge == 1)
1547 {
1550 }
1551 else
1552 {
// Source descriptor with an extra trailing dimension of length 1; only the
// HoStride_ / WoStride_ entries of the stride tuple are visible here.
1553 const auto nhwo_groups_k_1_desc =
1554 make_naive_tensor_descriptor(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge, K_, 1),
1556 HoStride_,
1557 WoStride_,
1561 // Pad the trailing length-1 dimension (left 0, right NumGroupsToMerge - 1) up to NumGroupsToMerge
1562 const auto padded_desc = transform_tensor_descriptor(
1563 nhwo_groups_k_1_desc,
1565 make_pass_through_transform(NumGroupsToMerge),
1567 make_pad_transform(1, 0, NumGroupsToMerge - 1)),
1570 // We need only the matrices on the diagonal. XOR returns 0 for equal
1571 // values, so a matrix that is not on the diagonal will be stored in padding.
1572 // To avoid a modulo after the xor we assume that the number of groups to merge is a power of 2.
1573 static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
1574 NumGroupsToMerge == 8 || NumGroupsToMerge == 16 ||
1575 NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
1576 const auto unmerged_padded_desc = transform_tensor_descriptor(
1577 padded_desc,
1579 make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
1583 // Merge to M = (NDoHoWo, groups) and N = (K_, groups)
1585 unmerged_padded_desc,
1586 make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)),
1587 make_merge_transform(make_tuple(K_, NumGroupsToMerge))),
1590 }
1591 }
1592
// MakeCDescriptor_M_N overload whose enable_if condition is on the missing
// lines 1596-1598; the body multiplies in Do_, so this is presumably the 3-D
// spatial variant -- TODO confirm. CTranspose is not supported.
// With NumGroupsToMerge == 1 the descriptor comes from statements on missing
// lines (1606-1607). Otherwise the descriptor is built in four steps:
// naive (N, Do, Ho, Wo, groups, K, 1) -> pad the trailing 1 -> xor -> merge to two dims.
// NOTE(review): some strides, the dimension-id Sequences and the final return
// are on lines missing from this listing.
1593 template <typename CLayout,
1594 index_t NDimSp = NDimSpatial,
1595 typename ck::enable_if<
1599 bool>::type = false>
1600 __host__ __device__ auto MakeCDescriptor_M_N() const
1601 {
1602 static_assert(CTranspose == false);
// Product of batch and the three output spatial lengths.
1603 const IndexType NDoHoWo = N_ * Do_ * Ho_ * Wo_;
1604 if constexpr(NumGroupsToMerge == 1)
1605 {
1608 }
1609 else
1610 {
// Source descriptor with an extra trailing dimension of length 1; only the
// DoStride_ / HoStride_ / WoStride_ entries of the stride tuple are visible here.
1611 const auto nhwo_groups_k_1_desc =
1612 make_naive_tensor_descriptor(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge, K_, 1),
1614 DoStride_,
1615 HoStride_,
1616 WoStride_,
1620 // Pad the trailing length-1 dimension (left 0, right NumGroupsToMerge - 1) up to NumGroupsToMerge
1621 const auto padded_desc = transform_tensor_descriptor(
1622 nhwo_groups_k_1_desc,
1624 make_pass_through_transform(NumGroupsToMerge),
1626 make_pad_transform(1, 0, NumGroupsToMerge - 1)),
1629 // We need only the matrices on the diagonal. XOR returns 0 for equal
1630 // values, so a matrix that is not on the diagonal will be stored in padding.
1631 // To avoid a modulo after the xor we assume that the number of groups to merge is a power of 2.
1632 static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
1633 NumGroupsToMerge == 8 || NumGroupsToMerge == 16 ||
1634 NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
1635 const auto unmerged_padded_desc = transform_tensor_descriptor(
1636 padded_desc,
1638 make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
1642 // Merge to M = (NDoHoWo, groups) and N = (K_, groups)
1644 unmerged_padded_desc,
1645 make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)),
1646 make_merge_transform(make_tuple(K_, NumGroupsToMerge))),
1649 }
1650 }
1651
// MakeCDescriptor_M_N overload enabled when NDimSp == 1 (plus a CLayout
// condition on the missing lines 1656-1657). Group merging is not supported.
// Builds n_k_wo_desc and returns one of two transforms of it, chosen at compile
// time on CTranspose.
// NOTE(review): the descriptor arguments (line 1663) and both transform lists
// are on lines missing from this listing.
1652 template <typename CLayout,
1653 index_t NDimSp = NDimSpatial,
1654
1655 typename ck::enable_if<NDimSp == 1 &&
1658 bool>::type = false>
1659 __host__ __device__ auto MakeCDescriptor_M_N() const
1660 {
1661 static_assert(NumGroupsToMerge == 1);
// Going by its name the descriptor is indexed (n, k, wo) -- TODO confirm
// against the missing argument line.
1662 auto n_k_wo_desc = make_naive_tensor_descriptor(
1664 if constexpr(CTranspose)
1665 {
// Transposed variant; the call name and transforms are on missing lines.
1667 n_k_wo_desc,
1672 }
1673 else
1674 {
1675 return transform_tensor_descriptor(n_k_wo_desc,
1680 }
1681 }
1682
// MakeCDescriptor_M_N overload enabled when NDimSp == 2 (plus a CLayout
// condition on the missing lines 1687-1688). Group merging is not supported.
// Builds n_k_howo_desc and hands it to one of two calls, chosen at compile time
// on CTranspose.
// NOTE(review): the descriptor arguments (line 1694), the call names and both
// transform lists are on lines missing from this listing.
1683 template <typename CLayout,
1684 index_t NDimSp = NDimSpatial,
1685
1686 typename ck::enable_if<NDimSp == 2 &&
1689 bool>::type = false>
1690 __host__ __device__ auto MakeCDescriptor_M_N() const
1691 {
1692 static_assert(NumGroupsToMerge == 1);
// Going by its name the descriptor is indexed (n, k, ho*wo) -- TODO confirm
// against the missing argument line.
1693 auto n_k_howo_desc = make_naive_tensor_descriptor(
1695 if constexpr(CTranspose)
1696 {
1698 n_k_howo_desc,
1703 }
1704 else
1705 {
1707 n_k_howo_desc,
1712 }
1713 }
1714
// MakeCDescriptor_M_N overload enabled when NDimSp == 3 (plus a CLayout
// condition on the missing lines 1719-1720). Group merging is not supported.
// Builds n_k_dohowo_desc and hands it to one of two calls, chosen at compile
// time on CTranspose.
// NOTE(review): the descriptor arguments (line 1726), the call names and both
// transform lists are on lines missing from this listing.
1715 template <typename CLayout,
1716 index_t NDimSp = NDimSpatial,
1717
1718 typename ck::enable_if<NDimSp == 3 &&
1721 bool>::type = false>
1722 __host__ __device__ auto MakeCDescriptor_M_N() const
1723 {
1724 static_assert(NumGroupsToMerge == 1);
// Going by its name the descriptor is indexed (n, k, do*ho*wo) -- TODO confirm
// against the missing argument line.
1725 auto n_k_dohowo_desc = make_naive_tensor_descriptor(
1727
1728 if constexpr(CTranspose)
1729 {
1731 n_k_dohowo_desc,
1736 }
1737 else
1738 {
1740 n_k_dohowo_desc,
1745 }
1746 }
// Data members read by the Make*Descriptor_* methods of this struct.
// NOTE(review): declarations on lines 1752-1753 and 1755-1761 are missing from
// this listing; the symbol index at the end of the page places the
// Di/Hi/Wi and Do/Ho/Wo strides, the tensor A/B/C strides, the conv
// strides/dilations and the left/right pads there.
// N_ is the leading length of the C descriptors built in this struct.
1747 IndexType N_;
// Presumably the input spatial lengths (depth, height, width) -- TODO confirm.
1748 IndexType Di_, Hi_, Wi_;
// Output spatial lengths; multiplied with N_ to form NDoHoWo in MakeCDescriptor_M_N.
1749 IndexType Do_, Ho_, Wo_;
// Presumably the filter spatial lengths -- TODO confirm.
1750 IndexType Z_, Y_, X_;
// K_ is a length in the B and C descriptors; C_ is multiplied with ZYX_ to
// form a B descriptor length.
1751 IndexType K_, C_;
1754 IndexType XStride_;
// Used as ZYX_ * C_ in MakeBDescriptor_N_K; presumably Z_ * Y_ * X_ -- TODO confirm.
1762 IndexType ZYX_;
1763};
1764
1765// wrapper class to call member functions on TransformConvToGemm struct at runtime
1766 // TODO: figure out a way to properly pass in layout as an argument
1768{
1770
// Calls MakeCDescriptor_M_N on conv_fwd_to_gemm with a fixed layout chosen by
// NDimSpatial: NHWGK for 2, NDHWGK for 3, NWGK for 1.
// NOTE(review): the signature line (1774) is missing from this listing; the
// symbol index at the end of the page gives it as
// transform_func(TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization> conv_fwd_to_gemm).
// NOTE(review): the dispatch uses a plain `if`, not `if constexpr`, and there is
// no return for any other NDimSpatial value.
1771 template <index_t NDimSpatial,
1772 device::ConvolutionForwardSpecialization ConvForwardSpecialization>
1773 auto
1775 {
1776 if(NDimSpatial == 2)
1777 {
1778 return conv_fwd_to_gemm
1779 .template MakeCDescriptor_M_N<ck::tensor_layout::convolution::NHWGK, 2>();
1780 }
1781 else if(NDimSpatial == 3)
1782 {
1783 return conv_fwd_to_gemm
1784 .template MakeCDescriptor_M_N<tensor_layout::convolution::NDHWGK, 3>();
1785 }
1786 else if(NDimSpatial == 1)
1787 {
1788 return conv_fwd_to_gemm
1789 .template MakeCDescriptor_M_N<tensor_layout::convolution::NWGK, 1>();
1790 }
1791 }
1792};
1793
1794} // namespace tensor_operation
1795} // namespace ck
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr T min(T x)
Definition utility/math.hpp:116
ConvolutionForwardSpecialization
Definition convolution_forward_specialization.hpp:15
@ Filter1x1Stride1Pad0
Definition convolution_forward_specialization.hpp:18
@ Filter3x3
Definition convolution_forward_specialization.hpp:20
@ Filter1x1Pad0
Definition convolution_forward_specialization.hpp:17
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
typename conditional< predicate, X, Y >::type conditional_t
Definition utility/functional.hpp:115
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:19
__host__ __device__ constexpr auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition multi_index_transform_helper.hpp:48
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
std::enable_if< B, T > enable_if
Definition enable_if.hpp:24
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_xor_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:191
Definition utility/sequence.hpp:43
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:25
index_t Di_
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1748
__host__ __device__ auto MakeBDescriptor_N_K() const
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1332
__host__ __device__ constexpr TransformConvFwdToGemm()
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:101
index_t InRightPadH_
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1761
index_t KStrideTensorB_
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1755
index_t Do_
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1749
index_t InRightPadW_
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1761
index_t CStrideTensorA_
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1755
index_t NStrideTensorC_
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1756
index_t ConvDilationH_
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1759
index_t ConvStrideH_
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1758
index_t ZYX_
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1762
index_t GStrideTensorA_
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1757
__host__ bool AreDescriptorsSmallerThan2GB() const
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:375
index_t Y_
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1750
index_t InLeftPadW_
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1760
index_t K_
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1751
index_t GStrideTensorB_
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1757
index_t ConvStrideW_
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1758
index_t C_
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1751
index_t N_
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1747
index_t InLeftPadH_
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1760
index_t DiStride_
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1752
__host__ auto SplitConvProblem(const ADataType *a_grid_ptr_base, DsPointer &ds_grid_ptr_base, CDataType *c_grid_ptr_base) const
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:393
index_t WiStride_
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1752
__host__ __device__ auto MakeADescriptor_M_K() const
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:507
index_t ConvStrideD_
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1758
index_t Z_
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1750
index_t DoStride_
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1753
index_t HiStride_
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1752
index_t KStrideTensorC_
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1755
__host__ __device__ TransformConvFwdToGemm(const ConvDimsType &a_g_n_c_wis_lengths, const ConvDimsType &a_g_n_c_wis_strides, const ConvDimsType &b_g_k_c_xs_lengths, const ConvDimsType &b_g_k_c_xs_strides, const ConvDimsType &c_g_n_k_wos_lengths, const ConvDimsType &c_g_n_k_wos_strides, const ConvSpatialDimsType &conv_filter_strides, const ConvSpatialDimsType &conv_filter_dilations, const ConvSpatialDimsType &input_left_pads, const ConvSpatialDimsType &input_right_pads)
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:154
index_t X_
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1750
index_t NStrideTensorA_
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1756
__host__ __device__ auto MakeCDescriptor_M_N() const
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1425
index_t Wi_
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1748
index_t ConvDilationD_
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1759
index_t Hi_
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1748
index_t Ho_
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1749
__host__ __device__ TransformConvFwdToGemm(const TransformConvFwdToGemmBase &transform_conv_fwd_to_gemm_base)
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:105
index_t HoStride_
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1753
index_t Wo_
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1749
index_t GStrideTensorC_
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1757
index_t WoStride_
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1753
index_t ConvDilationW_
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1759
index_t InRightPadD_
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1761
index_t CStrideTensorB_
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1755
index_t XStride_
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1754
index_t InLeftPadD_
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1760
TransformConv()
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1769
auto transform_func(TransformConvFwdToGemm< NDimSpatial, ConvForwardSpecialization > conv_fwd_to_gemm)
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1774