diff --git a/unsupported/benchmarks/Tensor/bench_contraction.cpp b/unsupported/benchmarks/Tensor/bench_contraction.cpp index 0b71ebfe3..faf648f6d 100644 --- a/unsupported/benchmarks/Tensor/bench_contraction.cpp +++ b/unsupported/benchmarks/Tensor/bench_contraction.cpp @@ -69,6 +69,9 @@ static void BM_Contraction_ThreadPool(benchmark::State& state) { } // --- Rank-3 batch contraction --- +// Contracts A(batch, M, K) with B(batch, K, N) over batch dim (0<->0) +// and K dim (2<->1), producing C(M, N). This sums over both the batch +// and inner dimensions: C(m, n) = sum_b sum_k A(b, m, k) * B(b, k, n). static void BM_BatchContraction(benchmark::State& state) { const int batch = state.range(0); const int M = state.range(1); @@ -77,12 +80,12 @@ static void BM_BatchContraction(benchmark::State& state) { Tensor A(batch, M, K); Tensor B(batch, K, N); - Tensor C(batch, M, N); + Tensor C(M, N); A.setRandom(); B.setRandom(); using ContractDims = Tensor::DimensionPair; - Eigen::array contract_dims = {ContractDims(2, 1)}; + Eigen::array contract_dims = {ContractDims(0, 0), ContractDims(2, 1)}; for (auto _ : state) { C = A.contract(B, contract_dims);