mirror of
https://gitlab.com/libeigen/eigen.git
synced 2026-04-10 11:34:33 +08:00
Add typed logicals
This commit is contained in:
committed by
Rasmus Munk Larsen
parent
e797974689
commit
049a144798
@@ -29,7 +29,7 @@ std::vector<Scalar> special_values() {
|
||||
const Scalar two = Scalar(2);
|
||||
const Scalar three = Scalar(3);
|
||||
const Scalar sqrt_half = Scalar(std::sqrt(0.5));
|
||||
const Scalar sqrt2 = Scalar(std::sqrt(2));
|
||||
const Scalar sqrt2 = Scalar(std::sqrt(2));
|
||||
const Scalar inf = Eigen::NumTraits<Scalar>::infinity();
|
||||
const Scalar nan = Eigen::NumTraits<Scalar>::quiet_NaN();
|
||||
const Scalar denorm_min = std::numeric_limits<Scalar>::denorm_min();
|
||||
@@ -968,6 +968,104 @@ void signed_shift_test(const ArrayType& m) {
|
||||
signed_shift_test_impl<ArrayType>::run(m);
|
||||
}
|
||||
|
||||
template <typename ArrayType>
|
||||
struct typed_logicals_test_impl {
|
||||
using Scalar = typename ArrayType::Scalar;
|
||||
|
||||
static bool scalar_to_bool(const Scalar& x) { return x != Scalar(0); }
|
||||
static Scalar bool_to_scalar(const bool& x) { return x ? Scalar(1) : Scalar(0); }
|
||||
|
||||
static Scalar eval_bool_and(const Scalar& x, const Scalar& y) { return bool_to_scalar(scalar_to_bool(x) && scalar_to_bool(y)); }
|
||||
static Scalar eval_bool_or(const Scalar& x, const Scalar& y) { return bool_to_scalar(scalar_to_bool(x) || scalar_to_bool(y)); }
|
||||
static Scalar eval_bool_xor(const Scalar& x, const Scalar& y) { return bool_to_scalar(scalar_to_bool(x) != scalar_to_bool(y)); }
|
||||
static Scalar eval_bool_not(const Scalar& x) { return bool_to_scalar(!scalar_to_bool(x)); }
|
||||
|
||||
static void run(const ArrayType& m) {
|
||||
|
||||
Index rows = m.rows();
|
||||
Index cols = m.cols();
|
||||
|
||||
ArrayType m1(rows, cols), m2(rows, cols), m3(rows, cols), m4(rows, cols);
|
||||
|
||||
m1.setRandom();
|
||||
m2.setRandom();
|
||||
m1 *= ArrayX<bool>::Random(rows, cols).cast<Scalar>();
|
||||
m2 *= ArrayX<bool>::Random(rows, cols).cast<Scalar>();
|
||||
|
||||
// test boolean and
|
||||
m3 = m1 && m2;
|
||||
m4 = m1.binaryExpr(m2, [](const Scalar& x, const Scalar& y) { return eval_bool_and(x, y); });
|
||||
VERIFY_IS_CWISE_EQUAL(m3, m4);
|
||||
for (const Scalar& val : m3) VERIFY(val == Scalar(0) || val == Scalar(1));
|
||||
|
||||
// test boolean or
|
||||
m3 = m1 || m2;
|
||||
m4 = m1.binaryExpr(m2, [](const Scalar& x, const Scalar& y) { return eval_bool_or(x, y); });
|
||||
VERIFY_IS_CWISE_EQUAL(m3, m4);
|
||||
for (const Scalar& val : m3) VERIFY(val == Scalar(0) || val == Scalar(1));
|
||||
|
||||
// test boolean xor
|
||||
m3 = m1.binaryExpr(m2, internal::scalar_boolean_xor_op<Scalar>());
|
||||
m4 = m1.binaryExpr(m2, [](const Scalar& x, const Scalar& y) { return eval_bool_xor(x, y); });
|
||||
VERIFY_IS_CWISE_EQUAL(m3, m4);
|
||||
for (const Scalar& val : m3) VERIFY(val == Scalar(0) || val == Scalar(1));
|
||||
|
||||
// test boolean not
|
||||
m3 = !m1;
|
||||
m4 = m1.unaryExpr([](const Scalar& x) { return eval_bool_not(x); });
|
||||
VERIFY_IS_CWISE_EQUAL(m3, m4);
|
||||
for (const Scalar& val : m3) VERIFY(val == Scalar(0) || val == Scalar(1));
|
||||
|
||||
// test something more complicated
|
||||
m3 = m1 && m2;
|
||||
m4 = !(!m1 || !m2);
|
||||
VERIFY_IS_CWISE_EQUAL(m3, m4);
|
||||
|
||||
m3 = m1.binaryExpr(m2, internal::scalar_boolean_xor_op<Scalar>());
|
||||
m4 = (!m1).binaryExpr((!m2), internal::scalar_boolean_xor_op<Scalar>());
|
||||
VERIFY_IS_CWISE_EQUAL(m3, m4);
|
||||
|
||||
const Index bytes = rows * cols * sizeof(Scalar);
|
||||
const uint8_t* m1_data = reinterpret_cast<const uint8_t*>(m1.data());
|
||||
const uint8_t* m2_data = reinterpret_cast<const uint8_t*>(m2.data());
|
||||
uint8_t* m3_data = reinterpret_cast<uint8_t*>(m3.data());
|
||||
uint8_t* m4_data = reinterpret_cast<uint8_t*>(m4.data());
|
||||
|
||||
// test bitwise and
|
||||
m3 = m1 & m2;
|
||||
for (Index i = 0; i < bytes; i++) m4_data[i] = m1_data[i] & m2_data[i];
|
||||
for (Index i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_data[i], m4_data[i]);
|
||||
|
||||
// test bitwise or
|
||||
m3 = m1 | m2;
|
||||
for (Index i = 0; i < bytes; i++) m4_data[i] = m1_data[i] | m2_data[i];
|
||||
for (Index i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_data[i], m4_data[i]);
|
||||
|
||||
// test bitwise xor
|
||||
m3 = m1 ^ m2;
|
||||
for (Index i = 0; i < bytes; i++) m4_data[i] = m1_data[i] ^ m2_data[i];
|
||||
for (Index i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_data[i], m4_data[i]);
|
||||
|
||||
// test bitwise not
|
||||
m3 = ~m1;
|
||||
for (Index i = 0; i < bytes; i++) m4_data[i] = ~m1_data[i];
|
||||
for (Index i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_data[i], m4_data[i]);
|
||||
|
||||
// test something more complicated
|
||||
m3 = m1 & m2;
|
||||
m4 = ~(~m1 | ~m2);
|
||||
for (Index i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_data[i], m4_data[i]);
|
||||
|
||||
m3 = m1 ^ m2;
|
||||
m4 = (~m1) ^ (~m2);
|
||||
for (Index i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_data[i], m4_data[i]);
|
||||
}
|
||||
};
|
||||
template <typename ArrayType>
|
||||
void typed_logicals_test(const ArrayType& m) {
|
||||
typed_logicals_test_impl<ArrayType>::run(m);
|
||||
}
|
||||
|
||||
EIGEN_DECLARE_TEST(array_cwise)
|
||||
{
|
||||
for(int i = 0; i < g_repeat; i++) {
|
||||
@@ -1016,6 +1114,14 @@ EIGEN_DECLARE_TEST(array_cwise)
|
||||
CALL_SUBTEST_7( mixed_pow_test() );
|
||||
CALL_SUBTEST_8( signbit_tests() );
|
||||
}
|
||||
for (int i = 0; i < g_repeat; i++) {
|
||||
CALL_SUBTEST_1( typed_logicals_test(ArrayX<bool>(internal::random<int>(1, EIGEN_TEST_MAX_SIZE))) );
|
||||
CALL_SUBTEST_2( typed_logicals_test(ArrayX<int>(internal::random<int>(1, EIGEN_TEST_MAX_SIZE))) );
|
||||
CALL_SUBTEST_2( typed_logicals_test(ArrayX<float>(internal::random<int>(1, EIGEN_TEST_MAX_SIZE))) );
|
||||
CALL_SUBTEST_3( typed_logicals_test(ArrayX<double>(internal::random<int>(1, EIGEN_TEST_MAX_SIZE))));
|
||||
CALL_SUBTEST_3( typed_logicals_test(ArrayX<std::complex<float>>(internal::random<int>(1, EIGEN_TEST_MAX_SIZE))));
|
||||
CALL_SUBTEST_3( typed_logicals_test(ArrayX<std::complex<double>>(internal::random<int>(1, EIGEN_TEST_MAX_SIZE))));
|
||||
}
|
||||
|
||||
VERIFY((internal::is_same< internal::global_math_functions_filtering_base<int>::type, int >::value));
|
||||
VERIFY((internal::is_same< internal::global_math_functions_filtering_base<float>::type, float >::value));
|
||||
|
||||
Reference in New Issue
Block a user