device_softmax_impl.hpp Source File

device_softmax_impl.hpp Source File#

Composable Kernel: device_softmax_impl.hpp Source File
device_softmax_impl.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
16
17namespace ck {
18namespace tensor_operation {
19namespace device {
20
21template <typename InDataType,
22 typename AccDataType,
23 typename OutDataType,
24 typename InElementwiseOp,
25 typename AccElementwiseOp,
26 index_t Rank,
27 index_t NumReduceDim,
28 index_t BlockSize,
29 index_t MThreadClusterSize,
30 index_t KThreadClusterSize,
31 index_t MThreadSliceSize,
32 index_t KThreadSliceSize,
33 index_t InSrcVectorDim,
34 index_t InSrcVectorSize,
35 index_t OutDstVectorSize>
36struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
37 AccDataType,
38 OutDataType,
39 InElementwiseOp,
40 AccElementwiseOp,
41 Rank,
42 NumReduceDim>
43{
// Number of tensor dimensions that are NOT reduced.
44 static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
45
// Source rank, and destination rank; the latter is clamped to 1 when every
// dimension is reduced (NumInvariantDim == 0).
46 static constexpr index_t NumSrcDim = Rank;
47 static constexpr index_t NumDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
// True when the reduction covers all dimensions of the tensor.
48 static constexpr bool reduceAllDim = (NumInvariantDim == 0);
49
// Tile extents handled by one block along M and K: thread-cluster size times
// the per-thread slice size in that direction.
50 static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
51 static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
52
// Builds the padded 2D descriptor (dimension 0 = invariant/M, dimension 1 = reduce/K)
// for a tensor described by `inLengths` / `inStrides`.
//
// inLengths, inStrides    : indexed 0 .. Rank-1.
// blkGroupSize,
// numBlockTileIteration   : together with K_BlockTileSize they define the padded K
//                           extent (K_BlockTileSize * numBlockTileIteration * blkGroupSize).
//
// Returns the descriptor after right-padding both dimensions.
//
// NOTE(review): this listing omits several original lines (the embedded line numbers
// jump, e.g. 70 -> 73, 74 -> 76, 89 -> 91), so some statements below are truncated.
53 static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
54 const std::vector<index_t>& inStrides,
55 int blkGroupSize,
56 int numBlockTileIteration)
57 {
// Turn the runtime vectors into Rank-sized tuples and build the N-d descriptor.
58 const auto tupleSrcLengths =
59 generate_tuple([&](auto I) { return inLengths[I]; }, Number<Rank>{});
60 const auto tupleSrcStrides =
61 generate_tuple([&](auto I) { return inStrides[I]; }, Number<Rank>{});
62
63 const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
64
// Collapse the N-d descriptor into two dimensions.
65 const auto in_grid_desc_m_k = [&]() {
66 if constexpr(reduceAllDim)
67 {
// All dimensions are reduced: merge everything into a single dimension first.
// The visible fragment "1, one_dim_inDesc.GetLength(...)" suggests the result is
// then split into (1, total length) -- TODO confirm against the full source.
68 const auto one_dim_inDesc = transform_tensor_descriptor(
69 inDesc,
70 make_tuple(make_merge_transform(tupleSrcLengths)),
71 73
72 74 return transform_tensor_descriptor(one_dim_inDesc,
73 76 1, one_dim_inDesc.GetLength(Number<0>{})))),
74 79 }
75 80 else
76 81 {
// The first NumInvariantDim dimensions are treated as invariant and the remaining
// NumReduceDim dimensions as reduced; each group is merged into one dimension.
// NOTE(review): this presumes the caller already ordered the dimensions that way,
// and the declaration of ReduceDims is among the omitted lines.
77 82 using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
78 84
79 85 const auto reduceDimLengths = generate_tuple(
80 86 [&](auto I) { return inLengths[NumInvariantDim + I]; }, Number<NumReduceDim>{});
81 87 const auto invariantDimLengths =
82 88 generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInvariantDim>{});
83 89
84 91 inDesc,
85 92 make_tuple(make_merge_transform(invariantDimLengths),
86 93 make_merge_transform(reduceDimLengths)),
87 94 make_tuple(InvariantDims{}, ReduceDims{}),
88 96 }
89 97 }();
90 98
91 99 const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
92 100 const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
93 101
// M is padded up to the next multiple of M_BlockTileSize. K is padded up to
// reduceSizePerBlock * blkGroupSize; NOTE(review): this assumes that product is
// not smaller than reduceLength, otherwise inPad_K would be negative.
94 102 const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration;
95 103 const auto inPad_M =
96 104 math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
97 105 const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;
98 106
99 107 auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
100 108 in_grid_desc_m_k,
101 109 make_tuple(make_right_pad_transform(invariantLength, inPad_M),
102 110 make_right_pad_transform(reduceLength, inPad_K)),
103 113
104 114 return (in_grid_desc_m_k_padded);
105 115 };
116
// Type of the descriptor returned by MakeSrc2dDescriptor. The arguments are
// placeholders: the call sits inside decltype and is never evaluated.
117 using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1));
118
// NOTE(review): the opening line of each of the two aliases below (original lines
// 119 and 133) and their GridDesc_M_K argument line are missing from this listing.
// According to the symbol index further down this page, the first alias is
// GridwiseSoftmaxGeneric = GridwiseSoftmax_mk_to_mk<..., false>.
120 OutDataType,
121 AccDataType,
123 BlockSize,
124 MThreadClusterSize,
125 KThreadClusterSize,
126 MThreadSliceSize,
127 KThreadSliceSize,
128 InSrcVectorDim,
129 InSrcVectorSize,
130 OutDstVectorSize,
131 false>;
132
// Second alias: GridwiseSoftmaxSweepOnce = GridwiseSoftmax_mk_to_mk<..., true>
// (per the same symbol index). It differs from the generic alias only in the
// final boolean template argument.
134 OutDataType,
135 AccDataType,
137 BlockSize,
138 MThreadClusterSize,
139 KThreadClusterSize,
140 MThreadSliceSize,
141 KThreadSliceSize,
142 InSrcVectorDim,
143 InSrcVectorSize,
144 OutDstVectorSize,
145 true>;
146
// Holds the validated problem description (lengths, strides, scaling factors,
// device pointers, element-wise operators) together with the launch parameters
// derived from it.
//
// NOTE(review): this listing omits several original lines (embedded line numbers
// jump at 189 -> 192, 196 -> 198, 199 -> 201 -> 203, 207 -> 209 and 221 -> 224 -> 227),
// so some statements and member declarations are not visible here.
147 struct Argument : public BaseArgument
148 {
// Validates the sizes of inLengths / inStrides / reduceDims against Rank and
// NumReduceDim and the range of every reduce dimension; throws std::runtime_error
// when a check fails.
149 Argument(const std::vector<index_t> inLengths,
150 const std::vector<index_t> inStrides,
151 const std::vector<index_t> reduceDims,
152 double alpha,
153 double beta,
154 const InDataType* in_dev,
155 OutDataType* out_dev,
156 InElementwiseOp in_elementwise_op,
157 AccElementwiseOp acc_elementwise_op)
158 : in_dev_{in_dev},
159 out_dev_{out_dev},
160 in_elementwise_op_{in_elementwise_op},
161 acc_elementwise_op_{acc_elementwise_op}
162 {
// The scaling factors arrive as double and are stored in the accumulation type.
163 alpha_ = static_cast<AccDataType>(alpha);
164 beta_ = static_cast<AccDataType>(beta);
165
// Container sizes must match the compile-time Rank / NumReduceDim.
166 if(Rank != inLengths.size() || Rank != inStrides.size() ||
167 NumReduceDim != reduceDims.size())
168 {
169 throw std::runtime_error(
170 "One of inLengths/inStrides/reduceDims has invalid size!"
171 "\nExpected size inLengths: " +
172 std::to_string(Rank) + ", inStrides: " + std::to_string(Rank) +
173 ", reduceDims: " + std::to_string(NumReduceDim) +
174 "\nBut have inLengths: " + std::to_string(inLengths.size()) +
175 ", inStrides: " + std::to_string(inStrides.size()) +
176 ", reduceDims: " + std::to_string(reduceDims.size()));
177 }
178
// Every reduce dimension index must lie in [0, Rank).
179 for(std::size_t i = 0; i < reduceDims.size(); ++i)
180 {
181 if(reduceDims[i] < 0 || reduceDims[i] >= Rank)
182 {
183 throw std::runtime_error("Provided reduce dimension exceed input tensor Rank!"
184 "\nHave reduceDims[" +
185 std::to_string(i) +
186 "]: " + std::to_string(reduceDims[i]));
187 }
188 }
189
// NOTE(review): the assignments to inLengths_ / inStrides_ are among the omitted
// lines. The page's symbol index lists shuffle_tensor_dimensions and get_2d_lengths,
// which are presumably what is called here -- confirm against the full source.
192
193 long_index_t invariant_total_length;
194 long_index_t reduce_total_length;
195
196 std::tie(invariant_total_length, reduce_total_length) =
198
// NOTE(review): both branch bodies are omitted; presumably they set
// invariant_lowest_length_, which IsSupportedArgument reads.
199 if constexpr(NumInvariantDim == 0)
201 else
203
// A single block group is used; the iteration count is the ceiling of
// reduce_total_length / K_BlockTileSize.
204 blkGroupSize = 1;
205 numBlockTileIteration = (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize;
206
// Grid size starts from the invariant length rounded up to a multiple of
// M_BlockTileSize (the divisor on the following line is omitted in this listing).
207 gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
209 }
210
211 std::vector<index_t> inLengths_;
212 std::vector<index_t> inStrides_;
213
214 AccDataType alpha_;
215 AccDataType beta_;
216
217 const InDataType* in_dev_;
218 OutDataType* out_dev_;
219
220 InElementwiseOp in_elementwise_op_;
221 AccElementwiseOp acc_elementwise_op_;
222
// NOTE(review): per the symbol index, the omitted lines 223, 225 and 226 declare
// `index_t invariant_lowest_length_`, `int blkGroupSize` and `int numBlockTileIteration`.
224
227 size_t gridSize;
228 };
229
230 struct Invoker : public BaseInvoker
231 {
// Builds the input and output descriptors, selects the softmax kernel variant and
// launches it through launch_and_time_kernel. Returns the value reported by
// launch_and_time_kernel.
//
// NOTE(review): this listing omits several original lines (embedded line numbers
// jump at 234 -> 236 -> 238, 245 -> 248, 250 -> 252 and 262 -> 264), so the
// descriptor arguments and part of the kernel template arguments are not visible.
232 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
233 {
// Both descriptors come from the same factory (call arguments omitted in this listing).
234 const auto in_grid_desc_m_k = DeviceSoftmaxImpl::MakeSrc2dDescriptor(
236 const auto out_grid_desc_m_k = DeviceSoftmaxImpl::MakeSrc2dDescriptor(
238
// sweep_once is true when the K extent of the descriptor (which MakeSrc2dDescriptor
// returns already right-padded) fits in one K block tile.
239 bool sweep_once =
240 in_grid_desc_m_k.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;
241
// Pick the sweep-once kernel instantiation when possible; the alternative branch
// of this conditional is partly omitted in this listing.
242 const auto kernel_main = sweep_once ? kernel_softmax<GridwiseSoftmaxSweepOnce,
243 InDataType,
244 OutDataType,
245 AccDataType,
248 InDataType,
249 OutDataType,
250 AccDataType,
252
253 float avg_time = 0;
254
// Launch arg.gridSize blocks of BlockSize threads; the literal 0 is the lds_byte
// argument of launch_and_time_kernel.
255 avg_time += launch_and_time_kernel(stream_config,
256 kernel_main,
257 dim3(arg.gridSize),
258 dim3(BlockSize),
259 0,
260 in_grid_desc_m_k,
261 out_grid_desc_m_k,
262 arg.blkGroupSize,
264 arg.alpha_,
265 arg.in_dev_,
266 arg.beta_,
267 arg.out_dev_);
268
269 return (avg_time);
270 };
271
272 float Run(const BaseArgument* p_arg,
273 const StreamConfig& stream_config = StreamConfig{}) override
274 {
275 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
276 };
277 };
278
279 static bool IsSupportedArgument(const Argument& arg)
280 {
281 if constexpr(InSrcVectorDim == 0)
282 {
283 if constexpr(NumInvariantDim == 0)
284 {
285 return false;
286 }
287 else
288 {
289 if(arg.inStrides_[NumInvariantDim - 1] != 1 && InSrcVectorSize != 1)
290 {
291 return false;
292 }
293 if(arg.invariant_lowest_length_ % InSrcVectorSize != 0)
294 {
295 return false;
296 }
297 }
298 }
299 else
300 {
301 if(arg.inStrides_[Rank - 1] != 1 && InSrcVectorSize != 1)
302 {
303 return false;
304 }
305 if(arg.inLengths_[Rank - 1] % InSrcVectorSize != 0)
306 {
307 return false;
308 }
309 }
310
311 // To improve
312 if(NumInvariantDim > 0 && arg.invariant_lowest_length_ % OutDstVectorSize != 0)
313 {
314 return false;
315 }
316
317 if(arg.inLengths_[Rank - 1] % OutDstVectorSize != 0)
318 {
319 return false;
320 }
321
322 return true;
323 };
324
325 bool IsSupportedArgument(const BaseArgument* p_arg) override
326 {
327 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
328 }
329
// Convenience factory: constructs an Argument by value from the given parameters,
// forwarding them unchanged to the Argument constructor (which validates the sizes
// of inLengths / inStrides / reduceDims and may throw std::runtime_error).
330 static auto MakeArgument(const std::vector<index_t> inLengths,
331 const std::vector<index_t> inStrides,
332 const std::vector<int> reduceDims,
333 double alpha,
334 double beta,
335 const InDataType* in_dev,
336 OutDataType* out_dev,
337 InElementwiseOp in_elementwise_op,
338 AccElementwiseOp acc_elementwise_op)
339 {
340 return Argument{inLengths,
341 inStrides,
342 reduceDims,
343 alpha,
344 beta,
345 in_dev,
346 out_dev,
347 in_elementwise_op,
348 acc_elementwise_op};
349 };
350
351 //
352 // @brief Makes a pointer to Argument class.
353 //
354 // @param[in] inLengths Input tensor extent(s) from high to low dimension
355 // @param[in] inStrides Input tensor stride(s) from high to low dimension
356 // @param[in] reduceDims The dimension(s) the softmax normalization is applied over
357 // @param[in] alpha The alpha scaling value, passed by value as a double; the
358 // Argument constructor converts it to AccDataType
359 // @param[in] beta The beta scaling value, passed by value as a double; the
360 // Argument constructor converts it to AccDataType
361 // @param[in] in_dev Typeless const pointer to the input tensor; it is cast to
362 // const InDataType* before being stored
363 // @param out_dev Typeless pointer to the output tensor; it is cast to OutDataType*
364 // @param[in] in_elementwise_op The input elementwise operation.
365 // @param[in] acc_elementwise_op The accumulation elementwise operation.
366 //
367 // @return Unique pointer to the Argument class.
368 //
369 std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths,
370 const std::vector<index_t> inStrides,
371 const std::vector<int> reduceDims,
372 double alpha,
373 double beta,
374 const void* in_dev,
375 void* out_dev,
376 InElementwiseOp in_elementwise_op,
377 AccElementwiseOp acc_elementwise_op) override
378 {
// The typeless pointers are reinterpreted as the element types of this instantiation.
379 return std::make_unique<Argument>(inLengths,
380 inStrides,
381 reduceDims,
382 alpha,
383 beta,
384 static_cast<const InDataType*>(in_dev),
385 static_cast<OutDataType*>(out_dev),
386 in_elementwise_op,
387 acc_elementwise_op);
388 };
389
// Returns a default-constructed Invoker by value.
390 static auto MakeInvoker() { return Invoker{}; }
391
// Returns a heap-allocated Invoker owned through a pointer to its BaseInvoker base.
392 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
393 {
394 return std::make_unique<Invoker>();
395 };
396
397 std::string GetTypeString() const override
398 {
399 auto str = std::stringstream();
400
401 // clang-format off
402 str << "DeviceReduceSoftmax<"
403 << Rank << "," << NumReduceDim << "," << BlockSize << ","
404 << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","
405 << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","
406 << "InSrcVectorDim_" << InSrcVectorDim
407 << "_InSrcVectorSize_" << InSrcVectorSize
408 << "_OutDstVectorSize_" << OutDstVectorSize << ">";
409 // clang-format on
410
411 return str.str();
412 }
413};
414
415} // namespace device
416} // namespace tensor_operation
417} // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
Definition convolution_backward_data_specialization.hpp:8
std::pair< long_index_t, long_index_t > get_2d_lengths(const std::vector< index_t > &inLengths)
Definition device_reduce_common.hpp:20
std::vector< index_t > shuffle_tensor_dimensions(const std::vector< index_t > &origLengthsStrides, const std::vector< int > &reduceDims)
Definition device_reduce_common.hpp:75
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__global__ void kernel_softmax(const GridDesc_M_K in_grid_desc_m_k, const GridDesc_M_K out_grid_desc_m_k, index_t block_group_size, index_t num_k_block_tile_iteration, AccDataType alpha, const InDataType *const __restrict__ p_in_value_global, AccDataType beta, OutDataType *const __restrict__ p_out_value_global)
Definition gridwise_softmax.hpp:22
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
Definition ck/stream_config.hpp:10
Definition gridwise_softmax.hpp:55
Definition utility/sequence.hpp:43
typename conditional< kHasContent, type0, type1 >::type type
Definition utility/sequence.hpp:271
Definition device_base.hpp:197
Definition device_softmax.hpp:24
Definition device_softmax_impl.hpp:148
std::vector< index_t > inLengths_
Definition device_softmax_impl.hpp:211
AccDataType alpha_
Definition device_softmax_impl.hpp:214
index_t invariant_lowest_length_
Definition device_softmax_impl.hpp:223
AccElementwiseOp acc_elementwise_op_
Definition device_softmax_impl.hpp:221
const InDataType * in_dev_
Definition device_softmax_impl.hpp:217
AccDataType beta_
Definition device_softmax_impl.hpp:215
size_t gridSize
Definition device_softmax_impl.hpp:227
Argument(const std::vector< index_t > inLengths, const std::vector< index_t > inStrides, const std::vector< index_t > reduceDims, double alpha, double beta, const InDataType *in_dev, OutDataType *out_dev, InElementwiseOp in_elementwise_op, AccElementwiseOp acc_elementwise_op)
Definition device_softmax_impl.hpp:149
int blkGroupSize
Definition device_softmax_impl.hpp:225
InElementwiseOp in_elementwise_op_
Definition device_softmax_impl.hpp:220
OutDataType * out_dev_
Definition device_softmax_impl.hpp:218
int numBlockTileIteration
Definition device_softmax_impl.hpp:226
std::vector< index_t > inStrides_
Definition device_softmax_impl.hpp:212
Definition device_softmax_impl.hpp:231
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_softmax_impl.hpp:232
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_softmax_impl.hpp:272
Definition device_softmax_impl.hpp:43
static auto MakeSrc2dDescriptor(const std::vector< index_t > &inLengths, const std::vector< index_t > &inStrides, int blkGroupSize, int numBlockTileIteration)
Definition device_softmax_impl.hpp:53
static constexpr index_t NumInvariantDim
Definition device_softmax_impl.hpp:44
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_softmax_impl.hpp:392
static auto MakeInvoker()
Definition device_softmax_impl.hpp:390
std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::vector< index_t > inLengths, const std::vector< index_t > inStrides, const std::vector< int > reduceDims, double alpha, double beta, const void *in_dev, void *out_dev, InElementwiseOp in_elementwise_op, AccElementwiseOp acc_elementwise_op) override
Definition device_softmax_impl.hpp:369
GridwiseSoftmax_mk_to_mk< InDataType, OutDataType, AccDataType, GridDesc_M_K, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize, false > GridwiseSoftmaxGeneric
Definition device_softmax_impl.hpp:119
static bool IsSupportedArgument(const Argument &arg)
Definition device_softmax_impl.hpp:279
static constexpr index_t NumSrcDim
Definition device_softmax_impl.hpp:46
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_softmax_impl.hpp:325
static constexpr index_t M_BlockTileSize
Definition device_softmax_impl.hpp:50
decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1)) GridDesc_M_K
Definition device_softmax_impl.hpp:117
static auto MakeArgument(const std::vector< index_t > inLengths, const std::vector< index_t > inStrides, const std::vector< int > reduceDims, double alpha, double beta, const InDataType *in_dev, OutDataType *out_dev, InElementwiseOp in_elementwise_op, AccElementwiseOp acc_elementwise_op)
Definition device_softmax_impl.hpp:330
static constexpr index_t NumDstDim
Definition device_softmax_impl.hpp:47
static constexpr index_t K_BlockTileSize
Definition device_softmax_impl.hpp:51
std::string GetTypeString() const override
Definition device_softmax_impl.hpp:397
GridwiseSoftmax_mk_to_mk< InDataType, OutDataType, AccDataType, GridDesc_M_K, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize, true > GridwiseSoftmaxSweepOnce
Definition device_softmax_impl.hpp:133
static constexpr bool reduceAllDim
Definition device_softmax_impl.hpp:48