device_avgpool3d_bwd_ndhwc_ndhwc.hpp Source File

device_avgpool3d_bwd_ndhwc_ndhwc.hpp Source File

Composable Kernel: device_avgpool3d_bwd_ndhwc_ndhwc.hpp Source File
device_avgpool3d_bwd_ndhwc_ndhwc.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
17
18namespace ck {
19namespace tensor_operation {
20namespace device {
21
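22// In and Din = [N, C, Di, Hi, Wi] (index order of the length/stride vectors)
23// Out and Dout = [N, C, Do, Ho, Wo] (index order of the length/stride vectors)
24// Out = AvgPoolFwd(In)
25// Din = AvgPoolBwd(Dout); the tensors themselves are expected in NDHWC layout
26// Pooled dimensions = D, H, W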
27template <typename DOutDataType,
28 typename DInDataType,
29 typename ComputeDataType,
30 ck::index_t BlockSize,
31 ck::index_t MThreadClusterSize,
32 ck::index_t KThreadClusterSize,
33 ck::index_t MThreadSliceSize,
34 ck::index_t KThreadSliceSize,
35 ck::index_t InSrcOutDstVectorSize>
37 DOutDataType,
38 DInDataType,
39 tensor_layout::convolution::NDHWC,
40 tensor_layout::convolution::NDHWC>
41{
42 static constexpr ck::index_t NDimSpatial = 3;
43
44 static constexpr auto I0 = Number<0>{};
45 static constexpr auto I1 = Number<1>{};
46
47 static constexpr ck::index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
48 static constexpr ck::index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
49
// Builds the grid descriptors for one tilde slice (i_ztilde, i_ytilde, i_xtilde) of the
// average-pool backward problem, expressed as a 2D reduction:
//   - dout viewed as [ReduceM, ReduceK], where ReduceM merges
//     (N, DTildeSlice, HTildeSlice, WTildeSlice, C) and ReduceK merges
//     (ZDotSlice, YDotSlice, XDotSlice);
//   - din viewed as [ReduceM] with the same M merge.
// MPad/KPad are the amounts needed to round M/K up to M_BlockTileSize/K_BlockTileSize;
// presumably they are applied by the final transforms (not fully visible here) — TODO confirm.
// Length/stride vectors are ordered N, C, D, H, W; window/pad vectors are D, H, W;
// `tildes` holds (i_ztilde, i_ytilde, i_xtilde).
// Returns make_tuple(dout descriptor [M, K], din descriptor [M]).
50 static auto
51 Make3DGridDescriptor_Out_M_K_In_M(const std::vector<ck::index_t>& dout_n_c_wos_lengths,
52 const std::vector<ck::index_t>& din_n_c_wos_length,
53 const std::vector<ck::index_t>& dout_n_c_wos_strides,
54 const std::vector<ck::index_t>& din_n_c_wos_strides,
55 const std::vector<ck::index_t>& window_lengths,
56 const std::vector<ck::index_t>& window_strides,
57 const std::vector<ck::index_t>& window_dilations,
58 const std::vector<ck::index_t>& input_left_pads,
59 const std::vector<ck::index_t>& input_right_pads,
60 const std::vector<ck::index_t>& tildes)
61 {
62 index_t i_ztilde = tildes[0];
63 index_t i_ytilde = tildes[1];
64 index_t i_xtilde = tildes[2];
65
66 const index_t N = dout_n_c_wos_lengths[0];
67 const index_t C = dout_n_c_wos_lengths[1];
68
69 const index_t Di = din_n_c_wos_length[2];
70 const index_t Hi = din_n_c_wos_length[3];
71 const index_t Wi = din_n_c_wos_length[4];
72
73 const index_t Do = dout_n_c_wos_lengths[2];
74 const index_t Ho = dout_n_c_wos_lengths[3];
75 const index_t Wo = dout_n_c_wos_lengths[4];
76
77 const index_t Z = window_lengths[0];
78 const index_t Y = window_lengths[1];
79 const index_t X = window_lengths[2];
80
81 const index_t InLeftPadD = input_left_pads[0];
82 const index_t InLeftPadH = input_left_pads[1];
83 const index_t InLeftPadW = input_left_pads[2];
84
85 const index_t InRightPadD = input_right_pads[0];
86 const index_t InRightPadH = input_right_pads[1];
87 const index_t InRightPadW = input_right_pads[2];
88
89 const index_t ConvStrideD = window_strides[0];
90 const index_t ConvStrideH = window_strides[1];
91 const index_t ConvStrideW = window_strides[2];
92
93 const index_t ConvDilationD = window_dilations[0];
94 const index_t ConvDilationH = window_dilations[1];
95 const index_t ConvDilationW = window_dilations[2];
96
// dout viewed as (N, Do, Ho, Wo, C): the caller's N, C, D, H, W strides are reordered.
97 const auto out_n_do_ho_wo_c_grid_desc =
99 make_tuple(dout_n_c_wos_strides[0],
100 dout_n_c_wos_strides[2],
101 dout_n_c_wos_strides[3],
102 dout_n_c_wos_strides[4],
103 dout_n_c_wos_strides[1]));
104
// Number of tilde slices per spatial dim: stride / gcd(stride, dilation). Slice
// i_ztilde covers the din positions i_ztilde * DilationD + dtilde * StrideD
// (see the embed + freeze transforms on din below); likewise for H and W.
105 const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
106 const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
107 const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
108
109 const auto ZTilde = ConvStrideD / GcdStrideDilationD;
110 const auto YTilde = ConvStrideH / GcdStrideDilationH;
111 const auto XTilde = ConvStrideW / GcdStrideDilationW;
112
113 const auto ZDot = math::integer_divide_ceil(Z, ZTilde);
114 const auto YDot = math::integer_divide_ceil(Y, YTilde);
115 const auto XDot = math::integer_divide_ceil(X, XTilde);
116
117 const auto DTilde = Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD);
118 const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
119 const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
120
121 // only work on Tildes that contribute to non-padding area of input tensor
122 const auto IDTildeSliceBegin = math::integer_divide_floor(
123 math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD);
124 const auto IHTildeSliceBegin = math::integer_divide_floor(
125 math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
126 const auto IWTildeSliceBegin = math::integer_divide_floor(
127 math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
128
129 const auto IDTildeSliceEnd =
130 math::min(DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1);
131 const auto IHTildeSliceEnd =
132 math::min(HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
133 const auto IWTildeSliceEnd =
134 math::min(WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
135
136 const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin;
137 const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
138 const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
139
140 // ReduceK is different for each Reduce
141 const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde);
142 const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
143 const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
144
145 // Problem size of reduction kernel
146 const index_t MRaw = N * DTildeSlice * HTildeSlice * WTildeSlice * C;
147 const index_t MPad = math::integer_least_multiple(MRaw, M_BlockTileSize) - MRaw;
148
149 const index_t KRaw = ZDotSlice * YDotSlice * XDotSlice;
150 const index_t KPad = math::integer_least_multiple(KRaw, K_BlockTileSize) - KRaw;
151
152 // Out[ReduceM, ReduceK]
153 const auto out_n_dop_hop_wop_c_grid_desc = transform_tensor_descriptor(
154 out_n_do_ho_wo_c_grid_desc,
162
163 const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_c_grid_desc =
165 out_n_dop_hop_wop_c_grid_desc,
168 make_embed_transform(make_tuple(ZDot, DTilde),
169 make_tuple(-ConvDilationD / GcdStrideDilationD, I1)),
170 make_embed_transform(make_tuple(YDot, HTilde),
171 make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
172 make_embed_transform(make_tuple(XDot, WTilde),
173 make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
181 Sequence<7>{}));
182
183 const auto
184 out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_c_grid_desc =
186 out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_c_grid_desc,
188 make_slice_transform(ZDot, I0, ZDotSlice),
189 make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
190 make_slice_transform(YDot, I0, YDotSlice),
191 make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
192 make_slice_transform(XDot, I0, XDotSlice),
193 make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
196 Sequence<1>{},
197 Sequence<2>{},
198 Sequence<3>{},
199 Sequence<4>{},
200 Sequence<5>{},
201 Sequence<6>{},
202 Sequence<7>{}),
204 Sequence<1>{},
205 Sequence<2>{},
206 Sequence<3>{},
207 Sequence<4>{},
208 Sequence<5>{},
209 Sequence<6>{},
210 Sequence<7>{}));
211
// C is the last (fastest) component of the M merge; the Z/Y/X dots form K.
212 const auto out_grid_desc_reducemraw_reducekraw = transform_tensor_descriptor(
213 out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_c_grid_desc,
215 make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice, C)),
216 make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice))),
219
220 const auto out_grid_desc_reducem_reducek = transform_tensor_descriptor(
221 out_grid_desc_reducemraw_reducekraw,
225
226 // In[ReduceM]
// din viewed as (N, Di, Hi, Wi, C): the caller's N, C, D, H, W strides are reordered.
227 const auto in_n_di_hi_wi_c_grid_desc =
229 make_tuple(din_n_c_wos_strides[0],
230 din_n_c_wos_strides[2],
231 din_n_c_wos_strides[3],
232 din_n_c_wos_strides[4],
233 din_n_c_wos_strides[1]));
234
235 const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
236 in_n_di_hi_wi_c_grid_desc,
238 make_pad_transform(Di, InLeftPadD, InRightPadD),
239 make_pad_transform(Hi, InLeftPadH, InRightPadH),
240 make_pad_transform(Wi, InLeftPadW, InRightPadW),
244
245 const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc =
247 in_n_dip_hip_wip_c_grid_desc,
// Upper lengths are (ZTilde, DTilde): the first index is the D-direction tilde,
// frozen to i_ztilde below (this previously read XTilde by mistake).
249 make_embed_transform(make_tuple(ZTilde, DTilde),
250 make_tuple(ConvDilationD, ConvStrideD)),
251 make_embed_transform(make_tuple(YTilde, HTilde),
252 make_tuple(ConvDilationH, ConvStrideH)),
253 make_embed_transform(make_tuple(XTilde, WTilde),
254 make_tuple(ConvDilationW, ConvStrideW)),
262 Sequence<7>{}));
263
// Fix this slice's tilde indices and keep only the contributing D/H/W tilde ranges.
264 const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc =
266 in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc,
268 make_freeze_transform(i_ztilde),
269 make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
270 make_freeze_transform(i_ytilde),
271 make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
272 make_freeze_transform(i_xtilde),
273 make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
276 Sequence<1>{},
277 Sequence<2>{},
278 Sequence<3>{},
279 Sequence<4>{},
280 Sequence<5>{},
281 Sequence<6>{},
282 Sequence<7>{}),
284 Sequence<>{},
285 Sequence<1>{},
286 Sequence<>{},
287 Sequence<2>{},
288 Sequence<>{},
289 Sequence<3>{},
290 Sequence<4>{}));
291
292 const auto in_grid_desc_reducemraw = transform_tensor_descriptor(
293 in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc,
295 make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice, C))),
298
299 const auto in_grid_desc_reducem =
300 transform_tensor_descriptor(in_grid_desc_reducemraw,
304
305 return make_tuple(out_grid_desc_reducem_reducek, in_grid_desc_reducem);
306 }
307
308 using DoutDinGridDesc = decltype(Make3DGridDescriptor_Out_M_K_In_M({0, 0, 0, 0, 0},
309 {0, 0, 0, 0, 0},
310 {0, 0, 0, 0, 0},
311 {0, 0, 0, 0, 0},
312 {0, 0, 0},
313 {0, 0, 0},
314 {0, 0, 0},
315 {0, 0, 0},
316 {0, 0, 0},
317 {0, 0, 0}));
318
321
322 // FIXME
323 // for NDHWC, the dim C is the fastest dimension, and is not reduced.
324 // Hence, it is in M dimension for reduction kernel.
325 static constexpr index_t OutSrcInDstVectorDim = 0; // 0: M, 1: K
326
329
331 DInDataType,
332 ComputeDataType,
333 int,
338 Div,
340 false, // propagate_nan
341 BlockSize,
342 MThreadSliceSize,
343 KThreadSliceSize,
345 InSrcOutDstVectorSize,
346 InSrcOutDstVectorSize>;
347
348 struct Argument : public BaseArgument
349 {
350 Argument(const DOutDataType* p_dout,
351 DInDataType* p_din,
352 std::vector<ck::index_t> dout_n_c_wos_lengths,
353 std::vector<ck::index_t> din_n_c_wos_length,
354 std::vector<ck::index_t> dout_n_c_wos_strides,
355 std::vector<ck::index_t> din_n_c_wos_strides,
356 std::vector<ck::index_t> window_lengths,
357 std::vector<ck::index_t> window_strides,
358 std::vector<ck::index_t> window_dilations,
359 std::vector<ck::index_t> input_left_pads,
360 std::vector<ck::index_t> input_right_pads)
361 : p_dout_grid_{p_dout},
362 p_din_grid_{p_din},
363 dout_n_c_wos_lengths_{dout_n_c_wos_lengths},
364 din_n_c_wos_length_{din_n_c_wos_length},
365 dout_n_c_wos_strides_{dout_n_c_wos_strides},
366 din_n_c_wos_strides_{din_n_c_wos_strides},
367 num_reduce_{1},
368 div_element_op_{window_lengths[0] * window_lengths[1] * window_lengths[2]}
369 {
370 std::vector<ck::index_t> Tildes(NDimSpatial);
371 for(int i = 0; i < NDimSpatial; ++i)
372 {
373 int GcdStrideDilation = math::gcd(window_strides[i], window_dilations[i]);
374 Tildes[i] = window_strides[i] / GcdStrideDilation;
375 num_reduce_ *= Tildes[i];
376 }
377
378 for(index_t i_ztilde = 0; i_ztilde < Tildes[0]; ++i_ztilde)
379 {
380 for(index_t i_ytilde = 0; i_ytilde < Tildes[1]; ++i_ytilde)
381 {
382 for(index_t i_xtilde = 0; i_xtilde < Tildes[2]; ++i_xtilde)
383 {
384 // check slice is valid
385 const auto ZDotSlice =
386 math::integer_divide_ceil(window_lengths[0] - i_ztilde, Tildes[0]);
387 const auto YDotSlice =
388 math::integer_divide_ceil(window_lengths[1] - i_ytilde, Tildes[1]);
389 const auto XDotSlice =
390 math::integer_divide_ceil(window_lengths[2] - i_xtilde, Tildes[2]);
391
392 if(ZDotSlice * YDotSlice * XDotSlice <= 0)
393 {
394 continue;
395 }
396
397 const auto dout_din_grid_desc =
398 Make3DGridDescriptor_Out_M_K_In_M(dout_n_c_wos_lengths,
399 din_n_c_wos_length,
400 dout_n_c_wos_strides,
401 din_n_c_wos_strides,
402 window_lengths,
403 window_strides,
404 window_dilations,
405 input_left_pads,
406 input_right_pads,
407 {i_ztilde, i_ytilde, i_xtilde});
408
409 dout_grid_desc_m_k_container_.push_back(dout_din_grid_desc[I0]);
410 din_grid_desc_m_container_.push_back(dout_din_grid_desc[I1]);
411 }
412 }
413 }
414 }
415
416 const DOutDataType* p_dout_grid_;
417 DInDataType* p_din_grid_;
418 std::vector<ck::index_t> dout_n_c_wos_lengths_;
419 std::vector<ck::index_t> din_n_c_wos_length_;
420 std::vector<ck::index_t> dout_n_c_wos_strides_;
421 std::vector<ck::index_t> din_n_c_wos_strides_;
422
424 std::vector<DoutGridDesc_M_K> dout_grid_desc_m_k_container_;
425 std::vector<DinGridDesc_M> din_grid_desc_m_container_;
426
428 };
429
430 struct Invoker : public BaseInvoker
431 {
// Launches one threadwise-reduction kernel per (dout, din) descriptor pair stored in arg
// and returns the sum of the times reported by launch_and_time_kernel.
432 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
433 {
434 float ave_time = 0;
435
// Bound the loop by the descriptor container itself rather than num_reduce_, so the
// indexing below can never run past the descriptors that were actually built.
436 for(std::size_t i = 0; i < arg.dout_grid_desc_m_k_container_.size(); ++i)
437 {
438 const auto kernel = kernel_reduce_threadwise<gridwise_reduce,
439 false,
440 false,
441 false, // don't have index input
442 DOutDataType,
443 DInDataType,
444 ComputeDataType,
445 int,
449 Div>;
450
451 ck::index_t M = arg.dout_grid_desc_m_k_container_[i].GetLength(I0);
// NOTE(review): exact only if M is already padded to a multiple of M_BlockTileSize
// (MPad in Make3DGridDescriptor_Out_M_K_In_M) — confirm.
452 const index_t grid_size = (M / M_BlockTileSize);
453
454 ave_time += launch_and_time_kernel(stream_config,
455 kernel,
456 dim3(grid_size),
457 dim3(BlockSize),
458 0,
461 PassThrough{},
462 arg.div_element_op_,
463 float(1),
464 arg.p_dout_grid_,
465 nullptr,
466 float(0),
467 arg.p_din_grid_,
468 nullptr);
469 }
470
471 return ave_time;
472 }
473
474 float Run(const BaseArgument* p_arg,
475 const StreamConfig& stream_config = StreamConfig{}) override
476 {
477 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
478 }
479 };
480
481 static bool IsSupportedArgument(const Argument& arg)
482 {
483 constexpr index_t Rank = NDimSpatial + 2;
484 int doutFastestDim = -1;
485 int dinFastestDim = -1;
486
487 for(int i = 0; i < Rank; ++i)
488 {
489 if(arg.dout_n_c_wos_strides_[i] == 1)
490 doutFastestDim = i;
491 if(arg.din_n_c_wos_strides_[i] == 1)
492 dinFastestDim = i;
493 }
494
495 if(doutFastestDim == -1 || dinFastestDim == -1)
496 {
497 if constexpr(InSrcOutDstVectorSize != 1)
498 return false;
499 }
500 else
501 {
502 if(arg.dout_n_c_wos_lengths_[doutFastestDim] % InSrcOutDstVectorSize != 0)
503 return false;
504 if(arg.din_n_c_wos_length_[dinFastestDim] % InSrcOutDstVectorSize != 0)
505 return false;
506 }
507
508 return true;
509 }
510
511 bool IsSupportedArgument(const BaseArgument* p_arg) override
512 {
513 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
514 }
515
516 std::unique_ptr<BaseArgument>
517 MakeArgumentPointer(const void* p_dout,
518 void* p_din,
519 std::vector<ck::index_t> dout_n_c_wos_lengths,
520 std::vector<ck::index_t> din_n_c_wos_length,
521 std::vector<ck::index_t> dout_n_c_wos_strides,
522 std::vector<ck::index_t> din_n_c_wos_strides,
523 std::vector<ck::index_t> window_lengths,
524 std::vector<ck::index_t> window_strides,
525 std::vector<ck::index_t> window_dilations,
526 std::vector<ck::index_t> input_left_pads,
527 std::vector<ck::index_t> input_right_pads) override
528 {
529 constexpr index_t Rank = NDimSpatial + 2;
530
531 if(dout_n_c_wos_strides.size() != Rank || din_n_c_wos_strides.size() != Rank ||
532 dout_n_c_wos_lengths.size() != Rank || din_n_c_wos_length.size() != Rank)
533 throw std::runtime_error("dimension is incorrect");
534
535 if(window_lengths.size() != NDimSpatial || window_strides.size() != NDimSpatial ||
536 window_dilations.size() != NDimSpatial || input_left_pads.size() != NDimSpatial ||
537 input_right_pads.size() != NDimSpatial)
538 throw std::runtime_error("dimension is incorrect");
539
540 return std::make_unique<Argument>(static_cast<const DOutDataType*>(p_dout),
541 static_cast<DInDataType*>(p_din),
542 dout_n_c_wos_lengths,
543 din_n_c_wos_length,
544 dout_n_c_wos_strides,
545 din_n_c_wos_strides,
546 window_lengths,
547 window_strides,
548 window_dilations,
549 input_left_pads,
550 input_right_pads);
551 }
552
553 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
554 {
555 return std::make_unique<Invoker>(Invoker{});
556 }
557
558 std::string GetTypeString() const override
559 {
560 auto str = std::stringstream();
561
562 // clang-format off
563 str << "DeviceAvgPool3dBwd<" << BlockSize << ",";
564 str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
565 str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
566 str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">";
567 // clang-format on
568
569 return str.str();
570 }
571};
572
573} // namespace device
574} // namespace tensor_operation
575} // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr index_t gcd(index_t x, index_t y)
Definition utility/math.hpp:154
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
Definition utility/math.hpp:66
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr T min(T x)
Definition utility/math.hpp:116
Definition convolution_backward_data_specialization.hpp:8
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
__host__ __device__ constexpr auto make_slice_transform(const LowLength &low_length, const SliceBegin &slice_begin, const SliceEnd &slice_end)
Definition multi_index_transform_helper.hpp:163
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:19
__host__ __device__ constexpr auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition multi_index_transform_helper.hpp:48
__global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k, const OutGridDesc_M out_grid_desc_m, const InElementwiseOperation in_elementwise_op, const AccElementwiseOperation acc_elementwise_op, AccDataType alpha, const InDataType *const __restrict__ p_in_value_global, const IndexDataType *const __restrict__ p_in_index_global, AccDataType beta, OutDataType *const __restrict__ p_out_value_global, IndexDataType *const __restrict__ p_out_index_global)
Definition gridwise_2d_reduction_threadwise.hpp:28
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
Definition ck/stream_config.hpp:10
Definition gridwise_2d_reduction_threadwise.hpp:84
Definition utility/sequence.hpp:43
Definition reduction_operator.hpp:37
Definition device_base.hpp:197
Definition device_avgpool3d_bwd_ndhwc_ndhwc.hpp:349
int num_reduce_
Definition device_avgpool3d_bwd_ndhwc_ndhwc.hpp:423
Div div_element_op_
Definition device_avgpool3d_bwd_ndhwc_ndhwc.hpp:427
std::vector< ck::index_t > dout_n_c_wos_strides_
Definition device_avgpool3d_bwd_ndhwc_ndhwc.hpp:420
const DOutDataType * p_dout_grid_
Definition device_avgpool3d_bwd_ndhwc_ndhwc.hpp:416
std::vector< ck::index_t > din_n_c_wos_length_
Definition device_avgpool3d_bwd_ndhwc_ndhwc.hpp:419
DInDataType * p_din_grid_
Definition device_avgpool3d_bwd_ndhwc_ndhwc.hpp:417
Argument(const DOutDataType *p_dout, DInDataType *p_din, std::vector< ck::index_t > dout_n_c_wos_lengths, std::vector< ck::index_t > din_n_c_wos_length, std::vector< ck::index_t > dout_n_c_wos_strides, std::vector< ck::index_t > din_n_c_wos_strides, std::vector< ck::index_t > window_lengths, std::vector< ck::index_t > window_strides, std::vector< ck::index_t > window_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads)
Definition device_avgpool3d_bwd_ndhwc_ndhwc.hpp:350
std::vector< ck::index_t > dout_n_c_wos_lengths_
Definition device_avgpool3d_bwd_ndhwc_ndhwc.hpp:418
std::vector< DinGridDesc_M > din_grid_desc_m_container_
Definition device_avgpool3d_bwd_ndhwc_ndhwc.hpp:425
std::vector< ck::index_t > din_n_c_wos_strides_
Definition device_avgpool3d_bwd_ndhwc_ndhwc.hpp:421
std::vector< DoutGridDesc_M_K > dout_grid_desc_m_k_container_
Definition device_avgpool3d_bwd_ndhwc_ndhwc.hpp:424
Definition device_avgpool3d_bwd_ndhwc_ndhwc.hpp:431
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_avgpool3d_bwd_ndhwc_ndhwc.hpp:474
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_avgpool3d_bwd_ndhwc_ndhwc.hpp:432
Definition device_avgpool3d_bwd_ndhwc_ndhwc.hpp:41
static constexpr auto I0
Definition device_avgpool3d_bwd_ndhwc_ndhwc.hpp:44
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_dout, void *p_din, std::vector< ck::index_t > dout_n_c_wos_lengths, std::vector< ck::index_t > din_n_c_wos_length, std::vector< ck::index_t > dout_n_c_wos_strides, std::vector< ck::index_t > din_n_c_wos_strides, std::vector< ck::index_t > window_lengths, std::vector< ck::index_t > window_strides, std::vector< ck::index_t > window_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads) override
Definition device_avgpool3d_bwd_ndhwc_ndhwc.hpp:517
tensor_operation::element_wise::UnaryDivide Div
Definition device_avgpool3d_bwd_ndhwc_ndhwc.hpp:328
GridwiseReduction_mk_to_m_threadwise< DOutDataType, DInDataType, ComputeDataType, int, DoutGridDesc_M_K, DinGridDesc_M, reduce::Add, PassThrough, Div, InMemoryDataOperationEnum::Set, false, BlockSize, MThreadSliceSize, KThreadSliceSize, OutSrcInDstVectorDim, InSrcOutDstVectorSize, InSrcOutDstVectorSize > gridwise_reduce
Definition device_avgpool3d_bwd_ndhwc_ndhwc.hpp:330
static auto Make3DGridDescriptor_Out_M_K_In_M(const std::vector< ck::index_t > &dout_n_c_wos_lengths, const std::vector< ck::index_t > &din_n_c_wos_length, const std::vector< ck::index_t > &dout_n_c_wos_strides, const std::vector< ck::index_t > &din_n_c_wos_strides, const std::vector< ck::index_t > &window_lengths, const std::vector< ck::index_t > &window_strides, const std::vector< ck::index_t > &window_dilations, const std::vector< ck::index_t > &input_left_pads, const std::vector< ck::index_t > &input_right_pads, const std::vector< ck::index_t > &tildes)
Definition device_avgpool3d_bwd_ndhwc_ndhwc.hpp:51
decltype(Make3DGridDescriptor_Out_M_K_In_M({0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0}, {0, 0, 0}, {0, 0, 0}, {0, 0, 0}, {0, 0, 0}, {0, 0, 0})) DoutDinGridDesc
Definition device_avgpool3d_bwd_ndhwc_ndhwc.hpp:308
remove_cvref_t< tuple_element_t< 0, DoutDinGridDesc > > DoutGridDesc_M_K
Definition device_avgpool3d_bwd_ndhwc_ndhwc.hpp:319
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_avgpool3d_bwd_ndhwc_ndhwc.hpp:511
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_avgpool3d_bwd_ndhwc_ndhwc.hpp:553
static constexpr ck::index_t M_BlockTileSize
Definition device_avgpool3d_bwd_ndhwc_ndhwc.hpp:47
tensor_operation::element_wise::PassThrough PassThrough
Definition device_avgpool3d_bwd_ndhwc_ndhwc.hpp:327
std::string GetTypeString() const override
Definition device_avgpool3d_bwd_ndhwc_ndhwc.hpp:558
static constexpr ck::index_t K_BlockTileSize
Definition device_avgpool3d_bwd_ndhwc_ndhwc.hpp:48
static bool IsSupportedArgument(const Argument &arg)
Definition device_avgpool3d_bwd_ndhwc_ndhwc.hpp:481
static constexpr index_t OutSrcInDstVectorDim
Definition device_avgpool3d_bwd_ndhwc_ndhwc.hpp:325
remove_cvref_t< tuple_element_t< 1, DoutDinGridDesc > > DinGridDesc_M
Definition device_avgpool3d_bwd_ndhwc_ndhwc.hpp:320
static constexpr auto I1
Definition device_avgpool3d_bwd_ndhwc_ndhwc.hpp:45
static constexpr ck::index_t NDimSpatial
Definition device_avgpool3d_bwd_ndhwc_ndhwc.hpp:42
Definition device_avgpool_bwd.hpp:20
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:701