Consolidate BF16/F16 wrapper macros and simplify arch math functions

libeigen/eigen!2195

Co-authored-by: Rasmus Munk Larsen <rmlarsen@gmail.com>
This commit is contained in:
Rasmus Munk Larsen
2026-02-22 20:17:43 -08:00
parent d5e67adbe7
commit 112c2324bd
14 changed files with 149 additions and 428 deletions

View File

@@ -116,34 +116,10 @@ EIGEN_STRONG_INLINE Packet8bf pldexp(const Packet8bf& a, const Packet8bf& expone
return F32ToBf16(pldexp<Packet8f>(Bf16ToF32(a), Bf16ToF32(exponent)));
}
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pcos)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pexp)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pexp2)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pexpm1)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, plog)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, plog1p)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, plog2)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, preciprocal)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, prsqrt)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pcbrt)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psin)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psqrt)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, ptanh)
EIGEN_INSTANTIATE_GENERIC_MATH_FUNCS_BF16(Packet8f, Packet8bf)
#ifndef EIGEN_VECTORIZE_AVX512FP16
F16_PACKET_FUNCTION(Packet8f, Packet8h, pcos)
F16_PACKET_FUNCTION(Packet8f, Packet8h, pexp)
F16_PACKET_FUNCTION(Packet8f, Packet8h, pexp2)
F16_PACKET_FUNCTION(Packet8f, Packet8h, pexpm1)
F16_PACKET_FUNCTION(Packet8f, Packet8h, plog)
F16_PACKET_FUNCTION(Packet8f, Packet8h, plog1p)
F16_PACKET_FUNCTION(Packet8f, Packet8h, plog2)
F16_PACKET_FUNCTION(Packet8f, Packet8h, preciprocal)
F16_PACKET_FUNCTION(Packet8f, Packet8h, prsqrt)
F16_PACKET_FUNCTION(Packet8f, Packet8h, psin)
F16_PACKET_FUNCTION(Packet8f, Packet8h, psqrt)
F16_PACKET_FUNCTION(Packet8f, Packet8h, pcbrt)
F16_PACKET_FUNCTION(Packet8f, Packet8h, ptanh)
EIGEN_INSTANTIATE_GENERIC_MATH_FUNCS_F16(Packet8f, Packet8h)
#endif
} // end namespace internal

View File

@@ -106,32 +106,10 @@ EIGEN_STRONG_INLINE Packet16f preciprocal<Packet16f>(const Packet16f& a) {
}
#endif
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pcos)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexp)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexp2)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexpm1)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog1p)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog2)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, preciprocal)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, prsqrt)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, psin)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, psqrt)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, ptanh)
EIGEN_INSTANTIATE_GENERIC_MATH_FUNCS_BF16(Packet16f, Packet16bf)
#ifndef EIGEN_VECTORIZE_AVX512FP16
F16_PACKET_FUNCTION(Packet16f, Packet16h, pcos)
F16_PACKET_FUNCTION(Packet16f, Packet16h, pexp)
F16_PACKET_FUNCTION(Packet16f, Packet16h, pexp2)
F16_PACKET_FUNCTION(Packet16f, Packet16h, pexpm1)
F16_PACKET_FUNCTION(Packet16f, Packet16h, plog)
F16_PACKET_FUNCTION(Packet16f, Packet16h, plog1p)
F16_PACKET_FUNCTION(Packet16f, Packet16h, plog2)
F16_PACKET_FUNCTION(Packet16f, Packet16h, preciprocal)
F16_PACKET_FUNCTION(Packet16f, Packet16h, prsqrt)
F16_PACKET_FUNCTION(Packet16f, Packet16h, psin)
F16_PACKET_FUNCTION(Packet16f, Packet16h, psqrt)
F16_PACKET_FUNCTION(Packet16f, Packet16h, ptanh)
EIGEN_INSTANTIATE_GENERIC_MATH_FUNCS_F16(Packet16f, Packet16h)
#endif // EIGEN_VECTORIZE_AVX512FP16
} // end namespace internal

View File

