device_grouped_conv_fwd_multiple_abd.hpp Source File

device_grouped_conv_fwd_multiple_abd.hpp Source File

Composable Kernel: device_grouped_conv_fwd_multiple_abd.hpp Source File
device_grouped_conv_fwd_multiple_abd.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#ifndef CK_CODE_GEN_RTC
7#include <array>
8#endif
9
13
14namespace ck {
15namespace tensor_operation {
16namespace device {
17
18#ifdef CK_CODE_GEN_RTC // is_tuple<T>: detection alias naming the type of T::IsTuple() (ck::declval under CK_CODE_GEN_RTC, std::declval otherwise)
19template <typename T>
20using is_tuple = decltype(ck::declval<T&>().IsTuple()); // ill-formed unless T has an IsTuple() member; meant for use with is_detected
21#else
22template <typename T>
23using is_tuple = decltype(std::declval<T&>().IsTuple()); // ill-formed unless T has an IsTuple() member; meant for use with is_detected
24#endif
25
53template <index_t NDimSpatial, // spatial rank; the length/stride arrays below hold NDimSpatial + 3 entries
54 typename ALayout,
55 typename BLayout,
56 typename DsLayout,
57 typename ELayout,
58 typename ADataType,
59 typename BDataType,
60 typename DsDataType,
61 typename EDataType,
62 typename AElementwiseOperation,
63 typename BElementwiseOperation,
64 typename CDEElementwiseOperation,
65 typename AComputeType =
66 decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
68 ADataType>()), // AComputeType defaults to ADataType (the first
69 // element of the tuple for multi-A/B); a tuple
70 // passed as ADataType is unpacked
71 typename BComputeType = AComputeType> // B compute type defaults to AComputeType
73{ // abstract interface: grouped convolution forward with multiple A/B/D tensors
76
77 static constexpr index_t NumATensor = GetNumABTensors<isMultiA, ADataType>(); // presumably >1 only when ADataType is a tuple (multi-A) — TODO confirm GetNumABTensors
78 static constexpr index_t NumBTensor = GetNumABTensors<isMultiB, BDataType>(); // presumably >1 only when BDataType is a tuple (multi-B) — TODO confirm GetNumABTensors
79 static constexpr index_t NumDTensor = DsDataType::Size(); // D tensor count comes from the DsDataType tuple
80
81 static_assert(NumDTensor == DsLayout::Size(), "wrong! Inconsistent NumDTensor"); // each D tensor needs a matching entry in DsLayout
82#ifdef CK_CODE_GEN_RTC
85#else
86 // If A/BDataType is a tuple (multi-A/B), the caller must pass a std::array of pointers.
87 using APointers =
89 using BPointers =
91#endif
92
93#ifndef CK_CODE_GEN_RTC // the virtual factory interface below is compiled out when CK_CODE_GEN_RTC is defined
94
121 virtual std::unique_ptr<BaseArgument> MakeArgumentPointer( // pure virtual; takes index_t (int32_t) lengths/strides
122 APointers p_a,
123 BPointers p_b,
124 const std::array<const void*, NumDTensor>& p_ds, // one pointer per D tensor
125 void* p_e,
126 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
127 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
128 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
129 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
130 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
131 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
132 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
133 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
134 const std::array<index_t, NDimSpatial>& conv_filter_strides, // per-spatial-dim convolution parameters (NDimSpatial entries each)
135 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
136 const std::array<index_t, NDimSpatial>& input_left_pads,
137 const std::array<index_t, NDimSpatial>& input_right_pads,
138 const AElementwiseOperation& a_element_op,
139 const BElementwiseOperation& b_element_op,
140 const CDEElementwiseOperation& cde_element_op) = 0;
141
142 virtual std::unique_ptr<BaseArgument> // pure-virtual MakeArgumentPointer overload taking long_index_t lengths/strides
144 BPointers p_b,
145 const std::array<const void*, NumDTensor>& p_ds,
146 void* p_e,
147 const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
148 const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
149 const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
150 const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
151 const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
152 ds_g_n_k_wos_lengths,
153 const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
154 ds_g_n_k_wos_strides,
155 const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
156 const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
157 const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
158 const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
159 const std::array<long_index_t, NDimSpatial>& input_left_pads,
160 const std::array<long_index_t, NDimSpatial>& input_right_pads,
161 const AElementwiseOperation& a_element_op,
162 const BElementwiseOperation& b_element_op,
163 const CDEElementwiseOperation& cde_element_op) = 0;
164
165 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; // pure-virtual factory for the invoker object
166#endif
167};
168
169} // namespace device
170} // namespace tensor_operation
171} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
decltype(std::declval< T & >().IsTuple()) is_tuple
Definition device_grouped_conv_fwd_multiple_abd.hpp:23
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
typename conditional< predicate, X, Y >::type conditional_t
Definition utility/functional.hpp:115
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition is_detected.hpp:34
integral_constant< index_t, N > Number
Definition number.hpp:12
Grouped Convolution Forward.
Definition device_grouped_conv_fwd_multiple_abd.hpp:73
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(APointers p_a, BPointers p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< long_index_t, NDimSpatial > &conv_filter_strides, const std::array< long_index_t, NDimSpatial > &conv_filter_dilations, const std::array< long_index_t, NDimSpatial > &input_left_pads, const std::array< long_index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op)=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(APointers p_a, BPointers p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op)=0
Make argument pointer for grouped convolution forward.
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0