diff --git a/Eigen/src/Core/SolveTriangular.h b/Eigen/src/Core/SolveTriangular.h index 9d3187422..684f7a55e 100644 --- a/Eigen/src/Core/SolveTriangular.h +++ b/Eigen/src/Core/SolveTriangular.h @@ -53,10 +53,11 @@ struct triangular_solver_selector { typedef typename Lhs::Scalar LhsScalar; typedef typename Rhs::Scalar RhsScalar; typedef blas_traits LhsProductTraits; - typedef typename LhsProductTraits::ExtractType ActualLhsType; + typedef typename LhsProductTraits::DirectLinearAccessType ActualLhsType; + typedef remove_all_t ActualLhsTypeCleaned; typedef Map, Aligned> MappedRhs; static EIGEN_DEVICE_FUNC void run(const Lhs& lhs, Rhs& rhs) { - ActualLhsType actualLhs = LhsProductTraits::extract(lhs); + add_const_on_value_type_t actualLhs = LhsProductTraits::extract(lhs); // FIXME find a way to allow an inner stride if packet_traits::size==1 @@ -67,10 +68,11 @@ struct triangular_solver_selector { if (!useRhsDirectly) MappedRhs(actualRhs, rhs.size()) = rhs; triangular_solve_vector::run(actualLhs.cols(), - actualLhs.data(), - actualLhs.outerStride(), - actualRhs); + (int(ActualLhsTypeCleaned::Flags) & RowMajorBit) ? RowMajor + : ColMajor>::run(actualLhs.cols(), + actualLhs.data(), + actualLhs.outerStride(), + actualRhs); if (!useRhsDirectly) rhs = MappedRhs(actualRhs, rhs.size()); } @@ -181,11 +183,15 @@ EIGEN_DEVICE_FUNC void TriangularViewImpl::solveInPlace if (derived().cols() == 0) return; enum { - copy = (internal::traits::Flags & RowMajorBit) && OtherDerived::IsVectorAtCompileTime && - OtherDerived::SizeAtCompileTime != 1 + OtherFlags = internal::traits::Flags, + IsRowMajorVector = + (OtherFlags & RowMajorBit) && OtherDerived::IsVectorAtCompileTime && OtherDerived::SizeAtCompileTime != 1, + copy = IsRowMajorVector || ((OtherFlags & DirectAccessBit) == 0) }; - typedef std::conditional_t::type, OtherDerived&> - OtherCopy; + typedef std::conditional_t::type, + typename internal::plain_matrix_type::type> + OtherPlainObject; + typedef std::conditional_t OtherCopy; OtherCopy otherCopy(other); internal::triangular_solver_selector, Side, Mode>::run( diff --git a/test/product_trsolve.cpp b/test/product_trsolve.cpp index c7dfb2581..998d422eb 100644 --- a/test/product_trsolve.cpp +++ b/test/product_trsolve.cpp @@ -221,6 +221,37 @@ void trsolve_strided_boundary() { } } +void trsolve_indexed_view() { + typedef Matrix MatrixX; + typedef Matrix VectorX; + + MatrixX lhs = MatrixX::Random(8, 8); + lhs *= 0.1; + lhs.diagonal().array() += 1.0; + + VectorX rhs = VectorX::Random(8); + std::vector indices{0, 1, 2, 7}; + + MatrixX lhs_slice = lhs(indices, indices); + VectorX rhs_slice = rhs(indices); + VectorX expected = lhs_slice.triangularView().solve(rhs_slice); + + VectorX actual = lhs(indices, indices).triangularView().solve(rhs(indices)); + VERIFY_IS_APPROX(actual, expected); + + VectorX assigned = VectorX::Random(8); + VectorX assigned_ref = assigned; + assigned(indices) = lhs_slice.triangularView().solve(rhs_slice); + assigned_ref(indices) = expected; + VERIFY_IS_APPROX(assigned, assigned_ref); + + VectorX inplace = rhs; + VectorX inplace_ref = rhs; + lhs_slice.triangularView().solveInPlace(inplace(indices)); + inplace_ref(indices) = expected; + VERIFY_IS_APPROX(inplace, inplace_ref); +} + EIGEN_DECLARE_TEST(product_trsolve) { for (int i = 0; i < g_repeat; i++) { // matrices @@ -250,4 +281,5 @@ EIGEN_DECLARE_TEST(product_trsolve) { // Strided solve at blocking boundaries (deterministic, outside g_repeat). CALL_SUBTEST_15(trsolve_strided_boundary<0>()); + CALL_SUBTEST_16(trsolve_indexed_view()); }