Add serialization for sparse matrix and sparse vector.

This commit is contained in:
Antonio Sánchez
2022-11-21 19:43:07 +00:00
parent 044f3f6234
commit e7b1ad0315
4 changed files with 267 additions and 5 deletions

View File

@@ -1496,6 +1496,125 @@ struct evaluator<SparseMatrix<Scalar_,Options_,StorageIndex_> >
}
// Specialization for SparseMatrix.
// Serializes [rows, cols, isCompressed, outerSize, numNonZeros, innerNonZeros,
// outerIndices, innerIndices, values].
template <typename Scalar, int Options, typename StorageIndex>
class Serializer<SparseMatrix<Scalar, Options, StorageIndex>, void> {
public:
typedef SparseMatrix<Scalar, Options, StorageIndex> SparseMat;
struct Header {
typename SparseMat::Index rows;
typename SparseMat::Index cols;
bool compressed;
Index outer_size;
Index num_non_zeros;
};
EIGEN_DEVICE_FUNC size_t size(const SparseMat& value) const {
// innerNonZeros.
std::size_t num_storage_indices =
value.isCompressed() ? 0 : value.outerSize();
// Outer indices.
num_storage_indices += value.outerSize() + 1;
// Inner indices.
num_storage_indices += value.nonZeros();
// Values.
std::size_t num_values = value.nonZeros();
return sizeof(Header) + sizeof(Scalar) * num_values +
sizeof(StorageIndex) * num_storage_indices;
}
EIGEN_DEVICE_FUNC uint8_t* serialize(uint8_t* dest, uint8_t* end,
const SparseMat& value) {
if (EIGEN_PREDICT_FALSE(dest == nullptr)) return nullptr;
if (EIGEN_PREDICT_FALSE(dest + size(value) > end)) return nullptr;
const size_t header_bytes = sizeof(Header);
Header header = {value.rows(), value.cols(), value.isCompressed(),
value.outerSize(), value.nonZeros()};
EIGEN_USING_STD(memcpy)
memcpy(dest, &header, header_bytes);
dest += header_bytes;
// innerNonZeros.
size_t data_bytes = sizeof(StorageIndex) * header.outer_size;
if (!header.compressed) {
memcpy(dest, value.innerNonZeroPtr(), data_bytes);
dest += data_bytes;
}
// Outer indices.
data_bytes = sizeof(StorageIndex) * (header.outer_size + 1);
memcpy(dest, value.outerIndexPtr(), data_bytes);
dest += data_bytes;
// Inner indices.
data_bytes = sizeof(StorageIndex) * header.num_non_zeros;
memcpy(dest, value.innerIndexPtr(), data_bytes);
dest += data_bytes;
// Values.
data_bytes = sizeof(Scalar) * header.num_non_zeros;
memcpy(dest, value.valuePtr(), data_bytes);
dest += data_bytes;
return dest;
}
EIGEN_DEVICE_FUNC const uint8_t* deserialize(const uint8_t* src,
const uint8_t* end,
SparseMat& value) const {
if (EIGEN_PREDICT_FALSE(src == nullptr)) return nullptr;
if (EIGEN_PREDICT_FALSE(src + sizeof(Header) > end)) return nullptr;
const size_t header_bytes = sizeof(Header);
Header header;
EIGEN_USING_STD(memcpy)
memcpy(&header, src, header_bytes);
src += header_bytes;
value.setZero();
value.resize(header.rows, header.cols);
// Initialize compressed state and inner non-zeros.
size_t data_bytes = sizeof(StorageIndex) * header.outer_size;
if (EIGEN_PREDICT_FALSE(src + data_bytes > end)) return nullptr;
if (header.compressed) {
value.makeCompressed();
value.resizeNonZeros(header.num_non_zeros);
} else {
// Temporarily load inner sizes, then reserve.
std::vector<StorageIndex> inner_sizes(header.outer_size);
memcpy(inner_sizes.data(), src, data_bytes);
src += data_bytes;
value.uncompress();
value.reserve(inner_sizes);
memcpy(value.innerNonZeroPtr(), inner_sizes.data(), data_bytes);
}
// Outer indices.
data_bytes = sizeof(StorageIndex) * (header.outer_size + 1);
memcpy(value.outerIndexPtr(), src, data_bytes);
src += data_bytes;
// Inner indices.
data_bytes = sizeof(StorageIndex) * header.num_non_zeros;
if (EIGEN_PREDICT_FALSE(src + data_bytes > end)) return nullptr;
memcpy(value.innerIndexPtr(), src, data_bytes);
src += data_bytes;
// Values.
data_bytes = sizeof(Scalar) * header.num_non_zeros;
if (EIGEN_PREDICT_FALSE(src + data_bytes > end)) return nullptr;
memcpy(value.valuePtr(), src, data_bytes);
src += data_bytes;
return src;
}
};
} // end namespace Eigen
#endif // EIGEN_SPARSEMATRIX_H