Fixes triangular solves on indexed/sliced dense expressions

libeigen/eigen!2340

Closes #2814
This commit is contained in:
Florian Maurin
2026-03-22 18:12:21 +00:00
committed by Rasmus Munk Larsen
parent ac6aedc60a
commit 71ef987edb
2 changed files with 48 additions and 10 deletions

View File

@@ -53,10 +53,11 @@ struct triangular_solver_selector<Lhs, Rhs, Side, Mode, NoUnrolling, 1> {
typedef typename Lhs::Scalar LhsScalar;
typedef typename Rhs::Scalar RhsScalar;
typedef blas_traits<Lhs> LhsProductTraits;
typedef typename LhsProductTraits::ExtractType ActualLhsType;
typedef typename LhsProductTraits::DirectLinearAccessType ActualLhsType;
typedef remove_all_t<ActualLhsType> ActualLhsTypeCleaned;
typedef Map<Matrix<RhsScalar, Dynamic, 1>, Aligned> MappedRhs;
static EIGEN_DEVICE_FUNC void run(const Lhs& lhs, Rhs& rhs) {
ActualLhsType actualLhs = LhsProductTraits::extract(lhs);
add_const_on_value_type_t<ActualLhsType> actualLhs = LhsProductTraits::extract(lhs);
// FIXME find a way to allow an inner stride if packet_traits<Scalar>::size==1
@@ -67,10 +68,11 @@ struct triangular_solver_selector<Lhs, Rhs, Side, Mode, NoUnrolling, 1> {
if (!useRhsDirectly) MappedRhs(actualRhs, rhs.size()) = rhs;
triangular_solve_vector<LhsScalar, RhsScalar, Index, Side, Mode, LhsProductTraits::NeedToConjugate,
(int(Lhs::Flags) & RowMajorBit) ? RowMajor : ColMajor>::run(actualLhs.cols(),
actualLhs.data(),
actualLhs.outerStride(),
actualRhs);
(int(ActualLhsTypeCleaned::Flags) & RowMajorBit) ? RowMajor
: ColMajor>::run(actualLhs.cols(),
actualLhs.data(),
actualLhs.outerStride(),
actualRhs);
if (!useRhsDirectly) rhs = MappedRhs(actualRhs, rhs.size());
}
@@ -181,11 +183,15 @@ EIGEN_DEVICE_FUNC void TriangularViewImpl<MatrixType, Mode, Dense>::solveInPlace
if (derived().cols() == 0) return;
enum {
copy = (internal::traits<OtherDerived>::Flags & RowMajorBit) && OtherDerived::IsVectorAtCompileTime &&
OtherDerived::SizeAtCompileTime != 1
OtherFlags = internal::traits<OtherDerived>::Flags,
IsRowMajorVector =
(OtherFlags & RowMajorBit) && OtherDerived::IsVectorAtCompileTime && OtherDerived::SizeAtCompileTime != 1,
copy = IsRowMajorVector || ((OtherFlags & DirectAccessBit) == 0)
};
typedef std::conditional_t<copy, typename internal::plain_matrix_type_column_major<OtherDerived>::type, OtherDerived&>
OtherCopy;
typedef std::conditional_t<IsRowMajorVector, typename internal::plain_matrix_type_column_major<OtherDerived>::type,
typename internal::plain_matrix_type<OtherDerived>::type>
OtherPlainObject;
typedef std::conditional_t<copy, OtherPlainObject, OtherDerived&> OtherCopy;
OtherCopy otherCopy(other);
internal::triangular_solver_selector<MatrixType, std::remove_reference_t<OtherCopy>, Side, Mode>::run(

View File

@@ -221,6 +221,37 @@ void trsolve_strided_boundary() {
}
}
void trsolve_indexed_view() {
typedef Matrix<double, Dynamic, Dynamic> MatrixX;
typedef Matrix<double, Dynamic, 1> VectorX;
MatrixX lhs = MatrixX::Random(8, 8);
lhs *= 0.1;
lhs.diagonal().array() += 1.0;
VectorX rhs = VectorX::Random(8);
std::vector<int> indices{0, 1, 2, 7};
MatrixX lhs_slice = lhs(indices, indices);
VectorX rhs_slice = rhs(indices);
VectorX expected = lhs_slice.triangularView<Upper>().solve(rhs_slice);
VectorX actual = lhs(indices, indices).triangularView<Upper>().solve(rhs(indices));
VERIFY_IS_APPROX(actual, expected);
VectorX assigned = VectorX::Random(8);
VectorX assigned_ref = assigned;
assigned(indices) = lhs_slice.triangularView<Upper>().solve(rhs_slice);
assigned_ref(indices) = expected;
VERIFY_IS_APPROX(assigned, assigned_ref);
VectorX inplace = rhs;
VectorX inplace_ref = rhs;
lhs_slice.triangularView<Upper>().solveInPlace(inplace(indices));
inplace_ref(indices) = expected;
VERIFY_IS_APPROX(inplace, inplace_ref);
}
EIGEN_DECLARE_TEST(product_trsolve) {
for (int i = 0; i < g_repeat; i++) {
// matrices
@@ -250,4 +281,5 @@ EIGEN_DECLARE_TEST(product_trsolve) {
// Strided solve at blocking boundaries (deterministic, outside g_repeat).
CALL_SUBTEST_15(trsolve_strided_boundary<0>());
CALL_SUBTEST_16(trsolve_indexed_view());
}