diff --git a/test/main.h b/test/main.h index 7ba0dc830..59a777312 100644 --- a/test/main.h +++ b/test/main.h @@ -605,6 +605,72 @@ typename NumTraits::Real get_test_precision( return test_precision::Real>(); } +// Rounding error bounds for matrix products, based on: +// +// Deterministic: Higham, "Accuracy and Stability of Numerical Algorithms", +// Thm 3.5: |fl(A*B) - A*B| <= gamma_k * |A| * |B|, gamma_k ~ k * epsilon. +// +// Probabilistic: Higham & Mary, "A New Approach to Probabilistic Rounding +// Error Analysis", SISC 2019, Thm 3.4: under the assumption that rounding +// errors are independent with mean zero: +// |fl(A*B) - A*B| <= gamma_tilde_k * |A| * |B|, +// gamma_tilde_k ~ lambda * sqrt(k) * epsilon, +// holding with probability >= 1 - 2*exp(-lambda^2/2) per inner product. +// +// Two overloads are provided: +// +// 1. product_tolerance(inner_dim, ...) — RELATIVE tolerance for use +// with isApprox(). Assumes random matrices in [-1,1], where sign +// cancellation gives || |A|*|B| ||_F / ||A*B||_F ~ (3/4)*sqrt(k). +// Combined: tol ~ lambda * num_products * k * epsilon. +// +// 2. product_error_bound(A, B, ...) — ABSOLUTE error bound for arbitrary +// matrices. Computes || |A|*|B| ||_F directly. +// Bound: lambda * sqrt(k) * epsilon * num_products * || |A|*|B| ||_F. +// +// Parameters common to both: +// num_products: number of independent products contributing error (default 1). +// Use 2 when comparing two different evaluations of A*B. +// lambda: probability parameter; P(lambda) = 1 - 2*exp(-lambda^2/2). +// lambda=5 gives P > 0.9999 per inner product. + +// Overload 1: Relative tolerance for random [-1,1] matrices. +template +typename NumTraits::Real product_tolerance(Index inner_dim, int num_products = 1, double lambda = 5) { + using Real = typename NumTraits::Real; + const Real lambda_real(lambda); + return lambda_real * Real(num_products) * Real(inner_dim) * NumTraits::epsilon(); +} + +// Overload 2: Absolute error bound for arbitrary matrices. +// Returns lambda * sqrt(k) * epsilon * num_products * || |A|*|B| ||_F. +template +typename NumTraits::Real product_error_bound(const MatrixBase& A, + const MatrixBase& B, + int num_products = 1, double lambda = 5) { + using Scalar = typename DerivedA::Scalar; + using Real = typename NumTraits::Real; + Index k = A.cols(); + Real abs_prod_norm = (A.cwiseAbs() * B.cwiseAbs()).norm(); + const Real lambda_real(lambda); + return lambda_real * numext::sqrt(Real(k)) * NumTraits::epsilon() * Real(num_products) * abs_prod_norm; +} + +// Verify that two computations of A*B agree within the Higham-Mary bound. +// Returns true if ||actual - expected||_F <= product_error_bound(A, B, ...). +template +inline bool verifyProduct(const MatrixBase& actual, const MatrixBase& expected, const MatrixBase& A, + const MatrixBase& B, int num_products = 2, double lambda = 5) { + using Real = typename NumTraits::Real; + Real bound = product_error_bound(A, B, num_products, lambda); + Real error = (actual - expected).norm(); + if (error > bound) { + std::cerr << "Product verification failed: error " << error << " exceeds bound " << bound << std::endl; + return false; + } + return true; +} + // verifyIsApprox is a wrapper to test_isApprox that outputs the relative difference magnitude if the test fails. template inline bool verifyIsApprox(const Type1& a, const Type2& b) { diff --git a/test/product.h b/test/product.h index 21b470119..645d40aa9 100644 --- a/test/product.h +++ b/test/product.h @@ -97,8 +97,12 @@ void product(const MatrixType& m) { // begin testing Product.h: only associativity for now // (we use Transpose.h but this doesn't count as a test for it) { - // Increase tolerance, since coefficients here can get relatively large. - RealScalar tol = RealScalar(2) * get_test_precision(m1); + // Associativity: (m1 * m1^T) * m2 vs m1 * (m1^T * m2). + // Two chained products with inner dims cols and rows. Intermediate entries + // of m1*m1^T are O(sqrt(cols)), amplifying the second product's error. + // Probabilistic bound (Higham & Mary 2019): ~lambda * sqrt(k) * epsilon + // per inner product, times sqrt(cols) amplification from chained product. + RealScalar tol = product_tolerance((std::max)(rows, cols), 3); VERIFY(verifyIsApprox((m1 * m1.transpose()) * m2, m1 * (m1.transpose() * m2), tol)); } m3 = m1; @@ -289,8 +293,10 @@ void product(const MatrixType& m) { // regression for blas_trais { - // Increase test tolerance, since coefficients can get relatively large. - RealScalar tol = RealScalar(2) * get_test_precision(square); + // Triple products of rows x rows matrices. Each side computes 2-3 + // products with inner dim = rows. Probabilistic bound with amplification + // from chained products with O(sqrt(rows)) intermediate entries. + RealScalar tol = product_tolerance(rows, 4); VERIFY( verifyIsApprox(square * (square * square).transpose(), square * square.transpose() * square.transpose(), tol)); VERIFY(verifyIsApprox(square * (-(square * square)), -square * square * square, tol)); diff --git a/test/product_trmv.cpp b/test/product_trmv.cpp index 447243f31..634bb2a7e 100644 --- a/test/product_trmv.cpp +++ b/test/product_trmv.cpp @@ -15,8 +15,6 @@ void trmv(const MatrixType& m) { typedef typename NumTraits::Real RealScalar; typedef Matrix VectorType; - RealScalar largerEps = 10 * test_precision(); - Index rows = m.rows(); Index cols = m.cols(); @@ -29,46 +27,53 @@ void trmv(const MatrixType& m) { // check with a column-major matrix m3 = m1.template triangularView(); - VERIFY((m3 * v1).isApprox(m1.template triangularView() * v1, largerEps)); + VERIFY(verifyProduct(m3 * v1, m1.template triangularView() * v1, m3, v1)); m3 = m1.template triangularView(); - VERIFY((m3 * v1).isApprox(m1.template triangularView() * v1, largerEps)); + VERIFY(verifyProduct(m3 * v1, m1.template triangularView() * v1, m3, v1)); m3 = m1.template triangularView(); - VERIFY((m3 * v1).isApprox(m1.template triangularView() * v1, largerEps)); + VERIFY(verifyProduct(m3 * v1, m1.template triangularView() * v1, m3, v1)); m3 = m1.template triangularView(); - VERIFY((m3 * v1).isApprox(m1.template triangularView() * v1, largerEps)); + VERIFY(verifyProduct(m3 * v1, m1.template triangularView() * v1, m3, v1)); // check conjugated and scalar multiple expressions (col-major) m3 = m1.template triangularView(); - VERIFY(((s1 * m3).conjugate() * v1) - .isApprox((s1 * m1).conjugate().template triangularView() * v1, largerEps)); + VERIFY(verifyProduct((s1 * m3).conjugate() * v1, (s1 * m1).conjugate().template triangularView() * v1, + (s1 * m3).conjugate(), v1)); m3 = m1.template triangularView(); - VERIFY((m3.conjugate() * v1.conjugate()) - .isApprox(m1.conjugate().template triangularView() * v1.conjugate(), largerEps)); + VERIFY(verifyProduct(m3.conjugate() * v1.conjugate(), + m1.conjugate().template triangularView() * v1.conjugate(), m3.conjugate(), + v1.conjugate())); // check with a row-major matrix m3 = m1.template triangularView(); - VERIFY((m3.transpose() * v1).isApprox(m1.transpose().template triangularView() * v1, largerEps)); + VERIFY(verifyProduct(m3.transpose() * v1, m1.transpose().template triangularView() * v1, m3.transpose(), + v1)); m3 = m1.template triangularView(); - VERIFY((m3.transpose() * v1).isApprox(m1.transpose().template triangularView() * v1, largerEps)); + VERIFY(verifyProduct(m3.transpose() * v1, m1.transpose().template triangularView() * v1, m3.transpose(), + v1)); m3 = m1.template triangularView(); - VERIFY((m3.transpose() * v1).isApprox(m1.transpose().template triangularView() * v1, largerEps)); + VERIFY(verifyProduct(m3.transpose() * v1, m1.transpose().template triangularView() * v1, + m3.transpose(), v1)); m3 = m1.template triangularView(); - VERIFY((m3.transpose() * v1).isApprox(m1.transpose().template triangularView() * v1, largerEps)); + VERIFY(verifyProduct(m3.transpose() * v1, m1.transpose().template triangularView() * v1, + m3.transpose(), v1)); // check conjugated and scalar multiple expressions (row-major) m3 = m1.template triangularView(); - VERIFY((m3.adjoint() * v1).isApprox(m1.adjoint().template triangularView() * v1, largerEps)); + VERIFY(verifyProduct(m3.adjoint() * v1, m1.adjoint().template triangularView() * v1, m3.adjoint(), v1)); m3 = m1.template triangularView(); - VERIFY((m3.adjoint() * (s1 * v1.conjugate())) - .isApprox(m1.adjoint().template triangularView() * (s1 * v1.conjugate()), largerEps)); + VERIFY(verifyProduct(m3.adjoint() * (s1 * v1.conjugate()), + m1.adjoint().template triangularView() * (s1 * v1.conjugate()), m3.adjoint(), + (s1 * v1.conjugate()).eval())); m3 = m1.template triangularView(); // check transposed cases: m3 = m1.template triangularView(); - VERIFY((v1.transpose() * m3).isApprox(v1.transpose() * m1.template triangularView(), largerEps)); - VERIFY((v1.adjoint() * m3).isApprox(v1.adjoint() * m1.template triangularView(), largerEps)); - VERIFY((v1.adjoint() * m3.adjoint()) - .isApprox(v1.adjoint() * m1.template triangularView().adjoint(), largerEps)); + VERIFY(verifyProduct(v1.transpose() * m3, v1.transpose() * m1.template triangularView(), v1.transpose(), + m3)); + VERIFY(verifyProduct(v1.adjoint() * m3, v1.adjoint() * m1.template triangularView(), v1.adjoint(), m3)); + VERIFY(verifyProduct(v1.adjoint() * m3.adjoint(), v1.adjoint() * m1.template triangularView().adjoint(), + v1.adjoint(), m3.adjoint())); // TODO check with sub-matrices }