gridwise_normalization_splitk_1st.hpp Source File

gridwise_normalization_splitk_1st.hpp Source File

Composable Kernel: gridwise_normalization_splitk_1st.hpp Source File
gridwise_normalization_splitk_1st.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7#include "ck/utility/math.hpp"
12
13namespace ck {
14
// Gridwise body for the first stage of a split-K normalization.
// Each workgroup performs a Welford reduction along K over one
// (M tile, K split) region of X. It writes one partial mean and one partial
// variance per M row for that region, and one thread per K split also writes
// the Welford element count (see Run()).
// NOTE(review): this is a Doxygen rendering of the header. The struct
// declaration (listing line 27) and several member lines are elided.
//
// Template parameters:
//   XDataType                - element type of X as stored in global memory.
//   ComputeDataType          - type of the per-thread X/mean/var buffers and of
//                              the Welford accumulation.
//   MeanVarDataType          - element type of the partial mean/variance outputs.
//   XGridDesc_M_K            - tensor descriptor of X viewed as (M, K).
//   MeanVarGridDesc_M_KBlock - descriptor of the partial outputs viewed as
//                              (M, K split). Its length along dim 1 is the number
//                              of K splits (k_grid_size in Run()).
//   BlockSize                - workgroup size, forwarded to BlockwiseWelford.
//                              Presumably MThreadClusterSize * KThreadClusterSize;
//                              no check for that is visible here.
//   M/KThreadClusterSize     - shape of the thread cluster inside a workgroup.
//   M/KThreadSliceSize       - per-thread slice of a block tile along M and K.
//   XSrcVectorDim            - dimension X is vector-loaded along (0 = M, 1 = K).
//   XSrcVectorSize           - vector width of those loads. It must divide the
//                              thread slice size along XSrcVectorDim.
15template <typename XDataType,
16 typename ComputeDataType,
17 typename MeanVarDataType,
18 typename XGridDesc_M_K,
19 typename MeanVarGridDesc_M_KBlock,
20 index_t BlockSize,
21 index_t MThreadClusterSize,
22 index_t KThreadClusterSize,
23 index_t MThreadSliceSize,
24 index_t KThreadSliceSize,
25 index_t XSrcVectorDim,
26 index_t XSrcVectorSize>
28{
    // Reject vector widths that do not evenly divide the per-thread slice
    // along the vectorized dimension.
29 static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
30 (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
31 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
32
    // True when X is vectorized along M. Its use is elided from this
    // rendering. It presumably selects the arrangement order of
    // thread_cluster_desc; confirm.
33 static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0);
34
    // Compile-time dimension indices.
35 static constexpr auto I0 = Number<0>{};
36 static constexpr auto I1 = Number<1>{};
37 static constexpr auto I2 = Number<2>{};
38
40
43
46
    // Maps a flat thread id to its (m, k) coordinate in the thread cluster.
    // Run() reads the two coordinates via [I0] and [I1]. Initializer elided.
47 static constexpr auto thread_cluster_desc =
49
53
    // Per-thread descriptor that Run() uses as the source layout when storing
    // the mean/variance thread buffers. Initializer elided; the name suggests
    // an (MThreadSliceSize, 1) shape (confirm).
55 static constexpr auto thread_buffer_desc_m_1 =
57
62
65
    // Workgroup-level Welford combine, applied once per M row in Run().
    // Some template arguments are elided in this rendering.
66 using BlockwiseWelford = BlockwiseWelford<ComputeDataType,
67 BlockSize,
70 false>;
71
73
    // M extent of one workgroup tile, and K extent covered by one workgroup
    // tile iteration. Each workgroup reduces
    // K_BlockTileSize * num_k_block_tile_iteration K elements (see Run()).
74 static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
75 static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
    // Distance the X load window advances along K after each vector load.
    // Threads of the K cluster start XSrcVectorSize apart, so one step spans
    // the whole K cluster.
76 static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize;
77
    // Number of XSrcVectorSize-wide K loads per thread per K tile.
    // Note: ThreadBufferNumber * K_BlockTileStepSize == K_BlockTileSize.
78 static constexpr auto ThreadBufferNumber = Number<KThreadSliceSize / XSrcVectorSize>{};
79
    // Returns how many valid K elements the calling thread accumulates across
    // all K tiles of its workgroup. Run() stores the result in
    // ThreadwiseWelford::max_count_.
    //
    //   k                   - K length of X's descriptor (GetLength(I1) in Run()).
    //   kRaw                - presumably the unpadded K length. The caller reads
    //                         it from X's descriptor transforms; confirm.
    //   kGridSize           - number of K splits (workgroups along K).
    //   block_k_cluster_id  - index of this workgroup's K split.
    //   thread_k_cluster_id - thread's coordinate inside the K thread cluster.
    //
    // Each split except the last is taken to own ceil(k / kGridSize) elements,
    // and only complete K tiles are counted there.
    // NOTE(review): the division truncates, so a partial tile in a non-last
    // split would not be counted. Presumably the host guarantees divisibility;
    // confirm.
    // The last split owns whatever raw elements remain, which may end in a
    // partial tile.
80 __device__ static int
81 GetKPerThread(int k, int kRaw, int kGridSize, int block_k_cluster_id, int thread_k_cluster_id)
82 {
83 bool is_rightmost_block = block_k_cluster_id == kGridSize - 1;
84
85 if(is_rightmost_block)
86 {
        // Raw elements left for the last split after every other split took
        // ceil(k / kGridSize) elements.
87 int left_kPerBlock = math::integer_divide_ceil(k, kGridSize);
88 int kRightmostBlock = kRaw - left_kPerBlock * (kGridSize - 1);
        // Every complete K tile contributes KThreadSliceSize elements per
        // thread. NOTE(review): the `< K_BlockTileSize ? 0` branch is
        // redundant with the integer division below.
89 int kPerThread = kRightmostBlock < K_BlockTileSize
90 ? 0
91 : KThreadSliceSize * (kRightmostBlock / K_BlockTileSize);
        // Elements in the trailing partial tile, if any.
        // kPerThread * KThreadClusterSize equals complete tiles * K_BlockTileSize.
92 int kPerBlockTail = kRightmostBlock - kPerThread * KThreadClusterSize;
93
94 if(kPerBlockTail > 0)
95 {
            // The loop header (listing line 96) is elided. It is presumably a
            // static_for over the thread's ThreadBufferNumber vector loads,
            // with index i.
            // Load i of this thread covers the tail positions
            //   [thread_k_cluster_id*V + Step*i, (thread_k_cluster_id+1)*V + Step*i)
            // with V = XSrcVectorSize and Step = K_BlockTileStepSize.
            // `delta` is the number of those V positions at or beyond the end
            // of the tail, clamped to [0, V]. The remaining V - delta
            // positions are counted.
            // NOTE(review): the initial value given to `delta` is immediately
            // overwritten by the clamp.
97 int thread_max_len =
98 (thread_k_cluster_id + 1) * XSrcVectorSize + K_BlockTileStepSize * i;
99 int delta = thread_max_len - kPerBlockTail;
100 delta = math::clamp(thread_max_len - kPerBlockTail, 0, XSrcVectorSize);
101 kPerThread += XSrcVectorSize - delta;
102 });
103 }
104
105 return kPerThread;
106 }
107 else
108 {
        // Non-last split: complete tiles only.
109 int kPerBlock = math::integer_divide_ceil(k, kGridSize);
110 return KThreadSliceSize * (kPerBlock / K_BlockTileSize);
111 }
112 }
113
114 // Calculate mean and variance by welford along k dimension
    //
    // The flat workgroup id is split into an M-tile index
    // (id / k_grid_size) and a K-split index (id % k_grid_size).
    // The workgroup reduces K_BlockTileSize * num_k_block_tile_iteration
    // consecutive K elements of its M tile, starting at block_k_cluster_id
    // times that length. It writes one partial mean and one partial variance
    // per M row into column block_k_cluster_id of the (M, K split) outputs.
    //
    //   x_grid_desc_m_k             - descriptor of X as (M, K).
    //   mean_var_grid_desc_m_kblock - descriptor of the partial outputs as
    //                                 (M, K split). Its dim-1 length is k_grid_size.
    //   num_k_block_tile_iteration  - number of K tiles each workgroup reduces.
    //                                 NOTE(review): GetKPerThread assumes each
    //                                 non-last split spans ceil(K / k_grid_size)
    //                                 elements. Presumably the host chooses this
    //                                 value so both agree; confirm.
    //   p_x_global                  - input X in global memory.
    //   p_mean_global               - partial mean output.
    //   p_variance_global           - partial variance output.
    //   p_welford_count_global      - one Welford count per K split, indexed
    //                                 by block_k_cluster_id.
