amd_buffer_addressing_builtins.hpp Source File

amd_buffer_addressing_builtins.hpp Source File#

Composable Kernel: amd_buffer_addressing_builtins.hpp Source File
utility/amd_buffer_addressing_builtins.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5#include "data_type.hpp"
6
7namespace ck {
8
9template <typename T>
11{
12 __device__ constexpr BufferResource() : content{} {}
13
14 // 128 bit SGPRs to supply buffer resource in buffer instructions
15 // https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
20};
21
// Builds the buffer resource descriptor for a wave and returns it as an int32x4_t.
//
// p_wave             : base pointer of the buffer; cv-qualifiers are cast away before it is
//                      stored in slot 0 of the `address` view.
// element_space_size : number of T elements in the buffer; stored as a byte count
//                      (element_space_size * sizeof(T)) in slot 2 of the `range` view.
//
// Slot 3 of the `config` view receives CK_BUFFER_RESOURCE_3RD_DWORD.
// Returns the `content` member of the BufferResource.
// NOTE(review): `address`, `range`, `config` and `content` are presumably overlapping views
// of the same 128 bits -- confirm against the BufferResource definition.
22template <typename T>
23__device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t element_space_size)
24{
25 BufferResource<T> wave_buffer_resource;
26
27 // wavewise base address (64 bit)
28 wave_buffer_resource.address(Number<0>{}) = const_cast<remove_cv_t<T>*>(p_wave);
29 // wavewise range (32 bit)
 // The product is computed with sizeof(T) (an unsigned size type) and then stored into the
 // 32-bit range slot.
30 wave_buffer_resource.range(Number<2>{}) = element_space_size * sizeof(T);
31 // wavewise setting (32 bit)
32 wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD;
33
34 return wave_buffer_resource.content;
35}
36
37template <typename T>
39{
40 BufferResource<T> wave_buffer_resource;
41
42 // wavewise base address (64 bit)
43 wave_buffer_resource.address(Number<0>{}) = const_cast<remove_cv_t<T>*>(p_wave);
44 // wavewise range (32 bit)
45 wave_buffer_resource.range(Number<2>{}) = 0xffffffff; // max possible range
46 // wavewise setting (32 bit)
47 wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD;
48
49 return wave_buffer_resource.content;
50}
51
// Builds a buffer resource of type __amdgpu_buffer_rsrc_t by calling
// __builtin_amdgcn_make_buffer_rsrc.
//
// p_wave             : base pointer of the buffer; cv-qualifiers are cast away.
// element_space_size : number of T elements; passed to the builtin as a byte count
//                      (element_space_size * sizeof(T)) narrowed to int32_t.
//
// The stride argument is always 0.
// NOTE(review): `flags` is declared on a line that is not present in this listing;
// presumably it carries CK_BUFFER_RESOURCE_3RD_DWORD -- confirm against the original header.
52template <typename T>
53__device__ __amdgpu_buffer_rsrc_t make_wave_buffer_resource_new(T* p_wave,
54 index_t element_space_size)
55{
56 // wavewise base address (64 bit)
57 auto p = const_cast<remove_cv_t<T>*>(p_wave);
58 int32_t stride = 0;
 // Byte size of the addressable range.
59 int32_t num = element_space_size * sizeof(T);
61
62 return __builtin_amdgcn_make_buffer_rsrc(p, stride, num, flags);
63}
64
// Same as make_wave_buffer_resource_new, but takes no element count: the range argument
// passed to __builtin_amdgcn_make_buffer_rsrc is the constant 0xffffffff.
//
// p_wave : base pointer of the buffer; cv-qualifiers are cast away.
//
// The stride argument is always 0.
// NOTE(review): `flags` is declared on a line that is not present in this listing --
// confirm its value against the original header.
65template <typename T>
66__device__ __amdgpu_buffer_rsrc_t make_wave_buffer_resource_with_default_range_new(T* p_wave)
67{
68 // wavewise base address (64 bit)
69 auto p = const_cast<remove_cv_t<T>*>(p_wave);
70 int32_t stride = 0;
 // All 32 bits set; held in a signed int32_t, so the stored value is the bit pattern
 // 0xffffffff rather than a positive count.
71 int32_t num = 0xffffffff;
73
74 return __builtin_amdgcn_make_buffer_rsrc(p, stride, num, flags);
75}
76
77// buffer atomic-add fp16
79 half2_t vdata,
80 int32x4_t rsrc,
81 index_t voffset,
82 index_t soffset,
83 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16");
84
85// buffer atomic-add i32
87 int32_t vdata,
88 int32x4_t rsrc,
89 index_t voffset,
90 index_t soffset,
91 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32");
92
93// buffer atomic-add fp32
95 float vdata,
96 int32x4_t rsrc,
97 index_t voffset,
98 index_t soffset,
99 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32");
100
101// buffer atomic-max fp64
102__device__ double
104 int32x4_t rsrc, // dst_wave_buffer_resource
105 int voffset, // dst_thread_addr_offset
106 int soffset, // dst_wave_addr_offset
107 int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64");
108
109// memory coherency bit for buffer store/load instruction
110// check ISA manual for each GFX target
111// e.g. for
112// https://www.amd.com/system/files/TechDocs/instinct-mi200-cdna2-instruction-set-architecture.pdf,
113// page 67~68
115{
116 DefaultCoherence = 0, // default value
117 GLC = 1,
118 SLC = 2,
120 // gfx94: bit 0 = sc0, bit 1 = nt, bit 3 = swz, bit 4 = sc1
121 // SC[1:0] System Cache level: 0=wave, 1=group, 2=device, 3=system
122 // NT Non-Temporal: 0=expect temporal reuse; 1=do not expect temporal reuse
131};
132
// Loads N bytes through the __builtin_amdgcn_raw_buffer_load_* builtins and returns them as
// vector_type<int8_t, N>::type.
//
// Template parameters:
//   N         : number of bytes to load; must be 1, 2, 4, 8, 16, 32 or 64 (static_assert).
//   coherence : AmdBufferCoherenceEnum value, cast to index_t and passed as the last argument
//               of every builtin call.
// Parameters:
//   src_wave_buffer_resource : buffer resource handed to the builtin.
//   src_thread_addr_offset   : passed unchanged as the builtin's second argument.
//   src_wave_addr_offset     : passed as the builtin's third argument; for N == 32 and
//                              N == 64 it is advanced by 16 bytes per additional load.
//
// N <= 16 maps to a single b8/b16/b32/b64/b128 load whose result is bit_cast to the int8
// vector type. N == 32 and N == 64 are split into two and four b128 loads respectively.
// NOTE(review): the declaration of `tmp` in the N == 32 and N == 64 branches is on lines
// that are not present in this listing.
133template <index_t N, AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
134__device__ typename vector_type<int8_t, N>::type
135amd_buffer_load_impl_raw(__amdgpu_buffer_rsrc_t src_wave_buffer_resource,
136 index_t src_thread_addr_offset,
137 index_t src_wave_addr_offset)
138{
139 static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64,
140 "wrong! not implemented");
141
 // 1 byte: the builtin result is returned directly.
