mirror of
https://gitlab.com/libeigen/eigen.git
synced 2026-04-10 11:34:33 +08:00
GEMM: catch all scalar-multiple variants when falling back to a coeff-based product.
Previously, only s*A*B was caught, which was inconsistent with GEMM, sub-optimal, and could even lead to compilation errors (https://stackoverflow.com/questions/54738495).
This commit is contained in:
@@ -11,6 +11,35 @@
|
||||
|
||||
#include "main.h"
|
||||
|
||||
// Checks the three noalias assignment forms (=, +=, -=) of the product A * B into dst.
// Each assignment is wrapped in VERIFY_EVALUATION_COUNT with an expected count of 0, and
// dst is then compared against a reference built from the explicitly evaluated operands.
//   dst : destination, written through dst.noalias()
//   A,B : left and right factors of the product (passed as-is, so they may be expressions)
// NOTE(review): the count checked by VERIFY_EVALUATION_COUNT is presumably the number of
// temporaries created, as stated by the header comment of product_notemporary -- confirm.
template<typename Dst, typename Lhs, typename Rhs>
|
||||
void check_scalar_multiple3(Dst &dst, const Lhs& A, const Rhs& B)
|
||||
{
|
||||
// Plain assignment: dst must end up equal to A*B.
VERIFY_EVALUATION_COUNT( (dst.noalias() = A * B), 0);
|
||||
VERIFY_IS_APPROX( dst, (A.eval() * B.eval()).eval() );
|
||||
// Accumulate on top of the previous result: dst must now hold 2*(A*B).
VERIFY_EVALUATION_COUNT( (dst.noalias() += A * B), 0);
|
||||
VERIFY_IS_APPROX( dst, 2*(A.eval() * B.eval()).eval() );
|
||||
// Subtract once: dst must be back to A*B.
VERIFY_EVALUATION_COUNT( (dst.noalias() -= A * B), 0);
|
||||
VERIFY_IS_APPROX( dst, (A.eval() * B.eval()).eval() );
|
||||
}
|
||||
|
||||
// Runs check_scalar_multiple3 with four variants of the right-hand factor while keeping
// the left-hand factor A unchanged: B, -B, s2*B and B*s2.
//   dst : destination forwarded to check_scalar_multiple3
//   A,B : left and right factors of the product
//   s2  : scalar applied to B, on the left and on the right
template<typename Dst, typename Lhs, typename Rhs, typename S2>
|
||||
void check_scalar_multiple2(Dst &dst, const Lhs& A, const Rhs& B, S2 s2)
|
||||
{
|
||||
// Unmodified right-hand factor.
CALL_SUBTEST( check_scalar_multiple3(dst, A, B) );
|
||||
// Negated right-hand factor.
CALL_SUBTEST( check_scalar_multiple3(dst, A, -B) );
|
||||
// Scalar multiple written as scalar * matrix.
CALL_SUBTEST( check_scalar_multiple3(dst, A, s2*B) );
|
||||
// Scalar multiple written as matrix * scalar.
CALL_SUBTEST( check_scalar_multiple3(dst, A, B*s2) );
|
||||
}
|
||||
|
||||
// Runs check_scalar_multiple2 with four variants of the left-hand factor: A, -A, s1*A and
// A*s1. Since check_scalar_multiple2 in turn tries four variants of the right-hand factor,
// one call here covers 4 x 4 combinations of (possibly scaled or negated) factors.
//   dst   : destination forwarded down to the checks
//   A,B   : left and right factors of the product
//   s1,s2 : scalars applied to A (here) and to B (in check_scalar_multiple2)
template<typename Dst, typename Lhs, typename Rhs, typename S1, typename S2>
|
||||
void check_scalar_multiple1(Dst &dst, const Lhs& A, const Rhs& B, S1 s1, S2 s2)
|
||||
{
|
||||
// Unmodified left-hand factor.
CALL_SUBTEST( check_scalar_multiple2(dst, A, B, s2) );
|
||||
// Negated left-hand factor.
CALL_SUBTEST( check_scalar_multiple2(dst, -A, B, s2) );
|
||||
// Scalar multiple written as scalar * matrix.
CALL_SUBTEST( check_scalar_multiple2(dst, s1*A, B, s2) );
|
||||
// Scalar multiple written as matrix * scalar.
CALL_SUBTEST( check_scalar_multiple2(dst, A*s1, B, s2) );
|
||||
}
|
||||
|
||||
template<typename MatrixType> void product_notemporary(const MatrixType& m)
|
||||
{
|
||||
/* This test checks the number of temporaries created
|
||||
@@ -148,6 +177,15 @@ template<typename MatrixType> void product_notemporary(const MatrixType& m)
|
||||
// Check nested products
|
||||
VERIFY_EVALUATION_COUNT( cvres.noalias() = m1.adjoint() * m1 * cv1, 1 );
|
||||
VERIFY_EVALUATION_COUNT( rvres.noalias() = rv1 * (m1 * m2.adjoint()), 1 );
|
||||
|
||||
// exhaustively check all scalar multiple combinations:
|
||||
{
|
||||
// Generic path:
|
||||
check_scalar_multiple1(m3, m1, m2, s1, s2);
|
||||
// Force fall back to coeff-based:
|
||||
typename ColMajorMatrixType::BlockXpr m3_blck = m3.block(r0,r0,1,1);
|
||||
check_scalar_multiple1(m3_blck, m1.block(r0,c0,1,1), m2.block(c0,r0,1,1), s1, s2);
|
||||
}
|
||||
}
|
||||
|
||||
EIGEN_DECLARE_TEST(product_notemporary)
|
||||
|
||||
Reference in New Issue
Block a user