flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp Source File

flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp Source File

Composable Kernel: flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp Source File
flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
10
11namespace ck_tile {
12
14{
// Compile-time index constants 0, 1 and 2. The functions below pass them to
// WarpTile::at(...) and BlockWarps::at(...) to select one component of those shapes.
15 static constexpr auto I0 = number<0>{};
16 static constexpr auto I1 = number<1>{};
17 static constexpr auto I2 = number<2>{};
18
// MakeALdsBlockDescriptor (name taken from this page's index, original line 21; the
// signature line itself is not part of this listing).
// Builds and returns a tensor descriptor named `a_lds_block_desc` from the block tile
// sizes Problem::BlockGemmShape::kM / kK and the pack width GetSmemPackA<Problem>().
// One of two constructions is compiled, chosen by the warp tile's M and N sizes.
// NOTE(review): several argument lines of the descriptor calls are missing from this
// listing (the gutter numbers skip), so the exact layouts cannot be read off here.
19 // 3d + padding
20 template <typename Problem>
22 {
// NOTE(review): the listing opens `namespace ck_tile` before this struct, so this
// directive looks redundant -- confirm before removing.
23 using namespace ck_tile;
24
25 constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
26 constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
// Branch 1: warp tile is 16 x 16 in M x N.
27 if constexpr(MPerXdl == 16 && NPerXdl == 16)
28 {
29 /* reduced transform layers compared with the old CK */
30 constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
31 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
32 constexpr index_t KPack = GetSmemPackA<Problem>();
33
// Base descriptor, then two successive transform_tensor_descriptor steps.
34 constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
38 number<1>{});
39
40 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
41 a_lds_block_desc_0,
43 make_tuple(number<MPerBlock>{}, number<KPerBlock / KPack>{})),
47
48 constexpr auto a_lds_block_desc = transform_tensor_descriptor(
49 a_lds_block_desc_permuted,
55
56 return a_lds_block_desc;
57 }
// Branch 2: every other warp tile size.
58 else
59 {
60 constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
61 constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
62 constexpr index_t kKPack = GetSmemPackA<Problem>();
63
// The tuple below uses (kMPerBlock + 1) * kKPack as its leading entry; presumably it
// is the stride tuple and the "+ 1" is the padding mentioned in the header comment --
// TODO confirm against the omitted original line 65.
64 constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
66 make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
68 number<1>{});
69
// Merges the (kKPerBlock / kKPack, kKPack) pair into a single dimension.
70 constexpr auto a_lds_block_desc = transform_tensor_descriptor(
71 a_lds_block_desc_0,
73 make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
76
77 return a_lds_block_desc;
78 }
// Everything between #if 0 and #endif is compiled out: an alternative construction
// labelled "xor".
// NOTE(review): it uses ADataType, which is not declared in the visible lines of this
// function (original line 84 is missing from the listing).
79/*xor*/
80#if 0
81 constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
82 constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
83 constexpr index_t kKPack = GetSmemPackA<Problem>();
85
86 constexpr auto DataTypeSize = sizeof(ADataType);
// MLdsLayer is 32 * 4 / kKPerBlock / DataTypeSize, clamped to a minimum of 1.
87 constexpr auto MLdsLayer =
88 (32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize);
89
90 constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
92 number<kMPerBlock / MLdsLayer>{},
96 number<1>{});
97
98 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
99 a_lds_block_desc_0,
101 number<kKPerBlock / kKPack * MLdsLayer>{})),
105
106 constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
107 a_lds_block_desc_permuted,
109 make_tuple(number<MLdsLayer>{}, number<kKPerBlock / kKPack>{})),
114
115 constexpr auto a_lds_block_desc = transform_tensor_descriptor(
116 a_lds_block_desc_xk0_mnldslayer_mn_xk1,
123 return a_lds_block_desc;
124#endif
125 }
126
// GetGlobalVectorLoadSize (name and brief from this page's index, original line 137:
// "Get the maximum global memory vector load size."; the signature line is not part of
// this listing).
//
// Template parameters:
//   Problem    - supplies kBlockSize and BlockGemmShape::kK.
//   DataType   - element type; only sizeof(DataType) is used in the visible lines.
//   MNPerBlock - block tile extent multiplied with kK to get the per-thread count.
//   XPerTile   - extent that the chosen vector length must divide evenly.
//
// Returns the first candidate PackedSize * W / sizeof(DataType), for W = 32, 16, 8, 4, 2
// in that order, that divides both XPerTile and elements_per_thread (subject to the
// extra guards below); otherwise returns PackedSize.
136 template <typename Problem, typename DataType, index_t MNPerBlock, index_t XPerTile>
138 {
139 constexpr index_t BlockSize = Problem::kBlockSize;
140 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
// Number of tile elements each thread handles when the tile is split across the block.
141 constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize;
// NOTE(review): the initializer of PackedSize (original line 143) is missing from this
// listing.
142 constexpr index_t PackedSize =
144
145 // Assume DataType is even!
// W = 32 is only considered when PackedSize == 2.
146 if constexpr(XPerTile % (PackedSize * 32 / sizeof(DataType)) == 0 &&
147 elements_per_thread % (PackedSize * 32 / sizeof(DataType)) == 0 &&
148 PackedSize == 2)
149 {
150 return (PackedSize * 32 / sizeof(DataType));
151 }
152 else if constexpr(XPerTile % (PackedSize * 16 / sizeof(DataType)) == 0 &&
153 elements_per_thread % (PackedSize * 16 / sizeof(DataType)) == 0)
154 {
155 return (PackedSize * 16 / sizeof(DataType));
156 }
157 else if constexpr(XPerTile % (PackedSize * 8 / sizeof(DataType)) == 0 &&
158 elements_per_thread % (PackedSize * 8 / sizeof(DataType)) == 0)
159 {
160 return (PackedSize * 8 / sizeof(DataType));
161 }
// NOTE(review): sizeof(DataType) >= PackedSize * 4 holds only when the candidate
// PackedSize * 4 / sizeof(DataType) is 0 or 1 under integer division, so this branch can
// never yield a length above 1 -- confirm that `>=` (rather than `<=`) is intended. The
// same applies to the W = 2 branch below.
162 else if constexpr(sizeof(DataType) >= PackedSize * 4 &&
163 XPerTile % (PackedSize * 4 / sizeof(DataType)) == 0 &&
164 elements_per_thread % (PackedSize * 4 / sizeof(DataType)) == 0)
165 {
166 return (PackedSize * 4 / sizeof(DataType));
167 }
168 else if constexpr(sizeof(DataType) >= PackedSize * 2 &&
169 XPerTile % (PackedSize * 2 / sizeof(DataType)) == 0 &&
170 elements_per_thread % (PackedSize * 2 / sizeof(DataType)) == 0)
171 {
172 return (PackedSize * 2 / sizeof(DataType));
173 }
// Fallback when no candidate width fits.
174 else
175 {
176 return PackedSize;
177 }
178 }
179
// Reads the M and K block tile sizes and branches at compile time on whether ALayout is
// RowMajor.
// NOTE(review): the alias declarations (original lines 183-184, presumably where ALayout
// is introduced) and both return statements (lines 190 and 194) are missing from this
// listing, so the returned values cannot be documented from here.
180 template <typename Problem>
181 CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA()
182 {
185 constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
186 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
187
188 if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
189 {
191 }
192 else
193 {
195 }
196 }
197
// Reads the N and K block tile sizes and branches at compile time on whether BLayout is
// RowMajor.
// NOTE(review): the alias declarations (original lines 201-202, presumably where BLayout
// is introduced) and both return statements (lines 208 and 212) are missing from this
// listing, so the returned values cannot be documented from here.
198 template <typename Problem>
199 CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB()
200 {
203 constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
204 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
205
206 if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
207 {
209 }
210 else
211 {
213 }
214 }
215
// GetSmemSizeA (name from this page's index, original line 217, declared there as
// returning index_t; the signature line is not part of this listing).
// Returns sizeof(Problem::ADataType) multiplied by the element-space size reported by
// the descriptor that MakeALdsBlockDescriptor<Problem>() builds.
216 template <typename Problem>
218 {
219 constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) *
220 MakeALdsBlockDescriptor<Problem>().get_element_space_size();
221 return smem_size_a;
222 }
223
// GetSmemSize (name from this page's index, original line 225, declared there as
// returning index_t; the signature line is not part of this listing).
// Returns GetSmemSizeA<Problem>() unchanged; no other term contributes to the result.
224 template <typename Problem>
226 {
227 constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
228
229 return smem_size_a;
230 }
231
// Returns Problem::VectorLoadSize divided by sizeof(Problem::ADataType), i.e. how many
// A elements fit in VectorLoadSize when that constant is expressed in the same unit as
// sizeof (presumably bytes -- TODO confirm against the Problem definition).
232 template <typename Problem>
233 CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
234 {
235 return Problem::VectorLoadSize / sizeof(typename Problem::ADataType);
236 }
237
// Derives a per-load K count from the warp tile (WarpTile component I1 selects the
// divisor, component I2 is the value divided):
//   WarpTile::at(I1) == 32 -> WarpTile::at(I2) / 2
//   WarpTile::at(I1) == 16 -> WarpTile::at(I2) / 4
// Any other value of WarpTile::at(I1) fails the static_assert at compile time.
238 template <typename Problem>
239 CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad()
240 {
241 using TileShape = typename Problem::BlockGemmShape;
242 if constexpr(TileShape::WarpTile::at(I1) == 32)
243 {
244 return TileShape::WarpTile::at(I2) / 2;
245 }
246 else
247 {
// Only 32 and 16 are supported for this warp tile component.
248 static_assert(TileShape::WarpTile::at(I1) == 16);
249 return TileShape::WarpTile::at(I2) / 4;
250 }
251 }
252
// MakeALDS_WarpTileDistribution (name from this page's index, original line 254; the
// signature line is not part of this listing).
// Computes the per-thread K split for a warp tile and, in lines mostly absent from this
// listing (original 273-279), builds the value that is returned.
// Requires BlockWarps::at(I0) == 1 (enforced by static_assert).
// NOTE(review): ADataType is used below but its declaration (original line 257) is
// missing from the listing.
253 template <typename Problem>
255 {
256 using TileShape = typename Problem::BlockGemmShape;
258
259 static_assert(TileShape::BlockWarps::at(I0) == 1, "requires Wave_M == 1");
260
261 constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
262 constexpr index_t KPerXdl = Problem::BlockGemmShape::WarpTile::at(I2);
263
264 constexpr int Repeat = TileShape::BlockWarps::at(number<1>{});
265
// Threads of one warp divided by the warp tile's M size, then K elements per thread.
266 constexpr int KLane = get_warp_size() / MPerXdl;
267 constexpr int KPerThread = KPerXdl / KLane;
268
// One load carries at most 16 / sizeof(ADataType) elements; KFragment is the number of
// such loads needed to cover KPerThread.
269 constexpr int MaxVecSize = 16 / sizeof(ADataType);
270 constexpr int KItemsPerLoad = min(MaxVecSize, KPerThread);
271 constexpr int KFragment = KPerThread / KItemsPerLoad;
272
280 sequence<0, 2>>{});
281 }
282
// MakeADramTileDistribution (name from this page's index, original line 284; the
// signature line is not part of this listing).
// Factors the M x K block tile into per-thread / per-warp pieces, with a separate
// factorisation for ColumnMajor A and for every other layout. Each innermost branch
// ends in a statement whose leading lines are absent from this listing.
// NOTE(review): ADataType and ALayout are used below but their declarations (original
// lines 286-287) are missing from the listing.
283 template <typename Problem>
285 {
288
289 constexpr index_t BlockSize = Problem::kBlockSize;
290
291 constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
292 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
293
294 constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
295
296 if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
297 {
// M is split as M0 * M1, where M1 is the vector width along M.
298 constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType) * APackedSize;
299 constexpr index_t M0 = MPerBlock / M1;
// Elements handled by one thread; must be a whole number of M1-wide vectors.
300 constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize;
301 static_assert(total_pixels % M1 == 0);
302 constexpr index_t K3 = total_pixels / M1;
303 constexpr index_t KPack = GetSmemPackA<Problem>();
304 static_assert(KPack % K3 == 0);
305 constexpr index_t K2 = KPack / K3;
// Case: K2 * M0 fits within one warp.
306 if constexpr(get_warp_size() >= (K2 * M0))
307 {
308 constexpr index_t K1 = get_warp_size() / (K2 * M0);
309 constexpr index_t K0 = BlockSize / get_warp_size();
310 static_assert(KPerBlock == K0 * K1 * K2 * K3);
317 sequence<3, 1>>{});
318 }
// Case: K2 * M0 exceeds one warp, so K2 is divided further (K2_m).
319 else
320 {
321 constexpr index_t K1 = (K2 * M0) / get_warp_size();
322 constexpr index_t K2_m = K2 / K1;
323 constexpr index_t K0 = BlockSize / get_warp_size() / K1;
324 static_assert(KPerBlock == K0 * K1 * K2_m * K3);
331 sequence<3, 1>>{});
332 }
333 }
334 else
335 {
// K is split as K0 * K1, where K1 is the vector width along K.
336 constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType) * APackedSize;
337 constexpr index_t K0 = KPerBlock / K1;
338 // coalesce reading for each block
339 if constexpr(get_warp_size() % K0 == 0)
340 {
341 constexpr index_t M2 = get_warp_size() / K0;
342 constexpr index_t M1 = BlockSize / get_warp_size();
343 static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
344 static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
345 constexpr index_t M0 = MPerBlock / (M2 * M1);
346 static_assert(M0 * M1 * M2 == MPerBlock,
347 "Incorrect M0, M2, M1 configuration! "
348 "M0, M1, M2 must cover whole MPerBlock!");
349
356 sequence<0, 1>>{});
357 }
358 else
359 {
// NOTE(review): KWave is 0 whenever K0 < get_warp_size(), and it is then used as a
// divisor on the next line. No static_assert guards it here, unlike M2 / M1 in the
// sibling branch -- confirm this branch is only reached with K0 >= warp size.
360 constexpr index_t KWave = K0 / get_warp_size();
361 constexpr index_t M0 = BlockSize / get_warp_size() / KWave;
362 constexpr index_t M1 = MPerBlock / M0;
363
371 sequence<1, 2>>{});
372 }
373 }
374 }
375
// MakeADramDistribution (name from this page's index, original line 377; the signature
// line is not part of this listing).
// Splits K into K0 * K1 with K1 = 16 / sizeof(ADataType), and derives M2 (rows per warp)
// and M1 (warps per block). MPerBlock and M0 are deliberately left commented out.
// NOTE(review): ADataType is used below but its declaration (original line 379) is
// missing from the listing, as are the lines that build the returned value (397-402).
376 template <typename Problem>
378 {
380
381 constexpr index_t BlockSize = Problem::kBlockSize;
382
383 // constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
384 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
385
386 constexpr index_t K1 = 16 / sizeof(ADataType);
387 constexpr index_t K0 = KPerBlock / K1;
388 constexpr index_t M2 = get_warp_size() / K0;
389 constexpr index_t M1 = BlockSize / get_warp_size();
390 static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
391 static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
392 // constexpr index_t M0 = MPerBlock / (M2 * M1);
393 // static_assert(M0 * M1 * M2 == MPerBlock,
394 // "Incorrect M0, M2, M1 configuration! "
395 // "M0, M1, M2 must cover whole MPerBlock!");
396
403 sequence<1>>{});
404 }
405
// MakeBFlatDramTileDistribution (name from this page's index, original line 407; the
// signature line is not part of this listing).
// Derives the per-wave / per-thread split used for the B operand from kBlockSize, the
// warp size, GetKBPerLoad<Problem>() and the flatKPerWarp / flatNPerWarp / BlockWarps
// members of the tile shape.
// Compile-time requirements enforced below:
//   - KBPerLoad is a whole multiple of KItemsPerLoad;
//   - flatKPerWarp == warp size * KBPerLoad;
//   - BlockWarps::at(number<2>{}) == 1.
// NOTE(review): the lines that assemble the returned encoding (original 433-438 and
// 445-446) are mostly missing from this listing.
406 template <typename Problem>
408 {
409 using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
410
411 constexpr index_t BlockSize = Problem::kBlockSize;
412 constexpr index_t WaveSize = get_warp_size();
413 constexpr index_t WaveNum = BlockSize / WaveSize;
414
415 constexpr index_t KBPerLoad = GetKBPerLoad<Problem>();
416
// One load carries at most 16 / sizeof(BDataType) elements; KFragment is the number of
// such loads needed to cover KBPerLoad.
417 constexpr index_t MaxVecSize = 16 / sizeof(typename Problem::BDataType);
418 constexpr index_t KItemsPerLoad = min(KBPerLoad, MaxVecSize);
419 constexpr index_t KFragment = KBPerLoad / KItemsPerLoad;
420 static_assert(KFragment * KItemsPerLoad == KBPerLoad);
421
422 constexpr index_t KThdPerWave = WaveSize; // thread count in the K dimension
423 constexpr index_t KWavePerBlk = 1;
424 static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
425 static_assert(TileShape::BlockWarps::at(number<2>{}) == 1, "Requires K_Warp == 1");
426
427 constexpr index_t NBPerLoad = 1;
428 constexpr index_t NThdPerWave = 1;
429 constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp
430 constexpr index_t NRepeat = 1;
431
432 constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
433
439 // direction
440 // wave in blk, // thd in wave
441 // <M, K> // <M, K>
442 tuple<sequence<0, 1, 2>, sequence<1, 2>>, // which direction
443 tuple<sequence<0, 1, 1>, sequence<2, 2>>, // which index
444 // <repeat, vec_load>
447 }
448
// MakeShuffledARegBlockDistribution (name from this page's index, original line 450;
// the signature line is not part of this listing).
// Only valid for ColumnMajor A (enforced by static_assert). Factors the M x K block
// tile the same way in both branches, differing only in whether K2 * M0 fits within one
// warp.
// NOTE(review): ADataType and ALayout are used below but their declarations (original
// lines 452-453) are missing from the listing, as are the leading lines of both return
// statements.
449 template <typename Problem>
451 {
454 static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
455 constexpr index_t kBlockSize = Problem::kBlockSize;
456 constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
457 constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
458
// M is split as M0 * M1, where M1 is the vector width along M.
459 constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
460 constexpr index_t M0 = kMPerBlock / M1;
// Elements handled by one thread; must be a whole number of M1-wide vectors.
461 constexpr index_t total_pixels = kMPerBlock * kKPerBlock / kBlockSize;
462 static_assert(total_pixels % M1 == 0);
463 constexpr index_t K3 = total_pixels / M1;
464 constexpr index_t kKPack = GetSmemPackA<Problem>();
465 static_assert(kKPack % K3 == 0);
466 constexpr index_t K2 = kKPack / K3; // TODO: this dimension could be outside single wave
467 constexpr index_t warp_size = get_warp_size();
468 if constexpr(warp_size >= (K2 * M0))
469 {
// NOTE(review): the else-branch asserts kKPerBlock == K0 * K1 * K2_m * K3, but no
// matching static_assert is visible in this branch -- confirm whether one is wanted.
470 constexpr index_t K1 = warp_size / (K2 * M0);
471 constexpr index_t K0 = kBlockSize / warp_size;
472
479 sequence<1, 3>>{});
480 }
481 else
482 {
483 constexpr index_t K1 = (K2 * M0) / get_warp_size();
484 constexpr index_t K2_m = K2 / K1;
485 constexpr index_t K0 = kBlockSize / get_warp_size() / K1;
486 static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
493 sequence<1, 3>>{});
494 }
495 }
496
// Assembles the type aliases for the block-level flatmm:
//   WarpGemm          - WarpGemmDispatcher over the A/B/C data types, the three warp
//                       tile components and Problem::TransposeC.
//   BlockFlatmmPolicy - BlockFlatmmASmemBSmemCRegV1CustomPolicy over the A/B/C data
//                       types, BlockWarps and WarpGemm.
// NOTE(review): the return statement (original line 519) is missing from this listing,
// so what is built from BlockFlatmmPolicy cannot be documented from here.
497 template <typename Problem>
498 CK_TILE_HOST_DEVICE static constexpr auto GetBlockFlatmm()
499 {
500 // using AccDataType = float;
501 using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
502 using WarpTile = typename Problem::BlockGemmShape::WarpTile;
503 using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
504 typename Problem::BDataType,
505 typename Problem::CDataType,
506 WarpTile::at(I0),
507 WarpTile::at(I1),
508 WarpTile::at(I2),
509 Problem::TransposeC>;
510
511 using BlockFlatmmPolicy = BlockFlatmmASmemBSmemCRegV1CustomPolicy<
512 typename Problem::ADataType,
513 // BlockGemmASmemBSmemCRegV1CustomPolicy<typename
514 // Problem::ADataType,
515 typename Problem::BDataType,
516 typename Problem::CDataType,
517 BlockWarps,
518 WarpGemm>;
520 }
521};
522
523} // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
typename impl::WarpGemmDispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition warp_gemm_dispatcher.hpp:182
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1615
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_descriptor.hpp:203
CK_TILE_HOST_DEVICE constexpr auto make_unmerge_transform(const UpLengths &up_lengths, bool_constant< Use24BitIntegerCalculation >=bool_constant< false >{})
Definition coordinate_transform.hpp:1622
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1662
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1609
CK_TILE_HOST_DEVICE constexpr T min(T x)
Definition tile/core/numeric/math.hpp:210
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition block_flatmm_asmem_bsmem_creg_v1_custom_policy.hpp:18
Definition block_flatmm_asmem_bsmem_creg_v1.hpp:16
Definition flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:14
static constexpr auto I2
Definition flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:17
static CK_TILE_HOST_DEVICE constexpr auto GetKBPerLoad()
Definition flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:239
static CK_TILE_HOST_DEVICE constexpr auto MakeShuffledARegBlockDistribution()
Definition flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:450
static CK_TILE_HOST_DEVICE constexpr auto GetSmemPackA()
Definition flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:233
static CK_TILE_HOST_DEVICE constexpr auto GetGlobalVectorLoadSize()
Get the maximum global memory vector load size.
Definition flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:137
static CK_TILE_HOST_DEVICE constexpr auto MakeALDS_WarpTileDistribution()
Definition flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:254
static CK_TILE_HOST_DEVICE constexpr auto GetVectorSizeA()
Definition flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:181
static CK_TILE_HOST_DEVICE constexpr auto MakeADramTileDistribution()
Definition flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:284
static CK_TILE_HOST_DEVICE constexpr auto GetBlockFlatmm()
Definition flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:498
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:225
static constexpr auto I0
Definition flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:15
static CK_TILE_HOST_DEVICE constexpr auto MakeADramDistribution()
Definition flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:377
static CK_TILE_HOST_DEVICE constexpr auto GetVectorSizeB()
Definition flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:199
static CK_TILE_HOST_DEVICE constexpr auto MakeALdsBlockDescriptor()
Definition flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:21
static CK_TILE_HOST_DEVICE constexpr auto MakeBFlatDramTileDistribution()
Definition flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:407
static constexpr auto I1
Definition flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:16
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeA()
Definition flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:217
Definition tile/core/numeric/numeric.hpp:81
static constexpr int PackedSize
Definition tile/core/numeric/numeric.hpp:82
Definition tile/core/container/sequence.hpp:49
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192