142 if constexpr(N == 1)
143 {
144 return __builtin_amdgcn_raw_buffer_load_b8(src_wave_buffer_resource,
145 src_thread_addr_offset,
146 src_wave_addr_offset,
147 static_cast<index_t>(coherence));
148 }
 // 2 bytes: loaded as one int16_t, reinterpreted as int8x2_t.
149 else if constexpr(N == 2)
150 {
151
152 int16_t tmp = __builtin_amdgcn_raw_buffer_load_b16(src_wave_buffer_resource,
153 src_thread_addr_offset,
154 src_wave_addr_offset,
155 static_cast<index_t>(coherence));
156
157 return bit_cast<int8x2_t>(tmp);
158 }
 // 4 bytes: loaded as one int32_t, reinterpreted as int8x4_t.
159 else if constexpr(N == 4)
160 {
161 int32_t tmp = __builtin_amdgcn_raw_buffer_load_b32(src_wave_buffer_resource,
162 src_thread_addr_offset,
163 src_wave_addr_offset,
164 static_cast<index_t>(coherence));
165
166 return bit_cast<int8x4_t>(tmp);
167 }
 // 8 bytes: loaded as int32x2_t, reinterpreted as int8x8_t.
168 else if constexpr(N == 8)
169 {
170 int32x2_t tmp = __builtin_amdgcn_raw_buffer_load_b64(src_wave_buffer_resource,
171 src_thread_addr_offset,
172 src_wave_addr_offset,
173 static_cast<index_t>(coherence));
174
175 return bit_cast<int8x8_t>(tmp);
176 }
 // 16 bytes: loaded as int32x4_t, reinterpreted as int8x16_t.
177 else if constexpr(N == 16)
178 {
179 int32x4_t tmp = __builtin_amdgcn_raw_buffer_load_b128(src_wave_buffer_resource,
180 src_thread_addr_offset,
181 src_wave_addr_offset,
182 static_cast<index_t>(coherence));
183 return bit_cast<int8x16_t>(tmp);
184 }
 // 32 bytes: two b128 loads; the second adds 4 * sizeof(int32_t) to the wave offset.
185 else if constexpr(N == 32)
186 {
187 int32x4_t tmp0 = __builtin_amdgcn_raw_buffer_load_b128(src_wave_buffer_resource,
188 src_thread_addr_offset,
189 src_wave_addr_offset,
190 static_cast<index_t>(coherence));
191 int32x4_t tmp1 =
192 __builtin_amdgcn_raw_buffer_load_b128(src_wave_buffer_resource,
193 src_thread_addr_offset,
194 src_wave_addr_offset + 4 * sizeof(int32_t),
195 static_cast<index_t>(coherence));
197
 // Place the two halves into int32x4_t slots 0 and 1 of `tmp`.
198 tmp.AsType<int32x4_t>()(Number<0>{}) = tmp0;
199 tmp.AsType<int32x4_t>()(Number<1>{}) = tmp1;
200
201 return bit_cast<int8x32_t>(tmp);
202 }
 // 64 bytes: four b128 loads at wave-offset increments of 0, 4, 8 and 12 int32_t.
