16#ifndef EIGEN_BFLOAT16_H
17#define EIGEN_BFLOAT16_H
// Defines the explicit specialization METHOD<PACKET_BF16> of a unary packet
// function: the bfloat16 packet is widened with Bf16ToF32, METHOD<PACKET_F>
// is evaluated on the float packet, and the result is narrowed back with
// F32ToBf16.
//
// The macro body is terminated explicitly (no trailing line continuation on
// the last line) so that the code following the macro is not absorbed into it.
#define BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, METHOD)         \
  template <>                                                       \
  EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED  \
  PACKET_BF16 METHOD<PACKET_BF16>(const PACKET_BF16& _x) {          \
    return F32ToBf16(METHOD<PACKET_F>(Bf16ToF32(_x)));              \
  }
30namespace bfloat16_impl {
39EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
__bfloat16_raw raw_uint16_to_bfloat16(
unsigned short value);
40template <
bool AssumeArgumentIsNormalOrInfinityOrZero>
41EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
__bfloat16_raw float_to_bfloat16_rtne(
float ff);
48EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
float bfloat16_to_float(
__bfloat16_raw h);
62 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
bfloat16() {}
66 explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
bfloat16(
bool b)
70 explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
bfloat16(T val)
73 explicit EIGEN_DEVICE_FUNC
bfloat16(
float f)
78 template<
typename RealScalar>
79 explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
bfloat16(
const std::complex<RealScalar>& val)
82 EIGEN_DEVICE_FUNC
operator float()
const {
83 return bfloat16_impl::bfloat16_to_float(*
this);
90struct numeric_limits<
Eigen::bfloat16> {
91 static const bool is_specialized =
true;
92 static const bool is_signed =
true;
93 static const bool is_integer =
false;
94 static const bool is_exact =
false;
95 static const bool has_infinity =
true;
96 static const bool has_quiet_NaN =
true;
97 static const bool has_signaling_NaN =
true;
98 static const float_denorm_style has_denorm = std::denorm_absent;
99 static const bool has_denorm_loss =
false;
100 static const std::float_round_style round_style = numeric_limits<float>::round_style;
101 static const bool is_iec559 =
false;
102 static const bool is_bounded =
true;
103 static const bool is_modulo =
false;
104 static const int digits = 8;
105 static const int digits10 = 2;
106 static const int max_digits10 = 4;
107 static const int radix = 2;
108 static const int min_exponent = numeric_limits<float>::min_exponent;
109 static const int min_exponent10 = numeric_limits<float>::min_exponent10;
110 static const int max_exponent = numeric_limits<float>::max_exponent;
111 static const int max_exponent10 = numeric_limits<float>::max_exponent10;
112 static const bool traps = numeric_limits<float>::traps;
113 static const bool tinyness_before = numeric_limits<float>::tinyness_before;
115 static Eigen::bfloat16 (min)() {
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x0080); }
116 static Eigen::bfloat16 lowest() {
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0xff7f); }
117 static Eigen::bfloat16 (max)() {
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f7f); }
118 static Eigen::bfloat16 epsilon() {
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x3c00); }
120 static Eigen::bfloat16 infinity() {
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f80); }
121 static Eigen::bfloat16 quiet_NaN() {
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7fc0); }
122 static Eigen::bfloat16 signaling_NaN() {
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f81); }
123 static Eigen::bfloat16 denorm_min() {
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x0001); }
131struct numeric_limits<const
Eigen::bfloat16> : numeric_limits<Eigen::bfloat16> {};
133struct numeric_limits<volatile
Eigen::bfloat16> : numeric_limits<Eigen::bfloat16> {};
135struct numeric_limits<const volatile
Eigen::bfloat16> : numeric_limits<Eigen::bfloat16> {};
140namespace bfloat16_impl {
145#if !defined(EIGEN_HAS_NATIVE_BF16) || (EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC)
147#if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
149#pragma push_macro("EIGEN_DEVICE_FUNC")
150#undef EIGEN_DEVICE_FUNC
151#if defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_NATIVE_BF16)
152#define EIGEN_DEVICE_FUNC __host__
154#define EIGEN_DEVICE_FUNC __host__ __device__
161EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (
const bfloat16& a,
const bfloat16& b) {
162 return bfloat16(
float(a) +
float(b));
164EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (
const bfloat16& a,
const int& b) {
165 return bfloat16(
float(a) +
static_cast<float>(b));
167EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (
const int& a,
const bfloat16& b) {
168 return bfloat16(
static_cast<float>(a) +
float(b));
170EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator * (
const bfloat16& a,
const bfloat16& b) {
171 return bfloat16(
float(a) *
float(b));
173EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator - (
const bfloat16& a,
const bfloat16& b) {
174 return bfloat16(
float(a) -
float(b));
176EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator / (
const bfloat16& a,
const bfloat16& b) {
177 return bfloat16(
float(a) /
float(b));
179EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator - (
const bfloat16& a) {
181 result.value = a.value ^ 0x8000;
184EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator += (bfloat16& a,
const bfloat16& b) {
185 a = bfloat16(
float(a) +
float(b));
188EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator *= (bfloat16& a,
const bfloat16& b) {
189 a = bfloat16(
float(a) *
float(b));
192EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator -= (bfloat16& a,
const bfloat16& b) {
193 a = bfloat16(
float(a) -
float(b));
196EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator /= (bfloat16& a,
const bfloat16& b) {
197 a = bfloat16(
float(a) /
float(b));
200EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a) {
204EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a) {
208EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a,
int) {
209 bfloat16 original_value = a;
211 return original_value;
213EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a,
int) {
214 bfloat16 original_value = a;
216 return original_value;
218EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
bool operator == (
const bfloat16& a,
const bfloat16& b) {
219 return numext::equal_strict(
float(a),
float(b));
221EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
bool operator != (
const bfloat16& a,
const bfloat16& b) {
222 return numext::not_equal_strict(
float(a),
float(b));
224EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
bool operator < (
const bfloat16& a,
const bfloat16& b) {
225 return float(a) < float(b);
227EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
bool operator <= (
const bfloat16& a,
const bfloat16& b) {
228 return float(a) <= float(b);
230EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
bool operator > (
const bfloat16& a,
const bfloat16& b) {
231 return float(a) > float(b);
233EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
bool operator >= (
const bfloat16& a,
const bfloat16& b) {
234 return float(a) >= float(b);
237#if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
238#pragma pop_macro("EIGEN_DEVICE_FUNC")
244EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator / (
const bfloat16& a,
Index b) {
245 return bfloat16(
static_cast<float>(a) /
static_cast<float>(b));
248EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw truncate_to_bfloat16(
const float v) {
249 __bfloat16_raw output;
250 if (Eigen::numext::isnan EIGEN_NOT_A_MACRO(v)) {
251 output.value = std::signbit(v) ? 0xFFC0: 0x7FC0;
254 output.value =
static_cast<numext::uint16_t
>(numext::bit_cast<numext::uint32_t>(v) >> 16);
258EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw raw_uint16_to_bfloat16(numext::uint16_t value) {
259 return __bfloat16_raw(value);
262EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR numext::uint16_t raw_bfloat16_as_uint16(
const __bfloat16_raw& bf) {
// Specialization for arbitrary float input: the visible code checks for NaN
// first and otherwise forwards to the <true> specialization.
// NOTE(review): the contents of the #if branch and the else/closing lines are
// not visible in this excerpt - confirm against the full file.
269EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<false>(
float ff) {
270#if (defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_HIP_BF16))
273 __bfloat16_raw output;
275 if (Eigen::numext::isnan EIGEN_NOT_A_MACRO(ff)) {
// NaN input: emit 0xFFC0 or 0x7FC0 depending on the sign bit of ff.
281 output.value = std::signbit(ff) ? 0xFFC0: 0x7FC0;
// Non-NaN input: delegate to the specialization that skips the NaN check.
432 output = float_to_bfloat16_rtne<true>(ff);
// Specialization that performs no NaN check: rounds the 32-bit pattern of ff
// to its upper 16 bits by adding a bias and shifting.
// NOTE(review): the contents of the #if branch and the closing lines are not
// visible in this excerpt - confirm against the full file.
443EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<true>(
float ff) {
444#if (defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_HIP_BF16))
447 numext::uint32_t input = numext::bit_cast<numext::uint32_t>(ff);
448 __bfloat16_raw output;
// lsb is the lowest bit that survives the shift by 16.
451 numext::uint32_t lsb = (input >> 16) & 1;
// Bias of 0x7fff, plus one more when the surviving lsb is set, so that an
// exact halfway value rounds towards the even result.
452 numext::uint32_t rounding_bias = 0x7fff + lsb;
453 input += rounding_bias;
// Keep the upper 16 bits of the biased pattern.
454 output.value =
static_cast<numext::uint16_t
>(input >> 16);
459EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
float bfloat16_to_float(__bfloat16_raw h) {
460 return numext::bit_cast<float>(
static_cast<numext::uint32_t
>(h.value) << 16);
464EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isinf)(
const bfloat16& a) {
465 EIGEN_USING_STD(isinf);
466 return (isinf)(float(a));
468EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isnan)(
const bfloat16& a) {
469 EIGEN_USING_STD(isnan);
470 return (isnan)(float(a));
472EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isfinite)(
const bfloat16& a) {
473 return !(isinf EIGEN_NOT_A_MACRO (a)) && !(isnan EIGEN_NOT_A_MACRO (a));
476EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 abs(
const bfloat16& a) {
478 result.value = a.
value & 0x7FFF;
481EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 exp(
const bfloat16& a) {
482 return bfloat16(::expf(
float(a)));
484EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 expm1(
const bfloat16& a) {
485 return bfloat16(numext::expm1(
float(a)));
487EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log(
const bfloat16& a) {
488 return bfloat16(::logf(
float(a)));
490EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log1p(
const bfloat16& a) {
491 return bfloat16(numext::log1p(
float(a)));
493EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log10(
const bfloat16& a) {
494 return bfloat16(::log10f(
float(a)));
496EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log2(
const bfloat16& a) {
497 return bfloat16(
static_cast<float>(EIGEN_LOG2E) * ::logf(
float(a)));
499EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sqrt(
const bfloat16& a) {
500 return bfloat16(::sqrtf(
float(a)));
502EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 pow(
const bfloat16& a,
const bfloat16& b) {
503 return bfloat16(::powf(
float(a),
float(b)));
505EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sin(
const bfloat16& a) {
506 return bfloat16(::sinf(
float(a)));
508EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cos(
const bfloat16& a) {
509 return bfloat16(::cosf(
float(a)));
511EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tan(
const bfloat16& a) {
512 return bfloat16(::tanf(
float(a)));
514EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asin(
const bfloat16& a) {
515 return bfloat16(::asinf(
float(a)));
517EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acos(
const bfloat16& a) {
518 return bfloat16(::acosf(
float(a)));
520EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atan(
const bfloat16& a) {
521 return bfloat16(::atanf(
float(a)));
523EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sinh(
const bfloat16& a) {
524 return bfloat16(::sinhf(
float(a)));
526EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cosh(
const bfloat16& a) {
527 return bfloat16(::coshf(
float(a)));
529EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tanh(
const bfloat16& a) {
530 return bfloat16(::tanhf(
float(a)));
532#if EIGEN_HAS_CXX11_MATH
533EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asinh(
const bfloat16& a) {
534 return bfloat16(::asinhf(
float(a)));
536EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acosh(
const bfloat16& a) {
537 return bfloat16(::acoshf(
float(a)));
539EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atanh(
const bfloat16& a) {
540 return bfloat16(::atanhf(
float(a)));
543EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 floor(
const bfloat16& a) {
544 return bfloat16(::floorf(
float(a)));
546EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 ceil(
const bfloat16& a) {
547 return bfloat16(::ceilf(
float(a)));
549EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 rint(
const bfloat16& a) {
550 return bfloat16(::rintf(
float(a)));
552EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 round(
const bfloat16& a) {
553 return bfloat16(::roundf(
float(a)));
555EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmod(
const bfloat16& a,
const bfloat16& b) {
556 return bfloat16(::fmodf(
float(a),
float(b)));
559EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 (min)(
const bfloat16& a,
const bfloat16& b) {
560 const float f1 =
static_cast<float>(a);
561 const float f2 =
static_cast<float>(b);
562 return f2 < f1 ? b : a;
564EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 (max)(
const bfloat16& a,
const bfloat16& b) {
565 const float f1 =
static_cast<float>(a);
566 const float f2 =
static_cast<float>(b);
567 return f1 < f2 ? b : a;
570EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmin(
const bfloat16& a,
const bfloat16& b) {
571 const float f1 =
static_cast<float>(a);
572 const float f2 =
static_cast<float>(b);
573 return bfloat16(::fminf(f1, f2));
575EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmax(
const bfloat16& a,
const bfloat16& b) {
576 const float f1 =
static_cast<float>(a);
577 const float f2 =
static_cast<float>(b);
578 return bfloat16(::fmaxf(f1, f2));
582EIGEN_ALWAYS_INLINE std::ostream& operator << (std::ostream& os,
const bfloat16& v) {
583 os << static_cast<float>(v);
616 RequireInitialization =
false
619 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static EIGEN_STRONG_INLINE
Eigen::bfloat16 epsilon() {
620 return bfloat16_impl::raw_uint16_to_bfloat16(0x3c00);
622 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static EIGEN_STRONG_INLINE
Eigen::bfloat16 dummy_precision() {
623 return bfloat16_impl::raw_uint16_to_bfloat16(0x3D4D);
626 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static EIGEN_STRONG_INLINE
Eigen::bfloat16 highest() {
627 return bfloat16_impl::raw_uint16_to_bfloat16(0x7F7F);
629 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static EIGEN_STRONG_INLINE
Eigen::bfloat16 lowest() {
630 return bfloat16_impl::raw_uint16_to_bfloat16(0xFF7F);
632 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static EIGEN_STRONG_INLINE
Eigen::bfloat16 infinity() {
633 return bfloat16_impl::raw_uint16_to_bfloat16(0x7f80);
635 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static EIGEN_STRONG_INLINE
Eigen::bfloat16 quiet_NaN() {
636 return bfloat16_impl::raw_uint16_to_bfloat16(0x7fc0);
646EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
648 return (bfloat16_impl::isnan)(h);
652EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
654 return (bfloat16_impl::isinf)(h);
658EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
660 return (bfloat16_impl::isfinite)(h);
664EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
Eigen::bfloat16 bit_cast<Eigen::bfloat16, uint16_t>(
const uint16_t& src) {
665 return Eigen::bfloat16(Eigen::bfloat16_impl::raw_uint16_to_bfloat16(src));
669EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint16_t bit_cast<uint16_t, Eigen::bfloat16>(
const Eigen::bfloat16& src) {
670 return Eigen::bfloat16_impl::raw_bfloat16_as_uint16(src);
676#if EIGEN_HAS_STD_HASH
679struct hash<
Eigen::bfloat16> {
680 EIGEN_STRONG_INLINE std::size_t operator()(
const Eigen::bfloat16& a)
const {
681 return static_cast<std::size_t
>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(a));
EIGEN_DEVICE_FUNC CoeffReturnType value() const
Definition DenseBase.h:526
Base class for all dense matrices, vectors, and expressions.
Definition MatrixBase.h:50
Namespace containing all symbols from the Eigen library.
Definition LDLT.h:16
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition Meta.h:74
Definition NumTraits.h:156
Holds information about the various numeric (i.e. scalar) types allowed by Eigen.
Definition NumTraits.h:236
Definition MathFunctions.h:806