diff --git a/Eigen/src/Core/products/TriangularSolverMatrix.h b/Eigen/src/Core/products/TriangularSolverMatrix.h index c09fec4f6..7d42e0132 100644 --- a/Eigen/src/Core/products/TriangularSolverMatrix.h +++ b/Eigen/src/Core/products/TriangularSolverMatrix.h @@ -97,14 +97,72 @@ EIGEN_STRONG_INLINE void trsmKernelR::Vectorizable) { + using Packet = typename packet_traits::type; + constexpr Index PS = unpacket_traits::size; + // Unrolled k3 loop by 4 to reduce r load/store traffic. + Index k3 = 0; + for (; k3 + 3 < k; k3 += 4) { + Index col0 = IsLower ? j + 1 + k3 : k3; + Scalar b0 = conj(rhs(col0, j)); + Scalar b1 = conj(rhs(col0 + 1, j)); + Scalar b2 = conj(rhs(col0 + 2, j)); + Scalar b3 = conj(rhs(col0 + 3, j)); + Packet neg_pb0 = pset1(-b0); + Packet neg_pb1 = pset1(-b1); + Packet neg_pb2 = pset1(-b2); + Packet neg_pb3 = pset1(-b3); + typename LhsMapper::LinearMapper a0 = lhs.getLinearMapper(0, col0); + typename LhsMapper::LinearMapper a1 = lhs.getLinearMapper(0, col0 + 1); + typename LhsMapper::LinearMapper a2 = lhs.getLinearMapper(0, col0 + 2); + typename LhsMapper::LinearMapper a3 = lhs.getLinearMapper(0, col0 + 3); + Index i = 0; + for (; i + PS <= otherSize; i += PS) { + Packet pr = r.template loadPacket(i); + pr = pmadd(a0.template loadPacket(i), neg_pb0, pr); + pr = pmadd(a1.template loadPacket(i), neg_pb1, pr); + pr = pmadd(a2.template loadPacket(i), neg_pb2, pr); + pr = pmadd(a3.template loadPacket(i), neg_pb3, pr); + r.template storePacket(i, pr); + } + for (; i < otherSize; ++i) { + r(i) -= a0(i) * b0 + a1(i) * b1 + a2(i) * b2 + a3(i) * b3; + } + } + // Handle remaining k3 iterations with vectorized inner loop. + for (; k3 < k; ++k3) { + Scalar b = conj(rhs(IsLower ? j + 1 + k3 : k3, j)); + typename LhsMapper::LinearMapper a = lhs.getLinearMapper(0, IsLower ? j + 1 + k3 : k3); + Packet neg_pb = pset1(-b); + Index i = 0; + for (; i + PS <= otherSize; i += PS) { + Packet pr = r.template loadPacket(i); + pr = pmadd(a.template loadPacket(i), neg_pb, pr); + r.template storePacket(i, pr); + } + for (; i < otherSize; ++i) r(i) -= a(i) * b; + } + // Vectorized diagonal scaling. + EIGEN_IF_CONSTEXPR((Mode & UnitDiag) == 0) { + Scalar inv_rjj = RealScalar(1) / conj(rhs(j, j)); + Packet pinv = pset1(inv_rjj); + Index i = 0; + for (; i + PS <= otherSize; i += PS) { + r.template storePacket(i, pmul(r.template loadPacket(i), pinv)); + } + for (; i < otherSize; ++i) r(i) *= inv_rjj; + } } - if ((Mode & UnitDiag) == 0) { - Scalar inv_rjj = RealScalar(1) / conj(rhs(j, j)); - for (Index i = 0; i < otherSize; ++i) r(i) *= inv_rjj; + else { + for (Index k3 = 0; k3 < k; ++k3) { + Scalar b = conj(rhs(IsLower ? j + 1 + k3 : k3, j)); + typename LhsMapper::LinearMapper a = lhs.getLinearMapper(0, IsLower ? j + 1 + k3 : k3); + for (Index i = 0; i < otherSize; ++i) r(i) -= a(i) * b; + } + EIGEN_IF_CONSTEXPR((Mode & UnitDiag) == 0) { + Scalar inv_rjj = RealScalar(1) / conj(rhs(j, j)); + for (Index i = 0; i < otherSize; ++i) r(i) *= inv_rjj; + } } } }