ck::wmma_type< WmmaInstr::wmma_i32_16x16x16_iu8, WaveSize, typename std::enable_if_t< WaveSize==32||WaveSize==64 > > Struct Template Reference
ck::wmma_type< WmmaInstr::wmma_i32_16x16x16_iu8, WaveSize, typename std::enable_if_t< WaveSize==32||WaveSize==64 > > Struct Template Reference
#include <wmma_gemm.hpp>
Public Member Functions | |
| template<index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC, bool neg_a = true, bool neg_b = true, bool clamp = false> | |
| __device__ void | run (const FloatA &a, const FloatB &b, FloatC &reg_c) const |
Static Public Attributes | |
| static constexpr index_t | m_per_wmma = 16 |
| static constexpr index_t | n_per_wmma = 16 |
| static constexpr index_t | k_per_wmma = 16 |
| static constexpr index_t | src_a_data_size = 2 |
| static constexpr index_t | src_b_data_size = 2 |
| static constexpr index_t | acc_data_size = 4 |
| static constexpr index_t | acc_pack_number = 1 |
| static constexpr index_t | num_thread_per_subgroups = n_per_wmma |
| static constexpr index_t | wave_size = Number<WaveSize>{} |
| static constexpr index_t | num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4 |
| static constexpr index_t | num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4 |
| static constexpr index_t | num_acc_vgprs_per_wave |
| static constexpr index_t | num_subgroups = wave_size / num_thread_per_subgroups |
Member Function Documentation
◆ run()
template<index_t WaveSize>
template<index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC, bool neg_a = true, bool neg_b = true, bool clamp = false>
|
inline |
Member Data Documentation
◆ acc_data_size
template<index_t WaveSize>
|
static constexpr |
◆ acc_pack_number
template<index_t WaveSize>
|
static constexpr |
◆ k_per_wmma
template<index_t WaveSize>
|
static constexpr |
◆ m_per_wmma
template<index_t WaveSize>
|
static constexpr |
◆ n_per_wmma
template<index_t WaveSize>
|
static constexpr |
◆ num_acc_vgprs_per_wave
template<index_t WaveSize>
|
static constexpr |
Initial value:
= m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4
static constexpr index_t acc_data_size
Definition wmma_gemm.hpp:100
static constexpr index_t m_per_wmma
Definition wmma_gemm.hpp:95
static constexpr index_t acc_pack_number
Definition wmma_gemm.hpp:101
static constexpr index_t n_per_wmma
Definition wmma_gemm.hpp:96
static constexpr index_t wave_size
Definition wmma_gemm.hpp:106
◆ num_src_a_vgprs_per_wave
template<index_t WaveSize>
|
static constexpr |
◆ num_src_b_vgprs_per_wave
template<index_t WaveSize>
|
static constexpr |
◆ num_subgroups
template<index_t WaveSize>
|
static constexpr |
◆ num_thread_per_subgroups
template<index_t WaveSize>
|
static constexpr |
◆ src_a_data_size
template<index_t WaveSize>
|
static constexpr |
◆ src_b_data_size
template<index_t WaveSize>
|
static constexpr |
◆ wave_size
template<index_t WaveSize>
|
static constexpr |
The documentation for this struct was generated from the following file: