mirror of
https://gitlab.com/libeigen/eigen.git
synced 2026-04-10 11:34:33 +08:00
Fixes triangular solves on indexed/sliced dense expressions
libeigen/eigen!2340 Closes #2814
This commit is contained in:
committed by
Rasmus Munk Larsen
parent
ac6aedc60a
commit
71ef987edb
@@ -53,10 +53,11 @@ struct triangular_solver_selector<Lhs, Rhs, Side, Mode, NoUnrolling, 1> {
|
||||
typedef typename Lhs::Scalar LhsScalar;
|
||||
typedef typename Rhs::Scalar RhsScalar;
|
||||
typedef blas_traits<Lhs> LhsProductTraits;
|
||||
typedef typename LhsProductTraits::ExtractType ActualLhsType;
|
||||
typedef typename LhsProductTraits::DirectLinearAccessType ActualLhsType;
|
||||
typedef remove_all_t<ActualLhsType> ActualLhsTypeCleaned;
|
||||
typedef Map<Matrix<RhsScalar, Dynamic, 1>, Aligned> MappedRhs;
|
||||
static EIGEN_DEVICE_FUNC void run(const Lhs& lhs, Rhs& rhs) {
|
||||
ActualLhsType actualLhs = LhsProductTraits::extract(lhs);
|
||||
add_const_on_value_type_t<ActualLhsType> actualLhs = LhsProductTraits::extract(lhs);
|
||||
|
||||
// FIXME find a way to allow an inner stride if packet_traits<Scalar>::size==1
|
||||
|
||||
@@ -67,10 +68,11 @@ struct triangular_solver_selector<Lhs, Rhs, Side, Mode, NoUnrolling, 1> {
|
||||
if (!useRhsDirectly) MappedRhs(actualRhs, rhs.size()) = rhs;
|
||||
|
||||
triangular_solve_vector<LhsScalar, RhsScalar, Index, Side, Mode, LhsProductTraits::NeedToConjugate,
|
||||
(int(Lhs::Flags) & RowMajorBit) ? RowMajor : ColMajor>::run(actualLhs.cols(),
|
||||
actualLhs.data(),
|
||||
actualLhs.outerStride(),
|
||||
actualRhs);
|
||||
(int(ActualLhsTypeCleaned::Flags) & RowMajorBit) ? RowMajor
|
||||
: ColMajor>::run(actualLhs.cols(),
|
||||
actualLhs.data(),
|
||||
actualLhs.outerStride(),
|
||||
actualRhs);
|
||||
|
||||
if (!useRhsDirectly) rhs = MappedRhs(actualRhs, rhs.size());
|
||||
}
|
||||
@@ -181,11 +183,15 @@ EIGEN_DEVICE_FUNC void TriangularViewImpl<MatrixType, Mode, Dense>::solveInPlace
|
||||
if (derived().cols() == 0) return;
|
||||
|
||||
enum {
|
||||
copy = (internal::traits<OtherDerived>::Flags & RowMajorBit) && OtherDerived::IsVectorAtCompileTime &&
|
||||
OtherDerived::SizeAtCompileTime != 1
|
||||
OtherFlags = internal::traits<OtherDerived>::Flags,
|
||||
IsRowMajorVector =
|
||||
(OtherFlags & RowMajorBit) && OtherDerived::IsVectorAtCompileTime && OtherDerived::SizeAtCompileTime != 1,
|
||||
copy = IsRowMajorVector || ((OtherFlags & DirectAccessBit) == 0)
|
||||
};
|
||||
typedef std::conditional_t<copy, typename internal::plain_matrix_type_column_major<OtherDerived>::type, OtherDerived&>
|
||||
OtherCopy;
|
||||
typedef std::conditional_t<IsRowMajorVector, typename internal::plain_matrix_type_column_major<OtherDerived>::type,
|
||||
typename internal::plain_matrix_type<OtherDerived>::type>
|
||||
OtherPlainObject;
|
||||
typedef std::conditional_t<copy, OtherPlainObject, OtherDerived&> OtherCopy;
|
||||
OtherCopy otherCopy(other);
|
||||
|
||||
internal::triangular_solver_selector<MatrixType, std::remove_reference_t<OtherCopy>, Side, Mode>::run(
|
||||
|
||||
@@ -221,6 +221,37 @@ void trsolve_strided_boundary() {
|
||||
}
|
||||
}
|
||||
|
||||
void trsolve_indexed_view() {
|
||||
typedef Matrix<double, Dynamic, Dynamic> MatrixX;
|
||||
typedef Matrix<double, Dynamic, 1> VectorX;
|
||||
|
||||
MatrixX lhs = MatrixX::Random(8, 8);
|
||||
lhs *= 0.1;
|
||||
lhs.diagonal().array() += 1.0;
|
||||
|
||||
VectorX rhs = VectorX::Random(8);
|
||||
std::vector<int> indices{0, 1, 2, 7};
|
||||
|
||||
MatrixX lhs_slice = lhs(indices, indices);
|
||||
VectorX rhs_slice = rhs(indices);
|
||||
VectorX expected = lhs_slice.triangularView<Upper>().solve(rhs_slice);
|
||||
|
||||
VectorX actual = lhs(indices, indices).triangularView<Upper>().solve(rhs(indices));
|
||||
VERIFY_IS_APPROX(actual, expected);
|
||||
|
||||
VectorX assigned = VectorX::Random(8);
|
||||
VectorX assigned_ref = assigned;
|
||||
assigned(indices) = lhs_slice.triangularView<Upper>().solve(rhs_slice);
|
||||
assigned_ref(indices) = expected;
|
||||
VERIFY_IS_APPROX(assigned, assigned_ref);
|
||||
|
||||
VectorX inplace = rhs;
|
||||
VectorX inplace_ref = rhs;
|
||||
lhs_slice.triangularView<Upper>().solveInPlace(inplace(indices));
|
||||
inplace_ref(indices) = expected;
|
||||
VERIFY_IS_APPROX(inplace, inplace_ref);
|
||||
}
|
||||
|
||||
EIGEN_DECLARE_TEST(product_trsolve) {
|
||||
for (int i = 0; i < g_repeat; i++) {
|
||||
// matrices
|
||||
@@ -250,4 +281,5 @@ EIGEN_DECLARE_TEST(product_trsolve) {
|
||||
|
||||
// Strided solve at blocking boundaries (deterministic, outside g_repeat).
|
||||
CALL_SUBTEST_15(trsolve_strided_boundary<0>());
|
||||
CALL_SUBTEST_16(trsolve_indexed_view());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user