Add CUDA complex sqrt.

This is to support scalar `sqrt` of complex numbers `std::complex<T>` on
device, requested by Tensorflow folks.

Technically `std::complex` is not supported by NVCC on device
(though it is by clang), so the default `sqrt(std::complex<T>)` function only
works on the host. Here we create an overload to add back the
functionality.

Also modified the CMake file to add `--relaxed-constexpr` (or
equivalent) flag for NVCC to allow calling constexpr functions from
device functions, and added support for specifying compute architecture for
NVCC (was already available for clang).
This commit is contained in:
Antonio Sanchez
2020-12-22 22:49:06 -08:00
parent fdf2ee62c5
commit 070d303d56
7 changed files with 217 additions and 28 deletions

View File

@@ -323,6 +323,27 @@ struct abs2_retval
typedef typename NumTraits<Scalar>::Real type;
};
/****************************************************************************
* Implementation of sqrt *
****************************************************************************/
template<typename Scalar>
struct sqrt_impl
{
EIGEN_DEVICE_FUNC
static EIGEN_ALWAYS_INLINE Scalar run(const Scalar& x)
{
EIGEN_USING_STD(sqrt);
return sqrt(x);
}
};
template<typename Scalar>
struct sqrt_retval
{
typedef Scalar type;
};
/****************************************************************************
* Implementation of norm1 *
****************************************************************************/
@@ -1368,12 +1389,11 @@ inline int log2(int x)
*
* It's usage is justified in performance critical functions, like norm/normalize.
*/
template<typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
T sqrt(const T &x)
template<typename Scalar>
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE EIGEN_MATHFUNC_RETVAL(sqrt, Scalar) sqrt(const Scalar& x)
{
EIGEN_USING_STD(sqrt);
return sqrt(x);
return EIGEN_MATHFUNC_IMPL(sqrt, Scalar)::run(x);
}
// Boolean specialization, avoids implicit float to bool conversion (-Wimplicit-conversion-floating-point-to-bool).

View File

@@ -12,12 +12,12 @@
// clang-format off
#if defined(EIGEN_CUDACC) && defined(EIGEN_GPU_COMPILE_PHASE)
namespace Eigen {
namespace internal {
#if defined(EIGEN_CUDACC) && defined(EIGEN_USE_GPU)
// Many std::complex methods such as operator+, operator-, operator* and
// operator/ are not constexpr. Due to this, clang does not treat them as device
// functions and thus Eigen functors making use of these operators fail to
@@ -94,10 +94,53 @@ template<typename T> struct scalar_quotient_op<const std::complex<T>, const std:
template<typename T> struct scalar_quotient_op<std::complex<T>, std::complex<T> > : scalar_quotient_op<const std::complex<T>, const std::complex<T> > {};
template<typename T>
struct sqrt_impl<std::complex<T>> {
static EIGEN_DEVICE_FUNC std::complex<T> run(const std::complex<T>& z) {
// Computes the principal sqrt of the input.
//
// For a complex square root of the number x + i*y. We want to find real
// numbers u and v such that
// (u + i*v)^2 = x + i*y <=>
// u^2 - v^2 + i*2*u*v = x + i*v.
// By equating the real and imaginary parts we get:
// u^2 - v^2 = x
// 2*u*v = y.
//
// For x >= 0, this has the numerically stable solution
// u = sqrt(0.5 * (x + sqrt(x^2 + y^2)))
// v = y / (2 * u)
// and for x < 0,
// v = sign(y) * sqrt(0.5 * (-x + sqrt(x^2 + y^2)))
// u = y / (2 * v)
//
// Letting w = sqrt(0.5 * (|x| + |z|)),
// if x == 0: u = w, v = sign(y) * w
// if x > 0: u = w, v = y / (2 * w)
// if x < 0: u = |y| / (2 * w), v = sign(y) * w
const T x = numext::real(z);
const T y = numext::imag(z);
const T zero = T(0);
const T cst_half = T(0.5);
// Special case of isinf(y)
if ((numext::isinf)(y)) {
const T inf = std::numeric_limits<T>::infinity();
return std::complex<T>(inf, y);
}
T w = numext::sqrt(cst_half * (numext::abs(x) + numext::abs(z)));
return
x == zero ? std::complex<T>(w, y < zero ? -w : w)
: x > zero ? std::complex<T>(w, y / (2 * w))
: std::complex<T>(numext::abs(y) / (2 * w), y < zero ? -w : w );
}
};
} // namespace internal
} // namespace Eigen
#endif
} // end namespace internal
} // end namespace Eigen
#endif // EIGEN_COMPLEX_CUDA_H
#endif // EIGEN_COMPLEX_CUDA_H

View File

@@ -703,8 +703,8 @@ Packet psqrt_complex(const Packet& a) {
// u = sqrt(0.5 * (x + sqrt(x^2 + y^2)))
// v = 0.5 * (y / u)
// and for x < 0,
// v = sign(y) * sqrt(0.5 * (x + sqrt(x^2 + y^2)))
// u = |0.5 * (y / v)|
// v = sign(y) * sqrt(0.5 * (-x + sqrt(x^2 + y^2)))
// u = 0.5 * (y / v)
//
// To avoid unnecessary over- and underflow, we compute sqrt(x^2 + y^2) as
// l = max(|x|, |y|) * sqrt(1 + (min(|x|, |y|) / max(|x|, |y|))^2) ,