203 else if constexpr(N == 64)
204 {
205 int32x4_t tmp0 = __builtin_amdgcn_raw_buffer_load_b128(src_wave_buffer_resource,
206 src_thread_addr_offset,
207 src_wave_addr_offset,
208 static_cast<index_t>(coherence));
209 int32x4_t tmp1 =
210 __builtin_amdgcn_raw_buffer_load_b128(src_wave_buffer_resource,
211 src_thread_addr_offset,
212 src_wave_addr_offset + 4 * sizeof(int32_t),
213 static_cast<index_t>(coherence));
214 int32x4_t tmp2 =
215 __builtin_amdgcn_raw_buffer_load_b128(src_wave_buffer_resource,
216 src_thread_addr_offset,
217 src_wave_addr_offset + 8 * sizeof(int32_t),
218 static_cast<index_t>(coherence));
219 int32x4_t tmp3 =
220 __builtin_amdgcn_raw_buffer_load_b128(src_wave_buffer_resource,
221 src_thread_addr_offset,
222 src_wave_addr_offset + 12 * sizeof(int32_t),
223 static_cast<index_t>(coherence));
224
226
 // Place the four quarters into int32x4_t slots 0..3 of `tmp`.
227 tmp.AsType<int32x4_t>()(Number<0>{}) = tmp0;
228 tmp.AsType<int32x4_t>()(Number<1>{}) = tmp1;
229 tmp.AsType<int32x4_t>()(Number<2>{}) = tmp2;
230 tmp.AsType<int32x4_t>()(Number<3>{}) = tmp3;
231
232 return bit_cast<int8x64_t>(tmp);
233 }
234}
235
236template <typename T,
237 index_t N,
239__device__ typename vector_type<T, N>::type
240amd_buffer_load_impl(__amdgpu_buffer_rsrc_t src_wave_buffer_resource,
241 index_t src_thread_addr_offset,
242 index_t src_wave_addr_offset)
243{
244 static_assert(
245 (is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
246 (is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
247 (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
248 (is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
249 (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
250 (is_same<T, f8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
251 (is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
252 (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
253 (is_same<T, uint8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
254 (is_same<T, pk_i4_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
255 "wrong! not implemented");
256
257 using r_t = typename vector_type<T, N>::type;
259 src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset);
260 return bit_cast<r_t>(raw_data);
261}
262
263template <index_t N, AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
264__device__ void
266 __amdgpu_buffer_rsrc_t dst_wave_buffer_resource,
267 index_t dst_thread_addr_offset,
268 index_t dst_wave_addr_offset)
269{
270 static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64,
271 "wrong! not implemented");
272
273 if constexpr(N == 1)
274 {
275 __builtin_amdgcn_raw_buffer_store_b8(src_thread_data,
276 dst_wave_buffer_resource,
277 dst_thread_addr_offset,
278 dst_wave_addr_offset,
279 static_cast<index_t>(coherence));
280 }
281 else if constexpr(N == 2)
282 {
283
284 __builtin_amdgcn_raw_buffer_store_b16(bit_cast<int16_t>(src_thread_data),
285 dst_wave_buffer_resource,
286 dst_thread_addr_offset,
287 dst_wave_addr_offset,
288 static_cast<index_t>(coherence));
289 }
290 else if constexpr(N == 4)
291 {
292 __builtin_amdgcn_raw_buffer_store_b32(bit_cast<int32_t>(src_thread_data),
293 dst_wave_buffer_resource,
294 dst_thread_addr_offset,
295 dst_wave_addr_offset,
296 static_cast<index_t>(coherence));
297 }
298 else if constexpr(N == 8)
299 {
300 __builtin_amdgcn_raw_buffer_store_b64(bit_cast<int32x2_t>(src_thread_data),
301 dst_wave_buffer_resource,
302 dst_thread_addr_offset,
303 dst_wave_addr_offset,
304 static_cast<index_t>(coherence));
305 }
306 else if constexpr(N == 16)
307 {
308 __builtin_amdgcn_raw_buffer_store_b128(bit_cast<int32x4_t>(src_thread_data),
309 dst_wave_buffer_resource,
310 dst_thread_addr_offset,
311 dst_wave_addr_offset,
312 static_cast<index_t>(coherence));
313 }
314 else if constexpr(N == 32)
315 {
316 vector_type<int32_t, 8> tmp{bit_cast<int32x8_t>(src_thread_data)};
317
318 __builtin_amdgcn_raw_buffer_store_b128(tmp.template AsType<int32x4_t>()[Number<0>{}],
319 dst_wave_buffer_resource,
320 dst_thread_addr_offset,
321 dst_wave_addr_offset,
322 static_cast<index_t>(coherence));
323
324 __builtin_amdgcn_raw_buffer_store_b128(tmp.template AsType<int32x4_t>()[Number<1>{}],
325 dst_wave_buffer_resource,
326 dst_thread_addr_offset,
327 dst_wave_addr_offset + sizeof(int32_t) * 4,
328 static_cast<index_t>(coherence));
329 }
330 else if constexpr(N == 64)
331 {
333
334 __builtin_amdgcn_raw_buffer_store_b128(tmp.template AsType<int32x4_t>()[Number<0>{}],
335 dst_wave_buffer_resource,
336 dst_thread_addr_offset,
337 dst_wave_addr_offset,
338 static_cast<index_t>(coherence));
339
340 __builtin_amdgcn_raw_buffer_store_b128(tmp.template AsType<int32x4_t>()[Number<1>{}],
341 dst_wave_buffer_resource,
342 dst_thread_addr_offset,
343 dst_wave_addr_offset + sizeof(int32_t) * 4,
344 static_cast<index_t>(coherence));
345
346 __builtin_amdgcn_raw_buffer_store_b128(tmp.template AsType<int32x4_t>()[Number<2>{}],
347 dst_wave_buffer_resource,
348 dst_thread_addr_offset,
349 dst_wave_addr_offset + sizeof(int32_t) * 8,
350 static_cast<index_t>(coherence));
351
352 __builtin_amdgcn_raw_buffer_store_b128(tmp.template AsType<int32x4_t>()[Number<3>{}],
353 dst_wave_buffer_resource,
354 dst_thread_addr_offset,
355 dst_wave_addr_offset + sizeof(int32_t) * 12,
356 static_cast<index_t>(coherence));
357 }
358}
359
360template <typename T,
361 index_t N,
363__device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src_thread_data,
364 __amdgpu_buffer_rsrc_t dst_wave_buffer_resource,
365 index_t dst_thread_addr_offset,
366 index_t dst_wave_addr_offset)
367{
368 static_assert(
369 (is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
370 (is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
371 (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
372 (is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
373 (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
374 (is_same<T, f8_fnuz_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
375 (is_same<T, bf8_fnuz_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
377 (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
378 (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
379 "wrong! not implemented");
380
381 using r_t = typename vector_type<int8_t, sizeof(T) * N>::type;
382
384 dst_wave_buffer_resource,
385 dst_thread_addr_offset,
386 dst_wave_addr_offset);
387}
388
// Atomically adds a vector of N half_t or bhalf_t values to memory starting at `addr`,
// two elements at a time, using the packed global atomic fadd builtins.
//
// Template parameters:
//   T : half_t or bhalf_t (static_assert).
//   N : 2, 4 or 8 (static_assert); N / 2 packed atomics are issued.
// Parameters:
//   src_thread_data : the N values to add.
//   addr            : destination pointer; reinterpreted as a pointer to the 2-element
//                     packed type and indexed by pair.
//
// NOTE(review): the bhalf_t branch is only compiled when __gfx942__, __gfx950__ or
// __gfx12__ is defined. When none of them is defined, a bhalf_t instantiation passes the
// static_assert but the function body contains no code for it, so nothing is written.
389template <typename T, index_t N>
390__device__ void amd_global_atomic_add_impl(const typename vector_type<T, N>::type src_thread_data,
391 T* addr)
392{
393 static_assert((is_same<T, bhalf_t>::value && (N == 2 || N == 4 || N == 8)) ||
394 (is_same<T, half_t>::value && (N == 2 || N == 4 || N == 8)),
395 "wrong! not implemented");
396
397 if constexpr(is_same<T, half_t>::value)
398 {
399 vector_type<half_t, N> tmp{src_thread_data};
 // One v2f16 atomic add per pair: pair i goes to the i-th half2_t slot after addr.
400 static_for<0, N / 2, 1>{}([&](auto i) {
401 __builtin_amdgcn_global_atomic_fadd_v2f16(bit_cast<half2_t*>(addr) + i,
402 tmp.template AsType<half2_t>()[i]);
403 });
404 }
405#if defined(__gfx942__) || defined(__gfx950__) || defined(__gfx12__)
406 else if constexpr(is_same<T, bhalf_t>::value)
407 {
408 vector_type<bhalf_t, N> tmp{src_thread_data};
 // One v2bf16 atomic add per pair: pair i goes to the i-th bhalf2_t slot after addr.
409 static_for<0, N / 2, 1>{}([&](auto i) {
410 __builtin_amdgcn_global_atomic_fadd_v2bf16(bit_cast<bhalf2_t*>(addr) + i,
411 tmp.template AsType<bhalf2_t>()[i]);
412 });
413 }
414#endif
415}
416
417template <typename T, index_t N>
418__device__ void amd_buffer_atomic_add_impl(const typename vector_type<T, N>::type src_thread_data,
419 int32x4_t dst_wave_buffer_resource,
420 index_t dst_thread_addr_offset,
421 index_t dst_wave_addr_offset)
422{
423 static_assert((is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
424 (is_same<T, half_t>::value && (N == 2 || N == 4 || N == 8)) ||
425 (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)),
426 "wrong! not implemented");
427
428 if constexpr(is_same<T, float>::value)
429 {
430 if constexpr(N == 1)
431 {
433 dst_wave_buffer_resource,
434 dst_thread_addr_offset,
435 dst_wave_addr_offset,
436 0);
437 }
438 else if constexpr(N == 2)
439 {
440 vector_type<float, 2> tmp{src_thread_data};
441
443 dst_wave_buffer_resource,
444 dst_thread_addr_offset,
445 dst_wave_addr_offset,
446 0);
447
449 dst_wave_buffer_resource,
450 dst_thread_addr_offset,
451 dst_wave_addr_offset + sizeof(float),
452 0);
453 }
454 else if constexpr(N == 4)
455 {
456 vector_type<float, 4> tmp{src_thread_data};
457
459 dst_wave_buffer_resource,
460 dst_thread_addr_offset,
461 dst_wave_addr_offset,
462 0);
463
465 dst_wave_buffer_resource,
466 dst_thread_addr_offset,
467 dst_wave_addr_offset + sizeof(float),
468 0);
469
471 dst_wave_buffer_resource,
472 dst_thread_addr_offset,
473 dst_wave_addr_offset + 2 * sizeof(float),
474 0);
475
477 dst_wave_buffer_resource,
478 dst_thread_addr_offset,
479 dst_wave_addr_offset + 3 * sizeof(float),
480 0);
481 }
482 }
483 else if constexpr(is_same<T, half_t>::value)
484 {
485 if constexpr(N == 2)
486 {
488 dst_wave_buffer_resource,
489 dst_thread_addr_offset,
490 dst_wave_addr_offset,
491 0);
492 }
493 else if constexpr(N == 4)
494 {
495 vector_type<half_t, 4> tmp{src_thread_data};
496
497 static_for<0, 2, 1>{}([&](auto i) {
499 dst_wave_buffer_resource,
500 dst_thread_addr_offset,
501 dst_wave_addr_offset + i * sizeof(half2_t),
502 0);
503 });
504 }
505 else if constexpr(N == 8)
506 {
507 vector_type<half_t, 8> tmp{src_thread_data};
508
509 static_for<0, 4, 1>{}([&](auto i) {
511 dst_wave_buffer_resource,
512 dst_thread_addr_offset,
513 dst_wave_addr_offset + i * sizeof(half2_t),
514 0);
515 });
516 }
517 }
518 else if constexpr(is_same<T, int32_t>::value)
519 {
520 if constexpr(N == 1)
521 {
523 dst_wave_buffer_resource,
524 dst_thread_addr_offset,
525 dst_wave_addr_offset,
526 0);
527 }
528 else if constexpr(N == 2)
529 {
530 vector_type<int32_t, 2> tmp{src_thread_data};
531
533 dst_wave_buffer_resource,
534 dst_thread_addr_offset,
535 dst_wave_addr_offset,
536 0);
537
539 dst_wave_buffer_resource,
540 dst_thread_addr_offset,
541 dst_wave_addr_offset + sizeof(int32_t),
542 0);
543 }
544 else if constexpr(N == 4)
545 {
546 vector_type<int32_t, 4> tmp{src_thread_data};
547
549 dst_wave_buffer_resource,
550 dst_thread_addr_offset,
551 dst_wave_addr_offset,
552 0);
553
555 dst_wave_buffer_resource,
556 dst_thread_addr_offset,
557 dst_wave_addr_offset + sizeof(int32_t),
558 0);
559
561 dst_wave_buffer_resource,
562 dst_thread_addr_offset,
563 dst_wave_addr_offset + 2 * sizeof(int32_t),
564 0);
565
567 dst_wave_buffer_resource,
568 dst_thread_addr_offset,
569 dst_wave_addr_offset + 3 * sizeof(int32_t),
570 0);
571 }
572 }
573}
574
575template <typename T, index_t N>
576__device__ void amd_buffer_atomic_max_impl(const typename vector_type<T, N>::type src_thread_data,
577 int32x4_t dst_wave_buffer_resource,
578 index_t dst_thread_addr_offset,
579 index_t dst_wave_addr_offset)
580{
581 static_assert((is_same<T, double>::value && (N == 1 || N == 2 || N == 4)),
582 "wrong! not implemented");
583 if constexpr(is_same<T, double>::value)
584 {
585 if constexpr(N == 1)
586 {
588 dst_wave_buffer_resource,
589 dst_thread_addr_offset,
590 dst_wave_addr_offset,
591 0);
592 }
593 else if constexpr(N == 2)
594 {
595 vector_type<double, 2> tmp{src_thread_data};
596
598 dst_wave_buffer_resource,
599 dst_thread_addr_offset,
600 dst_wave_addr_offset,
601 0);
602
604 dst_wave_buffer_resource,
605 dst_thread_addr_offset,
606 dst_wave_addr_offset + sizeof(double),
607 0);
608 }
609 else if constexpr(N == 4)
610 {
611 vector_type<double, 4> tmp{src_thread_data};
612
614 dst_wave_buffer_resource,
615 dst_thread_addr_offset,
616 dst_wave_addr_offset,
617 0);
618
620 dst_wave_buffer_resource,
621 dst_thread_addr_offset,
622 dst_wave_addr_offset + sizeof(double),
623 0);
624
626 dst_wave_buffer_resource,
627 dst_thread_addr_offset,
628 dst_wave_addr_offset + 2 * sizeof(double),
629 0);
630
632 dst_wave_buffer_resource,
633 dst_thread_addr_offset,
634 dst_wave_addr_offset + 3 * sizeof(double),
635 0);
636 }
637 }
638}
639
640// buffer_load requires:
641// 1) p_src_wave must point to global memory space
642// 2) p_src_wave must be a wavewise pointer.
643// It is user's responsibility to make sure that is true.
644template <typename T,
645 index_t N,
647__device__ typename vector_type_maker<T, N>::type::type
649 index_t src_thread_element_offset,
650 bool src_thread_element_valid,
651 index_t src_element_space_size)
652{
653 const __amdgpu_buffer_rsrc_t src_wave_buffer_resource =
654 make_wave_buffer_resource_new(p_src_wave, src_element_space_size);
655
656 index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
657
658 using vector_t = typename vector_type_maker<T, N>::type::type;
659 using scalar_t = typename scalar_type<vector_t>::type;
660
661 constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
662
663#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
664 uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000;
666 src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
667
668#else
669
671 src_wave_buffer_resource, src_thread_addr_offset, 0)};
672 return src_thread_element_valid ? tmp : vector_t(0);
673#endif
674}
675
676// buffer_load requires:
677// 1) p_src_wave must point to global memory space
678// 2) p_src_wave must be a wavewise pointer.
679// It is user's responsibility to make sure that is true.
680template <typename T,
681 index_t N,
683__device__ typename vector_type_maker<T, N>::type::type
685 index_t src_thread_element_offset,
686 bool src_thread_element_valid,
687 index_t src_element_space_size,
688 T customized_value)
689{
690 const __amdgpu_buffer_rsrc_t src_wave_buffer_resource =
691 make_wave_buffer_resource_new(p_src_wave, src_element_space_size);
692
693 index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
694
695 using vector_t = typename vector_type_maker<T, N>::type::type;
696 using scalar_t = typename scalar_type<vector_t>::type;
697
698 constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
699
701 src_wave_buffer_resource, src_thread_addr_offset, 0)};
702
703 return src_thread_element_valid ? tmp : vector_t(customized_value);
704}
705
706// buffer_store requires:
707// 1) p_dst_wave must point to global memory
708// 2) p_dst_wave must be a wavewise pointer.
709// It is user's responsibility to make sure that is true.
710template <typename T,
711 index_t N,
713__device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::type src_thread_data,
714 T* p_dst_wave,
715 const index_t dst_thread_element_offset,
716 const bool dst_thread_element_valid,
717 const index_t dst_element_space_size)
718{
719 const __amdgpu_buffer_rsrc_t dst_wave_buffer_resource =
720 make_wave_buffer_resource_new(p_dst_wave, dst_element_space_size);
721
722 index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
723
724 using vector_t = typename vector_type_maker<T, N>::type::type;
725 using scalar_t = typename scalar_type<vector_t>::type;
726 constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
727
728#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
729 uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
731 src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
732#else
733 if(dst_thread_element_valid)
734 {
736 src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
737 }
738#endif
739}
740
741// buffer_atomic_add requires:
742// 1) p_dst_wave must point to global memory
743// 2) p_dst_wave must be a wavewise pointer.
744// It is user's responsibility to make sure that is true.
745template <typename T, index_t N>
746__device__ void
748 T* p_dst_wave,
749 const index_t dst_thread_element_offset,
750 const bool dst_thread_element_valid,
751 const index_t dst_element_space_size)
752{
753 const int32x4_t dst_wave_buffer_resource =
754 make_wave_buffer_resource(p_dst_wave, dst_element_space_size);
755
756 index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
757
758 using vector_t = typename vector_type_maker<T, N>::type::type;
759 using scalar_t = typename scalar_type<vector_t>::type;
760 constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
761
762 if constexpr(is_same<T, bhalf_t>::value)
763 {
764 if(dst_thread_element_valid)
765 {
767 src_thread_data, p_dst_wave + dst_thread_element_offset);
768 }
769 }
770 else
771 {
772#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
773 uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
774
776 src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
777#else
778 if(dst_thread_element_valid)
779 {
781 src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
782 }
783#endif
784 }
785}
786
787// buffer_atomic_max requires:
788// 1) p_dst_wave must point to global memory
789// 2) p_dst_wave must be a wavewise pointer.
790// It is user's responsibility to make sure that is true.
791template <typename T, index_t N>
792__device__ void
794 T* p_dst_wave,
795 const index_t dst_thread_element_offset,
796 const bool dst_thread_element_valid,
797 const index_t dst_element_space_size)
798{
799 const int32x4_t dst_wave_buffer_resource =
800 make_wave_buffer_resource(p_dst_wave, dst_element_space_size);
801
802 index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
803
804 using vector_t = typename vector_type_maker<T, N>::type::type;
805 using scalar_t = typename scalar_type<vector_t>::type;
806 constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
807
808#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK
809 uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
810
812 src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
813#else
814 if(dst_thread_element_valid)
815 {
817 src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
818 }
819#endif
820}
821
822// Direct loads from global to LDS.
823__device__ void
825 __attribute__((address_space(3))) uint32_t* lds_ptr,
826 index_t size,
827 index_t voffset,
828 index_t soffset,
829 index_t offset,
830 index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
831
832#ifndef __HIPCC_RTC__
// Issues a direct load that moves sizeof(T) * NumElemsPerThread bytes per thread from a
// global buffer into LDS.
//
// Template parameters:
//   T                 : element type.
//   NumElemsPerThread : elements moved per thread; sizeof(T) * NumElemsPerThread must be
//                       4, 12 or 16 bytes when __gfx950__ is defined and exactly 4 bytes
//                       when __gfx942__ is defined (static_assert). No size check is made
//                       when neither macro is defined.
// Parameters:
//   global_base_ptr        : base of the source buffer; used to build the buffer resource.
//   global_offset          : source offset in elements; scaled by sizeof(T).
//   lds_base_ptr           : base of the destination.
//   lds_offset             : destination offset in elements (added to lds_base_ptr).
//   is_valid               : when false, the byte offset is replaced by 0x80000000.
//   src_element_space_size : element count used to build the buffer resource.
//
// NOTE(review): the 0x80000000 offset presumably pushes the access outside the resource
// range so that the hardware discards it -- confirm against the ISA manual.
833template <typename T, index_t NumElemsPerThread>
834__device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
835 const index_t global_offset,
836 T* lds_base_ptr,
837 const index_t lds_offset,
838 const bool is_valid,
839 const index_t src_element_space_size)
840{
841 // Direct loads require that each thread reads and writes a multiple of DWORDs (4 bytes).
842 // For gfx950: supports 1, 3, or 4 DWORDs per thread
843 // For gfx942: supports exactly 1 DWORD per thread
844 constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread;
845#if defined(__gfx950__)
846 constexpr auto dword_bytes = 4;
847 static_assert(bytes_per_thread == dword_bytes || bytes_per_thread == dword_bytes * 3 ||
848 bytes_per_thread == dword_bytes * 4);
849#elif defined(__gfx942__)
850 constexpr auto dword_bytes = 4;
851 static_assert(bytes_per_thread == dword_bytes);
852#endif
853
854 const int32x4_t src_resource =
855 make_wave_buffer_resource(global_base_ptr, src_element_space_size);
 // Byte offset into the source buffer, or 0x80000000 for an invalid element.
856 const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000;
857
858#if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
 // Inline-assembly path: the LDS address is read from the first lane, moved into m0,
 // and a buffer_load_dword with the `lds` modifier is issued.
 // NOTE(review): this path always emits buffer_load_dword, independent of
 // bytes_per_thread -- confirm that is intended for the 3- and 4-DWORD gfx950 cases.
859 T* lds_ptr = lds_base_ptr + lds_offset;
860#ifndef CK_CODE_GEN_RTC
861 auto const lds_ptr_sgpr =
862 __builtin_amdgcn_readfirstlane((reinterpret_cast<uintptr_t>(lds_ptr)));
863#else
864 auto const lds_ptr_sgpr = __builtin_amdgcn_readfirstlane((reinterpret_cast<size_t>(lds_ptr)));
865#endif
866 asm volatile("s_mov_b32 m0, %0; \n\t"
867 "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr),
868 "v"(global_offset_bytes),
869 "s"(src_resource)
870 : "memory");
871#else
872 // LDS pointer must be attributed with the LDS address space.
 // The generic pointer is converted through an integer type (uintptr_t, or size_t when
 // CK_CODE_GEN_RTC is defined) before being cast to the address_space(3) pointer.
873 __attribute__((address_space(3))) uint32_t* lds_ptr =
874#ifndef CK_CODE_GEN_RTC
875 reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
876 reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
877#else
878 reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
879 reinterpret_cast<size_t>(lds_base_ptr + lds_offset));
880#endif
881
 // NOTE(review): the line naming the callee is not present in this listing; the argument
 // list matches llvm_amdgcn_raw_buffer_load_lds(rsrc, lds_ptr, size, voffset, soffset,
 // offset, aux) -- confirm against the original header.
