diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h index bdc1a17a7..97f90f638 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h @@ -517,9 +517,15 @@ class TensorBase typedef Eigen::IndexPair DimensionPair; template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - const TensorContractionOp + const TensorContractionOp contract(const OtherDerived& other, const Dimensions& dims) const { - return TensorContractionOp(derived(), other.derived(), dims); + return TensorContractionOp(derived(), other.derived(), dims); + } + + template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorContractionOp + contract(const OtherDerived& other, const Dimensions& dims, const OutputKernel& output_kernel) const { + return TensorContractionOp(derived(), other.derived(), dims, output_kernel); } // Convolutions. diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h index 979fcf4d9..85126a127 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h @@ -85,8 +85,8 @@ template #endif -template -struct traits > +template +struct traits > { // Type promotion to handle the case where the types of the lhs and the rhs are different. 
typedef typename gebp_traits::type, @@ -112,23 +112,24 @@ struct traits > }; }; -template -struct eval, Eigen::Dense> +template +struct eval, Eigen::Dense> { - typedef const TensorContractionOp& type; + typedef const TensorContractionOp& type; }; -template -struct nested, 1, typename eval >::type> +template +struct nested, 1, typename eval >::type> { - typedef TensorContractionOp type; + typedef TensorContractionOp type; }; -template -struct traits, Device_> > { +template +struct traits, Device_> > { typedef Indices_ Indices; typedef LeftArgType_ LeftArgType; typedef RightArgType_ RightArgType; + typedef OutputKernelType_ OutputKernelType; typedef Device_ Device; // From NumDims below. @@ -137,8 +138,52 @@ struct traits -class TensorContractionOp : public TensorBase, ReadOnlyAccessors> +// Tensor contraction parameters that make it possible to map output matrix +// 2-dimensional coordinates to the output tensor dimensions. +struct TensorContractionParams { + // TensorContraction evaluator assumes that both tensors are in ColMajor + // layout; if the tensors are in RowMajor, the evaluator swaps lhs with rhs. + bool swapped_arguments; +}; + +// Output kernel allows fusing operations into the tensor contraction. +// +// Examples: +// 1. Elementwise Relu transformation following Conv2D. +// 2. AddBias to the Conv2D output channels dimension. +// +// See expected implementation in NoOpOutputKernel. +struct OutputKernel { + template + using OutputMapper = internal::blas_data_mapper; +}; + +// Output kernel that does absolutely nothing. +struct NoOpOutputKernel { + /** + * Tensor contraction evaluator calls this kernel after finishing each block + * of output matrix. Output blocks belong to the 2-dimensional output tensor. + * + * TensorContractionParams contains contraction dimensions information + * required to map output 2-d space into the expected output tensor space + * (potentially higher dimensional). 
+ * + * \param[in] output_mapper Access to output tensor memory + * \param[in] params Tensor contraction parameters + * \param[in] i Index of a first row available through output_mapper + * \param[in] j Index of a first column available through output_mapper + * \param[in] num_rows Number of available rows + * \param[in] num_cols Number of available columns + */ + template + EIGEN_ALWAYS_INLINE void operator()( + const OutputKernel::OutputMapper& output_mapper, + const TensorContractionParams& params, Index i, Index j, Index num_rows, + Index num_cols) const {} +}; + +template +class TensorContractionOp : public TensorBase, ReadOnlyAccessors> { public: typedef typename Eigen::internal::traits::Scalar Scalar; @@ -149,8 +194,10 @@ class TensorContractionOp : public TensorBase::Index Index; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionOp( - const LhsXprType& lhs, const RhsXprType& rhs, const Indices& dims) - : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_indices(dims) {} + const LhsXprType& lhs, const RhsXprType& rhs, const Indices& dims, + const OutputKernelType& output_kernel = OutputKernelType()) + : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_indices(dims), + m_output_kernel(output_kernel) {} EIGEN_DEVICE_FUNC const Indices& indices() const { return m_indices; } @@ -164,10 +211,14 @@ class TensorContractionOp : public TensorBase::type& rhsExpression() const { return m_rhs_xpr; } + EIGEN_DEVICE_FUNC + const OutputKernelType& outputKernel() const { return m_output_kernel; } + protected: typename LhsXprType::Nested m_lhs_xpr; typename RhsXprType::Nested m_rhs_xpr; const Indices m_indices; + const OutputKernelType m_output_kernel; }; @@ -177,9 +228,10 @@ struct TensorContractionEvaluatorBase typedef typename internal::traits::Indices Indices; typedef typename internal::traits::LeftArgType LeftArgType; typedef typename internal::traits::RightArgType RightArgType; + typedef typename internal::traits::OutputKernelType OutputKernelType; typedef typename internal::traits::Device 
Device; - typedef TensorContractionOp XprType; + typedef TensorContractionOp XprType; typedef typename internal::remove_const::type Scalar; typedef typename XprType::Index Index; typedef typename XprType::CoeffReturnType CoeffReturnType; @@ -221,6 +273,7 @@ struct TensorContractionEvaluatorBase op.lhsExpression(), op.rhsExpression()), device), m_rightImpl(choose(Cond(Layout) == static_cast(ColMajor)>(), op.rhsExpression(), op.lhsExpression()), device), + m_output_kernel(op.outputKernel()), m_device(device), m_result(NULL) { EIGEN_STATIC_ASSERT((static_cast(TensorEvaluator::Layout) == @@ -391,6 +444,13 @@ struct TensorContractionEvaluatorBase numext::swap(m_dimensions[i], m_dimensions[j]); } } + + // A set of parameters that will allow the output kernel to map output + // matrix coordinates (i, j) to the original tensor dimensions. + // TODO(ezhulenev): Add parameters required to infer output tensor index for + // more complex contractions than 2x2 on internal dimension. + m_tensor_contraction_params = { + /*swapped_arguments=*/static_cast(Layout) == RowMajor}; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; } @@ -585,7 +645,15 @@ struct TensorContractionEvaluatorBase // call gebp (matrix kernel) // The parameters here are copied from Eigen's GEMM implementation - gebp(output.getSubMapper(i2, j2), blockA, blockB, actual_mc, actual_kc, actual_nc, Scalar(1), -1, -1, 0, 0); + const auto output_mapper = output.getSubMapper(i2, j2); + gebp(output_mapper, blockA, blockB, actual_mc, actual_kc, actual_nc, + Scalar(1), -1, -1, 0, 0); + + // We are done with this [i2, j2] output block. 
+ if (k2 + kc >= k) { + m_output_kernel(output_mapper, m_tensor_contraction_params, i2, j2, + actual_mc, actual_nc); + } } } } @@ -848,23 +916,26 @@ protected: Index m_j_size; Index m_k_size; + TensorContractionParams m_tensor_contraction_params; + TensorEvaluator m_leftImpl; TensorEvaluator m_rightImpl; const Device& m_device; + OutputKernelType m_output_kernel; Scalar* m_result; bool m_can_use_xsmm; }; // evaluator for default device -template -struct TensorEvaluator, Device> : +template +struct TensorEvaluator, Device> : public TensorContractionEvaluatorBase< - TensorEvaluator, Device> > { - typedef TensorEvaluator, Device> Self; + TensorEvaluator, Device> > { + typedef TensorEvaluator, Device> Self; typedef TensorContractionEvaluatorBase Base; - typedef TensorContractionOp XprType; + typedef TensorContractionOp XprType; typedef typename internal::remove_const::type Scalar; typedef typename XprType::Index Index; typedef typename XprType::CoeffReturnType CoeffReturnType; diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h index 3c007b183..d7536bd6a 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h @@ -56,16 +56,16 @@ struct packRhsAndKernelArg { } // end namespace internal #endif // EIGEN_USE_SIMPLE_THREAD_POOL -template -struct TensorEvaluator, ThreadPoolDevice> : - public TensorContractionEvaluatorBase, ThreadPoolDevice> > { +template +struct TensorEvaluator, ThreadPoolDevice> : + public TensorContractionEvaluatorBase, ThreadPoolDevice> > { typedef ThreadPoolDevice Device; - typedef TensorEvaluator, Device> Self; + typedef TensorEvaluator, Device> Self; typedef TensorContractionEvaluatorBase Base; - typedef TensorContractionOp XprType; + typedef TensorContractionOp XprType; typedef typename internal::remove_const::type Scalar; typedef typename XprType::Index Index; typedef 
typename XprType::CoeffReturnType CoeffReturnType; @@ -308,7 +308,7 @@ struct TensorEvaluatorm_k_strides); Context(this->m_device, num_threads, lhs, rhs, buffer, m, n, + OutputMapper>(this, num_threads, lhs, rhs, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, nn0, shard_by_col, parallel_pack) .run(); @@ -319,16 +319,18 @@ struct TensorEvaluator class Context { public: - Context(const Device& device, int num_threads, LhsMapper& lhs, + Context(const Self* self, int num_threads, LhsMapper& lhs, RhsMapper& rhs, Scalar* buffer, Index tm, Index tn, Index tk, Index bm, Index bn, Index bk, Index nm, Index nn, Index nk, Index gm, Index gn, Index nm0, Index nn0, bool shard_by_col, bool parallel_pack) - : device_(device), + : device_(self->m_device), lhs_(lhs), rhs_(rhs), buffer_(buffer), output_(buffer, tm), + output_kernel_(self->m_output_kernel), + tensor_contraction_params_(self->m_tensor_contraction_params), num_threads_(num_threads), shard_by_col_(shard_by_col), parallel_pack_(parallel_pack), @@ -420,6 +422,8 @@ struct TensorEvaluator::value, + "SimpleThreadPool does not support contraction output kernels."); template void evalProduct(Scalar* buffer) const { @@ -1065,6 +1086,10 @@ struct TensorEvaluator::value, + "XSMM does not support contraction output kernels."); + template class ContextXsmm { public: diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h b/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h index 6c237bac3..19e456e19 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h @@ -65,7 +65,7 @@ template class Ma template class TensorIndexTupleOp; template class TensorTupleReducerOp; template class TensorConcatenationOp; -template class TensorContractionOp; +template class TensorContractionOp; template class TensorConversionOp; template class TensorConvolutionOp; template class TensorFFTOp; @@ -97,6 +97,8 @@ template class 
TensorForcedEvalOp; template class TensorDevice; template struct TensorEvaluator; +class NoOpOutputKernel; + struct DefaultDevice; struct ThreadPoolDevice; struct GpuDevice; diff --git a/unsupported/test/cxx11_tensor_contraction.cpp b/unsupported/test/cxx11_tensor_contraction.cpp index ace97057f..918c96277 100644 --- a/unsupported/test/cxx11_tensor_contraction.cpp +++ b/unsupported/test/cxx11_tensor_contraction.cpp @@ -510,6 +510,55 @@ static void test_const_inputs() VERIFY_IS_APPROX(mat3(1,1), mat1(1,0)*mat2(0,1) + mat1(1,1)*mat2(1,1) + mat1(1,2)*mat2(2,1)); } +// Apply Sqrt to all output elements. +struct SqrtOutputKernel { + template + EIGEN_ALWAYS_INLINE void operator()( + const OutputKernel::OutputMapper& output_mapper, + const TensorContractionParams&, Index, Index, Index num_rows, + Index num_cols) const { + for (int i = 0; i < num_rows; ++i) { + for (int j = 0; j < num_cols; ++j) { + output_mapper(i, j) = std::sqrt(output_mapper(i, j)); + } + } + } +}; + +template +static void test_large_contraction_with_output_kernel() { + Tensor t_left(30, 50, 8, 31); + Tensor t_right(8, 31, 7, 20, 10); + Tensor t_result(30, 50, 7, 20, 10); + + t_left.setRandom(); + t_right.setRandom(); + // Put trash in t_result to verify contraction clears output memory. + t_result.setRandom(); + + // Add a little offset so that the results won't be close to zero. 
+ t_left += t_left.constant(1.0f); + t_right += t_right.constant(1.0f); + + typedef Map> MapXf; + MapXf m_left(t_left.data(), 1500, 248); + MapXf m_right(t_right.data(), 248, 1400); + Eigen::Matrix m_result(1500, 1400); + + // this contraction should be equivalent to a single matrix multiplication + Eigen::array dims({{DimPair(2, 0), DimPair(3, 1)}}); + + // compute results by separate methods + t_result = t_left.contract(t_right, dims, SqrtOutputKernel()); + + m_result = m_left * m_right; + + for (size_t i = 0; i < t_result.dimensions().TotalSize(); i++) { + VERIFY(&t_result.data()[i] != &m_result.data()[i]); + VERIFY_IS_APPROX(t_result.data()[i], std::sqrt(m_result.data()[i])); + } +} + void test_cxx11_tensor_contraction() { CALL_SUBTEST(test_evals()); @@ -542,4 +591,6 @@ void test_cxx11_tensor_contraction() CALL_SUBTEST(test_tensor_product()); CALL_SUBTEST(test_const_inputs()); CALL_SUBTEST(test_const_inputs()); + CALL_SUBTEST(test_large_contraction_with_output_kernel()); + CALL_SUBTEST(test_large_contraction_with_output_kernel()); } diff --git a/unsupported/test/cxx11_tensor_thread_pool.cpp b/unsupported/test/cxx11_tensor_thread_pool.cpp index 2ef665f30..ea9d8afdc 100644 --- a/unsupported/test/cxx11_tensor_thread_pool.cpp +++ b/unsupported/test/cxx11_tensor_thread_pool.cpp @@ -232,6 +232,60 @@ void test_multithread_contraction_agrees_with_singlethread() { } } +// Apply Sqrt to all output elements. 
+struct SqrtOutputKernel { + template + EIGEN_ALWAYS_INLINE void operator()( + const OutputKernel::OutputMapper& output_mapper, + const TensorContractionParams&, Index, Index, Index num_rows, + Index num_cols) const { + for (int i = 0; i < num_rows; ++i) { + for (int j = 0; j < num_cols; ++j) { + output_mapper(i, j) = std::sqrt(output_mapper(i, j)); + } + } + } +}; + +template +static void test_multithread_contraction_with_output_kernel() { + typedef Tensor::DimensionPair DimPair; + + const int num_threads = internal::random(2, 11); + ThreadPool threads(num_threads); + Eigen::ThreadPoolDevice device(&threads, num_threads); + + Tensor t_left(30, 50, 8, 31); + Tensor t_right(8, 31, 7, 20, 10); + Tensor t_result(30, 50, 7, 20, 10); + + t_left.setRandom(); + t_right.setRandom(); + // Put trash in t_result to verify contraction clears output memory. + t_result.setRandom(); + + // Add a little offset so that the results won't be close to zero. + t_left += t_left.constant(1.0f); + t_right += t_right.constant(1.0f); + + typedef Map> MapXf; + MapXf m_left(t_left.data(), 1500, 248); + MapXf m_right(t_right.data(), 248, 1400); + Eigen::Matrix m_result(1500, 1400); + + // this contraction should be equivalent to a single matrix multiplication + Eigen::array dims({{DimPair(2, 0), DimPair(3, 1)}}); + + // compute results by separate methods + t_result.device(device) = t_left.contract(t_right, dims, SqrtOutputKernel()); + + m_result = m_left * m_right; + + for (size_t i = 0; i < t_result.dimensions().TotalSize(); i++) { + VERIFY(&t_result.data()[i] != &m_result.data()[i]); + VERIFY_IS_APPROX(t_result.data()[i], std::sqrt(m_result.data()[i])); + } +} template void test_full_contraction() { @@ -355,6 +409,8 @@ void test_cxx11_tensor_thread_pool() CALL_SUBTEST_3(test_multithread_contraction_agrees_with_singlethread()); CALL_SUBTEST_3(test_multithread_contraction_agrees_with_singlethread()); + CALL_SUBTEST_3(test_multithread_contraction_with_output_kernel()); + 
CALL_SUBTEST_3(test_multithread_contraction_with_output_kernel()); // Exercise various cases that have been problematic in the past. CALL_SUBTEST_4(test_contraction_corner_cases());