Add TernaryFunctors and the betainc SpecialFunction.

TernaryFunctors and their executors allow operations on 3-tuples of inputs.
API fully implemented for Arrays and Tensors based on binary functors.

Ported the cephes betainc function (regularized incomplete beta
integral) to Eigen, with support for CPU and GPU, floats, doubles, and
half types.

Added unit tests in array.cpp and cxx11_tensor_cuda.cu


Collapsed revision
* Merged helper methods for betainc across floats and doubles.
* Added TensorGlobalFunctions with betainc().  Removed betainc() from TensorBase.
* Clean up CwiseTernaryOp checks, change igamma_helper to cephes_helper.
* betainc: merge incbcf and incbd into incbeta_cfe.  and more cleanup.
* Update TernaryOp and SpecialFunctions (betainc) based on review comments.
This commit is contained in:
Eugene Brevdo
2016-06-02 17:04:19 -07:00
parent 02db4e1a82
commit 39baff850c
21 changed files with 1389 additions and 21 deletions

View File

@@ -592,16 +592,123 @@ template<typename ArrayType> void array_special_functions()
ref << 0.644934066848, 0.394934066848, 0.0399946696496, nan, 293.334565435, 0.445487887616, -2.47810300902e-07, -8.29668781082e-09, -0.434562276666, 0.567742190178, -0.0108615497927;
CALL_SUBTEST( verify_component_wise(ref, ref); );
if(sizeof(RealScalar)>=64) {
// CALL_SUBTEST( res = x.polygamma(n); verify_component_wise(res, ref); );
if(sizeof(RealScalar)>=8) { // double
// Reason for commented line: http://eigen.tuxfamily.org/bz/show_bug.cgi?id=1232
// CALL_SUBTEST( res = x.polygamma(n); verify_component_wise(res, ref); );
CALL_SUBTEST( res = polygamma(n,x); verify_component_wise(res, ref); );
}
else {
// CALL_SUBTEST( res = x.polygamma(n); verify_component_wise(res.head(8), ref.head(8)); );
// CALL_SUBTEST( res = x.polygamma(n); verify_component_wise(res.head(8), ref.head(8)); );
CALL_SUBTEST( res = polygamma(n,x); verify_component_wise(res.head(8), ref.head(8)); );
}
}
#endif
#if EIGEN_HAS_C99_MATH
{
// Inputs and ground truth generated with scipy via:
// a = np.logspace(-3, 3, 5) - 1e-3
// b = np.logspace(-3, 3, 5) - 1e-3
// x = np.linspace(-0.1, 1.1, 5)
// (full_a, full_b, full_x) = np.vectorize(lambda a, b, x: (a, b, x))(*np.ix_(a, b, x))
// full_a = full_a.flatten().tolist() # same for full_b, full_x
// v = scipy.special.betainc(full_a, full_b, full_x).flatten().tolist()
//
// Note in Eigen, we call betainc with arguments in the order (x, a, b).
ArrayType a(125);
ArrayType b(125);
ArrayType x(125);
ArrayType v(125);
ArrayType res(125);
a << 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999,
0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999,
0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999,
31.62177660168379, 31.62177660168379, 31.62177660168379,
31.62177660168379, 31.62177660168379, 31.62177660168379,
31.62177660168379, 31.62177660168379, 31.62177660168379,
31.62177660168379, 31.62177660168379, 31.62177660168379,
31.62177660168379, 31.62177660168379, 31.62177660168379,
31.62177660168379, 31.62177660168379, 31.62177660168379,
31.62177660168379, 31.62177660168379, 31.62177660168379,
31.62177660168379, 31.62177660168379, 31.62177660168379,
31.62177660168379, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999,
999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999,
999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999,
999.999, 999.999, 999.999;
b << 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379, 0.03062277660168379,
0.03062277660168379, 0.03062277660168379, 0.03062277660168379, 0.999,
0.999, 0.999, 0.999, 0.999, 31.62177660168379, 31.62177660168379,
31.62177660168379, 31.62177660168379, 31.62177660168379, 999.999,
999.999, 999.999, 999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0,
0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
0.03062277660168379, 0.03062277660168379, 0.999, 0.999, 0.999, 0.999,
0.999, 31.62177660168379, 31.62177660168379, 31.62177660168379,
31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999,
999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379,
0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999,
31.62177660168379, 31.62177660168379, 31.62177660168379,
31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999,
999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379,
0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999,
31.62177660168379, 31.62177660168379, 31.62177660168379,
31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999,
999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379,
0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999,
31.62177660168379, 31.62177660168379, 31.62177660168379,
31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999,
999.999, 999.999;
x << -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5,
0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2,
0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1,
0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1,
-0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8,
1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5,
0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2,
0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1,
0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5,
0.8, 1.1;
v << nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
nan, nan, nan, 0.47972119876364683, 0.5, 0.5202788012363533, nan, nan,
0.9518683957740043, 0.9789663010413743, 0.9931729188073435, nan, nan,
0.999995949033062, 0.9999999999993698, 0.9999999999999999, nan, nan,
0.9999999999999999, 0.9999999999999999, 0.9999999999999999, nan, nan,
nan, nan, nan, nan, nan, 0.006827081192655869, 0.0210336989586256,
0.04813160422599567, nan, nan, 0.20014344256217678, 0.5000000000000001,
0.7998565574378232, nan, nan, 0.9991401428435834, 0.999999999698403,
0.9999999999999999, nan, nan, 0.9999999999999999, 0.9999999999999999,
0.9999999999999999, nan, nan, nan, nan, nan, nan, nan,
1.0646600232370887e-25, 6.301722877826246e-13, 4.050966937974938e-06,
nan, nan, 7.864342668429763e-23, 3.015969667594166e-10,
0.0008598571564165444, nan, nan, 6.031987710123844e-08,
0.5000000000000007, 0.9999999396801229, nan, nan, 0.9999999999999999,
0.9999999999999999, 0.9999999999999999, nan, nan, nan, nan, nan, nan,
nan, 0.0, 7.029920380986636e-306, 2.2450728208591345e-101, nan, nan,
0.0, 9.275871147869727e-302, 1.2232913026152827e-97, nan, nan, 0.0,
3.0891393081932924e-252, 2.9303043666183996e-60, nan, nan,
2.248913486879199e-196, 0.5000000000004947, 0.9999999999999999, nan;
CALL_SUBTEST(res = betainc(a, b, x);
verify_component_wise(res, v););
}
#endif
}
void test_array()