mirror of
https://gitlab.com/libeigen/eigen.git
synced 2026-04-10 11:34:33 +08:00
Replace empirical product test tolerances with principled Higham-Mary bounds
libeigen/eigen!2292 Co-authored-by: Rasmus Munk Larsen <rmlarsen@gmail.com>
This commit is contained in:
66
test/main.h
66
test/main.h
@@ -605,6 +605,72 @@ typename NumTraits<T>::Real get_test_precision(
|
||||
return test_precision<typename NumTraits<T>::Real>();
|
||||
}
|
||||
|
||||
// Rounding error bounds for matrix products, based on:
|
||||
//
|
||||
// Deterministic: Higham, "Accuracy and Stability of Numerical Algorithms",
|
||||
// Thm 3.5: |fl(A*B) - A*B| <= gamma_k * |A| * |B|, gamma_k ~ k * epsilon.
|
||||
//
|
||||
// Probabilistic: Higham & Mary, "A New Approach to Probabilistic Rounding
|
||||
// Error Analysis", SISC 2019, Thm 3.4: under the assumption that rounding
|
||||
// errors are independent with mean zero:
|
||||
// |fl(A*B) - A*B| <= gamma_tilde_k * |A| * |B|,
|
||||
// gamma_tilde_k ~ lambda * sqrt(k) * epsilon,
|
||||
// holding with probability >= 1 - 2*exp(-lambda^2/2) per inner product.
|
||||
//
|
||||
// Two overloads are provided:
|
||||
//
|
||||
// 1. product_tolerance<Scalar>(inner_dim, ...) — RELATIVE tolerance for use
|
||||
// with isApprox(). Assumes random matrices in [-1,1], where sign
|
||||
// cancellation gives || |A|*|B| ||_F / ||A*B||_F ~ (3/4)*sqrt(k).
|
||||
// Combined: tol ~ lambda * num_products * k * epsilon.
|
||||
//
|
||||
// 2. product_error_bound(A, B, ...) — ABSOLUTE error bound for arbitrary
|
||||
// matrices. Computes || |A|*|B| ||_F directly.
|
||||
// Bound: lambda * sqrt(k) * epsilon * num_products * || |A|*|B| ||_F.
|
||||
//
|
||||
// Parameters common to both:
|
||||
// num_products: number of independent products contributing error (default 1).
|
||||
// Use 2 when comparing two different evaluations of A*B.
|
||||
// lambda: probability parameter; P(lambda) = 1 - 2*exp(-lambda^2/2).
|
||||
// lambda=5 gives P > 0.9999 per inner product.
|
||||
|
||||
// Overload 1: Relative tolerance for random [-1,1] matrices.
|
||||
template <typename Scalar>
|
||||
typename NumTraits<Scalar>::Real product_tolerance(Index inner_dim, int num_products = 1, double lambda = 5) {
|
||||
using Real = typename NumTraits<Scalar>::Real;
|
||||
const Real lambda_real(lambda);
|
||||
return lambda_real * Real(num_products) * Real(inner_dim) * NumTraits<Scalar>::epsilon();
|
||||
}
|
||||
|
||||
// Overload 2: Absolute error bound for arbitrary matrices.
|
||||
// Returns lambda * sqrt(k) * epsilon * num_products * || |A|*|B| ||_F.
|
||||
template <typename DerivedA, typename DerivedB>
|
||||
typename NumTraits<typename DerivedA::Scalar>::Real product_error_bound(const MatrixBase<DerivedA>& A,
|
||||
const MatrixBase<DerivedB>& B,
|
||||
int num_products = 1, double lambda = 5) {
|
||||
using Scalar = typename DerivedA::Scalar;
|
||||
using Real = typename NumTraits<Scalar>::Real;
|
||||
Index k = A.cols();
|
||||
Real abs_prod_norm = (A.cwiseAbs() * B.cwiseAbs()).norm();
|
||||
const Real lambda_real(lambda);
|
||||
return lambda_real * numext::sqrt(Real(k)) * NumTraits<Scalar>::epsilon() * Real(num_products) * abs_prod_norm;
|
||||
}
|
||||
|
||||
// Verify that two computations of A*B agree within the Higham-Mary bound.
|
||||
// Returns true if ||actual - expected||_F <= product_error_bound(A, B, ...).
|
||||
template <typename D1, typename D2, typename DA, typename DB>
|
||||
inline bool verifyProduct(const MatrixBase<D1>& actual, const MatrixBase<D2>& expected, const MatrixBase<DA>& A,
|
||||
const MatrixBase<DB>& B, int num_products = 2, double lambda = 5) {
|
||||
using Real = typename NumTraits<typename DA::Scalar>::Real;
|
||||
Real bound = product_error_bound(A, B, num_products, lambda);
|
||||
Real error = (actual - expected).norm();
|
||||
if (error > bound) {
|
||||
std::cerr << "Product verification failed: error " << error << " exceeds bound " << bound << std::endl;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// verifyIsApprox is a wrapper to test_isApprox that outputs the relative difference magnitude if the test fails.
|
||||
template <typename Type1, typename Type2>
|
||||
inline bool verifyIsApprox(const Type1& a, const Type2& b) {
|
||||
|
||||
@@ -97,8 +97,12 @@ void product(const MatrixType& m) {
|
||||
// begin testing Product.h: only associativity for now
|
||||
// (we use Transpose.h but this doesn't count as a test for it)
|
||||
{
|
||||
// Increase tolerance, since coefficients here can get relatively large.
|
||||
RealScalar tol = RealScalar(2) * get_test_precision(m1);
|
||||
// Associativity: (m1 * m1^T) * m2 vs m1 * (m1^T * m2).
|
||||
// Two chained products with inner dims cols and rows. Intermediate entries
|
||||
// of m1*m1^T are O(sqrt(cols)), amplifying the second product's error.
|
||||
// Probabilistic bound (Higham & Mary 2019): ~lambda * sqrt(k) * epsilon
|
||||
// per inner product, times sqrt(cols) amplification from chained product.
|
||||
RealScalar tol = product_tolerance<Scalar>((std::max)(rows, cols), 3);
|
||||
VERIFY(verifyIsApprox((m1 * m1.transpose()) * m2, m1 * (m1.transpose() * m2), tol));
|
||||
}
|
||||
m3 = m1;
|
||||
@@ -289,8 +293,10 @@ void product(const MatrixType& m) {
|
||||
|
||||
// regression for blas_trais
|
||||
{
|
||||
// Increase test tolerance, since coefficients can get relatively large.
|
||||
RealScalar tol = RealScalar(2) * get_test_precision(square);
|
||||
// Triple products of rows x rows matrices. Each side computes 2-3
|
||||
// products with inner dim = rows. Probabilistic bound with amplification
|
||||
// from chained products with O(sqrt(rows)) intermediate entries.
|
||||
RealScalar tol = product_tolerance<Scalar>(rows, 4);
|
||||
VERIFY(
|
||||
verifyIsApprox(square * (square * square).transpose(), square * square.transpose() * square.transpose(), tol));
|
||||
VERIFY(verifyIsApprox(square * (-(square * square)), -square * square * square, tol));
|
||||
|
||||
@@ -15,8 +15,6 @@ void trmv(const MatrixType& m) {
|
||||
typedef typename NumTraits<Scalar>::Real RealScalar;
|
||||
typedef Matrix<Scalar, MatrixType::RowsAtCompileTime, 1> VectorType;
|
||||
|
||||
RealScalar largerEps = 10 * test_precision<RealScalar>();
|
||||
|
||||
Index rows = m.rows();
|
||||
Index cols = m.cols();
|
||||
|
||||
@@ -29,46 +27,53 @@ void trmv(const MatrixType& m) {
|
||||
|
||||
// check with a column-major matrix
|
||||
m3 = m1.template triangularView<Eigen::Lower>();
|
||||
VERIFY((m3 * v1).isApprox(m1.template triangularView<Eigen::Lower>() * v1, largerEps));
|
||||
VERIFY(verifyProduct(m3 * v1, m1.template triangularView<Eigen::Lower>() * v1, m3, v1));
|
||||
m3 = m1.template triangularView<Eigen::Upper>();
|
||||
VERIFY((m3 * v1).isApprox(m1.template triangularView<Eigen::Upper>() * v1, largerEps));
|
||||
VERIFY(verifyProduct(m3 * v1, m1.template triangularView<Eigen::Upper>() * v1, m3, v1));
|
||||
m3 = m1.template triangularView<Eigen::UnitLower>();
|
||||
VERIFY((m3 * v1).isApprox(m1.template triangularView<Eigen::UnitLower>() * v1, largerEps));
|
||||
VERIFY(verifyProduct(m3 * v1, m1.template triangularView<Eigen::UnitLower>() * v1, m3, v1));
|
||||
m3 = m1.template triangularView<Eigen::UnitUpper>();
|
||||
VERIFY((m3 * v1).isApprox(m1.template triangularView<Eigen::UnitUpper>() * v1, largerEps));
|
||||
VERIFY(verifyProduct(m3 * v1, m1.template triangularView<Eigen::UnitUpper>() * v1, m3, v1));
|
||||
|
||||
// check conjugated and scalar multiple expressions (col-major)
|
||||
m3 = m1.template triangularView<Eigen::Lower>();
|
||||
VERIFY(((s1 * m3).conjugate() * v1)
|
||||
.isApprox((s1 * m1).conjugate().template triangularView<Eigen::Lower>() * v1, largerEps));
|
||||
VERIFY(verifyProduct((s1 * m3).conjugate() * v1, (s1 * m1).conjugate().template triangularView<Eigen::Lower>() * v1,
|
||||
(s1 * m3).conjugate(), v1));
|
||||
m3 = m1.template triangularView<Eigen::Upper>();
|
||||
VERIFY((m3.conjugate() * v1.conjugate())
|
||||
.isApprox(m1.conjugate().template triangularView<Eigen::Upper>() * v1.conjugate(), largerEps));
|
||||
VERIFY(verifyProduct(m3.conjugate() * v1.conjugate(),
|
||||
m1.conjugate().template triangularView<Eigen::Upper>() * v1.conjugate(), m3.conjugate(),
|
||||
v1.conjugate()));
|
||||
|
||||
// check with a row-major matrix
|
||||
m3 = m1.template triangularView<Eigen::Upper>();
|
||||
VERIFY((m3.transpose() * v1).isApprox(m1.transpose().template triangularView<Eigen::Lower>() * v1, largerEps));
|
||||
VERIFY(verifyProduct(m3.transpose() * v1, m1.transpose().template triangularView<Eigen::Lower>() * v1, m3.transpose(),
|
||||
v1));
|
||||
m3 = m1.template triangularView<Eigen::Lower>();
|
||||
VERIFY((m3.transpose() * v1).isApprox(m1.transpose().template triangularView<Eigen::Upper>() * v1, largerEps));
|
||||
VERIFY(verifyProduct(m3.transpose() * v1, m1.transpose().template triangularView<Eigen::Upper>() * v1, m3.transpose(),
|
||||
v1));
|
||||
m3 = m1.template triangularView<Eigen::UnitUpper>();
|
||||
VERIFY((m3.transpose() * v1).isApprox(m1.transpose().template triangularView<Eigen::UnitLower>() * v1, largerEps));
|
||||
VERIFY(verifyProduct(m3.transpose() * v1, m1.transpose().template triangularView<Eigen::UnitLower>() * v1,
|
||||
m3.transpose(), v1));
|
||||
m3 = m1.template triangularView<Eigen::UnitLower>();
|
||||
VERIFY((m3.transpose() * v1).isApprox(m1.transpose().template triangularView<Eigen::UnitUpper>() * v1, largerEps));
|
||||
VERIFY(verifyProduct(m3.transpose() * v1, m1.transpose().template triangularView<Eigen::UnitUpper>() * v1,
|
||||
m3.transpose(), v1));
|
||||
|
||||
// check conjugated and scalar multiple expressions (row-major)
|
||||
m3 = m1.template triangularView<Eigen::Upper>();
|
||||
VERIFY((m3.adjoint() * v1).isApprox(m1.adjoint().template triangularView<Eigen::Lower>() * v1, largerEps));
|
||||
VERIFY(verifyProduct(m3.adjoint() * v1, m1.adjoint().template triangularView<Eigen::Lower>() * v1, m3.adjoint(), v1));
|
||||
m3 = m1.template triangularView<Eigen::Lower>();
|
||||
VERIFY((m3.adjoint() * (s1 * v1.conjugate()))
|
||||
.isApprox(m1.adjoint().template triangularView<Eigen::Upper>() * (s1 * v1.conjugate()), largerEps));
|
||||
VERIFY(verifyProduct(m3.adjoint() * (s1 * v1.conjugate()),
|
||||
m1.adjoint().template triangularView<Eigen::Upper>() * (s1 * v1.conjugate()), m3.adjoint(),
|
||||
(s1 * v1.conjugate()).eval()));
|
||||
m3 = m1.template triangularView<Eigen::UnitUpper>();
|
||||
|
||||
// check transposed cases:
|
||||
m3 = m1.template triangularView<Eigen::Lower>();
|
||||
VERIFY((v1.transpose() * m3).isApprox(v1.transpose() * m1.template triangularView<Eigen::Lower>(), largerEps));
|
||||
VERIFY((v1.adjoint() * m3).isApprox(v1.adjoint() * m1.template triangularView<Eigen::Lower>(), largerEps));
|
||||
VERIFY((v1.adjoint() * m3.adjoint())
|
||||
.isApprox(v1.adjoint() * m1.template triangularView<Eigen::Lower>().adjoint(), largerEps));
|
||||
VERIFY(verifyProduct(v1.transpose() * m3, v1.transpose() * m1.template triangularView<Eigen::Lower>(), v1.transpose(),
|
||||
m3));
|
||||
VERIFY(verifyProduct(v1.adjoint() * m3, v1.adjoint() * m1.template triangularView<Eigen::Lower>(), v1.adjoint(), m3));
|
||||
VERIFY(verifyProduct(v1.adjoint() * m3.adjoint(), v1.adjoint() * m1.template triangularView<Eigen::Lower>().adjoint(),
|
||||
v1.adjoint(), m3.adjoint()));
|
||||
|
||||
// TODO check with sub-matrices
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user