Files
eigen/Eigen/src/Core/SolveTriangular.h
Gael Guennebaud 8885d56928 commit working versions of triangular solvers naturally
handling conjugated expressions. Still have to benchmark whether it
is faster (runtime and compile time) to directly call the
cache-friendly functions, hence all the commented-out pieces of code...
2009-07-09 23:59:18 +02:00

235 lines
9.8 KiB
C++

// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2008-2009 Gael Guennebaud <g.gael@free.fr>
//
// Eigen is free software; you can redistribute it and/or
// modify it under the terms of the GNU Lesser General Public
// License as published by the Free Software Foundation; either
// version 3 of the License, or (at your option) any later version.
//
// Alternatively, you can redistribute it and/or
// modify it under the terms of the GNU General Public License as
// published by the Free Software Foundation; either version 2 of
// the License, or (at your option) any later version.
//
// Eigen is distributed in the hope that it will be useful, but WITHOUT ANY
// WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
// FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License or the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public
// License and a copy of the GNU General Public License along with
// Eigen. If not, see <http://www.gnu.org/licenses/>.
#ifndef EIGEN_SOLVETRIANGULAR_H
#define EIGEN_SOLVETRIANGULAR_H
/** \internal
  * Selects the triangular solver kernel for inv(lhs) * other, writing the
  * result into \c other (see the \c run(const Lhs&, Rhs&) member of the
  * specializations).
  *
  * Template parameters:
  *  - Lhs          : type of the triangular factor
  *  - Rhs          : type of the right-hand side, overwritten with the solution
  *  - Mode         : UpperTriangularBit / LowerTriangularBit, optionally | UnitDiagBit
  *  - UpLo         : LowerTriangular, UpperTriangular, or -1 when Mode names neither;
  *                   derived from Mode (callers in this file leave it defaulted)
  *  - StorageOrder : Lhs::Flags & RowMajorBit; derived from Lhs (likewise defaulted)
  *
  * Conjugation flags (bool ConjugateLhs, bool ConjugateRhs) are planned as extra
  * parameters but are currently disabled.
  * Only the RowMajor and ColMajor partial specializations are defined.
  */
template<typename Lhs, typename Rhs,
         int Mode,
         int UpLo = (Mode & LowerTriangularBit) ? int(LowerTriangular)
                  : (Mode & UpperTriangularBit) ? int(UpperTriangular)
                  : -1,
         int StorageOrder = (int(Lhs::Flags) & RowMajorBit)>
struct ei_triangular_solver_selector;
// Row-major triangular solver: overwrites each column of `other` with inv(lhs) * column.
// A lower triangular lhs is solved by forward substitution (panels visited top-down),
// an upper triangular one by backward substitution (panels visited bottom-up).
// For every panel of PanelWidth rows, the contribution of the unknowns that are already
// solved is first removed with a single block * vector product; the panel is then
// finished row by row, each row subtracting the (non-conjugated) sum of products of
// its in-panel coefficients with the in-panel unknowns solved so far.
template<typename Lhs, typename Rhs, int Mode, /*bool ConjugateLhs, bool ConjugateRhs,*/ int UpLo>
struct ei_triangular_solver_selector<Lhs,Rhs,Mode,/*ConjugateLhs,ConjugateRhs,*/UpLo,RowMajor>
{
  typedef typename Rhs::Scalar Scalar;

