float8.hpp Source File

float8.hpp Source File

Composable Kernel: float8.hpp Source File
float8.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
12#include <stdint.h>
13#include <type_traits>
14
15#pragma once
16
// CK_TILE_FP8_CVT_DEVICE is 1 only while compiling device code for gfx94x or gfx12x
// targets; the hardware fp8/bf8 conversion paths in this header are guarded by it.
// On the host pass (and for other targets) it is 0 and the software conversions are used.
#if __HIP_DEVICE_COMPILE__ && (defined(__gfx94__) || defined(__gfx12__))
#define CK_TILE_FP8_CVT_DEVICE 1
#else
#define CK_TILE_FP8_CVT_DEVICE 0
#endif
22
23namespace ck_tile {
24
// fp8 rounding modes:
// - standard:   round to nearest (ties to even); the faster option
// - stochastic: stochastic rounding, which helps avoid accumulating rounding error
33
38{
39 E4M3_OCP = 0, // OCP FP8 E4M3
40 E5M2_OCP = 1, // OCP BF8 E5M2
41 E4M3_FNUZ = 2, // FNUZ FP8 E4M3
42 E5M2_FNUZ = 3, // FNUZ BF8 E5M2
43};
44
/*
 * Bit patterns are written sign.exponent.mantissa; 's' means either sign, a literal
 * 0/1 in the sign position means only that sign is valid.
 *
 *                 ___________________FNUZ___________________  | ___________________OCP___________________
 *                 e4m3                  e5m2                 | e4m3                  e5m2
 * bias          : 8                     16                   | 7                     15
 * inf           : N/A                   N/A                  | N/A                   s.11111.00
 * NaN           : 1.0000.000            1.00000.00           | s.1111.111            s.11111.{01, 10, 11}
 * zero          : 0.0000.000            0.00000.00           | s.0000.000            s.00000.00
 * Max(normal)   : s.1111.111 (240)      s.11111.11 (57344)   | s.1111.110 (448)      s.11110.11 (57344)
 * Max(subnorm)  : s.0000.111            s.00000.11           | s.0000.111            s.00000.11
 *                 0.0068359375          2.288818e-05         | 0.013671875           4.57763671875e-05
 * Min(normal)   : s.0001.000            s.00001.00           | s.0001.000            s.00001.00
 *                 2^-7 (0.0078125)      2^-15 (3.05176e-05)  | 2^-6 (0.015625)       2^-14 (6.10352e-05)
 * Min(subnorm)  : s.0000.001            s.00000.01           | s.0000.001            s.00000.01
 *                 2^-10 (0.0009765625)  2^-17 (7.62939e-06)  | 2^-9 (0.001953125)    2^-16 (1.52588e-05)
 */
60
61template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
63
64template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
66
69
70#if CK_TILE_USE_CUSTOM_DATA_TYPE
71struct alignas(1) float8_e4m3_t
72{
73 static constexpr int exponent = 4;
74 static constexpr int mantissa = 3;
75#if CK_TILE_USE_OCP_FP8
76 static constexpr int bias = 7; // OCP
77#else
78 static constexpr int bias = 8; // FNUZ
79#endif
80 using raw_type = uint8_t;
81 raw_type data;
82
84 static constexpr float8_e4m3_t bit_cast(raw_type x)
85 {
86 float8_e4m3_t y;
87 y.data = x;
88 return y;
89 }
90
91 // constructor
92 constexpr float8_e4m3_t() : data() {}
93
94 // construct from float
96 explicit constexpr float8_e4m3_t(const float& x) : data(float_to_fp8_raw(x)) {}
97
98 // construct from int
100 explicit constexpr float8_e4m3_t(const int& x) : data(float_to_fp8_raw(static_cast<float>(x)))
101 {
102 }
103
104 // construct from unsigned int
106 explicit constexpr float8_e4m3_t(const unsigned int& x)
107 : data(float_to_fp8_raw(static_cast<float>(x)))
108 {
109 }
110
111 // cast to float
113 explicit constexpr operator float() const { return fp8_to_float_raw(data); }
114
115 // cast to int
117 explicit constexpr operator int() const { return static_cast<int>(fp8_to_float_raw(data)); }
118
119 // internal access
121 constexpr raw_type& get() { return data; }
122
124 constexpr raw_type get() const { return data; }
125};
126using fp8_t = float8_e4m3_t;
127using fp8_raw_t = typename fp8_t::raw_type;
128
129struct alignas(1) float8_e5m2_t
130{
131 static constexpr int exponent = 5;
132 static constexpr int mantissa = 2;
133#if CK_TILE_USE_OCP_FP8
134 static constexpr int bias = 15; // OCP
135#else
136 static constexpr int bias = 16; // FNUZ
137#endif
138 using raw_type = uint8_t;
139 raw_type data;
140
142 static constexpr float8_e5m2_t bit_cast(raw_type x)
143 {
144 float8_e5m2_t y;
145 y.data = x;
146 return y;
147 }
148
149 // constructor
150 constexpr float8_e5m2_t() : data() {}
151
152 // construct from float
154 explicit constexpr float8_e5m2_t(const float& x) : data(float_to_bf8_raw(x)) {}
155
156 // construct from int
158 explicit constexpr float8_e5m2_t(const int& x) : data(float_to_bf8_raw(static_cast<float>(x)))
159 {
160 }
161
162 // construct from unsigned int
164 explicit constexpr float8_e5m2_t(const unsigned int& x)
165 : data(float_to_bf8_raw(static_cast<float>(x)))
166 {
167 }
168
169 // cast to float
171 explicit constexpr operator float() const { return bf8_to_float_raw(data); }
172
173 // cast to int
175 explicit constexpr operator int() const { return static_cast<int>(bf8_to_float_raw(data)); }
176
177 // internal access
179 constexpr raw_type& get() { return data; }
180
182 constexpr raw_type get() const { return data; }
183};
184using bf8_t = float8_e5m2_t;
185using bf8_raw_t = typename bf8_t::raw_type;
186
187template <typename>
188struct native_t;
189
190template <>
191struct native_t<fp8_t>
192{
193 using type = _BitInt(8);
194};
195
196template <>
197struct native_t<bf8_t>
198{
199 using type = unsigned _BitInt(8);
200};
201
202#else
203
204using fp8_t = _BitInt(8);
206using bf8_t = unsigned _BitInt(8);
208#endif
209
210template <>
212{
214
215 static constexpr int exp = 4;
216 static constexpr int mant = 3;
217#if CK_TILE_USE_OCP_FP8
218 static constexpr int bias = 7;
220#else
221 static constexpr int bias = 8;
223#endif
224 static constexpr uint8_t abs_mask = 0x7F;
225 static constexpr int PackedSize = 1;
226};
227
228template <>
230{
232
233 static constexpr int exp = 5;
234 static constexpr int mant = 2;
235#if CK_TILE_USE_OCP_FP8
236 static constexpr int bias = 15;
238#else
239 static constexpr int bias = 16;
241#endif
242 static constexpr uint8_t abs_mask = 0x7F;
243 static constexpr int PackedSize = 1;
244};
245
246// below is sw fp8 conversion, not utilizing hw instruction
247namespace impl {
248
249template <typename SrcT, typename DstT, bool clip = true, bool stoch = false>
250CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
251{
252 static_assert(std::is_same<DstT, fp8_t>::value || std::is_same<DstT, bf8_t>::value,
253 "DstT type must be fp8 or bf8.");
254
255 constexpr bool is_half = std::is_same<SrcT, half_t>::value;
256 constexpr bool is_float = std::is_same<SrcT, float>::value;
257 static_assert(is_half || is_float, "Only half and float can be cast to f8");
258
259 // fp8/bf8 type exponent/mantissa layout
260 constexpr int DstT_exp = numeric_traits<DstT>::exp; // exponent width of the destination type
261 constexpr int DstT_mant = numeric_traits<DstT>::mant; // mantissa width of the destination type
262 constexpr int DstT_bias = numeric_traits<DstT>::bias;
263 constexpr bool is_fnuz =
266
267 constexpr int SrcT_exp = numeric_traits<SrcT>::exp;
268 constexpr int SrcT_mant = numeric_traits<SrcT>::mant;
269 constexpr int bias = numeric_traits<SrcT>::bias;
270 constexpr unsigned int fInf = numeric_traits<SrcT>::Inf;
271 constexpr unsigned int abs_mask = numeric_traits<SrcT>::abs_mask;
272
273 using SrcT_bitwise = typename numeric_traits<SrcT>::bitwise_type;
274 SrcT_bitwise src_bitwise = bit_cast<SrcT_bitwise>(src);
275
276 unsigned int head, mantissa;
277 int exponent;
278 unsigned int sign;
279
280 head = src_bitwise & numeric_traits<SrcT>::head_mask;
281 mantissa = src_bitwise & numeric_traits<SrcT>::mant_mask;
282 exponent = (head >> SrcT_mant) & numeric_traits<SrcT>::exp_mask;
283 sign = head >> (SrcT_exp + SrcT_mant);
284
285 unsigned int signed_inf = 0;
286 unsigned int nan = 0;
287 if constexpr(is_fnuz)
288 {
289 signed_inf = clip ? ((sign << (DstT_exp + DstT_mant)) + 0x7f) : 0x80;
290 nan = 0x80;
291 }
292 else
293 {
294 if constexpr(DstT_exp == 4)
295 { // e4m3
296 signed_inf = (sign << (DstT_exp + DstT_mant)) + (clip ? 0x7e : 0x7f);
297 }
298 else
299 { // e5m2
300 signed_inf = (sign << (DstT_exp + DstT_mant)) + (clip ? 0x7b : 0x7c);
301 }
302 nan = (sign << (DstT_exp + DstT_mant)) + 0x7f;
303 }
304 // Max values
305 unsigned int ifmax = 0;
306 if constexpr(is_float)
307 {
308 if constexpr(DstT_exp == 5)
309 {
310 ifmax = 0x47600000;
311 }
312 else
313 {
314 if constexpr(is_fnuz)
315 {
316 ifmax = 0x43700000;
317 }
318 else
319 {
320 ifmax = 0x43E00000;
321 }
322 }
323 }
324 else if constexpr(is_half)
325 {
326 if constexpr(DstT_exp == 5)
327 {
328 ifmax = 0x7B00;
329 }
330 else
331 {
332 if constexpr(is_fnuz)
333 {
334 ifmax = 0x5B80;
335 }
336 else
337 {
338 ifmax = 0x5F00;
339 }
340 }
341 }
342
343 // Deal with inf and NaNs
344 if((src_bitwise & fInf) == fInf)
345 {
346 return mantissa != 0 ? nan : signed_inf;
347 }
348
349 if((src_bitwise & abs_mask) > ifmax)
350 {
351 return signed_inf;
352 }
353
354 // First need to check if it is normal or denorm as there is a difference of
355 // implicit 1 Then need to adjust the exponent to align with the F8 exponent,
356 // in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
357 // to mantissa and truncate. And for RNE, no need to add rng. Then probably
358 // need to check whether there is carry and adjust exponent and mantissa again
359
360 // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
361 // bits
362 constexpr int f8_denormal_act_exponent = 1 - DstT_bias; // actual exponent of f8 denormal
363 // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
364 // f8_exponent is the converted f8 exponent with bias encoding
365 // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
366 // the difference needs to be adjusted and mantissa shifted
367 int act_exponent, f8_exponent, exponent_diff;
368
369 if(exponent == 0)
370 { // fp32/fp16 is in denormal.
371 /* fp32 denormal is below 2^-127 so it is usually not a concern here, we
372 mostly concern fp16 here. In this case, f8 is usually in denormal. But there
373 could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has
374 exponent bias 16. It means that there are some numbers in fp16 denormal but they
375 are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
376 where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
377 (NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */
378 act_exponent = exponent - bias + 1;
379 exponent_diff = f8_denormal_act_exponent -
380 act_exponent; // actual exponent is exponent-bias+1 as it is denormal
381 }
382 else
383 { // fp32/fp16 is normal with implicit 1
384 act_exponent = exponent - bias;
385 if(act_exponent <= f8_denormal_act_exponent)
386 {
387 /* This is the case where fp32/fp16 is normal but it is in f8 denormal
388 range. For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16
389 actual exponent is -7, it is actually larger due to the implicit 1,
390 Therefore it needs to be adjust to -6 and mantissa shift right by 1.
391 So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
392 exponent_diff = f8_denormal_act_exponent - act_exponent;
393 }
394 else
395 { // both fp32/fp16 and f8 are in normal range
396 exponent_diff = 0; // exponent_diff=0 does not mean there is no difference
397 // for this case, act_exponent could be larger. Just
398 // that it does not need shift mantissa
399 }
400 mantissa += (1u << SrcT_mant); // Add the implicit 1 into mantissa
401 }
402 // The value is <= than min f8 denormal/2 and results in zero (the early exit also prevents
403 // an undefined behavior of bit shifts >= type width).
404 if(exponent_diff > DstT_mant + 1)
405 {
406 return is_fnuz ? 0 : (sign << (DstT_exp + DstT_mant));
407 }
408 bool midpoint = (mantissa & ((1u << (SrcT_mant - DstT_mant + exponent_diff)) - 1)) ==
409 (1u << (SrcT_mant - DstT_mant + exponent_diff - 1));
410 /* This part is a bit tricky. The judgment of whether it is a tie needs to be
411 done before we shift right as shift right could rip off some residual part and
412 make something not midpoint look like midpoint. For example, the fp16 number
413 0x1002 (0 00100 0000000010), it is larger than midpoint, but after shift right
414 by 4 bits, it would look like midpoint.
415 */
416
417 if(exponent_diff > 0)
418 mantissa >>= exponent_diff;
419 else if(exponent_diff == -1)
420 mantissa <<= -exponent_diff;
421 bool implicit_one = mantissa & (1u << SrcT_mant);
422 // if there is no implicit 1, it means the f8 is denormal and need to adjust
423 // to denorm exponent
424 f8_exponent =
425 (act_exponent + exponent_diff) /*actual f8 exponent*/ + DstT_bias - (implicit_one ? 0 : 1);
426
427 // Now we have the exponent and mantissa adjusted
428 unsigned int drop_mask = (1u << (SrcT_mant - DstT_mant)) - 1;
429 bool odd =
430 mantissa &
431 (1u << (SrcT_mant - DstT_mant)); // if the least significant bit that is not truncated is 1
432 mantissa +=
433 (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1u) : mantissa)) & drop_mask;
434
435 // Now we deal with overflow
436 if(f8_exponent == 0)
437 {
438 if((1u << SrcT_mant) & mantissa)
439 {
440 f8_exponent = 1; // denormal overflow to become normal, promote exponent
441 }
442 }
443 else
444 {
445 if((1u << (SrcT_mant + 1)) & mantissa)
446 {
447 mantissa >>= 1;
448 f8_exponent++;
449 }
450 }
451
452 mantissa >>= (SrcT_mant - DstT_mant);
453
454 // above range: quantize to maximum possible float of the same sign
455 const int max_exp = (1 << DstT_exp) - 1;
456 if(f8_exponent > max_exp)
457 {
458 if constexpr(clip)
459 {
460 mantissa = (1 << DstT_mant) - 1;
461 f8_exponent = max_exp;
462 }
463 else
464 {
465 return signed_inf;
466 }
467 }
468
469 if(f8_exponent == 0 && mantissa == 0)
470 return is_fnuz ? 0 : (sign << (DstT_exp + DstT_mant));
471 mantissa &= (1 << DstT_mant) - 1;
472 return (sign << (DstT_exp + DstT_mant)) | (f8_exponent << DstT_mant) | mantissa;
473}
474
475template <typename SrcT, typename DstT, bool clip = true>
477{
478 static_assert(std::is_same<SrcT, fp8_t>::value || std::is_same<SrcT, bf8_t>::value,
479 "SrcT type must be fp8 or bf8.");
480 constexpr int SrcT_exp = numeric_traits<SrcT>::exp;
481 constexpr int SrcT_mant = numeric_traits<SrcT>::mant;
482 constexpr uint8_t SrcT_abs_mask = numeric_traits<SrcT>::abs_mask;
483 constexpr bool is_fnuz =
486
487 constexpr bool is_half = std::is_same<DstT, half_t>::value;
488 constexpr bool is_float = std::is_same<DstT, float>::value;
489 static_assert(is_half || is_float, "DstT type must be half_t or float.");
490
491 // destination type exponent/mantissa layout
492 constexpr int DstT_exp = numeric_traits<DstT>::exp; // exponent width of the destination type
493 constexpr int DstT_mant = numeric_traits<DstT>::mant; // mantissa width of the destination type
494
495 constexpr DstT fInf = bit_cast<DstT>(numeric_traits<DstT>::Inf);
496 constexpr DstT fNegInf = bit_cast<DstT>(numeric_traits<DstT>::NegInf);
497 constexpr DstT fNaN = bit_cast<DstT>(numeric_traits<DstT>::NaN);
498 constexpr DstT fNeg0 = bit_cast<DstT>(numeric_traits<DstT>::Neg0);
499
500 DstT fmax{0}, fmin{0};
501 // Max number in e5m2 57344
502 if constexpr(is_half)
503 {
504 fmax = bit_cast<DstT>(static_cast<typename numeric_traits<DstT>::bitwise_type>(0x7B00));
505 fmin = bit_cast<DstT>(static_cast<typename numeric_traits<DstT>::bitwise_type>(0xFB00));
506 }
507 else if constexpr(is_float)
508 {
509 fmax = bit_cast<DstT>(static_cast<typename numeric_traits<DstT>::bitwise_type>(0x47600000));
510 fmin = bit_cast<DstT>(static_cast<typename numeric_traits<DstT>::bitwise_type>(0xC7600000));
511 }
512
513 if(x == 0)
514 {
515 return 0;
516 }
517
518 unsigned int sign = x >> (SrcT_exp + SrcT_mant);
519 unsigned int mantissa = x & ((1 << SrcT_mant) - 1);
520 int exponent = (x & SrcT_abs_mask) >> SrcT_mant;
521 if constexpr(is_fnuz)
522 {
523 if((x & 0xff) == 0x80)
524 {
525 return fNaN;
526 }
527 }
528 else
529 {
530 if(x == SrcT(0x80))
531 {
532 return fNeg0;
533 }
534 if constexpr(SrcT_exp == 4)
535 { // e4m3
536 if((x & 0x7F) == 0x7F)
537 {
538 return fNaN;
539 }
540 }
541 else if((x & 0x7C) == 0x7C)
542 { // e5m2
543 if((x & 0x3) == 0)
544 {
545 if constexpr(clip)
546 {
547 return sign ? fmin : fmax;
548 }
549 return sign ? fNegInf : fInf;
550 }
551 return fNaN;
552 }
553 }
554
556
557 if constexpr(SrcT_exp == 5 && is_half && !is_fnuz)
558 {
559 retval = static_cast<typename numeric_traits<DstT>::bitwise_type>(x) << 8;
560 return bit_cast<DstT>(retval);
561 }
562
563 const int exp_low_cutoff =
564 (1 << (DstT_exp - 1)) - (1 << (SrcT_exp - 1)) + 1 - (is_fnuz ? 1 : 0);
565
566 // subnormal input
567 if(exponent == 0)
568 {
569 int sh = 1 + clz(mantissa) - (32 - SrcT_mant);
570 mantissa <<= sh;
571 exponent += 1 - sh;
572 mantissa &= ((1ull << SrcT_mant) - 1);
573 }
574 exponent += exp_low_cutoff - 1;
575 mantissa <<= DstT_mant - SrcT_mant;
576
577 // subnormal output (occurs when DstT is half_t, we=5, is_fnuz=true)
578 if(exponent <= 0)
579 {
580 mantissa |= 1 << DstT_mant;
581 mantissa >>= 1 - exponent;
582 exponent = 0;
583 }
584
585 retval = (sign << (DstT_exp + DstT_mant)) | (exponent << DstT_mant) | mantissa;
586
587 return bit_cast<DstT>(retval);
588}
589
590template <typename X, typename Y, bool clip, bool stoch>
595
596#if CK_TILE_FP8_CVT_DEVICE
600template <fp8_interpretation interpret, bool saturate, bool stochastic_rounding = false>
601CK_TILE_DEVICE uint8_t cast_to_f8_from_f32(float v, unsigned int rng = 0)
602{
603 uint8_t i8data;
604 union
605 {
606 float fval;
607 unsigned int i32val;
608 unsigned char i8val[4]; // NOTE: not endian independent
609 } val;
610
611 unsigned int ival = 0;
612 val.fval = v;
613
614 if constexpr(saturate)
615 {
616 if constexpr(interpret == fp8_interpretation::E4M3_FNUZ)
617 {
618 if((val.i32val & 0x7F800000) != 0x7F800000)
619 {
620 val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
621 }
622 }
623 else if constexpr(interpret == fp8_interpretation::E4M3_OCP)
624 { // OCP type
625 if((val.i32val & 0x7F800000) != 0x7F800000)
626 {
627 val.fval = __builtin_amdgcn_fmed3f(val.fval, 448.0, -448.0);
628 }
629 }
630 else
631 {
632 if((val.i32val & 0x7F800000) != 0x7F800000)
633 {
634 val.fval = __builtin_amdgcn_fmed3f(val.fval, 57344.0, -57344.0);
635 }
636 }
637 }
638
639 if constexpr(stochastic_rounding)
640 {
641 ival = (interpret == fp8_interpretation::E4M3_FNUZ) ||
642 (interpret == fp8_interpretation::E4M3_OCP)
643 ? __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0)
644 : __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
645 val.i32val = ival;
646 i8data = val.i8val[0]; // little endian
647 }
648 else
649 { // RNE CVT
650 ival = (interpret == fp8_interpretation::E4M3_FNUZ) ||
651 (interpret == fp8_interpretation::E4M3_OCP)
652 ? __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false)
653 : __builtin_amdgcn_cvt_pk_bf8_f32(val.fval,
654 val.fval,
655 ival,
656 false); // false -> WORD0
657 val.i32val = ival;
658 i8data = val.i8val[0];
659 }
660 return i8data;
661}
662#endif // CK_TILE_FP8_CVT_DEVICE
663
664} // namespace impl
665
679template <typename SrcT, typename DstT>
681{
682 constexpr bool clip = true;
683 constexpr int seed = 42;
684 uint32_t rng = prand_generator_t<SrcT, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
685#if CK_TILE_FP8_CVT_DEVICE
686 return impl::cast_to_f8_from_f32<numeric_traits<DstT>::f8_interpret, clip, true>(x, rng);
687#else
690#endif
691}
692
705template <typename SrcT, typename DstT>
707{
708 constexpr bool clip = true;
709#if CK_TILE_FP8_CVT_DEVICE
710 return impl::cast_to_f8_from_f32<numeric_traits<DstT>::f8_interpret, clip, false>(x, 0);
711#else
714#endif
715}
716
717template <fp8_rounding_mode rounding>
719{
720 if constexpr(rounding == fp8_rounding_mode::standard)
721 {
723 }
724 else if constexpr(rounding == fp8_rounding_mode::stochastic)
725 {
727 }
728 else
729 {
730 return fp8_raw_t{0};
731 }
732}
733
734template <fp8_rounding_mode rounding>
736{
737 if constexpr(rounding == fp8_rounding_mode::standard)
738 {
740 }
741 else if constexpr(rounding == fp8_rounding_mode::stochastic)
742 {
744 }
745 else
746 {
747 return bf8_raw_t{0};
748 }
749}
750
752{
753#if CK_TILE_FP8_CVT_DEVICE
754 float fval;
755 uint32_t i32val = static_cast<uint32_t>(x);
756 fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
757 // asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
758 return fval;
759#else
761#endif
762}
763
765{
766#if CK_TILE_FP8_CVT_DEVICE
767 float fval;
768 uint32_t i32val = static_cast<uint32_t>(x);
769 fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0);
770 // asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
771 return fval;
772#else
774#endif
775}
776
777template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
782
783template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
788
790
792
// Declaration of the numeric-limits class template; the fp8_t / bf8_t
// specializations that follow provide min/max/lowest/epsilon/NaN/... for those types.
template <typename T>
struct numeric;
795
796#if CK_TILE_USE_OCP_FP8
797template <>
798struct numeric<fp8_t>
799{
800 // minimum finite value, or minimum positive normal value
801 CK_TILE_HOST_DEVICE static constexpr fp8_t min()
802 {
803 return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x08)); // 0b00001000 = 2^-6
804 }
805
806 // minumum finite value
807 CK_TILE_HOST_DEVICE static constexpr fp8_t lowest()
808 {
809 return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0xfe)); // 0b11111110 = -448
810 }
811
812 // maximum finite value
813 CK_TILE_HOST_DEVICE static constexpr fp8_t max()
814 {
815 return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x7e)); // 0b01111110 = 448
816 }
817
818 // difference between 1.0 and next representable f8 value (1.125)
819 // returns fp8_t(0.125)
820 CK_TILE_HOST_DEVICE static constexpr fp8_t epsilon()
821 {
822 return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x20)); // 0.125
823 }
824
825 // rounding error (0.0625)
826 // half of epsilon
827 CK_TILE_HOST_DEVICE static constexpr fp8_t round_error()
828 {
829 return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x18)); // 0.0625
830 }
831
832 // quiet NaN
833 CK_TILE_HOST_DEVICE static constexpr fp8_t quiet_NaN()
834 {
835 return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x7F)); // 0b01111111
836 }
837
838 // signaling NaN
839 CK_TILE_HOST_DEVICE static constexpr fp8_t signaling_NaN()
840 {
841 return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0xFF)); // 0b11111111
842 }
843
844 // smallest positive subnormal value
845 CK_TILE_HOST_DEVICE static constexpr fp8_t denorm_min()
846 {
847 return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x01));
848 }
849
850 CK_TILE_HOST_DEVICE static constexpr fp8_t zero()
851 {
852 return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0));
853 }
854};
855
856template <>
857struct numeric<bf8_t>
858{
859 // minimum finite value, or minimum positive normalized value for float
860 CK_TILE_HOST_DEVICE static constexpr bf8_t min()
861 {
862 return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x04)); // 0b00000100 = 2^-14
863 }
864
865 // minumum finite value
866 CK_TILE_HOST_DEVICE static constexpr bf8_t lowest()
867 {
868 return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0xfb)); // 0b11111011 = -57344
869 }
870
871 // maximum finite value
872 CK_TILE_HOST_DEVICE static constexpr bf8_t max()
873 {
874 return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x7b)); // 0b01111011 = 57344
875 }
876
877 // difference between 1.0 and next representable bf8 value (1.25)
878 CK_TILE_HOST_DEVICE static constexpr bf8_t epsilon()
879 {
880 return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x34)); // 0.25
881 }
882
883 // rounding error (0.125)
884 // half of epsilon
885 CK_TILE_HOST_DEVICE static constexpr bf8_t round_error()
886 {
887 return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x30)); // 0.125
888 }
889
890 // positive infinity value
891 CK_TILE_HOST_DEVICE static constexpr bf8_t infinity()
892 {
893 return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x7c)); // 0b01111100
894 }
895
896 // quiet NaN
897 CK_TILE_HOST_DEVICE static constexpr bf8_t quiet_NaN()
898 {
899 return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x7F)); // 0b01111111
900 }
901
902 // signaling NaN
903 CK_TILE_HOST_DEVICE static constexpr bf8_t signaling_NaN()
904 {
905 return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0xFF));
906 }
907
908 // smallest positive subnormal value
909 CK_TILE_HOST_DEVICE static constexpr bf8_t denorm_min()
910 {
911 return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x01));
912 }
913
914 CK_TILE_HOST_DEVICE static constexpr bf8_t zero()
915 {
916 return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0));
917 }
918};
919#else
920template <>
922{
923 // minimum finite value, or minimum positive normalized value for float
924 CK_TILE_HOST_DEVICE static constexpr fp8_t min()
925 {
926 return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x08));
927 }
928
929 // minumum finite value
931 {
932 return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0xff));
933 }
934
935 // maximum finite value
936 CK_TILE_HOST_DEVICE static constexpr fp8_t max()
937 {
938 return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x7f));
939 }
940
941 // difference between 1.0 and next value representable by float
943 {
944 return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x20));
945 }
946
947 // maximum rounding error
948 // bin : 7 6543 210
949 // bits: s eeee mmm
950 // 0 0110 000 (0.5)
951 //
953 {
954 return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x30));
955 }
956
957 // positive infinity value
959 {
960 return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x80));
961 }
962
963 // quiet NaN
965 {
966 return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x80));
967 }
968
969 // signaling NaN
971 {
972 return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x80));
973 }
974
975 // smallest positive subnormal value
977 {
978 return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x01));
979 }
980
981 CK_TILE_HOST_DEVICE static constexpr fp8_t zero()
982 {
983 return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0));
984 }
985};
986
987template <>
989{
990 // minimum finite value, or minimum positive normalized value for float
991 CK_TILE_HOST_DEVICE static constexpr bf8_t min()
992 {
993 return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x04));
994 }
995
996 // minumum finite value
998 {
999 return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0xff));
1000 }
1001
1002 // maximum finite value
1003 CK_TILE_HOST_DEVICE static constexpr bf8_t max()
1004 {
1005 return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x7f));
1006 }
1007
1008 // difference between 1.0 and next value representable by float
1010 {
1011 return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x34));
1012 }
1013
1014 // maximum rounding error
1015 // bin : 7 65432 10
1016 // bits: s eeeee mm
1017 // 0 01110 00 (0.5)
1018 //
1020 {
1021 return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x38));
1022 }
1023
1024 // positive infinity value
1026 {
1027 return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x80));
1028 }
1029
1030 // quiet NaN
1032 {
1033 return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x80));
1034 }
1035
1036 // signaling NaN
1038 {
1039 return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x80));
1040 }
1041
1042 // smallest positive subnormal value
1044 {
1045 return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x01));
1046 }
1047
1048 CK_TILE_HOST_DEVICE static constexpr bf8_t zero()
1049 {
1050 return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0));
1051 }
1052};
1053#endif
1054
1055#if CK_TILE_USE_CUSTOM_DATA_TYPE
1058#endif
1059
1060// math
1061template <typename T>
1063{
1064 static_assert(std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>,
1065 "Only fp8_t and bf8_t are supported");
1067}
1068
1070bool isnan(const fp8_t& x)
1071{
1073
1074#if CK_TILE_USE_OCP_FP8
1075 return (xx & 0x7f) == 0x7f;
1076#else
1077 return xx == 0x80;
1078#endif
1079}
1080#if CK_TILE_USE_CUSTOM_DATA_TYPE
1082fp8_t sqrt(fp8_t x) { return static_cast<fp8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); };
1083
1085fp8_t exp(fp8_t x) { return static_cast<fp8_t>(__ocml_exp_f32(static_cast<float>(x))); };
1086
1088fp8_t exp2(fp8_t x) { return static_cast<fp8_t>(exp2f(static_cast<float>(x))); };
1089
1091fp8_t log(fp8_t x) { return static_cast<fp8_t>(__logf(static_cast<float>(x))); };
1092#endif
1093
1095bool isnan(const bf8_t& x)
1096{
1098
1099#if CK_TILE_USE_OCP_FP8
1100 return (xx & 0x7f) > 0x7c;
1101#else
1102 return xx == 0x80;
1103#endif
1104}
1105
1106#if CK_TILE_USE_CUSTOM_DATA_TYPE
1108bf8_t sqrt(bf8_t x) { return static_cast<bf8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); };
1109
1111bf8_t exp(bf8_t x) { return static_cast<bf8_t>(__ocml_exp_f32(static_cast<float>(x))); };
1112
1114bf8_t exp2(bf8_t x) { return static_cast<bf8_t>(exp2f(static_cast<float>(x))); };
1115
1117bf8_t log(bf8_t x) { return static_cast<bf8_t>(__logf(static_cast<float>(x))); };
1118#endif
1119
1120} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_FLOAT_TO_FP8_DEFAULT
Definition config.hpp:79
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/arch/amd_buffer_addressing.hpp:110
CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng=0)
Definition float8.hpp:250
CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x)
Definition float8.hpp:476
CK_TILE_HOST_DEVICE Y cast_to_f8(X x, uint32_t rng)
Definition float8.hpp:591
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE numeric_traits< DstT >::bitwise_type float_to_fp8_rtn_raw(SrcT x)
Converts a floating-point value to an 8-bit floating-point representation with rounding to nearest even.
Definition float8.hpp:706
CK_TILE_DEVICE bfloat16_t log(bfloat16_t x)
Definition bfloat16.hpp:428
fp8_interpretation
FP8 interpretation used in conversion algorithms.
Definition float8.hpp:38
@ E4M3_OCP
Definition float8.hpp:39
@ E5M2_OCP
Definition float8.hpp:40
@ E5M2_FNUZ
Definition float8.hpp:42
@ E4M3_FNUZ
Definition float8.hpp:41
CK_TILE_HOST_DEVICE fp8_t float_to_fp8(float x, constant< rounding >={})
Definition float8.hpp:778
CK_TILE_HOST_DEVICE float fp8_to_float_raw(uint8_t)
Definition float8.hpp:751
_BitInt(8) fp8_t
Definition float8.hpp:204
CK_TILE_HOST_DEVICE float bf8_to_float_raw(uint8_t)
Definition float8.hpp:764
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X &x)
Definition bit_cast.hpp:11
fp8_rounding_mode
Definition float8.hpp:29
@ stochastic
Definition float8.hpp:31
@ standard
Definition float8.hpp:30
CK_TILE_DEVICE bfloat16_t sqrt(bfloat16_t x)
Definition bfloat16.hpp:413
CK_TILE_HOST_DEVICE uint8_t float_to_fp8_raw(float, constant< rounding >={})
Definition float8.hpp:718
uint8_t fp8_raw_t
Definition float8.hpp:205
CK_TILE_HOST_DEVICE float bf8_to_float(bf8_t x)
Definition float8.hpp:791
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition bfloat16.hpp:419
CK_TILE_HOST int clz(uint32_t x)
Definition tile/core/numeric/math.hpp:264
@ standard
Definition bfloat16.hpp:20
CK_TILE_HOST_DEVICE bfloat16_t abs(const bfloat16_t &x)
Definition bfloat16.hpp:400
unsigned _BitInt(8) bf8_t
Definition float8.hpp:206
uint8_t bf8_raw_t
Definition float8.hpp:207
@ constant
Definition arch.hpp:51
CK_TILE_HOST_DEVICE bf8_t float_to_bf8(float x, constant< rounding >={})
Definition float8.hpp:784
CK_TILE_HOST_DEVICE bool isnan(const bfloat16_t &x)
Definition bfloat16.hpp:406
CK_TILE_HOST_DEVICE float fp8_to_float(fp8_t x)
Definition float8.hpp:789
CK_TILE_HOST_DEVICE uint8_t float_to_bf8_raw(float, constant< rounding >={})
Definition float8.hpp:735
CK_TILE_DEVICE bfloat16_t exp2(bfloat16_t x)
Definition bfloat16.hpp:425
CK_TILE_HOST_DEVICE numeric_traits< DstT >::bitwise_type float_to_fp8_sr_raw(SrcT x)
Converts a floating-point value to an 8-bit floating-point representation with stochastic rounding.
Definition float8.hpp:680
__host__ __device__ constexpr Y bit_cast(const X &x)
Definition type.hpp:306
_W64 unsigned int uintptr_t
Definition stdint.h:164
unsigned int uint32_t
Definition stdint.h:126
unsigned char uint8_t
Definition stdint.h:124
Definition tile/core/numeric/integral_constant.hpp:13
Definition vector_type.hpp:26
remove_cvref_t< T > type
Definition vector_type.hpp:27
static CK_TILE_HOST_DEVICE constexpr bf8_t infinity()
Definition float8.hpp:1025
static CK_TILE_HOST_DEVICE constexpr bf8_t zero()
Definition float8.hpp:1048
static CK_TILE_HOST_DEVICE constexpr bf8_t quiet_NaN()
Definition float8.hpp:1031
static CK_TILE_HOST_DEVICE constexpr bf8_t max()
Definition float8.hpp:1003
static CK_TILE_HOST_DEVICE constexpr bf8_t round_error()
Definition float8.hpp:1019
static CK_TILE_HOST_DEVICE constexpr bf8_t min()
Definition float8.hpp:991
static CK_TILE_HOST_DEVICE constexpr bf8_t lowest()
Definition float8.hpp:997
static CK_TILE_HOST_DEVICE constexpr bf8_t signaling_NaN()
Definition float8.hpp:1037
static CK_TILE_HOST_DEVICE constexpr bf8_t epsilon()
Definition float8.hpp:1009
static CK_TILE_HOST_DEVICE constexpr bf8_t denorm_min()
Definition float8.hpp:1043
static CK_TILE_HOST_DEVICE constexpr fp8_t round_error()
Definition float8.hpp:952
static CK_TILE_HOST_DEVICE constexpr fp8_t signaling_NaN()
Definition float8.hpp:970
static CK_TILE_HOST_DEVICE constexpr fp8_t epsilon()
Definition float8.hpp:942
static CK_TILE_HOST_DEVICE constexpr fp8_t max()
Definition float8.hpp:936
static CK_TILE_HOST_DEVICE constexpr fp8_t infinity()
Definition float8.hpp:958
static CK_TILE_HOST_DEVICE constexpr fp8_t zero()
Definition float8.hpp:981
static CK_TILE_HOST_DEVICE constexpr fp8_t quiet_NaN()
Definition float8.hpp:964
static CK_TILE_HOST_DEVICE constexpr fp8_t min()
Definition float8.hpp:924
static CK_TILE_HOST_DEVICE constexpr fp8_t lowest()
Definition float8.hpp:930
static CK_TILE_HOST_DEVICE constexpr fp8_t denorm_min()
Definition float8.hpp:976
static constexpr uint8_t abs_mask
Definition float8.hpp:242
static constexpr int PackedSize
Definition float8.hpp:243
static constexpr fp8_interpretation f8_interpret
Definition float8.hpp:240
static constexpr int exp
Definition float8.hpp:233
static constexpr int mant
Definition float8.hpp:234
static constexpr int bias
Definition float8.hpp:239
bf8_raw_t bitwise_type
Definition float8.hpp:231
static constexpr int bias
Definition float8.hpp:221
fp8_raw_t bitwise_type
Definition float8.hpp:213
static constexpr uint8_t abs_mask
Definition float8.hpp:224
static constexpr fp8_interpretation f8_interpret
Definition float8.hpp:222
static constexpr int mant
Definition float8.hpp:216
static constexpr int exp
Definition float8.hpp:215
static constexpr int PackedSize
Definition float8.hpp:225
Definition tile/core/numeric/numeric.hpp:81
Definition tile/core/numeric/numeric.hpp:18
static CK_TILE_HOST_DEVICE constexpr T quiet_NaN()
Definition tile/core/numeric/numeric.hpp:41
static CK_TILE_HOST_DEVICE constexpr T infinity()
Definition tile/core/numeric/numeric.hpp:38
static CK_TILE_HOST_DEVICE constexpr T lowest()
Definition tile/core/numeric/numeric.hpp:23
static CK_TILE_HOST_DEVICE constexpr T zero()
Definition tile/core/numeric/numeric.hpp:58
static CK_TILE_HOST_DEVICE constexpr T denorm_min()
Definition tile/core/numeric/numeric.hpp:53
static CK_TILE_HOST_DEVICE constexpr T round_error()
Definition tile/core/numeric/numeric.hpp:32
static CK_TILE_HOST_DEVICE constexpr T signaling_NaN()
Definition tile/core/numeric/numeric.hpp:47
static CK_TILE_HOST_DEVICE constexpr T max()
Definition tile/core/numeric/numeric.hpp:26
static CK_TILE_HOST_DEVICE constexpr T min()
Definition tile/core/numeric/numeric.hpp:20
static CK_TILE_HOST_DEVICE constexpr T epsilon()
Definition tile/core/numeric/numeric.hpp:29
Definition random.hpp:17
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_)
Definition tile/core/numeric/numeric.hpp:106