diff --git a/Eigen/src/Core/arch/NEON/Kernels.h b/Eigen/src/Core/arch/NEON/Kernels.h index b01476a6a..4411389e5 100644 --- a/Eigen/src/Core/arch/NEON/Kernels.h +++ b/Eigen/src/Core/arch/NEON/Kernels.h @@ -17,30 +17,29 @@ namespace internal { #ifdef __ENABLE_VECTOR_KERNELS__ #define MICRO_12x1x4() \ + /* Loads three LHS packets (lhsPackMap.pCur, +4, +8) and one RHS packet from rhsPackMap.pCur, expands RHS elements 0..3 with pset1, accumulates the products into acc._acc1/_acc2/_acc3, then advances the RHS cursor by 4 and the LHS cursor by 12. */ \ - pRhs = pload(rhsPackMap.pCur); \ - rhsPackMap.advance(1*4); \ - pRhs0 = pset1(pRhs[0]); \ - pRhs1 = pset1(pRhs[1]); \ - pRhs2 = pset1(pRhs[2]); \ - pRhs3 = pset1(pRhs[3]); \ pLhs = pload(lhsPackMap.pCur); \ - lhsPackMap.advance(4*1); \ + pLhs2 = pload(lhsPackMap.pCur + 4); \ + pLhs3 = pload(lhsPackMap.pCur + 8); \ + pRhs = pload(rhsPackMap.pCur); \ + pRhs0 = pset1(pRhs[0]); \ acc._acc1.packet[0] += pLhs*pRhs0; \ - acc._acc1.packet[1] += pLhs*pRhs1; \ - acc._acc1.packet[2] += pLhs*pRhs2; \ - acc._acc1.packet[3] += pLhs*pRhs3; \ - pLhs2 = pload(lhsPackMap.pCur); \ - lhsPackMap.advance(4*1); \ acc._acc2.packet[0] += pLhs2*pRhs0; \ - acc._acc2.packet[1] += pLhs2*pRhs1; \ - acc._acc2.packet[2] += pLhs2*pRhs2; \ - acc._acc2.packet[3] += pLhs2*pRhs3; \ - pLhs3 = pload(lhsPackMap.pCur); \ acc._acc3.packet[0] += pLhs3*pRhs0; \ + pRhs1 = pset1(pRhs[1]); \ + acc._acc1.packet[1] += pLhs*pRhs1; \ + acc._acc2.packet[1] += pLhs2*pRhs1; \ acc._acc3.packet[1] += pLhs3*pRhs1; \ + pRhs2 = pset1(pRhs[2]); \ + acc._acc1.packet[2] += pLhs*pRhs2; \ + acc._acc2.packet[2] += pLhs2*pRhs2; \ acc._acc3.packet[2] += pLhs3*pRhs2; \ + pRhs3 = pset1(pRhs[3]); \ + acc._acc1.packet[3] += pLhs*pRhs3; \ + acc._acc2.packet[3] += pLhs2*pRhs3; \ acc._acc3.packet[3] += pLhs3*pRhs3; \ - lhsPackMap.advance(4*1); + rhsPackMap.advance(4); \ + lhsPackMap.advance(12); #define MICRO_8x1x4() \ pLhs = pload(lhsPackMap.pCur); \ diff --git a/Eigen/src/Core/arch/NEON/MatrixProduct.h b/Eigen/src/Core/arch/NEON/MatrixProduct.h index f478ea15e..177b7b25d 100644 --- a/Eigen/src/Core/arch/NEON/MatrixProduct.h +++ b/Eigen/src/Core/arch/NEON/MatrixProduct.h @@ -215,6 +215,7 @@ struct PackMap PackMap(const Scalar *base, Index 
d2Size, Index stride, Index offset) : pBase(base), pCur(base), d2Size(d2Size), stride(stride), offset(offset) {} EIGEN_STRONG_INLINE void resetCur() { pCur = pBase; } + EIGEN_STRONG_INLINE void updateBase() { pBase = pCur; } EIGEN_STRONG_INLINE void moveTo(Index p1) { pCur = pBase + pmc.getPosition(p1, d2Size); } EIGEN_STRONG_INLINE void advance(int progress) { pCur += progress; } }; @@ -362,12 +363,12 @@ struct LhsLoopStruct constexpr auto lhsProgress = SHAPES[IDX][SHAPES_LHS_DIMENSION]; constexpr auto rhsProgress = SHAPES[IDX][SHAPES_RHS_DIMENSION]; DepthLoopStruct depthLS; + rhsPackMap.resetCur(); for(;rowIdx + lhsProgress <= rows; rowIdx+=lhsProgress) { - lhsPackMap.moveTo(rowIdx); - rhsPackMap.moveTo(colIdx); - //prefetch(lhsPackMap.pCur + 2*lhsProgress); - //prefetch(rhsPackMap.pCur + 2*rhsProgress); + //lhsPackMap.moveTo(rowIdx); + //rhsPackMap.moveTo(colIdx); + depthLS(rowIdx, colIdx, 0, res, rows, depth, cols, alpha, pAlpha, lhsPackMap, rhsPackMap); } lhsLS(rowIdx, colIdx, res, rows, depth, cols, alpha, pAlpha, lhsPackMap, rhsPackMap); @@ -395,7 +396,9 @@ struct RhsLoopStruct for(;colIdx + rhsProgress <= cols; colIdx+=rhsProgress) { LhsLoopStruct lhsLS; + lhsPackMap.resetCur(); lhsLS(0, colIdx, res, rows, depth, cols, alpha, pAlpha, lhsPackMap, rhsPackMap); + rhsPackMap.updateBase(); } rhsLS(colIdx, res, rows, depth, cols, alpha, pAlpha, lhsPackMap, rhsPackMap); } diff --git a/new_gemm_test.cpp b/new_gemm_test.cpp index dabb33d02..1524b36d9 100644 --- a/new_gemm_test.cpp +++ b/new_gemm_test.cpp @@ -15,7 +15,7 @@ void set(MatrixXf& A, int m, int n, int id, int digits) int main(int argc, char* argv[]) { #ifdef __DEBUG__ - int m = 32, k = 32, n = 32, max = std::max(std::max(m,k),n); + int m = 9, k = 9, n = 9, max = std::max(std::max(m,k),n); MatrixXf A = MatrixXf::Zero(m, k); MatrixXf B = MatrixXf::Zero(k, n); MatrixXf C = MatrixXf::Zero(m, n); @@ -28,6 +28,7 @@ int main(int argc, char* argv[]) std::cout << A << std::endl; std::cout << B << std::endl; + 
std::cout << C << std::endl; std::cout << std::endl; @@ -47,6 +48,8 @@ int main(int argc, char* argv[]) } } } + + std::cout << D << std::endl; #else if(argc < 3) {