unary_element_wise_operation.hpp Source File

unary_element_wise_operation.hpp Source File#

Composable Kernel: unary_element_wise_operation.hpp Source File
tile/ops/elementwise/unary_element_wise_operation.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
7#include <cstdint>
8#include <type_traits>
9
10#define CONSTEXPR_LOOKUP_TABLE_FOR_BF16 1
11#define CONSTEXPR_LOOKUP_TABLE_FOR_FP8 0
12#define CONSTEXPR_LOOKUP_TABLE_FOR_BF8 0
13
14namespace ck_tile {
15namespace element_wise {
16
// Generalized constexpr lookup table generator

/// @brief Implementation detail of make_lookup_table: expands the index pack.
///
/// @tparam T   Element type of the resulting table.
/// @tparam N   Number of entries in the resulting std::array.
/// @tparam F   Callable type; it is invoked as func(i) with each index of Is.
/// @tparam Is  Compile-time indices, one per generated entry.
/// @param func Generator evaluated once per index.
/// @return std::array<T, N> list-initialised from {func(Is)...}.
template <typename T, std::size_t N, typename F, std::size_t... Is>
constexpr std::array<T, N> make_lookup_table_impl(F&& func, std::index_sequence<Is...>)
{
    return {func(Is)...};
}

/// @brief Builds an N-entry table whose i-th element is func(i), for i in [0, N).
///
/// @tparam T   Element type of the table.
/// @tparam N   Number of entries.
/// @tparam F   Callable type accepting an index.
/// @param func Generator; forwarded to make_lookup_table_impl.
/// @return The populated std::array<T, N>; usable in constant expressions when func is.
template <typename T, std::size_t N, typename F>
constexpr std::array<T, N> make_lookup_table(F&& func)
{
    return make_lookup_table_impl<T, N>(std::forward<F>(func), std::make_index_sequence<N>{});
}
29
48{
49 const int LO = 0x000f000f;
50 const int HI = 0x00f000f0;
51 const int EX = 0x64006400;
52
53 int lo;
54 int hi;
55 // Extract the two int4 at low bit and create two fp16 number.
56 asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(lo) : "v"(q), "v"(LO), "v"(EX));
57 // Extract the two int4 values in the high bits and create two fp16 numbers.
58 asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(hi) : "v"(q), "v"(HI), "v"(EX));
59
60 const int SUB = 0xE408E408; // half2 {-1032, -1032}
61 const int MUL = 0x2c002c00; // half2 {1 / 16, 1 / 16}
62 const int ADD = 0xd480d480; // half2 {-72, -72}
63
64 fp16x4_t res;
65
66 // for two fp16 from lowbit, subtract 1032 to get correct fp16 value
67 asm volatile("v_pk_add_f16 %0, %1, %2"
68 : "=v"(res.lo)
69 : "v"(bit_cast<fp16x2_t>(lo)), "v"(bit_cast<fp16x2_t>(SUB)));
70
71 // for two fp16 from highbit, divide 16 and subtract 72 to get correct fp16 value
72 asm volatile(
73 "v_pk_fma_f16 %0, %1, %2, %3"
74 : "=v"(res.hi)
75 : "v"(bit_cast<fp16x2_t>(hi)), "v"(bit_cast<fp16x2_t>(MUL)), "v"(bit_cast<fp16x2_t>(ADD)));
76
77 return res;
78}
79
93{
94 const int LO = 0x000f000f;
95 const int HI = 0x00f000f0;
96 const int EX = 0x64006400;
97
98 int lo;
99 int hi;
100 // Extract the two int4 at low bit and create two fp16 number.
101 asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(lo) : "v"(q), "v"(LO), "v"(EX));
102 // Extract the two int4 values in the high bits and create two fp16 numbers.
103 asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(hi) : "v"(q), "v"(HI), "v"(EX));
104
105 const int SUB = 0xE408E408; // half2 {-1032, -1032}
106 const int MUL = 0x2c002c00; // half2 {1 / 16, 1 / 16}
107 const int ADD = 0xd480d480; // half2 {-72, -72}
108
109 fp16x4_t res;
110
111 asm volatile("v_pk_add_f16 %0, %1, %2"
112 : "=v"(res.lo)
113 : "v"(bit_cast<fp16x2_t>(lo)), "v"(bit_cast<fp16x2_t>(SUB)));
114
115 asm volatile(
116 "v_pk_fma_f16 %0, %1, %2, %3"
117 : "=v"(res.hi)
118 : "v"(bit_cast<fp16x2_t>(hi)), "v"(bit_cast<fp16x2_t>(MUL)), "v"(bit_cast<fp16x2_t>(ADD)));
119
120 asm volatile("v_pk_mul_f16 %0, %1, %2" : "=v"(res.lo) : "v"(res.lo), "v"(scale));
121
122 asm volatile("v_pk_mul_f16 %0, %1, %2" : "=v"(res.hi) : "v"(res.hi), "v"(scale));
123
124 return res;
125}
126
140{
141#if !CONSTEXPR_LOOKUP_TABLE_FOR_BF16
142 // This approach fails validation in GEMM tests.
143 uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12);
144
145 static constexpr uint32_t fp32_base = 0x4B000000;
146
147 float fp32_intermediates[4];
148
149 uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
150
151 fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650);
152 fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7651);
153 fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7652);
154 fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653);
155
156 fp32_intermediates[0] -= 8388616.f;
157 fp32_intermediates[1] -= 8388616.f;
158 fp32_intermediates[2] -= 8388616.f;
159 fp32_intermediates[3] -= 8388616.f;
160
161 bf16x4_t res;
162 res.lo = bit_cast<bf16x2_t>(
163 __byte_perm(fp32_intermediates_casted[1], fp32_intermediates_casted[0], 0x7632));
164 res.hi = bit_cast<bf16x2_t>(
165 __byte_perm(fp32_intermediates_casted[3], fp32_intermediates_casted[2], 0x7632));
166
167 return res;
168#else
169 // Lookup table for bf16_t values corresponding to int4 values -8 to 7
170 constexpr auto bf16_lookup_table = make_lookup_table<bf16_t, 16>(
171 [](int i) { return bit_cast<bf16_t>(float_to_bf16_rtn_raw(i - 8)); });
172
173 return bf16x4_t{bf16_lookup_table[(q >> 0) & 0xf],
174 bf16_lookup_table[(q >> 16) & 0xf],
175 bf16_lookup_table[(q >> 4) & 0xf],
176 bf16_lookup_table[(q >> 20) & 0xf]};
177#endif
178}
179
180#if !CONSTEXPR_LOOKUP_TABLE_FOR_FP8
194{
195#if CK_TILE_USE_OCP_FP8
196 // register values [3, 2, 1, 0]
197 static constexpr uint32_t reg0 = 0xcaccced0;
198 // register values [7, 6, 5, 4]
199 static constexpr uint32_t reg1 = 0xb8c0c4c8;
200 // register values [-1, -2, -3, -4]
201 static constexpr uint32_t reg2 = 0x44403800;
202 // register values [-5, -6, -7, -8]
203 static constexpr uint32_t reg3 = 0x4e4c4a48;
204#else
205 // register values [3, 2, 1, 0]
206 static constexpr uint32_t reg0 = 0xd2d4d6d8;
207 // register values [7, 6, 5, 4]
208 static constexpr uint32_t reg1 = 0xc0c8ccd0;
209 // register values [-1, -2, -3, -4]
210 static constexpr uint32_t reg2 = 0x4C484000;
211 // register values [-5, -6, -7, -8]
212 static constexpr uint32_t reg3 = 0x56545250;
213#endif
214
215 uint32_t tmp_pos, tmp_neg, tmp_res_even, tmp_res_odd, final_sel;
216
217 uint32_t dict_sel = a & 0x07070707;
218 uint32_t sign = a >> 1;
219 asm volatile("v_and_or_b32 %0, %1, %2, %3"
220 : "=v"(final_sel)
221 : "v"(sign), "v"(0x04040404), "v"(0x03020100));
222
223 tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
224 tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
225 tmp_res_even = __builtin_amdgcn_perm(tmp_neg, tmp_pos, final_sel);
226
227 a >>= 4;
228 dict_sel = a & 0x07070707;
229 sign = a >> 1;
230 asm volatile("v_and_or_b32 %0, %1, %2, %3"
231 : "=v"(final_sel)
232 : "v"(sign), "v"(0x04040404), "v"(0x03020100));
233
234 tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
235 tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
236 tmp_res_odd = __builtin_amdgcn_perm(tmp_neg, tmp_pos, final_sel);
237 auto tmp_res_low = __builtin_amdgcn_perm(tmp_res_odd, tmp_res_even, 0x06040200);
238 auto tmp_res_high = __builtin_amdgcn_perm(tmp_res_odd, tmp_res_even, 0x07050301);
239
240 return bit_cast<fp8x8_t>((static_cast<uint64_t>(tmp_res_high) << 32) | tmp_res_low);
241}
242#else
243CK_TILE_DEVICE fp8x4_t i4_to_fp8x4(int q)
244{
245 // The approach below can be used once this compiler issue is resolved:
246 // "constexpr bit cast involving type 'unsigned _BitInt(8)' is not yet supported"
247 // Lookup table for fp8_t values corresponding to int4 values -8 to 7
248 constexpr auto fp8_lookup_table = make_lookup_table<fp8_t, 16>(
249 [](int i) { return impl::cast_to_f8<float, fp8_t, true, false>(i - 8, 0); });
250
251 return fp8x4_t{fp8_lookup_table[(q >> 0) & 0xf],
252 fp8_lookup_table[(q >> 16) & 0xf],
253 fp8_lookup_table[(q >> 4) & 0xf],
254 fp8_lookup_table[(q >> 20) & 0xf]};
255}
256#endif
257
259{
260 float res;
261 asm volatile("v_cvt_f32_fp8 %0, %1, src0_sel:BYTE_0" : "=v"(res) : "v"(src));
262 return res;
263}
264
266{
267 float res;
268 asm volatile("v_cvt_f32_bf8 %0, %1, src0_sel:BYTE_0" : "=v"(res) : "v"(src));
269 return res;
270}
271
272#if !CONSTEXPR_LOOKUP_TABLE_FOR_BF8
286{
287#if CK_TILE_USE_OCP_FP8
288 // register values [3, 2, 1, 0]
289 static constexpr uint32_t reg0 = 0Xc5c6c7c8;
290 // register values [7, 6, 5, 4]
291 static constexpr uint32_t reg1 = 0Xbcc0c2c4;
292 // register values [11, 10, 9, 8]
293 static constexpr uint32_t reg2 = 0X42403c00;
294 // register values [15, 14, 13, 12]
295 static constexpr uint32_t reg3 = 0X47464544;
296#else
297 // register values [3, 2, 1, 0]
298 static constexpr uint32_t reg0 = 0Xc9cacbcc;
299 // register values [7, 6, 5, 4]
300 static constexpr uint32_t reg1 = 0Xc0c4c6c8;
301 // register values [11, 10, 9, 8]
302 static constexpr uint32_t reg2 = 0X46444000;
303 // register values [15, 14, 13, 12]
304 static constexpr uint32_t reg3 = 0X4b4a4948;
305#endif
306
307 uint32_t tmp_pos, tmp_neg, tmp_res_even, tmp_res_odd, final_sel;
308
309 uint32_t dict_sel = a & 0x07070707;
310 uint32_t sign = a >> 1;
311 asm volatile("v_and_or_b32 %0, %1, %2, %3"
312 : "=v"(final_sel)
313 : "v"(sign), "v"(0x04040404), "v"(0x03020100));
314
315 tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
316 tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
317 tmp_res_even = __builtin_amdgcn_perm(tmp_neg, tmp_pos, final_sel);
318
319 a >>= 4;
320 dict_sel = a & 0x07070707;
321 sign = a >> 1;
322 asm volatile("v_and_or_b32 %0, %1, %2, %3"
323 : "=v"(final_sel)
324 : "v"(sign), "v"(0x04040404), "v"(0x03020100));
325
326 tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
327 tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
328 tmp_res_odd = __builtin_amdgcn_perm(tmp_neg, tmp_pos, final_sel);
329 auto tmp_res_low = __builtin_amdgcn_perm(tmp_res_odd, tmp_res_even, 0x06040200);
330 auto tmp_res_high = __builtin_amdgcn_perm(tmp_res_odd, tmp_res_even, 0x07050301);
331
332 return bit_cast<bf8x8_t>((static_cast<uint64_t>(tmp_res_high) << 32) | tmp_res_low);
333}
334#else
335CK_TILE_DEVICE bf8x4_t i4_to_bf8x4(int q)
336{
337 // The approach below can be used once this compiler issue is resolved:
338 // "constexpr bit cast involving type 'unsigned _BitInt(8)' is not yet supported"
339 // Lookup table for bf8_t values corresponding to int4 values -8 to 7
340 constexpr auto bf8_lookup_table = make_lookup_table<bf8_t, 16>(
341 [](int i) { return impl::cast_to_f8<float, bf8_t, true, false>(i - 8, 0); });
342
343 return bf8x4_t{bf8_lookup_table[(q >> 0) & 0xf],
344 bf8_lookup_table[(q >> 16) & 0xf],
345 bf8_lookup_table[(q >> 4) & 0xf],
346 bf8_lookup_table[(q >> 20) & 0xf]};
347}
348#endif
349
351{
352 static constexpr const char* name = "PassThroughPack8";
353
354 template <typename Y, typename X>
355 CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
356
357 CK_TILE_HOST_DEVICE constexpr void operator()(fp16x8_t& y, const pk_int4x4_t& x) const
358 {
359 y.lo = i4_to_half4(bit_cast<int>(x));
360 y.hi = i4_to_half4(bit_cast<int>(x) >> 8);
361 }
362
363 CK_TILE_HOST_DEVICE constexpr void operator()(bf16x8_t& y, const pk_int4x4_t& x) const
364 {
365 y.lo = i4_to_bhalf4(bit_cast<int>(x));
366 y.hi = i4_to_bhalf4(bit_cast<int>(x) >> 8);
367 }
368
369 CK_TILE_HOST_DEVICE constexpr void operator()(fp8x8_t& y, const pk_int4x4_t& x) const
370 {
371#if !CONSTEXPR_LOOKUP_TABLE_FOR_FP8
373#else
374 y.lo = i4_to_fp8x4(bit_cast<int>(x));
375 y.hi = i4_to_fp8x4(bit_cast<int>(x) >> 8);
376#endif
377 }
378
379 CK_TILE_HOST_DEVICE constexpr void operator()(bf8x8_t& y, const pk_int4x4_t& x) const
380 {
381#if !CONSTEXPR_LOOKUP_TABLE_FOR_BF8
383#else
384 y.lo = i4_to_bf8x4(bit_cast<int>(x));
385 y.hi = i4_to_bf8x4(bit_cast<int>(x) >> 8);
386#endif
387 }
388 constexpr const static bool is_pack8_invocable = true;
389};
390
392{
393 static constexpr const char* name = "DequantPack8";
394
395 template <typename Y, typename X, typename Z>
396 CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x, const Z& z) const;
397
398 CK_TILE_HOST_DEVICE constexpr void
399 operator()(fp16x8_t& y, const pk_int4x4_t& x, const fp16x2_t& z) const
400 {
401 y.lo = i4_to_half4_scale(bit_cast<int>(x), z);
402 y.hi = i4_to_half4_scale(bit_cast<int>(x) >> 8, z);
403 }
404
405 constexpr const static bool is_pack8_invocable = true;
406};
407
409{
410 static constexpr const char* name = "PassThroughPack2";
411
412 template <typename Y, typename X>
413 CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
414
415#if 0
416 CK_TILE_HOST_DEVICE constexpr void operator()(ck_tile::fp16x2_t& y, const ck_tile::f8x2_t& x) const
417 {
418 auto t = type_convert<float2_t>(x);
420 }
421#endif
422
423 CK_TILE_HOST_DEVICE constexpr void operator()(fp16x2_t& y, const pk_int4_t& x) const
424 {
425 uint8_t x_u8 = bit_cast<uint8_t>(x);
426 uint8_t x_l = (x_u8 & 0x0f) >> 0;
427 uint8_t x_h = (x_u8 & 0xf0) >> 4;
428
429 y.lo = type_convert<half_t>(x_l);
430 y.hi = type_convert<half_t>(x_h);
431 }
432
433 constexpr const static bool is_pack2_invocable = true;
434};
435
437{
438 static constexpr const char* name = "PassThrough";
439
440 template <class T>
441 using raw_t = std::remove_cv_t<std::remove_reference_t<T>>;
442
443 template <class Y, class X>
444 CK_TILE_HOST_DEVICE void operator()(Y&& y, const X& x) const
445 {
446 /* Only do the assignment when
447 - y is an *l-value* and
448 - y is *not* const */
449 if constexpr(std::is_lvalue_reference_v<Y&&> && !std::is_const_v<raw_t<Y>>)
450 {
452 }
453 /* otherwise (r-value or const) → do nothing */
454 }
455
456 template <typename E, typename C, typename... Ds>
457 CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&...) const -> void
458 {
459 // Just assign e with c
460 if constexpr(std::is_same_v<E, C>)
461 {
462 e = c;
463 }
464 else
465 {
467 }
468 }
469};
470
472{
473 static constexpr const char* name = "AddScale";
474
475 template <typename E, typename... As>
476 CK_TILE_HOST_DEVICE constexpr void operator()(E& a, const As&... as) const
477 {
478 // Start the accumulator at zero (there is no base value in this functor)
479 float result = ck_tile::type_convert<float>(0.0f);
480
481 // Add by each D parameter using fold expression
482 ((result += ck_tile::type_convert<float>(as)), ...);
483
484 a = ck_tile::type_convert<E>(scale * result);
485 }
486
487 float scale = 1.0;
488};
489
491{
492 static constexpr const char* name = "MultiDMultiply";
493
494 template <typename E, typename C, typename... Ds>
495 CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... ds) const -> void
496 {
497 // Start with the base value c
498 float result = ck_tile::type_convert<float>(c);
499
500 // Multiply by each D parameter using fold expression
501 ((result *= ck_tile::type_convert<float>(ds)), ...);
502
503 e = ck_tile::type_convert<E>(result);
504 }
505};
506
508{
509 static constexpr const char* name = "MultiDAdd";
510
511 template <typename E, typename C, typename... Ds>
512 CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... ds) const -> void
513 {
514 // Start with the base value c
515 float result = ck_tile::type_convert<float>(c);
516
517 // Add by each D parameter using fold expression
518 ((result += ck_tile::type_convert<float>(ds)), ...);
519
520 e = ck_tile::type_convert<E>(result);
521 }
522};
523
525{
526 static constexpr const char* name = "UnaryConvert";
527
528 template <typename Y, typename X>
529 CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
530 {
531 y = type_convert<Y>(x);
532 }
533};
534
535#if 0
536struct ConvertBF16RTN
537{
538 // convert to bf16 using round to nearest (rtn)
539 template <typename Y, typename X>
540 CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
541 {
542 // check Y datatype
543 static_assert(std::is_same_v<Y, ck_tile::bf16_t>, "Data type is not supported by this operation!");
544
545 // check X datatype
546 static_assert(std::is_same_v<X, float> || std::is_same_v<X, ck_tile::fp16_t>,
547 "Data type is not supported by this operation!");
548
549 y = bf16_convert_rtn<Y>(x);
550 }
551};
552
553struct ConvertF8SR
554{
555 // convert to fp8 using stochastic rounding (SR)
556 template <typename Y, typename X>
557 CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
558 {
559 // check Y datatype
560 static_assert(std::is_same_v<Y, ck_tile::fp8_t> || std::is_same_v<Y, ck_tile::bf8_t>,
561 "Data type is not supported by this operation!");
562
563 // check X datatype
564 static_assert(std::is_same_v<X, float> || std::is_same_v<X, ck_tile::fp16_t>,
565 "Data type is not supported by this operation!");
566
567 y = f8_convert_sr<Y>(x);
568 }
569};
570
571struct ConvertF8RNE
572{
573 // convert to fp8 using rounding to nearest even
574 template <typename Y, typename X>
575 CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
576 {
577 // check Y datatype
578 static_assert(std::is_same_v<Y, ck_tile::fp8_t> || std::is_same_v<Y, ck_tile::bf8_t>,
579 "Data type is not supported by this operation!");
580
581 // check X datatype
582 static_assert(std::is_same_v<X, float> || std::is_same_v<X, ck_tile::fp16_t>,
583 "Data type is not supported by this operation!");
584
585 y = f8_convert_rne<Y>(x);
586 }
587};
588#endif
589
590struct Scale
591{
592 static constexpr const char* name = "Scale";
593
594 CK_TILE_HOST_DEVICE Scale(float scale = 1.f) : scale_(scale) {}
595
596 template <typename Y, typename X>
597 CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
598 {
600 }
601
602 template <>
608
609 template <>
612 {
613 const float x_tmp = ck_tile::type_convert<float>(x);
614 const float y_tmp = scale_ * x_tmp;
616 };
617
618 template <>
619 CK_TILE_HOST_DEVICE void operator()<float, float>(float& y, const float& x) const
620 {
621 y = scale_ * x;
622 };
623
624 template <>
625 CK_TILE_HOST_DEVICE void operator()<double, double>(double& y, const double& x) const
626 {
627 y = scale_ * x;
628 };
629
630 template <>
631 CK_TILE_HOST_DEVICE void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
632 {
634 };
635
636 float scale_;
637};
638
640{
641 static constexpr const char* name = "ScaleAndResetNaNToMinusInfinity";
642
644
645 template <typename Y, typename X>
646 CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
647
648 template <>
649 CK_TILE_HOST_DEVICE void operator()<float, float>(float& y, const float& x) const
650 {
652 };
653
654 float scale_;
655};
656
658{
659 static constexpr const char* name = "UnaryDivide";
660
661 CK_TILE_HOST_DEVICE UnaryDivide(const int32_t divider = 1) : divider_(divider) {}
662
663 template <typename T>
664 CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
665 {
666 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
667 std::is_same_v<T, int32_t>,
668 "Data type is not supported by this operation!");
669
670 y = x / type_convert<T>(divider_);
671 };
672
674};
675
677{
678 static constexpr const char* name = "UnarySquare";
679
680 template <typename Y, typename X>
681 CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
682 {
683 static_assert(std::is_same_v<X, float> || std::is_same_v<X, ck_tile::fp16_t> ||
684 std::is_same_v<X, double> || std::is_same_v<X, int32_t> ||
685 std::is_same_v<X, int8_t>
686#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
687 || std::is_same_v<X, int4_t>
688#endif
689 ,
690 "Data type is not supported by this operation!");
691 y = x * x;
692 };
693};
694
696{
697 static constexpr const char* name = "UnaryAbs";
698
699 template <typename T>
700 CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
701 {
702 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
703 std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
704 std::is_same_v<T, int8_t>,
705 "Data type is not supported by this operation!");
706
707 y = ck_tile::abs(x);
708 };
709};
710
712{
713 static constexpr const char* name = "UnarySqrt";
714
715 template <typename T>
716 CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
717 {
718 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>,
719 "Data type is not supported by this operation!");
720
721 y = ck_tile::sqrt(x);
722 };
723};
724
725struct Relu
726{
727 static constexpr const char* name = "Relu";
728
729 template <typename T>
730 CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
731 {
732 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
733 std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
734 std::is_same_v<T, int8_t>,
735 "Data type is not supported by this operation!");
736 y = x > 0 ? x : 0;
737 }
738
739 template <>
741 {
742 float x_f32 = ck_tile::type_convert<float>(x);
743 float y_f32 = x_f32 > 0 ? x_f32 : 0;
745 }
746};
747
748// Fast GeLU
749// https://paperswithcode.com/method/gelu
750// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
751// host code use higher accuracy "exp" and "div"
752// gpu code use lower accuracy "_ocml_exp_f32" and "rcp" function
754{
755 static constexpr const char* name = "FastGelu";
756
757 template <typename Y, typename X>
758 CK_TILE_HOST void operator()(Y& y, const X& x) const;
759
760 template <typename Y, typename X>
761 CK_TILE_DEVICE void operator()(Y& y, const X& x) const;
762
763 template <>
764 CK_TILE_HOST void operator()<float, float>(float& y, const float& x) const
765 {
766 // const float u = -2.f * x * (0.035677f * x * x + 0.797885f);
767 const float c1 = -2.0 * 0.035677f;
768 const float c2 = -2.0 * 0.797885f;
769 const float u = x * (c1 * x * x + c2);
770 const float emu = exp(u);
771 y = x / (1.f + emu);
772 }
773
774 // device code, use lower precision "__ocml_exp_f32" and "rcp"
775 template <>
776 CK_TILE_DEVICE void operator()<float, float>(float& y, const float& x) const
777 {
778 // const float u = -2.f * x * (0.035677f * x * x + 0.797885f);
779 const float c1 = -2.0 * 0.035677f;
780 const float c2 = -2.0 * 0.797885f;
781 const float u = x * (c1 * x * x + c2);
782 const float emu = __ocml_exp_f32(u);
783
784 y = x * ck_tile::rcp(1.f + emu);
785 }
786
787 template <>
789 const ck_tile::fp16_t& x) const
790 {
791 float y_f;
792
793 this->operator()<float, float>(y_f, type_convert<float>(x));
794
796 }
797
798 template <>
800 const ck_tile::fp16_t& x) const
801 {
802 float y_f;
803
804 this->operator()<float, float>(y_f, type_convert<float>(x));
805
807 }
808
809 template <>
810 CK_TILE_HOST void operator()<ck_tile::fp16_t, float>(ck_tile::fp16_t& y, const float& x) const
811 {
812 float y_f;
813
814 this->operator()<float, float>(y_f, x);
815
817 }
818
819 template <>
820 CK_TILE_DEVICE void operator()<ck_tile::fp16_t, float>(ck_tile::fp16_t& y, const float& x) const
821 {
822 float y_f;
823
824 this->operator()<float, float>(y_f, x);
825
827 }
828
829 template <>
830 CK_TILE_HOST void operator()<ck_tile::bf16_t, float>(ck_tile::bf16_t& y, const float& x) const
831 {
832 float y_f;
833
834 this->operator()<float, float>(y_f, x);
835
837 }
838
839 template <>
840 CK_TILE_DEVICE void operator()<ck_tile::bf16_t, float>(ck_tile::bf16_t& y, const float& x) const
841 {
842 float y_f;
843
844 this->operator()<float, float>(y_f, x);
845
847 }
848
849 template <>
851 const ck_tile::bf16_t& x) const
852 {
853 float y_f;
854
855 this->operator()<float, float>(y_f, type_convert<float>(x));
856
858 }
859
860 template <>
862 const ck_tile::bf16_t& x) const
863 {
864 float y_f;
865
866 this->operator()<float, float>(y_f, type_convert<float>(x));
867
869 }
870};
871
873{
874 static constexpr const char* name = "FastGeluAsm";
875
876 template <typename Y, typename X>
877 CK_TILE_HOST void operator()(Y& y, const X& x) const;
878
879 template <typename Y, typename X>
880 CK_TILE_DEVICE void operator()(Y& y, const X& x) const;
881
882 template <>
883 CK_TILE_HOST void operator()<float, float>(float& y, const float& x) const
884 {
885 // const float u = -2.f * x * (0.035677f * x * x + 0.797885f);
886 const float c1 = -2.0 * 0.035677f;
887 const float c2 = -2.0 * 0.797885f;
888 const float u = x * (c1 * x * x + c2);
889 const float emu = exp(u);
890 y = x / (1.f + emu);
891 }
892
893 // device code, use lower precision "__ocml_exp_f32" and "rcp"
894 template <>
895 CK_TILE_DEVICE void operator()<float, float>(float& y, const float& x) const
896 {
897 const uint32_t c1 = 0xbd92220c; // -2.0 * 0.035677f;
898 const float c2 = -2.0 * 0.797885f;
899 const uint32_t log2e_ = 0x3fb8aa3b; // log2e_v<float>;
900 float tmp;
901
902 asm volatile("v_mul_f32 %[v_tmp], %[v_x], %[v_x] ; x*x\n"
903 "v_fma_f32 %[v_tmp], %[v_tmp], %[s_c1], %[v_c2] ; c1*x*x+c2\n"
904 "v_mul_f32 %[v_tmp], %[v_tmp], %[v_x] ; x*(c1*x*x+c2)\n"
905 "v_mul_f32 %[v_tmp], %[v_tmp], %[s_log2e] ; log2e*x*(c1*x*x+c2)\n"
906 "v_exp_f32 %[v_tmp], %[v_tmp] ; emu = exp2(log2e*x*(c1*x*x+c2))\n"
907 "s_nop 0 ; hazard for exp\n"
908 "v_add_f32 %[v_tmp], %[v_tmp], 1.0 ; emu+1.0f\n"
909 "v_rcp_f32 %[v_tmp], %[v_tmp] ; 1/(emu+1.0f)\n"
910 "s_nop 0 ; hazard for rcp \n"
911 "v_mul_f32 %[v_y], %[v_tmp], %[v_x] ; x * 1/(emu+1f)\n"
912 : [v_y] "=v"(y), [v_tmp] "+v"(tmp)
913 : [v_x] "v"(x), [s_c1] "s"(c1), [v_c2] "v"(c2), [s_log2e] "s"(log2e_)
914 :);
915 }
916
917 template <>
918 CK_TILE_HOST void operator()<fp32x2_t, fp32x2_t>(fp32x2_t& y, const fp32x2_t& x) const
919 {
920 const float c1 = -2.0 * 0.035677f;
921 const float c2 = -2.0 * 0.797885f;
922 const float u0 = x.x * (c1 * x.x * x.x + c2);
923 const float emu0 = exp(u0);
924 y.x = x.x / (1.f + emu0);
925 const float u1 = x.y * (c1 * x.y * x.y + c2);
926 const float emu1 = exp(u1);
927 y.y = x.y / (1.f + emu1);
928 }
929
930 // this is the packed version, used to remove the data hazard for trans
931 template <>
932 CK_TILE_DEVICE void operator()<fp32x2_t, fp32x2_t>(fp32x2_t& y, const fp32x2_t& x) const
933 {
934 const uint32_t c1 = 0xbd92220c; // -2.0 * 0.035677f;
935 float c2 = -2.0 * 0.797885f;
936 const uint32_t log2e_ = 0x3fb8aa3b; // log2e_v<float>;
937 float tmp0, tmp1;
938 float y0 = x.x, y1 = x.y;
939
940 asm volatile(
941 "v_mul_f32 %[v_tmp0], %[v_y0], %[v_y0] ; x*x\n"
942 "v_mul_f32 %[v_tmp1], %[v_y1], %[v_y1] ; x*x\n"
943 "v_fma_f32 %[v_tmp0], %[v_tmp0], %[s_c1], %[v_c2] ; c1*x*x+c2\n"
944 "v_fma_f32 %[v_tmp1], %[v_tmp1], %[s_c1], %[v_c2] ; c1*x*x+c2\n"
945 "v_mul_f32 %[v_tmp0], %[v_tmp0], %[v_y0] ; x*(c1*x*x+c2)\n"
946 "v_mul_f32 %[v_tmp1], %[v_tmp1], %[v_y1] ; x*(c1*x*x+c2)\n"
947 "v_mul_f32 %[v_tmp0], %[v_tmp0], %[s_log2e] ; log2e*x*(c1*x*x+c2)\n"
948 "v_mul_f32 %[v_tmp1], %[v_tmp1], %[s_log2e] ; log2e*x*(c1*x*x+c2)\n"
949 "v_exp_f32 %[v_tmp0], %[v_tmp0] ; emu = exp2(log2e*x*(c1*x*x+c2))\n"
950 "v_exp_f32 %[v_tmp1], %[v_tmp1] ; emu = exp2(log2e*x*(c1*x*x+c2))\n"
951 "v_add_f32 %[v_tmp0], %[v_tmp0], 1.0 ; emu+1.0f\n"
952 "v_add_f32 %[v_tmp1], %[v_tmp1], 1.0 ; emu+1.0f\n"
953 "v_rcp_f32 %[v_tmp0], %[v_tmp0] ; 1/(emu+1.0f)\n"
954 "v_rcp_f32 %[v_tmp1], %[v_tmp1] ; 1/(emu+1.0f)\n"
955 "v_mul_f32 %[v_y0], %[v_tmp0], %[v_y0] ; x * 1/(emu+1f)\n"
956 "v_mul_f32 %[v_y1], %[v_tmp1], %[v_y1] ; x * 1/(emu+1f)\n"
957 : [v_y0] "+v"(y0),
958 [v_y1] "+v"(y1),
959 [v_c2] "+v"(c2),
960 // NOTE: c2/y0/y1 could otherwise end up sharing the same register, since they are all
961 // local temporaries. We explicitly hint to the compiler that they may be read and written
962 // so that it allocates different registers; the side effect is that c2=** may be issued
963 // for every such inline asm block.
964 [v_tmp0] "+v"(tmp0),
965 [v_tmp1] "+v"(tmp1)
966 : [s_c1] "s"(c1), [s_log2e] "s"(log2e_)
967 :);
968 y.x = y0;
969 y.y = y1;
970 }
971};
972
973// https://paperswithcode.com/method/gelu
974// y = 0.5*x*(1+erf(x/sqrt(2)))
975struct Gelu
976{
977 static constexpr const char* name = "Gelu";
978
979 template <typename Y, typename X>
980 CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
981
982 template <>
983 CK_TILE_HOST_DEVICE void operator()<float, float>(float& y, const float& x) const
984 {
985 y = 0.5f * x * (1.f + erf(float(0.70710678118f * x)));
986 }
987
988 template <>
991 {
992 y = ck_tile::fp16_t(0.5) * x *
993 (ck_tile::fp16_t(1) + ck_tile::fp16_t(erf(float(0.70710678118f * x))));
994 }
995};
996
998{
999 static constexpr const char* name = "Sigmoid";
1000
1001 template <typename T>
1002 CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
1003 {
1004 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
1005 std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
1006 std::is_same_v<T, int32_t>,
1007 "Data type is not supported by this operation!");
1008 constexpr T one = type_convert<T>(1);
1009 y = one / (one + ck_tile::exp(-x));
1010 };
1011};
1012
1013struct Silu
1014{
1015 static constexpr const char* name = "Silu";
1016
1017 template <typename T>
1018 CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
1019 {
1020 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
1021 std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
1022 std::is_same_v<T, int32_t>,
1023 "Data type is not supported by this operation!");
1024 constexpr T one = type_convert<T>(1);
1025 y = x * (one / (one + ck_tile::exp(-x)));
1026 };
1027
1028 template <>
1029 CK_TILE_HOST_DEVICE void operator()<fp32x2_t>(fp32x2_t& y, const fp32x2_t& x) const
1030 {
1031 constexpr auto one = type_convert<float>(1);
1032 y[0] = x[0] * __builtin_amdgcn_rcpf(one + ck_tile::exp(-x[0]));
1033 y[1] = x[1] * __builtin_amdgcn_rcpf(one + ck_tile::exp(-x[1]));
1034 };
1035};
1036
1037#if 0
1038// Silu, the formular is not so good to do inline asm (dependency)
1039// we put the code here purposely if in the future ppl want to try
1040struct SiluAsm
1041{
1042 template <typename T>
1043 CK_TILE_HOST void operator()(T& y, T& x) const
1044 {
1045 static_assert(std::is_same_v<T, float>, "Data type is not supported by this operation!");
1046 constexpr T one = type_convert<T>(1);
1047 y = x * (one / (one + ck_tile::exp(-x)));
1048 };
1049
1050 template <typename T>
1051 CK_TILE_DEVICE void operator()(T& y, T& x) const
1052 {
1053 static_assert(std::is_same_v<T, float>, "Data type is not supported by this operation!");
1054
1055 const uint32_t log2e_neg_ = 0x3fb8aa3b | 0x80000000; // log2e_v<float> * -1;
1056
1057 // NOTE: x/y can't be same register before inline asm
1058 // "+v" as y, "v" as x is not enought, x/y stil maybe put to same register
1059 T tmp = x;
1060 asm volatile("v_mul_f32 %[v_y], %[s_log2e], %[v_x]\n"
1061 "v_exp_f32 %[v_y], %[v_y]\n"
1062 "s_nop 0 ; hazard for exp\n"
1063 "v_add_f32 %[v_y], %[v_y], 1.0\n"
1064 "v_rcp_f32 %[v_y], %[v_y]\n"
1065 "s_nop 0 ; hazard for rcp\n"
1066 "v_mul_f32 %[v_y], %[v_x], %[v_y]\n"
1067 : [v_y] "+v"(y), [v_x] "+v"(tmp)
1068 : [s_log2e] "s"(log2e_neg_)
1069 :);
1070 };
1071
1072 template <>
1073 CK_TILE_HOST void operator()<fp32x2_t>(fp32x2_t& y, fp32x2_t& x) const
1074 {
1075 constexpr auto one = type_convert<float>(1);
1076 y[0] = x[0] * (one / (one + ck_tile::exp(-x[0])));
1077 y[1] = x[1] * (one / (one + ck_tile::exp(-x[1])));
1078 };
1079
1080 template <>
1081 CK_TILE_DEVICE void operator()<fp32x2_t>(fp32x2_t& y, fp32x2_t& x) const
1082 {
1083 const uint32_t log2e_neg_ = 0x3fb8aa3b | 0x80000000; // log2e_v<float> * -1;
1084
1085 // NOTE: x/y can't be same register before inline asm
1086 // float tmp0 = x[0], tmp1 = x[1];
1087 asm volatile("v_mul_f32 %[v_y0], %[s_log2e], %[v_x0]\n"
1088 "v_mul_f32 %[v_y1], %[s_log2e], %[v_x1]\n"
1089 "v_exp_f32 %[v_y0], %[v_y0]\n"
1090 "v_exp_f32 %[v_y1], %[v_y1]\n"
1091 "v_add_f32 %[v_y0], %[v_y0], 1.0\n"
1092 "v_add_f32 %[v_y1], %[v_y1], 1.0\n"
1093 "v_rcp_f32 %[v_y0], %[v_y0]\n"
1094 "v_rcp_f32 %[v_y1], %[v_y1]\n"
1095 "v_mul_f32 %[v_y0], %[v_x0], %[v_y0]\n"
1096 "v_mul_f32 %[v_y1], %[v_x1], %[v_y1]\n"
1097 : [v_y0] "+v"(y[0]), [v_y1] "+v"(y[1]), [v_x0] "+v"(x[0]), [v_x1] "+v"(x[1])
1098 : [s_log2e] "s"(log2e_neg_)
1099 :);
1100 };
1101};
1102#endif
1103
1104struct TanH
1105{
1106 static constexpr const char* name = "TanH";
1107
1108 template <typename T>
1109 CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
1110 {
1111 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
1112 std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
1113 std::is_same_v<T, int32_t>,
1114 "Data type is not supported by this operation!");
1115
1116 y = ck_tile::tanh(x);
1117 };
1118};
1119
1120struct ACos
1121{
1122 static constexpr const char* name = "ACos";
1123
1124 template <typename T>
1125 CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
1126 {
1127 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
1128 std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
1129 std::is_same_v<T, int32_t>,
1130 "Data type is not supported by this operation!");
1131
1132 y = ck_tile::acos(x);
1133 };
1134};
1135
1136struct Neg
1137{
1138 static constexpr const char* name = "Neg";
1139
1140 template <typename T>
1141 CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
1142 {
1143 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
1144 std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
1145 std::is_same_v<T, int32_t>,
1146 "Data type is not supported by this operation!");
1147
1148 y = ck_tile::neg(x);
1149 };
1150};
1151
1152struct ATan
1153{
1154 static constexpr const char* name = "ATan";
1155
1156 template <typename T>
1157 CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
1158 {
1159 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
1160 std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
1161 std::is_same_v<T, int32_t>,
1162 "Data type is not supported by this operation!");
1163
1164 y = ck_tile::atan(x);
1165 };
1166};
1167
1168struct Sin
1169{
1170 static constexpr const char* name = "Sin";
1171
1172 template <typename T>
1173 CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
1174 {
1175 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
1176 std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
1177 std::is_same_v<T, int32_t>,
1178 "Data type is not supported by this operation!");
1179
1180 y = ck_tile::sin(x);
1181 };
1182};
1183
1184struct ASinH
1185{
1186 static constexpr const char* name = "ASinH";
1187
1188 template <typename T>
1189 CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
1190 {
1191 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
1192 std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
1193 std::is_same_v<T, int32_t>,
1194 "Data type is not supported by this operation!");
1195
1196 y = ck_tile::asinh(x);
1197 };
1198};
1199
1200struct Cos
1201{
1202 static constexpr const char* name = "Cos";
1203
1204 template <typename T>
1205 CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
1206 {
1207 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
1208 std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
1209 std::is_same_v<T, int32_t>,
1210 "Data type is not supported by this operation!");
1211
1212 y = ck_tile::cos(x);
1213 };
1214};
1215
1216struct ACosH
1217{
1218 static constexpr const char* name = "ACosH";
1219
1220 template <typename T>
1221 CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
1222 {
1223 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
1224 std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
1225 std::is_same_v<T, int32_t>,
1226 "Data type is not supported by this operation!");
1227
1228 y = ck_tile::acosh(x);
1229 };
1230};
1231
1232struct Tan
1233{
1234 static constexpr const char* name = "Tan";
1235
1236 template <typename T>
1237 CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
1238 {
1239 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
1240 std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
1241 std::is_same_v<T, int32_t>,
1242 "Data type is not supported by this operation!");
1243
1244 y = ck_tile::tan(x);
1245 };
1246};
1247
1248struct ATanH
1249{
1250 static constexpr const char* name = "ATanH";
1251
1252 template <typename T>
1253 CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
1254 {
1255 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
1256 std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
1257 std::is_same_v<T, int32_t>,
1258 "Data type is not supported by this operation!");
1259
1260 y = ck_tile::atanh(x);
1261 };
1262};
1263
1264struct SinH
1265{
1266 static constexpr const char* name = "SinH";
1267
1268 template <typename T>
1269 CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
1270 {
1271 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
1272 std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
1273 std::is_same_v<T, int32_t>,
1274 "Data type is not supported by this operation!");
1275
1276 y = ck_tile::sinh(x);
1277 };
1278};
1279
1280struct Ceil
1281{
1282 static constexpr const char* name = "Ceil";
1283
1284 template <typename T>
1285 CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
1286 {
1287 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
1288 std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
1289 std::is_same_v<T, int32_t>,
1290 "Data type is not supported by this operation!");
1291
1292 y = ck_tile::ceil(x);
1293 };
1294};
1295
1296struct Exp
1297{
1298 static constexpr const char* name = "Exp";
1299
1300 template <typename T>
1301 CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
1302 {
1303 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
1304 std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
1305 std::is_same_v<T, int32_t>,
1306 "Data type is not supported by this operation!");
1307
1308 y = ck_tile::exp(x);
1309 };
1310};
1311
1312struct CosH
1313{
1314 static constexpr const char* name = "CosH";
1315
1316 template <typename T>
1317 CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
1318 {
1319 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
1320 std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
1321 std::is_same_v<T, int32_t>,
1322 "Data type is not supported by this operation!");
1323
1324 y = ck_tile::cosh(x);
1325 };
1326};
1327
1328struct Floor
1329{
1330 static constexpr const char* name = "Floor";
1331
1332 template <typename T>
1333 CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
1334 {
1335 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
1336 std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
1337 std::is_same_v<T, int32_t>,
1338 "Data type is not supported by this operation!");
1339
1340 y = ck_tile::floor(x);
1341 };
1342};
1343
1344struct Log
1345{
1346 static constexpr const char* name = "Log";
1347
1348 template <typename T>
1349 CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
1350 {
1351 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
1352 std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
1353 std::is_same_v<T, int32_t>,
1354 "Data type is not supported by this operation!");
1355
1356 y = ck_tile::log(x);
1357 };
1358};
1359
1360struct ASin
1361{
1362 static constexpr const char* name = "ASin";
1363
1364 template <typename T>
1365 CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
1366 {
1367 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
1368 std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
1369 std::is_same_v<T, int32_t>,
1370 "Data type is not supported by this operation!");
1371
1372 y = ck_tile::asin(x);
1373 };
1374};
1375
1376struct Rcp
1377{
1378 static constexpr const char* name = "Rcp";
1379
1380 template <typename T>
1381 CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
1382 {
1383 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
1384 std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
1385 std::is_same_v<T, int32_t>,
1386 "Data type is not supported by this operation!");
1387
1388 y = ck_tile::rcp(x);
1389 };
1390};
1391
1392struct Swish
1393{
1394 static constexpr const char* name = "Swish";
1395
1396 Swish(float beta = 1.0f) : beta_(beta) {}
1397
1398 template <typename Y, typename X>
1399 CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
1400 {
1401 static_assert(std::is_same_v<X, float> || std::is_same_v<X, double> ||
1402 std::is_same_v<X, ck_tile::fp16_t>,
1403 "Data type is not supported by this operation!");
1404
1405 static_assert(std::is_same_v<Y, float> || std::is_same_v<Y, double> ||
1406 std::is_same_v<Y, ck_tile::fp16_t>,
1407 "Data type is not supported by this operation!");
1408
1409 float bx = -beta_ * type_convert<float>(x);
1410 y = type_convert<Y>(x / (1.f + ck_tile::exp(bx)));
1411 };
1412
1413 const float beta_;
1414};
1415
1417{
1418 static constexpr const char* name = "SoftRelu";
1419
1420 SoftRelu(float alpha = 1.f) : alpha_(alpha){};
1421
1422 template <typename T>
1423 CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
1424 {
1425 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
1426 std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
1427 std::is_same_v<T, int8_t>,
1428 "Data type is not supported by this operation!");
1429 T casted_alpha = type_convert<T>(alpha_);
1430 constexpr T one = type_convert<T>(1);
1431 y = ck_tile::log(one + ck_tile::exp(x * casted_alpha)) / casted_alpha;
1432 }
1433 const float alpha_;
1434};
1435
1436struct Power
1437{
1438 static constexpr const char* name = "Power";
1439
1440 Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
1441 : alpha_(alpha), beta_(beta), gamma_(gamma){};
1442
1443 template <typename T>
1444 CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
1445 {
1446 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
1447 std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
1448 std::is_same_v<T, int8_t>,
1449 "Data type is not supported by this operation!");
1450 T casted_alpha = type_convert<T>(alpha_);
1451 T casted_beta = type_convert<T>(beta_);
1452 T casted_gamma = type_convert<T>(gamma_);
1453 T shifted_scaled_x = casted_alpha + casted_beta * x;
1454 y = ck_tile::pow(shifted_scaled_x, casted_gamma);
1455 }
1456 const float alpha_;
1457 const float beta_;
1458 const float gamma_;
1459};
1460
1462{
1463 static constexpr const char* name = "ClippedRelu";
1464
1465 ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
1466
1467 template <typename T>
1468 CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
1469 {
1470 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
1471 std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
1472 std::is_same_v<T, int8_t>,
1473 "Data type is not supported by this operation!");
1474 T casted_alpha = type_convert<T>(alpha_);
1475 T casted_beta = type_convert<T>(beta_);
1476 y = ck_tile::min(casted_beta, ck_tile::max(casted_alpha, x));
1477 }
1478 const float alpha_;
1479 const float beta_;
1480};
1481
1483{
1484 static constexpr const char* name = "LeakyRelu";
1485
1486 LeakyRelu(float alpha = 0.01f) : alpha_(alpha){};
1487
1488 template <typename T>
1489 CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
1490 {
1491 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
1492 std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
1493 std::is_same_v<T, int8_t>,
1494 "Data type is not supported by this operation!");
1495 T casted_alpha = type_convert<T>(alpha_);
1496 y = x >= 0 ? x : x * casted_alpha;
1497 }
1498 const float alpha_;
1499};
1500
1501struct Elu
1502{
1503 static constexpr const char* name = "Elu";
1504
1505 Elu(float alpha = 1.f) : alpha_(alpha){};
1506
1507 template <typename T>
1508 CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
1509 {
1510 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
1511 std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
1512 std::is_same_v<T, int8_t>,
1513 "Data type is not supported by this operation!");
1514 T casted_alpha = type_convert<T>(alpha_);
1515 y = x > 0 ? x : casted_alpha * ck_tile::expm1(x);
1516 }
1517 const float alpha_;
1518};
1519
1521{
1522 static constexpr const char* name = "Logistic";
1523
1524 Logistic(float alpha = 1.f) : alpha_(alpha){};
1525
1526 template <typename T>
1527 CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
1528 {
1529 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
1530 std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
1531 std::is_same_v<T, int8_t>,
1532 "Data type is not supported by this operation!");
1533 T casted_alpha = type_convert<T>(alpha_);
1534 constexpr T one = type_convert<T>(1);
1535 y = casted_alpha / (one + ck_tile::exp(-x) * casted_alpha);
1536 }
1537 const float alpha_;
1538};
1539
// Clamp functor: stores ck_tile::clamp(x, lower, upper) into y.
// The constructor defaults span the full finite float range (lowest() .. max()).
// NOTE(review): this listing is missing lines. `lower`/`upper` used in operator() and
// the `lower_`/`upper_` members initialised by the constructor are not declared in the
// visible text. Restore them from the upstream header before compiling.
1540struct Clamp
1541{
1542 CK_TILE_HOST_DEVICE Clamp(float lower = std::numeric_limits<float>::lowest(),
1543 float upper = std::numeric_limits<float>::max())
1544 : lower_(lower), upper_(upper) {};
1545
// y: output, overwritten with the clamped value.  x: input value.
1546 template <typename T>
1547 CK_TILE_HOST_DEVICE constexpr void operator()(T& y, const T& x) const
1548 {
1551 y = ck_tile::clamp(x, lower, upper);
1552 }
1553
1555};
1556
// Functor named "ConvInvscale", configured with three float scales (input, weight,
// output), each defaulting to 1.
// NOTE(review): this listing is missing lines: the struct declaration, the first line
// of the operator() specialization's signature, its body, and the scale_in_ /
// scale_wei_ / scale_out_ member declarations. Restore them from the upstream header.
1558{
1559 static constexpr const char* name = "ConvInvscale";
1560
1562 ConvInvscale(float scale_in = 1.f, float scale_wei = 1.f, float scale_out = 1.f)
1563 : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
1564 {
1565 }
1566
// Generic operator() is declared only; behaviour comes from explicit specializations.
1567 template <typename E, typename C>
1568 CK_TILE_HOST_DEVICE void operator()(E& e, const C& c) const;
1569
// Specialization taking a float input `c`; output type and body are not visible here.
1570 template <>
1572 const float& c) const
1573 {
1575 };
1576
1580};
1581
// Functor named "ConvScale", configured with three float scales (input, weight,
// output), each defaulting to 1.
// NOTE(review): this listing is missing lines: the struct declaration, the first line
// of the operator() specialization's signature, its body, and the scale_in_ /
// scale_wei_ / scale_out_ member declarations. Restore them from the upstream header.
1583{
1584 static constexpr const char* name = "ConvScale";
1585
1587 ConvScale(float scale_in = 1.f, float scale_wei = 1.f, float scale_out = 1.f)
1588 : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
1589 {
1590 }
1591
// Generic operator() is declared only; behaviour comes from explicit specializations.
1592 template <typename E, typename C>
1593 CK_TILE_HOST_DEVICE void operator()(E& e, const C& c) const;
1594
// Specialization taking a float input `c`; output type and body are not visible here.
1595 template <>
1597 const float& c) const
1598 {
1600 };
1601
1605};
1606
// Functor named "ConvScaleRelu", configured with three float scales (input, weight,
// output), each defaulting to 1.
// NOTE(review): this listing is missing lines: the struct declaration, the first line
// of the operator() specialization's signature, the statement that stores the result,
// and the scale_in_ / scale_wei_ / scale_out_ member declarations. Restore them from
// the upstream header.
1608{
1609 static constexpr const char* name = "ConvScaleRelu";
1610
1612 ConvScaleRelu(float scale_in = 1.f, float scale_wei = 1.f, float scale_out = 1.f)
1613 : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
1614 {
1615 }
1616
// Generic operator() is declared only; behaviour comes from explicit specializations.
1617 template <typename E, typename C>
1618 CK_TILE_HOST_DEVICE void operator()(E& e, const C& c) const;
1619
// Specialization taking a float input `c`: applies Relu to c * scale_in_ * scale_wei_,
// leaving the result in the local `x`. What happens to `x` afterwards is not visible here.
1620 template <>
1622 const float& c) const
1623 {
1624 float x;
1625 Relu{}.template operator()<float>(x, c * scale_in_ * scale_wei_);
1627 };
1628
1632};
1633
// Functor named "Cast", parameterised on destination and source element types.
// operator() takes a DstType output reference and a SrcType input.
// NOTE(review): the body line of operator() is missing from this listing.
// NOTE(review): the member template parameter T does not appear in the parameter list,
// so it cannot be deduced. A plain call `cast(y, x)` would not compile; callers would
// have to name T explicitly. Confirm whether the `template <typename T>` is intended.
1634template <typename DstType, typename SrcType>
1635struct Cast
1636{
1637 static constexpr const char* name = "Cast";
1638
1639 template <typename T>
1640 CK_TILE_HOST_DEVICE void operator()(DstType& y, const SrcType& x) const
1641 {
1643 };
1644};
1645
1660template <typename FuncA, typename FuncB, bool FuncADs = false, bool FuncBDs = false>
1662{
1663 static_assert(!(FuncADs && FuncBDs), "Only one composed function may use the Ds tensor.");
1664
1665 CK_TILE_HOST_DEVICE Compose(FuncA func_a_ = FuncA{}, FuncB func_b_ = FuncB{})
1666 : func_a(func_a_), func_b(func_b_)
1667 {
1668 }
1669
1670 template <typename AIn, typename BOut, typename AOut = AIn, typename... ADs>
1671 CK_TILE_HOST_DEVICE constexpr void operator()(BOut& y, const AIn& x, const ADs&... ds) const
1672 {
1673 AOut tmp;
1674 if constexpr(FuncADs)
1675 {
1676 func_a(tmp, x, ds...);
1677 func_b(y, tmp);
1678 }
1679 else if constexpr(FuncBDs)
1680 {
1681 func_a(tmp, x);
1682 func_b(y, tmp, ds...);
1683 }
1684 else
1685 {
1686 func_a(tmp, x);
1687 func_b(y, tmp);
1688 }
1689 }
1690
1691 const FuncA func_a;
1692 const FuncB func_b;
1693};
1694
1695// support fastconvert of int8 to fp16
1696#if 0
1697template <typename InputDataType, typename OutputDataType, index_t RegPackNumber>
1698struct FastNumericArrayConverter
1699{
1700};
1701
// Four-element specialization: converts four packed uint8_t values to four fp16_t
// values using two byte-permute builtins and two packed fp16 adds.
1702template <>
1703struct FastNumericArrayConverter<uint8_t, ck_tile::fp16_t, 4>
1704{
1705 using InputArray = vector_type<uint8_t, 4>;
1706 using OutputArray = vector_type<ck_tile::fp16_t, 4>;
1707
1708 CK_TILE_DEVICE static OutputArray convert(InputArray const& Input)
1709 {
1710 OutputArray Output;
1711
// View the output as two 32-bit words and the input as one 32-bit word.
1712 uint32_t* half_2 = reinterpret_cast<uint32_t*>(&Output);
1713 uint32_t const uint8_4 = reinterpret_cast<uint32_t const&>(Input);
1714
// One selector per output word; the names indicate input bytes 0/1 and 2/3.
// NOTE(review): presumably each selected input byte is paired with a 0x64 byte from
// fp16_adder to form the bit pattern of an fp16 value -- confirm against the
// v_perm_b32 selector encoding.
1715 static constexpr uint32_t byte_selector_01 = 0x05010500;
1716 static constexpr uint32_t byte_selector_23 = 0x05030502;
1717 static constexpr uint32_t fp16_adder = 0x64646464;
1718 half_2[0] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_01);
1719 half_2[1] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_23);
1720
// Packed fp16 add of the magic constant with neg_lo/neg_hi modifiers on the second operand.
// NOTE(review): presumably this subtracts the constant to remove the bias introduced
// above -- confirm the modifier semantics and the intended input encoding.
1721 static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
1722 asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]"
1723 : "=v"(half_2[0])
1724 : "v"(half_2[0]), "s"(I8s_TO_F16s_MAGIC_NUM));
1725 asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]"
1726 : "=v"(half_2[1])
1727 : "v"(half_2[1]), "s"(I8s_TO_F16s_MAGIC_NUM));
1728
1729 return Output;
1730 }
1731
// Call-operator form; forwards to convert().
1732 CK_TILE_DEVICE OutputArray operator()(InputArray const& Input) { return convert(Input); }
1733};
1734
1735template <index_t N>
1736struct FastNumericArrayConverter<uint8_t, ck_tile::fp16_t, N>
1737{
1738 static constexpr int VEC_WIDTH = 4;
1739 static_assert(!(N % VEC_WIDTH), "N must be multiple of 4.");
1740
1741 using InputArray = vector_type<uint8_t, N>;
1742 using OutputArray = vector_type<ck_tile::fp16_t, N>;
1743
1744 CK_TILE_DEVICE static OutputArray convert(InputArray const& Input)
1745 {
1746 FastNumericArrayConverter<uint8_t, ck_tile::fp16_t, 4> converter;
1747
1748 OutputArray Output;
1749
1750 using Vec_InputArray = vector_type<uint8_t, 4>;
1751 using Vec_OutputArray = vector_type<ck_tile::fp16_t, 4>;
1752
1753 Vec_OutputArray* half_4_ptr = reinterpret_cast<Vec_OutputArray*>(&Output);
1754 Vec_InputArray const* uint8_4_ptr = reinterpret_cast<Vec_InputArray const*>(&Input);
1755
1756 static_for<0, N / VEC_WIDTH, 1>{}(
1757 [&](auto i) { half_4_ptr[i] = converter(uint8_4_ptr[i]); });
1758
1759 return Output;
1760 }
1761
1762 CK_TILE_DEVICE OutputArray operator()(InputArray const& Input) { return convert(Input); }
1763};
1764#endif
1765
1766} // namespace element_wise
1767} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
constexpr std::array< T, N > make_lookup_table(F &&func)
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:25
CK_TILE_DEVICE fp16x4_t i4_to_half4_scale(int q, const fp16x2_t &scale)
This function dequantizes 4 int4 values into 4 fp16 values and applies scaling.
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:92
CK_TILE_DEVICE fp8x8_t amd_assembly_i4_to_fp8x8(int a)
This function converts 8 packed 4-bit integers into 8 fp8 values.
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:193
CK_TILE_DEVICE float amd_assembly_fp8_to_fp32(uint32_t src)
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:258
CK_TILE_DEVICE bf16x4_t i4_to_bhalf4(int q)
This function converts 4 4-bit integers into 4 bf16 values.
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:139
CK_TILE_DEVICE fp16x4_t i4_to_half4(int q)
Fast int4x4 to fp16x8_t data type conversion based on paper "Who Says Elephants Can't Run: Bringing L...
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:47
constexpr std::array< T, N > make_lookup_table_impl(F &&func, std::index_sequence< Is... >)
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:19
CK_TILE_DEVICE bf8x8_t amd_assembly_i4_to_bf8x8(uint32_t a)
This function converts 8 packed 4-bit integers into 8 bf8 values.
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:285
CK_TILE_DEVICE float amd_assembly_bf8_to_fp32(uint32_t src)
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:265
CK_TILE_HOST_DEVICE Y cast_to_f8(X x, uint32_t rng)
Definition float8.hpp:591
Definition tile/core/algorithm/cluster_descriptor.hpp:13
_Float16 fp16x2_t
Definition half.hpp:385
CK_TILE_DEVICE bfloat16_t log(bfloat16_t x)
Definition bfloat16.hpp:428
CK_TILE_HOST T acos(T x)
Definition tile/core/numeric/math.hpp:632
fp8_t fp8x4_t
Definition vector_type.hpp:228
int8_t pk_int4x4_t
Definition vector_type.hpp:247
int8_t int8_t
Definition int8.hpp:20
bfloat16_t bf16_t
Definition bfloat16.hpp:113
CK_TILE_HOST T cos(T x)
Definition tile/core/numeric/math.hpp:752
_Float16 fp16_t
Definition half.hpp:110
_BitInt(8) fp8_t
Definition float8.hpp:204
CK_TILE_HOST_DEVICE constexpr T clamp(const T &x, const T &lowerbound, const T &upperbound)
Definition tile/core/numeric/math.hpp:259
CK_TILE_HOST T ceil(T x)
Definition tile/core/numeric/math.hpp:842
CK_TILE_HOST T acosh(T x)
Definition tile/core/numeric/math.hpp:770
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X &x)
Definition bit_cast.hpp:11
CK_TILE_HOST T expm1(T x)
Definition tile/core/numeric/math.hpp:956
CK_TILE_HOST T tanh(T x)
Definition tile/core/numeric/math.hpp:614
_Float16 fp16x4_t
Definition vector_type.hpp:137
CK_TILE_DEVICE bfloat16_t sqrt(bfloat16_t x)
Definition bfloat16.hpp:413
CK_TILE_HOST T atan(T x)
Definition tile/core/numeric/math.hpp:680
CK_TILE_HOST T sin(T x)
Definition tile/core/numeric/math.hpp:698
bf8_t bf8x8_t
Definition vector_type.hpp:238
CK_TILE_HOST T floor(T x)
Definition tile/core/numeric/math.hpp:878
CK_TILE_HOST T sinh(T x)
Definition tile/core/numeric/math.hpp:824
bfloat16_t bf16x4_t
Definition vector_type.hpp:146
CK_TILE_HOST T asin(T x)
Definition tile/core/numeric/math.hpp:716
CK_TILE_HOST T asinh(T x)
Definition tile/core/numeric/math.hpp:734
int32_t int32_t
Definition integer.hpp:10
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition bfloat16.hpp:419
_Float16 fp16x8_t
Definition vector_type.hpp:138
bfloat16_t bf16x8_t
Definition vector_type.hpp:147
CK_TILE_HOST_DEVICE bfloat16_t abs(const bfloat16_t &x)
Definition bfloat16.hpp:400
CK_TILE_HOST T atanh(T x)
Definition tile/core/numeric/math.hpp:806
CK_TILE_HOST T neg(T x)
Definition tile/core/numeric/math.hpp:650
CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_rtn_raw(float f)
Definition bfloat16.hpp:118
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
float fp32x2_t
Definition pk_fp4.hpp:22
bf8_t bf8x4_t
Definition vector_type.hpp:237
CK_TILE_HOST_DEVICE constexpr T min(T x)
Definition tile/core/numeric/math.hpp:210
CK_TILE_HOST_DEVICE bool isnan(const bfloat16_t &x)
Definition bfloat16.hpp:406
CK_TILE_HOST T tan(T x)
Definition tile/core/numeric/math.hpp:788
fp8_t fp8x8_t
Definition vector_type.hpp:229
CK_TILE_HOST T cosh(T x)
Definition tile/core/numeric/math.hpp:860
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_HOST T pow(T x, T gamma)
Definition tile/core/numeric/math.hpp:938
CK_TILE_HOST T rcp(T x)
Definition tile/core/numeric/math.hpp:896
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
unsigned int uint32_t
Definition stdint.h:126
unsigned char uint8_t
Definition stdint.h:124
unsigned __int64 uint64_t
Definition stdint.h:136
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1217
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1218
CK_TILE_HOST_DEVICE void operator()(T &y, const T &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1221
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1121
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1122
CK_TILE_HOST_DEVICE void operator()(T &y, const T &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1125
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1185
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1186
CK_TILE_HOST_DEVICE void operator()(T &y, const T &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1189
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1361
CK_TILE_HOST_DEVICE void operator()(T &y, const T &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1365
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1362
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1249
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1250
CK_TILE_HOST_DEVICE void operator()(T &y, const T &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1253
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1153
CK_TILE_HOST_DEVICE void operator()(T &y, const T &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1157
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1154
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:472
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:473
float scale
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:487
CK_TILE_HOST_DEVICE constexpr void operator()(E &a, const As &... as) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:476
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1636
CK_TILE_HOST_DEVICE void operator()(DstType &y, const SrcType &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1640
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1637
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1281
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1282
CK_TILE_HOST_DEVICE void operator()(T &y, const T &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1285
float lower_
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1554
CK_TILE_HOST_DEVICE Clamp(float lower=std::numeric_limits< float >::lowest(), float upper=std::numeric_limits< float >::max())
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1542
CK_TILE_HOST_DEVICE constexpr void operator()(T &y, const T &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1547
float upper_
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1554
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1463
ClippedRelu(float alpha=0.f, float beta=1.f)
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1465
const float alpha_
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1478
const float beta_
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1479
CK_TILE_HOST_DEVICE void operator()(T &y, const T &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1468
CK_TILE_HOST_DEVICE constexpr void operator()(BOut &y, const AIn &x, const ADs &... ds) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1671
const FuncA func_a
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1691
const FuncB func_b
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1692
CK_TILE_HOST_DEVICE Compose(FuncA func_a_=FuncA{}, FuncB func_b_=FuncB{})
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1665
float scale_wei_
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1578
float scale_in_
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1577
CK_TILE_HOST_DEVICE void operator()(E &e, const C &c) const
CK_TILE_HOST_DEVICE ConvInvscale(float scale_in=1.f, float scale_wei=1.f, float scale_out=1.f)
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1562
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1559
float scale_out_
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1579
float scale_wei_
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1603
float scale_in_
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1602
CK_TILE_HOST_DEVICE ConvScale(float scale_in=1.f, float scale_wei=1.f, float scale_out=1.f)
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1587
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1584
CK_TILE_HOST_DEVICE void operator()(E &e, const C &c) const
float scale_out_
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1604
CK_TILE_HOST_DEVICE void operator()(E &e, const C &c) const
float scale_wei_
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1630
float scale_out_
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1631
CK_TILE_HOST_DEVICE ConvScaleRelu(float scale_in=1.f, float scale_wei=1.f, float scale_out=1.f)
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1612
float scale_in_
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1629
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1609
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1313
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1314
CK_TILE_HOST_DEVICE void operator()(T &y, const T &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1317
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1201
CK_TILE_HOST_DEVICE void operator()(T &y, const T &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1205
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1202
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:392
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:393
constexpr static const bool is_pack8_invocable
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:405
CK_TILE_HOST_DEVICE constexpr void operator()(fp16x8_t &y, const pk_int4x4_t &x, const fp16x2_t &z) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:399
CK_TILE_HOST_DEVICE void operator()(Y &y, const X &x, const Z &z) const
CK_TILE_HOST_DEVICE void operator()(T &y, const T &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1508
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1503
Elu(float alpha=1.f)
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1505
const float alpha_
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1517
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1297
CK_TILE_HOST_DEVICE void operator()(T &y, const T &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1301
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1298
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:873
CK_TILE_HOST void operator()(Y &y, const X &x) const
CK_TILE_DEVICE void operator()(Y &y, const X &x) const
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:874
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:754
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:755
CK_TILE_DEVICE void operator()(Y &y, const X &x) const
CK_TILE_HOST void operator()(Y &y, const X &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1329
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1330
CK_TILE_HOST_DEVICE void operator()(T &y, const T &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1333
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:976
CK_TILE_HOST_DEVICE void operator()(Y &y, const X &x) const
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:977
const float alpha_
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1498
LeakyRelu(float alpha=0.01f)
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1486
CK_TILE_HOST_DEVICE void operator()(T &y, const T &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1489
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1484
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1345
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1346
CK_TILE_HOST_DEVICE void operator()(T &y, const T &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1349
CK_TILE_HOST_DEVICE void operator()(T &y, const T &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1527
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1522
Logistic(float alpha=1.f)
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1524
const float alpha_
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1537
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:508
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:509
CK_TILE_HOST_DEVICE auto operator()(E &e, const C &c, const Ds &... ds) const -> void
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:512
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:491
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:492
CK_TILE_HOST_DEVICE auto operator()(E &e, const C &c, const Ds &... ds) const -> void
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:495
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1137
CK_TILE_HOST_DEVICE void operator()(T &y, const T &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1141
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1138
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:437
CK_TILE_HOST_DEVICE void operator()(Y &&y, const X &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:444
std::remove_cv_t< std::remove_reference_t< T > > raw_t
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:441
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:438
CK_TILE_HOST_DEVICE auto operator()(E &e, const C &c, const Ds &...) const -> void
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:457
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:409
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:410
CK_TILE_HOST_DEVICE void operator()(Y &y, const X &x) const
CK_TILE_HOST_DEVICE constexpr void operator()(fp16x2_t &y, const pk_int4_t &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:423
constexpr static const bool is_pack2_invocable
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:433
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:351
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:352
CK_TILE_HOST_DEVICE void operator()(Y &y, const X &x) const
CK_TILE_HOST_DEVICE constexpr void operator()(bf16x8_t &y, const pk_int4x4_t &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:363
CK_TILE_HOST_DEVICE constexpr void operator()(fp16x8_t &y, const pk_int4x4_t &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:357
constexpr static const bool is_pack8_invocable
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:388
CK_TILE_HOST_DEVICE constexpr void operator()(fp8x8_t &y, const pk_int4x4_t &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:369
CK_TILE_HOST_DEVICE constexpr void operator()(bf8x8_t &y, const pk_int4x4_t &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:379
const float beta_
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1457
const float gamma_
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1458
const float alpha_
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1456
CK_TILE_HOST_DEVICE void operator()(T &y, const T &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1444
Power(float alpha=0.f, float beta=1.f, float gamma=2.f)
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1440
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1438
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1377
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1378
CK_TILE_HOST_DEVICE void operator()(T &y, const T &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1381
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:726
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:727
CK_TILE_HOST_DEVICE void operator()(ck_tile::bf16_t &y, const ck_tile::bf16_t &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:740
CK_TILE_HOST_DEVICE void operator()(T &y, const T &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:730
CK_TILE_HOST_DEVICE void operator()(Y &y, const X &x) const
float scale_
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:654
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:641
CK_TILE_HOST_DEVICE ScaleAndResetNaNToMinusInfinity(float scale)
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:643
CK_TILE_HOST_DEVICE void operator()(Y &y, const X &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:597
float scale_
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:636
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:592
CK_TILE_HOST_DEVICE Scale(float scale=1.f)
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:594
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:998
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:999
CK_TILE_HOST_DEVICE void operator()(T &y, const T &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1002
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1014
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1015
CK_TILE_HOST_DEVICE void operator()(T &y, const T &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1018
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1265
CK_TILE_HOST_DEVICE void operator()(T &y, const T &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1269
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1266
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1169
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1170
CK_TILE_HOST_DEVICE void operator()(T &y, const T &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1173
CK_TILE_HOST_DEVICE void operator()(T &y, const T &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1423
const float alpha_
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1433
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1418
SoftRelu(float alpha=1.f)
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1420
CK_TILE_HOST_DEVICE void operator()(Y &y, const X &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1399
const float beta_
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1413
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1394
Swish(float beta=1.0f)
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1396
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1105
CK_TILE_HOST_DEVICE void operator()(T &y, const T &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1109
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1106
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1233
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1234
CK_TILE_HOST_DEVICE void operator()(T &y, const T &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1237
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:696
CK_TILE_HOST_DEVICE void operator()(T &y, const T &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:700
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:697
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:525
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:526
CK_TILE_HOST_DEVICE void operator()(Y &y, const X &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:529
CK_TILE_HOST_DEVICE UnaryDivide(const int32_t divider=1)
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:661
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:659
int32_t divider_
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:673
CK_TILE_HOST_DEVICE void operator()(T &y, const T &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:664
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:712
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:713
CK_TILE_HOST_DEVICE void operator()(T &y, const T &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:716
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:677
CK_TILE_HOST_DEVICE void operator()(Y &y, const X &x) const
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:681
static constexpr const char * name
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:678
static CK_TILE_HOST_DEVICE constexpr T infinity()
Definition tile/core/numeric/numeric.hpp:38
Definition pk_int4.hpp:21