diff --git a/Eigen/src/Core/products/GeneralMatrixVector.h b/Eigen/src/Core/products/GeneralMatrixVector.h index 0b4957f76..5f74e37de 100644 --- a/Eigen/src/Core/products/GeneralMatrixVector.h +++ b/Eigen/src/Core/products/GeneralMatrixVector.h @@ -512,16 +512,17 @@ general_matrix_vector_product struct gemv_small_cols_unroller { - template - EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void madd(Packet* acc, const LhsMapper& lhs, Index i, Index j, - const Packet& b0, ConjHelper& pcj) { - gemv_small_cols_unroller::template madd(acc, lhs, i, j, b0, pcj); - acc[K] = pcj.pmadd(lhs.template load(i + K, j), b0, acc[K]); + template + EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void madd(AccPacket* acc, const LhsMapper& lhs, Index i, Index j, + const RhsType& b0, ConjHelper& pcj) { + gemv_small_cols_unroller::template madd(acc, lhs, i, j, b0, pcj); + acc[K] = pcj.pmadd(lhs.template load(i + K, j), b0, acc[K]); } - template - EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void scalar_madd(Scalar* cc, const LhsMapper& lhs, Index i, Index j, - const Scalar& b0, ConjHelper& cj) { + template + EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void scalar_madd(ResScalar* cc, const LhsMapper& lhs, Index i, Index j, + const RhsScalar& b0, ConjHelper& cj) { gemv_small_cols_unroller::scalar_madd(cc, lhs, i, j, b0, cj); cc[K] += cj.pmul(lhs(i + K, j), b0); } @@ -548,15 +549,16 @@ struct gemv_small_cols_unroller { template struct gemv_small_cols_unroller<0, N> { - template - EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void madd(Packet* acc, const LhsMapper& lhs, Index i, Index j, - const Packet& b0, ConjHelper& pcj) { - acc[0] = pcj.pmadd(lhs.template load(i, j), b0, acc[0]); + template + EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void madd(AccPacket* acc, const LhsMapper& lhs, Index i, Index j, + const RhsType& b0, ConjHelper& pcj) { + acc[0] = pcj.pmadd(lhs.template load(i, j), b0, acc[0]); } - template - EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void scalar_madd(Scalar* cc, const LhsMapper& lhs, Index i, Index j, - const Scalar& b0, ConjHelper& cj) { + template + EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void scalar_madd(ResScalar* cc, const LhsMapper& lhs, Index i, Index j, + const RhsScalar& b0, ConjHelper& cj) { cc[0] += cj.pmul(lhs(i, j), b0); } @@ -608,7 +610,7 @@ general_matrix_vector_product(j, 0); - Unroll::template madd(h, lhs, i, j, b0, pcj_half); + Unroll::template madd(h, lhs, i, j, b0, pcj_half); } Unroll::predux_accum(cc, h); } @@ -617,7 +619,7 @@ general_matrix_vector_product(j, 0); - Unroll::template madd(q, lhs, i, j, b0, pcj_quarter); + Unroll::template madd(q, lhs, i, j, b0, pcj_quarter); } Unroll::predux_accum(cc, q); }