Eliminate boolean product warnings by factoring out a

`combine_scalar_factors` helper function.
This commit is contained in:
Christoph Hertzberg
2021-01-01 20:54:45 +01:00
committed by Rasmus Munk Larsen
parent 070d303d56
commit 12dda34b15
5 changed files with 51 additions and 14 deletions

View File

@@ -228,8 +228,7 @@ template<> struct gemv_dense_selector<OnTheRight,ColMajor,true>
ActualLhsType actualLhs = LhsBlasTraits::extract(lhs);
ActualRhsType actualRhs = RhsBlasTraits::extract(rhs);
ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(lhs)
* RhsBlasTraits::extractScalarFactor(rhs);
ResScalar actualAlpha = combine_scalar_factors(alpha, lhs, rhs);
// make sure Dest is a compile-time vector type (bug 1166)
typedef typename conditional<Dest::IsVectorAtCompileTime, Dest, typename Dest::ColXpr>::type ActualDest;
@@ -320,8 +319,7 @@ template<> struct gemv_dense_selector<OnTheRight,RowMajor,true>
typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(lhs);
typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(rhs);
ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(lhs)
* RhsBlasTraits::extractScalarFactor(rhs);
ResScalar actualAlpha = combine_scalar_factors(alpha, lhs, rhs);
enum {
// FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1

View File

@@ -441,8 +441,8 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,CoeffBasedProductMode>
};
// FIXME: in c++11 this should be auto, and extractScalarFactor should also return auto
// this is important for real*complex_mat
Scalar actualAlpha = blas_traits<Lhs>::extractScalarFactor(lhs)
* blas_traits<Rhs>::extractScalarFactor(rhs);
Scalar actualAlpha = combine_scalar_factors<Scalar>(lhs, rhs);
eval_dynamic_impl(dst,
blas_traits<Lhs>::extract(lhs).template conjugateIf<ConjLhs>(),
blas_traits<Rhs>::extract(rhs).template conjugateIf<ConjRhs>(),

View File

@@ -489,8 +489,7 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct>
typename internal::add_const_on_value_type<ActualLhsType>::type lhs = LhsBlasTraits::extract(a_lhs);
typename internal::add_const_on_value_type<ActualRhsType>::type rhs = RhsBlasTraits::extract(a_rhs);
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs)
* RhsBlasTraits::extractScalarFactor(a_rhs);
Scalar actualAlpha = combine_scalar_factors(alpha, a_lhs, a_rhs);
typedef internal::gemm_blocking_space<(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,LhsScalar,RhsScalar,
Dest::MaxRowsAtCompileTime,Dest::MaxColsAtCompileTime,MaxDepthAtCompileTime> BlockingType;

View File

@@ -618,6 +618,47 @@ template<typename T> const typename T::Scalar* extract_data(const T& m)
return extract_data_selector<T>::run(m);
}
/**
* \c combine_scalar_factors extracts and multiplies factors from GEMM and GEMV products.
* There is a specialization for booleans
*/
template<typename ResScalar, typename Lhs, typename Rhs>
struct combine_scalar_factors_impl
{
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static ResScalar run(const Lhs& lhs, const Rhs& rhs)
{
return blas_traits<Lhs>::extractScalarFactor(lhs) * blas_traits<Rhs>::extractScalarFactor(rhs);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static ResScalar run(const ResScalar& alpha, const Lhs& lhs, const Rhs& rhs)
{
return alpha * blas_traits<Lhs>::extractScalarFactor(lhs) * blas_traits<Rhs>::extractScalarFactor(rhs);
}
};
template<typename Lhs, typename Rhs>
struct combine_scalar_factors_impl<bool, Lhs, Rhs>
{
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static bool run(const Lhs& lhs, const Rhs& rhs)
{
return blas_traits<Lhs>::extractScalarFactor(lhs) && blas_traits<Rhs>::extractScalarFactor(rhs);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static bool run(const bool& alpha, const Lhs& lhs, const Rhs& rhs)
{
return alpha && blas_traits<Lhs>::extractScalarFactor(lhs) && blas_traits<Rhs>::extractScalarFactor(rhs);
}
};
template<typename ResScalar, typename Lhs, typename Rhs>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE ResScalar combine_scalar_factors(const ResScalar& alpha, const Lhs& lhs, const Rhs& rhs)
{
return combine_scalar_factors_impl<ResScalar,Lhs,Rhs>::run(alpha, lhs, rhs);
}
template<typename ResScalar, typename Lhs, typename Rhs>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE ResScalar combine_scalar_factors(const Lhs& lhs, const Rhs& rhs)
{
return combine_scalar_factors_impl<ResScalar,Lhs,Rhs>::run(lhs, rhs);
}
} // end namespace internal
} // end namespace Eigen