Add digamma for CPU + CUDA. Includes tests.

This commit is contained in:
Eugene Brevdo
2015-12-24 21:15:38 -08:00
parent bdcbc66a5c
commit f7362772e3
9 changed files with 447 additions and 62 deletions

View File

@@ -219,6 +219,7 @@ template<typename ArrayType> void array_real(const ArrayType& m)
VERIFY_IS_APPROX(m1.tanh(), tanh(m1));
#ifdef EIGEN_HAS_C99_MATH
VERIFY_IS_APPROX(m1.lgamma(), lgamma(m1));
VERIFY_IS_APPROX(m1.digamma(), digamma(m1));
VERIFY_IS_APPROX(m1.erf(), erf(m1));
VERIFY_IS_APPROX(m1.erfc(), erfc(m1));
#endif // EIGEN_HAS_C99_MATH
@@ -309,7 +310,20 @@ template<typename ArrayType> void array_real(const ArrayType& m)
s1 += Scalar(tiny);
m1 += ArrayType::Constant(rows,cols,Scalar(tiny));
VERIFY_IS_APPROX(s1/m1, s1 * m1.inverse());
// check special functions (comparing against numpy implementation)
if (!NumTraits<Scalar>::IsComplex) {
VERIFY_IS_APPROX(numext::digamma(Scalar(1)), RealScalar(-0.5772156649015329));
VERIFY_IS_APPROX(numext::digamma(Scalar(1.5)), RealScalar(0.03648997397857645));
VERIFY_IS_APPROX(numext::digamma(Scalar(4)), RealScalar(1.2561176684318));
VERIFY_IS_APPROX(numext::digamma(Scalar(-10.5)), RealScalar(2.398239129535781));
VERIFY_IS_APPROX(numext::digamma(Scalar(10000.5)), RealScalar(9.210340372392849));
VERIFY_IS_EQUAL(numext::digamma(Scalar(0)),
std::numeric_limits<RealScalar>::infinity());
VERIFY_IS_EQUAL(numext::digamma(Scalar(-1)),
std::numeric_limits<RealScalar>::infinity());
}
// check inplace transpose
m3 = m1;
m3.transposeInPlace();
@@ -336,8 +350,6 @@ template<typename ArrayType> void array_complex(const ArrayType& m)
Array<RealScalar, -1, -1> m3(rows, cols);
Scalar s1 = internal::random<Scalar>();
for (Index i = 0; i < m.rows(); ++i)
for (Index j = 0; j < m.cols(); ++j)
m2(i,j) = sqrt(m1(i,j));
@@ -410,6 +422,7 @@ template<typename ArrayType> void array_complex(const ArrayType& m)
VERIFY_IS_APPROX( m1.sign() * m1.abs(), m1);
// scalar by array division
Scalar s1 = internal::random<Scalar>();
const RealScalar tiny = sqrt(std::numeric_limits<RealScalar>::epsilon());
s1 += Scalar(tiny);
m1 += ArrayType::Constant(rows,cols,Scalar(tiny));