tile_distribution_encoding.hpp Source File

tile_distribution_encoding.hpp Source File

Composable Kernel: tile_distribution_encoding.hpp Source File
tile_distribution_encoding.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
16
17namespace ck_tile {
18
19template <typename RsLengths_, // sequence<...>
20 typename HsLengthss_, // tuple<sequence<...>, ...>
21 typename Ps2RHssMajor_, // tuple<sequence<...>, ...>
22 typename Ps2RHssMinor_, // tuple<sequence<...>, ...>
23 typename Ys2RHsMajor_, // sequence<...>
24 typename Ys2RHsMinor_> // sequence<...>
26{
33
34 static_assert(Ps2RHssMajor::size() == Ps2RHssMinor::size(), "wrong!");
35 static_assert(Ys2RHsMajor::size() == Ys2RHsMinor::size(), "wrong!");
36
37 static constexpr index_t NDimX = HsLengthss::size();
38 static constexpr index_t NDimP = Ps2RHssMajor::size();
39 static constexpr index_t NDimY = Ys2RHsMajor::size();
40 static constexpr index_t NDimR = RsLengths::size();
41
42 // FIXME: move into detail
43 static constexpr auto rs_lengths_ = RsLengths{};
44 static constexpr auto hs_lengthss_ = HsLengthss{};
45 static constexpr auto ps_to_rhss_major_ = Ps2RHssMajor{};
46 static constexpr auto ps_to_rhss_minor_ = Ps2RHssMinor{};
47 static constexpr auto ys_to_rhs_major_ = Ys2RHsMajor{};
48 static constexpr auto ys_to_rhs_minor_ = Ys2RHsMinor{};
49
50#if !CK_TILE_ENC_SUPPORT_Y_TO_R
51 static_assert(container_find(ys_to_rhs_major_, 0) == NDimY,
52 "do not support Y dim pointed to R dim");
53#endif
54
55 // redundant but useful info
56 // TODO: really bad code, should be over-hauled
57 struct detail
58 {
59 // ndim_rh_major_, ndim_span_mainor_
60 static constexpr index_t ndim_rh_major_ = NDimX + 1;
61 static constexpr index_t ndim_span_major_ = NDimX;
62
63 // ndims_rhs_minor_[ndim_rh_major_]
64 static constexpr auto ndims_rhs_minor_ = generate_array(
65 [](auto i) {
66 if constexpr(i.value == 0)
67 {
68 return rs_lengths_.size();
69 }
70 else
71 {
72 return hs_lengthss_[i - number<1>{}].size();
73 }
74 },
76
77 // max_ndim_rh_minor_
78 static constexpr index_t max_ndim_rh_minor_ =
80
81 // rhs_lengthss_[ndim_rh_major_][max_ndim_rh_minor_]
82 static constexpr auto rhs_lengthss_ =
84
85 // ys_lengths_
86 static constexpr auto ys_lengths_ = [] {
87 array<index_t, NDimY> ys_lengths_tmp{-1};
88
89 for(index_t i = 0; i < NDimY; i++)
90 {
91 index_t rh_major = ys_to_rhs_major_[i];
92 index_t rh_minor = ys_to_rhs_minor_[i];
93
94 ys_lengths_tmp(i) = rhs_lengthss_[rh_major][rh_minor];
95 }
96
97 return ys_lengths_tmp;
98 }();
99
100 // rhs_major_minor_to_ys_[ndim_rh_majpr_][max_ndim_rh_minor_]
101 static constexpr auto rhs_major_minor_to_ys_ = [] {
102 array<array<index_t, max_ndim_rh_minor_>, NDimX + 1> rhs_major_minor_to_ys_tmp{{-1}};
103
104 static_for<0, NDimY, 1>{}([&](auto i) {
105 constexpr index_t rh_major = ys_to_rhs_major_[i];
106 constexpr index_t rh_minor = ys_to_rhs_minor_[i];
107
108 rhs_major_minor_to_ys_tmp(rh_major)(rh_minor) = i;
109 });
110
111 return rhs_major_minor_to_ys_tmp;
112 }();
113
114 // ndims_span_minor_[NDimY]
115 static constexpr auto ndims_span_minor_ = [] {
116 array<index_t, NDimX> ndims_span_minor{0};
117
118 for(index_t i = 0; i < NDimY; i++)
119 {
120 const index_t span_major = ys_to_rhs_major_[i] - 1;
121
122 ndims_span_minor(span_major)++;
123 }
124
125 return ndims_span_minor;
126 }();
127
128 // max_ndim_span_minor_
129 static constexpr index_t max_ndim_span_minor_ =
131
132 // rhs_major_minor_to_span_minor_ [ndim_rh_major_][max_ndim_rh_minor_]
133 static constexpr auto rhs_major_minor_to_span_minor_ = [] {
134 array<array<index_t, max_ndim_rh_minor_>, ndim_rh_major_> rhs_major_minor_to_span_minor{
135 {-1}};
136
137 static_for<0, ndim_rh_major_, 1>{}([&](auto rh_major) {
138 constexpr index_t ndim_rh_minor = ndims_rhs_minor_[rh_major];
139
140 index_t cnt_ndim_span_minor = 0;
141
142 static_for<0, ndim_rh_minor, 1>{}([&](auto rh_minor) {
143 constexpr index_t idim_y = rhs_major_minor_to_ys_[rh_major][rh_minor];
144
145 if(idim_y >= 0)
146 {
147 rhs_major_minor_to_span_minor(rh_major)(rh_minor) = cnt_ndim_span_minor;
148
149 cnt_ndim_span_minor++;
150 }
151 });
152 });
153
154 return rhs_major_minor_to_span_minor;
155 }();
156
157 // ys_to_span_major_[NDimY]
// For each Y dimension, its span-major (H) index: the rh_major it maps onto, minus one to
// drop the leading R group (rh_major 0). The array element type is deduced by
// generate_array from the type of the lambda's return expression.
158 static constexpr auto ys_to_span_major_ =
159 generate_array([](auto i) { return ys_to_rhs_major_[i] - 1; }, number<NDimY>{});
160
161 // ys_to_span_minor_[NDimY]
162 static constexpr auto ys_to_span_minor_ = generate_array(
163 [](auto i) {
165 },
166 number<NDimY>{});
167
168 // distributed_spans_lengthss_[ndim_span_major_][max_ndim_span_minor_]
169 static constexpr auto distributed_spans_lengthss_ = [] {
171 distributed_spans_lengthss{{-1}};
172
173 static_for<0, NDimY, 1>{}([&](auto i) {
174 const index_t rh_major = ys_to_rhs_major_[i];
175 const index_t rh_minor = ys_to_rhs_minor_[i];
176
177 const index_t h_length = hs_lengthss_[number<rh_major - 1>{}][rh_minor];
178
179 const index_t span_major = rh_major - 1;
180 const index_t span_minor = rhs_major_minor_to_span_minor_[rh_major][rh_minor];
181
182 distributed_spans_lengthss(span_major)(span_minor) = h_length;
183 });
184
185 return distributed_spans_lengthss;
186 }();
187
188 // ndims_distributed_spans_minor_[ndim_span_major_]
189 static constexpr auto ndims_distributed_spans_minor_ = [] {
190 array<index_t, ndim_span_major_> ndims_distributed_spans_minor{0};
191
192 static_for<0, NDimY, 1>{}([&](auto i) {
193 const index_t span_major = ys_to_rhs_major_[i] - 1;
194
195 ndims_distributed_spans_minor(span_major)++;
196 });
197
198 return ndims_distributed_spans_minor;
199 }();
200
201 // does_p_own_r_[NDimP][NDimR]
202 static constexpr auto does_p_own_r_ = [] {
203 if constexpr(NDimR > 0)
204 {
205 array<array<bool, NDimR>, NDimP> does_p_own_r{{false}};
206
207 static_for<0, NDimP, 1>{}([&](auto idim_p) {
208 constexpr index_t ndim_low = ps_to_rhss_major_[idim_p].size();
209
210 static_for<0, ndim_low, 1>{}([&](auto idim_low) {
211 constexpr index_t rh_major = ps_to_rhss_major_[idim_p][idim_low];
212 constexpr index_t rh_minor = ps_to_rhss_minor_[idim_p][idim_low];
213
214 if constexpr(rh_major == 0)
215 {
216 does_p_own_r(idim_p)(rh_minor) = true;
217 }
218 });
219 });
220
221 return does_p_own_r;
222 }
223 else
224 {
226 }
227 }();
228
229 // ps_over_rs_derivative_[NDimP][NDimR]
230 static constexpr auto ps_over_rs_derivative_ = [] {
231 if constexpr(NDimR > 0)
232 {
233 array<array<index_t, NDimR>, NDimP> ps_over_rs_derivative{{0}};
234
235 static_for<0, NDimP, 1>{}([&](auto idim_p) {
236 constexpr index_t ndim_low = ps_to_rhss_major_[idim_p].size();
237
238 index_t p_over_rh_derivative = 1;
239
240 static_for<ndim_low - 1, -1, -1>{}([&](auto idim_low) {
241 constexpr index_t rh_major = ps_to_rhss_major_[idim_p][idim_low];
242 constexpr index_t rh_minor = ps_to_rhss_minor_[idim_p][idim_low];
243
244 constexpr index_t rh_length = rhs_lengthss_[rh_major][rh_minor];
245
246 if constexpr(rh_major == 0)
247 {
248 ps_over_rs_derivative(idim_p)(rh_minor) = p_over_rh_derivative;
249 }
250
251 p_over_rh_derivative *= rh_length;
252 });
253 });
254
255 return ps_over_rs_derivative;
256 }
257 else
258 {
260 }
261 }();
262
264 {
265 // e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5>
266 constexpr auto uniformed_h_dim_lengths = generate_sequence_v2(
267 [&](auto i) {
268 constexpr index_t size_ = HsLengthss{}[i].size();
269 return number<size_>{};
270 },
271 number<NDimX>{});
272 return uniformed_h_dim_lengths;
273 }
274
275 // note: this function only count the p dim length along h, not r
277 {
278 // e.g. tuple<seq<1, 4, 32>, seq<1, 2, 8, 4, 4>>
279 // Y P Y Y P Y P Y
280 // | | |
281 // v v v
282 // return : seq<4, 2 * 4> => seq<4, 8>
283 constexpr auto uniformed_ps_to_rhss_major_ =
284 unpack([](auto... xs_) { return merge_sequences(xs_...); }, ps_to_rhss_major_);
285 constexpr auto uniformed_ps_to_rhss_minor_ =
286 unpack([](auto... xs_) { return merge_sequences(xs_...); }, ps_to_rhss_minor_);
287
288 constexpr auto p_len_ = [&]() {
289 array<index_t, NDimX> len_{1};
290 static_for<0, NDimX, 1>{}([&](auto idim_x_) {
291 constexpr auto major_ = number<idim_x_ + 1>{}; // RDim
292 static_for<0, uniformed_ps_to_rhss_major_.size(), 1>{}([&](auto idim_u_) {
293 if constexpr(major_.value == uniformed_ps_to_rhss_major_[idim_u_])
294 {
295 constexpr auto minor_ = uniformed_ps_to_rhss_minor_[idim_u_];
296 constexpr auto h_length_ = hs_lengthss_[idim_x_][minor_];
297 len_[idim_x_] *= h_length_;
298 }
299 });
300 });
301 return len_;
302 }();
303 constexpr auto p_len_over_h_seq_ = TO_SEQUENCE(p_len_, NDimX);
304 return p_len_over_h_seq_;
305 }
306
307 //
308 // R: seq<3>, H: tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>>
309 // => return seq<1, 3, 5>
310 // R: seq<>, H: tuple<seq<2, 4>, seq<16, 8, 8>>
311 // => return seq<0, 2, 3>
313 {
314 constexpr auto uniformed_rh_dim_lengths =
316
317 return uniformed_rh_dim_lengths;
318 }
319
320 // e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5> --> seq<0, 3, 8>
322 {
323 // <0, len_d0, len_d0+len_d1, ...>
324 // e.g. seq<3, 5> --> seq<0, 3, 8>
325 constexpr auto h_dim_prefix_sum = prefix_sum_sequence(get_uniformed_h_dim_lengths());
326
327 return h_dim_prefix_sum;
328 }
329
331 {
332 // <0, len_d0, len_d0+len_d1, ...>
333 // e.g. seq<3, 5> --> seq<0, 3, 8>
334 constexpr auto rh_dim_prefix_sum = prefix_sum_sequence(get_uniformed_rh_dim_lengths());
335
336 return rh_dim_prefix_sum;
337 }
338
340 {
341 // tuple<seq<xx..>, seq<yy..>> -> seq<xx..yy..>
342 constexpr auto uniformed_ps_to_rhss_major_ =
343 unpack([](auto... xs_) { return merge_sequences(xs_...); }, ps_to_rhss_major_);
344 constexpr auto uniformed_ps_to_rhss_minor_ =
345 unpack([](auto... xs_) { return merge_sequences(xs_...); }, ps_to_rhss_minor_);
346
347 constexpr auto all_ps_2_rhss = transform_sequences(
348 [](auto major, auto minor) constexpr {
349 constexpr auto rh_dim_prefix_sum = get_rh_dim_lengths_prefix_sum();
350 return rh_dim_prefix_sum.at(major) + minor;
351 },
352 uniformed_ps_to_rhss_major_,
353 uniformed_ps_to_rhss_minor_);
354
355 return all_ps_2_rhss;
356 }
357
359 {
360 constexpr auto all_ys_2_rhss = transform_sequences(
361 [](auto major, auto minor) constexpr {
362 constexpr auto rh_dim_prefix_sum = get_rh_dim_lengths_prefix_sum();
363 return rh_dim_prefix_sum.at(major) + minor;
364 },
365 Ys2RHsMajor{},
366 Ys2RHsMinor{});
367
368 return all_ys_2_rhss;
369 }
370
372 {
373 // TODO: Y can't point to R
374 constexpr auto all_ys_2_rhss = transform_sequences(
375 [](auto major, auto minor) constexpr {
376 constexpr auto rh_dim_prefix_sum = get_rh_dim_lengths_prefix_sum();
377 return rh_dim_prefix_sum.at(major) + minor - NDimR;
378 },
379 Ys2RHsMajor{},
380 Ys2RHsMinor{});
381
382 return all_ys_2_rhss;
383 }
384
385 // return tuple of seq
386 CK_TILE_HOST_DEVICE static constexpr auto get_y_to_h_masks()
387 {
388 constexpr auto masks_ = generate_tuple(
389 [&](auto i) {
390 constexpr auto size_ = HsLengthss{}[i].size();
391 constexpr auto current_y_to_h_mask_ = [&]() {
393 // TODO: we loop over all y for each h dim
394 for(auto j = 0; j < NDimY; j++)
395 {
396 if(Ys2RHsMajor{}[j] == (i + 1) /*RDim need plus 1*/)
397 {
398 m_[Ys2RHsMinor{}[j]] = 1;
399 }
400 }
401 return m_;
402 }();
403
404 return TO_SEQUENCE(current_y_to_h_mask_, size_);
405 },
406 number<NDimX>{});
407 return masks_;
408 }
409
410 // return tuple<sorted_dims, sorted_maps, sorted_prefix_sum>
411 template <typename IdxSeq, typename PrefixSumSeq>
412 CK_TILE_HOST_DEVICE static constexpr auto get_sorted_info(IdxSeq, PrefixSumSeq)
413 {
415
416 constexpr auto sorted_dims = typename sorted_idx::type{};
417 constexpr auto sorted_maps = typename sorted_idx::sorted2unsorted_map{};
418
419 constexpr auto sorted_histogram =
420 histogram_sorted_sequence(sorted_dims, PrefixSumSeq{});
421 constexpr auto sorted_prefix_sum = prefix_sum_sequence(sorted_histogram);
422
423 return make_tuple(sorted_dims, sorted_maps, sorted_prefix_sum);
424 }
425
426 // Note here y_to_h does not count R dim!
431 };
432};
433
434template <typename encoding, typename shuffle>
436template <typename encoding, index_t... shuffle>
438{
439 template <typename Ys2RHs>
440 using shuffled = sequence<(Ys2RHs::template get<shuffle>())...>;
441
442 public:
443 using type = tile_distribution_encoding<typename encoding::RsLengths,
444 typename encoding::HsLengthss,
445 typename encoding::Ps2RHssMajor,
446 typename encoding::Ps2RHssMinor,
447 shuffled<typename encoding::Ys2RHsMajor>,
448 shuffled<typename encoding::Ys2RHsMinor>>;
449};
450template <typename encoding, typename shuffle>
453
454namespace detail {
455
456template <typename OuterDstr, typename InnerDstr>
458{
459 static_assert(OuterDstr::NDimX == InnerDstr::NDimX, "wrong!");
460
461 constexpr index_t NDimHMajor = OuterDstr::NDimX;
462
463 using RsLengths =
465
466 constexpr auto hs_lengthss = generate_tuple(
467 [&](auto i) {
468 return merge_sequences(typename OuterDstr::HsLengthss{}[i],
469 typename InnerDstr::HsLengthss{}[i]);
470 },
472
473 //
474 constexpr auto rhs_major_2_ndim_outer_rhs_minor = [&]() {
475 array<index_t, NDimHMajor + 1> rhs_major_2_ndim_outer_rhs_minor_;
476
477 // R dimension
478 rhs_major_2_ndim_outer_rhs_minor_(0) = OuterDstr::RsLengths::size();
479
480 // Hs dimensions
481 static_for<0, NDimHMajor, 1>{}([&](auto i) {
482 rhs_major_2_ndim_outer_rhs_minor_(i + 1) = typename OuterDstr::HsLengthss{}[i].size();
483 });
484
485 return rhs_major_2_ndim_outer_rhs_minor_;
486 }();
487
488 // Ps2RHssMinor
489 constexpr auto updated_inner_ps_2_rhss_minor = generate_tuple(
490 [&](auto p) {
491 constexpr auto inner_p_2_rhss_major = typename InnerDstr::Ps2RHssMajor{}[p];
492 constexpr auto inner_p_2_rhss_minor = typename InnerDstr::Ps2RHssMinor{}[p];
493
494 constexpr index_t ndim_tmp = inner_p_2_rhss_minor.size();
495
496 constexpr auto updated_inner_p_2_rhss_minor = [&]() {
497 array<index_t, ndim_tmp> updated_inner_p_2_rhss_minor_;
498
499 for(index_t i = 0; i < ndim_tmp; i++)
500 {
501 index_t rh_major = inner_p_2_rhss_major[i];
502
503 index_t ndim_outer_h_minor = rhs_major_2_ndim_outer_rhs_minor[rh_major];
504
505 updated_inner_p_2_rhss_minor_(i) = inner_p_2_rhss_minor[i] + ndim_outer_h_minor;
506 }
507
508 return updated_inner_p_2_rhss_minor_;
509 }();
510
511 return TO_SEQUENCE(updated_inner_p_2_rhss_minor, ndim_tmp);
512 },
514
515 // Ys2RHsMinor
516 constexpr auto updated_inner_ys_2_rhs_minor = [&]() {
517 constexpr auto inner_ys_2_rhs_major = typename InnerDstr::Ys2RHsMajor{};
518 constexpr auto inner_ys_2_rhs_minor = typename InnerDstr::Ys2RHsMinor{};
519
520 constexpr index_t ndim_tmp = inner_ys_2_rhs_minor.size();
521
522 constexpr auto updated_inner_ys_2_rhs_minor_ = [&]() {
523 array<index_t, ndim_tmp> updated_inner_ys_2_rhs_minor__;
524
525 for(index_t i = 0; i < ndim_tmp; i++)
526 {
527 index_t rh_major = inner_ys_2_rhs_major[i];
528
529 index_t ndim_outer_h_minor = rhs_major_2_ndim_outer_rhs_minor[rh_major];
530
531 updated_inner_ys_2_rhs_minor__(i) = inner_ys_2_rhs_minor[i] + ndim_outer_h_minor;
532 }
533
534 return updated_inner_ys_2_rhs_minor__;
535 }();
536
537 return TO_SEQUENCE(updated_inner_ys_2_rhs_minor_, ndim_tmp);
538 }();
539
540 //
541 constexpr auto ps_2_rhss_major =
542 container_concat(typename OuterDstr::Ps2RHssMajor{}, typename InnerDstr::Ps2RHssMajor{});
543
544 constexpr auto ps_2_rhss_minor =
545 container_concat(typename OuterDstr::Ps2RHssMinor{}, updated_inner_ps_2_rhss_minor);
546
547 //
548 constexpr auto ys_2_rhs_major =
549 merge_sequences(typename OuterDstr::Ys2RHsMajor{}, typename InnerDstr::Ys2RHsMajor{});
550
551 constexpr auto ys_2_rhs_minor =
552 merge_sequences(typename OuterDstr::Ys2RHsMinor{}, updated_inner_ys_2_rhs_minor);
553
554 return tile_distribution_encoding<RsLengths,
555 remove_cvref_t<decltype(hs_lengthss)>,
556 remove_cvref_t<decltype(ps_2_rhss_major)>,
557 remove_cvref_t<decltype(ps_2_rhss_minor)>,
558 remove_cvref_t<decltype(ys_2_rhs_major)>,
559 remove_cvref_t<decltype(ys_2_rhs_minor)>>{};
560}
561
562template <typename InDstr, index_t... InReduceDimXs>
563CK_TILE_HOST_DEVICE constexpr auto
565{
566 constexpr auto I1 = number<1>{};
567
568 // FIXME: increase if fail
569 constexpr index_t max_ndim_r_out = 20;
570 constexpr index_t max_ndim_y_out = 20;
571
572 //
573 constexpr index_t ndim_p = InDstr::NDimP;
574 constexpr index_t ndim_x_in = InDstr::NDimX;
575 constexpr index_t ndim_y_in = InDstr::NDimY;
576 constexpr index_t ndim_rh_major_in = InDstr::NDimX + 1;
577 constexpr index_t ndim_x_out = ndim_x_in - sizeof...(InReduceDimXs);
578 constexpr index_t max_ndim_rh_minor_in = InDstr::detail::max_ndim_rh_minor_;
579
580 // ndims_ps_low
581 constexpr auto ndims_ps_low = generate_array(
582 [&](auto i) { return InDstr::ps_to_rhss_major_[i].size(); }, number<ndim_p>{});
583
584 // is_rh_major_in_for_reduce
585 array<bool, ndim_rh_major_in> is_rh_major_in_for_reduce{false};
586
587 for(index_t i = 0; i < reduce_dim_xs_in.size(); i++)
588 {
589 index_t rh_major = reduce_dim_xs_in[i] + 1;
590
591 is_rh_major_in_for_reduce(rh_major) = true;
592 }
593
594 // is_y_in_for_reduce
595 array<bool, ndim_y_in> is_y_in_for_reduce{false};
596
597 for(index_t i = 0; i < ndim_y_in; i++)
598 {
599 index_t rh_major = InDstr::ys_to_rhs_major_[i];
600
601 if(is_rh_major_in_for_reduce[rh_major])
602 {
603 is_y_in_for_reduce(i) = true;
604 }
605 }
606
607 // is_rh_minor_in_for_y_reduce
608 array<array<bool, max_ndim_rh_minor_in>, ndim_rh_major_in> is_rh_minor_in_for_y_reduce{{false}};
609
610 static_for<0, ndim_y_in, 1>{}([&](auto i) {
611 index_t rh_major = InDstr::ys_to_rhs_major_[i];
612 index_t rh_minor = InDstr::ys_to_rhs_minor_[i];
613
614 if(is_y_in_for_reduce[i])
615 {
616 is_rh_minor_in_for_y_reduce(rh_major)(rh_minor) = true;
617 }
618 });
619
620 // in2out_rh_major
621 array<index_t, ndim_rh_major_in> in2out_rh_major{-1};
622 index_t cnt_ndim_rh_major_out = 0;
623
624 for(index_t i = 0; i < ndim_rh_major_in; i++)
625 {
626 if(is_rh_major_in_for_reduce[i])
627 {
628 in2out_rh_major(i) = 0;
629 }
630 else
631 {
632 in2out_rh_major(i) = cnt_ndim_rh_major_out;
633
634 cnt_ndim_rh_major_out++;
635 }
636 }
637
638 // rs_lengths_out, in2out_rh_minor
639 array<index_t, max_ndim_r_out> rs_lengths_out{-1};
640 array<array<index_t, max_ndim_rh_minor_in>, ndim_rh_major_in> in2out_rh_minor{{-1}};
641
642 // loop over input R dim
643 for(index_t i = 0; i < InDstr::rs_lengths_.size(); i++)
644 {
645 // rs_lengths_out
646 rs_lengths_out(i) = InDstr::rs_lengths_[i];
647
648 // in2out_rh_minor
649 in2out_rh_minor(0)(i) = i;
650 }
651
652 // loop over input H Dim
653 index_t cnt_ndim_r_out = InDstr::rs_lengths_.size();
654
655 static_for<1, ndim_rh_major_in, 1>{}([&](auto rh_major_in) {
656 constexpr auto h_major_in = rh_major_in - I1;
657
658 constexpr index_t ndim_rh_minor_in = InDstr::hs_lengthss_[h_major_in].size();
659
660 if(is_rh_major_in_for_reduce[rh_major_in])
661 {
662 for(index_t rh_minor_in = 0; rh_minor_in < ndim_rh_minor_in; rh_minor_in++)
663 {
664 if(not is_rh_minor_in_for_y_reduce[rh_major_in][rh_minor_in])
665 {
666 // rs_lengths_out
667 rs_lengths_out(cnt_ndim_r_out) = InDstr::hs_lengthss_[h_major_in][rh_minor_in];
668
669 // in2out_rh_minor
670 in2out_rh_minor(rh_major_in)(rh_minor_in) = cnt_ndim_r_out;
671
672 cnt_ndim_r_out++;
673 }
674 }
675 }
676 else
677 {
678 for(index_t rh_minor_in = 0; rh_minor_in < ndim_rh_minor_in; rh_minor_in++)
679 {
680 // in2out_rh_minor
681 in2out_rh_minor(rh_major_in)(rh_minor_in) = rh_minor_in;
682 }
683 }
684 });
685
686 // ndim_r_out
687 const index_t ndim_r_out = cnt_ndim_r_out;
688
689 // ndims_hs_minor_out, hs_lengthss_out
690 array<index_t, ndim_x_out> ndims_hs_minor_out{-1};
691 array<array<index_t, max_ndim_rh_minor_in>, ndim_x_out> hs_lengthss_out{{-1}};
692
693 index_t cnt_ndim_x_out = 0;
694
695 static_for<0, ndim_x_in, 1>{}([&](auto i) {
696 if(not is_rh_major_in_for_reduce[i + I1])
697 {
698 // ndims_hs_minor_out
699 ndims_hs_minor_out(cnt_ndim_x_out) = InDstr::hs_lengthss_[i].size();
700
701 // hs_lengthss_out
702 static_for<0, InDstr::hs_lengthss_[i].size(), 1>{}(
703 [&](auto j) { hs_lengthss_out(cnt_ndim_x_out)(j) = InDstr::hs_lengthss_[i][j]; });
704
705 cnt_ndim_x_out++;
706 }
707 });
708
709 // ps_to_rhss_major_out, ps_to_rhss_minor_out
710 array<array<index_t, max_ndim_rh_minor_in>, ndim_p> ps_to_rhss_major_out{{-1}};
711 array<array<index_t, max_ndim_rh_minor_in>, ndim_p> ps_to_rhss_minor_out{{-1}};
712
713 static_for<0, ndim_p, 1>{}([&](auto idim_p) {
714 static_for<0, InDstr::ps_to_rhss_major_[idim_p].size(), 1>{}([&](auto idim_low) {
715 index_t rh_major_in = InDstr::ps_to_rhss_major_[idim_p][idim_low];
716 index_t rh_minor_in = InDstr::ps_to_rhss_minor_[idim_p][idim_low];
717
718 ps_to_rhss_major_out(idim_p)(idim_low) = in2out_rh_major[rh_major_in];
719 ps_to_rhss_minor_out(idim_p)(idim_low) = in2out_rh_minor[rh_major_in][rh_minor_in];
720 });
721 });
722
723 // ys_to_rhs_major_out, ys_to_rhs_minor_out
724 array<index_t, max_ndim_y_out> ys_to_rhs_major_out{-1};
725 array<index_t, max_ndim_y_out> ys_to_rhs_minor_out{-1};
726
727 index_t cnt_ndim_y_out = 0;
728
729 static_for<0, ndim_y_in, 1>{}([&](auto i) {
730 if(not is_y_in_for_reduce[i])
731 {
732 index_t rh_major_in = InDstr::ys_to_rhs_major_[i];
733 index_t rh_minor_in = InDstr::ys_to_rhs_minor_[i];
734
735 ys_to_rhs_major_out(cnt_ndim_y_out) = in2out_rh_major[rh_major_in];
736 ys_to_rhs_minor_out(cnt_ndim_y_out) = in2out_rh_minor[rh_major_in][rh_minor_in];
737
738 cnt_ndim_y_out++;
739 }
740 });
741
742 // ndim_y_out
743 const index_t ndim_y_out = cnt_ndim_y_out;
744
745 //
746 return make_tuple(ndim_x_out,
747 ndim_p,
748 ndim_y_out,
749 ndim_r_out,
750 ndims_hs_minor_out,
751 ndims_ps_low,
752 rs_lengths_out,
753 hs_lengthss_out,
754 ps_to_rhss_major_out,
755 ps_to_rhss_minor_out,
756 ys_to_rhs_major_out,
757 ys_to_rhs_minor_out);
758}
759
760template <typename InDstr, index_t... InReduceDimXs>
761CK_TILE_HOST_DEVICE constexpr auto
763{
764 constexpr auto impl = make_reduce_tile_distribution_encoding_impl(InDstr{}, reduce_dim_xs_in);
765
766 constexpr index_t ndim_x = impl.template at<0>();
767 constexpr index_t ndim_p = impl.template at<1>();
768 constexpr index_t ndim_y = impl.template at<2>();
769 constexpr index_t ndim_r = impl.template at<3>();
770 constexpr auto ndims_hs_minor = impl.template at<4>();
771 constexpr auto ndims_ps_low = impl.template at<5>();
772 constexpr auto rs_lengths_impl = impl.template at<6>();
773 constexpr auto hs_lengthss_impl = impl.template at<7>();
774 constexpr auto ps_to_rhss_major_impl = impl.template at<8>();
775 constexpr auto ps_to_rhss_minor_impl = impl.template at<9>();
776 constexpr auto ys_to_rhs_major_impl = impl.template at<10>();
777 constexpr auto ys_to_rhs_minor_impl = impl.template at<11>();
778
779 constexpr auto rs_lengths = TO_SEQUENCE(rs_lengths_impl, ndim_r);
780 constexpr auto hs_lengthss = TO_TUPLE_OF_SEQUENCE(hs_lengthss_impl, ndim_x, ndims_hs_minor);
781 constexpr auto ps_to_rhss_major =
782 TO_TUPLE_OF_SEQUENCE(ps_to_rhss_major_impl, ndim_p, ndims_ps_low);
783 constexpr auto ps_to_rhss_minor =
784 TO_TUPLE_OF_SEQUENCE(ps_to_rhss_minor_impl, ndim_p, ndims_ps_low);
785 constexpr auto ys_to_rhs_major = TO_SEQUENCE(ys_to_rhs_major_impl, ndim_y);
786 constexpr auto ys_to_rhs_minor = TO_SEQUENCE(ys_to_rhs_minor_impl, ndim_y);
787
788 return tile_distribution_encoding<remove_cvref_t<decltype(rs_lengths)>,
789 remove_cvref_t<decltype(hs_lengthss)>,
790 remove_cvref_t<decltype(ps_to_rhss_major)>,
791 remove_cvref_t<decltype(ps_to_rhss_minor)>,
792 remove_cvref_t<decltype(ys_to_rhs_major)>,
793 remove_cvref_t<decltype(ys_to_rhs_minor)>>{};
794}
795
796} // namespace detail
797
798// Free print function for tile_distribution_encoding::detail
799template <typename RsLengths_,
800 typename HsLengthss_,
801 typename Ps2RHssMajor_,
802 typename Ps2RHssMinor_,
803 typename Ys2RHsMajor_,
804 typename Ys2RHsMinor_>
806print(const typename tile_distribution_encoding<RsLengths_,
807 HsLengthss_,
808 Ps2RHssMajor_,
809 Ps2RHssMinor_,
810 Ys2RHsMajor_,
811 Ys2RHsMinor_>::detail& detail_obj)
812{
813 printf("tile_distribution_encoding::detail{");
814 printf("ndim_rh_major_: ");
815 print(detail_obj.ndim_rh_major_);
816 printf(", ");
817 printf("ndim_span_major_: ");
818 print(detail_obj.ndim_span_major_);
819 printf(", ");
820 printf("ndims_rhs_minor_: ");
821 print(detail_obj.ndims_rhs_minor_);
822 printf(", ");
823 printf("ndim_rh_major_: ");
824 print(detail_obj.ndim_rh_major_);
825 printf(", ");
826 printf("max_ndim_rh_minor_: ");
827 print(detail_obj.max_ndim_rh_minor_);
828 printf(", ");
829 printf("rhs_lengthss_: ");
830 print(detail_obj.rhs_lengthss_);
831 printf(", ");
832 printf("ys_lengths_: ");
833 print(detail_obj.ys_lengths_);
834 printf(", ");
835 printf("rhs_major_minor_to_ys_: ");
836 print(detail_obj.rhs_major_minor_to_ys_);
837 printf(", ");
838 printf("ndims_span_minor_: ");
839 print(detail_obj.ndims_span_minor_);
840 printf(", ");
841 printf("max_ndim_span_minor_: ");
842 print(detail_obj.max_ndim_span_minor_);
843 printf(", ");
844 printf("ys_to_span_major_: ");
845 print(detail_obj.ys_to_span_major_);
846 printf(", ");
847 printf("ys_to_span_minor_: ");
848 print(detail_obj.ys_to_span_minor_);
849 printf(", ");
850 printf("distributed_spans_lengthss_: ");
851 print(detail_obj.distributed_spans_lengthss_);
852 printf(", ");
853 printf("ndims_distributed_spans_minor_: ");
854 print(detail_obj.ndims_distributed_spans_minor_);
855 printf(", ");
856 printf("ps_over_rs_derivative_: ");
857 print(detail_obj.ps_over_rs_derivative_);
858 printf("}");
859}
860
861// Free print function for tile_distribution_encoding
862template <typename RsLengths_,
863 typename HsLengthss_,
864 typename Ps2RHssMajor_,
865 typename Ps2RHssMinor_,
866 typename Ys2RHsMajor_,
867 typename Ys2RHsMinor_>
869 HsLengthss_,
870 Ps2RHssMajor_,
871 Ps2RHssMinor_,
872 Ys2RHsMajor_,
873 Ys2RHsMinor_>& encoding)
874{
875 printf("tile_distribution_encoding{");
876
877 printf("NDimX: %d, NDimP: %d, NDimY: %d, ", encoding.NDimX, encoding.NDimP, encoding.NDimY);
878 printf("rs_lengths_: ");
879 print(encoding.rs_lengths_);
880 printf(", ");
881 printf("hs_lengthss_: ");
882 print(encoding.hs_lengthss_);
883 printf(", ");
884 printf("ps_to_rhss_major_: ");
885 print(encoding.ps_to_rhss_major_);
886 printf(", ");
887 printf("ps_to_rhss_minor_: ");
888 print(encoding.ps_to_rhss_minor_);
889 printf(", ");
890 printf("ys_to_rhs_major_: ");
891 print(encoding.ys_to_rhs_major_);
892 printf(", ");
893 printf("ys_to_rhs_minor_: ");
894 print(encoding.ys_to_rhs_minor_);
895 printf(", ");
896 printf("}");
897}
898
899} // namespace ck_tile
tile_distribution_encoding< typename encoding::RsLengths, typename encoding::HsLengthss, typename encoding::Ps2RHssMajor, typename encoding::Ps2RHssMinor, shuffled< typename encoding::Ys2RHsMajor >, shuffled< typename encoding::Ys2RHsMinor > > type
Definition tile_distribution_encoding.hpp:443
Definition tile_distribution_encoding.hpp:435
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition arch.hpp:385
CK_TILE_HOST_DEVICE constexpr auto make_reduce_tile_distribution_encoding_impl(InDstr, sequence< InReduceDimXs... > reduce_dim_xs_in)
Definition tile_distribution_encoding.hpp:564
CK_TILE_HOST_DEVICE constexpr auto make_reduce_tile_distribution_encoding(InDstr, sequence< InReduceDimXs... > reduce_dim_xs_in)
Definition tile_distribution_encoding.hpp:762
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition tile_distribution_encoding.hpp:457
Definition tile/core/arch/amd_buffer_addressing.hpp:110
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto container_reduce(const Container &x, Reduce reduce, Init init, number< IBegin >=number< 0 >{}, number< IEnd >=number< Container::size()>{}, number< IStep >=number< 1 >{})
Definition tile/core/container/container_helper.hpp:198
CK_TILE_HOST_DEVICE constexpr auto generate_sequence_v2(F &&f, number< N >)
Definition tile/core/container/sequence.hpp:1045
CK_TILE_HOST_DEVICE constexpr auto transform_sequences(F f, sequence< Xs... >)
Definition tile/core/container/sequence.hpp:832
typename tile_distribution_encoding_shuffle< encoding, shuffle >::type tile_distribution_encoding_shuffle_t
Definition tile_distribution_encoding.hpp:451
CK_TILE_HOST_DEVICE constexpr auto container_concat(const X &x, const Ys &... ys)
Definition tile/core/container/container_helper.hpp:363
CK_TILE_HOST_DEVICE constexpr auto generate_array(F &&f, number< N >)
Definition tile/core/container/sequence.hpp:1115
CK_TILE_HOST_DEVICE constexpr auto merge_sequences(Seqs...)
Definition tile/core/container/sequence.hpp:826
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:429
CK_TILE_HOST_DEVICE constexpr auto unpack(F &&f, X &&x)
Definition tile/core/utility/functional.hpp:200
constexpr index_t container_find(sequence< Is... > seq, index_t value)
Definition tile/core/container/container_helper.hpp:447
CK_TILE_HOST_DEVICE constexpr auto histogram_sorted_sequence(SeqSortedSamples, sequence< r, rs... >)
Definition tile/core/container/sequence.hpp:1102
typename sequence_merge< Seqs... >::type sequence_merge_t
Definition tile/core/container/sequence.hpp:1023
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto to_array_of_array(tuple< Seqs... > t_of_s)
Definition tile/core/container/tuple.hpp:630
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
constexpr auto prefix_sum_sequence(Seq)
Definition tile/core/container/sequence.hpp:908
A fixed-size array container similar to std::array with additional utilities.
Definition tile/core/container/array.hpp:43
static CK_TILE_HOST_DEVICE constexpr auto size()
Definition tile/core/container/array.hpp:97
Definition tile/core/numeric/math.hpp:329
Definition tile/core/numeric/math.hpp:122
Definition tile/core/container/sequence.hpp:593
Definition tile/core/container/sequence.hpp:49
static CK_TILE_HOST_DEVICE constexpr index_t size()
Definition tile/core/container/sequence.hpp:53
Definition tile/core/utility/functional.hpp:43
Definition tile_distribution_encoding.hpp:58
static constexpr index_t max_ndim_span_minor_
Definition tile_distribution_encoding.hpp:129
static CK_TILE_HOST_DEVICE constexpr auto get_uniformed_idx_p_to_h()
Definition tile_distribution_encoding.hpp:339
static constexpr auto rhs_lengthss_
Definition tile_distribution_encoding.hpp:82
static CK_TILE_HOST_DEVICE constexpr auto get_uniformed_p_dim_lengths_over_h()
Definition tile_distribution_encoding.hpp:276
static constexpr auto distributed_spans_lengthss_
Definition tile_distribution_encoding.hpp:169
static constexpr auto does_p_own_r_
Definition tile_distribution_encoding.hpp:202
static CK_TILE_HOST_DEVICE constexpr auto get_sorted_y_to_h_info()
Definition tile_distribution_encoding.hpp:427
static CK_TILE_HOST_DEVICE constexpr auto get_uniformed_h_dim_lengths()
Definition tile_distribution_encoding.hpp:263
static CK_TILE_HOST_DEVICE constexpr auto get_uniformed_rh_dim_lengths()
Definition tile_distribution_encoding.hpp:312
static constexpr index_t max_ndim_rh_minor_
Definition tile_distribution_encoding.hpp:78
static CK_TILE_HOST_DEVICE constexpr auto get_h_dim_lengths_prefix_sum()
Definition tile_distribution_encoding.hpp:321
static constexpr auto ys_to_span_major_
Definition tile_distribution_encoding.hpp:158
static CK_TILE_HOST_DEVICE constexpr auto get_uniformed_idx_y_to_rh()
Definition tile_distribution_encoding.hpp:358
static constexpr auto rhs_major_minor_to_span_minor_
Definition tile_distribution_encoding.hpp:133
static CK_TILE_HOST_DEVICE constexpr auto get_uniformed_idx_y_to_h()
Definition tile_distribution_encoding.hpp:371
static constexpr index_t ndim_span_major_
Definition tile_distribution_encoding.hpp:61
static constexpr auto ndims_span_minor_
Definition tile_distribution_encoding.hpp:115
static constexpr index_t ndim_rh_major_
Definition tile_distribution_encoding.hpp:60
static constexpr auto ps_over_rs_derivative_
Definition tile_distribution_encoding.hpp:230
static CK_TILE_HOST_DEVICE constexpr auto get_sorted_info(IdxSeq, PrefixSumSeq)
Definition tile_distribution_encoding.hpp:412
static CK_TILE_HOST_DEVICE constexpr auto get_y_to_h_masks()
Definition tile_distribution_encoding.hpp:386
static constexpr auto ys_lengths_
Definition tile_distribution_encoding.hpp:86
static constexpr auto ndims_distributed_spans_minor_
Definition tile_distribution_encoding.hpp:189
static constexpr auto rhs_major_minor_to_ys_
Definition tile_distribution_encoding.hpp:101
static constexpr auto ndims_rhs_minor_
Definition tile_distribution_encoding.hpp:64
static constexpr auto ys_to_span_minor_
Definition tile_distribution_encoding.hpp:162
static CK_TILE_HOST_DEVICE constexpr auto get_rh_dim_lengths_prefix_sum()
Definition tile_distribution_encoding.hpp:330
Definition tile_distribution_encoding.hpp:26
#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes)
Definition tile/core/container/container_helper.hpp:486
#define TO_SEQUENCE(a, n)
Definition to_sequence.hpp:10