mirror of
https://gitlab.com/libeigen/eigen.git
synced 2026-04-10 11:34:33 +08:00
@@ -367,11 +367,11 @@ class DenseBase
|
||||
EIGEN_DEVICE_FUNC inline bool allFinite() const;
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& operator*=(const Scalar& other);
|
||||
template <bool Enable = !internal::is_same<Scalar, RealScalar>::value, typename = std::enable_if_t<Enable>>
|
||||
template <bool Enable = internal::complex_array_access<Scalar>::value, typename = std::enable_if_t<Enable>>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& operator*=(const RealScalar& other);
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& operator/=(const Scalar& other);
|
||||
template <bool Enable = !internal::is_same<Scalar, RealScalar>::value, typename = std::enable_if_t<Enable>>
|
||||
template <bool Enable = internal::complex_array_access<Scalar>::value, typename = std::enable_if_t<Enable>>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& operator/=(const RealScalar& other);
|
||||
|
||||
typedef internal::add_const_on_value_type_t<typename internal::eval<Derived>::type> EvalReturnType;
|
||||
|
||||
@@ -20,10 +20,7 @@ namespace internal {
|
||||
template <typename Derived, typename Scalar = typename traits<Derived>::Scalar>
|
||||
struct squared_norm_impl {
|
||||
using Real = typename NumTraits<Scalar>::Real;
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Real run(const Derived& a) {
|
||||
Scalar result = a.unaryExpr(squared_norm_functor<Scalar>()).sum();
|
||||
return numext::real(result) + numext::imag(result);
|
||||
}
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Real run(const Derived& a) { return a.realView().cwiseAbs2().sum(); }
|
||||
};
|
||||
|
||||
template <typename Derived>
|
||||
|
||||
@@ -17,20 +17,16 @@ namespace Eigen {
|
||||
|
||||
namespace internal {
|
||||
|
||||
// Vectorized assignment to RealView requires array-oriented access to the real and imaginary components.
|
||||
// Write access and vectorization requires array-oriented access to the real and imaginary components.
|
||||
// From https://en.cppreference.com/w/cpp/numeric/complex.html:
|
||||
// For any pointer to an element of an array of std::complex<T> named p and any valid array index i,
|
||||
// reinterpret_cast<T*>(p)[2 * i] is the real part of the complex number p[i], and
|
||||
// reinterpret_cast<T*>(p)[2 * i + 1] is the imaginary part of the complex number p[i].
|
||||
|
||||
template <typename ComplexScalar>
|
||||
template <typename T>
|
||||
struct complex_array_access : std::false_type {};
|
||||
template <>
|
||||
struct complex_array_access<std::complex<float>> : std::true_type {};
|
||||
template <>
|
||||
struct complex_array_access<std::complex<double>> : std::true_type {};
|
||||
template <>
|
||||
struct complex_array_access<std::complex<long double>> : std::true_type {};
|
||||
template <typename T>
|
||||
struct complex_array_access<std::complex<T>> : std::true_type {};
|
||||
|
||||
template <typename Xpr>
|
||||
struct traits<RealView<Xpr>> : public traits<Xpr> {
|
||||
@@ -40,13 +36,17 @@ struct traits<RealView<Xpr>> : public traits<Xpr> {
|
||||
if (size_as_int == Dynamic) return Dynamic;
|
||||
return times_two ? (2 * size_as_int) : size_as_int;
|
||||
}
|
||||
|
||||
using Base = traits<Xpr>;
|
||||
using ComplexScalar = typename Base::Scalar;
|
||||
using Scalar = typename NumTraits<ComplexScalar>::Real;
|
||||
static constexpr int ActualDirectAccessBit = complex_array_access<ComplexScalar>::value ? DirectAccessBit : 0;
|
||||
|
||||
static constexpr bool ArrayAccess = complex_array_access<ComplexScalar>::value;
|
||||
static constexpr int ActualDirectAccessBit = ArrayAccess ? DirectAccessBit : 0;
|
||||
static constexpr int ActualLvaluebit = !std::is_const<Xpr>::value && ArrayAccess ? LvalueBit : 0;
|
||||
static constexpr int ActualPacketAccessBit = packet_traits<Scalar>::Vectorizable ? PacketAccessBit : 0;
|
||||
static constexpr int FlagMask =
|
||||
ActualDirectAccessBit | ActualPacketAccessBit | HereditaryBits | LinearAccessBit | LvalueBit;
|
||||
ActualDirectAccessBit | ActualLvaluebit | ActualPacketAccessBit | HereditaryBits | LinearAccessBit;
|
||||
static constexpr int BaseFlags = int(evaluator<Xpr>::Flags) | int(Base::Flags);
|
||||
static constexpr int Flags = BaseFlags & FlagMask;
|
||||
static constexpr bool IsRowMajor = Flags & RowMajorBit;
|
||||
@@ -66,68 +66,84 @@ struct evaluator<RealView<Xpr>> : private evaluator<Xpr> {
|
||||
using XprType = RealView<Xpr>;
|
||||
using ExpressionTraits = traits<XprType>;
|
||||
using ComplexScalar = typename ExpressionTraits::ComplexScalar;
|
||||
using ComplexCoeffReturnType = typename BaseEvaluator::CoeffReturnType;
|
||||
using Scalar = typename ExpressionTraits::Scalar;
|
||||
|
||||
static constexpr bool IsRowMajor = ExpressionTraits::IsRowMajor;
|
||||
static constexpr int Flags = ExpressionTraits::Flags;
|
||||
static constexpr int CoeffReadCost = BaseEvaluator::CoeffReadCost;
|
||||
static constexpr int Alignment = BaseEvaluator::Alignment;
|
||||
static constexpr bool IsRowMajor = ExpressionTraits::IsRowMajor;
|
||||
static constexpr bool DirectAccess = Flags & DirectAccessBit;
|
||||
|
||||
using ComplexCoeffReturnType = std::conditional_t<DirectAccess, const ComplexScalar&, ComplexScalar>;
|
||||
using CoeffReturnType = std::conditional_t<DirectAccess, const Scalar&, Scalar>;
|
||||
|
||||
EIGEN_DEVICE_FUNC explicit evaluator(XprType realView) : BaseEvaluator(realView.m_xpr) {}
|
||||
|
||||
template <bool Enable = std::is_reference<ComplexCoeffReturnType>::value, typename = std::enable_if_t<!Enable>>
|
||||
template <bool Enable = DirectAccess, std::enable_if_t<!Enable, bool> = true>
|
||||
constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(Index row, Index col) const {
|
||||
ComplexCoeffReturnType cscalar = BaseEvaluator::coeff(IsRowMajor ? row : row / 2, IsRowMajor ? col / 2 : col);
|
||||
Index p = (IsRowMajor ? col : row) & 1;
|
||||
return p ? numext::real(cscalar) : numext::imag(cscalar);
|
||||
Index r = IsRowMajor ? row : row / 2;
|
||||
Index c = IsRowMajor ? col / 2 : col;
|
||||
bool p = (IsRowMajor ? col : row) & 1;
|
||||
ComplexScalar ccoeff = BaseEvaluator::coeff(r, c);
|
||||
return p ? numext::imag(ccoeff) : numext::real(ccoeff);
|
||||
}
|
||||
|
||||
template <bool Enable = std::is_reference<ComplexCoeffReturnType>::value, typename = std::enable_if_t<Enable>>
|
||||
constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& coeff(Index row, Index col) const {
|
||||
ComplexCoeffReturnType cscalar = BaseEvaluator::coeff(IsRowMajor ? row : row / 2, IsRowMajor ? col / 2 : col);
|
||||
template <bool Enable = DirectAccess, std::enable_if_t<Enable, bool> = true>
|
||||
constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index row, Index col) const {
|
||||
Index r = IsRowMajor ? row : row / 2;
|
||||
Index c = IsRowMajor ? col / 2 : col;
|
||||
Index p = (IsRowMajor ? col : row) & 1;
|
||||
return reinterpret_cast<const Scalar(&)[2]>(cscalar)[p];
|
||||
ComplexCoeffReturnType ccoeff = BaseEvaluator::coeff(r, c);
|
||||
return reinterpret_cast<const Scalar(&)[2]>(ccoeff)[p];
|
||||
}
|
||||
|
||||
constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index row, Index col) {
|
||||
ComplexScalar& cscalar = BaseEvaluator::coeffRef(IsRowMajor ? row : row / 2, IsRowMajor ? col / 2 : col);
|
||||
Index p = (IsRowMajor ? col : row) & 1;
|
||||
return reinterpret_cast<Scalar(&)[2]>(cscalar)[p];
|
||||
}
|
||||
|
||||
template <bool Enable = std::is_reference<ComplexCoeffReturnType>::value, typename = std::enable_if_t<!Enable>>
|
||||
template <bool Enable = DirectAccess, std::enable_if_t<!Enable, bool> = true>
|
||||
constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(Index index) const {
|
||||
ComplexCoeffReturnType cscalar = BaseEvaluator::coeff(index / 2);
|
||||
Index p = index & 1;
|
||||
return p ? numext::real(cscalar) : numext::imag(cscalar);
|
||||
ComplexScalar ccoeff = BaseEvaluator::coeff(index / 2);
|
||||
bool p = index & 1;
|
||||
return p ? numext::imag(ccoeff) : numext::real(ccoeff);
|
||||
}
|
||||
|
||||
template <bool Enable = std::is_reference<ComplexCoeffReturnType>::value, typename = std::enable_if_t<Enable>>
|
||||
constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& coeff(Index index) const {
|
||||
ComplexCoeffReturnType cscalar = BaseEvaluator::coeff(index / 2);
|
||||
template <bool Enable = DirectAccess, std::enable_if_t<Enable, bool> = true>
|
||||
constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
|
||||
ComplexCoeffReturnType ccoeff = BaseEvaluator::coeff(index / 2);
|
||||
Index p = index & 1;
|
||||
return reinterpret_cast<const Scalar(&)[2]>(cscalar)[p];
|
||||
return reinterpret_cast<const Scalar(&)[2]>(ccoeff)[p];
|
||||
}
|
||||
constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index row, Index col) {
|
||||
Index r = IsRowMajor ? row : row / 2;
|
||||
Index c = IsRowMajor ? col / 2 : col;
|
||||
Index p = (IsRowMajor ? col : row) & 1;
|
||||
ComplexScalar& ccoeffRef = BaseEvaluator::coeffRef(r, c);
|
||||
return reinterpret_cast<Scalar(&)[2]>(ccoeffRef)[p];
|
||||
}
|
||||
|
||||
constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) {
|
||||
ComplexScalar& cscalar = BaseEvaluator::coeffRef(index / 2);
|
||||
ComplexScalar& ccoeffRef = BaseEvaluator::coeffRef(index / 2);
|
||||
Index p = index & 1;
|
||||
return reinterpret_cast<Scalar(&)[2]>(cscalar)[p];
|
||||
return reinterpret_cast<Scalar(&)[2]>(ccoeffRef)[p];
|
||||
}
|
||||
|
||||
// If the first index is odd (imaginary), discard the first scalar
|
||||
// in 'result' and assign the missing scalar.
|
||||
// This operation is safe as the real component of the first scalar must exist.
|
||||
|
||||
template <int LoadMode, typename PacketType>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketType packet(Index row, Index col) const {
|
||||
constexpr int RealPacketSize = unpacket_traits<PacketType>::size;
|
||||
using ComplexPacket = typename find_packet_by_size<ComplexScalar, RealPacketSize / 2>::type;
|
||||
EIGEN_STATIC_ASSERT((find_packet_by_size<ComplexScalar, RealPacketSize / 2>::value),
|
||||
MISSING COMPATIBLE COMPLEX PACKET TYPE)
|
||||
eigen_assert(((IsRowMajor ? col : row) % 2 == 0) && "the inner index must be even");
|
||||
|
||||
Index crow = IsRowMajor ? row : row / 2;
|
||||
Index ccol = IsRowMajor ? col / 2 : col;
|
||||
ComplexPacket cpacket = BaseEvaluator::template packet<LoadMode, ComplexPacket>(crow, ccol);
|
||||
return preinterpret<PacketType, ComplexPacket>(cpacket);
|
||||
Index r = IsRowMajor ? row : row / 2;
|
||||
Index c = IsRowMajor ? col / 2 : col;
|
||||
bool p = (IsRowMajor ? col : row) & 1;
|
||||
ComplexPacket cresult = BaseEvaluator::template packet<LoadMode, ComplexPacket>(r, c);
|
||||
PacketType result = preinterpret<PacketType>(cresult);
|
||||
if (p) {
|
||||
Scalar aux[RealPacketSize + 1];
|
||||
pstoreu(aux, result);
|
||||
Index lastr = IsRowMajor ? row : row + RealPacketSize - 1;
|
||||
Index lastc = IsRowMajor ? col + RealPacketSize - 1 : col;
|
||||
aux[RealPacketSize] = coeff(lastr, lastc);
|
||||
result = ploadu<PacketType>(aux + 1);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
template <int LoadMode, typename PacketType>
|
||||
@@ -136,28 +152,48 @@ struct evaluator<RealView<Xpr>> : private evaluator<Xpr> {
|
||||
using ComplexPacket = typename find_packet_by_size<ComplexScalar, RealPacketSize / 2>::type;
|
||||
EIGEN_STATIC_ASSERT((find_packet_by_size<ComplexScalar, RealPacketSize / 2>::value),
|
||||
MISSING COMPATIBLE COMPLEX PACKET TYPE)
|
||||
eigen_assert((index % 2 == 0) && "the index must be even");
|
||||
|
||||
Index cindex = index / 2;
|
||||
ComplexPacket cpacket = BaseEvaluator::template packet<LoadMode, ComplexPacket>(cindex);
|
||||
return preinterpret<PacketType, ComplexPacket>(cpacket);
|
||||
ComplexPacket cresult = BaseEvaluator::template packet<LoadMode, ComplexPacket>(index / 2);
|
||||
PacketType result = preinterpret<PacketType>(cresult);
|
||||
bool p = index & 1;
|
||||
if (p) {
|
||||
Scalar aux[RealPacketSize + 1];
|
||||
pstoreu(aux, result);
|
||||
aux[RealPacketSize] = coeff(index + RealPacketSize - 1);
|
||||
result = ploadu<PacketType>(aux + 1);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// The requested real packet segment forms the half-open interval [begin, end), where 'end' = 'begin' + 'count'.
|
||||
// In order to access the underlying complex array, even indices must be aligned with the real components
|
||||
// of the complex scalars. 'begin' and 'count' must be modified as follows:
|
||||
// a) 'begin' must be rounded down to the nearest even number; and
|
||||
// b) 'end' must be rounded up to the nearest even number.
|
||||
|
||||
template <int LoadMode, typename PacketType>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketType packetSegment(Index row, Index col, Index begin, Index count) const {
|
||||
constexpr int RealPacketSize = unpacket_traits<PacketType>::size;
|
||||
using ComplexPacket = typename find_packet_by_size<ComplexScalar, RealPacketSize / 2>::type;
|
||||
EIGEN_STATIC_ASSERT((find_packet_by_size<ComplexScalar, RealPacketSize / 2>::value),
|
||||
MISSING COMPATIBLE COMPLEX PACKET TYPE)
|
||||
eigen_assert(((IsRowMajor ? col : row) % 2 == 0) && "the inner index must be even");
|
||||
eigen_assert((begin % 2 == 0) && (count % 2 == 0) && "begin and count must be even");
|
||||
|
||||
Index crow = IsRowMajor ? row : row / 2;
|
||||
Index ccol = IsRowMajor ? col / 2 : col;
|
||||
Index cbegin = begin / 2;
|
||||
Index ccount = count / 2;
|
||||
ComplexPacket cpacket = BaseEvaluator::template packetSegment<LoadMode, ComplexPacket>(crow, ccol, cbegin, ccount);
|
||||
return preinterpret<PacketType, ComplexPacket>(cpacket);
|
||||
Index actualBegin = numext::round_down(begin, 2);
|
||||
Index actualEnd = numext::round_down(begin + count + 1, 2);
|
||||
Index actualCount = actualEnd - actualBegin;
|
||||
Index r = IsRowMajor ? row : row / 2;
|
||||
Index c = IsRowMajor ? col / 2 : col;
|
||||
ComplexPacket cresult =
|
||||
BaseEvaluator::template packetSegment<LoadMode, ComplexPacket>(r, c, actualBegin / 2, actualCount / 2);
|
||||
PacketType result = preinterpret<PacketType>(cresult);
|
||||
bool p = (IsRowMajor ? col : row) & 1;
|
||||
if (p) {
|
||||
Scalar aux[RealPacketSize + 1] = {};
|
||||
pstoreu(aux, result);
|
||||
Index lastr = IsRowMajor ? row : row + actualEnd - 1;
|
||||
Index lastc = IsRowMajor ? col + actualEnd - 1 : col;
|
||||
aux[actualEnd] = coeff(lastr, lastc);
|
||||
result = ploadu<PacketType>(aux + 1);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
template <int LoadMode, typename PacketType>
|
||||
@@ -166,14 +202,20 @@ struct evaluator<RealView<Xpr>> : private evaluator<Xpr> {
|
||||
using ComplexPacket = typename find_packet_by_size<ComplexScalar, RealPacketSize / 2>::type;
|
||||
EIGEN_STATIC_ASSERT((find_packet_by_size<ComplexScalar, RealPacketSize / 2>::value),
|
||||
MISSING COMPATIBLE COMPLEX PACKET TYPE)
|
||||
eigen_assert((index % 2 == 0) && "the index must be even");
|
||||
eigen_assert((begin % 2 == 0) && (count % 2 == 0) && "begin and count must be even");
|
||||
|
||||
Index cindex = index / 2;
|
||||
Index cbegin = begin / 2;
|
||||
Index ccount = count / 2;
|
||||
ComplexPacket cpacket = BaseEvaluator::template packetSegment<LoadMode, ComplexPacket>(cindex, cbegin, ccount);
|
||||
return preinterpret<PacketType, ComplexPacket>(cpacket);
|
||||
Index actualBegin = numext::round_down(begin, 2);
|
||||
Index actualEnd = numext::round_down(begin + count + 1, 2);
|
||||
Index actualCount = actualEnd - actualBegin;
|
||||
ComplexPacket cresult =
|
||||
BaseEvaluator::template packetSegment<LoadMode, ComplexPacket>(index / 2, actualBegin / 2, actualCount / 2);
|
||||
PacketType result = preinterpret<PacketType>(cresult);
|
||||
bool p = index & 1;
|
||||
if (p) {
|
||||
Scalar aux[RealPacketSize + 1] = {};
|
||||
pstoreu(aux, result);
|
||||
aux[actualEnd] = coeff(index + actualEnd - 1);
|
||||
result = ploadu<PacketType>(aux + 1);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -211,7 +253,7 @@ class RealView : public internal::dense_xpr_base<RealView<Xpr>>::type {
|
||||
EIGEN_DEVICE_FUNC RealView& operator=(const DenseBase<OtherDerived>& other);
|
||||
|
||||
protected:
|
||||
friend struct internal::evaluator<RealView<Xpr>>;
|
||||
friend struct internal::evaluator<RealView>;
|
||||
Xpr& m_xpr;
|
||||
};
|
||||
|
||||
|
||||
@@ -106,26 +106,6 @@ struct functor_traits<scalar_abs2_op<Scalar>> {
|
||||
};
|
||||
};
|
||||
|
||||
template <typename Scalar, bool IsComplex = NumTraits<Scalar>::IsComplex>
|
||||
struct squared_norm_functor {
|
||||
typedef Scalar result_type;
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a) const {
|
||||
return Scalar(numext::real(a) * numext::real(a), numext::imag(a) * numext::imag(a));
|
||||
}
|
||||
template <typename Packet>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a) const {
|
||||
return Packet(pmul(a.v, a.v));
|
||||
}
|
||||
};
|
||||
template <typename Scalar>
|
||||
struct squared_norm_functor<Scalar, false> : scalar_abs2_op<Scalar> {};
|
||||
|
||||
template <typename Scalar>
|
||||
struct functor_traits<squared_norm_functor<Scalar>> {
|
||||
using Real = typename NumTraits<Scalar>::Real;
|
||||
enum { Cost = NumTraits<Real>::MulCost, PacketAccess = packet_traits<Real>::HasMul };
|
||||
};
|
||||
|
||||
/** \internal
|
||||
* \brief Template functor to compute the conjugate of a complex value
|
||||
*
|
||||
|
||||
@@ -517,6 +517,9 @@ struct eigen_zero_impl;
|
||||
|
||||
template <typename Packet>
|
||||
struct has_packet_segment : std::false_type {};
|
||||
|
||||
template <typename T>
|
||||
struct complex_array_access;
|
||||
} // namespace internal
|
||||
|
||||
} // end namespace Eigen
|
||||
|
||||
@@ -9,8 +9,20 @@
|
||||
|
||||
#include "main.h"
|
||||
|
||||
// wrapper that disables array-oriented access to real and imaginary components
|
||||
struct TestComplex : public std::complex<float> {
|
||||
TestComplex() = default;
|
||||
TestComplex(const TestComplex&) = default;
|
||||
TestComplex(std::complex<float> x) : std::complex<float>(x){};
|
||||
TestComplex(float x) : std::complex<float>(x){};
|
||||
};
|
||||
template <>
|
||||
struct NumTraits<TestComplex> : NumTraits<std::complex<float>> {};
|
||||
template <>
|
||||
struct internal::random_impl<TestComplex> : internal::random_impl<std::complex<float>> {};
|
||||
|
||||
template <typename T>
|
||||
void test_realview(const T&) {
|
||||
void test_realview_readonly(const T&) {
|
||||
using Scalar = typename T::Scalar;
|
||||
using RealScalar = typename NumTraits<Scalar>::Real;
|
||||
|
||||
@@ -26,21 +38,16 @@ void test_realview(const T&) {
|
||||
Index rows = internal::random<Index>(minRows, maxRows);
|
||||
Index cols = internal::random<Index>(minCols, maxCols);
|
||||
|
||||
T A(rows, cols), B, C;
|
||||
T A(rows, cols), B(rows, cols);
|
||||
|
||||
VERIFY(A.realView().rows() == rowFactor * A.rows());
|
||||
VERIFY(A.realView().cols() == colFactor * A.cols());
|
||||
VERIFY(A.realView().size() == sizeFactor * A.size());
|
||||
|
||||
RealScalar alpha = internal::random(RealScalar(1), RealScalar(2));
|
||||
A.setRandom();
|
||||
VERIFY_IS_APPROX(A.matrix().cwiseAbs2().sum(), A.realView().matrix().cwiseAbs2().sum());
|
||||
|
||||
VERIFY_IS_APPROX(A.matrix().squaredNorm(), A.realView().matrix().squaredNorm());
|
||||
|
||||
// test re-sizing realView during assignment
|
||||
B.realView() = A.realView();
|
||||
VERIFY_IS_APPROX(A, B);
|
||||
VERIFY_IS_APPROX(A.realView(), B.realView());
|
||||
RealScalar alpha = internal::random(RealScalar(1), RealScalar(2));
|
||||
|
||||
// B = A * alpha
|
||||
for (Index r = 0; r < rows; r++) {
|
||||
@@ -48,14 +55,7 @@ void test_realview(const T&) {
|
||||
B.coeffRef(r, c) = A.coeff(r, c) * Scalar(alpha);
|
||||
}
|
||||
}
|
||||
|
||||
VERIFY_IS_APPROX(B.realView(), A.realView() * alpha);
|
||||
C = A;
|
||||
C.realView() *= alpha;
|
||||
VERIFY_IS_APPROX(B, C);
|
||||
|
||||
alpha = internal::random(RealScalar(1), RealScalar(2));
|
||||
A.setRandom();
|
||||
VERIFY_IS_CWISE_APPROX(B.realView(), A.realView() * alpha);
|
||||
|
||||
// B = A / alpha
|
||||
for (Index r = 0; r < rows; r++) {
|
||||
@@ -63,15 +63,155 @@ void test_realview(const T&) {
|
||||
B.coeffRef(r, c) = A.coeff(r, c) / Scalar(alpha);
|
||||
}
|
||||
}
|
||||
VERIFY_IS_CWISE_APPROX(B.realView(), A.realView() / alpha);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void test_realview(const T&) {
|
||||
using Scalar = typename T::Scalar;
|
||||
using RealScalar = typename NumTraits<Scalar>::Real;
|
||||
|
||||
constexpr Index minRows = T::RowsAtCompileTime == Dynamic ? 1 : T::RowsAtCompileTime;
|
||||
constexpr Index maxRows = T::MaxRowsAtCompileTime == Dynamic ? (EIGEN_TEST_MAX_SIZE / 2) : T::MaxRowsAtCompileTime;
|
||||
constexpr Index minCols = T::ColsAtCompileTime == Dynamic ? 1 : T::ColsAtCompileTime;
|
||||
constexpr Index maxCols = T::MaxColsAtCompileTime == Dynamic ? (EIGEN_TEST_MAX_SIZE / 2) : T::MaxColsAtCompileTime;
|
||||
|
||||
constexpr Index rowFactor = (NumTraits<Scalar>::IsComplex && !T::IsRowMajor) ? 2 : 1;
|
||||
constexpr Index colFactor = (NumTraits<Scalar>::IsComplex && T::IsRowMajor) ? 2 : 1;
|
||||
constexpr Index sizeFactor = NumTraits<Scalar>::IsComplex ? 2 : 1;
|
||||
|
||||
const Index rows = internal::random<Index>(minRows, maxRows);
|
||||
const Index cols = internal::random<Index>(minCols, maxCols);
|
||||
const Index realViewRows = rowFactor * rows;
|
||||
const Index realViewCols = colFactor * cols;
|
||||
|
||||
const T A = T::Random(rows, cols);
|
||||
T B;
|
||||
|
||||
VERIFY_IS_EQUAL(A.realView().rows(), rowFactor * A.rows());
|
||||
VERIFY_IS_EQUAL(A.realView().cols(), colFactor * A.cols());
|
||||
VERIFY_IS_EQUAL(A.realView().size(), sizeFactor * A.size());
|
||||
|
||||
VERIFY_IS_APPROX(A.matrix().cwiseAbs2().sum(), A.realView().matrix().cwiseAbs2().sum());
|
||||
|
||||
// test re-sizing realView during assignment
|
||||
B.realView() = A.realView();
|
||||
VERIFY_IS_APPROX(A, B);
|
||||
VERIFY_IS_APPROX(A.realView(), B.realView());
|
||||
|
||||
const RealScalar alpha = internal::random(RealScalar(1), RealScalar(2));
|
||||
|
||||
// B = A * alpha
|
||||
for (Index r = 0; r < rows; r++) {
|
||||
for (Index c = 0; c < cols; c++) {
|
||||
B.coeffRef(r, c) = A.coeff(r, c) * Scalar(alpha);
|
||||
}
|
||||
}
|
||||
VERIFY_IS_APPROX(B.realView(), A.realView() * alpha);
|
||||
|
||||
B = A;
|
||||
B.realView() *= alpha;
|
||||
VERIFY_IS_APPROX(B.realView(), A.realView() * alpha);
|
||||
|
||||
// B = A / alpha
|
||||
for (Index r = 0; r < rows; r++) {
|
||||
for (Index c = 0; c < cols; c++) {
|
||||
B.coeffRef(r, c) = A.coeff(r, c) / Scalar(alpha);
|
||||
}
|
||||
}
|
||||
VERIFY_IS_APPROX(B.realView(), A.realView() / alpha);
|
||||
|
||||
B = A;
|
||||
B.realView() /= alpha;
|
||||
VERIFY_IS_APPROX(B.realView(), A.realView() / alpha);
|
||||
|
||||
// force some usual access patterns
|
||||
Index malloc_size = (rows * cols * sizeof(Scalar)) + sizeof(RealScalar);
|
||||
void* data1 = internal::aligned_malloc(malloc_size);
|
||||
void* data2 = internal::aligned_malloc(malloc_size);
|
||||
Scalar* ptr1 = reinterpret_cast<Scalar*>(reinterpret_cast<uint8_t*>(data1) + sizeof(RealScalar));
|
||||
Scalar* ptr2 = reinterpret_cast<Scalar*>(reinterpret_cast<uint8_t*>(data2) + sizeof(RealScalar));
|
||||
Map<T> C(ptr1, rows, cols), D(ptr2, rows, cols);
|
||||
|
||||
C.setRandom();
|
||||
D.setRandom();
|
||||
for (Index r = 0; r < realViewRows; r++) {
|
||||
for (Index c = 0; c < realViewCols; c++) {
|
||||
C.realView().coeffRef(r, c) = D.realView().coeff(r, c);
|
||||
}
|
||||
}
|
||||
VERIFY_IS_CWISE_EQUAL(C, D);
|
||||
|
||||
C = A;
|
||||
C.realView() /= alpha;
|
||||
VERIFY_IS_APPROX(B, C);
|
||||
|
||||
for (Index c = 0; c < realViewCols - 1; c++) {
|
||||
B.realView().row(0).coeffRef(realViewCols - 1 - c) = C.realView().row(0).coeff(c + 1);
|
||||
}
|
||||
D.realView().row(0).tail(realViewCols - 1) = C.realView().row(0).tail(realViewCols - 1).reverse();
|
||||
VERIFY_IS_CWISE_EQUAL(B.realView().row(0).tail(realViewCols - 1), D.realView().row(0).tail(realViewCols - 1));
|
||||
|
||||
for (Index r = 0; r < realViewRows - 1; r++) {
|
||||
B.realView().col(0).coeffRef(realViewRows - 1 - r) = C.realView().col(0).coeff(r + 1);
|
||||
}
|
||||
D.realView().col(0).tail(realViewRows - 1) = C.realView().col(0).tail(realViewRows - 1).reverse();
|
||||
VERIFY_IS_CWISE_EQUAL(B.realView().col(0).tail(realViewRows - 1), D.realView().col(0).tail(realViewRows - 1));
|
||||
}
|
||||
|
||||
template <typename ComplexScalar, bool Enable = internal::packet_traits<ComplexScalar>::Vectorizable>
|
||||
struct test_edge_cases_impl {
|
||||
static void run() {
|
||||
using namespace internal;
|
||||
using RealScalar = typename NumTraits<ComplexScalar>::Real;
|
||||
using ComplexPacket = typename packet_traits<ComplexScalar>::type;
|
||||
using RealPacket = typename unpacket_traits<ComplexPacket>::as_real;
|
||||
constexpr int ComplexSize = unpacket_traits<ComplexPacket>::size;
|
||||
constexpr int RealSize = 2 * ComplexSize;
|
||||
VectorX<ComplexScalar> a_data(2 * ComplexSize);
|
||||
Map<const VectorX<RealScalar>> a_data_asreal(reinterpret_cast<const RealScalar*>(a_data.data()), 2 * a_data.size());
|
||||
VectorX<RealScalar> b_data(RealSize);
|
||||
|
||||
a_data.setRandom();
|
||||
evaluator<RealView<VectorX<ComplexScalar>>> eval(a_data.realView());
|
||||
|
||||
for (Index offset = 0; offset < RealSize; offset++) {
|
||||
for (Index begin = 0; offset + begin < RealSize; begin++) {
|
||||
for (Index count = 0; begin + count < RealSize; count++) {
|
||||
b_data.setRandom();
|
||||
RealPacket res = eval.packetSegment<Unaligned, RealPacket>(offset, begin, count);
|
||||
pstoreSegment(b_data.data(), res, begin, count);
|
||||
VERIFY_IS_CWISE_EQUAL(a_data_asreal.segment(offset + begin, count), b_data.segment(begin, count));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ComplexScalar>
|
||||
struct test_edge_cases_impl<ComplexScalar, false> {
|
||||
static void run() {}
|
||||
};
|
||||
|
||||
template <typename ComplexScalar>
|
||||
void test_edge_cases(const ComplexScalar&) {
|
||||
test_edge_cases_impl<ComplexScalar>::run();
|
||||
}
|
||||
|
||||
template <typename Scalar, int Rows, int Cols, int MaxRows = Rows, int MaxCols = Cols>
|
||||
void test_realview_driver() {
|
||||
void test_realview_readonly() {
|
||||
// if Rows == 1, don't test ColMajor as it is not a valid array
|
||||
using ColMajorMatrixType = Matrix<Scalar, Rows, Cols, Rows == 1 ? RowMajor : ColMajor, MaxRows, MaxCols>;
|
||||
using ColMajorArrayType = Array<Scalar, Rows, Cols, Rows == 1 ? RowMajor : ColMajor, MaxRows, MaxCols>;
|
||||
// if Cols == 1, don't test RowMajor as it is not a valid array
|
||||
using RowMajorMatrixType = Matrix<Scalar, Rows, Cols, Cols == 1 ? ColMajor : RowMajor, MaxRows, MaxCols>;
|
||||
using RowMajorArrayType = Array<Scalar, Rows, Cols, Cols == 1 ? ColMajor : RowMajor, MaxRows, MaxCols>;
|
||||
test_realview_readonly(ColMajorMatrixType());
|
||||
test_realview_readonly(ColMajorArrayType());
|
||||
test_realview_readonly(RowMajorMatrixType());
|
||||
test_realview_readonly(RowMajorArrayType());
|
||||
}
|
||||
|
||||
template <typename Scalar, int Rows, int Cols, int MaxRows = Rows, int MaxCols = Cols>
|
||||
void test_realview_readwrite() {
|
||||
// if Rows == 1, don't test ColMajor as it is not a valid array
|
||||
using ColMajorMatrixType = Matrix<Scalar, Rows, Cols, Rows == 1 ? RowMajor : ColMajor, MaxRows, MaxCols>;
|
||||
using ColMajorArrayType = Array<Scalar, Rows, Cols, Rows == 1 ? RowMajor : ColMajor, MaxRows, MaxCols>;
|
||||
@@ -85,26 +225,29 @@ void test_realview_driver() {
|
||||
}
|
||||
|
||||
template <int Rows, int Cols, int MaxRows = Rows, int MaxCols = Cols>
|
||||
void test_realview_driver_complex() {
|
||||
test_realview_driver<float, Rows, Cols, MaxRows, MaxCols>();
|
||||
test_realview_driver<std::complex<float>, Rows, Cols, MaxRows, MaxCols>();
|
||||
test_realview_driver<double, Rows, Cols, MaxRows, MaxCols>();
|
||||
test_realview_driver<std::complex<double>, Rows, Cols, MaxRows, MaxCols>();
|
||||
test_realview_driver<long double, Rows, Cols, MaxRows, MaxCols>();
|
||||
test_realview_driver<std::complex<long double>, Rows, Cols, MaxRows, MaxCols>();
|
||||
void test_realview() {
|
||||
test_realview_readwrite<float, Rows, Cols, MaxRows, MaxCols>();
|
||||
test_realview_readwrite<std::complex<float>, Rows, Cols, MaxRows, MaxCols>();
|
||||
test_realview_readwrite<double, Rows, Cols, MaxRows, MaxCols>();
|
||||
test_realview_readwrite<std::complex<double>, Rows, Cols, MaxRows, MaxCols>();
|
||||
test_realview_readwrite<long double, Rows, Cols, MaxRows, MaxCols>();
|
||||
test_realview_readwrite<std::complex<long double>, Rows, Cols, MaxRows, MaxCols>();
|
||||
test_realview_readonly<TestComplex, Rows, Cols, MaxRows, MaxCols>();
|
||||
}
|
||||
|
||||
EIGEN_DECLARE_TEST(realview) {
|
||||
for (int i = 0; i < g_repeat; i++) {
|
||||
CALL_SUBTEST_1((test_realview_driver_complex<Dynamic, Dynamic, Dynamic, Dynamic>()));
|
||||
CALL_SUBTEST_2((test_realview_driver_complex<Dynamic, Dynamic, 17, Dynamic>()));
|
||||
CALL_SUBTEST_3((test_realview_driver_complex<Dynamic, Dynamic, Dynamic, 19>()));
|
||||
CALL_SUBTEST_4((test_realview_driver_complex<Dynamic, Dynamic, 17, 19>()));
|
||||
CALL_SUBTEST_5((test_realview_driver_complex<17, Dynamic, 17, Dynamic>()));
|
||||
CALL_SUBTEST_6((test_realview_driver_complex<Dynamic, 19, Dynamic, 19>()));
|
||||
CALL_SUBTEST_7((test_realview_driver_complex<17, 19, 17, 19>()));
|
||||
CALL_SUBTEST_8((test_realview_driver_complex<Dynamic, 1>()));
|
||||
CALL_SUBTEST_9((test_realview_driver_complex<1, Dynamic>()));
|
||||
CALL_SUBTEST_10((test_realview_driver_complex<1, 1>()));
|
||||
CALL_SUBTEST_1((test_realview<Dynamic, Dynamic, Dynamic, Dynamic>()));
|
||||
CALL_SUBTEST_2((test_realview<Dynamic, Dynamic, 17, Dynamic>()));
|
||||
CALL_SUBTEST_3((test_realview<Dynamic, Dynamic, Dynamic, 19>()));
|
||||
CALL_SUBTEST_4((test_realview<Dynamic, Dynamic, 17, 19>()));
|
||||
CALL_SUBTEST_5((test_realview<17, Dynamic, 17, Dynamic>()));
|
||||
CALL_SUBTEST_6((test_realview<Dynamic, 19, Dynamic, 19>()));
|
||||
CALL_SUBTEST_7((test_realview<17, 19, 17, 19>()));
|
||||
CALL_SUBTEST_8((test_realview<Dynamic, 1>()));
|
||||
CALL_SUBTEST_9((test_realview<1, Dynamic>()));
|
||||
CALL_SUBTEST_10((test_realview<1, 1>()));
|
||||
CALL_SUBTEST_11(test_edge_cases(std::complex<float>()));
|
||||
CALL_SUBTEST_12(test_edge_cases(std::complex<double>()));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user