Vectorize generic trsmKernelR for non-AVX512 targets

libeigen/eigen!2135

Closes #3027

Co-authored-by: Rasmus Munk Larsen <rmlarsen@gmail.com>
This commit is contained in:
Rasmus Munk Larsen
2026-02-18 17:34:31 -08:00
parent 43a01f06ad
commit 3c86a013b1

View File

@@ -97,14 +97,72 @@ EIGEN_STRONG_INLINE void trsmKernelR<Scalar, Index, Mode, Conjugate, TriStorageO
Index j = IsLower ? size - k - 1 : k;
typename LhsMapper::LinearMapper r = lhs.getLinearMapper(0, j);
for (Index k3 = 0; k3 < k; ++k3) {
Scalar b = conj(rhs(IsLower ? j + 1 + k3 : k3, j));
typename LhsMapper::LinearMapper a = lhs.getLinearMapper(0, IsLower ? j + 1 + k3 : k3);
for (Index i = 0; i < otherSize; ++i) r(i) -= a(i) * b;
EIGEN_IF_CONSTEXPR(OtherInnerStride == 1 && packet_traits<Scalar>::Vectorizable) {
using Packet = typename packet_traits<Scalar>::type;
constexpr Index PS = unpacket_traits<Packet>::size;
// Unroll the k3 loop by 4 so each packet of r is loaded and stored once per four column updates.
Index k3 = 0;
for (; k3 + 3 < k; k3 += 4) {
Index col0 = IsLower ? j + 1 + k3 : k3;
Scalar b0 = conj(rhs(col0, j));
Scalar b1 = conj(rhs(col0 + 1, j));
Scalar b2 = conj(rhs(col0 + 2, j));
Scalar b3 = conj(rhs(col0 + 3, j));
Packet neg_pb0 = pset1<Packet>(-b0);
Packet neg_pb1 = pset1<Packet>(-b1);
Packet neg_pb2 = pset1<Packet>(-b2);
Packet neg_pb3 = pset1<Packet>(-b3);
typename LhsMapper::LinearMapper a0 = lhs.getLinearMapper(0, col0);
typename LhsMapper::LinearMapper a1 = lhs.getLinearMapper(0, col0 + 1);
typename LhsMapper::LinearMapper a2 = lhs.getLinearMapper(0, col0 + 2);
typename LhsMapper::LinearMapper a3 = lhs.getLinearMapper(0, col0 + 3);
Index i = 0;
for (; i + PS <= otherSize; i += PS) {
Packet pr = r.template loadPacket<Packet>(i);
pr = pmadd(a0.template loadPacket<Packet>(i), neg_pb0, pr);
pr = pmadd(a1.template loadPacket<Packet>(i), neg_pb1, pr);
pr = pmadd(a2.template loadPacket<Packet>(i), neg_pb2, pr);
pr = pmadd(a3.template loadPacket<Packet>(i), neg_pb3, pr);
r.template storePacket(i, pr);
}
for (; i < otherSize; ++i) {
r(i) -= a0(i) * b0 + a1(i) * b1 + a2(i) * b2 + a3(i) * b3;
}
}
// Handle the remaining (k % 4) k3 iterations one at a time, still with a vectorized inner loop over i.
for (; k3 < k; ++k3) {
Scalar b = conj(rhs(IsLower ? j + 1 + k3 : k3, j));
typename LhsMapper::LinearMapper a = lhs.getLinearMapper(0, IsLower ? j + 1 + k3 : k3);
Packet neg_pb = pset1<Packet>(-b);
Index i = 0;
for (; i + PS <= otherSize; i += PS) {
Packet pr = r.template loadPacket<Packet>(i);
pr = pmadd(a.template loadPacket<Packet>(i), neg_pb, pr);
r.template storePacket(i, pr);
}
for (; i < otherSize; ++i) r(i) -= a(i) * b;
}
// Vectorized diagonal scaling: multiply r by 1 / conj(rhs(j, j)); skipped when Mode has UnitDiag set.
EIGEN_IF_CONSTEXPR((Mode & UnitDiag) == 0) {
Scalar inv_rjj = RealScalar(1) / conj(rhs(j, j));
Packet pinv = pset1<Packet>(inv_rjj);
Index i = 0;
for (; i + PS <= otherSize; i += PS) {
r.template storePacket(i, pmul(r.template loadPacket<Packet>(i), pinv));
}
for (; i < otherSize; ++i) r(i) *= inv_rjj;
}
}
if ((Mode & UnitDiag) == 0) {
Scalar inv_rjj = RealScalar(1) / conj(rhs(j, j));
for (Index i = 0; i < otherSize; ++i) r(i) *= inv_rjj;
else {
for (Index k3 = 0; k3 < k; ++k3) {
Scalar b = conj(rhs(IsLower ? j + 1 + k3 : k3, j));
typename LhsMapper::LinearMapper a = lhs.getLinearMapper(0, IsLower ? j + 1 + k3 : k3);
for (Index i = 0; i < otherSize; ++i) r(i) -= a(i) * b;
}
EIGEN_IF_CONSTEXPR((Mode & UnitDiag) == 0) {
Scalar inv_rjj = RealScalar(1) / conj(rhs(j, j));
for (Index i = 0; i < otherSize; ++i) r(i) *= inv_rjj;
}
}
}
}