115 __device__ static void Run(const XGridDesc_M_K& x_grid_desc_m_k,
116 const MeanVarGridDesc_M_KBlock& mean_var_grid_desc_m_kblock,
117 index_t num_k_block_tile_iteration,
118 const XDataType* const __restrict__ p_x_global,
119 MeanVarDataType* const p_mean_global,
120 MeanVarDataType* const p_variance_global,
121 int32_t* const p_welford_count_global)
122 {
        // Per-thread buffers:
        // - x_thread_buf: a tuple of ComputeDataType buffers of
        //   MThreadSliceSize * XSrcVectorSize elements, indexed per vector
        //   load i in the main loop. The tuple length is elided; presumably
        //   ThreadBufferNumber.
        // - mean_thread_buf / var_thread_buf: per-M-row mean and variance
        //   accumulators. Their types are elided.
123 auto x_thread_buf = generate_tuple(
124 [&](auto) {
126 ComputeDataType,
127 MThreadSliceSize * XSrcVectorSize,
128 true>{};
129 },
131
133 mean_thread_buf;
135 var_thread_buf;
136
        // Decompose the workgroup id into (M tile, K split). k_grid_size is the
        // number of K splits, taken from the output descriptor.
137 const index_t thread_local_id = get_thread_local_1d_id();
138 const index_t block_global_id = get_block_1d_id();
139
140 const index_t k_grid_size = mean_var_grid_desc_m_kblock.GetLength(I1);
141 const index_t block_m_cluster_id = block_global_id / k_grid_size;
142 const index_t block_k_cluster_id = block_global_id % k_grid_size;
143
        // This thread's (m, k) coordinate within the workgroup's thread cluster.
144 const auto thread_cluster_idx =
145 thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
146
147 const auto thread_m_cluster_id = thread_cluster_idx[I0];
148 const auto thread_k_cluster_id = thread_cluster_idx[I1];
149
        // Number of K elements reduced by this workgroup.
150 const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
151
        // X loader. Its window starts at this thread's first M row and at this
        // split's K offset plus thread_k_cluster_id * XSrcVectorSize, so
        // threads of the K cluster read adjacent vector chunks.
        // Some template arguments and constructor lines are elided.
152 auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
153 ComputeDataType,
154 XGridDesc_M_K,
155 decltype(thread_buffer_desc_m_k),
158 XSrcVectorDim,
159 XSrcVectorSize,
160 1,
161 true>(
162 x_grid_desc_m_k,
164 block_m_cluster_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
165 block_k_cluster_id * reduceSizePerBlock + thread_k_cluster_id * XSrcVectorSize));
166
        // Output coordinate for the mean/var stores: this thread's first M row
        // and the K-split column.
167 auto mean_var_count_store_index = make_multi_index(
168 block_m_cluster_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
169 block_k_cluster_id);
170
        // Writer for the per-thread mean/var buffers. Some template arguments
        // are elided.
171 auto threadwise_welford_mean_var_store =
173 MeanVarDataType,
174 decltype(thread_buffer_desc_m_1),
175 MeanVarGridDesc_M_KBlock,
179 1,
180 1,
182 1,
183 true>(
184 mean_var_grid_desc_m_kblock, mean_var_count_store_index, PassThroughOp{});
185
        // After each vector load the X window moves K_BlockTileStepSize along K.
186 constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);
187
        // Global-memory views of X and of the two partial outputs.
