arch.hpp Source File

arch.hpp Source File

Composable Kernel: arch.hpp Source File
arch.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6// Address Space for AMDGCN
7// https://llvm.org/docs/AMDGPUUsage.html#address-space
8
15
// Bit pattern 0xCF7F: the value obtained when the three encodings below
// (CK_TILE_VMCNT, CK_TILE_EXPCNT, CK_TILE_LGKMCNT) are each given their
// largest accepted count and OR-ed together.
16#define CK_TILE_S_CNT_MAX 0b1100'1111'0111'1111
// Encodes a vmcnt value: bits [3:0] stay in place, bits [5:4] are moved to [15:14].
// The immediately-invoked lambda exists only to host the static_assert, so cnt
// must be a constant expression that fits in 6 bits.
17#define CK_TILE_VMCNT(cnt) \
18 ([]() { static_assert(!((cnt) >> 6), "VMCNT only has 6 bits"); }(), \
19 ((cnt) & 0b1111) | (((cnt) & 0b110000) << 10))
// Encodes an expcnt value: checked to fit in 3 bits, then shifted to bits [6:4].
20#define CK_TILE_EXPCNT(cnt) \
21 ([]() { static_assert(!((cnt) >> 3), "EXP only has 3 bits"); }(), ((cnt) << 4))
// Encodes an lgkmcnt value: checked to fit in 4 bits, then shifted to bits [11:8].
22#define CK_TILE_LGKMCNT(cnt) \
23 ([]() { static_assert(!((cnt) >> 4), "LGKM only has 4 bits"); }(), ((cnt) << 8))
24
25namespace ck_tile {
26
// Primary template of safe_underlying_type. Its declaration (listing line 28) is
// elided here; the bool parameter selects one of the two specializations below.
27template <typename, bool>
29
// Flag == true: type is the underlying integer type of T.
30template <typename T>
31struct safe_underlying_type<T, true>
32{
33 using type = std::underlying_type_t<T>;
34};
35
// Flag == false: type is void, so std::underlying_type_t<T> is never instantiated.
36template <typename T>
37struct safe_underlying_type<T, false>
38{
39 using type = void;
40};
41
42template <typename T>
44
// Address-space identifiers, stored as std::uint16_t.
// Only generic = 0 is visible; listing lines 48-52 (the remaining enumerators)
// are elided from this rendering.
45enum struct address_space_enum : std::uint16_t
46{
47 generic = 0,
53};
54
// Memory-operation selector, stored as std::uint16_t.
// Only set = 0 is visible; listing lines 58-60 (the remaining enumerators)
// are elided from this rendering.
55enum struct memory_operation_enum : std::uint16_t
56{
57 set = 0,
61};
62
64{
    // Returns 64 when __GFX9__ is defined or when this is not a HIP device
    // compilation pass (__HIP_DEVICE_COMPILE__ undefined); returns 32 for every
    // other device target.
65#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
66 return 64;
67#else
68 return 32;
69#endif
70}
71
73{
    // Looks up the properties of the currently selected HIP device.
    // Any HIP error is reported as false rather than propagated.
74 hipDeviceProp_t props{};
75 int device;
76 auto status = hipGetDevice(&device);
77 if(status != hipSuccess)
78 {
79 return false;
80 }
81 status = hipGetDeviceProperties(&props, device);
82 if(status != hipSuccess)
83 {
84 return false;
85 }
    // True when the device major version is greater than 9.
    // NOTE(review): presumably those are the wave32 architectures - confirm.
86 return props.major > 9;
87}
88
// Returns gridDim.x (x dimension only) as index_t.
89CK_TILE_DEVICE index_t get_grid_size() { return gridDim.x; }
90
// Returns blockDim.x (x dimension only) as index_t.
91CK_TILE_DEVICE index_t get_block_size() { return blockDim.x; }
92
93// TODO: deprecate these
95
// Returns blockIdx.x * blockDim.x + threadIdx.x; only the x dimension is used.
96CK_TILE_DEVICE index_t get_thread_global_1d_id() { return blockIdx.x * blockDim.x + threadIdx.x; }
97
// Returns blockIdx.x as index_t.
98CK_TILE_DEVICE index_t get_block_1d_id() { return blockIdx.x; }
99
100// Use these instead
// Returns the value of the __lane_id() builtin as index_t.
101CK_TILE_DEVICE index_t get_lane_id() { return __lane_id(); }
102
// Returns threadIdx.x / get_warp_size().
// When ReturnSgpr is true (the default) the value is passed through
// amd_wave_read_first_lane before being returned; otherwise it is returned as is.
// NOTE(review): presumably the first-lane read lets the compiler keep the result
// in a scalar register - confirm. The signature line (listing line 104) is elided.
103template <bool ReturnSgpr = true>
105{
106 const index_t warp_id = threadIdx.x / get_warp_size();
107 if constexpr(ReturnSgpr)
108 {
109 return amd_wave_read_first_lane(warp_id);
110 }
111 else
112 {
113 return warp_id;
114 }
115}
116
// Returns threadIdx.x as index_t.
117CK_TILE_DEVICE index_t get_thread_id() { return threadIdx.x; }
118
// Returns blockIdx.x as index_t.
119CK_TILE_DEVICE index_t get_block_id() { return blockIdx.x; }
120
122{
    // Waits until the load counter has dropped to cnt, then executes a barrier.
    // NOTE(review): the "n" constraint expects an immediate, so cnt presumably has
    // to fold to a compile-time constant at the call site - confirm.
123#ifdef __gfx12__
    // gfx12 path: s_wait_loadcnt followed by the s_barrier_signal / s_barrier_wait pair.
124 asm volatile("s_wait_loadcnt %0 \n"
125 "s_barrier_signal -1 \n"
126 "s_barrier_wait -1"
127 :
128 : "n"(cnt)
129 : "memory");
130#else
    // Other targets: s_waitcnt vmcnt(cnt) followed by s_barrier.
131 asm volatile("s_waitcnt vmcnt(%0) \n"
132 "s_barrier"
133 :
134 : "n"(cnt)
135 : "memory");
136#endif
137}
138
140{ // s_wait_loadcnt_dscnt: mem[13:8], ds[5:0]
    // Wait-counter layout; the struct header (listing line 139) is elided.
    // NOTE(review): presumably WaitcntLayoutGfx12 - confirm.
    // Both counters are 6 bits wide and there is no export counter.
141 CK_TILE_DEVICE static constexpr index_t VM_MASK = 0x3F; // mem
142 CK_TILE_DEVICE static constexpr index_t LGKM_MASK = 0x3F; // ds
143 CK_TILE_DEVICE static constexpr bool HAS_EXP = false;
144
    // pack_vm places the masked count at bits [13:8], pack_lgkm at bits [5:0];
    // pack_exp ignores its argument and contributes no bits.
145 CK_TILE_DEVICE static constexpr index_t pack_vm(index_t c) { return ((c & VM_MASK) << 8); }
146 CK_TILE_DEVICE static constexpr index_t pack_lgkm(index_t c) { return ((c & LGKM_MASK) << 0); }
147 CK_TILE_DEVICE static constexpr index_t pack_exp(index_t) { return 0; }
148};
149
151{ // vm[15:10] (6), lgkm[9:4] (6), exp unused
    // Wait-counter layout; the struct header (listing line 150) is elided.
    // NOTE(review): presumably WaitcntLayoutGfx11 - confirm.
    // Both counters are 6 bits wide; HAS_EXP is false, so no export bits are produced.
152 CK_TILE_DEVICE static constexpr index_t VM_MASK = 0x3F;
153 CK_TILE_DEVICE static constexpr index_t LGKM_MASK = 0x3F;
154 CK_TILE_DEVICE static constexpr bool HAS_EXP = false;
155
    // pack_vm places the masked count at bits [15:10], pack_lgkm at bits [9:4];
    // pack_exp ignores its argument and always returns 0.
156 CK_TILE_DEVICE static constexpr index_t pack_vm(index_t c) { return ((c & VM_MASK) << 10); }
157 CK_TILE_DEVICE static constexpr index_t pack_lgkm(index_t c) { return ((c & LGKM_MASK) << 4); }
158 CK_TILE_DEVICE static constexpr index_t pack_exp(index_t) { return 0; }
159};
160
162{ // FE'DC'BA98'7'654'3210 => VV'UU'LLLL'U'EEE'VVVV
    // Wait-counter layout with a split vm field; the struct header (listing line
    // 161) is elided. NOTE(review): presumably WaitcntLayoutLegacy - confirm.
    // In the bit map above V = vm, L = lgkm, E = exp (U presumably marks unused bits).
163 CK_TILE_DEVICE static constexpr index_t VM_MASK = 0x3F; // split: low4 + hi2
164 CK_TILE_DEVICE static constexpr index_t LGKM_MASK = 0x0F; // [11:8]
165 CK_TILE_DEVICE static constexpr index_t EXP_MASK = 0x07; // [6:4]
166 CK_TILE_DEVICE static constexpr bool HAS_EXP = true;
167
    // Body of pack_vm (its signature, listing line 168, is elided): the low four
    // bits of c stay at [3:0] and bits [5:4] are moved up to [15:14].
169 {
170 c &= VM_MASK;
171 return ((c & 0xF) << 0) | ((c & 0x30) << 10);
172 }
    // pack_lgkm places the masked count at bits [11:8], pack_exp at bits [6:4].
173 CK_TILE_DEVICE static constexpr index_t pack_lgkm(index_t c) { return ((c & LGKM_MASK) << 8); }
174 CK_TILE_DEVICE static constexpr index_t pack_exp(index_t c) { return ((c & EXP_MASK) << 4); }
175};
176
177// Select active layout
// Waitcnt aliases the layout for the target being compiled: the gfx12 layout
// when __gfx12__ is defined, the gfx11 layout when __gfx11__ is defined.
// The alias in the #else branch (listing line 183) is elided from this rendering.
178#if defined(__gfx12__)
179using Waitcnt = WaitcntLayoutGfx12;
180#elif defined(__gfx11__)
181using Waitcnt = WaitcntLayoutGfx11;
182#else
184#endif
185
186//----------------------------------------------
187// Public API: only from_* (constexpr templates)
188//----------------------------------------------
190{
    // Compile-time helpers that turn a counter value into its bit field for the
    // active Waitcnt layout. The struct header (listing line 189) and the three
    // function signatures (listing lines 203, 210, 217) are elided from this rendering.
191 // kMax* exposed for callers; match field widths per-arch
192#if defined(__gfx12__) || defined(__gfx11__)
193 CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0x3F; // 6 bits
194 CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0x3F; // 6 bits
195 CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0x0; // none
196#else
197 CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0x3F; // 6 bits (split)
198 CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0x0F; // 4 bits
199 CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0x07; // 3 bits
200#endif
201
    // from_vmcnt: rejects at compile time any cnt with bits outside
    // Waitcnt::VM_MASK, then returns Waitcnt::pack_vm(cnt).
202 template <index_t cnt>
204 {
205 static_assert((cnt & ~Waitcnt::VM_MASK) == 0, "vmcnt out of range");
206 return Waitcnt::pack_vm(cnt);
207 }
208
    // from_lgkmcnt: same check against Waitcnt::LGKM_MASK, then Waitcnt::pack_lgkm(cnt).
209 template <index_t cnt>
211 {
212 static_assert((cnt & ~Waitcnt::LGKM_MASK) == 0, "lgkmcnt out of range");
213 return Waitcnt::pack_lgkm(cnt);
214 }
215
    // from_expcnt: on layouts with HAS_EXP the count is range-checked and packed;
    // on layouts without it only cnt == 0 is accepted and the result is 0.
    // The preprocessor guard inside the first branch keeps EXP_MASK from being
    // named when compiling for gfx11 / gfx12, whose layouts do not declare it.
216 template <index_t cnt>
218 {
219 if constexpr(Waitcnt::HAS_EXP)
220 {
221 // EXP_MASK only exists on legacy
222#if !defined(__gfx12__) && !defined(__gfx11__)
223 static_assert((cnt & ~Waitcnt::EXP_MASK) == 0, "expcnt out of range");
224 return Waitcnt::pack_exp(cnt);
225#else
226 (void)cnt;
227 return 0;
228#endif
229 }
230 else
231 {
232 static_assert(cnt == 0, "expcnt unsupported on this arch");
233 return 0;
234 }
235 }
236};
237
// Issues a wait on the hardware counters using a mask built at compile time from
// the template arguments. vmcnt defaults to waitcnt_arg::kMaxVmCnt.
// The remaining template parameters, the signature and the tail of each OR
// expression (listing lines 239-241, 246-247, 252-253) are elided from this rendering.
238template <index_t vmcnt = waitcnt_arg::kMaxVmCnt,
242{
243#if defined(__gfx12__)
244 // GFX12 does not use __builtin_amdgcn_s_waitcnt
    // The mask is passed to s_wait_loadcnt_dscnt as an immediate ("n" constraint).
245 constexpr index_t wait_mask = waitcnt_arg::from_vmcnt<vmcnt>() |
248
249 asm volatile("s_wait_loadcnt_dscnt %0" : : "n"(wait_mask) : "memory");
250#else
    // Other targets hand the packed mask to the s_waitcnt builtin.
251 __builtin_amdgcn_s_waitcnt(waitcnt_arg::from_vmcnt<vmcnt>() |
254#endif
255}
256
// Waits on the hardware counters and then executes a workgroup barrier.
// vmcnt defaults to waitcnt_arg::kMaxVmCnt. The remaining template parameters,
// the signature, the tail of the OR expression and the statement on listing
// line 276 are elided from this rendering.
257template <index_t vmcnt = waitcnt_arg::kMaxVmCnt,
261{
262#if defined(__gfx12__)
263 // GFX12 optimization: Manual barrier implementation avoids performance penalty
264 // from __builtin_amdgcn_s_barrier which inserts extra s_wait_loadcnt_dscnt 0x0
265 constexpr index_t wait_mask = waitcnt_arg::from_vmcnt<vmcnt>() |
268
    // Wait and barrier are emitted as one asm statement: the wait with the
    // compile-time mask, then the s_barrier_signal / s_barrier_wait pair.
269 asm volatile("s_wait_loadcnt_dscnt %0\n"
270 "s_barrier_signal -1\n"
271 "s_barrier_wait -1"
272 :
273 : "n"(wait_mask)
274 : "memory");
275#else
    // Other targets use the barrier builtin.
277 __builtin_amdgcn_s_barrier();
278#endif
279}
280
281template <index_t lgkmcnt = 0>
286
287template <index_t vmcnt = 0>
292
294{
    // Emits an s_nop instruction with cnt as its immediate operand.
    // The #if 1 makes the inline-asm branch the only one compiled; the #else
    // branch is dead code and calls __builtin_amdgcn_sched_barrier instead.
    // NOTE(review): the "n" constraint presumably requires cnt to fold to a
    // compile-time constant at the call site - confirm.
295#if 1
296 asm volatile("s_nop %0" : : "n"(cnt) :);
297#else
298 __builtin_amdgcn_sched_barrier(cnt);
299#endif
300}
301
// Expands to an address_space attribute whose argument is the numeric value of
// address_space_enum::constant, converted through safe_underlying_type_t.
302#define CK_CONSTANT_ADDRESS_SPACE \
303 __attribute__((address_space( \
304 static_cast<safe_underlying_type_t<address_space_enum>>(address_space_enum::constant))))
305
// Returns p reinterpreted as a plain T*. The signature (listing line 307) is
// elided from this rendering.
306template <typename T>
308{
309 // Cast a pointer from the "Constant" address space (4) to the "Generic" address space (0).
310 // Only a C-style pointer cast seems to compile here, so the old-style-cast
    // warning is silenced around the return statement.
311#pragma clang diagnostic push
312#pragma clang diagnostic ignored "-Wold-style-cast"
313 return (T*)(p); // NOLINT(old-style-cast)
314#pragma clang diagnostic pop
315}
316
// Returns p reinterpreted as a T CK_CONSTANT_ADDRESS_SPACE*. The signature
// (listing line 318) is elided from this rendering.
317template <typename T>
319{
320 // Cast a pointer from the "Generic" address space (0) to the "Constant" address space (4).
321 // Only a C-style pointer cast seems to compile here, so the old-style-cast
    // warning is silenced around the return statement.
322#pragma clang diagnostic push
323#pragma clang diagnostic ignored "-Wold-style-cast"
324 return (T CK_CONSTANT_ADDRESS_SPACE*)p; // NOLINT(old-style-cast)
325#pragma clang diagnostic pop
326}
327
329{
    // Returns 163840 when compiling for gfx950 and 65536 for every other target.
    // NOTE(review): presumably a size in bytes (160 KiB vs 64 KiB) - confirm.
330#if defined(__gfx950__)
331 return 163840;
332#else
333 return 65536;
334#endif
335}
336
339{
    // Maps each address_space_enum value to its lowercase name as a string
    // literal; any value without a case label yields "unknown".
340 switch(addr_space)
341 {
342 case address_space_enum::generic: return "generic";
343 case address_space_enum::global: return "global";
344 case address_space_enum::lds: return "lds";
345 case address_space_enum::sgpr: return "sgpr";
346 case address_space_enum::constant: return "constant";
347 case address_space_enum::vgpr: return "vgpr";
348 default: return "unknown";
349 }
350}
351
352// Architecture tags
// Empty structs passed by value to select an overload per architecture, as the
// detail::get_n_lds_banks overloads later in this file do.
// Only the header of gfx9_t is visible; the headers of the five other tag
// structs (listing lines 356, 359, 362, 365, 368) are elided from this rendering.
353struct gfx9_t
354{
355};
357{
358};
360{
361};
363{
364};
366{
367};
369{
370};
371
// Returns an architecture tag object for the target being compiled:
// gfx11_t when __gfx11__ is defined, gfx12_t in every other case.
// The two FIXME notes below record that this fallback is known to be inaccurate.
372CK_TILE_DEVICE static constexpr auto get_device_arch()
373{
374// FIXME(0): on all devices except gfx11 it returns gfx12_t
375// FIXME(1): during the host compilation pass it returns gfx12_t
376#if defined(__gfx11__)
377 return gfx11_t{};
378#else
379 return gfx12_t{};
380#endif
381}
382
// Returns the constant 4.
// NOTE(review): presumably the number of 32-bit words in 128 bits - confirm.
383CK_TILE_DEVICE static constexpr auto get_n_words_per_128b() { return 4; }
384
// Implementation helpers for get_n_lds_banks(): one overload per architecture
// tag plus a function that picks the tag from the compiler target macros.
385namespace detail {
// Bank count per architecture tag: 32 for gfx9, gfx103, gfx11 and gfx12,
// 64 for gfx950, and 0 for the invalid tag.
386CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx9_t) { return 32; }
387
388CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx103_t) { return 32; }
389
390CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx11_t) { return 32; }
391
392CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx12_t) { return 32; }
393
394CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx950_t) { return 64; }
395
396CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx_invalid_t) { return 0; }
397
// Returns the tag matching the first target macro that is defined, checked in
// the order gfx103, gfx11, gfx12, gfx950, gfx9; gfx_invalid_t when none is.
// NOTE(review): __gfx950__ is tested before __gfx9__, presumably because gfx950
// builds also define __gfx9__ - confirm.
398CK_TILE_DEVICE static constexpr auto arch_tag_dispatch()
399{
400#if defined(__gfx103__)
401 return gfx103_t{};
402#elif defined(__gfx11__)
403 return gfx11_t{};
404#elif defined(__gfx12__)
405 return gfx12_t{};
406#elif defined(__gfx950__)
407 return gfx950_t{};
408#elif defined(__gfx9__)
409 return gfx9_t{};
410#else
411 return gfx_invalid_t{};
412#endif
413}
414} // namespace detail
// Returns the bank count for the target being compiled by forwarding the tag
// from detail::arch_tag_dispatch() to the matching detail::get_n_lds_banks
// overload; the result is 0 when no known target macro is defined.
415CK_TILE_DEVICE static constexpr auto get_n_lds_banks()
416{
417 return detail::get_n_lds_banks(detail::arch_tag_dispatch());
418}
419
421{
    // Bit flags, one bit per instruction group; the enum header (listing line
    // 420) is elided from this rendering.
    // NOTE(review): presumably mirrors the mask argument of the LLVM AMDGPU
    // scheduling-group builtins - confirm.
422 NONE = 0,
423 ALU = 1 << 0,
424 VALU = 1 << 1,
425 SALU = 1 << 2,
426 MFMA = 1 << 3,
427 VMEM = 1 << 4,
428 VMEM_READ = 1 << 5,
429 VMEM_WRITE = 1 << 6,
430 DS = 1 << 7,
431 DS_READ = 1 << 8,
432 DS_WRITE = 1 << 9,
    // ALL sets every bit from ALU through DS_WRITE (value 0x3FF).
433 ALL = (DS_WRITE << 1) - 1,
434};
435} // namespace ck_tile
#define CK_CONSTANT_ADDRESS_SPACE
Definition ck.hpp:23
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition arch.hpp:385
Definition tile/core/algorithm/cluster_descriptor.hpp:13
typename safe_underlying_type< T, std::is_enum< T >::value >::type safe_underlying_type_t
Definition arch.hpp:43
CK_TILE_DEVICE index_t get_lane_id()
Definition arch.hpp:101
CK_TILE_DEVICE void atomic_add(X *p_dst, const X &x)
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
memory_operation_enum
Definition arch.hpp:56
@ atomic_max
Definition arch.hpp:59
@ set
Definition arch.hpp:57
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
CK_TILE_DEVICE index_t get_warp_id(bool_constant< ReturnSgpr >={})
Definition arch.hpp:104
CK_TILE_DEVICE void block_sync_lds_direct_load()
Definition arch.hpp:288
CK_TILE_HOST_DEVICE T add(const T &a, const T &b)
Definition tile/core/arch/generic_memory_space_atomic.hpp:16
CK_TILE_DEVICE index_t get_block_1d_id()
Definition arch.hpp:98
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
CK_TILE_HOST_DEVICE constexpr index_t get_smem_capacity()
Definition arch.hpp:328
CK_TILE_DEVICE void s_nop(index_t cnt=0)
Definition arch.hpp:293
CK_TILE_DEVICE index_t get_thread_local_1d_id()
Definition arch.hpp:94
CK_TILE_DEVICE void s_waitcnt()
Definition arch.hpp:241
CK_TILE_DEVICE index_t get_block_size()
Definition arch.hpp:91
int32_t int32_t
Definition integer.hpp:10
LLVMSchedGroupMask
Definition arch.hpp:421
@ SALU
Definition arch.hpp:425
@ VALU
Definition arch.hpp:424
@ ALL
Definition arch.hpp:433
@ MFMA
Definition arch.hpp:426
@ DS_READ
Definition arch.hpp:431
@ NONE
Definition arch.hpp:422
@ DS
Definition arch.hpp:430
@ VMEM_READ
Definition arch.hpp:428
@ VMEM_WRITE
Definition arch.hpp:429
@ ALU
Definition arch.hpp:423
@ VMEM
Definition arch.hpp:427
@ DS_WRITE
Definition arch.hpp:432
__device__ T * cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE *p)
Definition arch.hpp:307
CK_TILE_DEVICE index_t get_thread_id()
Definition arch.hpp:117
CK_TILE_DEVICE index_t get_thread_global_1d_id()
Definition arch.hpp:96
address_space_enum
Definition arch.hpp:46
@ generic
Definition arch.hpp:47
@ sgpr
Definition arch.hpp:50
@ constant
Definition arch.hpp:51
@ global
Definition arch.hpp:48
@ lds
Definition arch.hpp:49
@ vgpr
Definition arch.hpp:52
CK_TILE_DEVICE index_t get_block_id()
Definition arch.hpp:119
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE * cast_pointer_to_constant_address_space(T *p)
Definition arch.hpp:318
WaitcntLayoutLegacy Waitcnt
Definition arch.hpp:183
CK_TILE_DEVICE index_t get_grid_size()
Definition arch.hpp:89
int32_t index_t
Definition integer.hpp:9
CK_TILE_DEVICE void block_sync_load_raw(index_t cnt=0)
Definition arch.hpp:121
CK_TILE_DEVICE void s_waitcnt_barrier()
Definition arch.hpp:260
CK_TILE_HOST_DEVICE constexpr const char * address_space_to_string(address_space_enum addr_space)
Helper function to convert address space enum to string.
Definition arch.hpp:338
CK_TILE_HOST bool is_wave32()
Definition arch.hpp:72
Definition arch.hpp:151
static CK_TILE_DEVICE constexpr index_t pack_lgkm(index_t c)
Definition arch.hpp:157
static CK_TILE_DEVICE constexpr index_t LGKM_MASK
Definition arch.hpp:153
static CK_TILE_DEVICE constexpr index_t pack_exp(index_t)
Definition arch.hpp:158
static CK_TILE_DEVICE constexpr index_t VM_MASK
Definition arch.hpp:152
static CK_TILE_DEVICE constexpr bool HAS_EXP
Definition arch.hpp:154
static CK_TILE_DEVICE constexpr index_t pack_vm(index_t c)
Definition arch.hpp:156
Definition arch.hpp:140
static CK_TILE_DEVICE constexpr index_t VM_MASK
Definition arch.hpp:141
static CK_TILE_DEVICE constexpr index_t pack_vm(index_t c)
Definition arch.hpp:145
static CK_TILE_DEVICE constexpr index_t pack_lgkm(index_t c)
Definition arch.hpp:146
static CK_TILE_DEVICE constexpr index_t pack_exp(index_t)
Definition arch.hpp:147
static CK_TILE_DEVICE constexpr index_t LGKM_MASK
Definition arch.hpp:142
static CK_TILE_DEVICE constexpr bool HAS_EXP
Definition arch.hpp:143
Definition arch.hpp:162
static CK_TILE_DEVICE constexpr bool HAS_EXP
Definition arch.hpp:166
static CK_TILE_DEVICE constexpr index_t pack_lgkm(index_t c)
Definition arch.hpp:173
static CK_TILE_DEVICE constexpr index_t pack_vm(index_t c)
Definition arch.hpp:168
static CK_TILE_DEVICE constexpr index_t EXP_MASK
Definition arch.hpp:165
static CK_TILE_DEVICE constexpr index_t LGKM_MASK
Definition arch.hpp:164
static CK_TILE_DEVICE constexpr index_t VM_MASK
Definition arch.hpp:163
static CK_TILE_DEVICE constexpr index_t pack_exp(index_t c)
Definition arch.hpp:174
Definition tile/core/numeric/integral_constant.hpp:13
Definition arch.hpp:360
Definition arch.hpp:363
Definition arch.hpp:366
Definition arch.hpp:357
Definition arch.hpp:354
Definition arch.hpp:369
std::underlying_type_t< T > type
Definition arch.hpp:33
Definition arch.hpp:28
Definition arch.hpp:190
static CK_TILE_DEVICE constexpr index_t from_lgkmcnt()
Definition arch.hpp:210
static CK_TILE_DEVICE constexpr index_t kMaxVmCnt
Definition arch.hpp:197
static CK_TILE_DEVICE constexpr index_t from_vmcnt()
Definition arch.hpp:203
static CK_TILE_DEVICE constexpr index_t from_expcnt()
Definition arch.hpp:217
static CK_TILE_DEVICE constexpr index_t kMaxExpCnt
Definition arch.hpp:199
static CK_TILE_DEVICE constexpr index_t kMaxLgkmCnt
Definition arch.hpp:198