universal_gemm_kernel.hpp Source File

universal_gemm_kernel.hpp Source File

Composable Kernel: universal_gemm_kernel.hpp Source File
universal_gemm_kernel.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <string>
8
9#include "ck_tile/core.hpp"
16
17namespace ck_tile {
18
// The Universal GEMM kernel host arguments: pointers, problem sizes and strides for
// NumATensor A operands, NumBTensor B operands, NumDTensor D operands and one output.
// NOTE(review): the struct declaration line (listing line 31) and the declarations of
// M, N, K, stride_E (and its union alias) and k_batch (listing lines 68-70, 76-77, 80)
// are not shown in this listing; they are referenced by the initializer list below.
30template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
32{
// Stores every argument in the member of the same name (without the trailing '_').
// k_batch_ is the fifth parameter, ahead of M_/N_/K_, but is initialized last.
33 CK_TILE_HOST UniversalGemmHostArgs(const std::array<const void*, NumATensor>& as_ptr_,
34 const std::array<const void*, NumBTensor>& bs_ptr_,
35 const std::array<const void*, NumDTensor>& ds_ptr_,
36 void* e_ptr_,
37 index_t k_batch_,
38 index_t M_,
39 index_t N_,
40 index_t K_,
41 const std::array<index_t, NumATensor>& stride_As_,
42 const std::array<index_t, NumBTensor>& stride_Bs_,
43 const std::array<index_t, NumDTensor>& stride_Ds_,
44 index_t stride_E_)
45 : as_ptr(as_ptr_),
46 bs_ptr(bs_ptr_),
47 ds_ptr(ds_ptr_),
48 e_ptr(e_ptr_),
49 M(M_),
50 N(N_),
51 K(K_),
52 stride_As(stride_As_),
53 stride_Bs(stride_Bs_),
54 stride_Ds(stride_Ds_),
55 stride_E(stride_E_),
56 k_batch(k_batch_)
57 {
58 }
59
// Type-erased input pointers, one entry per A / B / D tensor.
60 const std::array<const void*, NumATensor> as_ptr;
61 const std::array<const void*, NumBTensor> bs_ptr;
62 const std::array<const void*, NumDTensor> ds_ptr;
// e_ptr and c_ptr share storage: two names for the same output pointer.
63 union
64 {
65 void* e_ptr;
66 void* c_ptr;
67 };
// One stride per A / B / D tensor.
71 const std::array<index_t, NumATensor> stride_As;
72 const std::array<index_t, NumBTensor> stride_Bs;
73 const std::array<index_t, NumDTensor> stride_Ds;
// Union members (listing lines 76-77) are not shown; presumably stride_E and an alias
// for it, mirroring the e_ptr / c_ptr union above - confirm against the real header.
74 union
75 {
78 };
79
81};
82
// Argument bundle handed to the kernel. Elsewhere in this file it is brace-initialized
// in the order as_ptr, bs_ptr, ds_ptr, e_ptr, M, N, K, stride_As, stride_Bs, stride_Ds,
// stride_E, k_batch, so member order matters.
// NOTE(review): the struct declaration line and the M, N, K, stride_E and k_batch
// members fall on listing lines that are not shown here.
84template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
86{
// Type-erased input pointers, one entry per A / B / D tensor.
88 const std::array<const void*, NumATensor> as_ptr;
90 const std::array<const void*, NumBTensor> bs_ptr;
92 const std::array<const void*, NumDTensor> ds_ptr;
// Output pointer; cast to EDataType* by the kernel entry points.
94 void* e_ptr;
// One stride per A / B / D tensor.
103 std::array<index_t, NumATensor> stride_As;
106 std::array<index_t, NumBTensor> stride_Bs;
109 std::array<index_t, NumDTensor> stride_Ds;
114};
115
152template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
154{
158
159 static constexpr bool ADataTypeIsTuple =
161 static constexpr bool BDataTypeIsTuple =
163 static constexpr bool DDataTypeIsTuple =
165 static constexpr bool ALayoutIsTuple =
167 static constexpr bool BLayoutIsTuple =
169 static constexpr bool DLayoutIsTuple =
171
172 using AsLayout = std::conditional_t<ALayoutIsTuple,
175 using BsLayout = std::conditional_t<BLayoutIsTuple,
178
179 using DsLayout = std::conditional_t<DLayoutIsTuple,
182
183 using AsDataType = std::conditional_t<ADataTypeIsTuple,
186
187 using BsDataType = std::conditional_t<BDataTypeIsTuple,
190
192 std::conditional_t<DDataTypeIsTuple,
195
198
201
202 static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
203
204 // Get the persistent kernel if the pipeline has it available
// NOTE(review): the helper struct's declaration (listing line 205) is not shown.
206 {
// Detection alias: well-formed only when T exposes a static UsePersistentKernel member.
207 template <typename T>
208 using has_persistent_type = decltype(T::UsePersistentKernel);
209
// Immediately-invoked lambda: yields GemmPipeline::UsePersistentKernel in the first
// branch and false otherwise. The branch condition (listing line 211) is not shown;
// presumably an `if constexpr` detection test on has_persistent_type - confirm.
210 static constexpr bool value = []() {
212 return GemmPipeline::UsePersistentKernel;
213 else
214 return false;
215 }();
216 };
218
219 // Check if TilePartitioner has GetOutputOffset method with kargs and k_id
// NOTE(review): the helper struct's declaration (listing line 220) and the name of the
// detection alias (listing line 223) are not shown.
221 {
// Detection alias: well-formed only when T::GetOutputOffset(KernelArgs, index_t) is callable.
222 template <typename T, typename KernelArgs>
224 decltype(T::GetOutputOffset(std::declval<KernelArgs>(), std::declval<index_t>()));
225
// Immediately-invoked lambda: true in the first branch, false otherwise. The branch
// condition (listing line 227) is not shown; presumably an `if constexpr` detection
// test using the alias above - confirm.
226 static constexpr bool value = []() {
228 return true;
229 else
230 return false;
231 }();
232 };
235
236 static constexpr auto I0 = number<0>();
237 static constexpr auto I1 = number<1>();
238 static constexpr auto I2 = number<2>();
239 static constexpr auto I3 = number<3>{};
240
241 static constexpr index_t NumATensor = AsDataType::size();
242 static constexpr index_t NumBTensor = BsDataType::size();
243 static constexpr index_t NumDTensor = DsDataType::size();
244
247
248 static_assert(AsLayout::size() == AsDataType::size(),
249 "The size of AsLayout and AsDataType should be the same");
250
251 static_assert(BsLayout::size() == BsDataType::size(),
252 "The size of BsLayout and BsDataType should be the same");
253
254 static_assert(DsLayout::size() == DsDataType::size(),
255 "The size of DsLayout and DsDataType should be the same");
256
258 UniversalGemmKernelArgs<AsLayout::size(), BsLayout::size(), DsLayout::size()>;
259
260 [[nodiscard]] CK_TILE_HOST static const std::string GetName()
261 {
262 // clang-format off
263 return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>(), GemmPipeline::GetName());
264 // clang-format on
265 }
266
267 CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
268 {
269 return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
270 }
271
// Returns a 1-D grid of (available compute units for stream config `s`) x (occupancy
// reported by hipOccupancyMaxActiveBlocksPerMultiprocessor for this kernel entry at
// BlockSize().x threads, with 0 passed as the last argument).
278 CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
279 {
// NOTE(review): listing line 280 is not shown; presumably it defines `Kernel` - confirm.
281 const auto kernel = kentry<1, Kernel, KernelArgs>;
// Written by the occupancy query below before it is read.
282 int occupancy;
// NOTE(review): the opening of this statement (listing line 283) is not shown; the
// trailing "))" indicates the HIP call is wrapped, presumably in hip_check_error(...).
284 hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize().x, 0));
285
286 const int grid_size = get_available_compute_units(s) * occupancy;
287 return dim3(grid_size, 1, 1);
288 }
289
// Returns the workgroup size as a dim3: half of kBlockSize in the first branch,
// kBlockSize otherwise.
// NOTE(review): the signature (listing line 290) and the branch condition (listing
// line 292) are not shown; presumably the condition is the wave32 check - confirm.
291 {
293 {
294 return dim3(kBlockSize / 2);
295 }
296 else
297 {
298 return dim3(kBlockSize);
299 }
300 }
301
// Converts host arguments into KernelArgs by copying each field one-to-one.
// NOTE(review): the function name and parameter list (listing line 303) are not shown;
// the body reads a parameter named `hostArgs`.
302 CK_TILE_HOST static constexpr KernelArgs
304 {
// Brace-initialization: the order below must match KernelArgs' member declaration order.
305 return KernelArgs{hostArgs.as_ptr,
306 hostArgs.bs_ptr,
307 hostArgs.ds_ptr,
308 hostArgs.e_ptr,
309 hostArgs.M,
310 hostArgs.N,
311 hostArgs.K,
312 hostArgs.stride_As,
313 hostArgs.stride_Bs,
314 hostArgs.stride_Ds,
315 hostArgs.stride_E,
316 hostArgs.k_batch};
317 }
318
// Returns the larger of the GEMM pipeline's and the epilogue pipeline's shared-memory
// sizes. In this file the same buffer (smem_ptr_0) is passed to both pipelines, so a
// single allocation of this size serves either.
// NOTE(review): the signature (listing line 319) is not shown; the function is called
// as GetSmemSize() elsewhere in this file.
320 {
321 return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
322 }
323
// Per-split-K offsets: for split index k_id, computes the element offset to add to each
// A and B base pointer and the K extent (splitted_k) handled by that split.
// NOTE(review): the struct declaration (listing line 324) and the declaration of the
// splitted_k member (listing line 370) are not shown in this listing.
325 {
// k_id defaults to blockIdx.z; the persistent entry point passes an explicit value.
326 __device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
327 {
// K1 is the warp-tile extent at index 2. KRead = ceil(K / (k_batch * K1)) * K1, i.e.
// the K extent given to each split, rounded up to a multiple of K1.
328 constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
329 const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1);
330 const index_t KRead = amd_wave_read_first_lane((kargs.K + K_t - 1) / K_t * K1);
331
// A offsets: RowMajor uses k_id * KRead directly; ColumnMajor additionally scales by
// the tensor's stride. Any other layout leaves the entry unassigned (no else branch).
332 static_for<0, NumATensor, 1>{}([&](auto index) {
333 using AiLayout = remove_cvref_t<std::tuple_element_t<index.value, AsLayout>>;
334 if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, AiLayout>)
335 {
336 as_k_split_offset[index] = amd_wave_read_first_lane(k_id * KRead);
337 }
338 else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, AiLayout>)
339 {
340 as_k_split_offset[index] =
341 amd_wave_read_first_lane(k_id * KRead * kargs.stride_As[index]);
342 }
343 });
344
// B offsets: mirror image of A - RowMajor scales by the stride, ColumnMajor does not.
345 static_for<0, NumBTensor, 1>{}([&](auto index) {
346 using BiLayout = remove_cvref_t<std::tuple_element_t<index.value, BsLayout>>;
347 if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BiLayout>)
348 {
349 bs_k_split_offset[index] =
350 amd_wave_read_first_lane(k_id * KRead * kargs.stride_Bs[index]);
351 }
352 else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BiLayout>)
353 {
354 bs_k_split_offset[index] = amd_wave_read_first_lane(k_id * KRead);
355 }
356 });
357
// Every split except the last takes the first branch; its body (listing line 360) is
// not shown - presumably it assigns KRead to splitted_k, confirm.
358 if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
359 {
361 }
362 else
363 {
// The last split takes whatever remains of K.
// NOTE(review): this is non-positive when K <= KRead * (k_batch - 1) (e.g. K=70,
// k_batch=8, K1=16); presumably such arguments are rejected before launch - verify.
364 splitted_k = amd_wave_read_first_lane(kargs.K - KRead * (kargs.k_batch - 1));
365 }
366 }
367
// Element offsets to add to each A / B base pointer for this split.
368 std::array<index_t, NumATensor> as_k_split_offset;
369 std::array<index_t, NumBTensor> bs_k_split_offset;
371 };
372
// Validates `kargs` against the tile-size, padding and vector-size constraints of the
// pipeline and epilogue. Returns true only if every check passes; failures are logged
// through CK_TILE_ERROR when the CK_TILE_LOGGING environment switch is enabled.
// NOTE(review): the signature (listing line 373) and several "CK_TILE_ERROR(" opening
// lines are not shown in this listing; the orphaned string literals below belong to them.
// NOTE(review): the locals are spelled "Tesnor" (sic, for "Tensor").
374 {
// Split-K gate: when the compile-time condition holds (its second half, listing line
// 376, is not shown), any k_batch other than 1 is rejected immediately.
375 if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
377 {
378 if(kargs.k_batch != 1)
379 {
380 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
381 {
382 CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
383 }
384 return false;
385 }
386 }
387
// --- A tensors. Failures are accumulated in a flag rather than returned at once, so
// every violated constraint gets logged.
388 const auto vectorSizeA = is_wave32() ? GemmPipeline::template GetVectorSizeA<true>()
389 : GemmPipeline::template GetVectorSizeA<false>();
390 bool AsTesnorIsValid = {true};
391 static_for<0, NumATensor, 1>{}([&](auto index) {
392 using AiLayout = remove_cvref_t<std::tuple_element_t<index.value, AsLayout>>;
393 if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
394 {
// RowMajor A: the constraints apply to K.
395 if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
396 GemmPipeline::kPadK == false)
397 {
398 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
399 {
401 "Can't support K that is not a multiple of k_batch * KPerBlock "
402 "without padding!");
403 }
404 AsTesnorIsValid = false;
405 }
406 if(kargs.K % vectorSizeA != 0)
407 {
408 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
409 {
410 CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
411 }
412 AsTesnorIsValid = false;
413 }
414 }
415 else
416 {
// Any other A layout: the constraints apply to M.
417 if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
418 {
419 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
420 {
422 "Can't support M that is not a multiple of MPerBlock without padding!");
423 }
424 AsTesnorIsValid = false;
425 }
426 if(kargs.M % vectorSizeA != 0)
427 {
428 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
429 {
430 CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!");
431 }
432 AsTesnorIsValid = false;
433 }
434 }
435 });
436
// --- B tensors. Same accumulate-in-a-flag scheme as A.
437 bool BsTesnorIsValid = {true};
438 const auto vectorSizeB = is_wave32() ? GemmPipeline::template GetVectorSizeB<true>()
439 : GemmPipeline::template GetVectorSizeB<false>();
440 static_for<0, NumBTensor, 1>{}([&](auto index) {
441 using BiLayout = remove_cvref_t<std::tuple_element_t<index.value, BsLayout>>;
442 if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::RowMajor>)
443 {
// RowMajor B: the constraints apply to N.
444 if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
445 {
446 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
447 {
449 "Can't support N that is not a multiple of NPerBlock without padding!");
450 }
451 BsTesnorIsValid = false;
452 }
453 if(kargs.N % vectorSizeB != 0)
454 {
455 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
456 {
457 CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!");
458 }
459 BsTesnorIsValid = false;
460 }
461 }
462 else
463 {
// Any other B layout: the constraints apply to K.
464 if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
465 GemmPipeline::kPadK == false)
466 {
467 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
468 {
470 "Can't support K that is not a multiple of k_batch * KPerBlock "
471 "without padding!");
472 }
473 BsTesnorIsValid = false;
474 }
475 if(kargs.K % vectorSizeB != 0)
476 {
477 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
478 {
479 CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!");
480 }
481 BsTesnorIsValid = false;
482 }
483 }
484 });
485
// --- D tensors. Same accumulate-in-a-flag scheme.
486 bool DTesnorIsValid = {true};
487 static_for<0, NumDTensor, 1>{}([&](auto index) {
488 using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
// Every D layout must equal CLayout.
// NOTE(review): this is a runtime `if` on a compile-time constant, and unlike the other
// checks it logs nothing on failure.
489 if(std::is_same_v<DiLayout, CLayout> == false)
490 {
491 DTesnorIsValid = false;
492 }
493 if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
494 {
// RowMajor D: the constraints apply to N.
495 if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
496 {
497 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
498 {
499 CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of "
500 "NPerBlock without padding!");
501 }
502 DTesnorIsValid = false;
503 }
504 if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0)
505 {
506 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
507 {
508 CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!");
509 }
510 DTesnorIsValid = false;
511 }
512 }
513 else
514 {
// Any other D layout: the constraints apply to M.
515 if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
516 {
517 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
518 {
519 CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of "
520 "MPerBlock without padding!");
521 }
522 DTesnorIsValid = false;
523 }
524 if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0)
525 {
526 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
527 {
528 CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!");
529 }
530 DTesnorIsValid = false;
531 }
532 }
533 });
534
// --- C/E output tensor. Unlike A/B/D, a failure here returns false immediately.
535 if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
536 {
537 if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
538 {
539 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
540 {
542 "Can't support N that is not a multiple of NPerBlock without padding!");
543 }
544 return false;
545 }
546 if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
547 {
548 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
549 {
550 CK_TILE_ERROR("N is not a multiple of vector load size for C tensor!");
551 }
552 return false;
553 }
554 }
555 else
556 {
557 if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
558 {
559 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
560 {
562 "Can't support M that is not a multiple of MPerBlock without padding!");
563 }
564 return false;
565 }
566 if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
567 {
568 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
569 {
570 CK_TILE_ERROR("M is not a multiple of vector load size for C tensor!");
571 }
572 return false;
573 }
574 }
// The output checks passed; the overall result is the conjunction of the A/B/D flags.
575 return AsTesnorIsValid && BsTesnorIsValid && DTesnorIsValid;
576 }
577
// Builds the tensor views for all A, B, D tensors and the output E tensor.
//
// as_ptr / bs_ptr : base pointers (the entry points pass them already offset for split-K)
// ds_ptr          : type-erased D pointers, cast per tensor below
// e_ptr           : output pointer
// kargs           : supplies M, N, K and the strides
// k_size          : K extent used for the A and B views
// Returns make_tuple(as views, bs views, ds views, e view).
// NOTE(review): the lines that name the view-construction calls (e.g. listing lines
// 595, 604, 642, 732) and the closing argument of each generate_tuple call are not
// shown in this listing; only their argument lists are visible.
578 template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
579 CK_TILE_DEVICE static auto
580 MakeGemmTensorViews(const std::array<const ADataType*, NumATensor>& as_ptr,
581 const std::array<const BDataType*, NumBTensor>& bs_ptr,
582 const std::array<const void*, NumDTensor>& ds_ptr,
583 EDataType* e_ptr,
584 const KernelArgs& kargs,
585 const index_t k_size)
586 {
587 static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
588
// A views: RowMajor is described as (M, k_size), any other layout as (k_size, M);
// both use strides (stride_As[i], 1).
589 const auto& as_tensor_view = generate_tuple(
590 [&](auto i) {
591 using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
592 using AiDataType = remove_cvref_t<std::tuple_element_t<i.value, AsDataType>>;
593 if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
594 {
596 static_cast<const AiDataType*>(as_ptr[i]),
597 make_tuple(kargs.M, k_size),
598 make_tuple(kargs.stride_As[i], 1),
599 number<GemmPipeline::GetVectorSizeA()>{},
600 number<1>{});
601 }
602 else
603 {
605 static_cast<const AiDataType*>(as_ptr[i]),
606 make_tuple(k_size, kargs.M),
607 make_tuple(kargs.stride_As[i], 1),
608 number<GemmPipeline::GetVectorSizeA()>{},
609 number<1>{});
610 }
611 },
613
// B views: three shapes depending on layout, PermuteB and Preshuffle.
614 const auto& bs_tensor_view = generate_tuple(
615 [&](auto i) {
616 using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
617 using BiDataType = remove_cvref_t<std::tuple_element_t<i.value, BsDataType>>;
618 if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::RowMajor>)
619 {
620 if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
621 {
// Permuted B: K is split into K0 = k_size / K1 groups of K1, described with strides
// (N * K1, K1, 1) and then transformed into the descriptor used for the view.
// (The descriptor lengths and the transform arguments are on lines not shown.)
622 constexpr index_t K1 = GemmPipeline::GetSmemPackB();
623 const index_t K0 = k_size / K1;
624 constexpr index_t VectorSizeB =
625 std::min(K1, GemmPipeline::GetVectorSizeB());
626 const auto b_k0_n_k1_desc =
628 make_tuple(kargs.N * K1, K1, I1),
630 number<1>{});
631 const auto b_n_k_desc = transform_tensor_descriptor(
632 b_k0_n_k1_desc,
638 static_cast<const BiDataType*>(bs_ptr[i]), b_n_k_desc);
639 }
640 else
641 {
// Plain RowMajor B: (k_size, N) with strides (stride_Bs[i], 1).
643 bs_ptr[i],
644 make_tuple(k_size, kargs.N),
645 make_tuple(kargs.stride_Bs[i], 1),
646 number<GemmPipeline::GetVectorSizeB()>{},
647 number<1>{});
648 }
649 }
650 else
651 {
652 if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
653 {
// Same permuted construction as in the RowMajor branch above.
654 constexpr index_t K1 = GemmPipeline::GetSmemPackB();
655 const index_t K0 = k_size / K1;
656 constexpr index_t VectorSizeB =
657 std::min(K1, GemmPipeline::GetVectorSizeB());
658 const auto b_k0_n_k1_desc =
660 make_tuple(kargs.N * K1, K1, I1),
662 number<1>{});
663 const auto b_n_k_desc = transform_tensor_descriptor(
664 b_k0_n_k1_desc,
670 static_cast<const BiDataType*>(bs_ptr[i]), b_n_k_desc);
671 }
672 else
673 {
674 if constexpr(GemmPipeline::Preshuffle)
675 {
// Preshuffled B: a flat (kFlatN, kFlatK) view with row stride kFlatK, where
// kFlatK = flatKPerWarp * (k_size / WarpTile[2]) and kFlatN = N * K / kFlatK.
// NOTE(review): kargs.N * kargs.K is evaluated in index_t before the division.
676 index_t kFlatK =
677 GemmPipeline::BlockGemmShape::flatKPerWarp *
678 (k_size /
679 TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}));
680 index_t kFlatN = kargs.N * kargs.K / kFlatK;
681
683 bs_ptr[i],
684 make_tuple(kFlatN, kFlatK),
685 make_tuple(kFlatK, 1),
686 number<GemmPipeline::GetVectorSizeB()>{},
687 number<1>{});
688 }
689 else
690 {
// Plain non-RowMajor B: (N, k_size) with strides (stride_Bs[i], 1).
692 bs_ptr[i],
693 make_tuple(kargs.N, k_size),
694 make_tuple(kargs.stride_Bs[i], 1),
695 number<GemmPipeline::GetVectorSizeB()>{},
696 number<1>{});
697 }
698 }
699 }
700 },
702
// D views: RowMajor is (M, N), any other layout is (N, M); strides (stride_Ds[i], 1).
703 const auto& ds_tensor_view = generate_tuple(
704 [&](auto i) {
705 using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
706 using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
707 if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
708 {
710 static_cast<const DDataType_*>(ds_ptr[i]),
711 make_tuple(kargs.M, kargs.N),
712 make_tuple(kargs.stride_Ds[i], 1),
713 number<EpiloguePipeline::GetVectorSizeD(i)>{},
714 number<1>{});
715 }
716 else
717 {
719 static_cast<const DDataType_*>(ds_ptr[i]),
720 make_tuple(kargs.N, kargs.M),
721 make_tuple(kargs.stride_Ds[i], 1),
722 number<EpiloguePipeline::GetVectorSizeD(i)>{},
723 number<1>{});
724 }
725 },
727
// E view: always described as (M, N). RowMajor uses strides (stride_E, 1) with the
// epilogue's vector size; otherwise strides are (1, stride_E) with vector length 1.
728 // TODO: enable vector write for C in ColMajor
729 const auto& e_tensor_view = [&]() {
730 if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
731 {
733 e_ptr,
734 make_tuple(kargs.M, kargs.N), // arguments not matching with flatmm.
735 make_tuple(kargs.stride_E, 1),
736 number<EpiloguePipeline::GetVectorSizeC()>{},
737 number<1>{});
738 }
739 else
740 {
742 e_ptr,
743 make_tuple(kargs.M, kargs.N),
744 make_tuple(1, kargs.stride_E),
745 number<1>{},
746 number<1>{});
747 }
748 }();
749
750 return make_tuple(as_tensor_view, bs_tensor_view, ds_tensor_view, e_tensor_view);
751 }
752
// Wraps each tensor view from MakeGemmTensorViews in pad_tensor_view.
// `views` is the 4-tuple (A views, B views, D views, E view); the result has the same
// shape. With GemmPipeline::Preshuffle the B views are passed through unpadded.
// NOTE(review): the tile-length and padding arguments of every pad_tensor_view call,
// and the closing argument of each generate_tuple call, are on lines not shown here.
753 template <typename TensorView>
754 CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
755 {
// A: one padded view per tensor; the arguments differ between RowMajor and other layouts.
756 const auto& as_pad_view = generate_tuple(
757 [&](auto i) {
758 const auto& a_tensor_view = views.at(I0);
759 using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
760 if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
761 {
762 return pad_tensor_view(a_tensor_view[i],
766 }
767 else
768 {
769 return pad_tensor_view(a_tensor_view[i],
773 }
774 },
776
// Unpadded B views, returned as-is on the Preshuffle path at the end of the function.
777 const auto& b_flat_pad_view = views.at(I1);
778
// B: note the branch tests ColumnMajor first, unlike A and D which test RowMajor.
779 const auto& bs_pad_view = generate_tuple(
780 [&](auto i) {
781 const auto& b_tensor_view = views.at(I1);
782 using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
783 if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::ColumnMajor>)
784 {
785 return pad_tensor_view(b_tensor_view[i],
789 }
790 else
791 {
792 return pad_tensor_view(b_tensor_view[i],
796 }
797 },
799
// D: one padded view per tensor.
800 const auto& ds_pad_view = generate_tuple(
801 [&](auto i) {
802 const auto& d_tensor_view = views.at(I2);
803 using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
804 if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
805 {
806 return pad_tensor_view(d_tensor_view[i],
810 }
811 else
812 {
813 return pad_tensor_view(d_tensor_view[i],
817 }
818 },
820
// E: the single output view.
821 // TODO vector write in for C in ColMajor
822 const auto& e_pad_view = [&]() {
823 const auto& e_tensor_view = views.at(I3);
824 if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
825 {
826 return pad_tensor_view(e_tensor_view,
830 }
831 else
832 {
833 return pad_tensor_view(e_tensor_view,
837 }
838 }();
839
840 if constexpr(GemmPipeline::Preshuffle)
841 {
842 // For flatmm, we need to use the flat B tensor view
843 return make_tuple(as_pad_view, b_flat_pad_view, ds_pad_view, e_pad_view);
844 }
845 else
846 {
847 return make_tuple(as_pad_view, bs_pad_view, ds_pad_view, e_pad_view);
848 }
849 }
850
// Creates one tile window per padded view, positioned for the output tile whose
// element origin is (i_m, i_n).
// `views` is the 4-tuple from MakeGemmPadViews; returns (A windows, B windows,
// D windows, E window).
// NOTE(review): the window-length arguments of every make_tile_window call, and the
// closing argument of each generate_tuple call, are on lines not shown here; only the
// window origins are visible.
851 template <typename PadView>
852 CK_TILE_DEVICE static auto
853 MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
854 {
855 const auto& as_pad_view = views.at(I0);
856 const auto& bs_pad_view = views.at(I1);
857 const auto& ds_pad_view = views.at(I2);
858 const auto& e_pad_view = views.at(I3);
859
// A windows: origin {i_m, 0} for RowMajor, {0, i_m} otherwise.
860 const auto& as_block_window = generate_tuple(
861 [&](auto i) {
862 using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
863 if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
864 {
865 return make_tile_window(as_pad_view[i],
868 {i_m, 0});
869 }
870 else
871 {
872 return make_tile_window(as_pad_view[i],
875 {0, i_m});
876 }
877 },
879
// B windows: on the Preshuffle path the origin's first coordinate is i_n divided by
// the warp-tile extent at index 1; otherwise {i_n, 0} for ColumnMajor and {0, i_n}
// for any other layout.
880 const auto& bs_block_window = generate_tuple(
881 [&](auto i) {
882 using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
883 if constexpr(GemmPipeline::Preshuffle)
884 {
885 return make_tile_window(
886 bs_pad_view[i],
889 {static_cast<int>(i_n / GemmPipeline::BlockGemmShape::WarpTile::at(I1)),
890 0});
891 }
892 else
893 {
894 if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::ColumnMajor>)
895 {
896 return make_tile_window(bs_pad_view[i],
899 {i_n, 0});
900 }
901 else
902 {
903 return make_tile_window(bs_pad_view[i],
906 {0, i_n});
907 }
908 }
909 },
911
// D windows: origin {i_m, i_n} for RowMajor, {i_n, i_m} otherwise.
912 const auto ds_block_window = generate_tuple(
913 [&](auto i) {
914 using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
915 if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
916 {
917 return make_tile_window(ds_pad_view[i],
920 {i_m, i_n});
921 }
922 else
923 {
924 return make_tile_window(ds_pad_view[i],
927 {i_n, i_m});
928 }
929 },
931
// E window: origin {i_m, i_n} regardless of CLayout.
932 auto e_block_window = make_tile_window(
933 e_pad_view,
935 {i_m, i_n});
936
937 return make_tuple(as_block_window, bs_block_window, ds_block_window, e_block_window);
938 }
939
// Runs one output tile with a single shared-memory buffer: builds the views and
// windows, runs the GEMM pipeline, then the epilogue.
//
// as_ptr / bs_ptr      : A and B base pointers (callers add the split-K offsets first)
// ds_ptr               : type-erased D pointers
// e_ptr                : output pointer
// smem_ptr_0           : shared-memory buffer, used by the pipeline and then the epilogue
// kargs                : kernel arguments
// splitk_batch_offset  : supplies splitted_k, the K extent for this split
// block_idx_m / _n     : element origin of the output tile
// UseDefaultScheduler  : when false, only the warp with get_warp_id() == 0 runs the epilogue
954 template <bool UseDefaultScheduler = true>
955 CK_TILE_DEVICE static void RunGemm(const std::array<const ADataType*, NumATensor>& as_ptr,
956 const std::array<const BDataType*, NumBTensor>& bs_ptr,
957 const std::array<const void*, NumDTensor>& ds_ptr,
958 EDataType* e_ptr,
959 void* smem_ptr_0,
960 const KernelArgs& kargs,
961 const SplitKBatchOffset& splitk_batch_offset,
962 const index_t block_idx_m,
963 const index_t block_idx_n)
964 {
965 // Create Gemm tensor views, pad views and tile windows
// NOTE(review): the callee name (listing line 967) is not shown; from the argument
// list this is presumably a MakeGemmTensorViews<...> call - confirm.
966 const auto& gemm_tensor_views_tuple =
968 as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset.splitted_k);
969
970 const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
971 auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
972
// Main-loop iteration count, derived from this split's K extent.
973 const index_t num_loop =
974 amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
975
976 // Run GEMM cooperatively by whole workgroup.
977 const auto& as_block_window = gemm_tile_windows.at(I0);
978 const auto& bs_block_window = gemm_tile_windows.at(I1);
979 const auto& ds_block_window = gemm_tile_windows.at(I2);
980
981 const auto& c_block_tile = GemmPipeline{}.template operator()(
982 as_block_window, AElementWise{}, bs_block_window, BElementWise{}, num_loop, smem_ptr_0);
983
// With the default scheduler every warp runs the epilogue; otherwise only warp 0 does.
984 if(UseDefaultScheduler || (get_warp_id() == 0))
985 {
986 // Run Epilogue Pipeline
987 auto& c_block_window = gemm_tile_windows.at(I3);
988
989 EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
990 }
991 }
992
// Variant of RunGemm for pipelines that take two shared-memory buffers: the GEMM
// pipeline receives both smem_ptr_0 and smem_ptr_1, the epilogue only smem_ptr_0.
// Unlike RunGemm there is no scheduler parameter; the epilogue always runs.
// Other parameters have the same meaning as in RunGemm.
1010 CK_TILE_DEVICE static void RunGemm2LDS(const std::array<const ADataType*, NumATensor>& as_ptr,
1011 const std::array<const BDataType*, NumBTensor>& bs_ptr,
1012 const std::array<const void*, NumDTensor>& ds_ptr,
1013 EDataType* e_ptr,
1014 void* __restrict__ smem_ptr_0,
1015 void* __restrict__ smem_ptr_1,
1016 const KernelArgs& kargs,
1017 const SplitKBatchOffset& splitk_batch_offset,
1018 const index_t block_idx_m,
1019 const index_t block_idx_n)
1020 {
1021 // Create Gemm tensor views, pad views and tile windows
// NOTE(review): the callee name (listing line 1023) is not shown; from the argument
// list this is presumably a MakeGemmTensorViews<...> call - confirm.
1022 const auto& gemm_tensor_views_tuple =
1024 as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset.splitted_k);
1025
1026 const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
1027 auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
1028
// Main-loop iteration count, derived from this split's K extent.
1029 const index_t num_loop =
1030 amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
1031
1032 // Run GEMM cooperatively by whole workgroup.
1033 const auto& as_block_window = gemm_tile_windows.at(I0);
1034 const auto& bs_block_window = gemm_tile_windows.at(I1);
1035 const auto& ds_block_window = gemm_tile_windows.at(I2);
1036
1037 const auto& c_block_tile = GemmPipeline{}.template operator()(as_block_window,
1038 AElementWise{},
1039 bs_block_window,
1040 BElementWise{},
1041 num_loop,
1042 smem_ptr_0,
1043 smem_ptr_1);
1044
1045 // Run Epilogue Pipeline
1046 auto& c_block_window = gemm_tile_windows.at(I3);
1047
1048 EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
1049 }
1050
1051 // Non-persistent kernel entry point
// Enabled only when PersistentKernel is false. Each workgroup computes one output tile
// (selected by blockIdx.x) for one split-K batch (SplitKBatchOffset defaults to blockIdx.z).
// NOTE(review): the signature (listing line 1053) and several condition lines (1077,
// 1091, 1109, 1112) are not shown in this listing.
1052 template <bool U = !PersistentKernel, typename = std::enable_if_t<U>>
1054 {
// Map the block id to a tile index, then to the tile's element origin (i_m, i_n).
1055 const auto blockId = amd_wave_read_first_lane(blockIdx.x);
1056 const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
1057 const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
1058 const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
1059
1060 const SplitKBatchOffset splitk_batch_offset(kargs);
1061
1062 // options
// Advance every A and B base pointer by this split's offset.
1063 std::array<const ADataType*, NumATensor> as_ptr;
1064 static_for<0, NumATensor, 1>{}([&](auto i) {
1065 as_ptr[i] = static_cast<const ADataType*>(kargs.as_ptr[i]) +
1066 splitk_batch_offset.as_k_split_offset[i];
1067 });
1068
1069 std::array<const BDataType*, NumBTensor> bs_ptr;
1070 static_for<0, NumBTensor, 1>{}([&](auto i) {
1071 bs_ptr[i] = static_cast<const BDataType*>(kargs.bs_ptr[i]) +
1072 splitk_batch_offset.bs_k_split_offset[i];
1073 });
1074
1075 // Calculate output offset from tile partitioner and apply to output pointer
// The guard for this block (listing line 1077) is not shown; presumably it applies
// only when TilePartitioner provides GetOutputOffset - confirm.
1076 EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
1078 {
1079 const index_t output_offset = TilePartitioner::GetOutputOffset(kargs, blockIdx.z);
1080 e_ptr += output_offset;
1081 }
1082
1083 // allocate LDS
1084 __shared__ char smem_ptr_0[GetSmemSize()];
1085
// Double-buffered pipelines get a second shared buffer and go through RunGemm2LDS.
// In both paths the run is compiled out when the epilogue uses atomic_add with an odd
// GetVectorSizeC() and a third condition holds (that condition's line is not shown).
1086 if constexpr(GemmPipeline::DoubleSmemBuffer == true)
1087 {
1088 __shared__ char smem_ptr_1[GetSmemSize()];
1089 if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
1090 EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
1092 {
1093 RunGemm2LDS(as_ptr,
1094 bs_ptr,
1095 kargs.ds_ptr,
1096 e_ptr,
1097 smem_ptr_0,
1098 smem_ptr_1,
1099 kargs,
1100 splitk_batch_offset,
1101 i_m,
1102 i_n);
1103 }
1104 }
1105 else
1106 {
1107 if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
1108 EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
1110 {
// scheduler_type is true when the pipeline has exactly one wave group. The call line
// (listing line 1112) is not shown; presumably RunGemm<scheduler_type>(as_ptr, - confirm.
1111 constexpr auto scheduler_type = (GemmPipeline::NumWaveGroups == 1);
1113 bs_ptr,
1114 kargs.ds_ptr,
1115 e_ptr,
1116 smem_ptr_0,
1117 kargs,
1118 splitk_batch_offset,
1119 i_m,
1120 i_n);
1121 }
1122 }
1123 }
1124
1125 // Persistent kernel entry point
// Enabled only when PersistentKernel is true. Each workgroup loops over work items,
// starting at its block id and advancing by the grid size; a work item is one output
// tile for one split-K batch (num_work = num_tiles * k_batch).
// NOTE(review): the signature (listing line 1127) and several condition lines (1137,
// 1162, 1175, 1177, 1194, 1196) are not shown in this listing.
1126 template <bool U = PersistentKernel, typename = std::enable_if_t<U>, typename = void>
1128 {
1129 const auto grid_size = amd_wave_read_first_lane(get_grid_size());
1130 const auto num_tiles =
1131 amd_wave_read_first_lane(TilePartitioner::GridSize(kargs.M, kargs.N));
1132 const auto num_work = amd_wave_read_first_lane(num_tiles * kargs.k_batch);
1133 auto block_id = amd_wave_read_first_lane(get_block_id());
1134
1135 while(block_id < num_work)
1136 {
1138 // Get the tile index for this block
// block_id % num_tiles selects the tile; block_id / num_tiles (below) selects the split.
1139 const auto tile_idx = amd_wave_read_first_lane(block_id % num_tiles);
1140 const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(tile_idx);
1141 const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
1142 const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
1143
1144 // Get the SplitK offset for this block
1145 const auto k_batch = amd_wave_read_first_lane(block_id / num_tiles);
1146 const SplitKBatchOffset splitk_batch_offset(kargs, k_batch);
1147
// Advance every A and B base pointer by this split's offset.
1148 std::array<const ADataType*, NumATensor> as_ptr;
1149 static_for<0, NumATensor, 1>{}([&](auto i) {
1150 as_ptr[i] = static_cast<const ADataType*>(kargs.as_ptr[i]) +
1151 splitk_batch_offset.as_k_split_offset[i];
1152 });
1153
1154 std::array<const BDataType*, NumBTensor> bs_ptr;
1155 static_for<0, NumBTensor, 1>{}([&](auto i) {
1156 bs_ptr[i] = static_cast<const BDataType*>(kargs.bs_ptr[i]) +
1157 splitk_batch_offset.bs_k_split_offset[i];
1158 });
1159
1160 // Calculate output offset from tile partitioner and apply to output pointer
// The guard for this block (listing line 1162) is not shown; presumably it applies
// only when TilePartitioner provides GetOutputOffset - confirm.
1161 EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
1163 {
1164 const index_t output_offset = TilePartitioner::GetOutputOffset(kargs, k_batch);
1165 e_ptr += output_offset;
1166 }
1167
1168 // allocate LDS
1169 __shared__ char smem_ptr_0[GetSmemSize()];
1170 // Run the GEMM
// Same dispatch as the non-persistent entry point, except that the single-buffer path
// calls RunGemm with its default template argument.
1171 if constexpr(GemmPipeline::DoubleSmemBuffer == true)
1172 {
1173 __shared__ char smem_ptr_1[GetSmemSize()];
1174 if constexpr(!(EpiloguePipeline::MemoryOperation ==
1176 EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
1178 {
1179 RunGemm2LDS(as_ptr,
1180 bs_ptr,
1181 kargs.ds_ptr,
1182 e_ptr,
1183 smem_ptr_0,
1184 smem_ptr_1,
1185 kargs,
1186 splitk_batch_offset,
1187 i_m,
1188 i_n);
1189 }
1190 }
1191 else
1192 {
1193 if constexpr(!(EpiloguePipeline::MemoryOperation ==
1195 EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
1197 {
1198 RunGemm(as_ptr,
1199 bs_ptr,
1200 kargs.ds_ptr,
1201 e_ptr,
1202 smem_ptr_0,
1203 kargs,
1204 splitk_batch_offset,
1205 i_m,
1206 i_n);
1207 }
1208 }
1209 // Advance to the next work item
1210 block_id += grid_size;
// NOTE(review): this test repeats the while-condition, so the explicit break looks
// redundant as far as the visible code shows.
1211 if(block_id >= num_work)
1212 {
1213 break;
1214 }
1215 }
1216 }
1217};
1218} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tensor_view.hpp:471
bool EnvIsEnabled(EnvVar)
Definition tile/core/utility/env.hpp:156
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
void CK_TILE_ERROR(Args &&... args) noexcept
Definition tile/core/utility/env.hpp:12
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
__global__ void kentry(Args... args)
Definition tile/host/kernel_launch.hpp:22
@ atomic_add
Definition arch.hpp:58
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition type_traits.hpp:67
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1615
CK_TILE_DEVICE index_t get_warp_id(bool_constant< ReturnSgpr >={})
Definition arch.hpp:104
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_descriptor.hpp:203
CK_TILE_HOST void hip_check_error(hipError_t x)
Definition tile/host/hip_check_error.hpp:13
std::string gemm_prec_str()
Definition utils.hpp:31
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition concat.hpp:43
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:429
CK_TILE_HOST_DEVICE constexpr auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition tensor_view.hpp:530
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
CK_TILE_DEVICE index_t get_block_id()
Definition arch.hpp:119
CK_TILE_DEVICE index_t get_grid_size()
Definition arch.hpp:89
int32_t index_t
Definition integer.hpp:9
CK_TILE_DEVICE void s_waitcnt_barrier()
Definition arch.hpp:260
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
CK_TILE_HOST bool is_wave32()
Definition arch.hpp:72
unsigned int uint32_t
Definition stdint.h:126
The Universal GEMM kernel host arguments.
Definition universal_gemm_kernel.hpp:32
void * c_ptr
Definition universal_gemm_kernel.hpp:66
const std::array< index_t, NumDTensor > stride_Ds
Definition universal_gemm_kernel.hpp:73
const std::array< index_t, NumBTensor > stride_Bs
Definition universal_gemm_kernel.hpp:72
CK_TILE_HOST UniversalGemmHostArgs(const std::array< const void *, NumATensor > &as_ptr_, const std::array< const void *, NumBTensor > &bs_ptr_, const std::array< const void *, NumDTensor > &ds_ptr_, void *e_ptr_, index_t k_batch_, index_t M_, index_t N_, index_t K_, const std::array< index_t, NumATensor > &stride_As_, const std::array< index_t, NumBTensor > &stride_Bs_, const std::array< index_t, NumDTensor > &stride_Ds_, index_t stride_E_)
Definition universal_gemm_kernel.hpp:33
index_t K
Definition universal_gemm_kernel.hpp:70
void * e_ptr
Definition universal_gemm_kernel.hpp:65
index_t M
Definition universal_gemm_kernel.hpp:68
const std::array< const void *, NumDTensor > ds_ptr
Definition universal_gemm_kernel.hpp:62
const std::array< const void *, NumATensor > as_ptr
Definition universal_gemm_kernel.hpp:60
const std::array< index_t, NumATensor > stride_As
Definition universal_gemm_kernel.hpp:71
index_t N
Definition universal_gemm_kernel.hpp:69
index_t stride_E
Definition universal_gemm_kernel.hpp:76
const std::array< const void *, NumBTensor > bs_ptr
Definition universal_gemm_kernel.hpp:61
index_t stride_C
Definition universal_gemm_kernel.hpp:77
index_t k_batch
Definition universal_gemm_kernel.hpp:80
Definition universal_gemm_kernel.hpp:325
std::array< index_t, NumATensor > as_k_split_offset
Definition universal_gemm_kernel.hpp:368
index_t splitted_k
Definition universal_gemm_kernel.hpp:370
__device__ SplitKBatchOffset(const KernelArgs &kargs, const std::size_t k_id=blockIdx.z)
Definition universal_gemm_kernel.hpp:326
std::array< index_t, NumBTensor > bs_k_split_offset
Definition universal_gemm_kernel.hpp:369
Definition universal_gemm_kernel.hpp:206
static constexpr bool value
Definition universal_gemm_kernel.hpp:210
decltype(T::UsePersistentKernel) has_persistent_type
Definition universal_gemm_kernel.hpp:208
decltype(T::GetOutputOffset(std::declval< KernelArgs >(), std::declval< index_t >())) has_get_output_offset_t
Definition universal_gemm_kernel.hpp:223
static constexpr bool value
Definition universal_gemm_kernel.hpp:226
The GEMM kernel device arguments.
Definition universal_gemm_kernel.hpp:86
std::array< index_t, NumBTensor > stride_Bs
Definition universal_gemm_kernel.hpp:106
const std::array< const void *, NumDTensor > ds_ptr
Definition universal_gemm_kernel.hpp:92
std::array< index_t, NumATensor > stride_As
Definition universal_gemm_kernel.hpp:103
const std::array< const void *, NumATensor > as_ptr
Definition universal_gemm_kernel.hpp:88
const std::array< const void *, NumBTensor > bs_ptr
Definition universal_gemm_kernel.hpp:90
std::array< index_t, NumDTensor > stride_Ds
Definition universal_gemm_kernel.hpp:109
The Universal GEMM kernel template.
Definition universal_gemm_kernel.hpp:154
CK_TILE_DEVICE void operator()(KernelArgs kargs) const
Definition universal_gemm_kernel.hpp:1053
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition universal_gemm_kernel.hpp:156
static CK_TILE_HOST const std::string GetName()
Definition universal_gemm_kernel.hpp:260
std::conditional_t< DLayoutIsTuple, remove_cvref_t< typename EpiloguePipeline::DsLayout >, remove_cvref_t< tuple< typename EpiloguePipeline::DsLayout > > > DsLayout
Definition universal_gemm_kernel.hpp:179
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition universal_gemm_kernel.hpp:155
CK_TILE_DEVICE void operator()(KernelArgs kargs) const
Definition universal_gemm_kernel.hpp:1127
static CK_TILE_DEVICE void RunGemm(const std::array< const ADataType *, NumATensor > &as_ptr, const std::array< const BDataType *, NumBTensor > &bs_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, void *smem_ptr_0, const KernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Runs a single GEMM problem cooperatively across the whole workgroup.
Definition universal_gemm_kernel.hpp:955
static constexpr bool BDataTypeIsTuple
Definition universal_gemm_kernel.hpp:161
static constexpr auto I2
Definition universal_gemm_kernel.hpp:238
static constexpr bool BLayoutIsTuple
Definition universal_gemm_kernel.hpp:167
std::conditional_t< BDataTypeIsTuple, remove_cvref_t< typename GemmPipeline::BsDataType >, remove_cvref_t< tuple< typename GemmPipeline::BDataType > > > BsDataType
Definition universal_gemm_kernel.hpp:187
remove_cvref_t< typename GemmPipeline::BElementWise > BElementWise
Definition universal_gemm_kernel.hpp:200
static constexpr index_t NumATensor
Definition universal_gemm_kernel.hpp:241
static constexpr bool ALayoutIsTuple
Definition universal_gemm_kernel.hpp:165
static CK_TILE_DEVICE void RunGemm2LDS(const std::array< const ADataType *, NumATensor > &as_ptr, const std::array< const BDataType *, NumBTensor > &bs_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, void *__restrict__ smem_ptr_0, void *__restrict__ smem_ptr_1, const KernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Runs a single GEMM problem cooperatively across the whole workgroup.
Definition universal_gemm_kernel.hpp:1010
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView &views, const index_t i_m, const index_t i_n)
Definition universal_gemm_kernel.hpp:853
static constexpr auto I3
Definition universal_gemm_kernel.hpp:239
static constexpr bool ADataTypeIsTuple
Definition universal_gemm_kernel.hpp:159
static constexpr bool has_tile_partitioner_output_offset
Definition universal_gemm_kernel.hpp:233
static CK_TILE_HOST constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
Definition universal_gemm_kernel.hpp:267
std::conditional_t< DDataTypeIsTuple, remove_cvref_t< typename EpiloguePipeline::DsDataType >, remove_cvref_t< tuple< typename EpiloguePipeline::DsDataType > > > DsDataType
Definition universal_gemm_kernel.hpp:191
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView &views)
Definition universal_gemm_kernel.hpp:754
std::conditional_t< BLayoutIsTuple, remove_cvref_t< typename GemmPipeline::BsLayout >, remove_cvref_t< tuple< typename GemmPipeline::BLayout > > > BsLayout
Definition universal_gemm_kernel.hpp:175
static constexpr index_t NumDTensor
Definition universal_gemm_kernel.hpp:243
remove_cvref_t< std::tuple_element_t< I0, AsDataType > > ADataType
Definition universal_gemm_kernel.hpp:245
static constexpr bool DDataTypeIsTuple
Definition universal_gemm_kernel.hpp:163
static constexpr bool PersistentKernel
Definition universal_gemm_kernel.hpp:217
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition universal_gemm_kernel.hpp:319
remove_cvref_t< typename GemmPipeline::CLayout > CLayout
Definition universal_gemm_kernel.hpp:196
static constexpr auto I1
Definition universal_gemm_kernel.hpp:237
static CK_TILE_HOST auto BlockSize()
Definition universal_gemm_kernel.hpp:290
static CK_TILE_HOST auto MaxOccupancyGridSize(const stream_config &s) -> dim3
Get the maximum occupancy grid size for the persistent kernel on the current device.
Definition universal_gemm_kernel.hpp:278
static constexpr index_t NumBTensor
Definition universal_gemm_kernel.hpp:242
static CK_TILE_DEVICE auto MakeGemmTensorViews(const std::array< const ADataType *, NumATensor > &as_ptr, const std::array< const BDataType *, NumBTensor > &bs_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, const KernelArgs &kargs, const index_t k_size)
Definition universal_gemm_kernel.hpp:580
static constexpr auto I0
Definition universal_gemm_kernel.hpp:236
std::conditional_t< ALayoutIsTuple, remove_cvref_t< typename GemmPipeline::AsLayout >, remove_cvref_t< tuple< typename GemmPipeline::ALayout > > > AsLayout
Definition universal_gemm_kernel.hpp:172
static CK_TILE_HOST bool IsSupportedArgument(const KernelArgs &kargs)
Definition universal_gemm_kernel.hpp:373
static constexpr bool DLayoutIsTuple
Definition universal_gemm_kernel.hpp:169
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition universal_gemm_kernel.hpp:157
static CK_TILE_HOST constexpr KernelArgs MakeKernelArgs(const UniversalGemmHostArgs< NumATensor, NumBTensor, NumDTensor > &hostArgs)
Definition universal_gemm_kernel.hpp:303
remove_cvref_t< typename GemmPipeline::AElementWise > AElementWise
Definition universal_gemm_kernel.hpp:199
static constexpr index_t kBlockSize
Definition universal_gemm_kernel.hpp:202
std::conditional_t< ADataTypeIsTuple, remove_cvref_t< typename GemmPipeline::AsDataType >, remove_cvref_t< tuple< typename GemmPipeline::ADataType > > > AsDataType
Definition universal_gemm_kernel.hpp:183
UniversalGemmKernelArgs< AsLayout::size(), BsLayout::size(), DsLayout::size()> KernelArgs
Definition universal_gemm_kernel.hpp:257
remove_cvref_t< typename EpiloguePipeline::ODataType > EDataType
Definition universal_gemm_kernel.hpp:197
remove_cvref_t< std::tuple_element_t< I0, BsDataType > > BDataType
Definition universal_gemm_kernel.hpp:246
Definition type_traits.hpp:115
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43
Definition ck_tile/host/stream_config.hpp:30
#define CK_TILE_ENV(name)
Definition tile/core/utility/env.hpp:145