diff --git a/Eigen/src/Core/Dot.h b/Eigen/src/Core/Dot.h index a173306b4..059527c85 100644 --- a/Eigen/src/Core/Dot.h +++ b/Eigen/src/Core/Dot.h @@ -20,7 +20,10 @@ namespace internal { template ::Scalar> struct squared_norm_impl { using Real = typename NumTraits::Real; - static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Real run(const Derived& a) { return a.realView().cwiseAbs2().sum(); } + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Real run(const Derived& a) { + Scalar result = a.unaryExpr(squared_norm_functor()).sum(); + return numext::real(result) + numext::imag(result); + } }; template diff --git a/Eigen/src/Core/functors/UnaryFunctors.h b/Eigen/src/Core/functors/UnaryFunctors.h index d7fc7bb4d..202995ff0 100644 --- a/Eigen/src/Core/functors/UnaryFunctors.h +++ b/Eigen/src/Core/functors/UnaryFunctors.h @@ -106,6 +106,26 @@ struct functor_traits> { }; }; +template ::IsComplex> +struct squared_norm_functor { + typedef Scalar result_type; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a) const { + return Scalar(numext::real(a) * numext::real(a), numext::imag(a) * numext::imag(a)); + } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a) const { + return Packet(pmul(a.v, a.v)); + } +}; +template +struct squared_norm_functor : scalar_abs2_op {}; + +template +struct functor_traits> { + using Real = typename NumTraits::Real; + enum { Cost = NumTraits::MulCost, PacketAccess = packet_traits::HasMul }; +}; + /** \internal * \brief Template functor to compute the conjugate of a complex value *