  static void run(const Lhs& lhs, Rhs& other)
  {
    static const int PanelWidth = 40; // TODO make this a user definable constant
    static const bool IsLowerTriangular = (UpLo==LowerTriangular);
    const int size = lhs.cols();

    // The panel boundary starts at the top (lower) or bottom (upper) edge of the matrix
    // and moves by one panel width per iteration in the matching direction.
    const int firstBoundary = IsLowerTriangular ? 0 : size;
    const int step = IsLowerTriangular ? PanelWidth : -PanelWidth;

    for(int col=0 ; col<other.cols() ; ++col)
    {
      for(int boundary = firstBoundary;
          IsLowerTriangular ? (boundary < size) : (boundary > 0);
          boundary += step)
      {
        // The last panel may be narrower than PanelWidth.
        const int panel = std::min(IsLowerTriangular ? size - boundary : boundary, PanelWidth);
        // First row of the current panel.
        const int panelBegin = IsLowerTriangular ? boundary : boundary - panel;
        // Unknowns solved by previous panels: [0, boundary) when lower, [boundary, size) when upper.
        const int solvedBegin = IsLowerTriangular ? 0 : boundary;
        const int solvedCount = IsLowerTriangular ? boundary : size - boundary;

        if (solvedCount > 0)
        {
          // Disabled direct call to the cache-friendly kernel (pending benchmarks):
          // Block<Rhs,Dynamic,1> target(other,panelBegin,col,panel,1);
          // ei_cache_friendly_product_rowmajor_times_vector<ConjugateLhs,ConjugateRhs>(
          //   &(lhs.const_cast_derived().coeffRef(panelBegin,solvedBegin)), lhs.stride(),
          //   &(other.coeffRef(solvedBegin, col)), solvedCount,
          //   target, Scalar(-1));
          other.col(col).segment(panelBegin,panel) -=
              lhs.block(panelBegin,solvedBegin,panel,solvedCount)
            * other.col(col).segment(solvedBegin,solvedCount);
        }

        for(int k=0; k<panel; ++k)
        {
          // Row solved at this step, and first column of its in-panel coefficients
          // that multiply already-solved unknowns of this panel.
          const int row = IsLowerTriangular ? boundary+k : boundary-k-1;
          const int first = IsLowerTriangular ? boundary : row+1;
          if (k>0)
            other.coeffRef(row,col) -= ((lhs.row(row).segment(first,k).transpose())
                                         .cwise()*(other.col(col).segment(first,k))).sum();
          if(!(Mode & UnitDiagBit))
            other.coeffRef(row,col) /= lhs.coeff(row,row);
        }
      }
    }
  }
};
// Column-major triangular solver: overwrites each column of `other` (loop over
// other.cols()) with inv(lhs) * column. Implements the following configurations:
// - inv(LowerTriangular, ColMajor) * Column vector
// - inv(LowerTriangular,UnitDiag,ColMajor) * Column vector
// - inv(UpperTriangular, ColMajor) * Column vector
// - inv(UpperTriangular,UnitDiag,ColMajor) * Column vector
//
// Algorithm (column oriented, blocked): unknowns are visited in panels of
// PanelWidth rows, top-down for a lower triangular lhs (forward substitution)
// and bottom-up for an upper triangular lhs (backward substitution).
// Inside a panel, as soon as x_i is known it is eliminated from the remaining
// rows of the same panel using column i of lhs (an axpy). Once the panel is
// done, its contribution is removed from every row not visited yet with a
// single block * vector product.
template<typename Lhs, typename Rhs, int Mode, /*bool ConjugateLhs, bool ConjugateRhs,*/ int UpLo>
struct ei_triangular_solver_selector<Lhs,Rhs,Mode,/*ConjugateLhs,ConjugateRhs,*/UpLo,ColMajor>
{
typedef typename Rhs::Scalar Scalar;
// NOTE(review): Packet and PacketSize are not used within this specialization.
typedef typename ei_packet_traits<Scalar>::type Packet;
enum { PacketSize = ei_packet_traits<Scalar>::size };
static void run(const Lhs& lhs, Rhs& other)
{//std::cerr << "col maj " << ConjugateLhs << " , " << ConjugateRhs << "\n";
static const int PanelWidth = 4; // TODO make this a user definable constant
static const bool IsLowerTriangular = (UpLo==LowerTriangular);
const int size = lhs.cols();
for(int c=0 ; c<other.cols() ; ++c)
{
// pi is the panel boundary: it moves up from 0 for a lower triangular lhs
// and down from size for an upper triangular one.
for(int pi=IsLowerTriangular ? 0 : size;
IsLowerTriangular ? pi<size : pi>0;
IsLowerTriangular ? pi+=PanelWidth : pi-=PanelWidth)
{
// The last panel may be narrower than PanelWidth.
int actualPanelWidth = std::min(IsLowerTriangular ? size - pi : pi, PanelWidth);
// Rows of the current panel: [startBlock, startBlock+actualPanelWidth).
int startBlock = IsLowerTriangular ? pi : pi-actualPanelWidth;
// Lower: first row after the panel. Upper: 0, since the rows still to be
// updated after this panel are [0, startBlock).
int endBlock = IsLowerTriangular ? pi + actualPanelWidth : 0;
for(int k=0; k<actualPanelWidth; ++k)
{
// Row solved at this step.
int i = IsLowerTriangular ? pi+k : pi-k-1;
// The diagonal is not read when UnitDiagBit is set.
if(!(Mode & UnitDiagBit))
other.coeffRef(i,c) /= lhs.coeff(i,i);
int r = actualPanelWidth - k - 1; // remaining size
// r = number of rows of this panel still unsolved.
if (r>0)
{
// Eliminate x_i from those rows using column i of lhs:
// rows i+1..i+r when lower, rows i-r..i-1 when upper.
other.col(c).segment((IsLowerTriangular ? i+1 : i-r), r) -=
other.coeffRef(i,c)
* Block<Lhs,Dynamic,1>(lhs, (IsLowerTriangular ? i+1 : i-r), i, r, 1);
}
}
int r = IsLowerTriangular ? size - endBlock : startBlock; // remaining size
// r = number of rows outside the panel (not visited yet) that depend on it.
if (r > 0)
{
// Disabled direct call to the cache-friendly kernel, kept until it is
// benchmarked; ConjugateLhs/ConjugateRhs are not template parameters yet.
// ei_cache_friendly_product_colmajor_times_vector<ConjugateLhs,ConjugateRhs>(
// r,
// &(lhs.const_cast_derived().coeffRef(endBlock,startBlock)), lhs.stride(),
// other.col(c).segment(startBlock, actualPanelWidth),
// &(other.coeffRef(endBlock, c)),
// Scalar(-1));
other.col(c).segment(endBlock,r) -=
lhs.block(endBlock,startBlock,r,actualPanelWidth)
* other.col(c).segment(startBlock,actualPanelWidth);
}
}
}
}
};
/** "In-place" triangular solve: overwrites \a _rhs with inverse(\c *this) * \a _rhs.
  *
  * \nonstableyet
  *
  * \warning The parameter is only marked 'const' to make the C++ compiler accept a temporary expression here.
  * This function will const_cast it, so constness isn't honored here.
  *
  * Preconditions (asserted): \c *this is square, \a _rhs has as many rows as \c *this,
  * \c Mode does not contain \c ZeroDiagBit, and \c Mode contains either
  * \c UpperTriangularBit or \c LowerTriangularBit.
  *
  * A row-major right-hand side is solved through a column-major temporary which is then
  * copied back; any other right-hand side is solved directly.
  *
  * See TriangularView::solve() for the details.
  */