@@ -38,6 +38,40 @@ limitations under the License.
return F32ToBf16(METHOD<PACKET_F>(Bf16ToF32(_x))); \
}
// Bulk-instantiates the BF16 wrappers for the core math functions: expands to
// one BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, <func>) invocation for each of
// pcos, psin, pexp, pexp2, pexpm1, plog, plog1p, plog2, preciprocal, prsqrt,
// pcbrt, psqrt and ptanh.
//   PACKET_F    - packet type forwarded as the first BF16_PACKET_FUNCTION argument.
//   PACKET_BF16 - packet type forwarded as the second BF16_PACKET_FUNCTION argument.
#define EIGEN_INSTANTIATE_GENERIC_MATH_FUNCS_BF16(PACKET_F, PACKET_BF16) \
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, pcos) \
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, psin) \
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, pexp) \
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, pexp2) \
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, pexpm1) \
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, plog) \
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, plog1p) \
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, plog2) \
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, preciprocal) \
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, prsqrt) \
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, pcbrt) \
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, psqrt) \
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, ptanh)
// BF16 wrappers for unsupported/SpecialFunctions.
// Expands to one BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, <func>) invocation
// for each of perf and pndtri.
#define EIGEN_INSTANTIATE_SPECIAL_FUNCS_BF16(PACKET_F, PACKET_BF16) \
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, perf) \
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, pndtri)
// BF16 wrappers for the packet Bessel functions.
// Expands to one BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, <func>) invocation
// for each of the twelve pbessel_* functions listed below
// (i0, i0e, i1, i1e, j0, j1, k0, k0e, k1, k1e, y0, y1).
#define EIGEN_INSTANTIATE_BESSEL_FUNCS_BF16(PACKET_F, PACKET_BF16) \
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, pbessel_i0) \
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, pbessel_i0e) \
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, pbessel_i1) \
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, pbessel_i1e) \
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, pbessel_j0) \
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, pbessel_j1) \
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, pbessel_k0) \
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, pbessel_k0e) \
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, pbessel_k1) \
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, pbessel_k1e) \
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, pbessel_y0) \
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, pbessel_y1)
// Only use HIP GPU bf16 in kernels
#if defined(EIGEN_HAS_HIP_BF16) && defined(EIGEN_GPU_COMPILE_PHASE)
#define EIGEN_USE_HIP_BF16

View File

