MatrixProduct.h
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2020 Everton Constantino (everton.constantino@ibm.com)
// Copyright (C) 2021 Chip Kerchner (chip.kerchner@ibm.com)
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_MATRIX_PRODUCT_ALTIVEC_H
#define EIGEN_MATRIX_PRODUCT_ALTIVEC_H

#ifndef EIGEN_ALTIVEC_USE_CUSTOM_PACK
#define EIGEN_ALTIVEC_USE_CUSTOM_PACK 1
#endif

#include "MatrixProductCommon.h"

#if !defined(EIGEN_ALTIVEC_DISABLE_MMA)
#define EIGEN_ALTIVEC_DISABLE_MMA 0
#endif

// Check for MMA builtin support.
#if !EIGEN_ALTIVEC_DISABLE_MMA && defined(__has_builtin)
#if __has_builtin(__builtin_mma_assemble_acc)
  #define EIGEN_ALTIVEC_MMA_SUPPORT
#endif
#endif
// Check if and how we should actually use MMA if supported.
#if defined(EIGEN_ALTIVEC_MMA_SUPPORT)

#if !defined(EIGEN_ALTIVEC_ENABLE_MMA_DYNAMIC_DISPATCH)
#define EIGEN_ALTIVEC_ENABLE_MMA_DYNAMIC_DISPATCH 0
#endif

// Check if we want to enable dynamic dispatch. Not supported by LLVM.
#if EIGEN_ALTIVEC_ENABLE_MMA_DYNAMIC_DISPATCH && !EIGEN_COMP_LLVM
#define EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH 1
// Otherwise, use MMA by default if available.
#elif defined(__MMA__)
#define EIGEN_ALTIVEC_MMA_ONLY 1
#endif

#endif // EIGEN_ALTIVEC_MMA_SUPPORT

#if defined(EIGEN_ALTIVEC_MMA_ONLY) || defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
  #include "MatrixProductMMA.h"
#endif
/**************************************************************************************************
 * TODO                                                                                           *
 *  - Check StorageOrder on dhs_pack (the innermost second loop seems unvectorized when it could  *
 *    be vectorized).                                                                             *
 *  - Check the possibility of transposing as GETREAL and GETIMAG when needed.                    *
 **************************************************************************************************/
namespace Eigen {

namespace internal {

/**************************
 * Constants and typedefs *
 **************************/
template<typename Scalar>
struct quad_traits
{
  typedef typename packet_traits<Scalar>::type vectortype;
  typedef PacketBlock<vectortype,4>            type;
  typedef vectortype                           rhstype;
  enum
  {
    vectorsize = packet_traits<Scalar>::size,
    size = 4,
    rows = 4
  };
};

template<>
struct quad_traits<double>
{
  typedef Packet2d                  vectortype;
  typedef PacketBlock<vectortype,4> type;
  typedef PacketBlock<Packet2d,2>   rhstype;
  enum
  {
    vectorsize = packet_traits<double>::size,
    size = 2,
    rows = 4
  };
};
// MatrixProduct decomposes real/imaginary vectors into a separate real vector and imaginary vector; this turned
// out to be faster than Eigen's usual approach of keeping real/imaginary pairs in a single vector. The constants
// below are responsible for converting between Eigen's layout and the MatrixProduct layout.

//[a,ai],[b,bi],[c,ci],[d,di] = [a,b,c,d]
const static Packet16uc p16uc_GETREAL32 = {  0,  1,  2,  3,
                                             8,  9, 10, 11,
                                            16, 17, 18, 19,
                                            24, 25, 26, 27};

//[a,ai],[b,bi],[c,ci],[d,di] = [ai,bi,ci,di]
const static Packet16uc p16uc_GETIMAG32 = {  4,  5,  6,  7,
                                            12, 13, 14, 15,
                                            20, 21, 22, 23,
                                            28, 29, 30, 31};

//[a,ai],[b,bi] = [a,b]
const static Packet16uc p16uc_GETREAL64 = {  0,  1,  2,  3,  4,  5,  6,  7,
                                            16, 17, 18, 19, 20, 21, 22, 23};

//[a,ai],[b,bi] = [ai,bi]
const static Packet16uc p16uc_GETIMAG64 = {  8,  9, 10, 11, 12, 13, 14, 15,
                                            24, 25, 26, 27, 28, 29, 30, 31};

/*********************************************
 * Single precision real and complex packing *
 * *******************************************/
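
// Symmetric (selfadjoint) packing: the helpers below read the full selfadjoint matrix
// through getAdjointVal, so anything below the diagonal is fetched from the matching
// element above the diagonal and conjugated, while the diagonal is kept real. There is
// no PanelMode variant for these symmetric packing helpers.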
template<typename Scalar, typename Index, int StorageOrder>
EIGEN_ALWAYS_INLINE std::complex<Scalar> getAdjointVal(Index i, Index j, const_blas_data_mapper<std::complex<Scalar>, Index, StorageOrder>& dt)
{
  std::complex<Scalar> v;
  if(i < j)
  {
    v.real( dt(j,i).real());
    v.imag(-dt(j,i).imag());
  } else if(i > j)
  {
    v.real( dt(i,j).real());
    v.imag( dt(i,j).imag());
  } else {
    v.real( dt(i,j).real());
    v.imag((Scalar)0.0);
  }
  return v;
}

template<typename Scalar, typename Index, int StorageOrder, int N>
EIGEN_STRONG_INLINE void symm_pack_complex_rhs_helper(std::complex<Scalar>* blockB, const std::complex<Scalar>* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
{
  const Index depth = k2 + rows;
  const_blas_data_mapper<std::complex<Scalar>, Index, StorageOrder> rhs(_rhs, rhsStride);
  const Index vectorSize = N*quad_traits<Scalar>::vectorsize;
  const Index vectorDelta = vectorSize * rows;
  Scalar* blockBf = reinterpret_cast<Scalar *>(blockB);

  Index rir = 0, rii, j = 0;
  for(; j + vectorSize <= cols; j+=vectorSize)
  {
    rii = rir + vectorDelta;

    for(Index i = k2; i < depth; i++)
    {
      for(Index k = 0; k < vectorSize; k++)
      {
        std::complex<Scalar> v = getAdjointVal<Scalar, Index, StorageOrder>(i, j + k, rhs);

        blockBf[rir + k] = v.real();
        blockBf[rii + k] = v.imag();
      }
      rir += vectorSize;
      rii += vectorSize;
    }

    rir += vectorDelta;
  }

  for(; j < cols; j++)
  {
    rii = rir + rows;

    for(Index i = k2; i < depth; i++)
    {
      std::complex<Scalar> v = getAdjointVal<Scalar, Index, StorageOrder>(i, j, rhs);

      blockBf[rir] = v.real();
      blockBf[rii] = v.imag();

      rir += 1;
      rii += 1;
    }

    rir += rows;
  }
}
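
// Note the layout produced above: for each panel of vectorSize columns all the real
// parts are stored first and the matching imaginary parts follow vectorDelta scalars
// later, which is the decoupled real/imaginary format the complex GEMM kernels consume.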

template<typename Scalar, typename Index, int StorageOrder>
EIGEN_STRONG_INLINE void symm_pack_complex_lhs_helper(std::complex<Scalar>* blockA, const std::complex<Scalar>* _lhs, Index lhsStride, Index cols, Index rows)
{
  const Index depth = cols;
  const_blas_data_mapper<std::complex<Scalar>, Index, StorageOrder> lhs(_lhs, lhsStride);
  const Index vectorSize = quad_traits<Scalar>::vectorsize;
  const Index vectorDelta = vectorSize * depth;
  Scalar* blockAf = (Scalar *)(blockA);

  Index rir = 0, rii, j = 0;
  for(; j + vectorSize <= rows; j+=vectorSize)
  {
    rii = rir + vectorDelta;

    for(Index i = 0; i < depth; i++)
    {
      for(Index k = 0; k < vectorSize; k++)
      {
        std::complex<Scalar> v = getAdjointVal<Scalar, Index, StorageOrder>(j+k, i, lhs);

        blockAf[rir + k] = v.real();
        blockAf[rii + k] = v.imag();
      }
      rir += vectorSize;
      rii += vectorSize;
    }

    rir += vectorDelta;
  }

  if (j < rows)
  {
    rii = rir + ((rows - j) * depth);

    for(Index i = 0; i < depth; i++)
    {
      Index k = j;
      for(; k < rows; k++)
      {
        std::complex<Scalar> v = getAdjointVal<Scalar, Index, StorageOrder>(k, i, lhs);

        blockAf[rir] = v.real();
        blockAf[rii] = v.imag();

        rir += 1;
        rii += 1;
      }
    }
  }
}

template<typename Scalar, typename Index, int StorageOrder, int N>
EIGEN_STRONG_INLINE void symm_pack_rhs_helper(Scalar* blockB, const Scalar* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
{
  const Index depth = k2 + rows;
  const_blas_data_mapper<Scalar, Index, StorageOrder> rhs(_rhs, rhsStride);
  const Index vectorSize = quad_traits<Scalar>::vectorsize;

  Index ri = 0, j = 0;
  for(; j + N*vectorSize <= cols; j+=N*vectorSize)
  {
    Index i = k2;
    for(; i < depth; i++)
    {
      for(Index k = 0; k < N*vectorSize; k++)
      {
        if(i <= j+k)
          blockB[ri + k] = rhs(j+k, i);
        else
          blockB[ri + k] = rhs(i, j+k);
      }
      ri += N*vectorSize;
    }
  }

  for(; j < cols; j++)
  {
    for(Index i = k2; i < depth; i++)
    {
      if(j <= i)
        blockB[ri] = rhs(i, j);
      else
        blockB[ri] = rhs(j, i);
      ri += 1;
    }
  }
}

template<typename Scalar, typename Index, int StorageOrder>
EIGEN_STRONG_INLINE void symm_pack_lhs_helper(Scalar* blockA, const Scalar* _lhs, Index lhsStride, Index cols, Index rows)
{
  const Index depth = cols;
  const_blas_data_mapper<Scalar, Index, StorageOrder> lhs(_lhs, lhsStride);
  const Index vectorSize = quad_traits<Scalar>::vectorsize;

  Index ri = 0, j = 0;
  for(; j + vectorSize <= rows; j+=vectorSize)
  {
    Index i = 0;

    for(; i < depth; i++)
    {
      for(Index k = 0; k < vectorSize; k++)
      {
        if(i <= j+k)
          blockA[ri + k] = lhs(j+k, i);
        else
          blockA[ri + k] = lhs(i, j+k);
      }
      ri += vectorSize;
    }
  }

  if (j < rows)
  {
    for(Index i = 0; i < depth; i++)
    {
      Index k = j;
      for(; k < rows; k++)
      {
        if(i <= k)
          blockA[ri] = lhs(k, i);
        else
          blockA[ri] = lhs(i, k);
        ri += 1;
      }
    }
  }
}

// *********** symm_pack std::complex<float32> ***********
template<typename Index, int nr, int StorageOrder>
struct symm_pack_rhs<std::complex<float>, Index, nr, StorageOrder>
{
  void operator()(std::complex<float>* blockB, const std::complex<float>* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
  {
    symm_pack_complex_rhs_helper<float, Index, StorageOrder, 1>(blockB, _rhs, rhsStride, rows, cols, k2);
  }
};

template<typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
struct symm_pack_lhs<std::complex<float>, Index, Pack1, Pack2_dummy, StorageOrder>
{
  void operator()(std::complex<float>* blockA, const std::complex<float>* _lhs, Index lhsStride, Index cols, Index rows)
  {
    symm_pack_complex_lhs_helper<float, Index, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
  }
};

// *********** symm_pack std::complex<float64> ***********

template<typename Index, int nr, int StorageOrder>
struct symm_pack_rhs<std::complex<double>, Index, nr, StorageOrder>
{
  void operator()(std::complex<double>* blockB, const std::complex<double>* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
  {
    symm_pack_complex_rhs_helper<double, Index, StorageOrder, 2>(blockB, _rhs, rhsStride, rows, cols, k2);
  }
};

template<typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
struct symm_pack_lhs<std::complex<double>, Index, Pack1, Pack2_dummy, StorageOrder>
{
  void operator()(std::complex<double>* blockA, const std::complex<double>* _lhs, Index lhsStride, Index cols, Index rows)
  {
    symm_pack_complex_lhs_helper<double, Index, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
  }
};

// *********** symm_pack float32 ***********
template<typename Index, int nr, int StorageOrder>
struct symm_pack_rhs<float, Index, nr, StorageOrder>
{
  void operator()(float* blockB, const float* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
  {
    symm_pack_rhs_helper<float, Index, StorageOrder, 1>(blockB, _rhs, rhsStride, rows, cols, k2);
  }
};

template<typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
struct symm_pack_lhs<float, Index, Pack1, Pack2_dummy, StorageOrder>
{
  void operator()(float* blockA, const float* _lhs, Index lhsStride, Index cols, Index rows)
  {
    symm_pack_lhs_helper<float, Index, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
  }
};

// *********** symm_pack float64 ***********
template<typename Index, int nr, int StorageOrder>
struct symm_pack_rhs<double, Index, nr, StorageOrder>
{
  void operator()(double* blockB, const double* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
  {
    symm_pack_rhs_helper<double, Index, StorageOrder, 2>(blockB, _rhs, rhsStride, rows, cols, k2);
  }
};

template<typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
struct symm_pack_lhs<double, Index, Pack1, Pack2_dummy, StorageOrder>
{
  void operator()(double* blockA, const double* _lhs, Index lhsStride, Index cols, Index rows)
  {
    symm_pack_lhs_helper<double, Index, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
  }
};

template<typename Scalar, typename Packet, typename Index, int N>
EIGEN_ALWAYS_INLINE void storeBlock(Scalar* to, PacketBlock<Packet,N>& block)
{
  const Index size = 16 / sizeof(Scalar);
  pstore<Scalar>(to + (0 * size), block.packet[0]);
  pstore<Scalar>(to + (1 * size), block.packet[1]);
  if (N > 2) {
    pstore<Scalar>(to + (2 * size), block.packet[2]);
  }
  if (N > 3) {
    pstore<Scalar>(to + (3 * size), block.packet[3]);
  }
}
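
// Since VSX vectors are 16 bytes wide, size above is exactly one packet of Scalar
// (4 floats or 2 doubles), so storeBlock writes N consecutive aligned packets.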

// General template for lhs & rhs complex packing.
template<typename Scalar, typename Index, typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode, bool UseLhs>
struct dhs_cpack {
  EIGEN_STRONG_INLINE void operator()(std::complex<Scalar>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
  {
    const Index vectorSize = quad_traits<Scalar>::vectorsize;
    const Index vectorDelta = vectorSize * ((PanelMode) ? stride : depth);
    Index rir = ((PanelMode) ? (vectorSize*offset) : 0), rii;
    Scalar* blockAt = reinterpret_cast<Scalar *>(blockA);
    Index j = 0;

    for(; j + vectorSize <= rows; j+=vectorSize)
    {
      Index i = 0;

      rii = rir + vectorDelta;

      for(; i + vectorSize <= depth; i+=vectorSize)
      {
        PacketBlock<Packet,4> blockr, blocki;
        PacketBlock<PacketC,8> cblock;

        if (UseLhs) {
          bload<DataMapper, PacketC, Index, 2, StorageOrder, true, 4>(cblock, lhs, j, i);
        } else {
          bload<DataMapper, PacketC, Index, 2, StorageOrder, true, 4>(cblock, lhs, i, j);
        }

        blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETREAL32);
        blockr.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, p16uc_GETREAL32);
        blockr.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, p16uc_GETREAL32);
        blockr.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, p16uc_GETREAL32);

        blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETIMAG32);
        blocki.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, p16uc_GETIMAG32);
        blocki.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, p16uc_GETIMAG32);
        blocki.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, p16uc_GETIMAG32);

        if(Conjugate)
        {
          blocki.packet[0] = -blocki.packet[0];
          blocki.packet[1] = -blocki.packet[1];
          blocki.packet[2] = -blocki.packet[2];
          blocki.packet[3] = -blocki.packet[3];
        }

        if(((StorageOrder == RowMajor) && UseLhs) || (((StorageOrder == ColMajor) && !UseLhs)))
        {
          ptranspose(blockr);
          ptranspose(blocki);
        }

        storeBlock<Scalar, Packet, Index, 4>(blockAt + rir, blockr);
        storeBlock<Scalar, Packet, Index, 4>(blockAt + rii, blocki);

        rir += 4*vectorSize;
        rii += 4*vectorSize;
      }
      for(; i < depth; i++)
      {
        PacketBlock<Packet,1> blockr, blocki;
        PacketBlock<PacketC,2> cblock;

        if(((StorageOrder == ColMajor) && UseLhs) || (((StorageOrder == RowMajor) && !UseLhs)))
        {
          if (UseLhs) {
            cblock.packet[0] = lhs.template loadPacket<PacketC>(j + 0, i);
            cblock.packet[1] = lhs.template loadPacket<PacketC>(j + 2, i);
          } else {
            cblock.packet[0] = lhs.template loadPacket<PacketC>(i, j + 0);
            cblock.packet[1] = lhs.template loadPacket<PacketC>(i, j + 2);
          }
        } else {
          if (UseLhs) {
            cblock.packet[0] = pload2(lhs(j + 0, i), lhs(j + 1, i));
            cblock.packet[1] = pload2(lhs(j + 2, i), lhs(j + 3, i));
          } else {
            cblock.packet[0] = pload2(lhs(i, j + 0), lhs(i, j + 1));
            cblock.packet[1] = pload2(lhs(i, j + 2), lhs(i, j + 3));
          }
        }

        blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL32);
        blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG32);

        if(Conjugate)
        {
          blocki.packet[0] = -blocki.packet[0];
        }

        pstore<Scalar>(blockAt + rir, blockr.packet[0]);
        pstore<Scalar>(blockAt + rii, blocki.packet[0]);

        rir += vectorSize;
        rii += vectorSize;
      }

      rir += ((PanelMode) ? (vectorSize*(2*stride - depth)) : vectorDelta);
    }

    if (!UseLhs)
    {
      if(PanelMode) rir -= (offset*(vectorSize - 1));

      for(; j < rows; j++)
      {
        rii = rir + ((PanelMode) ? stride : depth);

        for(Index i = 0; i < depth; i++)
        {
          blockAt[rir] = lhs(i, j).real();

          if(Conjugate)
            blockAt[rii] = -lhs(i, j).imag();
          else
            blockAt[rii] = lhs(i, j).imag();

          rir += 1;
          rii += 1;
        }

        rir += ((PanelMode) ? (2*stride - depth) : depth);
      }
    } else {
      if (j < rows)
      {
        if(PanelMode) rir += (offset*(rows - j - vectorSize));
        rii = rir + (((PanelMode) ? stride : depth) * (rows - j));

        for(Index i = 0; i < depth; i++)
        {
          Index k = j;
          for(; k < rows; k++)
          {
            blockAt[rir] = lhs(k, i).real();

            if(Conjugate)
              blockAt[rii] = -lhs(k, i).imag();
            else
              blockAt[rii] = lhs(k, i).imag();

            rir += 1;
            rii += 1;
          }
        }
      }
    }
  }
};
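
// As with the symmetric packing above, dhs_cpack emits a block of real parts followed,
// vectorDelta scalars later, by the matching block of imaginary parts (negated when
// Conjugate is set), so the complex GEMM kernels can work on decoupled accumulators.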

// General template for lhs & rhs packing.
template<typename Scalar, typename Index, typename DataMapper, typename Packet, int StorageOrder, bool PanelMode, bool UseLhs>
struct dhs_pack{
  EIGEN_STRONG_INLINE void operator()(Scalar* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
  {
    const Index vectorSize = quad_traits<Scalar>::vectorsize;
    Index ri = 0, j = 0;

    for(; j + vectorSize <= rows; j+=vectorSize)
    {
      Index i = 0;

      if(PanelMode) ri += vectorSize*offset;

      for(; i + vectorSize <= depth; i+=vectorSize)
      {
        PacketBlock<Packet,4> block;

        if (UseLhs) {
          bload<DataMapper, Packet, Index, 4, StorageOrder, false, 4>(block, lhs, j, i);
        } else {
          bload<DataMapper, Packet, Index, 4, StorageOrder, false, 4>(block, lhs, i, j);
        }
        if(((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs))
        {
          ptranspose(block);
        }

        storeBlock<Scalar, Packet, Index, 4>(blockA + ri, block);

        ri += 4*vectorSize;
      }
      for(; i < depth; i++)
      {
        if(((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs))
        {
          if (UseLhs) {
            blockA[ri+0] = lhs(j+0, i);
            blockA[ri+1] = lhs(j+1, i);
            blockA[ri+2] = lhs(j+2, i);
            blockA[ri+3] = lhs(j+3, i);
          } else {
            blockA[ri+0] = lhs(i, j+0);
            blockA[ri+1] = lhs(i, j+1);
            blockA[ri+2] = lhs(i, j+2);
            blockA[ri+3] = lhs(i, j+3);
          }
        } else {
          Packet lhsV;
          if (UseLhs) {
            lhsV = lhs.template loadPacket<Packet>(j, i);
          } else {
            lhsV = lhs.template loadPacket<Packet>(i, j);
          }
          pstore<Scalar>(blockA + ri, lhsV);
        }

        ri += vectorSize;
      }

      if(PanelMode) ri += vectorSize*(stride - offset - depth);
    }

    if (!UseLhs)
    {
      if(PanelMode) ri += offset;

      for(; j < rows; j++)
      {
        for(Index i = 0; i < depth; i++)
        {
          blockA[ri] = lhs(i, j);
          ri += 1;
        }

        if(PanelMode) ri += stride - depth;
      }
    } else {
      if (j < rows)
      {
        if(PanelMode) ri += offset*(rows - j);

        for(Index i = 0; i < depth; i++)
        {
          Index k = j;
          for(; k < rows; k++)
          {
            blockA[ri] = lhs(k, i);
            ri += 1;
          }
        }
      }
    }
  }
};

// General template for lhs packing, float64 specialization.
template<typename Index, typename DataMapper, int StorageOrder, bool PanelMode>
struct dhs_pack<double, Index, DataMapper, Packet2d, StorageOrder, PanelMode, true>
{
  EIGEN_STRONG_INLINE void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
  {
    const Index vectorSize = quad_traits<double>::vectorsize;
    Index ri = 0, j = 0;

    for(; j + vectorSize <= rows; j+=vectorSize)
    {
      Index i = 0;

      if(PanelMode) ri += vectorSize*offset;

      for(; i + vectorSize <= depth; i+=vectorSize)
      {
        PacketBlock<Packet2d,2> block;
        if(StorageOrder == RowMajor)
        {
          block.packet[0] = lhs.template loadPacket<Packet2d>(j + 0, i);
          block.packet[1] = lhs.template loadPacket<Packet2d>(j + 1, i);

          ptranspose(block);
        } else {
          block.packet[0] = lhs.template loadPacket<Packet2d>(j, i + 0);
          block.packet[1] = lhs.template loadPacket<Packet2d>(j, i + 1);
        }

        storeBlock<double, Packet2d, Index, 2>(blockA + ri, block);

        ri += 2*vectorSize;
      }
      for(; i < depth; i++)
      {
        if(StorageOrder == RowMajor)
        {
          blockA[ri+0] = lhs(j+0, i);
          blockA[ri+1] = lhs(j+1, i);
        } else {
          Packet2d lhsV = lhs.template loadPacket<Packet2d>(j, i);
          pstore<double>(blockA + ri, lhsV);
        }

        ri += vectorSize;
      }

      if(PanelMode) ri += vectorSize*(stride - offset - depth);
    }

    if (j < rows)
    {
      if(PanelMode) ri += offset*(rows - j);

      for(Index i = 0; i < depth; i++)
      {
        Index k = j;
        for(; k < rows; k++)
        {
          blockA[ri] = lhs(k, i);
          ri += 1;
        }
      }
    }
  }
};

// General template for rhs packing, float64 specialization.
template<typename Index, typename DataMapper, int StorageOrder, bool PanelMode>
struct dhs_pack<double, Index, DataMapper, Packet2d, StorageOrder, PanelMode, false>
{
  EIGEN_STRONG_INLINE void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
  {
    const Index vectorSize = quad_traits<double>::vectorsize;
    Index ri = 0, j = 0;

    for(; j + 2*vectorSize <= cols; j+=2*vectorSize)
    {
      Index i = 0;

      if(PanelMode) ri += offset*(2*vectorSize);

      for(; i + vectorSize <= depth; i+=vectorSize)
      {
        PacketBlock<Packet2d,4> block;
        if(StorageOrder == ColMajor)
        {
          PacketBlock<Packet2d,2> block1, block2;
          block1.packet[0] = rhs.template loadPacket<Packet2d>(i, j + 0);
          block1.packet[1] = rhs.template loadPacket<Packet2d>(i, j + 1);
          block2.packet[0] = rhs.template loadPacket<Packet2d>(i, j + 2);
          block2.packet[1] = rhs.template loadPacket<Packet2d>(i, j + 3);

          ptranspose(block1);
          ptranspose(block2);

          pstore<double>(blockB + ri    , block1.packet[0]);
          pstore<double>(blockB + ri + 2, block2.packet[0]);
          pstore<double>(blockB + ri + 4, block1.packet[1]);
          pstore<double>(blockB + ri + 6, block2.packet[1]);
        } else {
          block.packet[0] = rhs.template loadPacket<Packet2d>(i + 0, j + 0); //[a1 a2]
          block.packet[1] = rhs.template loadPacket<Packet2d>(i + 0, j + 2); //[a3 a4]
          block.packet[2] = rhs.template loadPacket<Packet2d>(i + 1, j + 0); //[b1 b2]
          block.packet[3] = rhs.template loadPacket<Packet2d>(i + 1, j + 2); //[b3 b4]

          storeBlock<double, Packet2d, Index, 4>(blockB + ri, block);
        }

        ri += 4*vectorSize;
      }
      for(; i < depth; i++)
      {
        if(StorageOrder == ColMajor)
        {
          blockB[ri+0] = rhs(i, j+0);
          blockB[ri+1] = rhs(i, j+1);

          ri += vectorSize;

          blockB[ri+0] = rhs(i, j+2);
          blockB[ri+1] = rhs(i, j+3);
        } else {
          Packet2d rhsV = rhs.template loadPacket<Packet2d>(i, j);
          pstore<double>(blockB + ri, rhsV);

          ri += vectorSize;

          rhsV = rhs.template loadPacket<Packet2d>(i, j + 2);
          pstore<double>(blockB + ri, rhsV);
        }
        ri += vectorSize;
      }

      if(PanelMode) ri += (2*vectorSize)*(stride - offset - depth);
    }

    if(PanelMode) ri += offset;

    for(; j < cols; j++)
    {
      for(Index i = 0; i < depth; i++)
      {
        blockB[ri] = rhs(i, j);
        ri += 1;
      }

      if(PanelMode) ri += stride - depth;
    }
  }
};

// General template for lhs complex packing, float64 specialization.
template<typename Index, typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode>
struct dhs_cpack<double, Index, DataMapper, Packet, PacketC, StorageOrder, Conjugate, PanelMode, true>
{
  EIGEN_STRONG_INLINE void operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
  {
    const Index vectorSize = quad_traits<double>::vectorsize;
    const Index vectorDelta = vectorSize * ((PanelMode) ? stride : depth);
    Index rir = ((PanelMode) ? (vectorSize*offset) : 0), rii;
    double* blockAt = reinterpret_cast<double *>(blockA);
    Index j = 0;

    for(; j + vectorSize <= rows; j+=vectorSize)
    {
      Index i = 0;

      rii = rir + vectorDelta;

      for(; i + vectorSize <= depth; i+=vectorSize)
      {
        PacketBlock<Packet,2> blockr, blocki;
        PacketBlock<PacketC,4> cblock;

        if(StorageOrder == ColMajor)
        {
          cblock.packet[0] = lhs.template loadPacket<PacketC>(j, i + 0); //[a1 a1i]
          cblock.packet[1] = lhs.template loadPacket<PacketC>(j, i + 1); //[b1 b1i]

          cblock.packet[2] = lhs.template loadPacket<PacketC>(j + 1, i + 0); //[a2 a2i]
          cblock.packet[3] = lhs.template loadPacket<PacketC>(j + 1, i + 1); //[b2 b2i]

          blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[2].v, p16uc_GETREAL64); //[a1 a2]
          blockr.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[3].v, p16uc_GETREAL64); //[b1 b2]

          blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[2].v, p16uc_GETIMAG64);
          blocki.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[3].v, p16uc_GETIMAG64);
        } else {
          cblock.packet[0] = lhs.template loadPacket<PacketC>(j + 0, i); //[a1 a1i]
          cblock.packet[1] = lhs.template loadPacket<PacketC>(j + 1, i); //[a2 a2i]

          cblock.packet[2] = lhs.template loadPacket<PacketC>(j + 0, i + 1); //[b1 b1i]
          cblock.packet[3] = lhs.template loadPacket<PacketC>(j + 1, i + 1); //[b2 b2i]

          blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64); //[a1 a2]
          blockr.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETREAL64); //[b1 b2]

          blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG64);
          blocki.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETIMAG64);
        }

        if(Conjugate)
        {
          blocki.packet[0] = -blocki.packet[0];
          blocki.packet[1] = -blocki.packet[1];
        }

        storeBlock<double, Packet, Index, 2>(blockAt + rir, blockr);
        storeBlock<double, Packet, Index, 2>(blockAt + rii, blocki);

        rir += 2*vectorSize;
        rii += 2*vectorSize;
      }
      for(; i < depth; i++)
      {
        PacketBlock<Packet,1> blockr, blocki;
        PacketBlock<PacketC,2> cblock;

        cblock.packet[0] = lhs.template loadPacket<PacketC>(j + 0, i);
        cblock.packet[1] = lhs.template loadPacket<PacketC>(j + 1, i);

        blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64);
        blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG64);

        if(Conjugate)
        {
          blocki.packet[0] = -blocki.packet[0];
        }

        pstore<double>(blockAt + rir, blockr.packet[0]);
        pstore<double>(blockAt + rii, blocki.packet[0]);

        rir += vectorSize;
        rii += vectorSize;
      }

      rir += ((PanelMode) ? (vectorSize*(2*stride - depth)) : vectorDelta);
    }

    if (j < rows)
    {
      if(PanelMode) rir += (offset*(rows - j - vectorSize));
      rii = rir + (((PanelMode) ? stride : depth) * (rows - j));

      for(Index i = 0; i < depth; i++)
      {
        Index k = j;
        for(; k < rows; k++)
        {
          blockAt[rir] = lhs(k, i).real();

          if(Conjugate)
            blockAt[rii] = -lhs(k, i).imag();
          else
            blockAt[rii] = lhs(k, i).imag();

          rir += 1;
          rii += 1;
        }
      }
    }
  }
};

// General template for rhs complex packing, float64 specialization.
template<typename Index, typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode>
struct dhs_cpack<double, Index, DataMapper, Packet, PacketC, StorageOrder, Conjugate, PanelMode, false>
{
  EIGEN_STRONG_INLINE void operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
  {
    const Index vectorSize = quad_traits<double>::vectorsize;
    const Index vectorDelta = 2*vectorSize * ((PanelMode) ? stride : depth);
    Index rir = ((PanelMode) ? (2*vectorSize*offset) : 0), rii;
    double* blockBt = reinterpret_cast<double *>(blockB);
    Index j = 0;

    for(; j + 2*vectorSize <= cols; j+=2*vectorSize)
    {
      Index i = 0;

      rii = rir + vectorDelta;

      for(; i < depth; i++)
      {
        PacketBlock<PacketC,4> cblock;
        PacketBlock<Packet,2> blockr, blocki;

        bload<DataMapper, PacketC, Index, 2, ColMajor, false, 4>(cblock, rhs, i, j);

        blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64);
        blockr.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETREAL64);

        blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG64);
        blocki.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETIMAG64);

        if(Conjugate)
        {
          blocki.packet[0] = -blocki.packet[0];
          blocki.packet[1] = -blocki.packet[1];
        }

        storeBlock<double, Packet, Index, 2>(blockBt + rir, blockr);
        storeBlock<double, Packet, Index, 2>(blockBt + rii, blocki);

        rir += 2*vectorSize;
        rii += 2*vectorSize;
      }

      rir += ((PanelMode) ? (2*vectorSize*(2*stride - depth)) : vectorDelta);
    }

    if(PanelMode) rir -= (offset*(2*vectorSize - 1));

    for(; j < cols; j++)
    {
      rii = rir + ((PanelMode) ? stride : depth);

      for(Index i = 0; i < depth; i++)
      {
        blockBt[rir] = rhs(i, j).real();

        if(Conjugate)
          blockBt[rii] = -rhs(i, j).imag();
        else
          blockBt[rii] = rhs(i, j).imag();

        rir += 1;
        rii += 1;
      }

      rir += ((PanelMode) ? (2*stride - depth) : depth);
    }
  }
};

/**************
 * GEMM utils *
 **************/

// 512-bit rank-1 update of acc. It can accumulate either positively or negatively (useful for complex gemm).
template<typename Packet, bool NegativeAccumulate, int N>
EIGEN_ALWAYS_INLINE void pger_common(PacketBlock<Packet,N>* acc, const Packet& lhsV, const Packet* rhsV)
{
  if(NegativeAccumulate)
  {
    acc->packet[0] = vec_nmsub(lhsV, rhsV[0], acc->packet[0]);
    if (N > 1) {
      acc->packet[1] = vec_nmsub(lhsV, rhsV[1], acc->packet[1]);
    }
    if (N > 2) {
      acc->packet[2] = vec_nmsub(lhsV, rhsV[2], acc->packet[2]);
    }
    if (N > 3) {
      acc->packet[3] = vec_nmsub(lhsV, rhsV[3], acc->packet[3]);
    }
  } else {
    acc->packet[0] = vec_madd(lhsV, rhsV[0], acc->packet[0]);
    if (N > 1) {
      acc->packet[1] = vec_madd(lhsV, rhsV[1], acc->packet[1]);
    }
    if (N > 2) {
      acc->packet[2] = vec_madd(lhsV, rhsV[2], acc->packet[2]);
    }
    if (N > 3) {
      acc->packet[3] = vec_madd(lhsV, rhsV[3], acc->packet[3]);
    }
  }
}
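
// pger_common is an outer-product accumulate: column k of the accumulator is updated as
// acc->packet[k] (+/-)= lhsV * rhsV[k], where each rhsV[k] holds one broadcast rhs value.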

template<int N, typename Scalar, typename Packet, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pger(PacketBlock<Packet,N>* acc, const Scalar* lhs, const Packet* rhsV)
{
  Packet lhsV = pload<Packet>(lhs);

  pger_common<Packet, NegativeAccumulate, N>(acc, lhsV, rhsV);
}

template<typename Scalar, typename Packet, typename Index, const Index remaining_rows>
EIGEN_ALWAYS_INLINE void loadPacketRemaining(const Scalar* lhs, Packet &lhsV)
{
#ifdef _ARCH_PWR9
  lhsV = vec_xl_len((Scalar *)lhs, remaining_rows * sizeof(Scalar));
#else
  Index i = 0;
  do {
    lhsV[i] = lhs[i];
  } while (++i < remaining_rows);
#endif
}

template<int N, typename Scalar, typename Packet, typename Index, bool NegativeAccumulate, const Index remaining_rows>
EIGEN_ALWAYS_INLINE void pger(PacketBlock<Packet,N>* acc, const Scalar* lhs, const Packet* rhsV)
{
  Packet lhsV;
  loadPacketRemaining<Scalar, Packet, Index, remaining_rows>(lhs, lhsV);

  pger_common<Packet, NegativeAccumulate, N>(acc, lhsV, rhsV);
}
// 512-bit rank-1 update of complex acc. It takes decoupled accumulators as inputs and also takes care of the mixed types real * complex and complex * real.
template<int N, typename Packet, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void pgerc_common(PacketBlock<Packet,N>* accReal, PacketBlock<Packet,N>* accImag, const Packet &lhsV, const Packet &lhsVi, const Packet* rhsV, const Packet* rhsVi)
{
  pger_common<Packet, false, N>(accReal, lhsV, rhsV);
  if(LhsIsReal)
  {
    pger_common<Packet, ConjugateRhs, N>(accImag, lhsV, rhsVi);
    EIGEN_UNUSED_VARIABLE(lhsVi);
  } else {
    if (!RhsIsReal) {
      pger_common<Packet, ConjugateLhs == ConjugateRhs, N>(accReal, lhsVi, rhsVi);
      pger_common<Packet, ConjugateRhs, N>(accImag, lhsV, rhsVi);
    } else {
      EIGEN_UNUSED_VARIABLE(rhsVi);
    }
    pger_common<Packet, ConjugateLhs, N>(accImag, lhsVi, rhsV);
  }
}
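
// This follows (a + bi)(c + di) = (ac - bd) + (ad + bc)i: the real accumulator collects
// a*c and -b*d, the imaginary accumulator collects a*d and b*c, and the conjugation
// flags flip the relevant signs through the NegativeAccumulate parameter of pger_common.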

template<int N, typename Scalar, typename Packet, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void pgerc(PacketBlock<Packet,N>* accReal, PacketBlock<Packet,N>* accImag, const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, const Packet* rhsV, const Packet* rhsVi)
{
  Packet lhsV = ploadLhs<Scalar, Packet>(lhs_ptr);
  Packet lhsVi;
  if(!LhsIsReal) lhsVi = ploadLhs<Scalar, Packet>(lhs_ptr_imag);
  else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);

  pgerc_common<N, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(accReal, accImag, lhsV, lhsVi, rhsV, rhsVi);
}

template<typename Scalar, typename Packet, typename Index, bool LhsIsReal, const Index remaining_rows>
EIGEN_ALWAYS_INLINE void loadPacketRemaining(const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, Packet &lhsV, Packet &lhsVi)
{
#ifdef _ARCH_PWR9
  lhsV = vec_xl_len((Scalar *)lhs_ptr, remaining_rows * sizeof(Scalar));
  if(!LhsIsReal) lhsVi = vec_xl_len((Scalar *)lhs_ptr_imag, remaining_rows * sizeof(Scalar));
  else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
#else
  Index i = 0;
  do {
    lhsV[i] = lhs_ptr[i];
    if(!LhsIsReal) lhsVi[i] = lhs_ptr_imag[i];
  } while (++i < remaining_rows);
  if(LhsIsReal) EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
#endif
}

template<int N, typename Scalar, typename Packet, typename Index, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal, const Index remaining_rows>
EIGEN_ALWAYS_INLINE void pgerc(PacketBlock<Packet,N>* accReal, PacketBlock<Packet,N>* accImag, const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, const Packet* rhsV, const Packet* rhsVi)
{
  Packet lhsV, lhsVi;
  loadPacketRemaining<Scalar, Packet, Index, LhsIsReal, remaining_rows>(lhs_ptr, lhs_ptr_imag, lhsV, lhsVi);

  pgerc_common<N, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(accReal, accImag, lhsV, lhsVi, rhsV, rhsVi);
}

template<typename Scalar, typename Packet>
EIGEN_ALWAYS_INLINE Packet ploadLhs(const Scalar* lhs)
{
  return ploadu<Packet>(lhs);
}

// Zero the accumulator on PacketBlock.
template<typename Scalar, typename Packet, int N>
EIGEN_ALWAYS_INLINE void bsetzero(PacketBlock<Packet,N>& acc)
{
  acc.packet[0] = pset1<Packet>((Scalar)0);
  if (N > 1) {
    acc.packet[1] = pset1<Packet>((Scalar)0);
  }
  if (N > 2) {
    acc.packet[2] = pset1<Packet>((Scalar)0);
  }
  if (N > 3) {
    acc.packet[3] = pset1<Packet>((Scalar)0);
  }
}

// Scale the PacketBlock vectors by alpha.
template<typename Packet, int N>
EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,N>& acc, PacketBlock<Packet,N>& accZ, const Packet& pAlpha)
{
  acc.packet[0] = pmadd(pAlpha, accZ.packet[0], acc.packet[0]);
  if (N > 1) {
    acc.packet[1] = pmadd(pAlpha, accZ.packet[1], acc.packet[1]);
  }
  if (N > 2) {
    acc.packet[2] = pmadd(pAlpha, accZ.packet[2], acc.packet[2]);
  }
  if (N > 3) {
    acc.packet[3] = pmadd(pAlpha, accZ.packet[3], acc.packet[3]);
  }
}

template<typename Packet, int N>
EIGEN_ALWAYS_INLINE void bscalec_common(PacketBlock<Packet,N>& acc, PacketBlock<Packet,N>& accZ, const Packet& pAlpha)
{
  acc.packet[0] = pmul<Packet>(accZ.packet[0], pAlpha);
  if (N > 1) {
    acc.packet[1] = pmul<Packet>(accZ.packet[1], pAlpha);
  }
  if (N > 2) {
    acc.packet[2] = pmul<Packet>(accZ.packet[2], pAlpha);
  }
  if (N > 3) {
    acc.packet[3] = pmul<Packet>(accZ.packet[3], pAlpha);
  }
}

// Complex version of PacketBlock scaling.
template<typename Packet, int N>
EIGEN_ALWAYS_INLINE void bscalec(PacketBlock<Packet,N>& aReal, PacketBlock<Packet,N>& aImag, const Packet& bReal, const Packet& bImag, PacketBlock<Packet,N>& cReal, PacketBlock<Packet,N>& cImag)
{
  bscalec_common<Packet, N>(cReal, aReal, bReal);

  bscalec_common<Packet, N>(cImag, aImag, bReal);

  pger_common<Packet, true, N>(&cReal, bImag, aImag.packet);

  pger_common<Packet, false, N>(&cImag, bImag, aReal.packet);
}
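
// bscalec computes c = alpha * a in decoupled form: cReal = aReal*bReal - aImag*bImag
// and cImag = aImag*bReal + aReal*bImag, with (bReal, bImag) the broadcast alpha.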

template<typename Packet, int N>
EIGEN_ALWAYS_INLINE void band(PacketBlock<Packet,N>& acc, const Packet& pMask)
{
  acc.packet[0] = pand(acc.packet[0], pMask);
  if (N > 1) {
    acc.packet[1] = pand(acc.packet[1], pMask);
  }
  if (N > 2) {
    acc.packet[2] = pand(acc.packet[2], pMask);
  }
  if (N > 3) {
    acc.packet[3] = pand(acc.packet[3], pMask);
  }
}

template<typename Packet, int N>
EIGEN_ALWAYS_INLINE void bscalec(PacketBlock<Packet,N>& aReal, PacketBlock<Packet,N>& aImag, const Packet& bReal, const Packet& bImag, PacketBlock<Packet,N>& cReal, PacketBlock<Packet,N>& cImag, const Packet& pMask)
{
  band<Packet, N>(aReal, pMask);
  band<Packet, N>(aImag, pMask);

  bscalec<Packet,N>(aReal, aImag, bReal, bImag, cReal, cImag);
}

// Load a PacketBlock; the N parameter makes tuning gemm easier, so we can add more accumulators as needed.
template<typename DataMapper, typename Packet, typename Index, const Index accCols, int StorageOrder, bool Complex, int N>
EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,N*(Complex?2:1)>& acc, const DataMapper& res, Index row, Index col)
{
  if (StorageOrder == RowMajor) {
    acc.packet[0] = res.template loadPacket<Packet>(row + 0, col);
    if (N > 1) {
      acc.packet[1] = res.template loadPacket<Packet>(row + 1, col);
    }
    if (N > 2) {
      acc.packet[2] = res.template loadPacket<Packet>(row + 2, col);
    }
    if (N > 3) {
      acc.packet[3] = res.template loadPacket<Packet>(row + 3, col);
    }
    if (Complex) {
      acc.packet[0+N] = res.template loadPacket<Packet>(row + 0, col + accCols);
      if (N > 1) {
        acc.packet[1+N] = res.template loadPacket<Packet>(row + 1, col + accCols);
      }
      if (N > 2) {
        acc.packet[2+N] = res.template loadPacket<Packet>(row + 2, col + accCols);
      }
      if (N > 3) {
        acc.packet[3+N] = res.template loadPacket<Packet>(row + 3, col + accCols);
      }
    }
  } else {
    acc.packet[0] = res.template loadPacket<Packet>(row, col + 0);
    if (N > 1) {
      acc.packet[1] = res.template loadPacket<Packet>(row, col + 1);
    }
    if (N > 2) {
      acc.packet[2] = res.template loadPacket<Packet>(row, col + 2);
    }
    if (N > 3) {
      acc.packet[3] = res.template loadPacket<Packet>(row, col + 3);
    }
    if (Complex) {
      acc.packet[0+N] = res.template loadPacket<Packet>(row + accCols, col + 0);
      if (N > 1) {
        acc.packet[1+N] = res.template loadPacket<Packet>(row + accCols, col + 1);
      }
      if (N > 2) {
        acc.packet[2+N] = res.template loadPacket<Packet>(row + accCols, col + 2);
      }
      if (N > 3) {
        acc.packet[3+N] = res.template loadPacket<Packet>(row + accCols, col + 3);
      }
    }
  }
}

const static Packet4i mask41 = { -1,  0,  0,  0 };
const static Packet4i mask42 = { -1, -1,  0,  0 };
const static Packet4i mask43 = { -1, -1, -1,  0 };

const static Packet2l mask21 = { -1, 0 };

template<typename Packet>
EIGEN_ALWAYS_INLINE Packet bmask(const int remaining_rows)
{
  if (remaining_rows == 0) {
    return pset1<Packet>(float(0.0)); // Not used
  } else {
    switch (remaining_rows) {
      case 1:  return Packet(mask41);
      case 2:  return Packet(mask42);
      default: return Packet(mask43);
    }
  }
}

template<>
EIGEN_ALWAYS_INLINE Packet2d bmask<Packet2d>(const int remaining_rows)
{
  if (remaining_rows == 0) {
    return pset1<Packet2d>(double(0.0)); // Not used
  } else {
    return Packet2d(mask21);
  }
}
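
// Example: bmask<Packet4f>(2) yields the lane mask {-1,-1,0,0}, so band() keeps the
// first two result lanes and zeroes the rest when only two rows remain in a panel.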

template<typename Packet, int N>
EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,N>& acc, PacketBlock<Packet,N>& accZ, const Packet& pAlpha, const Packet& pMask)
{
  band<Packet, N>(accZ, pMask);

  bscale<Packet, N>(acc, accZ, pAlpha);
}

template<typename Packet, int N> EIGEN_ALWAYS_INLINE void
pbroadcastN_old(const __UNPACK_TYPE__(Packet) *a,
                Packet& a0, Packet& a1, Packet& a2, Packet& a3)
{
  a0 = pset1<Packet>(a[0]);
  if (N > 1) {
    a1 = pset1<Packet>(a[1]);
  } else {
    EIGEN_UNUSED_VARIABLE(a1);
  }
  if (N > 2) {
    a2 = pset1<Packet>(a[2]);
  } else {
    EIGEN_UNUSED_VARIABLE(a2);
  }
  if (N > 3) {
    a3 = pset1<Packet>(a[3]);
  } else {
    EIGEN_UNUSED_VARIABLE(a3);
  }
}

template<>
EIGEN_ALWAYS_INLINE void pbroadcastN_old<Packet4f,4>(const float* a, Packet4f& a0, Packet4f& a1, Packet4f& a2, Packet4f& a3)
{
  pbroadcast4<Packet4f>(a, a0, a1, a2, a3);
}

template<>
EIGEN_ALWAYS_INLINE void pbroadcastN_old<Packet2d,4>(const double* a, Packet2d& a0, Packet2d& a1, Packet2d& a2, Packet2d& a3)
{
  a1 = pload<Packet2d>(a);
  a3 = pload<Packet2d>(a + 2);
  a0 = vec_splat(a1, 0);
  a1 = vec_splat(a1, 1);
  a2 = vec_splat(a3, 0);
  a3 = vec_splat(a3, 1);
}

template<typename Packet, int N> EIGEN_ALWAYS_INLINE void
pbroadcastN(const __UNPACK_TYPE__(Packet) *a,
            Packet& a0, Packet& a1, Packet& a2, Packet& a3)
{
  a0 = pset1<Packet>(a[0]);
  if (N > 1) {
    a1 = pset1<Packet>(a[1]);
  } else {
    EIGEN_UNUSED_VARIABLE(a1);
  }
  if (N > 2) {
    a2 = pset1<Packet>(a[2]);
  } else {
    EIGEN_UNUSED_VARIABLE(a2);
  }
  if (N > 3) {
    a3 = pset1<Packet>(a[3]);
  } else {
    EIGEN_UNUSED_VARIABLE(a3);
  }
}

template<> EIGEN_ALWAYS_INLINE void
pbroadcastN<Packet4f,4>(const float *a,
                        Packet4f& a0, Packet4f& a1, Packet4f& a2, Packet4f& a3)
{
  a3 = pload<Packet4f>(a);
  a0 = vec_splat(a3, 0);
  a1 = vec_splat(a3, 1);
  a2 = vec_splat(a3, 2);
  a3 = vec_splat(a3, 3);
}

// PEEL loop factor.
#define PEEL 7
#define PEEL_ROW 7
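
// PEEL controls how many rank-1 updates are unrolled inside the main k-loop of the
// GEMM microkernels; PEEL_ROW is the same factor for the remaining-rows path.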

#define MICRO_UNROLL_PEEL(func) \
  func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7)

#define MICRO_ZERO_PEEL(peel) \
  if ((PEEL_ROW > peel) && (peel != 0)) { \
    bsetzero<Scalar, Packet, accRows>(accZero##peel); \
  } else { \
    EIGEN_UNUSED_VARIABLE(accZero##peel); \
  }

#define MICRO_ZERO_PEEL_ROW \
  MICRO_UNROLL_PEEL(MICRO_ZERO_PEEL);

#define MICRO_WORK_PEEL(peel) \
  if (PEEL_ROW > peel) { \
    pbroadcastN<Packet,accRows>(rhs_ptr + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
    pger<accRows, Scalar, Packet, false>(&accZero##peel, lhs_ptr + (remaining_rows * peel), rhsV##peel); \
  } else { \
    EIGEN_UNUSED_VARIABLE(rhsV##peel); \
  }

#define MICRO_WORK_PEEL_ROW \
  Packet rhsV0[4], rhsV1[4], rhsV2[4], rhsV3[4], rhsV4[4], rhsV5[4], rhsV6[4], rhsV7[4]; \
  MICRO_UNROLL_PEEL(MICRO_WORK_PEEL); \
  lhs_ptr += (remaining_rows * PEEL_ROW); \
  rhs_ptr += (accRows * PEEL_ROW);

#define MICRO_ADD_PEEL(peel, sum) \
  if (PEEL_ROW > peel) { \
    for (Index i = 0; i < accRows; i++) { \
      accZero##sum.packet[i] += accZero##peel.packet[i]; \
    } \
  }

#define MICRO_ADD_PEEL_ROW \
  MICRO_ADD_PEEL(4, 0) MICRO_ADD_PEEL(5, 1) MICRO_ADD_PEEL(6, 2) MICRO_ADD_PEEL(7, 3) \
  MICRO_ADD_PEEL(2, 0) MICRO_ADD_PEEL(3, 1) MICRO_ADD_PEEL(1, 0)

template<typename Scalar, typename Packet, typename Index, const Index accRows, const Index remaining_rows>
EIGEN_ALWAYS_INLINE void MICRO_EXTRA_ROW(
  const Scalar* &lhs_ptr,
  const Scalar* &rhs_ptr,
  PacketBlock<Packet,accRows> &accZero)
{
  Packet rhsV[4];
  pbroadcastN<Packet,accRows>(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
  pger<accRows, Scalar, Packet, false>(&accZero, lhs_ptr, rhsV);
  lhs_ptr += remaining_rows;
  rhs_ptr += accRows;
}

template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols, const Index remaining_rows>
EIGEN_ALWAYS_INLINE void gemm_unrolled_row_iteration(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index offsetA,
  Index row,
  Index col,
  Index rows,
  Index cols,
  const Packet& pAlpha,
  const Packet& pMask)
{
  const Scalar* rhs_ptr = rhs_base;
  const Scalar* lhs_ptr = lhs_base + row*strideA + remaining_rows*offsetA;
  PacketBlock<Packet,accRows> accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7, acc;

  bsetzero<Scalar, Packet, accRows>(accZero0);

  Index remaining_depth = (col + quad_traits<Scalar>::rows < cols) ? depth : (depth & -quad_traits<Scalar>::rows);
  Index k = 0;
  if (remaining_depth >= PEEL_ROW) {
    MICRO_ZERO_PEEL_ROW
    do
    {
      EIGEN_POWER_PREFETCH(rhs_ptr);
      EIGEN_POWER_PREFETCH(lhs_ptr);
      MICRO_WORK_PEEL_ROW
    } while ((k += PEEL_ROW) + PEEL_ROW <= remaining_depth);
    MICRO_ADD_PEEL_ROW
  }
  for(; k < remaining_depth; k++)
  {
    MICRO_EXTRA_ROW<Scalar, Packet, Index, accRows, remaining_rows>(lhs_ptr, rhs_ptr, accZero0);
  }

  if ((remaining_depth == depth) && (rows >= accCols))
  {
    bload<DataMapper, Packet, Index, 0, ColMajor, false, accRows>(acc, res, row, 0);
    bscale<Packet,accRows>(acc, accZero0, pAlpha, pMask);
    res.template storePacketBlock<Packet,accRows>(row, 0, acc);
  } else {
    for(; k < depth; k++)
    {
      Packet rhsV[4];
      pbroadcastN<Packet,accRows>(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
      pger<accRows, Scalar, Packet, Index, false, remaining_rows>(&accZero0, lhs_ptr, rhsV);
      lhs_ptr += remaining_rows;
      rhs_ptr += accRows;
    }

    for(Index j = 0; j < accRows; j++) {
      accZero0.packet[j] = vec_mul(pAlpha, accZero0.packet[j]);
      for(Index i = 0; i < remaining_rows; i++) {
        res(row + i, j) += accZero0.packet[j][i];
      }
    }
  }
}

template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols>
EIGEN_ALWAYS_INLINE void gemm_extra_row(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index offsetA,
  Index row,
  Index col,
  Index rows,
  Index cols,
  Index remaining_rows,
  const Packet& pAlpha,
  const Packet& pMask)
{
  switch(remaining_rows) {
    case 1:
      gemm_unrolled_row_iteration<Scalar, Packet, DataMapper, Index, accRows, accCols, 1>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, pAlpha, pMask);
      break;
    case 2:
      if (sizeof(Scalar) == sizeof(float)) {
        gemm_unrolled_row_iteration<Scalar, Packet, DataMapper, Index, accRows, accCols, 2>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, pAlpha, pMask);
      }
      break;
    default:
      if (sizeof(Scalar) == sizeof(float)) {
        gemm_unrolled_row_iteration<Scalar, Packet, DataMapper, Index, accRows, accCols, 3>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, pAlpha, pMask);
      }
      break;
  }
}
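
// Only float needs the remaining_rows == 2 and 3 instantiations: a float panel is four
// rows wide so up to three rows can be left over, while the double mask (mask21 above)
// only ever covers a single leftover row.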

#define MICRO_UNROLL(func) \
  func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7)

#define MICRO_UNROLL_WORK(func, func2, peel) \
  MICRO_UNROLL(func2); \
  func(0,peel) func(1,peel) func(2,peel) func(3,peel) \
  func(4,peel) func(5,peel) func(6,peel) func(7,peel)

#define MICRO_LOAD_ONE(iter) \
  if (unroll_factor > iter) { \
    lhsV##iter = ploadLhs<Scalar, Packet>(lhs_ptr##iter); \
    lhs_ptr##iter += accCols; \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhsV##iter); \
  }

#define MICRO_WORK_ONE(iter, peel) \
  if (unroll_factor > iter) { \
    pger_common<Packet, false, accRows>(&accZero##iter, lhsV##iter, rhsV##peel); \
  }

#define MICRO_TYPE_PEEL4(func, func2, peel) \
  if (PEEL > peel) { \
    Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
    pbroadcastN<Packet,accRows>(rhs_ptr + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
    MICRO_UNROLL_WORK(func, func2, peel) \
  } else { \
    EIGEN_UNUSED_VARIABLE(rhsV##peel); \
  }

#define MICRO_UNROLL_TYPE_PEEL(M, func, func1, func2) \
  Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M], rhsV4[M], rhsV5[M], rhsV6[M], rhsV7[M]; \
  func(func1,func2,0); func(func1,func2,1); \
  func(func1,func2,2); func(func1,func2,3); \
  func(func1,func2,4); func(func1,func2,5); \
  func(func1,func2,6); func(func1,func2,7);

#define MICRO_UNROLL_TYPE_ONE(M, func, func1, func2) \
  Packet rhsV0[M]; \
  func(func1,func2,0);

#define MICRO_ONE_PEEL4 \
  MICRO_UNROLL_TYPE_PEEL(4, MICRO_TYPE_PEEL4, MICRO_WORK_ONE, MICRO_LOAD_ONE); \
  rhs_ptr += (accRows * PEEL);

#define MICRO_ONE4 \
  MICRO_UNROLL_TYPE_ONE(4, MICRO_TYPE_PEEL4, MICRO_WORK_ONE, MICRO_LOAD_ONE); \
  rhs_ptr += accRows;

#define MICRO_DST_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    bsetzero<Scalar, Packet, accRows>(accZero##iter); \
  } else { \
    EIGEN_UNUSED_VARIABLE(accZero##iter); \
  }

#define MICRO_DST_PTR MICRO_UNROLL(MICRO_DST_PTR_ONE)

#define MICRO_SRC_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    lhs_ptr##iter = lhs_base + ( (row/accCols) + iter )*strideA*accCols; \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhs_ptr##iter); \
  }

#define MICRO_SRC_PTR MICRO_UNROLL(MICRO_SRC_PTR_ONE)

#define MICRO_PREFETCH_ONE(iter) \
  if (unroll_factor > iter) { \
    EIGEN_POWER_PREFETCH(lhs_ptr##iter); \
  }

#define MICRO_PREFETCH MICRO_UNROLL(MICRO_PREFETCH_ONE)

#define MICRO_STORE_ONE(iter) \
  if (unroll_factor > iter) { \
    bload<DataMapper, Packet, Index, 0, ColMajor, false, accRows>(acc, res, row + iter*accCols, 0); \
    bscale<Packet,accRows>(acc, accZero##iter, pAlpha); \
    res.template storePacketBlock<Packet,accRows>(row + iter*accCols, 0, acc); \
  }

#define MICRO_STORE MICRO_UNROLL(MICRO_STORE_ONE)

template<int unroll_factor, typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols>
EIGEN_STRONG_INLINE void gemm_unrolled_iteration(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index& row,
  const Packet& pAlpha)
{
  const Scalar* rhs_ptr = rhs_base;
  const Scalar* lhs_ptr0 = NULL, * lhs_ptr1 = NULL, * lhs_ptr2 = NULL, * lhs_ptr3 = NULL, * lhs_ptr4 = NULL, * lhs_ptr5 = NULL, * lhs_ptr6 = NULL, * lhs_ptr7 = NULL;
  PacketBlock<Packet,accRows> accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;
  PacketBlock<Packet,accRows> acc;

  MICRO_SRC_PTR
  MICRO_DST_PTR

  Index k = 0;
  for(; k + PEEL <= depth; k+= PEEL)
  {
    EIGEN_POWER_PREFETCH(rhs_ptr);
    MICRO_PREFETCH
    MICRO_ONE_PEEL4
  }
  for(; k < depth; k++)
  {
    MICRO_ONE4
  }
  MICRO_STORE

  row += unroll_factor*accCols;
}
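
// unroll_factor selects how many accCols-wide row panels a single call processes; the
// switch in gemm_cols below picks the largest instantiation that still fits the rows.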

template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols>
EIGEN_ALWAYS_INLINE void gemm_cols(
  const DataMapper& res,
  const Scalar* blockA,
  const Scalar* blockB,
  Index depth,
  Index strideA,
  Index offsetA,
  Index strideB,
  Index offsetB,
  Index col,
  Index rows,
  Index cols,
  Index remaining_rows,
  const Packet& pAlpha,
  const Packet& pMask)
{
  const DataMapper res3 = res.getSubMapper(0, col);

  const Scalar* rhs_base = blockB + col*strideB + accRows*offsetB;
  const Scalar* lhs_base = blockA + accCols*offsetA;
  Index row = 0;

#define MAX_UNROLL 6
  while(row + MAX_UNROLL*accCols <= rows) {
    gemm_unrolled_iteration<MAX_UNROLL, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
  }
  switch( (rows-row)/accCols ) {
#if MAX_UNROLL > 7
    case 7:
      gemm_unrolled_iteration<7, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
      break;
#endif
#if MAX_UNROLL > 6
    case 6:
      gemm_unrolled_iteration<6, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
      break;
#endif
#if MAX_UNROLL > 5
    case 5:
      gemm_unrolled_iteration<5, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
      break;
#endif
#if MAX_UNROLL > 4
    case 4:
      gemm_unrolled_iteration<4, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
      break;
#endif
#if MAX_UNROLL > 3
    case 3:
      gemm_unrolled_iteration<3, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
      break;
#endif
#if MAX_UNROLL > 2
    case 2:
      gemm_unrolled_iteration<2, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
      break;
#endif
#if MAX_UNROLL > 1
    case 1:
      gemm_unrolled_iteration<1, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
      break;
#endif
    default:
      break;
  }
#undef MAX_UNROLL

  if(remaining_rows > 0)
  {
    gemm_extra_row<Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, blockA, rhs_base, depth, strideA, offsetA, row, col, rows, cols, remaining_rows, pAlpha, pMask);
  }
}

template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accCols>
EIGEN_STRONG_INLINE void gemm_extra_cols(
  const DataMapper& res,
  const Scalar* blockA,
  const Scalar* blockB,
  Index depth,
  Index strideA,
  Index offsetA,
  Index strideB,
  Index offsetB,
  Index col,
  Index rows,
  Index cols,
  Index remaining_rows,
  const Packet& pAlpha,
  const Packet& pMask)
{
  for (; col < cols; col++) {
    gemm_cols<Scalar, Packet, DataMapper, Index, 1, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask);
  }
}

/****************
 * GEMM kernels *
 ****************/
template<typename Scalar, typename Index, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols>
EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{
  const Index remaining_rows = rows % accCols;

  if( strideA == -1 ) strideA = depth;
  if( strideB == -1 ) strideB = depth;

  const Packet pAlpha = pset1<Packet>(alpha);
  const Packet pMask = bmask<Packet>((const int)(remaining_rows));

  Index col = 0;
  for(; col + accRows <= cols; col += accRows)
  {
    gemm_cols<Scalar, Packet, DataMapper, Index, accRows, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask);
  }

  gemm_extra_cols<Scalar, Packet, DataMapper, Index, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask);
}

#define accColsC (accCols / 2)
#define advanceRows ((LhsIsReal) ? 1 : 2)
#define advanceCols ((RhsIsReal) ? 1 : 2)
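
// advanceRows/advanceCols account for the packed real and imaginary planes: a complex
// operand occupies two strides in the packed buffer, a real one only a single stride.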

// PEEL_COMPLEX loop factor.
#define PEEL_COMPLEX 3
#define PEEL_COMPLEX_ROW 3

#define MICRO_COMPLEX_UNROLL_PEEL(func) \
  func(0) func(1) func(2) func(3)

#define MICRO_COMPLEX_ZERO_PEEL(peel) \
  if ((PEEL_COMPLEX_ROW > peel) && (peel != 0)) { \
    bsetzero<Scalar, Packet, accRows>(accReal##peel); \
    bsetzero<Scalar, Packet, accRows>(accImag##peel); \
  } else { \
    EIGEN_UNUSED_VARIABLE(accReal##peel); \
    EIGEN_UNUSED_VARIABLE(accImag##peel); \
  }

#define MICRO_COMPLEX_ZERO_PEEL_ROW \
  MICRO_COMPLEX_UNROLL_PEEL(MICRO_COMPLEX_ZERO_PEEL);

#define MICRO_COMPLEX_WORK_PEEL(peel) \
  if (PEEL_COMPLEX_ROW > peel) { \
    pbroadcastN_old<Packet,accRows>(rhs_ptr_real + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
    if(!RhsIsReal) pbroadcastN_old<Packet,accRows>(rhs_ptr_imag + (accRows * peel), rhsVi##peel[0], rhsVi##peel[1], rhsVi##peel[2], rhsVi##peel[3]); \
    pgerc<accRows, Scalar, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##peel, &accImag##peel, lhs_ptr_real + (remaining_rows * peel), lhs_ptr_imag + (remaining_rows * peel), rhsV##peel, rhsVi##peel); \
  } else { \
    EIGEN_UNUSED_VARIABLE(rhsV##peel); \
    EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
  }

#define MICRO_COMPLEX_WORK_PEEL_ROW \
  Packet rhsV0[4], rhsV1[4], rhsV2[4], rhsV3[4]; \
  Packet rhsVi0[4], rhsVi1[4], rhsVi2[4], rhsVi3[4]; \
  MICRO_COMPLEX_UNROLL_PEEL(MICRO_COMPLEX_WORK_PEEL); \
  lhs_ptr_real += (remaining_rows * PEEL_COMPLEX_ROW); \
  if(!LhsIsReal) lhs_ptr_imag += (remaining_rows * PEEL_COMPLEX_ROW); \
  else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag); \
  rhs_ptr_real += (accRows * PEEL_COMPLEX_ROW); \
  if(!RhsIsReal) rhs_ptr_imag += (accRows * PEEL_COMPLEX_ROW); \
  else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);

#define MICRO_COMPLEX_ADD_PEEL(peel, sum) \
  if (PEEL_COMPLEX_ROW > peel) { \
    for (Index i = 0; i < accRows; i++) { \
      accReal##sum.packet[i] += accReal##peel.packet[i]; \
      accImag##sum.packet[i] += accImag##peel.packet[i]; \
    } \
  }

#define MICRO_COMPLEX_ADD_PEEL_ROW \
  MICRO_COMPLEX_ADD_PEEL(2, 0) MICRO_COMPLEX_ADD_PEEL(3, 1) \
  MICRO_COMPLEX_ADD_PEEL(1, 0)
1827template<typename Scalar, typename Packet, typename Index, const Index accRows, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal, const Index remaining_rows>
1828EIGEN_ALWAYS_INLINE void MICRO_COMPLEX_EXTRA_ROW(
1829 const Scalar* &lhs_ptr_real, const Scalar* &lhs_ptr_imag,
1830 const Scalar* &rhs_ptr_real, const Scalar* &rhs_ptr_imag,
1831 PacketBlock<Packet,accRows> &accReal, PacketBlock<Packet,accRows> &accImag)
1832{
1833 Packet rhsV[4], rhsVi[4];
1834 pbroadcastN_old<Packet,accRows>(rhs_ptr_real, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
1835 if(!RhsIsReal) pbroadcastN_old<Packet,accRows>(rhs_ptr_imag, rhsVi[0], rhsVi[1], rhsVi[2], rhsVi[3]);
1836 pgerc<accRows, Scalar, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal, &accImag, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi);
1837 lhs_ptr_real += remaining_rows;
1838 if(!LhsIsReal) lhs_ptr_imag += remaining_rows;
1839 else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
1840 rhs_ptr_real += accRows;
1841 if(!RhsIsReal) rhs_ptr_imag += accRows;
1842 else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
1843}
1844
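// Computes the leftover rows (remaining_rows < accCols) of one column panel.
// When a full-depth pass is possible and at least accCols rows exist, results
// are written back with a masked packet store; otherwise the depth tail and
// the writeback fall through to the scalar/partial-store path.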
1845template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal, const Index remaining_rows>
1846EIGEN_ALWAYS_INLINE void gemm_unrolled_complex_row_iteration(
1847 const DataMapper& res,
1848 const Scalar* lhs_base,
1849 const Scalar* rhs_base,
1850 Index depth,
1851 Index strideA,
1852 Index offsetA,
1853 Index strideB,
1854 Index row,
1855 Index col,
1856 Index rows,
1857 Index cols,
1858 const Packet& pAlphaReal,
1859 const Packet& pAlphaImag,
1860 const Packet& pMask)
1861{
1862 const Scalar* rhs_ptr_real = rhs_base;
1863 const Scalar* rhs_ptr_imag = NULL;
1864 if(!RhsIsReal) rhs_ptr_imag = rhs_base + accRows*strideB;
1865 else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
1866 const Scalar* lhs_ptr_real = lhs_base + advanceRows*row*strideA + remaining_rows*offsetA;
1867 const Scalar* lhs_ptr_imag = NULL;
1868 if(!LhsIsReal) lhs_ptr_imag = lhs_ptr_real + remaining_rows*strideA;
1869 else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
1870 PacketBlock<Packet,accRows> accReal0, accImag0, accReal1, accImag1, accReal2, accImag2, accReal3, accImag3;
1871 PacketBlock<Packet,accRows> taccReal, taccImag;
1872 PacketBlock<Packetc,accRows> acc0, acc1;
1873 PacketBlock<Packetc,accRows*2> tRes;
1874
1875 bsetzero<Scalar, Packet, accRows>(accReal0);
1876 bsetzero<Scalar, Packet, accRows>(accImag0);
1877
1878 Index remaining_depth = (col + quad_traits<Scalar>::rows < cols) ? depth : (depth & -quad_traits<Scalar>::rows);
1879 Index k = 0;
1880 if (remaining_depth >= PEEL_COMPLEX_ROW) {
1881 MICRO_COMPLEX_ZERO_PEEL_ROW
1882 do
1883 {
1884 EIGEN_POWER_PREFETCH(rhs_ptr_real);
1885 if(!RhsIsReal) {
1886 EIGEN_POWER_PREFETCH(rhs_ptr_imag);
1887 }
1888 EIGEN_POWER_PREFETCH(lhs_ptr_real);
1889 if(!LhsIsReal) {
1890 EIGEN_POWER_PREFETCH(lhs_ptr_imag);
1891 }
1892 MICRO_COMPLEX_WORK_PEEL_ROW
1893 } while ((k += PEEL_COMPLEX_ROW) + PEEL_COMPLEX_ROW <= remaining_depth);
1894 MICRO_COMPLEX_ADD_PEEL_ROW
1895 }
1896 for(; k < remaining_depth; k++)
1897 {
1898 MICRO_COMPLEX_EXTRA_ROW<Scalar, Packet, Index, accRows, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal, remaining_rows>(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real, rhs_ptr_imag, accReal0, accImag0);
1899 }
1900
1901 if ((remaining_depth == depth) && (rows >= accCols))
1902 {
1903 bload<DataMapper, Packetc, Index, accColsC, ColMajor, true, accRows>(tRes, res, row, 0);
1904 bscalec<Packet,accRows>(accReal0, accImag0, pAlphaReal, pAlphaImag, taccReal, taccImag, pMask);
1905 bcouple<Packet, Packetc, accRows>(taccReal, taccImag, tRes, acc0, acc1);
1906 res.template storePacketBlock<Packetc,accRows>(row + 0, 0, acc0);
1907 res.template storePacketBlock<Packetc,accRows>(row + accColsC, 0, acc1);
1908 } else {
1909 for(; k < depth; k++)
1910 {
1911 Packet rhsV[4], rhsVi[4];
1912 pbroadcastN_old<Packet,accRows>(rhs_ptr_real, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
1913 if(!RhsIsReal) pbroadcastN_old<Packet,accRows>(rhs_ptr_imag, rhsVi[0], rhsVi[1], rhsVi[2], rhsVi[3]);
1914 pgerc<accRows, Scalar, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal0, &accImag0, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi);
1915 lhs_ptr_real += remaining_rows;
1916 if(!LhsIsReal) lhs_ptr_imag += remaining_rows;
1917 rhs_ptr_real += accRows;
1918 if(!RhsIsReal) rhs_ptr_imag += accRows;
1919 }
1920
1921 bscalec<Packet,accRows>(accReal0, accImag0, pAlphaReal, pAlphaImag, taccReal, taccImag);
1922 bcouple_common<Packet, Packetc, accRows>(taccReal, taccImag, acc0, acc1);
1923
1924 if ((sizeof(Scalar) == sizeof(float)) && (remaining_rows == 1))
1925 {
1926 for(Index j = 0; j < accRows; j++) {
1927 res(row + 0, j) += pfirst<Packetc>(acc0.packet[j]);
1928 }
1929 } else {
1930 for(Index j = 0; j < accRows; j++) {
1931 PacketBlock<Packetc,1> acc2;
1932 acc2.packet[0] = res.template loadPacket<Packetc>(row + 0, j) + acc0.packet[j];
1933 res.template storePacketBlock<Packetc,1>(row + 0, j, acc2);
1934 if(remaining_rows > accColsC) {
1935 res(row + accColsC, j) += pfirst<Packetc>(acc1.packet[j]);
1936 }
1937 }
1938 }
1939 }
1940}
1941
1942template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
1943EIGEN_ALWAYS_INLINE void gemm_complex_extra_row(
1944 const DataMapper& res,
1945 const Scalar* lhs_base,
1946 const Scalar* rhs_base,
1947 Index depth,
1948 Index strideA,
1949 Index offsetA,
1950 Index strideB,
1951 Index row,
1952 Index col,
1953 Index rows,
1954 Index cols,
1955 Index remaining_rows,
1956 const Packet& pAlphaReal,
1957 const Packet& pAlphaImag,
1958 const Packet& pMask)
1959{
1960 switch(remaining_rows) {
1961 case 1:
1962 gemm_unrolled_complex_row_iteration<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal, 1>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, pAlphaReal, pAlphaImag, pMask);
1963 break;
1964 case 2:
1965 if (sizeof(Scalar) == sizeof(float)) {
1966 gemm_unrolled_complex_row_iteration<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal, 2>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, pAlphaReal, pAlphaImag, pMask);
1967 }
1968 break;
1969 default:
1970 if (sizeof(Scalar) == sizeof(float)) {
1971 gemm_unrolled_complex_row_iteration<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal, 3>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, pAlphaReal, pAlphaImag, pMask);
1972 }
1973 break;
1974 }
1975}
1976
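// The macros below form the register-blocked complex micro-kernel: up to four
// row panels (iter 0-3) of accCols rows each are multiplied against accRows
// broadcast columns, with every real/imaginary accumulator pair kept in
// vector registers for the whole depth loop.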
1977#define MICRO_COMPLEX_UNROLL(func) \
1978 func(0) func(1) func(2) func(3)
1979
1980#define MICRO_COMPLEX_UNROLL_WORK(func, func2, peel) \
1981 MICRO_COMPLEX_UNROLL(func2); \
1982 func(0,peel) func(1,peel) func(2,peel) func(3,peel)
1983
1984#define MICRO_COMPLEX_LOAD_ONE(iter) \
1985 if (unroll_factor > iter) { \
1986 lhsV##iter = ploadLhs<Scalar, Packet>(lhs_ptr_real##iter); \
1987 if(!LhsIsReal) { \
1988 lhsVi##iter = ploadLhs<Scalar, Packet>(lhs_ptr_real##iter + imag_delta); \
1989 } else { \
1990 EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
1991 } \
1992 lhs_ptr_real##iter += accCols; \
1993 } else { \
1994 EIGEN_UNUSED_VARIABLE(lhsV##iter); \
1995 EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
1996 }
1997
1998#define MICRO_COMPLEX_WORK_ONE4(iter, peel) \
1999 if (unroll_factor > iter) { \
2000 pgerc_common<accRows, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \
2001 }
2002
2003#define MICRO_COMPLEX_TYPE_PEEL4(func, func2, peel) \
2004 if (PEEL_COMPLEX > peel) { \
2005 Packet lhsV0, lhsV1, lhsV2, lhsV3; \
2006 Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3; \
2007 pbroadcastN_old<Packet,accRows>(rhs_ptr_real + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
2008 if(!RhsIsReal) { \
2009 pbroadcastN_old<Packet,accRows>(rhs_ptr_imag + (accRows * peel), rhsVi##peel[0], rhsVi##peel[1], rhsVi##peel[2], rhsVi##peel[3]); \
2010 } else { \
2011 EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
2012 } \
2013 MICRO_COMPLEX_UNROLL_WORK(func, func2, peel) \
2014 } else { \
2015 EIGEN_UNUSED_VARIABLE(rhsV##peel); \
2016 EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
2017 }
2018
2019#define MICRO_COMPLEX_UNROLL_TYPE_PEEL(M, func, func1, func2) \
2020 Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M]; \
2021 Packet rhsVi0[M], rhsVi1[M], rhsVi2[M], rhsVi3[M]; \
2022 func(func1,func2,0); func(func1,func2,1); \
2023 func(func1,func2,2); func(func1,func2,3);
2024
2025#define MICRO_COMPLEX_UNROLL_TYPE_ONE(M, func, func1, func2) \
2026 Packet rhsV0[M], rhsVi0[M];\
2027 func(func1,func2,0);
2028
2029#define MICRO_COMPLEX_ONE_PEEL4 \
2030 MICRO_COMPLEX_UNROLL_TYPE_PEEL(4, MICRO_COMPLEX_TYPE_PEEL4, MICRO_COMPLEX_WORK_ONE4, MICRO_COMPLEX_LOAD_ONE); \
2031 rhs_ptr_real += (accRows * PEEL_COMPLEX); \
2032 if(!RhsIsReal) rhs_ptr_imag += (accRows * PEEL_COMPLEX);
2033
2034#define MICRO_COMPLEX_ONE4 \
2035 MICRO_COMPLEX_UNROLL_TYPE_ONE(4, MICRO_COMPLEX_TYPE_PEEL4, MICRO_COMPLEX_WORK_ONE4, MICRO_COMPLEX_LOAD_ONE); \
2036 rhs_ptr_real += accRows; \
2037 if(!RhsIsReal) rhs_ptr_imag += accRows;
2038
2039#define MICRO_COMPLEX_DST_PTR_ONE(iter) \
2040 if (unroll_factor > iter) { \
2041 bsetzero<Scalar, Packet, accRows>(accReal##iter); \
2042 bsetzero<Scalar, Packet, accRows>(accImag##iter); \
2043 } else { \
2044 EIGEN_UNUSED_VARIABLE(accReal##iter); \
2045 EIGEN_UNUSED_VARIABLE(accImag##iter); \
2046 }
2047
2048#define MICRO_COMPLEX_DST_PTR MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_DST_PTR_ONE)
2049
2050#define MICRO_COMPLEX_SRC_PTR_ONE(iter) \
2051 if (unroll_factor > iter) { \
2052 lhs_ptr_real##iter = lhs_base + ( ((advanceRows*row)/accCols) + iter*advanceRows )*strideA*accCols; \
2053 } else { \
2054 EIGEN_UNUSED_VARIABLE(lhs_ptr_real##iter); \
2055 }
2056
2057#define MICRO_COMPLEX_SRC_PTR MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_SRC_PTR_ONE)
2058
2059#define MICRO_COMPLEX_PREFETCH_ONE(iter) \
2060 if (unroll_factor > iter) { \
2061 EIGEN_POWER_PREFETCH(lhs_ptr_real##iter); \
2062 }
2063
2064#define MICRO_COMPLEX_PREFETCH MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_PREFETCH_ONE)
2065
2066#define MICRO_COMPLEX_STORE_ONE(iter) \
2067 if (unroll_factor > iter) { \
2068 bload<DataMapper, Packetc, Index, accColsC, ColMajor, true, accRows>(tRes, res, row + iter*accCols, 0); \
2069 bscalec<Packet,accRows>(accReal##iter, accImag##iter, pAlphaReal, pAlphaImag, taccReal, taccImag); \
2070 bcouple<Packet, Packetc, accRows>(taccReal, taccImag, tRes, acc0, acc1); \
2071 res.template storePacketBlock<Packetc,accRows>(row + iter*accCols + 0, 0, acc0); \
2072 res.template storePacketBlock<Packetc,accRows>(row + iter*accCols + accColsC, 0, acc1); \
2073 }
2074
2075#define MICRO_COMPLEX_STORE MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_STORE_ONE)
2076
2077template<int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
2078EIGEN_STRONG_INLINE void gemm_complex_unrolled_iteration(
2079 const DataMapper& res,
2080 const Scalar* lhs_base,
2081 const Scalar* rhs_base,
2082 Index depth,
2083 Index strideA,
2084 Index strideB,
2085 Index& row,
2086 const Packet& pAlphaReal,
2087 const Packet& pAlphaImag)
2088{
2089 const Scalar* rhs_ptr_real = rhs_base;
2090 const Scalar* rhs_ptr_imag = NULL;
2091 const Index imag_delta = accCols*strideA;
2092 if(!RhsIsReal) {
2093 rhs_ptr_imag = rhs_base + accRows*strideB;
2094 } else {
2095 EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
2096 }
2097 const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_real1 = NULL;
2098 const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_real3 = NULL;
2099 PacketBlock<Packet,accRows> accReal0, accImag0, accReal1, accImag1;
2100 PacketBlock<Packet,accRows> accReal2, accImag2, accReal3, accImag3;
2101 PacketBlock<Packet,accRows> taccReal, taccImag;
2102 PacketBlock<Packetc,accRows> acc0, acc1;
2103 PacketBlock<Packetc,accRows*2> tRes;
2104
2105 MICRO_COMPLEX_SRC_PTR
2106 MICRO_COMPLEX_DST_PTR
2107
2108 Index k = 0;
2109 for(; k + PEEL_COMPLEX <= depth; k+= PEEL_COMPLEX)
2110 {
2111 EIGEN_POWER_PREFETCH(rhs_ptr_real);
2112 if(!RhsIsReal) {
2113 EIGEN_POWER_PREFETCH(rhs_ptr_imag);
2114 }
2115 MICRO_COMPLEX_PREFETCH
2116 MICRO_COMPLEX_ONE_PEEL4
2117 }
2118 for(; k < depth; k++)
2119 {
2120 MICRO_COMPLEX_ONE4
2121 }
2122 MICRO_COMPLEX_STORE
2123
2124 row += unroll_factor*accCols;
2125}
2126
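// Processes one block of accRows columns: full multiples of accCols rows run
// through the unrolled micro-kernel (largest unroll factor first), and any
// leftover rows are finished by gemm_complex_extra_row.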
2127template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
2128EIGEN_ALWAYS_INLINE void gemm_complex_cols(
2129 const DataMapper& res,
2130 const Scalar* blockA,
2131 const Scalar* blockB,
2132 Index depth,
2133 Index strideA,
2134 Index offsetA,
2135 Index strideB,
2136 Index offsetB,
2137 Index col,
2138 Index rows,
2139 Index cols,
2140 Index remaining_rows,
2141 const Packet& pAlphaReal,
2142 const Packet& pAlphaImag,
2143 const Packet& pMask)
2144{
2145 const DataMapper res3 = res.getSubMapper(0, col);
2146
2147 const Scalar* rhs_base = blockB + advanceCols*col*strideB + accRows*offsetB;
2148 const Scalar* lhs_base = blockA + accCols*offsetA;
2149 Index row = 0;
2150
2151#define MAX_COMPLEX_UNROLL 3
2152 while(row + MAX_COMPLEX_UNROLL*accCols <= rows) {
2153 gemm_complex_unrolled_iteration<MAX_COMPLEX_UNROLL, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
2154 }
2155 switch( (rows-row)/accCols ) {
2156#if MAX_COMPLEX_UNROLL > 4
2157 case 4:
2158 gemm_complex_unrolled_iteration<4, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
2159 break;
2160#endif
2161#if MAX_COMPLEX_UNROLL > 3
2162 case 3:
2163 gemm_complex_unrolled_iteration<3, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
2164 break;
2165#endif
2166#if MAX_COMPLEX_UNROLL > 2
2167 case 2:
2168 gemm_complex_unrolled_iteration<2, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
2169 break;
2170#endif
2171#if MAX_COMPLEX_UNROLL > 1
2172 case 1:
2173 gemm_complex_unrolled_iteration<1, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
2174 break;
2175#endif
2176 default:
2177 break;
2178 }
2179#undef MAX_COMPLEX_UNROLL
2180
2181 if(remaining_rows > 0)
2182 {
2183 gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
2184 }
2185}
2186
2187template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
2188EIGEN_STRONG_INLINE void gemm_complex_extra_cols(
2189 const DataMapper& res,
2190 const Scalar* blockA,
2191 const Scalar* blockB,
2192 Index depth,
2193 Index strideA,
2194 Index offsetA,
2195 Index strideB,
2196 Index offsetB,
2197 Index col,
2198 Index rows,
2199 Index cols,
2200 Index remaining_rows,
2201 const Packet& pAlphaReal,
2202 const Packet& pAlphaImag,
2203 const Packet& pMask)
2204{
2205 for (; col < cols; col++) {
2206 gemm_complex_cols<Scalar, Packet, Packetc, DataMapper, Index, 1, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
2207 }
2208}
2209
2210template<typename LhsScalar, typename RhsScalar, typename Scalarc, typename Scalar, typename Index, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
2211EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
2212{
2213 const Index remaining_rows = rows % accCols;
2214
2215 if( strideA == -1 ) strideA = depth;
2216 if( strideB == -1 ) strideB = depth;
2217
2218 const Packet pAlphaReal = pset1<Packet>(alpha.real());
2219 const Packet pAlphaImag = pset1<Packet>(alpha.imag());
2220 const Packet pMask = bmask<Packet>((const int)(remaining_rows));
2221
2222 const Scalar* blockA = (Scalar *) blockAc;
2223 const Scalar* blockB = (Scalar *) blockBc;
2224
2225 Index col = 0;
2226 for(; col + accRows <= cols; col += accRows)
2227 {
2228 gemm_complex_cols<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
2229 }
2230
2231 gemm_complex_extra_cols<Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
2232}
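For orientation, here is a minimal usage sketch (an illustration only, relying on nothing beyond Eigen's public API): on ppc64le, an ordinary product like the one below is blocked, packed by the gemm_pack_lhs/gemm_pack_rhs specializations that follow, and executed by the gebp_kernel specializations at the end of this file.

#include <Eigen/Dense>
#include <complex>

int main()
{
  // Any sizes work; 64x64 simply exercises the blocked GEMM path.
  Eigen::MatrixXcf A = Eigen::MatrixXcf::Random(64, 64);
  Eigen::MatrixXcf B = Eigen::MatrixXcf::Random(64, 64);
  const std::complex<float> alpha(0.5f, -1.0f);
  // Evaluated through gemm_complex (or gemm_complexMMA on Power10).
  Eigen::MatrixXcf C = alpha * A * B;
  return C.rows() == 64 ? 0 : 1;
}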
2233
2234#undef accColsC
2235#undef advanceCols
2236#undef advanceRows
2237
2238/************************************
2239 * ppc64le template specializations *
2240 * **********************************/
2241template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
2242struct gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
2243{
2244 void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
2245};
2246
2247template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
2248void gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
2249 ::operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
2250{
2251 dhs_pack<double, Index, DataMapper, Packet2d, ColMajor, PanelMode, true> pack;
2252 pack(blockA, lhs, depth, rows, stride, offset);
2253}
2254
2255template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
2256struct gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
2257{
2258 void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
2259};
2260
2261template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
2262void gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
2263 ::operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
2264{
2265 dhs_pack<double, Index, DataMapper, Packet2d, RowMajor, PanelMode, true> pack;
2266 pack(blockA, lhs, depth, rows, stride, offset);
2267}
2268
2269#if EIGEN_ALTIVEC_USE_CUSTOM_PACK
2270template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
2271struct gemm_pack_rhs<double, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
2272{
2273 void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
2274};
2275
2276template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
2277void gemm_pack_rhs<double, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
2278 ::operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
2279{
2280 dhs_pack<double, Index, DataMapper, Packet2d, ColMajor, PanelMode, false> pack;
2281 pack(blockB, rhs, depth, cols, stride, offset);
2282}
2283
2284template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
2285struct gemm_pack_rhs<double, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
2286{
2287 void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
2288};
2289
2290template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
2291void gemm_pack_rhs<double, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
2292 ::operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
2293{
2294 dhs_pack<double, Index, DataMapper, Packet2d, RowMajor, PanelMode, false> pack;
2295 pack(blockB, rhs, depth, cols, stride, offset);
2296}
2297#endif
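// Note: EIGEN_ALTIVEC_USE_CUSTOM_PACK only gates the real-valued RHS packers;
// when it is 0, Eigen's generic gemm_pack_rhs is used for real matrices, while
// the LHS and complex packers are always provided.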
2298
2299template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
2300struct gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
2301{
2302 void operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
2303};
2304
2305template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
2306void gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
2307 ::operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
2308{
2309 dhs_pack<float, Index, DataMapper, Packet4f, ColMajor, PanelMode, true> pack;
2310 pack(blockA, lhs, depth, rows, stride, offset);
2311}
2312
2313template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
2314struct gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
2315{
2316 void operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
2317};
2318
2319template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
2320void gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
2321 ::operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
2322{
2323 dhs_pack<float, Index, DataMapper, Packet4f, RowMajor, PanelMode, true> pack;
2324 pack(blockA, lhs, depth, rows, stride, offset);
2325}
2326
2327template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
2328struct gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
2329{
2330 void operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
2331};
2332
2333template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
2334void gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
2335 ::operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
2336{
2337 dhs_cpack<float, Index, DataMapper, Packet4f, Packet2cf, ColMajor, Conjugate, PanelMode, true> pack;
2338 pack(blockA, lhs, depth, rows, stride, offset);
2339}
2340
2341template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
2342struct gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
2343{
2344 void operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
2345};
2346
2347template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
2348void gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
2349 ::operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
2350{
2351 dhs_cpack<float, Index, DataMapper, Packet4f, Packet2cf, RowMajor, Conjugate, PanelMode, true> pack;
2352 pack(blockA, lhs, depth, rows, stride, offset);
2353}
2354
2355#if EIGEN_ALTIVEC_USE_CUSTOM_PACK
2356template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
2357struct gemm_pack_rhs<float, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
2358{
2359 void operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
2360};
2361
2362template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
2363void gemm_pack_rhs<float, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
2364 ::operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
2365{
2366 dhs_pack<float, Index, DataMapper, Packet4f, ColMajor, PanelMode, false> pack;
2367 pack(blockB, rhs, depth, cols, stride, offset);
2368}
2369
2370template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
2371struct gemm_pack_rhs<float, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
2372{
2373 void operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
2374};
2375
2376template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
2377void gemm_pack_rhs<float, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
2378 ::operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
2379{
2380 dhs_pack<float, Index, DataMapper, Packet4f, RowMajor, PanelMode, false> pack;
2381 pack(blockB, rhs, depth, cols, stride, offset);
2382}
2383#endif
2384
2385template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
2386struct gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
2387{
2388 void operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
2389};
2390
2391template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
2392void gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
2393 ::operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
2394{
2395 dhs_cpack<float, Index, DataMapper, Packet4f, Packet2cf, ColMajor, Conjugate, PanelMode, false> pack;
2396 pack(blockB, rhs, depth, cols, stride, offset);
2397}
2398
2399template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
2400struct gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
2401{
2402 void operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
2403};
2404
2405template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
2406void gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
2407 ::operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
2408{
2409 dhs_cpack<float, Index, DataMapper, Packet4f, Packet2cf, RowMajor, Conjugate, PanelMode, false> pack;
2410 pack(blockB, rhs, depth, cols, stride, offset);
2411}
2412
2413template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
2414struct gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
2415{
2416 void operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
2417};
2418
2419template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
2420void gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
2421 ::operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
2422{
2423 dhs_cpack<double, Index, DataMapper, Packet2d, Packet1cd, ColMajor, Conjugate, PanelMode, true> pack;
2424 pack(blockA, lhs, depth, rows, stride, offset);
2425}
2426
2427template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
2428struct gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
2429{
2430 void operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
2431};
2432
2433template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
2434void gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
2435 ::operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
2436{
2437 dhs_cpack<double, Index, DataMapper, Packet2d, Packet1cd, RowMajor, Conjugate, PanelMode, true> pack;
2438 pack(blockA, lhs, depth, rows, stride, offset);
2439}
2440
2441template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
2442struct gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
2443{
2444 void operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
2445};
2446
2447template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
2448void gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
2449 ::operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
2450{
2451 dhs_cpack<double, Index, DataMapper, Packet2d, Packet1cd, ColMajor, Conjugate, PanelMode, false> pack;
2452 pack(blockB, rhs, depth, cols, stride, offset);
2453}
2454
2455template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
2456struct gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
2457{
2458 void operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
2459};
2460
2461template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
2462void gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
2463 ::operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
2464{
2465 dhs_cpack<double, Index, DataMapper, Packet2d, Packet1cd, RowMajor, Conjugate, PanelMode, false> pack;
2466 pack(blockB, rhs, depth, cols, stride, offset);
2467}
2468
2469// ********* gebp specializations *********
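// Every kernel below resolves gemm_function the same way: an MMA-only build
// always takes the MMA path; with dynamic dispatch it checks
// __builtin_cpu_supports("arch_3_1") and ("mma") at runtime; otherwise it
// falls back to the plain VSX implementation.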
2470template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
2471struct gebp_kernel<float, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
2472{
2473 typedef typename quad_traits<float>::vectortype Packet;
2474 typedef typename quad_traits<float>::rhstype RhsPacket;
2475
2476 void operator()(const DataMapper& res, const float* blockA, const float* blockB,
2477 Index rows, Index depth, Index cols, float alpha,
2478 Index strideA, Index strideB, Index offsetA, Index offsetB);
2479};
2480
2481template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
2482void gebp_kernel<float, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
2483 ::operator()(const DataMapper& res, const float* blockA, const float* blockB,
2484 Index rows, Index depth, Index cols, float alpha,
2485 Index strideA, Index strideB, Index offsetA, Index offsetB)
2486 {
2487 const Index accRows = quad_traits<float>::rows;
2488 const Index accCols = quad_traits<float>::size;
2489 void (*gemm_function)(const DataMapper&, const float*, const float*, Index, Index, Index, float, Index, Index, Index, Index);
2490
2491 #if defined(EIGEN_ALTIVEC_MMA_ONLY)
2492 //generate with MMA only
2493 gemm_function = &Eigen::internal::gemmMMA<float, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
2494 #elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
2495 if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
2496 gemm_function = &Eigen::internal::gemmMMA<float, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
2497 }
2498 else{
2499 gemm_function = &Eigen::internal::gemm<float, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
2500 }
2501 #else
2502 gemm_function = &Eigen::internal::gemm<float, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
2503 #endif
2504 gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
2505 }
2506
2507template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
2508struct gebp_kernel<std::complex<float>, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
2509{
2510 typedef Packet4f Packet;
2511 typedef Packet2cf Packetc;
2512 typedef Packet4f RhsPacket;
2513
2514 void operator()(const DataMapper& res, const std::complex<float>* blockA, const std::complex<float>* blockB,
2515 Index rows, Index depth, Index cols, std::complex<float> alpha,
2516 Index strideA, Index strideB, Index offsetA, Index offsetB);
2517};
2518
2519template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
2520void gebp_kernel<std::complex<float>, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
2521 ::operator()(const DataMapper& res, const std::complex<float>* blockA, const std::complex<float>* blockB,
2522 Index rows, Index depth, Index cols, std::complex<float> alpha,
2523 Index strideA, Index strideB, Index offsetA, Index offsetB)
2524 {
2525 const Index accRows = quad_traits<float>::rows;
2526 const Index accCols = quad_traits<float>::size;
2527 void (*gemm_function)(const DataMapper&, const std::complex<float>*, const std::complex<float>*,
2528 Index, Index, Index, std::complex<float>, Index, Index, Index, Index);
2529
2530 #if defined(EIGEN_ALTIVEC_MMA_ONLY)
2531 //generate with MMA only
2532 gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
2533 #elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
2534 if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
2535 gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
2536 }
2537 else{
2538 gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
2539 }
2540 #else
2541 gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
2542 #endif
2543 gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
2544 }
2545
2546template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
2547struct gebp_kernel<float, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
2548{
2549 typedef Packet4f Packet;
2550 typedef Packet2cf Packetc;
2551 typedef Packet4f RhsPacket;
2552
2553 void operator()(const DataMapper& res, const float* blockA, const std::complex<float>* blockB,
2554 Index rows, Index depth, Index cols, std::complex<float> alpha,
2555 Index strideA, Index strideB, Index offsetA, Index offsetB);
2556};
2557
2558template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
2559void gebp_kernel<float, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
2560 ::operator()(const DataMapper& res, const float* blockA, const std::complex<float>* blockB,
2561 Index rows, Index depth, Index cols, std::complex<float> alpha,
2562 Index strideA, Index strideB, Index offsetA, Index offsetB)
2563 {
2564 const Index accRows = quad_traits<float>::rows;
2565 const Index accCols = quad_traits<float>::size;
2566 void (*gemm_function)(const DataMapper&, const float*, const std::complex<float>*,
2567 Index, Index, Index, std::complex<float>, Index, Index, Index, Index);
2568 #if defined(EIGEN_ALTIVEC_MMA_ONLY)
2569 //generate with MMA only
2570 gemm_function = &Eigen::internal::gemm_complexMMA<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
2571 #elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
2572 if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
2573 gemm_function = &Eigen::internal::gemm_complexMMA<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
2574 }
2575 else{
2576 gemm_function = &Eigen::internal::gemm_complex<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
2577 }
2578 #else
2579 gemm_function = &Eigen::internal::gemm_complex<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
2580 #endif
2581 gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
2582 }
2583
2584template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
2585struct gebp_kernel<std::complex<float>, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
2586{
2587 typedef Packet4f Packet;
2588 typedef Packet2cf Packetc;
2589 typedef Packet4f RhsPacket;
2590
2591 void operator()(const DataMapper& res, const std::complex<float>* blockA, const float* blockB,
2592 Index rows, Index depth, Index cols, std::complex<float> alpha,
2593 Index strideA, Index strideB, Index offsetA, Index offsetB);
2594};
2595
2596template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
2597void gebp_kernel<std::complex<float>, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
2598 ::operator()(const DataMapper& res, const std::complex<float>* blockA, const float* blockB,
2599 Index rows, Index depth, Index cols, std::complex<float> alpha,
2600 Index strideA, Index strideB, Index offsetA, Index offsetB)
2601 {
2602 const Index accRows = quad_traits<float>::rows;
2603 const Index accCols = quad_traits<float>::size;
2604 void (*gemm_function)(const DataMapper&, const std::complex<float>*, const float*,
2605 Index, Index, Index, std::complex<float>, Index, Index, Index, Index);
2606 #if defined(EIGEN_ALTIVEC_MMA_ONLY)
2607 //generate with MMA only
2608 gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
2609 #elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
2610 if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
2611 gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
2612 }
2613 else{
2614 gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
2615 }
2616 #else
2617 gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
2618 #endif
2619 gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
2620 }
2621
2622template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
2623struct gebp_kernel<double, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
2624{
2625 typedef typename quad_traits<double>::vectortype Packet;
2626 typedef typename quad_traits<double>::rhstype RhsPacket;
2627
2628 void operator()(const DataMapper& res, const double* blockA, const double* blockB,
2629 Index rows, Index depth, Index cols, double alpha,
2630 Index strideA, Index strideB, Index offsetA, Index offsetB);
2631};
2632
2633template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
2634void gebp_kernel<double, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
2635 ::operator()(const DataMapper& res, const double* blockA, const double* blockB,
2636 Index rows, Index depth, Index cols, double alpha,
2637 Index strideA, Index strideB, Index offsetA, Index offsetB)
2638 {
2639 const Index accRows = quad_traits<double>::rows;
2640 const Index accCols = quad_traits<double>::size;
2641 void (*gemm_function)(const DataMapper&, const double*, const double*, Index, Index, Index, double, Index, Index, Index, Index);
2642
2643 #if defined(EIGEN_ALTIVEC_MMA_ONLY)
2644 //generate with MMA only
2645 gemm_function = &Eigen::internal::gemmMMA<double, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
2646 #elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
2647 if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
2648 gemm_function = &Eigen::internal::gemmMMA<double, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
2649 }
2650 else{
2651 gemm_function = &Eigen::internal::gemm<double, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
2652 }
2653 #else
2654 gemm_function = &Eigen::internal::gemm<double, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
2655 #endif
2656 gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
2657 }
2658
2659template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
2660struct gebp_kernel<std::complex<double>, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
2661{
2662 typedef quad_traits<double>::vectortype Packet;
2663 typedef Packet1cd Packetc;
2664 typedef quad_traits<double>::rhstype RhsPacket;
2665
2666 void operator()(const DataMapper& res, const std::complex<double>* blockA, const std::complex<double>* blockB,
2667 Index rows, Index depth, Index cols, std::complex<double> alpha,
2668 Index strideA, Index strideB, Index offsetA, Index offsetB);
2669};
2670
2671template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
2672void gebp_kernel<std::complex<double>, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
2673 ::operator()(const DataMapper& res, const std::complex<double>* blockA, const std::complex<double>* blockB,
2674 Index rows, Index depth, Index cols, std::complex<double> alpha,
2675 Index strideA, Index strideB, Index offsetA, Index offsetB)
2676 {
2677 const Index accRows = quad_traits<double>::rows;
2678 const Index accCols = quad_traits<double>::size;
2679 void (*gemm_function)(const DataMapper&, const std::complex<double>*, const std::complex<double>*,
2680 Index, Index, Index, std::complex<double>, Index, Index, Index, Index);
2681 #if defined(EIGEN_ALTIVEC_MMA_ONLY)
2682 //generate with MMA only
2683 gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
2684 #elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
2685 if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
2686 gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
2687 }
2688 else{
2689 gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
2690 }
2691 #else
2692 gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
2693 #endif
2694 gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
2695 }
2696
2697template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
2698struct gebp_kernel<std::complex<double>, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
2699{
2700 typedef quad_traits<double>::vectortype Packet;
2701 typedef Packet1cd Packetc;
2702 typedef quad_traits<double>::rhstype RhsPacket;
2703
2704 void operator()(const DataMapper& res, const std::complex<double>* blockA, const double* blockB,
2705 Index rows, Index depth, Index cols, std::complex<double> alpha,
2706 Index strideA, Index strideB, Index offsetA, Index offsetB);
2707};
2708
2709template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
2710void gebp_kernel<std::complex<double>, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
2711 ::operator()(const DataMapper& res, const std::complex<double>* blockA, const double* blockB,
2712 Index rows, Index depth, Index cols, std::complex<double> alpha,
2713 Index strideA, Index strideB, Index offsetA, Index offsetB)
2714 {
2715 const Index accRows = quad_traits<double>::rows;
2716 const Index accCols = quad_traits<double>::size;
2717 void (*gemm_function)(const DataMapper&, const std::complex<double>*, const double*,
2718 Index, Index, Index, std::complex<double>, Index, Index, Index, Index);
2719 #if defined(EIGEN_ALTIVEC_MMA_ONLY)
2720 //generate with MMA only
2721 gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
2722 #elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
2723 if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
2724 gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
2725 }
2726 else{
2727 gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
2728 }
2729 #else
2730 gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
2731 #endif
2732 gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
2733 }
2734
2735template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
2736struct gebp_kernel<double, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
2737{
2738 typedef quad_traits<double>::vectortype Packet;
2739 typedef Packet1cd Packetc;
2740 typedef quad_traits<double>::rhstype RhsPacket;
2741
2742 void operator()(const DataMapper& res, const double* blockA, const std::complex<double>* blockB,
2743 Index rows, Index depth, Index cols, std::complex<double> alpha,
2744 Index strideA, Index strideB, Index offsetA, Index offsetB);
2745};
2746
2747template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
2748void gebp_kernel<double, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
2749 ::operator()(const DataMapper& res, const double* blockA, const std::complex<double>* blockB,
2750 Index rows, Index depth, Index cols, std::complex<double> alpha,
2751 Index strideA, Index strideB, Index offsetA, Index offsetB)
2752 {
2753 const Index accRows = quad_traits<double>::rows;
2754 const Index accCols = quad_traits<double>::size;
2755 void (*gemm_function)(const DataMapper&, const double*, const std::complex<double>*,
2756 Index, Index, Index, std::complex<double>, Index, Index, Index, Index);
2757 #if defined(EIGEN_ALTIVEC_MMA_ONLY)
2758 //generate with MMA only
2759 gemm_function = &Eigen::internal::gemm_complexMMA<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
2760 #elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
2761 if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
2762 gemm_function = &Eigen::internal::gemm_complexMMA<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
2763 }
2764 else{
2765 gemm_function = &Eigen::internal::gemm_complex<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
2766 }
2767 #else
2768 gemm_function = &Eigen::internal::gemm_complex<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
2769 #endif
2770 gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
2771 }
2772} // end namespace internal
2773
2774} // end namespace Eigen
2775
2776#endif // EIGEN_MATRIX_PRODUCT_ALTIVEC_H