#ifndef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
#define EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H

#if defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
#pragma GCC push_options
#pragma GCC target("cpu=power10,htm")
#endif

#if !__has_builtin(__builtin_vsx_assemble_pair)
#define __builtin_vsx_assemble_pair __builtin_mma_assemble_pair
#endif

namespace Eigen {

namespace internal {

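// Zero an MMA accumulator (__vector_quad): four 128-bit VSX rows holding a 4x4 float32 or 4x2 float64 tile.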
template<typename Scalar, typename Packet>
EIGEN_ALWAYS_INLINE void bsetzeroMMA(__vector_quad* acc)
{
  __builtin_mma_xxsetaccz(acc);
}

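// Scale a finished accumulator tile by alpha, add the existing destination block and store it back.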
template<typename DataMapper, typename Index, typename Packet, const Index accCols>
EIGEN_ALWAYS_INLINE void storeAccumulator(Index i, const DataMapper& data, const Packet& alpha, __vector_quad* acc)
{
  PacketBlock<Packet, 4> result;
  __builtin_mma_disassemble_acc(&result.packet, acc);

  PacketBlock<Packet, 4> tRes;
  bload<DataMapper, Packet, Index, accCols, ColMajor, false, 4>(tRes, data, i, 0);

  bscale<Packet, 4>(tRes, result, alpha);

  data.template storePacketBlock<Packet, 4>(i, 0, tRes);
}

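// Complex counterpart of storeAccumulator: blends separate real and imaginary accumulators into complex results.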
template<typename DataMapper, typename Index, typename Packet, typename Packetc, const Index accColsC>
EIGEN_ALWAYS_INLINE void storeComplexAccumulator(Index i, const DataMapper& data, const Packet& alphaReal, const Packet& alphaImag, __vector_quad* accReal, __vector_quad* accImag)
{
  PacketBlock<Packet, 4> resultReal, resultImag;
  __builtin_mma_disassemble_acc(&resultReal.packet, accReal);
  __builtin_mma_disassemble_acc(&resultImag.packet, accImag);

  PacketBlock<Packetc, 8> tRes;
  bload<DataMapper, Packetc, Index, accColsC, ColMajor, true, 4>(tRes, data, i, 0);

  PacketBlock<Packet, 4> taccReal, taccImag;
  bscalec<Packet, 4>(resultReal, resultImag, alphaReal, alphaImag, taccReal, taccImag);

  PacketBlock<Packetc, 4> acc1, acc2;
  bcouple<Packet, Packetc, 4>(taccReal, taccImag, tRes, acc1, acc2);

  data.template storePacketBlock<Packetc, 4>(i, 0, acc1);
  data.template storePacketBlock<Packetc, 4>(i + accColsC, 0, acc2);
}

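// Rank-1 MMA update for float32: accumulates the outer product of a and b (xvf32gerpp) or its negation (xvf32gernp).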
template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const RhsPacket& a, const LhsPacket& b)
{
  if(NegativeAccumulate)
  {
    __builtin_mma_xvf32gernp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
  } else {
    __builtin_mma_xvf32gerpp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
  }
}

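// Rank-1 MMA update for float64, where the RHS was loaded as two Packet2d vectors and is reinterpreted as a __vector_pair.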
template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const PacketBlock<Packet2d,2>& a, const Packet2d& b)
{
  __vector_pair* a0 = (__vector_pair *)(&a.packet[0]);
  if(NegativeAccumulate)
  {
    __builtin_mma_xvf64gernp(acc, *a0, (__vector unsigned char)b);
  } else {
    __builtin_mma_xvf64gerpp(acc, *a0, (__vector unsigned char)b);
  }
}

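// Same float64 update when the RHS was loaded directly into a __vector_pair.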
template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const __vector_pair& a, const Packet2d& b)
{
  if(NegativeAccumulate)
  {
    __builtin_mma_xvf64gernp(acc, (__vector_pair)a, (__vector unsigned char)b);
  } else {
    __builtin_mma_xvf64gerpp(acc, (__vector_pair)a, (__vector unsigned char)b);
  }
}

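// A float32 GER never takes a __vector_pair operand; this overload exists only so the shared template code compiles.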
template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad*, const __vector_pair&, const Packet4f&)
{
  // Just for compilation
}

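// Complex rank-1 update: issues up to four real GER updates into separate real/imaginary accumulators,
// honoring the conjugation flags and skipping work when either operand is purely real.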
template<typename Scalar, typename Packet, typename RhsPacket, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void pgercMMA(__vector_quad* accReal, __vector_quad* accImag, const Packet& lhsV, const Packet& lhsVi, const RhsPacket& rhsV, const RhsPacket& rhsVi)
{
  pgerMMA<Packet, RhsPacket, false>(accReal, rhsV, lhsV);
  if(LhsIsReal) {
    pgerMMA<Packet, RhsPacket, ConjugateRhs>(accImag, rhsVi, lhsV);
    EIGEN_UNUSED_VARIABLE(lhsVi);
  } else {
    if(!RhsIsReal) {
      pgerMMA<Packet, RhsPacket, ConjugateLhs == ConjugateRhs>(accReal, rhsVi, lhsVi);
      pgerMMA<Packet, RhsPacket, ConjugateRhs>(accImag, rhsVi, lhsV);
    } else {
      EIGEN_UNUSED_VARIABLE(rhsVi);
    }
    pgerMMA<Packet, RhsPacket, ConjugateLhs>(accImag, rhsV, lhsVi);
  }
}

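// Load one RHS packet. For double the RHS must end up in a register pair, hence the specializations below.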
template<typename Scalar, typename Packet>
EIGEN_ALWAYS_INLINE void ploadRhsMMA(const Scalar* rhs, Packet& rhsV)
{
  rhsV = ploadRhs<Scalar, Packet>(rhs);
}

template<>
EIGEN_ALWAYS_INLINE void ploadRhsMMA<double, PacketBlock<Packet2d, 2> >(const double* rhs, PacketBlock<Packet2d, 2>& rhsV)
{
  rhsV.packet[0] = ploadRhs<double, Packet2d>((const double *)((Packet2d *)rhs      ));
  rhsV.packet[1] = ploadRhs<double, Packet2d>((const double *)(((Packet2d *)rhs) + 1));
}

template<>
EIGEN_ALWAYS_INLINE void ploadRhsMMA<double, __vector_pair>(const double* rhs, __vector_pair& rhsV)
{
#if EIGEN_COMP_LLVM
  __builtin_vsx_assemble_pair(&rhsV,
    (__vector unsigned char)(ploadRhs<double, Packet2d>((const double *)(((Packet2d *)rhs) + 1))),
    (__vector unsigned char)(ploadRhs<double, Packet2d>((const double *)((Packet2d *)rhs      ))));
#else
  __asm__ ("lxvp %x0,%1" : "=wa" (rhsV) : "Y" (*rhs));
#endif
}

template<>
EIGEN_ALWAYS_INLINE void ploadRhsMMA(const float*, __vector_pair&)
{
  // Just for compilation
}

// Peeling factor of the k loop for the real micro-kernel.
#define PEEL_MMA 7

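// Helper macros for the unrolled real micro-kernel: "iter" indexes the LHS row block (up to 8),
// "peel" indexes the peeled k iteration. Inactive iterations collapse to EIGEN_UNUSED_VARIABLE.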
#define MICRO_MMA_UNROLL(func) \
  func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7)

#define MICRO_MMA_LOAD_ONE(iter) \
  if (unroll_factor > iter) { \
    lhsV##iter = ploadLhs<Scalar, Packet>(lhs_ptr##iter); \
    lhs_ptr##iter += accCols; \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhsV##iter); \
  }

#define MICRO_MMA_WORK_ONE(iter, type, peel) \
  if (unroll_factor > iter) { \
    pgerMMA<Packet, type, false>(&accZero##iter, rhsV##peel, lhsV##iter); \
  }

#define MICRO_MMA_TYPE_PEEL(func, func2, type, peel) \
  if (PEEL_MMA > peel) { \
    Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
    ploadRhsMMA<Scalar, type>(rhs_ptr + (accRows * peel), rhsV##peel); \
    MICRO_MMA_UNROLL(func2); \
    func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) \
    func(4,type,peel) func(5,type,peel) func(6,type,peel) func(7,type,peel) \
  } else { \
    EIGEN_UNUSED_VARIABLE(rhsV##peel); \
  }

#define MICRO_MMA_UNROLL_TYPE_PEEL(func, func2, type) \
  type rhsV0, rhsV1, rhsV2, rhsV3, rhsV4, rhsV5, rhsV6, rhsV7; \
  MICRO_MMA_TYPE_PEEL(func,func2,type,0); MICRO_MMA_TYPE_PEEL(func,func2,type,1); \
  MICRO_MMA_TYPE_PEEL(func,func2,type,2); MICRO_MMA_TYPE_PEEL(func,func2,type,3); \
  MICRO_MMA_TYPE_PEEL(func,func2,type,4); MICRO_MMA_TYPE_PEEL(func,func2,type,5); \
  MICRO_MMA_TYPE_PEEL(func,func2,type,6); MICRO_MMA_TYPE_PEEL(func,func2,type,7);

#define MICRO_MMA_UNROLL_TYPE_ONE(func, func2, type) \
  type rhsV0; \
  MICRO_MMA_TYPE_PEEL(func,func2,type,0);

#define MICRO_MMA_ONE_PEEL \
  if (sizeof(Scalar) == sizeof(float)) { \
    MICRO_MMA_UNROLL_TYPE_PEEL(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, RhsPacket); \
  } else { \
    MICRO_MMA_UNROLL_TYPE_PEEL(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, __vector_pair); \
  } \
  rhs_ptr += (accRows * PEEL_MMA);

#define MICRO_MMA_ONE \
  if (sizeof(Scalar) == sizeof(float)) { \
    MICRO_MMA_UNROLL_TYPE_ONE(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, RhsPacket); \
  } else { \
    MICRO_MMA_UNROLL_TYPE_ONE(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, __vector_pair); \
  } \
  rhs_ptr += accRows;

#define MICRO_MMA_DST_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    bsetzeroMMA<Scalar, Packet>(&accZero##iter); \
  } else { \
    EIGEN_UNUSED_VARIABLE(accZero##iter); \
  }

#define MICRO_MMA_DST_PTR MICRO_MMA_UNROLL(MICRO_MMA_DST_PTR_ONE)

#define MICRO_MMA_SRC_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    lhs_ptr##iter = lhs_base + ( (row/accCols) + iter )*strideA*accCols; \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhs_ptr##iter); \
  }

#define MICRO_MMA_SRC_PTR MICRO_MMA_UNROLL(MICRO_MMA_SRC_PTR_ONE)

#define MICRO_MMA_PREFETCH_ONE(iter) \
  if (unroll_factor > iter) { \
    EIGEN_POWER_PREFETCH(lhs_ptr##iter); \
  }

#define MICRO_MMA_PREFETCH MICRO_MMA_UNROLL(MICRO_MMA_PREFETCH_ONE)

#define MICRO_MMA_STORE_ONE(iter) \
  if (unroll_factor > iter) { \
    storeAccumulator<DataMapper, Index, Packet, accCols>(row + iter*accCols, res, pAlpha, &accZero##iter); \
  }

#define MICRO_MMA_STORE MICRO_MMA_UNROLL(MICRO_MMA_STORE_ONE)

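// One unrolled GEMM iteration: accumulate unroll_factor row blocks of accCols rows against accRows
// RHS columns over the whole depth, then scale by alpha, store, and advance `row`.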
template<int unroll_factor, typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols>
EIGEN_ALWAYS_INLINE void gemm_unrolled_MMA_iteration(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index& row,
  const Packet& pAlpha)
{
  const Scalar* rhs_ptr = rhs_base;
  const Scalar* lhs_ptr0 = NULL, * lhs_ptr1 = NULL, * lhs_ptr2 = NULL, * lhs_ptr3 = NULL, * lhs_ptr4 = NULL, * lhs_ptr5 = NULL, * lhs_ptr6 = NULL, * lhs_ptr7 = NULL;
  __vector_quad accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;

  MICRO_MMA_SRC_PTR
  MICRO_MMA_DST_PTR

  Index k = 0;
  for(; k + PEEL_MMA <= depth; k+= PEEL_MMA)
  {
    EIGEN_POWER_PREFETCH(rhs_ptr);
    MICRO_MMA_PREFETCH
    MICRO_MMA_ONE_PEEL
  }
  for(; k < depth; k++)
  {
    MICRO_MMA_ONE
  }
  MICRO_MMA_STORE

  row += unroll_factor*accCols;
}

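// Process one panel of accRows columns: run the widest micro-kernel that still fits, then the
// leftover full row blocks, and finally the (masked) remaining rows.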
template<typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols>
EIGEN_ALWAYS_INLINE void gemmMMA_cols(
  const DataMapper& res,
  const Scalar* blockA,
  const Scalar* blockB,
  Index depth, Index strideA, Index offsetA, Index strideB, Index offsetB,
  Index col, Index rows, Index cols, Index remaining_rows,
  const Packet& pAlpha,
  const Packet& pMask)
{
  const DataMapper res3 = res.getSubMapper(0, col);

  const Scalar* rhs_base = blockB + col*strideB + accRows*offsetB;
  const Scalar* lhs_base = blockA + accCols*offsetA;
  Index row = 0;

#define MAX_MMA_UNROLL 7
  while(row + MAX_MMA_UNROLL*accCols <= rows) {
    gemm_unrolled_MMA_iteration<MAX_MMA_UNROLL, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
  }
  switch( (rows-row)/accCols ) {
#if MAX_MMA_UNROLL > 7
    case 7:
      gemm_unrolled_MMA_iteration<7, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
      break;
#endif
#if MAX_MMA_UNROLL > 6
    case 6:
      gemm_unrolled_MMA_iteration<6, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
      break;
#endif
#if MAX_MMA_UNROLL > 5
    case 5:
      gemm_unrolled_MMA_iteration<5, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
      break;
#endif
#if MAX_MMA_UNROLL > 4
    case 4:
      gemm_unrolled_MMA_iteration<4, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
      break;
#endif
#if MAX_MMA_UNROLL > 3
    case 3:
      gemm_unrolled_MMA_iteration<3, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
      break;
#endif
#if MAX_MMA_UNROLL > 2
    case 2:
      gemm_unrolled_MMA_iteration<2, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
      break;
#endif
#if MAX_MMA_UNROLL > 1
    case 1:
      gemm_unrolled_MMA_iteration<1, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
      break;
#endif
    default:
      break;
  }
#undef MAX_MMA_UNROLL

  if(remaining_rows > 0)
  {
    gemm_extra_row<Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, blockA, rhs_base, depth, strideA, offsetA, row, col, rows, cols, remaining_rows, pAlpha, pMask);
  }
}

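// Main MMA GEMM entry point for real scalars (float/double).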
template<typename Scalar, typename Index, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols>
void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{
  const Index remaining_rows = rows % accCols;

  if( strideA == -1 ) strideA = depth;
  if( strideB == -1 ) strideB = depth;

  const Packet pAlpha = pset1<Packet>(alpha);
  const Packet pMask  = bmask<Packet>((const int)(remaining_rows));

  Index col = 0;
  for(; col + accRows <= cols; col += accRows)
  {
    gemmMMA_cols<Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask);
  }

  gemm_extra_cols<Scalar, Packet, DataMapper, Index, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask);
}

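// Complex GEMM. accColsC counts complex elements per row block; advanceRows/advanceCols account
// for real vs. complex packed operands.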
#define accColsC (accCols / 2)
#define advanceRows ((LhsIsReal) ? 1 : 2)
#define advanceCols ((RhsIsReal) ? 1 : 2)

// Peeling factor of the k loop for the complex micro-kernel.
#define PEEL_COMPLEX_MMA 3

#define MICRO_COMPLEX_MMA_UNROLL(func) \
  func(0) func(1) func(2) func(3)

#define MICRO_COMPLEX_MMA_LOAD_ONE(iter) \
  if (unroll_factor > iter) { \
    lhsV##iter = ploadLhs<Scalar, Packet>(lhs_ptr_real##iter); \
    if(!LhsIsReal) { \
      lhsVi##iter = ploadLhs<Scalar, Packet>(lhs_ptr_real##iter + imag_delta); \
    } else { \
      EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
    } \
    lhs_ptr_real##iter += accCols; \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhsV##iter); \
    EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
  }

#define MICRO_COMPLEX_MMA_WORK_ONE(iter, type, peel) \
  if (unroll_factor > iter) { \
    pgercMMA<Scalar, Packet, type, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \
  }

#define MICRO_COMPLEX_MMA_TYPE_PEEL(func, func2, type, peel) \
  if (PEEL_COMPLEX_MMA > peel) { \
    Packet lhsV0, lhsV1, lhsV2, lhsV3; \
    Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3; \
    ploadRhsMMA<Scalar, type>(rhs_ptr_real + (accRows * peel), rhsV##peel); \
    if(!RhsIsReal) { \
      ploadRhsMMA<Scalar, type>(rhs_ptr_imag + (accRows * peel), rhsVi##peel); \
    } else { \
      EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
    } \
    MICRO_COMPLEX_MMA_UNROLL(func2); \
    func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) \
  } else { \
    EIGEN_UNUSED_VARIABLE(rhsV##peel); \
    EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
  }

#define MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(func, func2, type) \
  type rhsV0, rhsV1, rhsV2, rhsV3; \
  type rhsVi0, rhsVi1, rhsVi2, rhsVi3; \
  MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,0); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,1); \
  MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,2); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,3);

#define MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(func, func2, type) \
  type rhsV0, rhsVi0; \
  MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,0);

#define MICRO_COMPLEX_MMA_ONE_PEEL \
  if (sizeof(Scalar) == sizeof(float)) { \
    MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, RhsPacket); \
  } else { \
    MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, __vector_pair); \
  } \
  rhs_ptr_real += (accRows * PEEL_COMPLEX_MMA); \
  if(!RhsIsReal) rhs_ptr_imag += (accRows * PEEL_COMPLEX_MMA);

#define MICRO_COMPLEX_MMA_ONE \
  if (sizeof(Scalar) == sizeof(float)) { \
    MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, RhsPacket); \
  } else { \
    MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, __vector_pair); \
  } \
  rhs_ptr_real += accRows; \
  if(!RhsIsReal) rhs_ptr_imag += accRows;

#define MICRO_COMPLEX_MMA_DST_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    bsetzeroMMA<Scalar, Packet>(&accReal##iter); \
    bsetzeroMMA<Scalar, Packet>(&accImag##iter); \
  } else { \
    EIGEN_UNUSED_VARIABLE(accReal##iter); \
    EIGEN_UNUSED_VARIABLE(accImag##iter); \
  }

#define MICRO_COMPLEX_MMA_DST_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_DST_PTR_ONE)

#define MICRO_COMPLEX_MMA_SRC_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    lhs_ptr_real##iter = lhs_base + ( ((advanceRows*row)/accCols) + iter*advanceRows )*strideA*accCols; \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhs_ptr_real##iter); \
  }

#define MICRO_COMPLEX_MMA_SRC_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_SRC_PTR_ONE)

#define MICRO_COMPLEX_MMA_PREFETCH_ONE(iter) \
  if (unroll_factor > iter) { \
    EIGEN_POWER_PREFETCH(lhs_ptr_real##iter); \
  }

#define MICRO_COMPLEX_MMA_PREFETCH MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_PREFETCH_ONE)

#define MICRO_COMPLEX_MMA_STORE_ONE(iter) \
  if (unroll_factor > iter) { \
    storeComplexAccumulator<DataMapper, Index, Packet, Packetc, accColsC>(row + iter*accCols, res, pAlphaReal, pAlphaImag, &accReal##iter, &accImag##iter); \
  }

#define MICRO_COMPLEX_MMA_STORE MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_STORE_ONE)

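// One complex micro-kernel iteration over up to 4 row blocks, keeping real and imaginary
// accumulators separate until the final store.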
template<int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void gemm_complex_unrolled_MMA_iteration(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index strideB,
  Index& row,
  const Packet& pAlphaReal,
  const Packet& pAlphaImag)
{
  const Scalar* rhs_ptr_real = rhs_base;
  const Scalar* rhs_ptr_imag = NULL;
  const Index imag_delta = accCols*strideA;
  if(!RhsIsReal) {
    rhs_ptr_imag = rhs_base + accRows*strideB;
  } else {
    EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
  }
  const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_real1 = NULL;
  const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_real3 = NULL;
  __vector_quad accReal0, accImag0, accReal1, accImag1, accReal2, accImag2, accReal3, accImag3;

  MICRO_COMPLEX_MMA_SRC_PTR
  MICRO_COMPLEX_MMA_DST_PTR

  Index k = 0;
  for(; k + PEEL_COMPLEX_MMA <= depth; k+= PEEL_COMPLEX_MMA)
  {
    EIGEN_POWER_PREFETCH(rhs_ptr_real);
    if(!RhsIsReal) {
      EIGEN_POWER_PREFETCH(rhs_ptr_imag);
    }
    MICRO_COMPLEX_MMA_PREFETCH
    MICRO_COMPLEX_MMA_ONE_PEEL
  }
  for(; k < depth; k++)
  {
    MICRO_COMPLEX_MMA_ONE
  }
  MICRO_COMPLEX_MMA_STORE

  row += unroll_factor*accCols;
}

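// Process one panel of accRows complex columns, mirroring gemmMMA_cols.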
template<typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void gemmMMA_complex_cols(
  const DataMapper& res,
  const Scalar* blockA,
  const Scalar* blockB,
  Index depth, Index strideA, Index offsetA, Index strideB, Index offsetB,
  Index col, Index rows, Index cols, Index remaining_rows,
  const Packet& pAlphaReal,
  const Packet& pAlphaImag,
  const Packet& pMask)
{
  const DataMapper res3 = res.getSubMapper(0, col);

  const Scalar* rhs_base = blockB + advanceCols*col*strideB + accRows*offsetB;
  const Scalar* lhs_base = blockA + accCols*offsetA;
  Index row = 0;

#define MAX_COMPLEX_MMA_UNROLL 4
  while(row + MAX_COMPLEX_MMA_UNROLL*accCols <= rows) {
    gemm_complex_unrolled_MMA_iteration<MAX_COMPLEX_MMA_UNROLL, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
  }
  switch( (rows-row)/accCols ) {
#if MAX_COMPLEX_MMA_UNROLL > 4
    case 4:
      gemm_complex_unrolled_MMA_iteration<4, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
      break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 3
    case 3:
      gemm_complex_unrolled_MMA_iteration<3, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
      break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 2
    case 2:
      gemm_complex_unrolled_MMA_iteration<2, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
      break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 1
    case 1:
      gemm_complex_unrolled_MMA_iteration<1, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
      break;
#endif
    default:
      break;
  }
#undef MAX_COMPLEX_MMA_UNROLL

  if(remaining_rows > 0)
  {
    gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
  }
}

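// Main MMA GEMM entry point for complex scalars.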
template<typename LhsScalar, typename RhsScalar, typename Scalarc, typename Scalar, typename Index, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
void gemm_complexMMA(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{
  const Index remaining_rows = rows % accCols;

  if( strideA == -1 ) strideA = depth;
  if( strideB == -1 ) strideB = depth;

  const Packet pAlphaReal = pset1<Packet>(alpha.real());
  const Packet pAlphaImag = pset1<Packet>(alpha.imag());
  const Packet pMask = bmask<Packet>((const int)(remaining_rows));

  const Scalar* blockA = (Scalar *) blockAc;
  const Scalar* blockB = (Scalar *) blockBc;

  Index col = 0;
  for(; col + accRows <= cols; col += accRows)
  {
    gemmMMA_complex_cols<Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
  }

  gemm_complex_extra_cols<Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
}

} // end namespace internal

} // end namespace Eigen

#if defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
#pragma GCC pop_options
#endif

#endif // EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H