Specialize std::complex operators for use on GPU device.

NVCC and older versions of clang do not fully support `std::complex` on device,
leading to either compile errors (Cannot call `__host__` function) or worse,
runtime errors (Illegal instruction).  For most functions, we can
implement specialized `numext` versions. Here we specialize the standard
operators (with the exception of stream operators and member function operators
with a scalar that are already specialized in `<complex>`) so they can be used
in device code as well.

To import these operators into the current scope, use
`EIGEN_USING_STD_COMPLEX_OPERATORS`. By default, these are imported into
the `Eigen`, `Eigen:internal`, and `Eigen::numext` namespaces.

This allow us to remove specializations of the
sum/difference/product/quotient ops, and allow us to treat complex
numbers like most other scalars (e.g. in tests).
This commit is contained in:
Antonio Sanchez
2021-01-06 09:41:15 -08:00
committed by Antonio Sánchez
parent 65e2169c45
commit f19bcffee6
3 changed files with 336 additions and 76 deletions

View File

@@ -106,6 +106,116 @@ struct complex_sqrt {
}
};
template<typename T>
struct complex_operators {
EIGEN_DEVICE_FUNC
void operator()(int i, const typename T::Scalar* in, typename T::Scalar* out) const
{
using namespace Eigen;
typedef typename T::Scalar ComplexType;
typedef typename T::Scalar::value_type ValueType;
const int num_scalar_operators = 24;
const int num_vector_operators = 23; // no unary + operator.
int out_idx = i * (num_scalar_operators + num_vector_operators * T::MaxSizeAtCompileTime);
// Scalar operators.
const ComplexType a = in[i];
const ComplexType b = in[i + 1];
out[out_idx++] = +a;
out[out_idx++] = -a;
out[out_idx++] = a + b;
out[out_idx++] = a + numext::real(b);
out[out_idx++] = numext::real(a) + b;
out[out_idx++] = a - b;
out[out_idx++] = a - numext::real(b);
out[out_idx++] = numext::real(a) - b;
out[out_idx++] = a * b;
out[out_idx++] = a * numext::real(b);
out[out_idx++] = numext::real(a) * b;
out[out_idx++] = a / b;
out[out_idx++] = a / numext::real(b);
out[out_idx++] = numext::real(a) / b;
out[out_idx] = a; out[out_idx++] += b;
out[out_idx] = a; out[out_idx++] -= b;
out[out_idx] = a; out[out_idx++] *= b;
out[out_idx] = a; out[out_idx++] /= b;
const ComplexType true_value = ComplexType(ValueType(1), ValueType(0));
const ComplexType false_value = ComplexType(ValueType(0), ValueType(0));
out[out_idx++] = (a == b ? true_value : false_value);
out[out_idx++] = (a == numext::real(b) ? true_value : false_value);
out[out_idx++] = (numext::real(a) == b ? true_value : false_value);
out[out_idx++] = (a != b ? true_value : false_value);
out[out_idx++] = (a != numext::real(b) ? true_value : false_value);
out[out_idx++] = (numext::real(a) != b ? true_value : false_value);
// Vector versions.
T x1(in + i);
T x2(in + i + 1);
const int res_size = T::MaxSizeAtCompileTime * num_scalar_operators;
const int size = T::MaxSizeAtCompileTime;
int block_idx = 0;
Map<VectorX<ComplexType>> res(out + out_idx, res_size);
res.segment(block_idx, size) = -x1;
block_idx += size;
res.segment(block_idx, size) = x1 + x2;
block_idx += size;
res.segment(block_idx, size) = x1 + x2.real();
block_idx += size;
res.segment(block_idx, size) = x1.real() + x2;
block_idx += size;
res.segment(block_idx, size) = x1 - x2;
block_idx += size;
res.segment(block_idx, size) = x1 - x2.real();
block_idx += size;
res.segment(block_idx, size) = x1.real() - x2;
block_idx += size;
res.segment(block_idx, size) = x1.array() * x2.array();
block_idx += size;
res.segment(block_idx, size) = x1.array() * x2.real().array();
block_idx += size;
res.segment(block_idx, size) = x1.real().array() * x2.array();
block_idx += size;
res.segment(block_idx, size) = x1.array() / x2.array();
block_idx += size;
res.segment(block_idx, size) = x1.array() / x2.real().array();
block_idx += size;
res.segment(block_idx, size) = x1.real().array() / x2.array();
block_idx += size;
res.segment(block_idx, size) = x1; res.segment(block_idx, size) += x2;
block_idx += size;
res.segment(block_idx, size) = x1; res.segment(block_idx, size) -= x2;
block_idx += size;
res.segment(block_idx, size) = x1; res.segment(block_idx, size).array() *= x2.array();
block_idx += size;
res.segment(block_idx, size) = x1; res.segment(block_idx, size).array() /= x2.array();
block_idx += size;
// Equality comparisons currently not functional on device
// (std::equal_to<T> is host-only).
// const T true_vector = T::Constant(true_value);
// const T false_vector = T::Constant(false_value);
// res.segment(block_idx, size) = (x1 == x2 ? true_vector : false_vector);
// block_idx += size;
// res.segment(block_idx, size) = (x1 == x2.real() ? true_vector : false_vector);
// block_idx += size;
// res.segment(block_idx, size) = (x1.real() == x2 ? true_vector : false_vector);
// block_idx += size;
// res.segment(block_idx, size) = (x1 != x2 ? true_vector : false_vector);
// block_idx += size;
// res.segment(block_idx, size) = (x1 != x2.real() ? true_vector : false_vector);
// block_idx += size;
// res.segment(block_idx, size) = (x1.real() != x2 ? true_vector : false_vector);
// block_idx += size;
}
};
template<typename T>
struct replicate {
EIGEN_DEVICE_FUNC
@@ -297,6 +407,8 @@ EIGEN_DECLARE_TEST(gpu_basic)
CALL_SUBTEST( run_and_compare_to_gpu(eigenvalues_direct<Matrix3f>(), nthreads, in, out) );
CALL_SUBTEST( run_and_compare_to_gpu(eigenvalues_direct<Matrix2f>(), nthreads, in, out) );
// Test std::complex.
CALL_SUBTEST( run_and_compare_to_gpu(complex_operators<Vector3cf>(), nthreads, cfin, cfout) );
CALL_SUBTEST( test_with_infs_nans(complex_sqrt<Vector3cf>(), nthreads, cfin, cfout) );
#if defined(__NVCC__)