mirror of
https://gitlab.com/libeigen/eigen.git
synced 2026-04-10 11:34:33 +08:00
Fix tensor casts for large packets and casts to/from std::complex
The original tensor casts were only defined for `SrcCoeffRatio`:`TgtCoeffRatio` 1:1, 1:2, 2:1, 4:1. Here we add the missing 1:N and 8:1. We also add casting `Eigen::half` to/from `std::complex<T>`, which was missing to make it consistent with `Eigen:bfloat16`, and generalize the overload to work for any complex type. Tests were added to `basicstuff`, `packetmath`, and `cxx11_tensor_casts` to test all cast configurations.
This commit is contained in:
committed by
Antonio Sánchez
parent
145e51516f
commit
9cb8771e9c
@@ -8,8 +8,8 @@
|
||||
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
#include <limits>
|
||||
#include "packetmath_test_shared.h"
|
||||
#include "random_without_cast_overflow.h"
|
||||
|
||||
template <typename T>
|
||||
inline T REF_ADD(const T& a, const T& b) {
|
||||
@@ -126,129 +126,6 @@ struct test_cast_helper<SrcPacket, TgtPacket, SrcCoeffRatio, TgtCoeffRatio, fals
|
||||
static void run() {}
|
||||
};
|
||||
|
||||
// Generates random values that fit in both SrcScalar and TgtScalar without
|
||||
// overflowing when cast.
|
||||
template <typename SrcScalar, typename TgtScalar, typename EnableIf = void>
|
||||
struct random_without_cast_overflow {
|
||||
static SrcScalar value() { return internal::random<SrcScalar>(); }
|
||||
};
|
||||
|
||||
// Widening integer cast signed to unsigned.
|
||||
template <typename SrcScalar, typename TgtScalar>
|
||||
struct random_without_cast_overflow<
|
||||
SrcScalar, TgtScalar,
|
||||
typename internal::enable_if<NumTraits<SrcScalar>::IsInteger && NumTraits<TgtScalar>::IsInteger &&
|
||||
!NumTraits<TgtScalar>::IsSigned &&
|
||||
(std::numeric_limits<SrcScalar>::digits < std::numeric_limits<TgtScalar>::digits ||
|
||||
(std::numeric_limits<SrcScalar>::digits == std::numeric_limits<TgtScalar>::digits &&
|
||||
NumTraits<SrcScalar>::IsSigned))>::type> {
|
||||
static SrcScalar value() {
|
||||
SrcScalar a = internal::random<SrcScalar>();
|
||||
return a < SrcScalar(0) ? -(a + 1) : a;
|
||||
}
|
||||
};
|
||||
|
||||
// Narrowing integer cast to unsigned.
|
||||
template <typename SrcScalar, typename TgtScalar>
|
||||
struct random_without_cast_overflow<
|
||||
SrcScalar, TgtScalar,
|
||||
typename internal::enable_if<
|
||||
NumTraits<SrcScalar>::IsInteger && NumTraits<TgtScalar>::IsInteger && !NumTraits<SrcScalar>::IsSigned &&
|
||||
(std::numeric_limits<SrcScalar>::digits > std::numeric_limits<TgtScalar>::digits)>::type> {
|
||||
static SrcScalar value() {
|
||||
TgtScalar b = internal::random<TgtScalar>();
|
||||
return static_cast<SrcScalar>(b < TgtScalar(0) ? -(b + 1) : b);
|
||||
}
|
||||
};
|
||||
|
||||
// Narrowing integer cast to signed.
|
||||
template <typename SrcScalar, typename TgtScalar>
|
||||
struct random_without_cast_overflow<
|
||||
SrcScalar, TgtScalar,
|
||||
typename internal::enable_if<
|
||||
NumTraits<SrcScalar>::IsInteger && NumTraits<TgtScalar>::IsInteger && NumTraits<SrcScalar>::IsSigned &&
|
||||
(std::numeric_limits<SrcScalar>::digits > std::numeric_limits<TgtScalar>::digits)>::type> {
|
||||
static SrcScalar value() {
|
||||
TgtScalar b = internal::random<TgtScalar>();
|
||||
return static_cast<SrcScalar>(b);
|
||||
}
|
||||
};
|
||||
|
||||
// Unsigned to signed narrowing cast.
|
||||
template <typename SrcScalar, typename TgtScalar>
|
||||
struct random_without_cast_overflow<
|
||||
SrcScalar, TgtScalar,
|
||||
typename internal::enable_if<NumTraits<SrcScalar>::IsInteger && NumTraits<TgtScalar>::IsInteger &&
|
||||
!NumTraits<SrcScalar>::IsSigned && NumTraits<TgtScalar>::IsSigned &&
|
||||
(std::numeric_limits<SrcScalar>::digits ==
|
||||
std::numeric_limits<TgtScalar>::digits)>::type> {
|
||||
static SrcScalar value() { return internal::random<SrcScalar>() / 2; }
|
||||
};
|
||||
|
||||
template <typename Scalar>
|
||||
struct is_floating_point {
|
||||
enum { value = 0 };
|
||||
};
|
||||
template <>
|
||||
struct is_floating_point<float> {
|
||||
enum { value = 1 };
|
||||
};
|
||||
template <>
|
||||
struct is_floating_point<double> {
|
||||
enum { value = 1 };
|
||||
};
|
||||
template <>
|
||||
struct is_floating_point<half> {
|
||||
enum { value = 1 };
|
||||
};
|
||||
template <>
|
||||
struct is_floating_point<bfloat16> {
|
||||
enum { value = 1 };
|
||||
};
|
||||
|
||||
// Floating-point to integer, full precision.
|
||||
template <typename SrcScalar, typename TgtScalar>
|
||||
struct random_without_cast_overflow<
|
||||
SrcScalar, TgtScalar,
|
||||
typename internal::enable_if<is_floating_point<SrcScalar>::value && NumTraits<TgtScalar>::IsInteger &&
|
||||
(std::numeric_limits<TgtScalar>::digits <=
|
||||
std::numeric_limits<SrcScalar>::digits)>::type> {
|
||||
static SrcScalar value() { return static_cast<SrcScalar>(internal::random<TgtScalar>()); }
|
||||
};
|
||||
|
||||
// Floating-point to integer, narrowing precision.
|
||||
template <typename SrcScalar, typename TgtScalar>
|
||||
struct random_without_cast_overflow<
|
||||
SrcScalar, TgtScalar,
|
||||
typename internal::enable_if<is_floating_point<SrcScalar>::value && NumTraits<TgtScalar>::IsInteger &&
|
||||
(std::numeric_limits<TgtScalar>::digits >
|
||||
std::numeric_limits<SrcScalar>::digits)>::type> {
|
||||
static SrcScalar value() {
|
||||
static const int BitShift = std::numeric_limits<TgtScalar>::digits - std::numeric_limits<SrcScalar>::digits;
|
||||
return static_cast<SrcScalar>(internal::random<TgtScalar>() >> BitShift);
|
||||
}
|
||||
};
|
||||
|
||||
// Floating-point target from integer, re-use above logic.
|
||||
template <typename SrcScalar, typename TgtScalar>
|
||||
struct random_without_cast_overflow<
|
||||
SrcScalar, TgtScalar,
|
||||
typename internal::enable_if<NumTraits<SrcScalar>::IsInteger && is_floating_point<TgtScalar>::value>::type> {
|
||||
static SrcScalar value() {
|
||||
return static_cast<SrcScalar>(random_without_cast_overflow<TgtScalar, SrcScalar>::value());
|
||||
}
|
||||
};
|
||||
|
||||
// Floating-point narrowing conversion.
|
||||
template <typename SrcScalar, typename TgtScalar>
|
||||
struct random_without_cast_overflow<
|
||||
SrcScalar, TgtScalar,
|
||||
typename internal::enable_if<is_floating_point<SrcScalar>::value && is_floating_point<TgtScalar>::value &&
|
||||
(std::numeric_limits<SrcScalar>::digits >
|
||||
std::numeric_limits<TgtScalar>::digits)>::type> {
|
||||
static SrcScalar value() { return static_cast<SrcScalar>(internal::random<TgtScalar>()); }
|
||||
};
|
||||
|
||||
template <typename SrcPacket, typename TgtPacket, int SrcCoeffRatio, int TgtCoeffRatio>
|
||||
struct test_cast_helper<SrcPacket, TgtPacket, SrcCoeffRatio, TgtCoeffRatio, true> {
|
||||
static void run() {
|
||||
@@ -266,10 +143,12 @@ struct test_cast_helper<SrcPacket, TgtPacket, SrcCoeffRatio, TgtCoeffRatio, true
|
||||
|
||||
// Construct a packet of scalars that will not overflow when casting
|
||||
for (int i = 0; i < DataSize; ++i) {
|
||||
data1[i] = random_without_cast_overflow<SrcScalar, TgtScalar>::value();
|
||||
data1[i] = internal::random_without_cast_overflow<SrcScalar, TgtScalar>::value();
|
||||
}
|
||||
|
||||
for (int i = 0; i < DataSize; ++i) ref[i] = static_cast<const TgtScalar>(data1[i]);
|
||||
for (int i = 0; i < DataSize; ++i) {
|
||||
ref[i] = static_cast<const TgtScalar>(data1[i]);
|
||||
}
|
||||
|
||||
pcast_array<SrcPacket, TgtPacket, SrcCoeffRatio, TgtCoeffRatio>::cast(data1, DataSize, data2);
|
||||
|
||||
@@ -318,21 +197,37 @@ struct test_cast_runner<SrcPacket, TgtScalar, TgtPacket, false, false> {
|
||||
static void run() {}
|
||||
};
|
||||
|
||||
template <typename Scalar, typename Packet, typename EnableIf = void>
|
||||
struct packetmath_pcast_ops_runner {
|
||||
static void run() {
|
||||
test_cast_runner<Packet, float>::run();
|
||||
test_cast_runner<Packet, double>::run();
|
||||
test_cast_runner<Packet, int8_t>::run();
|
||||
test_cast_runner<Packet, uint8_t>::run();
|
||||
test_cast_runner<Packet, int16_t>::run();
|
||||
test_cast_runner<Packet, uint16_t>::run();
|
||||
test_cast_runner<Packet, int32_t>::run();
|
||||
test_cast_runner<Packet, uint32_t>::run();
|
||||
test_cast_runner<Packet, int64_t>::run();
|
||||
test_cast_runner<Packet, uint64_t>::run();
|
||||
test_cast_runner<Packet, bool>::run();
|
||||
test_cast_runner<Packet, std::complex<float>>::run();
|
||||
test_cast_runner<Packet, std::complex<double>>::run();
|
||||
test_cast_runner<Packet, half>::run();
|
||||
test_cast_runner<Packet, bfloat16>::run();
|
||||
}
|
||||
};
|
||||
|
||||
// Only some types support cast from std::complex<>.
|
||||
template <typename Scalar, typename Packet>
|
||||
void packetmath_pcast_ops() {
|
||||
test_cast_runner<Packet, float>::run();
|
||||
test_cast_runner<Packet, double>::run();
|
||||
test_cast_runner<Packet, int8_t>::run();
|
||||
test_cast_runner<Packet, uint8_t>::run();
|
||||
test_cast_runner<Packet, int16_t>::run();
|
||||
test_cast_runner<Packet, uint16_t>::run();
|
||||
test_cast_runner<Packet, int32_t>::run();
|
||||
test_cast_runner<Packet, uint32_t>::run();
|
||||
test_cast_runner<Packet, int64_t>::run();
|
||||
test_cast_runner<Packet, uint64_t>::run();
|
||||
test_cast_runner<Packet, bool>::run();
|
||||
test_cast_runner<Packet, half>::run();
|
||||
}
|
||||
struct packetmath_pcast_ops_runner<Scalar, Packet, typename internal::enable_if<NumTraits<Scalar>::IsComplex>::type> {
|
||||
static void run() {
|
||||
test_cast_runner<Packet, std::complex<float>>::run();
|
||||
test_cast_runner<Packet, std::complex<double>>::run();
|
||||
test_cast_runner<Packet, half>::run();
|
||||
test_cast_runner<Packet, bfloat16>::run();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Scalar, typename Packet>
|
||||
void packetmath_boolean_mask_ops() {
|
||||
@@ -356,10 +251,8 @@ void packetmath_boolean_mask_ops() {
|
||||
|
||||
// Packet16b representing bool does not support ptrue, pandnot or pcmp_eq, since the scalar path
|
||||
// (for some compilers) compute the bitwise and with 0x1 of the results to keep the value in [0,1].
|
||||
#ifdef EIGEN_PACKET_MATH_SSE_H
|
||||
template <>
|
||||
void packetmath_boolean_mask_ops<bool, internal::Packet16b>() {}
|
||||
#endif
|
||||
template<>
|
||||
void packetmath_boolean_mask_ops<bool, typename internal::packet_traits<bool>::type>() {}
|
||||
|
||||
template <typename Scalar, typename Packet>
|
||||
void packetmath() {
|
||||
@@ -560,7 +453,7 @@ void packetmath() {
|
||||
CHECK_CWISE2_IF(true, internal::pand, internal::pand);
|
||||
|
||||
packetmath_boolean_mask_ops<Scalar, Packet>();
|
||||
packetmath_pcast_ops<Scalar, Packet>();
|
||||
packetmath_pcast_ops_runner<Scalar, Packet>::run();
|
||||
}
|
||||
|
||||
template <typename Scalar, typename Packet>
|
||||
@@ -975,9 +868,7 @@ EIGEN_DECLARE_TEST(packetmath) {
|
||||
CALL_SUBTEST_11(test::runner<std::complex<float> >::run());
|
||||
CALL_SUBTEST_12(test::runner<std::complex<double> >::run());
|
||||
CALL_SUBTEST_13((packetmath<half, internal::packet_traits<half>::type>()));
|
||||
#ifdef EIGEN_PACKET_MATH_SSE_H
|
||||
CALL_SUBTEST_14((packetmath<bool, internal::packet_traits<bool>::type>()));
|
||||
#endif
|
||||
CALL_SUBTEST_15((packetmath<bfloat16, internal::packet_traits<bfloat16>::type>()));
|
||||
g_first_pass = false;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user