threadwise_tensor_slice_transfer_v3r1_gather.hpp Source File

threadwise_tensor_slice_transfer_v3r1_gather.hpp Source File

Composable Kernel: threadwise_tensor_slice_transfer_v3r1_gather.hpp Source File
threadwise_tensor_slice_transfer_v3r1_gather.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
12
14
15namespace ck {
16
// Threadwise tensor-slice copy (v3r1 variant) with a gathered source dimension.
// The slice is read from src into per-thread scratch (RunRead), moved into a dst
// scratch (TransferDataFromSrcThreadScratchToDstThreadScratch) and written to dst
// (RunWrite). Along GatherDim the source is not addressed through the source
// coordinate. Instead, RunRead adds one entry of gather_offsets_ to the load offset.
// NOTE(review): this listing omits original line 25 (a template parameter) and
// line 47 (the struct name line). The cross-references name the type
// ThreadwiseTensorSliceTransfer_v3r1_gather.
17// Assume:
18// 1. src_desc and dst_desc are not known at compile-time
19// 2. SrcBuffer and DstBuffer are DynamicBuffer
20// 3. src_slice_origin and dst_slice_origin are not known at compile-time
21// 4. Use thread buffer
22template <typename SliceLengths,
23 typename SrcElementwiseOperation,
24 typename DstElementwiseOperation,
26 typename SrcData,
27 typename DstData,
28 typename SrcDesc,
29 typename DstDesc,
30 typename SrcDimAccessOrder,
31 typename DstDimAccessOrder,
32 index_t SrcVectorDim,
33 index_t DstVectorDim,
34 index_t SrcScalarPerVector_,
35 index_t DstScalarPerVector_,
36 index_t SrcScalarStrideInVector,
37 index_t DstScalarStrideInVector,
38 bool SrcResetCoordinateAfterRun, // control whether to move back src coordinate after each
39 // RunRead(), will be fused with MoveSrcSliceWindow to
40 // save addr computation
41 bool DstResetCoordinateAfterRun, // control whether to move back dst coordinate after each
42 // RunWrite(), will be fused with MoveDstSliceWindow to
43 // save addr computation
// IndexType: type of the gather offsets and of the computed source load offset.
// GatherDim: source dimension addressed through gather_offsets_ (one offset per
//            slice position along it) rather than through the source coordinate.
// NumThreadScratch: presumably the number of src scratch buffers selectable via
//                   the ThreadScratchId template argument. Confirm this: the
//                   declaration of src_thread_scratch_tuple_ is not in this listing.
44 typename IndexType,
45 index_t GatherDim = 1,
46 index_t NumThreadScratch = 1>
48{
// Rank of the slice handled by this transfer.
49 static constexpr index_t nDim = SliceLengths::Size();
// NOTE(review): original line 50 is missing from this listing. `Index`, used
// throughout, is presumably declared there (e.g. as MultiIndex<nDim>). Confirm.
51
52 using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
53 using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
54
55 using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
56 using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
57
// Compile-time index constants.
58 static constexpr auto I0 = Number<0>{};
59 static constexpr auto I1 = Number<1>{};
60 static constexpr auto I2 = Number<2>{};
61 static constexpr auto I3 = Number<3>{};
62 static constexpr auto I4 = Number<4>{};
63 static constexpr auto I5 = Number<5>{};
64 static constexpr auto I6 = Number<6>{};
65 static constexpr auto I7 = Number<7>{};
66 static constexpr auto I8 = Number<8>{};
67 static constexpr auto I10 = Number<10>{};
68 static constexpr auto I12 = Number<12>{};
69 static constexpr auto I13 = Number<13>{};
70 static constexpr auto I14 = Number<14>{};
71 static constexpr auto I16 = Number<16>{};
72
// PackedSize is 2 or 1 depending on a condition on original line 74, which is
// missing from this listing. The constructor's packed_size_v<SrcData> check suggests
// 2 means "two logical elements per stored element" (packed sub-byte types). Confirm.
73 static constexpr index_t PackedSize = []() {
75 return 2;
76 else
77 return 1;
78 }();
79
// Vector widths in storage units: the *ScalarPerVector_ template arguments
// are divided by PackedSize.
80 static constexpr auto SrcScalarPerVector = Number<SrcScalarPerVector_ / PackedSize>{};
81 static constexpr auto DstScalarPerVector = Number<DstScalarPerVector_ / PackedSize>{};
82
// Number of gather offsets each thread holds: one per slice position along GatherDim.
83 static constexpr index_t gather_num = SliceLengths{}.At(Number<GatherDim>{});
84
// Constructor. Builds the src/dst tensor coordinates from the given origins and
// stores the element-wise operators and the per-thread gather offsets.
// NOTE(review): the signature line (original line 85) and the final parameter
// (line 92) are missing from this listing. Per the cross-reference, the last
// parameter is `const StaticallyIndexedArray<IndexType, gather_num>& gather_offsets`.
// NOTE(review): unlike SetSrcSliceOrigin, this does not zero the GatherDim component
// of src_slice_origin. Callers presumably pass 0 for it; confirm.
86 const SrcDesc& src_desc,
87 const Index& src_slice_origin,
88 const SrcElementwiseOperation& src_element_op,
89 const DstDesc& dst_desc,
90 const Index& dst_slice_origin,
91 const DstElementwiseOperation& dst_element_op,
93 : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
94 dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)),
95 src_element_op_(src_element_op),
96 dst_element_op_(dst_element_op),
97 gather_offsets_(gather_offsets)
98 {
99 if constexpr((packed_size_v<SrcData>) > 1)
100 {
// Packed element types have three compile-time restrictions:
// - a static_assert with message "SrcData != DstData" (its condition, on original
//   line 101, is missing from this listing);
// - vector widths must be multiples of PackedSize;
// - src and dst must share one vector dim, i.e. no in-register transpose.
102 "SrcData != DstData");
103
104 static_assert(
105 SrcScalarPerVector_ % PackedSize == 0 && DstScalarPerVector_ % PackedSize == 0,
106 "SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1 for packed data type");
107
108 static_assert(SrcVectorDim == DstVectorDim,
109 "Packed data type does not support transpose");
110 }
111 }
112
113 __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
114 {
115
116 auto adjusted_origin_idx = [&]() {
117 Index idx;
118 static_for<0, nDim, 1>{}([&](auto i) {
119 idx(i) = i.value == GatherDim ? 0 : src_slice_origin_idx[Number<i>{}];
120 });
121 return idx;
122 }();
123 src_coord_ = make_tensor_coordinate(src_desc, adjusted_origin_idx);
124 }
125
126 __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
127 {
128 dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
129 }
130
// Reads the whole slice from src_buf into src scratch buffer `thread_scratch_id`.
// Each access loads one src vector from
//   src_coord_.GetOffset() / PackedSize + gather_offsets_(index along the gather dim),
// applies src_element_op_ (producing DstData), and stores the result in the scratch
// at the access's element index.
// NOTE(review): original lines 134, 141, 147, 189, 230, 236, 281, 301, 306 and 317
// are missing from this listing. Per the cross-reference, line 134 is the
// `Number<ThreadScratchId> thread_scratch_id` parameter.
131 template <typename SrcBuffer, index_t ThreadScratchId = 0>
132 __device__ void RunRead(const SrcDesc& src_desc,
133 const SrcBuffer& src_buf,
135 {
136 static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
137 SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
138 "wrong!");
139
140 static_assert(
142 "wrong! SrcBuffer and SrcData data type are inconsistent");
143
144 // scalar per access on each dim
145 // TODO: don't use lambda_scalar_per_access
146 constexpr auto src_scalar_per_access = generate_sequence(
148
149 constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
150
151 static_assert(SliceLengths::At(SrcVectorDim) % (SrcScalarPerVector_) == 0,
152 "SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector");
153
154 constexpr auto src_dim_access_order = SrcDimAccessOrder{};
// NOTE(review): this is the dimension stored at ordered position GatherDim. It is
// then used as the *ordered position* of the gather dim (lines 224 and 290).
// The two agree only if SrcDimAccessOrder[SrcDimAccessOrder[GatherDim]] == GatherDim,
// which holds for the identity order. Confirm for any other access order.
155 constexpr auto ordered_gather_dim = src_dim_access_order[GatherDim];
156 constexpr auto ordered_src_access_lengths =
157 container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
158
159 // make forward steps
160 const auto src_forward_steps = generate_tuple(
161 [&](auto i) {
162 Index forward_step_idx;
163
164 static_for<0, nDim, 1>{}([&](auto j) {
165 forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
166 });
167
168 return make_tensor_coordinate_step(src_desc, forward_step_idx);
169 },
170 Number<nDim>{});
171
172 // make backward steps
173 const auto src_backward_steps = generate_tuple(
174 [&](auto i) {
175 Index backward_step_idx;
176
177 static_for<0, nDim, 1>{}([&](auto j) {
178 backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
179 });
180
181 return make_tensor_coordinate_step(src_desc, backward_step_idx);
182 },
183 Number<nDim>{});
184
// Traverse accesses in SrcDimAccessOrder. The direction along ordered dim i flips
// with the parity of the linearized index of the ordered dims before it, so
// consecutive accesses differ by a single forward or backward step.
185 // loop over tensor and copy
186 static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
187 // judge move forward or move backward
188 constexpr auto forward_sweep = [&]() {
190
191 forward_sweep_(I0) = true;
192
193 static_for<1, nDim, 1>{}([&](auto i) {
194 index_t tmp = ordered_src_access_idx[I0];
195
196 static_for<1, i, 1>{}([&](auto j) {
197 tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
198 });
199
200 forward_sweep_(i) = tmp % 2 == 0;
201 });
202
203 return forward_sweep_;
204 }();
205
206 // calculate src data index
207 constexpr auto src_data_idx = [&]() {
208 Index ordered_idx;
209
210 static_for<0, nDim, 1>{}([&](auto i) {
211 ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i]
212 : ordered_src_access_lengths[i] - 1 -
213 ordered_src_access_idx[i];
214 });
215
216 return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
217 src_scalar_per_access;
218 }();
219
220 constexpr auto src_data_idx_seq = generate_sequence_v2(
221 [&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});
222
// The gather offset is chosen with the raw loop index, whereas src_data_idx above
// uses the direction-adjusted index. They agree whenever the traversal along the
// gather position runs forward (always true at ordered position 0).
223 auto gather_offset =
224 gather_offsets_(ordered_src_access_idx[Number<ordered_gather_dim>{}]);
225
// NOTE(review): gather_offset is added after the division by PackedSize, so it
// appears to be expected in storage units. Confirm with the caller.
// The OOB flag is set to true unconditionally, and the load below passes `true` as
// its validity flag, so gathered loads are never masked.
226 const IndexType ld_offset = src_coord_.GetOffset() / PackedSize + gather_offset;
227 src_oob_thread_scratch_tuple_(thread_scratch_id)
228 .template SetAsType<bool>(src_data_idx_seq, true);
229
231 using src_vector_t = typename src_vector_type::type;
232
233 auto src_vector_container =
234 src_vector_type{src_buf.template Get<src_vector_t>(ld_offset, true)};
235
// op_r_v receives the DstData results of src_element_op_ (its vector type
// is declared on missing line 236).
237 using dst_vector_t = typename dst_vector_type::type;
238 dst_vector_type op_r_v;
239
// Pack width used to call src_element_op_: 8, 4 or 2 (capped at SrcScalarPerVector)
// if the op declares is_pack8/4/2_invocable (checked in that order), else 1.
// NOTE(review): if the op declares is_packN_invocable but sets it to false, that
// branch has no return statement and does not fall through to a narrower width, so
// this lambda would not yield an index_t. Presumably such ops always set it true.
240 constexpr auto get_elem_op_vec_len = []() {
241 if constexpr(is_detected<is_pack8_invocable_t, decltype(src_element_op_)>::value)
242 {
243 if constexpr(decltype(src_element_op_)::is_pack8_invocable)
244 return math::min(8, SrcScalarPerVector);
245 }
246 else if constexpr(is_detected<is_pack4_invocable_t,
247 decltype(src_element_op_)>::value)
248 {
249 if constexpr(decltype(src_element_op_)::is_pack4_invocable)
250 return math::min(4, SrcScalarPerVector);
251 }
252 else if constexpr(is_detected<is_pack2_invocable_t,
253 decltype(src_element_op_)>::value)
254 {
255 if constexpr(decltype(src_element_op_)::is_pack2_invocable)
256 return math::min(2, SrcScalarPerVector);
257 }
258 else
259 {
260 return 1;
261 }
262 };
263
264 constexpr index_t elem_op_vec_len = get_elem_op_vec_len();
265
266 using src_elem_op_vec_t = typename vector_type<SrcData, elem_op_vec_len>::type;
267 using dst_elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;
268
269 static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto idx) {
270 // apply the src elementwise op and convert to DstData under the hood if needed
271 src_element_op_(op_r_v.template AsType<dst_elem_op_vec_t>()(idx),
272 src_vector_container.template AsType<src_elem_op_vec_t>()[idx]);
273 });
274
275 // copy the converted data (op_r_v) into src_thread_scratch_
276 src_thread_scratch_tuple_(thread_scratch_id)
277 .template SetAsType<dst_vector_t>(src_data_idx_seq,
278 op_r_v.template AsType<dst_vector_t>()[I0]);
279
// Step along ordered dim i only if it is not at its last index and every later
// ordered dim is. The ordered gather position is never stepped.
280 auto move_on_dim = [&]() constexpr {
282
283 static_for<0, nDim, 1>{}([&](auto i) {
284 move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1;
285
286 static_for<i + 1, nDim, 1>{}([&](auto j) {
287 move_on_dim_(i) &=
288 ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1;
289 });
290 move_on_dim_(i) &= i.value != ordered_gather_dim;
291 });
292
293 return move_on_dim_;
294 }();
295 // move src coord
296 static_for<0, nDim, 1>{}([&](auto i) {
297 if(move_on_dim[i])
298 {
299 if constexpr(forward_sweep[i])
300 {
302 src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]);
303 }
304 else
305 {
307 src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
308 }
309 }
310 });
311 });
312
313 // move src coordinate back to slice origin (or not)
314 if constexpr(SrcResetCoordinateAfterRun)
315 {
316 const auto src_reset_step =
318
319 move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
320 }
321 }
322
// Returns the vector stored at compile-time index SeqIdx in src scratch buffer
// `thread_scratch_id`.
// NOTE(review): the signature line (original line 325) and the definition of
// vector_t (line 327) are missing from this listing. Per the cross-reference, the
// function is GetSrcThreadScratchIdx(Number<ThreadScratchId> thread_scratch_id).
323 template <typename SeqIdx, index_t ThreadScratchId = 0>
324 __device__ constexpr auto
326 {
328 return src_thread_scratch_tuple_(thread_scratch_id).template GetAsType<vector_t>(SeqIdx{});
329 }
330
// Moves the contents of src scratch buffer `thread_scratch_id` into dst_thread_scratch_.
// - Without CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE: a plain
//   element-by-element copy over SliceLengths.
// - With it, three steps:
//   (1) a pass over the src scratch;
//   (2) an in-register transpose when the src and dst vector dims differ and the
//       width/type conditions hold;
//   (3) otherwise, an element copy.
// NOTE(review): the signature line (original line 333; per the cross-reference it
// takes Number<ThreadScratchId> thread_scratch_id) is missing from this listing, as
// are lines 343, 356, 390, 404, 406, 408, 411, 424, 427, 430-431, 433, 444-445, 455,
// 464, 467 and 474.
331 template <index_t ThreadScratchId>
332 __device__ void
334 {
335#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
336 static_ford<SliceLengths>{}([&](auto idx) {
337 dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
338 });
339#else
340
// NOTE(review): despite the "OOB Check" label, the visible code of this pass reads
// each src vector and writes it back unchanged (op_r_v = op_r). No value from
// src_oob_thread_scratch_tuple_ is consulted here. Confirm this is intended.
341 // OOB Check
342 constexpr auto src_scalar_per_access = generate_sequence(
344
345 constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
346
347 constexpr auto src_dim_access_order = SrcDimAccessOrder{};
348
349 constexpr auto ordered_src_access_lengths =
350 container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
351
352 // loop over tensor and copy
353 static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
354 // judge move forward or move backward
355 constexpr auto forward_sweep = [&]() {
357
358 forward_sweep_(I0) = true;
359
360 static_for<1, nDim, 1>{}([&](auto i) {
361 index_t tmp = ordered_src_access_idx[I0];
362
363 static_for<1, i, 1>{}([&](auto j) {
364 tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
365 });
366
367 forward_sweep_(i) = tmp % 2 == 0;
368 });
369
370 return forward_sweep_;
371 }();
372
373 // calculate src data index
374 constexpr auto src_data_idx = [&]() {
375 Index ordered_idx;
376
377 static_for<0, nDim, 1>{}([&](auto i) {
378 ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i]
379 : ordered_src_access_lengths[i] - 1 -
380 ordered_src_access_idx[i];
381 });
382
383 return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
384 src_scalar_per_access;
385 }();
386
387 constexpr auto src_data_idx_seq = generate_sequence_v2(
388 [&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});
389
391
392 auto op_r = src_thread_scratch_tuple_(thread_scratch_id)
393 .template GetAsType<vector_t>(src_data_idx_seq);
394
395 auto op_r_v = op_r;
396
397 src_thread_scratch_tuple_(thread_scratch_id)
398 .template SetAsType<vector_t>(src_data_idx_seq, op_r_v);
399 });
400
// The type/width conditions of this branch (lines 404, 406, 408) are missing from
// this listing. Visible parts require both vector widths to be multiples of 2 or 4.
401 // sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
402 // TODO make this logic more generic for more sub-dword datatype
403 if constexpr(SrcVectorDim != DstVectorDim &&
405 SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) ||
407 SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0) ||
409 SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
410 {
412 "in-register transpose is not supported for pk_i4_t");
413 // each transpose does
414 // DstScalarPerVector # of src vectors in src_thread_scratch_
415 // SrcScalarPerVector # of dst vectors in dst_thread_scratch_
416 constexpr index_t num_src_vector = Number<DstScalarPerVector>{};
417 constexpr index_t num_dst_vector = Number<SrcScalarPerVector>{};
418
419 // Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose
420 // TODO: make this logic generic for all scenario
421 static_assert(SrcVectorDim != DstVectorDim, "wrong");
422
423 constexpr auto src_scalar_step_in_vector = generate_sequence(
425
426 constexpr auto dst_scalar_step_in_vector = generate_sequence(
428
429 constexpr auto scalar_per_access = generate_sequence(
432 DstVectorDim,
434 Number<nDim>{});
435
436 constexpr auto access_lengths = SliceLengths{} / scalar_per_access;
437
438 static_ford<decltype(access_lengths)>{}([&](auto access_idx) {
439 constexpr auto data_idx = access_idx * scalar_per_access;
440
441 constexpr auto data_idx_seq = generate_sequence_v2(
442 [&](auto i) { return Number<data_idx[i]>{}; }, Number<nDim>{});
443
446
447 // get DstScalarPerVector # of read-only references to src vectors from
448 // src_thread_scratch_
449 const auto src_vector_refs = generate_tie(
450 [&](auto i) -> const src_vector_t& {
451 // i increment corresponds to movement in DstVectorDim
452 return src_thread_scratch_tuple_[thread_scratch_id].GetVectorTypeReference(
453 data_idx_seq + i * dst_scalar_step_in_vector);
454 },
456
457 // get SrcScalarPerVector # of references to dst vectors from dst_thread_scratch_
458 auto dst_vector_refs = generate_tie(
459 [&](auto i) -> dst_vector_t& {
460 // i increment corresponds to movement in SrcVectorDim
461 return dst_thread_scratch_.GetVectorTypeReference(
462 data_idx_seq + i * src_scalar_step_in_vector);
463 },
465
466 // do data transpose
468 src_vector_refs, dst_vector_refs);
469 });
470 }
471 else
472 {
// No transpose: iterate over SliceLengths / packed_per_access and copy the element
// at each of those indices. The generator of packed_per_access (line 474) is missing.
473 constexpr auto packed_per_access = generate_sequence(
475
476 constexpr auto packed_access_lengths = SliceLengths{} / packed_per_access;
477
478 static_ford<decltype(packed_access_lengths)>{}([&](auto idx) {
479 dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
480 });
481 }
482#endif
483 }
484
// Writes dst_thread_scratch_ to dst_buf. For each access it:
//   - reads one DstData vector from the dst scratch;
//   - applies dst_element_op_ to each element;
//   - stores the vector at dst_coord_.GetOffset() / PackedSize, guarded by is_dst_valid.
// NOTE(review): original lines 488, 493, 500, 506, 545, 580, 582, 605, 625, 630 and
// 641 are missing from this listing. Per the cross-reference, line 488 is the
// `Number<ThreadScratchId> thread_scratch_id` parameter. Given the comments below,
// line 493 presumably calls TransferDataFromSrcThreadScratchToDstThreadScratch(
// thread_scratch_id). Confirm.
485 template <typename DstBuffer, index_t ThreadScratchId = 0>
486 __device__ void RunWrite(const DstDesc& dst_desc,
487 DstBuffer& dst_buf,
489 {
490 // if there is transpose, it's done here
491 // if there is oob check, it's done here
492 // TODO move this elsewhere
494
495 static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
496 DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
497 "wrong!");
498
499 static_assert(
501 "wrong! SrcBuffer or DstBuffer data type is wrong");
502
503 // dst scalar per access on each dim
504 // TODO: don't use this
505 constexpr auto dst_scalar_per_access = generate_sequence(
507
508 constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
509
510 constexpr auto dst_dim_access_order = DstDimAccessOrder{};
511
512 constexpr auto ordered_dst_access_lengths =
513 container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
514
515 // make forward steps
516 const auto dst_forward_steps = generate_tuple(
517 [&](auto i) {
518 Index forward_step_idx;
519
520 static_for<0, nDim, 1>{}([&](auto j) {
521 forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
522 });
523
524 return make_tensor_coordinate_step(dst_desc, forward_step_idx);
525 },
526 Number<nDim>{});
527
528 // make backward steps
529 const auto dst_backward_steps = generate_tuple(
530 [&](auto i) {
531 Index backward_step_idx;
532
533 static_for<0, nDim, 1>{}([&](auto j) {
534 backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
535 });
536
537 return make_tensor_coordinate_step(dst_desc, backward_step_idx);
538 },
539 Number<nDim>{});
540
// Traverse accesses in DstDimAccessOrder. The direction along ordered dim i flips
// with the parity of the linearized index of the ordered dims before it.
541 // loop over tensor and copy
542 static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
543 // judge move forward or move backward
544 constexpr auto forward_sweep = [&]() {
546
547 forward_sweep_(I0) = true;
548
549 static_for<1, nDim, 1>{}([&](auto i) {
550 index_t tmp = ordered_dst_access_idx[I0];
551
552 static_for<1, i, 1>{}([&](auto j) {
553 tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
554 });
555
556 forward_sweep_(i) = tmp % 2 == 0;
557 });
558
559 return forward_sweep_;
560 }();
561
562 // calculate dst data index
563 constexpr auto dst_data_idx = [&]() {
564 Index ordered_idx;
565
566 static_for<0, nDim, 1>{}([&](auto i) {
567 ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i]
568 : ordered_dst_access_lengths[i] - 1 -
569 ordered_dst_access_idx[i];
570 });
571
572 return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
573 dst_scalar_per_access;
574 }();
575
576 constexpr auto dst_data_idx_seq = generate_sequence_v2(
577 [&](auto i) { return Number<dst_data_idx[i]>{}; }, Number<dst_data_idx.Size()>{});
578
// Validity of the current dst coordinate. Its expression (line 580) is missing from
// this listing; given the cross-references it is presumably
// coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_).
// Confirm.
579 const bool is_dst_valid =
581
583 using dst_vector_t = typename dst_vector_type::type;
584
585 // copy data from dst_thread_scratch_ into dst_vector_container
586 auto dst_vector_container = dst_vector_type{
587 dst_thread_scratch_.template GetAsType<dst_vector_t>(dst_data_idx_seq)};
588
589 static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
590 DstData dst_v;
591
592 // apply DstElementwiseOperation
593 dst_element_op_(dst_v, dst_vector_container.template AsType<DstData>()[i]);
594
595 dst_vector_container.template AsType<DstData>()(i) = dst_v;
596 });
597
598 // copy data from dst_vector_container to dst_buf
599 dst_buf.template Set<dst_vector_t>(
600 dst_coord_.GetOffset() / PackedSize,
601 is_dst_valid,
602 dst_vector_container.template AsType<dst_vector_t>()[I0]);
603
// Step along ordered dim i only if it is not at its last index and every later
// ordered dim is.
604 constexpr auto move_on_dim = [&]() constexpr {
606
607 static_for<0, nDim, 1>{}([&](auto i) {
608 move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1;
609
610 static_for<i + 1, nDim, 1>{}([&](auto j) {
611 move_on_dim_(i) &=
612 ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1;
613 });
614 });
615
616 return move_on_dim_;
617 }();
618
619 // move dst coord
620 static_for<0, nDim, 1>{}([&](auto i) {
621 if constexpr(move_on_dim[i])
622 {
623 if constexpr(forward_sweep[i])
624 {
626 dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
627 }
628 else
629 {
631 dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
632 }
633 }
634 });
635 });
636
637 // move dst coordinate back to slice origin (or not)
638 if constexpr(DstResetCoordinateAfterRun)
639 {
640 const auto dst_reset_step =
642
643 move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
644 }
645 }
646
// Returns the step that moves src_coord_ from where RunRead() leaves it back to the
// slice origin. The step is the negated data index of the last access, with the
// GatherDim component kept at 0.
// NOTE(review): original lines 652 and 663 are missing from this listing.
647 __device__ static constexpr auto GetSrcCoordinateResetStep()
648 {
649 // scalar per access on each dim
650 // TODO: don't use lambda_scalar_per_access
651 constexpr auto src_scalar_per_access = generate_sequence(
653
654 constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
655
656 constexpr auto src_dim_access_order = SrcDimAccessOrder{};
657
658 constexpr auto ordered_src_access_lengths =
659 container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
660
661 // judge move forward or move backward during the last iteration
662 constexpr auto forward_sweep = [&]() {
664
665 forward_sweep_(I0) = true;
666
667 static_for<1, nDim, 1>{}([&](auto i) {
668 index_t tmp = ordered_src_access_lengths[I0] - 1;
669
670 static_for<1, i, 1>{}([&](auto j) {
671 tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1;
672 });
673
674 forward_sweep_(i) = tmp % 2 == 0;
675 });
676
677 return forward_sweep_;
678 }();
679
680 // calculate src data index after last iteration in RunRead(), if it has not been reset by
681 // RunRead()
682 constexpr auto src_data_idx = [&]() {
683 Index ordered_idx;
684
685 static_for<0, nDim, 1>{}([&](auto i) {
686 ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0;
687 });
688
689 return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
690 src_scalar_per_access;
691 }();
692
693 // negate the end position; GatherDim stays 0 since it is addressed via gather_offsets_
694 constexpr auto reset_src_data_step = [&]() {
695 Index reset_src_data_step_;
696
697 static_for<0, nDim, 1>{}([&](auto i) {
698 reset_src_data_step_(i) = i.value == GatherDim ? 0 : -src_data_idx[i];
699 });
700
701 return reset_src_data_step_;
702 }();
703 return reset_src_data_step;
704 }
705
// Returns the step that moves dst_coord_ from where RunWrite() leaves it back to the
// slice origin: the negated data index of the last access.
// NOTE(review): original lines 711 and 722 are missing from this listing.
706 __device__ static constexpr auto GetDstCoordinateResetStep()
707 {
708 // scalar per access on each dim
709 // TODO: don't use lambda_scalar_per_access
710 constexpr auto dst_scalar_per_access = generate_sequence(
712
713 constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
714
715 constexpr auto dst_dim_access_order = DstDimAccessOrder{};
716
717 constexpr auto ordered_dst_access_lengths =
718 container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
719
720 // judge move forward or move backward during the last iteration
721 constexpr auto forward_sweep = [&]() {
723
724 forward_sweep_(I0) = true;
725
726 static_for<1, nDim, 1>{}([&](auto i) {
727 index_t tmp = ordered_dst_access_lengths[I0] - 1;
728
729 static_for<1, i, 1>{}([&](auto j) {
730 tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
731 });
732
733 forward_sweep_(i) = tmp % 2 == 0;
734 });
735
736 return forward_sweep_;
737 }();
738
739 // calculate dst data index after last iteration in RunWrite(), if it has not been reset by
740 // RunWrite()
741 constexpr auto dst_data_idx = [&]() {
742 Index ordered_idx;
743
744 static_for<0, nDim, 1>{}([&](auto i) {
745 ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0;
746 });
747
748 return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
749 dst_scalar_per_access;
750 }();
751
752 // negate the end position to obtain the step back to the slice origin
753 constexpr auto reset_dst_data_step = [&]() {
754 Index reset_dst_data_step_;
755
756 static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; });
757
758 return reset_dst_data_step_;
759 }();
760
761 return reset_dst_data_step;
762 }
763
764 // src_slice_origin_step_idx need to be known at compile-time, for performance reason
765 __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
766 const Index& src_slice_origin_step_idx)
767 {
768 // if src coord was not reset by RunRead(), then need to adjust the step here
769 const auto adjusted_step_idx =
770 SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
771 : src_slice_origin_step_idx + GetSrcCoordinateResetStep();
772 // is it OK to construct a new step every time?
773 const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
774
775 move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
776 }
777
778 // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
779 __device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
780 const Index& dst_slice_origin_step_idx)
781 {
782 // if dst coord was not reset by RunWrite(), then need to adjust the step here
783 const auto adjusted_step_idx =
784 DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
785 : dst_slice_origin_step_idx + GetDstCoordinateResetStep();
786
787 // is it OK to construct a new step every time?
788 const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
789
790 move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
791 }
792
// Describes the src scratch as an nDim tensor indexed by element position in the slice.
// Storage is a packed (src_access_lengths..., vector length) array. The SrcVectorDim
// axis is formed from its access axis combined with the trailing vector axis, so the
// elements of one src vector are adjacent in storage. All other axes pass through.
// NOTE(review): lines 796, 801 and 812 are missing from this listing. Line 812 is
// presumably the merge transform (make_merge_transform_v3_division_mod appears in the
// cross-references), and line 801 the pushed-back vector length. Confirm.
793 __device__ static constexpr auto GetSrcThreadScratchDescriptor()
794 {
795 constexpr auto src_scalar_per_access = generate_sequence(
797
798 constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
799
800 constexpr auto src_access_lengths_and_vector_length = container_push_back(
802
803 // 1st stage of transforms
804 constexpr auto desc0 =
805 make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length);
806
807 // 2nd stage of transforms
808 constexpr auto transforms = generate_tuple(
809 [&](auto i) {
810 if constexpr(i == SrcVectorDim)
811 {
813 make_tuple(src_access_lengths_and_vector_length[i],
814 src_access_lengths_and_vector_length[Number<nDim>{}]));
815 }
816 else
817 {
818 return make_pass_through_transform(src_access_lengths_and_vector_length[i]);
819 }
820 },
821 Number<nDim>{});
822
// Lower dims consumed by each upper dim: {i, nDim} for the vector dim, {i} otherwise.
823 constexpr auto low_dim_idss = generate_tuple(
824 [&](auto i) {
825 if constexpr(i == SrcVectorDim)
826 {
827 return Sequence<i.value, nDim>{};
828 }
829 else
830 {
831 return Sequence<i.value>{};
832 }
833 },
834 Number<nDim>{});
835
836 constexpr auto up_dim_idss =
837 generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
838
839 return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
840 }
841
// Packed descriptor over the per-access grid SliceLengths / src_scalar_per_access,
// with one entry per src vector access.
// NOTE(review): src_oob_thread_scratch_desc_ is built from
// GetSrcThreadScratchDescriptor(), not from this function. RunRead indexes the OOB
// scratch with element-granular indices, which matches the former. This function
// appears unused in the visible code. Line 845 is missing from this listing.
842 __device__ static constexpr auto GetSrcOOBThreadScratchDescriptor()
843 {
844 constexpr auto src_scalar_per_access = generate_sequence(
846
847 constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
848
849 return make_naive_tensor_descriptor_packed(src_access_lengths);
850 }
851
// Describes the dst scratch as an nDim tensor indexed by element position in the slice.
// Storage is a packed (dst_access_lengths..., vector length) array. The DstVectorDim
// axis is formed from its access axis combined with the trailing vector axis, so the
// elements of one dst vector are adjacent in storage. All other axes pass through.
// NOTE(review): lines 856, 861 and 871 are missing from this listing. Line 871 is
// presumably the merge transform, and line 861 the pushed-back vector length. Confirm.
852 __device__ static constexpr auto GetDstThreadScratchDescriptor()
853 {
854 // 1st stage of transforms
855 constexpr auto dst_scalar_per_access = generate_sequence(
857
858 constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
859
860 constexpr auto dst_access_lengths_and_vector_length = container_push_back(
862
863 constexpr auto desc0 =
864 make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length);
865
866 // 2nd stage of transforms
867 constexpr auto transforms = generate_tuple(
868 [&](auto i) {
869 if constexpr(i == DstVectorDim)
870 {
872 make_tuple(dst_access_lengths_and_vector_length[i],
873 dst_access_lengths_and_vector_length[Number<nDim>{}]));
874 }
875 else
876 {
877 return make_pass_through_transform(dst_access_lengths_and_vector_length[i]);
878 }
879 },
880 Number<nDim>{});
881
// Lower dims consumed by each upper dim: {i, nDim} for the vector dim, {i} otherwise.
882 constexpr auto low_dim_idss = generate_tuple(
883 [&](auto i) {
884 if constexpr(i == DstVectorDim)
885 {
886 return Sequence<i.value, nDim>{};
887 }
888 else
889 {
890 return Sequence<i.value>{};
891 }
892 },
893 Number<nDim>{});
894
895 constexpr auto up_dim_idss =
896 generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
897
898 return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
899 }
900
// Per-thread state.
// - Scratch descriptors: element-granular for both src and dst. The OOB scratch
//   reuses the src element-granular descriptor.
// - Coordinates and element-wise operators.
// NOTE(review): lines 910 and 923 (vector widths of the scratch buffers), 927-928
// (presumably the src scratch tuple members), and 936 (presumably gather_offsets_)
// are missing from this listing.
901 private:
902 static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
903 static constexpr auto src_oob_thread_scratch_desc_ =
904 decltype(GetSrcThreadScratchDescriptor()){};
905 static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};
906
// Src scratch holds DstData: RunRead converts while applying src_element_op_.
907 using SrcThreadScratch =
908 StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
909 DstData, // apply data_convert with SrcThreadScratch
911 decltype(src_thread_scratch_desc_),
912 true>;
913
914 using SrcOOBThreadScratch =
915 StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
916 bool, // per-access validity flag (RunRead sets it to true)
917 1,
918 decltype(src_oob_thread_scratch_desc_),
919 true>;
920
921 using DstThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
922 DstData,
924 decltype(dst_thread_scratch_desc_),
925 true>;
926
929
930 DstThreadScratch dst_thread_scratch_;
931
932 SrcCoord src_coord_;
933 DstCoord dst_coord_;
934 const SrcElementwiseOperation src_element_op_;
935 const DstElementwiseOperation dst_element_op_;
937};
938
939} // namespace ck
__host__ __device__ constexpr T min(T x)
Definition utility/math.hpp:116
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
decltype(ck::declval< T & >().is_pack8_invocable) is_pack8_invocable_t
Definition is_detected.hpp:43
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc &, const VisibleIndex &idx_diff_visible, UpdateLowerIndexHack)
Definition tensor_description/tensor_descriptor.hpp:444
__host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc &tensor_desc, TensorCoord &coord, const TensorCoordStep &coord_step)
Definition tensor_description/tensor_descriptor.hpp:508
__host__ __device__ constexpr auto container_push_back(const Array< TData, NSize > &a, const TData &x)
Definition utility/container_helper.hpp:18
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
decltype(ck::declval< T & >().is_pack4_invocable) is_pack4_invocable_t
Definition is_detected.hpp:40
__host__ __device__ constexpr bool coordinate_has_valid_offset_assuming_visible_index_is_valid(const TensorDesc &tensor_desc, const TensorCoord &coord)
Definition tensor_description/tensor_descriptor.hpp:560
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition is_detected.hpp:34
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Lds
Definition amd_address_space.hpp:18
@ Global
Definition amd_address_space.hpp:17
@ Vgpr
Definition amd_address_space.hpp:20
__host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence< Is... >)
Definition utility/container_helper.hpp:380
__host__ __device__ constexpr auto generate_sequence_v2(F &&f, Number< N >)
Definition sequence_helper.hpp:25
__host__ __device__ constexpr auto container_reorder_given_old2new(const Array< TData, NSize > &old_array, Sequence< IRs... > old2new)
Definition utility/container_helper.hpp:54
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
decltype(ck::declval< T & >().is_pack2_invocable) is_pack2_invocable_t
Definition is_detected.hpp:37
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto generate_sequence(F, Number< N >)
Definition sequence_helper.hpp:18
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
constexpr index_t packed_size_v
Definition data_type.hpp:411
__host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc &tensor_desc, const VisibleIndex &idx_visible)
Definition tensor_description/tensor_descriptor.hpp:407
__host__ __device__ constexpr auto container_reorder_given_new2old(const Array< TData, NSize > &old_array, Sequence< IRs... >)
Definition utility/container_helper.hpp:43
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
typename vector_type_maker< T, N >::type vector_type_maker_t
Definition dtype_vector.hpp:54
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
Definition utility/sequence.hpp:43
__device__ void SetDstSliceOrigin(const DstDesc &dst_desc, const Index &dst_slice_origin_idx)
Definition threadwise_tensor_slice_transfer_v3r1_gather.hpp:126
static __device__ constexpr auto GetSrcCoordinateResetStep()
Definition threadwise_tensor_slice_transfer_v3r1_gather.hpp:647
static __device__ constexpr auto GetDstThreadScratchDescriptor()
Definition threadwise_tensor_slice_transfer_v3r1_gather.hpp:852
static __device__ constexpr auto GetSrcOOBThreadScratchDescriptor()
Definition threadwise_tensor_slice_transfer_v3r1_gather.hpp:842
__device__ void TransferDataFromSrcThreadScratchToDstThreadScratch(Number< ThreadScratchId > thread_scratch_id)
Definition threadwise_tensor_slice_transfer_v3r1_gather.hpp:333
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r1_gather(const SrcDesc &src_desc, const Index &src_slice_origin, const SrcElementwiseOperation &src_element_op, const DstDesc &dst_desc, const Index &dst_slice_origin, const DstElementwiseOperation &dst_element_op, const StaticallyIndexedArray< IndexType, gather_num > &gather_offsets)
Definition threadwise_tensor_slice_transfer_v3r1_gather.hpp:85
__device__ void MoveSrcSliceWindow(const SrcDesc &src_desc, const Index &src_slice_origin_step_idx)
Definition threadwise_tensor_slice_transfer_v3r1_gather.hpp:765
__device__ void SetSrcSliceOrigin(const SrcDesc &src_desc, const Index &src_slice_origin_idx)
Definition threadwise_tensor_slice_transfer_v3r1_gather.hpp:113
__device__ constexpr auto GetSrcThreadScratchIdx(Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v3r1_gather.hpp:325
__device__ void RunWrite(const DstDesc &dst_desc, DstBuffer &dst_buf, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v3r1_gather.hpp:486
__device__ void MoveDstSliceWindow(const DstDesc &dst_desc, const Index &dst_slice_origin_step_idx)
Definition threadwise_tensor_slice_transfer_v3r1_gather.hpp:779
static __device__ constexpr auto GetSrcThreadScratchDescriptor()
Definition threadwise_tensor_slice_transfer_v3r1_gather.hpp:793
static __device__ constexpr auto GetDstCoordinateResetStep()
Definition threadwise_tensor_slice_transfer_v3r1_gather.hpp:706
__device__ void RunRead(const SrcDesc &src_desc, const SrcBuffer &src_buf, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v3r1_gather.hpp:132
Definition threadwise_tensor_slice_transfer_util.hpp:43
Definition threadwise_tensor_slice_transfer_util.hpp:20
Definition threadwise_tensor_slice_transfer_util.hpp:29
Definition type.hpp:177
Definition data_type.hpp:187
Definition functional2.hpp:33
Definition functional3.hpp:97
Definition utility/transpose_vectors.hpp:16
Definition dtype_vector.hpp:30