From ddce1d7d12076d13bac1c517609ca6b638d071f4 Mon Sep 17 00:00:00 2001 From: Artem Bishev Date: Thu, 7 Aug 2025 16:58:22 +0000 Subject: [PATCH] Fixes #2952 --- Eigen/src/Core/VectorwiseOp.h | 36 ++++++++++++++++++++++++++++++----- test/vectorwiseop.cpp | 26 +++++++++++++++++-------- 2 files changed, 49 insertions(+), 13 deletions(-) diff --git a/Eigen/src/Core/VectorwiseOp.h b/Eigen/src/Core/VectorwiseOp.h index 9ccbf7d76..9e34d8c99 100644 --- a/Eigen/src/Core/VectorwiseOp.h +++ b/Eigen/src/Core/VectorwiseOp.h @@ -146,6 +146,22 @@ struct member_redux { const BinaryOp& binaryFunc() const { return m_functor; } const BinaryOp m_functor; }; + +template +struct scalar_replace_zero_with_one_op { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const Scalar& x) const { + return numext::is_exactly_zero(x) ? Scalar(1) : x; + } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const { + return pselect(pcmp_eq(x, pzero(x)), pset1(Scalar(1)), x); + } +}; +template +struct functor_traits> { + enum { Cost = 1, PacketAccess = packet_traits::HasCmp }; +}; + } // namespace internal /** \class VectorwiseOp @@ -624,18 +640,28 @@ class VectorwiseOp { return m_matrix / extendedTo(other.derived()); } + using Normalized_NonzeroNormType = + CwiseUnaryOp, const NormReturnType>; + using NormalizedReturnType = CwiseBinaryOp, const ExpressionTypeNestedCleaned, + const typename OppositeExtendedType::Type>; + /** \returns an expression where each column (or row) of the referenced matrix are normalized. * The referenced matrix is \b not modified. + * + * \warning If the input columns (or rows) are too small (i.e., their norm equals to 0), they remain unchanged in the + * resulting expression. + * * \sa MatrixBase::normalized(), normalize() */ - EIGEN_DEVICE_FUNC CwiseBinaryOp, const ExpressionTypeNestedCleaned, - const typename OppositeExtendedType::Type> - normalized() const { - return m_matrix.cwiseQuotient(extendedToOpposite(this->norm())); + EIGEN_DEVICE_FUNC NormalizedReturnType normalized() const { + return m_matrix.cwiseQuotient(extendedToOpposite(Normalized_NonzeroNormType(this->norm()))); } /** Normalize in-place each row or columns of the referenced matrix. - * \sa MatrixBase::normalize(), normalized() + * + * \warning If the input columns (or rows) are too small (i.e., their norm equals to 0), they are left unchanged. + * + * \sa MatrixBase::normalized(), normalize() */ EIGEN_DEVICE_FUNC void normalize() { m_matrix = this->normalized(); } diff --git a/test/vectorwiseop.cpp b/test/vectorwiseop.cpp index 6d0e5cbf8..d037bb49b 100644 --- a/test/vectorwiseop.cpp +++ b/test/vectorwiseop.cpp @@ -114,6 +114,8 @@ void vectorwiseop_matrix(const MatrixType& m) { RealColVectorType rcres; RealRowVectorType rrres; + Scalar small_scalar = (std::numeric_limits::min)(); + // test broadcast assignment m2 = m1; m2.colwise() = colvec; @@ -171,18 +173,26 @@ void vectorwiseop_matrix(const MatrixType& m) { VERIFY_IS_APPROX(m1.cwiseAbs().colwise().sum().x(), m1.col(0).cwiseAbs().sum()); // test normalized - m2 = m1.colwise().normalized(); - VERIFY_IS_APPROX(m2.col(c), m1.col(c).normalized()); - m2 = m1.rowwise().normalized(); - VERIFY_IS_APPROX(m2.row(r), m1.row(r).normalized()); + m2 = m1; + m2.col(c).fill(small_scalar); + m3 = m2.colwise().normalized(); + for (Index k = 0; k < cols; ++k) VERIFY_IS_APPROX(m3.col(k), m2.col(k).normalized()); + m2 = m1; + m2.row(r).setZero(); + m3 = m2.rowwise().normalized(); + for (Index k = 0; k < rows; ++k) VERIFY_IS_APPROX(m3.row(k), m2.row(k).normalized()); // test normalize m2 = m1; - m2.colwise().normalize(); - VERIFY_IS_APPROX(m2.col(c), m1.col(c).normalized()); + m2.col(c).setZero(); + m3 = m2; + m3.colwise().normalize(); + for (Index k = 0; k < cols; ++k) VERIFY_IS_APPROX(m3.col(k), m2.col(k).normalized()); m2 = m1; - m2.rowwise().normalize(); - VERIFY_IS_APPROX(m2.row(r), m1.row(r).normalized()); + m2.row(r).fill(small_scalar); + m3 = m2; + m3.rowwise().normalize(); + for (Index k = 0; k < rows; ++k) VERIFY_IS_APPROX(m3.row(k), m2.row(k).normalized()); // test with partial reduction of products Matrix m1m1 = m1 * m1.transpose();