// This file is part of Eigen, a lightweight C++ template library // for linear algebra. // // Copyright (C) 2014 Gael Guennebaud // // This Source Code Form is subject to the terms of the Mozilla // Public License v. 2.0. If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. #ifndef EIGEN_SOLVE_H #define EIGEN_SOLVE_H // IWYU pragma: private #include "./InternalHeaderCheck.h" namespace Eigen { template class SolveImpl; /** \class Solve * \ingroup Core_Module * * \brief Pseudo expression representing a solving operation * * \tparam Decomposition the type of the matrix or decomposition object * \tparam Rhstype the type of the right-hand side * * This class represents an expression of A.solve(B) * and most of the time this is the only way it is used. * */ namespace internal { // this solve_traits class permits to determine the evaluation type with respect to storage kind (Dense vs Sparse) template struct solve_traits; template struct solve_traits { typedef typename make_proper_matrix_type::type PlainObject; }; template struct traits > : traits< typename solve_traits::StorageKind>::PlainObject> { typedef typename solve_traits::StorageKind>::PlainObject PlainObject; typedef typename promote_index_type::type StorageIndex; typedef traits BaseTraits; enum { Flags = BaseTraits::Flags & RowMajorBit, CoeffReadCost = HugeCost }; }; } // namespace internal template class Solve : public SolveImpl::StorageKind> { public: typedef typename internal::traits::PlainObject PlainObject; typedef typename internal::traits::StorageIndex StorageIndex; Solve(const Decomposition &dec, const RhsType &rhs) : m_dec(dec), m_rhs(rhs) {} EIGEN_DEVICE_FUNC constexpr Index rows() const noexcept { return m_dec.cols(); } EIGEN_DEVICE_FUNC constexpr Index cols() const noexcept { return m_rhs.cols(); } EIGEN_DEVICE_FUNC constexpr const Decomposition &dec() const { return m_dec; } EIGEN_DEVICE_FUNC constexpr const RhsType &rhs() const { return m_rhs; } protected: const Decomposition &m_dec; const typename internal::ref_selector::type m_rhs; }; // Specialization of the Solve expression for dense results template class SolveImpl : public MatrixBase > { typedef Solve Derived; public: typedef MatrixBase > Base; EIGEN_DENSE_PUBLIC_INTERFACE(Derived) private: Scalar coeff(Index row, Index col) const; Scalar coeff(Index i) const; }; // Generic API dispatcher template class SolveImpl : public internal::generic_xpr_base, MatrixXpr, StorageKind>::type { public: typedef typename internal::generic_xpr_base, MatrixXpr, StorageKind>::type Base; }; namespace internal { // Evaluator of Solve -> eval into a temporary template struct evaluator > : public evaluator::PlainObject> { typedef Solve SolveType; typedef typename SolveType::PlainObject PlainObject; typedef evaluator Base; enum { Flags = Base::Flags | EvalBeforeNestingBit }; EIGEN_DEVICE_FUNC explicit evaluator(const SolveType &solve) : m_result(solve.rows(), solve.cols()) { internal::construct_at(this, m_result); solve.dec()._solve_impl(solve.rhs(), m_result); } protected: PlainObject m_result; }; // Specialization for "dst = dec.solve(rhs)" // NOTE we need to specialize it for Dense2Dense to avoid ambiguous specialization error and a Sparse2Sparse // specialization must exist somewhere template struct Assignment, internal::assign_op, Dense2Dense> { typedef Solve SrcXprType; static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op &) { Index dstRows = src.rows(); Index dstCols = src.cols(); if ((dst.rows() != dstRows) || (dst.cols() != dstCols)) dst.resize(dstRows, dstCols); src.dec()._solve_impl(src.rhs(), dst); } }; // Specialization for "dst = dec.transpose().solve(rhs)" template struct Assignment, RhsType>, internal::assign_op, Dense2Dense> { typedef Solve, RhsType> SrcXprType; static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op &) { Index dstRows = src.rows(); Index dstCols = src.cols(); if ((dst.rows() != dstRows) || (dst.cols() != dstCols)) dst.resize(dstRows, dstCols); src.dec().nestedExpression().template _solve_impl_transposed(src.rhs(), dst); } }; // Specialization for "dst = dec.adjoint().solve(rhs)" template struct Assignment< DstXprType, Solve, const Transpose >, RhsType>, internal::assign_op, Dense2Dense> { typedef Solve, const Transpose >, RhsType> SrcXprType; static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op &) { Index dstRows = src.rows(); Index dstCols = src.cols(); if ((dst.rows() != dstRows) || (dst.cols() != dstCols)) dst.resize(dstRows, dstCols); src.dec().nestedExpression().nestedExpression().template _solve_impl_transposed(src.rhs(), dst); } }; } // end namespace internal } // end namespace Eigen #endif // EIGEN_SOLVE_H