Vectorize cast

This commit is contained in:
Charles Schlosser
2023-04-26 02:50:13 +00:00
committed by Rasmus Munk Larsen
parent 3918768be1
commit eb5ff1861a
8 changed files with 254 additions and 128 deletions

View File

@@ -9,6 +9,7 @@
#include <vector>
#include "main.h"
#include "random_without_cast_overflow.h"
template <typename Scalar, std::enable_if_t<NumTraits<Scalar>::IsInteger,int> = 0>
std::vector<Scalar> special_values() {
@@ -1183,6 +1184,59 @@ void typed_logicals_test(const ArrayType& m) {
typed_logicals_test_impl<ArrayType>::run(m);
}
template <typename SrcType, typename DstType>
struct cast_test_impl {
using SrcArray = ArrayX<SrcType>;
using DstArray = ArrayX<DstType>;
static constexpr int SrcPacketSize = internal::packet_traits<SrcType>::size;
static constexpr int DstPacketSize = internal::packet_traits<DstType>::size;
static constexpr int MaxPacketSize = internal::plain_enum_max(SrcPacketSize, DstPacketSize);
static void run() {
const Index testSize = 100 * MaxPacketSize;
SrcArray src(testSize);
for (Index i = 0; i < testSize; i++) src(i) = internal::random_without_cast_overflow<SrcType, DstType>::value();
DstArray dst = src.template cast<DstType>();
for (Index i = 0; i < testSize; i++) {
DstType ref = static_cast<DstType>(src(i));
bool all_nan = ((numext::isnan)(src(i)) && (numext::isnan)(ref) && (numext::isnan)(dst(i)));
bool is_equal = ref == dst(i);
bool pass = all_nan || is_equal;
if (!pass) {
std::cout << typeid(SrcType).name() << ": [" << +src(i) << "] to " << typeid(DstType).name() << ": [" << +dst(i)
<< "] != [" << +ref << "]\n";
}
VERIFY(pass);
}
}
};
template <typename... ScalarTypes>
struct cast_tests_impl {
using ScalarTuple = std::tuple<ScalarTypes...>;
static constexpr size_t ScalarTupleSize = std::tuple_size<ScalarTuple>::value;
template <size_t i = 0, size_t j = i + 1, bool Done = (i >= ScalarTupleSize - 1) || (j >= ScalarTupleSize)>
static std::enable_if_t<Done> run() {}
template <size_t i = 0, size_t j = i + 1, bool Done = (i >= ScalarTupleSize - 1) || (j >= ScalarTupleSize)>
static std::enable_if_t<!Done> run() {
using Type1 = typename std::tuple_element<i, ScalarTuple>::type;
using Type2 = typename std::tuple_element<j, ScalarTuple>::type;
cast_test_impl<Type1, Type2>::run();
cast_test_impl<Type2, Type1>::run();
static constexpr size_t next_i = (j == ScalarTupleSize - 1) ? (i + 1) : (i + 0);
static constexpr size_t next_j = (j == ScalarTupleSize - 1) ? (i + 2) : (j + 1);
run<next_i, next_j>();
}
};
void cast_test() {
cast_tests_impl<bool, int8_t, int16_t, int32_t, int64_t, uint8_t, uint16_t, uint32_t, uint64_t, float, double,
long double, half, bfloat16>::run();
}
EIGEN_DECLARE_TEST(array_cwise)
{
for(int i = 0; i < g_repeat; i++) {
@@ -1238,6 +1292,9 @@ EIGEN_DECLARE_TEST(array_cwise)
CALL_SUBTEST_3( typed_logicals_test(ArrayX<std::complex<float>>(internal::random<int>(1, EIGEN_TEST_MAX_SIZE))));
CALL_SUBTEST_3( typed_logicals_test(ArrayX<std::complex<double>>(internal::random<int>(1, EIGEN_TEST_MAX_SIZE))));
}
for (int i = 0; i < g_repeat; i++) {
cast_test();
}
VERIFY((internal::is_same< internal::global_math_functions_filtering_base<int>::type, int >::value));
VERIFY((internal::is_same< internal::global_math_functions_filtering_base<float>::type, float >::value));