@@ -57,6 +57,40 @@
return float2half(METHOD<PACKET_F>(half2float(_x))); \
}
// Bulk-instantiates the F16 wrappers for the core math functions: expands to
// one F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, <func>) invocation for each of
// pcos, psin, pexp, pexp2, pexpm1, plog, plog1p, plog2, preciprocal, prsqrt,
// pcbrt, psqrt and ptanh.
//   PACKET_F   - packet type forwarded as the first F16_PACKET_FUNCTION argument.
//   PACKET_F16 - packet type forwarded as the second F16_PACKET_FUNCTION argument.
#define EIGEN_INSTANTIATE_GENERIC_MATH_FUNCS_F16(PACKET_F, PACKET_F16) \
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, pcos) \
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, psin) \
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, pexp) \
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, pexp2) \
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, pexpm1) \
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, plog) \
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, plog1p) \
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, plog2) \
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, preciprocal) \
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, prsqrt) \
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, pcbrt) \
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, psqrt) \
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, ptanh)
// F16 wrappers for unsupported/SpecialFunctions.
// Expands to one F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, <func>) invocation
// for each of perf and pndtri.
#define EIGEN_INSTANTIATE_SPECIAL_FUNCS_F16(PACKET_F, PACKET_F16) \
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, perf) \
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, pndtri)
// F16 wrappers for the packet Bessel functions.
// Expands to one F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, <func>) invocation
// for each of the twelve pbessel_* functions listed below
// (i0, i0e, i1, i1e, j0, j1, k0, k0e, k1, k1e, y0, y1).
#define EIGEN_INSTANTIATE_BESSEL_FUNCS_F16(PACKET_F, PACKET_F16) \
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, pbessel_i0) \
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, pbessel_i0e) \
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, pbessel_i1) \
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, pbessel_i1e) \
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, pbessel_j0) \
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, pbessel_j1) \
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, pbessel_k0) \
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, pbessel_k0e) \
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, pbessel_k1) \
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, pbessel_k1e) \
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, pbessel_y0) \
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, pbessel_y1)
namespace Eigen {
struct half;

View File

@@ -33,12 +33,7 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Packet8hf ptanh<Packet8hf>(const Packet8hf
}
#endif // EIGEN_HAS_ARM64_FP16_VECTOR_ARITHMETIC
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, psin)
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pcos)
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, plog)
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pexp)
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pexp2)
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, ptanh)
EIGEN_INSTANTIATE_GENERIC_MATH_FUNCS_BF16(Packet4f, Packet4bf)
template <>
EIGEN_STRONG_INLINE Packet4bf pfrexp(const Packet4bf& a, Packet4bf& exponent) {

View File

@@ -16,36 +16,7 @@
namespace Eigen {
namespace internal {
// pexp specialization for PacketXf: forwards the packet to pexp_float.
template <>
EIGEN_STRONG_INLINE PacketXf pexp<PacketXf>(const PacketXf& x) {
return pexp_float(x);
}
// plog specialization for PacketXf: forwards the packet to plog_float.
template <>
EIGEN_STRONG_INLINE PacketXf plog<PacketXf>(const PacketXf& x) {
return plog_float(x);
}
// psin specialization for PacketXf: forwards the packet to psin_float.
template <>
EIGEN_STRONG_INLINE PacketXf psin<PacketXf>(const PacketXf& x) {
return psin_float(x);
}
// pcos specialization for PacketXf: forwards the packet to pcos_float.
template <>
EIGEN_STRONG_INLINE PacketXf pcos<PacketXf>(const PacketXf& x) {
return pcos_float(x);
}
// ptan specialization for PacketXf: forwards the packet to ptan_float.
template <>
EIGEN_STRONG_INLINE PacketXf ptan<PacketXf>(const PacketXf& x) {
return ptan_float(x);
}
// Hyperbolic tangent function.
// ptanh specialization for PacketXf: forwards the packet to ptanh_float.
template <>
EIGEN_STRONG_INLINE PacketXf ptanh<PacketXf>(const PacketXf& x) {
return ptanh_float(x);
}
EIGEN_INSTANTIATE_GENERIC_MATH_FUNCS_FLOAT(PacketXf)
} // end namespace internal
} // end namespace Eigen

View File

@@ -354,10 +354,17 @@ struct packet_traits<float> : default_packet_traits {
HasSin = EIGEN_FAST_MATH,
HasCos = EIGEN_FAST_MATH,
HasTan = EIGEN_FAST_MATH,
HasACos = 1,
HasASin = 1,
HasATan = 1,
HasATanh = 1,
HasLog = 1,
HasLog1p = 1,
HasExpm1 = 1,
HasExp = 1,
HasPow = 1,
HasSqrt = 1,
HasCbrt = 1,
HasTanh = EIGEN_FAST_MATH,
HasErf = EIGEN_FAST_MATH,
HasErfc = EIGEN_FAST_MATH

View File

@@ -31,259 +31,69 @@ namespace internal {
// introduce conflicts between these packet_traits definitions and the ones
// we'll use on the host side (SSE, AVX, ...)
#if defined(SYCL_DEVICE_ONLY)
#define SYCL_PLOG(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type plog<packet_type>(const packet_type& a) { \
return cl::sycl::log(a); \
// Generic macro for unary SYCL math functions.
// Defines the EIGEN_FUNC<PACKET> specialization, which takes one packet by
// const reference and returns cl::sycl::SYCL_FUNC applied to it.
//   EIGEN_FUNC - name of the function template being specialized.
//   SYCL_FUNC  - name of the cl::sycl function the specialization calls.
//   PACKET     - packet type used as the argument and return type.
#define SYCL_PACKET_FUNCTION(EIGEN_FUNC, SYCL_FUNC, PACKET) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PACKET EIGEN_FUNC<PACKET>(const PACKET& a) { \
return cl::sycl::SYCL_FUNC(a); \
}
SYCL_PLOG(cl::sycl::cl_half8)
SYCL_PLOG(cl::sycl::cl_float4)
SYCL_PLOG(cl::sycl::cl_double2)
#undef SYCL_PLOG
// Instantiate a unary function for the standard set of SYCL vector types.
// Expands to three SYCL_PACKET_FUNCTION invocations, one each for
// cl::sycl::cl_half8, cl::sycl::cl_float4 and cl::sycl::cl_double2.
#define SYCL_UNARY_FUNCTION(EIGEN_FUNC, SYCL_FUNC) \
SYCL_PACKET_FUNCTION(EIGEN_FUNC, SYCL_FUNC, cl::sycl::cl_half8) \
SYCL_PACKET_FUNCTION(EIGEN_FUNC, SYCL_FUNC, cl::sycl::cl_float4) \
SYCL_PACKET_FUNCTION(EIGEN_FUNC, SYCL_FUNC, cl::sycl::cl_double2)
#define SYCL_PLOG1P(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type plog1p<packet_type>(const packet_type& a) { \
return cl::sycl::log1p(a); \
SYCL_UNARY_FUNCTION(plog, log)
SYCL_UNARY_FUNCTION(plog1p, log1p)
SYCL_UNARY_FUNCTION(plog10, log10)
SYCL_UNARY_FUNCTION(pexpm1, expm1)
SYCL_UNARY_FUNCTION(psqrt, sqrt)
SYCL_UNARY_FUNCTION(prsqrt, rsqrt)
SYCL_UNARY_FUNCTION(psin, sin)
SYCL_UNARY_FUNCTION(pcos, cos)
SYCL_UNARY_FUNCTION(ptan, tan)
SYCL_UNARY_FUNCTION(pasin, asin)
SYCL_UNARY_FUNCTION(pacos, acos)
SYCL_UNARY_FUNCTION(patan, atan)
SYCL_UNARY_FUNCTION(psinh, sinh)
SYCL_UNARY_FUNCTION(pcosh, cosh)
SYCL_UNARY_FUNCTION(ptanh, tanh)
SYCL_UNARY_FUNCTION(pround, round)
SYCL_UNARY_FUNCTION(print, rint)
SYCL_UNARY_FUNCTION(pfloor, floor)
// pexp has additional scalar type instantiations.
SYCL_UNARY_FUNCTION(pexp, exp)
SYCL_PACKET_FUNCTION(pexp, exp, cl::sycl::cl_half)
SYCL_PACKET_FUNCTION(pexp, exp, cl::sycl::cl_float)
// pceil uses cl_half (scalar) instead of cl_half8 (vector) — preserving original behavior.
SYCL_PACKET_FUNCTION(pceil, ceil, cl::sycl::cl_half)
SYCL_PACKET_FUNCTION(pceil, ceil, cl::sycl::cl_float4)
SYCL_PACKET_FUNCTION(pceil, ceil, cl::sycl::cl_double2)
#undef SYCL_UNARY_FUNCTION
#undef SYCL_PACKET_FUNCTION
// Binary min/max functions.
// Defines the EIGEN_FUNC<PACKET> specialization, which takes two packets by
// const reference and returns cl::sycl::SYCL_FUNC(a, b).
//   EIGEN_FUNC - name of the function template being specialized.
//   SYCL_FUNC  - name of the cl::sycl function the specialization calls.
//   PACKET     - packet type used for both arguments and the return value.
#define SYCL_BINARY_FUNCTION(EIGEN_FUNC, SYCL_FUNC, PACKET) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PACKET EIGEN_FUNC<PACKET>(const PACKET& a, const PACKET& b) { \
return cl::sycl::SYCL_FUNC(a, b); \
}
SYCL_PLOG1P(cl::sycl::cl_half8)
SYCL_PLOG1P(cl::sycl::cl_float4)
SYCL_PLOG1P(cl::sycl::cl_double2)
#undef SYCL_PLOG1P
SYCL_BINARY_FUNCTION(pmin, fmin, cl::sycl::cl_half8)
SYCL_BINARY_FUNCTION(pmin, fmin, cl::sycl::cl_float4)
SYCL_BINARY_FUNCTION(pmin, fmin, cl::sycl::cl_double2)
SYCL_BINARY_FUNCTION(pmax, fmax, cl::sycl::cl_half8)
SYCL_BINARY_FUNCTION(pmax, fmax, cl::sycl::cl_float4)
SYCL_BINARY_FUNCTION(pmax, fmax, cl::sycl::cl_double2)
/** \internal \returns cl::sycl::log10 applied to \a a */
// SYCL_PLOG10(packet_type) defines the plog10<packet_type> specialization; it is
// instantiated for cl_half8, cl_float4 and cl_double2 and then undefined.
#define SYCL_PLOG10(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type plog10<packet_type>(const packet_type& a) { \
return cl::sycl::log10(a); \
}
SYCL_PLOG10(cl::sycl::cl_half8)
SYCL_PLOG10(cl::sycl::cl_float4)
SYCL_PLOG10(cl::sycl::cl_double2)
#undef SYCL_PLOG10
/** \internal \returns cl::sycl::exp applied to \a a */
// SYCL_PEXP(packet_type) defines the pexp<packet_type> specialization. Unlike the
// neighbouring helpers it is instantiated for five types: cl_half8, cl_half,
// cl_float4, cl_float and cl_double2. The macro is undefined afterwards.
#define SYCL_PEXP(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pexp<packet_type>(const packet_type& a) { \
return cl::sycl::exp(a); \
}
SYCL_PEXP(cl::sycl::cl_half8)
SYCL_PEXP(cl::sycl::cl_half)
SYCL_PEXP(cl::sycl::cl_float4)
SYCL_PEXP(cl::sycl::cl_float)
SYCL_PEXP(cl::sycl::cl_double2)
#undef SYCL_PEXP
/** \internal \returns cl::sycl::expm1 applied to \a a */
// SYCL_PEXPM1(packet_type) defines the pexpm1<packet_type> specialization; it is
// instantiated for cl_half8, cl_float4 and cl_double2 and then undefined.
#define SYCL_PEXPM1(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pexpm1<packet_type>(const packet_type& a) { \
return cl::sycl::expm1(a); \
}
SYCL_PEXPM1(cl::sycl::cl_half8)
SYCL_PEXPM1(cl::sycl::cl_float4)
SYCL_PEXPM1(cl::sycl::cl_double2)
#undef SYCL_PEXPM1
/** \internal \returns cl::sycl::sqrt applied to \a a */
// SYCL_PSQRT(packet_type) defines the psqrt<packet_type> specialization; it is
// instantiated for cl_half8, cl_float4 and cl_double2 and then undefined.
#define SYCL_PSQRT(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type psqrt<packet_type>(const packet_type& a) { \
return cl::sycl::sqrt(a); \
}
SYCL_PSQRT(cl::sycl::cl_half8)
SYCL_PSQRT(cl::sycl::cl_float4)
SYCL_PSQRT(cl::sycl::cl_double2)
#undef SYCL_PSQRT
/** \internal \returns cl::sycl::rsqrt applied to \a a */
// SYCL_PRSQRT(packet_type) defines the prsqrt<packet_type> specialization; it is
// instantiated for cl_half8, cl_float4 and cl_double2 and then undefined.
#define SYCL_PRSQRT(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type prsqrt<packet_type>(const packet_type& a) { \
return cl::sycl::rsqrt(a); \
}
SYCL_PRSQRT(cl::sycl::cl_half8)
SYCL_PRSQRT(cl::sycl::cl_float4)
SYCL_PRSQRT(cl::sycl::cl_double2)
#undef SYCL_PRSQRT
/** \internal \returns the sine of \a a (coeff-wise) */
// SYCL_PSIN(packet_type) defines the psin<packet_type> specialization, which
// calls cl::sycl::sin; it is instantiated for cl_half8, cl_float4 and
// cl_double2 and then undefined.
#define SYCL_PSIN(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type psin<packet_type>(const packet_type& a) { \
return cl::sycl::sin(a); \
}
SYCL_PSIN(cl::sycl::cl_half8)
SYCL_PSIN(cl::sycl::cl_float4)
SYCL_PSIN(cl::sycl::cl_double2)
#undef SYCL_PSIN
/** \internal \returns the cosine of \a a (coeff-wise) */
// SYCL_PCOS(packet_type) defines the pcos<packet_type> specialization, which
// calls cl::sycl::cos; it is instantiated for cl_half8, cl_float4 and
// cl_double2 and then undefined.
#define SYCL_PCOS(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pcos<packet_type>(const packet_type& a) { \
return cl::sycl::cos(a); \
}
SYCL_PCOS(cl::sycl::cl_half8)
SYCL_PCOS(cl::sycl::cl_float4)
SYCL_PCOS(cl::sycl::cl_double2)
#undef SYCL_PCOS
/** \internal \returns the tangent of \a a (coeff-wise) */
// SYCL_PTAN(packet_type) defines the ptan<packet_type> specialization, which
// calls cl::sycl::tan; it is instantiated for cl_half8, cl_float4 and
// cl_double2 and then undefined.
#define SYCL_PTAN(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type ptan<packet_type>(const packet_type& a) { \
return cl::sycl::tan(a); \
}
SYCL_PTAN(cl::sycl::cl_half8)
SYCL_PTAN(cl::sycl::cl_float4)
SYCL_PTAN(cl::sycl::cl_double2)
#undef SYCL_PTAN
/** \internal \returns the arc sine of \a a (coeff-wise) */
// SYCL_PASIN(packet_type) defines the pasin<packet_type> specialization, which
// calls cl::sycl::asin; it is instantiated for cl_half8, cl_float4 and
// cl_double2 and then undefined.
#define SYCL_PASIN(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pasin<packet_type>(const packet_type& a) { \
return cl::sycl::asin(a); \
}
SYCL_PASIN(cl::sycl::cl_half8)
SYCL_PASIN(cl::sycl::cl_float4)
SYCL_PASIN(cl::sycl::cl_double2)
#undef SYCL_PASIN
/** \internal \returns the arc cosine of \a a (coeff-wise) */
// SYCL_PACOS(packet_type) defines the pacos<packet_type> specialization, which
// calls cl::sycl::acos; it is instantiated for cl_half8, cl_float4 and
// cl_double2 and then undefined.
#define SYCL_PACOS(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pacos<packet_type>(const packet_type& a) { \
return cl::sycl::acos(a); \
}
SYCL_PACOS(cl::sycl::cl_half8)
SYCL_PACOS(cl::sycl::cl_float4)
SYCL_PACOS(cl::sycl::cl_double2)
#undef SYCL_PACOS
/** \internal \returns the arc tangent of \a a (coeff-wise) */
// SYCL_PATAN(packet_type) defines the patan<packet_type> specialization, which
// calls cl::sycl::atan; it is instantiated for cl_half8, cl_float4 and
// cl_double2 and then undefined.
#define SYCL_PATAN(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type patan<packet_type>(const packet_type& a) { \
return cl::sycl::atan(a); \
}
SYCL_PATAN(cl::sycl::cl_half8)
SYCL_PATAN(cl::sycl::cl_float4)
SYCL_PATAN(cl::sycl::cl_double2)
#undef SYCL_PATAN
/** \internal \returns the hyperbolic sine of \a a (coeff-wise) */
// SYCL_PSINH(packet_type) defines the psinh<packet_type> specialization, which
// calls cl::sycl::sinh; it is instantiated for cl_half8, cl_float4 and
// cl_double2 and then undefined.
#define SYCL_PSINH(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type psinh<packet_type>(const packet_type& a) { \
return cl::sycl::sinh(a); \
}
SYCL_PSINH(cl::sycl::cl_half8)
SYCL_PSINH(cl::sycl::cl_float4)
SYCL_PSINH(cl::sycl::cl_double2)
#undef SYCL_PSINH
/** \internal \returns the hyperbolic cosine of \a a (coeff-wise) */
// SYCL_PCOSH(packet_type) defines the pcosh<packet_type> specialization, which
// calls cl::sycl::cosh; it is instantiated for cl_half8, cl_float4 and
// cl_double2 and then undefined.
#define SYCL_PCOSH(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pcosh<packet_type>(const packet_type& a) { \
return cl::sycl::cosh(a); \
}
SYCL_PCOSH(cl::sycl::cl_half8)
SYCL_PCOSH(cl::sycl::cl_float4)
SYCL_PCOSH(cl::sycl::cl_double2)
#undef SYCL_PCOSH
/** \internal \returns the hyperbolic tangent of \a a (coeff-wise) */
// SYCL_PTANH(packet_type) defines the ptanh<packet_type> specialization, which
// calls cl::sycl::tanh; it is instantiated for cl_half8, cl_float4 and
// cl_double2 and then undefined.
#define SYCL_PTANH(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type ptanh<packet_type>(const packet_type& a) { \
return cl::sycl::tanh(a); \
}
SYCL_PTANH(cl::sycl::cl_half8)
SYCL_PTANH(cl::sycl::cl_float4)
SYCL_PTANH(cl::sycl::cl_double2)
#undef SYCL_PTANH
/** \internal \returns cl::sycl::ceil applied to \a a */
// SYCL_PCEIL(packet_type) defines the pceil<packet_type> specialization; it is
// instantiated for cl_half, cl_float4 and cl_double2 and then undefined.
// NOTE(review): the first instantiation uses cl::sycl::cl_half, whereas the
// neighbouring helper macros instantiate cl::sycl::cl_half8 -- confirm whether
// this difference is intentional.
#define SYCL_PCEIL(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pceil<packet_type>(const packet_type& a) { \
return cl::sycl::ceil(a); \
}
SYCL_PCEIL(cl::sycl::cl_half)
SYCL_PCEIL(cl::sycl::cl_float4)
SYCL_PCEIL(cl::sycl::cl_double2)
#undef SYCL_PCEIL
/** \internal \returns cl::sycl::round applied to \a a */
// SYCL_PROUND(packet_type) defines the pround<packet_type> specialization; it is
// instantiated for cl_half8, cl_float4 and cl_double2 and then undefined.
#define SYCL_PROUND(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pround<packet_type>(const packet_type& a) { \
return cl::sycl::round(a); \
}
SYCL_PROUND(cl::sycl::cl_half8)
SYCL_PROUND(cl::sycl::cl_float4)
SYCL_PROUND(cl::sycl::cl_double2)
#undef SYCL_PROUND
/** \internal \returns cl::sycl::rint applied to \a a */
// SYCL_PRINT(packet_type) defines the print<packet_type> specialization (the
// packet "rint" operation, not output printing); it is instantiated for
// cl_half8, cl_float4 and cl_double2 and then undefined.
#define SYCL_PRINT(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type print<packet_type>(const packet_type& a) { \
return cl::sycl::rint(a); \
}
SYCL_PRINT(cl::sycl::cl_half8)
SYCL_PRINT(cl::sycl::cl_float4)
SYCL_PRINT(cl::sycl::cl_double2)
#undef SYCL_PRINT
/** \internal \returns cl::sycl::floor applied to \a a */
// SYCL_FLOOR(packet_type) defines the pfloor<packet_type> specialization (note
// the helper macro name has no leading P, unlike its neighbours); it is
// instantiated for cl_half8, cl_float4 and cl_double2 and then undefined.
#define SYCL_FLOOR(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pfloor<packet_type>(const packet_type& a) { \
return cl::sycl::floor(a); \
}
SYCL_FLOOR(cl::sycl::cl_half8)
SYCL_FLOOR(cl::sycl::cl_float4)
SYCL_FLOOR(cl::sycl::cl_double2)
#undef SYCL_FLOOR
/** \internal \returns the value of \c expr evaluated on packets \a a and \a b */
// SYCL_PMIN(packet_type, expr) defines the pmin<packet_type> specialization whose
// body returns expr; expr may refer to the parameters a and b. Every
// instantiation below passes cl::sycl::fmin(a, b). The macro is undefined
// afterwards.
#define SYCL_PMIN(packet_type, expr) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pmin<packet_type>(const packet_type& a, const packet_type& b) { \
return expr; \
}
SYCL_PMIN(cl::sycl::cl_half8, cl::sycl::fmin(a, b))
SYCL_PMIN(cl::sycl::cl_float4, cl::sycl::fmin(a, b))
SYCL_PMIN(cl::sycl::cl_double2, cl::sycl::fmin(a, b))
#undef SYCL_PMIN
/** \internal \returns the value of \c expr evaluated on packets \a a and \a b */
// SYCL_PMAX(packet_type, expr) defines the pmax<packet_type> specialization whose
// body returns expr; expr may refer to the parameters a and b. Every
// instantiation below passes cl::sycl::fmax(a, b). The macro is undefined
// afterwards.
#define SYCL_PMAX(packet_type, expr) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pmax<packet_type>(const packet_type& a, const packet_type& b) { \
return expr; \
}
SYCL_PMAX(cl::sycl::cl_half8, cl::sycl::fmax(a, b))
SYCL_PMAX(cl::sycl::cl_float4, cl::sycl::fmax(a, b))
SYCL_PMAX(cl::sycl::cl_double2, cl::sycl::fmax(a, b))
#undef SYCL_PMAX
#undef SYCL_BINARY_FUNCTION
// pldexp requires integer conversion of the exponent.
#define SYCL_PLDEXP(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pldexp(const packet_type& a, const packet_type& exponent) { \

