Medial Code Documentation
Loading...
Searching...
No Matches
GenericPacketMathFunctions.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2007 Julien Pommier
5// Copyright (C) 2014 Pedro Gonnet (pedro.gonnet@gmail.com)
6// Copyright (C) 2009-2019 Gael Guennebaud <gael.guennebaud@inria.fr>
7//
8// This Source Code Form is subject to the terms of the Mozilla
9// Public License v. 2.0. If a copy of the MPL was not distributed
10// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
11
12/* The exp and log functions of this file initially come from
13 * Julien Pommier's sse math library: http://gruntthepeon.free.fr/ssemath/
14 */
15
16#ifndef EIGEN_ARCH_GENERIC_PACKET_MATH_FUNCTIONS_H
17#define EIGEN_ARCH_GENERIC_PACKET_MATH_FUNCTIONS_H
18
19namespace Eigen {
20namespace internal {
21
22// Creates a Scalar integer type with same bit-width.
23template<typename T> struct make_integer;
24template<> struct make_integer<float> { typedef numext::int32_t type; };
25template<> struct make_integer<double> { typedef numext::int64_t type; };
26template<> struct make_integer<half> { typedef numext::int16_t type; };
27template<> struct make_integer<bfloat16> { typedef numext::int16_t type; };
28
29template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
30Packet pfrexp_generic_get_biased_exponent(const Packet& a) {
31 typedef typename unpacket_traits<Packet>::type Scalar;
32 typedef typename unpacket_traits<Packet>::integer_packet PacketI;
33 enum { mantissa_bits = numext::numeric_limits<Scalar>::digits - 1};
35}
36
37// Safely applies frexp, correctly handles denormals.
38// Assumes IEEE floating point format.
39template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
40Packet pfrexp_generic(const Packet& a, Packet& exponent) {
41 typedef typename unpacket_traits<Packet>::type Scalar;
42 typedef typename make_unsigned<typename make_integer<Scalar>::type>::type ScalarUI;
43 enum {
44 TotalBits = sizeof(Scalar) * CHAR_BIT,
45 MantissaBits = numext::numeric_limits<Scalar>::digits - 1,
46 ExponentBits = int(TotalBits) - int(MantissaBits) - 1
47 };
48
49 EIGEN_CONSTEXPR ScalarUI scalar_sign_mantissa_mask =
50 ~(((ScalarUI(1) << int(ExponentBits)) - ScalarUI(1)) << int(MantissaBits)); // ~0x7f800000
51 const Packet sign_mantissa_mask = pset1frombits<Packet>(static_cast<ScalarUI>(scalar_sign_mantissa_mask));
52 const Packet half = pset1<Packet>(Scalar(0.5));
53 const Packet zero = pzero(a);
54 const Packet normal_min = pset1<Packet>((numext::numeric_limits<Scalar>::min)()); // Minimum normal value, 2^-126
55
56 // To handle denormals, normalize by multiplying by 2^(int(MantissaBits)+1).
57 const Packet is_denormal = pcmp_lt(pabs(a), normal_min);
58 EIGEN_CONSTEXPR ScalarUI scalar_normalization_offset = ScalarUI(int(MantissaBits) + 1); // 24
59 // The following cannot be constexpr because bfloat16(uint16_t) is not constexpr.
60 const Scalar scalar_normalization_factor = Scalar(ScalarUI(1) << int(scalar_normalization_offset)); // 2^24
61 const Packet normalization_factor = pset1<Packet>(scalar_normalization_factor);
62 const Packet normalized_a = pselect(is_denormal, pmul(a, normalization_factor), a);
63
64 // Determine exponent offset: -126 if normal, -126-24 if denormal
65 const Scalar scalar_exponent_offset = -Scalar((ScalarUI(1)<<(int(ExponentBits)-1)) - ScalarUI(2)); // -126
66 Packet exponent_offset = pset1<Packet>(scalar_exponent_offset);
67 const Packet normalization_offset = pset1<Packet>(-Scalar(scalar_normalization_offset)); // -24
68 exponent_offset = pselect(is_denormal, padd(exponent_offset, normalization_offset), exponent_offset);
69
70 // Determine exponent and mantissa from normalized_a.
71 exponent = pfrexp_generic_get_biased_exponent(normalized_a);
72 // Zero, Inf and NaN return 'a' unmodified, exponent is zero
73 // (technically the exponent is unspecified for inf/NaN, but GCC/Clang set it to zero)
74 const Scalar scalar_non_finite_exponent = Scalar((ScalarUI(1) << int(ExponentBits)) - ScalarUI(1)); // 255
75 const Packet non_finite_exponent = pset1<Packet>(scalar_non_finite_exponent);
76 const Packet is_zero_or_not_finite = por(pcmp_eq(a, zero), pcmp_eq(exponent, non_finite_exponent));
77 const Packet m = pselect(is_zero_or_not_finite, a, por(pand(normalized_a, sign_mantissa_mask), half));
78 exponent = pselect(is_zero_or_not_finite, zero, padd(exponent, exponent_offset));
79 return m;
80}
81
82// Safely applies ldexp, correctly handles overflows, underflows and denormals.
83// Assumes IEEE floating point format.
84template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
85Packet pldexp_generic(const Packet& a, const Packet& exponent) {
86 // We want to return a * 2^exponent, allowing for all possible integer
87 // exponents without overflowing or underflowing in intermediate
88 // computations.
89 //
90 // Since 'a' and the output can be denormal, the maximum range of 'exponent'
91 // to consider for a float is:
92 // -255-23 -> 255+23
93 // Below -278 any finite float 'a' will become zero, and above +278 any
94 // finite float will become inf, including when 'a' is the smallest possible
95 // denormal.
96 //
97 // Unfortunately, 2^(278) cannot be represented using either one or two
98 // finite normal floats, so we must split the scale factor into at least
99 // three parts. It turns out to be faster to split 'exponent' into four
100 // factors, since [exponent>>2] is much faster to compute that [exponent/3].
101 //
102 // Set e = min(max(exponent, -278), 278);
103 // b = floor(e/4);
104 // out = ((((a * 2^(b)) * 2^(b)) * 2^(b)) * 2^(e-3*b))
105 //
106 // This will avoid any intermediate overflows and correctly handle 0, inf,
107 // NaN cases.
108 typedef typename unpacket_traits<Packet>::integer_packet PacketI;
109 typedef typename unpacket_traits<Packet>::type Scalar;
110 typedef typename unpacket_traits<PacketI>::type ScalarI;
111 enum {
112 TotalBits = sizeof(Scalar) * CHAR_BIT,
113 MantissaBits = numext::numeric_limits<Scalar>::digits - 1,
114 ExponentBits = int(TotalBits) - int(MantissaBits) - 1
115 };
116
117 const Packet max_exponent = pset1<Packet>(Scalar((ScalarI(1)<<int(ExponentBits)) + ScalarI(int(MantissaBits) - 1))); // 278
118 const PacketI bias = pset1<PacketI>((ScalarI(1)<<(int(ExponentBits)-1)) - ScalarI(1)); // 127
119 const PacketI e = pcast<Packet, PacketI>(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent));
120 PacketI b = parithmetic_shift_right<2>(e); // floor(e/4);
121 Packet c = preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(padd(b, bias))); // 2^b
122 Packet out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b)
123 b = psub(psub(psub(e, b), b), b); // e - 3b
124 c = preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(padd(b, bias))); // 2^(e-3*b)
125 out = pmul(out, c);
126 return out;
127}
128
129// Explicitly multiplies
130// a * (2^e)
131// clamping e to the range
132// [NumTraits<Scalar>::min_exponent()-2, NumTraits<Scalar>::max_exponent()]
133//
134// This is approx 7x faster than pldexp_impl, but will prematurely over/underflow
135// if 2^e doesn't fit into a normal floating-point Scalar.
136//
137// Assumes IEEE floating point format
138template<typename Packet>
140 typedef typename unpacket_traits<Packet>::integer_packet PacketI;
141 typedef typename unpacket_traits<Packet>::type Scalar;
143 enum {
144 TotalBits = sizeof(Scalar) * CHAR_BIT,
145 MantissaBits = numext::numeric_limits<Scalar>::digits - 1,
146 ExponentBits = int(TotalBits) - int(MantissaBits) - 1
147 };
148
149 static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
150 Packet run(const Packet& a, const Packet& exponent) {
151 const Packet bias = pset1<Packet>(Scalar((ScalarI(1)<<(int(ExponentBits)-1)) - ScalarI(1))); // 127
152 const Packet limit = pset1<Packet>(Scalar((ScalarI(1)<<int(ExponentBits)) - ScalarI(1))); // 255
153 // restrict biased exponent between 0 and 255 for float.
154 const PacketI e = pcast<Packet, PacketI>(pmin(pmax(padd(exponent, bias), pzero(limit)), limit)); // exponent + 127
155 // return a * (2^e)
156 return pmul(a, preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(e)));
157 }
158};
159
160// Natural or base 2 logarithm.
161// Computes log(x) as log(2^e * m) = C*e + log(m), where the constant C =log(2)
162// and m is in the range [sqrt(1/2),sqrt(2)). In this range, the logarithm can
163// be easily approximated by a polynomial centered on m=1 for stability.
164// TODO(gonnet): Further reduce the interval allowing for lower-degree
165// polynomial interpolants -> ... -> profit!
166template <typename Packet, bool base2>
167EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
168EIGEN_UNUSED
169Packet plog_impl_float(const Packet _x)
170{
171 Packet x = _x;
172
173 const Packet cst_1 = pset1<Packet>(1.0f);
174 const Packet cst_neg_half = pset1<Packet>(-0.5f);
175 // The smallest non denormalized float number.
176 const Packet cst_min_norm_pos = pset1frombits<Packet>( 0x00800000u);
177 const Packet cst_minus_inf = pset1frombits<Packet>( 0xff800000u);
178 const Packet cst_pos_inf = pset1frombits<Packet>( 0x7f800000u);
179
180 // Polynomial coefficients.
181 const Packet cst_cephes_SQRTHF = pset1<Packet>(0.707106781186547524f);
182 const Packet cst_cephes_log_p0 = pset1<Packet>(7.0376836292E-2f);
183 const Packet cst_cephes_log_p1 = pset1<Packet>(-1.1514610310E-1f);
184 const Packet cst_cephes_log_p2 = pset1<Packet>(1.1676998740E-1f);
185 const Packet cst_cephes_log_p3 = pset1<Packet>(-1.2420140846E-1f);
186 const Packet cst_cephes_log_p4 = pset1<Packet>(+1.4249322787E-1f);
187 const Packet cst_cephes_log_p5 = pset1<Packet>(-1.6668057665E-1f);
188 const Packet cst_cephes_log_p6 = pset1<Packet>(+2.0000714765E-1f);
189 const Packet cst_cephes_log_p7 = pset1<Packet>(-2.4999993993E-1f);
190 const Packet cst_cephes_log_p8 = pset1<Packet>(+3.3333331174E-1f);
191
192 // Truncate input values to the minimum positive normal.
193 x = pmax(x, cst_min_norm_pos);
194
195 Packet e;
196 // extract significant in the range [0.5,1) and exponent
197 x = pfrexp(x,e);
198
199 // part2: Shift the inputs from the range [0.5,1) to [sqrt(1/2),sqrt(2))
200 // and shift by -1. The values are then centered around 0, which improves
201 // the stability of the polynomial evaluation.
202 // if( x < SQRTHF ) {
203 // e -= 1;
204 // x = x + x - 1.0;
205 // } else { x = x - 1.0; }
206 Packet mask = pcmp_lt(x, cst_cephes_SQRTHF);
207 Packet tmp = pand(x, mask);
208 x = psub(x, cst_1);
209 e = psub(e, pand(cst_1, mask));
210 x = padd(x, tmp);
211
212 Packet x2 = pmul(x, x);
213 Packet x3 = pmul(x2, x);
214
215 // Evaluate the polynomial approximant of degree 8 in three parts, probably
216 // to improve instruction-level parallelism.
217 Packet y, y1, y2;
219 y1 = pmadd(cst_cephes_log_p3, x, cst_cephes_log_p4);
221 y = pmadd(y, x, cst_cephes_log_p2);
222 y1 = pmadd(y1, x, cst_cephes_log_p5);
223 y2 = pmadd(y2, x, cst_cephes_log_p8);
224 y = pmadd(y, x3, y1);
225 y = pmadd(y, x3, y2);
226 y = pmul(y, x3);
227
228 y = pmadd(cst_neg_half, x2, y);
229 x = padd(x, y);
230
231 // Add the logarithm of the exponent back to the result of the interpolation.
232 if (base2) {
233 const Packet cst_log2e = pset1<Packet>(static_cast<float>(EIGEN_LOG2E));
234 x = pmadd(x, cst_log2e, e);
235 } else {
236 const Packet cst_ln2 = pset1<Packet>(static_cast<float>(EIGEN_LN2));
237 x = pmadd(e, cst_ln2, x);
238 }
239
240 Packet invalid_mask = pcmp_lt_or_nan(_x, pzero(_x));
241 Packet iszero_mask = pcmp_eq(_x,pzero(_x));
242 Packet pos_inf_mask = pcmp_eq(_x,cst_pos_inf);
243 // Filter out invalid inputs, i.e.:
244 // - negative arg will be NAN
245 // - 0 will be -INF
246 // - +INF will be +INF
247 return pselect(iszero_mask, cst_minus_inf,
248 por(pselect(pos_inf_mask,cst_pos_inf,x), invalid_mask));
249}
250
251template <typename Packet>
252EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
253EIGEN_UNUSED
254Packet plog_float(const Packet _x)
255{
256 return plog_impl_float<Packet, /* base2 */ false>(_x);
257}
258
259template <typename Packet>
260EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
261EIGEN_UNUSED
262Packet plog2_float(const Packet _x)
263{
264 return plog_impl_float<Packet, /* base2 */ true>(_x);
265}
266
267/* Returns the base e (2.718...) or base 2 logarithm of x.
268 * The argument is separated into its exponent and fractional parts.
269 * The logarithm of the fraction in the interval [sqrt(1/2), sqrt(2)],
270 * is approximated by
271 *
272 * log(1+x) = x - 0.5 x**2 + x**3 P(x)/Q(x).
273 *
274 * for more detail see: http://www.netlib.org/cephes/
275 */
276template <typename Packet, bool base2>
277EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
278EIGEN_UNUSED
279Packet plog_impl_double(const Packet _x)
280{
281 Packet x = _x;
282
283 const Packet cst_1 = pset1<Packet>(1.0);
284 const Packet cst_neg_half = pset1<Packet>(-0.5);
285 // The smallest non denormalized double.
286 const Packet cst_min_norm_pos = pset1frombits<Packet>( static_cast<uint64_t>(0x0010000000000000ull));
287 const Packet cst_minus_inf = pset1frombits<Packet>( static_cast<uint64_t>(0xfff0000000000000ull));
288 const Packet cst_pos_inf = pset1frombits<Packet>( static_cast<uint64_t>(0x7ff0000000000000ull));
289
290
291 // Polynomial Coefficients for log(1+x) = x - x**2/2 + x**3 P(x)/Q(x)
292 // 1/sqrt(2) <= x < sqrt(2)
293 const Packet cst_cephes_SQRTHF = pset1<Packet>(0.70710678118654752440E0);
294 const Packet cst_cephes_log_p0 = pset1<Packet>(1.01875663804580931796E-4);
295 const Packet cst_cephes_log_p1 = pset1<Packet>(4.97494994976747001425E-1);
296 const Packet cst_cephes_log_p2 = pset1<Packet>(4.70579119878881725854E0);
297 const Packet cst_cephes_log_p3 = pset1<Packet>(1.44989225341610930846E1);
298 const Packet cst_cephes_log_p4 = pset1<Packet>(1.79368678507819816313E1);
299 const Packet cst_cephes_log_p5 = pset1<Packet>(7.70838733755885391666E0);
300
301 const Packet cst_cephes_log_q0 = pset1<Packet>(1.0);
302 const Packet cst_cephes_log_q1 = pset1<Packet>(1.12873587189167450590E1);
303 const Packet cst_cephes_log_q2 = pset1<Packet>(4.52279145837532221105E1);
304 const Packet cst_cephes_log_q3 = pset1<Packet>(8.29875266912776603211E1);
305 const Packet cst_cephes_log_q4 = pset1<Packet>(7.11544750618563894466E1);
306 const Packet cst_cephes_log_q5 = pset1<Packet>(2.31251620126765340583E1);
307
308 // Truncate input values to the minimum positive normal.
309 x = pmax(x, cst_min_norm_pos);
310
311 Packet e;
312 // extract significant in the range [0.5,1) and exponent
313 x = pfrexp(x,e);
314
315 // Shift the inputs from the range [0.5,1) to [sqrt(1/2),sqrt(2))
316 // and shift by -1. The values are then centered around 0, which improves
317 // the stability of the polynomial evaluation.
318 // if( x < SQRTHF ) {
319 // e -= 1;
320 // x = x + x - 1.0;
321 // } else { x = x - 1.0; }
322 Packet mask = pcmp_lt(x, cst_cephes_SQRTHF);
323 Packet tmp = pand(x, mask);
324 x = psub(x, cst_1);
325 e = psub(e, pand(cst_1, mask));
326 x = padd(x, tmp);
327
328 Packet x2 = pmul(x, x);
329 Packet x3 = pmul(x2, x);
330
331 // Evaluate the polynomial approximant , probably to improve instruction-level parallelism.
332 // y = x - 0.5*x^2 + x^3 * polevl( x, P, 5 ) / p1evl( x, Q, 5 ) );
333 Packet y, y1, y_;
334 y = pmadd(cst_cephes_log_p0, x, cst_cephes_log_p1);
335 y1 = pmadd(cst_cephes_log_p3, x, cst_cephes_log_p4);
336 y = pmadd(y, x, cst_cephes_log_p2);
337 y1 = pmadd(y1, x, cst_cephes_log_p5);
338 y_ = pmadd(y, x3, y1);
339
340 y = pmadd(cst_cephes_log_q0, x, cst_cephes_log_q1);
341 y1 = pmadd(cst_cephes_log_q3, x, cst_cephes_log_q4);
342 y = pmadd(y, x, cst_cephes_log_q2);
343 y1 = pmadd(y1, x, cst_cephes_log_q5);
344 y = pmadd(y, x3, y1);
345
346 y_ = pmul(y_, x3);
347 y = pdiv(y_, y);
348
349 y = pmadd(cst_neg_half, x2, y);
350 x = padd(x, y);
351
352 // Add the logarithm of the exponent back to the result of the interpolation.
353 if (base2) {
354 const Packet cst_log2e = pset1<Packet>(static_cast<double>(EIGEN_LOG2E));
355 x = pmadd(x, cst_log2e, e);
356 } else {
357 const Packet cst_ln2 = pset1<Packet>(static_cast<double>(EIGEN_LN2));
358 x = pmadd(e, cst_ln2, x);
359 }
360
361 Packet invalid_mask = pcmp_lt_or_nan(_x, pzero(_x));
362 Packet iszero_mask = pcmp_eq(_x,pzero(_x));
363 Packet pos_inf_mask = pcmp_eq(_x,cst_pos_inf);
364 // Filter out invalid inputs, i.e.:
365 // - negative arg will be NAN
366 // - 0 will be -INF
367 // - +INF will be +INF
368 return pselect(iszero_mask, cst_minus_inf,
369 por(pselect(pos_inf_mask,cst_pos_inf,x), invalid_mask));
370}
371
372template <typename Packet>
373EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
374EIGEN_UNUSED
375Packet plog_double(const Packet _x)
376{
377 return plog_impl_double<Packet, /* base2 */ false>(_x);
378}
379
380template <typename Packet>
381EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
382EIGEN_UNUSED
383Packet plog2_double(const Packet _x)
384{
385 return plog_impl_double<Packet, /* base2 */ true>(_x);
386}
387
391template<typename Packet>
392Packet generic_plog1p(const Packet& x)
393{
394 typedef typename unpacket_traits<Packet>::type ScalarType;
395 const Packet one = pset1<Packet>(ScalarType(1));
396 Packet xp1 = padd(x, one);
397 Packet small_mask = pcmp_eq(xp1, one);
398 Packet log1 = plog(xp1);
399 Packet inf_mask = pcmp_eq(xp1, log1);
400 Packet log_large = pmul(x, pdiv(log1, psub(xp1, one)));
401 return pselect(por(small_mask, inf_mask), x, log_large);
402}
403
407template<typename Packet>
408Packet generic_expm1(const Packet& x)
409{
410 typedef typename unpacket_traits<Packet>::type ScalarType;
411 const Packet one = pset1<Packet>(ScalarType(1));
412 const Packet neg_one = pset1<Packet>(ScalarType(-1));
413 Packet u = pexp(x);
414 Packet one_mask = pcmp_eq(u, one);
415 Packet u_minus_one = psub(u, one);
416 Packet neg_one_mask = pcmp_eq(u_minus_one, neg_one);
417 Packet logu = plog(u);
418 // The following comparison is to catch the case where
419 // exp(x) = +inf. It is written in this way to avoid having
420 // to form the constant +inf, which depends on the packet
421 // type.
422 Packet pos_inf_mask = pcmp_eq(logu, u);
423 Packet expm1 = pmul(u_minus_one, pdiv(x, logu));
424 expm1 = pselect(pos_inf_mask, u, expm1);
425 return pselect(one_mask,
426 x,
427 pselect(neg_one_mask,
428 neg_one,
429 expm1));
430}
431
432
433// Exponential function. Works by writing "x = m*log(2) + r" where
434// "m = floor(x/log(2)+1/2)" and "r" is the remainder. The result is then
435// "exp(x) = 2^m*exp(r)" where exp(r) is in the range [-1,1).
436template <typename Packet>
437EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
438EIGEN_UNUSED
439Packet pexp_float(const Packet _x)
440{
441 const Packet cst_1 = pset1<Packet>(1.0f);
442 const Packet cst_half = pset1<Packet>(0.5f);
443 const Packet cst_exp_hi = pset1<Packet>( 88.723f);
444 const Packet cst_exp_lo = pset1<Packet>(-88.723f);
445
446 const Packet cst_cephes_LOG2EF = pset1<Packet>(1.44269504088896341f);
447 const Packet cst_cephes_exp_p0 = pset1<Packet>(1.9875691500E-4f);
448 const Packet cst_cephes_exp_p1 = pset1<Packet>(1.3981999507E-3f);
449 const Packet cst_cephes_exp_p2 = pset1<Packet>(8.3334519073E-3f);
450 const Packet cst_cephes_exp_p3 = pset1<Packet>(4.1665795894E-2f);
451 const Packet cst_cephes_exp_p4 = pset1<Packet>(1.6666665459E-1f);
452 const Packet cst_cephes_exp_p5 = pset1<Packet>(5.0000001201E-1f);
453
454 // Clamp x.
455 Packet x = pmax(pmin(_x, cst_exp_hi), cst_exp_lo);
456
457 // Express exp(x) as exp(m*ln(2) + r), start by extracting
458 // m = floor(x/ln(2) + 0.5).
459 Packet m = pfloor(pmadd(x, cst_cephes_LOG2EF, cst_half));
460
461 // Get r = x - m*ln(2). If no FMA instructions are available, m*ln(2) is
462 // subtracted out in two parts, m*C1+m*C2 = m*ln(2), to avoid accumulating
463 // truncation errors.
464 const Packet cst_cephes_exp_C1 = pset1<Packet>(-0.693359375f);
465 const Packet cst_cephes_exp_C2 = pset1<Packet>(2.12194440e-4f);
466 Packet r = pmadd(m, cst_cephes_exp_C1, x);
467 r = pmadd(m, cst_cephes_exp_C2, r);
468
469 Packet r2 = pmul(r, r);
470 Packet r3 = pmul(r2, r);
471
472 // Evaluate the polynomial approximant,improved by instruction-level parallelism.
473 Packet y, y1, y2;
474 y = pmadd(cst_cephes_exp_p0, r, cst_cephes_exp_p1);
475 y1 = pmadd(cst_cephes_exp_p3, r, cst_cephes_exp_p4);
476 y2 = padd(r, cst_1);
477 y = pmadd(y, r, cst_cephes_exp_p2);
478 y1 = pmadd(y1, r, cst_cephes_exp_p5);
479 y = pmadd(y, r3, y1);
480 y = pmadd(y, r2, y2);
481
482 // Return 2^m * exp(r).
483 // TODO: replace pldexp with faster implementation since y in [-1, 1).
484 return pmax(pldexp(y,m), _x);
485}
486
487template <typename Packet>
488EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
489EIGEN_UNUSED
490Packet pexp_double(const Packet _x)
491{
492 Packet x = _x;
493
494 const Packet cst_1 = pset1<Packet>(1.0);
495 const Packet cst_2 = pset1<Packet>(2.0);
496 const Packet cst_half = pset1<Packet>(0.5);
497
498 const Packet cst_exp_hi = pset1<Packet>(709.784);
499 const Packet cst_exp_lo = pset1<Packet>(-709.784);
500
501 const Packet cst_cephes_LOG2EF = pset1<Packet>(1.4426950408889634073599);
502 const Packet cst_cephes_exp_p0 = pset1<Packet>(1.26177193074810590878e-4);
503 const Packet cst_cephes_exp_p1 = pset1<Packet>(3.02994407707441961300e-2);
504 const Packet cst_cephes_exp_p2 = pset1<Packet>(9.99999999999999999910e-1);
505 const Packet cst_cephes_exp_q0 = pset1<Packet>(3.00198505138664455042e-6);
506 const Packet cst_cephes_exp_q1 = pset1<Packet>(2.52448340349684104192e-3);
507 const Packet cst_cephes_exp_q2 = pset1<Packet>(2.27265548208155028766e-1);
508 const Packet cst_cephes_exp_q3 = pset1<Packet>(2.00000000000000000009e0);
509 const Packet cst_cephes_exp_C1 = pset1<Packet>(0.693145751953125);
510 const Packet cst_cephes_exp_C2 = pset1<Packet>(1.42860682030941723212e-6);
511
512 Packet tmp, fx;
513
514 // clamp x
515 x = pmax(pmin(x, cst_exp_hi), cst_exp_lo);
516 // Express exp(x) as exp(g + n*log(2)).
517 fx = pmadd(cst_cephes_LOG2EF, x, cst_half);
518
519 // Get the integer modulus of log(2), i.e. the "n" described above.
520 fx = pfloor(fx);
521
522 // Get the remainder modulo log(2), i.e. the "g" described above. Subtract
523 // n*log(2) out in two steps, i.e. n*C1 + n*C2, C1+C2=log2 to get the last
524 // digits right.
525 tmp = pmul(fx, cst_cephes_exp_C1);
526 Packet z = pmul(fx, cst_cephes_exp_C2);
527 x = psub(x, tmp);
528 x = psub(x, z);
529
530 Packet x2 = pmul(x, x);
531
532 // Evaluate the numerator polynomial of the rational interpolant.
533 Packet px = cst_cephes_exp_p0;
534 px = pmadd(px, x2, cst_cephes_exp_p1);
535 px = pmadd(px, x2, cst_cephes_exp_p2);
536 px = pmul(px, x);
537
538 // Evaluate the denominator polynomial of the rational interpolant.
539 Packet qx = cst_cephes_exp_q0;
540 qx = pmadd(qx, x2, cst_cephes_exp_q1);
541 qx = pmadd(qx, x2, cst_cephes_exp_q2);
542 qx = pmadd(qx, x2, cst_cephes_exp_q3);
543
544 // I don't really get this bit, copied from the SSE2 routines, so...
545 // TODO(gonnet): Figure out what is going on here, perhaps find a better
546 // rational interpolant?
547 x = pdiv(px, psub(qx, px));
548 x = pmadd(cst_2, x, cst_1);
549
550 // Construct the result 2^n * exp(g) = e * x. The max is used to catch
551 // non-finite values in the input.
552 // TODO: replace pldexp with faster implementation since x in [-1, 1).
553 return pmax(pldexp(x,fx), _x);
554}
555
556// The following code is inspired by the following stack-overflow answer:
557// https://stackoverflow.com/questions/30463616/payne-hanek-algorithm-implementation-in-c/30465751#30465751
558// It has been largely optimized:
559// - By-pass calls to frexp.
560// - Aligned loads of required 96 bits of 2/pi. This is accomplished by
561// (1) balancing the mantissa and exponent to the required bits of 2/pi are
562// aligned on 8-bits, and (2) replicating the storage of the bits of 2/pi.
563// - Avoid a branch in rounding and extraction of the remaining fractional part.
564// Overall, I measured a speed up higher than x2 on x86-64.
565inline float trig_reduce_huge (float xf, int *quadrant)
566{
567 using Eigen::numext::int32_t;
568 using Eigen::numext::uint32_t;
569 using Eigen::numext::int64_t;
570 using Eigen::numext::uint64_t;
571
572 const double pio2_62 = 3.4061215800865545e-19; // pi/2 * 2^-62
573 const uint64_t zero_dot_five = uint64_t(1) << 61; // 0.5 in 2.62-bit fixed-point foramt
574
575 // 192 bits of 2/pi for Payne-Hanek reduction
576 // Bits are introduced by packet of 8 to enable aligned reads.
577 static const uint32_t two_over_pi [] =
578 {
579 0x00000028, 0x000028be, 0x0028be60, 0x28be60db,
580 0xbe60db93, 0x60db9391, 0xdb939105, 0x9391054a,
581 0x91054a7f, 0x054a7f09, 0x4a7f09d5, 0x7f09d5f4,
582 0x09d5f47d, 0xd5f47d4d, 0xf47d4d37, 0x7d4d3770,
583 0x4d377036, 0x377036d8, 0x7036d8a5, 0x36d8a566,
584 0xd8a5664f, 0xa5664f10, 0x664f10e4, 0x4f10e410,
585 0x10e41000, 0xe4100000
586 };
587
588 uint32_t xi = numext::bit_cast<uint32_t>(xf);
589 // Below, -118 = -126 + 8.
590 // -126 is to get the exponent,
591 // +8 is to enable alignment of 2/pi's bits on 8 bits.
592 // This is possible because the fractional part of x as only 24 meaningful bits.
593 uint32_t e = (xi >> 23) - 118;
594 // Extract the mantissa and shift it to align it wrt the exponent
595 xi = ((xi & 0x007fffffu)| 0x00800000u) << (e & 0x7);
596
597 uint32_t i = e >> 3;
598 uint32_t twoopi_1 = two_over_pi[i-1];
599 uint32_t twoopi_2 = two_over_pi[i+3];
600 uint32_t twoopi_3 = two_over_pi[i+7];
601
602 // Compute x * 2/pi in 2.62-bit fixed-point format.
603 uint64_t p;
604 p = uint64_t(xi) * twoopi_3;
605 p = uint64_t(xi) * twoopi_2 + (p >> 32);
606 p = (uint64_t(xi * twoopi_1) << 32) + p;
607
608 // Round to nearest: add 0.5 and extract integral part.
609 uint64_t q = (p + zero_dot_five) >> 62;
610 *quadrant = int(q);
611 // Now it remains to compute "r = x - q*pi/2" with high accuracy,
612 // since we have p=x/(pi/2) with high accuracy, we can more efficiently compute r as:
613 // r = (p-q)*pi/2,
614 // where the product can be be carried out with sufficient accuracy using double precision.
615 p -= q<<62;
616 return float(double(int64_t(p)) * pio2_62);
617}
618
619template<bool ComputeSine,typename Packet>
620EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
621EIGEN_UNUSED
622#if EIGEN_GNUC_AT_LEAST(4,4) && EIGEN_COMP_GNUC_STRICT
623__attribute__((optimize("-fno-unsafe-math-optimizations")))
624#endif
625Packet psincos_float(const Packet& _x)
626{
627 typedef typename unpacket_traits<Packet>::integer_packet PacketI;
628
629 const Packet cst_2oPI = pset1<Packet>(0.636619746685028076171875f); // 2/PI
630 const Packet cst_rounding_magic = pset1<Packet>(12582912); // 2^23 for rounding
631 const PacketI csti_1 = pset1<PacketI>(1);
632 const Packet cst_sign_mask = pset1frombits<Packet>(0x80000000u);
633
634 Packet x = pabs(_x);
635
636 // Scale x by 2/Pi to find x's octant.
637 Packet y = pmul(x, cst_2oPI);
638
639 // Rounding trick:
640 Packet y_round = padd(y, cst_rounding_magic);
641 EIGEN_OPTIMIZATION_BARRIER(y_round)
642 PacketI y_int = preinterpret<PacketI>(y_round); // last 23 digits represent integer (if abs(x)<2^24)
643 y = psub(y_round, cst_rounding_magic); // nearest integer to x*4/pi
644
645 // Subtract y * Pi/2 to reduce x to the interval -Pi/4 <= x <= +Pi/4
646 // using "Extended precision modular arithmetic"
647 #if defined(EIGEN_VECTORIZE_FMA)
648 // This version requires true FMA for high accuracy.
649 // It provides a max error of 1ULP up to (with absolute_error < 5.9605e-08):
650 const float huge_th = ComputeSine ? 117435.992f : 71476.0625f;
651 x = pmadd(y, pset1<Packet>(-1.57079601287841796875f), x);
652 x = pmadd(y, pset1<Packet>(-3.1391647326017846353352069854736328125e-07f), x);
653 x = pmadd(y, pset1<Packet>(-5.390302529957764765544681040410068817436695098876953125e-15f), x);
654 #else
655 // Without true FMA, the previous set of coefficients maintain 1ULP accuracy
656 // up to x<15.7 (for sin), but accuracy is immediately lost for x>15.7.
657 // We thus use one more iteration to maintain 2ULPs up to reasonably large inputs.
658
659 // The following set of coefficients maintain 1ULP up to 9.43 and 14.16 for sin and cos respectively.
660 // and 2 ULP up to:
661 const float huge_th = ComputeSine ? 25966.f : 18838.f;
662 x = pmadd(y, pset1<Packet>(-1.5703125), x); // = 0xbfc90000
663 EIGEN_OPTIMIZATION_BARRIER(x)
664 x = pmadd(y, pset1<Packet>(-0.000483989715576171875), x); // = 0xb9fdc000
665 EIGEN_OPTIMIZATION_BARRIER(x)
666 x = pmadd(y, pset1<Packet>(1.62865035235881805419921875e-07), x); // = 0x342ee000
667 x = pmadd(y, pset1<Packet>(5.5644315544167710640977020375430583953857421875e-11), x); // = 0x2e74b9ee
668
669 // For the record, the following set of coefficients maintain 2ULP up
670 // to a slightly larger range:
671 // const float huge_th = ComputeSine ? 51981.f : 39086.125f;
672 // but it slightly fails to maintain 1ULP for two values of sin below pi.
673 // x = pmadd(y, pset1<Packet>(-3.140625/2.), x);
674 // x = pmadd(y, pset1<Packet>(-0.00048351287841796875), x);
675 // x = pmadd(y, pset1<Packet>(-3.13855707645416259765625e-07), x);
676 // x = pmadd(y, pset1<Packet>(-6.0771006282767103812147979624569416046142578125e-11), x);
677
678 // For the record, with only 3 iterations it is possible to maintain
679 // 1 ULP up to 3PI (maybe more) and 2ULP up to 255.
680 // The coefficients are: 0xbfc90f80, 0xb7354480, 0x2e74b9ee
681 #endif
682
683 if(predux_any(pcmp_le(pset1<Packet>(huge_th),pabs(_x))))
684 {
685 const int PacketSize = unpacket_traits<Packet>::size;
686 EIGEN_ALIGN_TO_BOUNDARY(sizeof(Packet)) float vals[PacketSize];
687 EIGEN_ALIGN_TO_BOUNDARY(sizeof(Packet)) float x_cpy[PacketSize];
688 EIGEN_ALIGN_TO_BOUNDARY(sizeof(Packet)) int y_int2[PacketSize];
689 pstoreu(vals, pabs(_x));
690 pstoreu(x_cpy, x);
691 pstoreu(y_int2, y_int);
692 for(int k=0; k<PacketSize;++k)
693 {
694 float val = vals[k];
695 if(val>=huge_th && (numext::isfinite)(val))
696 x_cpy[k] = trig_reduce_huge(val,&y_int2[k]);
697 }
698 x = ploadu<Packet>(x_cpy);
699 y_int = ploadu<PacketI>(y_int2);
700 }
701
702 // Compute the sign to apply to the polynomial.
703 // sin: sign = second_bit(y_int) xor signbit(_x)
704 // cos: sign = second_bit(y_int+1)
705 Packet sign_bit = ComputeSine ? pxor(_x, preinterpret<Packet>(plogical_shift_left<30>(y_int)))
706 : preinterpret<Packet>(plogical_shift_left<30>(padd(y_int,csti_1)));
707 sign_bit = pand(sign_bit, cst_sign_mask); // clear all but left most bit
708
709 // Get the polynomial selection mask from the second bit of y_int
710 // We'll calculate both (sin and cos) polynomials and then select from the two.
711 Packet poly_mask = preinterpret<Packet>(pcmp_eq(pand(y_int, csti_1), pzero(y_int)));
712
713 Packet x2 = pmul(x,x);
714
715 // Evaluate the cos(x) polynomial. (-Pi/4 <= x <= Pi/4)
716 Packet y1 = pset1<Packet>(2.4372266125283204019069671630859375e-05f);
717 y1 = pmadd(y1, x2, pset1<Packet>(-0.00138865201734006404876708984375f ));
718 y1 = pmadd(y1, x2, pset1<Packet>(0.041666619479656219482421875f ));
719 y1 = pmadd(y1, x2, pset1<Packet>(-0.5f));
720 y1 = pmadd(y1, x2, pset1<Packet>(1.f));
721
722 // Evaluate the sin(x) polynomial. (Pi/4 <= x <= Pi/4)
723 // octave/matlab code to compute those coefficients:
724 // x = (0:0.0001:pi/4)';
725 // A = [x.^3 x.^5 x.^7];
726 // w = ((1.-(x/(pi/4)).^2).^5)*2000+1; # weights trading relative accuracy
727 // c = (A'*diag(w)*A)\‍(A'*diag(w)*(sin(x)-x)); # weighted LS, linear coeff forced to 1
728 // printf('%.64f\n %.64f\n%.64f\n', c(3), c(2), c(1))
729 //
730 Packet y2 = pset1<Packet>(-0.0001959234114083702898469196984621021329076029360294342041015625f);
731 y2 = pmadd(y2, x2, pset1<Packet>( 0.0083326873655616851693794799871284340042620897293090820312500000f));
732 y2 = pmadd(y2, x2, pset1<Packet>(-0.1666666203982298255503735617821803316473960876464843750000000000f));
733 y2 = pmul(y2, x2);
734 y2 = pmadd(y2, x, x);
735
736 // Select the correct result from the two polynomials.
737 y = ComputeSine ? pselect(poly_mask,y2,y1)
738 : pselect(poly_mask,y1,y2);
739
740 // Update the sign and filter huge inputs
741 return pxor(y, sign_bit);
742}
743
744template<typename Packet>
745EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
746EIGEN_UNUSED
747Packet psin_float(const Packet& x)
748{
749 return psincos_float<true>(x);
750}
751
752template<typename Packet>
753EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
754EIGEN_UNUSED
755Packet pcos_float(const Packet& x)
756{
757 return psincos_float<false>(x);
758}
759
760template<typename Packet>
761EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
762EIGEN_UNUSED Packet pdiv_complex(const Packet& x, const Packet& y) {
763 typedef typename unpacket_traits<Packet>::as_real RealPacket;
764 // In the following we annotate the code for the case where the inputs
765 // are a pair length-2 SIMD vectors representing a single pair of complex
766 // numbers x = a + i*b, y = c + i*d.
767 const RealPacket y_abs = pabs(y.v); // |c|, |d|
768 const RealPacket y_abs_flip = pcplxflip(Packet(y_abs)).v; // |d|, |c|
769 const RealPacket y_max = pmax(y_abs, y_abs_flip); // max(|c|, |d|), max(|c|, |d|)
770 const RealPacket y_scaled = pdiv(y.v, y_max); // c / max(|c|, |d|), d / max(|c|, |d|)
771 // Compute scaled denominator.
772 const RealPacket y_scaled_sq = pmul(y_scaled, y_scaled); // c'**2, d'**2
773 const RealPacket denom = padd(y_scaled_sq, pcplxflip(Packet(y_scaled_sq)).v);
774 Packet result_scaled = pmul(x, pconj(Packet(y_scaled))); // a * c' + b * d', -a * d + b * c
775 // Divide elementwise by denom.
776 result_scaled = Packet(pdiv(result_scaled.v, denom));
777 // Rescale result
778 return Packet(pdiv(result_scaled.v, y_max));
779}
780
781template<typename Packet>
782EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
783EIGEN_UNUSED
784Packet psqrt_complex(const Packet& a) {
785 typedef typename unpacket_traits<Packet>::type Scalar;
786 typedef typename Scalar::value_type RealScalar;
787 typedef typename unpacket_traits<Packet>::as_real RealPacket;
788
789 // Computes the principal sqrt of the complex numbers in the input.
790 //
791 // For example, for packets containing 2 complex numbers stored in interleaved format
792 // a = [a0, a1] = [x0, y0, x1, y1],
793 // where x0 = real(a0), y0 = imag(a0) etc., this function returns
794 // b = [b0, b1] = [u0, v0, u1, v1],
795 // such that b0^2 = a0, b1^2 = a1.
796 //
797 // To derive the formula for the complex square roots, let's consider the equation for
798 // a single complex square root of the number x + i*y. We want to find real numbers
799 // u and v such that
800 // (u + i*v)^2 = x + i*y <=>
801 // u^2 - v^2 + i*2*u*v = x + i*v.
802 // By equating the real and imaginary parts we get:
803 // u^2 - v^2 = x
804 // 2*u*v = y.
805 //
806 // For x >= 0, this has the numerically stable solution
807 // u = sqrt(0.5 * (x + sqrt(x^2 + y^2)))
808 // v = 0.5 * (y / u)
809 // and for x < 0,
810 // v = sign(y) * sqrt(0.5 * (-x + sqrt(x^2 + y^2)))
811 // u = 0.5 * (y / v)
812 //
813 // To avoid unnecessary over- and underflow, we compute sqrt(x^2 + y^2) as
814 // l = max(|x|, |y|) * sqrt(1 + (min(|x|, |y|) / max(|x|, |y|))^2) ,
815
816 // In the following, without lack of generality, we have annotated the code, assuming
817 // that the input is a packet of 2 complex numbers.
818 //
819 // Step 1. Compute l = [l0, l0, l1, l1], where
820 // l0 = sqrt(x0^2 + y0^2), l1 = sqrt(x1^2 + y1^2)
821 // To avoid over- and underflow, we use the stable formula for each hypotenuse
822 // l0 = (min0 == 0 ? max0 : max0 * sqrt(1 + (min0/max0)**2)),
823 // where max0 = max(|x0|, |y0|), min0 = min(|x0|, |y0|), and similarly for l1.
824
825 RealPacket a_abs = pabs(a.v); // [|x0|, |y0|, |x1|, |y1|]
826 RealPacket a_abs_flip = pcplxflip(Packet(a_abs)).v; // [|y0|, |x0|, |y1|, |x1|]
827 RealPacket a_max = pmax(a_abs, a_abs_flip);
828 RealPacket a_min = pmin(a_abs, a_abs_flip);
829 RealPacket a_min_zero_mask = pcmp_eq(a_min, pzero(a_min));
830 RealPacket a_max_zero_mask = pcmp_eq(a_max, pzero(a_max));
831 RealPacket r = pdiv(a_min, a_max);
832 const RealPacket cst_one = pset1<RealPacket>(RealScalar(1));
833 RealPacket l = pmul(a_max, psqrt(padd(cst_one, pmul(r, r)))); // [l0, l0, l1, l1]
834 // Set l to a_max if a_min is zero.
835 l = pselect(a_min_zero_mask, a_max, l);
836
837 // Step 2. Compute [rho0, *, rho1, *], where
838 // rho0 = sqrt(0.5 * (l0 + |x0|)), rho1 = sqrt(0.5 * (l1 + |x1|))
839 // We don't care about the imaginary parts computed here. They will be overwritten later.
840 const RealPacket cst_half = pset1<RealPacket>(RealScalar(0.5));
841 Packet rho;
842 rho.v = psqrt(pmul(cst_half, padd(a_abs, l)));
843
844 // Step 3. Compute [rho0, eta0, rho1, eta1], where
845 // eta0 = (y0 / l0) / 2, and eta1 = (y1 / l1) / 2.
846 // set eta = 0 of input is 0 + i0.
847 RealPacket eta = pandnot(pmul(cst_half, pdiv(a.v, pcplxflip(rho).v)), a_max_zero_mask);
848 RealPacket real_mask = peven_mask(a.v);
849 Packet positive_real_result;
850 // Compute result for inputs with positive real part.
851 positive_real_result.v = pselect(real_mask, rho.v, eta);
852
853 // Step 4. Compute solution for inputs with negative real part:
854 // [|eta0|, sign(y0)*rho0, |eta1|, sign(y1)*rho1]
855 const RealScalar neg_zero = RealScalar(numext::bit_cast<float>(0x80000000u));
856 const RealPacket cst_imag_sign_mask = pset1<Packet>(Scalar(RealScalar(0.0), neg_zero)).v;
857 RealPacket imag_signs = pand(a.v, cst_imag_sign_mask);
858 Packet negative_real_result;
859 // Notice that rho is positive, so taking it's absolute value is a noop.
860 negative_real_result.v = por(pabs(pcplxflip(positive_real_result).v), imag_signs);
861
862 // Step 5. Select solution branch based on the sign of the real parts.
863 Packet negative_real_mask;
864 negative_real_mask.v = pcmp_lt(pand(real_mask, a.v), pzero(a.v));
865 negative_real_mask.v = por(negative_real_mask.v, pcplxflip(negative_real_mask).v);
866 Packet result = pselect(negative_real_mask, negative_real_result, positive_real_result);
867
868 // Step 6. Handle special cases for infinities:
869 // * If z is (x,+∞), the result is (+∞,+∞) even if x is NaN
870 // * If z is (x,-∞), the result is (+∞,-∞) even if x is NaN
871 // * If z is (-∞,y), the result is (0*|y|,+∞) for finite or NaN y
872 // * If z is (+∞,y), the result is (+∞,0*|y|) for finite or NaN y
873 const RealPacket cst_pos_inf = pset1<RealPacket>(NumTraits<RealScalar>::infinity());
874 Packet is_inf;
875 is_inf.v = pcmp_eq(a_abs, cst_pos_inf);
876 Packet is_real_inf;
877 is_real_inf.v = pand(is_inf.v, real_mask);
878 is_real_inf = por(is_real_inf, pcplxflip(is_real_inf));
879 // prepare packet of (+∞,0*|y|) or (0*|y|,+∞), depending on the sign of the infinite real part.
880 Packet real_inf_result;
881 real_inf_result.v = pmul(a_abs, pset1<Packet>(Scalar(RealScalar(1.0), RealScalar(0.0))).v);
882 real_inf_result.v = pselect(negative_real_mask.v, pcplxflip(real_inf_result).v, real_inf_result.v);
883 // prepare packet of (+∞,+∞) or (+∞,-∞), depending on the sign of the infinite imaginary part.
884 Packet is_imag_inf;
885 is_imag_inf.v = pandnot(is_inf.v, real_mask);
886 is_imag_inf = por(is_imag_inf, pcplxflip(is_imag_inf));
887 Packet imag_inf_result;
888 imag_inf_result.v = por(pand(cst_pos_inf, real_mask), pandnot(a.v, real_mask));
889
890 return pselect(is_imag_inf, imag_inf_result,
891 pselect(is_real_inf, real_inf_result,result));
892}
893
// TODO(rmlarsen): The following set of utilities for double word arithmetic
// should perhaps be refactored as a separate file, since it would be generally
// useful for special function implementation etc. Writing the algorithms in
// terms of a double word type would also make the code more readable.
898
899// This function splits x into the nearest integer n and fractional part r,
900// such that x = n + r holds exactly.
901template<typename Packet>
902EIGEN_STRONG_INLINE
903void absolute_split(const Packet& x, Packet& n, Packet& r) {
904 n = pround(x);
905 r = psub(x, n);
906}
907
908// This function computes the sum {s, r}, such that x + y = s_hi + s_lo
909// holds exactly, and s_hi = fl(x+y), if |x| >= |y|.
910template<typename Packet>
911EIGEN_STRONG_INLINE
912void fast_twosum(const Packet& x, const Packet& y, Packet& s_hi, Packet& s_lo) {
913 s_hi = padd(x, y);
914 const Packet t = psub(s_hi, x);
915 s_lo = psub(y, t);
916}
917
918#ifdef EIGEN_VECTORIZE_FMA
919// This function implements the extended precision product of
920// a pair of floating point numbers. Given {x, y}, it computes the pair
921// {p_hi, p_lo} such that x * y = p_hi + p_lo holds exactly and
922// p_hi = fl(x * y).
923template<typename Packet>
924EIGEN_STRONG_INLINE
925void twoprod(const Packet& x, const Packet& y,
926 Packet& p_hi, Packet& p_lo) {
927 p_hi = pmul(x, y);
928 p_lo = pmadd(x, y, pnegate(p_hi));
929}
930
931#else
932
933// This function implements the Veltkamp splitting. Given a floating point
934// number x it returns the pair {x_hi, x_lo} such that x_hi + x_lo = x holds
935// exactly and that half of the significant of x fits in x_hi.
936// This is Algorithm 3 from Jean-Michel Muller, "Elementary Functions",
937// 3rd edition, Birkh\"auser, 2016.
938template<typename Packet>
939EIGEN_STRONG_INLINE
940void veltkamp_splitting(const Packet& x, Packet& x_hi, Packet& x_lo) {
941 typedef typename unpacket_traits<Packet>::type Scalar;
942 EIGEN_CONSTEXPR int shift = (NumTraits<Scalar>::digits() + 1) / 2;
943 const Scalar shift_scale = Scalar(uint64_t(1) << shift); // Scalar constructor not necessarily constexpr.
944 const Packet gamma = pmul(pset1<Packet>(shift_scale + Scalar(1)), x);
945 Packet rho = psub(x, gamma);
946 x_hi = padd(rho, gamma);
947 x_lo = psub(x, x_hi);
948}
949
950// This function implements Dekker's algorithm for products x * y.
951// Given floating point numbers {x, y} computes the pair
952// {p_hi, p_lo} such that x * y = p_hi + p_lo holds exactly and
953// p_hi = fl(x * y).
954template<typename Packet>
955EIGEN_STRONG_INLINE
956void twoprod(const Packet& x, const Packet& y,
957 Packet& p_hi, Packet& p_lo) {
958 Packet x_hi, x_lo, y_hi, y_lo;
959 veltkamp_splitting(x, x_hi, x_lo);
960 veltkamp_splitting(y, y_hi, y_lo);
961
962 p_hi = pmul(x, y);
963 p_lo = pmadd(x_hi, y_hi, pnegate(p_hi));
964 p_lo = pmadd(x_hi, y_lo, p_lo);
965 p_lo = pmadd(x_lo, y_hi, p_lo);
966 p_lo = pmadd(x_lo, y_lo, p_lo);
967}
968
969#endif // EIGEN_VECTORIZE_FMA
970
971
972// This function implements Dekker's algorithm for the addition
973// of two double word numbers represented by {x_hi, x_lo} and {y_hi, y_lo}.
974// It returns the result as a pair {s_hi, s_lo} such that
975// x_hi + x_lo + y_hi + y_lo = s_hi + s_lo holds exactly.
976// This is Algorithm 5 from Jean-Michel Muller, "Elementary Functions",
977// 3rd edition, Birkh\"auser, 2016.
978template<typename Packet>
979EIGEN_STRONG_INLINE
980 void twosum(const Packet& x_hi, const Packet& x_lo,
981 const Packet& y_hi, const Packet& y_lo,
982 Packet& s_hi, Packet& s_lo) {
983 const Packet x_greater_mask = pcmp_lt(pabs(y_hi), pabs(x_hi));
984 Packet r_hi_1, r_lo_1;
985 fast_twosum(x_hi, y_hi,r_hi_1, r_lo_1);
986 Packet r_hi_2, r_lo_2;
987 fast_twosum(y_hi, x_hi,r_hi_2, r_lo_2);
988 const Packet r_hi = pselect(x_greater_mask, r_hi_1, r_hi_2);
989
990 const Packet s1 = padd(padd(y_lo, r_lo_1), x_lo);
991 const Packet s2 = padd(padd(x_lo, r_lo_2), y_lo);
992 const Packet s = pselect(x_greater_mask, s1, s2);
993
994 fast_twosum(r_hi, s, s_hi, s_lo);
995}
996
997// This is a version of twosum for double word numbers,
998// which assumes that |x_hi| >= |y_hi|.
999template<typename Packet>
1000EIGEN_STRONG_INLINE
1001 void fast_twosum(const Packet& x_hi, const Packet& x_lo,
1002 const Packet& y_hi, const Packet& y_lo,
1003 Packet& s_hi, Packet& s_lo) {
1004 Packet r_hi, r_lo;
1005 fast_twosum(x_hi, y_hi, r_hi, r_lo);
1006 const Packet s = padd(padd(y_lo, r_lo), x_lo);
1007 fast_twosum(r_hi, s, s_hi, s_lo);
1008}
1009
1010// This is a version of twosum for adding a floating point number x to
1011// double word number {y_hi, y_lo} number, with the assumption
1012// that |x| >= |y_hi|.
1013template<typename Packet>
1014EIGEN_STRONG_INLINE
1015void fast_twosum(const Packet& x,
1016 const Packet& y_hi, const Packet& y_lo,
1017 Packet& s_hi, Packet& s_lo) {
1018 Packet r_hi, r_lo;
1019 fast_twosum(x, y_hi, r_hi, r_lo);
1020 const Packet s = padd(y_lo, r_lo);
1021 fast_twosum(r_hi, s, s_hi, s_lo);
1022}
1023
1024// This function implements the multiplication of a double word
1025// number represented by {x_hi, x_lo} by a floating point number y.
1026// It returns the result as a pair {p_hi, p_lo} such that
1027// (x_hi + x_lo) * y = p_hi + p_lo hold with a relative error
1028// of less than 2*2^{-2p}, where p is the number of significand bit
1029// in the floating point type.
1030// This is Algorithm 7 from Jean-Michel Muller, "Elementary Functions",
1031// 3rd edition, Birkh\"auser, 2016.
1032template<typename Packet>
1033EIGEN_STRONG_INLINE
1034void twoprod(const Packet& x_hi, const Packet& x_lo, const Packet& y,
1035 Packet& p_hi, Packet& p_lo) {
1036 Packet c_hi, c_lo1;
1037 twoprod(x_hi, y, c_hi, c_lo1);
1038 const Packet c_lo2 = pmul(x_lo, y);
1039 Packet t_hi, t_lo1;
1040 fast_twosum(c_hi, c_lo2, t_hi, t_lo1);
1041 const Packet t_lo2 = padd(t_lo1, c_lo1);
1042 fast_twosum(t_hi, t_lo2, p_hi, p_lo);
1043}
1044
1045// This function implements the multiplication of two double word
1046// numbers represented by {x_hi, x_lo} and {y_hi, y_lo}.
1047// It returns the result as a pair {p_hi, p_lo} such that
1048// (x_hi + x_lo) * (y_hi + y_lo) = p_hi + p_lo holds with a relative error
1049// of less than 2*2^{-2p}, where p is the number of significand bit
1050// in the floating point type.
1051template<typename Packet>
1052EIGEN_STRONG_INLINE
1053void twoprod(const Packet& x_hi, const Packet& x_lo,
1054 const Packet& y_hi, const Packet& y_lo,
1055 Packet& p_hi, Packet& p_lo) {
1056 Packet p_hi_hi, p_hi_lo;
1057 twoprod(x_hi, x_lo, y_hi, p_hi_hi, p_hi_lo);
1058 Packet p_lo_hi, p_lo_lo;
1059 twoprod(x_hi, x_lo, y_lo, p_lo_hi, p_lo_lo);
1060 fast_twosum(p_hi_hi, p_hi_lo, p_lo_hi, p_lo_lo, p_hi, p_lo);
1061}
1062
1063// This function computes the reciprocal of a floating point number
1064// with extra precision and returns the result as a double word.
1065template <typename Packet>
1066void doubleword_reciprocal(const Packet& x, Packet& recip_hi, Packet& recip_lo) {
1067 typedef typename unpacket_traits<Packet>::type Scalar;
1068 // 1. Approximate the reciprocal as the reciprocal of the high order element.
1069 Packet approx_recip = prsqrt(x);
1070 approx_recip = pmul(approx_recip, approx_recip);
1071
1072 // 2. Run one step of Newton-Raphson iteration in double word arithmetic
1073 // to get the bottom half. The NR iteration for reciprocal of 'a' is
1074 // x_{i+1} = x_i * (2 - a * x_i)
1075
1076 // -a*x_i
1077 Packet t1_hi, t1_lo;
1078 twoprod(pnegate(x), approx_recip, t1_hi, t1_lo);
1079 // 2 - a*x_i
1080 Packet t2_hi, t2_lo;
1081 fast_twosum(pset1<Packet>(Scalar(2)), t1_hi, t2_hi, t2_lo);
1082 Packet t3_hi, t3_lo;
1083 fast_twosum(t2_hi, padd(t2_lo, t1_lo), t3_hi, t3_lo);
1084 // x_i * (2 - a * x_i)
1085 twoprod(t3_hi, t3_lo, approx_recip, recip_hi, recip_lo);
1086}
1087
1088
1089// This function computes log2(x) and returns the result as a double word.
1090template <typename Scalar>
1092 template <typename Packet>
1093 EIGEN_STRONG_INLINE
1094 void operator()(const Packet& x, Packet& log2_x_hi, Packet& log2_x_lo) {
1095 log2_x_hi = plog2(x);
1096 log2_x_lo = pzero(x);
1097 }
1098};
1099
1100// This specialization uses a more accurate algorithm to compute log2(x) for
1101// floats in [1/sqrt(2);sqrt(2)] with a relative accuracy of ~6.42e-10.
1102// This additional accuracy is needed to counter the error-magnification
1103// inherent in multiplying by a potentially large exponent in pow(x,y).
1104// The minimax polynomial used was calculated using the Sollya tool.
1105// See sollya.org.
1106template <>
1107struct accurate_log2<float> {
1108 template <typename Packet>
1109 EIGEN_STRONG_INLINE
1110 void operator()(const Packet& z, Packet& log2_x_hi, Packet& log2_x_lo) {
1111 // The function log(1+x)/x is approximated in the interval
1112 // [1/sqrt(2)-1;sqrt(2)-1] by a degree 10 polynomial of the form
1113 // Q(x) = (C0 + x * (C1 + x * (C2 + x * (C3 + x * P(x))))),
1114 // where the degree 6 polynomial P(x) is evaluated in single precision,
1115 // while the remaining 4 terms of Q(x), as well as the final multiplication by x
1116 // to reconstruct log(1+x) are evaluated in extra precision using
1117 // double word arithmetic. C0 through C3 are extra precise constants
1118 // stored as double words.
1119 //
1120 // The polynomial coefficients were calculated using Sollya commands:
1121 // > n = 10;
1122 // > f = log2(1+x)/x;
1123 // > interval = [sqrt(0.5)-1;sqrt(2)-1];
1124 // > p = fpminimax(f,n,[|double,double,double,double,single...|],interval,relative,floating);
1125
1126 const Packet p6 = pset1<Packet>( 9.703654795885e-2f);
1127 const Packet p5 = pset1<Packet>(-0.1690667718648f);
1128 const Packet p4 = pset1<Packet>( 0.1720575392246f);
1129 const Packet p3 = pset1<Packet>(-0.1789081543684f);
1130 const Packet p2 = pset1<Packet>( 0.2050433009862f);
1131 const Packet p1 = pset1<Packet>(-0.2404672354459f);
1132 const Packet p0 = pset1<Packet>( 0.2885761857032f);
1133
1134 const Packet C3_hi = pset1<Packet>(-0.360674142838f);
1135 const Packet C3_lo = pset1<Packet>(-6.13283912543e-09f);
1136 const Packet C2_hi = pset1<Packet>(0.480897903442f);
1137 const Packet C2_lo = pset1<Packet>(-1.44861207474e-08f);
1138 const Packet C1_hi = pset1<Packet>(-0.721347510815f);
1139 const Packet C1_lo = pset1<Packet>(-4.84483164698e-09f);
1140 const Packet C0_hi = pset1<Packet>(1.44269502163f);
1141 const Packet C0_lo = pset1<Packet>(2.01711713999e-08f);
1142 const Packet one = pset1<Packet>(1.0f);
1143
1144 const Packet x = psub(z, one);
1145 // Evaluate P(x) in working precision.
1146 // We evaluate it in multiple parts to improve instruction level
1147 // parallelism.
1148 Packet x2 = pmul(x,x);
1149 Packet p_even = pmadd(p6, x2, p4);
1150 p_even = pmadd(p_even, x2, p2);
1151 p_even = pmadd(p_even, x2, p0);
1152 Packet p_odd = pmadd(p5, x2, p3);
1153 p_odd = pmadd(p_odd, x2, p1);
1154 Packet p = pmadd(p_odd, x, p_even);
1155
1156 // Now evaluate the low-order tems of Q(x) in double word precision.
1157 // In the following, due to the alternating signs and the fact that
1158 // |x| < sqrt(2)-1, we can assume that |C*_hi| >= q_i, and use
1159 // fast_twosum instead of the slower twosum.
1160 Packet q_hi, q_lo;
1161 Packet t_hi, t_lo;
1162 // C3 + x * p(x)
1163 twoprod(p, x, t_hi, t_lo);
1164 fast_twosum(C3_hi, C3_lo, t_hi, t_lo, q_hi, q_lo);
1165 // C2 + x * p(x)
1166 twoprod(q_hi, q_lo, x, t_hi, t_lo);
1167 fast_twosum(C2_hi, C2_lo, t_hi, t_lo, q_hi, q_lo);
1168 // C1 + x * p(x)
1169 twoprod(q_hi, q_lo, x, t_hi, t_lo);
1170 fast_twosum(C1_hi, C1_lo, t_hi, t_lo, q_hi, q_lo);
1171 // C0 + x * p(x)
1172 twoprod(q_hi, q_lo, x, t_hi, t_lo);
1173 fast_twosum(C0_hi, C0_lo, t_hi, t_lo, q_hi, q_lo);
1174
1175 // log(z) ~= x * Q(x)
1176 twoprod(q_hi, q_lo, x, log2_x_hi, log2_x_lo);
1177 }
1178};
1179
1180// This specialization uses a more accurate algorithm to compute log2(x) for
1181// floats in [1/sqrt(2);sqrt(2)] with a relative accuracy of ~1.27e-18.
1182// This additional accuracy is needed to counter the error-magnification
1183// inherent in multiplying by a potentially large exponent in pow(x,y).
1184// The minimax polynomial used was calculated using the Sollya tool.
1185// See sollya.org.
1186
1187template <>
1189 template <typename Packet>
1190 EIGEN_STRONG_INLINE
1191 void operator()(const Packet& x, Packet& log2_x_hi, Packet& log2_x_lo) {
1192 // We use a transformation of variables:
1193 // r = c * (x-1) / (x+1),
1194 // such that
1195 // log2(x) = log2((1 + r/c) / (1 - r/c)) = f(r).
1196 // The function f(r) can be approximated well using an odd polynomial
1197 // of the form
1198 // P(r) = ((Q(r^2) * r^2 + C) * r^2 + 1) * r,
1199 // For the implementation of log2<double> here, Q is of degree 6 with
1200 // coefficient represented in working precision (double), while C is a
1201 // constant represented in extra precision as a double word to achieve
1202 // full accuracy.
1203 //
1204 // The polynomial coefficients were computed by the Sollya script:
1205 //
1206 // c = 2 / log(2);
1207 // trans = c * (x-1)/(x+1);
1208 // itrans = (1+x/c)/(1-x/c);
1209 // interval=[trans(sqrt(0.5)); trans(sqrt(2))];
1210 // print(interval);
1211 // f = log2(itrans(x));
1212 // p=fpminimax(f,[|1,3,5,7,9,11,13,15,17|],[|1,DD,double...|],interval,relative,floating);
1213 const Packet q12 = pset1<Packet>(2.87074255468000586e-9);
1214 const Packet q10 = pset1<Packet>(2.38957980901884082e-8);
1215 const Packet q8 = pset1<Packet>(2.31032094540014656e-7);
1216 const Packet q6 = pset1<Packet>(2.27279857398537278e-6);
1217 const Packet q4 = pset1<Packet>(2.31271023278625638e-5);
1218 const Packet q2 = pset1<Packet>(2.47556738444535513e-4);
1219 const Packet q0 = pset1<Packet>(2.88543873228900172e-3);
1220 const Packet C_hi = pset1<Packet>(0.0400377511598501157);
1221 const Packet C_lo = pset1<Packet>(-4.77726582251425391e-19);
1222 const Packet one = pset1<Packet>(1.0);
1223
1224 const Packet cst_2_log2e_hi = pset1<Packet>(2.88539008177792677);
1225 const Packet cst_2_log2e_lo = pset1<Packet>(4.07660016854549667e-17);
1226 // c * (x - 1)
1228 twoprod(cst_2_log2e_hi, cst_2_log2e_lo, psub(x, one), num_hi, num_lo);
1229 // TODO(rmlarsen): Investigate if using the division algorithm by
1230 // Muller et al. is faster/more accurate.
1231 // 1 / (x + 1)
1233 doubleword_reciprocal(padd(x, one), denom_hi, denom_lo);
1234 // r = c * (x-1) / (x+1),
1235 Packet r_hi, r_lo;
1236 twoprod(num_hi, num_lo, denom_hi, denom_lo, r_hi, r_lo);
1237 // r2 = r * r
1239 twoprod(r_hi, r_lo, r_hi, r_lo, r2_hi, r2_lo);
1240 // r4 = r2 * r2
1242 twoprod(r2_hi, r2_lo, r2_hi, r2_lo, r4_hi, r4_lo);
1243
1244 // Evaluate Q(r^2) in working precision. We evaluate it in two parts
1245 // (even and odd in r^2) to improve instruction level parallelism.
1246 Packet q_even = pmadd(q12, r4_hi, q8);
1247 Packet q_odd = pmadd(q10, r4_hi, q6);
1248 q_even = pmadd(q_even, r4_hi, q4);
1249 q_odd = pmadd(q_odd, r4_hi, q2);
1250 q_even = pmadd(q_even, r4_hi, q0);
1251 Packet q = pmadd(q_odd, r2_hi, q_even);
1252
1253 // Now evaluate the low order terms of P(x) in double word precision.
1254 // In the following, due to the increasing magnitude of the coefficients
1255 // and r being constrained to [-0.5, 0.5] we can use fast_twosum instead
1256 // of the slower twosum.
1257 // Q(r^2) * r^2
1258 Packet p_hi, p_lo;
1259 twoprod(r2_hi, r2_lo, q, p_hi, p_lo);
1260 // Q(r^2) * r^2 + C
1262 fast_twosum(C_hi, C_lo, p_hi, p_lo, p1_hi, p1_lo);
1263 // (Q(r^2) * r^2 + C) * r^2
1265 twoprod(r2_hi, r2_lo, p1_hi, p1_lo, p2_hi, p2_lo);
1266 // ((Q(r^2) * r^2 + C) * r^2 + 1)
1268 fast_twosum(one, p2_hi, p2_lo, p3_hi, p3_lo);
1269
1270 // log(z) ~= ((Q(r^2) * r^2 + C) * r^2 + 1) * r
1271 twoprod(p3_hi, p3_lo, r_hi, r_lo, log2_x_hi, log2_x_lo);
1272 }
1273};
1274
1275// This function computes exp2(x) (i.e. 2**x).
1276template <typename Scalar>
1278 template <typename Packet>
1279 EIGEN_STRONG_INLINE
1280 Packet operator()(const Packet& x) {
1281 // TODO(rmlarsen): Add a pexp2 packetop.
1282 return pexp(pmul(pset1<Packet>(Scalar(EIGEN_LN2)), x));
1283 }
1284};
1285
1286// This specialization uses a faster algorithm to compute exp2(x) for floats
1287// in [-0.5;0.5] with a relative accuracy of 1 ulp.
1288// The minimax polynomial used was calculated using the Sollya tool.
1289// See sollya.org.
1290template <>
1291struct fast_accurate_exp2<float> {
1292 template <typename Packet>
1293 EIGEN_STRONG_INLINE
1294 Packet operator()(const Packet& x) {
1295 // This function approximates exp2(x) by a degree 6 polynomial of the form
1296 // Q(x) = 1 + x * (C + x * P(x)), where the degree 4 polynomial P(x) is evaluated in
1297 // single precision, and the remaining steps are evaluated with extra precision using
1298 // double word arithmetic. C is an extra precise constant stored as a double word.
1299 //
1300 // The polynomial coefficients were calculated using Sollya commands:
1301 // > n = 6;
1302 // > f = 2^x;
1303 // > interval = [-0.5;0.5];
1304 // > p = fpminimax(f,n,[|1,double,single...|],interval,relative,floating);
1305
1306 const Packet p4 = pset1<Packet>(1.539513905e-4f);
1307 const Packet p3 = pset1<Packet>(1.340007293e-3f);
1308 const Packet p2 = pset1<Packet>(9.618283249e-3f);
1309 const Packet p1 = pset1<Packet>(5.550328270e-2f);
1310 const Packet p0 = pset1<Packet>(0.2402264923f);
1311
1312 const Packet C_hi = pset1<Packet>(0.6931471825f);
1313 const Packet C_lo = pset1<Packet>(2.36836577e-08f);
1314 const Packet one = pset1<Packet>(1.0f);
1315
1316 // Evaluate P(x) in working precision.
1317 // We evaluate even and odd parts of the polynomial separately
1318 // to gain some instruction level parallelism.
1319 Packet x2 = pmul(x,x);
1320 Packet p_even = pmadd(p4, x2, p2);
1321 Packet p_odd = pmadd(p3, x2, p1);
1322 p_even = pmadd(p_even, x2, p0);
1323 Packet p = pmadd(p_odd, x, p_even);
1324
1325 // Evaluate the remaining terms of Q(x) with extra precision using
1326 // double word arithmetic.
1327 Packet p_hi, p_lo;
1328 // x * p(x)
1329 twoprod(p, x, p_hi, p_lo);
1330 // C + x * p(x)
1332 twosum(p_hi, p_lo, C_hi, C_lo, q1_hi, q1_lo);
1333 // x * (C + x * p(x))
1335 twoprod(q1_hi, q1_lo, x, q2_hi, q2_lo);
1336 // 1 + x * (C + x * p(x))
1338 // Since |q2_hi| <= sqrt(2)-1 < 1, we can use fast_twosum
1339 // for adding it to unity here.
1340 fast_twosum(one, q2_hi, q3_hi, q3_lo);
1341 return padd(q3_hi, padd(q2_lo, q3_lo));
1342 }
1343};
1344
1345// in [-0.5;0.5] with a relative accuracy of 1 ulp.
1346// The minimax polynomial used was calculated using the Sollya tool.
1347// See sollya.org.
1348template <>
1350 template <typename Packet>
1351 EIGEN_STRONG_INLINE
1352 Packet operator()(const Packet& x) {
1353 // This function approximates exp2(x) by a degree 10 polynomial of the form
1354 // Q(x) = 1 + x * (C + x * P(x)), where the degree 8 polynomial P(x) is evaluated in
1355 // single precision, and the remaining steps are evaluated with extra precision using
1356 // double word arithmetic. C is an extra precise constant stored as a double word.
1357 //
1358 // The polynomial coefficients were calculated using Sollya commands:
1359 // > n = 11;
1360 // > f = 2^x;
1361 // > interval = [-0.5;0.5];
1362 // > p = fpminimax(f,n,[|1,DD,double...|],interval,relative,floating);
1363
1364 const Packet p9 = pset1<Packet>(4.431642109085495276e-10);
1365 const Packet p8 = pset1<Packet>(7.073829923303358410e-9);
1366 const Packet p7 = pset1<Packet>(1.017822306737031311e-7);
1367 const Packet p6 = pset1<Packet>(1.321543498017646657e-6);
1368 const Packet p5 = pset1<Packet>(1.525273342728892877e-5);
1369 const Packet p4 = pset1<Packet>(1.540353045780084423e-4);
1370 const Packet p3 = pset1<Packet>(1.333355814685869807e-3);
1371 const Packet p2 = pset1<Packet>(9.618129107593478832e-3);
1372 const Packet p1 = pset1<Packet>(5.550410866481961247e-2);
1373 const Packet p0 = pset1<Packet>(0.240226506959101332);
1374 const Packet C_hi = pset1<Packet>(0.693147180559945286);
1375 const Packet C_lo = pset1<Packet>(4.81927865669806721e-17);
1376 const Packet one = pset1<Packet>(1.0);
1377
1378 // Evaluate P(x) in working precision.
1379 // We evaluate even and odd parts of the polynomial separately
1380 // to gain some instruction level parallelism.
1381 Packet x2 = pmul(x,x);
1382 Packet p_even = pmadd(p8, x2, p6);
1383 Packet p_odd = pmadd(p9, x2, p7);
1384 p_even = pmadd(p_even, x2, p4);
1385 p_odd = pmadd(p_odd, x2, p5);
1386 p_even = pmadd(p_even, x2, p2);
1387 p_odd = pmadd(p_odd, x2, p3);
1388 p_even = pmadd(p_even, x2, p0);
1389 p_odd = pmadd(p_odd, x2, p1);
1390 Packet p = pmadd(p_odd, x, p_even);
1391
1392 // Evaluate the remaining terms of Q(x) with extra precision using
1393 // double word arithmetic.
1394 Packet p_hi, p_lo;
1395 // x * p(x)
1396 twoprod(p, x, p_hi, p_lo);
1397 // C + x * p(x)
1399 twosum(p_hi, p_lo, C_hi, C_lo, q1_hi, q1_lo);
1400 // x * (C + x * p(x))
1402 twoprod(q1_hi, q1_lo, x, q2_hi, q2_lo);
1403 // 1 + x * (C + x * p(x))
1405 // Since |q2_hi| <= sqrt(2)-1 < 1, we can use fast_twosum
1406 // for adding it to unity here.
1407 fast_twosum(one, q2_hi, q3_hi, q3_lo);
1408 return padd(q3_hi, padd(q2_lo, q3_lo));
1409 }
1410};
1411
1412// This function implements the non-trivial case of pow(x,y) where x is
1413// positive and y is (possibly) non-integer.
1414// Formally, pow(x,y) = exp2(y * log2(x)), where exp2(x) is shorthand for 2^x.
1415// TODO(rmlarsen): We should probably add this as a packet up 'ppow', to make it
1416// easier to specialize or turn off for specific types and/or backends.x
1417template <typename Packet>
1418EIGEN_STRONG_INLINE Packet generic_pow_impl(const Packet& x, const Packet& y) {
1419 typedef typename unpacket_traits<Packet>::type Scalar;
1420 // Split x into exponent e_x and mantissa m_x.
1421 Packet e_x;
1422 Packet m_x = pfrexp(x, e_x);
1423
1424 // Adjust m_x to lie in [1/sqrt(2):sqrt(2)] to minimize absolute error in log2(m_x).
1425 EIGEN_CONSTEXPR Scalar sqrt_half = Scalar(0.70710678118654752440);
1426 const Packet m_x_scale_mask = pcmp_lt(m_x, pset1<Packet>(sqrt_half));
1427 m_x = pselect(m_x_scale_mask, pmul(pset1<Packet>(Scalar(2)), m_x), m_x);
1428 e_x = pselect(m_x_scale_mask, psub(e_x, pset1<Packet>(Scalar(1))), e_x);
1429
1430 // Compute log2(m_x) with 6 extra bits of accuracy.
1433
1434 // Compute the two terms {y * e_x, y * r_x} in f = y * log2(x) with doubled
1435 // precision using double word arithmetic.
1437 twoprod(e_x, y, f1_hi, f1_lo);
1438 twoprod(rx_hi, rx_lo, y, f2_hi, f2_lo);
1439 // Sum the two terms in f using double word arithmetic. We know
1440 // that |e_x| > |log2(m_x)|, except for the case where e_x==0.
1441 // This means that we can use fast_twosum(f1,f2).
1442 // In the case e_x == 0, e_x * y = f1 = 0, so we don't lose any
1443 // accuracy by violating the assumption of fast_twosum, because
1444 // it's a no-op.
1445 Packet f_hi, f_lo;
1446 fast_twosum(f1_hi, f1_lo, f2_hi, f2_lo, f_hi, f_lo);
1447
1448 // Split f into integer and fractional parts.
1449 Packet n_z, r_z;
1450 absolute_split(f_hi, n_z, r_z);
1451 r_z = padd(r_z, f_lo);
1452 Packet n_r;
1453 absolute_split(r_z, n_r, r_z);
1454 n_z = padd(n_z, n_r);
1455
1456 // We now have an accurate split of f = n_z + r_z and can compute
1457 // x^y = 2**{n_z + r_z) = exp2(r_z) * 2**{n_z}.
1458 // Since r_z is in [-0.5;0.5], we compute the first factor to high accuracy
1459 // using a specialized algorithm. Multiplication by the second factor can
1460 // be done exactly using pldexp(), since it is an integer power of 2.
1462 return pldexp(e_r, n_z);
1463}
1464
1465// Generic implementation of pow(x,y).
//
// Computes pow(x, y) lane-wise. Every special case is encoded as a lane mask,
// and the result is assembled by the pselect() chain in the return statement,
// in this order of precedence:
//   y == 1       -> x
//   pow_is_one   -> 1
//   pow_is_nan   -> cst_nan
//   pow_is_inf   -> inf_val
//   pow_is_zero  -> 0
//   otherwise    -> generic_pow_impl(|x|, y), negated where x < 0 and y is not even.
//
// NOTE(review): this listing is incomplete. cst_pos_inf, cst_neg_inf and
// cst_nan are used below but not defined in the visible lines (presumably on
// the omitted lines 1470-1471 and 1474), and the continuation lines 1499, 1511
// and 1520 (right-hand side of huge_exponent, right-hand side of pow_is_one and
// the tail of inf_val) are missing. Consult the real header before changing
// this function.
1466template <typename Packet>
1467EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_pow(const Packet& x, const Packet& y) {
1468 typedef typename unpacket_traits<Packet>::type Scalar;
1469
1472 const Packet cst_zero = pset1<Packet>(Scalar(0));
1473 const Packet cst_one = pset1<Packet>(Scalar(1));
1475
1476 const Packet abs_x = pabs(x);
1477 // Predicates for sign and magnitude of x.
1478 const Packet abs_x_is_zero = pcmp_eq(abs_x, cst_zero);
  // Sign-bit test: (x & -inf) | +inf keeps only the sign bit of x on top of an
  // all-ones exponent and zero mantissa, so it compares equal to -inf exactly
  // when the sign bit of x is set. Unlike x < 0, this also flags -0.0.
  // (Assumes cst_neg_inf / cst_pos_inf are -inf / +inf; not visible here.)
1479 const Packet x_has_signbit = pcmp_eq(por(pand(x, cst_neg_inf), cst_pos_inf), cst_neg_inf);
  // x_is_neg excludes -0.0; x_is_neg_zero is set only for -0.0.
1480 const Packet x_is_neg = pandnot(x_has_signbit, abs_x_is_zero);
1481 const Packet x_is_neg_zero = pand(x_has_signbit, abs_x_is_zero);
1482 const Packet abs_x_is_inf = pcmp_eq(abs_x, cst_pos_inf);
1483 const Packet abs_x_is_one = pcmp_eq(abs_x, cst_one);
1484 const Packet abs_x_is_gt_one = pcmp_lt(cst_one, abs_x);
1485 const Packet abs_x_is_lt_one = pcmp_lt(abs_x, cst_one);
1486 const Packet x_is_one = pandnot(abs_x_is_one, x_is_neg);
1487 const Packet x_is_neg_one = pand(abs_x_is_one, x_is_neg);
  // NaN is the only value that compares unequal to itself.
1488 const Packet x_is_nan = pandnot(ptrue(x), pcmp_eq(x, x));
1489
1490 // Predicates for sign and magnitude of y.
1491 const Packet abs_y = pabs(y);
1492 const Packet y_is_one = pcmp_eq(y, cst_one);
1493 const Packet abs_y_is_zero = pcmp_eq(abs_y, cst_zero);
1494 const Packet y_is_neg = pcmp_lt(y, cst_zero);
  // y_is_pos is computed as !(y == 0 || y < 0), so it is also set for NaN y;
  // such lanes are normally resolved by pow_is_nan, which the return statement
  // selects before pow_is_inf and pow_is_zero.
1495 const Packet y_is_pos = pandnot(ptrue(y), por(abs_y_is_zero, y_is_neg));
1496 const Packet y_is_nan = pandnot(ptrue(y), pcmp_eq(y, y));
1497 const Packet abs_y_is_inf = pcmp_eq(abs_y, cst_pos_inf);
  // |y| >= huge_exponent saturates the result to 0 or inf whenever |x| != 1
  // (see pow_is_zero / pow_is_inf). Its value is on the omitted line 1499.
1498 EIGEN_CONSTEXPR Scalar huge_exponent =
1500 const Packet abs_y_is_huge = pcmp_le(pset1<Packet>(huge_exponent), pabs(y));
1501
1502 // Predicates for whether y is integer and/or even.
1503 const Packet y_is_int = pcmp_eq(pfloor(y), y);
1504 const Packet y_div_2 = pmul(y, pset1<Packet>(Scalar(0.5)));
1505 const Packet y_is_even = pcmp_eq(pround(y_div_2), y_div_2);
1506
1507 // Predicates encoding special cases for the value of pow(x,y)
  // invalid_negative_x: x < 0 (excluding -0.0), x finite, y finite and not an
  // integer -- the real-valued result is undefined, hence NaN.
1508 const Packet invalid_negative_x = pandnot(pandnot(pandnot(x_is_neg, abs_x_is_inf), y_is_int), abs_y_is_inf);
1509 const Packet pow_is_nan = por(invalid_negative_x, por(x_is_nan, y_is_nan));
1510 const Packet pow_is_one =
  // pow_is_zero: 0^(y>0), inf^(y<0), or |y| huge with (|x| < 1, y > 0) or
  // (|x| > 1, y < 0).
  // NOTE(review): these lanes return cst_zero (+0.0), so pow(-0.0, y) for an
  // odd integer y > 1 and pow(-inf, y) for an odd integer y < 0 yield +0.0,
  // whereas C's pow() yields -0.0 in those cases -- confirm this is intended.
1512 const Packet pow_is_zero = por(por(por(pand(abs_x_is_zero, y_is_pos), pand(abs_x_is_inf, y_is_neg)),
1513 pand(pand(abs_x_is_lt_one, abs_y_is_huge), y_is_pos)),
1514 pand(pand(abs_x_is_gt_one, abs_y_is_huge), y_is_neg));
  // pow_is_inf: 0^(y<0), inf^(y>0), or |y| huge with (|x| < 1, y < 0) or
  // (|x| > 1, y > 0).
1515 const Packet pow_is_inf = por(por(por(pand(abs_x_is_zero, y_is_neg), pand(abs_x_is_inf, y_is_pos)),
1516 pand(pand(abs_x_is_lt_one, abs_y_is_huge), y_is_neg)),
1517 pand(pand(abs_x_is_gt_one, abs_y_is_huge), y_is_pos));
  // inf_val: the select mask below is set when y is an odd integer and either
  // x = -inf or (x = -0.0 and y < 0), i.e. the pow_is_inf lanes whose result is
  // negative. The two select arms are on the omitted line 1520 (presumably
  // cst_neg_inf and cst_pos_inf respectively -- confirm).
1518 const Packet inf_val =
1519 pselect(pandnot(pand(por(pand(abs_x_is_inf, x_is_neg), pand(x_is_neg_zero, y_is_neg)), y_is_int), y_is_even),
1521
1522 // General computation of pow(x,y) for positive x or negative x and integer y.
1523 const Packet negate_pow_abs = pandnot(x_is_neg, y_is_even);
1524 const Packet pow_abs = generic_pow_impl(abs_x, y);
1525 return pselect(
1526 y_is_one, x,
1527 pselect(pow_is_one, cst_one,
1528 pselect(pow_is_nan, cst_nan,
1529 pselect(pow_is_inf, inf_val,
1530 pselect(pow_is_zero, cst_zero, pselect(negate_pow_abs, pnegate(pow_abs), pow_abs))))));
1531}
1532
1533/* polevl (modified for Eigen)
1534 *
1535 * Evaluate polynomial
1536 *
1537 *
1538 *
1539 * SYNOPSIS:
1540 *
1541 * int N;
1542 * Packet x, y; Scalar coef[N+1];
1543 *
1544 * y = ppolevl<Packet, N>::run(x, coef);
1545 *
1546 *
1547 *
1548 * DESCRIPTION:
1549 *
1550 * Evaluates polynomial of degree N:
1551 *
1552 *
1553 *   y = C_0 + C_1 x + C_2 x^2 + ... + C_N x^N     (i.e. y = \sum_{i=0}^{N} C_i x^i)
1554 *
1555 *
1556 * Coefficients are stored in reverse order:
1557 *
1558 *   coef[0] = C_N, ..., coef[N] = C_0.
1559 *
1560 *
1561 * The related function p1evl() assumes that the leading coefficient
1562 * C_N = 1.0 and omits it from the array; its calling arguments are
1563 * otherwise the same as polevl().
1564 *
1565 *
1566 * The Eigen implementation is templatized. For best speed, store
1567 * coef as a const array (constexpr), e.g.
1568 *
1569 * const double coef[] = {1.0, 2.0, 3.0, ...};
1570 *
1571 */
1572template <typename Packet, int N>
1573struct ppolevl {
1574 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const typename unpacket_traits<Packet>::type coeff[]) {
1575 EIGEN_STATIC_ASSERT((N > 0), YOU_MADE_A_PROGRAMMING_MISTAKE);
1576 return pmadd(ppolevl<Packet, N-1>::run(x, coeff), x, pset1<Packet>(coeff[N]));
1577 }
1578};
1579
1580template <typename Packet>
1581struct ppolevl<Packet, 0> {
1582 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const typename unpacket_traits<Packet>::type coeff[]) {
1583 EIGEN_UNUSED_VARIABLE(x);
1584 return pset1<Packet>(coeff[0]);
1585 }
1586};
1587
1588/* chbevl (modified for Eigen)
1589 *
1590 * Evaluate Chebyshev series
1591 *
1592 *
1593 *
1594 * SYNOPSIS:
1595 *
1596 * int N;
1597 * Packet x, y; Scalar coef[N];
1598 *
1599 * y = pchebevl<Packet, N>::run(x, coef);
1600 *
1601 *
1602 *
1603 * DESCRIPTION:
1604 *
1605 * Evaluates the series
1606 *
1607 *   y = \sum'_{i=0}^{N-1} c_i T_i(x/2),   with c_i = coef[N-1-i],
1608 *
1609 * of Chebyshev polynomials T_i at argument x/2. The prime on the sum
1610 * means that the i = 0 term (c_0 T_0 = c_0) is halved.
1614 *
1615 * Coefficients are stored in reverse order, i.e. the zero
1616 * order term is last in the array. Note N is the number of
1617 * coefficients, not the order.
1618 *
1619 * If coefficients are for the interval a to b, x must
1620 * have been transformed to x -> 2(2x - b - a)/(b-a) before
1621 * entering the routine. This maps x from (a, b) to (-2, 2), so the
1622 * argument x/2 lies in (-1, 1), over which the Chebyshev polynomials are defined.
1623 *
1624 * If the coefficients are for the inverted interval, in
1625 * which (a, b) is mapped to (1/b, 1/a), the transformation
1626 * required is x -> 2(2ab/x - b - a)/(b-a). If b is infinity,
1627 * this becomes x -> 4a/x - 2.
1628 *
1629 *
1630 *
1631 * SPEED:
1632 *
1633 * Taking advantage of the recurrence properties of the
1634 * Chebyshev polynomials, the routine requires one more
1635 * addition per loop than evaluating a nested polynomial of
1636 * the same degree.
1637 *
1638 */
1639
1640template <typename Packet, int N>
1641struct pchebevl {
1642 EIGEN_DEVICE_FUNC
1643 static EIGEN_STRONG_INLINE Packet run(Packet x, const typename unpacket_traits<Packet>::type coef[]) {
1644 typedef typename unpacket_traits<Packet>::type Scalar;
1645 Packet b0 = pset1<Packet>(coef[0]);
1646 Packet b1 = pset1<Packet>(static_cast<Scalar>(0.f));
1647 Packet b2;
1648
1649 for (int i = 1; i < N; i++) {
1650 b2 = b1;
1651 b1 = b0;
1652 b0 = psub(pmadd(x, b1, pset1<Packet>(coef[i])), b2);
1653 }
1654
1655 return pmul(pset1<Packet>(static_cast<Scalar>(0.5f)), psub(b0, b2));
1656 }
1657};
1658
1659} // end namespace internal
1660} // end namespace Eigen
1661
1662#endif // EIGEN_ARCH_GENERIC_PACKET_MATH_FUNCTIONS_H
Base class for all dense matrices, vectors, and expressions.
Definition MatrixBase.h:50
Namespace containing all symbols from the Eigen library.
Definition LDLT.h:16
Holds information about the various numeric (i.e.
Definition NumTraits.h:236
Definition BFloat16.h:58
Definition Half.h:140
Definition GenericPacketMathFunctions.h:1091
Definition GenericPacketMathFunctions.h:1277
Definition GenericPacketMathFunctions.h:23
Definition GenericPacketMathFunctions.h:1641
Definition GenericPacketMathFunctions.h:139
Definition GenericPacketMathFunctions.h:1573
Definition GenericPacketMath.h:133
Definition PacketMath.h:47