|
|
|
|
@@ -52,7 +52,13 @@ class TensorLazyEvaluatorReadOnly
|
|
|
|
|
typedef TensorEvaluator<Expr, Device> EvalType;
|
|
|
|
|
|
|
|
|
|
TensorLazyEvaluatorReadOnly(const Expr& expr, const Device& device) : m_impl(expr, device), m_dummy(Scalar(0)) {
|
|
|
|
|
m_dims = m_impl.dimensions();
|
|
|
|
|
EIGEN_STATIC_ASSERT(
|
|
|
|
|
internal::array_size<Dimensions>::value == internal::array_size<typename EvalType::Dimensions>::value,
|
|
|
|
|
"Dimension sizes must match.");
|
|
|
|
|
const auto& other_dims = m_impl.dimensions();
|
|
|
|
|
for (std::size_t i = 0; i < m_dims.size(); ++i) {
|
|
|
|
|
m_dims[i] = other_dims[i];
|
|
|
|
|
}
|
|
|
|
|
m_impl.evalSubExprsIfNeeded(NULL);
|
|
|
|
|
}
|
|
|
|
|
virtual ~TensorLazyEvaluatorReadOnly() { m_impl.cleanup(); }
|
|
|
|
|
@@ -86,14 +92,12 @@ class TensorLazyEvaluatorWritable : public TensorLazyEvaluatorReadOnly<Dimension
|
|
|
|
|
EIGEN_DEVICE_FUNC virtual Scalar& coeffRef(DenseIndex index) { return this->m_impl.coeffRef(index); }
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename Dimensions, typename Expr, typename Device>
|
|
|
|
|
class TensorLazyEvaluator : public std::conditional_t<bool(internal::is_lvalue<Expr>::value),
|
|
|
|
|
TensorLazyEvaluatorWritable<Dimensions, Expr, Device>,
|
|
|
|
|
TensorLazyEvaluatorReadOnly<Dimensions, const Expr, Device> > {
|
|
|
|
|
template <typename Dimensions, typename Expr, typename Device, bool IsWritable>
|
|
|
|
|
class TensorLazyEvaluator : public std::conditional_t<IsWritable, TensorLazyEvaluatorWritable<Dimensions, Expr, Device>,
|
|
|
|
|
TensorLazyEvaluatorReadOnly<Dimensions, const Expr, Device>> {
|
|
|
|
|
public:
|
|
|
|
|
typedef std::conditional_t<bool(internal::is_lvalue<Expr>::value),
|
|
|
|
|
TensorLazyEvaluatorWritable<Dimensions, Expr, Device>,
|
|
|
|
|
TensorLazyEvaluatorReadOnly<Dimensions, const Expr, Device> >
|
|
|
|
|
typedef std::conditional_t<IsWritable, TensorLazyEvaluatorWritable<Dimensions, Expr, Device>,
|
|
|
|
|
TensorLazyEvaluatorReadOnly<Dimensions, const Expr, Device>>
|
|
|
|
|
Base;
|
|
|
|
|
typedef typename Base::Scalar Scalar;
|
|
|
|
|
|
|
|
|
|
@@ -101,24 +105,15 @@ class TensorLazyEvaluator : public std::conditional_t<bool(internal::is_lvalue<E
|
|
|
|
|
virtual ~TensorLazyEvaluator() {}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace internal
|
|
|
|
|
|
|
|
|
|
/** \class TensorRef
|
|
|
|
|
* \ingroup CXX11_Tensor_Module
|
|
|
|
|
*
|
|
|
|
|
* \brief A reference to a tensor expression
|
|
|
|
|
* The expression will be evaluated lazily (as much as possible).
|
|
|
|
|
*
|
|
|
|
|
*/
|
|
|
|
|
template <typename PlainObjectType>
|
|
|
|
|
class TensorRef : public TensorBase<TensorRef<PlainObjectType> > {
|
|
|
|
|
template <typename Derived>
|
|
|
|
|
class TensorRefBase : public TensorBase<Derived> {
|
|
|
|
|
public:
|
|
|
|
|
typedef TensorRef<PlainObjectType> Self;
|
|
|
|
|
typedef typename traits<Derived>::PlainObjectType PlainObjectType;
|
|
|
|
|
typedef typename PlainObjectType::Base Base;
|
|
|
|
|
typedef typename Eigen::internal::nested<Self>::type Nested;
|
|
|
|
|
typedef typename internal::traits<PlainObjectType>::StorageKind StorageKind;
|
|
|
|
|
typedef typename internal::traits<PlainObjectType>::Index Index;
|
|
|
|
|
typedef typename internal::traits<PlainObjectType>::Scalar Scalar;
|
|
|
|
|
typedef typename Eigen::internal::nested<Derived>::type Nested;
|
|
|
|
|
typedef typename traits<PlainObjectType>::StorageKind StorageKind;
|
|
|
|
|
typedef typename traits<PlainObjectType>::Index Index;
|
|
|
|
|
typedef typename traits<PlainObjectType>::Scalar Scalar;
|
|
|
|
|
typedef typename NumTraits<Scalar>::Real RealScalar;
|
|
|
|
|
typedef typename Base::CoeffReturnType CoeffReturnType;
|
|
|
|
|
typedef Scalar* PointerType;
|
|
|
|
|
@@ -138,33 +133,17 @@ class TensorRef : public TensorBase<TensorRef<PlainObjectType> > {
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
//===- Tensor block evaluation strategy (see TensorBlock.h) -----------===//
|
|
|
|
|
typedef internal::TensorBlockNotImplemented TensorBlock;
|
|
|
|
|
typedef TensorBlockNotImplemented TensorBlock;
|
|
|
|
|
//===------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
|
|
EIGEN_STRONG_INLINE TensorRef() : m_evaluator(NULL) {}
|
|
|
|
|
EIGEN_STRONG_INLINE TensorRefBase() : m_evaluator(NULL) {}
|
|
|
|
|
|
|
|
|
|
template <typename Expression>
|
|
|
|
|
EIGEN_STRONG_INLINE TensorRef(const Expression& expr)
|
|
|
|
|
: m_evaluator(new internal::TensorLazyEvaluator<Dimensions, Expression, DefaultDevice>(expr, DefaultDevice())) {
|
|
|
|
|
m_evaluator->incrRefCount();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename Expression>
|
|
|
|
|
EIGEN_STRONG_INLINE TensorRef& operator=(const Expression& expr) {
|
|
|
|
|
unrefEvaluator();
|
|
|
|
|
m_evaluator = new internal::TensorLazyEvaluator<Dimensions, Expression, DefaultDevice>(expr, DefaultDevice());
|
|
|
|
|
m_evaluator->incrRefCount();
|
|
|
|
|
return *this;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
~TensorRef() { unrefEvaluator(); }
|
|
|
|
|
|
|
|
|
|
TensorRef(const TensorRef& other) : TensorBase<TensorRef<PlainObjectType> >(other), m_evaluator(other.m_evaluator) {
|
|
|
|
|
TensorRefBase(const TensorRefBase& other) : TensorBase<Derived>(other), m_evaluator(other.m_evaluator) {
|
|
|
|
|
eigen_assert(m_evaluator->refCount() > 0);
|
|
|
|
|
m_evaluator->incrRefCount();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TensorRef& operator=(const TensorRef& other) {
|
|
|
|
|
TensorRefBase& operator=(const TensorRefBase& other) {
|
|
|
|
|
if (this != &other) {
|
|
|
|
|
unrefEvaluator();
|
|
|
|
|
m_evaluator = other.m_evaluator;
|
|
|
|
|
@@ -174,6 +153,28 @@ class TensorRef : public TensorBase<TensorRef<PlainObjectType> > {
|
|
|
|
|
return *this;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename Expression,
|
|
|
|
|
typename EnableIf = std::enable_if_t<!std::is_same<std::decay_t<Expression>, Derived>::value>>
|
|
|
|
|
EIGEN_STRONG_INLINE TensorRefBase(const Expression& expr)
|
|
|
|
|
: m_evaluator(new TensorLazyEvaluator<Dimensions, Expression, DefaultDevice,
|
|
|
|
|
/*IsWritable=*/!std::is_const<PlainObjectType>::value &&
|
|
|
|
|
bool(is_lvalue<Expression>::value)>(expr, DefaultDevice())) {
|
|
|
|
|
m_evaluator->incrRefCount();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename Expression,
|
|
|
|
|
typename EnableIf = std::enable_if_t<!std::is_same<std::decay_t<Expression>, Derived>::value>>
|
|
|
|
|
EIGEN_STRONG_INLINE TensorRefBase& operator=(const Expression& expr) {
|
|
|
|
|
unrefEvaluator();
|
|
|
|
|
m_evaluator = new TensorLazyEvaluator < Dimensions, Expression, DefaultDevice,
|
|
|
|
|
/*IsWritable=*/!std::is_const<PlainObjectType>::value&& bool(is_lvalue<Expression>::value) >
|
|
|
|
|
(expr, DefaultDevice());
|
|
|
|
|
m_evaluator->incrRefCount();
|
|
|
|
|
return *this;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
~TensorRefBase() { unrefEvaluator(); }
|
|
|
|
|
|
|
|
|
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index rank() const { return m_evaluator->dimensions().size(); }
|
|
|
|
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index dimension(Index n) const { return m_evaluator->dimensions()[n]; }
|
|
|
|
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_evaluator->dimensions(); }
|
|
|
|
|
@@ -188,12 +189,6 @@ class TensorRef : public TensorBase<TensorRef<PlainObjectType> > {
|
|
|
|
|
const array<Index, num_indices> indices{{firstIndex, otherIndices...}};
|
|
|
|
|
return coeff(indices);
|
|
|
|
|
}
|
|
|
|
|
template <typename... IndexTypes>
|
|
|
|
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index firstIndex, IndexTypes... otherIndices) {
|
|
|
|
|
const std::size_t num_indices = (sizeof...(otherIndices) + 1);
|
|
|
|
|
const array<Index, num_indices> indices{{firstIndex, otherIndices...}};
|
|
|
|
|
return coeffRef(indices);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <std::size_t NumIndices>
|
|
|
|
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(const array<Index, NumIndices>& indices) const {
|
|
|
|
|
@@ -212,6 +207,70 @@ class TensorRef : public TensorBase<TensorRef<PlainObjectType> > {
|
|
|
|
|
}
|
|
|
|
|
return m_evaluator->coeff(index);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index index) const { return m_evaluator->coeff(index); }
|
|
|
|
|
|
|
|
|
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) { return m_evaluator->coeffRef(index); }
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
TensorLazyBaseEvaluator<Dimensions, Scalar>* evaluator() { return m_evaluator; }
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
EIGEN_STRONG_INLINE void unrefEvaluator() {
|
|
|
|
|
if (m_evaluator) {
|
|
|
|
|
m_evaluator->decrRefCount();
|
|
|
|
|
if (m_evaluator->refCount() == 0) {
|
|
|
|
|
delete m_evaluator;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TensorLazyBaseEvaluator<Dimensions, Scalar>* m_evaluator;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace internal
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* \ingroup CXX11_Tensor_Module
|
|
|
|
|
*
|
|
|
|
|
* \brief A reference to a tensor expression
|
|
|
|
|
* The expression will be evaluated lazily (as much as possible).
|
|
|
|
|
*
|
|
|
|
|
*/
|
|
|
|
|
template <typename PlainObjectType>
|
|
|
|
|
class TensorRef : public internal::TensorRefBase<TensorRef<PlainObjectType>> {
|
|
|
|
|
typedef internal::TensorRefBase<TensorRef<PlainObjectType>> Base;
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
using Scalar = typename Base::Scalar;
|
|
|
|
|
using Dimensions = typename Base::Dimensions;
|
|
|
|
|
|
|
|
|
|
EIGEN_STRONG_INLINE TensorRef() : Base() {}
|
|
|
|
|
|
|
|
|
|
template <typename Expression>
|
|
|
|
|
EIGEN_STRONG_INLINE TensorRef(const Expression& expr) : Base(expr) {
|
|
|
|
|
EIGEN_STATIC_ASSERT(internal::is_lvalue<Expression>::value,
|
|
|
|
|
"Expression must be mutable to create a mutable TensorRef<Expression>. Did you mean "
|
|
|
|
|
"TensorRef<const Expression>?)");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename Expression>
|
|
|
|
|
EIGEN_STRONG_INLINE TensorRef& operator=(const Expression& expr) {
|
|
|
|
|
EIGEN_STATIC_ASSERT(internal::is_lvalue<Expression>::value,
|
|
|
|
|
"Expression must be mutable to create a mutable TensorRef<Expression>. Did you mean "
|
|
|
|
|
"TensorRef<const Expression>?)");
|
|
|
|
|
return Base::operator=(expr).derived();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TensorRef& operator=(const TensorRef& other) { return Base::operator=(other).derived(); }
|
|
|
|
|
|
|
|
|
|
template <typename... IndexTypes>
|
|
|
|
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index firstIndex, IndexTypes... otherIndices) {
|
|
|
|
|
const std::size_t num_indices = (sizeof...(otherIndices) + 1);
|
|
|
|
|
const array<Index, num_indices> indices{{firstIndex, otherIndices...}};
|
|
|
|
|
return coeffRef(indices);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <std::size_t NumIndices>
|
|
|
|
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(const array<Index, NumIndices>& indices) {
|
|
|
|
|
const Dimensions& dims = this->dimensions();
|
|
|
|
|
@@ -227,24 +286,37 @@ class TensorRef : public TensorBase<TensorRef<PlainObjectType> > {
|
|
|
|
|
index = index * dims[i] + indices[i];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return m_evaluator->coeffRef(index);
|
|
|
|
|
return Base::evaluator()->coeffRef(index);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index index) const { return m_evaluator->coeff(index); }
|
|
|
|
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) { return Base::evaluator()->coeffRef(index); }
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) { return m_evaluator->coeffRef(index); }
|
|
|
|
|
/**
|
|
|
|
|
* \ingroup CXX11_Tensor_Module
|
|
|
|
|
*
|
|
|
|
|
* \brief A reference to a constant tensor expression
|
|
|
|
|
* The expression will be evaluated lazily (as much as possible).
|
|
|
|
|
*
|
|
|
|
|
*/
|
|
|
|
|
template <typename PlainObjectType>
|
|
|
|
|
class TensorRef<const PlainObjectType> : public internal::TensorRefBase<TensorRef<const PlainObjectType>> {
|
|
|
|
|
typedef internal::TensorRefBase<TensorRef<const PlainObjectType>> Base;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
EIGEN_STRONG_INLINE void unrefEvaluator() {
|
|
|
|
|
if (m_evaluator) {
|
|
|
|
|
m_evaluator->decrRefCount();
|
|
|
|
|
if (m_evaluator->refCount() == 0) {
|
|
|
|
|
delete m_evaluator;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
public:
|
|
|
|
|
EIGEN_STRONG_INLINE TensorRef() : Base() {}
|
|
|
|
|
|
|
|
|
|
template <typename Expression>
|
|
|
|
|
EIGEN_STRONG_INLINE TensorRef(const Expression& expr) : Base(expr) {}
|
|
|
|
|
|
|
|
|
|
template <typename Expression>
|
|
|
|
|
EIGEN_STRONG_INLINE TensorRef& operator=(const Expression& expr) {
|
|
|
|
|
return Base::operator=(expr).derived();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
internal::TensorLazyBaseEvaluator<Dimensions, Scalar>* m_evaluator;
|
|
|
|
|
TensorRef(const TensorRef& other) : Base(other) {}
|
|
|
|
|
|
|
|
|
|
TensorRef& operator=(const TensorRef& other) { return Base::operator=(other).derived(); }
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// evaluator for rvalues
|
|
|
|
|
|