template<typename MatrixType, unsigned int Mode>
template<typename RhsDerived>
void TriangularView<MatrixType,Mode>::solveInPlace(const MatrixBase<RhsDerived>& _rhs) const
{
  RhsDerived& rhs = _rhs.const_cast_derived();

  ei_assert(cols() == rows());
  ei_assert(cols() == rhs.rows());
  ei_assert(!(Mode & ZeroDiagBit));
  ei_assert(Mode & (UpperTriangularBit|LowerTriangularBit));

  // Disabled path that would extract conjugation flags from the expressions:
  // typedef ei_product_factor_traits<MatrixType> LhsProductTraits;
  // typedef ei_product_factor_traits<RhsDerived> RhsProductTraits;
  // typedef typename LhsProductTraits::ActualXprType ActualLhsType;
  // typedef typename RhsProductTraits::ActualXprType ActualRhsType;
  // const ActualLhsType& actualLhs = LhsProductTraits::extract(_expression());
  // ActualRhsType& actualRhs = const_cast<ActualRhsType&>(RhsProductTraits::extract(rhs));

  // The solver kernels walk the right-hand side column by column (other.col(c)), so a
  // row-major rhs is copied into a column-major temporary; otherwise `workspace` is
  // simply a reference to rhs.
  enum { RhsIsRowMajor = ei_traits<RhsDerived>::Flags & RowMajorBit };
  typedef typename ei_plain_matrix_type_column_major<RhsDerived>::type ColMajorTemporary;
  typedef typename ei_meta_if<RhsIsRowMajor, ColMajorTemporary, RhsDerived&>::ret Workspace;
  typedef typename ei_unref<Workspace>::type WorkspaceType;

  Workspace workspace(rhs);
  ei_triangular_solver_selector<MatrixType, WorkspaceType,
    Mode/*, LhsProductTraits::NeedToConjugate,RhsProductTraits::NeedToConjugate*/>::run(_expression(), workspace);

  // Only a real temporary needs to be written back.
  if (RhsIsRowMajor)
    rhs = workspace;
}
/** \returns the product of the inverse of \c *this with \a rhs, \c *this being triangular.
  *
  * \nonstableyet
  *
  * This function computes the inverse-matrix matrix product inverse(\c *this) * \a rhs.
  * The matrix \c *this must be triangular and invertible, i.e., all the coefficients of the
  * diagonal must be non zero. When \c Mode contains \c UnitDiagBit the diagonal is assumed
  * to be made of ones and is never read. It works as a forward (resp. backward) substitution
  * if \c *this is a lower (resp. upper) triangular matrix.
  *
  * The \c Mode of this view must contain either \c UpperTriangularBit or \c LowerTriangularBit
  * and must not contain \c ZeroDiagBit; both conditions are asserted by solveInPlace().
  *
  * \addexample SolveTriangular \label How to solve a triangular system (aka. how to multiply the inverse of a triangular matrix by another one)
  *
  * Example: \include MatrixBase_marked.cpp
  * Output: \verbinclude MatrixBase_marked.out
  *
  * This function is a thin wrapper around solveInPlace(): it copies \a rhs into a plain
  * column-major matrix, calls solveInPlace() on that copy and returns it. Therefore, if the
  * right-hand side is not needed anymore, calling solveInPlace() directly avoids the copy.
  *
  * For users coming from BLAS, this function (and more specifically solveInPlace()) offers
  * the operations of the \c *TRSV and \c *TRSM BLAS routines.
  *
  * \b Tips: to perform a \em "right-inverse-multiply", transpose the operation:
  * \code
  * M * T^{-1}  <=>  ( (T^T)^{-1} * M^T )^T
  * \endcode
  * i.e. solve in place with the transposed triangular matrix applied to \c M.transpose().
  *
  * \sa solveInPlace()
  */
template<typename Derived, unsigned int Mode>
template<typename RhsDerived>
typename ei_plain_matrix_type_column_major<RhsDerived>::type
TriangularView<Derived,Mode>::solve(const MatrixBase<RhsDerived>& rhs) const
{
  // Solve on a plain column-major copy so that the caller's rhs is left untouched.
  typename ei_plain_matrix_type_column_major<RhsDerived>::type res(rhs);
  solveInPlace(res);
  return res;
}
#endif // EIGEN_SOLVETRIANGULAR_H