device_grouped_contraction_multiple_d_xdl_cshuffle.hpp Source File

device_grouped_contraction_multiple_d_xdl_cshuffle.hpp Source File#

Composable Kernel: device_grouped_contraction_multiple_d_xdl_cshuffle.hpp Source File
device_grouped_contraction_multiple_d_xdl_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
20
21namespace ck {
22
// Device kernel entry point for a grouped contraction: one launch covers
// several independent problems ("groups").
//
//   contraction_args : array of ContractionMultiDKernelArg, one entry per group
//                      (read below as contraction_arg_ptr[group_id]).
//   group_count      : upper bound used to seed the group search.
//   a_element_op, b_element_op, cde_element_op :
//                      element-wise functors forwarded unchanged to
//                      GridwiseGemm::Run.
//
// Each block finds the group whose [block_start_, block_end_) interval contains
// its 1-D block id and then runs GridwiseGemm on that group's pointers and
// descriptors.
// NOTE(review): this listing is missing a few lines (launch bounds, the function
// name, the pointer cast argument); comments only describe what is visible.
23template <typename GridwiseGemm,
24 typename ContractionMultiDKernelArg,
25 typename AElementwiseOperation,
26 typename BElementwiseOperation,
27 typename CDEElementwiseOperation,
28 bool HasMainKBlockLoop>
29__global__ void
30#if CK_USE_LAUNCH_BOUNDS
32#endif
34 const void CK_CONSTANT_ADDRESS_SPACE* contraction_args,
35 const index_t group_count,
36 const AElementwiseOperation a_element_op,
37 const BElementwiseOperation b_element_op,
38 const CDEElementwiseOperation cde_element_op)
39{
// The real body is only compiled for the listed gfx targets; elsewhere the
// #else branch merely marks the parameters as used.
40#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
41 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
42 {
// Scratch buffer declared __shared__, sized by GridwiseGemm and handed to
// GridwiseGemm::Run below.
43 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
44
45 const index_t block_id = get_block_1d_id();
46
47 const auto contraction_arg_ptr = reinterpret_cast<const ContractionMultiDKernelArg*>(
49
// Bisection over the per-group block ranges: start in the middle of
// [0, group_count] and halve towards the group containing block_id.
50 index_t left = 0;
51 index_t right = group_count;
52 index_t group_id = index_t((left + right) / 2);
53
// NOTE(review): `left` and `right` are only ever set to the midpoint, so
// `left <= right` stays true; the loop ends only when block_id falls inside
// the probed group's range. Presumably the host guarantees every launched
// block belongs to some group -- confirm.
54 while((!(block_id >= contraction_arg_ptr[group_id].block_start_ &&
55 block_id < contraction_arg_ptr[group_id].block_end_)) &&
56 left <= right)
57 {
58 if(block_id < contraction_arg_ptr[group_id].block_start_)
59 {
60 right = group_id;
61 }
62 else
63 {
64 left = group_id;
65 }
66 group_id = index_t((left + right) / 2);
67 }
68
// Run the GEMM for the selected group; the output memory operation is fixed
// to InMemoryDataOperationEnum::Set.
69 GridwiseGemm::template Run<HasMainKBlockLoop, InMemoryDataOperationEnum::Set>(
70 contraction_arg_ptr[group_id].p_a_grid_,
71 contraction_arg_ptr[group_id].p_b_grid_,
72 contraction_arg_ptr[group_id].p_ds_grid_,
73 contraction_arg_ptr[group_id].p_e_grid_,
74 p_shared,
75 a_element_op,
76 b_element_op,
77 cde_element_op,
78 contraction_arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
79 contraction_arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
80 contraction_arg_ptr[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
81 contraction_arg_ptr[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
82 contraction_arg_ptr[group_id].block_2_etile_map_);
83 }
84#else
// Unsupported target: reference every parameter so none is reported as unused.
85 ignore = contraction_args;
86 ignore = group_count;
87 ignore = a_element_op;
88 ignore = b_element_op;
89 ignore = cde_element_op;
90#endif
91}
92
93} // namespace ck
94
95namespace ck {
96namespace tensor_operation {
97namespace device {
98
99// Tensor Contraction:
100// input : A
101// input : B
102// input : D0, D1, ...
103// output : E
104// C = a_op(A) * b_op(B)
105// E = cde_op(C, D0, D1, ...)
106// Assume:
107// A[M0, M1, M2, ..., K0, K1, K2, ...]
108// B[N0, N1, N2, ..., K0, K1, K2, ...]
109// D[M0, M1, M2, ..., N0, N1, N2, ...]
110// E[M0, M1, M2, ..., N0, N1, N2, ...]
111template <index_t NumDimM,
112 index_t NumDimN,
113 index_t NumDimK,
114 typename ADataType,
115 typename BDataType,
116 typename AccDataType,
117 typename CShuffleDataType,
118 typename DsDataType,
119 typename EDataType,
120 typename AElementwiseOperation,
121 typename BElementwiseOperation,
122 typename CDEElementwiseOperation,
123 GemmSpecialization GemmSpec,
127 index_t NumGemmKPrefetchStage,
128 index_t BlockSize,
129 index_t MPerBlock,
130 index_t NPerBlock,
131 index_t KPerBlock,
132 index_t AK1,
133 index_t BK1,
134 index_t MPerXDL,
135 index_t NPerXDL,
136 index_t MXdlPerWave,
137 index_t NXdlPerWave,
138 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
139 typename ABlockTransferThreadClusterArrangeOrder,
140 typename ABlockTransferSrcAccessOrder,
141 index_t ABlockTransferSrcVectorDim,
142 index_t ABlockTransferSrcScalarPerVector,
143 index_t ABlockTransferDstScalarPerVector_AK1,
144 bool ABlockLdsExtraM,
145 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
146 typename BBlockTransferThreadClusterArrangeOrder,
147 typename BBlockTransferSrcAccessOrder,
148 index_t BBlockTransferSrcVectorDim,
149 index_t BBlockTransferSrcScalarPerVector,
150 index_t BBlockTransferDstScalarPerVector_BK1,
151 bool BBlockLdsExtraN,
152 index_t CShuffleMXdlPerWavePerShuffle,
153 index_t CShuffleNXdlPerWavePerShuffle,
154 typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
155 index_t CDEBlockTransferScalarPerVector_NPerBlock,
158 : public DeviceGroupedContractionMultipleD<NumDimM,
159 NumDimN,
160 NumDimK,
161 ADataType,
162 BDataType,
163 DsDataType,
164 EDataType,
165 AElementwiseOperation,
166 BElementwiseOperation,
167 CDEElementwiseOperation>
168{
170
172 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
173 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
174 static constexpr index_t NumDTensor = DsDataType::Size();
175
176 static constexpr auto I0 = Number<0>{};
177 static constexpr auto I1 = Number<1>{};
178 static constexpr auto I2 = Number<2>{};
179 static constexpr auto I3 = Number<3>{};
180
181 static constexpr auto matrix_padder =
182 MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
183
184 // Assume: A[M0, M1, M2, ..., K0, K1, K2, ...]
185 static auto MakeAGridDescriptor_M_K(const std::vector<index_t>& a_ms_ks_lengths_vec,
186 const std::vector<index_t>& a_ms_ks_strides_vec)
187 {
188 assert(a_ms_ks_lengths_vec.size() == NumDimM + NumDimK &&
189 a_ms_ks_strides_vec.size() == NumDimM + NumDimK);
190
191 const auto to_tuple = [&](auto& vec, auto num) {
192 return generate_tuple([&](auto i) { return vec[i]; }, num);
193 };
194
195 const auto a_ms_ks_lengths = to_tuple(a_ms_ks_lengths_vec, Number<NumDimM + NumDimK>{});
196 const auto a_ms_ks_strides = to_tuple(a_ms_ks_strides_vec, Number<NumDimM + NumDimK>{});
197
198 // dimension Ids for M0, M1, ...
199 constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
200
201 // dimension Ids for K0, K1, ...
202 constexpr auto kDimIds =
204
205 // lengths for M0, M1, ...
206 const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds);
207
208 // lengths for K0, K1, ...
209 const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds);
210
211 if constexpr(ASpec == TensorSpecialization::Packed)
212 {
213 auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
214 auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{});
215 const auto a_grid_desc_mraw_kraw = make_naive_tensor_descriptor(
216 make_tuple(M, K),
217 make_tuple(a_ms_ks_strides[Number<NumDimM - 1>{}],
218 a_ms_ks_strides[Number<NumDimM + NumDimK - 1>{}]));
219 return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
220 }
221 else
222 {
223 // naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
224 const auto a_grid_desc_ms_ks =
225 make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides);
226
227 // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...]
228 const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor(
229 a_grid_desc_ms_ks,
231 make_tuple(mDimIds, kDimIds),
233
234 return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
235 }
236 }
237
238 // Assume: B[N0, N1, N2, ..., K0, K1, K2, ...]
239 static auto MakeBGridDescriptor_N_K(const std::vector<index_t>& b_ns_ks_lengths_vec,
240 const std::vector<index_t>& b_ns_ks_strides_vec)
241 {
242 assert(b_ns_ks_lengths_vec.size() == NumDimN + NumDimK &&
243 b_ns_ks_strides_vec.size() == NumDimN + NumDimK);
244
245 const auto to_tuple = [&](auto& vec, auto num) {
246 return generate_tuple([&](auto i) { return vec[i]; }, num);
247 };
248
249 const auto b_ns_ks_lengths = to_tuple(b_ns_ks_lengths_vec, Number<NumDimN + NumDimK>{});
250 const auto b_ns_ks_strides = to_tuple(b_ns_ks_strides_vec, Number<NumDimN + NumDimK>{});
251
252 // dimension Ids for N0, N1, ...
253 constexpr auto nDimIds = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{};
254
255 // dimension Ids for K0, K1, ...
256 constexpr auto kDimIds =
258
259 // lengths for K0, K1, ...
260 const auto kLengths = get_container_subset(b_ns_ks_lengths, kDimIds);
261
262 // lengths for N0, N1, ...
263 const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds);
264
265 if constexpr(BSpec == TensorSpecialization::Packed)
266 {
267 auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
268 auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{});
269 const auto b_grid_desc_nraw_kraw = make_naive_tensor_descriptor(
270 make_tuple(N, K),
271 make_tuple(b_ns_ks_strides[Number<NumDimN - 1>{}],
272 b_ns_ks_strides[Number<NumDimN + NumDimK - 1>{}]));
273 return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
274 }
275 else
276 {
277 // naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...]
278 const auto b_grid_desc_ns_ks =
279 make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides);
280
281 // transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...]
282 const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor(
283 b_grid_desc_ns_ks,
285 make_tuple(nDimIds, kDimIds),
287
288 return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
289 }
290 }
291
292 // assume E[M0, M1, M2, ..., N0, N1, N2...]
293 static auto MakeEGridDescriptor_M_N(const std::vector<index_t>& e_ms_ns_lengths_vec,
294 const std::vector<index_t>& e_ms_ns_strides_vec)
295 {
296 assert(e_ms_ns_lengths_vec.size() == NumDimM + NumDimN &&
297 e_ms_ns_strides_vec.size() == NumDimM + NumDimN);
298
299 const auto to_tuple = [&](auto& vec, auto num) {
300 return generate_tuple([&](auto i) { return vec[i]; }, num);
301 };
302
303 const auto e_ms_ns_lengths = to_tuple(e_ms_ns_lengths_vec, Number<NumDimM + NumDimN>{});
304 const auto e_ms_ns_strides = to_tuple(e_ms_ns_strides_vec, Number<NumDimM + NumDimN>{});
305
306 // dimension Ids for M0, M1, ...
307 constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
308
309 // dimension Ids for N0, N1, ...
310 constexpr auto nDimIds =
312
313 // lengths for M0, M1, ...
314 const auto mLengths = get_container_subset(e_ms_ns_lengths, mDimIds);
315
316 // lengths for N0, N1, ...
317 const auto nLengths = get_container_subset(e_ms_ns_lengths, nDimIds);
318
319 if constexpr(DESpec == TensorSpecialization::Packed)
320 {
321 auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
322 auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
323 const auto e_grid_desc_mraw_nraw = make_naive_tensor_descriptor(
324 make_tuple(M, N),
325 make_tuple(e_ms_ns_strides[Number<NumDimM - 1>{}],
326 e_ms_ns_strides[Number<NumDimM + NumDimN - 1>{}]));
327 return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
328 }
329 else
330 {
331 // naive tensor E[M0, M1, M2, ..., N0, N1, N2...]
332 const auto e_grid_desc_ms_ns =
333 make_naive_tensor_descriptor(e_ms_ns_lengths, e_ms_ns_strides);
334
335 // transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
336 const auto e_grid_desc_mraw_nraw = transform_tensor_descriptor(
337 e_grid_desc_ms_ns,
339 make_tuple(mDimIds, nDimIds),
341
342 return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
343 }
344 }
345
347 const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_lengths_vec,
348 const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_strides_vec)
349 {
350 return generate_tuple(
351 [&](auto i) {
352 return DeviceOp::MakeEGridDescriptor_M_N(ds_ms_ns_lengths_vec[i],
353 ds_ms_ns_strides_vec[i]);
354 },
356 }
357
358 using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K({}, {}));
359 using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K({}, {}));
361 using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {}));
362
363 using ComputeDataType = ADataType;
364
365 // GridwiseGemm
366 template <index_t NXdlPerWave_>
368 ADataType, // TODO: distinguish A/B datatype
369 BDataType,
371 AccDataType,
372 CShuffleDataType,
373 DsDataType,
374 EDataType,
375 AElementwiseOperation,
376 BElementwiseOperation,
377 CDEElementwiseOperation,
378 NumGemmKPrefetchStage,
379 BlockSize,
380 MPerBlock,
381 NPerBlock,
382 KPerBlock,
383 AK1,
384 BK1,
385 MPerXDL,
386 NPerXDL,
387 MXdlPerWave,
388 NXdlPerWave_,
389 ABlockTransferThreadClusterLengths_AK0_M_AK1,
390 ABlockTransferThreadClusterArrangeOrder,
391 ABlockTransferSrcAccessOrder,
392 ABlockTransferSrcVectorDim,
393 ABlockTransferSrcScalarPerVector,
394 ABlockTransferDstScalarPerVector_AK1,
395 false,
396 ABlockLdsExtraM,
397 BBlockTransferThreadClusterLengths_BK0_N_BK1,
398 BBlockTransferThreadClusterArrangeOrder,
399 BBlockTransferSrcAccessOrder,
400 BBlockTransferSrcVectorDim,
401 BBlockTransferSrcScalarPerVector,
402 BBlockTransferDstScalarPerVector_BK1,
403 false,
404 BBlockLdsExtraN,
405 CShuffleMXdlPerWavePerShuffle,
406 CShuffleNXdlPerWavePerShuffle,
407 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
408 CDEBlockTransferScalarPerVector_NPerBlock,
409 LoopSched>;
412
413 // desc for blockwise copy
416 AGridDesc_M_K{}))>;
419 BGridDesc_N_K{}))>;
422 DsGridDesc_M_N{}))>;
425 EGridDesc_M_N{}))>;
426
428 {
429 // block-to-e-tile map
432
434 ck::index_t BlockStart)
435 {
437 block_start_ = BlockStart;
438 }
439
440 template <typename TopIdx>
441 __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
442 {
443 return default_block_2_etile_map_.CalculateBottomIndex(
444 make_multi_index(idx_top[I0] - block_start_));
445 }
446
447 // it's actually E-Tile
448 template <typename CTileIdx, typename CTileDim>
449 __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
450 const CTileDim& c_tile_dim) const
451 {
452 return default_block_2_etile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
453 }
454
455 __host__ bool CheckValidity(const EGridDesc_M_N& e_grid_desc_m_n) const
456 {
457 return default_block_2_etile_map_.CheckValidity(e_grid_desc_m_n);
458 }
459
462 };
463
484
486 {
487 // tensor descriptors for problem definition
492
493 // Strides for the last M/N/K dimensions of A/B/Ds/E
494 // for sanity check of vector load/store
499 std::array<index_t, NumDTensor> ds_nz_stride_;
500 // index_t e_mz_stride_;
502 };
503
504 // Argument
505 struct Argument : public BaseArgument
506 {
507 Argument(std::vector<const void*> p_a_vec,
508 std::vector<const void*> p_b_vec,
509 std::vector<std::array<const void*, NumDTensor>> p_ds_vec,
510 std::vector<void*> p_e_vec,
511 std::vector<ContractionDesc<NumDTensor>> contraction_descs,
512 AElementwiseOperation a_element_op,
513 BElementwiseOperation b_element_op,
514 CDEElementwiseOperation cde_element_op)
515 : a_element_op_{a_element_op},
516 b_element_op_{b_element_op},
517 cde_element_op_{cde_element_op}
518 {
519 group_count_ = contraction_descs.size();
520
521 if(!(group_count_ == p_a_vec.size() && group_count_ == p_b_vec.size() &&
522 group_count_ == p_e_vec.size()))
523 {
524 throw std::runtime_error("wrong! group_count_ != a/b/e_vec.size");
525 }
526
528
529 grid_size_ = 0;
530
531 for(std::size_t i = 0; i < group_count_; i++)
532 {
533 const auto p_a_grid = static_cast<const ADataType*>(p_a_vec[i]);
534 const auto p_b_grid = static_cast<const BDataType*>(p_b_vec[i]);
535 const auto p_e_grid = static_cast<EDataType*>(p_e_vec[i]);
536
537 const auto a_grid_desc_m_k = DeviceOp::MakeAGridDescriptor_M_K(
538 contraction_descs[i].a_ms_ks_lengths, contraction_descs[i].a_ms_ks_strides);
539 const auto b_grid_desc_n_k = DeviceOp::MakeBGridDescriptor_N_K(
540 contraction_descs[i].b_ns_ks_lengths, contraction_descs[i].b_ns_ks_strides);
541
542 DsGridDesc_M_N ds_grid_desc_m_n;
543 typename GridwiseGemm64::DsGridPointer p_ds_grid;
544
545 // populate pointer, batch stride, desc for Ds
546 static_for<0, NumDTensor, 1>{}([&](auto j) {
547 using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
548
549 // D pointer
550 p_ds_grid(j) = static_cast<const DDataType*>(p_ds_vec[i][j]);
551
552 // D desc
553 ds_grid_desc_m_n(j) =
554 DeviceOp::MakeEGridDescriptor_M_N(contraction_descs[i].ds_ms_ns_lengths[j],
555 contraction_descs[i].ds_ms_ns_strides[j]);
556 });
557
558 const auto e_grid_desc_m_n = DeviceOp::MakeEGridDescriptor_M_N(
559 contraction_descs[i].e_ms_ns_lengths, contraction_descs[i].e_ms_ns_strides);
560
561 const auto a_grid_desc_ak0_m_ak1 =
563 const auto b_grid_desc_bk0_n_bk1 =
565
566 const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
568 ds_grid_desc_m_n);
569 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
571 e_grid_desc_m_n);
572
573 const index_t grid_size_grp =
575 .CalculateGridSize(e_grid_desc_m_n);
576
577 const index_t BlockStart = grid_size_;
578 const index_t BlockEnd = grid_size_ + grid_size_grp;
579
580 grid_size_ += grid_size_grp;
581
582 const auto block_2_etile_map =
583 GroupedContractionBlock2ETileMap(e_grid_desc_m_n, BlockStart);
584
585 // for sanity check of vector memory access
586 const index_t a_mz_stride = contraction_descs[i].a_ms_ks_strides[NumDimM - 1];
587 const index_t a_kz_stride =
588 contraction_descs[i].a_ms_ks_strides[NumDimM + NumDimK - 1];
589
590 const index_t b_nz_stride = contraction_descs[i].b_ns_ks_strides[NumDimN - 1];
591 const index_t b_kz_stride =
592 contraction_descs[i].b_ns_ks_strides[NumDimN + NumDimK - 1];
593
594 std::array<index_t, NumDTensor> ds_nz_stride;
595 for(index_t j = 0; j < NumDTensor; ++j)
596 {
597 ds_nz_stride[j] =
598 contraction_descs[i].ds_ms_ns_strides[j][NumDimM + NumDimN - 1];
599 }
600
601 const index_t e_nz_stride =
602 contraction_descs[i].e_ms_ns_strides[NumDimM + NumDimN - 1];
603
604 bool valid = false;
605 if(get_warp_size() == 64)
606 {
607 if constexpr(NXdlPerWave64 > 0)
608 {
609 valid = GridwiseGemm64::CheckValidity(a_grid_desc_m_k,
610 b_grid_desc_n_k,
611 ds_grid_desc_m_n,
612 e_grid_desc_m_n,
613 block_2_etile_map);
614 }
615 }
616 else
617 {
618 if constexpr(NXdlPerWave32 > 0)
619 {
620 valid = GridwiseGemm32::CheckValidity(a_grid_desc_m_k,
621 b_grid_desc_n_k,
622 ds_grid_desc_m_n,
623 e_grid_desc_m_n,
624 block_2_etile_map);
625 }
626 }
627 if(valid)
628 {
630 {p_a_grid,
631 p_b_grid,
632 p_ds_grid,
633 p_e_grid,
634 a_grid_desc_ak0_m_ak1,
635 b_grid_desc_bk0_n_bk1,
636 ds_grid_desc_mblock_mperblock_nblock_nperblock,
637 e_grid_desc_mblock_mperblock_nblock_nperblock,
638 block_2_etile_map,
639 BlockStart,
640 BlockEnd});
641
642 contraction_multi_d_device_args_.push_back({a_grid_desc_m_k,
643 b_grid_desc_n_k,
644 ds_grid_desc_m_n,
645 e_grid_desc_m_n,
646 a_mz_stride,
647 a_kz_stride,
648 b_nz_stride,
649 b_kz_stride,
650 ds_nz_stride,
651 e_nz_stride});
652 }
653 }
654 }
655
656 std::vector<ContractionMultiDKernelArg> contraction_multi_d_kernel_args_;
657 std::vector<ContractionMultiDDeviceArg> contraction_multi_d_device_args_;
658
659 std::size_t group_count_;
661
662 // element-wise op
663 AElementwiseOperation a_element_op_;
664 BElementwiseOperation b_element_op_;
665 CDEElementwiseOperation cde_element_op_;
666 };
667
668 // Invoker
669 struct Invoker : public BaseInvoker
670 {
672
673 template <typename GridwiseGemm>
674 float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
675 {
676 bool has_main_k_block_loop = true;
677
678 for(std::size_t i = 0; i < arg.group_count_; i++)
679 {
680 const auto K =
681 arg.contraction_multi_d_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) *
682 arg.contraction_multi_d_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2);
683
684 if(GridwiseGemm::CalculateHasMainKBlockLoop(K) != has_main_k_block_loop)
685 {
686 throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
687 }
688 }
689
690 hipGetErrorString(hipMemcpyWithStream(arg.p_workspace_,
693 sizeof(ContractionMultiDKernelArg),
694 hipMemcpyHostToDevice,
695 stream_config.stream_id_));
696
697 float ave_time = 0;
698
699 auto launch_kernel = [&](auto has_main_k_block_loop_) {
700 const auto kernel =
702 ContractionMultiDKernelArg,
703 AElementwiseOperation,
704 BElementwiseOperation,
705 CDEElementwiseOperation,
706 has_main_k_block_loop_>;
707
709 stream_config,
710 kernel,
711 dim3(arg.grid_size_),
712 dim3(BlockSize),
713 0,
715 arg.group_count_,
716 arg.a_element_op_,
717 arg.b_element_op_,
718 arg.cde_element_op_);
719 };
720
721 if(has_main_k_block_loop)
722 {
723 ave_time = launch_kernel(integral_constant<bool, true>{});
724 }
725 else
726 {
727 ave_time = launch_kernel(integral_constant<bool, false>{});
728 }
729
730 return ave_time;
731 }
732
734
735 // polymorphic
736 float Run(const BaseArgument* p_arg,
737 const StreamConfig& stream_config = StreamConfig{}) override
738 {
739 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
740 }
741 };
742
743 static bool IsSupportedArgument(const Argument& arg)
744 {
746 {
747 return false;
748 }
749 for(std::size_t i = 0; i < arg.group_count_; i++)
750 {
751 const auto a_grid_desc_m_k_ = arg.contraction_multi_d_device_args_[i].a_grid_desc_m_k_;
752 const auto b_grid_desc_n_k_ = arg.contraction_multi_d_device_args_[i].b_grid_desc_n_k_;
753 const auto ds_grid_desc_m_n_ =
754 arg.contraction_multi_d_device_args_[i].ds_grid_desc_m_n_;
755 const auto e_grid_desc_m_n_ = arg.contraction_multi_d_device_args_[i].e_grid_desc_m_n_;
756 const auto a_grid_desc_ak0_m_ak1_ =
757 arg.contraction_multi_d_kernel_args_[i].a_grid_desc_ak0_m_ak1_;
758 const auto b_grid_desc_bk0_n_bk1_ =
759 arg.contraction_multi_d_kernel_args_[i].b_grid_desc_bk0_n_bk1_;
760 const auto ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
762 .ds_grid_desc_mblock_mperblock_nblock_nperblock_;
763 const auto e_grid_desc_mblock_mperblock_nblock_nperblock_ =
765 .e_grid_desc_mblock_mperblock_nblock_nperblock_;
766
767 const auto block_2_etile_map_ =
768 arg.contraction_multi_d_kernel_args_[i].block_2_etile_map_;
769
770 const auto a_mz_stride_ = arg.contraction_multi_d_device_args_[i].a_mz_stride_;
771 const auto a_kz_stride_ = arg.contraction_multi_d_device_args_[i].a_kz_stride_;
772 const auto b_nz_stride_ = arg.contraction_multi_d_device_args_[i].b_nz_stride_;
773 const auto b_kz_stride_ = arg.contraction_multi_d_device_args_[i].b_kz_stride_;
774 const auto ds_nz_stride_ = arg.contraction_multi_d_device_args_[i].ds_nz_stride_;
775 const auto e_nz_stride_ = arg.contraction_multi_d_device_args_[i].e_nz_stride_;
776
777 bool valid = false;
778 if(get_warp_size() == 64)
779 {
780 if constexpr(NXdlPerWave64 > 0)
781 {
782 valid = GridwiseGemm64::CheckValidity(a_grid_desc_m_k_,
783 b_grid_desc_n_k_,
784 ds_grid_desc_m_n_,
785 e_grid_desc_m_n_,
786 block_2_etile_map_);
787 }
788 }
789 else
790 {
791 if constexpr(NXdlPerWave32 > 0)
792 {
793 valid = GridwiseGemm32::CheckValidity(a_grid_desc_m_k_,
794 b_grid_desc_n_k_,
795 ds_grid_desc_m_n_,
796 e_grid_desc_m_n_,
797 block_2_etile_map_);
798 }
799 }
800 if(!valid)
801 {
802 return false;
803 }
804
805 // check vector access
806 static_assert((ABlockTransferSrcVectorDim == 1 || ABlockTransferSrcVectorDim == 2) &&
807 (BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2),
808 "wrong!");
809
810 // vector memory access of A: could be on M or AK1 dimension
811 if constexpr(ABlockTransferSrcVectorDim == 1)
812 {
813 if(!(a_mz_stride_ == 1 &&
814 a_grid_desc_ak0_m_ak1_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0))
815 {
816 return false;
817 }
818 }
819 else
820 {
821 if(!(a_kz_stride_ == 1 &&
822 a_grid_desc_ak0_m_ak1_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0))
823 {
824 return false;
825 }
826 }
827
828 // vector memory access of B: could be on N or BK1 dimension
829 if constexpr(BBlockTransferSrcVectorDim == 1)
830 {
831 if(!(b_nz_stride_ == 1 &&
832 b_grid_desc_bk0_n_bk1_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0))
833 {
834 return false;
835 }
836 }
837 else
838 {
839 if(!(b_kz_stride_ == 1 &&
840 b_grid_desc_bk0_n_bk1_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0))
841 {
842 return false;
843 }
844 }
845
846 // vector memory access of Ds: always on NPerBlock dimension
847 bool valid_d_access = true;
848
849 static_for<0, NumDTensor, 1>{}([&](auto j) {
850 if(!(ds_nz_stride_[j] == 1 &&
851 ds_grid_desc_mblock_mperblock_nblock_nperblock_[j].GetLength(I3) %
852 CDEBlockTransferScalarPerVector_NPerBlock ==
853 0))
854 {
855 valid_d_access = false;
856 }
857 });
858
859 if(valid_d_access == false)
860 {
861 return false;
862 }
863
864 // vector memory access of E: always on NPerBlock dimension
865 if(!(e_nz_stride_ == 1 && e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3) %
866 CDEBlockTransferScalarPerVector_NPerBlock ==
867 0))
868 {
869 return false;
870 }
871 }
872
873 return true;
874 }
875
876 // polymorphic
877 bool IsSupportedArgument(const BaseArgument* p_arg) override
878 {
879 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
880 }
881
882 static auto MakeArgument(std::vector<const void*> p_a_vec,
883 std::vector<const void*> p_b_vec,
884 std::vector<std::array<const void*, NumDTensor>> p_ds_vec,
885 std::vector<void*> p_e_vec,
886 std::vector<ContractionDesc<NumDTensor>> contraction_descs,
887 AElementwiseOperation a_element_op,
888 BElementwiseOperation b_element_op,
889 CDEElementwiseOperation cde_element_op)
890 {
891 return Argument{p_a_vec,
892 p_b_vec,
893 p_ds_vec,
894 p_e_vec,
895 contraction_descs,
896 a_element_op,
897 b_element_op,
898 cde_element_op};
899 }
900
901 static auto MakeInvoker() { return Invoker{}; }
902
903 // polymorphic
904 std::unique_ptr<BaseArgument>
905 MakeArgumentPointer(std::vector<const void*> p_a_vec,
906 std::vector<const void*> p_b_vec,
907 std::vector<std::array<const void*, NumDTensor>> p_ds_vec,
908 std::vector<void*> p_e_vec,
909 std::vector<ContractionDesc<NumDTensor>> contraction_descs,
910 AElementwiseOperation a_element_op,
911 BElementwiseOperation b_element_op,
912 CDEElementwiseOperation cde_element_op) override
913 {
914 return std::make_unique<Argument>(p_a_vec,
915 p_b_vec,
916 p_ds_vec,
917 p_e_vec,
918 contraction_descs,
919 a_element_op,
920 b_element_op,
921 cde_element_op);
922 }
923
924 // polymorphic
925 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
926 {
927 return std::make_unique<Invoker>(Invoker{});
928 }
929
930 // polymorphic
931 std::string GetTypeString() const override
932 {
933 auto str = std::stringstream();
934
935 // clang-format off
936 str << "DeviceGroupedContractionMultipleD_Xdl_CShuffle"
937 << "<"
938 << NumDimM << ", "
939 << NumDimN << ", "
940 << NumDimK << ", "
941 << BlockSize << ", "
942 << MPerBlock << ", "
943 << NPerBlock << ", "
944 << KPerBlock << ", "
945 << AK1 << ", "
946 << BK1 << ", "
947 << ABlockTransferSrcVectorDim << ", "
948 << BBlockTransferSrcVectorDim
949 << ">";
950 // clang-format on
951
952 return str.str();
953 }
954
955 size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
956 {
957 return dynamic_cast<const Argument*>(p_arg)->group_count_ *
959 }
960};
961
962} // namespace device
963} // namespace tensor_operation
964} // namespace ck
#define CK_CONSTANT_ADDRESS_SPACE
Definition ck.hpp:23
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
TensorSpecialization
Definition tensor_specialization.hpp:11
@ Packed
Definition tensor_specialization.hpp:13
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE * cast_pointer_to_constant_address_space(T *p)
Definition amd_address_space.hpp:35
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
__host__ __device__ constexpr auto get_container_subset(const Array< T, N > &arr, Sequence< Is... >)
Definition utility/container_helper.hpp:346
__host__ __device__ constexpr auto container_reduce(const Container &x, Reduce reduce, Init init, Number< IBegin >=Number< 0 >{}, Number< IEnd >=Number< Container::Size()>{}, Number< IStep >=Number< 1 >{})
Definition utility/container_helper.hpp:111
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__global__ void kernel_grouped_contraction_multiple_d_xdl_cshuffle(const void CK_CONSTANT_ADDRESS_SPACE *contraction_args, const index_t group_count, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op)
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:33
__device__ T * cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE *p)
Definition amd_address_space.hpp:24
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_multiple_d_xdl_cshuffle.hpp:78
Definition utility/sequence.hpp:43
typename conditional< kHasContent, type0, type1 >::type type
Definition utility/sequence.hpp:271
Definition utility/integral_constant.hpp:20
Definition utility/math.hpp:34
Definition functional2.hpp:33
Definition device_base.hpp:197
void * p_workspace_
Definition device_base.hpp:204
Definition device_grouped_contraction_multiple_d.hpp:17
DsGridDesc_M_N ds_grid_desc_m_n_
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:490
index_t b_nz_stride_
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:497
index_t a_mz_stride_
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:495
index_t b_kz_stride_
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:498
EGridDesc_M_N e_grid_desc_m_n_
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:491
BGridDesc_N_K b_grid_desc_n_k_
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:489
index_t a_kz_stride_
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:496
std::array< index_t, NumDTensor > ds_nz_stride_
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:499
AGridDesc_M_K a_grid_desc_m_k_
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:488
index_t e_nz_stride_
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:501
const BDataType * p_b_grid_
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:468
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:477
GroupedContractionBlock2ETileMap block_2_etile_map_
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:480
GridwiseGemm64::DsGridPointer p_ds_grid_
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:469
EDataType * p_e_grid_
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:470
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock_
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:476
ck::index_t block_end_
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:482
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:473
ck::index_t block_start_
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:482
const ADataType * p_a_grid_
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:467
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:474
__host__ __device__ bool ValidCTileIndex(const CTileIdx &c_tile_idx, const CTileDim &c_tile_dim) const
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:449
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:441
__host__ bool CheckValidity(const EGridDesc_M_N &e_grid_desc_m_n) const
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:455
Block2ETileMap default_block_2_etile_map_
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:460
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))> Block2ETileMap
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:430
GroupedContractionBlock2ETileMap(const EGridDesc_M_N &e_grid_desc_m_n, ck::index_t BlockStart)
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:433
ck::index_t block_start_
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:461
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:506
CDEElementwiseOperation cde_element_op_
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:665
BElementwiseOperation b_element_op_
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:664
index_t grid_size_
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:660
std::vector< ContractionMultiDKernelArg > contraction_multi_d_kernel_args_
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:656
AElementwiseOperation a_element_op_
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:663
Argument(std::vector< const void * > p_a_vec, std::vector< const void * > p_b_vec, std::vector< std::array< const void *, NumDTensor > > p_ds_vec, std::vector< void * > p_e_vec, std::vector< ContractionDesc< NumDTensor > > contraction_descs, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:507
std::vector< ContractionMultiDDeviceArg > contraction_multi_d_device_args_
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:657
std::size_t group_count_
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:659
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:670
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:674
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:736
DeviceOp::Argument Argument
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:671
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:168
remove_cvref_t< decltype(GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DsGridDesc_M_N{}))> DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:420
static constexpr auto I3
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:179
static constexpr auto I2
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:178
static constexpr auto I0
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:176
static auto MakeArgument(std::vector< const void * > p_a_vec, std::vector< const void * > p_b_vec, std::vector< std::array< const void *, NumDTensor > > p_ds_vec, std::vector< void * > p_e_vec, std::vector< ContractionDesc< NumDTensor > > contraction_descs, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:882
size_t GetWorkSpaceSize(const BaseArgument *p_arg) const override
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:955
decltype(MakeEGridDescriptor_M_N({}, {})) EGridDesc_M_N
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:361
decltype(MakeAGridDescriptor_M_K({}, {})) AGridDesc_M_K
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:358
std::string GetTypeString() const override
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:931
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1( BGridDesc_N_K{}))> BGridDesc_BK0_N_BK1
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:417
static auto MakeDsGridDescriptor_M_N(const std::array< std::vector< index_t >, NumDTensor > &ds_ms_ns_lengths_vec, const std::array< std::vector< index_t >, NumDTensor > &ds_ms_ns_strides_vec)
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:346
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1( AGridDesc_M_K{}))> AGridDesc_AK0_M_AK1
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:414
decltype(MakeBGridDescriptor_N_K({}, {})) BGridDesc_N_K
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:359
static constexpr auto I1
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:177
static auto MakeAGridDescriptor_M_K(const std::vector< index_t > &a_ms_ks_lengths_vec, const std::vector< index_t > &a_ms_ks_strides_vec)
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:185
static auto MakeEGridDescriptor_M_N(const std::vector< index_t > &e_ms_ns_lengths_vec, const std::vector< index_t > &e_ms_ns_strides_vec)
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:293
std::unique_ptr< BaseArgument > MakeArgumentPointer(std::vector< const void * > p_a_vec, std::vector< const void * > p_b_vec, std::vector< std::array< const void *, NumDTensor > > p_ds_vec, std::vector< void * > p_e_vec, std::vector< ContractionDesc< NumDTensor > > contraction_descs, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) override
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:905
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:877
remove_cvref_t< decltype(GridwiseGemm64::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( EGridDesc_M_N{}))> EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:423
ADataType ComputeDataType
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:363
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:172
static constexpr auto matrix_padder
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:181
DeviceGroupedContractionMultipleD_Xdl_CShuffle DeviceOp
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:169
static auto MakeBGridDescriptor_N_K(const std::vector< index_t > &b_ns_ks_lengths_vec, const std::vector< index_t > &b_ns_ks_strides_vec)
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:239
static constexpr auto NXdlPerWave32
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:173
static bool IsSupportedArgument(const Argument &arg)
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:743
static auto MakeInvoker()
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:901
GridwiseGemmMultipleD_xdl_cshuffle< ADataType, BDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched > GridwiseGemmBase
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:367
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:411
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:410
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:925
static constexpr index_t NumDTensor
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:174
remove_cvref_t< decltype(MakeDsGridDescriptor_M_N({{}}, {{}}))> DsGridDesc_M_N
Definition device_grouped_contraction_multiple_d_xdl_cshuffle.hpp:360
Definition device_grouped_contraction_multiple_d.hpp:54
Definition matrix_padder.hpp:180