device_grouped_gemm_multi_abd_fixed_nk.hpp Source File

device_grouped_gemm_multi_abd_fixed_nk.hpp Source File

Composable Kernel: device_grouped_gemm_multi_abd_fixed_nk.hpp Source File
device_grouped_gemm_multi_abd_fixed_nk.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <array>
8
10
11namespace ck {
12namespace tensor_operation {
13namespace device {
14
15template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
17{
18 std::array<const void*, NumATensor> p_as_grid;
19 std::array<const void*, NumBTensor> p_bs_grid;
20 std::array<const void*, NumDTensor> p_ds_grid;
21 void* p_e_grid;
22
23 index_t M;
24 index_t N;
25 index_t K;
26
27 std::array<index_t, NumATensor> StrideAs;
28 std::array<index_t, NumBTensor> StrideBs;
29 std::array<index_t, NumDTensor> StrideDs;
31};
32
33/*
34 * \brief Grouped Gemm Multi ABD Fixed NK
35 *
36 * C = a_op(A, A1...) * b_op(B, B1...)
37 * E = cde_op(C, D0, D1, ...)
38 *
39 * \tparam AsLayout A layouts (tuple).
40 * \tparam BsLayout B layouts (tuple).
41 * \tparam DsLayout Ds layouts (tuple).
42 * \tparam ELayout Output layout.
43 * \tparam AsDataType A data types (tuple).
44 * \tparam BsDataType B data types (tuple).
45 * \tparam DsDataType D data types (tuple).
46 * \tparam EDataType Output data type.
47 * \tparam AElementwiseOperation A elementwise operation.
48 * \tparam BElementwiseOperation B elementwise operation.
49 * \tparam CElementwiseOperation CDE elementwise operation, applied as E = cde_op(C, D0, D1, ...).
50 */
51template <typename AsLayout,
52 typename BsLayout,
53 typename DsLayout,
54 typename ELayout,
55 typename AsDataType,
56 typename BsDataType,
57 typename DsDataType,
58 typename EDataType,
59 typename AElementwiseOperation,
60 typename BElementwiseOperation,
61 typename CElementwiseOperation>
63 BsLayout,
64 DsLayout,
65 ELayout,
66 AsDataType,
67 BsDataType,
68 DsDataType,
69 EDataType,
70 AElementwiseOperation,
71 BElementwiseOperation,
72 CElementwiseOperation>
73{
74 virtual void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const = 0;
75 virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0;
76 virtual void SetKBatch(BaseArgument* p_arg, index_t k_batch) const = 0;
77};
78
79} // namespace device
80} // namespace tensor_operation
81} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
Definition device_base.hpp:197
Definition device_grouped_gemm_multi_abd_fixed_nk.hpp:73
virtual void SetKBatch(BaseArgument *p_arg, index_t k_batch) const =0
virtual size_t GetDeviceKernelArgSize(const BaseArgument *p_arg) const =0
virtual void SetDeviceKernelArgs(BaseArgument *p_arg, const void *kernel_args) const =0
Definition device_grouped_gemm_multi_abd.hpp:56
Definition device_grouped_gemm_multi_abd_fixed_nk.hpp:17
std::array< index_t, NumBTensor > StrideBs
Definition device_grouped_gemm_multi_abd_fixed_nk.hpp:28
index_t M
Definition device_grouped_gemm_multi_abd_fixed_nk.hpp:23
std::array< const void *, NumBTensor > p_bs_grid
Definition device_grouped_gemm_multi_abd_fixed_nk.hpp:19
index_t N
Definition device_grouped_gemm_multi_abd_fixed_nk.hpp:24
std::array< const void *, NumATensor > p_as_grid
Definition device_grouped_gemm_multi_abd_fixed_nk.hpp:18
std::array< index_t, NumATensor > StrideAs
Definition device_grouped_gemm_multi_abd_fixed_nk.hpp:27
std::array< const void *, NumDTensor > p_ds_grid
Definition device_grouped_gemm_multi_abd_fixed_nk.hpp:20
index_t K
Definition device_grouped_gemm_multi_abd_fixed_nk.hpp:25
std::array< index_t, NumDTensor > StrideDs
Definition device_grouped_gemm_multi_abd_fixed_nk.hpp:29
void * p_e_grid
Definition device_grouped_gemm_multi_abd_fixed_nk.hpp:21
index_t StrideE
Definition device_grouped_gemm_multi_abd_fixed_nk.hpp:30