diff --git a/Eigen/src/Core/GenericPacketMath.h b/Eigen/src/Core/GenericPacketMath.h index 075d18aa6..b5eb1cf99 100644 --- a/Eigen/src/Core/GenericPacketMath.h +++ b/Eigen/src/Core/GenericPacketMath.h @@ -216,12 +216,12 @@ template EIGEN_DEVICE_FUNC inline Packet pdiv(const Packet& a, const Packet& b) { return a/b; } /** \internal \returns the min of \a a and \a b (coeff-wise). -Equivalent to std::min(a, b), so if either a or b is NaN, a is returned. */ + If \a a or \b b is NaN, the return value is implementation defined. */ template EIGEN_DEVICE_FUNC inline Packet pmin(const Packet& a, const Packet& b) { return numext::mini(a, b); } /** \internal \returns the max of \a a and \a b (coeff-wise) -Equivalent to std::max(a, b), so if either a or b is NaN, a is returned.*/ + If \a a or \b b is NaN, the return value is implementation defined. */ template EIGEN_DEVICE_FUNC inline Packet pmax(const Packet& a, const Packet& b) { return numext::maxi(a, b); } @@ -635,23 +635,54 @@ Packet print(const Packet& a) { using numext::rint; return rint(a); } template EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pceil(const Packet& a) { using numext::ceil; return ceil(a); } -/** \internal \returns the min of \a a and \a b (coeff-wise) - Equivalent to std::fmin(a, b). Only if both a and b are NaN is NaN returned. -*/ + +/** \internal \returns the max of \a a and \a b (coeff-wise) + If both \a a and \a b are NaN, NaN is returned. + Equivalent to std::fmax(a, b). */ +template EIGEN_DEVICE_FUNC inline Packet +pfmax(const Packet& a, const Packet& b) { + Packet not_nan_mask_a = pcmp_eq(a, a); + Packet not_nan_mask_b = pcmp_eq(b, b); + return pselect(not_nan_mask_a, + pselect(not_nan_mask_b, pmax(a, b), a), + b); +} + +/** \internal \returns the min of \a a and \a b (coeff-wise) + If both \a a and \a b are NaN, NaN is returned. + Equivalent to std::fmin(a, b). */ template EIGEN_DEVICE_FUNC inline Packet pfmin(const Packet& a, const Packet& b) { - Packet not_nan_mask = pcmp_eq(a, a); - return pselect(not_nan_mask, pmin(a, b), b); + Packet not_nan_mask_a = pcmp_eq(a, a); + Packet not_nan_mask_b = pcmp_eq(b, b); + return pselect(not_nan_mask_a, + pselect(not_nan_mask_b, pmin(a, b), a), + b); } /** \internal \returns the max of \a a and \a b (coeff-wise) - Equivalent to std::fmax(a, b). Only if both a and b are NaN is NaN returned.*/ + If either \a a or \a b are NaN, NaN is returned. */ template EIGEN_DEVICE_FUNC inline Packet -pfmax(const Packet& a, const Packet& b) { - Packet not_nan_mask = pcmp_eq(a, a); - return pselect(not_nan_mask, pmax(a, b), b); +pfmax_nan(const Packet& a, const Packet& b) { + Packet not_nan_mask_a = pcmp_eq(a, a); + Packet not_nan_mask_b = pcmp_eq(b, b); + return pselect(not_nan_mask_a, + pselect(not_nan_mask_b, pmax(a, b), b), + a); } +/** \internal \returns the min of \a a and \a b (coeff-wise) + If either \a a or \a b are NaN, NaN is returned. */ +template EIGEN_DEVICE_FUNC inline Packet +pfmin_nan(const Packet& a, const Packet& b) { + Packet not_nan_mask_a = pcmp_eq(a, a); + Packet not_nan_mask_b = pcmp_eq(b, b); + return pselect(not_nan_mask_a, + pselect(not_nan_mask_b, pmin(a, b), b), + a); +} + + /*************************************************************************** * The following functions might not have to be overwritten for vectorized types ***************************************************************************/ diff --git a/Eigen/src/Core/functors/BinaryFunctors.h b/Eigen/src/Core/functors/BinaryFunctors.h index d8b7b1eba..55650bb8d 100644 --- a/Eigen/src/Core/functors/BinaryFunctors.h +++ b/Eigen/src/Core/functors/BinaryFunctors.h @@ -134,21 +134,39 @@ struct functor_traits > { * * \sa class CwiseBinaryOp, MatrixBase::cwiseMin, class VectorwiseOp, MatrixBase::minCoeff() */ -template +template struct scalar_min_op : binary_op_base { typedef typename ScalarBinaryOpTraits::ReturnType result_type; EIGEN_EMPTY_STRUCT_CTOR(scalar_min_op) - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator() (const LhsScalar& a, const RhsScalar& b) const { return numext::mini(a, b); } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator() (const LhsScalar& a, const RhsScalar& b) const { + if (NaNPropagation == PropagateFast) { + return numext::mini(a, b); + } else if (NaNPropagation == PropagateNumbers) { + return internal::pfmin(a,b); + } else if (NaNPropagation == PropagateNaN) { + return internal::pfmin_nan(a,b); + } + } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const - { return internal::pmin(a,b); } + { + if (NaNPropagation == PropagateFast) { + return internal::pmin(a,b); + } else if (NaNPropagation == PropagateNumbers) { + return internal::pfmin(a,b); + } else if (NaNPropagation == PropagateNaN) { + return internal::pfmin_nan(a,b); + } + } + // TODO(rmlarsen): Handle all NaN propagation semantics reductions. template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type predux(const Packet& a) const { return internal::predux_min(a); } }; -template -struct functor_traits > { + +template +struct functor_traits > { enum { Cost = (NumTraits::AddCost+NumTraits::AddCost)/2, PacketAccess = internal::is_same::value && packet_traits::HasMin @@ -160,21 +178,39 @@ struct functor_traits > { * * \sa class CwiseBinaryOp, MatrixBase::cwiseMax, class VectorwiseOp, MatrixBase::maxCoeff() */ -template -struct scalar_max_op : binary_op_base +template +struct scalar_max_op : binary_op_base { typedef typename ScalarBinaryOpTraits::ReturnType result_type; EIGEN_EMPTY_STRUCT_CTOR(scalar_max_op) - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator() (const LhsScalar& a, const RhsScalar& b) const { return numext::maxi(a, b); } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator() (const LhsScalar& a, const RhsScalar& b) const { + if (NaNPropagation == PropagateFast) { + return numext::maxi(a, b); + } else if (NaNPropagation == PropagateNumbers) { + return internal::pfmax(a,b); + } else if (NaNPropagation == PropagateNaN) { + return internal::pfmax_nan(a,b); + } + } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const - { return internal::pmax(a,b); } + { + if (NaNPropagation == PropagateFast) { + return internal::pmax(a,b); + } else if (NaNPropagation == PropagateNumbers) { + return internal::pfmax(a,b); + } else if (NaNPropagation == PropagateNaN) { + return internal::pfmax_nan(a,b); + } + } + // TODO(rmlarsen): Handle all NaN propagation semantics reductions. template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type predux(const Packet& a) const { return internal::predux_max(a); } }; -template -struct functor_traits > { + +template +struct functor_traits > { enum { Cost = (NumTraits::AddCost+NumTraits::AddCost)/2, PacketAccess = internal::is_same::value && packet_traits::HasMax diff --git a/Eigen/src/Core/util/Constants.h b/Eigen/src/Core/util/Constants.h index 7ada82195..ad9af5727 100644 --- a/Eigen/src/Core/util/Constants.h +++ b/Eigen/src/Core/util/Constants.h @@ -328,12 +328,21 @@ enum StorageOptions { * Enum for specifying whether to apply or solve on the left or right. */ enum SideType { /** Apply transformation on the left. */ - OnTheLeft = 1, + OnTheLeft = 1, /** Apply transformation on the right. */ - OnTheRight = 2 + OnTheRight = 2 }; - +/** \ingroup enums + * Enum for specifying NaN-propagation behavior, e.g. for coeff-wise min/max. */ +enum NaNPropagationOptions { + /** Implementation defined behavior if NaNs are present. */ + PropagateFast = 0, + /** Always propagate NaNs. */ + PropagateNaN, + /** Always propagate not-NaNs. */ + PropagateNumbers +}; /* the following used to be written as: * diff --git a/Eigen/src/Core/util/ForwardDeclarations.h b/Eigen/src/Core/util/ForwardDeclarations.h index 208b96c9c..2f9cc4491 100644 --- a/Eigen/src/Core/util/ForwardDeclarations.h +++ b/Eigen/src/Core/util/ForwardDeclarations.h @@ -180,8 +180,8 @@ template struct scalar_sum_op; template struct scalar_difference_op; template struct scalar_conj_product_op; -template struct scalar_min_op; -template struct scalar_max_op; +template struct scalar_min_op; +template struct scalar_max_op; template struct scalar_opposite_op; template struct scalar_conjugate_op; template struct scalar_real_op; diff --git a/test/packetmath.cpp b/test/packetmath.cpp index dd3e5b41e..6cde7e87b 100644 --- a/test/packetmath.cpp +++ b/test/packetmath.cpp @@ -763,6 +763,20 @@ void packetmath_real::type> } +template +Scalar propagate_nan_max(const Scalar& a, const Scalar& b) { + if ((std::isnan)(a)) return a; + if ((std::isnan)(b)) return b; + return (std::max)(a,b); +} + +template +Scalar propagate_nan_min(const Scalar& a, const Scalar& b) { + if ((std::isnan)(a)) return a; + if ((std::isnan)(b)) return b; + return (std::min)(a,b); +} + template void packetmath_notcomplex() { typedef internal::packet_traits PacketTraits; @@ -829,12 +843,12 @@ void packetmath_notcomplex() { data1[i] = internal::random() ? std::numeric_limits::quiet_NaN() : Scalar(0); data1[i + PacketSize] = internal::random() ? std::numeric_limits::quiet_NaN() : Scalar(0); } - // Test NaN propagation for pmin and pmax. It should be equivalent to std::min. - CHECK_CWISE2_IF(PacketTraits::HasMin, (std::min), internal::pmin); - CHECK_CWISE2_IF(PacketTraits::HasMax, (std::max), internal::pmax); // Test NaN propagation for pfmin and pfmax. It should be equivalent to std::fmin. + // Note: NaN propagation is implementation defined for pmin/pmax, so we do not test it here. CHECK_CWISE2_IF(PacketTraits::HasMin, fmin, internal::pfmin); CHECK_CWISE2_IF(PacketTraits::HasMax, fmax, internal::pfmax); + CHECK_CWISE2_IF(PacketTraits::HasMin, propagate_nan_min, internal::pfmin_nan); + CHECK_CWISE2_IF(PacketTraits::HasMax, propagate_nan_max, internal::pfmax_nan); } template <> diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h index bb0969f49..ef332dd19 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h @@ -395,16 +395,18 @@ class TensorBase return unaryExpr(internal::scalar_mod_op(rhs)); } + template EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const TensorCwiseNullaryOp, const Derived> > + EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const TensorCwiseNullaryOp, const Derived> > cwiseMax(Scalar threshold) const { - return cwiseMax(constant(threshold)); + return cwiseMax(constant(threshold)); } + template EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const TensorCwiseNullaryOp, const Derived> > + EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const TensorCwiseNullaryOp, const Derived> > cwiseMin(Scalar threshold) const { - return cwiseMin(constant(threshold)); + return cwiseMin(constant(threshold)); } template @@ -472,16 +474,16 @@ class TensorBase return binaryExpr(other.derived(), internal::scalar_quotient_op()); } - template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - const TensorCwiseBinaryOp, const Derived, const OtherDerived> + template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorCwiseBinaryOp, const Derived, const OtherDerived> cwiseMax(const OtherDerived& other) const { - return binaryExpr(other.derived(), internal::scalar_max_op()); + return binaryExpr(other.derived(), internal::scalar_max_op()); } - template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - const TensorCwiseBinaryOp, const Derived, const OtherDerived> + template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorCwiseBinaryOp, const Derived, const OtherDerived> cwiseMin(const OtherDerived& other) const { - return binaryExpr(other.derived(), internal::scalar_min_op()); + return binaryExpr(other.derived(), internal::scalar_min_op()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE diff --git a/unsupported/test/cxx11_tensor_expr.cpp b/unsupported/test/cxx11_tensor_expr.cpp index b49663fe9..7fac3b4ed 100644 --- a/unsupported/test/cxx11_tensor_expr.cpp +++ b/unsupported/test/cxx11_tensor_expr.cpp @@ -303,40 +303,79 @@ template void test_minmax_nan_propagation_templ() { for (int size = 1; size < 17; ++size) { const Scalar kNan = std::numeric_limits::quiet_NaN(); + const Scalar kZero(0); Tensor vec_nan(size); Tensor vec_zero(size); - Tensor vec_res(size); vec_nan.setConstant(kNan); vec_zero.setZero(); - vec_res.setZero(); - // Test that we propagate NaNs in the tensor when applying the - // cwiseMax(scalar) operator, which is used for the Relu operator. - vec_res = vec_nan.cwiseMax(Scalar(0)); - for (int i = 0; i < size; ++i) { - VERIFY((numext::isnan)(vec_res(i))); - } + auto verify_all_nan = [&](const Tensor& v) { + for (int i = 0; i < size; ++i) { + VERIFY((numext::isnan)(v(i))); + } + }; - // Test that NaNs do not propagate if we reverse the arguments. - vec_res = vec_zero.cwiseMax(kNan); - for (int i = 0; i < size; ++i) { - VERIFY_IS_EQUAL(vec_res(i), Scalar(0)); - } + auto verify_all_zero = [&](const Tensor& v) { + for (int i = 0; i < size; ++i) { + VERIFY_IS_EQUAL(v(i), Scalar(0)); + } + }; - // Test that we propagate NaNs in the tensor when applying the - // cwiseMin(scalar) operator. - vec_res.setZero(); - vec_res = vec_nan.cwiseMin(Scalar(0)); - for (int i = 0; i < size; ++i) { - VERIFY((numext::isnan)(vec_res(i))); - } + // Test NaN propagating max. + // max(nan, nan) = nan + // max(nan, 0) = nan + // max(0, nan) = nan + // max(0, 0) = 0 + verify_all_nan(vec_nan.template cwiseMax(kNan)); + verify_all_nan(vec_nan.template cwiseMax(vec_nan)); + verify_all_nan(vec_nan.template cwiseMax(kZero)); + verify_all_nan(vec_nan.template cwiseMax(vec_zero)); + verify_all_nan(vec_zero.template cwiseMax(kNan)); + verify_all_nan(vec_zero.template cwiseMax(vec_nan)); + verify_all_zero(vec_zero.template cwiseMax(kZero)); + verify_all_zero(vec_zero.template cwiseMax(vec_zero)); + // Test number propagating max. + // max(nan, nan) = nan + // max(nan, 0) = 0 + // max(0, nan) = 0 + // max(0, 0) = 0 + verify_all_nan(vec_nan.template cwiseMax(kNan)); + verify_all_nan(vec_nan.template cwiseMax(vec_nan)); + verify_all_zero(vec_nan.template cwiseMax(kZero)); + verify_all_zero(vec_nan.template cwiseMax(vec_zero)); + verify_all_zero(vec_zero.template cwiseMax(kNan)); + verify_all_zero(vec_zero.template cwiseMax(vec_nan)); + verify_all_zero(vec_zero.template cwiseMax(kZero)); + verify_all_zero(vec_zero.template cwiseMax(vec_zero)); - // Test that NaNs do not propagate if we reverse the arguments. - vec_res = vec_zero.cwiseMin(kNan); - for (int i = 0; i < size; ++i) { - VERIFY_IS_EQUAL(vec_res(i), Scalar(0)); - } + // Test NaN propagating min. + // min(nan, nan) = nan + // min(nan, 0) = nan + // min(0, nan) = nan + // min(0, 0) = 0 + verify_all_nan(vec_nan.template cwiseMin(kNan)); + verify_all_nan(vec_nan.template cwiseMin(vec_nan)); + verify_all_nan(vec_nan.template cwiseMin(kZero)); + verify_all_nan(vec_nan.template cwiseMin(vec_zero)); + verify_all_nan(vec_zero.template cwiseMin(kNan)); + verify_all_nan(vec_zero.template cwiseMin(vec_nan)); + verify_all_zero(vec_zero.template cwiseMin(kZero)); + verify_all_zero(vec_zero.template cwiseMin(vec_zero)); + + // Test number propagating min. + // min(nan, nan) = nan + // min(nan, 0) = 0 + // min(0, nan) = 0 + // min(0, 0) = 0 + verify_all_nan(vec_nan.template cwiseMin(kNan)); + verify_all_nan(vec_nan.template cwiseMin(vec_nan)); + verify_all_zero(vec_nan.template cwiseMin(kZero)); + verify_all_zero(vec_nan.template cwiseMin(vec_zero)); + verify_all_zero(vec_zero.template cwiseMin(kNan)); + verify_all_zero(vec_zero.template cwiseMin(vec_nan)); + verify_all_zero(vec_zero.template cwiseMin(kZero)); + verify_all_zero(vec_zero.template cwiseMin(vec_zero)); } }