device_sparse_embeddings_forward_layernorm.hpp Source File
device_sparse_embeddings_forward_layernorm.hpp
Go to the documentation of this file.
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
Definition convolution_backward_data_specialization.hpp:8
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__global__ void kernel_sparse_embeddings_forward_layernorm(OutType *p_out, const ck::Array< EmbType *, NumEmbeddings > p_embs, const ck::Array< IndexType *, NumEmbeddings > p_indexes, const GammaDataType *p_gamma, const BetaDataType *p_beta, const OutGridDesc out_grid_desc, const AccDataType epsilon, const EmbElementwiseOperation emb_elementwise_op)
Definition gridwise_sparse_embeddings_forward_layernorm.hpp:26
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
Definition ck/stream_config.hpp:10
Definition utility/array.hpp:14
Definition gridwise_sparse_embeddings_forward_layernorm.hpp:57
Definition device_base.hpp:197
BaseArgument()=default
BaseInvoker()=default
BaseOperator()=default
Definition device_sparse_embeddings_forward_layernorm.hpp:48
const GammaDataType * p_gamma_
Definition device_sparse_embeddings_forward_layernorm.hpp:74
ck::index_t IndexLength_
Definition device_sparse_embeddings_forward_layernorm.hpp:77
ck::Array< EmbType *, NumEmbeddings > p_embs_
Definition device_sparse_embeddings_forward_layernorm.hpp:72
size_t grid_size_
Definition device_sparse_embeddings_forward_layernorm.hpp:81
OutType * p_out_
Definition device_sparse_embeddings_forward_layernorm.hpp:71
const BetaDataType * p_beta_
Definition device_sparse_embeddings_forward_layernorm.hpp:75
ck::index_t EmbeddingDim_
Definition device_sparse_embeddings_forward_layernorm.hpp:76
Argument(OutType *p_out, const ck::Array< EmbType *, NumEmbeddings > &p_embs, const ck::Array< IndexType *, NumEmbeddings > &p_indexs, const GammaDataType *p_gamma, const BetaDataType *p_beta, const ck::index_t EmbeddingDim, const ck::index_t IndexLength, const AccDataType epsilon, const EmbElementwiseOperation emb_elementwise_op)
Definition device_sparse_embeddings_forward_layernorm.hpp:49
AccDataType epsilon_
Definition device_sparse_embeddings_forward_layernorm.hpp:78
ck::Array< IndexType *, NumEmbeddings > p_indexs_
Definition device_sparse_embeddings_forward_layernorm.hpp:73
EmbElementwiseOperation emb_elementwise_op_
Definition device_sparse_embeddings_forward_layernorm.hpp:79
Definition device_sparse_embeddings_forward_layernorm.hpp:125
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_sparse_embeddings_forward_layernorm.hpp:126
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_sparse_embeddings_forward_layernorm.hpp:158
Definition device_sparse_embeddings_forward_layernorm.hpp:41
std::unique_ptr< BaseArgument > MakeArgumentPointer(void *p_out, const ck::Array< EmbType *, NumEmbeddings > &p_embs, const ck::Array< IndexType *, NumEmbeddings > &p_indexs, const void *p_gamma, const void *p_beta, ck::index_t EmbeddingDim, ck::index_t IndexLength, const AccDataType epsilon, const EmbElementwiseOperation emb_elementwise_op)
Definition device_sparse_embeddings_forward_layernorm.hpp:85
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()
Definition device_sparse_embeddings_forward_layernorm.hpp:175
static bool IsSupportedArgument(const Argument *p_arg)
Definition device_sparse_embeddings_forward_layernorm.hpp:165
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_sparse_embeddings_forward_layernorm.hpp:170
std::string GetTypeString() const override
Definition device_sparse_embeddings_forward_layernorm.hpp:180
GridwiseSparseEmbeddingsForwardLayernorm< EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, decltype(MakeOutputDescriptor(1, 1)), EmbElementwiseOperation, BlockSize, DimClusterSize, RowClusterSize, DimPerBlock, RowPerBlock, DimThreadSize, RowVectorSize, NumEmbeddings > GridwiseSparseEmbedding
Definition device_sparse_embeddings_forward_layernorm.hpp:106
static auto MakeOutputDescriptor(const index_t index_length, const index_t rows)
Definition device_sparse_embeddings_forward_layernorm.hpp:42