Medial Code Documentation
MatrixVectorProduct.h
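A minimal usage sketch (not part of this header, and assuming an Eigen build targeting Power with VSX/MMA enabled): an ordinary matrix-vector product such as the one below is what ultimately dispatches into the gemv_col / gemv_complex_col kernels defined in this file.

    #include <Eigen/Dense>

    int main() {
      Eigen::MatrixXf A = Eigen::MatrixXf::Random(256, 256);
      Eigen::VectorXf x = Eigen::VectorXf::Random(256);
      Eigen::VectorXf y = A * x;  // GEMV: evaluated by the vectorized Power kernels below
      return y.size() == 256 ? 0 : 1;
    }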
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2021 Chip Kerchner (chip.kerchner@ibm.com)
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_MATRIX_VECTOR_PRODUCT_ALTIVEC_H
11#define EIGEN_MATRIX_VECTOR_PRODUCT_ALTIVEC_H
12
13#include "../../InternalHeaderCheck.h"
14
15#if defined(__MMA__) && !EIGEN_ALTIVEC_DISABLE_MMA
16#if EIGEN_COMP_LLVM || (__GNUC__ > 10 || __GNUC_MINOR__ >= 3)
17#define USE_GEMV_MMA
18#endif
19
20#if !EIGEN_COMP_LLVM && (__GNUC__ == 10 && __GNUC_MINOR__ <= 3)
21// Only allow one vector_pair in buggy gcc - gcc 10.3 has a bug
22#define GCC_ONE_VECTORPAIR_BUG
23#endif
24#endif
25
26//#define USE_SLOWER_GEMV_MMA // MMA is currently not as fast as VSX in complex double GEMV (revisit when gcc is improved)
27
28//#define EIGEN_POWER_USE_GEMV_PREFETCH
29#ifdef EIGEN_POWER_USE_GEMV_PREFETCH
30#define EIGEN_POWER_GEMV_PREFETCH(p) prefetch(p)
31#else
32#define EIGEN_POWER_GEMV_PREFETCH(p)
33#endif
34
35#ifdef __has_builtin
36#if !__has_builtin(__builtin_vsx_assemble_pair)
37#define __builtin_vsx_assemble_pair __builtin_mma_assemble_pair
38#endif
39#if !__has_builtin(__builtin_vsx_disassemble_pair)
40#define __builtin_vsx_disassemble_pair __builtin_mma_disassemble_pair
41#endif
42#endif
43
44#if EIGEN_COMP_LLVM
45#define GEMV_BUILDPAIR_MMA(dst, src1, src2) \
46 __builtin_vsx_assemble_pair(&dst, (__vector unsigned char)src2, (__vector unsigned char)src1)
47#else
48#if (__GNUC__ <= 10)
49#if (__GNUC_MINOR__ > 3)
50#define GEMV_BUILDPAIR_MMA(dst, src1, src2) \
51 __builtin_vsx_assemble_pair(&dst, (__vector unsigned char)src2, (__vector unsigned char)src1)
52#else
53#define GEMV_BUILDPAIR_MMA(dst, src1, src2) \
54 __builtin_vsx_assemble_pair(&dst, (__vector unsigned char)src1, (__vector unsigned char)src2)
55#endif
56#else
57#define GEMV_BUILDPAIR_MMA(dst, src1, src2) \
58 __builtin_vsx_build_pair(&dst, (__vector unsigned char)src1, (__vector unsigned char)src2)
59#endif
60#endif
61
62#define GEMV_IS_COMPLEX_COMPLEX ((sizeof(LhsPacket) == 16) && (sizeof(RhsPacket) == 16))
63#define GEMV_IS_FLOAT (ResPacketSize == (16 / sizeof(float)))
64#define GEMV_IS_SCALAR (sizeof(ResPacket) != 16)
65#define GEMV_IS_COMPLEX_FLOAT (ResPacketSize == (16 / sizeof(std::complex<float>)))
66
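67// Scale data by alpha and accumulate into res (unaligned packet and plain scalar versions)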
68template<typename ResPacket, typename ResScalar>
69EIGEN_ALWAYS_INLINE void storeMaddData(ResScalar* res, ResPacket& palpha, ResPacket& data)
70{
71 pstoreu(res, pmadd(data, palpha, ploadu<ResPacket>(res)));
72}
73
74template<typename ResScalar>
75EIGEN_ALWAYS_INLINE void storeMaddData(ResScalar* res, ResScalar& alpha, ResScalar& data)
76{
77 *res += (alpha * data);
78}
79
80#define GEMV_UNROLL(func, N) \
81 func(0, N) func(1, N) func(2, N) func(3, N) \
82 func(4, N) func(5, N) func(6, N) func(7, N)
83
84#define GEMV_UNROLL_HALF(func, N) \
85 func(0, 0, 1, N) func(1, 2, 3, N) func(2, 4, 5, N) func(3, 6, 7, N)
86
87#define GEMV_GETN(N) (((N) * ResPacketSize) >> 2)
88
89#define GEMV_LOADPACKET_COL(iter) \
90 lhs.template load<LhsPacket, LhsAlignment>(i + ((iter) * LhsPacketSize), j)
91
92#ifdef USE_GEMV_MMA
93#define GEMV_UNROLL3(func, N, which) \
94 func(0, N, which) func(1, N, which) func(2, N, which) func(3, N, which) \
95 func(4, N, which) func(5, N, which) func(6, N, which) func(7, N, which)
96
97#define GEMV_UNUSED_VAR(iter, N, which) \
98 if (GEMV_GETN(N) <= iter) { \
99 EIGEN_UNUSED_VARIABLE(which##iter); \
100 }
101
102#define GEMV_UNUSED_EXTRA_VAR(iter, N, which) \
103 if (N <= iter) { \
104 EIGEN_UNUSED_VARIABLE(which##iter); \
105 }
106
107#define GEMV_UNUSED_EXTRA(N, which) \
108 GEMV_UNROLL3(GEMV_UNUSED_EXTRA_VAR, N, which)
109
110#define GEMV_UNUSED(N, which) \
111 GEMV_UNROLL3(GEMV_UNUSED_VAR, N, which)
112
113#define GEMV_INIT_MMA(iter, N) \
114 if (GEMV_GETN(N) > iter) { \
115 __builtin_mma_xxsetaccz(&e##iter); \
116 }
117
118#if EIGEN_COMP_LLVM
119#define GEMV_LOADPAIR_COL_MMA(iter1, iter2) \
120 GEMV_BUILDPAIR_MMA(b##iter1, GEMV_LOADPACKET_COL(iter2), GEMV_LOADPACKET_COL((iter2) + 1));
121#else
122#define GEMV_LOADPAIR_COL_MMA(iter1, iter2) \
123 const LhsScalar& src##iter1 = lhs(i + ((iter1 * 32) / sizeof(LhsScalar)), j); \
124 b##iter1 = *reinterpret_cast<__vector_pair *>(const_cast<LhsScalar *>(&src##iter1));
125#endif
126
127#define GEMV_LOAD1A_COL_MMA(iter, N) \
128 if (GEMV_GETN(N) > iter) { \
129 if (GEMV_IS_FLOAT) { \
130 g##iter = GEMV_LOADPACKET_COL(iter); \
131 EIGEN_UNUSED_VARIABLE(b##iter); \
132 } else { \
133 GEMV_LOADPAIR_COL_MMA(iter, iter << 1) \
134 EIGEN_UNUSED_VARIABLE(g##iter); \
135 } \
136 } else { \
137 EIGEN_UNUSED_VARIABLE(b##iter); \
138 EIGEN_UNUSED_VARIABLE(g##iter); \
139 }
140
141#define GEMV_WORK1A_COL_MMA(iter, N) \
142 if (GEMV_GETN(N) > iter) { \
143 if (GEMV_IS_FLOAT) { \
144 pger_vecMMA_acc<LhsPacket, RhsPacket, true>(&e##iter, a0, g##iter); \
145 } else { \
146 pger_vecMMA_acc<LhsPacket, RhsPacket, true>(&e##iter, b##iter, a0); \
147 } \
148 }
149
150#define GEMV_LOAD1B_COL_MMA(iter1, iter2, iter3, N) \
151 if (GEMV_GETN(N) > iter1) { \
152 if (GEMV_IS_FLOAT) { \
153 GEMV_LOADPAIR_COL_MMA(iter2, iter2) \
154 EIGEN_UNUSED_VARIABLE(b##iter3); \
155 } else { \
156 GEMV_LOADPAIR_COL_MMA(iter2, iter2 << 1) \
157 GEMV_LOADPAIR_COL_MMA(iter3, iter3 << 1) \
158 } \
159 } else { \
160 EIGEN_UNUSED_VARIABLE(b##iter2); \
161 EIGEN_UNUSED_VARIABLE(b##iter3); \
162 } \
163 EIGEN_UNUSED_VARIABLE(g##iter2); \
164 EIGEN_UNUSED_VARIABLE(g##iter3);
165
166#define GEMV_WORK1B_COL_MMA(iter1, iter2, iter3, N) \
167 if (GEMV_GETN(N) > iter1) { \
168 if (GEMV_IS_FLOAT) { \
169 LhsPacket h[2]; \
170 __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(h), &b##iter2); \
171 pger_vecMMA_acc<LhsPacket, RhsPacket, true>(&e##iter2, a0, h[0]); \
172 pger_vecMMA_acc<LhsPacket, RhsPacket, true>(&e##iter3, a0, h[1]); \
173 } else { \
174 pger_vecMMA_acc<LhsPacket, RhsPacket, true>(&e##iter2, b##iter2, a0); \
175 pger_vecMMA_acc<LhsPacket, RhsPacket, true>(&e##iter3, b##iter3, a0); \
176 } \
177 }
178
179#if EIGEN_COMP_LLVM
180#define GEMV_LOAD_COL_MMA(N) \
181 if (GEMV_GETN(N) > 1) { \
182 GEMV_UNROLL_HALF(GEMV_LOAD1B_COL_MMA, (N >> 1)) \
183 } else { \
184 GEMV_UNROLL(GEMV_LOAD1A_COL_MMA, N) \
185 }
186
187#define GEMV_WORK_COL_MMA(N) \
188 if (GEMV_GETN(N) > 1) { \
189 GEMV_UNROLL_HALF(GEMV_WORK1B_COL_MMA, (N >> 1)) \
190 } else { \
191 GEMV_UNROLL(GEMV_WORK1A_COL_MMA, N) \
192 }
193#else
194#define GEMV_LOAD_COL_MMA(N) \
195 GEMV_UNROLL(GEMV_LOAD1A_COL_MMA, N)
196
197#define GEMV_WORK_COL_MMA(N) \
198 GEMV_UNROLL(GEMV_WORK1A_COL_MMA, N)
199#endif
200
201#define GEMV_DISASSEMBLE_MMA(iter, N) \
202 if (GEMV_GETN(N) > iter) { \
203 __builtin_mma_disassemble_acc(&result##iter.packet, &e##iter); \
204 if (!GEMV_IS_FLOAT) { \
205 result##iter.packet[0][1] = result##iter.packet[1][0]; \
206 result##iter.packet[2][1] = result##iter.packet[3][0]; \
207 } \
208 }
209
210#define GEMV_LOADPAIR2_COL_MMA(iter1, iter2) \
211 b##iter1 = *reinterpret_cast<__vector_pair *>(res + i + ((iter2) * ResPacketSize));
212
213#define GEMV_LOAD2_COL_MMA(iter1, iter2, iter3, N) \
214 if (GEMV_GETN(N) > iter1) { \
215 if (GEMV_IS_FLOAT) { \
216 GEMV_LOADPAIR2_COL_MMA(iter2, iter2); \
217 EIGEN_UNUSED_VARIABLE(b##iter3); \
218 } else { \
219 GEMV_LOADPAIR2_COL_MMA(iter2, iter2 << 1); \
220 GEMV_LOADPAIR2_COL_MMA(iter3, iter3 << 1); \
221 } \
222 } else { \
223 EIGEN_UNUSED_VARIABLE(b##iter2); \
224 EIGEN_UNUSED_VARIABLE(b##iter3); \
225 }
226
227#if EIGEN_COMP_LLVM
228#define GEMV_WORKPAIR2_COL_MMA(iter2, iter3, iter4) \
229 ResPacket f##iter2[2]; \
230 __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(f##iter2), &b##iter2); \
231 f##iter2[0] = pmadd(result##iter2.packet[0], palpha, f##iter2[0]); \
232 f##iter2[1] = pmadd(result##iter3.packet[(iter2 == iter3) ? 2 : 0], palpha, f##iter2[1]); \
233 GEMV_BUILDPAIR_MMA(b##iter2, f##iter2[0], f##iter2[1]);
234#else
235#define GEMV_WORKPAIR2_COL_MMA(iter2, iter3, iter4) \
236 if (GEMV_IS_FLOAT) { \
237 __asm__ ("xvmaddasp %0,%x1,%x3\n\txvmaddasp %L0,%x2,%x3" : "+&d" (b##iter2) : "wa" (result##iter3.packet[0]), "wa" (result##iter2.packet[0]), "wa" (palpha)); \
238 } else { \
239 __asm__ ("xvmaddadp %0,%x1,%x3\n\txvmaddadp %L0,%x2,%x3" : "+&d" (b##iter2) : "wa" (result##iter2.packet[2]), "wa" (result##iter2.packet[0]), "wa" (palpha)); \
240 }
241#endif
242
243#define GEMV_WORK2_COL_MMA(iter1, iter2, iter3, N) \
244 if (GEMV_GETN(N) > iter1) { \
245 if (GEMV_IS_FLOAT) { \
246 GEMV_WORKPAIR2_COL_MMA(iter2, iter3, iter2); \
247 } else { \
248 GEMV_WORKPAIR2_COL_MMA(iter2, iter2, iter2 << 1); \
249 GEMV_WORKPAIR2_COL_MMA(iter3, iter3, iter3 << 1); \
250 } \
251 }
252
253#define GEMV_STOREPAIR2_COL_MMA(iter1, iter2) \
254 *reinterpret_cast<__vector_pair *>(res + i + ((iter2) * ResPacketSize)) = b##iter1;
255
256#define GEMV_STORE_COL_MMA(iter, N) \
257 if (GEMV_GETN(N) > iter) { \
258 if (GEMV_IS_FLOAT) { \
259 storeMaddData<ResPacket, ResScalar>(res + i + (iter * ResPacketSize), palpha, result##iter.packet[0]); \
260 } else { \
261 GEMV_LOADPAIR2_COL_MMA(iter, iter << 1) \
262 GEMV_WORKPAIR2_COL_MMA(iter, iter, iter << 1) \
263 GEMV_STOREPAIR2_COL_MMA(iter, iter << 1) \
264 } \
265 }
266
267#define GEMV_STORE2_COL_MMA(iter1, iter2, iter3, N) \
268 if (GEMV_GETN(N) > iter1) { \
269 if (GEMV_IS_FLOAT) { \
270 GEMV_STOREPAIR2_COL_MMA(iter2, iter2); \
271 } else { \
272 GEMV_STOREPAIR2_COL_MMA(iter2, iter2 << 1) \
273 GEMV_STOREPAIR2_COL_MMA(iter3, iter3 << 1) \
274 } \
275 }
276
277#define GEMV_PROCESS_COL_ONE_MMA(N) \
278 GEMV_UNROLL(GEMV_INIT_MMA, N) \
279 Index j = j2; \
280 __vector_pair b0, b1, b2, b3, b4, b5, b6, b7; \
281 do { \
282 LhsPacket g0, g1, g2, g3, g4, g5, g6, g7; \
283 RhsPacket a0 = pset1<RhsPacket>(rhs2(j, 0)); \
284 GEMV_UNROLL(GEMV_PREFETCH, N) \
285 GEMV_LOAD_COL_MMA(N) \
286 GEMV_WORK_COL_MMA(N) \
287 } while (++j < jend); \
288 GEMV_UNROLL(GEMV_DISASSEMBLE_MMA, N) \
289 if (GEMV_GETN(N) <= 1) { \
290 GEMV_UNROLL(GEMV_STORE_COL_MMA, N) \
291 } else { \
292 GEMV_UNROLL_HALF(GEMV_LOAD2_COL_MMA, (N >> 1)) \
293 GEMV_UNROLL_HALF(GEMV_WORK2_COL_MMA, (N >> 1)) \
294 GEMV_UNROLL_HALF(GEMV_STORE2_COL_MMA, (N >> 1)) \
295 } \
296 i += (ResPacketSize * N);
297#endif
298
299#define GEMV_INIT(iter, N) \
300 if (N > iter) { \
301 c##iter = pset1<ResPacket>(ResScalar(0)); \
302 } else { \
303 EIGEN_UNUSED_VARIABLE(c##iter); \
304 }
305
306#ifdef EIGEN_POWER_USE_GEMV_PREFETCH
307#define GEMV_PREFETCH(iter, N) \
308 if (GEMV_GETN(N) > ((iter >> 1) + ((N >> 1) * (iter & 1)))) { \
309 lhs.prefetch(i + (iter * LhsPacketSize) + prefetch_dist, j); \
310 }
311#else
312#define GEMV_PREFETCH(iter, N)
313#endif
314
315#define GEMV_WORK_COL(iter, N) \
316 if (N > iter) { \
317 c##iter = pcj.pmadd(GEMV_LOADPACKET_COL(iter), a0, c##iter); \
318 }
319
320#define GEMV_STORE_COL(iter, N) \
321 if (N > iter) { \
322 pstoreu(res + i + (iter * ResPacketSize), pmadd(c##iter, palpha, ploadu<ResPacket>(res + i + (iter * ResPacketSize)))); \
323 }
324
326#define GEMV_PROCESS_COL_ONE(N) \
327 GEMV_UNROLL(GEMV_INIT, N) \
328 Index j = j2; \
329 do { \
330 RhsPacket a0 = pset1<RhsPacket>(rhs2(j, 0)); \
331 GEMV_UNROLL(GEMV_PREFETCH, N) \
332 GEMV_UNROLL(GEMV_WORK_COL, N) \
333 } while (++j < jend); \
334 GEMV_UNROLL(GEMV_STORE_COL, N) \
335 i += (ResPacketSize * N);
336
337#ifdef USE_GEMV_MMA
338#define GEMV_PROCESS_COL(N) \
339 GEMV_PROCESS_COL_ONE_MMA(N)
340#else
341#define GEMV_PROCESS_COL(N) \
342 GEMV_PROCESS_COL_ONE(N)
343#endif
344
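345// MMA rank-1 update helpers: multiply lhs and rhs packets into a __vector_quad accumulator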
346#ifdef USE_GEMV_MMA
347template<typename LhsPacket, typename RhsPacket, bool accumulate>
348EIGEN_ALWAYS_INLINE void pger_vecMMA_acc(__vector_quad* acc, const RhsPacket& a, const LhsPacket& b)
349{
350 if (accumulate)
351 {
352 __builtin_mma_xvf32gerpp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
353 }
354 else
355 {
356 __builtin_mma_xvf32ger(acc, (__vector unsigned char)a, (__vector unsigned char)b);
357 }
358}
359
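360// Variant taking a __vector_pair of doubles (xvf64ger / xvf64gerpp)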
361template<typename LhsPacket, typename RhsPacket, bool accumulate>
362EIGEN_ALWAYS_INLINE void pger_vecMMA_acc(__vector_quad* acc, __vector_pair& a, const LhsPacket& b)
363{
364 if (accumulate)
365 {
366 __builtin_mma_xvf64gerpp(acc, a, (__vector unsigned char)b);
367 }
368 else
369 {
370 __builtin_mma_xvf64ger(acc, a, (__vector unsigned char)b);
371 }
372}
373#endif
374
375template<typename LhsScalar, typename LhsMapper, typename RhsScalar, typename RhsMapper, typename ResScalar>
376EIGEN_STRONG_INLINE void gemv_col(
377 Index rows, Index cols,
378 const LhsMapper& alhs,
379 const RhsMapper& rhs,
380 ResScalar* res, Index resIncr,
381 ResScalar alpha)
382{
383 typedef gemv_traits<LhsScalar, RhsScalar> Traits;
384
385 typedef typename Traits::LhsPacket LhsPacket;
386 typedef typename Traits::RhsPacket RhsPacket;
387 typedef typename Traits::ResPacket ResPacket;
388
389 EIGEN_UNUSED_VARIABLE(resIncr);
390 eigen_internal_assert(resIncr == 1);
391
392 // The following copy tells the compiler that lhs's attributes are not modified outside this function
393 // This helps GCC to generate proper code.
394 LhsMapper lhs(alhs);
395 RhsMapper rhs2(rhs);
396
397 conj_helper<LhsScalar, RhsScalar, false, false> cj;
398 conj_helper<LhsPacket, RhsPacket, false, false> pcj;
399
400 const Index lhsStride = lhs.stride();
401 // TODO: for padded aligned inputs, we could enable aligned reads
402 enum {
403 LhsAlignment = Unaligned,
404 ResPacketSize = Traits::ResPacketSize,
405 LhsPacketSize = Traits::LhsPacketSize,
406 RhsPacketSize = Traits::RhsPacketSize,
407 };
408
409#ifndef GCC_ONE_VECTORPAIR_BUG
410 const Index n8 = rows - 8 * ResPacketSize + 1;
411 const Index n4 = rows - 4 * ResPacketSize + 1;
412 const Index n2 = rows - 2 * ResPacketSize + 1;
413#endif
414 const Index n1 = rows - 1 * ResPacketSize + 1;
415#ifdef EIGEN_POWER_USE_GEMV_PREFETCH
416 const Index prefetch_dist = 64 * LhsPacketSize;
417#endif
418
419 // TODO: improve the following heuristic:
420 const Index block_cols = cols < 128 ? cols : (lhsStride * sizeof(LhsScalar) < 16000 ? 16 : 8);
421 ResPacket palpha = pset1<ResPacket>(alpha);
422
423 for (Index j2 = 0; j2 < cols; j2 += block_cols)
424 {
425 Index jend = numext::mini(j2 + block_cols, cols);
426 Index i = 0;
427 ResPacket c0, c1, c2, c3, c4, c5, c6, c7;
428#ifdef USE_GEMV_MMA
429 __vector_quad e0, e1, e2, e3, e4, e5, e6, e7;
430 PacketBlock<ResPacket, 4> result0, result1, result2, result3, result4, result5, result6, result7;
431 GEMV_UNUSED(8, e)
432 GEMV_UNUSED(8, result)
433 GEMV_UNUSED_EXTRA(1, c)
434#endif
435#ifndef GCC_ONE_VECTORPAIR_BUG
436 while (i < n8)
437 {
438 GEMV_PROCESS_COL(8)
439 }
440 if (i < n4)
441 {
442 GEMV_PROCESS_COL(4)
443 }
444 if (i < n2)
445 {
446 GEMV_PROCESS_COL(2)
447 }
448 if (i < n1)
449#else
450 while (i < n1)
451#endif
452 {
453 GEMV_PROCESS_COL_ONE(1)
454 }
455 for (;i < rows;++i)
456 {
457 ResScalar d0(0);
458 Index j = j2;
459 do {
460 d0 += cj.pmul(lhs(i, j), rhs2(j, 0));
461 } while (++j < jend);
462 res[i] += alpha * d0;
463 }
464 }
465}
466
467const Packet16uc p16uc_COMPLEX32_XORFLIP = { 0x44,0x55,0x66,0x77, 0x00,0x11,0x22,0x33, 0xcc,0xdd,0xee,0xff, 0x88,0x99,0xaa,0xbb };
468const Packet16uc p16uc_COMPLEX64_XORFLIP = { 0x88,0x99,0xaa,0xbb, 0xcc,0xdd,0xee,0xff, 0x00,0x11,0x22,0x33, 0x44,0x55,0x66,0x77 };
469
470#ifdef _BIG_ENDIAN
471const Packet16uc p16uc_COMPLEX32_CONJ_XOR = { 0x00,0x00,0x00,0x00, 0x80,0x00,0x00,0x00, 0x00,0x00,0x00,0x00, 0x80,0x00,0x00,0x00 };
472const Packet16uc p16uc_COMPLEX64_CONJ_XOR = { 0x00,0x00,0x00,0x00, 0x00,0x00,0x00,0x00, 0x80,0x00,0x00,0x00, 0x00,0x00,0x00,0x00 };
473const Packet16uc p16uc_COMPLEX32_CONJ_XOR2 = { 0x80,0x00,0x00,0x00, 0x00,0x00,0x00,0x00, 0x80,0x00,0x00,0x00, 0x00,0x00,0x00,0x00 };
474const Packet16uc p16uc_COMPLEX64_CONJ_XOR2 = { 0x80,0x00,0x00,0x00, 0x00,0x00,0x00,0x00, 0x00,0x00,0x00,0x00, 0x00,0x00,0x00,0x00 };
475const Packet16uc p16uc_COMPLEX32_NEGATE = { 0x80,0x00,0x00,0x00, 0x80,0x00,0x00,0x00, 0x80,0x00,0x00,0x00, 0x80,0x00,0x00,0x00 };
476const Packet16uc p16uc_COMPLEX64_NEGATE = { 0x80,0x00,0x00,0x00, 0x00,0x00,0x00,0x00, 0x80,0x00,0x00,0x00, 0x00,0x00,0x00,0x00 };
477#else
478const Packet16uc p16uc_COMPLEX32_CONJ_XOR = { 0x00,0x00,0x00,0x00, 0x00,0x00,0x00,0x80, 0x00,0x00,0x00,0x00, 0x00,0x00,0x00,0x80 };
479const Packet16uc p16uc_COMPLEX64_CONJ_XOR = { 0x00,0x00,0x00,0x00, 0x00,0x00,0x00,0x00, 0x00,0x00,0x00,0x00, 0x00,0x00,0x00,0x80 };
480const Packet16uc p16uc_COMPLEX32_CONJ_XOR2 = { 0x00,0x00,0x00,0x80, 0x00,0x00,0x00,0x00, 0x00,0x00,0x00,0x80, 0x00,0x00,0x00,0x00 };
481const Packet16uc p16uc_COMPLEX64_CONJ_XOR2 = { 0x00,0x00,0x00,0x00, 0x00,0x00,0x00,0x80, 0x00,0x00,0x00,0x00, 0x00,0x00,0x00,0x00 };
482const Packet16uc p16uc_COMPLEX32_NEGATE = { 0x00,0x00,0x00,0x80, 0x00,0x00,0x00,0x80, 0x00,0x00,0x00,0x80, 0x00,0x00,0x00,0x80 };
483const Packet16uc p16uc_COMPLEX64_NEGATE = { 0x00,0x00,0x00,0x00, 0x00,0x00,0x00,0x80, 0x00,0x00,0x00,0x00, 0x00,0x00,0x00,0x80 };
484#endif
485
486#ifdef _BIG_ENDIAN
487#define COMPLEX_DELTA 0
488#else
489#define COMPLEX_DELTA 2
490#endif
491
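492// Conjugate a complex packet by flipping the sign bit of the imaginary parts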
493EIGEN_ALWAYS_INLINE Packet2cf pconj2(const Packet2cf& a) {
494 return Packet2cf(pxor(a.v, reinterpret_cast<Packet4f>(p16uc_COMPLEX32_CONJ_XOR)));
495}
496
497EIGEN_ALWAYS_INLINE Packet1cd pconj2(const Packet1cd& a) {
498 return Packet1cd(pxor(a.v, reinterpret_cast<Packet2d>(p16uc_COMPLEX64_CONJ_XOR)));
499}
500
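501// Complementary sign flip to pconj2, used when recombining MMA partial results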
502EIGEN_ALWAYS_INLINE Packet2cf pconjinv(const Packet2cf& a) {
503#ifdef __POWER8_VECTOR__
504 return Packet2cf(Packet4f(vec_neg(Packet2d(a.v))));
505#else
506 return Packet2cf(pxor(a.v, reinterpret_cast<Packet4f>(p16uc_COMPLEX32_CONJ_XOR2)));
507#endif
508}
509
510EIGEN_ALWAYS_INLINE Packet1cd pconjinv(const Packet1cd& a) {
511 return Packet1cd(pxor(a.v, reinterpret_cast<Packet2d>(p16uc_COMPLEX64_CONJ_XOR2)));
512}
513
514#if defined(_ARCH_PWR8) && (!EIGEN_COMP_LLVM || __clang_major__ >= 12)
515#define PERMXOR_GOOD // Clang had a bug with vec_permxor and endianness prior to version 12
516#endif
517
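518// pcplxflip(pconj2(a)), fused into a single vec_permxor where available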
519EIGEN_ALWAYS_INLINE Packet2cf pcplxflipconj(Packet2cf a)
520{
521#ifdef PERMXOR_GOOD
522 return Packet2cf(Packet4f(vec_permxor(Packet16uc(a.v), p16uc_COMPLEX32_CONJ_XOR, p16uc_COMPLEX32_XORFLIP)));
523#else
524 return pcplxflip(pconj2(a));
525#endif
526}
527
528EIGEN_ALWAYS_INLINE Packet1cd pcplxflipconj(Packet1cd a)
529{
530#ifdef PERMXOR_GOOD
531 return Packet1cd(Packet2d(vec_permxor(Packet16uc(a.v), p16uc_COMPLEX64_CONJ_XOR, p16uc_COMPLEX64_XORFLIP)));
532#else
533 return pcplxflip(pconj2(a));
534#endif
535}
536
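537// pconj2(pcplxflip(a)), fused into a single vec_permxor where available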
538EIGEN_ALWAYS_INLINE Packet2cf pcplxconjflip(Packet2cf a)
539{
540#ifdef PERMXOR_GOOD
541 return Packet2cf(Packet4f(vec_permxor(Packet16uc(a.v), p16uc_COMPLEX32_CONJ_XOR2, p16uc_COMPLEX32_XORFLIP)));
542#else
543 return pconj2(pcplxflip(a));
544#endif
545}
546
547EIGEN_ALWAYS_INLINE Packet1cd pcplxconjflip(Packet1cd a)
548{
549#ifdef PERMXOR_GOOD
550 return Packet1cd(Packet2d(vec_permxor(Packet16uc(a.v), p16uc_COMPLEX64_CONJ_XOR2, p16uc_COMPLEX64_XORFLIP)));
551#else
552 return pconj2(pcplxflip(a));
553#endif
554}
555
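556// Negate both the real and imaginary parts of a complex packet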
557EIGEN_ALWAYS_INLINE Packet2cf pnegate2(Packet2cf a)
558{
559#ifdef __POWER8_VECTOR__
560 return Packet2cf(vec_neg(a.v));
561#else
562 return Packet2cf(pxor(a.v, reinterpret_cast<Packet4f>(p16uc_COMPLEX32_NEGATE)));
563#endif
564}
565
566EIGEN_ALWAYS_INLINE Packet1cd pnegate2(Packet1cd a)
567{
568#ifdef __POWER8_VECTOR__
569 return Packet1cd(vec_neg(a.v));
570#else
571 return Packet1cd(pxor(a.v, reinterpret_cast<Packet2d>(p16uc_COMPLEX64_NEGATE)));
572#endif
573}
574
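575// pcplxflip(pnegate2(a)), fused into a single vec_permxor where available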
576EIGEN_ALWAYS_INLINE Packet2cf pcplxflipnegate(Packet2cf a)
577{
578#ifdef PERMXOR_GOOD
579 return Packet2cf(Packet4f(vec_permxor(Packet16uc(a.v), p16uc_COMPLEX32_NEGATE, p16uc_COMPLEX32_XORFLIP)));
580#else
581 return pcplxflip(pnegate2(a));
582#endif
583}
584
585EIGEN_ALWAYS_INLINE Packet1cd pcplxflipnegate(Packet1cd a)
586{
587#ifdef PERMXOR_GOOD
588 return Packet1cd(Packet2d(vec_permxor(Packet16uc(a.v), p16uc_COMPLEX64_NEGATE, p16uc_COMPLEX64_XORFLIP)));
589#else
590 return pcplxflip(pnegate2(a));
591#endif
592}
593
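594// Swap the real and imaginary parts of each complex element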
595EIGEN_ALWAYS_INLINE Packet2cf pcplxflip2(Packet2cf a)
596{
597 return Packet2cf(Packet4f(vec_perm(Packet16uc(a.v), Packet16uc(a.v), p16uc_COMPLEX32_XORFLIP)));
598}
599
600EIGEN_ALWAYS_INLINE Packet1cd pcplxflip2(Packet1cd a)
601{
602#ifdef EIGEN_VECTORIZE_VSX
603 return Packet1cd(__builtin_vsx_xxpermdi(a.v, a.v, 2));
604#else
605 return Packet1cd(Packet2d(vec_perm(Packet16uc(a.v), Packet16uc(a.v), p16uc_COMPLEX64_XORFLIP)));
606#endif
607}
608
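609// Load a single std::complex<float> (64 bits) into a Packet4f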
610EIGEN_ALWAYS_INLINE Packet4f pload_complex_half(std::complex<float>* src)
611{
612 Packet4f t;
613#ifdef EIGEN_VECTORIZE_VSX
614 // Load float64/two float32 (doubleword alignment)
615 __asm__("lxsdx %x0,%y1" : "=wa" (t) : "Z" (*src));
616#else
617 *reinterpret_cast<std::complex<float>*>(reinterpret_cast<float*>(&t) + COMPLEX_DELTA) = *src;
618#endif
619 return t;
620}
621
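622// Broadcast the real part of *src into r and the imaginary part into i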
623template<typename RhsScalar>
624EIGEN_ALWAYS_INLINE void pload_realimag(RhsScalar* src, Packet4f& r, Packet4f& i)
625{
626#ifdef _ARCH_PWR9
627 __asm__("lxvwsx %x0,%y1" : "=wa" (r) : "Z" (*(reinterpret_cast<float*>(src) + 0)));
628 __asm__("lxvwsx %x0,%y1" : "=wa" (i) : "Z" (*(reinterpret_cast<float*>(src) + 1)));
629#else
630 Packet4f t = pload_complex_half(src);
631 r = vec_splat(t, COMPLEX_DELTA + 0);
632 i = vec_splat(t, COMPLEX_DELTA + 1);
633#endif
634}
635
636template<typename RhsScalar>
637EIGEN_ALWAYS_INLINE void pload_realimag(RhsScalar* src, Packet2d& r, Packet2d& i)
638{
639#ifdef EIGEN_VECTORIZE_VSX
640 __asm__("lxvdsx %x0,%y1" : "=wa" (r) : "Z" (*(reinterpret_cast<double*>(src) + 0)));
641 __asm__("lxvdsx %x0,%y1" : "=wa" (i) : "Z" (*(reinterpret_cast<double*>(src) + 1)));
642#else
643 Packet2d t = ploadu<Packet2d>(reinterpret_cast<double*>(src));
644 r = vec_splat(t, 0);
645 i = vec_splat(t, 1);
646#endif
647}
648
649#ifndef __POWER8_VECTOR__
650const Packet16uc p16uc_MERGEE = { 0x00, 0x01, 0x02, 0x03, 0x10, 0x11, 0x12, 0x13, 0x08, 0x09, 0x0A, 0x0B, 0x18, 0x19, 0x1A, 0x1B };
651
652const Packet16uc p16uc_MERGEO = { 0x04, 0x05, 0x06, 0x07, 0x14, 0x15, 0x16, 0x17, 0x0C, 0x0D, 0x0E, 0x0F, 0x1C, 0x1D, 0x1E, 0x1F };
653#endif
654
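655// Row-major variant: deinterleave the loaded values into real (even) and imaginary (odd) lanes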
656template<typename RhsScalar>
657EIGEN_ALWAYS_INLINE void pload_realimag_row(RhsScalar* src, Packet4f& r, Packet4f& i)
658{
659 Packet4f t = ploadu<Packet4f>(reinterpret_cast<float*>(src));
660#ifdef __POWER8_VECTOR__
661 r = vec_mergee(t, t);
662 i = vec_mergeo(t, t);
663#else
664 r = vec_perm(t, t, p16uc_MERGEE);
665 i = vec_perm(t, t, p16uc_MERGEO);
666#endif
667}
668
669template<typename RhsScalar>
670EIGEN_ALWAYS_INLINE void pload_realimag_row(RhsScalar* src, Packet2d& r, Packet2d& i)
671{
672 return pload_realimag(src, r, i);
673}
674
676EIGEN_ALWAYS_INLINE Packet4f pload_realimag_combine(std::complex<float>* src)
677{
678#ifdef EIGEN_VECTORIZE_VSX
679 Packet4f ret;
680 __asm__("lxvdsx %x0,%y1" : "=wa" (ret) : "Z" (*(reinterpret_cast<double*>(src) + 0)));
681 return ret;
682#else
683 return Packet4f(ploaddup<Packet2d>(reinterpret_cast<double *>(src)));
684#endif
685}
686
687EIGEN_ALWAYS_INLINE Packet2d pload_realimag_combine(std::complex<double>* src)
688{
689 return ploadu<Packet1cd>(src).v;
690}
691
693EIGEN_ALWAYS_INLINE Packet4f pload_realimag_combine_row(std::complex<float>* src)
694{
695 return ploadu<Packet2cf>(src).v;
696}
697
698EIGEN_ALWAYS_INLINE Packet2d pload_realimag_combine_row(std::complex<double>* src)
699{
700 return ploadu<Packet1cd>(src).v;
701}
702
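703// Load complex data as a raw real packet (only half a packet when the result is scalar)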
704template<typename ResPacket>
705EIGEN_ALWAYS_INLINE Packet4f pload_complex(std::complex<float>* src)
706{
707 if (GEMV_IS_SCALAR) {
708 return pload_complex_half(src);
709 }
710 else
711 {
712 return ploadu<Packet4f>(reinterpret_cast<float*>(src));
713 }
714}
715
716template<typename ResPacket>
717EIGEN_ALWAYS_INLINE Packet2d pload_complex(std::complex<double>* src)
718{
719 return ploadu<Packet2d>(reinterpret_cast<double*>(src));
720}
721
723template<typename ResPacket>
724EIGEN_ALWAYS_INLINE Packet4f pload_complex(Packet2cf* src)
725{
726 return src->v;
727}
728
729template<typename ResPacket>
730EIGEN_ALWAYS_INLINE Packet2d pload_complex(Packet1cd* src)
731{
732 return src->v;
733}
734
736EIGEN_ALWAYS_INLINE Packet4f pload_complex_full(std::complex<float>* src)
737{
738 return Packet4f(ploaddup<Packet2d>(reinterpret_cast<double *>(src)));
739}
740
741EIGEN_ALWAYS_INLINE Packet2d pload_complex_full(std::complex<double>* src)
742{
743 return ploadu<Packet1cd>(src).v;
744}
745
747EIGEN_ALWAYS_INLINE Packet4f pload_complex_full_row(std::complex<float>* src)
748{
749 return ploadu<Packet2cf>(src).v;
750}
751
752EIGEN_ALWAYS_INLINE Packet2d pload_complex_full_row(std::complex<double>* src)
753{
754 return pload_complex_full(src);
755}
756
758EIGEN_ALWAYS_INLINE Packet4f pload_real(float* src)
759{
760 return pset1<Packet4f>(*src);
761}
762
763EIGEN_ALWAYS_INLINE Packet2d pload_real(double* src)
764{
765 return pset1<Packet2d>(*src);
766}
767
768EIGEN_ALWAYS_INLINE Packet4f pload_real(Packet4f& src)
769{
770 return src;
771}
772
773EIGEN_ALWAYS_INLINE Packet2d pload_real(Packet2d& src)
774{
775 return src;
776}
777
779EIGEN_ALWAYS_INLINE Packet4f pload_real_full(float* src)
780{
781 Packet4f ret = ploadu<Packet4f>(src);
782 return vec_mergeh(ret, ret);
783}
784
785EIGEN_ALWAYS_INLINE Packet2d pload_real_full(double* src)
786{
787 return pload_real(src);
788}
789
790EIGEN_ALWAYS_INLINE Packet4f pload_real_full(std::complex<float>* src)
791{
792 return pload_complex_full(src); // Just for compilation
793}
794
795EIGEN_ALWAYS_INLINE Packet2d pload_real_full(std::complex<double>* src)
796{
797 return pload_complex_full(src); // Just for compilation
798}
799
801template<typename ResPacket>
802EIGEN_ALWAYS_INLINE Packet4f pload_real_row(float* src)
803{
804 if (GEMV_IS_SCALAR) {
805 return pload_real_full(src);
806 }
807 else {
808 return ploadu<Packet4f>(src);
809 }
810}
811
812template<typename ResPacket>
813EIGEN_ALWAYS_INLINE Packet2d pload_real_row(double* src)
814{
815 return pload_real(src);
816}
817
818EIGEN_ALWAYS_INLINE Packet2cf padd(Packet2cf& a, std::complex<float>& b)
819{
820 EIGEN_UNUSED_VARIABLE(b);
821 return a; // Just for compilation
822}
823
824EIGEN_ALWAYS_INLINE Packet1cd padd(Packet1cd& a, std::complex<double>& b)
825{
826 EIGEN_UNUSED_VARIABLE(b);
827 return a; // Just for compilation
828}
829
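830// Select the (optionally negated) real or imaginary part of alpha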
831template<typename Scalar, typename ResScalar>
832EIGEN_ALWAYS_INLINE Scalar pset1_realimag(ResScalar& alpha, int which, int conj)
833{
834 return (which) ? ((conj) ? -alpha.real() : alpha.real()) : ((conj) ? -alpha.imag() : alpha.imag());
835}
836
838template<typename Scalar, typename ResScalar, typename ResPacket, int which>
839EIGEN_ALWAYS_INLINE Packet2cf pset1_complex(std::complex<float>& alpha)
840{
841 Packet2cf ret;
842 ret.v[COMPLEX_DELTA + 0] = pset1_realimag<Scalar, ResScalar>(alpha, (which & 0x01), (which & 0x04));
843 ret.v[COMPLEX_DELTA + 1] = pset1_realimag<Scalar, ResScalar>(alpha, (which & 0x02), (which & 0x08));
844 ret.v[2 - COMPLEX_DELTA] = ret.v[COMPLEX_DELTA + 0];
845 ret.v[3 - COMPLEX_DELTA] = ret.v[COMPLEX_DELTA + 1];
846 return ret;
847}
848
849template<typename Scalar, typename ResScalar, typename ResPacket, int which>
850EIGEN_ALWAYS_INLINE Packet1cd pset1_complex(std::complex<double>& alpha)
851{
852 Packet1cd ret;
853 ret.v[0] = pset1_realimag<Scalar, ResScalar>(alpha, (which & 0x01), (which & 0x04));
854 ret.v[1] = pset1_realimag<Scalar, ResScalar>(alpha, (which & 0x02), (which & 0x08));
855 return ret;
856}
857
859template<typename Packet>
860EIGEN_ALWAYS_INLINE Packet pset_zero()
861{
862 return pset1<Packet>(__UNPACK_TYPE__(Packet)(0));
863}
864
865template<>
866EIGEN_ALWAYS_INLINE Packet2cf pset_zero<Packet2cf>()
867{
868 return Packet2cf(pset1<Packet4f>(float(0)));
869}
870
871template<>
872EIGEN_ALWAYS_INLINE Packet1cd pset_zero<Packet1cd>()
873{
874 return Packet1cd(pset1<Packet2d>(double(0)));
875}
876
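877// Zero-initialize the extra accumulator only for complex-complex GEMV; otherwise it stays unused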
878template<typename Packet, typename LhsPacket, typename RhsPacket>
879EIGEN_ALWAYS_INLINE Packet pset_init(Packet& c1)
880{
881 if (GEMV_IS_COMPLEX_COMPLEX) {
882 EIGEN_UNUSED_VARIABLE(c1);
883 return pset_zero<Packet>();
884 }
885 else
886 {
887 return c1; // Intentionally left uninitialized
888 }
889}
890
891template<typename PResPacket, typename ResPacket, typename ResScalar, typename Scalar>
892struct alpha_store
893{
894 alpha_store(ResScalar& alpha) {
895 separate.r = pset1_complex<Scalar, ResScalar, ResPacket, 0x3>(alpha);
896 separate.i = pset1_complex<Scalar, ResScalar, ResPacket, 0x0>(alpha);
897 }
898 struct ri {
899 PResPacket r;
900 PResPacket i;
901 } separate;
902};
903
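904// Return c4 + c0 * alpha_r + c2 * alpha_i using the separated alpha packets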
905template<typename ScalarPacket, typename AlphaData>
906EIGEN_ALWAYS_INLINE ScalarPacket pmadd_complex(ScalarPacket& c0, ScalarPacket& c2, ScalarPacket& c4, AlphaData& b0)
907{
908 return pmadd(c2, b0.separate.i.v, pmadd(c0, b0.separate.r.v, c4));
909}
910
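911// Scale the accumulated result by alpha and add it into res (unaligned store)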
912template<typename Scalar, typename ScalarPacket, typename PResPacket, typename ResPacket, typename ResScalar, typename AlphaData>
913EIGEN_ALWAYS_INLINE void pstoreu_pmadd_complex(PResPacket& c0, AlphaData& b0, ResScalar* res)
914{
915 PResPacket c2 = pcplxflipconj(c0);
916 if (GEMV_IS_SCALAR) {
917 ScalarPacket c4 = ploadu<ScalarPacket>(reinterpret_cast<Scalar*>(res));
918 ScalarPacket c3 = pmadd_complex<ScalarPacket, AlphaData>(c0.v, c2.v, c4, b0);
919 pstoreu(reinterpret_cast<Scalar*>(res), c3);
920 } else {
921 ScalarPacket c4 = pload_complex<ResPacket>(res);
922 PResPacket c3 = PResPacket(pmadd_complex<ScalarPacket, AlphaData>(c0.v, c2.v, c4, b0));
923 pstoreu(res, c3);
924 }
925}
926
927template<typename ScalarPacket, typename PResPacket, typename ResPacket, typename ResScalar, typename AlphaData, Index ResPacketSize, Index iter2>
928EIGEN_ALWAYS_INLINE void pstoreu_pmadd_complex(PResPacket& c0, PResPacket& c1, AlphaData& b0, ResScalar* res)
929{
930 PResPacket c2 = pcplxflipconj(c0);
931 PResPacket c3 = pcplxflipconj(c1);
932#if !defined(_ARCH_PWR10)
933 ScalarPacket c4 = pload_complex<ResPacket>(res + (iter2 * ResPacketSize));
934 ScalarPacket c5 = pload_complex<ResPacket>(res + ((iter2 + 1) * ResPacketSize));
935 PResPacket c6 = PResPacket(pmadd_complex<ScalarPacket, AlphaData>(c0.v, c2.v, c4, b0));
936 PResPacket c7 = PResPacket(pmadd_complex<ScalarPacket, AlphaData>(c1.v, c3.v, c5, b0));
937 pstoreu(res + (iter2 * ResPacketSize), c6);
938 pstoreu(res + ((iter2 + 1) * ResPacketSize), c7);
939#else
940 __vector_pair a = *reinterpret_cast<__vector_pair *>(res + (iter2 * ResPacketSize));
941#if EIGEN_COMP_LLVM
942 PResPacket c6[2];
943 __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(c6), &a);
944 c6[0] = PResPacket(pmadd_complex<ScalarPacket, AlphaData>(c0.v, c2.v, c6[0].v, b0));
945 c6[1] = PResPacket(pmadd_complex<ScalarPacket, AlphaData>(c1.v, c3.v, c6[1].v, b0));
946 GEMV_BUILDPAIR_MMA(a, c6[0].v, c6[1].v);
947#else
948 if (GEMV_IS_COMPLEX_FLOAT) {
949 __asm__ ("xvmaddasp %L0,%x1,%x2\n\txvmaddasp %0,%x1,%x3" : "+&d" (a) : "wa" (b0.separate.r.v), "wa" (c0.v), "wa" (c1.v));
950 __asm__ ("xvmaddasp %L0,%x1,%x2\n\txvmaddasp %0,%x1,%x3" : "+&d" (a) : "wa" (b0.separate.i.v), "wa" (c2.v), "wa" (c3.v));
951 } else {
952 __asm__ ("xvmaddadp %L0,%x1,%x2\n\txvmaddadp %0,%x1,%x3" : "+&d" (a) : "wa" (b0.separate.r.v), "wa" (c0.v), "wa" (c1.v));
953 __asm__ ("xvmaddadp %L0,%x1,%x2\n\txvmaddadp %0,%x1,%x3" : "+&d" (a) : "wa" (b0.separate.i.v), "wa" (c2.v), "wa" (c3.v));
954 }
955#endif
956 *reinterpret_cast<__vector_pair *>(res + (iter2 * ResPacketSize)) = a;
957#endif
958}
959
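960// Load an lhs packet; real lhs values are duplicated (pload_real_full) to match the complex layout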
961template<typename Scalar, typename LhsScalar, typename LhsMapper, typename LhsPacket>
962EIGEN_ALWAYS_INLINE LhsPacket loadLhsPacket(LhsMapper& lhs, Index i, Index j)
963{
964 if (sizeof(Scalar) == sizeof(LhsScalar)) {
965 const LhsScalar& src = lhs(i + 0, j);
966 return LhsPacket(pload_real_full(const_cast<LhsScalar*>(&src)));
967 }
968 return lhs.template load<LhsPacket, Unaligned>(i + 0, j);
969}
970
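971// Multiply-add for complex*complex, selecting conjugation/negation variants at compile time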
972template<typename ComplexPacket, typename RealPacket, bool ConjugateLhs, bool ConjugateRhs, bool Negate>
973EIGEN_ALWAYS_INLINE RealPacket pmadd_complex_complex(RealPacket& a, RealPacket& b, RealPacket& c)
974{
975 if (ConjugateLhs && ConjugateRhs) {
976 return vec_madd(a, pconj2(ComplexPacket(b)).v, c);
977 }
978 else if (Negate && !ConjugateLhs && ConjugateRhs) {
979 return vec_nmsub(a, b, c);
980 }
981 else {
982 return vec_madd(a, b, c);
983 }
984}
985
987template<typename ComplexPacket, typename RealPacket, bool Conjugate>
988EIGEN_ALWAYS_INLINE RealPacket pmadd_complex_real(RealPacket& a, RealPacket& b, RealPacket& c)
989{
990 if (Conjugate) {
991 return vec_madd(a, pconj2(ComplexPacket(b)).v, c);
992 }
993 else {
994 return vec_madd(a, b, c);
995 }
996}
997
998template<typename LhsPacket, typename RhsScalar, typename RhsPacket, typename PResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder>
999EIGEN_ALWAYS_INLINE void gemv_mult_generic(LhsPacket& a0, RhsScalar* b, PResPacket& c0)
1000{
1001 conj_helper<LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs> pcj;
1002 RhsPacket b0;
1003 if (StorageOrder == ColMajor) {
1004 b0 = pset1<RhsPacket>(*b);
1005 }
1006 else {
1007 b0 = ploadu<RhsPacket>(b);
1008 }
1009 c0 = pcj.pmadd(a0, b0, c0);
1010}
1011
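1012// Complex*complex GEMV step: accumulate a0*real(b) into c0 and pcplxflipconj(a0)*imag(b) into c1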
1013template<typename ScalarPacket, typename LhsPacket, typename RhsScalar, typename RhsPacket, typename PResPacket, typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder>
1014EIGEN_ALWAYS_INLINE void gemv_mult_complex_complex(LhsPacket& a0, RhsScalar* b, PResPacket& c0, ResPacket& c1)
1015{
1016 ScalarPacket br, bi;
1017 if (StorageOrder == ColMajor) {
1018 pload_realimag<RhsScalar>(b, br, bi);
1019 }
1020 else {
1021 pload_realimag_row<RhsScalar>(b, br, bi);
1022 }
1023 if (ConjugateLhs && !ConjugateRhs) a0 = pconj2(a0);
1024 LhsPacket a1 = pcplxflipconj(a0);
1025 ScalarPacket cr = pmadd_complex_complex<LhsPacket, ScalarPacket, ConjugateLhs, ConjugateRhs, false>(a0.v, br, c0.v);
1026 ScalarPacket ci = pmadd_complex_complex<LhsPacket, ScalarPacket, ConjugateLhs, ConjugateRhs, true>(a1.v, bi, c1.v);
1027 c1 = ResPacket(ci);
1028 c0 = PResPacket(cr);
1029}
1030
1032template<typename ScalarPacket, typename LhsPacket, typename RhsScalar, typename RhsPacket, typename PResPacket, typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder>
1033EIGEN_ALWAYS_INLINE void gemv_mult_real_complex(LhsPacket& a0, RhsScalar* b, PResPacket& c0)
1034{
1035 ScalarPacket b0;
1036 if (StorageOrder == ColMajor) {
1037 b0 = pload_complex_full(b);
1038 }
1039 else {
1040 b0 = pload_complex_full_row(b);
1041 }
1042 ScalarPacket cri = pmadd_complex_real<PResPacket, ScalarPacket, ConjugateRhs>(a0, b0, c0.v);
1043 c0 = PResPacket(cri);
1044}
1045
1047template<typename ScalarPacket, typename LhsPacket, typename RhsScalar, typename RhsPacket, typename PResPacket, typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder>
1048EIGEN_ALWAYS_INLINE void gemv_mult_complex_real(LhsPacket& a0, RhsScalar* b, PResPacket& c0)
1049{
1050 ScalarPacket a1 = pload_complex<ResPacket>(&a0);
1051 ScalarPacket b0;
1052 if (StorageOrder == ColMajor) {
1053 b0 = pload_real(b);
1054 }
1055 else {
1056 b0 = pload_real_row<ResPacket>(b);
1057 }
1058 ScalarPacket cri = pmadd_complex_real<PResPacket, ScalarPacket, ConjugateLhs>(a1, b0, c0.v);
1059 c0 = PResPacket(cri);
1060}
1061
1062#define GEMV_MULT_COMPLEX_COMPLEX(LhsType, RhsType, ResType) \
1063template<typename ScalarPacket, typename LhsPacket, typename RhsScalar, typename RhsPacket, typename PResPacket, typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder> \
1064EIGEN_ALWAYS_INLINE void gemv_mult_complex(LhsType& a0, RhsType* b, ResType& c0, ResType& c1) \
1065{ \
1066 gemv_mult_complex_complex<ScalarPacket, LhsPacket, RhsScalar, RhsPacket, PResPacket, ResPacket, ConjugateLhs, ConjugateRhs, StorageOrder>(a0, b, c0, c1); \
1067}
1068
1069GEMV_MULT_COMPLEX_COMPLEX(Packet2cf, std::complex<float>, Packet2cf)
1070GEMV_MULT_COMPLEX_COMPLEX(Packet1cd, std::complex<double>, Packet1cd)
1071
1072#define GEMV_MULT_REAL_COMPLEX(LhsType, RhsType, ResType) \
1073template<typename ScalarPacket, typename LhsPacket, typename RhsScalar, typename RhsPacket, typename PResPacket, typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder> \
1074EIGEN_ALWAYS_INLINE void gemv_mult_complex(LhsType& a0, RhsType* b, ResType& c0, RhsType&) \
1075{ \
1076 gemv_mult_real_complex<ScalarPacket, LhsPacket, RhsScalar, RhsPacket, PResPacket, ResPacket, ConjugateLhs, ConjugateRhs, StorageOrder>(a0, b, c0); \
1077}
1078
1079GEMV_MULT_REAL_COMPLEX(float, std::complex<float>, Packet2cf)
1080GEMV_MULT_REAL_COMPLEX(double, std::complex<double>, Packet1cd)
1081GEMV_MULT_REAL_COMPLEX(Packet4f, std::complex<float>, Packet2cf)
1082GEMV_MULT_REAL_COMPLEX(Packet2d, std::complex<double>, Packet1cd)
1083
1084#define GEMV_MULT_COMPLEX_REAL(LhsType, RhsType, ResType1, ResType2) \
1085template<typename ScalarPacket, typename LhsPacket, typename RhsScalar, typename RhsPacket, typename PResPacket, typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder> \
1086EIGEN_ALWAYS_INLINE void gemv_mult_complex(LhsType& a0, RhsType* b, ResType1& c0, ResType2&) \
1087{ \
1088 gemv_mult_complex_real<ScalarPacket, LhsPacket, RhsScalar, RhsPacket, PResPacket, ResPacket, ConjugateLhs, ConjugateRhs, StorageOrder>(a0, b, c0); \
1089}
1090
1091GEMV_MULT_COMPLEX_REAL(Packet2cf, float, Packet2cf, std::complex<float>)
1092GEMV_MULT_COMPLEX_REAL(Packet1cd, double, Packet1cd, std::complex<double>)
1093GEMV_MULT_COMPLEX_REAL(std::complex<float>, float, Packet2cf, std::complex<float>)
1094GEMV_MULT_COMPLEX_REAL(std::complex<double>, double, Packet1cd, std::complex<double>)
1095
1096#ifdef USE_GEMV_MMA
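1097// Reinterpret complex packets as their underlying real packets (identity for real packets)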
1098template<typename T>
1099EIGEN_ALWAYS_INLINE T convertReal(T a)
1100{
1101 return a;
1102}
1103
1104EIGEN_ALWAYS_INLINE Packet4f convertReal(Packet2cf a)
1105{
1106 return a.v;
1107}
1108
1109EIGEN_ALWAYS_INLINE Packet2d convertReal(Packet1cd a)
1110{
1111 return a.v;
1112}
1113
1115template<typename T>
1116EIGEN_ALWAYS_INLINE T convertComplex(T a)
1117{
1118 return a;
1119}
1120
1121EIGEN_ALWAYS_INLINE Packet2cf convertComplex(Packet4f a)
1122{
1123 return Packet2cf(a);
1124}
1125
1126EIGEN_ALWAYS_INLINE Packet1cd convertComplex(Packet2d a)
1127{
1128 return Packet1cd(a);
1129}
1130
1132template<typename ScalarPacket, typename LhsPacket, typename SLhsPacket, typename ResPacket>
1133EIGEN_ALWAYS_INLINE void pload_complex_MMA(SLhsPacket& a)
1134{
1135 a = SLhsPacket(pload_complex<ResPacket>(&a));
1136}
1137
1138template<typename ScalarPacket, typename LhsPacket, typename SLhsPacket, typename ResPacket>
1139EIGEN_ALWAYS_INLINE void pload_complex_MMA(__vector_pair&)
1140{
1141 // Pass thru
1142}
1143
1145template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
1146EIGEN_ALWAYS_INLINE void pger_vecMMA(__vector_quad* acc, RhsPacket& a, LhsPacket& b)
1147{
1148 if (NegativeAccumulate)
1149 {
1150 __builtin_mma_xvf32gernp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
1151 }
1152 else {
1153 __builtin_mma_xvf32gerpp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
1154 }
1155}
1156
1158template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
1159EIGEN_ALWAYS_INLINE void pger_vecMMA(__vector_quad* acc, __vector_pair& a, Packet2d& b)
1160{
1161 if (NegativeAccumulate)
1162 {
1163 __builtin_mma_xvf64gernp(acc, (__vector_pair)a, (__vector unsigned char)b);
1164 }
1165 else {
1166 __builtin_mma_xvf64gerpp(acc, (__vector_pair)a, (__vector unsigned char)b);
1167 }
1168}
1169
1170template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
1171EIGEN_ALWAYS_INLINE void pger_vecMMA(__vector_quad*, __vector_pair&, Packet4f&)
1172{
1173 // Just for compilation
1174}
1175
1177template<typename RealPacket, typename LhsPacket, bool ConjugateLhs, bool ConjugateRhs, bool Negate>
1178EIGEN_ALWAYS_INLINE void pmadd_complex_complex_MMA(LhsPacket& a, RealPacket& b, __vector_quad* c)
1179{
1180 if (ConjugateLhs && ConjugateRhs) {
1181 RealPacket b2 = pconj2(convertComplex(b)).v;
1182 return pger_vecMMA<RealPacket, RealPacket, false>(c, b2, a.v);
1183 }
1184 else if (Negate && !ConjugateLhs && ConjugateRhs) {
1185 return pger_vecMMA<RealPacket, RealPacket, true>(c, b, a.v);
1186 }
1187 else {
1188 return pger_vecMMA<RealPacket, RealPacket, false>(c, b, a.v);
1189 }
1190}
1191
1192template<typename RealPacket, typename LhsPacket, bool ConjugateLhs, bool ConjugateRhs, bool Negate>
1193EIGEN_ALWAYS_INLINE void pmadd_complex_complex_MMA(__vector_pair& a, RealPacket& b, __vector_quad* c)
1194{
1195 if (ConjugateLhs && ConjugateRhs) {
1196 RealPacket b2 = pconj2(convertComplex(b)).v;
1197 return pger_vecMMA<RealPacket, __vector_pair, false>(c, a, b2);
1198 }
1199 else if (Negate && !ConjugateLhs && ConjugateRhs) {
1200 return pger_vecMMA<RealPacket, __vector_pair, true>(c, a, b);
1201 }
1202 else {
1203 return pger_vecMMA<RealPacket, __vector_pair, false>(c, a, b);
1204 }
1205}
1206
1208template<typename RealPacket, typename LhsPacket, bool Conjugate, int StorageOrder>
1209EIGEN_ALWAYS_INLINE void pmadd_complex_real_MMA(LhsPacket& a, RealPacket& b, __vector_quad* c)
1210{
1211 RealPacket a2 = convertReal(a);
1212 if (Conjugate) {
1213 RealPacket b2 = pconj2(convertComplex(b)).v;
1214 if (StorageOrder == ColMajor) {
1215 return pger_vecMMA<RealPacket, RealPacket, false>(c, b2, a2);
1216 } else {
1217 return pger_vecMMA<RealPacket, RealPacket, false>(c, a2, b2);
1218 }
1219 }
1220 else {
1221 if (StorageOrder == ColMajor) {
1222 return pger_vecMMA<RealPacket, RealPacket, false>(c, b, a2);
1223 } else {
1224 return pger_vecMMA<RealPacket, RealPacket, false>(c, a2, b);
1225 }
1226 }
1227}
1228
1230template<typename RealPacket, typename LhsPacket, bool Conjugate, int StorageOrder>
1231EIGEN_ALWAYS_INLINE void pmadd_complex_real_MMA(__vector_pair& a, RealPacket& b, __vector_quad* c)
1232{
1233 if (Conjugate) {
1234 RealPacket b2 = pconj2(convertComplex(b)).v;
1235 return pger_vecMMA<RealPacket, __vector_pair, false>(c, a, b2);
1236 }
1237 else {
1238 return pger_vecMMA<RealPacket, __vector_pair, false>(c, a, b);
1239 }
1240}
1241
1243template<typename ScalarPacket, typename LhsPacket, typename SLhsPacket, typename RhsScalar, typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder>
1244EIGEN_ALWAYS_INLINE void gemv_mult_complex_complex_MMA(SLhsPacket& a0, RhsScalar* b, __vector_quad* c0)
1245{
1246 ScalarPacket b0;
1247 if (StorageOrder == ColMajor) {
1248 b0 = pload_realimag_combine(b);
1249 } else {
1250 b0 = pload_realimag_combine_row(b);
1251 }
1252 pmadd_complex_complex_MMA<ScalarPacket, LhsPacket, ConjugateLhs, ConjugateRhs, false>(a0, b0, c0);
1253}
1254
1256template<typename ScalarPacket, typename LhsPacket, typename SLhsPacket, typename RhsScalar, typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder>
1257EIGEN_ALWAYS_INLINE void gemv_mult_complex_real_MMA(SLhsPacket& a0, RhsScalar* b, __vector_quad* c0)
1258{
1259 pload_complex_MMA<ScalarPacket, LhsPacket, SLhsPacket, ResPacket>(a0);
1260 ScalarPacket b0;
1261 if (StorageOrder == ColMajor) {
1262 b0 = pload_real(b);
1263 }
1264 else {
1265 b0 = pload_real_row<ResPacket>(b);
1266 }
1267 pmadd_complex_real_MMA<ScalarPacket, LhsPacket, ConjugateLhs, ColMajor>(a0, b0, c0);
1268}
1269
1271template<typename ScalarPacket, typename LhsPacket, typename SLhsPacket, typename RhsScalar, typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder>
1272EIGEN_ALWAYS_INLINE void gemv_mult_real_complex_MMA(SLhsPacket& a0, RhsScalar* b, __vector_quad* c0)
1273{
1274 ScalarPacket b0;
1275 if (StorageOrder == ColMajor) {
1276 b0 = pload_complex_full(b);
1277 }
1278 else {
1279 b0 = pload_complex_full_row(b);
1280 }
1281 pmadd_complex_real_MMA<ScalarPacket, LhsPacket, ConjugateRhs, (sizeof(RhsScalar) == sizeof(std::complex<float>)) ? StorageOrder : ColMajor>(a0, b0, c0);
1282}
1283
1284#define GEMV_MULT_COMPLEX_COMPLEX_MMA(LhsType, RhsType) \
1285template<typename ScalarPacket, typename LhsScalar, typename LhsPacket, typename SLhsPacket, typename RhsScalar, typename RhsPacket, typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder> \
1286EIGEN_ALWAYS_INLINE void gemv_mult_complex_MMA(LhsType& a0, RhsType* b, __vector_quad* c0) \
1287{ \
1288 gemv_mult_complex_complex_MMA<ScalarPacket, LhsPacket, SLhsPacket, RhsScalar, ResPacket, ConjugateLhs, ConjugateRhs, StorageOrder>(a0, b, c0); \
1289}
1290
1291GEMV_MULT_COMPLEX_COMPLEX_MMA(Packet2cf, std::complex<float>)
1292GEMV_MULT_COMPLEX_COMPLEX_MMA(__vector_pair, std::complex<float>)
1293GEMV_MULT_COMPLEX_COMPLEX_MMA(Packet1cd, std::complex<double>)
1294
1295
1296template<typename ScalarPacket, typename LhsScalar, typename LhsPacket, typename SLhsPacket, typename RhsScalar, typename RhsPacket, typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder>
1297EIGEN_ALWAYS_INLINE void gemv_mult_complex_MMA(__vector_pair& a0, std::complex<double>* b, __vector_quad* c0)
1298{
1299 if (sizeof(LhsScalar) == 16) {
1300 gemv_mult_complex_complex_MMA<ScalarPacket, LhsPacket, SLhsPacket, RhsScalar, ResPacket, ConjugateLhs, ConjugateRhs, StorageOrder>(a0, b, c0);
1301 }
1302 else {
1303 gemv_mult_real_complex_MMA<ScalarPacket, LhsPacket, SLhsPacket, RhsScalar, ResPacket, ConjugateLhs, ConjugateRhs, StorageOrder>(a0, b, c0);
1304 }
1305}
1306
1307#define GEMV_MULT_REAL_COMPLEX_MMA(LhsType, RhsType) \
1308template<typename ScalarPacket, typename LhsScalar, typename LhsPacket, typename SLhsPacket, typename RhsScalar, typename RhsPacket, typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder> \
1309EIGEN_ALWAYS_INLINE void gemv_mult_complex_MMA(LhsType& a0, RhsType* b, __vector_quad* c0) \
1310{ \
1311 gemv_mult_real_complex_MMA<ScalarPacket, LhsPacket, SLhsPacket, RhsScalar, ResPacket, ConjugateLhs, ConjugateRhs, StorageOrder>(a0, b, c0); \
1312}
1313
1314GEMV_MULT_REAL_COMPLEX_MMA(Packet4f, std::complex<float>)
1315GEMV_MULT_REAL_COMPLEX_MMA(Packet2d, std::complex<double>)
1316
1317#define GEMV_MULT_COMPLEX_REAL_MMA(LhsType, RhsType) \
1318template<typename ScalarPacket, typename LhsScalar, typename LhsPacket, typename SLhsPacket, typename RhsScalar, typename RhsPacket, typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder> \
1319EIGEN_ALWAYS_INLINE void gemv_mult_complex_MMA(LhsType& a0, RhsType* b, __vector_quad* c0) \
1320{ \
1321 gemv_mult_complex_real_MMA<ScalarPacket, LhsPacket, SLhsPacket, RhsScalar, ResPacket, ConjugateLhs, ConjugateRhs, StorageOrder>(a0, b, c0); \
1322}
1323
1324GEMV_MULT_COMPLEX_REAL_MMA(Packet2cf, float)
1325GEMV_MULT_COMPLEX_REAL_MMA(Packet1cd, double)
1326GEMV_MULT_COMPLEX_REAL_MMA(__vector_pair, float)
1327GEMV_MULT_COMPLEX_REAL_MMA(__vector_pair, double)
1328
1329
1330template <typename Scalar, typename ScalarPacket, typename LhsPacket, typename RhsPacket, bool ConjugateLhs, bool ConjugateRhs>
1331EIGEN_ALWAYS_INLINE void disassembleResults2(__vector_quad* c0, PacketBlock<ScalarPacket, 4>& result0)
1332{
1333 __builtin_mma_disassemble_acc(&result0.packet, c0);
1334 if (sizeof(LhsPacket) == 16) {
1335 if (sizeof(RhsPacket) == 16) {
1336 ScalarPacket tmp0, tmp2;
1337 tmp2 = vec_mergeh(result0.packet[2], result0.packet[3]);
1338 tmp0 = vec_mergeh(result0.packet[0], result0.packet[1]);
1339 result0.packet[3] = vec_mergel(result0.packet[3], result0.packet[2]);
1340 result0.packet[1] = vec_mergel(result0.packet[1], result0.packet[0]);
1341 result0.packet[2] = tmp2;
1342 result0.packet[0] = tmp0;
1343
1344 if (ConjugateLhs) {
1345 result0.packet[0] = pconj2(convertComplex(result0.packet[0])).v;
1346 result0.packet[2] = pconj2(convertComplex(result0.packet[2])).v;
1347 } else if (ConjugateRhs) {
1348 result0.packet[1] = pconj2(convertComplex(result0.packet[1])).v;
1349 result0.packet[3] = pconj2(convertComplex(result0.packet[3])).v;
1350 } else {
1351 result0.packet[1] = pconjinv(convertComplex(result0.packet[1])).v;
1352 result0.packet[3] = pconjinv(convertComplex(result0.packet[3])).v;
1353 }
1354 result0.packet[0] = vec_add(result0.packet[0], result0.packet[1]);
1355 result0.packet[2] = vec_add(result0.packet[2], result0.packet[3]);
1356 } else {
1357 result0.packet[0][1] = result0.packet[1][1];
1358 result0.packet[2][1] = result0.packet[3][1];
1359 }
1360 }
1361}
1362
1363template <typename Scalar, typename ScalarPacket, typename LhsPacket, typename RhsPacket, bool ConjugateLhs, bool ConjugateRhs>
1364EIGEN_ALWAYS_INLINE void disassembleResults4(__vector_quad* c0, PacketBlock<ScalarPacket, 4>& result0)
1365{
1366 __builtin_mma_disassemble_acc(&result0.packet, c0);
1367 if (GEMV_IS_COMPLEX_COMPLEX) {
1368 if (ConjugateLhs) {
1369 result0.packet[0] = pconj2(convertComplex(result0.packet[0])).v;
1370 result0.packet[1] = pcplxflip2(convertComplex(result0.packet[1])).v;
1371 } else {
1372 if (ConjugateRhs) {
1373 result0.packet[1] = pcplxconjflip(convertComplex(result0.packet[1])).v;
1374 } else {
1375 result0.packet[1] = pcplxflipconj(convertComplex(result0.packet[1])).v;
1376 }
1377 }
1378 result0.packet[0] = vec_add(result0.packet[0], result0.packet[1]);
1379 } else if (sizeof(LhsPacket) == sizeof(std::complex<float>)) {
1380 if (ConjugateLhs) {
1381 result0.packet[0] = pconj2(convertComplex(result0.packet[0])).v;
1382 }
1383 } else {
1384 result0.packet[0] = vec_mergee(result0.packet[0], result0.packet[1]);
1385 }
1386}
1387
1388template <typename Scalar, typename ScalarPacket, int ResPacketSize, typename LhsPacket, typename RhsPacket, bool ConjugateLhs, bool ConjugateRhs>
1389EIGEN_ALWAYS_INLINE void disassembleResults(__vector_quad* c0, PacketBlock<ScalarPacket, 4>& result0)
1390{
1391 if (!GEMV_IS_COMPLEX_FLOAT) {
1392 disassembleResults2<Scalar, ScalarPacket, LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs>(c0, result0);
1393 } else {
1394 disassembleResults4<Scalar, ScalarPacket, LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs>(c0, result0);
1395 }
1396}
1397#endif
1398
1399#define GEMV_GETN_COMPLEX(N) (((N) * ResPacketSize) >> 1)
1400
1401#define GEMV_LOADPACKET_COL_COMPLEX(iter) \
1402 loadLhsPacket<Scalar, LhsScalar, LhsMapper, PLhsPacket>(lhs, i + ((iter) * ResPacketSize), j)
1403
1404#define GEMV_LOADPACKET_COL_COMPLEX_DATA(iter) \
1405 convertReal(GEMV_LOADPACKET_COL_COMPLEX(iter))
1406
1407#ifdef USE_GEMV_MMA
1408#define GEMV_INIT_COL_COMPLEX_MMA(iter, N) \
1409 if (GEMV_GETN_COMPLEX(N) > iter) { \
1410 __builtin_mma_xxsetaccz(&e0##iter); \
1411 }
1412
1413#if EIGEN_COMP_LLVM
1414#define GEMV_LOADPAIR_COL_COMPLEX_MMA(iter1, iter2) \
1415 GEMV_BUILDPAIR_MMA(a##iter1, GEMV_LOADPACKET_COL_COMPLEX_DATA(iter2), GEMV_LOADPACKET_COL_COMPLEX_DATA((iter2) + 1)); \
1416 EIGEN_UNUSED_VARIABLE(f##iter1);
1417#else
1418#define GEMV_LOADPAIR_COL_COMPLEX_MMA(iter1, iter2) \
1419 if (sizeof(LhsPacket) == 16) { \
1420 const LhsScalar& src = lhs(i + ((32 * iter1) / sizeof(LhsScalar)), j); \
1421 a##iter1 = *reinterpret_cast<__vector_pair *>(const_cast<LhsScalar *>(&src)); \
1422 EIGEN_UNUSED_VARIABLE(f##iter1); \
1423 } else { \
1424 f##iter1 = lhs.template load<PLhsPacket, Unaligned>(i + ((iter2) * ResPacketSize), j); \
1425 GEMV_BUILDPAIR_MMA(a##iter1, vec_splat(convertReal(f##iter1), 0), vec_splat(convertReal(f##iter1), 1)); \
1426 }
1427#endif
1428
1429#define GEMV_LOAD1_COL_COMPLEX_MMA(iter, N) \
1430 if (GEMV_GETN_COMPLEX(N) > iter) { \
1431 if (GEMV_IS_COMPLEX_FLOAT) { \
1432 f##iter = GEMV_LOADPACKET_COL_COMPLEX(iter); \
1433 EIGEN_UNUSED_VARIABLE(a##iter); \
1434 } else { \
1435 GEMV_LOADPAIR_COL_COMPLEX_MMA(iter, iter << 1) \
1436 } \
1437 } else { \
1438 EIGEN_UNUSED_VARIABLE(a##iter); \
1439 EIGEN_UNUSED_VARIABLE(f##iter); \
1440 }
1441
1442#define GEMV_WORK1_COL_COMPLEX_MMA(iter, N) \
1443 if (GEMV_GETN_COMPLEX(N) > iter) { \
1444 if (GEMV_IS_COMPLEX_FLOAT) { \
1445 gemv_mult_complex_MMA<ScalarPacket, LhsScalar, PLhsPacket, PLhsPacket, RhsScalar, RhsPacket, ResPacket, ConjugateLhs, ConjugateRhs, ColMajor>(f##iter, b, &e0##iter); \
1446 } else { \
1447 gemv_mult_complex_MMA<ScalarPacket, LhsScalar, PLhsPacket, __vector_pair, RhsScalar, RhsPacket, ResPacket, ConjugateLhs, ConjugateRhs, ColMajor>(a##iter, b, &e0##iter); \
1448 } \
1449 }
1450
1451#define GEMV_LOADPAIR2_COL_COMPLEX_MMA(iter1, iter2) \
1452 GEMV_BUILDPAIR_MMA(a##iter1, GEMV_LOADPACKET_COL_COMPLEX_DATA(iter2), GEMV_LOADPACKET_COL_COMPLEX_DATA((iter2) + 1));
1453
1454#define GEMV_LOAD2_COL_COMPLEX_MMA(iter1, iter2, iter3, N) \
1455 if (GEMV_GETN_COMPLEX(N) > iter1) { \
1456 if (GEMV_IS_COMPLEX_FLOAT) { \
1457 GEMV_LOADPAIR2_COL_COMPLEX_MMA(iter2, iter2); \
1458 EIGEN_UNUSED_VARIABLE(a##iter3) \
1459 } else { \
1460 GEMV_LOADPAIR2_COL_COMPLEX_MMA(iter2, iter2 << 1); \
1461 GEMV_LOADPAIR2_COL_COMPLEX_MMA(iter3, iter3 << 1); \
1462 } \
1463 } else { \
1464 EIGEN_UNUSED_VARIABLE(a##iter2); \
1465 EIGEN_UNUSED_VARIABLE(a##iter3); \
1466 } \
1467 EIGEN_UNUSED_VARIABLE(f##iter2); \
1468 EIGEN_UNUSED_VARIABLE(f##iter3);
1469
1470#define GEMV_WORK2_COL_COMPLEX_MMA(iter1, iter2, iter3, N) \
1471 if (GEMV_GETN_COMPLEX(N) > iter1) { \
1472 if (GEMV_IS_COMPLEX_FLOAT) { \
1473 PLhsPacket g[2]; \
1474 __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(g), &a##iter2); \
1475 gemv_mult_complex_MMA<ScalarPacket, LhsScalar, PLhsPacket, PLhsPacket, RhsScalar, RhsPacket, ResPacket, ConjugateLhs, ConjugateRhs, ColMajor>(g[0], b, &e0##iter2); \
1476 gemv_mult_complex_MMA<ScalarPacket, LhsScalar, PLhsPacket, PLhsPacket, RhsScalar, RhsPacket, ResPacket, ConjugateLhs, ConjugateRhs, ColMajor>(g[1], b, &e0##iter3); \
1477 } else { \
1478 gemv_mult_complex_MMA<ScalarPacket, LhsScalar, PLhsPacket, __vector_pair, RhsScalar, RhsPacket, ResPacket, ConjugateLhs, ConjugateRhs, ColMajor>(a##iter2, b, &e0##iter2); \
1479 gemv_mult_complex_MMA<ScalarPacket, LhsScalar, PLhsPacket, __vector_pair, RhsScalar, RhsPacket, ResPacket, ConjugateLhs, ConjugateRhs, ColMajor>(a##iter3, b, &e0##iter3); \
1480 } \
1481 }
1482
1483#if EIGEN_COMP_LLVM
1484#define GEMV_LOAD_COL_COMPLEX_MMA(N) \
1485 if (GEMV_GETN_COMPLEX(N) > 1) { \
1486 GEMV_UNROLL_HALF(GEMV_LOAD2_COL_COMPLEX_MMA, (N >> 1)) \
1487 } else { \
1488 GEMV_UNROLL(GEMV_LOAD1_COL_COMPLEX_MMA, N) \
1489 }
1490
1491#define GEMV_WORK_COL_COMPLEX_MMA(N) \
1492 if (GEMV_GETN_COMPLEX(N) > 1) { \
1493 GEMV_UNROLL_HALF(GEMV_WORK2_COL_COMPLEX_MMA, (N >> 1)) \
1494 } else { \
1495 GEMV_UNROLL(GEMV_WORK1_COL_COMPLEX_MMA, N) \
1496 }
1497#else
1498#define GEMV_LOAD_COL_COMPLEX_MMA(N) \
1499 GEMV_UNROLL(GEMV_LOAD1_COL_COMPLEX_MMA, N)
1500
1501#define GEMV_WORK_COL_COMPLEX_MMA(N) \
1502 GEMV_UNROLL(GEMV_WORK1_COL_COMPLEX_MMA, N)
1503#endif
1504
1505#define GEMV_DISASSEMBLE_COMPLEX_MMA(iter) \
1506 disassembleResults<Scalar, ScalarPacket, ResPacketSize, LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs>(&e0##iter, result0##iter);
1507
1508#define GEMV_STORE_COL_COMPLEX_MMA(iter, N) \
1509 if (GEMV_GETN_COMPLEX(N) > iter) { \
1510 GEMV_DISASSEMBLE_COMPLEX_MMA(iter); \
1511 c0##iter = PResPacket(result0##iter.packet[0]); \
1512 if (GEMV_IS_COMPLEX_FLOAT) { \
1513 pstoreu_pmadd_complex<Scalar, ScalarPacket, PResPacket, ResPacket, ResScalar, AlphaData>(c0##iter, alpha_data, res + i + (iter * ResPacketSize)); \
1514 } else { \
1515 pstoreu_pmadd_complex<Scalar, ScalarPacket, PResPacket, ResPacket, ResScalar, AlphaData>(c0##iter, alpha_data, res + i + ((iter << 1) * ResPacketSize)); \
1516 c0##iter = PResPacket(result0##iter.packet[2]); \
1517 pstoreu_pmadd_complex<Scalar, ScalarPacket, PResPacket, ResPacket, ResScalar, AlphaData>(c0##iter, alpha_data, res + i + (((iter << 1) + 1) * ResPacketSize)); \
1518 } \
1519 }
1520
1521#define GEMV_STORE2_COL_COMPLEX_MMA(iter1, iter2, iter3, N) \
1522 if (GEMV_GETN_COMPLEX(N) > iter1) { \
1523 GEMV_DISASSEMBLE_COMPLEX_MMA(iter2); \
1524 GEMV_DISASSEMBLE_COMPLEX_MMA(iter3); \
1525 c0##iter2 = PResPacket(result0##iter2.packet[0]); \
1526 if (GEMV_IS_COMPLEX_FLOAT) { \
1527 c0##iter3 = PResPacket(result0##iter3.packet[0]); \
1528 pstoreu_pmadd_complex<ScalarPacket, PResPacket, ResPacket, ResScalar, AlphaData, ResPacketSize, iter2>(c0##iter2, c0##iter3, alpha_data, res + i); \
1529 } else { \
1530 c0##iter3 = PResPacket(result0##iter2.packet[2]); \
1531 pstoreu_pmadd_complex<ScalarPacket, PResPacket, ResPacket, ResScalar, AlphaData, ResPacketSize, iter2 << 1>(c0##iter2, c0##iter3, alpha_data, res + i); \
1532 c0##iter2 = PResPacket(result0##iter3.packet[0]); \
1533 c0##iter3 = PResPacket(result0##iter3.packet[2]); \
1534 pstoreu_pmadd_complex<ScalarPacket, PResPacket, ResPacket, ResScalar, AlphaData, ResPacketSize, iter3 << 1>(c0##iter2, c0##iter3, alpha_data, res + i); \
1535 } \
1536 }
1537
1538#define GEMV_PROCESS_COL_COMPLEX_ONE_MMA(N) \
1539 GEMV_UNROLL(GEMV_INIT_COL_COMPLEX_MMA, N) \
1540 Index j = j2; \
1541 do { \
1542 const RhsScalar& b1 = rhs2(j, 0); \
1543 RhsScalar* b = const_cast<RhsScalar *>(&b1); \
1544 GEMV_UNROLL(GEMV_PREFETCH, N) \
1545 GEMV_LOAD_COL_COMPLEX_MMA(N) \
1546 GEMV_WORK_COL_COMPLEX_MMA(N) \
1547 } while (++j < jend); \
1548 if (GEMV_GETN(N) <= 2) { \
1549 GEMV_UNROLL(GEMV_STORE_COL_COMPLEX_MMA, N) \
1550 } else { \
1551 GEMV_UNROLL_HALF(GEMV_STORE2_COL_COMPLEX_MMA, (N >> 1)) \
1552 } \
1553 i += (ResPacketSize * N);
1554#endif
1555
1556#define GEMV_INIT_COMPLEX(iter, N) \
1557 if (N > iter) { \
1558 c0##iter = pset_zero<PResPacket>(); \
1559 c1##iter = pset_init<ResPacket, LhsPacket, RhsPacket>(c1##iter); \
1560 } else { \
1561 EIGEN_UNUSED_VARIABLE(c0##iter); \
1562 EIGEN_UNUSED_VARIABLE(c1##iter); \
1563 }
1564
1565#define GEMV_WORK_COL_COMPLEX(iter, N) \
1566 if (N > iter) { \
1567 f##iter = GEMV_LOADPACKET_COL_COMPLEX(iter); \
1568 gemv_mult_complex<ScalarPacket, PLhsPacket, RhsScalar, RhsPacket, PResPacket, ResPacket, ConjugateLhs, ConjugateRhs, ColMajor>(f##iter, b, c0##iter, c1##iter); \
1569 } else { \
1570 EIGEN_UNUSED_VARIABLE(f##iter); \
1571 }
1572
1573#define GEMV_STORE_COL_COMPLEX(iter, N) \
1574 if (N > iter) { \
1575 if (GEMV_IS_COMPLEX_COMPLEX) { \
1576 c0##iter = padd(c0##iter, c1##iter); \
1577 } \
1578 pstoreu_pmadd_complex<Scalar, ScalarPacket, PResPacket, ResPacket, ResScalar, AlphaData>(c0##iter, alpha_data, res + i + (iter * ResPacketSize)); \
1579 }
1580
1582#define GEMV_PROCESS_COL_COMPLEX_ONE(N) \
1583 GEMV_UNROLL(GEMV_INIT_COMPLEX, N) \
1584 Index j = j2; \
1585 do { \
1586 const RhsScalar& b1 = rhs2(j, 0); \
1587 RhsScalar* b = const_cast<RhsScalar *>(&b1); \
1588 GEMV_UNROLL(GEMV_PREFETCH, N) \
1589 GEMV_UNROLL(GEMV_WORK_COL_COMPLEX, N) \
1590 } while (++j < jend); \
1591 GEMV_UNROLL(GEMV_STORE_COL_COMPLEX, N) \
1592 i += (ResPacketSize * N);
1593
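// Both column-major complex variants above follow the same structure: zero N result
// accumulators, then for each column j of the current block load the single rhs
// coefficient and multiply-accumulate N lhs packets from that column, and finally
// multiply the accumulated results by alpha and add them into res.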
1594#if defined(USE_GEMV_MMA) && (EIGEN_COMP_LLVM || defined(USE_SLOWER_GEMV_MMA))
1595#define USE_GEMV_COL_COMPLEX_MMA
1596#endif
1597
1598#ifdef USE_GEMV_COL_COMPLEX_MMA
1599#define GEMV_PROCESS_COL_COMPLEX(N) \
1600 GEMV_PROCESS_COL_COMPLEX_ONE_MMA(N)
1601#else
1602#if defined(USE_GEMV_MMA) && (__GNUC__ > 10)
1603#define GEMV_PROCESS_COL_COMPLEX(N) \
1604 if (sizeof(Scalar) != sizeof(LhsPacket)) { \
1605 GEMV_PROCESS_COL_COMPLEX_ONE_MMA(N) \
1606 } else { \
1607 GEMV_PROCESS_COL_COMPLEX_ONE(N) \
1608 }
1609#else
1610#define GEMV_PROCESS_COL_COMPLEX(N) \
1611 GEMV_PROCESS_COL_COMPLEX_ONE(N)
1612#endif
1613#endif
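// Dispatch summary: when MMA is available and the compiler is clang (or
// USE_SLOWER_GEMV_MMA is defined), GEMV_PROCESS_COL_COMPLEX(N) always expands to the
// MMA variant. With gcc > 10 the MMA variant is used only for mixed real/complex
// operands (sizeof(Scalar) != sizeof(LhsPacket)); in all remaining cases the plain
// VSX variant is used.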
1614
1615template<typename Scalar, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, bool LhsIsReal, typename RhsScalar, typename RhsMapper, bool ConjugateRhs, bool RhsIsReal, typename ResScalar>
1616EIGEN_STRONG_INLINE void gemv_complex_col(
1617 Index rows, Index cols,
1618 const LhsMapper& alhs,
1619 const RhsMapper& rhs,
1620 ResScalar* res, Index resIncr,
1621 ResScalar alpha)
1622{
1623 typedef gemv_traits<LhsScalar, RhsScalar> Traits;
1624
1625 typedef typename Traits::LhsPacket LhsPacket;
1626 typedef typename Traits::RhsPacket RhsPacket;
1627 typedef typename Traits::ResPacket ResPacket;
1628
1629 typedef typename packet_traits<Scalar>::type ScalarPacket;
1630 typedef typename packet_traits<LhsScalar>::type PLhsPacket;
1631 typedef typename packet_traits<ResScalar>::type PResPacket;
1632 typedef gemv_traits<ResPacket, ResPacket> PTraits;
1633
1634 EIGEN_UNUSED_VARIABLE(resIncr);
1635 eigen_internal_assert(resIncr == 1);
1636
1637 // The following copy tells the compiler that lhs's attributes are not modified outside this function
1638 // This helps GCC to generate proper code.
1639 LhsMapper lhs(alhs);
1640 RhsMapper rhs2(rhs);
1641
1642 conj_helper<LhsScalar, RhsScalar, ConjugateLhs, ConjugateRhs> cj;
1643
1644 const Index lhsStride = lhs.stride();
1645 // TODO: for padded aligned inputs, we could enable aligned reads
1646 enum {
1647 LhsAlignment = Unaligned,
1648 ResPacketSize = PTraits::ResPacketSize,
1649 LhsPacketSize = PTraits::LhsPacketSize,
1650 RhsPacketSize = PTraits::RhsPacketSize,
1651 };
1652#ifdef EIGEN_POWER_USE_GEMV_PREFETCH
1653 const Index prefetch_dist = 64 * LhsPacketSize;
1654#endif
1655
1656#ifndef GCC_ONE_VECTORPAIR_BUG
1657 const Index n8 = rows - 8 * ResPacketSize + 1;
1658 const Index n4 = rows - 4 * ResPacketSize + 1;
1659 const Index n2 = rows - 2 * ResPacketSize + 1;
1660#endif
1661 const Index n1 = rows - 1 * ResPacketSize + 1;
1662
1663 // TODO: improve the following heuristic:
1664 const Index block_cols = cols < 128 ? cols : (lhsStride * sizeof(LhsScalar) < 16000 ? 16 : 8);
1665
1667 AlphaData alpha_data(alpha);
1668
1669 for (Index j2 = 0; j2 < cols; j2 += block_cols)
1670 {
1671 Index jend = numext::mini(j2 + block_cols, cols);
1672 Index i = 0;
1673 PResPacket c00, c01, c02, c03, c04, c05, c06, c07;
1674 ResPacket c10, c11, c12, c13, c14, c15, c16, c17;
1675 PLhsPacket f0, f1, f2, f3, f4, f5, f6, f7;
1676#ifdef USE_GEMV_MMA
1677 __vector_quad e00, e01, e02, e03, e04, e05, e06, e07;
1678 __vector_pair a0, a1, a2, a3, a4, a5, a6, a7;
1679 PacketBlock<ScalarPacket, 4> result00, result01, result02, result03, result04, result05, result06, result07;
1680 GEMV_UNUSED(8, e0)
1681 GEMV_UNUSED(8, result0)
1682 GEMV_UNUSED(8, a)
1683 GEMV_UNUSED(8, f)
1684#if !defined(GCC_ONE_VECTORPAIR_BUG) && defined(USE_GEMV_COL_COMPLEX_MMA)
1685 if (GEMV_IS_COMPLEX_COMPLEX || !GEMV_IS_COMPLEX_FLOAT)
1686#endif
1687#endif
1688#ifndef GCC_ONE_VECTORPAIR_BUG
1689 {
1690 while (i < n8)
1691 {
1692 GEMV_PROCESS_COL_COMPLEX(8)
1693 }
1694 }
1695 while (i < n4)
1696 {
1697 GEMV_PROCESS_COL_COMPLEX(4)
1698 }
1699 if (i < n2)
1700 {
1701 GEMV_PROCESS_COL_COMPLEX(2)
1702 }
1703 if (i < n1)
1704#else
1705 while (i < n1)
1706#endif
1707 {
1708 GEMV_PROCESS_COL_COMPLEX_ONE(1)
1709 }
1710 for (;i < rows;++i)
1711 {
1712 ResScalar d0(0);
1713 Index j = j2;
1714 do {
1715 d0 += cj.pmul(lhs(i, j), rhs2(j, 0));
1716 } while (++j < jend);
1717 res[i] += alpha * d0;
1718 }
1719 }
1720}
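// A minimal usage sketch (illustrative only, not part of this header): the kernel
// above is reached through the general_matrix_vector_product specializations defined
// further below, e.g.
//   #include <Eigen/Dense>
//   Eigen::MatrixXcf A(rows, cols);
//   Eigen::VectorXcf x(cols), y(rows);
//   std::complex<float> alpha(1.5f, 0.0f);
//   y.noalias() += alpha * A * x;   // column-major lhs -> gemv_complex_col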
1721
1722template <typename Scalar, int N> struct ScalarBlock {
1723 Scalar scalar[N];
1724};
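// ScalarBlock packs the two horizontally-reduced results produced for a pair of rows
// by the predux_real/predux_complex helpers below.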
1725
1726#ifdef USE_GEMV_MMA
1727static Packet16uc p16uc_ELEMENT_3 = { 0x0c,0x0d,0x0e,0x0f, 0x1c,0x1d,0x1e,0x1f, 0x0c,0x0d,0x0e,0x0f, 0x1c,0x1d,0x1e,0x1f };
1728
1730template<typename ResScalar, typename ResPacket>
1731EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_real(__vector_quad* acc0, __vector_quad* acc1)
1732{
1733 PacketBlock<ResPacket, 4> result0, result1;
1734 __builtin_mma_disassemble_acc(&result0.packet, acc0);
1735 __builtin_mma_disassemble_acc(&result1.packet, acc1);
1736 result0.packet[0] = vec_mergeh(result0.packet[0], result1.packet[0]);
1737 result0.packet[1] = vec_mergeo(result0.packet[1], result1.packet[1]);
1738 result0.packet[2] = vec_mergel(result0.packet[2], result1.packet[2]);
1739 result0.packet[3] = vec_perm(result0.packet[3], result1.packet[3], p16uc_ELEMENT_3);
1740 result0.packet[0] = vec_add(vec_add(result0.packet[0], result0.packet[2]), vec_add(result0.packet[1], result0.packet[3]));
1741 return *reinterpret_cast<ScalarBlock<ResScalar, 2> *>(&result0.packet[0]);
1742}
1743
1744template<>
1745EIGEN_ALWAYS_INLINE ScalarBlock<double, 2> predux_real<double, Packet2d>(__vector_quad* acc0, __vector_quad* acc1)
1746{
1747 PacketBlock<Packet2d, 4> result0, result1;
1748 __builtin_mma_disassemble_acc(&result0.packet, acc0);
1749 __builtin_mma_disassemble_acc(&result1.packet, acc1);
1750 result0.packet[0] = vec_add(vec_mergeh(result0.packet[0], result1.packet[0]), vec_mergel(result0.packet[1], result1.packet[1]));
1751 return *reinterpret_cast<ScalarBlock<double, 2> *>(&result0.packet[0]);
1752}
1753
1755template<typename LhsPacket, typename RhsPacket, bool ConjugateLhs, bool ConjugateRhs>
1756EIGEN_ALWAYS_INLINE ScalarBlock<std::complex<float>, 2> addComplexResults(PacketBlock<Packet4f, 4>& result0, PacketBlock<Packet4f, 4>& result1)
1757{
1758 ScalarBlock<std::complex<float>, 2> cc0;
1759 result0.packet[0] = reinterpret_cast<Packet4f>(vec_mergeh(reinterpret_cast<Packet2d>(result0.packet[0]), reinterpret_cast<Packet2d>(result1.packet[0])));
1760 result0.packet[2] = reinterpret_cast<Packet4f>(vec_mergel(reinterpret_cast<Packet2d>(result0.packet[2]), reinterpret_cast<Packet2d>(result1.packet[2])));
1761 result0.packet[0] = vec_add(result0.packet[0], result0.packet[2]);
1762 if (GEMV_IS_COMPLEX_COMPLEX) {
1763 result0.packet[1] = reinterpret_cast<Packet4f>(vec_mergeh(reinterpret_cast<Packet2d>(result0.packet[1]), reinterpret_cast<Packet2d>(result1.packet[1])));
1764 result0.packet[3] = reinterpret_cast<Packet4f>(vec_mergel(reinterpret_cast<Packet2d>(result0.packet[3]), reinterpret_cast<Packet2d>(result1.packet[3])));
1765 result0.packet[1] = vec_add(result0.packet[1], result0.packet[3]);
1766 if (ConjugateLhs) {
1767 result0.packet[0] = pconj2(convertComplex(result0.packet[0])).v;
1768 result0.packet[1] = pcplxflip2(convertComplex(result0.packet[1])).v;
1769 } else if (ConjugateRhs) {
1770 result0.packet[1] = pcplxconjflip(convertComplex(result0.packet[1])).v;
1771 } else {
1772 result0.packet[1] = pcplxflipconj(convertComplex(result0.packet[1])).v;
1773 }
1774 result0.packet[0] = vec_add(result0.packet[0], result0.packet[1]);
1775 } else {
1776 if (ConjugateLhs && (sizeof(LhsPacket) == sizeof(std::complex<float>))) {
1777 result0.packet[0] = pconj2(convertComplex(result0.packet[0])).v;
1778 }
1779 }
1780 cc0.scalar[0].real(result0.packet[0][0]);
1781 cc0.scalar[0].imag(result0.packet[0][1]);
1782 cc0.scalar[1].real(result0.packet[0][2]);
1783 cc0.scalar[1].imag(result0.packet[0][3]);
1784 return cc0;
1785}
1786
1787template<typename LhsPacket, typename RhsPacket, bool ConjugateLhs, bool ConjugateRhs>
1788EIGEN_ALWAYS_INLINE ScalarBlock<std::complex<double>, 2> addComplexResults(PacketBlock<Packet2d, 4>&, PacketBlock<Packet2d, 4>&)
1789{
1790 ScalarBlock<std::complex<double>, 2> cc0;
1791 EIGEN_UNUSED_VARIABLE(cc0);
1792 return cc0; // Just for compilation
1793}
1794
1796template<typename ResScalar, typename ResPacket, typename LhsPacket, typename RhsPacket, bool ConjugateLhs, bool ConjugateRhs>
1797EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_complex(__vector_quad* acc0, __vector_quad* acc1)
1798{
1799 PacketBlock<ResPacket, 4> result0, result1;
1800 __builtin_mma_disassemble_acc(&result0.packet, acc0);
1801 __builtin_mma_disassemble_acc(&result1.packet, acc1);
1802 return addComplexResults<LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs>(result0, result1);
1803}
1804
1805template<typename ResScalar, typename ResPacket>
1806EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_real(__vector_quad* acc0)
1807{
1808 PacketBlock<ResPacket, 4> result0;
1809 __builtin_mma_disassemble_acc(&result0.packet, acc0);
1810 result0.packet[0] = vec_add(vec_mergeh(result0.packet[0], result0.packet[2]), vec_mergel(result0.packet[1], result0.packet[3]));
1811 return *reinterpret_cast<ScalarBlock<ResScalar, 2> *>(&result0.packet[0]);
1812}
1813
1814template<typename ResScalar, typename ResPacket, typename LhsPacket, typename RhsPacket, bool ConjugateLhs, bool ConjugateRhs>
1815EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_complex(__vector_quad* acc0)
1816{
1817 ScalarBlock<ResScalar, 2> cc0;
1818 PacketBlock<ResPacket, 4> result0;
1819 __builtin_mma_disassemble_acc(&result0.packet, acc0);
1820 if (GEMV_IS_COMPLEX_COMPLEX) {
1821 if (ConjugateLhs) {
1822 result0.packet[1] = pconjinv(convertComplex(result0.packet[1])).v;
1823 result0.packet[3] = pconjinv(convertComplex(result0.packet[3])).v;
1824 } else if (ConjugateRhs) {
1825 result0.packet[0] = pconj2(convertComplex(result0.packet[0])).v;
1826 result0.packet[2] = pconj2(convertComplex(result0.packet[2])).v;
1827 } else {
1828 result0.packet[1] = pconj2(convertComplex(result0.packet[1])).v;
1829 result0.packet[3] = pconj2(convertComplex(result0.packet[3])).v;
1830 }
1831 result0.packet[0] = vec_add(result0.packet[0], __builtin_vsx_xxpermdi(result0.packet[1], result0.packet[1], 2));
1832 result0.packet[2] = vec_add(result0.packet[2], __builtin_vsx_xxpermdi(result0.packet[3], result0.packet[3], 2));
1833 } else {
1834 result0.packet[0] = __builtin_vsx_xxpermdi(result0.packet[0], result0.packet[1], 1);
1835 result0.packet[2] = __builtin_vsx_xxpermdi(result0.packet[2], result0.packet[3], 1);
1836 }
1837 cc0.scalar[0].real(result0.packet[0][0]);
1838 cc0.scalar[0].imag(result0.packet[0][1]);
1839 cc0.scalar[1].real(result0.packet[2][0]);
1840 cc0.scalar[1].imag(result0.packet[2][1]);
1841 return cc0;
1842}
1843#endif
1844
1845template<typename ResScalar, typename ResPacket>
1846EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_real(ResPacket& a, ResPacket& b)
1847{
1848 ScalarBlock<ResScalar, 2> cc0;
1849 cc0.scalar[0] = predux(a);
1850 cc0.scalar[1] = predux(b);
1851 return cc0;
1852}
1853
1854template<typename ResScalar, typename ResPacket>
1855EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_complex(ResPacket& a, ResPacket& b)
1856{
1857 return predux_real<ResScalar, ResPacket>(a, b);
1858}
1859
1860#define GEMV_UNROLL_ROW(func, N) \
1861 func(0, N) func(1, N) func(2, N) func(3, N) func(4, N) func(5, N) func(6, N) func(7, N)
1862
1863#define GEMV_UNROLL_ROW_HALF(func, N) \
1864 func(0, 0, 1, N) func(1, 2, 3, N) func(2, 4, 5, N) func(3, 6, 7, N)
1865
1866#define GEMV_LOADPACKET_ROW(iter) \
1867 lhs.template load<LhsPacket, Unaligned>(i + (iter), j)
1868
1869#ifdef USE_GEMV_MMA
1870#define GEMV_UNROLL3_ROW(func, N, which) \
1871 func(0, N, which) func(1, N, which) func(2, N, which) func(3, N, which) \
1872 func(4, N, which) func(5, N, which) func(6, N, which) func(7, N, which)
1873
1874#define GEMV_UNUSED_ROW(N, which) \
1875 GEMV_UNROLL3_ROW(GEMV_UNUSED_VAR, N, which)
1876
1877#define GEMV_INIT_ROW(iter, N) \
1878 if (GEMV_GETN(N) > iter) { \
1879 __builtin_mma_xxsetaccz(&c##iter); \
1880 }
1881
1882#define GEMV_LOADPAIR_ROW(iter1, iter2) \
1883 GEMV_BUILDPAIR_MMA(b##iter1, GEMV_LOADPACKET_ROW(iter2), GEMV_LOADPACKET_ROW((iter2) + 1));
1884
1885#define GEMV_WORK_ROW(iter, N) \
1886 if (GEMV_GETN(N) > iter) { \
1887 if (GEMV_IS_FLOAT) { \
1888 pger_vecMMA_acc<LhsPacket, RhsPacket, true>(&c##iter, a0, GEMV_LOADPACKET_ROW(iter)); \
1889 } else { \
1890 __vector_pair b##iter; \
1891 GEMV_LOADPAIR_ROW(iter, iter << 1) \
1892 pger_vecMMA_acc<LhsPacket, RhsPacket, true>(&c##iter, b##iter, a0); \
1893 } \
1894 }
1895
1896#define GEMV_PREDUX2(iter1, iter2, iter3, N) \
1897 if (N > iter1) { \
1898 if (GEMV_IS_FLOAT) { \
1899 cc##iter1 = predux_real<ResScalar, ResPacket>(&c##iter2, &c##iter3); \
1900 } else { \
1901 cc##iter1 = predux_real<ResScalar, ResPacket>(&c##iter1); \
1902 } \
1903 } else { \
1904 EIGEN_UNUSED_VARIABLE(cc##iter1); \
1905 }
1906#else
1907#define GEMV_INIT_ROW(iter, N) \
1908 if (N > iter) { \
1909 c##iter = pset1<ResPacket>(ResScalar(0)); \
1910 } else { \
1911 EIGEN_UNUSED_VARIABLE(c##iter); \
1912 }
1913
1914#define GEMV_WORK_ROW(iter, N) \
1915 if (N > iter) { \
1916 c##iter = pcj.pmadd(GEMV_LOADPACKET_ROW(iter), a0, c##iter); \
1917 }
1918
1919#define GEMV_PREDUX2(iter1, iter2, iter3, N) \
1920 if (N > iter1) { \
1921 cc##iter1 = predux_real<ResScalar, ResPacket>(c##iter2, c##iter3); \
1922 } else { \
1923 EIGEN_UNUSED_VARIABLE(cc##iter1); \
1924 }
1925#endif
1926
1927#define GEMV_MULT(iter1, iter2, iter3, N) \
1928 if (N > iter1) { \
1929 cc##iter1.scalar[0] += cj.pmul(lhs(i + iter2, j), a0); \
1930 cc##iter1.scalar[1] += cj.pmul(lhs(i + iter3, j), a0); \
1931 }
1932
1933#define GEMV_STORE_ROW(iter1, iter2, iter3, N) \
1934 if (N > iter1) { \
1935 storeMaddData<ResScalar>(res + ((i + iter2) * resIncr), alpha, cc##iter1.scalar[0]); \
1936 storeMaddData<ResScalar>(res + ((i + iter3) * resIncr), alpha, cc##iter1.scalar[1]); \
1937 }
1938
1940#define GEMV_PROCESS_ROW(N) \
1941 for (; i < n##N; i += N) { \
1942 GEMV_UNROLL_ROW(GEMV_INIT_ROW, N) \
1943 Index j = 0; \
1944 for (; j + LhsPacketSize <= cols; j += LhsPacketSize) { \
1945 RhsPacket a0 = rhs2.template load<RhsPacket, Unaligned>(j); \
1946 GEMV_UNROLL_ROW(GEMV_WORK_ROW, N) \
1947 } \
1948 GEMV_UNROLL_ROW_HALF(GEMV_PREDUX2, (N >> 1)) \
1949 for (; j < cols; ++j) { \
1950 RhsScalar a0 = rhs2(j); \
1951 GEMV_UNROLL_ROW_HALF(GEMV_MULT, (N >> 1)) \
1952 } \
1953 GEMV_UNROLL_ROW_HALF(GEMV_STORE_ROW, (N >> 1)) \
1954 }
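// Row-major kernel structure: for each group of N rows, zero the accumulators, run a
// vectorized loop over whole LhsPacketSize column chunks, horizontally reduce the
// accumulators into pairs of scalars (GEMV_PREDUX2), finish the remaining columns with
// scalar multiplies (GEMV_MULT), and finally scale by alpha and store (GEMV_STORE_ROW).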
1955
1956template<typename LhsScalar, typename LhsMapper, typename RhsScalar, typename RhsMapper, typename ResScalar>
1957EIGEN_STRONG_INLINE void gemv_row(
1958 Index rows, Index cols,
1959 const LhsMapper& alhs,
1960 const RhsMapper& rhs,
1961 ResScalar* res, Index resIncr,
1962 ResScalar alpha)
1963{
1964 typedef gemv_traits<LhsScalar, RhsScalar> Traits;
1965
1966 typedef typename Traits::LhsPacket LhsPacket;
1967 typedef typename Traits::RhsPacket RhsPacket;
1968 typedef typename Traits::ResPacket ResPacket;
1969
1970 // The following copy tells the compiler that lhs's attributes are not modified outside this function
1971 // This helps GCC to generate proper code.
1972 LhsMapper lhs(alhs);
1973 typename RhsMapper::LinearMapper rhs2 = rhs.getLinearMapper(0, 0);
1974
1975 eigen_internal_assert(rhs.stride() == 1);
1976 conj_helper<LhsScalar, RhsScalar, false, false> cj;
1977 conj_helper<LhsPacket, RhsPacket, false, false> pcj;
1978
1979 // TODO: fine tune the following heuristic. The rationale is that if the matrix is very large,
1980 // processing 8 rows at once might be counter productive wrt cache.
1981#ifndef GCC_ONE_VECTORPAIR_BUG
1982 const Index n8 = lhs.stride() * sizeof(LhsScalar) > 32000 ? (rows - 7) : (rows - 7);
1983 const Index n4 = rows - 3;
1984 const Index n2 = rows - 1;
1985#endif
1986
1987 // TODO: for padded aligned inputs, we could enable aligned reads
1988 enum {
1989 LhsAlignment = Unaligned,
1990 ResPacketSize = Traits::ResPacketSize,
1991 LhsPacketSize = Traits::LhsPacketSize,
1992 RhsPacketSize = Traits::RhsPacketSize,
1993 };
1994
1995 Index i = 0;
1996#ifdef USE_GEMV_MMA
1997 __vector_quad c0, c1, c2, c3, c4, c5, c6, c7;
1998 GEMV_UNUSED_ROW(8, c)
1999#else
2000 ResPacket c0, c1, c2, c3, c4, c5, c6, c7;
2001#endif
2002#ifndef GCC_ONE_VECTORPAIR_BUG
2003 ScalarBlock<ResScalar, 2> cc0, cc1, cc2, cc3;
2004 GEMV_PROCESS_ROW(8)
2005 GEMV_PROCESS_ROW(4)
2006 GEMV_PROCESS_ROW(2)
2007#endif
2008 for (; i < rows; ++i)
2009 {
2010 ResPacket d0 = pset1<ResPacket>(ResScalar(0));
2011 Index j = 0;
2012 for (; j + LhsPacketSize <= cols; j += LhsPacketSize)
2013 {
2014 RhsPacket b0 = rhs2.template load<RhsPacket, Unaligned>(j);
2015
2016 d0 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + 0, j), b0, d0);
2017 }
2018 ResScalar dd0 = predux(d0);
2019 for (; j < cols; ++j)
2020 {
2021 dd0 += cj.pmul(lhs(i, j), rhs2(j));
2022 }
2023 res[i * resIncr] += alpha * dd0;
2024 }
2025}
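// A minimal usage sketch (illustrative only, not part of this header): a row-major
// lhs selects the kernel above, e.g.
//   #include <Eigen/Dense>
//   Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> A(rows, cols);
//   Eigen::VectorXf x(cols), y(rows);
//   y.noalias() += 2.0f * A * x;   // row-major lhs -> gemv_row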
2026
2027#define EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL(Scalar) \
2028template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> \
2029struct general_matrix_vector_product<Index, Scalar, LhsMapper, ColMajor, ConjugateLhs, Scalar, RhsMapper, ConjugateRhs, Version> \
2030{ \
2031 typedef typename ScalarBinaryOpTraits<Scalar, Scalar>::ReturnType ResScalar; \
2032\
2033 EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run( \
2034 Index rows, Index cols, \
2035 const LhsMapper& lhs, \
2036 const RhsMapper& rhs, \
2037 ResScalar* res, Index resIncr, \
2038 ResScalar alpha) { \
2039 gemv_col<Scalar, LhsMapper, Scalar, RhsMapper, ResScalar>(rows, cols, lhs, rhs, res, resIncr, alpha); \
2040 } \
2041};
2042
2043#define EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW(Scalar) \
2044template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> \
2045struct general_matrix_vector_product<Index, Scalar, LhsMapper, RowMajor, ConjugateLhs, Scalar, RhsMapper, ConjugateRhs, Version> \
2046{ \
2047 typedef typename ScalarBinaryOpTraits<Scalar, Scalar>::ReturnType ResScalar; \
2048\
2049 EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run( \
2050 Index rows, Index cols, \
2051 const LhsMapper& lhs, \
2052 const RhsMapper& rhs, \
2053 ResScalar* res, Index resIncr, \
2054 ResScalar alpha) { \
2055 gemv_row<Scalar, LhsMapper, Scalar, RhsMapper, ResScalar>(rows, cols, lhs, rhs, res, resIncr, alpha); \
2056 } \
2057};
2058
2059EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL(float)
2060EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL(double)
2061EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW(float)
2062EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW(double)
2063
2064template<typename ResScalar, typename PResPacket, typename ResPacket, typename LhsPacket, typename RhsPacket>
2065EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_complex(PResPacket& a0, PResPacket& b0, ResPacket& a1, ResPacket& b1)
2066{
2067 if (GEMV_IS_COMPLEX_COMPLEX) {
2068 a0 = padd(a0, a1);
2069 b0 = padd(b0, b1);
2070 }
2071 return predux_complex<ResScalar, PResPacket>(a0, b0);
2072}
2073
2074#define GEMV_LOADPACKET_ROW_COMPLEX(iter) \
2075 loadLhsPacket<Scalar, LhsScalar, LhsMapper, PLhsPacket>(lhs, i + (iter), j)
2076
2077#define GEMV_LOADPACKET_ROW_COMPLEX_DATA(iter) \
2078 convertReal(GEMV_LOADPACKET_ROW_COMPLEX(iter))
2079
2080#define GEMV_PROCESS_ROW_COMPLEX_SINGLE_WORK(which, N) \
2081 j = 0; \
2082 for (; j + LhsPacketSize <= cols; j += LhsPacketSize) { \
2083 const RhsScalar& b1 = rhs2(j); \
2084 RhsScalar* b = const_cast<RhsScalar *>(&b1); \
2085 GEMV_UNROLL_ROW(which, N) \
2086 }
2087
2088#define GEMV_PROCESS_END_ROW_COMPLEX(N) \
2089 for (; j < cols; ++j) { \
2090 RhsScalar b0 = rhs2(j); \
2091 GEMV_UNROLL_ROW_HALF(GEMV_MULT_COMPLEX, (N >> 1)) \
2092 } \
2093 GEMV_UNROLL_ROW_HALF(GEMV_STORE_ROW_COMPLEX, (N >> 1))
2094
2095#ifdef USE_GEMV_MMA
2096#define GEMV_INIT_ROW_COMPLEX_MMA(iter, N) \
2097 if (GEMV_GETN_COMPLEX(N) > iter) { \
2098 __builtin_mma_xxsetaccz(&e0##iter); \
2099 }
2100
2101#define GEMV_LOADPAIR_ROW_COMPLEX_MMA(iter1, iter2) \
2102 GEMV_BUILDPAIR_MMA(a##iter1, GEMV_LOADPACKET_ROW_COMPLEX_DATA(iter2), GEMV_LOADPACKET_ROW_COMPLEX_DATA((iter2) + 1));
2103
2104#define GEMV_WORK_ROW_COMPLEX_MMA(iter, N) \
2105 if (GEMV_GETN_COMPLEX(N) > iter) { \
2106 if (GEMV_IS_COMPLEX_FLOAT) { \
2107 PLhsPacket a##iter = GEMV_LOADPACKET_ROW_COMPLEX(iter); \
2108 gemv_mult_complex_MMA<ScalarPacket, LhsScalar, PLhsPacket, PLhsPacket, RhsScalar, RhsPacket, ResPacket, ConjugateLhs, ConjugateRhs, RowMajor>(a##iter, b, &e0##iter); \
2109 } else { \
2110 __vector_pair a##iter; \
2111 GEMV_LOADPAIR_ROW_COMPLEX_MMA(iter, iter << 1) \
2112 gemv_mult_complex_MMA<ScalarPacket, LhsScalar, PLhsPacket, __vector_pair, RhsScalar, RhsPacket, ResPacket, ConjugateLhs, ConjugateRhs, RowMajor>(a##iter, b, &e0##iter); \
2113 } \
2114 }
2115
2116#define GEMV_PREDUX4_COMPLEX_MMA(iter1, iter2, iter3, N) \
2117 if (N > iter1) { \
2118 if (GEMV_IS_COMPLEX_FLOAT) { \
2119 cc##iter1 = predux_complex<ResScalar, ScalarPacket, LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs>(&e0##iter2, &e0##iter3); \
2120 } else { \
2121 cc##iter1 = predux_complex<ResScalar, ScalarPacket, LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs>(&e0##iter1); \
2122 } \
2123 } else { \
2124 EIGEN_UNUSED_VARIABLE(cc##iter1); \
2125 }
2126
2127#define GEMV_PROCESS_ROW_COMPLEX_SINGLE_MMA(N) \
2128 GEMV_UNROLL_ROW(GEMV_INIT_ROW_COMPLEX_MMA, N) \
2129 GEMV_PROCESS_ROW_COMPLEX_SINGLE_WORK(GEMV_WORK_ROW_COMPLEX_MMA, N)
2130
2131#define GEMV_PROCESS_ROW_COMPLEX_ONE_MMA(N) \
2132 for (; i < n##N; i += N) { \
2133 GEMV_PROCESS_ROW_COMPLEX_SINGLE_MMA(N) \
2134 GEMV_UNROLL_ROW_HALF(GEMV_PREDUX4_COMPLEX_MMA, (N >> 1)) \
2135 GEMV_PROCESS_END_ROW_COMPLEX(N); \
2136 }
2137#endif
2138
2139#define GEMV_WORK_ROW_COMPLEX(iter, N) \
2140 if (N > iter) { \
2141 PLhsPacket a##iter = GEMV_LOADPACKET_ROW_COMPLEX(iter); \
2142 gemv_mult_complex<ScalarPacket, PLhsPacket, RhsScalar, RhsPacket, PResPacket, ResPacket, ConjugateLhs, ConjugateRhs, RowMajor>(a##iter, b, c0##iter, c1##iter); \
2143 }
2144
2145#define GEMV_PREDUX4_COMPLEX(iter1, iter2, iter3, N) \
2146 if (N > iter1) { \
2147 cc##iter1 = predux_complex<ResScalar, PResPacket, ResPacket, LhsPacket, RhsPacket>(c0##iter2, c0##iter3, c1##iter2, c1##iter3); \
2148 } else { \
2149 EIGEN_UNUSED_VARIABLE(cc##iter1); \
2150 }
2151
2152#define GEMV_MULT_COMPLEX(iter1, iter2, iter3, N) \
2153 if (N > iter1) { \
2154 cc##iter1.scalar[0] += cj.pmul(lhs(i + iter2, j), b0); \
2155 cc##iter1.scalar[1] += cj.pmul(lhs(i + iter3, j), b0); \
2156 }
2157
2158#define GEMV_STORE_ROW_COMPLEX(iter1, iter2, iter3, N) \
2159 if (N > iter1) { \
2160 storeMaddData<ResScalar>(res + ((i + iter2) * resIncr), alpha, cc##iter1.scalar[0]); \
2161 storeMaddData<ResScalar>(res + ((i + iter3) * resIncr), alpha, cc##iter1.scalar[1]); \
2162 }
2163
2164#define GEMV_PROCESS_ROW_COMPLEX_SINGLE_NEW(N) \
2165 GEMV_UNROLL_ROW(GEMV_INIT_COMPLEX, N) \
2166 GEMV_PROCESS_ROW_COMPLEX_SINGLE_WORK(GEMV_WORK_ROW_COMPLEX, N)
2167
2169#define GEMV_PROCESS_ROW_COMPLEX_ONE_NEW(N) \
2170 for (; i < n##N; i += N) { \
2171 GEMV_PROCESS_ROW_COMPLEX_SINGLE_NEW(N) \
2172 GEMV_UNROLL_ROW_HALF(GEMV_PREDUX4_COMPLEX, (N >> 1)) \
2173 GEMV_PROCESS_END_ROW_COMPLEX(N); \
2174 }
2175
2176#define GEMV_PROCESS_ROW_COMPLEX_PREDUX_NEW(iter) \
2177 if (GEMV_IS_COMPLEX_COMPLEX) { \
2178 c0##iter = padd(c0##iter, c1##iter); \
2179 } \
2180 dd0 = predux(c0##iter);
2181
2182#if EIGEN_COMP_LLVM
2183#define GEMV_PROCESS_ROW_COMPLEX_SINGLE(N) \
2184 GEMV_PROCESS_ROW_COMPLEX_SINGLE_NEW(N)
2185
2186#define GEMV_PROCESS_ROW_COMPLEX_ONE(N) \
2187 GEMV_PROCESS_ROW_COMPLEX_ONE_NEW(N)
2188
2189#define GEMV_PROCESS_ROW_COMPLEX_PREDUX(iter) \
2190 GEMV_PROCESS_ROW_COMPLEX_PREDUX_NEW(iter)
2191#else
2192// gcc seems to be reading and writing registers unnecessarily to memory.
2193// Use the old way for complex double until it is fixed.
2194
2195#define GEMV_LOADPACKET_ROW_COMPLEX_OLD(iter) \
2196 lhs.template load<LhsPacket, LhsAlignment>(i + (iter), j)
2197
2198#define GEMV_INIT_COMPLEX_OLD(iter, N) \
2199 EIGEN_UNUSED_VARIABLE(c0##iter); \
2200 if (N > iter) { \
2201 c1##iter = pset_zero<ResPacket>(); \
2202 } else { \
2203 EIGEN_UNUSED_VARIABLE(c1##iter); \
2204 }
2205
2206#define GEMV_WORK_ROW_COMPLEX_OLD(iter, N) \
2207 if (N > iter) { \
2208 LhsPacket a##iter = GEMV_LOADPACKET_ROW_COMPLEX_OLD(iter); \
2209 c1##iter = pcj.pmadd(a##iter, b0, c1##iter); \
2210 }
2211
2212#define GEMV_PREDUX4_COMPLEX_OLD(iter1, iter2, iter3, N) \
2213 if (N > iter1) { \
2214 cc##iter1.scalar[0] = predux(c1##iter2); \
2215 cc##iter1.scalar[1] = predux(c1##iter3); \
2216 } else { \
2217 EIGEN_UNUSED_VARIABLE(cc##iter1); \
2218 }
2219
2220#define GEMV_PROCESS_ROW_COMPLEX_SINGLE_OLD(N) \
2221 GEMV_UNROLL_ROW(GEMV_INIT_COMPLEX_OLD, N) \
2222 j = 0; \
2223 for (; j + LhsPacketSize <= cols; j += LhsPacketSize) { \
2224 RhsPacket b0 = rhs2.template load<RhsPacket, Unaligned>(j); \
2225 GEMV_UNROLL_ROW(GEMV_WORK_ROW_COMPLEX_OLD, N) \
2226 }
2227
2228#define GEMV_PROCESS_ROW_COMPLEX_ONE_OLD(N) \
2229 for (; i < n##N; i += N) { \
2230 GEMV_PROCESS_ROW_COMPLEX_SINGLE_OLD(N) \
2231 GEMV_UNROLL_ROW_HALF(GEMV_PREDUX4_COMPLEX_OLD, (N >> 1)) \
2232 GEMV_PROCESS_END_ROW_COMPLEX(N) \
2233 }
2234
2235#define GEMV_PROCESS_ROW_COMPLEX_PREDUX_OLD(iter) \
2236 dd0 = predux(c1##iter);
2237
2238#if (__GNUC__ > 10)
2239#define GEMV_PROCESS_ROW_COMPLEX_IS_NEW 1
2240#else
2241#define GEMV_PROCESS_ROW_COMPLEX_IS_NEW \
2242 (sizeof(Scalar) == sizeof(float)) || GEMV_IS_COMPLEX_COMPLEX
2243#endif
2244
2245#define GEMV_PROCESS_ROW_COMPLEX_SINGLE(N) \
2246 if (GEMV_PROCESS_ROW_COMPLEX_IS_NEW) { \
2247 GEMV_PROCESS_ROW_COMPLEX_SINGLE_NEW(N) \
2248 } else { \
2249 GEMV_PROCESS_ROW_COMPLEX_SINGLE_OLD(N) \
2250 }
2251
2252#define GEMV_PROCESS_ROW_COMPLEX_ONE(N) \
2253 if (GEMV_PROCESS_ROW_COMPLEX_IS_NEW) { \
2254 GEMV_PROCESS_ROW_COMPLEX_ONE_NEW(N) \
2255 } else { \
2256 GEMV_PROCESS_ROW_COMPLEX_ONE_OLD(N) \
2257 }
2258
2259#define GEMV_PROCESS_ROW_COMPLEX_PREDUX(iter) \
2260 if (GEMV_PROCESS_ROW_COMPLEX_IS_NEW) { \
2261 GEMV_PROCESS_ROW_COMPLEX_PREDUX_NEW(iter) \
2262 } else { \
2263 GEMV_PROCESS_ROW_COMPLEX_PREDUX_OLD(iter) \
2264 }
2265#endif
2266
2267#ifdef USE_GEMV_MMA
2268#define GEMV_PROCESS_ROW_COMPLEX(N) \
2269 GEMV_PROCESS_ROW_COMPLEX_ONE_MMA(N)
2270#else
2271#define GEMV_PROCESS_ROW_COMPLEX(N) \
2272 GEMV_PROCESS_ROW_COMPLEX_ONE(N)
2273#endif
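// Dispatch summary: GEMV_PROCESS_ROW_COMPLEX(N) uses the MMA variant whenever MMA is
// available and otherwise falls back to the VSX variant, which on gcc additionally
// chooses between the new and old code paths via GEMV_PROCESS_ROW_COMPLEX_IS_NEW (the
// old path is kept for complex double, where gcc currently spills registers to memory).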
2274
2275template<typename Scalar, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, bool LhsIsReal, typename RhsScalar, typename RhsMapper, bool ConjugateRhs, bool RhsIsReal, typename ResScalar>
2276EIGEN_STRONG_INLINE void gemv_complex_row(
2277 Index rows, Index cols,
2278 const LhsMapper& alhs,
2279 const RhsMapper& rhs,
2280 ResScalar* res, Index resIncr,
2281 ResScalar alpha)
2282{
2283 typedef gemv_traits<LhsScalar, RhsScalar> Traits;
2284
2285 typedef typename Traits::LhsPacket LhsPacket;
2286 typedef typename Traits::RhsPacket RhsPacket;
2287 typedef typename Traits::ResPacket ResPacket;
2288
2289 typedef typename packet_traits<Scalar>::type ScalarPacket;
2290 typedef typename packet_traits<LhsScalar>::type PLhsPacket;
2291 typedef typename packet_traits<ResScalar>::type PResPacket;
2292 typedef gemv_traits<ResPacket, ResPacket> PTraits;
2293
2294 // The following copy tells the compiler that lhs's attributes are not modified outside this function
2295 // This helps GCC to generate proper code.
2296 LhsMapper lhs(alhs);
2297 typename RhsMapper::LinearMapper rhs2 = rhs.getLinearMapper(0, 0);
2298
2299 eigen_internal_assert(rhs.stride() == 1);
2300 conj_helper<LhsScalar, RhsScalar, ConjugateLhs, ConjugateRhs> cj;
2301#if !EIGEN_COMP_LLVM
2302 conj_helper<LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs> pcj;
2303#endif
2304
2305 // TODO: fine tune the following heuristic. The rationale is that if the matrix is very large,
2306 // processing 8 rows at once might be counter productive wrt cache.
2307#ifndef GCC_ONE_VECTORPAIR_BUG
2308 const Index n8 = lhs.stride() * sizeof(LhsScalar) > 32000 ? (rows - 7) : (rows - 7);
2309 const Index n4 = rows - 3;
2310 const Index n2 = rows - 1;
2311#endif
2312
2313 // TODO: for padded aligned inputs, we could enable aligned reads
2314 enum {
2315 LhsAlignment = Unaligned,
2316 ResPacketSize = PTraits::ResPacketSize,
2317 LhsPacketSize = PTraits::LhsPacketSize,
2318 RhsPacketSize = PTraits::RhsPacketSize,
2319 };
2320
2321 Index i = 0, j;
2322 PResPacket c00, c01, c02, c03, c04, c05, c06, c07;
2323 ResPacket c10, c11, c12, c13, c14, c15, c16, c17;
2324#ifdef USE_GEMV_MMA
2325 __vector_quad e00, e01, e02, e03, e04, e05, e06, e07;
2326 GEMV_UNUSED_ROW(8, e0)
2327 GEMV_UNUSED_EXTRA(1, c0)
2328 GEMV_UNUSED_EXTRA(1, c1)
2329#endif
2330 ResScalar dd0;
2331#ifndef GCC_ONE_VECTORPAIR_BUG
2332 ScalarBlock<ResScalar, 2> cc0, cc1, cc2, cc3;
2333#ifdef USE_GEMV_MMA
2334 if (!GEMV_IS_COMPLEX_COMPLEX)
2335#endif
2336 {
2337 GEMV_PROCESS_ROW_COMPLEX(8)
2338 }
2339 GEMV_PROCESS_ROW_COMPLEX(4)
2340 GEMV_PROCESS_ROW_COMPLEX(2)
2341#endif
2342 for (; i < rows; ++i)
2343 {
2344 GEMV_PROCESS_ROW_COMPLEX_SINGLE(1)
2345 GEMV_PROCESS_ROW_COMPLEX_PREDUX(0)
2346 for (; j < cols; ++j)
2347 {
2348 dd0 += cj.pmul(lhs(i, j), rhs2(j));
2349 }
2350 res[i * resIncr] += alpha * dd0;
2351 }
2352}
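// A minimal usage sketch (illustrative only, not part of this header): adjoint or
// transposed products on a column-major lhs are typically evaluated through this
// row-major path, e.g.
//   #include <Eigen/Dense>
//   Eigen::MatrixXcd A(rows, cols);
//   Eigen::VectorXcd x(rows), y(cols);
//   std::complex<double> alpha(0.5, 0.0);
//   y.noalias() += alpha * A.adjoint() * x;   // conjugated row-major traversal -> gemv_complex_row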
2353
2354#define EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_COL(Scalar, LhsScalar, RhsScalar) \
2355template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> \
2356struct general_matrix_vector_product<Index, LhsScalar, LhsMapper, ColMajor, ConjugateLhs, RhsScalar, RhsMapper, ConjugateRhs, Version> \
2357{ \
2358 typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar; \
2359\
2360 EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run( \
2361 Index rows, Index cols, \
2362 const LhsMapper& lhs, \
2363 const RhsMapper& rhs, \
2364 ResScalar* res, Index resIncr, \
2365 ResScalar alpha) { \
2366 gemv_complex_col<Scalar, LhsScalar, LhsMapper, ConjugateLhs, sizeof(Scalar) == sizeof(LhsScalar), RhsScalar, RhsMapper, ConjugateRhs, sizeof(Scalar) == sizeof(RhsScalar), ResScalar>(rows, cols, lhs, rhs, res, resIncr, alpha); \
2367 } \
2368};
2369
2370#define EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_ROW(Scalar, LhsScalar, RhsScalar) \
2371template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> \
2372struct general_matrix_vector_product<Index, LhsScalar, LhsMapper, RowMajor, ConjugateLhs, RhsScalar, RhsMapper, ConjugateRhs, Version> \
2373{ \
2374 typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar; \
2375\
2376 EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run( \
2377 Index rows, Index cols, \
2378 const LhsMapper& lhs, \
2379 const RhsMapper& rhs, \
2380 ResScalar* res, Index resIncr, \
2381 ResScalar alpha) { \
2382 gemv_complex_row<Scalar, LhsScalar, LhsMapper, ConjugateLhs, sizeof(Scalar) == sizeof(LhsScalar), RhsScalar, RhsMapper, ConjugateRhs, sizeof(Scalar) == sizeof(RhsScalar), ResScalar>(rows, cols, lhs, rhs, res, resIncr, alpha); \
2383 } \
2384};
2385
2386EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_COL(float, float, std::complex<float>)
2387EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_COL(float, std::complex<float>, float)
2388EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_COL(float, std::complex<float>, std::complex<float>)
2389EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_COL(double, double, std::complex<double>)
2390EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_COL(double, std::complex<double>, double)
2391EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_COL(double, std::complex<double>, std::complex<double>)
2392EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_ROW(float, float, std::complex<float>)
2393EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_ROW(float, std::complex<float>, float)
2394EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_ROW(float, std::complex<float>, std::complex<float>)
2395EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_ROW(double, double, std::complex<double>)
2396EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_ROW(double, std::complex<double>, double)
2397EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_ROW(double, std::complex<double>, std::complex<double>)
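// Together these instantiations register the column- and row-major complex kernels for
// every supported operand combination: real/complex, complex/real and complex/complex,
// in both single and double precision.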
2398
2399#endif // EIGEN_MATRIX_VECTOR_PRODUCT_ALTIVEC_H
2400