mirror of
https://gitlab.com/libeigen/eigen.git
synced 2026-04-10 11:34:33 +08:00
Vectorize generic trsmKernelR for non-AVX512 targets
libeigen/eigen!2135
Closes #3027
Co-authored-by: Rasmus Munk Larsen <rmlarsen@gmail.com>
This commit is contained in:
@@ -97,14 +97,72 @@ EIGEN_STRONG_INLINE void trsmKernelR<Scalar, Index, Mode, Conjugate, TriStorageO
|
||||
Index j = IsLower ? size - k - 1 : k;
|
||||
|
||||
typename LhsMapper::LinearMapper r = lhs.getLinearMapper(0, j);
|
||||
for (Index k3 = 0; k3 < k; ++k3) {
|
||||
Scalar b = conj(rhs(IsLower ? j + 1 + k3 : k3, j));
|
||||
typename LhsMapper::LinearMapper a = lhs.getLinearMapper(0, IsLower ? j + 1 + k3 : k3);
|
||||
for (Index i = 0; i < otherSize; ++i) r(i) -= a(i) * b;
|
||||
EIGEN_IF_CONSTEXPR(OtherInnerStride == 1 && packet_traits<Scalar>::Vectorizable) {
|
||||
using Packet = typename packet_traits<Scalar>::type;
|
||||
constexpr Index PS = unpacket_traits<Packet>::size;
|
||||
// Unrolled k3 loop by 4 to reduce r load/store traffic.
|
||||
Index k3 = 0;
|
||||
for (; k3 + 3 < k; k3 += 4) {
|
||||
Index col0 = IsLower ? j + 1 + k3 : k3;
|
||||
Scalar b0 = conj(rhs(col0, j));
|
||||
Scalar b1 = conj(rhs(col0 + 1, j));
|
||||
Scalar b2 = conj(rhs(col0 + 2, j));
|
||||
Scalar b3 = conj(rhs(col0 + 3, j));
|
||||
Packet neg_pb0 = pset1<Packet>(-b0);
|
||||
Packet neg_pb1 = pset1<Packet>(-b1);
|
||||
Packet neg_pb2 = pset1<Packet>(-b2);
|
||||
Packet neg_pb3 = pset1<Packet>(-b3);
|
||||
typename LhsMapper::LinearMapper a0 = lhs.getLinearMapper(0, col0);
|
||||
typename LhsMapper::LinearMapper a1 = lhs.getLinearMapper(0, col0 + 1);
|
||||
typename LhsMapper::LinearMapper a2 = lhs.getLinearMapper(0, col0 + 2);
|
||||
typename LhsMapper::LinearMapper a3 = lhs.getLinearMapper(0, col0 + 3);
|
||||
Index i = 0;
|
||||
for (; i + PS <= otherSize; i += PS) {
|
||||
Packet pr = r.template loadPacket<Packet>(i);
|
||||
pr = pmadd(a0.template loadPacket<Packet>(i), neg_pb0, pr);
|
||||
pr = pmadd(a1.template loadPacket<Packet>(i), neg_pb1, pr);
|
||||
pr = pmadd(a2.template loadPacket<Packet>(i), neg_pb2, pr);
|
||||
pr = pmadd(a3.template loadPacket<Packet>(i), neg_pb3, pr);
|
||||
r.template storePacket(i, pr);
|
||||
}
|
||||
for (; i < otherSize; ++i) {
|
||||
r(i) -= a0(i) * b0 + a1(i) * b1 + a2(i) * b2 + a3(i) * b3;
|
||||
}
|
||||
}
|
||||
// Handle remaining k3 iterations with vectorized inner loop.
|
||||
for (; k3 < k; ++k3) {
|
||||
Scalar b = conj(rhs(IsLower ? j + 1 + k3 : k3, j));
|
||||
typename LhsMapper::LinearMapper a = lhs.getLinearMapper(0, IsLower ? j + 1 + k3 : k3);
|
||||
Packet neg_pb = pset1<Packet>(-b);
|
||||
Index i = 0;
|
||||
for (; i + PS <= otherSize; i += PS) {
|
||||
Packet pr = r.template loadPacket<Packet>(i);
|
||||
pr = pmadd(a.template loadPacket<Packet>(i), neg_pb, pr);
|
||||
r.template storePacket(i, pr);
|
||||
}
|
||||
for (; i < otherSize; ++i) r(i) -= a(i) * b;
|
||||
}
|
||||
// Vectorized diagonal scaling.
|
||||
EIGEN_IF_CONSTEXPR((Mode & UnitDiag) == 0) {
|
||||
Scalar inv_rjj = RealScalar(1) / conj(rhs(j, j));
|
||||
Packet pinv = pset1<Packet>(inv_rjj);
|
||||
Index i = 0;
|
||||
for (; i + PS <= otherSize; i += PS) {
|
||||
r.template storePacket(i, pmul(r.template loadPacket<Packet>(i), pinv));
|
||||
}
|
||||
for (; i < otherSize; ++i) r(i) *= inv_rjj;
|
||||
}
|
||||
}
|
||||
if ((Mode & UnitDiag) == 0) {
|
||||
Scalar inv_rjj = RealScalar(1) / conj(rhs(j, j));
|
||||
for (Index i = 0; i < otherSize; ++i) r(i) *= inv_rjj;
|
||||
else {
|
||||
for (Index k3 = 0; k3 < k; ++k3) {
|
||||
Scalar b = conj(rhs(IsLower ? j + 1 + k3 : k3, j));
|
||||
typename LhsMapper::LinearMapper a = lhs.getLinearMapper(0, IsLower ? j + 1 + k3 : k3);
|
||||
for (Index i = 0; i < otherSize; ++i) r(i) -= a(i) * b;
|
||||
}
|
||||
EIGEN_IF_CONSTEXPR((Mode & UnitDiag) == 0) {
|
||||
Scalar inv_rjj = RealScalar(1) / conj(rhs(j, j));
|
||||
for (Index i = 0; i < otherSize; ++i) r(i) *= inv_rjj;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user