Medial Code Documentation
Loading...
Searching...
No Matches
BFloat16.h
1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef EIGEN_BFLOAT16_H
17#define EIGEN_BFLOAT16_H
18
// Generates the PACKET_BF16 specialization of the unary packet function
// METHOD: the argument is widened with Bf16ToF32, the PACKET_F specialization
// of METHOD is evaluated, and the result is narrowed back with F32ToBf16.
#define BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, METHOD)                 \
  template <>                                                               \
  EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED          \
  PACKET_BF16 METHOD<PACKET_BF16>(const PACKET_BF16& bf16_arg) {            \
    return F32ToBf16(METHOD<PACKET_F>(Bf16ToF32(bf16_arg)));                \
  }
25
26namespace Eigen {
27
28struct bfloat16;
29
30namespace bfloat16_impl {
31
32// Make our own __bfloat16_raw definition.
34 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw() : value(0) {}
35 explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw(unsigned short raw) : value(raw) {}
36 unsigned short value;
37};
38
39EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw raw_uint16_to_bfloat16(unsigned short value);
40template <bool AssumeArgumentIsNormalOrInfinityOrZero>
41EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne(float ff);
42// Forward declarations of template specializations, to avoid Visual C++ 2019 errors, saying:
43// > error C2908: explicit specialization; 'float_to_bfloat16_rtne' has already been instantiated
44template <>
45EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<false>(float ff);
46template <>
47EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<true>(float ff);
48EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h);
49
51 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16_base() {}
52 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16_base(const __bfloat16_raw& h) : __bfloat16_raw(h) {}
53};
54
55} // namespace bfloat16_impl
56
57// Class definition.
59
61
62 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16() {}
63
64 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(const __bfloat16_raw& h) : bfloat16_impl::bfloat16_base(h) {}
65
66 explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(bool b)
67 : bfloat16_impl::bfloat16_base(bfloat16_impl::raw_uint16_to_bfloat16(b ? 0x3f80 : 0)) {}
68
69 template<class T>
70 explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(T val)
71 : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<internal::is_integral<T>::value>(static_cast<float>(val))) {}
72
73 explicit EIGEN_DEVICE_FUNC bfloat16(float f)
74 : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<false>(f)) {}
75
76 // Following the convention of numpy, converting between complex and
77 // float will lead to loss of imag value.
78 template<typename RealScalar>
79 explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(const std::complex<RealScalar>& val)
80 : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<false>(static_cast<float>(val.real()))) {}
81
82 EIGEN_DEVICE_FUNC operator float() const { // NOLINT: Allow implicit conversion to float, because it is lossless.
83 return bfloat16_impl::bfloat16_to_float(*this);
84 }
85};
86} // namespace Eigen
87
88namespace std {
89template<>
90struct numeric_limits<Eigen::bfloat16> {
91 static const bool is_specialized = true;
92 static const bool is_signed = true;
93 static const bool is_integer = false;
94 static const bool is_exact = false;
95 static const bool has_infinity = true;
96 static const bool has_quiet_NaN = true;
97 static const bool has_signaling_NaN = true;
98 static const float_denorm_style has_denorm = std::denorm_absent;
99 static const bool has_denorm_loss = false;
100 static const std::float_round_style round_style = numeric_limits<float>::round_style;
101 static const bool is_iec559 = false;
102 static const bool is_bounded = true;
103 static const bool is_modulo = false;
104 static const int digits = 8;
105 static const int digits10 = 2;
106 static const int max_digits10 = 4;
107 static const int radix = 2;
108 static const int min_exponent = numeric_limits<float>::min_exponent;
109 static const int min_exponent10 = numeric_limits<float>::min_exponent10;
110 static const int max_exponent = numeric_limits<float>::max_exponent;
111 static const int max_exponent10 = numeric_limits<float>::max_exponent10;
112 static const bool traps = numeric_limits<float>::traps;
113 static const bool tinyness_before = numeric_limits<float>::tinyness_before;
114
115 static Eigen::bfloat16 (min)() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x0080); }
116 static Eigen::bfloat16 lowest() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0xff7f); }
117 static Eigen::bfloat16 (max)() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f7f); }
118 static Eigen::bfloat16 epsilon() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x3c00); }
119 static Eigen::bfloat16 round_error() { return Eigen::bfloat16(0x3f00); }
120 static Eigen::bfloat16 infinity() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f80); }
121 static Eigen::bfloat16 quiet_NaN() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7fc0); }
122 static Eigen::bfloat16 signaling_NaN() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f81); }
123 static Eigen::bfloat16 denorm_min() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x0001); }
124};
125
126// If std::numeric_limits<T> is specialized, should also specialize
127// std::numeric_limits<const T>, std::numeric_limits<volatile T>, and
128// std::numeric_limits<const volatile T>
129// https://stackoverflow.com/a/16519653/
130template<>
131struct numeric_limits<const Eigen::bfloat16> : numeric_limits<Eigen::bfloat16> {};
132template<>
133struct numeric_limits<volatile Eigen::bfloat16> : numeric_limits<Eigen::bfloat16> {};
134template<>
135struct numeric_limits<const volatile Eigen::bfloat16> : numeric_limits<Eigen::bfloat16> {};
136} // namespace std
137
138namespace Eigen {
139
140namespace bfloat16_impl {
141
142// We need to distinguish ‘clang as the CUDA compiler’ from ‘clang as the host compiler,
143// invoked by NVCC’ (e.g. on MacOS). The former needs to see both host and device implementation
144// of the functions, while the latter can only deal with one of them.
145#if !defined(EIGEN_HAS_NATIVE_BF16) || (EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC) // Emulate support for bfloat16 floats
146
147#if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
148// We need to provide emulated *host-side* BF16 operators for clang.
149#pragma push_macro("EIGEN_DEVICE_FUNC")
150#undef EIGEN_DEVICE_FUNC
151#if defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_NATIVE_BF16)
152#define EIGEN_DEVICE_FUNC __host__
153#else // both host and device need emulated ops.
154#define EIGEN_DEVICE_FUNC __host__ __device__
155#endif
156#endif
157
158// Definitions for CPUs, mostly working through conversion
159// to/from fp32.
160
161EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (const bfloat16& a, const bfloat16& b) {
162 return bfloat16(float(a) + float(b));
163}
164EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (const bfloat16& a, const int& b) {
165 return bfloat16(float(a) + static_cast<float>(b));
166}
167EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (const int& a, const bfloat16& b) {
168 return bfloat16(static_cast<float>(a) + float(b));
169}
170EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator * (const bfloat16& a, const bfloat16& b) {
171 return bfloat16(float(a) * float(b));
172}
173EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator - (const bfloat16& a, const bfloat16& b) {
174 return bfloat16(float(a) - float(b));
175}
176EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator / (const bfloat16& a, const bfloat16& b) {
177 return bfloat16(float(a) / float(b));
178}
179EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator - (const bfloat16& a) {
180 bfloat16 result;
181 result.value = a.value ^ 0x8000;
182 return result;
183}
184EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator += (bfloat16& a, const bfloat16& b) {
185 a = bfloat16(float(a) + float(b));
186 return a;
187}
188EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator *= (bfloat16& a, const bfloat16& b) {
189 a = bfloat16(float(a) * float(b));
190 return a;
191}
192EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator -= (bfloat16& a, const bfloat16& b) {
193 a = bfloat16(float(a) - float(b));
194 return a;
195}
196EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator /= (bfloat16& a, const bfloat16& b) {
197 a = bfloat16(float(a) / float(b));
198 return a;
199}
200EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a) {
201 a += bfloat16(1);
202 return a;
203}
204EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a) {
205 a -= bfloat16(1);
206 return a;
207}
208EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a, int) {
209 bfloat16 original_value = a;
210 ++a;
211 return original_value;
212}
213EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a, int) {
214 bfloat16 original_value = a;
215 --a;
216 return original_value;
217}
218EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator == (const bfloat16& a, const bfloat16& b) {
219 return numext::equal_strict(float(a),float(b));
220}
221EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator != (const bfloat16& a, const bfloat16& b) {
222 return numext::not_equal_strict(float(a), float(b));
223}
224EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator < (const bfloat16& a, const bfloat16& b) {
225 return float(a) < float(b);
226}
227EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator <= (const bfloat16& a, const bfloat16& b) {
228 return float(a) <= float(b);
229}
230EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator > (const bfloat16& a, const bfloat16& b) {
231 return float(a) > float(b);
232}
233EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator >= (const bfloat16& a, const bfloat16& b) {
234 return float(a) >= float(b);
235}
236
237#if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
238#pragma pop_macro("EIGEN_DEVICE_FUNC")
239#endif
240#endif // Emulate support for bfloat16 floats
241
242// Division by an index. Do it in full float precision to avoid accuracy
243// issues in converting the denominator to bfloat16.
244EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator / (const bfloat16& a, Index b) {
245 return bfloat16(static_cast<float>(a) / static_cast<float>(b));
246}
247
248EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw truncate_to_bfloat16(const float v) {
249 __bfloat16_raw output;
250 if (Eigen::numext::isnan EIGEN_NOT_A_MACRO(v)) {
251 output.value = std::signbit(v) ? 0xFFC0: 0x7FC0;
252 return output;
253 }
254 output.value = static_cast<numext::uint16_t>(numext::bit_cast<numext::uint32_t>(v) >> 16);
255 return output;
256}
257
258EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw raw_uint16_to_bfloat16(numext::uint16_t value) {
259 return __bfloat16_raw(value);
260}
261
262EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR numext::uint16_t raw_bfloat16_as_uint16(const __bfloat16_raw& bf) {
263 return bf.value;
264}
265
266// float_to_bfloat16_rtne template specialization that does not make any
267// assumption about the value of its function argument (ff).
268template <>
269EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<false>(float ff) {
270#if (defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_HIP_BF16))
271 // Nothing to do here
272#else
273 __bfloat16_raw output;
274
275 if (Eigen::numext::isnan EIGEN_NOT_A_MACRO(ff)) {
276 // If the value is a NaN, squash it to a qNaN with msb of fraction set,
277 // this makes sure after truncation we don't end up with an inf.
278 //
279 // qNaN magic: All exponent bits set + most significant bit of fraction
280 // set.
281 output.value = std::signbit(ff) ? 0xFFC0: 0x7FC0;
282 } else {
283 // Fast rounding algorithm that rounds a half value to nearest even. This
284 // reduces expected error when we convert a large number of floats. Here
285 // is how it works:
286 //
287 // Definitions:
288 // To convert a float 32 to bfloat16, a float 32 can be viewed as 32 bits
289 // with the following tags:
290 //
291 // Sign | Exp (8 bits) | Frac (23 bits)
292 // S EEEEEEEE FFFFFFLRTTTTTTTTTTTTTTT
293 //
294 // S: Sign bit.
295 // E: Exponent bits.
296 // F: First 6 bits of fraction.
297 // L: Least significant bit of resulting bfloat16 if we truncate away the
298 // rest of the float32. This is also the 7th bit of fraction
299 // R: Rounding bit, 8th bit of fraction.
300 // T: Sticky bits, rest of fraction, 15 bits.
301 //
302 // To round half to nearest even, there are 3 cases where we want to round
303 // down (simply truncate the result of the bits away, which consists of
304 // rounding bit and sticky bits) and two cases where we want to round up
305 // (truncate then add one to the result).
306 //
307 // The fast converting algorithm simply adds lsb (L) to 0x7fff (15 bits of
308 // 1s) as the rounding bias, adds the rounding bias to the input, then
309 // truncates the last 16 bits away.
310 //
311 // To understand how it works, we can analyze this algorithm case by case:
312 //
313 // 1. L = 0, R = 0:
314 // Expect: round down, this is less than half value.
315 //
316 // Algorithm:
317 // - Rounding bias: 0x7fff + 0 = 0x7fff
318 // - Adding rounding bias to input may create any carry, depending on
319 // whether there is any value set to 1 in T bits.
320 // - R may be set to 1 if there is a carry.
321 // - L remains 0.
322 // - Note that this case also handles Inf and -Inf, where all fraction
323 // bits, including L, R and Ts are all 0. The output remains Inf after
324 // this algorithm.
325 //
326 // 2. L = 1, R = 0:
327 // Expect: round down, this is less than half value.
328 //
329 // Algorithm:
330 // - Rounding bias: 0x7fff + 1 = 0x8000
331 // - Adding rounding bias to input doesn't change sticky bits but
332 // adds 1 to rounding bit.
333 // - L remains 1.
334 //
335 // 3. L = 0, R = 1, all of T are 0:
336 // Expect: round down, this is exactly at half, the result is already
337 // even (L=0).
338 //
339 // Algorithm:
340 // - Rounding bias: 0x7fff + 0 = 0x7fff
341 // - Adding rounding bias to input sets all sticky bits to 1, but
342 // doesn't create a carry.
343 // - R remains 1.
344 // - L remains 0.
345 //
346 // 4. L = 1, R = 1:
347 // Expect: round up, this is exactly at half, the result needs to be
348 // round to the next even number.
349 //
350 // Algorithm:
351 // - Rounding bias: 0x7fff + 1 = 0x8000
352 // - Adding rounding bias to input doesn't change sticky bits, but
353 // creates a carry from rounding bit.
354 // - The carry sets L to 0, creates another carry bit and propagate
355 // forward to F bits.
356 // - If all the F bits are 1, a carry then propagates to the exponent
357 // bits, which then creates the minimum value with the next exponent
358 // value. Note that we won't have the case where exponents are all 1,
359 // since that's either a NaN (handled in the other if condition) or inf
360 // (handled in case 1).
361 //
362 // 5. L = 0, R = 1, any of T is 1:
363 // Expect: round up, this is greater than half.
364 //
365 // Algorithm:
366 // - Rounding bias: 0x7fff + 0 = 0x7fff
367 // - Adding rounding bias to input creates a carry from sticky bits,
368 // sets rounding bit to 0, then create another carry.
369 // - The second carry sets L to 1.
370 //
371 // Examples:
372 //
373 // Exact half value that is already even:
374 // Input:
375 // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
376 // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
377 // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1000000000000000
378 //
379 // This falls into case 3. We truncate the rest of 16 bits and no
380 // carry is created into F and L:
381 //
382 // Output:
383 // Sign | Exp (8 bit) | Frac (first 7 bit)
384 // S E E E E E E E E F F F F F F L
385 // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
386 //
387 // Exact half value, round to next even number:
388 // Input:
389 // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
390 // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
391 // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1000000000000000
392 //
393 // This falls into case 4. We create a carry from R and T,
394 // which then propagates into L and F:
395 //
396 // Output:
397 // Sign | Exp (8 bit) | Frac (first 7 bit)
398 // S E E E E E E E E F F F F F F L
399 // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
400 //
401 //
402 // Max denormal value round to min normal value:
403 // Input:
404 // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
405 // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
406 // 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1111111111111111
407 //
408 // This falls into case 4. We create a carry from R and T,
409 // propagate into L and F, which then propagates into exponent
410 // bits:
411 //
412 // Output:
413 // Sign | Exp (8 bit) | Frac (first 7 bit)
414 // S E E E E E E E E F F F F F F L
415 // 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0
416 //
417 // Max normal value round to Inf:
418 // Input:
419 // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
420 // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
421 // 0 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1111111111111111
422 //
423 // This falls into case 4. We create a carry from R and T,
424 // propagate into L and F, which then propagates into exponent
425 // bits:
426 //
427 // Sign | Exp (8 bit) | Frac (first 7 bit)
428 // S E E E E E E E E F F F F F F L
429 // 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0
430
431 // At this point, ff must be either a normal float, or +/-infinity.
432 output = float_to_bfloat16_rtne<true>(ff);
433 }
434 return output;
435#endif
436}
437
438// float_to_bfloat16_rtne template specialization that assumes that its function
439// argument (ff) is either a normal floating point number, or +/-infinity, or
440// zero. Used to improve the runtime performance of conversion from an integer
441// type to bfloat16.
442template <>
443EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<true>(float ff) {
444#if (defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_HIP_BF16))
445 // Nothing to do here
446#else
447 numext::uint32_t input = numext::bit_cast<numext::uint32_t>(ff);
448 __bfloat16_raw output;
449
450 // Least significant bit of resulting bfloat.
451 numext::uint32_t lsb = (input >> 16) & 1;
452 numext::uint32_t rounding_bias = 0x7fff + lsb;
453 input += rounding_bias;
454 output.value = static_cast<numext::uint16_t>(input >> 16);
455 return output;
456#endif
457}
458
459EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h) {
460 return numext::bit_cast<float>(static_cast<numext::uint32_t>(h.value) << 16);
461}
462// --- standard functions ---
463
464EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isinf)(const bfloat16& a) {
465 EIGEN_USING_STD(isinf);
466 return (isinf)(float(a));
467}
468EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isnan)(const bfloat16& a) {
469 EIGEN_USING_STD(isnan);
470 return (isnan)(float(a));
471}
472EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isfinite)(const bfloat16& a) {
473 return !(isinf EIGEN_NOT_A_MACRO (a)) && !(isnan EIGEN_NOT_A_MACRO (a));
474}
475
476EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 abs(const bfloat16& a) {
477 bfloat16 result;
478 result.value = a.value & 0x7FFF;
479 return result;
480}
481EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 exp(const bfloat16& a) {
482 return bfloat16(::expf(float(a)));
483}
484EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 expm1(const bfloat16& a) {
485 return bfloat16(numext::expm1(float(a)));
486}
487EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log(const bfloat16& a) {
488 return bfloat16(::logf(float(a)));
489}
490EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log1p(const bfloat16& a) {
491 return bfloat16(numext::log1p(float(a)));
492}
493EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log10(const bfloat16& a) {
494 return bfloat16(::log10f(float(a)));
495}
496EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log2(const bfloat16& a) {
497 return bfloat16(static_cast<float>(EIGEN_LOG2E) * ::logf(float(a)));
498}
499EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sqrt(const bfloat16& a) {
500 return bfloat16(::sqrtf(float(a)));
501}
502EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 pow(const bfloat16& a, const bfloat16& b) {
503 return bfloat16(::powf(float(a), float(b)));
504}
505EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sin(const bfloat16& a) {
506 return bfloat16(::sinf(float(a)));
507}
508EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cos(const bfloat16& a) {
509 return bfloat16(::cosf(float(a)));
510}
511EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tan(const bfloat16& a) {
512 return bfloat16(::tanf(float(a)));
513}
514EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asin(const bfloat16& a) {
515 return bfloat16(::asinf(float(a)));
516}
517EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acos(const bfloat16& a) {
518 return bfloat16(::acosf(float(a)));
519}
520EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atan(const bfloat16& a) {
521 return bfloat16(::atanf(float(a)));
522}
523EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sinh(const bfloat16& a) {
524 return bfloat16(::sinhf(float(a)));
525}
526EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cosh(const bfloat16& a) {
527 return bfloat16(::coshf(float(a)));
528}
529EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tanh(const bfloat16& a) {
530 return bfloat16(::tanhf(float(a)));
531}
532#if EIGEN_HAS_CXX11_MATH
533EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asinh(const bfloat16& a) {
534 return bfloat16(::asinhf(float(a)));
535}
536EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acosh(const bfloat16& a) {
537 return bfloat16(::acoshf(float(a)));
538}
539EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atanh(const bfloat16& a) {
540 return bfloat16(::atanhf(float(a)));
541}
542#endif
543EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 floor(const bfloat16& a) {
544 return bfloat16(::floorf(float(a)));
545}
546EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 ceil(const bfloat16& a) {
547 return bfloat16(::ceilf(float(a)));
548}
549EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 rint(const bfloat16& a) {
550 return bfloat16(::rintf(float(a)));
551}
552EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 round(const bfloat16& a) {
553 return bfloat16(::roundf(float(a)));
554}
555EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmod(const bfloat16& a, const bfloat16& b) {
556 return bfloat16(::fmodf(float(a), float(b)));
557}
558
559EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 (min)(const bfloat16& a, const bfloat16& b) {
560 const float f1 = static_cast<float>(a);
561 const float f2 = static_cast<float>(b);
562 return f2 < f1 ? b : a;
563}
564EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 (max)(const bfloat16& a, const bfloat16& b) {
565 const float f1 = static_cast<float>(a);
566 const float f2 = static_cast<float>(b);
567 return f1 < f2 ? b : a;
568}
569
570EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmin(const bfloat16& a, const bfloat16& b) {
571 const float f1 = static_cast<float>(a);
572 const float f2 = static_cast<float>(b);
573 return bfloat16(::fminf(f1, f2));
574}
575EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmax(const bfloat16& a, const bfloat16& b) {
576 const float f1 = static_cast<float>(a);
577 const float f2 = static_cast<float>(b);
578 return bfloat16(::fmaxf(f1, f2));
579}
580
581#ifndef EIGEN_NO_IO
582EIGEN_ALWAYS_INLINE std::ostream& operator << (std::ostream& os, const bfloat16& v) {
583 os << static_cast<float>(v);
584 return os;
585}
586#endif
587
588} // namespace bfloat16_impl
589
590namespace internal {
591
592template<>
594{
595 static inline bfloat16 run(const bfloat16& x, const bfloat16& y)
596 {
597 return x + (y-x) * bfloat16(float(std::rand()) / float(RAND_MAX));
598 }
599 static inline bfloat16 run()
600 {
601 return run(bfloat16(-1.f), bfloat16(1.f));
602 }
603};
604
605template<> struct is_arithmetic<bfloat16> { enum { value = true }; };
606
607} // namespace internal
608
609template<> struct NumTraits<Eigen::bfloat16>
610 : GenericNumTraits<Eigen::bfloat16>
611{
612 enum {
613 IsSigned = true,
614 IsInteger = false,
615 IsComplex = false,
616 RequireInitialization = false
617 };
618
619 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 epsilon() {
620 return bfloat16_impl::raw_uint16_to_bfloat16(0x3c00);
621 }
622 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 dummy_precision() {
623 return bfloat16_impl::raw_uint16_to_bfloat16(0x3D4D); // bfloat16(5e-2f);
624
625 }
626 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 highest() {
627 return bfloat16_impl::raw_uint16_to_bfloat16(0x7F7F);
628 }
629 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 lowest() {
630 return bfloat16_impl::raw_uint16_to_bfloat16(0xFF7F);
631 }
632 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 infinity() {
633 return bfloat16_impl::raw_uint16_to_bfloat16(0x7f80);
634 }
635 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 quiet_NaN() {
636 return bfloat16_impl::raw_uint16_to_bfloat16(0x7fc0);
637 }
638};
639
640} // namespace Eigen
641
642namespace Eigen {
643namespace numext {
644
645template<>
646EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
647bool (isnan)(const Eigen::bfloat16& h) {
648 return (bfloat16_impl::isnan)(h);
649}
650
651template<>
652EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
653bool (isinf)(const Eigen::bfloat16& h) {
654 return (bfloat16_impl::isinf)(h);
655}
656
657template<>
658EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
659bool (isfinite)(const Eigen::bfloat16& h) {
660 return (bfloat16_impl::isfinite)(h);
661}
662
663template <>
664EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bit_cast<Eigen::bfloat16, uint16_t>(const uint16_t& src) {
665 return Eigen::bfloat16(Eigen::bfloat16_impl::raw_uint16_to_bfloat16(src));
666}
667
668template <>
669EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint16_t bit_cast<uint16_t, Eigen::bfloat16>(const Eigen::bfloat16& src) {
670 return Eigen::bfloat16_impl::raw_bfloat16_as_uint16(src);
671}
672
673} // namespace numext
674} // namespace Eigen
675
676#if EIGEN_HAS_STD_HASH
677namespace std {
678template <>
679struct hash<Eigen::bfloat16> {
680 EIGEN_STRONG_INLINE std::size_t operator()(const Eigen::bfloat16& a) const {
681 return static_cast<std::size_t>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(a));
682 }
683};
684} // namespace std
685#endif
686
687
688#endif // EIGEN_BFLOAT16_H
EIGEN_DEVICE_FUNC CoeffReturnType value() const
Definition DenseBase.h:526
Base class for all dense matrices, vectors, and expressions.
Definition MatrixBase.h:50
Namespace containing all symbols from the Eigen library.
Definition LDLT.h:16
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition Meta.h:74
Definition BFloat16.h:88
Definition NumTraits.h:156
Holds information about the various numeric (i.e.
Definition NumTraits.h:236
Definition BFloat16.h:33
Definition BFloat16.h:50
Definition BFloat16.h:58
Definition Meta.h:133
Definition Meta.h:162
Definition MathFunctions.h:806