Improved std::complex sqrt and rsqrt.

Replaces `std::sqrt` with `complex_sqrt` for all platforms (previously
`complex_sqrt` was only used for CUDA and MSVC), and implements
custom `complex_rsqrt`.

Also introduces `numext::rsqrt` to simplify implementation, and modified
`numext::hypot` to adhere to IEEE IEC 6059 for special cases.

The `complex_sqrt` and `complex_rsqrt` implementations were found to be
significantly faster than `std::sqrt<std::complex<T>>` and
`1/numext::sqrt<std::complex<T>>`.

Benchmark file attached.
```
GCC 10, Intel Xeon, x86_64:
---------------------------------------------------------------------------
Benchmark                                 Time             CPU   Iterations
---------------------------------------------------------------------------
BM_Sqrt<std::complex<float>>           9.21 ns         9.21 ns     73225448
BM_StdSqrt<std::complex<float>>        17.1 ns         17.1 ns     40966545
BM_Sqrt<std::complex<double>>          8.53 ns         8.53 ns     81111062
BM_StdSqrt<std::complex<double>>       21.5 ns         21.5 ns     32757248
BM_Rsqrt<std::complex<float>>          10.3 ns         10.3 ns     68047474
BM_DivSqrt<std::complex<float>>        16.3 ns         16.3 ns     42770127
BM_Rsqrt<std::complex<double>>         11.3 ns         11.3 ns     61322028
BM_DivSqrt<std::complex<double>>       16.5 ns         16.5 ns     42200711

Clang 11, Intel Xeon, x86_64:
---------------------------------------------------------------------------
Benchmark                                 Time             CPU   Iterations
---------------------------------------------------------------------------
BM_Sqrt<std::complex<float>>           7.46 ns         7.45 ns     90742042
BM_StdSqrt<std::complex<float>>        16.6 ns         16.6 ns     42369878
BM_Sqrt<std::complex<double>>          8.49 ns         8.49 ns     81629030
BM_StdSqrt<std::complex<double>>       21.8 ns         21.7 ns     31809588
BM_Rsqrt<std::complex<float>>          8.39 ns         8.39 ns     82933666
BM_DivSqrt<std::complex<float>>        14.4 ns         14.4 ns     48638676
BM_Rsqrt<std::complex<double>>         9.83 ns         9.82 ns     70068956
BM_DivSqrt<std::complex<double>>       15.7 ns         15.7 ns     44487798

Clang 9, Pixel 2, aarch64:
---------------------------------------------------------------------------
Benchmark                                 Time             CPU   Iterations
---------------------------------------------------------------------------
BM_Sqrt<std::complex<float>>           24.2 ns         24.1 ns     28616031
BM_StdSqrt<std::complex<float>>         104 ns          103 ns      6826926
BM_Sqrt<std::complex<double>>          31.8 ns         31.8 ns     22157591
BM_StdSqrt<std::complex<double>>        128 ns          128 ns      5437375
BM_Rsqrt<std::complex<float>>          31.9 ns         31.8 ns     22384383
BM_DivSqrt<std::complex<float>>        99.2 ns         98.9 ns      7250438
BM_Rsqrt<std::complex<double>>         46.0 ns         45.8 ns     15338689
BM_DivSqrt<std::complex<double>>        119 ns          119 ns      5898944
```
This commit is contained in:
Antonio Sanchez
2021-01-16 10:22:07 -08:00
parent 21a8a2487c
commit bde6741641
8 changed files with 357 additions and 82 deletions

View File

@@ -161,8 +161,12 @@ template<typename MatrixType> void stable_norm(const MatrixType& m)
// mix
{
Index i2 = internal::random<Index>(0,rows-1);
Index j2 = internal::random<Index>(0,cols-1);
// Ensure unique indices otherwise inf may be overwritten by NaN.
Index i2, j2;
do {
i2 = internal::random<Index>(0,rows-1);
j2 = internal::random<Index>(0,cols-1);
} while (i2 == i && j2 == j);
v = vrand;
v(i,j) = -std::numeric_limits<RealScalar>::infinity();
v(i2,j2) = std::numeric_limits<RealScalar>::quiet_NaN();
@@ -170,7 +174,8 @@ template<typename MatrixType> void stable_norm(const MatrixType& m)
VERIFY(!(numext::isfinite)(v.norm())); VERIFY((numext::isnan)(v.norm()));
VERIFY(!(numext::isfinite)(v.stableNorm())); VERIFY((numext::isnan)(v.stableNorm()));
VERIFY(!(numext::isfinite)(v.blueNorm())); VERIFY((numext::isnan)(v.blueNorm()));
VERIFY(!(numext::isfinite)(v.hypotNorm())); VERIFY((numext::isnan)(v.hypotNorm()));
// hypot propagates inf over NaN.
VERIFY(!(numext::isfinite)(v.hypotNorm())); VERIFY((numext::isinf)(v.hypotNorm()));
}
// stableNormalize[d]