188 const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
189 p_x_global, x_grid_desc_m_k.GetElementSpaceSize());
190
191 auto mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
192 p_mean_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize());
193
194 auto var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
195 p_variance_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize());
196
        // Per-thread Welford state.
        // - kRaw is read as the first upper length of transform #2 of X's
        //   descriptor. It is presumably the unpadded K length; confirm
        //   against the descriptor builder.
        // - max_count_ is set to the number of valid K elements this thread
        //   sees. Presumably ThreadwiseWelford uses it to skip padded
        //   elements; confirm in ThreadwiseWelford.
197 auto threadwise_welford = ThreadwiseWelford();
198 int kRaw = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];
199 threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k.GetLength(I1),
200 kRaw,
201 k_grid_size,
202 block_k_cluster_id,
203 thread_k_cluster_id);
        // Zero-initialize the per-row mean and variance accumulators. The loop
        // header (listing line 205) is elided; presumably it iterates I over
        // the MThreadSliceSize rows.
204
206 mean_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
207 var_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
208 });
209
        // Main accumulation loop, once per K tile. Inside it, for each vector
        // load i of the thread, it loads one vector slice of X, advances the
        // window, and folds the slice into the Welford state. The inner loop
        // header (listing line 212) is elided.
        // Per tile the window advances ThreadBufferNumber * K_BlockTileStepSize
        // == K_BlockTileSize.
