#include <grouped_flatmm_kernel.hpp>
CK_TILE_HOST GroupedFlatmmHostArgs() = default
CK_TILE_HOST GroupedFlatmmHostArgs(index_t group_count_, index_t* M_, index_t* N_, index_t* K_, const void** a_ptr_, index_t* stride_A_, const void** b_shuffle_ptr_, index_t* stride_B_, const std::array<const void*, NumDTensor>& ds_ptr_, const std::array<index_t, NumDTensor>& stride_Ds_, void** c_ptr_, index_t* stride_C_, index_t k_batch_, ScaleM* scale_m_ = nullptr, ScaleN* scale_n_ = nullptr)
◆ GroupedFlatmmHostArgs() [1/2]
template<class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
◆ GroupedFlatmmHostArgs() [2/2]
template<class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
CK_TILE_HOST ck_tile::GroupedFlatmmHostArgs<ScaleM, ScaleN, NumDTensor>::GroupedFlatmmHostArgs(
    index_t                                    group_count_,
    index_t*                                   M_,
    index_t*                                   N_,
    index_t*                                   K_,
    const void**                               a_ptr_,
    index_t*                                   stride_A_,
    const void**                               b_shuffle_ptr_,
    index_t*                                   stride_B_,
    const std::array<const void*, NumDTensor>& ds_ptr_,
    const std::array<index_t, NumDTensor>&     stride_Ds_,
    void**                                     c_ptr_,
    index_t*                                   stride_C_,
    index_t                                    k_batch_,
    ScaleM*                                    scale_m_ = nullptr,
    ScaleN*                                    scale_n_ = nullptr)
[inline]
◆ [union]
◆ a_ptr
template<class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
◆ b_shuffle_ptr
template<class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
◆ c_ptr
template<class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
◆ ds_ptr
template<class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
◆ e_ptr
template<class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
◆ group_count
template<class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
◆ K
template<class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
◆ k_batch
template<class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
◆ M
template<class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
◆ N
template<class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
◆ scale_m
template<class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
◆ scale_n
template<class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
◆ stride_A
template<class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
◆ stride_B
template<class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
◆ stride_C
template<class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
◆ stride_Ds
template<class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
The documentation for this struct was generated from the following file: grouped_flatmm_kernel.hpp