Fix scalar_logistic_function overflow for complex inputs.

This commit is contained in:
Antonio Sánchez
2023-12-05 18:21:04 +00:00
committed by Rasmus Munk Larsen
parent 9688081029
commit 3252ecc7a4
2 changed files with 28 additions and 8 deletions

View File

@@ -976,7 +976,14 @@ template<typename ArrayType> void array_complex(const ArrayType& m)
VERIFY_IS_APPROX(sinh(m1), 0.5*(exp(m1)-exp(-m1)));
VERIFY_IS_APPROX(cosh(m1), 0.5*(exp(m1)+exp(-m1)));
VERIFY_IS_APPROX(tanh(m1), (0.5*(exp(m1)-exp(-m1)))/(0.5*(exp(m1)+exp(-m1))));
VERIFY_IS_APPROX(logistic(m1), (1.0/(1.0 + exp(-m1))));
VERIFY_IS_APPROX(logistic(m1), (1.0 / (1.0 + exp(-m1))));
if (m1.size() > 0) {
// Complex exponential overflow edge-case.
Scalar old_m1_val = m1(0, 0);
m1(0, 0) = std::complex<RealScalar>(1000.0, 1000.0);
VERIFY_IS_APPROX(logistic(m1), (1.0 / (1.0 + exp(-m1))));
m1(0, 0) = old_m1_val; // Restore value for future tests.
}
for (Index i = 0; i < m.rows(); ++i)
for (Index j = 0; j < m.cols(); ++j)