thread_group_tensor_slice_transfer_v7.hpp Source File

thread_group_tensor_slice_transfer_v7.hpp Source File

Composable Kernel: thread_group_tensor_slice_transfer_v7.hpp Source File
thread_group_tensor_slice_transfer_v7.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
11
12namespace ck {
13
14// Thread-group level multi-source, multi-destination tensor slice data movement
15// Assumes:
16// 1. All sources and destinations are DynamicBuffer
17// 2. Same VectorDim and ScalarPerVector for all sources and destinations
18// 3. DstInMemOps are per destination tensor
19// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor
20// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor
21//
22// Does the following to avoid scratch-memory issues:
23// 1. Passes tensor descriptors by reference (or as a tuple of references)
24// 2. Does not keep references to tensor descriptors
25// 3. Does not construct new tensor coordinates when Run() is called
26template <typename ThreadGroup,
27 typename SrcDatas,
28 typename DstDatas,
29 typename SrcDescs,
30 typename DstDescs,
31 typename ElementwiseOperation,
32 typename DstInMemOps, // Sequence<InMemoryDataOperationEnum ...>
33 typename SliceLengths,
34 typename ThreadClusterLengths,
35 typename ThreadClusterArrangeOrder,
36 typename DimAccessOrder,
37 index_t VectorDim,
38 index_t ScalarPerVector,
39 typename ThreadTransferSrcResetCoordinateAfterRunFlags,
40 typename ThreadTransferDstResetCoordinateAfterRunFlags>
42{
43 static constexpr index_t nDim =
45
48
50
// Per-thread slice lengths: the block-level SliceLengths divided by ThreadClusterLengths
// (presumably element-wise, per ck::Sequence operator/). The constructor static_asserts
// that thread_slice_lengths * ThreadClusterLengths has exactly the type SliceLengths, i.e.
// the thread cluster must tile the slicing window evenly.
51 static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
52
54 const SrcDescs& src_descs,
55 const StaticallyIndexedArray<Index, nSrc>& src_block_slice_origins,
56 const DstDescs& dst_descs,
57 const StaticallyIndexedArray<Index, nDst>& dst_block_slice_origins,
58 const ElementwiseOperation& element_op)
59 : threadwise_transfer_(src_descs,
61 dst_descs,
63 element_op)
64 {
65 static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() &&
66 nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() &&
67 nDst == DstDatas::Size() && nDst == DstDescs::Size() &&
68 nDst == ThreadTransferDstResetCoordinateAfterRunFlags::Size(),
69 "wrong!");
70
71 static_for<0, nSrc, 1>{}([&](auto i) {
72 static_assert(
73 nDim == remove_cvref_t<tuple_element_t<i.value, SrcDescs>>::GetNumOfDimension(),
74 "wrong!");
75 });
76
77 static_for<0, nDst, 1>{}([&](auto i) {
78 static_assert(
79 nDim == remove_cvref_t<tuple_element_t<i.value, DstDescs>>::GetNumOfDimension(),
80 "wrong!");
81 });
82
83 static_assert(nDim == ThreadClusterLengths::Size() &&
84 nDim == ThreadClusterArrangeOrder::Size() &&
85 nDim == DimAccessOrder::Size(),
86 "wrong! nDim not consistent");
87
88 static_assert(
89 is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
90 "wrong! threads should be mapped to cover entire slicing window");
91
92 static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
93 "wrong! ThreadGroup::GetNumOfThread() too small");
94
95 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
96 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
97 {
98 const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
100
101 const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
102
103 const auto src_thread_slice_origins = generate_tuple(
104 [&](auto i) { return src_block_slice_origins[i] + thread_data_idx_begin; },
105 Number<nSrc>{});
106
107 const auto dst_thread_slice_origins = generate_tuple(
108 [&](auto i) { return dst_block_slice_origins[i] + thread_data_idx_begin; },
109 Number<nDst>{});
110
111 threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins);
112 threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins);
113 }
114 }
115
116 template <typename SrcBuffers, typename DstBuffers>
117 __device__ void Run(const SrcDescs& src_descs,
118 const SrcBuffers& src_bufs,
119 const DstDescs& dst_descs,
120 DstBuffers dst_bufs)
121 {
122 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
123 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
124 {
125 threadwise_transfer_.Run(src_descs, src_bufs, dst_descs, dst_bufs);
126 }
127 }
128
129 template <index_t ISrc>
130 __device__ void
131 MoveSrcSliceWindow(const SrcDescs& src_descs, Number<ISrc> iSrc, const Index& step)
132 {
133 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
134 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
135 {
136 threadwise_transfer_.MoveSrcSliceWindow(src_descs, iSrc, step);
137 }
138 }
139
140 template <index_t IDst>
141 __device__ void
142 MoveDstSliceWindow(const DstDescs& dst_descs, Number<IDst> iDst, const Index& step)
143 {
144 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
145 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
146 {
147 threadwise_transfer_.MoveDstSliceWindow(dst_descs, iDst, step);
148 }
149 }
150
151 private:
// Cluster descriptor built from ThreadClusterLengths / ThreadClusterArrangeOrder. The
// constructor uses CalculateBottomIndex on it to find this thread's cluster position, and
// the constructor and every public method compare the thread id against its element size
// to decide whether the calling thread participates.
152 static constexpr auto thread_cluster_desc_ =
153 make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
154
// Per-thread worker that performs the actual data movement over a thread_slice_lengths
// sized slice.
// NOTE(review): decltype(thread_slice_lengths) names the declared type of a
// `static constexpr auto` variable, which is const-qualified; do not "clean it up" with
// remove_cvref_t without checking how ThreadwiseTensorSliceTransfer_v7 uses SliceLengths.
155 using ThreadwiseTransfer =
156 ThreadwiseTensorSliceTransfer_v7<SrcDatas,
157 DstDatas,
158 SrcDescs,
159 DstDescs,
160 ElementwiseOperation,
161 DstInMemOps,
162 decltype(thread_slice_lengths),
163 DimAccessOrder,
164 VectorDim,
165 ScalarPerVector,
166 ThreadTransferSrcResetCoordinateAfterRunFlags,
167 ThreadTransferDstResetCoordinateAfterRunFlags>;
168
// The only non-static data member in the visible source; Run and the Move*SliceWindow
// methods all delegate to it.
169 ThreadwiseTransfer threadwise_transfer_;
170};
171
172} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
static constexpr index_t nDst
Definition thread_group_tensor_slice_transfer_v7.hpp:47
static constexpr index_t nSrc
Definition thread_group_tensor_slice_transfer_v7.hpp:46
__device__ constexpr ThreadGroupTensorSliceTransfer_v7(const SrcDescs &src_descs, const StaticallyIndexedArray< Index, nSrc > &src_block_slice_origins, const DstDescs &dst_descs, const StaticallyIndexedArray< Index, nDst > &dst_block_slice_origins, const ElementwiseOperation &element_op)
Definition thread_group_tensor_slice_transfer_v7.hpp:53
MultiIndex< nDim > Index
Definition thread_group_tensor_slice_transfer_v7.hpp:49
static constexpr auto thread_slice_lengths
Definition thread_group_tensor_slice_transfer_v7.hpp:51
static constexpr index_t nDim
Definition thread_group_tensor_slice_transfer_v7.hpp:43
__device__ void MoveSrcSliceWindow(const SrcDescs &src_descs, Number< ISrc > iSrc, const Index &step)
Definition thread_group_tensor_slice_transfer_v7.hpp:131
__device__ void MoveDstSliceWindow(const DstDescs &dst_descs, Number< IDst > iDst, const Index &step)
Definition thread_group_tensor_slice_transfer_v7.hpp:142
__device__ void Run(const SrcDescs &src_descs, const SrcBuffers &src_bufs, const DstDescs &dst_descs, DstBuffers dst_bufs)
Definition thread_group_tensor_slice_transfer_v7.hpp:117