diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h index 72b09d998..564eb97dc 100644 --- a/Eigen/src/Core/arch/AVX512/PacketMath.h +++ b/Eigen/src/Core/arch/AVX512/PacketMath.h @@ -283,9 +283,27 @@ EIGEN_STRONG_INLINE Packet16f cat256(Packet8f a, Packet8f b) { } #endif +template<> EIGEN_STRONG_INLINE Packet16f pcmp_le(const Packet16f& a, const Packet16f& b) { + __m256 lo = pcmp_le(extract256<0>(a), extract256<0>(b)); + __m256 hi = pcmp_le(extract256<1>(a), extract256<1>(b)); + return cat256(lo, hi); +} + +template<> EIGEN_STRONG_INLINE Packet16f pcmp_lt(const Packet16f& a, const Packet16f& b) { + __m256 lo = pcmp_lt(extract256<0>(a), extract256<0>(b)); + __m256 hi = pcmp_lt(extract256<1>(a), extract256<1>(b)); + return cat256(lo, hi); +} + +template<> EIGEN_STRONG_INLINE Packet16f pcmp_eq(const Packet16f& a, const Packet16f& b) { + __m256 lo = pcmp_eq(extract256<0>(a), extract256<0>(b)); + __m256 hi = pcmp_eq(extract256<1>(a), extract256<1>(b)); + return cat256(lo, hi); +} + template<> EIGEN_STRONG_INLINE Packet16f pcmp_lt_or_nan(const Packet16f& a, const Packet16f& b) { - __m256 lo = _mm256_cmp_ps(extract256<0>(a), extract256<0>(b), _CMP_NGE_UQ); - __m256 hi = _mm256_cmp_ps(extract256<1>(a), extract256<1>(b), _CMP_NGE_UQ); + __m256 lo = pcmp_lt_or_nan(extract256<0>(a), extract256<0>(b)); + __m256 hi = pcmp_lt_or_nan(extract256<1>(a), extract256<1>(b)); return cat256(lo, hi); }