pool_kernel.hpp Source File

pool_kernel.hpp Source File

Composable Kernel: pool_kernel.hpp Source File
pool_kernel.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
9#include <type_traits>
10
11namespace ck_tile {
12
14template <typename TensorShape, typename WindowShape>
16{
17
18 CK_TILE_HOST PoolHostArgs(const void* input_ptr_,
19 void* output_ptr_,
20 void* output_index_ptr_,
21 TensorShape input_shape_,
22 TensorShape output_shape_,
23 TensorShape input_strides_,
24 TensorShape output_strides_,
25 WindowShape window_lengths_,
26 WindowShape window_strides_,
27 WindowShape window_dilations_,
28 WindowShape input_left_pads_,
29 WindowShape input_right_pads_)
30 : input_ptr(input_ptr_),
31 output_ptr(output_ptr_),
32 output_index_ptr(output_index_ptr_),
33 input_shape(input_shape_),
34 output_shape(output_shape_),
35 input_strides(input_strides_),
36 output_strides(output_strides_),
37 window_lengths(window_lengths_),
38 window_strides(window_strides_),
39 window_dilations(window_dilations_),
40 input_left_pads(input_left_pads_),
41 input_right_pads(input_right_pads_)
42 {
43 }
44
45 const void* input_ptr;
48
49 TensorShape input_shape;
50 TensorShape output_shape;
51 TensorShape input_strides;
52 TensorShape output_strides;
53 WindowShape window_lengths;
54 WindowShape window_strides;
55 WindowShape window_dilations;
56 WindowShape input_left_pads;
57 WindowShape input_right_pads;
58};
59
61template <typename TensorShape, typename WindowShape>
63{
64 const void* input_ptr;
67 TensorShape input_shape;
68 TensorShape output_shape;
69 TensorShape input_strides;
70 TensorShape output_strides;
71 WindowShape window_lengths;
72 WindowShape window_strides;
73 WindowShape window_dilations;
74 WindowShape input_left_pads;
75 WindowShape input_right_pads;
76};
77
78template <typename Problem_, typename Policy_ = PoolDefaultPolicy>
80{
83
88
89 static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
90
91 CK_TILE_HOST static constexpr auto BlockSize()
92 {
93 return is_wave32() ? kBlockSize / 2 : kBlockSize;
94 }
95
96 template <typename TensorShape, typename WindowShape>
98 {
99 using S = typename Problem::BlockShape;
100
101 // Compile-time validation for 2D pooling
102 static_assert(TensorShape::size() == 4, "2D pooling requires 4D input tensor (N,H,W,C)");
103 static_assert(WindowShape::size() == 2, "2D pooling requires 2D window shape (Y,X)");
104
105 // Extract dimension values
106 const index_t N = kargs.input_shape.at(number<0>{});
107 const index_t H = kargs.input_shape.at(number<1>{});
108 const index_t W = kargs.input_shape.at(number<2>{});
109 const index_t C = kargs.input_shape.at(number<3>{});
110
111 const index_t No = kargs.output_shape.at(number<0>{});
112 const index_t Ho = kargs.output_shape.at(number<1>{});
113 const index_t Wo = kargs.output_shape.at(number<2>{});
114 const index_t Co = kargs.output_shape.at(number<3>{});
115
116 const index_t Y = kargs.window_lengths.at(number<0>{});
117 const index_t X = kargs.window_lengths.at(number<1>{});
118
119 const index_t WindowStrideH = kargs.window_strides.at(number<0>{});
120 const index_t WindowStrideW = kargs.window_strides.at(number<1>{});
121
122 const index_t WindowDilationH = kargs.window_dilations.at(number<0>{});
123 const index_t WindowDilationW = kargs.window_dilations.at(number<1>{});
124
125 const index_t InLeftPadH = kargs.input_left_pads.at(number<0>{});
126 const index_t InLeftPadW = kargs.input_left_pads.at(number<1>{});
127
128 const index_t InRightPadH = kargs.input_right_pads.at(number<0>{});
129 const index_t InRightPadW = kargs.input_right_pads.at(number<1>{});
130
131 const index_t MRaw = N * Ho * Wo * C;
132 const index_t KRaw = Y * X;
133 const index_t MPad = integer_least_multiple(MRaw, S::Block_M) - MRaw;
134 const index_t KPad = integer_least_multiple(KRaw, S::Block_N) - KRaw;
135
136 auto reduce_op = typename Problem::ReduceOp{};
137
138 // Create input descriptor with all transformations
139 auto in_desc = make_naive_tensor_descriptor(kargs.input_shape, kargs.input_strides);
140
141 // Apply spatial padding to input descriptor
142 const auto padded_in_desc = transform_tensor_descriptor(
143 in_desc,
145 make_pad_transform(H, InLeftPadH, InRightPadH),
146 make_pad_transform(W, InLeftPadW, InRightPadW),
150
151 // Create sliding windows by embedding pooling windows into descriptor
152 const auto embed_in_desc = transform_tensor_descriptor(
153 padded_in_desc,
156 make_embed_transform(make_tuple(Y, Ho), make_tuple(WindowDilationH, WindowStrideH)),
157 make_embed_transform(make_tuple(X, Wo), make_tuple(WindowDilationW, WindowStrideW)),
161
162 // Reshape into 2D matrix: output positions (M) x pooling window elements (K)
163 const auto merged_embed_in_desc =
164 transform_tensor_descriptor(embed_in_desc,
169
170 const auto in_desc_padded = transform_tensor_descriptor(
171 merged_embed_in_desc,
175
176 // Create output descriptor with transformations
177 auto out_desc = make_naive_tensor_descriptor(kargs.output_shape, kargs.output_strides);
178
179 const auto merged_out_desc = transform_tensor_descriptor(
180 out_desc,
184
185 const auto out_desc_padded =
186 transform_tensor_descriptor(merged_out_desc,
190
191 // Now create buffer views and tensor views with the fully transformed descriptors
192 const InDataType in_identity =
193 type_convert<InDataType>(reduce_op.template GetIdentityValue<ComputeDataType>());
194 const OutDataType out_identity =
195 type_convert<OutDataType>(reduce_op.template GetIdentityValue<ComputeDataType>());
196
198 static_cast<const InDataType*>(kargs.input_ptr),
199 in_desc.get_element_space_size(),
200 in_identity);
201 const auto in_tensor_padded =
202 tensor_view<decltype(in_buffer_view), decltype(in_desc_padded)>{in_buffer_view,
203 in_desc_padded};
204
206 static_cast<OutDataType*>(kargs.output_ptr),
207 out_desc.get_element_space_size(),
208 out_identity);
209 const auto out_tensor_padded =
210 tensor_view<decltype(out_buffer_view), decltype(out_desc_padded)>{out_buffer_view,
211 out_desc_padded};
212
213 if constexpr(Problem::kOutputIndex)
214 {
215 auto out_index_buffer_view = make_buffer_view<address_space_enum::global>(
216 static_cast<IndexDataType*>(kargs.output_index_ptr),
217 out_desc.get_element_space_size(),
218 IndexDataType(-1));
219 const auto out_index_tensor_padded =
220 tensor_view<decltype(out_index_buffer_view), decltype(out_desc_padded)>{
221 out_index_buffer_view, out_desc_padded};
222
223 return make_tuple(in_tensor_padded, out_tensor_padded, out_index_tensor_padded);
224 }
225 else
226 {
227 // Return a dummy tensor for the third element when index output is not needed
228 return make_tuple(in_tensor_padded, out_tensor_padded, null_tensor{});
229 }
230 }
231
232 template <typename TensorShape, typename WindowShape>
234 {
235 using S = typename Problem::BlockShape;
236
237 // Compile-time validation for 3D pooling
238 static_assert(TensorShape::size() == 5, "3D pooling requires 5D input tensor (N,D,H,W,C)");
239 static_assert(WindowShape::size() == 3, "3D pooling requires 3D window shape (Z,Y,X)");
240
241 // Extract dimension values
242 const index_t N = kargs.input_shape.at(number<0>{});
243 const index_t D = kargs.input_shape.at(number<1>{});
244 const index_t H = kargs.input_shape.at(number<2>{});
245 const index_t W = kargs.input_shape.at(number<3>{});
246 const index_t C = kargs.input_shape.at(number<4>{});
247
248 const index_t No = kargs.output_shape.at(number<0>{});
249 const index_t Do = kargs.output_shape.at(number<1>{});
250 const index_t Ho = kargs.output_shape.at(number<2>{});
251 const index_t Wo = kargs.output_shape.at(number<3>{});
252 const index_t Co = kargs.output_shape.at(number<4>{});
253
254 const index_t Z = kargs.window_lengths.at(number<0>{});
255 const index_t Y = kargs.window_lengths.at(number<1>{});
256 const index_t X = kargs.window_lengths.at(number<2>{});
257
258 const index_t WindowStrideD = kargs.window_strides.at(number<0>{});
259 const index_t WindowStrideH = kargs.window_strides.at(number<1>{});
260 const index_t WindowStrideW = kargs.window_strides.at(number<2>{});
261
262 const index_t WindowDilationD = kargs.window_dilations.at(number<0>{});
263 const index_t WindowDilationH = kargs.window_dilations.at(number<1>{});
264 const index_t WindowDilationW = kargs.window_dilations.at(number<2>{});
265
266 const index_t InLeftPadD = kargs.input_left_pads.at(number<0>{});
267 const index_t InLeftPadH = kargs.input_left_pads.at(number<1>{});
268 const index_t InLeftPadW = kargs.input_left_pads.at(number<2>{});
269
270 const index_t InRightPadD = kargs.input_right_pads.at(number<0>{});
271 const index_t InRightPadH = kargs.input_right_pads.at(number<1>{});
272 const index_t InRightPadW = kargs.input_right_pads.at(number<2>{});
273
274 const index_t MRaw = N * Do * Ho * Wo * C;
275 const index_t KRaw = Z * Y * X;
276 const index_t MPad = integer_least_multiple(MRaw, S::Block_M) - MRaw;
277 const index_t KPad = integer_least_multiple(KRaw, S::Block_N) - KRaw;
278
279 auto reduce_op = typename Problem::ReduceOp{};
280
281 // Create input descriptor with all transformations
282 auto in_desc = make_naive_tensor_descriptor(kargs.input_shape, kargs.input_strides);
283
284 // Apply spatial padding to input descriptor (all 3D dimensions)
285 const auto padded_in_desc = transform_tensor_descriptor(
286 in_desc,
288 make_pad_transform(D, InLeftPadD, InRightPadD),
289 make_pad_transform(H, InLeftPadH, InRightPadH),
290 make_pad_transform(W, InLeftPadW, InRightPadW),
294
295 // Create 3D sliding windows by embedding pooling windows into descriptor
296 const auto embed_in_desc = transform_tensor_descriptor(
297 padded_in_desc,
300 make_embed_transform(make_tuple(Z, Do), make_tuple(WindowDilationD, WindowStrideD)),
301 make_embed_transform(make_tuple(Y, Ho), make_tuple(WindowDilationH, WindowStrideH)),
302 make_embed_transform(make_tuple(X, Wo), make_tuple(WindowDilationW, WindowStrideW)),
309 sequence<7>{}));
310
311 // Reshape into 2D matrix: output positions (M) x pooling window elements (K)
312 const auto merged_embed_in_desc = transform_tensor_descriptor(
313 embed_in_desc,
314 make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo, C)),
318
319 const auto in_desc_padded = transform_tensor_descriptor(
320 merged_embed_in_desc,
324
325 // Create output descriptor with transformations
326 auto out_desc = make_naive_tensor_descriptor(kargs.output_shape, kargs.output_strides);
327
328 const auto merged_out_desc = transform_tensor_descriptor(
329 out_desc,
330 make_tuple(make_merge_transform(make_tuple(No, Do, Ho, Wo, Co))),
333
334 const auto out_desc_padded =
335 transform_tensor_descriptor(merged_out_desc,
339
340 // Now create buffer views and tensor views with the fully transformed descriptors
341 const InDataType in_identity =
342 type_convert<InDataType>(reduce_op.template GetIdentityValue<ComputeDataType>());
343 const OutDataType out_identity =
344 type_convert<OutDataType>(reduce_op.template GetIdentityValue<ComputeDataType>());
345
347 static_cast<const InDataType*>(kargs.input_ptr),
348 in_desc.get_element_space_size(),
349 in_identity);
350 const auto in_tensor_padded =
351 tensor_view<decltype(in_buffer_view), decltype(in_desc_padded)>{in_buffer_view,
352 in_desc_padded};
353
355 static_cast<OutDataType*>(kargs.output_ptr),
356 out_desc.get_element_space_size(),
357 out_identity);
358 const auto out_tensor_padded =
359 tensor_view<decltype(out_buffer_view), decltype(out_desc_padded)>{out_buffer_view,
360 out_desc_padded};
361
362 if constexpr(Problem::kOutputIndex)
363 {
364 auto out_index_buffer_view = make_buffer_view<address_space_enum::global>(
365 static_cast<IndexDataType*>(kargs.output_index_ptr),
366 out_desc.get_element_space_size(),
367 IndexDataType(-1));
368 const auto out_index_tensor_padded =
369 tensor_view<decltype(out_index_buffer_view), decltype(out_desc_padded)>{
370 out_index_buffer_view, out_desc_padded};
371
372 return make_tuple(in_tensor_padded, out_tensor_padded, out_index_tensor_padded);
373 }
374 else
375 {
376 // Return a dummy tensor for the third element when index output is not needed
377 return make_tuple(in_tensor_padded, out_tensor_padded, null_tensor{});
378 }
379 }
380
381 public:
382 template <typename TensorShape, typename WindowShape>
384 {
385 using S = typename Problem::BlockShape;
386
387 // Compile-time validation for supported window dimensions
388 static_assert(WindowShape::size() == 2 || WindowShape::size() == 3,
389 "Only 2D and 3D pooling operations are supported");
390
391 const auto iM = get_block_id() * S::Block_M;
392
393 // Get tensors based on dimensionality
394 auto [in_tensor_padded, out_tensor_padded, out_index_tensor_padded] = [&]() {
395 if constexpr(WindowShape::size() == 2)
396 return MakeTensorView2D(kargs);
397 else if constexpr(WindowShape::size() == 3)
398 return MakeTensorView3D(kargs);
399 else
400 static_assert(WindowShape::size() == 2 || WindowShape::size() == 3,
401 "Unsupported WindowShape rank: only 2D or 3D pooling is supported");
402 }();
403
404 auto reduce_op = typename Problem::ReduceOp{};
405
406 auto x_window = make_tile_window(in_tensor_padded,
408 {iM, 0},
409 Policy::template MakeXBlockTileDistribution<Problem>());
410 auto y_window = make_tile_window(out_tensor_padded, make_tuple(number<S::Block_M>{}), {iM});
411
412 __shared__ char smem[Policy::template GetSmemSize<Problem>()];
413
414 const auto reduce_len =
415 in_tensor_padded.get_tensor_descriptor().get_lengths().at(number<1>{});
416 index_t num_k_tiles =
417 __builtin_amdgcn_readfirstlane(integer_divide_ceil(reduce_len, S::Block_N));
418
419 auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
420 auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
421 auto block_reduce2d_cross_warp = Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
422
423 using XTensorTile = decltype(load_tile(x_window));
424 auto y_tile = block_reduce2d.template MakeYBlockTile<XTensorTile>();
425 set_tile(y_tile, reduce_op.template GetIdentityValue<ComputeDataType>());
426
427 if constexpr(Problem::kOutputIndex)
428 {
429 auto y_index_window =
430 make_tile_window(out_index_tensor_padded, make_tuple(number<S::Block_M>{}), {iM});
431
432 auto y_index_tile =
433 block_reduce2d.template MakeYIndexBlockTile<XTensorTile, IndexDataType>();
434 set_tile(y_index_tile, IndexDataType(0));
435
436 // Main reduction loop - with index tracking
437 for(int k_tile = amd_wave_read_first_lane(0); k_tile < num_k_tiles; ++k_tile)
438 {
439 const auto x_tile = load_tile(x_window);
440 auto index_calculator = [&](const auto& x_indices) {
441 // Get global coordinates in the 2D matrix space (M, N)
442 const auto global_M = x_indices.at(number<0>{}) + iM;
443 const auto global_N = (k_tile * S::Block_N) + x_indices.at(number<1>{});
444 return in_tensor_padded.get_tensor_descriptor().calculate_offset(
445 make_tuple(global_M, global_N));
446 };
447
448 block_reduce2d(x_tile, y_tile, y_index_tile, reduce_op, index_calculator);
449 move_tile_window(x_window, {0, S::Block_N});
450 }
451
452 block_reduce2d_sync(y_tile, y_index_tile, reduce_op);
453 if constexpr(Problem::kNeedCrossWarpSync)
454 {
455 __shared__ char smem_indices[Policy::template GetIndicesSmemSize<Problem>()];
456
457 block_reduce2d_cross_warp(y_tile, y_index_tile, smem, smem_indices, reduce_op);
458 }
459
460 store_tile(y_window, cast_tile<OutDataType>(y_tile));
461 store_tile(y_index_window, cast_tile<IndexDataType>(y_index_tile));
462 }
463 else
464 {
465 // Main reduction loop - without index tracking
466 for(int k_tile = __builtin_amdgcn_readfirstlane(0); k_tile < num_k_tiles; ++k_tile)
467 {
468 const auto x_tile = load_tile(x_window);
469 block_reduce2d(x_tile, y_tile, reduce_op);
470 move_tile_window(x_window, {0, S::Block_N});
471 }
472
473 block_reduce2d_sync(y_tile, reduce_op);
474 block_reduce2d_cross_warp(y_tile, smem, reduce_op);
475
476 store_tile(y_window, cast_tile<OutDataType>(y_tile));
477 }
478 }
479
490 template <typename TensorShape, typename WindowShape>
492 {
493 constexpr index_t InputRank = TensorShape::size();
494 constexpr index_t OutputRank = TensorShape::size(); // Same as input rank
495 constexpr index_t WindowRank = WindowShape::size();
496
497 // Validate window dimensions (only 2D and 3D supported)
498 if constexpr(WindowRank != 2 && WindowRank != 3)
499 {
500 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
501 {
502 CK_TILE_ERROR("Only 2D and 3D pooling are supported!");
503 }
504 return false;
505 }
506
507 // Validate that input rank matches expected rank for window dimensions
508 if constexpr((WindowRank == 2 && InputRank != 4) || (WindowRank == 3 && InputRank != 5))
509 {
510 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
511 {
512 CK_TILE_ERROR("Input tensor rank doesn't match window dimensions!");
513 }
514 return false;
515 }
516
517 // Check that channel dimension (last dimension) is contiguous for both input and output
518 if(kargs.input_strides.at(number<InputRank - 1>{}) != 1)
519 {
520 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
521 {
522 CK_TILE_ERROR("Input tensor's channel dimension must have stride 1!");
523 }
524 return false;
525 }
526
527 if(kargs.output_strides.at(number<OutputRank - 1>{}) != 1)
528 {
529 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
530 {
531 CK_TILE_ERROR("Output tensor's channel dimension must have stride 1!");
532 }
533 return false;
534 }
535
536 return true;
537 }
538
541 template <typename TensorShape, typename WindowShape>
542 CK_TILE_HOST static constexpr index_t
544 {
545 using S = typename Problem::BlockShape;
546
547 // Calculate total output elements (M dimension)
548 index_t M = 1;
549 static_for<0, TensorShape::size(), 1>{}([&](auto i) { M *= kargs.output_shape.at(i); });
550
551 // Calculate grid size: ceil(M / Block_M)
552 return (M + S::Block_M - 1) / S::Block_M;
553 }
554
556 template <typename TensorShape, typename WindowShape>
557 CK_TILE_HOST static constexpr auto
559 {
561 host_args.output_ptr,
562 host_args.output_index_ptr,
563 host_args.input_shape,
564 host_args.output_shape,
565 host_args.input_strides,
566 host_args.output_strides,
567 host_args.window_lengths,
568 host_args.window_strides,
569 host_args.window_dilations,
570 host_args.input_left_pads,
571 host_args.input_right_pads};
572 }
573};
574
575} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
bool EnvIsEnabled(EnvVar)
Definition tile/core/utility/env.hpp:156
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
CK_TILE_DEVICE void set_tile(DstrTensors &dstr_tensor, const T &value)
Definition tile_elementwise.hpp:95
void CK_TILE_ERROR(Args &&... args) noexcept
Definition tile/core/utility/env.hpp:12
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1615
CK_TILE_HOST_DEVICE constexpr auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, bool_constant< SkipIsValidCheck >=bool_constant< false >{})
Definition coordinate_transform.hpp:1565
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_descriptor.hpp:203
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_HOST_DEVICE constexpr auto integer_least_multiple(X x, Y y)
Definition tile/core/numeric/math.hpp:155
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_DEVICE index_t get_block_id()
Definition arch.hpp:119
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_HOST_DEVICE constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad_, bool_constant< SkipIsValidCheck >=bool_constant< false >{})
Definition coordinate_transform.hpp:1584
CK_TILE_HOST_DEVICE constexpr auto make_buffer_view(T *__restrict__ p, BufferSizeType buffer_size)
Definition buffer_view.hpp:1262
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
CK_TILE_HOST bool is_wave32()
Definition arch.hpp:72
CK_TILE_HOST_DEVICE constexpr auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition coordinate_transform.hpp:1594
Host arguments for pooling operations.
Definition pool_kernel.hpp:16
TensorShape input_strides
Definition pool_kernel.hpp:51
void * output_ptr
Definition pool_kernel.hpp:46
WindowShape input_left_pads
Definition pool_kernel.hpp:56
const void * input_ptr
Definition pool_kernel.hpp:45
WindowShape window_lengths
Definition pool_kernel.hpp:53
WindowShape window_strides
Definition pool_kernel.hpp:54
TensorShape input_shape
Definition pool_kernel.hpp:49
TensorShape output_strides
Definition pool_kernel.hpp:52
CK_TILE_HOST PoolHostArgs(const void *input_ptr_, void *output_ptr_, void *output_index_ptr_, TensorShape input_shape_, TensorShape output_shape_, TensorShape input_strides_, TensorShape output_strides_, WindowShape window_lengths_, WindowShape window_strides_, WindowShape window_dilations_, WindowShape input_left_pads_, WindowShape input_right_pads_)
Definition pool_kernel.hpp:18
TensorShape output_shape
Definition pool_kernel.hpp:50
WindowShape input_right_pads
Definition pool_kernel.hpp:57
WindowShape window_dilations
Definition pool_kernel.hpp:55
void * output_index_ptr
Definition pool_kernel.hpp:47
Kernel arguments for pooling operations.
Definition pool_kernel.hpp:63
TensorShape output_shape
Definition pool_kernel.hpp:68
WindowShape input_right_pads
Definition pool_kernel.hpp:75
WindowShape window_lengths
Definition pool_kernel.hpp:71
WindowShape window_dilations
Definition pool_kernel.hpp:73
TensorShape input_strides
Definition pool_kernel.hpp:69
const void * input_ptr
Definition pool_kernel.hpp:64
WindowShape input_left_pads
Definition pool_kernel.hpp:74
TensorShape input_shape
Definition pool_kernel.hpp:67
WindowShape window_strides
Definition pool_kernel.hpp:72
void * output_ptr
Definition pool_kernel.hpp:65
TensorShape output_strides
Definition pool_kernel.hpp:70
void * output_index_ptr
Definition pool_kernel.hpp:66
Definition pool_kernel.hpp:80
static CK_TILE_HOST constexpr index_t CalculateGridSize(PoolKernelArgs< TensorShape, WindowShape > kargs)
Definition pool_kernel.hpp:543
ck_tile::remove_cvref_t< Policy_ > Policy
Definition pool_kernel.hpp:82
ck_tile::remove_cvref_t< typename Problem::OutDataType > OutDataType
Definition pool_kernel.hpp:86
ck_tile::remove_cvref_t< typename Problem::ComputeDataType > ComputeDataType
Definition pool_kernel.hpp:85
static constexpr index_t kBlockSize
Definition pool_kernel.hpp:89
static CK_TILE_HOST bool IsSupportedArgument(PoolKernelArgs< TensorShape, WindowShape > kargs)
Validates whether the given arguments are supported by the pooling kernel.
Definition pool_kernel.hpp:491
static CK_TILE_HOST constexpr auto BlockSize()
Definition pool_kernel.hpp:91
static CK_TILE_DEVICE auto MakeTensorView2D(PoolKernelArgs< TensorShape, WindowShape > kargs)
Definition pool_kernel.hpp:97
static CK_TILE_DEVICE auto MakeTensorView3D(PoolKernelArgs< TensorShape, WindowShape > kargs)
Definition pool_kernel.hpp:233
ck_tile::remove_cvref_t< typename Problem::InDataType > InDataType
Definition pool_kernel.hpp:84
ck_tile::remove_cvref_t< typename Problem::IndexDataType > IndexDataType
Definition pool_kernel.hpp:87
CK_TILE_DEVICE void operator()(PoolKernelArgs< TensorShape, WindowShape > kargs) const
Definition pool_kernel.hpp:383
static CK_TILE_HOST constexpr auto MakeKernelArgs(PoolHostArgs< TensorShape, WindowShape > &host_args)
Create kernel arguments from host arguments.
Definition pool_kernel.hpp:558
ck_tile::remove_cvref_t< Problem_ > Problem
Definition pool_kernel.hpp:81
Definition null_tensor.hpp:9
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43
Definition tensor_view.hpp:41
#define CK_TILE_ENV(name)
Definition tile/core/utility/env.hpp:145