Medial Code Documentation
Loading...
Searching...
No Matches
TriangularMatrixVector.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2009 Gael Guennebaud <gael.guennebaud@inria.fr>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_TRIANGULARMATRIXVECTOR_H
11#define EIGEN_TRIANGULARMATRIXVECTOR_H
12
13namespace Eigen {
14
15namespace internal {
16
// Primary template of the triangular matrix * vector kernel. It is only declared;
// the StorageOrder argument selects one of the ColMajor / RowMajor partial
// specializations that follow, each of which accumulates alpha * T * rhs into res.
17template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int StorageOrder, int Version=Specialized>
18struct triangular_matrix_vector_product;
19
// Column-major specialization of the triangular matrix * vector kernel.
// run() accumulates res += alpha * T * rhs, where T is the triangular part of the
// column-major matrix lhs selected by Mode; lhs / rhs are optionally conjugated
// according to ConjLhs / ConjRhs (through conj_expr_if in run()).
20template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version>
21struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,ColMajor,Version>
22{
// NOTE(review): original line 23 is missing from this listing; it has to declare
// the ResScalar type used in run()'s signature.
24 enum {
// Flags decoded from Mode. With UnitDiag / ZeroDiag, run() never reads the stored
// diagonal of lhs: it is treated as all ones / all zeros instead.
25 IsLower = ((Mode&Lower)==Lower),
26 HasUnitDiag = (Mode & UnitDiag)==UnitDiag,
27 HasZeroDiag = (Mode & ZeroDiag)==ZeroDiag
28 };
// _rows x _cols: size of the full lhs, stored column-major with outer stride lhsStride.
// rhsIncr / resIncr: element increments of rhs and res. alpha: scale applied to the product.
29 static EIGEN_DONT_INLINE void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
30 const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const ResScalar& alpha);
31};
32
// Column-major kernel: res += alpha * T * rhs, processed in panels of PanelWidth
// columns along the diagonal. Inside a panel the small triangular block is applied
// column by column; the dense rectangle in the panel's columns (below the panel for
// Lower, above it for Upper) is delegated to general_matrix_vector_product.
33template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version>
// NOTE(review): original line 34 (return type and qualified class name preceding
// ::run) is missing from this listing.
35 ::run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
36 const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const ResScalar& alpha)
37 {
38 static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
// size = length of the diagonal. A Lower matrix needs all rows but only its first
// `size` columns (later columns lie entirely above the diagonal); an Upper matrix
// needs all columns but only its first `size` rows.
39 Index size = (std::min)(_rows,_cols);
40 Index rows = IsLower ? _rows : (std::min)(_rows,_cols);
41 Index cols = IsLower ? (std::min)(_rows,_cols) : _cols;
42
// NOTE(review): original line 43 is missing here; it presumably declares LhsMap.
44 const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
45 typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);
46
// NOTE(review): original line 47 is missing here; it presumably declares RhsMap.
48 const RhsMap rhs(_rhs,cols,InnerStride<>(rhsIncr));
49 typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);
50
// NOTE(review): original line 51 is missing here; it presumably declares ResMap.
// res is mapped without an inner stride, yet resIncr is still forwarded to the
// gemv calls below, so this kernel appears to assume resIncr == 1 -- confirm.
52 ResMap res(_res,rows);
53
// NOTE(review): original lines 54-55 are missing here; they presumably declare
// the LhsMapper / RhsMapper types used by the gemv calls below.
56
57 for (Index pi=0; pi<size; pi+=PanelWidth)
58 {
59 Index actualPanelWidth = (std::min)(PanelWidth, size-pi);
// Triangular block of the panel, one column i at a time: [s, s+r) is the range
// of rows of column i, inside the panel, on the stored side of the diagonal.
60 for (Index k=0; k<actualPanelWidth; ++k)
61 {
62 Index i = pi + k;
63 Index s = IsLower ? ((HasUnitDiag||HasZeroDiag) ? i+1 : i ) : pi;
64 Index r = IsLower ? actualPanelWidth-k : k+1;
// With an implicit diagonal the diagonal entry is excluded (for Lower, s was
// already moved past it), and the update is skipped when the range is empty.
65 if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0)
66 res.segment(s,r) += (alpha * cjRhs.coeff(i)) * cjLhs.col(i).segment(s,r);
// Contribution of the implicit unit diagonal entry (ZeroDiag contributes nothing).
67 if (HasUnitDiag)
68 res.coeffRef(i) += alpha * cjRhs.coeff(i);
69 }
// Dense rectangle in the panel's columns: the rows below the panel for Lower,
// rows [0, pi) above it for Upper.
70 Index r = IsLower ? rows - pi - actualPanelWidth : pi;
71 if (r>0)
72 {
73 Index s = IsLower ? pi+actualPanelWidth : 0;
// NOTE(review): original line 75 (the leading arguments of this call, presumably
// the block's row and column counts) is missing from this listing.
74 general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,ConjLhs,RhsScalar,RhsMapper,ConjRhs,BuiltIn>::run(
76 LhsMapper(&lhs.coeffRef(s,pi), lhsStride),
77 RhsMapper(&rhs.coeffRef(pi), rhsIncr),
78 &res.coeffRef(s), resIncr, alpha);
79 }
80 }
// Upper with more columns than rows: columns [size, cols) lie entirely above the
// diagonal and form a dense rows x (cols-size) block.
81 if((!IsLower) && cols>size)
82 {
83 general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,ConjLhs,RhsScalar,RhsMapper,ConjRhs>::run(
84 rows, cols-size,
85 LhsMapper(&lhs.coeffRef(0,size), lhsStride),
86 RhsMapper(&rhs.coeffRef(size), rhsIncr),
87 _res, resIncr, alpha);
88 }
89 }
90
// Row-major specialization of the triangular matrix * vector kernel.
// run() accumulates res += alpha * T * rhs, where T is the triangular part of the
// row-major matrix lhs selected by Mode; lhs / rhs are optionally conjugated
// according to ConjLhs / ConjRhs (through conj_expr_if in run()).
91template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs,int Version>
92struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,RowMajor,Version>
93{
// NOTE(review): original line 94 is missing from this listing; it has to declare
// the ResScalar type used in run()'s signature.
95 enum {
// Flags decoded from Mode. With UnitDiag / ZeroDiag, run() never reads the stored
// diagonal of lhs: it is treated as all ones / all zeros instead.
96 IsLower = ((Mode&Lower)==Lower),
97 HasUnitDiag = (Mode & UnitDiag)==UnitDiag,
98 HasZeroDiag = (Mode & ZeroDiag)==ZeroDiag
99 };
// _rows x _cols: size of the full lhs, stored row-major with outer stride lhsStride.
// rhsIncr / resIncr: element increments of rhs and res. alpha: scale applied to the product.
100 static EIGEN_DONT_INLINE void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
101 const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const ResScalar& alpha);
102};
103
// Row-major kernel: res += alpha * T * rhs, processed in panels of PanelWidth rows
// along the diagonal. Inside a panel each row's triangular part is applied as a dot
// product; the dense rectangle in the panel's rows (left of the panel for Lower,
// right of it for Upper) is delegated to general_matrix_vector_product.
104template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs,int Version>
// NOTE(review): original line 105 (return type and qualified class name preceding
// ::run) is missing from this listing.
106 ::run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
107 const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const ResScalar& alpha)
108 {
109 static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
// diagSize = length of the diagonal. A Lower matrix needs all rows but only its
// first diagSize columns; an Upper matrix needs all columns but only its first
// diagSize rows.
110 Index diagSize = (std::min)(_rows,_cols);
111 Index rows = IsLower ? _rows : diagSize;
112 Index cols = IsLower ? diagSize : _cols;
113
// NOTE(review): original line 114 is missing here; it presumably declares LhsMap.
115 const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
116 typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);
117
// NOTE(review): original line 118 is missing here; it presumably declares RhsMap.
// rhs is mapped without an inner stride, yet rhsIncr is still forwarded to the
// gemv calls below, so this kernel appears to assume rhsIncr == 1 (the RowMajor
// trmv_selector in this file passes 1) -- confirm.
119 const RhsMap rhs(_rhs,cols);
120 typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);
121
// NOTE(review): original line 122 is missing here; it presumably declares ResMap.
123 ResMap res(_res,rows,InnerStride<>(resIncr));
124
// NOTE(review): original lines 125-126 are missing here; they presumably declare
// the LhsMapper / RhsMapper types used by the gemv calls below.
127
128 for (Index pi=0; pi<diagSize; pi+=PanelWidth)
129 {
130 Index actualPanelWidth = (std::min)(PanelWidth, diagSize-pi);
// Triangular block of the panel, one row i at a time: [s, s+r) is the range of
// columns of row i, inside the panel, on the stored side of the diagonal.
131 for (Index k=0; k<actualPanelWidth; ++k)
132 {
133 Index i = pi + k;
134 Index s = IsLower ? pi : ((HasUnitDiag||HasZeroDiag) ? i+1 : i);
135 Index r = IsLower ? k+1 : actualPanelWidth-k;
// With an implicit diagonal the diagonal entry is excluded (for Upper, s was
// already moved past it), and the update is skipped when the range is empty.
136 if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0)
137 res.coeffRef(i) += alpha * (cjLhs.row(i).segment(s,r).cwiseProduct(cjRhs.segment(s,r).transpose())).sum();
// Contribution of the implicit unit diagonal entry (ZeroDiag contributes nothing).
138 if (HasUnitDiag)
139 res.coeffRef(i) += alpha * cjRhs.coeff(i);
140 }
// Dense rectangle in the panel's rows: columns [0, pi) left of the panel for
// Lower, the columns right of the panel for Upper.
141 Index r = IsLower ? pi : cols - pi - actualPanelWidth;
142 if (r>0)
143 {
144 Index s = IsLower ? 0 : pi + actualPanelWidth;
// NOTE(review): original line 146 (the leading arguments of this call, presumably
// the block's row and column counts) is missing from this listing.
145 general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,ConjLhs,RhsScalar,RhsMapper,ConjRhs,BuiltIn>::run(
147 LhsMapper(&lhs.coeffRef(pi,s), lhsStride),
148 RhsMapper(&rhs.coeffRef(s), rhsIncr),
149 &res.coeffRef(pi), resIncr, alpha);
150 }
151 }
// Lower with more rows than columns: rows [diagSize, rows) lie entirely below the
// diagonal and form a dense (rows-diagSize) x cols block.
152 if(IsLower && rows>diagSize)
153 {
154 general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,ConjLhs,RhsScalar,RhsMapper,ConjRhs>::run(
155 rows-diagSize, cols,
156 LhsMapper(&lhs.coeffRef(diagSize,0), lhsStride),
157 RhsMapper(&rhs.coeffRef(0), rhsIncr),
158 &res.coeffRef(diagSize), resIncr, alpha);
159 }
160 }
161
162/***************************************************************************
163* Wrapper to product_triangular_vector
164***************************************************************************/
165
// Dispatcher for triangular matrix * vector products. StorageOrder is the storage
// order of the triangular (left) operand and selects the ColMajor or RowMajor
// specialization defined further down; the primary template is only declared.
166template<int Mode,int StorageOrder>
167struct trmv_selector;
168
169} // end namespace internal
170
171namespace internal {
172
// Entry point for a triangular-matrix * vector product scaled by alpha.
// NOTE(review): original line 174 (the struct head with its specialization
// arguments) is missing from this listing. The sibling specialization below
// reduces its case to "triangular operand on the left" by transposition, so this
// is presumably that left-triangular case -- confirm.
173template<int Mode, typename Lhs, typename Rhs>
175{
// Checks that dst has shape lhs.rows() x rhs.cols() before dispatching.
176 template<typename Dest> static void run(Dest& dst, const Lhs &lhs, const Rhs &rhs, const typename Dest::Scalar& alpha)
177 {
178 eigen_assert(dst.rows()==lhs.rows() && dst.cols()==rhs.cols());
179
// NOTE(review): original line 180 (the actual dispatch, presumably to
// trmv_selector) is missing from this listing.
181 }
182};
183
// Vector * triangular-matrix case, handled by transposition: the product is
// evaluated as rhs^T * lhs^T into dstT, so that the triangular operand (rhs)
// becomes the left operand of trmv_selector. Transposing swaps the stored
// triangle, so Lower/Upper is flipped while the UnitDiag / ZeroDiag bits are kept.
// NOTE(review): original lines 185 (struct head), 191 (presumably the declaration
// of dstT, a transposed view of dst) and 193 (presumably the storage-order
// argument of trmv_selector) are missing from this listing.
184template<int Mode, typename Lhs, typename Rhs>
186{
// Checks that dst has shape lhs.rows() x rhs.cols() before dispatching.
187 template<typename Dest> static void run(Dest& dst, const Lhs &lhs, const Rhs &rhs, const typename Dest::Scalar& alpha)
188 {
189 eigen_assert(dst.rows()==lhs.rows() && dst.cols()==rhs.cols());
190
192 internal::trmv_selector<(Mode & (UnitDiag|ZeroDiag)) | ((Mode & Lower) ? Upper : Lower),
194 ::run(rhs.transpose(),lhs.transpose(), dstT, alpha);
195 }
196};
197
198} // end namespace internal
199
200namespace internal {
201
202// TODO: find a way to factorize this piece of code with gemv_selector since the logic is exactly the same.
203
// Column-major dispatcher. It uses blas_traits to unwrap lhs / rhs into their
// underlying dense operands, conjugation flags and scalar factors. It then calls
// the column-major kernel, going through a temporary result buffer when dest
// cannot be written directly.
204template<int Mode> struct trmv_selector<Mode,ColMajor>
205{
206 template<typename Lhs, typename Rhs, typename Dest>
207 static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
208 {
209 typedef typename Lhs::Scalar LhsScalar;
210 typedef typename Rhs::Scalar RhsScalar;
211 typedef typename Dest::Scalar ResScalar;
212 typedef typename Dest::RealScalar RealScalar;
213
214 typedef internal::blas_traits<Lhs> LhsBlasTraits;
215 typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
216 typedef internal::blas_traits<Rhs> RhsBlasTraits;
217 typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
218
// NOTE(review): original line 219 is missing here; it presumably declares the
// MappedDest type used below.
220
// Underlying operands; the scalar factors pulled out of lhs and rhs are folded
// into a single actualAlpha.
221 typename internal::add_const_on_value_type<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(lhs);
222 typename internal::add_const_on_value_type<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(rhs);
223
224 ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(lhs)
225 * RhsBlasTraits::extractScalarFactor(rhs);
226
227 enum {
228 // FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1
229 // On the other hand, it is good for the cache to pack the vector anyway...
230 EvalToDestAtCompileTime = Dest::InnerStrideAtCompileTime==1,
// NOTE(review): original line 231 is missing here; it presumably defines
// ComplexByReal (used in MightCannotUseDest and alphaIsCompatible).
232 MightCannotUseDest = (Dest::InnerStrideAtCompileTime!=1) || ComplexByReal
233 };
234
// NOTE(review): original line 235 is missing here; it presumably declares
// static_dest (used below).
236
// alphaIsCompatible is false only for a ComplexByReal product whose actualAlpha has
// a non-zero imaginary part. Presumably, in that case the kernel runs with unit
// alpha into a zeroed buffer and actualAlpha is applied afterwards (the branch
// conditions, original lines 251 and 272, are missing) -- confirm.
237 bool alphaIsCompatible = (!ComplexByReal) || (numext::imag(actualAlpha)==RealScalar(0));
// NOTE(review): original lines 238 and 240 are missing here; they presumably
// declare compatibleAlpha and evalToDest (both used below).
239
241
// actualDestPtr points at dest's storage when writing directly, otherwise at
// static_dest's storage (the macro may allocate a temporary -- see its definition).
242 ei_declare_aligned_stack_constructed_variable(ResScalar,actualDestPtr,dest.size(),
243 evalToDest ? dest.data() : static_dest.data());
244
// Prepare the temporary result buffer when dest is not written directly.
245 if(!evalToDest)
246 {
247 #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
248 Index size = dest.size();
// NOTE(review): original line 249 is missing from this listing.
250 #endif
// NOTE(review): original line 251 (the condition of this branch) is missing.
252 {
// Start from zero with unit alpha; actualAlpha is applied when the buffer is
// added back into dest below.
253 MappedDest(actualDestPtr, dest.size()).setZero();
254 compatibleAlpha = RhsScalar(1);
255 }
// Otherwise start from dest's current contents so that the kernel accumulates onto them.
256 else
257 MappedDest(actualDestPtr, dest.size()) = dest;
258 }
259
// NOTE(review): original lines 260 (the kernel name, presumably
// triangular_matrix_vector_product) and 268 (the trailing arguments: result
// pointer, its increment and alpha) are missing from this listing.
261 <Index,Mode,
262 LhsScalar, LhsBlasTraits::NeedToConjugate,
263 RhsScalar, RhsBlasTraits::NeedToConjugate,
264 ColMajor>
265 ::run(actualLhs.rows(),actualLhs.cols(),
266 actualLhs.data(),actualLhs.outerStride(),
267 actualRhs.data(),actualRhs.innerStride(),
269
// Write the temporary buffer back into dest.
270 if (!evalToDest)
271 {
// NOTE(review): original line 272 (the condition of this branch) is missing.
273 dest += actualAlpha * MappedDest(actualDestPtr, dest.size());
274 else
275 dest = MappedDest(actualDestPtr, dest.size());
276 }
277 }
278};
279
// Row-major dispatcher. It uses blas_traits to unwrap lhs / rhs into their
// underlying dense operands, conjugation flags and scalar factors. It then calls
// the row-major kernel with a contiguous rhs (increment 1), copying rhs into a
// temporary when its inner stride is not 1 at compile time. dest is passed with
// its own inner stride.
280template<int Mode> struct trmv_selector<Mode,RowMajor>
281{
282 template<typename Lhs, typename Rhs, typename Dest>
283 static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
284 {
285 typedef typename Lhs::Scalar LhsScalar;
286 typedef typename Rhs::Scalar RhsScalar;
287 typedef typename Dest::Scalar ResScalar;
288
289 typedef internal::blas_traits<Lhs> LhsBlasTraits;
290 typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
291 typedef internal::blas_traits<Rhs> RhsBlasTraits;
292 typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
293 typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned;
294
// NOTE(review): the ColMajor dispatcher uses add_const_on_value_type here. If
// ActualLhsType / ActualRhsType is a reference type, add_const may leave the
// referenced value non-const -- consider aligning the two.
295 typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(lhs);
296 typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(rhs);
297
// Scalar factors pulled out of lhs and rhs are folded into the kernel's alpha.
298 ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(lhs)
299 * RhsBlasTraits::extractScalarFactor(rhs);
300
// rhs can be handed to the kernel as-is only if its unit inner stride is known at compile time.
301 enum {
302 DirectlyUseRhs = ActualRhsTypeCleaned::InnerStrideAtCompileTime==1
303 };
304
// NOTE(review): original line 305 is missing here; it presumably declares
// static_rhs (used below).
306
// actualRhsPtr aliases rhs's data when it can be used directly, otherwise points at
// static_rhs's storage (the macro may allocate a temporary -- see its definition).
307 ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhsPtr,actualRhs.size(),
308 DirectlyUseRhs ? const_cast<RhsScalar*>(actualRhs.data()) : static_rhs.data());
309
310 if(!DirectlyUseRhs)
311 {
312 #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
313 Index size = actualRhs.size();
// NOTE(review): original line 314 is missing from this listing.
315 #endif
// NOTE(review): original line 316 is missing here; it presumably copies actualRhs
// into the contiguous buffer at actualRhsPtr.
317 }
318
// NOTE(review): original lines 319 (the kernel name, presumably
// triangular_matrix_vector_product) and 328 (the trailing argument, presumably
// actualAlpha) are missing from this listing.
320 <Index,Mode,
321 LhsScalar, LhsBlasTraits::NeedToConjugate,
322 RhsScalar, RhsBlasTraits::NeedToConjugate,
323 RowMajor>
324 ::run(actualLhs.rows(),actualLhs.cols(),
325 actualLhs.data(),actualLhs.outerStride(),
326 actualRhsPtr,1,
327 dest.data(),dest.innerStride(),
329 }
330};
331
332} // end namespace internal
333
334} // end namespace Eigen
335
336#endif // EIGEN_TRIANGULARMATRIXVECTOR_H
Pseudo expression representing a solving operation.
Definition Solve.h:63
@ UnitDiag
Matrix has ones on the diagonal; to be used in combination with Lower or Upper.
Definition Constants.h:208
@ ZeroDiag
Matrix has zeros on the diagonal; to be used in combination with Lower or Upper.
Definition Constants.h:210
@ Lower
View matrix as a lower triangular matrix.
Definition Constants.h:204
@ Upper
View matrix as an upper triangular matrix.
Definition Constants.h:206
@ Aligned
Definition Constants.h:235
@ ColMajor
Storage order is column major (see TopicStorageOrders).
Definition Constants.h:320
@ RowMajor
Storage order is row major (see TopicStorageOrders).
Definition Constants.h:322
const unsigned int RowMajorBit
for a matrix, this means that the storage order is row-major.
Definition Constants.h:61
Holds information about the various numeric (i.e. scalar) types allowed by Eigen.
Definition NumTraits.h:108
Definition BlasUtil.h:257
Definition BlasUtil.h:113
Definition ForwardDeclarations.h:17
Definition TriangularMatrixVector.h:18
Definition ProductEvaluators.h:656
Definition TriangularMatrixVector.h:167