device_batched_gemm_e_permute_xdl.hpp Source File#
device_batched_gemm_e_permute_xdl.hpp
Go to the documentation of this file.
28 * strided batched, but we can easily extend to other layouts. The returned offset can be either \p
36 * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
39 * impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for
41 * DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of
44 * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes.
45 * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
__global__ void kernel_batched_gemm_e_permute_xdl(const ABDataType *__restrict__ p_a_grid, const ABDataType *__restrict__ p_b_grid, EDataType *__restrict__ p_e_grid, const index_t batch_count, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2ETileMap block_2_etile_map)
Definition device_batched_gemm_e_permute_xdl.hpp:65
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ NPadding
Definition gemm_specialization.hpp:15
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
@ NKPadding
Definition gemm_specialization.hpp:19
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_multiple_d_xdl_cshuffle.hpp:78
Definition utility/sequence.hpp:43
Definition utility/tuple.hpp:186
Definition utility/integral_constant.hpp:20
Definition device_base.hpp:197
BaseArgument()=default
BaseInvoker()=default
Definition device_batched_gemm_e_permute.hpp:12
Definition device_batched_gemm_e_permute.hpp:27
Definition device_batched_gemm_e_permute_xdl.hpp:412
void Print() const
Definition device_batched_gemm_e_permute_xdl.hpp:460
EDataType * p_e_grid_
Definition device_batched_gemm_e_permute_xdl.hpp:471
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition device_batched_gemm_e_permute_xdl.hpp:482
BGridDesc_N_K b_grid_desc_n_k_
Definition device_batched_gemm_e_permute_xdl.hpp:478
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition device_batched_gemm_e_permute_xdl.hpp:483
CDEElementwiseOperation cde_element_op_
Definition device_batched_gemm_e_permute_xdl.hpp:496
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_
Definition device_batched_gemm_e_permute_xdl.hpp:488
EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_
Definition device_batched_gemm_e_permute_xdl.hpp:485
EGridDesc_M_N e_grid_desc_m_n_
Definition device_batched_gemm_e_permute_xdl.hpp:479
Argument(const ADataType *p_a_grid, const BDataType *p_b_grid, EDataType *p_e_grid, index_t M, index_t N, index_t K, index_t stride_A, index_t stride_B, index_t batch_stride_A, index_t batch_stride_B, BatchedGemmEPermuteDesc batched_gemm_e_permute_desc, index_t BatchCount, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_batched_gemm_e_permute_xdl.hpp:413
const ADataType * p_a_grid_
Definition device_batched_gemm_e_permute_xdl.hpp:469
index_t BatchCount_
Definition device_batched_gemm_e_permute_xdl.hpp:474
const BDataType * p_b_grid_
Definition device_batched_gemm_e_permute_xdl.hpp:470
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock
Definition device_batched_gemm_e_permute_xdl.hpp:484
AGridDesc_M_K a_grid_desc_m_k_
Definition device_batched_gemm_e_permute_xdl.hpp:477
Block2ETileMap block_2_etile_map_
Definition device_batched_gemm_e_permute_xdl.hpp:491
BElementwiseOperation b_element_op_
Definition device_batched_gemm_e_permute_xdl.hpp:495
AElementwiseOperation a_element_op_
Definition device_batched_gemm_e_permute_xdl.hpp:494
Definition device_batched_gemm_e_permute_xdl.hpp:310
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
Definition device_batched_gemm_e_permute_xdl.hpp:320
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
Definition device_batched_gemm_e_permute_xdl.hpp:330
ComputePtrOffsetOfStridedBatch(index_t Batchstride_A, index_t Batchstride_B, EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n)
Definition device_batched_gemm_e_permute_xdl.hpp:311
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
Definition device_batched_gemm_e_permute_xdl.hpp:325
Definition device_batched_gemm_e_permute_xdl.hpp:501
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_batched_gemm_e_permute_xdl.hpp:505
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_batched_gemm_e_permute_xdl.hpp:574
DeviceOp::Argument Argument
Definition device_batched_gemm_e_permute_xdl.hpp:502
Definition device_batched_gemm_e_permute_xdl.hpp:178
GridwiseGemmMultipleD_xdl_cshuffle< ADataType, BDataType, ComputeDataType, AccDataType, CShuffleDataType, ck::Tuple<>, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, AGridDesc_M_K, BGridDesc_N_K, Tuple<>, EGridDesc_M_N, NumPrefetch, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched > GridwiseGemmBase
Definition device_batched_gemm_e_permute_xdl.hpp:348
decltype(MakeEGridDescriptor_G0_G1_M_N(1, 1, 1, 1, 1, 1, 1, 1)) EGridDesc_G0_G1_M_N
Definition device_batched_gemm_e_permute_xdl.hpp:307
static auto MakeInvoker()
Definition device_batched_gemm_e_permute_xdl.hpp:640
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_e, index_t M, index_t N, index_t K, index_t stride_A, index_t stride_B, index_t batch_stride_A, index_t batch_stride_B, BatchedGemmEPermuteDesc batched_gemm_e_permute_desc, index_t BatchCount, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) override
Definition device_batched_gemm_e_permute_xdl.hpp:644
static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
Definition device_batched_gemm_e_permute_xdl.hpp:210
std::string GetTypeString() const override
Definition device_batched_gemm_e_permute_xdl.hpp:684
static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t stride_M, index_t stride_N)
Definition device_batched_gemm_e_permute_xdl.hpp:229
static constexpr auto NXdlPerWave32
Definition device_batched_gemm_e_permute_xdl.hpp:183
decltype(MakeAGridDescriptor_M_K(1, 1, 1)) AGridDesc_M_K
Definition device_batched_gemm_e_permute_xdl.hpp:304
static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
Definition device_batched_gemm_e_permute_xdl.hpp:192
static constexpr auto I1
Definition device_batched_gemm_e_permute_xdl.hpp:186
static constexpr auto matrix_padder
Definition device_batched_gemm_e_permute_xdl.hpp:189
DeviceBatchedGemmEPermuteXdl DeviceOp
Definition device_batched_gemm_e_permute_xdl.hpp:179
remove_cvref_t< decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1( AGridDesc_M_K{}))> AGridDesc_AK0_M_AK1
Definition device_batched_gemm_e_permute_xdl.hpp:398
decltype(MakeBGridDescriptor_N_K(1, 1, 1)) BGridDesc_N_K
Definition device_batched_gemm_e_permute_xdl.hpp:305
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_batched_gemm_e_permute_xdl.hpp:182
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, EDataType *p_e, index_t M, index_t N, index_t K, index_t stride_A, index_t stride_B, index_t batch_stride_A, index_t batch_stride_B, BatchedGemmEPermuteDesc batched_gemm_e_permute_desc, index_t BatchCount, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_batched_gemm_e_permute_xdl.hpp:607
static constexpr auto I0
Definition device_batched_gemm_e_permute_xdl.hpp:185
static constexpr bool IsValidCompilationParameter()
Definition device_batched_gemm_e_permute_xdl.hpp:581
remove_cvref_t< decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1( BGridDesc_N_K{}))> BGridDesc_BK0_N_BK1
Definition device_batched_gemm_e_permute_xdl.hpp:401
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_batched_gemm_e_permute_xdl.hpp:602
decltype(MakeEGridDescriptor_M_N(1, 1, 1, 1)) EGridDesc_M_N
Definition device_batched_gemm_e_permute_xdl.hpp:306
static constexpr auto I2
Definition device_batched_gemm_e_permute_xdl.hpp:187
typename GridwiseGemm::DefaultBlock2ETileMap Block2ETileMap
Definition device_batched_gemm_e_permute_xdl.hpp:408
decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( EGridDesc_M_N{})) EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_batched_gemm_e_permute_xdl.hpp:405
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_batched_gemm_e_permute_xdl.hpp:678
static bool IsSupportedArgument(const Argument &arg)
Definition device_batched_gemm_e_permute_xdl.hpp:587
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_batched_gemm_e_permute_xdl.hpp:395
static auto MakeEGridDescriptor_G0_G1_M_N(index_t G0, index_t G1, index_t MRaw, index_t NRaw, index_t stride_G0, index_t stride_G1, index_t stride_M, index_t stride_N)
Definition device_batched_gemm_e_permute_xdl.hpp:237
ADataType ComputeDataType
Definition device_batched_gemm_e_permute_xdl.hpp:344
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_batched_gemm_e_permute_xdl.hpp:396
Definition matrix_padder.hpp:180