13#ifndef CK_USE_FNUZ_FP8
14#define CK_USE_FNUZ_FP8 0
18#define CK_USE_OCP_FP8 0
21#if(defined(__gfx942__) || defined(__gfx950__) || defined(__gfx12__)) && __HIP_DEVICE_COMPILE__
22#define CK_FP8_CVT_FAST_PATH 1
24#define CK_FP8_CVT_FAST_PATH 0
27#if(defined(__gfx950__) || defined(__gfx12__)) && __HIP_DEVICE_COMPILE__
28#define CK_OCP_FP8_CVT_FAST_PATH 1
30#define CK_OCP_FP8_CVT_FAST_PATH 0
40 __host__ __device__
explicit constexpr f8_fnuz_t() =
default;
45 __host__ __device__
explicit constexpr operator data_type()
const {
return m_data; }
53 __host__ __device__
explicit constexpr bf8_fnuz_t() =
default;
58 __host__ __device__
explicit constexpr operator data_type()
const {
return m_data; }
61static_assert(1 ==
sizeof(f8_fnuz_t));
62static_assert(1 ==
sizeof(bf8_fnuz_t));
89typedef _Float16
half2_t __attribute__((ext_vector_type(2)));
90typedef ushort
ushortx2_t __attribute__((ext_vector_type(2)));
91typedef short shortx2_t __attribute__((ext_vector_type(2)));
92typedef float float2_t __attribute__((ext_vector_type(2)));
94__host__ __device__
static inline constexpr bool fnuz_f8_is_nan(
f8_fnuz_t a)
96 return static_cast<unsigned char>(
a) == 0x80;
98__host__ __device__
static inline constexpr bool fnuz_bf8_is_nan(
bf8_fnuz_t a)
100 return static_cast<unsigned char>(
a) == 0x80;
103__host__ __device__
static inline constexpr bool ocp_f8_is_nan(
fp8_storage_t a)
105 return (
a & 0x7f) == 0x7f;
107__host__ __device__
static inline constexpr bool ocp_bf8_is_nan(
fp8_storage_t a)
109 return (
a & 0x7f) > 0x7c;
115template <
typename T,
int wm,
int we,
bool is_fnuz,
bool clip = false>
116__host__ __device__
static inline T cast_from_f8(
fp8_storage_t x)
118 constexpr bool is_half = __hip_internal::is_same<T, _Float16>::value;
119 constexpr bool is_float = __hip_internal::is_same<T, float>::value;
120 constexpr bool is_double = __hip_internal::is_same<T, double>::value;
121 static_assert(is_half || is_float || is_double,
"only half, float and double are supported");
123 constexpr int weo = is_half ? 5 : (is_float ? 8 : 11);
124 constexpr int wmo = is_half ? 10 : (is_float ? 23 : 52);
126 T fInf, fNegInf, fNaN, fNeg0, fmax, fmin;
127 if constexpr(is_half)
129 const unsigned short int ihInf = 0x7C00;
130 const unsigned short int ihNegInf = 0xFC00;
131 const unsigned short int ihNaN = 0x7C01;
132 const unsigned short int ihNeg0 = 0x8000;
134 const unsigned short int ifmax = 0x7B00;
135 const unsigned short int ifmin = 0xFB00;
144 else if constexpr(is_float)
146 const unsigned int ifInf = 0x7F800000;
147 const unsigned int ifNegInf = 0xFF800000;
148 const unsigned int ifNaN = 0x7F800001;
149 const unsigned int ifNeg0 = 0x80000000;
151 const unsigned int ifmax = 0x47600000;
152 const unsigned int ifmin = 0xC7600000;
161 else if constexpr(is_double)
163 const unsigned long long ifInf = 0x7FF0000000000000ull;
164 const unsigned long long ifNegInf = 0xFFF0000000000000ull;
165 const unsigned long long ifNaN = 0x7FF0000000000001ull;
166 const unsigned long long ifNeg0 = 0x8000000000000000ull;
168 const unsigned long long ifmax = 0x40EC000000000000ull;
169 const unsigned long long ifmin = 0xC0EC000000000000ull;
184 unsigned long long sign = x >> 7;
185 unsigned long long mantissa = x & ((1 << wm) - 1);
186 int exponent = (x & 0x7F) >> wm;
187 if constexpr(is_fnuz)
200 if constexpr(we == 4)
202 if((x & 0x7F) == 0x7F)
207 else if((x & 0x7C) == 0x7C)
213 return sign ? fmin : fmax;
215 return sign ? fNegInf : fInf;
227 if constexpr(we == 5 && is_half && !is_fnuz)
233 const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (is_fnuz ? 1 : 0);
238#if defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__
240 int sh = 1 + __clz(mantissa) - (32 - wm);
242 int sh = 1 + __builtin_clz(mantissa) - (32 - wm);
246 mantissa &= ((1ull << wm) - 1);
248 exponent += exp_low_cutoff - 1;
249 mantissa <<= wmo - wm;
254 mantissa |= 1 << wmo;
255 mantissa >>= 1 - exponent;
259 if constexpr(
sizeof(T) == 2)
260 retval = (sign << 15) | (exponent << 10) | mantissa;
261 else if constexpr(
sizeof(T) == 4)
262 retval = (sign << 31) | (exponent << 23) | mantissa;
264 retval = (sign << 63) | (static_cast<unsigned long long>(exponent) << 52) | mantissa;
269#if CK_FP8_CVT_FAST_PATH
270template <ck_fp8_
interpretation_t
interpret>
271static __host__ __device__
float cast_to_f32_from_f8(
fp8_storage_t v)
276 unsigned char i8val[4];
284 "Only FNUZ and OCP interpretations are supported");
289 return __builtin_amdgcn_cvt_f32_fp8(val.i32val, 0);
293 return __builtin_amdgcn_cvt_f32_bf8(val.i32val, 0);
297template <ck_fp8_
interpretation_t
interpret>
306 "Only FNUZ and OCP interpretations are supported");
311 return __builtin_amdgcn_cvt_pk_f32_fp8(i16val,
false);
315 return __builtin_amdgcn_cvt_pk_f32_bf8(i16val,
false);
331 static constexpr unsigned int we = 4;
332 static constexpr unsigned int wm = 3;
336 return (
data == other.
data) && (fp8_impl::ocp_f8_is_nan(
data) ==
false);
340 __host__ __device__
explicit operator float() const
342 __host__
explicit operator float() const
345#if CK_OCP_FP8_CVT_FAST_PATH
346 return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
348 return fp8_impl::cast_from_f8<float, wm, we, false>(
354 __host__ __device__
explicit operator _Float16() const
356 __host__
explicit operator _Float16() const
359#if CK_OCP_FP8_CVT_FAST_PATH
360 return static_cast<_Float16
>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
362 return fp8_impl::cast_from_f8<_Float16, wm, we, false>(
377 static constexpr unsigned int we = 5;
378 static constexpr unsigned int wm = 2;
382 return (
data == other.
data) && (fp8_impl::ocp_bf8_is_nan(
data) ==
false);
386 __host__ __device__
explicit operator float() const
389 __host__
explicit operator float() const
392#if defined(__gfx950__) || defined(__gfx12__)
393 return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
395 return fp8_impl::cast_from_f8<float, wm, we, false>(
401 __host__ __device__
explicit operator _Float16() const
403 __host__
explicit operator _Float16() const
406#if defined(__gfx950__) || defined(__gfx12__)
407 return static_cast<_Float16
>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
409 return fp8_impl::cast_from_f8<_Float16, wm, we, false>(
416__host__ __device__
static inline constexpr bool fp8_is_nan(T);
419__host__ __device__
inline constexpr bool fp8_is_nan(
f8_ocp_t a)
421 return fp8_impl::ocp_f8_is_nan(
a.data);
424__host__ __device__
inline constexpr bool fp8_is_nan(
bf8_ocp_t a)
426 return fp8_impl::ocp_bf8_is_nan(
a.data);
429__host__ __device__
inline constexpr bool fp8_is_nan(
f8_fnuz_t a)
431 return fp8_impl::fnuz_f8_is_nan(
a);
436 return fp8_impl::fnuz_bf8_is_nan(
a);
443__host__ __device__
static inline constexpr bool fp8_is_inf(T)
448__host__ __device__
inline constexpr bool fp8_is_inf(
bf8_ocp_t a)
450 return (
a.data & 0x7f) == 0x7c;
456#define __fp8_impl_assert_ocp_support(interp) \
458 if(interp != ck_fp8_interpretation_t::CK_E4M3_OCP && \
459 interp != ck_fp8_interpretation_t::CK_E5M2_OCP) \
461 __hip_assert(false && "type is unsupported by current target device"); \
464#define __fp8_impl_assert_fnuz_support(interp) \
466 if(interp != ck_fp8_interpretation_t::CK_E4M3_FNUZ && \
467 interp != ck_fp8_interpretation_t::CK_E5M2_FNUZ) \
469 __hip_assert(false && "type is unsupported by current target device"); \
473__host__ __device__
static inline void
476#if defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__
486#if defined(__gfx950__)
489 bool stochastic_rounding =
false,
492static __device__
fp8_storage_t cast_to_f8_from_f16(_Float16 v,
unsigned int rng = 0)
501 constexpr unsigned int i32val = 0;
504 if constexpr(saturate)
506 if((val.i32val & 0x7FFF) != 0x7FFF)
508 val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 448.0, -448.0);
513 __builtin_amdgcn_cvt_scalef32_sr_fp8_f16(i32val, val.half_vec[0], rng, 1.f, 0);
520 bool stochastic_rounding =
false,
527 cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[0], rng),
528 cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[1], rng)};
533 bool stochastic_rounding =
false,
536static __device__
fp8_storage_t cast_to_f8_from_f16(_Float16 v,
unsigned int rng = 0)
545 constexpr unsigned int i32val = 0;
548 if constexpr(saturate)
550 if((val.i32val & 0x7FFF) != 0x7FFF)
552 val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 57344.0, -57344.0);
557 __builtin_amdgcn_cvt_scalef32_sr_bf8_f16(i32val, val.half_vec[0], rng, 1.f, 0);
564 bool stochastic_rounding =
false,
571 cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[0], rng),
572 cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[1], rng)};
577 bool stochastic_rounding =
false,
580static __device__
fp8_storage_t cast_to_f8_from_f16(_Float16 v,
unsigned int rng = 0)
595 if constexpr(saturate)
597 if((val.i32val & 0x7FFF) != 0x7FFF)
599 val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 448.0, -448.0);
604 __builtin_amdgcn_cvt_scalef32_pk_fp8_f16(i16x2val, val.half_vec, 1.f, 0);
611 bool stochastic_rounding =
false,
616#if CK_WORKAROUND_FP16_TO_FP8_CONVERSION
618 cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[0], rng),
619 cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[1], rng)};
633 if constexpr(saturate)
635 if((val.i16_vec[0] & 0x7FFF) != 0x7FFF)
637 val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 448.0, -448.0);
639 if((val.i16_vec[1] & 0x7FFF) != 0x7FFF)
641 val.half_vec[1] = __builtin_amdgcn_fmed3h(val.half_vec[1], 448.0, -448.0);
646 __builtin_amdgcn_cvt_scalef32_pk_fp8_f16(i16x2val, val.half_vec, 1.f, 0);
654 bool stochastic_rounding =
false,
657static __device__
fp8_storage_t cast_to_f8_from_f16(_Float16 v,
unsigned int rng = 0)
672 if constexpr(saturate)
674 if((val.i32val & 0x7FFF) != 0x7FFF)
676 val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 57344.0, -57344.0);
681 __builtin_amdgcn_cvt_scalef32_pk_bf8_f16(i16x2val, val.half_vec, 1.f, 0);
688 bool stochastic_rounding =
false,
693#if CK_WORKAROUND_FP16_TO_FP8_CONVERSION
695 cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[0], rng),
696 cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[1], rng)};
710 if constexpr(saturate)
712 if((val.i16_vec[0] & 0x7FFF) != 0x7FFF)
714 val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 57344.0, -57344.0);
716 if((val.i16_vec[1] & 0x7FFF) != 0x7FFF)
718 val.half_vec[1] = __builtin_amdgcn_fmed3h(val.half_vec[1], 57344.0, -57344.0);
723 __builtin_amdgcn_cvt_scalef32_pk_bf8_f16(i16x2val, val.half_vec, 1.f, 0);
731 bool stochastic_rounding =
false,
734static __device__
fp8_storage_t cast_to_f8_from_bf16(ushort v,
unsigned int rng = 0)
743 constexpr unsigned int i32val = 0;
744 val.bhalf_vec[0] = v;
746 if constexpr(saturate)
748 if((val.i32val & 0x7FFF) != 0x7FFF)
757 val.i32val = __builtin_amdgcn_cvt_scalef32_sr_fp8_bf16(
758 i32val, val.bhalf_vec[0], rng, 1.f, 0);
765 bool stochastic_rounding =
false,
772 cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[0], rng),
773 cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[1], rng)};
778 bool stochastic_rounding =
false,
781static __device__
fp8_storage_t cast_to_f8_from_bf16(ushort v,
unsigned int rng = 0)
790 constexpr unsigned int i32val = 0;
791 val.bhalf_vec[0] = v;
793 if constexpr(saturate)
795 if((val.i32val & 0x7FFF) != 0x7FFF)
797 val.bhalf_vec[0] = ushort(
804 val.i32val = __builtin_amdgcn_cvt_scalef32_sr_bf8_bf16(
805 i32val, val.bhalf_vec[0], rng, 1.f, 0);
812 bool stochastic_rounding =
false,
819 cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[0], rng),
820 cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[1], rng)};
825 bool stochastic_rounding =
false,
828static __device__
fp8_storage_t cast_to_f8_from_bf16(ushort v,
unsigned int rng = 0)
841 val.bhalf_vec[0] = v;
843 if constexpr(saturate)
845 if((val.i32val & 0x7FFF) != 0x7FFF)
855 __builtin_amdgcn_cvt_scalef32_pk_fp8_bf16(i16x2val, val.bhalf_vec, 1.f, 0);
862 bool stochastic_rounding =
false,
867#if CK_WORKAROUND_BF16_TO_FP8_CONVERSION
869 cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[0], rng),
870 cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[1], rng)};
884 if constexpr(saturate)
886 if((val.i16_vec[0] & 0x7FFF) != 0x7FFF)
893 if((val.i16_vec[1] & 0x7FFF) != 0x7FFF)
903 __builtin_amdgcn_cvt_scalef32_pk_fp8_bf16(i16x2val, val.bhalf_vec, 1.f, 0);
911 bool stochastic_rounding =
false,
914static __device__
fp8_storage_t cast_to_f8_from_bf16(ushort v,
unsigned int rng = 0)
927 val.bhalf_vec[0] = v;
929 if constexpr(saturate)
931 if((val.i32val & 0x7FFF) != 0x7FFF)
933 val.bhalf_vec[0] = ushort(
941 __builtin_amdgcn_cvt_scalef32_pk_bf8_bf16(i16x2val, val.bhalf_vec, 1.f, 0);
948 bool stochastic_rounding =
false,
965 if constexpr(saturate)
967 if((val.i16_vec[0] & 0x7FFF) != 0x7FFF)
969 val.bhalf_vec[0] = ushort(
974 if((val.i16_vec[1] & 0x7FFF) != 0x7FFF)
976 val.bhalf_vec[1] = ushort(
984 __builtin_amdgcn_cvt_scalef32_pk_bf8_bf16(i16x2val, val.bhalf_vec, 1.f, 0);
990#if CK_FP8_CVT_FAST_PATH
993template <ck_fp8_
interpretation_t
interpret,
bool saturate,
bool stochastic_rounding = false>
994static __device__
fp8_storage_t cast_to_f8_from_f32(
float v,
unsigned int rng = 0)
1000 unsigned int i32val;
1001 unsigned char i8val[4];
1004 unsigned int ival = 0;
1007 if constexpr(saturate)
1011 if((val.i32val & 0x7F800000) != 0x7F800000)
1013 val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
1018 if((val.i32val & 0x7F800000) != 0x7F800000)
1020 val.fval = __builtin_amdgcn_fmed3f(val.fval, 448.0, -448.0);
1025 if((val.i32val & 0x7F800000) != 0x7F800000)
1027 val.fval = __builtin_amdgcn_fmed3f(val.fval, 57344.0, -57344.0);
1032 if constexpr(stochastic_rounding)
1036 ? __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0)
1037 : __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0);
1039 i8data = val.i8val[0];
1045 ? __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival,
false)
1046 : __builtin_amdgcn_cvt_pk_bf8_f32(val.fval,
1051 i8data = val.i8val[0];
1056template <ck_fp8_
interpretation_t
interpret,
bool saturate,
bool stochastic_rounding = false>
1059 if constexpr(stochastic_rounding)
1063 cast_to_f8_from_f32<interpret, saturate, stochastic_rounding>(v[0], rng),
1064 cast_to_f8_from_f32<interpret, saturate, stochastic_rounding>(v[1], rng)};
1071 unsigned int i32val;
1072 unsigned char i8val[4];
1078 unsigned int ival = 0;
1080 if constexpr(saturate)
1084 if((val0.i32val & 0x7F800000) != 0x7F800000)
1086 val0.fval = __builtin_amdgcn_fmed3f(val0.fval, 240.0, -240.0);
1088 if((val1.i32val & 0x7F800000) != 0x7F800000)
1090 val1.fval = __builtin_amdgcn_fmed3f(val1.fval, 240.0, -240.0);
1095 if((val0.i32val & 0x7F800000) != 0x7F800000)
1097 val0.fval = __builtin_amdgcn_fmed3f(val0.fval, 448.0, -448.0);
1099 if((val1.i32val & 0x7F800000) != 0x7F800000)
1101 val1.fval = __builtin_amdgcn_fmed3f(val1.fval, 448.0, -448.0);
1106 if((val0.i32val & 0x7F800000) != 0x7F800000)
1108 val0.fval = __builtin_amdgcn_fmed3f(val0.fval, 57344.0, -57344.0);
1110 if((val1.i32val & 0x7F800000) != 0x7F800000)
1112 val1.fval = __builtin_amdgcn_fmed3f(val1.fval, 57344.0, -57344.0);
1121 ival = __builtin_amdgcn_cvt_pk_fp8_f32(val0.fval, val1.fval, ival,
false);
1125 ival = __builtin_amdgcn_cvt_pk_bf8_f32(val0.fval, val1.fval, ival,
false);
1138template <
typename T,
int wm,
int we,
bool is_fnuz,
bool clip = false,
bool stoch = false>
1139__host__ __device__
static inline fp8_storage_t cast_to_f8(T _x,
unsigned int rng = 0)
1141 constexpr bool is_half = __hip_internal::is_same<T, _Float16>::value;
1142 constexpr bool is_float = __hip_internal::is_same<T, float>::value;
1143 constexpr bool is_double = __hip_internal::is_same<T, double>::value;
1144 static_assert(is_half || is_float || is_double,
1145 "Only half, float and double can be cast to f8");
1147 constexpr int mfmt = (
sizeof(T) == 8) ? 52 : ((sizeof(T) == 4) ? 23 : 10);
1155 unsigned long long x{x_bitwise};
1157 unsigned long long head, mantissa;
1160 unsigned long long fInf, mask;
1162 if constexpr(
sizeof(T) == 8)
1164 head = x & 0xFFF0000000000000ull;
1165 mantissa = x & 0xFFFFFFFFFFFFFull;
1166 exponent = (head >> 52) & 0x7FF;
1169 fInf = 0x7FF0000000000000ull;
1170 mask = 0x7FFFFFFFFFFFFFFFull;
1172 else if constexpr(
sizeof(T) == 4)
1174 head = x & 0xFF800000;
1175 mantissa = x & 0x7FFFFF;
1176 exponent = (head >> 23) & 0xFF;
1185 mantissa = x & 0x3FF;
1186 exponent = (head >> 10) & 0x1F;
1192 unsigned int signed_inf = 0;
1193 unsigned int nan = 0;
1194 if constexpr(is_fnuz)
1196 signed_inf = clip ? ((sign << 7) + 0x7f) : 0x80;
1201 if constexpr(we == 4)
1203 signed_inf = (sign << 7) + (clip ? 0x7e : 0x7f);
1207 signed_inf = (sign << 7) + (clip ? 0x7b : 0x7c);
1209 nan = (sign << 7) + 0x7f;
1212 unsigned long long ifmax = 0;
1213 if constexpr(
sizeof(T) == 8)
1215 if constexpr(we == 5)
1217 ifmax = 0x40EC000000000000ull;
1221 if constexpr(is_fnuz)
1223 ifmax = 0x406E000000000000ull;
1227 ifmax = 0x407C000000000000ull;
1231 else if(
sizeof(T) == 4)
1233 if constexpr(we == 5)
1239 if constexpr(is_fnuz)
1251 if constexpr(we == 5)
1257 if constexpr(is_fnuz)
1268 if((x & fInf) == fInf)
1270 if constexpr(is_fnuz)
1273 return mantissa != 0 ? nan : signed_inf;
1276 if((x & mask) > ifmax)
1294 const int f8_bias = (1 << (we - 1)) - 1 + (is_fnuz ? 1 : 0);
1295 const int f8_denormal_act_exponent = 1 - f8_bias;
1300 int act_exponent, f8_exponent, exponent_diff;
1311 act_exponent = exponent - bias + 1;
1312 exponent_diff = f8_denormal_act_exponent -
1317 act_exponent = exponent - bias;
1318 if(act_exponent <= f8_denormal_act_exponent)
1325 exponent_diff = f8_denormal_act_exponent - act_exponent;
1333 mantissa += (1ull << mfmt);
1336 bool midpoint = (mantissa & ((1ull << (mfmt - wm + exponent_diff)) - 1)) ==
1337 (1ull << (mfmt - wm + exponent_diff - 1));
1345 if(exponent_diff > 0)
1346 mantissa >>= exponent_diff;
1347 else if(exponent_diff == -1)
1348 mantissa <<= -exponent_diff;
1349 bool implicit_one = mantissa & (1ull << mfmt);
1353 (act_exponent + exponent_diff) + f8_bias - (implicit_one ? 0 : 1);
1356 unsigned long long drop_mask = (1ull << (mfmt - wm)) - 1;
1358 mantissa & (1ull << (mfmt - wm));
1360 (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask;
1363 if(f8_exponent == 0)
1365 if((1ull << mfmt) & mantissa)
1372 if((1ull << (mfmt + 1)) & mantissa)
1379 mantissa >>= (mfmt - wm);
1382 const int max_exp = (1 << we) - 1;
1383 if(f8_exponent > max_exp)
1387 mantissa = (1 << wm) - 1;
1388 f8_exponent = max_exp;
1396 if(f8_exponent == 0 && mantissa == 0)
1397 return is_fnuz ? 0 : (sign << 7);
1398 mantissa &= (1 << wm) - 1;
1399 return (sign << 7) | (f8_exponent << wm) | mantissa;
1413 bool stochastic_rounding =
false>
1414#if CK_FP8_CVT_FAST_PATH
1415__host__ __device__
static inline fp8_storage_t cvt_float_to_fp8(
const float f)
1417 __is_interpret_supported(interp);
1419 if constexpr(stochastic_rounding)
1421#if defined(__gfx950__)
1423 rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
1426 constexpr int seed = 1254739;
1427#ifndef CK_CODE_GEN_RTC
1434 return cast_to_f8_from_f32<interp, sat == ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
1438__host__ __device__
static inline fp8_storage_t cvt_float_to_fp8(
const float f)
1441__host__
static inline fp8_storage_t cvt_float_to_fp8(
const float f)
1445 if constexpr(stochastic_rounding)
1447#if defined(__gfx950__)
1449 rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
1452 constexpr int seed = 1254739;
1453#ifndef CK_CODE_GEN_RTC
1463 return cast_to_f8<float,
1468 stochastic_rounding>(f, rng);
1472 return cast_to_f8<float,
1477 stochastic_rounding>(f, rng);
1481 return cast_to_f8<float,
1486 stochastic_rounding>(f, rng);
1490 return cast_to_f8<float,
1495 stochastic_rounding>(f, rng);
1499 __hip_assert(
false &&
"FP8 type is not supported by current target device");
1516 bool stochastic_rounding =
false>
1517#if CK_FP8_CVT_FAST_PATH
1520 __is_interpret_supported(interp);
1522 if constexpr(stochastic_rounding)
1524#if defined(__gfx950__)
1526 rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
1529 constexpr int seed = 1254739;
1530#ifndef CK_CODE_GEN_RTC
1537 return cast_to_f8_from_f32<interp, sat == ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
1547 return fp8x2_storage_t{cvt_float_to_fp8<interp, sat, stochastic_rounding>(f[0]),
1548 cvt_float_to_fp8<interp, sat, stochastic_rounding>(f[1])};
1563 bool stochastic_rounding =
false>
1564#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8
1565__host__ __device__
static inline fp8_storage_t cvt_half_t_to_fp8(
const _Float16 x)
1567__host__
static inline fp8_storage_t cvt_half_t_to_fp8(
const _Float16 x)
1571 __is_interpret_supported(interp);
1573 if constexpr(stochastic_rounding)
1575#if defined(__gfx950__)
1577 rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
1580 constexpr int seed = 1254739;
1581#ifndef CK_CODE_GEN_RTC
1588#if defined(__gfx950__)
1589 return cast_to_f8_from_f16<interp,
1591 stochastic_rounding>(x, rng);
1594 return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
1595 static_cast<float>(x));
1611 bool stochastic_rounding =
false>
1612#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8
1619 __is_interpret_supported(interp);
1621 if constexpr(stochastic_rounding)
1623#if defined(__gfx950__)
1625 rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
1628 constexpr int seed = 1254739;
1629#ifndef CK_CODE_GEN_RTC
1636#if defined(__gfx950__)
1637 return cast_to_f8_from_f16<interp,
1639 stochastic_rounding>(x, rng);
1642 return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
1643 float2_t{
static_cast<float>(x[0]),
static_cast<float>(x[1])});
1659 bool stochastic_rounding =
false>
1660#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8
1661__host__ __device__
static inline fp8_storage_t cvt_bhalf_t_to_fp8(
const ushort x)
1663__host__
static inline fp8_storage_t cvt_bhalf_t_to_fp8(
const ushort x)
1667 __is_interpret_supported(interp);
1669 if constexpr(stochastic_rounding)
1671#if defined(__gfx950__)
1673 rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
1676 constexpr int seed = 1254739;
1677#ifndef CK_CODE_GEN_RTC
1679 static_cast<float>(x));
1685#if defined(__gfx950__)
1686 return cast_to_f8_from_bf16<interp,
1688 stochastic_rounding>(x, rng);
1691 return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
1708 bool stochastic_rounding =
false>
1709#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8
1715#if CK_WORKAROUND_BF16_TO_FP8_CONVERSION
1716 return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
1721 __is_interpret_supported(interp);
1723 if constexpr(stochastic_rounding)
1725#if defined(__gfx950__)
1727 rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
1730 constexpr int seed = 1254739;
1731#ifndef CK_CODE_GEN_RTC
1733 static_cast<float>(x[0]));
1736 static_cast<float>(x[0]));
1740#if defined(__gfx950__)
1741 return cast_to_f8_from_bf16<interp,
1743 stochastic_rounding>(x, rng);
1746 return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
1757using f8_t = f8_ocp_t;
1758using bf8_t = bf8_ocp_t;
1759#define CK_FP8_TYPE_FNUZ 0
1760#define CK_FP8_TYPE_OCP 1
1764#define CK_FP8_TYPE_FNUZ 1
1765#define CK_FP8_TYPE_OCP 0
#define __fp8_impl_assert_fnuz_support(interp)
Definition amd_ck_fp8.hpp:464
#define __fp8_impl_assert_ocp_support(interp)
Definition amd_ck_fp8.hpp:456
Definition amd_ck_fp8.hpp:86
ushort ushortx2_t
Definition amd_ck_fp8.hpp:90
short shortx2_t
Definition amd_ck_fp8.hpp:91
float float2_t
Definition amd_ck_fp8.hpp:92
fp8_storage_t fp8x2_storage_t
Definition amd_ck_fp8.hpp:88
_Float16 half2_t
Definition amd_ck_fp8.hpp:89
f8_fnuz_t f8_t
Definition amd_ck_fp8.hpp:1762
typename conditional< predicate, X, Y >::type conditional_t
Definition utility/functional.hpp:115
ck_fp8_interpretation_t
Describes FP8 interpretation.
Definition amd_ck_fp8.hpp:70
@ CK_E4M3_OCP
Definition amd_ck_fp8.hpp:71
@ CK_E5M2_OCP
Definition amd_ck_fp8.hpp:72
@ CK_E5M2_FNUZ
Definition amd_ck_fp8.hpp:74
@ CK_E4M3_FNUZ
Definition amd_ck_fp8.hpp:73
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
bf8_fnuz_t bf8_t
Definition amd_ck_fp8.hpp:1763
__device__ index_t get_thread_global_1d_id()
Definition get_id.hpp:43
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed=seed_t)
Definition random_gen.hpp:19
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr Y bit_cast(const X &x)
Definition type.hpp:306
typename std::enable_if< B, T >::type enable_if_t
Definition enable_if.hpp:27
ck_saturation_t
Describes saturation behavior.
Definition amd_ck_fp8.hpp:81
@ CK_SATFINITE
Definition amd_ck_fp8.hpp:83
@ CK_NOSAT
Definition amd_ck_fp8.hpp:82
unsigned char fp8_storage_t
Definition amd_ck_fp8.hpp:64
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
_W64 unsigned int uintptr_t
Definition stdint.h:164
unsigned int uint32_t
Definition stdint.h:126
Definition amd_ck_fp8.hpp:49
data_type m_data
Definition amd_ck_fp8.hpp:51
__host__ __device__ bool constexpr operator==(bf8_fnuz_t other) const
Definition amd_ck_fp8.hpp:54
unsigned char data_type
Definition amd_ck_fp8.hpp:50
__host__ __device__ constexpr bf8_fnuz_t(data_type in_data)
Definition amd_ck_fp8.hpp:52
__host__ __device__ constexpr bf8_fnuz_t()=default
Definition amd_ck_fp8.hpp:369
static constexpr unsigned int wm
Definition amd_ck_fp8.hpp:378
static constexpr unsigned int we
Definition amd_ck_fp8.hpp:377
fp8_storage_t data_type
Definition amd_ck_fp8.hpp:370
data_type data
Definition amd_ck_fp8.hpp:371
static constexpr ck_fp8_interpretation_t default_interpret
Definition amd_ck_fp8.hpp:374
static constexpr ck_saturation_t default_saturation
Definition amd_ck_fp8.hpp:373
__host__ __device__ constexpr bool operator==(const bf8_ocp_t &other) const
Definition amd_ck_fp8.hpp:380
Definition amd_ck_fp8.hpp:36
__host__ __device__ constexpr f8_fnuz_t()=default
data_type m_data
Definition amd_ck_fp8.hpp:38
__host__ __device__ constexpr f8_fnuz_t(data_type in_data)
Definition amd_ck_fp8.hpp:39
__host__ __device__ bool constexpr operator==(f8_fnuz_t other) const
Definition amd_ck_fp8.hpp:41
unsigned char data_type
Definition amd_ck_fp8.hpp:37
Definition amd_ck_fp8.hpp:323
fp8_storage_t data_type
Definition amd_ck_fp8.hpp:324
data_type data
Definition amd_ck_fp8.hpp:325
static constexpr unsigned int we
Definition amd_ck_fp8.hpp:331
__host__ __device__ constexpr bool operator==(const f8_ocp_t &other) const
Definition amd_ck_fp8.hpp:334
static constexpr unsigned int wm
Definition amd_ck_fp8.hpp:332
static constexpr ck_fp8_interpretation_t default_interpret
Definition amd_ck_fp8.hpp:328
static constexpr ck_saturation_t default_saturation
Definition amd_ck_fp8.hpp:327