Add typed logicals

This commit is contained in:
Charles Schlosser
2023-02-18 01:23:47 +00:00
committed by Rasmus Munk Larsen
parent e797974689
commit 049a144798
13 changed files with 415 additions and 124 deletions

View File

@@ -29,7 +29,7 @@ std::vector<Scalar> special_values() {
const Scalar two = Scalar(2);
const Scalar three = Scalar(3);
const Scalar sqrt_half = Scalar(std::sqrt(0.5));
const Scalar sqrt2 = Scalar(std::sqrt(2));
const Scalar sqrt2 = Scalar(std::sqrt(2));
const Scalar inf = Eigen::NumTraits<Scalar>::infinity();
const Scalar nan = Eigen::NumTraits<Scalar>::quiet_NaN();
const Scalar denorm_min = std::numeric_limits<Scalar>::denorm_min();
@@ -968,6 +968,104 @@ void signed_shift_test(const ArrayType& m) {
signed_shift_test_impl<ArrayType>::run(m);
}
template <typename ArrayType>
struct typed_logicals_test_impl {
using Scalar = typename ArrayType::Scalar;
static bool scalar_to_bool(const Scalar& x) { return x != Scalar(0); }
static Scalar bool_to_scalar(const bool& x) { return x ? Scalar(1) : Scalar(0); }
static Scalar eval_bool_and(const Scalar& x, const Scalar& y) { return bool_to_scalar(scalar_to_bool(x) && scalar_to_bool(y)); }
static Scalar eval_bool_or(const Scalar& x, const Scalar& y) { return bool_to_scalar(scalar_to_bool(x) || scalar_to_bool(y)); }
static Scalar eval_bool_xor(const Scalar& x, const Scalar& y) { return bool_to_scalar(scalar_to_bool(x) != scalar_to_bool(y)); }
static Scalar eval_bool_not(const Scalar& x) { return bool_to_scalar(!scalar_to_bool(x)); }
static void run(const ArrayType& m) {
Index rows = m.rows();
Index cols = m.cols();
ArrayType m1(rows, cols), m2(rows, cols), m3(rows, cols), m4(rows, cols);
m1.setRandom();
m2.setRandom();
m1 *= ArrayX<bool>::Random(rows, cols).cast<Scalar>();
m2 *= ArrayX<bool>::Random(rows, cols).cast<Scalar>();
// test boolean and
m3 = m1 && m2;
m4 = m1.binaryExpr(m2, [](const Scalar& x, const Scalar& y) { return eval_bool_and(x, y); });
VERIFY_IS_CWISE_EQUAL(m3, m4);
for (const Scalar& val : m3) VERIFY(val == Scalar(0) || val == Scalar(1));
// test boolean or
m3 = m1 || m2;
m4 = m1.binaryExpr(m2, [](const Scalar& x, const Scalar& y) { return eval_bool_or(x, y); });
VERIFY_IS_CWISE_EQUAL(m3, m4);
for (const Scalar& val : m3) VERIFY(val == Scalar(0) || val == Scalar(1));
// test boolean xor
m3 = m1.binaryExpr(m2, internal::scalar_boolean_xor_op<Scalar>());
m4 = m1.binaryExpr(m2, [](const Scalar& x, const Scalar& y) { return eval_bool_xor(x, y); });
VERIFY_IS_CWISE_EQUAL(m3, m4);
for (const Scalar& val : m3) VERIFY(val == Scalar(0) || val == Scalar(1));
// test boolean not
m3 = !m1;
m4 = m1.unaryExpr([](const Scalar& x) { return eval_bool_not(x); });
VERIFY_IS_CWISE_EQUAL(m3, m4);
for (const Scalar& val : m3) VERIFY(val == Scalar(0) || val == Scalar(1));
// test something more complicated
m3 = m1 && m2;
m4 = !(!m1 || !m2);
VERIFY_IS_CWISE_EQUAL(m3, m4);
m3 = m1.binaryExpr(m2, internal::scalar_boolean_xor_op<Scalar>());
m4 = (!m1).binaryExpr((!m2), internal::scalar_boolean_xor_op<Scalar>());
VERIFY_IS_CWISE_EQUAL(m3, m4);
const Index bytes = rows * cols * sizeof(Scalar);
const uint8_t* m1_data = reinterpret_cast<const uint8_t*>(m1.data());
const uint8_t* m2_data = reinterpret_cast<const uint8_t*>(m2.data());
uint8_t* m3_data = reinterpret_cast<uint8_t*>(m3.data());
uint8_t* m4_data = reinterpret_cast<uint8_t*>(m4.data());
// test bitwise and
m3 = m1 & m2;
for (Index i = 0; i < bytes; i++) m4_data[i] = m1_data[i] & m2_data[i];
for (Index i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_data[i], m4_data[i]);
// test bitwise or
m3 = m1 | m2;
for (Index i = 0; i < bytes; i++) m4_data[i] = m1_data[i] | m2_data[i];
for (Index i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_data[i], m4_data[i]);
// test bitwise xor
m3 = m1 ^ m2;
for (Index i = 0; i < bytes; i++) m4_data[i] = m1_data[i] ^ m2_data[i];
for (Index i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_data[i], m4_data[i]);
// test bitwise not
m3 = ~m1;
for (Index i = 0; i < bytes; i++) m4_data[i] = ~m1_data[i];
for (Index i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_data[i], m4_data[i]);
// test something more complicated
m3 = m1 & m2;
m4 = ~(~m1 | ~m2);
for (Index i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_data[i], m4_data[i]);
m3 = m1 ^ m2;
m4 = (~m1) ^ (~m2);
for (Index i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_data[i], m4_data[i]);
}
};
template <typename ArrayType>
void typed_logicals_test(const ArrayType& m) {
typed_logicals_test_impl<ArrayType>::run(m);
}
EIGEN_DECLARE_TEST(array_cwise)
{
for(int i = 0; i < g_repeat; i++) {
@@ -1016,6 +1114,14 @@ EIGEN_DECLARE_TEST(array_cwise)
CALL_SUBTEST_7( mixed_pow_test() );
CALL_SUBTEST_8( signbit_tests() );
}
for (int i = 0; i < g_repeat; i++) {
CALL_SUBTEST_1( typed_logicals_test(ArrayX<bool>(internal::random<int>(1, EIGEN_TEST_MAX_SIZE))) );
CALL_SUBTEST_2( typed_logicals_test(ArrayX<int>(internal::random<int>(1, EIGEN_TEST_MAX_SIZE))) );
CALL_SUBTEST_2( typed_logicals_test(ArrayX<float>(internal::random<int>(1, EIGEN_TEST_MAX_SIZE))) );
CALL_SUBTEST_3( typed_logicals_test(ArrayX<double>(internal::random<int>(1, EIGEN_TEST_MAX_SIZE))));
CALL_SUBTEST_3( typed_logicals_test(ArrayX<std::complex<float>>(internal::random<int>(1, EIGEN_TEST_MAX_SIZE))));
CALL_SUBTEST_3( typed_logicals_test(ArrayX<std::complex<double>>(internal::random<int>(1, EIGEN_TEST_MAX_SIZE))));
}
VERIFY((internal::is_same< internal::global_math_functions_filtering_base<int>::type, int >::value));
VERIFY((internal::is_same< internal::global_math_functions_filtering_base<float>::type, float >::value));