mirror of
https://gitlab.com/libeigen/eigen.git
synced 2026-04-10 11:34:33 +08:00
add sparse sort inner vectors function
This commit is contained in:
committed by
Antonio Sánchez
parent
d194167149
commit
44fe539150
@@ -22,6 +22,9 @@ template<typename Derived>
|
||||
struct traits<SparseCompressedBase<Derived> > : traits<Derived>
|
||||
{};
|
||||
|
||||
template <typename Derived, class Comp, bool IsVector>
|
||||
struct inner_sort_impl;
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
/** \ingroup SparseCore_Module
|
||||
@@ -126,6 +129,40 @@ class SparseCompressedBase
|
||||
*
|
||||
* \sa valuePtr(), isCompressed() */
|
||||
Map<Array<Scalar,Dynamic,1> > coeffs() { eigen_assert(isCompressed()); return Array<Scalar,Dynamic,1>::Map(valuePtr(),nonZeros()); }
|
||||
|
||||
/** sorts the inner vectors in the range [begin,end) with respect to `Comp`
|
||||
* \sa innerIndicesAreSorted() */
|
||||
template <class Comp = std::less<>>
|
||||
inline void sortInnerIndices(Index begin, Index end) {
|
||||
eigen_assert(begin >= 0 && end <= derived().outerSize() && end >= begin);
|
||||
internal::inner_sort_impl<Derived, Comp, IsVectorAtCompileTime>::run(*this, begin, end);
|
||||
}
|
||||
|
||||
/** \returns the index of the first inner vector in the range [begin,end) that is not sorted with respect to `Comp`, or `end` if the range is fully sorted
|
||||
* \sa sortInnerIndices() */
|
||||
template <class Comp = std::less<>>
|
||||
inline Index innerIndicesAreSorted(Index begin, Index end) const {
|
||||
eigen_assert(begin >= 0 && end <= derived().outerSize() && end >= begin);
|
||||
return internal::inner_sort_impl<Derived, Comp, IsVectorAtCompileTime>::check(*this, begin, end);
|
||||
}
|
||||
|
||||
/** sorts the inner vectors in the range [0,outerSize) with respect to `Comp`
|
||||
* \sa innerIndicesAreSorted() */
|
||||
template <class Comp = std::less<>>
|
||||
inline void sortInnerIndices() {
|
||||
Index begin = 0;
|
||||
Index end = derived().outerSize();
|
||||
internal::inner_sort_impl<Derived, Comp, IsVectorAtCompileTime>::run(*this, begin, end);
|
||||
}
|
||||
|
||||
/** \returns the index of the first inner vector in the range [0,outerSize) that is not sorted with respect to `Comp`, or `outerSize` if the range is fully sorted
|
||||
* \sa sortInnerIndices() */
|
||||
template<class Comp = std::less<>>
|
||||
inline Index innerIndicesAreSorted() const {
|
||||
Index begin = 0;
|
||||
Index end = derived().outerSize();
|
||||
return internal::inner_sort_impl<Derived, Comp, IsVectorAtCompileTime>::check(*this, begin, end);
|
||||
}
|
||||
|
||||
protected:
|
||||
/** Default constructor. Do nothing. */
|
||||
@@ -306,6 +343,136 @@ class SparseCompressedBase<Derived>::ReverseInnerIterator
|
||||
|
||||
namespace internal {
|
||||
|
||||
// modified from https://artificial-mind.net/blog/2020/11/28/std-sort-multiple-ranges
|
||||
template <typename Scalar, typename StorageIndex>
|
||||
class CompressedStorageIterator;
|
||||
|
||||
// wrapper class analogous to std::pair<StorageIndex&, Scalar&>
|
||||
// used to define assignment, swap, and comparison operators for CompressedStorageIterator
|
||||
template <typename Scalar, typename StorageIndex>
|
||||
class StorageRef
|
||||
{
|
||||
public:
|
||||
using value_type = std::pair<StorageIndex, Scalar>;
|
||||
|
||||
inline StorageRef& operator=(const StorageRef& other) {
|
||||
*m_innerIndexIterator = *other.m_innerIndexIterator;
|
||||
*m_valueIterator = *other.m_valueIterator;
|
||||
return *this;
|
||||
}
|
||||
inline StorageRef& operator=(const value_type& other) {
|
||||
std::tie(*m_innerIndexIterator, *m_valueIterator) = other;
|
||||
return *this;
|
||||
}
|
||||
inline operator value_type() const { return std::make_pair(*m_innerIndexIterator, *m_valueIterator); }
|
||||
inline friend void swap(const StorageRef& a, const StorageRef& b) {
|
||||
std::iter_swap(a.m_innerIndexIterator, b.m_innerIndexIterator);
|
||||
std::iter_swap(a.m_valueIterator, b.m_valueIterator);
|
||||
}
|
||||
|
||||
inline static const StorageIndex& key(const StorageRef& a) { return *a.m_innerIndexIterator; }
|
||||
inline static const StorageIndex& key(const value_type& a) { return a.first; }
|
||||
#define REF_COMP_REF(OP) inline friend bool operator OP(const StorageRef& a, const StorageRef& b) { return key(a) OP key(b); };
|
||||
#define REF_COMP_VAL(OP) inline friend bool operator OP(const StorageRef& a, const value_type& b) { return key(a) OP key(b); };
|
||||
#define VAL_COMP_REF(OP) inline friend bool operator OP(const value_type& a, const StorageRef& b) { return key(a) OP key(b); };
|
||||
#define MAKE_COMPS(OP) REF_COMP_REF(OP) REF_COMP_VAL(OP) VAL_COMP_REF(OP)
|
||||
MAKE_COMPS(<) MAKE_COMPS(>) MAKE_COMPS(<=) MAKE_COMPS(>=) MAKE_COMPS(==) MAKE_COMPS(!=)
|
||||
|
||||
protected:
|
||||
StorageIndex* m_innerIndexIterator;
|
||||
Scalar* m_valueIterator;
|
||||
private:
|
||||
StorageRef() = delete;
|
||||
// these constructors are only called by the CompressedStorageIterator constructors for convenience only
|
||||
StorageRef(StorageIndex* innerIndexIterator, Scalar* valueIterator) : m_innerIndexIterator(innerIndexIterator), m_valueIterator(valueIterator) {}
|
||||
StorageRef(const StorageRef& other) : m_innerIndexIterator(other.m_innerIndexIterator), m_valueIterator(other.m_valueIterator) {}
|
||||
|
||||
friend class CompressedStorageIterator<Scalar, StorageIndex>;
|
||||
};
|
||||
|
||||
// STL-compatible iterator class that operates on inner indices and values
|
||||
template<typename Scalar, typename StorageIndex>
|
||||
class CompressedStorageIterator
|
||||
{
|
||||
public:
|
||||
using iterator_category = std::random_access_iterator_tag;
|
||||
using reference = StorageRef<Scalar, StorageIndex>;
|
||||
using difference_type = Index;
|
||||
using value_type = typename reference::value_type;
|
||||
using pointer = value_type*;
|
||||
|
||||
CompressedStorageIterator() = delete;
|
||||
CompressedStorageIterator(difference_type index, StorageIndex* innerIndexPtr, Scalar* valuePtr) : m_index(index), m_data(innerIndexPtr, valuePtr) {}
|
||||
CompressedStorageIterator(difference_type index, reference data) : m_index(index), m_data(data) {}
|
||||
CompressedStorageIterator(const CompressedStorageIterator& other) : m_index(other.m_index), m_data(other.m_data) {}
|
||||
inline CompressedStorageIterator& operator=(const CompressedStorageIterator& other) {
|
||||
m_index = other.m_index;
|
||||
m_data = other.m_data;
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline bool operator==(const CompressedStorageIterator& other) const { return m_index == other.m_index; }
|
||||
inline bool operator!=(const CompressedStorageIterator& other) const { return m_index != other.m_index; }
|
||||
inline bool operator< (const CompressedStorageIterator& other) const { return m_index < other.m_index; }
|
||||
inline CompressedStorageIterator operator+(difference_type offset) const { return CompressedStorageIterator(m_index + offset, m_data); }
|
||||
inline CompressedStorageIterator operator-(difference_type offset) const { return CompressedStorageIterator(m_index - offset, m_data); }
|
||||
inline difference_type operator-(const CompressedStorageIterator& other) const { return m_index - other.m_index; }
|
||||
inline CompressedStorageIterator& operator++() { ++m_index; return *this; }
|
||||
inline CompressedStorageIterator& operator--() { --m_index; return *this; }
|
||||
inline reference operator*() const { return reference(m_data.m_innerIndexIterator + m_index, m_data.m_valueIterator + m_index); }
|
||||
|
||||
protected:
|
||||
difference_type m_index;
|
||||
reference m_data;
|
||||
};
|
||||
|
||||
template <typename Derived, class Comp, bool IsVector>
|
||||
struct inner_sort_impl {
|
||||
typedef typename Derived::Scalar Scalar;
|
||||
typedef typename Derived::StorageIndex StorageIndex;
|
||||
static inline void run(SparseCompressedBase<Derived>& obj, Index begin, Index end) {
|
||||
const bool is_compressed = obj.isCompressed();
|
||||
for (Index outer = begin; outer < end; outer++) {
|
||||
Index begin_offset = obj.outerIndexPtr()[outer];
|
||||
Index end_offset = is_compressed ? obj.outerIndexPtr()[outer + 1] : (begin_offset + obj.innerNonZeroPtr()[outer]);
|
||||
CompressedStorageIterator<Scalar, StorageIndex> begin_it(begin_offset, obj.innerIndexPtr(), obj.valuePtr());
|
||||
CompressedStorageIterator<Scalar, StorageIndex> end_it(end_offset, obj.innerIndexPtr(), obj.valuePtr());
|
||||
std::sort(begin_it, end_it, Comp());
|
||||
}
|
||||
}
|
||||
static inline Index check(const SparseCompressedBase<Derived>& obj, Index begin, Index end) {
|
||||
const bool is_compressed = obj.isCompressed();
|
||||
for (Index outer = begin; outer < end; outer++) {
|
||||
Index begin_offset = obj.outerIndexPtr()[outer];
|
||||
Index end_offset = is_compressed ? obj.outerIndexPtr()[outer + 1] : (begin_offset + obj.innerNonZeroPtr()[outer]);
|
||||
const StorageIndex* begin_it = obj.innerIndexPtr() + begin_offset;
|
||||
const StorageIndex* end_it = obj.innerIndexPtr() + end_offset;
|
||||
bool is_sorted = std::is_sorted(begin_it, end_it, Comp());
|
||||
if (!is_sorted) return outer;
|
||||
}
|
||||
return end;
|
||||
}
|
||||
};
|
||||
template <typename Derived, class Comp>
|
||||
struct inner_sort_impl<Derived, Comp, true> {
|
||||
typedef typename Derived::Scalar Scalar;
|
||||
typedef typename Derived::StorageIndex StorageIndex;
|
||||
static inline void run(SparseCompressedBase<Derived>& obj, Index, Index) {
|
||||
Index begin_offset = 0;
|
||||
Index end_offset = obj.nonZeros();
|
||||
CompressedStorageIterator<Scalar, StorageIndex> begin_it(begin_offset, obj.innerIndexPtr(), obj.valuePtr());
|
||||
CompressedStorageIterator<Scalar, StorageIndex> end_it(end_offset, obj.innerIndexPtr(), obj.valuePtr());
|
||||
std::sort(begin_it, end_it, Comp());
|
||||
}
|
||||
static inline Index check(const SparseCompressedBase<Derived>& obj, Index, Index) {
|
||||
Index begin_offset = 0;
|
||||
Index end_offset = obj.nonZeros();
|
||||
const StorageIndex* begin_it = obj.innerIndexPtr() + begin_offset;
|
||||
const StorageIndex* end_it = obj.innerIndexPtr() + end_offset;
|
||||
return std::is_sorted(begin_it, end_it, Comp()) ? 1 : 0;
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Derived>
|
||||
struct evaluator<SparseCompressedBase<Derived> >
|
||||
: evaluator_base<Derived>
|
||||
|
||||
Reference in New Issue
Block a user