mirror of
https://gitlab.com/libeigen/eigen.git
synced 2026-04-10 11:34:33 +08:00
Specialize std::complex operators for use on GPU device.
NVCC and older versions of clang do not fully support `std::complex` on device, leading to either compile errors (Cannot call `__host__` function) or worse, runtime errors (Illegal instruction). For most functions, we can implement specialized `numext` versions. Here we specialize the standard operators (with the exception of stream operators and member function operators with a scalar that are already specialized in `<complex>`) so they can be used in device code as well. To import these operators into the current scope, use `EIGEN_USING_STD_COMPLEX_OPERATORS`. By default, these are imported into the `Eigen`, `Eigen:internal`, and `Eigen::numext` namespaces. This allow us to remove specializations of the sum/difference/product/quotient ops, and allow us to treat complex numbers like most other scalars (e.g. in tests).
This commit is contained in:
committed by
Antonio Sánchez
parent
65e2169c45
commit
f19bcffee6
@@ -106,6 +106,116 @@ struct complex_sqrt {
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
struct complex_operators {
|
||||
EIGEN_DEVICE_FUNC
|
||||
void operator()(int i, const typename T::Scalar* in, typename T::Scalar* out) const
|
||||
{
|
||||
using namespace Eigen;
|
||||
typedef typename T::Scalar ComplexType;
|
||||
typedef typename T::Scalar::value_type ValueType;
|
||||
const int num_scalar_operators = 24;
|
||||
const int num_vector_operators = 23; // no unary + operator.
|
||||
int out_idx = i * (num_scalar_operators + num_vector_operators * T::MaxSizeAtCompileTime);
|
||||
|
||||
// Scalar operators.
|
||||
const ComplexType a = in[i];
|
||||
const ComplexType b = in[i + 1];
|
||||
|
||||
out[out_idx++] = +a;
|
||||
out[out_idx++] = -a;
|
||||
|
||||
out[out_idx++] = a + b;
|
||||
out[out_idx++] = a + numext::real(b);
|
||||
out[out_idx++] = numext::real(a) + b;
|
||||
out[out_idx++] = a - b;
|
||||
out[out_idx++] = a - numext::real(b);
|
||||
out[out_idx++] = numext::real(a) - b;
|
||||
out[out_idx++] = a * b;
|
||||
out[out_idx++] = a * numext::real(b);
|
||||
out[out_idx++] = numext::real(a) * b;
|
||||
out[out_idx++] = a / b;
|
||||
out[out_idx++] = a / numext::real(b);
|
||||
out[out_idx++] = numext::real(a) / b;
|
||||
|
||||
out[out_idx] = a; out[out_idx++] += b;
|
||||
out[out_idx] = a; out[out_idx++] -= b;
|
||||
out[out_idx] = a; out[out_idx++] *= b;
|
||||
out[out_idx] = a; out[out_idx++] /= b;
|
||||
|
||||
const ComplexType true_value = ComplexType(ValueType(1), ValueType(0));
|
||||
const ComplexType false_value = ComplexType(ValueType(0), ValueType(0));
|
||||
out[out_idx++] = (a == b ? true_value : false_value);
|
||||
out[out_idx++] = (a == numext::real(b) ? true_value : false_value);
|
||||
out[out_idx++] = (numext::real(a) == b ? true_value : false_value);
|
||||
out[out_idx++] = (a != b ? true_value : false_value);
|
||||
out[out_idx++] = (a != numext::real(b) ? true_value : false_value);
|
||||
out[out_idx++] = (numext::real(a) != b ? true_value : false_value);
|
||||
|
||||
// Vector versions.
|
||||
T x1(in + i);
|
||||
T x2(in + i + 1);
|
||||
const int res_size = T::MaxSizeAtCompileTime * num_scalar_operators;
|
||||
const int size = T::MaxSizeAtCompileTime;
|
||||
int block_idx = 0;
|
||||
|
||||
Map<VectorX<ComplexType>> res(out + out_idx, res_size);
|
||||
res.segment(block_idx, size) = -x1;
|
||||
block_idx += size;
|
||||
|
||||
res.segment(block_idx, size) = x1 + x2;
|
||||
block_idx += size;
|
||||
res.segment(block_idx, size) = x1 + x2.real();
|
||||
block_idx += size;
|
||||
res.segment(block_idx, size) = x1.real() + x2;
|
||||
block_idx += size;
|
||||
res.segment(block_idx, size) = x1 - x2;
|
||||
block_idx += size;
|
||||
res.segment(block_idx, size) = x1 - x2.real();
|
||||
block_idx += size;
|
||||
res.segment(block_idx, size) = x1.real() - x2;
|
||||
block_idx += size;
|
||||
res.segment(block_idx, size) = x1.array() * x2.array();
|
||||
block_idx += size;
|
||||
res.segment(block_idx, size) = x1.array() * x2.real().array();
|
||||
block_idx += size;
|
||||
res.segment(block_idx, size) = x1.real().array() * x2.array();
|
||||
block_idx += size;
|
||||
res.segment(block_idx, size) = x1.array() / x2.array();
|
||||
block_idx += size;
|
||||
res.segment(block_idx, size) = x1.array() / x2.real().array();
|
||||
block_idx += size;
|
||||
res.segment(block_idx, size) = x1.real().array() / x2.array();
|
||||
block_idx += size;
|
||||
|
||||
res.segment(block_idx, size) = x1; res.segment(block_idx, size) += x2;
|
||||
block_idx += size;
|
||||
res.segment(block_idx, size) = x1; res.segment(block_idx, size) -= x2;
|
||||
block_idx += size;
|
||||
res.segment(block_idx, size) = x1; res.segment(block_idx, size).array() *= x2.array();
|
||||
block_idx += size;
|
||||
res.segment(block_idx, size) = x1; res.segment(block_idx, size).array() /= x2.array();
|
||||
block_idx += size;
|
||||
|
||||
// Equality comparisons currently not functional on device
|
||||
// (std::equal_to<T> is host-only).
|
||||
// const T true_vector = T::Constant(true_value);
|
||||
// const T false_vector = T::Constant(false_value);
|
||||
// res.segment(block_idx, size) = (x1 == x2 ? true_vector : false_vector);
|
||||
// block_idx += size;
|
||||
// res.segment(block_idx, size) = (x1 == x2.real() ? true_vector : false_vector);
|
||||
// block_idx += size;
|
||||
// res.segment(block_idx, size) = (x1.real() == x2 ? true_vector : false_vector);
|
||||
// block_idx += size;
|
||||
// res.segment(block_idx, size) = (x1 != x2 ? true_vector : false_vector);
|
||||
// block_idx += size;
|
||||
// res.segment(block_idx, size) = (x1 != x2.real() ? true_vector : false_vector);
|
||||
// block_idx += size;
|
||||
// res.segment(block_idx, size) = (x1.real() != x2 ? true_vector : false_vector);
|
||||
// block_idx += size;
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
struct replicate {
|
||||
EIGEN_DEVICE_FUNC
|
||||
@@ -297,6 +407,8 @@ EIGEN_DECLARE_TEST(gpu_basic)
|
||||
CALL_SUBTEST( run_and_compare_to_gpu(eigenvalues_direct<Matrix3f>(), nthreads, in, out) );
|
||||
CALL_SUBTEST( run_and_compare_to_gpu(eigenvalues_direct<Matrix2f>(), nthreads, in, out) );
|
||||
|
||||
// Test std::complex.
|
||||
CALL_SUBTEST( run_and_compare_to_gpu(complex_operators<Vector3cf>(), nthreads, cfin, cfout) );
|
||||
CALL_SUBTEST( test_with_infs_nans(complex_sqrt<Vector3cf>(), nthreads, cfin, cfout) );
|
||||
|
||||
#if defined(__NVCC__)
|
||||
|
||||
Reference in New Issue
Block a user