883 src_resource, lds_ptr, bytes_per_thread, global_offset_bytes, 0, 0, 0);
884#endif
885}
886#endif
887
888} // namespace ck
#define CK_BUFFER_RESOURCE_3RD_DWORD
Definition ck.hpp:80
Definition ck.hpp:268
__device__ int32x4_t make_wave_buffer_resource_with_default_range(T *p_wave)
Definition utility/amd_buffer_addressing.hpp:38
__device__ void amd_buffer_store(const typename vector_type_maker< T, N >::type::type src_thread_data, T *p_dst_wave, const index_t dst_thread_element_offset, const bool dst_thread_element_valid, const index_t dst_element_space_size)
Definition utility/amd_buffer_addressing.hpp:894
__device__ void amd_direct_load_global_to_lds(const T *global_base_ptr, const index_t global_offset, T *lds_base_ptr, const index_t lds_offset, const bool is_valid, const index_t src_element_space_size)
Definition utility/amd_buffer_addressing.hpp:1015
__device__ void amd_buffer_atomic_max(const typename vector_type_maker< T, N >::type::type src_thread_data, T *p_dst_wave, const index_t dst_thread_element_offset, const bool dst_thread_element_valid, const index_t dst_element_space_size)
Definition utility/amd_buffer_addressing.hpp:974
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
__device__ void amd_buffer_store_impl(const typename vector_type< T, N >::type src_thread_data, int32x4_t dst_wave_buffer_resource, index_t dst_thread_addr_offset, index_t dst_wave_addr_offset)
Definition utility/amd_buffer_addressing.hpp:544
AmdBufferCoherenceEnum
Definition utility/amd_buffer_addressing.hpp:295
@ GLC
Definition utility/amd_buffer_addressing.hpp:297
@ SYSTEM_NT1
Definition utility/amd_buffer_addressing.hpp:310
@ WAVE_NT0
Definition utility/amd_buffer_addressing.hpp:303
@ GLC_SLC
Definition utility/amd_buffer_addressing.hpp:299
@ SLC
Definition utility/amd_buffer_addressing.hpp:298
@ DefaultCoherence
Definition utility/amd_buffer_addressing.hpp:296
@ DEVICE_NT1
Definition utility/amd_buffer_addressing.hpp:308
@ SYSTEM_NT0
Definition utility/amd_buffer_addressing.hpp:309
@ GROUP_NT1
Definition utility/amd_buffer_addressing.hpp:306
@ DEVICE_NT0
Definition utility/amd_buffer_addressing.hpp:307
@ GROUP_NT0
Definition utility/amd_buffer_addressing.hpp:305
@ WAVE_NT1
Definition utility/amd_buffer_addressing.hpp:304
__device__ int32x4_t make_wave_buffer_resource(T *p_wave, index_t element_space_size)
Definition utility/amd_buffer_addressing.hpp:23
__device__ void llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc, uint32_t *lds_ptr, index_t size, index_t voffset, index_t soffset, index_t offset, index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds")
integral_constant< index_t, N > Number
Definition number.hpp:12
__device__ void amd_buffer_atomic_add_impl(const typename vector_type< T, N >::type src_thread_data, int32x4_t dst_wave_buffer_resource, index_t dst_thread_addr_offset, index_t dst_wave_addr_offset)
Definition utility/amd_buffer_addressing.hpp:599
__device__ vector_type_maker< T, N >::type::type amd_buffer_load_invalid_element_return_customized_value(const T *p_src_wave, index_t src_thread_element_offset, bool src_thread_element_valid, index_t src_element_space_size, T customized_value)
Definition utility/amd_buffer_addressing.hpp:865
__device__ float llvm_amdgcn_raw_buffer_atomic_add_fp32(float vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32")
__device__ void amd_global_atomic_add_impl(const typename vector_type< T, N >::type src_thread_data, T *addr)
Definition utility/amd_buffer_addressing.hpp:571
typename vector_type< int32_t, 4 >::type int32x4_t
Definition dtype_vector.hpp:2168
__device__ half2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2(half2_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16")
__device__ __amdgpu_buffer_rsrc_t make_wave_buffer_resource_new(T *p_wave, index_t element_space_size)
Definition utility/amd_buffer_addressing_builtins.hpp:53
typename vector_type< int32_t, 2 >::type int32x2_t
Definition dtype_vector.hpp:2167
__device__ vector_type< int8_t, N >::type amd_buffer_load_impl_raw(int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset, index_t src_wave_addr_offset)
Definition utility/amd_buffer_addressing.hpp:315
__device__ void amd_buffer_atomic_add(const typename vector_type_maker< T, N >::type::type src_thread_data, T *p_dst_wave, const index_t dst_thread_element_offset, const bool dst_thread_element_valid, const index_t dst_element_space_size)
Definition utility/amd_buffer_addressing.hpp:928
__device__ double llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata, int32x4_t rsrc, int voffset, int soffset, int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64")
typename vector_type< half_t, 2 >::type half2_t
Definition dtype_vector.hpp:2153
__device__ vector_type_maker< T, N >::type::type amd_buffer_load_invalid_element_return_zero(const T *p_src_wave, index_t src_thread_element_offset, bool src_thread_element_valid, index_t src_element_space_size)
Definition utility/amd_buffer_addressing.hpp:829
__device__ void amd_buffer_atomic_max_impl(const typename vector_type< T, N >::type src_thread_data, int32x4_t dst_wave_buffer_resource, index_t dst_thread_addr_offset, index_t dst_wave_addr_offset)
Definition utility/amd_buffer_addressing.hpp:757
__device__ int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(int32_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32")
__device__ vector_type< T, N >::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset, index_t src_wave_addr_offset)
Definition utility/amd_buffer_addressing.hpp:419
__host__ __device__ constexpr Y bit_cast(const X &x)
Definition type.hpp:306
__device__ void amd_buffer_store_impl_raw(const typename vector_type< int8_t, N >::type src_thread_data, int32x4_t dst_wave_buffer_resource, index_t dst_thread_addr_offset, index_t dst_wave_addr_offset)
Definition utility/amd_buffer_addressing.hpp:446
typename remove_cv< T >::type remove_cv_t
Definition type.hpp:295
__device__ __amdgpu_buffer_rsrc_t make_wave_buffer_resource_with_default_range_new(T *p_wave)
Definition utility/amd_buffer_addressing_builtins.hpp:66
signed short int16_t
Definition stdint.h:122
_W64 unsigned int uintptr_t
Definition stdint.h:164
unsigned int uint32_t
Definition stdint.h:126
signed int int32_t
Definition stdint.h:123
signed char int8_t
Definition stdint.h:121
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition data_type.hpp:39
Definition functional2.hpp:33
Definition dtype_vector.hpp:30
Definition dtype_vector.hpp:10
Definition utility/amd_buffer_addressing.hpp:11
int32x4_t content
Definition utility/amd_buffer_addressing.hpp:16
StaticallyIndexedArray< int32_t, 4 > config
Definition utility/amd_buffer_addressing.hpp:19
StaticallyIndexedArray< int32_t, 4 > range
Definition utility/amd_buffer_addressing.hpp:18
StaticallyIndexedArray< T *, 2 > address
Definition utility/amd_buffer_addressing.hpp:17
__device__ constexpr BufferResource()
Definition utility/amd_buffer_addressing_builtins.hpp:12