Partially Vectorize Cast

This commit is contained in:
Charles Schlosser
2023-06-09 16:54:31 +00:00
committed by Rasmus Munk Larsen
parent 7d7576f326
commit 59b3ef5409
14 changed files with 1051 additions and 399 deletions

View File

@@ -9,6 +9,7 @@
#include <vector>
#include "main.h"
#include "random_without_cast_overflow.h"
// suppress annoying unsigned integer warnings
template <typename Scalar, bool IsSigned = NumTraits<Scalar>::IsSigned>
@@ -1213,6 +1214,109 @@ void typed_logicals_test(const ArrayType& m) {
typed_logicals_test_impl<ArrayType>::run(m);
}
// Checks that the array expression `src.cast<DstType>()` produces, for every
// coefficient, the same value as the scalar conversion
// internal::cast_impl<SrcType, DstType>::run().
//
// SrcType / DstType      : scalar types converted from / to.
// Rows/ColsAtCompileTime : compile-time shape of the arrays under test
//                          (Dynamic is accepted, see run()).
template <typename SrcType, typename DstType, int RowsAtCompileTime, int ColsAtCompileTime>
struct cast_test_impl {
  using SrcArray = Array<SrcType, RowsAtCompileTime, ColsAtCompileTime>;
  using DstArray = Array<DstType, RowsAtCompileTime, ColsAtCompileTime>;

  // Unary functor that ignores its argument and returns a fresh value from
  // internal::random_without_cast_overflow<SrcType, DstType>.
  // NOTE(review): judging by the helper's name, the value is presumably chosen
  // so that converting it to DstType does not overflow -- confirm in
  // random_without_cast_overflow.h.
  struct RandomOp {
    inline SrcType operator()(const SrcType&) const {
      return internal::random_without_cast_overflow<SrcType, DstType>::value();
    }
  };

  static constexpr int SrcPacketSize = internal::packet_traits<SrcType>::size;
  static constexpr int DstPacketSize = internal::packet_traits<DstType>::size;
  // Larger of the two packet sizes; used to size the Dynamic test arrays.
  static constexpr int MaxPacketSize = internal::plain_enum_max(SrcPacketSize, DstPacketSize);

  // Returns a readable (non-mangled) name for the scalar types exercised by
  // this test. Any other type falls back to typeid(T).name(). The argument is
  // only used to deduce T.
  template <typename T>
  static std::string printTypeInfo(const T&) {
    if (internal::is_same<bool, T>::value) return "bool";
    if (internal::is_same<int8_t, T>::value) return "int8_t";
    if (internal::is_same<int16_t, T>::value) return "int16_t";
    if (internal::is_same<int32_t, T>::value) return "int32_t";
    if (internal::is_same<int64_t, T>::value) return "int64_t";
    if (internal::is_same<uint8_t, T>::value) return "uint8_t";
    if (internal::is_same<uint16_t, T>::value) return "uint16_t";
    if (internal::is_same<uint32_t, T>::value) return "uint32_t";
    if (internal::is_same<uint64_t, T>::value) return "uint64_t";
    if (internal::is_same<float, T>::value) return "float";
    if (internal::is_same<double, T>::value) return "double";
    // if (internal::is_same<long double, T>::value) return "long double";
    if (internal::is_same<half, T>::value) return "half";
    if (internal::is_same<bfloat16, T>::value) return "bfloat16";
    return typeid(T).name();
  }

