diff --git a/Eigen/src/SparseCore/CompressedStorage.h b/Eigen/src/SparseCore/CompressedStorage.h index 696f29d02..1f7aeb59c 100644 --- a/Eigen/src/SparseCore/CompressedStorage.h +++ b/Eigen/src/SparseCore/CompressedStorage.h @@ -225,22 +225,6 @@ class CompressedStorage } } - void prune(const Scalar& reference, const RealScalar& epsilon = NumTraits::dummy_precision()) - { - Index k = 0; - Index n = size(); - for (Index i=0; i::dummy_precision()) + Index prune(const Scalar& reference, const RealScalar& epsilon = NumTraits::dummy_precision()) { + return prune([&](const Scalar& val){ return !internal::isMuchSmallerThan(val, reference, epsilon); }); + } + + /** + * \brief Prunes the entries of the vector based on a `predicate` + * \tparam F Type of the predicate. + * \param keep_predicate The predicate that is used to test whether a value should be kept. A callable that + * gets passed om a `Scalar` value and returns a boolean. If the predicate returns true, the value is kept. + * \return The new number of structural non-zeros. + */ + template + Index prune(F&& keep_predicate) { - m_data.prune(reference,epsilon); + Index k = 0; + Index n = m_data.size(); + for (Index i = 0; i < n; ++i) + { + if (keep_predicate(m_data.value(i))) + { + m_data.value(k) = std::move(m_data.value(i)); + m_data.index(k) = m_data.index(i); + ++k; + } + } + m_data.resize(k); + return k; } /** Resizes the sparse vector to \a rows x \a cols diff --git a/test/sparse_vector.cpp b/test/sparse_vector.cpp index 35129278b..7bd57cdbe 100644 --- a/test/sparse_vector.cpp +++ b/test/sparse_vector.cpp @@ -111,7 +111,7 @@ template void sparse_vector(int rows, int // check copy to dense vector with transpose refV3.resize(0); VERIFY_IS_APPROX(refV3 = v1.transpose(),v1.toDense()); - VERIFY_IS_APPROX(DenseVector(v1),v1.toDense()); + VERIFY_IS_APPROX(DenseVector(v1),v1.toDense()); // test conservative resize { @@ -144,6 +144,31 @@ template void sparse_vector(int rows, int } } +void test_pruning() { + using SparseVectorType = SparseVector; + + SparseVectorType vec; + auto init_vec = [&](){; + vec.resize(10); + vec.insert(3) = 0.1; + vec.insert(5) = 1.0; + vec.insert(8) = -0.1; + vec.insert(9) = -0.2; + }; + init_vec(); + + VERIFY_IS_EQUAL(vec.nonZeros(), 4); + VERIFY_IS_EQUAL(vec.prune(0.1, 1.0), 2); + VERIFY_IS_EQUAL(vec.nonZeros(), 2); + VERIFY_IS_EQUAL(vec.coeff(5), 1.0); + VERIFY_IS_EQUAL(vec.coeff(9), -0.2); + + init_vec(); + VERIFY_IS_EQUAL(vec.prune([](double v) { return v >= 0; }), 2); + VERIFY_IS_EQUAL(vec.nonZeros(), 2); + VERIFY_IS_EQUAL(vec.coeff(3), 0.1); + VERIFY_IS_EQUAL(vec.coeff(5), 1.0); +} EIGEN_DECLARE_TEST(sparse_vector) { @@ -159,5 +184,7 @@ EIGEN_DECLARE_TEST(sparse_vector) CALL_SUBTEST_1(( sparse_vector(r, c) )); CALL_SUBTEST_1(( sparse_vector(r, c) )); } + + CALL_SUBTEST_1(test_pruning()); }