mirror of
https://gitlab.com/libeigen/eigen.git
synced 2026-04-10 11:34:33 +08:00
Add CUDA complex sqrt.
This is to support scalar `sqrt` of complex numbers `std::complex<T>` on device, requested by Tensorflow folks. Technically `std::complex` is not supported by NVCC on device (though it is by clang), so the default `sqrt(std::complex<T>)` function only works on the host. Here we create an overload to add back the functionality. Also modified the CMake file to add `--relaxed-constexpr` (or equivalent) flag for NVCC to allow calling constexpr functions from device functions, and added support for specifying compute architecture for NVCC (was already available for clang).
This commit is contained in:
@@ -323,6 +323,27 @@ struct abs2_retval
|
||||
typedef typename NumTraits<Scalar>::Real type;
|
||||
};
|
||||
|
||||
/****************************************************************************
|
||||
* Implementation of sqrt *
|
||||
****************************************************************************/
|
||||
|
||||
template<typename Scalar>
|
||||
struct sqrt_impl
|
||||
{
|
||||
EIGEN_DEVICE_FUNC
|
||||
static EIGEN_ALWAYS_INLINE Scalar run(const Scalar& x)
|
||||
{
|
||||
EIGEN_USING_STD(sqrt);
|
||||
return sqrt(x);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Scalar>
|
||||
struct sqrt_retval
|
||||
{
|
||||
typedef Scalar type;
|
||||
};
|
||||
|
||||
/****************************************************************************
|
||||
* Implementation of norm1 *
|
||||
****************************************************************************/
|
||||
@@ -1368,12 +1389,11 @@ inline int log2(int x)
|
||||
*
|
||||
* It's usage is justified in performance critical functions, like norm/normalize.
|
||||
*/
|
||||
template<typename T>
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
|
||||
T sqrt(const T &x)
|
||||
template<typename Scalar>
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_ALWAYS_INLINE EIGEN_MATHFUNC_RETVAL(sqrt, Scalar) sqrt(const Scalar& x)
|
||||
{
|
||||
EIGEN_USING_STD(sqrt);
|
||||
return sqrt(x);
|
||||
return EIGEN_MATHFUNC_IMPL(sqrt, Scalar)::run(x);
|
||||
}
|
||||
|
||||
// Boolean specialization, avoids implicit float to bool conversion (-Wimplicit-conversion-floating-point-to-bool).
|
||||
|
||||
@@ -12,12 +12,12 @@
|
||||
|
||||
// clang-format off
|
||||
|
||||
#if defined(EIGEN_CUDACC) && defined(EIGEN_GPU_COMPILE_PHASE)
|
||||
|
||||
namespace Eigen {
|
||||
|
||||
namespace internal {
|
||||
|
||||
#if defined(EIGEN_CUDACC) && defined(EIGEN_USE_GPU)
|
||||
|
||||
// Many std::complex methods such as operator+, operator-, operator* and
|
||||
// operator/ are not constexpr. Due to this, clang does not treat them as device
|
||||
// functions and thus Eigen functors making use of these operators fail to
|
||||
@@ -94,10 +94,53 @@ template<typename T> struct scalar_quotient_op<const std::complex<T>, const std:
|
||||
|
||||
template<typename T> struct scalar_quotient_op<std::complex<T>, std::complex<T> > : scalar_quotient_op<const std::complex<T>, const std::complex<T> > {};
|
||||
|
||||
template<typename T>
|
||||
struct sqrt_impl<std::complex<T>> {
|
||||
static EIGEN_DEVICE_FUNC std::complex<T> run(const std::complex<T>& z) {
|
||||
// Computes the principal sqrt of the input.
|
||||
//
|
||||
// For a complex square root of the number x + i*y. We want to find real
|
||||
// numbers u and v such that
|
||||
// (u + i*v)^2 = x + i*y <=>
|
||||
// u^2 - v^2 + i*2*u*v = x + i*v.
|
||||
// By equating the real and imaginary parts we get:
|
||||
// u^2 - v^2 = x
|
||||
// 2*u*v = y.
|
||||
//
|
||||
// For x >= 0, this has the numerically stable solution
|
||||
// u = sqrt(0.5 * (x + sqrt(x^2 + y^2)))
|
||||
// v = y / (2 * u)
|
||||
// and for x < 0,
|
||||
// v = sign(y) * sqrt(0.5 * (-x + sqrt(x^2 + y^2)))
|
||||
// u = y / (2 * v)
|
||||
//
|
||||
// Letting w = sqrt(0.5 * (|x| + |z|)),
|
||||
// if x == 0: u = w, v = sign(y) * w
|
||||
// if x > 0: u = w, v = y / (2 * w)
|
||||
// if x < 0: u = |y| / (2 * w), v = sign(y) * w
|
||||
|
||||
const T x = numext::real(z);
|
||||
const T y = numext::imag(z);
|
||||
const T zero = T(0);
|
||||
const T cst_half = T(0.5);
|
||||
|
||||
// Special case of isinf(y)
|
||||
if ((numext::isinf)(y)) {
|
||||
const T inf = std::numeric_limits<T>::infinity();
|
||||
return std::complex<T>(inf, y);
|
||||
}
|
||||
|
||||
T w = numext::sqrt(cst_half * (numext::abs(x) + numext::abs(z)));
|
||||
return
|
||||
x == zero ? std::complex<T>(w, y < zero ? -w : w)
|
||||
: x > zero ? std::complex<T>(w, y / (2 * w))
|
||||
: std::complex<T>(numext::abs(y) / (2 * w), y < zero ? -w : w );
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
} // namespace Eigen
|
||||
|
||||
#endif
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
} // end namespace Eigen
|
||||
|
||||
#endif // EIGEN_COMPLEX_CUDA_H
|
||||
#endif // EIGEN_COMPLEX_CUDA_H
|
||||
|
||||
@@ -703,8 +703,8 @@ Packet psqrt_complex(const Packet& a) {
|
||||
// u = sqrt(0.5 * (x + sqrt(x^2 + y^2)))
|
||||
// v = 0.5 * (y / u)
|
||||
// and for x < 0,
|
||||
// v = sign(y) * sqrt(0.5 * (x + sqrt(x^2 + y^2)))
|
||||
// u = |0.5 * (y / v)|
|
||||
// v = sign(y) * sqrt(0.5 * (-x + sqrt(x^2 + y^2)))
|
||||
// u = 0.5 * (y / v)
|
||||
//
|
||||
// To avoid unnecessary over- and underflow, we compute sqrt(x^2 + y^2) as
|
||||
// l = max(|x|, |y|) * sqrt(1 + (min(|x|, |y|) / max(|x|, |y|))^2) ,
|
||||
|
||||
Reference in New Issue
Block a user