  // Fills `src` with random values, casts it to `dst` through the array
  // expression, and compares every coefficient against the scalar cast.
  // A coefficient passes when the two results compare equal, or when the
  // source, the reference and the array result are all NaN.
  static void run() {
    // A Dynamic dimension is given 10 * MaxPacketSize + 1 entries; fixed
    // dimensions keep their compile-time extent.
    const Index rows = RowsAtCompileTime == Dynamic ? ((10 * MaxPacketSize) + 1) : RowsAtCompileTime;
    const Index cols = ColsAtCompileTime == Dynamic ? ((10 * MaxPacketSize) + 1) : ColsAtCompileTime;

    // Repeat small shapes often enough that at least `minCoeffs` coefficients
    // are checked in total.
    const Index coeffsPerPass = rows * cols;
    const Index minCoeffs = 100;
    const Index repeats = numext::div_ceil(minCoeffs, coeffsPerPass);

    SrcArray src(rows, cols);
    DstArray dst(rows, cols);

    for (Index iter = 0; iter < repeats; ++iter) {
      // RandomOp discards the incoming coefficient, so this simply overwrites
      // every entry of src with a new random value.
      src = src.unaryExpr(RandomOp());
      dst = src.template cast<DstType>();

      for (Index row = 0; row < rows; ++row) {
        for (Index col = 0; col < cols; ++col) {
          const SrcType& input = src(row, col);
          const DstType& actual = dst(row, col);
          const DstType ref = internal::cast_impl<SrcType, DstType>::run(input);

          const bool all_nan = (numext::isnan)(input) && (numext::isnan)(ref) && (numext::isnan)(actual);
          const bool is_equal = (ref == actual);
          const bool pass = all_nan || is_equal;

          if (!pass) {
            // Unary + promotes narrow integer types so they are printed as
            // numbers rather than as characters.
            std::cout << printTypeInfo(SrcType()) << ": [" << +input << "] to " << printTypeInfo(DstType()) << ": ["
                      << +actual << "] != [" << +ref << "]\n";
          }
          VERIFY(pass);
        }
      }
    }
  }
};
// Runs cast_test_impl for every unordered pair of distinct positions (i, j),
// i < j, in the ScalarTypes list, testing the conversion in both directions.
// The pairs are enumerated at compile time by recursive instantiation of
// run<i, j>().
template <int RowsAtCompileTime, int ColsAtCompileTime, typename... ScalarTypes>
struct cast_tests_impl {
  using ScalarTuple = std::tuple<ScalarTypes...>;
  static constexpr size_t ScalarTupleSize = sizeof...(ScalarTypes);

  // Terminating overload: selected once i or j has moved past the last valid
  // pair. Does nothing.
  template <size_t i = 0, size_t j = i + 1, bool Done = (i >= ScalarTupleSize - 1) || (j >= ScalarTupleSize)>
  static std::enable_if_t<Done> run() {}

  // Recursive overload: tests the pair (i, j) both ways, then advances to the
  // next pair.
  template <size_t i = 0, size_t j = i + 1, bool Done = (i >= ScalarTupleSize - 1) || (j >= ScalarTupleSize)>
  static std::enable_if_t<!Done> run() {
    using Type1 = std::tuple_element_t<i, ScalarTuple>;
    using Type2 = std::tuple_element_t<j, ScalarTuple>;

    cast_test_impl<Type1, Type2, RowsAtCompileTime, ColsAtCompileTime>::run();
    cast_test_impl<Type2, Type1, RowsAtCompileTime, ColsAtCompileTime>::run();

    // When j is the last index, move on to the next i and restart j just
    // after it; otherwise keep i and step j by one.
    constexpr bool row_finished = (j + 1 == ScalarTupleSize);
    constexpr size_t next_i = row_finished ? (i + 1) : i;
    constexpr size_t next_j = row_finished ? (i + 2) : (j + 1);
    run<next_i, next_j>();
  }
};
// for now, remove all references to 'long double' until test passes on all platforms
template <int RowsAtCompileTime, int ColsAtCompileTime>
void cast_test() {
cast_tests_impl<RowsAtCompileTime, ColsAtCompileTime, bool, int8_t, int16_t, int32_t, int64_t, uint8_t, uint16_t,
uint32_t, uint64_t, float, double, /*long double, */half, bfloat16>::run();
}
EIGEN_DECLARE_TEST(array_cwise)
{
for(int i = 0; i < g_repeat; i++) {
@@ -1269,6 +1373,20 @@ EIGEN_DECLARE_TEST(array_cwise)
CALL_SUBTEST_3( typed_logicals_test(ArrayX<std::complex<double>>(internal::random<int>(1, EIGEN_TEST_MAX_SIZE))));
}
for (int i = 0; i < g_repeat; i++) {
CALL_SUBTEST_1((cast_test<1, 1>()));
CALL_SUBTEST_2((cast_test<3, 1>()));
CALL_SUBTEST_2((cast_test<3, 3>()));
CALL_SUBTEST_3((cast_test<5, 1>()));
CALL_SUBTEST_3((cast_test<5, 5>()));
CALL_SUBTEST_4((cast_test<9, 1>()));
CALL_SUBTEST_4((cast_test<9, 9>()));
CALL_SUBTEST_5((cast_test<17, 1>()));
CALL_SUBTEST_5((cast_test<17, 17>()));
CALL_SUBTEST_6((cast_test<Dynamic, 1>()));
CALL_SUBTEST_6((cast_test<Dynamic, Dynamic>()));
}
VERIFY((internal::is_same< internal::global_math_functions_filtering_base<int>::type, int >::value));
VERIFY((internal::is_same< internal::global_math_functions_filtering_base<float>::type, float >::value));
VERIFY((internal::is_same< internal::global_math_functions_filtering_base<Array2i>::type, ArrayBase<Array2i> >::value));