Medial Code Documentation
MatrixProductMMA.h
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2020 Everton Constantino (everton.constantino@ibm.com)
// Copyright (C) 2021 Chip Kerchner (chip.kerchner@ibm.com)
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
#define EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H

// If using dynamic dispatch, set the CPU target.
#if defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
#pragma GCC push_options
#pragma GCC target("cpu=power10,htm")
#endif

#ifdef __has_builtin
#if !__has_builtin(__builtin_vsx_assemble_pair)
#define __builtin_vsx_assemble_pair __builtin_mma_assemble_pair
#endif
#endif

namespace Eigen {

namespace internal {

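// Zero an MMA accumulator (__vector_quad) before it is used to accumulate outer products.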
template<typename Scalar, typename Packet>
EIGEN_ALWAYS_INLINE void bsetzeroMMA(__vector_quad* acc)
{
  __builtin_mma_xxsetaccz(acc);
}

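// Write one accumulator back to memory: the accumulator is disassembled into four packets,
// scaled by alpha and combined with the values already present in the result block, which is
// then stored through the DataMapper.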
template<typename DataMapper, typename Index, typename Packet, const Index accCols>
EIGEN_ALWAYS_INLINE void storeAccumulator(Index i, const DataMapper& data, const Packet& alpha, __vector_quad* acc)
{
  PacketBlock<Packet, 4> result;
  __builtin_mma_disassemble_acc(&result.packet, acc);

  PacketBlock<Packet, 4> tRes;
  bload<DataMapper, Packet, Index, accCols, ColMajor, false, 4>(tRes, data, i, 0);

  bscale<Packet, 4>(tRes, result, alpha);

  data.template storePacketBlock<Packet, 4>(i, 0, tRes);
}

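// Complex counterpart of storeAccumulator: the real and imaginary accumulators are scaled by
// the complex alpha (alphaReal, alphaImag), coupled back into interleaved complex packets and
// stored as two halves of the result block.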
template<typename DataMapper, typename Index, typename Packet, typename Packetc, const Index accColsC>
EIGEN_ALWAYS_INLINE void storeComplexAccumulator(Index i, const DataMapper& data, const Packet& alphaReal, const Packet& alphaImag, __vector_quad* accReal, __vector_quad* accImag)
{
  PacketBlock<Packet, 4> resultReal, resultImag;
  __builtin_mma_disassemble_acc(&resultReal.packet, accReal);
  __builtin_mma_disassemble_acc(&resultImag.packet, accImag);

  PacketBlock<Packetc, 8> tRes;
  bload<DataMapper, Packetc, Index, accColsC, ColMajor, true, 4>(tRes, data, i, 0);

  PacketBlock<Packet,4> taccReal, taccImag;
  bscalec<Packet,4>(resultReal, resultImag, alphaReal, alphaImag, taccReal, taccImag);

  PacketBlock<Packetc, 4> acc1, acc2;
  bcouple<Packet, Packetc, 4>(taccReal, taccImag, tRes, acc1, acc2);

  data.template storePacketBlock<Packetc, 4>(i, 0, acc1);
  data.template storePacketBlock<Packetc, 4>(i + accColsC, 0, acc2);
}

// Defaults to float32; since Eigen still supports C++03, we can't use default template arguments.
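// pgerMMA issues one MMA rank-1 (outer product) update of the accumulator, adding or
// subtracting the product of the two packets depending on the NegativeAccumulate flag.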
template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const RhsPacket& a, const LhsPacket& b)
{
  if(NegativeAccumulate)
  {
    __builtin_mma_xvf32gernp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
  } else {
    __builtin_mma_xvf32gerpp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
  }
}

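// Double-precision overloads: the xvf64ger builtins take a __vector_pair as the multiplicand,
// so two Packet2d values are either reinterpreted in place (below) or passed as an already
// assembled pair (the next overload).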
template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const PacketBlock<Packet2d,2>& a, const Packet2d& b)
{
  __vector_pair* a0 = (__vector_pair *)(&a.packet[0]);
  if(NegativeAccumulate)
  {
    __builtin_mma_xvf64gernp(acc, *a0, (__vector unsigned char)b);
  } else {
    __builtin_mma_xvf64gerpp(acc, *a0, (__vector unsigned char)b);
  }
}

template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const __vector_pair& a, const Packet2d& b)
{
  if(NegativeAccumulate)
  {
    __builtin_mma_xvf64gernp(acc, (__vector_pair)a, (__vector unsigned char)b);
  } else {
    __builtin_mma_xvf64gerpp(acc, (__vector_pair)a, (__vector unsigned char)b);
  }
}

template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad*, const __vector_pair&, const Packet4f&)
{
  // Just for compilation
}

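// Complex rank-1 update built from real ger operations: the product (lhsV + i*lhsVi) * (rhsV + i*rhsVi)
// is expanded into separate real and imaginary accumulators, and conjugation of either operand is
// handled by flipping the accumulate sign (the NegativeAccumulate template flag) of the affected terms.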
template<typename Scalar, typename Packet, typename RhsPacket, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void pgercMMA(__vector_quad* accReal, __vector_quad* accImag, const Packet& lhsV, const Packet& lhsVi, const RhsPacket& rhsV, const RhsPacket& rhsVi)
{
  pgerMMA<Packet, RhsPacket, false>(accReal, rhsV, lhsV);
  if(LhsIsReal) {
    pgerMMA<Packet, RhsPacket, ConjugateRhs>(accImag, rhsVi, lhsV);
  } else {
    if(!RhsIsReal) {
      pgerMMA<Packet, RhsPacket, ConjugateLhs == ConjugateRhs>(accReal, rhsVi, lhsVi);
      pgerMMA<Packet, RhsPacket, ConjugateRhs>(accImag, rhsVi, lhsV);
    } else {
      EIGEN_UNUSED_VARIABLE(rhsVi);
    }
    pgerMMA<Packet, RhsPacket, ConjugateLhs>(accImag, rhsV, lhsVi);
  }
}

// This is necessary because ploadRhs for double returns a pair of vectors when MMA is enabled.
template<typename Scalar, typename Packet>
EIGEN_ALWAYS_INLINE void ploadRhsMMA(const Scalar* rhs, Packet& rhsV)
{
  rhsV = ploadRhs<Scalar, Packet>(rhs);
}

template<>
EIGEN_ALWAYS_INLINE void ploadRhsMMA<double, PacketBlock<Packet2d, 2> >(const double* rhs, PacketBlock<Packet2d, 2>& rhsV)
{
  rhsV.packet[0] = ploadRhs<double, Packet2d>((const double *)((Packet2d *)rhs));
  rhsV.packet[1] = ploadRhs<double, Packet2d>((const double *)(((Packet2d *)rhs) + 1));
}

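// Load a __vector_pair worth of rhs data: Clang assembles the pair from two vector loads,
// while other compilers use a single lxvp instruction via inline assembly.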
template<>
EIGEN_ALWAYS_INLINE void ploadRhsMMA<double, __vector_pair>(const double* rhs, __vector_pair& rhsV)
{
#if EIGEN_COMP_LLVM
  __builtin_vsx_assemble_pair(&rhsV,
    (__vector unsigned char)(ploadRhs<double, Packet2d>((const double *)(((Packet2d *)rhs) + 1))),
    (__vector unsigned char)(ploadRhs<double, Packet2d>((const double *)((Packet2d *)rhs))));
#else
  __asm__ ("lxvp %x0,%1" : "=wa" (rhsV) : "Y" (*rhs));
#endif
}

template<>
EIGEN_ALWAYS_INLINE void ploadRhsMMA(const float*, __vector_pair&)
{
  // Just for compilation
}

// PEEL_MMA loop factor.
#define PEEL_MMA 7

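// The MICRO_MMA_* macros build the unrolled real micro-kernel: MICRO_MMA_UNROLL repeats a
// statement for up to eight lhs row blocks, the *_TYPE_PEEL/*_ONE_PEEL macros additionally
// unroll the depth loop by PEEL_MMA steps (loading one rhs packet per step), and the
// SRC_PTR/DST_PTR/PREFETCH/STORE variants set up pointers, zero accumulators, prefetch and
// store results for each active row block.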
#define MICRO_MMA_UNROLL(func) \
  func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7)

#define MICRO_MMA_LOAD_ONE(iter) \
  if (unroll_factor > iter) { \
    lhsV##iter = ploadLhs<Scalar, Packet>(lhs_ptr##iter); \
    lhs_ptr##iter += accCols; \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhsV##iter); \
  }

#define MICRO_MMA_WORK_ONE(iter, type, peel) \
  if (unroll_factor > iter) { \
    pgerMMA<Packet, type, false>(&accZero##iter, rhsV##peel, lhsV##iter); \
  }

#define MICRO_MMA_TYPE_PEEL(func, func2, type, peel) \
  if (PEEL_MMA > peel) { \
    Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
    ploadRhsMMA<Scalar, type>(rhs_ptr + (accRows * peel), rhsV##peel); \
    MICRO_MMA_UNROLL(func2); \
    func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) \
    func(4,type,peel) func(5,type,peel) func(6,type,peel) func(7,type,peel) \
  } else { \
    EIGEN_UNUSED_VARIABLE(rhsV##peel); \
  }

#define MICRO_MMA_UNROLL_TYPE_PEEL(func, func2, type) \
  type rhsV0, rhsV1, rhsV2, rhsV3, rhsV4, rhsV5, rhsV6, rhsV7; \
  MICRO_MMA_TYPE_PEEL(func,func2,type,0); MICRO_MMA_TYPE_PEEL(func,func2,type,1); \
  MICRO_MMA_TYPE_PEEL(func,func2,type,2); MICRO_MMA_TYPE_PEEL(func,func2,type,3); \
  MICRO_MMA_TYPE_PEEL(func,func2,type,4); MICRO_MMA_TYPE_PEEL(func,func2,type,5); \
  MICRO_MMA_TYPE_PEEL(func,func2,type,6); MICRO_MMA_TYPE_PEEL(func,func2,type,7);

#define MICRO_MMA_UNROLL_TYPE_ONE(func, func2, type) \
  type rhsV0; \
  MICRO_MMA_TYPE_PEEL(func,func2,type,0);

#define MICRO_MMA_ONE_PEEL \
  if (sizeof(Scalar) == sizeof(float)) { \
    MICRO_MMA_UNROLL_TYPE_PEEL(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, RhsPacket); \
  } else { \
    MICRO_MMA_UNROLL_TYPE_PEEL(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, __vector_pair); \
  } \
  rhs_ptr += (accRows * PEEL_MMA);

#define MICRO_MMA_ONE \
  if (sizeof(Scalar) == sizeof(float)) { \
    MICRO_MMA_UNROLL_TYPE_ONE(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, RhsPacket); \
  } else { \
    MICRO_MMA_UNROLL_TYPE_ONE(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, __vector_pair); \
  } \
  rhs_ptr += accRows;

#define MICRO_MMA_DST_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    bsetzeroMMA<Scalar, Packet>(&accZero##iter); \
  } else { \
    EIGEN_UNUSED_VARIABLE(accZero##iter); \
  }

#define MICRO_MMA_DST_PTR MICRO_MMA_UNROLL(MICRO_MMA_DST_PTR_ONE)

#define MICRO_MMA_SRC_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    lhs_ptr##iter = lhs_base + ( (row/accCols) + iter )*strideA*accCols; \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhs_ptr##iter); \
  }

#define MICRO_MMA_SRC_PTR MICRO_MMA_UNROLL(MICRO_MMA_SRC_PTR_ONE)

#define MICRO_MMA_PREFETCH_ONE(iter) \
  if (unroll_factor > iter) { \
    EIGEN_POWER_PREFETCH(lhs_ptr##iter); \
  }

#define MICRO_MMA_PREFETCH MICRO_MMA_UNROLL(MICRO_MMA_PREFETCH_ONE)

#define MICRO_MMA_STORE_ONE(iter) \
  if (unroll_factor > iter) { \
    storeAccumulator<DataMapper, Index, Packet, accCols>(row + iter*accCols, res, pAlpha, &accZero##iter); \
  }

#define MICRO_MMA_STORE MICRO_MMA_UNROLL(MICRO_MMA_STORE_ONE)

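// One unrolled pass of the real GEMM micro-kernel: processes unroll_factor row blocks of
// accCols rows against the current panel of accRows columns over the full depth, first in
// PEEL_MMA-sized chunks and then one k at a time, and finally stores the scaled accumulators.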
template<int unroll_factor, typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols>
EIGEN_ALWAYS_INLINE void gemm_unrolled_MMA_iteration(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index& row,
  const Packet& pAlpha)
{
  const Scalar* rhs_ptr = rhs_base;
  const Scalar* lhs_ptr0 = NULL, * lhs_ptr1 = NULL, * lhs_ptr2 = NULL, * lhs_ptr3 = NULL, * lhs_ptr4 = NULL, * lhs_ptr5 = NULL, * lhs_ptr6 = NULL, * lhs_ptr7 = NULL;
  __vector_quad accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;

  MICRO_MMA_SRC_PTR
  MICRO_MMA_DST_PTR

  Index k = 0;
  for(; k + PEEL_MMA <= depth; k+= PEEL_MMA)
  {
    EIGEN_POWER_PREFETCH(rhs_ptr);
    MICRO_MMA_PREFETCH
    MICRO_MMA_ONE_PEEL
  }
  for(; k < depth; k++)
  {
    MICRO_MMA_ONE
  }
  MICRO_MMA_STORE

  row += unroll_factor*accCols;
}

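// Process one panel of accRows columns: run the micro-kernel with the maximum unroll factor
// while enough rows remain, dispatch the leftover row blocks to a smaller unroll factor via
// the switch below, and hand any remaining rows (< accCols) to the generic extra-row path.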
template<typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols>
EIGEN_ALWAYS_INLINE void gemmMMA_cols(
  const DataMapper& res,
  const Scalar* blockA,
  const Scalar* blockB,
  Index depth,
  Index strideA,
  Index offsetA,
  Index strideB,
  Index offsetB,
  Index col,
  Index rows,
  Index cols,
  Index remaining_rows,
  const Packet& pAlpha,
  const Packet& pMask)
{
  const DataMapper res3 = res.getSubMapper(0, col);

  const Scalar* rhs_base = blockB + col*strideB + accRows*offsetB;
  const Scalar* lhs_base = blockA + accCols*offsetA;
  Index row = 0;

#define MAX_MMA_UNROLL 7
  while(row + MAX_MMA_UNROLL*accCols <= rows) {
    gemm_unrolled_MMA_iteration<MAX_MMA_UNROLL, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
  }
  switch( (rows-row)/accCols ) {
#if MAX_MMA_UNROLL > 7
    case 7:
      gemm_unrolled_MMA_iteration<7, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
      break;
#endif
#if MAX_MMA_UNROLL > 6
    case 6:
      gemm_unrolled_MMA_iteration<6, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
      break;
#endif
#if MAX_MMA_UNROLL > 5
    case 5:
      gemm_unrolled_MMA_iteration<5, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
      break;
#endif
#if MAX_MMA_UNROLL > 4
    case 4:
      gemm_unrolled_MMA_iteration<4, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
      break;
#endif
#if MAX_MMA_UNROLL > 3
    case 3:
      gemm_unrolled_MMA_iteration<3, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
      break;
#endif
#if MAX_MMA_UNROLL > 2
    case 2:
      gemm_unrolled_MMA_iteration<2, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
      break;
#endif
#if MAX_MMA_UNROLL > 1
    case 1:
      gemm_unrolled_MMA_iteration<1, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
      break;
#endif
    default:
      break;
  }
#undef MAX_MMA_UNROLL

  if(remaining_rows > 0)
  {
    gemm_extra_row<Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, blockA, rhs_base, depth, strideA, offsetA, row, col, rows, cols, remaining_rows, pAlpha, pMask);
  }
}

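// Top-level MMA GEMM entry point for real scalars: broadcasts alpha, builds the tail mask for
// the leftover rows, walks the output in panels of accRows columns and lets gemm_extra_cols
// handle the remaining columns.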
template<typename Scalar, typename Index, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols>
void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{
  const Index remaining_rows = rows % accCols;

  if( strideA == -1 ) strideA = depth;
  if( strideB == -1 ) strideB = depth;

  const Packet pAlpha = pset1<Packet>(alpha);
  const Packet pMask = bmask<Packet>((const int)(remaining_rows));

  Index col = 0;
  for(; col + accRows <= cols; col += accRows)
  {
    gemmMMA_cols<Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask);
  }

  gemm_extra_cols<Scalar, Packet, DataMapper, Index, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask);
}

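// Helpers for the complex kernels: accColsC is accCols expressed in complex elements (each
// complex value occupies two scalars), and advanceRows/advanceCols are 2 when the corresponding
// operand is complex, since its packed data then carries separate real and imaginary blocks.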
#define accColsC (accCols / 2)
#define advanceRows ((LhsIsReal) ? 1 : 2)
#define advanceCols ((RhsIsReal) ? 1 : 2)

// PEEL_COMPLEX_MMA loop factor.
#define PEEL_COMPLEX_MMA 3

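// The MICRO_COMPLEX_MMA_* macros mirror the real MICRO_MMA_* machinery, but unroll at most
// four row blocks, load the imaginary lhs/rhs parts alongside the real ones, and drive
// pgercMMA with a pair of accumulators per block.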
#define MICRO_COMPLEX_MMA_UNROLL(func) \
  func(0) func(1) func(2) func(3)

#define MICRO_COMPLEX_MMA_LOAD_ONE(iter) \
  if (unroll_factor > iter) { \
    lhsV##iter = ploadLhs<Scalar, Packet>(lhs_ptr_real##iter); \
    if(!LhsIsReal) { \
      lhsVi##iter = ploadLhs<Scalar, Packet>(lhs_ptr_real##iter + imag_delta); \
    } else { \
      EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
    } \
    lhs_ptr_real##iter += accCols; \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhsV##iter); \
    EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
  }

#define MICRO_COMPLEX_MMA_WORK_ONE(iter, type, peel) \
  if (unroll_factor > iter) { \
    pgercMMA<Scalar, Packet, type, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \
  }

#define MICRO_COMPLEX_MMA_TYPE_PEEL(func, func2, type, peel) \
  if (PEEL_COMPLEX_MMA > peel) { \
    Packet lhsV0, lhsV1, lhsV2, lhsV3; \
    Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3; \
    ploadRhsMMA<Scalar, type>(rhs_ptr_real + (accRows * peel), rhsV##peel); \
    if(!RhsIsReal) { \
      ploadRhsMMA<Scalar, type>(rhs_ptr_imag + (accRows * peel), rhsVi##peel); \
    } else { \
      EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
    } \
    MICRO_COMPLEX_MMA_UNROLL(func2); \
    func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) \
  } else { \
    EIGEN_UNUSED_VARIABLE(rhsV##peel); \
    EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
  }

#define MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(func, func2, type) \
  type rhsV0, rhsV1, rhsV2, rhsV3; \
  type rhsVi0, rhsVi1, rhsVi2, rhsVi3; \
  MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,0); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,1); \
  MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,2); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,3);

#define MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(func, func2, type) \
  type rhsV0, rhsVi0; \
  MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,0);

#define MICRO_COMPLEX_MMA_ONE_PEEL \
  if (sizeof(Scalar) == sizeof(float)) { \
    MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, RhsPacket); \
  } else { \
    MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, __vector_pair); \
  } \
  rhs_ptr_real += (accRows * PEEL_COMPLEX_MMA); \
  if(!RhsIsReal) rhs_ptr_imag += (accRows * PEEL_COMPLEX_MMA);

#define MICRO_COMPLEX_MMA_ONE \
  if (sizeof(Scalar) == sizeof(float)) { \
    MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, RhsPacket); \
  } else { \
    MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, __vector_pair); \
  } \
  rhs_ptr_real += accRows; \
  if(!RhsIsReal) rhs_ptr_imag += accRows;

#define MICRO_COMPLEX_MMA_DST_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    bsetzeroMMA<Scalar, Packet>(&accReal##iter); \
    bsetzeroMMA<Scalar, Packet>(&accImag##iter); \
  } else { \
    EIGEN_UNUSED_VARIABLE(accReal##iter); \
    EIGEN_UNUSED_VARIABLE(accImag##iter); \
  }

#define MICRO_COMPLEX_MMA_DST_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_DST_PTR_ONE)

#define MICRO_COMPLEX_MMA_SRC_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    lhs_ptr_real##iter = lhs_base + ( ((advanceRows*row)/accCols) + iter*advanceRows )*strideA*accCols; \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhs_ptr_real##iter); \
  }

#define MICRO_COMPLEX_MMA_SRC_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_SRC_PTR_ONE)

#define MICRO_COMPLEX_MMA_PREFETCH_ONE(iter) \
  if (unroll_factor > iter) { \
    EIGEN_POWER_PREFETCH(lhs_ptr_real##iter); \
  }

#define MICRO_COMPLEX_MMA_PREFETCH MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_PREFETCH_ONE)

#define MICRO_COMPLEX_MMA_STORE_ONE(iter) \
  if (unroll_factor > iter) { \
    storeComplexAccumulator<DataMapper, Index, Packet, Packetc, accColsC>(row + iter*accCols, res, pAlphaReal, pAlphaImag, &accReal##iter, &accImag##iter); \
  }

#define MICRO_COMPLEX_MMA_STORE MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_STORE_ONE)

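// One unrolled pass of the complex GEMM micro-kernel; same structure as
// gemm_unrolled_MMA_iteration, with separate real/imaginary rhs pointers and accumulators and
// a PEEL_COMPLEX_MMA depth-peeling factor.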
template<int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void gemm_complex_unrolled_MMA_iteration(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index strideB,
  Index& row,
  const Packet& pAlphaReal,
  const Packet& pAlphaImag)
{
  const Scalar* rhs_ptr_real = rhs_base;
  const Scalar* rhs_ptr_imag = NULL;
  const Index imag_delta = accCols*strideA;
  if(!RhsIsReal) {
    rhs_ptr_imag = rhs_base + accRows*strideB;
  } else {
    EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
  }
  const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_real1 = NULL;
  const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_real3 = NULL;
  __vector_quad accReal0, accImag0, accReal1, accImag1, accReal2, accImag2, accReal3, accImag3;

  MICRO_COMPLEX_MMA_SRC_PTR
  MICRO_COMPLEX_MMA_DST_PTR

  Index k = 0;
  for(; k + PEEL_COMPLEX_MMA <= depth; k+= PEEL_COMPLEX_MMA)
  {
    EIGEN_POWER_PREFETCH(rhs_ptr_real);
    if(!RhsIsReal) {
      EIGEN_POWER_PREFETCH(rhs_ptr_imag);
    }
    MICRO_COMPLEX_MMA_PREFETCH
    MICRO_COMPLEX_MMA_ONE_PEEL
  }
  for(; k < depth; k++)
  {
    MICRO_COMPLEX_MMA_ONE
  }
  MICRO_COMPLEX_MMA_STORE

  row += unroll_factor*accCols;
}

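// Complex analogue of gemmMMA_cols: one panel of accRows columns is processed with the largest
// fitting unroll factor, leftover row blocks go through the switch below, and remaining rows
// are handled by gemm_complex_extra_row.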
template<typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void gemmMMA_complex_cols(
  const DataMapper& res,
  const Scalar* blockA,
  const Scalar* blockB,
  Index depth,
  Index strideA,
  Index offsetA,
  Index strideB,
  Index offsetB,
  Index col,
  Index rows,
  Index cols,
  Index remaining_rows,
  const Packet& pAlphaReal,
  const Packet& pAlphaImag,
  const Packet& pMask)
{
  const DataMapper res3 = res.getSubMapper(0, col);

  const Scalar* rhs_base = blockB + advanceCols*col*strideB + accRows*offsetB;
  const Scalar* lhs_base = blockA + accCols*offsetA;
  Index row = 0;

#define MAX_COMPLEX_MMA_UNROLL 4
  while(row + MAX_COMPLEX_MMA_UNROLL*accCols <= rows) {
    gemm_complex_unrolled_MMA_iteration<MAX_COMPLEX_MMA_UNROLL, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
  }
  switch( (rows-row)/accCols ) {
#if MAX_COMPLEX_MMA_UNROLL > 4
    case 4:
      gemm_complex_unrolled_MMA_iteration<4, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
      break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 3
    case 3:
      gemm_complex_unrolled_MMA_iteration<3, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
      break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 2
    case 2:
      gemm_complex_unrolled_MMA_iteration<2, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
      break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 1
    case 1:
      gemm_complex_unrolled_MMA_iteration<1, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
      break;
#endif
    default:
      break;
  }
#undef MAX_COMPLEX_MMA_UNROLL

  if(remaining_rows > 0)
  {
    gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
  }
}

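// Top-level MMA GEMM entry point for complex scalars: splits alpha into broadcast real and
// imaginary packets, reinterprets the packed operands as their real scalar type, and walks the
// output in panels of accRows columns, with gemm_complex_extra_cols handling the rest.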
template<typename LhsScalar, typename RhsScalar, typename Scalarc, typename Scalar, typename Index, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
void gemm_complexMMA(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{
  const Index remaining_rows = rows % accCols;

  if( strideA == -1 ) strideA = depth;
  if( strideB == -1 ) strideB = depth;

  const Packet pAlphaReal = pset1<Packet>(alpha.real());
  const Packet pAlphaImag = pset1<Packet>(alpha.imag());
  const Packet pMask = bmask<Packet>((const int)(remaining_rows));

  const Scalar* blockA = (Scalar *) blockAc;
  const Scalar* blockB = (Scalar *) blockBc;

  Index col = 0;
  for(; col + accRows <= cols; col += accRows)
  {
    gemmMMA_complex_cols<Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
  }

  gemm_complex_extra_cols<Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
}

#undef accColsC
#undef advanceRows
#undef advanceCols

} // end namespace internal

} // end namespace Eigen

#if defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
#pragma GCC pop_options
#endif

#endif // EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H