Enable equality comparisons on GPU.

Since `std::equal_to::operator()` is not a device function, it
fails on GPU.  On my device, I seem to get a silent crash in the
kernel (no reported error, but the kernel does not complete).

Replacing this with a portable version enables comparisons on device.

Addresses #2292 - would need to be cherry-picked.  The 3.3 branch
also requires adding `EIGEN_DEVICE_FUNC` in `BooleanRedux.h` to get
fully working.


(cherry picked from commit 7880f10526)
This commit is contained in:
Antonio Sanchez
2021-07-20 13:53:41 -07:00
committed by Rasmus Munk Larsen
parent 7adc1545b4
commit 3dc42eeaec
3 changed files with 41 additions and 12 deletions

View File

@@ -39,10 +39,10 @@ cwiseProduct(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived> &other) const
*/
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
inline const CwiseBinaryOp<std::equal_to<Scalar>, const Derived, const OtherDerived>
inline const CwiseBinaryOp<numext::equal_to<Scalar>, const Derived, const OtherDerived>
cwiseEqual(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived> &other) const
{
return CwiseBinaryOp<std::equal_to<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
return CwiseBinaryOp<numext::equal_to<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
}
/** \returns an expression of the coefficient-wise != operator of *this and \a other
@@ -59,10 +59,10 @@ cwiseEqual(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived> &other) const
*/
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
inline const CwiseBinaryOp<std::not_equal_to<Scalar>, const Derived, const OtherDerived>
inline const CwiseBinaryOp<numext::not_equal_to<Scalar>, const Derived, const OtherDerived>
cwiseNotEqual(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived> &other) const
{
return CwiseBinaryOp<std::not_equal_to<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
return CwiseBinaryOp<numext::not_equal_to<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
}
/** \returns an expression of the coefficient-wise min of *this and \a other