gemm_pipeline_ag_bg_cr_comp_v4.hpp Source File

gemm_pipeline_ag_bg_cr_comp_v4.hpp Source File

Composable Kernel: gemm_pipeline_ag_bg_cr_comp_v4.hpp Source File
gemm_pipeline_ag_bg_cr_comp_v4.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3#pragma once
4#include "ck_tile/core.hpp"
8
9namespace ck_tile {
10
11// A Tile Window: global memory
12// B Tile Window: global memory
13// C Distributed tensor: register
// Shared base for the CompV4 GEMM pipeline: stage-count constants plus helpers that
// classify a loop count into a (hot-loop, tail) case and dispatch a callable on it.
// NOTE(review): listing line 15 (the struct declaration itself) is missing from this
// listing.
14template <typename Problem>
16{
// Stage / buffer counts exposed as compile-time constants.
17 static constexpr index_t PrefetchStages = 2;
18 static constexpr index_t PrefillStages = 1;
19 static constexpr index_t GlobalBufferNum = 1;
20
// Forwarded unchanged from the problem traits.
21 static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel;
22
// Returns true when num_loop is strictly greater than PrefetchStages.
23 CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
24 {
25 return num_loop > PrefetchStages;
26 }
27
// Body of GetBlockLoopTailNum(index_t num_loop); its declaration (listing line 28) is
// missing here. Mapping implemented below:
//   num_loop == 1                   -> TailNumber::One
//   num_loop % PrefetchStages == 1  -> TailNumber::Three
//   anything else                   -> TailNumber::Two
29 {
30 if(num_loop == 1)
31 {
32 return TailNumber::One;
33 }
34 if(num_loop % PrefetchStages == 1)
35 {
36 return TailNumber::Three;
37 }
38 else
39 {
40 return TailNumber::Two;
41 }
42 }
43
// Turns the runtime pair (has_hot_loop, tail_number) into compile-time tags and calls
// run_func with them. The first argument of every call is bool_constant<true/false>
// mirroring has_hot_loop.
// NOTE(review): the second argument of each run_func call is cut off in this listing
// (presumably a constant wrapping the matching TailNumber) - confirm against the header.
44 template <typename RunFunction>
45 CK_TILE_HOST_DEVICE static auto
46 TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
47 {
48 // Handle all the valid cases.
49 if(has_hot_loop)
50 {
// With a hot loop only Three and Two are dispatched; any other tail_number falls
// through to the error path at the bottom.
51 if(tail_number == TailNumber::Three)
52 {
53 return run_func(bool_constant<true>{},
55 }
56 else if(tail_number == TailNumber::Two)
57 {
58 return run_func(bool_constant<true>{},
60 }
61 }
62 else
63 {
// Without a hot loop every tail_number is dispatched: Three, Two, and a catch-all
// else branch for the rest.
64 if(tail_number == TailNumber::Three)
65 {
66 return run_func(bool_constant<false>{},
68 }
69 else if(tail_number == TailNumber::Two)
70 {
71 return run_func(bool_constant<false>{},
73 }
74 else
75 {
76 return (run_func(bool_constant<false>{},
78 }
79 }
80 // If execution reaches here, it's an invalid tail_number because it wasn't handled above.
// Device compilation marks this point unreachable; host compilation throws instead.
// NOTE(review): the message mentions TailNumber::Full, which is not one of the values
// tested above - confirm the wording is still accurate.
81#if defined(__HIP_DEVICE_COMPILE__)
82 __builtin_unreachable();
83#else
84 throw std::logic_error("Invalid TailNumber: Only TailNumber::Full and smaller than "
85 "PrefetchStages are supported.");
86#endif
87 }
88};
89
103template <typename Problem, typename Policy = GemmPipelineAgBgCrCompV4DefaultPolicy>
105{
108
113
117
120
123
126
// Compile-time guard: this pipeline refuses BDataType == pk_int4_t.
127 static_assert(!std::is_same_v<BDataType, pk_int4_t>, "Not implemented");
128
// NOTE(review): the initializers of APackedSize / BPackedSize (listing lines 130 and 132)
// are missing from this listing.
129 static constexpr index_t APackedSize =
131 static constexpr index_t BPackedSize =
133
// Compile-time index tags 0, 1 and 2.
135 using I0 = number<0>;
136 using I1 = number<1>;
137 using I2 = number<2>;
138
// Block size taken from the Problem.
139 static constexpr index_t BlockSize = Problem::kBlockSize;
140
// Per-block M / N / K extents taken from BlockGemmShape.
141 static constexpr index_t MPerBlock = BlockGemmShape::kM;
142 static constexpr index_t NPerBlock = BlockGemmShape::kN;
143 static constexpr index_t KPerBlock = BlockGemmShape::kK;
144
// Returns the A-side vector size reported by the Policy for this Problem.
// IsWave32Host (default false) is forwarded to the Policy unchanged.
145 template <bool IsWave32Host = false>
146 static constexpr index_t GetVectorSizeA()
147 {
148 return Policy::template GetVectorSizeA<Problem, IsWave32Host>();
149 }
// Returns the B-side vector size reported by the Policy for this Problem.
// IsWave32Host (default false) is forwarded to the Policy unchanged.
150 template <bool IsWave32Host = false>
151 static constexpr index_t GetVectorSizeB()
152 {
153 return Policy::template GetVectorSizeB<Problem, IsWave32Host>();
154 }
// Returns the C-side vector size reported by the Policy for this Problem.
155 static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
156
// Pass-through accessors: return Policy::GetSmemPackA / GetSmemPackB for this Problem.
157 static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
158 static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
159
// Padding flags for the M / N / K dimensions, copied from the Problem.
160 static constexpr bool kPadM = Problem::kPadM;
161 static constexpr bool kPadN = Problem::kPadN;
162 static constexpr bool kPadK = Problem::kPadK;
163
// Further configuration values re-exported from the Problem.
164 static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
165 static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
166 static constexpr index_t Preshuffle = Problem::Preshuffle;
167
// Compile-time defaults used by the operator() overloads that take no runtime
// has_hot_loop / tail_number arguments, plus the scheduler that selects PipelineImpl.
168 static constexpr bool HasHotLoop = Problem::HasHotLoop;
169 static constexpr auto TailNum = Problem::TailNum;
170 static constexpr auto Scheduler = Problem::Scheduler;
171
174
// Host-only: builds a '_'-joined name that starts with "pipeline_AgBgCrCompV4" and ends
// with the three padding flags joined by 'x'.
// NOTE(review): listing lines 179-180 (the middle arguments of concat) are missing here.
175 [[nodiscard]] CK_TILE_HOST static const std::string GetName()
176 {
177 // clang-format off
178 return concat('_', "pipeline_AgBgCrCompV4",
181 concat('x', kPadM, kPadN, kPadK));
182 // clang-format on
183 }
184
// Body of GetSmemSize(); its declaration (listing line 185) is missing from this listing.
// Returns whatever Policy::GetSmemSize<Problem>() reports.
186 {
187 return Policy::template GetSmemSize<Problem>();
188 }
189
// Pass-through: returns Policy::IsTransposeC<Problem>().
190 CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
191 {
192 return Policy::template IsTransposeC<Problem>();
193 }
194
// Primary template parameterized on the scheduler; its body is empty.
// NOTE(review): the declaration line (listing line 196) is missing here; presumably this
// is the PipelineImpl that the operator() overloads instantiate as PipelineImpl<Scheduler>.
195 template <GemmPipelineScheduler Scheduler>
197 {
198 };
199
200 template <>
202 {
204
// Emits scheduling hints: derives per-block instruction counts from the tile shapes,
// then issues a fixed pattern of sched_group_barrier calls followed by a full
// sched_barrier.
// NOTE(review): listing lines 219, 221 (buffer-load count initializers) and 247 (the
// header of the construct closed by "});") are missing from this listing.
205 CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
206 {
// Warp-tile extents along M / N / K.
207 constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0{});
208 constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1{});
209 constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2{});
210
// Warp size and warp counts along M / N.
211 constexpr index_t WaveSize = get_warp_size();
212 constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
213 constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
214
215 constexpr index_t A_LDS_Read_Width = KPerXDL;
216 constexpr index_t B_LDS_Read_Width = KPerXDL;
217
218 constexpr index_t A_Buffer_Load_Inst_Num =
220 constexpr index_t B_Buffer_Load_Inst_Num =
222
223 constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL);
224 constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL);
225
// Read counts are the write counts scaled by the warp count along the other dimension.
226 constexpr index_t A_LDS_Read_Inst_Num =
227 WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
228 constexpr index_t B_LDS_Read_Inst_Num =
229 WaveNumM * NPerBlock * KPerBlock / (BlockSize * KPerXDL);
230
231 constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
232 (BlockSize / WaveSize) /
233 (MPerXDL * NPerXDL * KPerXDL);
234
// Keep the read count as-is when one read spans exactly 16 bytes
// (width * sizeof(type) / packed size); otherwise halve it.
235 constexpr auto num_ds_read_inst_a =
236 A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num
237 : A_LDS_Read_Inst_Num / 2;
238 constexpr auto num_ds_read_inst_b =
239 B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? B_LDS_Read_Inst_Num
240 : B_LDS_Read_Inst_Num / 2;
241
// Totals; the number of issue groups equals the number of buffer loads.
242 constexpr auto num_ds_read_inst = num_ds_read_inst_a + num_ds_read_inst_b;
243 constexpr auto num_ds_write_inst = A_LDS_Write_Inst_Num + B_LDS_Write_Inst_Num;
244 constexpr auto num_buffer_load_inst = A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num;
245 constexpr auto num_issue = num_buffer_load_inst;
246
// Per-group pattern (the index i is unused): 1 MFMA, a share of DS reads, 1 MFMA, a
// share of DS writes, 1 MFMA, 1 VMEM read, then the remaining MFMAs of the group
// (C_MFMA_Inst_Num / num_issue minus the 3 already placed).
248 ignore = i;
249 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
250 __builtin_amdgcn_sched_group_barrier(
251 0x100, num_ds_read_inst / num_issue, 0); // DS read : 2
252 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA: 1
253 __builtin_amdgcn_sched_group_barrier(
254 0x200, num_ds_write_inst / num_issue, 0); // DS write : 1
255 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
256 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read :1
257 __builtin_amdgcn_sched_group_barrier(
258 0x008, C_MFMA_Inst_Num / num_issue - 3, 0); // MFMA : 5
259 });
260 __builtin_amdgcn_sched_barrier(0);
261 }
262
// Pipeline body. Takes tuples of A / B DRAM block windows plus element-wise functions,
// alternates between two LDS views built from p_smem_0 and p_smem_1, accumulates into a
// C block tile created by BlockGemm, and returns that tile.
// NOTE(review): this listing drops a number of lines inside the function (for example
// listing lines 264, 270, 288, 319, 353, 372, 389, 504, 540, 589); comments below describe
// only what is visible.
263 template <bool HasHotLoop,
265 typename AsDramBlockWindowTmp,
266 typename BsDramBlockWindowTmp,
267 typename AElementFunction,
268 typename BElementFunction,
269 typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
271 bool>* = nullptr>
272 CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
273 const AElementFunction& a_element_func,
274 const BsDramBlockWindowTmp& b_dram_block_window_tmp,
275 const BElementFunction& b_element_func,
276 index_t num_loop,
277 void* __restrict__ p_smem_0,
278 void* __restrict__ p_smem_1) const
279 {
// The first element of each window tuple is used for the type and shape checks.
280 using ADramBlockWindowTmp =
281 remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
282 using BDramBlockWindowTmp =
283 remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
284
285 static_assert(
286 std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
287 std::is_same_v<BDataType,
289 "Data Type conflict on A and B matrix input data type.");
290
291 constexpr bool is_a_col_major =
292 std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
293 constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
294
// Window lengths must be (K, M) for column-major A and (M, K) otherwise; likewise
// (K, N) for row-major B and (N, K) otherwise.
295 static_assert(is_a_col_major
296 ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
297 MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
298 : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
299 KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
300 "A block window has incorrect lengths for defined ALayout!");
301 static_assert(is_b_row_major
302 ? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
303 NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
304 : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
305 KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
306 "B block window has incorrect lengths for defined BLayout!");
307
308 using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
309 using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
310
// Per-iteration window step: KPerBlock along whichever axis holds K for the layout.
311 constexpr ADramTileWindowStep a_dram_tile_window_step =
312 is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
313 constexpr BDramTileWindowStep b_dram_tile_window_step =
314 is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
315
316 // global prefetch 0
317 // global read 0
318
// One (A, B) pair of LDS tensor views per shared-memory pointer.
320 auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0);
321 auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1);
322
// LDS window shape for A depends on is_a_load_tr_v; both return statements (listing
// lines 325 and 327) are missing from this listing.
323 constexpr auto a_lds_shape = []() {
324 if constexpr(is_a_load_tr_v())
326 else
328 }();
329 auto a_copy_lds_window0 = make_tile_window(a_lds_block0, a_lds_shape, {0, 0});
330 auto a_copy_lds_window1 = make_tile_window(a_lds_block1, a_lds_shape, {0, 0});
331
// Same for B (listing lines 334 and 336 missing).
332 constexpr auto b_lds_shape = []() {
333 if constexpr(is_b_load_tr_v())
335 else
337 }();
338 auto b_copy_lds_window0 = make_tile_window(b_lds_block0, b_lds_shape, {0, 0});
339 auto b_copy_lds_window1 = make_tile_window(b_lds_block1, b_lds_shape, {0, 0});
340
341 // Block GEMM
342 auto block_gemm = BlockGemm();
343 auto c_block_tile = block_gemm.MakeCBlockTile();
344
345 // initialize C
346 tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
347
348 // Generating a tuple with tile_windows for values A0, A1, ... AN
349 auto a_tile_windows = generate_tuple(
350 [&](auto idx) {
351 return make_tile_window(
352 a_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
354 a_dram_block_window_tmp[number<idx>{}].get_window_origin(),
355 Policy::template MakeADramTileDistribution<Problem>());
356 },
357 number<AsLayout::size()>{});
358
359 // Load tile — during value loading, an elementwise function is executed for each A0,
360 // A1, … AN. The values A0, A1, … AN are read by the same thread.
361 auto elementwise_As_res = load_tile_with_elementwise(a_tile_windows, a_element_func);
362
363 // Move each A — the enhanced function move_tile_window is executed, which takes a tuple
364 // as input.
365 move_tile_window(a_tile_windows, a_dram_tile_window_step);
366
367 // Generating a tuple with tile_windows for values B0, B1, ... BN
// NOTE(review): the B-window tuple below is sized with AsLayout::size(), the same count
// used for the A windows above. Presumably the B-side count was intended; confirm
// against the class's layout aliases.
368 auto b_tile_windows = generate_tuple(
369 [&](auto idx) {
370 return make_tile_window(
371 b_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
373 b_dram_block_window_tmp[number<idx>{}].get_window_origin(),
374 Policy::template MakeBDramTileDistribution<Problem>());
375 },
376 number<AsLayout::size()>{});
377
378 // Load tile — during value loading, an elementwise function is executed for each B0,
379 // B1, … BN. The values B0, B1, … BN are read by the same thread.
380 auto elementwise_Bs_res = load_tile_with_elementwise(b_tile_windows, b_element_func);
381
382 // Move each B — the enhanced function move_tile_window is executed, which takes a tuple
383 // as input.
384 move_tile_window(b_tile_windows, b_dram_tile_window_step);
385
386 // LDS write 0
// Column-major A without the load-transpose path is transposed into a temporary before
// being written to LDS window 0; otherwise the loaded tile is written directly.
// (The declaration of a_shuffle_tmp starts on a listing line that is missing here.)
387 if constexpr(is_a_col_major && !is_a_load_tr_v())
388 {
390 Policy::template MakeShuffledARegTileDistribution<Problem>());
391 transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
392 Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp);
393 }
394 else
395 {
396 Base::LocalPrefill(a_copy_lds_window0, elementwise_As_res);
397 }
// Same decision for row-major B.
398 if constexpr(is_b_row_major && !is_b_load_tr_v())
399 {
401 Policy::template MakeShuffledBRegTileDistribution<Problem>());
402 transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
403 Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp);
404 }
405 else
406 {
407 Base::LocalPrefill(b_copy_lds_window0, elementwise_Bs_res);
408 }
409
410 // global read 1
411
412 elementwise_As_res = load_tile_with_elementwise(a_tile_windows, a_element_func);
413 move_tile_window(a_tile_windows, a_dram_tile_window_step);
414
415 elementwise_Bs_res = load_tile_with_elementwise(b_tile_windows, b_element_func);
416 move_tile_window(b_tile_windows, b_dram_tile_window_step);
418
// Distributions and register tiles consumed by block_gemm; two tiles per operand.
419 constexpr auto ALdsTileDistr =
420 make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
421 constexpr auto BLdsTileDistr =
422 make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
423
424 using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr));
425 using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
426 ALdsTile a_block_tile0, a_block_tile1;
427 BLdsTile b_block_tile0, b_block_tile1;
428
// Distribution used when reading from LDS: a transposed encoding on the load-transpose
// path, otherwise the plain block distribution (parts of the transposed branch are
// missing from this listing).
429 constexpr auto a_lds_input_tile_distr = [&]() {
430 if constexpr(is_a_load_tr_v())
433 decltype(BlockGemm::MakeABlockDistributionEncode()),
434 typename Problem::ADataType>::TransposedDstrEncode{});
435 else
436 return ALdsTileDistr;
437 }();
438 constexpr auto b_lds_input_tile_distr = [&]() {
439 if constexpr(is_b_load_tr_v())
442 decltype(BlockGemm::MakeBBlockDistributionEncode()),
443 typename Problem::BDataType>::TransposedDstrEncode{});
444 else
445 return BLdsTileDistr;
446 }();
// Read windows over both LDS views, for A and B.
447 auto a_lds_ld_window0 =
448 make_tile_window(a_lds_block0, a_lds_shape, {0, 0}, a_lds_input_tile_distr);
449 auto a_lds_ld_window1 =
450 make_tile_window(a_lds_block1, a_lds_shape, {0, 0}, a_lds_input_tile_distr);
451 auto b_lds_ld_window0 =
452 make_tile_window(b_lds_block0, b_lds_shape, {0, 0}, b_lds_input_tile_distr);
453 auto b_lds_ld_window1 =
454 make_tile_window(b_lds_block1, b_lds_shape, {0, 0}, b_lds_input_tile_distr);
455
456 static_assert(!is_tile_window_linear_v<decltype(a_lds_ld_window0)> &&
457 !is_tile_window_linear_v<decltype(a_lds_ld_window1)> &&
458 !is_tile_window_linear_v<decltype(b_lds_ld_window0)> &&
459 !is_tile_window_linear_v<decltype(b_lds_ld_window1)>,
460 "LDS windows must not be linear");
461
// Read view 0 into register tile 0, then write the second global read into view 1.
462 Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v);
463 Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v);
464
465 if constexpr(is_a_col_major && !is_a_load_tr_v())
466 {
468 Policy::template MakeShuffledARegTileDistribution<Problem>());
469 transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
470 Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp);
471 }
472 else
473 {
474 Base::LocalPrefill(a_copy_lds_window1, elementwise_As_res);
475 }
476 if constexpr(is_b_row_major && !is_b_load_tr_v())
477 {
479 Policy::template MakeShuffledBRegTileDistribution<Problem>());
480 transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
481 Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp);
482 }
483 else
484 {
485 Base::LocalPrefill(b_copy_lds_window1, elementwise_Bs_res);
486 }
487
488 // Load tile — during value loading, an elementwise function is executed for each A0,
489 // A1, … AN. The values A0, A1, … AN are read by the same thread.
490 elementwise_As_res = load_tile_with_elementwise(a_tile_windows, a_element_func);
491 move_tile_window(a_tile_windows, a_dram_tile_window_step);
492
493 elementwise_Bs_res = load_tile_with_elementwise(b_tile_windows, b_element_func);
494 move_tile_window(b_tile_windows, b_dram_tile_window_step);
495
// Hot loop: each trip runs a "ping" and a "pong" half and decrements the counter by 2,
// repeating while the counter stays above 1.
496 if constexpr(HasHotLoop)
497 {
498 // minus 2 because we have ping-pong double buffer.
499 index_t iCounter = amd_wave_read_first_lane(num_loop - 2);
500 do
501 {
502 // ping
// Reads view 1 into tile 1, writes the pending global data into view 0, issues
// the next global read, and multiplies tile 0.
503 {
505 Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v);
506 Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v);
507
508 if constexpr(is_a_col_major && !is_a_load_tr_v())
509 {
511 Policy::template MakeShuffledARegTileDistribution<Problem>());
512 transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
513 Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp);
514 }
515 else
516 {
517 Base::LocalPrefill(a_copy_lds_window0, elementwise_As_res);
518 }
519 if constexpr(is_b_row_major && !is_b_load_tr_v())
520 {
522 Policy::template MakeShuffledBRegTileDistribution<Problem>());
523 transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
524 Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp);
525 }
526 else
527 {
528 Base::LocalPrefill(b_copy_lds_window0, elementwise_Bs_res);
529 }
530
531 elementwise_As_res =
532 load_tile_with_elementwise(a_tile_windows, a_element_func);
533 move_tile_window(a_tile_windows, a_dram_tile_window_step);
534
535 elementwise_Bs_res =
536 load_tile_with_elementwise(b_tile_windows, b_element_func);
537 move_tile_window(b_tile_windows, b_dram_tile_window_step);
538 // gemm
539 block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
541 }
542 // pong
// Mirror of the ping half with the roles of view/tile 0 and 1 swapped.
543 {
545 Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v);
546 Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v);
547
548 if constexpr(is_a_col_major && !is_a_load_tr_v())
549 {
551 Policy::template MakeShuffledARegTileDistribution<Problem>());
552 transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
553 Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp);
554 }
555 else
556 {
557 Base::LocalPrefill(a_copy_lds_window1, elementwise_As_res);
558 }
559 if constexpr(is_b_row_major && !is_b_load_tr_v())
560 {
562 Policy::template MakeShuffledBRegTileDistribution<Problem>());
563 transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
564 Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp);
565 }
566 else
567 {
568 Base::LocalPrefill(b_copy_lds_window1, elementwise_Bs_res);
569 }
571
572 elementwise_As_res =
573 load_tile_with_elementwise(a_tile_windows, a_element_func);
574 move_tile_window(a_tile_windows, a_dram_tile_window_step);
575
576 elementwise_Bs_res =
577 load_tile_with_elementwise(b_tile_windows, b_element_func);
578 move_tile_window(b_tile_windows, b_dram_tile_window_step);
579
580 // gemm
581 block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
583 }
584 iCounter -= 2;
585 } while(iCounter > 1);
586 }
587
588 // tail 3
// NOTE(review): the condition that opens this branch (listing line 589) is missing
// from this listing. The branch performs three block_gemm calls and no further
// global reads.
590 {
591 // 3
592 {
594 Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v);
595 Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v);
596 if constexpr(is_a_col_major && !is_a_load_tr_v())
597 {
599 Policy::template MakeShuffledARegTileDistribution<Problem>());
600 transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
601 Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp);
602 }
603 else
604 {
605 Base::LocalPrefill(a_copy_lds_window0, elementwise_As_res);
606 }
607 if constexpr(is_b_row_major && !is_b_load_tr_v())
608 {
610 Policy::template MakeShuffledBRegTileDistribution<Problem>());
611 transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
612 Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp);
613 }
614 else
615 {
616 Base::LocalPrefill(b_copy_lds_window0, elementwise_Bs_res);
617 }
618 block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
619 }
620 // 2
621 {
623 Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v);
624 Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v);
625 block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
626 }
627 // 1
628 {
629 block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
630 __builtin_amdgcn_sched_barrier(0);
631 }
632 }
// Tail of two: read view 1 into tile 1, multiply tile 0, then multiply tile 1.
633 else if(TailNum == TailNumber::Two)
634 {
635 // 2
636 {
638 Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v);
639 Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v);
640 block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
// Eight repetitions of the hint pair: 1 DS read followed by 8 MFMA.
641 static_for<0, 8, 1>{}([&](auto) {
642 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
643 __builtin_amdgcn_sched_group_barrier(0x008, 8, 0); // MFMA
644 });
645 __builtin_amdgcn_sched_barrier(0);
646 }
647 // 1
648 {
649 block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
650 __builtin_amdgcn_sched_barrier(0);
651 }
652 }
// Tail of one: a single multiply of tile 0.
653 else if(TailNum == TailNumber::One)
654 {
656 block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
657 __builtin_amdgcn_sched_barrier(0);
658 }
659 return c_block_tile;
660 }
661 };
662
663 public:
// Tuple-window overload with caller-supplied element functions. Forwards everything to
// PipelineImpl<Scheduler>::operator() using the class-level HasHotLoop and TailNum.
// NOTE(review): the second line of the enable_if condition (listing line 669) is missing
// from this listing.
664 template <typename AsDramBlockWindowTmp,
665 typename BsDramBlockWindowTmp,
666 typename AElementFunction,
667 typename BElementFunction,
668 typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
670 bool>* = nullptr>
671 CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
672 const AElementFunction& a_element_func,
673 const BsDramBlockWindowTmp& b_dram_block_window_tmp,
674 const BElementFunction& b_element_func,
675 index_t num_loop,
676 void* p_smem_0,
677 void* p_smem_1) const
678 {
679 return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
680 a_dram_block_window_tmp,
681 a_element_func,
682 b_dram_block_window_tmp,
683 b_element_func,
684 num_loop,
685 p_smem_0,
686 p_smem_1);
687 }
688
// Tuple-window overload without element functions: supplies plain-assignment lambdas
// (e = a / e = b) and uses the class-level HasHotLoop and TailNum.
// NOTE(review): the second line of the enable_if condition (listing line 692) is missing
// from this listing.
689 template <typename AsDramBlockWindowTmp,
690 typename BsDramBlockWindowTmp,
691 typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
693 bool>* = nullptr>
694 CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
695 const BsDramBlockWindowTmp& b_dram_block_window_tmp,
696 const index_t num_loop,
697 void* __restrict__ p_smem_0,
698 void* __restrict__ p_smem_1) const
699 {
700 return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
701 a_dram_block_window_tmp,
702 [](auto& e, const ADataType& a) { e = a; },
703 b_dram_block_window_tmp,
704 [](auto& e, const BDataType& b) { e = b; },
705 num_loop,
706 p_smem_0,
707 p_smem_1);
708 }
709
// Tuple-window overload taking has_hot_loop and tail_number at run time. It wraps the
// PipelineImpl call in a lambda that receives compile-time tags and lets
// Base::TailHandler pick the matching instantiation.
// NOTE(review): listing lines 713, 729 and 731 are missing here; the latter two are the
// element-function arguments of the call (presumably PassThrough) - confirm.
710 template <typename AsDramBlockWindowTmp,
711 typename BsDramBlockWindowTmp,
712 typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
714 bool>* = nullptr>
715 CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
716 const BsDramBlockWindowTmp& b_dram_block_window_tmp,
717 index_t num_loop,
718 bool has_hot_loop,
719 TailNumber tail_number,
720 void* __restrict__ p_smem_0,
721 void* __restrict__ p_smem_1) const
722 {
723 const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
// Unwrap the tag objects into constant expressions usable as template arguments.
724 constexpr bool hot_loop = hot_loop_.value;
725 constexpr auto tail_num = tail_num_.value;
// Element function that simply assigns the source value to the destination.
726 constexpr auto PassThrough = [](auto& e, const auto& x) { e = x; };
727 return PipelineImpl<Scheduler>{}.template operator()<hot_loop, tail_num>(
728 a_dram_block_window_tmp,
730 b_dram_block_window_tmp,
732 num_loop,
733 p_smem_0,
734 p_smem_1);
735 };
736 return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
737 }
738
// Single-window convenience overload (selected when the A window is not a tuple): wraps
// each window in a one-element ck_tile tuple and forwards to the tuple overload that
// takes element functions.
// NOTE(review): the second line of the enable_if condition (listing line 744) is missing
// from this listing.
739 template <typename ADramBlockWindowTmp,
740 typename BDramBlockWindowTmp,
741 typename AElementFunction,
742 typename BElementFunction,
743 typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
745 bool>* = nullptr>
746 CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
747 const AElementFunction& a_element_func,
748 const BDramBlockWindowTmp& b_dram_block_window_tmp,
749 const BElementFunction& b_element_func,
750 index_t num_loop,
751 void* p_smem_0,
752 void* p_smem_1) const
753 {
754 return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
755 a_element_func,
756 ck_tile::make_tuple(b_dram_block_window_tmp),
757 b_element_func,
758 num_loop,
759 p_smem_0,
760 p_smem_1);
761 }
762
// Single-window convenience overload without element functions: wraps each window in a
// one-element ck_tile tuple and forwards to the matching tuple overload.
// NOTE(review): the second line of the enable_if condition (listing line 766) is missing
// from this listing.
763 template <typename ADramBlockWindowTmp,
764 typename BDramBlockWindowTmp,
765 typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
767 bool>* = nullptr>
768 CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
769 const BDramBlockWindowTmp& b_dram_block_window_tmp,
770 const index_t num_loop,
771 void* __restrict__ p_smem_0,
772 void* __restrict__ p_smem_1) const
773 {
774 return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
775 ck_tile::make_tuple(b_dram_block_window_tmp),
776 num_loop,
777 p_smem_0,
778 p_smem_1);
779 }
780
// Single-window convenience overload with run-time has_hot_loop / tail_number: wraps each
// window in a one-element ck_tile tuple and forwards to the tuple overload that performs
// the run-time dispatch.
// NOTE(review): the second line of the enable_if condition (listing line 784) is missing
// from this listing.
781 template <typename ADramBlockWindowTmp,
782 typename BDramBlockWindowTmp,
783 typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
785 bool>* = nullptr>
786 CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
787 const BDramBlockWindowTmp& b_dram_block_window_tmp,
788 index_t num_loop,
789 bool has_hot_loop,
790 TailNumber tail_number,
791 void* __restrict__ p_smem_0,
792 void* __restrict__ p_smem_1) const
793 {
794 return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
795 ck_tile::make_tuple(b_dram_block_window_tmp),
796 num_loop,
797 has_hot_loop,
798 tail_number,
799 p_smem_0,
800 p_smem_1);
801 }
802};
803} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
TailNumber
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:21
@ One
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:27
@ Two
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:28
@ Three
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:29
TransposeTileDistributionTraits< TileDistributionEncoding_, DataType_, Policy, true > InputTileDistributionTraits
Definition load_tile_transpose.hpp:343
CK_TILE_DEVICE auto load_tile_with_elementwise(const TileWindow_ &tile_window, ElementWise_ elementwise, number< i_access >={}, bool_constant< oob_conditional_check >={})
Load tile with elementwise function.
Definition load_tile.hpp:41
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition type_traits.hpp:67
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition tile_elementwise.hpp:23
ck_tile::element_wise::PassThrough PassThrough
Definition grouped_convolution_utils.hpp:47
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition concat.hpp:43
CK_TILE_DEVICE void transpose_tile2d(OutTensor &out, const InTensor &in)
Definition transpose_tile.hpp:195
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:429
constexpr detail::ignore_t ignore
Definition tile/core/utility/ignore.hpp:20
constexpr bool is_tile_window_linear_v
Helper variable template to check if a type is a linear tile window.
Definition tile_window_linear.hpp:1119
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
GemmPipelineScheduler
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:14
@ Intrawave
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:16
CK_TILE_HOST_DEVICE constexpr details::return_type< D, Ts... > make_array(Ts &&... ts)
Definition tile/core/container/array.hpp:242
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:16
static CK_TILE_HOST_DEVICE constexpr bool BlockHasHotloop(index_t num_loop)
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:23
static CK_TILE_HOST_DEVICE constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:28
static constexpr index_t PrefetchStages
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:17
static constexpr bool UsePersistentKernel
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:21
static CK_TILE_HOST_DEVICE auto TailHandler(const RunFunction &run_func, bool has_hot_loop, TailNumber tail_number)
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:46
static constexpr index_t GlobalBufferNum
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:19
static constexpr index_t PrefillStages
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:18
static CK_TILE_DEVICE constexpr auto HotLoopScheduler()
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:205
PipelineImplBase Base
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:203
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp &a_dram_block_window_tmp, const AElementFunction &a_element_func, const BsDramBlockWindowTmp &b_dram_block_window_tmp, const BElementFunction &b_element_func, index_t num_loop, void *__restrict__ p_smem_0, void *__restrict__ p_smem_1) const
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:272
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:197
Compute optimized pipeline version 4.
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:105
static constexpr index_t Preshuffle
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:166
static constexpr index_t KPerBlock
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:143
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp &a_dram_block_window_tmp, const BDramBlockWindowTmp &b_dram_block_window_tmp, index_t num_loop, bool has_hot_loop, TailNumber tail_number, void *__restrict__ p_smem_0, void *__restrict__ p_smem_1) const
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:786
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:185
static constexpr index_t GetVectorSizeB()
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:151
number< 0 > I0
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:135
remove_cvref_t< typename Problem::BlockGemmShape > BlockGemmShape
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:112
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp &a_dram_block_window_tmp, const BDramBlockWindowTmp &b_dram_block_window_tmp, const index_t num_loop, void *__restrict__ p_smem_0, void *__restrict__ p_smem_1) const
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:768
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp &a_dram_block_window_tmp, const AElementFunction &a_element_func, const BsDramBlockWindowTmp &b_dram_block_window_tmp, const BElementFunction &b_element_func, index_t num_loop, void *p_smem_0, void *p_smem_1) const
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:671
remove_cvref_t< decltype(Policy::template GetBlockGemm< Problem >())> BlockGemm
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:134
number< 1 > I1
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:136
remove_cvref_t< typename Problem::CLayout > CLayout
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:116
static constexpr index_t GetSmemPackB()
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:158
static constexpr auto is_b_load_tr_v
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:173
static constexpr auto TailNum
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:169
static constexpr index_t GetSmemPackA()
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:157
static constexpr index_t NPerBlock
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:142
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp &a_dram_block_window_tmp, const AElementFunction &a_element_func, const BDramBlockWindowTmp &b_dram_block_window_tmp, const BElementFunction &b_element_func, index_t num_loop, void *p_smem_0, void *p_smem_1) const
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:746
static CK_TILE_HOST_DEVICE constexpr auto IsTransposeC()
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:190
static constexpr bool HasHotLoop
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:168
static constexpr bool kPadK
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:162
remove_cvref_t< std::tuple_element_t< 0, AsDataType > > ADataType
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:124
static constexpr index_t GetVectorSizeA()
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:146
static constexpr auto is_a_load_tr_v
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:172
static constexpr index_t GetVectorSizeC()
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:155
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp &a_dram_block_window_tmp, const BsDramBlockWindowTmp &b_dram_block_window_tmp, const index_t num_loop, void *__restrict__ p_smem_0, void *__restrict__ p_smem_1) const
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:694
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp &a_dram_block_window_tmp, const BsDramBlockWindowTmp &b_dram_block_window_tmp, index_t num_loop, bool has_hot_loop, TailNumber tail_number, void *__restrict__ p_smem_0, void *__restrict__ p_smem_1) const
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:715
remove_cvref_t< std::tuple_element_t< 0, AsLayout > > ALayout
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:121
remove_cvref_t< typename Problem::BElementWise > BElementWise
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:119
static constexpr index_t BlockSize
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:139
static CK_TILE_HOST const std::string GetName()
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:175
static constexpr bool DoubleSmemBuffer
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:164
number< 2 > I2
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:137
remove_cvref_t< typename Problem::CDataType > CDataType
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:111
remove_cvref_t< typename Problem::AElementWise > AElementWise
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:118
static constexpr index_t NumWaveGroups
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:165
static constexpr index_t APackedSize
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:129
remove_cvref_t< typename Problem::BsLayoutTuple > BsLayout
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:115
remove_cvref_t< std::tuple_element_t< 0, BsDataType > > BDataType
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:125
static constexpr index_t BPackedSize
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:131
BaseGemmPipelineAgBgCrCompV4< Problem > Base
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:106
remove_cvref_t< std::tuple_element_t< 0, BsLayout > > BLayout
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:122
remove_cvref_t< typename Problem::AsLayoutTuple > AsLayout
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:114
static constexpr index_t MPerBlock
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:141
GemmPipelineAgBgCrImplBase< Problem, Policy > PipelineImplBase
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:107
static constexpr auto Scheduler
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:170
static constexpr bool kPadN
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:161
static constexpr bool kPadM
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:160
remove_cvref_t< typename Problem::AsDataTypeTuple > AsDataType
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:109
remove_cvref_t< typename Problem::BsDataTypeTuple > BsDataType
Definition gemm_pipeline_ag_bg_cr_comp_v4.hpp:110
Definition gemm_pipeline_ag_bg_cr_base.hpp:13
remove_cvref_t< std::tuple_element_t< number< 0 >{}, BsDataType > > BDataType
Definition gemm_pipeline_ag_bg_cr_base.hpp:22
CK_TILE_DEVICE auto GetABLdsTensorViews(void *p_smem) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:83
static constexpr index_t NPerBlock
Definition gemm_pipeline_ag_bg_cr_base.hpp:26
CK_TILE_DEVICE void LocalPrefetch(DstBlockTile &dst_block_tile, const SrcTileWindow &lds_tile_window, bool_constant< LoadTranspose >={}) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:73
static constexpr index_t MPerBlock
Definition gemm_pipeline_ag_bg_cr_base.hpp:25
CK_TILE_DEVICE void LocalPrefill(DstTileWindow &lds_tile_window, const SrcBlockTile &src_block_tile, const ElementFunction &element_func) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:57
remove_cvref_t< std::tuple_element_t< number< 0 >{}, AsDataType > > ADataType
Definition gemm_pipeline_ag_bg_cr_base.hpp:20
static constexpr index_t KPerBlock
Definition gemm_pipeline_ag_bg_cr_base.hpp:27
Definition tile/core/numeric/integral_constant.hpp:30
Definition tile/core/numeric/numeric.hpp:81
Definition tile/core/utility/functional.hpp:43