device_elementwise_normalization.hpp Source File

device_elementwise_normalization.hpp Source File

Composable Kernel: device_elementwise_normalization.hpp Source File
device_elementwise_normalization.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
#include <array>
#include <iostream>
#include <memory>
#include <vector>

#include "ck/tensor_operation/gpu/device/device_base.hpp"

11namespace ck {
12namespace tensor_operation {
13namespace device {
14
15template <typename InDataTypeTuple,
16 typename GammaDataType,
17 typename BetaDataType,
18 typename AccDataType,
19 typename YDataType,
20 typename XElementwiseOperation,
21 typename YElementwiseOperation,
22 index_t Rank,
23 index_t NumReduceDim>
25{
26 static constexpr int NumInput = InDataTypeTuple::Size();
27
28 virtual std::unique_ptr<BaseArgument>
29 MakeArgumentPointer(const std::vector<index_t> lengths,
30 const std::array<std::vector<index_t>, NumInput> inStridesArray,
31 const std::vector<index_t> gammaStrides,
32 const std::vector<index_t> betaStrides,
33 const std::vector<index_t> yStrides,
34 const std::vector<index_t> reduceDims,
35 double epsilon,
36 const std::array<const void*, NumInput> in_dev_buffers,
37 const void* p_gamma,
38 const void* p_beta,
39 void* p_y,
40 XElementwiseOperation x_elementwise_op,
41 YElementwiseOperation y_elementwise_op) = 0;
42
43 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
44};
45
46template <typename InDataTypeTuple,
47 typename GammaDataType,
48 typename BetaDataType,
49 typename AccDataType,
50 typename YDataType,
51 typename XElementwiseOperation,
52 typename YElementwiseOperation,
53 index_t Rank,
54 index_t NumReduceDim>
56 std::unique_ptr<DeviceElementwiseNormalization<InDataTypeTuple,
57 GammaDataType,
58 BetaDataType,
59 AccDataType,
60 YDataType,
61 XElementwiseOperation,
62 YElementwiseOperation,
63 Rank,
64 NumReduceDim>>;
65
66} // namespace device
67} // namespace tensor_operation
68} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
std::unique_ptr< DeviceElementwiseNormalization< InDataTypeTuple, GammaDataType, BetaDataType, AccDataType, YDataType, XElementwiseOperation, YElementwiseOperation, Rank, NumReduceDim > > DeviceElementwiseNormalizationPtr
Definition device_elementwise_normalization.hpp:55
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
Definition device_elementwise_normalization.hpp:25
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
static constexpr int NumInput
Definition device_elementwise_normalization.hpp:26
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::vector< index_t > lengths, const std::array< std::vector< index_t >, NumInput > inStridesArray, const std::vector< index_t > gammaStrides, const std::vector< index_t > betaStrides, const std::vector< index_t > yStrides, const std::vector< index_t > reduceDims, double epsilon, const std::array< const void *, NumInput > in_dev_buffers, const void *p_gamma, const void *p_beta, void *p_y, XElementwiseOperation x_elementwise_op, YElementwiseOperation y_elementwise_op)=0