33#ifndef EIGEN_TRIANGULAR_MATRIX_VECTOR_MKL_H
34#define EIGEN_TRIANGULAR_MATRIX_VECTOR_MKL_H
46template<
typename Index,
int Mode,
typename LhsScalar,
bool ConjLhs,
typename RhsScalar,
bool ConjRhs,
int StorageOrder>
50#define EIGEN_MKL_TRMV_SPECIALIZE(Scalar) \
51template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
52struct triangular_matrix_vector_product<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,ColMajor,Specialized> { \
53 static void run(Index _rows, Index _cols, const Scalar* _lhs, Index lhsStride, \
54 const Scalar* _rhs, Index rhsIncr, Scalar* _res, Index resIncr, Scalar alpha) { \
55 triangular_matrix_vector_product_trmv<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,ColMajor>::run( \
56 _rows, _cols, _lhs, lhsStride, _rhs, rhsIncr, _res, resIncr, alpha); \
59template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
60struct triangular_matrix_vector_product<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,RowMajor,Specialized> { \
61 static void run(Index _rows, Index _cols, const Scalar* _lhs, Index lhsStride, \
62 const Scalar* _rhs, Index rhsIncr, Scalar* _res, Index resIncr, Scalar alpha) { \
63 triangular_matrix_vector_product_trmv<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,RowMajor>::run( \
64 _rows, _cols, _lhs, lhsStride, _rhs, rhsIncr, _res, resIncr, alpha); \
68EIGEN_MKL_TRMV_SPECIALIZE(
double)
69EIGEN_MKL_TRMV_SPECIALIZE(
float)
74#define EIGEN_MKL_TRMV_CM(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
75template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
76struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,ColMajor> { \
78 IsLower = (Mode&Lower) == Lower, \
79 SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \
80 IsUnitDiag = (Mode&UnitDiag) ? 1 : 0, \
81 IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \
82 LowUp = IsLower ? Lower : Upper \
84 static void run(Index _rows, Index _cols, const EIGTYPE* _lhs, Index lhsStride, \
85 const EIGTYPE* _rhs, Index rhsIncr, EIGTYPE* _res, Index resIncr, EIGTYPE alpha) \
87 if (ConjLhs || IsZeroDiag) { \
88 triangular_matrix_vector_product<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,ColMajor,BuiltIn>::run( \
89 _rows, _cols, _lhs, lhsStride, _rhs, rhsIncr, _res, resIncr, alpha); \
92 Index size = (std::min)(_rows,_cols); \
93 Index rows = IsLower ? _rows : size; \
94 Index cols = IsLower ? size : _cols; \
96 typedef VectorX##EIGPREFIX VectorRhs; \
100 Map<const VectorRhs, 0, InnerStride<> > rhs(_rhs,cols,InnerStride<>(rhsIncr)); \
102 if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
107 char trans, uplo, diag; \
108 MKL_INT m, n, lda, incx, incy; \
110 MKLTYPE alpha_, beta_; \
111 assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \
112 assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(beta_, EIGTYPE(1)); \
122 uplo = IsLower ? 'L' : 'U'; \
123 diag = IsUnitDiag ? 'U' : 'N'; \
126 MKLPREFIX##trmv(&uplo, &trans, &diag, &n, (const MKLTYPE*)_lhs, &lda, (MKLTYPE*)x, &incx); \
129 MKLPREFIX##axpy(&n, &alpha_,(const MKLTYPE*)x, &incx, (MKLTYPE*)_res, &incy); \
131 if (size<(std::max)(rows,cols)) { \
132 if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
135 y = _res + size*resIncr; \
143 a = _lhs + size*lda; \
147 MKLPREFIX##gemv(&trans, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (const MKLTYPE*)x, &incx, &beta_, (MKLTYPE*)y, &incy); \
152EIGEN_MKL_TRMV_CM(
double,
double, d, d)
154EIGEN_MKL_TRMV_CM(
float,
float, f, s)
158#define EIGEN_MKL_TRMV_RM(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
159template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
160struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,RowMajor> { \
162 IsLower = (Mode&Lower) == Lower, \
163 SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \
164 IsUnitDiag = (Mode&UnitDiag) ? 1 : 0, \
165 IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \
166 LowUp = IsLower ? Lower : Upper \
168 static void run(Index _rows, Index _cols, const EIGTYPE* _lhs, Index lhsStride, \
169 const EIGTYPE* _rhs, Index rhsIncr, EIGTYPE* _res, Index resIncr, EIGTYPE alpha) \
172 triangular_matrix_vector_product<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,RowMajor,BuiltIn>::run( \
173 _rows, _cols, _lhs, lhsStride, _rhs, rhsIncr, _res, resIncr, alpha); \
176 Index size = (std::min)(_rows,_cols); \
177 Index rows = IsLower ? _rows : size; \
178 Index cols = IsLower ? size : _cols; \
180 typedef VectorX##EIGPREFIX VectorRhs; \
184 Map<const VectorRhs, 0, InnerStride<> > rhs(_rhs,cols,InnerStride<>(rhsIncr)); \
186 if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
191 char trans, uplo, diag; \
192 MKL_INT m, n, lda, incx, incy; \
194 MKLTYPE alpha_, beta_; \
195 assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \
196 assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(beta_, EIGTYPE(1)); \
205 trans = ConjLhs ? 'C' : 'T'; \
206 uplo = IsLower ? 'U' : 'L'; \
207 diag = IsUnitDiag ? 'U' : 'N'; \
210 MKLPREFIX##trmv(&uplo, &trans, &diag, &n, (const MKLTYPE*)_lhs, &lda, (MKLTYPE*)x, &incx); \
213 MKLPREFIX##axpy(&n, &alpha_,(const MKLTYPE*)x, &incx, (MKLTYPE*)_res, &incy); \
215 if (size<(std::max)(rows,cols)) { \
216 if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
219 y = _res + size*resIncr; \
220 a = _lhs + size*lda; \
231 MKLPREFIX##gemv(&trans, &n, &m, &alpha_, (const MKLTYPE*)a, &lda, (const MKLTYPE*)x, &incx, &beta_, (MKLTYPE*)y, &incy); \
236EIGEN_MKL_TRMV_RM(
double,
double, d, d)
238EIGEN_MKL_TRMV_RM(
float,
float, f, s)
Pseudo expression representing a solving operation.
Definition Solve.h:63
Definition TriangularMatrixVector_MKL.h:48
Definition TriangularMatrixVector.h:18