View File

@@ -4,41 +4,8 @@
namespace Eigen {
namespace internal {
F16_PACKET_FUNCTION(Packet8f, Packet8h, pbessel_i0)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pbessel_i0)
F16_PACKET_FUNCTION(Packet8f, Packet8h, pbessel_i0e)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pbessel_i0e)
F16_PACKET_FUNCTION(Packet8f, Packet8h, pbessel_i1)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pbessel_i1)
F16_PACKET_FUNCTION(Packet8f, Packet8h, pbessel_i1e)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pbessel_i1e)
F16_PACKET_FUNCTION(Packet8f, Packet8h, pbessel_j0)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pbessel_j0)
F16_PACKET_FUNCTION(Packet8f, Packet8h, pbessel_j1)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pbessel_j1)
F16_PACKET_FUNCTION(Packet8f, Packet8h, pbessel_k0)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pbessel_k0)
F16_PACKET_FUNCTION(Packet8f, Packet8h, pbessel_k0e)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pbessel_k0e)
F16_PACKET_FUNCTION(Packet8f, Packet8h, pbessel_k1)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pbessel_k1)
F16_PACKET_FUNCTION(Packet8f, Packet8h, pbessel_k1e)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pbessel_k1e)
F16_PACKET_FUNCTION(Packet8f, Packet8h, pbessel_y0)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pbessel_y0)
F16_PACKET_FUNCTION(Packet8f, Packet8h, pbessel_y1)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pbessel_y1)
EIGEN_INSTANTIATE_BESSEL_FUNCS_F16(Packet8f, Packet8h)
EIGEN_INSTANTIATE_BESSEL_FUNCS_BF16(Packet8f, Packet8bf)
} // namespace internal
} // namespace Eigen

