mirror of
https://gitlab.com/libeigen/eigen.git
synced 2026-04-10 11:34:33 +08:00
Fix mixed-type compilation error in row-major GEMV small-cols path
libeigen/eigen!2160 Co-authored-by: Rasmus Munk Larsen <rmlarsen@gmail.com>
This commit is contained in:
@@ -512,16 +512,17 @@ general_matrix_vector_product<Index, LhsScalar, LhsMapper, RowMajor, ConjugateLh
|
||||
// of compiler unrolling heuristics.
|
||||
template <int K, int N>
|
||||
struct gemv_small_cols_unroller {
|
||||
template <typename Packet, int Alignment, typename ConjHelper, typename LhsMapper, typename Index>
|
||||
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void madd(Packet* acc, const LhsMapper& lhs, Index i, Index j,
|
||||
const Packet& b0, ConjHelper& pcj) {
|
||||
gemv_small_cols_unroller<K - 1, N>::template madd<Packet, Alignment>(acc, lhs, i, j, b0, pcj);
|
||||
acc[K] = pcj.pmadd(lhs.template load<Packet, Alignment>(i + K, j), b0, acc[K]);
|
||||
template <typename LhsPacket, typename AccPacket, int Alignment, typename RhsType, typename ConjHelper,
|
||||
typename LhsMapper, typename Index>
|
||||
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void madd(AccPacket* acc, const LhsMapper& lhs, Index i, Index j,
|
||||
const RhsType& b0, ConjHelper& pcj) {
|
||||
gemv_small_cols_unroller<K - 1, N>::template madd<LhsPacket, AccPacket, Alignment>(acc, lhs, i, j, b0, pcj);
|
||||
acc[K] = pcj.pmadd(lhs.template load<LhsPacket, Alignment>(i + K, j), b0, acc[K]);
|
||||
}
|
||||
|
||||
template <typename Scalar, typename ConjHelper, typename LhsMapper, typename Index>
|
||||
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void scalar_madd(Scalar* cc, const LhsMapper& lhs, Index i, Index j,
|
||||
const Scalar& b0, ConjHelper& cj) {
|
||||
template <typename ResScalar, typename RhsScalar, typename ConjHelper, typename LhsMapper, typename Index>
|
||||
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void scalar_madd(ResScalar* cc, const LhsMapper& lhs, Index i, Index j,
|
||||
const RhsScalar& b0, ConjHelper& cj) {
|
||||
gemv_small_cols_unroller<K - 1, N>::scalar_madd(cc, lhs, i, j, b0, cj);
|
||||
cc[K] += cj.pmul(lhs(i + K, j), b0);
|
||||
}
|
||||
@@ -548,15 +549,16 @@ struct gemv_small_cols_unroller {
|
||||
|
||||
template <int N>
|
||||
struct gemv_small_cols_unroller<0, N> {
|
||||
template <typename Packet, int Alignment, typename ConjHelper, typename LhsMapper, typename Index>
|
||||
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void madd(Packet* acc, const LhsMapper& lhs, Index i, Index j,
|
||||
const Packet& b0, ConjHelper& pcj) {
|
||||
acc[0] = pcj.pmadd(lhs.template load<Packet, Alignment>(i, j), b0, acc[0]);
|
||||
template <typename LhsPacket, typename AccPacket, int Alignment, typename RhsType, typename ConjHelper,
|
||||
typename LhsMapper, typename Index>
|
||||
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void madd(AccPacket* acc, const LhsMapper& lhs, Index i, Index j,
|
||||
const RhsType& b0, ConjHelper& pcj) {
|
||||
acc[0] = pcj.pmadd(lhs.template load<LhsPacket, Alignment>(i, j), b0, acc[0]);
|
||||
}
|
||||
|
||||
template <typename Scalar, typename ConjHelper, typename LhsMapper, typename Index>
|
||||
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void scalar_madd(Scalar* cc, const LhsMapper& lhs, Index i, Index j,
|
||||
const Scalar& b0, ConjHelper& cj) {
|
||||
template <typename ResScalar, typename RhsScalar, typename ConjHelper, typename LhsMapper, typename Index>
|
||||
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void scalar_madd(ResScalar* cc, const LhsMapper& lhs, Index i, Index j,
|
||||
const RhsScalar& b0, ConjHelper& cj) {
|
||||
cc[0] += cj.pmul(lhs(i, j), b0);
|
||||
}
|
||||
|
||||
@@ -608,7 +610,7 @@ general_matrix_vector_product<Index, LhsScalar, LhsMapper, RowMajor, ConjugateLh
|
||||
Unroll::init_zero(h);
|
||||
for (Index j = 0; j < halfColBlockEnd; j += LhsPacketSizeHalf) {
|
||||
RhsPacketHalf b0 = rhs.template load<RhsPacketHalf, Unaligned>(j, 0);
|
||||
Unroll::template madd<ResPacketHalf, LhsAlignment>(h, lhs, i, j, b0, pcj_half);
|
||||
Unroll::template madd<LhsPacketHalf, ResPacketHalf, LhsAlignment>(h, lhs, i, j, b0, pcj_half);
|
||||
}
|
||||
Unroll::predux_accum(cc, h);
|
||||
}
|
||||
@@ -617,7 +619,7 @@ general_matrix_vector_product<Index, LhsScalar, LhsMapper, RowMajor, ConjugateLh
|
||||
Unroll::init_zero(q);
|
||||
for (Index j = halfColBlockEnd; j < quarterColBlockEnd; j += LhsPacketSizeQuarter) {
|
||||
RhsPacketQuarter b0 = rhs.template load<RhsPacketQuarter, Unaligned>(j, 0);
|
||||
Unroll::template madd<ResPacketQuarter, LhsAlignment>(q, lhs, i, j, b0, pcj_quarter);
|
||||
Unroll::template madd<LhsPacketQuarter, ResPacketQuarter, LhsAlignment>(q, lhs, i, j, b0, pcj_quarter);
|
||||
}
|
||||
Unroll::predux_accum(cc, q);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user