diff --git a/Eigen/src/Core/DenseBase.h b/Eigen/src/Core/DenseBase.h index c81e1d109..dcadcbfb5 100644 --- a/Eigen/src/Core/DenseBase.h +++ b/Eigen/src/Core/DenseBase.h @@ -367,11 +367,11 @@ class DenseBase EIGEN_DEVICE_FUNC inline bool allFinite() const; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& operator*=(const Scalar& other); - template ::value, typename = std::enable_if_t> + template ::value, typename = std::enable_if_t> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& operator*=(const RealScalar& other); EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& operator/=(const Scalar& other); - template ::value, typename = std::enable_if_t> + template ::value, typename = std::enable_if_t> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& operator/=(const RealScalar& other); typedef internal::add_const_on_value_type_t::type> EvalReturnType; diff --git a/Eigen/src/Core/Dot.h b/Eigen/src/Core/Dot.h index 059527c85..a173306b4 100644 --- a/Eigen/src/Core/Dot.h +++ b/Eigen/src/Core/Dot.h @@ -20,10 +20,7 @@ namespace internal { template ::Scalar> struct squared_norm_impl { using Real = typename NumTraits::Real; - static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Real run(const Derived& a) { - Scalar result = a.unaryExpr(squared_norm_functor()).sum(); - return numext::real(result) + numext::imag(result); - } + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Real run(const Derived& a) { return a.realView().cwiseAbs2().sum(); } }; template diff --git a/Eigen/src/Core/RealView.h b/Eigen/src/Core/RealView.h index 7ba42f9a1..3be5556a8 100644 --- a/Eigen/src/Core/RealView.h +++ b/Eigen/src/Core/RealView.h @@ -17,20 +17,16 @@ namespace Eigen { namespace internal { -// Vectorized assignment to RealView requires array-oriented access to the real and imaginary components. +// Write access and vectorization requires array-oriented access to the real and imaginary components. // From https://en.cppreference.com/w/cpp/numeric/complex.html: // For any pointer to an element of an array of std::complex named p and any valid array index i, // reinterpret_cast(p)[2 * i] is the real part of the complex number p[i], and // reinterpret_cast(p)[2 * i + 1] is the imaginary part of the complex number p[i]. -template +template struct complex_array_access : std::false_type {}; -template <> -struct complex_array_access> : std::true_type {}; -template <> -struct complex_array_access> : std::true_type {}; -template <> -struct complex_array_access> : std::true_type {}; +template +struct complex_array_access> : std::true_type {}; template struct traits> : public traits { @@ -40,13 +36,17 @@ struct traits> : public traits { if (size_as_int == Dynamic) return Dynamic; return times_two ? (2 * size_as_int) : size_as_int; } + using Base = traits; using ComplexScalar = typename Base::Scalar; using Scalar = typename NumTraits::Real; - static constexpr int ActualDirectAccessBit = complex_array_access::value ? DirectAccessBit : 0; + + static constexpr bool ArrayAccess = complex_array_access::value; + static constexpr int ActualDirectAccessBit = ArrayAccess ? DirectAccessBit : 0; + static constexpr int ActualLvaluebit = !std::is_const::value && ArrayAccess ? LvalueBit : 0; static constexpr int ActualPacketAccessBit = packet_traits::Vectorizable ? PacketAccessBit : 0; static constexpr int FlagMask = - ActualDirectAccessBit | ActualPacketAccessBit | HereditaryBits | LinearAccessBit | LvalueBit; + ActualDirectAccessBit | ActualLvaluebit | ActualPacketAccessBit | HereditaryBits | LinearAccessBit; static constexpr int BaseFlags = int(evaluator::Flags) | int(Base::Flags); static constexpr int Flags = BaseFlags & FlagMask; static constexpr bool IsRowMajor = Flags & RowMajorBit; @@ -66,68 +66,84 @@ struct evaluator> : private evaluator { using XprType = RealView; using ExpressionTraits = traits; using ComplexScalar = typename ExpressionTraits::ComplexScalar; - using ComplexCoeffReturnType = typename BaseEvaluator::CoeffReturnType; using Scalar = typename ExpressionTraits::Scalar; - static constexpr bool IsRowMajor = ExpressionTraits::IsRowMajor; static constexpr int Flags = ExpressionTraits::Flags; static constexpr int CoeffReadCost = BaseEvaluator::CoeffReadCost; static constexpr int Alignment = BaseEvaluator::Alignment; + static constexpr bool IsRowMajor = ExpressionTraits::IsRowMajor; + static constexpr bool DirectAccess = Flags & DirectAccessBit; + + using ComplexCoeffReturnType = std::conditional_t; + using CoeffReturnType = std::conditional_t; EIGEN_DEVICE_FUNC explicit evaluator(XprType realView) : BaseEvaluator(realView.m_xpr) {} - template ::value, typename = std::enable_if_t> + template = true> constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(Index row, Index col) const { - ComplexCoeffReturnType cscalar = BaseEvaluator::coeff(IsRowMajor ? row : row / 2, IsRowMajor ? col / 2 : col); - Index p = (IsRowMajor ? col : row) & 1; - return p ? numext::real(cscalar) : numext::imag(cscalar); + Index r = IsRowMajor ? row : row / 2; + Index c = IsRowMajor ? col / 2 : col; + bool p = (IsRowMajor ? col : row) & 1; + ComplexScalar ccoeff = BaseEvaluator::coeff(r, c); + return p ? numext::imag(ccoeff) : numext::real(ccoeff); } - - template ::value, typename = std::enable_if_t> - constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& coeff(Index row, Index col) const { - ComplexCoeffReturnType cscalar = BaseEvaluator::coeff(IsRowMajor ? row : row / 2, IsRowMajor ? col / 2 : col); + template = true> + constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index row, Index col) const { + Index r = IsRowMajor ? row : row / 2; + Index c = IsRowMajor ? col / 2 : col; Index p = (IsRowMajor ? col : row) & 1; - return reinterpret_cast(cscalar)[p]; + ComplexCoeffReturnType ccoeff = BaseEvaluator::coeff(r, c); + return reinterpret_cast(ccoeff)[p]; } - - constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index row, Index col) { - ComplexScalar& cscalar = BaseEvaluator::coeffRef(IsRowMajor ? row : row / 2, IsRowMajor ? col / 2 : col); - Index p = (IsRowMajor ? col : row) & 1; - return reinterpret_cast(cscalar)[p]; - } - - template ::value, typename = std::enable_if_t> + template = true> constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(Index index) const { - ComplexCoeffReturnType cscalar = BaseEvaluator::coeff(index / 2); - Index p = index & 1; - return p ? numext::real(cscalar) : numext::imag(cscalar); + ComplexScalar ccoeff = BaseEvaluator::coeff(index / 2); + bool p = index & 1; + return p ? numext::imag(ccoeff) : numext::real(ccoeff); } - - template ::value, typename = std::enable_if_t> - constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& coeff(Index index) const { - ComplexCoeffReturnType cscalar = BaseEvaluator::coeff(index / 2); + template = true> + constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { + ComplexCoeffReturnType ccoeff = BaseEvaluator::coeff(index / 2); Index p = index & 1; - return reinterpret_cast(cscalar)[p]; + return reinterpret_cast(ccoeff)[p]; + } + constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index row, Index col) { + Index r = IsRowMajor ? row : row / 2; + Index c = IsRowMajor ? col / 2 : col; + Index p = (IsRowMajor ? col : row) & 1; + ComplexScalar& ccoeffRef = BaseEvaluator::coeffRef(r, c); + return reinterpret_cast(ccoeffRef)[p]; } - constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) { - ComplexScalar& cscalar = BaseEvaluator::coeffRef(index / 2); + ComplexScalar& ccoeffRef = BaseEvaluator::coeffRef(index / 2); Index p = index & 1; - return reinterpret_cast(cscalar)[p]; + return reinterpret_cast(ccoeffRef)[p]; } + // If the first index is odd (imaginary), discard the first scalar + // in 'result' and assign the missing scalar. + // This operation is safe as the real component of the first scalar must exist. + template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketType packet(Index row, Index col) const { constexpr int RealPacketSize = unpacket_traits::size; using ComplexPacket = typename find_packet_by_size::type; EIGEN_STATIC_ASSERT((find_packet_by_size::value), MISSING COMPATIBLE COMPLEX PACKET TYPE) - eigen_assert(((IsRowMajor ? col : row) % 2 == 0) && "the inner index must be even"); - - Index crow = IsRowMajor ? row : row / 2; - Index ccol = IsRowMajor ? col / 2 : col; - ComplexPacket cpacket = BaseEvaluator::template packet(crow, ccol); - return preinterpret(cpacket); + Index r = IsRowMajor ? row : row / 2; + Index c = IsRowMajor ? col / 2 : col; + bool p = (IsRowMajor ? col : row) & 1; + ComplexPacket cresult = BaseEvaluator::template packet(r, c); + PacketType result = preinterpret(cresult); + if (p) { + Scalar aux[RealPacketSize + 1]; + pstoreu(aux, result); + Index lastr = IsRowMajor ? row : row + RealPacketSize - 1; + Index lastc = IsRowMajor ? col + RealPacketSize - 1 : col; + aux[RealPacketSize] = coeff(lastr, lastc); + result = ploadu(aux + 1); + } + return result; } template @@ -136,28 +152,48 @@ struct evaluator> : private evaluator { using ComplexPacket = typename find_packet_by_size::type; EIGEN_STATIC_ASSERT((find_packet_by_size::value), MISSING COMPATIBLE COMPLEX PACKET TYPE) - eigen_assert((index % 2 == 0) && "the index must be even"); - - Index cindex = index / 2; - ComplexPacket cpacket = BaseEvaluator::template packet(cindex); - return preinterpret(cpacket); + ComplexPacket cresult = BaseEvaluator::template packet(index / 2); + PacketType result = preinterpret(cresult); + bool p = index & 1; + if (p) { + Scalar aux[RealPacketSize + 1]; + pstoreu(aux, result); + aux[RealPacketSize] = coeff(index + RealPacketSize - 1); + result = ploadu(aux + 1); + } + return result; } + // The requested real packet segment forms the half-open interval [begin, end), where 'end' = 'begin' + 'count'. + // In order to access the underlying complex array, even indices must be aligned with the real components + // of the complex scalars. 'begin' and 'count' must be modified as follows: + // a) 'begin' must be rounded down to the nearest even number; and + // b) 'end' must be rounded up to the nearest even number. + template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketType packetSegment(Index row, Index col, Index begin, Index count) const { constexpr int RealPacketSize = unpacket_traits::size; using ComplexPacket = typename find_packet_by_size::type; EIGEN_STATIC_ASSERT((find_packet_by_size::value), MISSING COMPATIBLE COMPLEX PACKET TYPE) - eigen_assert(((IsRowMajor ? col : row) % 2 == 0) && "the inner index must be even"); - eigen_assert((begin % 2 == 0) && (count % 2 == 0) && "begin and count must be even"); - - Index crow = IsRowMajor ? row : row / 2; - Index ccol = IsRowMajor ? col / 2 : col; - Index cbegin = begin / 2; - Index ccount = count / 2; - ComplexPacket cpacket = BaseEvaluator::template packetSegment(crow, ccol, cbegin, ccount); - return preinterpret(cpacket); + Index actualBegin = numext::round_down(begin, 2); + Index actualEnd = numext::round_down(begin + count + 1, 2); + Index actualCount = actualEnd - actualBegin; + Index r = IsRowMajor ? row : row / 2; + Index c = IsRowMajor ? col / 2 : col; + ComplexPacket cresult = + BaseEvaluator::template packetSegment(r, c, actualBegin / 2, actualCount / 2); + PacketType result = preinterpret(cresult); + bool p = (IsRowMajor ? col : row) & 1; + if (p) { + Scalar aux[RealPacketSize + 1] = {}; + pstoreu(aux, result); + Index lastr = IsRowMajor ? row : row + actualEnd - 1; + Index lastc = IsRowMajor ? col + actualEnd - 1 : col; + aux[actualEnd] = coeff(lastr, lastc); + result = ploadu(aux + 1); + } + return result; } template @@ -166,14 +202,20 @@ struct evaluator> : private evaluator { using ComplexPacket = typename find_packet_by_size::type; EIGEN_STATIC_ASSERT((find_packet_by_size::value), MISSING COMPATIBLE COMPLEX PACKET TYPE) - eigen_assert((index % 2 == 0) && "the index must be even"); - eigen_assert((begin % 2 == 0) && (count % 2 == 0) && "begin and count must be even"); - - Index cindex = index / 2; - Index cbegin = begin / 2; - Index ccount = count / 2; - ComplexPacket cpacket = BaseEvaluator::template packetSegment(cindex, cbegin, ccount); - return preinterpret(cpacket); + Index actualBegin = numext::round_down(begin, 2); + Index actualEnd = numext::round_down(begin + count + 1, 2); + Index actualCount = actualEnd - actualBegin; + ComplexPacket cresult = + BaseEvaluator::template packetSegment(index / 2, actualBegin / 2, actualCount / 2); + PacketType result = preinterpret(cresult); + bool p = index & 1; + if (p) { + Scalar aux[RealPacketSize + 1] = {}; + pstoreu(aux, result); + aux[actualEnd] = coeff(index + actualEnd - 1); + result = ploadu(aux + 1); + } + return result; } }; @@ -211,7 +253,7 @@ class RealView : public internal::dense_xpr_base>::type { EIGEN_DEVICE_FUNC RealView& operator=(const DenseBase& other); protected: - friend struct internal::evaluator>; + friend struct internal::evaluator; Xpr& m_xpr; }; diff --git a/Eigen/src/Core/functors/UnaryFunctors.h b/Eigen/src/Core/functors/UnaryFunctors.h index 202995ff0..d7fc7bb4d 100644 --- a/Eigen/src/Core/functors/UnaryFunctors.h +++ b/Eigen/src/Core/functors/UnaryFunctors.h @@ -106,26 +106,6 @@ struct functor_traits> { }; }; -template ::IsComplex> -struct squared_norm_functor { - typedef Scalar result_type; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a) const { - return Scalar(numext::real(a) * numext::real(a), numext::imag(a) * numext::imag(a)); - } - template - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a) const { - return Packet(pmul(a.v, a.v)); - } -}; -template -struct squared_norm_functor : scalar_abs2_op {}; - -template -struct functor_traits> { - using Real = typename NumTraits::Real; - enum { Cost = NumTraits::MulCost, PacketAccess = packet_traits::HasMul }; -}; - /** \internal * \brief Template functor to compute the conjugate of a complex value * diff --git a/Eigen/src/Core/util/ForwardDeclarations.h b/Eigen/src/Core/util/ForwardDeclarations.h index e0bc57eab..e2e71555e 100644 --- a/Eigen/src/Core/util/ForwardDeclarations.h +++ b/Eigen/src/Core/util/ForwardDeclarations.h @@ -517,6 +517,9 @@ struct eigen_zero_impl; template struct has_packet_segment : std::false_type {}; + +template +struct complex_array_access; } // namespace internal } // end namespace Eigen diff --git a/test/realview.cpp b/test/realview.cpp index 8658a3f94..48b45fb17 100644 --- a/test/realview.cpp +++ b/test/realview.cpp @@ -9,8 +9,20 @@ #include "main.h" +// wrapper that disables array-oriented access to real and imaginary components +struct TestComplex : public std::complex { + TestComplex() = default; + TestComplex(const TestComplex&) = default; + TestComplex(std::complex x) : std::complex(x){}; + TestComplex(float x) : std::complex(x){}; +}; +template <> +struct NumTraits : NumTraits> {}; +template <> +struct internal::random_impl : internal::random_impl> {}; + template -void test_realview(const T&) { +void test_realview_readonly(const T&) { using Scalar = typename T::Scalar; using RealScalar = typename NumTraits::Real; @@ -26,21 +38,16 @@ void test_realview(const T&) { Index rows = internal::random(minRows, maxRows); Index cols = internal::random(minCols, maxCols); - T A(rows, cols), B, C; + T A(rows, cols), B(rows, cols); VERIFY(A.realView().rows() == rowFactor * A.rows()); VERIFY(A.realView().cols() == colFactor * A.cols()); VERIFY(A.realView().size() == sizeFactor * A.size()); - RealScalar alpha = internal::random(RealScalar(1), RealScalar(2)); A.setRandom(); + VERIFY_IS_APPROX(A.matrix().cwiseAbs2().sum(), A.realView().matrix().cwiseAbs2().sum()); - VERIFY_IS_APPROX(A.matrix().squaredNorm(), A.realView().matrix().squaredNorm()); - - // test re-sizing realView during assignment - B.realView() = A.realView(); - VERIFY_IS_APPROX(A, B); - VERIFY_IS_APPROX(A.realView(), B.realView()); + RealScalar alpha = internal::random(RealScalar(1), RealScalar(2)); // B = A * alpha for (Index r = 0; r < rows; r++) { @@ -48,14 +55,7 @@ void test_realview(const T&) { B.coeffRef(r, c) = A.coeff(r, c) * Scalar(alpha); } } - - VERIFY_IS_APPROX(B.realView(), A.realView() * alpha); - C = A; - C.realView() *= alpha; - VERIFY_IS_APPROX(B, C); - - alpha = internal::random(RealScalar(1), RealScalar(2)); - A.setRandom(); + VERIFY_IS_CWISE_APPROX(B.realView(), A.realView() * alpha); // B = A / alpha for (Index r = 0; r < rows; r++) { @@ -63,15 +63,155 @@ void test_realview(const T&) { B.coeffRef(r, c) = A.coeff(r, c) / Scalar(alpha); } } + VERIFY_IS_CWISE_APPROX(B.realView(), A.realView() / alpha); +} +template +void test_realview(const T&) { + using Scalar = typename T::Scalar; + using RealScalar = typename NumTraits::Real; + + constexpr Index minRows = T::RowsAtCompileTime == Dynamic ? 1 : T::RowsAtCompileTime; + constexpr Index maxRows = T::MaxRowsAtCompileTime == Dynamic ? (EIGEN_TEST_MAX_SIZE / 2) : T::MaxRowsAtCompileTime; + constexpr Index minCols = T::ColsAtCompileTime == Dynamic ? 1 : T::ColsAtCompileTime; + constexpr Index maxCols = T::MaxColsAtCompileTime == Dynamic ? (EIGEN_TEST_MAX_SIZE / 2) : T::MaxColsAtCompileTime; + + constexpr Index rowFactor = (NumTraits::IsComplex && !T::IsRowMajor) ? 2 : 1; + constexpr Index colFactor = (NumTraits::IsComplex && T::IsRowMajor) ? 2 : 1; + constexpr Index sizeFactor = NumTraits::IsComplex ? 2 : 1; + + const Index rows = internal::random(minRows, maxRows); + const Index cols = internal::random(minCols, maxCols); + const Index realViewRows = rowFactor * rows; + const Index realViewCols = colFactor * cols; + + const T A = T::Random(rows, cols); + T B; + + VERIFY_IS_EQUAL(A.realView().rows(), rowFactor * A.rows()); + VERIFY_IS_EQUAL(A.realView().cols(), colFactor * A.cols()); + VERIFY_IS_EQUAL(A.realView().size(), sizeFactor * A.size()); + + VERIFY_IS_APPROX(A.matrix().cwiseAbs2().sum(), A.realView().matrix().cwiseAbs2().sum()); + + // test re-sizing realView during assignment + B.realView() = A.realView(); + VERIFY_IS_APPROX(A, B); + VERIFY_IS_APPROX(A.realView(), B.realView()); + + const RealScalar alpha = internal::random(RealScalar(1), RealScalar(2)); + + // B = A * alpha + for (Index r = 0; r < rows; r++) { + for (Index c = 0; c < cols; c++) { + B.coeffRef(r, c) = A.coeff(r, c) * Scalar(alpha); + } + } + VERIFY_IS_APPROX(B.realView(), A.realView() * alpha); + + B = A; + B.realView() *= alpha; + VERIFY_IS_APPROX(B.realView(), A.realView() * alpha); + + // B = A / alpha + for (Index r = 0; r < rows; r++) { + for (Index c = 0; c < cols; c++) { + B.coeffRef(r, c) = A.coeff(r, c) / Scalar(alpha); + } + } VERIFY_IS_APPROX(B.realView(), A.realView() / alpha); + + B = A; + B.realView() /= alpha; + VERIFY_IS_APPROX(B.realView(), A.realView() / alpha); + + // force some usual access patterns + Index malloc_size = (rows * cols * sizeof(Scalar)) + sizeof(RealScalar); + void* data1 = internal::aligned_malloc(malloc_size); + void* data2 = internal::aligned_malloc(malloc_size); + Scalar* ptr1 = reinterpret_cast(reinterpret_cast(data1) + sizeof(RealScalar)); + Scalar* ptr2 = reinterpret_cast(reinterpret_cast(data2) + sizeof(RealScalar)); + Map C(ptr1, rows, cols), D(ptr2, rows, cols); + + C.setRandom(); + D.setRandom(); + for (Index r = 0; r < realViewRows; r++) { + for (Index c = 0; c < realViewCols; c++) { + C.realView().coeffRef(r, c) = D.realView().coeff(r, c); + } + } + VERIFY_IS_CWISE_EQUAL(C, D); + C = A; - C.realView() /= alpha; - VERIFY_IS_APPROX(B, C); + + for (Index c = 0; c < realViewCols - 1; c++) { + B.realView().row(0).coeffRef(realViewCols - 1 - c) = C.realView().row(0).coeff(c + 1); + } + D.realView().row(0).tail(realViewCols - 1) = C.realView().row(0).tail(realViewCols - 1).reverse(); + VERIFY_IS_CWISE_EQUAL(B.realView().row(0).tail(realViewCols - 1), D.realView().row(0).tail(realViewCols - 1)); + + for (Index r = 0; r < realViewRows - 1; r++) { + B.realView().col(0).coeffRef(realViewRows - 1 - r) = C.realView().col(0).coeff(r + 1); + } + D.realView().col(0).tail(realViewRows - 1) = C.realView().col(0).tail(realViewRows - 1).reverse(); + VERIFY_IS_CWISE_EQUAL(B.realView().col(0).tail(realViewRows - 1), D.realView().col(0).tail(realViewRows - 1)); +} + +template ::Vectorizable> +struct test_edge_cases_impl { + static void run() { + using namespace internal; + using RealScalar = typename NumTraits::Real; + using ComplexPacket = typename packet_traits::type; + using RealPacket = typename unpacket_traits::as_real; + constexpr int ComplexSize = unpacket_traits::size; + constexpr int RealSize = 2 * ComplexSize; + VectorX a_data(2 * ComplexSize); + Map> a_data_asreal(reinterpret_cast(a_data.data()), 2 * a_data.size()); + VectorX b_data(RealSize); + + a_data.setRandom(); + evaluator>> eval(a_data.realView()); + + for (Index offset = 0; offset < RealSize; offset++) { + for (Index begin = 0; offset + begin < RealSize; begin++) { + for (Index count = 0; begin + count < RealSize; count++) { + b_data.setRandom(); + RealPacket res = eval.packetSegment(offset, begin, count); + pstoreSegment(b_data.data(), res, begin, count); + VERIFY_IS_CWISE_EQUAL(a_data_asreal.segment(offset + begin, count), b_data.segment(begin, count)); + } + } + } + } +}; + +template +struct test_edge_cases_impl { + static void run() {} +}; + +template +void test_edge_cases(const ComplexScalar&) { + test_edge_cases_impl::run(); } template -void test_realview_driver() { +void test_realview_readonly() { + // if Rows == 1, don't test ColMajor as it is not a valid array + using ColMajorMatrixType = Matrix; + using ColMajorArrayType = Array; + // if Cols == 1, don't test RowMajor as it is not a valid array + using RowMajorMatrixType = Matrix; + using RowMajorArrayType = Array; + test_realview_readonly(ColMajorMatrixType()); + test_realview_readonly(ColMajorArrayType()); + test_realview_readonly(RowMajorMatrixType()); + test_realview_readonly(RowMajorArrayType()); +} + +template +void test_realview_readwrite() { // if Rows == 1, don't test ColMajor as it is not a valid array using ColMajorMatrixType = Matrix; using ColMajorArrayType = Array; @@ -85,26 +225,29 @@ void test_realview_driver() { } template -void test_realview_driver_complex() { - test_realview_driver(); - test_realview_driver, Rows, Cols, MaxRows, MaxCols>(); - test_realview_driver(); - test_realview_driver, Rows, Cols, MaxRows, MaxCols>(); - test_realview_driver(); - test_realview_driver, Rows, Cols, MaxRows, MaxCols>(); +void test_realview() { + test_realview_readwrite(); + test_realview_readwrite, Rows, Cols, MaxRows, MaxCols>(); + test_realview_readwrite(); + test_realview_readwrite, Rows, Cols, MaxRows, MaxCols>(); + test_realview_readwrite(); + test_realview_readwrite, Rows, Cols, MaxRows, MaxCols>(); + test_realview_readonly(); } EIGEN_DECLARE_TEST(realview) { for (int i = 0; i < g_repeat; i++) { - CALL_SUBTEST_1((test_realview_driver_complex())); - CALL_SUBTEST_2((test_realview_driver_complex())); - CALL_SUBTEST_3((test_realview_driver_complex())); - CALL_SUBTEST_4((test_realview_driver_complex())); - CALL_SUBTEST_5((test_realview_driver_complex<17, Dynamic, 17, Dynamic>())); - CALL_SUBTEST_6((test_realview_driver_complex())); - CALL_SUBTEST_7((test_realview_driver_complex<17, 19, 17, 19>())); - CALL_SUBTEST_8((test_realview_driver_complex())); - CALL_SUBTEST_9((test_realview_driver_complex<1, Dynamic>())); - CALL_SUBTEST_10((test_realview_driver_complex<1, 1>())); + CALL_SUBTEST_1((test_realview())); + CALL_SUBTEST_2((test_realview())); + CALL_SUBTEST_3((test_realview())); + CALL_SUBTEST_4((test_realview())); + CALL_SUBTEST_5((test_realview<17, Dynamic, 17, Dynamic>())); + CALL_SUBTEST_6((test_realview())); + CALL_SUBTEST_7((test_realview<17, 19, 17, 19>())); + CALL_SUBTEST_8((test_realview())); + CALL_SUBTEST_9((test_realview<1, Dynamic>())); + CALL_SUBTEST_10((test_realview<1, 1>())); + CALL_SUBTEST_11(test_edge_cases(std::complex())); + CALL_SUBTEST_12(test_edge_cases(std::complex())); } }