View File

@@ -4,13 +4,10 @@
namespace Eigen {
namespace internal {
F16_PACKET_FUNCTION(Packet8f, Packet8h, perf)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, perf)
F16_PACKET_FUNCTION(Packet8f, Packet8h, pndtri)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pndtri)
EIGEN_INSTANTIATE_SPECIAL_FUNCS_F16(Packet8f, Packet8h)
EIGEN_INSTANTIATE_SPECIAL_FUNCS_BF16(Packet8f, Packet8bf)
} // namespace internal
} // namespace Eigen
#endif // EIGEN_AVX_SPECIAL_FUNCTIONS_H
#endif // EIGEN_AVX_SPECIALFUNCTIONS_H

View File

@@ -4,41 +4,8 @@
namespace Eigen {
namespace internal {
F16_PACKET_FUNCTION(Packet16f, Packet16h, pbessel_i0)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pbessel_i0)
F16_PACKET_FUNCTION(Packet16f, Packet16h, pbessel_i0e)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pbessel_i0e)
F16_PACKET_FUNCTION(Packet16f, Packet16h, pbessel_i1)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pbessel_i1)
F16_PACKET_FUNCTION(Packet16f, Packet16h, pbessel_i1e)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pbessel_i1e)
F16_PACKET_FUNCTION(Packet16f, Packet16h, pbessel_j0)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pbessel_j0)
F16_PACKET_FUNCTION(Packet16f, Packet16h, pbessel_j1)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pbessel_j1)
F16_PACKET_FUNCTION(Packet16f, Packet16h, pbessel_k0)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pbessel_k0)
F16_PACKET_FUNCTION(Packet16f, Packet16h, pbessel_k0e)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pbessel_k0e)
F16_PACKET_FUNCTION(Packet16f, Packet16h, pbessel_k1)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pbessel_k1)
F16_PACKET_FUNCTION(Packet16f, Packet16h, pbessel_k1e)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pbessel_k1e)
F16_PACKET_FUNCTION(Packet16f, Packet16h, pbessel_y0)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pbessel_y0)
F16_PACKET_FUNCTION(Packet16f, Packet16h, pbessel_y1)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pbessel_y1)
EIGEN_INSTANTIATE_BESSEL_FUNCS_F16(Packet16f, Packet16h)
EIGEN_INSTANTIATE_BESSEL_FUNCS_BF16(Packet16f, Packet16bf)
} // namespace internal
} // namespace Eigen

