blockwise_gemm_pipeline_wmmaops_v1.hpp Source File

blockwise_gemm_pipeline_wmmaops_v1.hpp Source File

Composable Kernel: blockwise_gemm_pipeline_wmmaops_v1.hpp Source File
blockwise_gemm_pipeline_wmmaops_v1.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
10// Naive pipeline with the lowest resource usage per WGP (workgroup processor)
11// GlobalPrefetchStages: 1
12// LocalPreFillStages: 1
13// LocalPreFetchStages: 0
14// LocalSharedMemoryBuffer: 1
15
16template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17 index_t BlockSize,
18 typename ADataType,
19 typename BDataType,
20 typename ComputeTypeA,
21 typename ComputeTypeB,
22 typename AccDataType,
23 typename AWmmaTileDesc,
24 typename BWmmaTileDesc,
25 index_t ABlockTransferSrcScalarPerVector,
26 index_t BBlockTransferSrcScalarPerVector,
27 index_t MPerBlock,
28 index_t NPerBlock,
29 index_t KPerBlock,
30 index_t MPerWmma,
31 index_t NPerWmma,
32 index_t MRepeat,
33 index_t NRepeat,
34 index_t KPack,
35 bool TransposeC = false>
39
40template <index_t BlockSize,
41 typename ADataType,
42 typename BDataType,
43 typename ComputeTypeA,
44 typename ComputeTypeB,
45 typename AccDataType,
46 typename AWmmaTileDesc,
47 typename BWmmaTileDesc,
48 index_t ABlockTransferSrcScalarPerVector,
49 index_t BBlockTransferSrcScalarPerVector,
50 index_t MPerBlock,
51 index_t NPerBlock,
52 index_t KPerBlock,
53 index_t MPerWmma,
54 index_t NPerWmma,
55 index_t MRepeat,
56 index_t NRepeat,
57 index_t KPack,
58 bool TransposeC>
60 BlockSize,
61 ADataType,
62 BDataType,
63 ComputeTypeA,
64 ComputeTypeB,
65 AccDataType,
66 AWmmaTileDesc,
67 BWmmaTileDesc,
68 ABlockTransferSrcScalarPerVector,
69 BBlockTransferSrcScalarPerVector,
70 MPerBlock,
71 NPerBlock,
72 KPerBlock,
73 MPerWmma,
74 NPerWmma,
75 MRepeat,
76 NRepeat,
77 KPack,
78 TransposeC>
80 ADataType,
81 BDataType,
82 ComputeTypeA,
83 ComputeTypeB,
84 AccDataType,
85 AWmmaTileDesc,
86 BWmmaTileDesc,
87 ABlockTransferSrcScalarPerVector,
88 BBlockTransferSrcScalarPerVector,
89 MPerBlock,
90 NPerBlock,
91 KPerBlock,
92 MPerWmma,
93 NPerWmma,
94 MRepeat,
95 NRepeat,
96 KPack,
97 TransposeC>
98{
100 ADataType,
101 BDataType,
102 ComputeTypeA,
103 ComputeTypeB,
104 AccDataType,
105 AWmmaTileDesc,
106 BWmmaTileDesc,
107 ABlockTransferSrcScalarPerVector,
108 BBlockTransferSrcScalarPerVector,
109 MPerBlock,
110 NPerBlock,
111 KPerBlock,
112 MPerWmma,
113 NPerWmma,
114 MRepeat,
115 NRepeat,
116 KPack,
117 TransposeC>;
118 using Base::I0;
119 using Base::I1;
120 using Base::WaveSize;
121 using typename Base::HotLoopInstList;
122
123 using Base::A_K1;
124 using Base::A_KRow;
125 using Base::B_K1;
126 using Base::B_KRow;
127 using Base::KRepeat;
128 using Base::WmmaK;
129
130 using Base::wmma_gemm;
131
133 using Base::
134 GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
136 using Base::
137 GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
138
141
142 using typename Base::Empty;
143
// Pipeline depth constants. Run() issues exactly one global read (RunRead) and one
// LDS write (RunWrite) before entering the main loop ("Global prefetch 1" / "Local prefill 1").
144 static constexpr index_t PrefetchStages = 1;
145 static constexpr index_t PrefillStages = 1;
// NOTE(review): presumably the number of global-read staging buffers per block transfer;
// Run() keeps only one RunRead outstanding at a time — confirm against the gridwise caller.
146 static constexpr index_t GlobalBufferNum = 1;
147
// Returns true when the main (hot) loop must run. Run()'s hot loop is a do/while that
// executes num_loop - 1 iterations, so it is only valid when num_loop > PrefetchStages;
// callers presumably use this result to select the HasMainLoop template argument.
148 static bool BlockHasHotloop(index_t num_loop) { return num_loop > PrefetchStages; }
149
151 {
152 ignore = num_loop;
153 return TailNumber::Full;
154 }
155
156 template <bool HasMainLoop,
157 TailNumber TailNum,
158 typename AGridDesc,
159 typename ABlockDesc,
160 typename ABlockTransfer,
161 typename AGridBuffer,
162 typename ABlockBuffer,
163 typename ABlockTransferStep,
164 typename BGridDesc,
165 typename BBlockDesc,
166 typename BBlockTransfer,
167 typename BGridBuffer,
168 typename BBlockBuffer,
169 typename BBlockTransferStep,
170 typename CThreadBuffer,
171 typename BScaleStruct>
172 __device__ void Run(const AGridDesc& a_grid_desc,
173 const ABlockDesc& a_block_desc,
174 ABlockTransfer& a_blockwise_copy,
175 const AGridBuffer& a_grid_buf,
176 ABlockBuffer& a_block_buf,
177 const ABlockTransferStep& a_block_copy_step,
178 const BGridDesc& b_grid_desc,
179 const BBlockDesc& b_block_desc,
180 BBlockTransfer& b_blockwise_copy,
181 const BGridBuffer& b_grid_buf,
182 BBlockBuffer& b_block_buf,
183 const BBlockTransferStep& b_block_copy_step,
184 CThreadBuffer& c_thread_buf,
185 // BScaleThreadCopy
186 BScaleStruct& b_scale_struct,
187 index_t num_loop,
188 index_t num_loop_per_scale) const
189 {
191 a_thread_desc_.GetElementSpaceSize());
193 b_thread_desc_.GetElementSpaceSize());
194
195 // Global prefetch 1
196 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
197 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
198
199 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
200 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
201
202 b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
203
204 // Local prefill 1
205 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
206 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
207
208 // Initialize C
209 c_thread_buf.Clear();
210
211 auto blockwise_gemm_func = [&]() {
212 static_for<0, KRepeat, 1>{}([&](auto k0) {
213 static_for<0, MRepeat, 1>{}([&](auto m0) {
214 a_thread_copy_.Run(
217 a_block_buf,
219 make_tuple(I0, I0, I0, I0, I0, I0),
220 a_thread_buf);
221
222 if constexpr(m0 == I0)
223 {
224 if constexpr(ck::is_same<BScaleStruct, Empty>::value == true)
225 {
226 static_for<0, NRepeat, 1>{}([&](auto n0) {
227 b_thread_copy_.Run(
231 b_block_buf,
233 make_tuple(I0, n0, I0, I0, I0, I0),
234 b_thread_buf);
235 });
236 }
237 else
238 {
239 static_for<0, NRepeat, 1>{}([&](auto n0) {
240 b_thread_copy_.Run(
244 b_block_buf,
245 b_scale_struct.b_scale_thread_bufs(
246 I0)[Number<n0 * BScaleStruct::num_scale_k_block +
247 k0 / BScaleStruct::num_scale_krepeat>{}],
249 make_tuple(I0, n0, I0, I0, I0, I0),
250 b_thread_buf);
251 });
252 }
253 }
254
255 static_for<0, NRepeat, 1>{}([&](auto n0) {
256 vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
257 vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
258
259 static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
260 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
261 a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
263 });
264 static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
265 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
266 b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
267 Number<ik / B_K1>{}, n0, I0, I0, I0, Number<ik % B_K1>{}))>{}];
268 });
269
270 using wmma_input_type_a =
271 typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
272 using wmma_input_type_b =
273 typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
274
275 constexpr index_t c_offset =
276 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
277
278 wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
279 b_thread_vec.template AsType<wmma_input_type_b>(),
280 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
281 });
282 });
283 });
284 };
285
286 // main body
287 if constexpr(HasMainLoop)
288 {
289 index_t i = 0;
290 do
291 {
292 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
293 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
294
295 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
296 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
297
299 blockwise_gemm_func();
300
302 b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
303 if constexpr(ck::is_same<BScaleStruct, Empty>::value == false)
304 {
306 }
307 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
308 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
309
310 constexpr index_t num_ds_write_inst =
312
313 constexpr index_t num_buffer_load_inst = HotLoopInstList::A_Buffer_Load_Inst_Num +
316 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
317 });
318 static_for<0, KRepeat, 1>{}([&](auto) {
319 static_for<0, MRepeat, 1>{}([&](auto m0) {
320 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
321 if constexpr(m0 == I0)
322 {
323 static_for<0, NRepeat, 1>{}([&](auto) {
324 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
325 });
326 }
327 static_for<0, NRepeat, 1>{}([&](auto) {
328 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
329 });
330 });
331 });
333 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
334 });
335
336 i += 1;
337 } while(i < (num_loop - 1));
338 }
339
340 // tail
341 if constexpr(TailNum == TailNumber::Full)
342 {
344 blockwise_gemm_func();
345 }
346 }
347
348 protected:
349 // A[MRepeat, I1, I1, KPack]
352
353 // B[NRepeat, N1, N2, KPack]
356
359 ComputeTypeA,
361 decltype(a_thread_desc_),
362 Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
364 5,
365 A_K1,
366 A_K1>;
367
370 ComputeTypeB,
372 decltype(b_thread_desc_),
373 Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
375 5,
376 B_K1,
377 B_K1>;
378
381 using Base::c_thread_desc_;
382};
383
384template <index_t BlockSize,
385 typename ADataType,
386 typename BDataType,
387 typename ComputeTypeA,
388 typename ComputeTypeB,
389 typename AccDataType,
390 typename AWmmaTileDesc,
391 typename BWmmaTileDesc,
392 index_t ABlockTransferSrcScalarPerVector,
393 index_t BBlockTransferSrcScalarPerVector,
394 index_t MPerBlock,
395 index_t NPerBlock,
396 index_t KPerBlock,
397 index_t MPerWmma,
398 index_t NPerWmma,
399 index_t MRepeat,
400 index_t NRepeat,
401 index_t KPack,
402 bool TransposeC>
404 BlockSize,
405 ADataType,
406 BDataType,
407 ComputeTypeA,
408 ComputeTypeB,
409 AccDataType,
410 AWmmaTileDesc,
411 BWmmaTileDesc,
412 ABlockTransferSrcScalarPerVector,
413 BBlockTransferSrcScalarPerVector,
414 MPerBlock,
415 NPerBlock,
416 KPerBlock,
417 MPerWmma,
418 NPerWmma,
419 MRepeat,
420 NRepeat,
421 KPack,
422 TransposeC>
424 ADataType,
425 BDataType,
426 ComputeTypeA,
427 ComputeTypeB,
428 AccDataType,
429 AWmmaTileDesc,
430 BWmmaTileDesc,
431 ABlockTransferSrcScalarPerVector,
432 BBlockTransferSrcScalarPerVector,
433 MPerBlock,
434 NPerBlock,
435 KPerBlock,
436 MPerWmma,
437 NPerWmma,
438 MRepeat,
439 NRepeat,
440 KPack,
441 TransposeC>
442{
444 ADataType,
445 BDataType,
446 ComputeTypeA,
447 ComputeTypeB,
448 AccDataType,
449 AWmmaTileDesc,
450 BWmmaTileDesc,
451 ABlockTransferSrcScalarPerVector,
452 BBlockTransferSrcScalarPerVector,
453 MPerBlock,
454 NPerBlock,
455 KPerBlock,
456 MPerWmma,
457 NPerWmma,
458 MRepeat,
459 NRepeat,
460 KPack,
461 TransposeC>;
462 using Base::I0;
463 using Base::I1;
464
465 using Base::A_K1;
466 using Base::A_KRow;
467 using Base::B_K1;
468 using Base::B_KRow;
469 using Base::KRepeat;
470 using Base::WmmaK;
471
472 using Base::wmma_gemm;
473
475 using Base::
476 GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
478 using Base::
479 GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
480
483
484 using typename Base::Empty;
485
488
// Pipeline depth constants (same as the Intrawave variant). Run() issues exactly one
// global read (RunRead) and one LDS write (RunWrite) before entering the main loop.
489 static constexpr index_t PrefetchStages = 1;
490 static constexpr index_t PrefillStages = 1;
// NOTE(review): presumably the number of global-read staging buffers per block transfer;
// Run() keeps only one RunRead outstanding at a time — confirm against the gridwise caller.
491 static constexpr index_t GlobalBufferNum = 1;
492
// Returns true when the main (hot) loop must run. Run()'s hot loop is a do/while that
// executes num_loop - 1 iterations, so it is only valid when num_loop > PrefetchStages;
// callers presumably use this result to select the HasMainLoop template argument.
493 static bool BlockHasHotloop(index_t num_loop) { return num_loop > PrefetchStages; }
494
496 {
497 ignore = num_loop;
498 return TailNumber::Full;
499 }
500
501 template <bool HasMainLoop,
502 TailNumber TailNum,
503 typename AGridDesc,
504 typename ABlockDesc,
505 typename ABlockTransfer,
506 typename AGridBuffer,
507 typename ABlockBuffer,
508 typename ABlockTransferStep,
509 typename BGridDesc,
510 typename BBlockDesc,
511 typename BBlockTransfer,
512 typename BGridBuffer,
513 typename BBlockBuffer,
514 typename BBlockTransferStep,
515 typename CThreadBuffer,
516 typename BScaleStruct>
517 __device__ void Run(const AGridDesc& a_grid_desc,
518 const ABlockDesc& a_block_desc,
519 ABlockTransfer& a_blockwise_copy,
520 const AGridBuffer& a_grid_buf,
521 ABlockBuffer& a_block_buf,
522 const ABlockTransferStep& a_block_copy_step,
523 const BGridDesc& b_grid_desc,
524 const BBlockDesc& b_block_desc,
525 BBlockTransfer& b_blockwise_copy,
526 const BGridBuffer& b_grid_buf,
527 BBlockBuffer& b_block_buf,
528 const BBlockTransferStep& b_block_copy_step,
529 CThreadBuffer& c_thread_buf,
530 // BScaleThreadCopy
531 BScaleStruct& b_scale_struct,
532 index_t num_loop,
533 index_t num_loop_per_scale) const
534 {
536 a_thread_desc_.GetElementSpaceSize());
538 b_thread_desc_.GetElementSpaceSize());
539
540 // Global prefetch 1
541 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
542 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
543
544 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
545 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
546
547 b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
548
549 // Local prefill 1
550 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
551 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
552
553 // Initialize C
554 c_thread_buf.Clear();
555
556 auto blockwise_gemm_func = [&]() {
557 static_for<0, KRepeat, KRepeatPerCluster>{}([&](auto k0_offset) {
558 static_for<0, KRepeatPerCluster, 1>{}([&](auto k0_inner) {
559 static_for<0, MRepeat, 1>{}([&](auto m0) {
560 a_thread_copy_.Run(
562 make_tuple(Number<(k0_offset + k0_inner) * KPack / A_K1 / A_KRow>{},
563 m0,
564 I0,
565 I0,
566 I0,
567 I0),
568 a_block_buf,
570 make_tuple(I0, m0, k0_inner, I0, I0, I0),
571 a_thread_buf);
572 });
573 if constexpr(ck::is_same<BScaleStruct, Empty>::value == true)
574 {
575 static_for<0, NRepeat, 1>{}([&](auto n0) {
576 b_thread_copy_.Run(
578 make_tuple(Number<(k0_offset + k0_inner) * KPack / B_K1 / B_KRow>{},
579 n0,
580 I0,
581 I0,
582 I0,
583 I0),
584 b_block_buf,
586 make_tuple(I0, n0, k0_inner, I0, I0, I0),
587 b_thread_buf);
588 });
589 }
590 else
591 {
592 static_for<0, NRepeat, 1>{}([&](auto n0) {
593 b_thread_copy_.Run(
595 make_tuple(Number<(k0_offset + k0_inner) * KPack / B_K1 / B_KRow>{},
596 n0,
597 I0,
598 I0,
599 I0,
600 I0),
601 b_block_buf,
602 b_scale_struct.b_scale_thread_bufs(I0)[Number<
603 n0 * BScaleStruct::num_scale_k_block +
604 (k0_offset + k0_inner) / BScaleStruct::num_scale_krepeat>{}],
606 make_tuple(I0, n0, k0_inner, I0, I0, I0),
607 b_thread_buf);
608 });
609 }
610 });
611
612 __builtin_amdgcn_sched_barrier(0);
613 // NOTE: Synchronize threads in a workgroup at the start of each MAC cluster,
614 // but except the first, as we can shorten non-MAC cluster a bit and there's no
615 // observable negative impact. The desired effect is waves in a workgroup
616 // executing MAC in sync. This avoids some out-of-sync waves hijacking MAC
617 // resource from other workgroups and reducing the chance of latency hiding by
618 // waiting for the rest of the workgroup at the eventual sync point.
619 if constexpr(k0_offset != 0 || KRepeat == 1)
620 {
621 __builtin_amdgcn_s_barrier();
622 __builtin_amdgcn_sched_barrier(0);
623 }
624 static_for<0, KRepeatPerCluster, 1>{}([&](auto k0_inner) {
625 static_for<0, MRepeat, 1>{}([&](auto m0) {
626 static_for<0, NRepeat, 1>{}([&](auto n0) {
627 vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
628 vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
629
630 static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
631 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
632 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
634 m0,
635 k0_inner,
636 I0,
637 I0,
638 Number<ik % A_K1>{}))>{}];
639 });
640 static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
641 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
642 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
644 n0,
645 k0_inner,
646 I0,
647 I0,
648 Number<ik % B_K1>{}))>{}];
649 });
650
651 using wmma_input_type_a =
652 typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
653 using wmma_input_type_b =
654 typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
655
656 constexpr index_t c_offset =
657 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
658
659 // The block_sync_lds() here performs double duty:
660 // A) safeguard against data hazard.
661 // B) reduce VMEM FIFO congestion by applying small delays to
662 // different wavefronts.
663 // It is performed near the end of MAC cluster to minimize lgkmcnt
664 // penalty
665 if constexpr(k0_offset + k0_inner == KRepeat - 1 && m0 == MRepeat - 1 &&
666 n0 == NRepeat - 1)
667 {
668 __builtin_amdgcn_sched_barrier(0);
670 __builtin_amdgcn_sched_barrier(0);
671 }
672 wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
673 b_thread_vec.template AsType<wmma_input_type_b>(),
674 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
675 if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0)
676 {
677 __builtin_amdgcn_sched_barrier(0);
678 __builtin_amdgcn_s_setprio(1);
679 __builtin_amdgcn_sched_barrier(0);
680 }
681 });
682 });
683 });
684 __builtin_amdgcn_sched_barrier(0);
685 __builtin_amdgcn_s_setprio(0);
686 __builtin_amdgcn_sched_barrier(0);
687 });
688 };
689
690 // main body
691 if constexpr(HasMainLoop)
692 {
693 index_t i = 0;
694 do
695 {
696 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
697 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
698
699 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
700 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
701
703 blockwise_gemm_func();
704
705 b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
706 if constexpr(ck::is_same<BScaleStruct, Empty>::value == false)
707 {
709 }
710 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
711 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
712
713 i += 1;
714 } while(i < (num_loop - 1));
715 }
716
717 // tail
718 if constexpr(TailNum == TailNumber::Full)
719 {
721 blockwise_gemm_func();
722 }
723 }
724
725 protected:
726 static constexpr auto a_thread_desc_ =
730 I1,
731 I1,
732 Number<A_K1>{}),
734 Number<KPack / A_KRow>{},
735 Number<KPack / A_KRow * MRepeat>{},
736 I0,
737 I0,
738 I1));
739
740 static constexpr auto b_thread_desc_ =
744 I1,
745 I1,
746 Number<B_K1>{}),
748 Number<KPack / B_KRow>{},
749 Number<KPack / B_KRow * NRepeat>{},
750 I0,
751 I0,
752 I1));
753
756 ComputeTypeA,
758 decltype(a_thread_desc_),
759 Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
761 5,
762 A_K1,
763 A_K1>;
764
767 ComputeTypeB,
769 decltype(b_thread_desc_),
770 Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
772 5,
773 B_K1,
774 B_K1>;
775
778 using Base::c_thread_desc_;
779};
780
781} // namespace ck
#define CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS
Definition ck.hpp:209
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition ck.hpp:268
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Full
Definition blkgemmpipe_scheduler.hpp:49
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ void block_sync_lds()
Definition synchronization.hpp:16
ck::BlockwiseGemmWmmaops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerWmma, NPerWmma, wmma_gemm.wmma_instr.k_per_wmma > HotLoopInstList
Definition blockwise_gemm_pipeline_wmmaops_base.hpp:70
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, BScaleStruct &b_scale_struct, index_t num_loop, index_t num_loop_per_scale) const
Definition blockwise_gemm_pipeline_wmmaops_v1.hpp:517
ThreadwiseTensorSliceTransfer_v4< ADataType, ComputeTypeA, decltype(a_block_desc_k0_m0_m1_m2_k1), decltype(a_thread_desc_), Sequence< KPack/A_K1/A_KRow, 1, 1, 1, 1, A_K1 >, Sequence< 0, 1, 2, 3, 4, 5 >, 5, A_K1, A_K1 > AThreadCopy
Definition blockwise_gemm_pipeline_wmmaops_v1.hpp:754
BlockwiseGemmWmmaops_pipeline_base< BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, AWmmaTileDesc, BWmmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack, TransposeC > Base
Definition blockwise_gemm_pipeline_wmmaops_v1.hpp:443
ThreadwiseTensorSliceTransfer_v4< BDataType, ComputeTypeB, decltype(b_block_desc_k0_n0_n1_n2_k1), decltype(b_thread_desc_), Sequence< KPack/B_K1/B_KRow, 1, 1, 1, 1, B_K1 >, Sequence< 0, 1, 2, 3, 4, 5 >, 5, B_K1, B_K1 > BThreadCopy
Definition blockwise_gemm_pipeline_wmmaops_v1.hpp:765
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, BScaleStruct &b_scale_struct, index_t num_loop, index_t num_loop_per_scale) const
Definition blockwise_gemm_pipeline_wmmaops_v1.hpp:172
ThreadwiseTensorSliceTransfer_v4< BDataType, ComputeTypeB, decltype(b_block_desc_k0_n0_n1_n2_k1), decltype(b_thread_desc_), Sequence< KPack/B_K1/B_KRow, 1, 1, 1, 1, B_K1 >, Sequence< 0, 1, 2, 3, 4, 5 >, 5, B_K1, B_K1 > BThreadCopy
Definition blockwise_gemm_pipeline_wmmaops_v1.hpp:368
ThreadwiseTensorSliceTransfer_v4< ADataType, ComputeTypeA, decltype(a_block_desc_k0_m0_m1_m2_k1), decltype(a_thread_desc_), Sequence< KPack/A_K1/A_KRow, 1, 1, 1, 1, A_K1 >, Sequence< 0, 1, 2, 3, 4, 5 >, 5, A_K1, A_K1 > AThreadCopy
Definition blockwise_gemm_pipeline_wmmaops_v1.hpp:357
BlockwiseGemmWmmaops_pipeline_base< BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, AWmmaTileDesc, BWmmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack, TransposeC > Base
Definition blockwise_gemm_pipeline_wmmaops_v1.hpp:99
Definition blockwise_gemm_pipeline_wmmaops_v1.hpp:37
Definition utility/sequence.hpp:43
Definition threadwise_tensor_slice_transfer.hpp:1260
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition functional2.hpp:33
Definition dtype_vector.hpp:10