210 for(index_t k = 0; k < num_k_block_tile_iteration; ++k)
211 {
213 threadwise_x_load.Run(x_grid_desc_m_k,
214 x_global_val_buf,
216 make_tuple(I0, I0),
217 x_thread_buf(i));
218 threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
219 threadwise_welford.Run(x_thread_buf[i], mean_thread_buf, var_thread_buf);
220 });
221 }
222
        // Combine the per-thread partial results across the workgroup, one
        // M row (I) at a time.
        // - The loop header (listing line 224) is elided.
        // - The statement guarded by `if constexpr(I > 0)` (listing line 226)
        //   is also elided. It is presumably block_sync_lds(), separating
        //   successive BlockwiseWelford calls; confirm.
        // - BlockwiseWelford::Run takes `count` by reference. The value seen
        //   after the last row is kept in welford_count.
223 int welford_count = 0;
225 if constexpr(I > 0)
227
228 int count = threadwise_welford.cur_count_;
229 BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
230
231 // The value of count is same for all I
232 if constexpr(I == MThreadSliceSize - 1)
233 welford_count = count;
234 });
235
        // Only threads at k-coordinate 0 write results. Presumably they hold
        // the combined values after BlockwiseWelford.
        // Each writing thread stores its MThreadSliceSize partial means and
        // variances into column block_k_cluster_id.
        // The count per K split is written once, by the thread at m-coordinate
        // 0 of M tile 0. This is presumably valid because the count depends
        // only on K-split parameters (see GetKPerThread).
236 if(thread_k_cluster_id == 0)
237 {
238 threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1,
239 make_tuple(I0, I0),
240 mean_thread_buf,
241 mean_var_grid_desc_m_kblock,
242 mean_global_val_buf);
243
244 threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1,
245 make_tuple(I0, I0),
246 var_thread_buf,
247 mean_var_grid_desc_m_kblock,
248 var_global_val_buf);
249
250 if(block_m_cluster_id == 0 && thread_m_cluster_id == 0)
251 p_welford_count_global[block_k_cluster_id] = welford_count;
252 }
253 }
254};
255
256} // namespace ck
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr T clamp(const T &x, const T &lowerbound, const T &upperbound)
Definition utility/math.hpp:148
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Vgpr
Definition amd_address_space.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
signed int int32_t
Definition stdint.h:123
static __device__ void Run(T &mean_value, T &var_value, CountDataType &count)
Definition blockwise_welford.hpp:51
Definition gridwise_normalization_splitk_1st.hpp:28
static __device__ void Run(const XGridDesc_M_K &x_grid_desc_m_k, const MeanVarGridDesc_M_KBlock &mean_var_grid_desc_m_kblock, index_t num_k_block_tile_iteration, const XDataType *const __restrict__ p_x_global, MeanVarDataType *const p_mean_global, MeanVarDataType *const p_variance_global, int32_t *const p_welford_count_global)
Definition gridwise_normalization_splitk_1st.hpp:115
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number< MThreadSliceSize >{}, Number< XSrcVectorSize >{}))) ThreadReduceSrcDesc_M_K
Definition gridwise_normalization_splitk_1st.hpp:58
static __device__ int GetKPerThread(int k, int kRaw, int kGridSize, int block_k_cluster_id, int thread_k_cluster_id)
Definition gridwise_normalization_splitk_1st.hpp:81
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:16
Definition threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition utility/functional.hpp:100
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340