mirror of
https://gitlab.com/libeigen/eigen.git
synced 2026-04-10 11:34:33 +08:00
vectorize comparisons and select by enabling typed comparisons
This commit is contained in:
committed by
Rasmus Munk Larsen
parent
2e9b945baf
commit
826627f653
@@ -590,6 +590,21 @@ template<typename ArrayType> void comparisons(const ArrayType& m)
|
||||
typedef typename ArrayType::Scalar Scalar;
|
||||
typedef typename NumTraits<Scalar>::Real RealScalar;
|
||||
|
||||
// explicitly test both typed and boolean comparison ops
|
||||
using typed_eq = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_EQ, true>;
|
||||
using typed_ne = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_NEQ, true>;
|
||||
using typed_lt = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LT, true>;
|
||||
using typed_le = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LE, true>;
|
||||
using typed_gt = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GT, true>;
|
||||
using typed_ge = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GE, true>;
|
||||
|
||||
using bool_eq = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_EQ, false>;
|
||||
using bool_ne = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_NEQ, false>;
|
||||
using bool_lt = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LT, false>;
|
||||
using bool_le = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LE, false>;
|
||||
using bool_gt = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GT, false>;
|
||||
using bool_ge = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GE, false>;
|
||||
|
||||
Index rows = m.rows();
|
||||
Index cols = m.cols();
|
||||
|
||||
@@ -603,6 +618,8 @@ template<typename ArrayType> void comparisons(const ArrayType& m)
|
||||
|
||||
m4 = (m4.abs()==Scalar(0)).select(1,m4);
|
||||
|
||||
// use operator overloads with default return type
|
||||
|
||||
VERIFY(((m1 + Scalar(1)) > m1).all());
|
||||
VERIFY(((m1 - Scalar(1)) < m1).all());
|
||||
if (rows*cols>1)
|
||||
@@ -627,6 +644,34 @@ template<typename ArrayType> void comparisons(const ArrayType& m)
|
||||
VERIFY( ( (m1(r,c)+1) > m1).any() );
|
||||
VERIFY( ( m1(r,c) == m1).any() );
|
||||
|
||||
// currently, any() / all() are not vectorized, so use VERIFY_IS_CWISE_EQUAL to test vectorized path
|
||||
|
||||
// use typed comparisons, regardless of operator overload behavior
|
||||
typename ArrayType::ConstantReturnType typed_true = ArrayType::Constant(rows, cols, Scalar(1));
|
||||
// (m1 + Scalar(1)) > m1).all()
|
||||
VERIFY_IS_CWISE_EQUAL((m1 + Scalar(1)).binaryExpr(m1, typed_gt()), typed_true);
|
||||
// (m1 - Scalar(1)) < m1).all()
|
||||
VERIFY_IS_CWISE_EQUAL((m1 - Scalar(1)).binaryExpr(m1, typed_lt()), typed_true);
|
||||
// (m1 + Scalar(1)) == (m1 + Scalar(1))).all()
|
||||
VERIFY_IS_CWISE_EQUAL((m1 + Scalar(1)).binaryExpr(m1 + Scalar(1), typed_eq()), typed_true);
|
||||
// (m1 - Scalar(1)) != m1).all()
|
||||
VERIFY_IS_CWISE_EQUAL((m1 - Scalar(1)).binaryExpr(m1, typed_ne()), typed_true);
|
||||
// (m1 <= m2 || m1 >= m2).all()
|
||||
VERIFY_IS_CWISE_EQUAL(m1.binaryExpr(m2, typed_le()) || m1.binaryExpr(m2, typed_ge()), typed_true);
|
||||
|
||||
// use boolean comparisons, regardless of operator overload behavior
|
||||
ArrayXX<bool>::ConstantReturnType bool_true = ArrayXX<bool>::Constant(rows, cols, true);
|
||||
// (m1 + Scalar(1)) > m1).all()
|
||||
VERIFY_IS_CWISE_EQUAL((m1 + Scalar(1)).binaryExpr(m1, bool_gt()), bool_true);
|
||||
// (m1 - Scalar(1)) < m1).all()
|
||||
VERIFY_IS_CWISE_EQUAL((m1 - Scalar(1)).binaryExpr(m1, bool_lt()), bool_true);
|
||||
// (m1 + Scalar(1)) == (m1 + Scalar(1))).all()
|
||||
VERIFY_IS_CWISE_EQUAL((m1 + Scalar(1)).binaryExpr(m1 + Scalar(1), bool_eq()), bool_true);
|
||||
// (m1 - Scalar(1)) != m1).all()
|
||||
VERIFY_IS_CWISE_EQUAL((m1 - Scalar(1)).binaryExpr(m1, bool_ne()), bool_true);
|
||||
// (m1 <= m2 || m1 >= m2).all()
|
||||
VERIFY_IS_CWISE_EQUAL(m1.binaryExpr(m2, bool_le()) || m1.binaryExpr(m2, bool_ge()), bool_true);
|
||||
|
||||
// test Select
|
||||
VERIFY_IS_APPROX( (m1<m2).select(m1,m2), m1.cwiseMin(m2) );
|
||||
VERIFY_IS_APPROX( (m1>m2).select(m1,m2), m1.cwiseMax(m2) );
|
||||
@@ -642,7 +687,7 @@ template<typename ArrayType> void comparisons(const ArrayType& m)
|
||||
VERIFY_IS_APPROX( (m1.abs()>=ArrayType::Constant(rows,cols,mid))
|
||||
.select(m1,0), m3);
|
||||
// even shorter version:
|
||||
VERIFY_IS_APPROX( (m1.abs()<mid).select(0,m1), m3);
|
||||
VERIFY_IS_APPROX( (m1.abs()<mid).select(0,m1), m3);
|
||||
|
||||
// count
|
||||
VERIFY(((m1.abs()+1)>RealScalar(0.1)).count() == rows*cols);
|
||||
@@ -1039,7 +1084,7 @@ struct typed_logicals_test_impl {
|
||||
using Scalar = typename ArrayType::Scalar;
|
||||
|
||||
static bool scalar_to_bool(const Scalar& x) { return x != Scalar(0); }
|
||||
static Scalar bool_to_scalar(const bool& x) { return x ? Scalar(1) : Scalar(0); }
|
||||
static Scalar bool_to_scalar(bool x) { return x ? Scalar(1) : Scalar(0); }
|
||||
|
||||
static Scalar eval_bool_and(const Scalar& x, const Scalar& y) { return bool_to_scalar(scalar_to_bool(x) && scalar_to_bool(y)); }
|
||||
static Scalar eval_bool_or(const Scalar& x, const Scalar& y) { return bool_to_scalar(scalar_to_bool(x) || scalar_to_bool(y)); }
|
||||
@@ -1091,40 +1136,45 @@ struct typed_logicals_test_impl {
|
||||
m4 = (!m1).binaryExpr((!m2), internal::scalar_boolean_xor_op<Scalar>());
|
||||
VERIFY_IS_CWISE_EQUAL(m3, m4);
|
||||
|
||||
const Index bytes = rows * cols * sizeof(Scalar);
|
||||
const uint8_t* m1_data = reinterpret_cast<const uint8_t*>(m1.data());
|
||||
const uint8_t* m2_data = reinterpret_cast<const uint8_t*>(m2.data());
|
||||
uint8_t* m3_data = reinterpret_cast<uint8_t*>(m3.data());
|
||||
uint8_t* m4_data = reinterpret_cast<uint8_t*>(m4.data());
|
||||
const size_t bytes = size_t(rows) * size_t(cols) * sizeof(Scalar);
|
||||
|
||||
std::vector<uint8_t> m1_buffer(bytes), m2_buffer(bytes), m3_buffer(bytes), m4_buffer(bytes);
|
||||
|
||||
std::memcpy(m1_buffer.data(), m1.data(), bytes);
|
||||
std::memcpy(m2_buffer.data(), m2.data(), bytes);
|
||||
|
||||
// test bitwise and
|
||||
m3 = m1 & m2;
|
||||
for (Index i = 0; i < bytes; i++) m4_data[i] = m1_data[i] & m2_data[i];
|
||||
for (Index i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_data[i], m4_data[i]);
|
||||
std::memcpy(m3_buffer.data(), m3.data(), bytes);
|
||||
for (size_t i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_buffer[i], uint8_t(m1_buffer[i] & m2_buffer[i]));
|
||||
|
||||
// test bitwise or
|
||||
m3 = m1 | m2;
|
||||
for (Index i = 0; i < bytes; i++) m4_data[i] = m1_data[i] | m2_data[i];
|
||||
for (Index i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_data[i], m4_data[i]);
|
||||
std::memcpy(m3_buffer.data(), m3.data(), bytes);
|
||||
for (size_t i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_buffer[i], uint8_t(m1_buffer[i] | m2_buffer[i]));
|
||||
|
||||
// test bitwise xor
|
||||
m3 = m1 ^ m2;
|
||||
for (Index i = 0; i < bytes; i++) m4_data[i] = m1_data[i] ^ m2_data[i];
|
||||
for (Index i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_data[i], m4_data[i]);
|
||||
std::memcpy(m3_buffer.data(), m3.data(), bytes);
|
||||
for (size_t i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_buffer[i], uint8_t(m1_buffer[i] ^ m2_buffer[i]));
|
||||
|
||||
// test bitwise not
|
||||
m3 = ~m1;
|
||||
for (Index i = 0; i < bytes; i++) m4_data[i] = ~m1_data[i];
|
||||
for (Index i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_data[i], m4_data[i]);
|
||||
std::memcpy(m3_buffer.data(), m3.data(), bytes);
|
||||
for (size_t i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_buffer[i], uint8_t(~m1_buffer[i]));
|
||||
|
||||
// test something more complicated
|
||||
m3 = m1 & m2;
|
||||
m4 = ~(~m1 | ~m2);
|
||||
for (Index i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_data[i], m4_data[i]);
|
||||
std::memcpy(m3_buffer.data(), m3.data(), bytes);
|
||||
std::memcpy(m4_buffer.data(), m4.data(), bytes);
|
||||
for (size_t i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_buffer[i], m4_buffer[i]);
|
||||
|
||||
m3 = m1 ^ m2;
|
||||
m4 = (~m1) ^ (~m2);
|
||||
for (Index i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_data[i], m4_data[i]);
|
||||
std::memcpy(m3_buffer.data(), m3.data(), bytes);
|
||||
std::memcpy(m4_buffer.data(), m4.data(), bytes);
|
||||
for (size_t i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_buffer[i], m4_buffer[i]);
|
||||
}
|
||||
};
|
||||
template <typename ArrayType>
|
||||
@@ -1181,7 +1231,6 @@ EIGEN_DECLARE_TEST(array_cwise)
|
||||
CALL_SUBTEST_8( signbit_tests() );
|
||||
}
|
||||
for (int i = 0; i < g_repeat; i++) {
|
||||
CALL_SUBTEST_1( typed_logicals_test(ArrayX<bool>(internal::random<int>(1, EIGEN_TEST_MAX_SIZE))) );
|
||||
CALL_SUBTEST_2( typed_logicals_test(ArrayX<int>(internal::random<int>(1, EIGEN_TEST_MAX_SIZE))) );
|
||||
CALL_SUBTEST_2( typed_logicals_test(ArrayX<float>(internal::random<int>(1, EIGEN_TEST_MAX_SIZE))) );
|
||||
CALL_SUBTEST_3( typed_logicals_test(ArrayX<double>(internal::random<int>(1, EIGEN_TEST_MAX_SIZE))));
|
||||
|
||||
Reference in New Issue
Block a user