Add missing x86 pcasts

This commit is contained in:
Charles Schlosser
2023-07-28 23:41:38 +00:00
committed by Rasmus Munk Larsen
parent 24d15e086f
commit 5527e78a64
5 changed files with 177 additions and 238 deletions

View File

@@ -1211,6 +1211,22 @@ void typed_logicals_test(const ArrayType& m) {
typed_logicals_test_impl<ArrayType>::run(m);
}
// print non-mangled typenames
template<typename T> std::string printTypeInfo(const T&) { return typeid(T).name(); }
template<> std::string printTypeInfo(const int8_t&) { return "int8_t"; }
template<> std::string printTypeInfo(const int16_t&) { return "int16_t"; }
template<> std::string printTypeInfo(const int32_t&) { return "int32_t"; }
template<> std::string printTypeInfo(const int64_t&) { return "int64_t"; }
template<> std::string printTypeInfo(const uint8_t&) { return "uint8_t"; }
template<> std::string printTypeInfo(const uint16_t&) { return "uint16_t"; }
template<> std::string printTypeInfo(const uint32_t&) { return "uint32_t"; }
template<> std::string printTypeInfo(const uint64_t&) { return "uint64_t"; }
template<> std::string printTypeInfo(const float&) { return "float"; }
template<> std::string printTypeInfo(const double&) { return "double"; }
//template<> std::string printTypeInfo(const long double&) { return "long double"; }
template<> std::string printTypeInfo(const half&) { return "half"; }
template<> std::string printTypeInfo(const bfloat16&) { return "bfloat16"; }
template <typename SrcType, typename DstType, int RowsAtCompileTime, int ColsAtCompileTime>
struct cast_test_impl {
using SrcArray = Array<SrcType, RowsAtCompileTime, ColsAtCompileTime>;
@@ -1225,63 +1241,30 @@ struct cast_test_impl {
static constexpr int DstPacketSize = internal::packet_traits<DstType>::size;
static constexpr int MaxPacketSize = internal::plain_enum_max(SrcPacketSize, DstPacketSize);
// print non-mangled typenames
template <typename T>
static std::string printTypeInfo(const T&) {
if (internal::is_same<bool, T>::value)
return "bool";
else if (internal::is_same<int8_t, T>::value)
return "int8_t";
else if (internal::is_same<int16_t, T>::value)
return "int16_t";
else if (internal::is_same<int32_t, T>::value)
return "int32_t";
else if (internal::is_same<int64_t, T>::value)
return "int64_t";
else if (internal::is_same<uint8_t, T>::value)
return "uint8_t";
else if (internal::is_same<uint16_t, T>::value)
return "uint16_t";
else if (internal::is_same<uint32_t, T>::value)
return "uint32_t";
else if (internal::is_same<uint64_t, T>::value)
return "uint64_t";
else if (internal::is_same<float, T>::value)
return "float";
else if (internal::is_same<double, T>::value)
return "double";
//else if (internal::is_same<long double, T>::value)
// return "long double";
else if (internal::is_same<half, T>::value)
return "half";
else if (internal::is_same<bfloat16, T>::value)
return "bfloat16";
else
return typeid(T).name();
}
static void run() {
const Index testRows = RowsAtCompileTime == Dynamic ? ((10 * MaxPacketSize) + 1) : RowsAtCompileTime;
const Index testCols = ColsAtCompileTime == Dynamic ? ((10 * MaxPacketSize) + 1) : ColsAtCompileTime;
const Index testSize = testRows * testCols;
const Index minTestSize = 100;
const Index repeats = numext::div_ceil(minTestSize, testSize);
SrcArray src(testRows, testCols);
DstArray dst(testRows, testCols);
for (Index repeat = 0; repeat < repeats; repeat++) {
src = src.unaryExpr(RandomOp());
dst = src.template cast<DstType>();
for (Index i = 0; i < testRows; i++)
for (Index j = 0; j < testCols; j++) {
DstType ref = internal::cast_impl<SrcType, DstType>::run(src(i, j));
bool all_nan = ((numext::isnan)(src(i, j)) && (numext::isnan)(ref) && (numext::isnan)(dst(i, j)));
bool is_equal = ref == dst(i, j);
bool pass = all_nan || is_equal;
if (!pass) {
std::cout << printTypeInfo(SrcType()) << ": [" << +src(i, j) << "] to " << printTypeInfo(DstType()) << ": ["
<< +dst(i, j) << "] != [" << +ref << "]\n";
}
VERIFY(pass);
for (Index j = 0; j < testCols; j++)
for (Index i = 0; i < testRows; i++) {
SrcType srcVal = src(i, j);
DstType refVal = internal::cast_impl<SrcType, DstType>::run(srcVal);
DstType dstVal = dst(i, j);
bool isApprox = verifyIsApprox(dstVal, refVal);
if (!isApprox)
std::cout << printTypeInfo(srcVal) << ": [" << +srcVal << "] to " << printTypeInfo(dstVal) << ": ["
<< +dstVal << "] != [" << +refVal << "]\n";
VERIFY(isApprox);
}
}
}