vectorize comparisons and select by enabling typed comparisons

This commit is contained in:
Charles Schlosser
2023-02-25 20:52:11 +00:00
committed by Rasmus Munk Larsen
parent 2e9b945baf
commit 826627f653
11 changed files with 463 additions and 185 deletions

View File

@@ -590,6 +590,21 @@ template<typename ArrayType> void comparisons(const ArrayType& m)
typedef typename ArrayType::Scalar Scalar;
typedef typename NumTraits<Scalar>::Real RealScalar;
// explicitly test both typed and boolean comparison ops
using typed_eq = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_EQ, true>;
using typed_ne = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_NEQ, true>;
using typed_lt = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LT, true>;
using typed_le = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LE, true>;
using typed_gt = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GT, true>;
using typed_ge = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GE, true>;
using bool_eq = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_EQ, false>;
using bool_ne = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_NEQ, false>;
using bool_lt = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LT, false>;
using bool_le = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LE, false>;
using bool_gt = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GT, false>;
using bool_ge = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GE, false>;
Index rows = m.rows();
Index cols = m.cols();
@@ -603,6 +618,8 @@ template<typename ArrayType> void comparisons(const ArrayType& m)
m4 = (m4.abs()==Scalar(0)).select(1,m4);
// use operator overloads with default return type
VERIFY(((m1 + Scalar(1)) > m1).all());
VERIFY(((m1 - Scalar(1)) < m1).all());
if (rows*cols>1)
@@ -627,6 +644,34 @@ template<typename ArrayType> void comparisons(const ArrayType& m)
VERIFY( ( (m1(r,c)+1) > m1).any() );
VERIFY( ( m1(r,c) == m1).any() );
// currently, any() / all() are not vectorized, so use VERIFY_IS_CWISE_EQUAL to test vectorized path
// use typed comparisons, regardless of operator overload behavior
typename ArrayType::ConstantReturnType typed_true = ArrayType::Constant(rows, cols, Scalar(1));
// ((m1 + Scalar(1)) > m1).all()
VERIFY_IS_CWISE_EQUAL((m1 + Scalar(1)).binaryExpr(m1, typed_gt()), typed_true);
// ((m1 - Scalar(1)) < m1).all()
VERIFY_IS_CWISE_EQUAL((m1 - Scalar(1)).binaryExpr(m1, typed_lt()), typed_true);
// ((m1 + Scalar(1)) == (m1 + Scalar(1))).all()
VERIFY_IS_CWISE_EQUAL((m1 + Scalar(1)).binaryExpr(m1 + Scalar(1), typed_eq()), typed_true);
// ((m1 - Scalar(1)) != m1).all()
VERIFY_IS_CWISE_EQUAL((m1 - Scalar(1)).binaryExpr(m1, typed_ne()), typed_true);
// (m1 <= m2 || m1 >= m2).all()
VERIFY_IS_CWISE_EQUAL(m1.binaryExpr(m2, typed_le()) || m1.binaryExpr(m2, typed_ge()), typed_true);
// use boolean comparisons, regardless of operator overload behavior
ArrayXX<bool>::ConstantReturnType bool_true = ArrayXX<bool>::Constant(rows, cols, true);
// ((m1 + Scalar(1)) > m1).all()
VERIFY_IS_CWISE_EQUAL((m1 + Scalar(1)).binaryExpr(m1, bool_gt()), bool_true);
// ((m1 - Scalar(1)) < m1).all()
VERIFY_IS_CWISE_EQUAL((m1 - Scalar(1)).binaryExpr(m1, bool_lt()), bool_true);
// ((m1 + Scalar(1)) == (m1 + Scalar(1))).all()
VERIFY_IS_CWISE_EQUAL((m1 + Scalar(1)).binaryExpr(m1 + Scalar(1), bool_eq()), bool_true);
// ((m1 - Scalar(1)) != m1).all()
VERIFY_IS_CWISE_EQUAL((m1 - Scalar(1)).binaryExpr(m1, bool_ne()), bool_true);
// (m1 <= m2 || m1 >= m2).all()
VERIFY_IS_CWISE_EQUAL(m1.binaryExpr(m2, bool_le()) || m1.binaryExpr(m2, bool_ge()), bool_true);
// test Select
VERIFY_IS_APPROX( (m1<m2).select(m1,m2), m1.cwiseMin(m2) );
VERIFY_IS_APPROX( (m1>m2).select(m1,m2), m1.cwiseMax(m2) );
@@ -642,7 +687,7 @@ template<typename ArrayType> void comparisons(const ArrayType& m)
VERIFY_IS_APPROX( (m1.abs()>=ArrayType::Constant(rows,cols,mid))
.select(m1,0), m3);
// even shorter version:
VERIFY_IS_APPROX( (m1.abs()<mid).select(0,m1), m3);
VERIFY_IS_APPROX( (m1.abs()<mid).select(0,m1), m3);
// count
VERIFY(((m1.abs()+1)>RealScalar(0.1)).count() == rows*cols);
@@ -1039,7 +1084,7 @@ struct typed_logicals_test_impl {
using Scalar = typename ArrayType::Scalar;
static bool scalar_to_bool(const Scalar& x) { return x != Scalar(0); }
static Scalar bool_to_scalar(const bool& x) { return x ? Scalar(1) : Scalar(0); }
static Scalar bool_to_scalar(bool x) { return x ? Scalar(1) : Scalar(0); }
static Scalar eval_bool_and(const Scalar& x, const Scalar& y) { return bool_to_scalar(scalar_to_bool(x) && scalar_to_bool(y)); }
static Scalar eval_bool_or(const Scalar& x, const Scalar& y) { return bool_to_scalar(scalar_to_bool(x) || scalar_to_bool(y)); }
@@ -1091,40 +1136,45 @@ struct typed_logicals_test_impl {
m4 = (!m1).binaryExpr((!m2), internal::scalar_boolean_xor_op<Scalar>());
VERIFY_IS_CWISE_EQUAL(m3, m4);
const Index bytes = rows * cols * sizeof(Scalar);
const uint8_t* m1_data = reinterpret_cast<const uint8_t*>(m1.data());
const uint8_t* m2_data = reinterpret_cast<const uint8_t*>(m2.data());
uint8_t* m3_data = reinterpret_cast<uint8_t*>(m3.data());
uint8_t* m4_data = reinterpret_cast<uint8_t*>(m4.data());
const size_t bytes = size_t(rows) * size_t(cols) * sizeof(Scalar);
std::vector<uint8_t> m1_buffer(bytes), m2_buffer(bytes), m3_buffer(bytes), m4_buffer(bytes);
std::memcpy(m1_buffer.data(), m1.data(), bytes);
std::memcpy(m2_buffer.data(), m2.data(), bytes);
// test bitwise and
m3 = m1 & m2;
for (Index i = 0; i < bytes; i++) m4_data[i] = m1_data[i] & m2_data[i];
for (Index i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_data[i], m4_data[i]);
std::memcpy(m3_buffer.data(), m3.data(), bytes);
for (size_t i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_buffer[i], uint8_t(m1_buffer[i] & m2_buffer[i]));
// test bitwise or
m3 = m1 | m2;
for (Index i = 0; i < bytes; i++) m4_data[i] = m1_data[i] | m2_data[i];
for (Index i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_data[i], m4_data[i]);
std::memcpy(m3_buffer.data(), m3.data(), bytes);
for (size_t i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_buffer[i], uint8_t(m1_buffer[i] | m2_buffer[i]));
// test bitwise xor
m3 = m1 ^ m2;
for (Index i = 0; i < bytes; i++) m4_data[i] = m1_data[i] ^ m2_data[i];
for (Index i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_data[i], m4_data[i]);
std::memcpy(m3_buffer.data(), m3.data(), bytes);
for (size_t i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_buffer[i], uint8_t(m1_buffer[i] ^ m2_buffer[i]));
// test bitwise not
m3 = ~m1;
for (Index i = 0; i < bytes; i++) m4_data[i] = ~m1_data[i];
for (Index i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_data[i], m4_data[i]);
std::memcpy(m3_buffer.data(), m3.data(), bytes);
for (size_t i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_buffer[i], uint8_t(~m1_buffer[i]));
// test something more complicated
m3 = m1 & m2;
m4 = ~(~m1 | ~m2);
for (Index i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_data[i], m4_data[i]);
std::memcpy(m3_buffer.data(), m3.data(), bytes);
std::memcpy(m4_buffer.data(), m4.data(), bytes);
for (size_t i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_buffer[i], m4_buffer[i]);
m3 = m1 ^ m2;
m4 = (~m1) ^ (~m2);
for (Index i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_data[i], m4_data[i]);
std::memcpy(m3_buffer.data(), m3.data(), bytes);
std::memcpy(m4_buffer.data(), m4.data(), bytes);
for (size_t i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_buffer[i], m4_buffer[i]);
}
};
template <typename ArrayType>
@@ -1181,7 +1231,6 @@ EIGEN_DECLARE_TEST(array_cwise)
CALL_SUBTEST_8( signbit_tests() );
}
for (int i = 0; i < g_repeat; i++) {
CALL_SUBTEST_1( typed_logicals_test(ArrayX<bool>(internal::random<int>(1, EIGEN_TEST_MAX_SIZE))) );
CALL_SUBTEST_2( typed_logicals_test(ArrayX<int>(internal::random<int>(1, EIGEN_TEST_MAX_SIZE))) );
CALL_SUBTEST_2( typed_logicals_test(ArrayX<float>(internal::random<int>(1, EIGEN_TEST_MAX_SIZE))) );
CALL_SUBTEST_3( typed_logicals_test(ArrayX<double>(internal::random<int>(1, EIGEN_TEST_MAX_SIZE))));