View File

@@ -4,13 +4,10 @@
namespace Eigen {
namespace internal {
F16_PACKET_FUNCTION(Packet16f, Packet16h, perf)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, perf)
F16_PACKET_FUNCTION(Packet16f, Packet16h, pndtri)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pndtri)
EIGEN_INSTANTIATE_SPECIAL_FUNCS_F16(Packet16f, Packet16h)
EIGEN_INSTANTIATE_SPECIAL_FUNCS_BF16(Packet16f, Packet16bf)
} // namespace internal
} // namespace Eigen
#endif // EIGEN_AVX512_SPECIAL_FUNCTIONS_H
#endif // EIGEN_AVX512_SPECIALFUNCTIONS_H

View File

@@ -35,18 +35,7 @@ NEON_HALF_TO_FLOAT_FUNCTIONS(pbessel_y1)
#undef NEON_HALF_TO_FLOAT_FUNCTIONS
#endif
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pbessel_i0)
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pbessel_i0e)
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pbessel_i1)
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pbessel_i1e)
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pbessel_j0)
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pbessel_j1)
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pbessel_k0)
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pbessel_k0e)
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pbessel_k1)
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pbessel_k1e)
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pbessel_y0)
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pbessel_y1)
EIGEN_INSTANTIATE_BESSEL_FUNCS_BF16(Packet4f, Packet4bf)
} // namespace internal
} // namespace Eigen

View File

@@ -25,8 +25,7 @@ NEON_HALF_TO_FLOAT_FUNCTIONS(pndtri)
#undef NEON_HALF_TO_FLOAT_FUNCTIONS
#endif
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, perf)
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pndtri)
EIGEN_INSTANTIATE_SPECIAL_FUNCS_BF16(Packet4f, Packet4bf)
} // namespace internal
} // namespace Eigen