BatchedContractionHostArgs< NumDTensor > Struct Template Reference
#include <batched_contraction_kernel.hpp>
Public Member Functions
| CK_TILE_HOST | BatchedContractionHostArgs (const void *a_ptr_, const void *b_ptr_, const std::array< const void *, NumDTensor > &ds_ptr_, void *e_ptr_, ck_tile::index_t k_batch_, const std::vector< ck_tile::index_t > &A_dims_, const std::vector< ck_tile::index_t > &B_dims_, const std::array< std::vector< ck_tile::index_t >, NumDTensor > &Ds_dims_, const std::vector< ck_tile::index_t > &E_dims_, const std::vector< ck_tile::index_t > &A_strides_, const std::vector< ck_tile::index_t > &B_strides_, const std::array< std::vector< ck_tile::index_t >, NumDTensor > &Ds_strides_, const std::vector< ck_tile::index_t > &E_strides_) |
| Constructor for batched contraction host arguments. | |
Public Attributes
| const void * | a_ptr |
| Pointer to input tensor A. | |
| const void * | b_ptr |
| Pointer to input tensor B. | |
| std::array< const void *, NumDTensor > | ds_ptr |
| Array of pointers to auxiliary input tensors D. | |
| void * | e_ptr |
| Pointer to output tensor E. | |
| ck_tile::index_t | k_batch |
| Number of k-splits for split-K batching. | |
| const std::vector< ck_tile::index_t > | A_dims |
| Dimension vector for tensor A: [G0, G1, ..., M0, M1, ..., K0, K1, ...]. | |
| const std::vector< ck_tile::index_t > | B_dims |
| Dimension vector for tensor B: [G0, G1, ..., N0, N1, ..., K0, K1, ...]. | |
| const std::array< std::vector< ck_tile::index_t >, NumDTensor > | Ds_dims |
| Dimension vectors for D tensors: [G0, G1, ..., M0, M1, ..., N0, N1, ...]. | |
| const std::vector< ck_tile::index_t > | E_dims |
| Dimension vector for tensor E: [G0, G1, ..., M0, M1, ..., N0, N1, ...]. | |
| const std::vector< ck_tile::index_t > | A_strides |
| Stride vector for tensor A: [G0, G1, ..., M0, M1, ..., K0, K1, ...]. | |
| const std::vector< ck_tile::index_t > | B_strides |
| Stride vector for tensor B: [G0, G1, ..., N0, N1, ..., K0, K1, ...]. | |
| const std::array< std::vector< ck_tile::index_t >, NumDTensor > | Ds_strides |
| Stride vectors for D tensors: [G0, G1, ..., M0, M1, ..., N0, N1, ...]. | |
| const std::vector< ck_tile::index_t > | E_strides |
| Stride vector for tensor E: [G0, G1, ..., M0, M1, ..., N0, N1, ...]. | |
Constructor & Destructor Documentation
◆ BatchedContractionHostArgs()
inline
Constructor for batched contraction host arguments.
Parameters
| a_ptr_ | Pointer to input tensor A. |
| b_ptr_ | Pointer to input tensor B. |
| ds_ptr_ | Array of pointers to auxiliary input tensors D. |
| e_ptr_ | Pointer to output tensor E. |
| k_batch_ | Number of k-splits for split-K batching. |
| A_dims_ | Dimension vector for tensor A: [G0, G1, ..., M0, M1, ..., K0, K1, ...]. |
| B_dims_ | Dimension vector for tensor B: [G0, G1, ..., N0, N1, ..., K0, K1, ...]. |
| Ds_dims_ | Dimension vectors for D tensors: [G0, G1, ..., M0, M1, ..., N0, N1, ...]. |
| E_dims_ | Dimension vector for tensor E: [G0, G1, ..., M0, M1, ..., N0, N1, ...]. |
| A_strides_ | Stride vector for tensor A: [G0, G1, ..., M0, M1, ..., K0, K1, ...]. |
| B_strides_ | Stride vector for tensor B: [G0, G1, ..., N0, N1, ..., K0, K1, ...]. |
| Ds_strides_ | Stride vectors for D tensors: [G0, G1, ..., M0, M1, ..., N0, N1, ...]. |
| E_strides_ | Stride vector for tensor E: [G0, G1, ..., M0, M1, ..., N0, N1, ...]. |
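The argument order of the constructor is easy to get wrong, so a minimal sketch of a call may help. It does not include the real header: `HostArgsSketch`, `dims_t`, `make_example_args`, and the `index_t` alias are hypothetical stand-ins that mirror the documented fields and parameter order, and `ck_tile::index_t` is assumed to be a 32-bit signed integer. The example further assumes one dimension per group (G, M, N, K), a single D tensor, and packed row-major tensors with strides counted in elements.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Stand-in for ck_tile::index_t (assumed to be a 32-bit signed integer).
using index_t = std::int32_t;
using dims_t  = std::vector<index_t>;

// Hypothetical mirror of the documented struct, so the sketch compiles without ck_tile.
template <std::size_t NumDTensor>
struct HostArgsSketch
{
    HostArgsSketch(const void* a_ptr_,
                   const void* b_ptr_,
                   const std::array<const void*, NumDTensor>& ds_ptr_,
                   void* e_ptr_,
                   index_t k_batch_,
                   const dims_t& A_dims_,
                   const dims_t& B_dims_,
                   const std::array<dims_t, NumDTensor>& Ds_dims_,
                   const dims_t& E_dims_,
                   const dims_t& A_strides_,
                   const dims_t& B_strides_,
                   const std::array<dims_t, NumDTensor>& Ds_strides_,
                   const dims_t& E_strides_)
        : a_ptr(a_ptr_), b_ptr(b_ptr_), ds_ptr(ds_ptr_), e_ptr(e_ptr_), k_batch(k_batch_),
          A_dims(A_dims_), B_dims(B_dims_), Ds_dims(Ds_dims_), E_dims(E_dims_),
          A_strides(A_strides_), B_strides(B_strides_), Ds_strides(Ds_strides_),
          E_strides(E_strides_)
    {
    }

    const void* a_ptr;
    const void* b_ptr;
    std::array<const void*, NumDTensor> ds_ptr;
    void* e_ptr;
    index_t k_batch;
    const dims_t A_dims;
    const dims_t B_dims;
    const std::array<dims_t, NumDTensor> Ds_dims;
    const dims_t E_dims;
    const dims_t A_strides;
    const dims_t B_strides;
    const std::array<dims_t, NumDTensor> Ds_strides;
    const dims_t E_strides;
};

// Fills the arguments for G=2, M=4, N=3, K=5 with one D tensor shaped like E.
inline HostArgsSketch<1> make_example_args(const float* a, const float* b, const float* d, float* e)
{
    assert(a && b && d && e);
    const index_t G = 2, M = 4, N = 3, K = 5;
    const std::array<const void*, 1> ds_ptr = {d};
    const std::array<dims_t, 1> Ds_dims     = {dims_t{G, M, N}};
    const std::array<dims_t, 1> Ds_strides  = {dims_t{M * N, N, 1}};
    return HostArgsSketch<1>(a, b, ds_ptr, e, /*k_batch_=*/1,
                             {G, M, K}, {G, N, K}, Ds_dims, {G, M, N},
                             {M * K, K, 1}, {N * K, K, 1}, Ds_strides, {M * N, N, 1});
}
```

With the real type, the same call shape would presumably apply after replacing `HostArgsSketch<1>` with `BatchedContractionHostArgs<1>`; check the header for the exact template parameter type before relying on this.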
Member Data Documentation
◆ A_dims
| const std::vector<ck_tile::index_t> BatchedContractionHostArgs< NumDTensor >::A_dims |
Dimension vector for tensor A: [G0, G1, ..., M0, M1, ..., K0, K1, ...].
◆ a_ptr
| const void* BatchedContractionHostArgs< NumDTensor >::a_ptr |
Pointer to input tensor A.
◆ A_strides
| const std::vector<ck_tile::index_t> BatchedContractionHostArgs< NumDTensor >::A_strides |
Stride vector for tensor A: [G0, G1, ..., M0, M1, ..., K0, K1, ...].
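As an illustration of how a stride vector pairs with a dimension vector, the sketch below derives strides for a densely packed tensor whose last dimension is contiguous, and then uses them to compute a linear offset. The helpers `packed_strides` and `linear_offset` are hypothetical, and the sketch assumes that strides are expressed in elements (not bytes) and listed in the same order as the dimensions; the reference above does not state either point explicitly.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Stand-in for ck_tile::index_t (assumed to be a 32-bit signed integer).
using index_t = std::int32_t;

// Strides for a densely packed tensor whose last dimension is contiguous.
inline std::vector<index_t> packed_strides(const std::vector<index_t>& dims)
{
    std::vector<index_t> strides(dims.size(), 1);
    for(std::size_t i = dims.size(); i > 1; --i)
    {
        assert(dims[i - 1] > 0);
        // Each stride is the product of all dimension lengths to its right.
        strides[i - 2] = strides[i - 1] * dims[i - 1];
    }
    return strides;
}

// Linear element offset of a multi-index under the given strides.
inline index_t linear_offset(const std::vector<index_t>& idx,
                             const std::vector<index_t>& strides)
{
    assert(idx.size() == strides.size());
    index_t offset = 0;
    for(std::size_t i = 0; i < idx.size(); ++i)
        offset += idx[i] * strides[i];
    return offset;
}
```

For `A_dims = {2, 4, 5}` (one G, one M, one K dimension), `packed_strides` yields `{20, 5, 1}`, and the element at index `(1, 2, 3)` sits at offset 33. Non-packed layouts (padding, transposed storage) need strides supplied by hand.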
◆ B_dims
| const std::vector<ck_tile::index_t> BatchedContractionHostArgs< NumDTensor >::B_dims |
Dimension vector for tensor B: [G0, G1, ..., N0, N1, ..., K0, K1, ...].
◆ b_ptr
| const void* BatchedContractionHostArgs< NumDTensor >::b_ptr |
Pointer to input tensor B.
◆ B_strides
| const std::vector<ck_tile::index_t> BatchedContractionHostArgs< NumDTensor >::B_strides |
Stride vector for tensor B: [G0, G1, ..., N0, N1, ..., K0, K1, ...].
◆ Ds_dims
| const std::array<std::vector<ck_tile::index_t>, NumDTensor> BatchedContractionHostArgs< NumDTensor >::Ds_dims |
Dimension vectors for D tensors: [G0, G1, ..., M0, M1, ..., N0, N1, ...].
◆ ds_ptr
| std::array<const void*, NumDTensor> BatchedContractionHostArgs< NumDTensor >::ds_ptr |
Array of pointers to auxiliary input tensors D.
◆ Ds_strides
| const std::array<std::vector<ck_tile::index_t>, NumDTensor> BatchedContractionHostArgs< NumDTensor >::Ds_strides |
Stride vectors for D tensors: [G0, G1, ..., M0, M1, ..., N0, N1, ...].
◆ E_dims
| const std::vector<ck_tile::index_t> BatchedContractionHostArgs< NumDTensor >::E_dims |
Dimension vector for tensor E: [G0, G1, ..., M0, M1, ..., N0, N1, ...].
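The documented layouts imply that E takes its G and M dimensions from A and its N dimensions from B, while the K dimensions are contracted away. A minimal sketch of that relationship follows; `expected_e_dims` and the four count parameters are hypothetical and not part of the documented struct, which stores only the vectors themselves.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Stand-in for ck_tile::index_t (assumed to be a 32-bit signed integer).
using index_t = std::int32_t;

// Hypothetical helper: builds [G..., M..., N...] from A = [G..., M..., K...]
// and B = [G..., N..., K...], given how many dimensions belong to each group.
inline std::vector<index_t> expected_e_dims(const std::vector<index_t>& A_dims,
                                            const std::vector<index_t>& B_dims,
                                            std::size_t num_g,
                                            std::size_t num_m,
                                            std::size_t num_n,
                                            std::size_t num_k)
{
    assert(A_dims.size() == num_g + num_m + num_k);
    assert(B_dims.size() == num_g + num_n + num_k);
    // Batch (G) and contraction (K) lengths must agree between A and B.
    for(std::size_t i = 0; i < num_g; ++i)
        assert(A_dims[i] == B_dims[i]);
    for(std::size_t i = 0; i < num_k; ++i)
        assert(A_dims[num_g + num_m + i] == B_dims[num_g + num_n + i]);

    std::vector<index_t> e_dims;
    e_dims.reserve(num_g + num_m + num_n);
    for(std::size_t i = 0; i < num_g + num_m; ++i)
        e_dims.push_back(A_dims[i]); // G..., M... come from A
    for(std::size_t i = 0; i < num_n; ++i)
        e_dims.push_back(B_dims[num_g + i]); // N... come from B
    return e_dims;
}
```

For example, `A_dims = {2, 4, 5}` and `B_dims = {2, 3, 5}` with one dimension per group give `{2, 4, 3}`. A check of this kind can catch mismatched shapes on the host before the arguments are handed to the kernel.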
◆ e_ptr
| void* BatchedContractionHostArgs< NumDTensor >::e_ptr |
Pointer to output tensor E.
◆ E_strides
| const std::vector<ck_tile::index_t> BatchedContractionHostArgs< NumDTensor >::E_strides |
Stride vector for tensor E: [G0, G1, ..., M0, M1, ..., N0, N1, ...].
◆ k_batch
| ck_tile::index_t BatchedContractionHostArgs< NumDTensor >::k_batch |
Number of k-splits for split-K batching.
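To illustrate the general idea behind split-K, the sketch below partitions the contraction length K into `k_batch` chunks, computes a partial sum per chunk, and adds the partials together. This is a generic host-side illustration only: the helpers `split_k_ranges` and `split_k_dot` are hypothetical, and the way the ck_tile kernel actually partitions K and combines partial results is not described in this reference.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Stand-in for ck_tile::index_t (assumed to be a 32-bit signed integer).
using index_t = std::int32_t;

// Half-open K ranges [begin, end), one per split; the chunk size is rounded up,
// so trailing splits may be shorter or empty.
inline std::vector<std::pair<index_t, index_t>> split_k_ranges(index_t K, index_t k_batch)
{
    assert(K >= 0 && k_batch >= 1);
    const index_t chunk = (K + k_batch - 1) / k_batch;
    std::vector<std::pair<index_t, index_t>> ranges;
    for(index_t s = 0; s < k_batch; ++s)
    {
        const index_t begin = std::min(s * chunk, K);
        const index_t end   = std::min(begin + chunk, K);
        ranges.emplace_back(begin, end);
    }
    return ranges;
}

// Dot product computed as the sum of per-split partial sums.
inline float split_k_dot(const std::vector<float>& a, const std::vector<float>& b, index_t k_batch)
{
    assert(a.size() == b.size());
    float total = 0.0f;
    for(const auto& range : split_k_ranges(static_cast<index_t>(a.size()), k_batch))
    {
        float partial = 0.0f;
        for(index_t k = range.first; k < range.second; ++k)
            partial += a[static_cast<std::size_t>(k)] * b[static_cast<std::size_t>(k)];
        total += partial; // partial results from each split are combined
    }
    return total;
}
```

With `K = 10` and `k_batch = 4`, the ranges are `[0, 3)`, `[3, 6)`, `[6, 9)`, `[9, 10)`. Setting `k_batch = 1` reduces to a single pass over K.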
The documentation for this struct was generated from the following file: