mirror of
https://gitlab.com/libeigen/eigen.git
synced 2026-04-10 11:34:33 +08:00
Consolidate BF16/F16 wrapper macros and simplify arch math functions
libeigen/eigen!2195 Co-authored-by: Rasmus Munk Larsen <rmlarsen@gmail.com>
This commit is contained in:
@@ -116,34 +116,10 @@ EIGEN_STRONG_INLINE Packet8bf pldexp(const Packet8bf& a, const Packet8bf& expone
|
||||
return F32ToBf16(pldexp<Packet8f>(Bf16ToF32(a), Bf16ToF32(exponent)));
|
||||
}
|
||||
|
||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pcos)
|
||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pexp)
|
||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pexp2)
|
||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pexpm1)
|
||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, plog)
|
||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, plog1p)
|
||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, plog2)
|
||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, preciprocal)
|
||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, prsqrt)
|
||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pcbrt)
|
||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psin)
|
||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psqrt)
|
||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, ptanh)
|
||||
EIGEN_INSTANTIATE_GENERIC_MATH_FUNCS_BF16(Packet8f, Packet8bf)
|
||||
|
||||
#ifndef EIGEN_VECTORIZE_AVX512FP16
|
||||
F16_PACKET_FUNCTION(Packet8f, Packet8h, pcos)
|
||||
F16_PACKET_FUNCTION(Packet8f, Packet8h, pexp)
|
||||
F16_PACKET_FUNCTION(Packet8f, Packet8h, pexp2)
|
||||
F16_PACKET_FUNCTION(Packet8f, Packet8h, pexpm1)
|
||||
F16_PACKET_FUNCTION(Packet8f, Packet8h, plog)
|
||||
F16_PACKET_FUNCTION(Packet8f, Packet8h, plog1p)
|
||||
F16_PACKET_FUNCTION(Packet8f, Packet8h, plog2)
|
||||
F16_PACKET_FUNCTION(Packet8f, Packet8h, preciprocal)
|
||||
F16_PACKET_FUNCTION(Packet8f, Packet8h, prsqrt)
|
||||
F16_PACKET_FUNCTION(Packet8f, Packet8h, psin)
|
||||
F16_PACKET_FUNCTION(Packet8f, Packet8h, psqrt)
|
||||
F16_PACKET_FUNCTION(Packet8f, Packet8h, pcbrt)
|
||||
F16_PACKET_FUNCTION(Packet8f, Packet8h, ptanh)
|
||||
EIGEN_INSTANTIATE_GENERIC_MATH_FUNCS_F16(Packet8f, Packet8h)
|
||||
#endif
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
@@ -106,32 +106,10 @@ EIGEN_STRONG_INLINE Packet16f preciprocal<Packet16f>(const Packet16f& a) {
|
||||
}
|
||||
#endif
|
||||
|
||||
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pcos)
|
||||
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexp)
|
||||
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexp2)
|
||||
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexpm1)
|
||||
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog)
|
||||
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog1p)
|
||||
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog2)
|
||||
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, preciprocal)
|
||||
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, prsqrt)
|
||||
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, psin)
|
||||
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, psqrt)
|
||||
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, ptanh)
|
||||
EIGEN_INSTANTIATE_GENERIC_MATH_FUNCS_BF16(Packet16f, Packet16bf)
|
||||
|
||||
#ifndef EIGEN_VECTORIZE_AVX512FP16
|
||||
F16_PACKET_FUNCTION(Packet16f, Packet16h, pcos)
|
||||
F16_PACKET_FUNCTION(Packet16f, Packet16h, pexp)
|
||||
F16_PACKET_FUNCTION(Packet16f, Packet16h, pexp2)
|
||||
F16_PACKET_FUNCTION(Packet16f, Packet16h, pexpm1)
|
||||
F16_PACKET_FUNCTION(Packet16f, Packet16h, plog)
|
||||
F16_PACKET_FUNCTION(Packet16f, Packet16h, plog1p)
|
||||
F16_PACKET_FUNCTION(Packet16f, Packet16h, plog2)
|
||||
F16_PACKET_FUNCTION(Packet16f, Packet16h, preciprocal)
|
||||
F16_PACKET_FUNCTION(Packet16f, Packet16h, prsqrt)
|
||||
F16_PACKET_FUNCTION(Packet16f, Packet16h, psin)
|
||||
F16_PACKET_FUNCTION(Packet16f, Packet16h, psqrt)
|
||||
F16_PACKET_FUNCTION(Packet16f, Packet16h, ptanh)
|
||||
EIGEN_INSTANTIATE_GENERIC_MATH_FUNCS_F16(Packet16f, Packet16h)
|
||||
#endif // EIGEN_VECTORIZE_AVX512FP16
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
@@ -38,6 +38,40 @@ limitations under the License.
|
||||
return F32ToBf16(METHOD<PACKET_F>(Bf16ToF32(_x))); \
|
||||
}
|
||||
|
||||
#define EIGEN_INSTANTIATE_GENERIC_MATH_FUNCS_BF16(PACKET_F, PACKET_BF16) \
|
||||
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, pcos) \
|
||||
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, psin) \
|
||||
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, pexp) \
|
||||
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, pexp2) \
|
||||
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, pexpm1) \
|
||||
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, plog) \
|
||||
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, plog1p) \
|
||||
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, plog2) \
|
||||
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, preciprocal) \
|
||||
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, prsqrt) \
|
||||
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, pcbrt) \
|
||||
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, psqrt) \
|
||||
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, ptanh)
|
||||
|
||||
// BF16 wrappers for unsupported/SpecialFunctions.
|
||||
#define EIGEN_INSTANTIATE_SPECIAL_FUNCS_BF16(PACKET_F, PACKET_BF16) \
|
||||
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, perf) \
|
||||
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, pndtri)
|
||||
|
||||
#define EIGEN_INSTANTIATE_BESSEL_FUNCS_BF16(PACKET_F, PACKET_BF16) \
|
||||
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, pbessel_i0) \
|
||||
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, pbessel_i0e) \
|
||||
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, pbessel_i1) \
|
||||
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, pbessel_i1e) \
|
||||
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, pbessel_j0) \
|
||||
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, pbessel_j1) \
|
||||
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, pbessel_k0) \
|
||||
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, pbessel_k0e) \
|
||||
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, pbessel_k1) \
|
||||
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, pbessel_k1e) \
|
||||
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, pbessel_y0) \
|
||||
BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, pbessel_y1)
|
||||
|
||||
// Only use HIP GPU bf16 in kernels
|
||||
#if defined(EIGEN_HAS_HIP_BF16) && defined(EIGEN_GPU_COMPILE_PHASE)
|
||||
#define EIGEN_USE_HIP_BF16
|
||||
|
||||
@@ -57,6 +57,40 @@
|
||||
return float2half(METHOD<PACKET_F>(half2float(_x))); \
|
||||
}
|
||||
|
||||
#define EIGEN_INSTANTIATE_GENERIC_MATH_FUNCS_F16(PACKET_F, PACKET_F16) \
|
||||
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, pcos) \
|
||||
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, psin) \
|
||||
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, pexp) \
|
||||
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, pexp2) \
|
||||
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, pexpm1) \
|
||||
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, plog) \
|
||||
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, plog1p) \
|
||||
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, plog2) \
|
||||
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, preciprocal) \
|
||||
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, prsqrt) \
|
||||
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, pcbrt) \
|
||||
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, psqrt) \
|
||||
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, ptanh)
|
||||
|
||||
// F16 wrappers for unsupported/SpecialFunctions.
|
||||
#define EIGEN_INSTANTIATE_SPECIAL_FUNCS_F16(PACKET_F, PACKET_F16) \
|
||||
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, perf) \
|
||||
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, pndtri)
|
||||
|
||||
#define EIGEN_INSTANTIATE_BESSEL_FUNCS_F16(PACKET_F, PACKET_F16) \
|
||||
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, pbessel_i0) \
|
||||
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, pbessel_i0e) \
|
||||
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, pbessel_i1) \
|
||||
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, pbessel_i1e) \
|
||||
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, pbessel_j0) \
|
||||
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, pbessel_j1) \
|
||||
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, pbessel_k0) \
|
||||
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, pbessel_k0e) \
|
||||
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, pbessel_k1) \
|
||||
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, pbessel_k1e) \
|
||||
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, pbessel_y0) \
|
||||
F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, pbessel_y1)
|
||||
|
||||
namespace Eigen {
|
||||
|
||||
struct half;
|
||||
|
||||
@@ -33,12 +33,7 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Packet8hf ptanh<Packet8hf>(const Packet8hf
|
||||
}
|
||||
#endif // EIGEN_HAS_ARM64_FP16_VECTOR_ARITHMETIC
|
||||
|
||||
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, psin)
|
||||
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pcos)
|
||||
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, plog)
|
||||
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pexp)
|
||||
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pexp2)
|
||||
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, ptanh)
|
||||
EIGEN_INSTANTIATE_GENERIC_MATH_FUNCS_BF16(Packet4f, Packet4bf)
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet4bf pfrexp(const Packet4bf& a, Packet4bf& exponent) {
|
||||
|
||||
@@ -16,36 +16,7 @@
|
||||
namespace Eigen {
|
||||
namespace internal {
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE PacketXf pexp<PacketXf>(const PacketXf& x) {
|
||||
return pexp_float(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE PacketXf plog<PacketXf>(const PacketXf& x) {
|
||||
return plog_float(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE PacketXf psin<PacketXf>(const PacketXf& x) {
|
||||
return psin_float(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE PacketXf pcos<PacketXf>(const PacketXf& x) {
|
||||
return pcos_float(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE PacketXf ptan<PacketXf>(const PacketXf& x) {
|
||||
return ptan_float(x);
|
||||
}
|
||||
|
||||
// Hyperbolic Tangent function.
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE PacketXf ptanh<PacketXf>(const PacketXf& x) {
|
||||
return ptanh_float(x);
|
||||
}
|
||||
EIGEN_INSTANTIATE_GENERIC_MATH_FUNCS_FLOAT(PacketXf)
|
||||
|
||||
} // end namespace internal
|
||||
} // end namespace Eigen
|
||||
|
||||
@@ -354,10 +354,17 @@ struct packet_traits<float> : default_packet_traits {
|
||||
HasSin = EIGEN_FAST_MATH,
|
||||
HasCos = EIGEN_FAST_MATH,
|
||||
HasTan = EIGEN_FAST_MATH,
|
||||
HasACos = 1,
|
||||
HasASin = 1,
|
||||
HasATan = 1,
|
||||
HasATanh = 1,
|
||||
HasLog = 1,
|
||||
HasLog1p = 1,
|
||||
HasExpm1 = 1,
|
||||
HasExp = 1,
|
||||
HasPow = 1,
|
||||
HasSqrt = 1,
|
||||
HasCbrt = 1,
|
||||
HasTanh = EIGEN_FAST_MATH,
|
||||
HasErf = EIGEN_FAST_MATH,
|
||||
HasErfc = EIGEN_FAST_MATH
|
||||
|
||||
@@ -31,259 +31,69 @@ namespace internal {
|
||||
// introduce conflicts between these packet_traits definitions and the ones
|
||||
// we'll use on the host side (SSE, AVX, ...)
|
||||
#if defined(SYCL_DEVICE_ONLY)
|
||||
#define SYCL_PLOG(packet_type) \
|
||||
template <> \
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type plog<packet_type>(const packet_type& a) { \
|
||||
return cl::sycl::log(a); \
|
||||
|
||||
// Generic macro for unary SYCL math functions.
|
||||
#define SYCL_PACKET_FUNCTION(EIGEN_FUNC, SYCL_FUNC, PACKET) \
|
||||
template <> \
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PACKET EIGEN_FUNC<PACKET>(const PACKET& a) { \
|
||||
return cl::sycl::SYCL_FUNC(a); \
|
||||
}
|
||||
|
||||
SYCL_PLOG(cl::sycl::cl_half8)
|
||||
SYCL_PLOG(cl::sycl::cl_float4)
|
||||
SYCL_PLOG(cl::sycl::cl_double2)
|
||||
#undef SYCL_PLOG
|
||||
// Instantiate a unary function for the standard set of SYCL vector types.
|
||||
#define SYCL_UNARY_FUNCTION(EIGEN_FUNC, SYCL_FUNC) \
|
||||
SYCL_PACKET_FUNCTION(EIGEN_FUNC, SYCL_FUNC, cl::sycl::cl_half8) \
|
||||
SYCL_PACKET_FUNCTION(EIGEN_FUNC, SYCL_FUNC, cl::sycl::cl_float4) \
|
||||
SYCL_PACKET_FUNCTION(EIGEN_FUNC, SYCL_FUNC, cl::sycl::cl_double2)
|
||||
|
||||
#define SYCL_PLOG1P(packet_type) \
|
||||
template <> \
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type plog1p<packet_type>(const packet_type& a) { \
|
||||
return cl::sycl::log1p(a); \
|
||||
SYCL_UNARY_FUNCTION(plog, log)
|
||||
SYCL_UNARY_FUNCTION(plog1p, log1p)
|
||||
SYCL_UNARY_FUNCTION(plog10, log10)
|
||||
SYCL_UNARY_FUNCTION(pexpm1, expm1)
|
||||
SYCL_UNARY_FUNCTION(psqrt, sqrt)
|
||||
SYCL_UNARY_FUNCTION(prsqrt, rsqrt)
|
||||
SYCL_UNARY_FUNCTION(psin, sin)
|
||||
SYCL_UNARY_FUNCTION(pcos, cos)
|
||||
SYCL_UNARY_FUNCTION(ptan, tan)
|
||||
SYCL_UNARY_FUNCTION(pasin, asin)
|
||||
SYCL_UNARY_FUNCTION(pacos, acos)
|
||||
SYCL_UNARY_FUNCTION(patan, atan)
|
||||
SYCL_UNARY_FUNCTION(psinh, sinh)
|
||||
SYCL_UNARY_FUNCTION(pcosh, cosh)
|
||||
SYCL_UNARY_FUNCTION(ptanh, tanh)
|
||||
SYCL_UNARY_FUNCTION(pround, round)
|
||||
SYCL_UNARY_FUNCTION(print, rint)
|
||||
SYCL_UNARY_FUNCTION(pfloor, floor)
|
||||
|
||||
// pexp has additional scalar type instantiations.
|
||||
SYCL_UNARY_FUNCTION(pexp, exp)
|
||||
SYCL_PACKET_FUNCTION(pexp, exp, cl::sycl::cl_half)
|
||||
SYCL_PACKET_FUNCTION(pexp, exp, cl::sycl::cl_float)
|
||||
|
||||
// pceil uses cl_half (scalar) instead of cl_half8 (vector) — preserving original behavior.
|
||||
SYCL_PACKET_FUNCTION(pceil, ceil, cl::sycl::cl_half)
|
||||
SYCL_PACKET_FUNCTION(pceil, ceil, cl::sycl::cl_float4)
|
||||
SYCL_PACKET_FUNCTION(pceil, ceil, cl::sycl::cl_double2)
|
||||
|
||||
#undef SYCL_UNARY_FUNCTION
|
||||
#undef SYCL_PACKET_FUNCTION
|
||||
|
||||
// Binary min/max functions.
|
||||
#define SYCL_BINARY_FUNCTION(EIGEN_FUNC, SYCL_FUNC, PACKET) \
|
||||
template <> \
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PACKET EIGEN_FUNC<PACKET>(const PACKET& a, const PACKET& b) { \
|
||||
return cl::sycl::SYCL_FUNC(a, b); \
|
||||
}
|
||||
|
||||
SYCL_PLOG1P(cl::sycl::cl_half8)
|
||||
SYCL_PLOG1P(cl::sycl::cl_float4)
|
||||
SYCL_PLOG1P(cl::sycl::cl_double2)
|
||||
#undef SYCL_PLOG1P
|
||||
SYCL_BINARY_FUNCTION(pmin, fmin, cl::sycl::cl_half8)
|
||||
SYCL_BINARY_FUNCTION(pmin, fmin, cl::sycl::cl_float4)
|
||||
SYCL_BINARY_FUNCTION(pmin, fmin, cl::sycl::cl_double2)
|
||||
SYCL_BINARY_FUNCTION(pmax, fmax, cl::sycl::cl_half8)
|
||||
SYCL_BINARY_FUNCTION(pmax, fmax, cl::sycl::cl_float4)
|
||||
SYCL_BINARY_FUNCTION(pmax, fmax, cl::sycl::cl_double2)
|
||||
|
||||
#define SYCL_PLOG10(packet_type) \
|
||||
template <> \
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type plog10<packet_type>(const packet_type& a) { \
|
||||
return cl::sycl::log10(a); \
|
||||
}
|
||||
|
||||
SYCL_PLOG10(cl::sycl::cl_half8)
|
||||
SYCL_PLOG10(cl::sycl::cl_float4)
|
||||
SYCL_PLOG10(cl::sycl::cl_double2)
|
||||
#undef SYCL_PLOG10
|
||||
|
||||
#define SYCL_PEXP(packet_type) \
|
||||
template <> \
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pexp<packet_type>(const packet_type& a) { \
|
||||
return cl::sycl::exp(a); \
|
||||
}
|
||||
|
||||
SYCL_PEXP(cl::sycl::cl_half8)
|
||||
SYCL_PEXP(cl::sycl::cl_half)
|
||||
SYCL_PEXP(cl::sycl::cl_float4)
|
||||
SYCL_PEXP(cl::sycl::cl_float)
|
||||
SYCL_PEXP(cl::sycl::cl_double2)
|
||||
#undef SYCL_PEXP
|
||||
|
||||
#define SYCL_PEXPM1(packet_type) \
|
||||
template <> \
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pexpm1<packet_type>(const packet_type& a) { \
|
||||
return cl::sycl::expm1(a); \
|
||||
}
|
||||
|
||||
SYCL_PEXPM1(cl::sycl::cl_half8)
|
||||
SYCL_PEXPM1(cl::sycl::cl_float4)
|
||||
SYCL_PEXPM1(cl::sycl::cl_double2)
|
||||
#undef SYCL_PEXPM1
|
||||
|
||||
#define SYCL_PSQRT(packet_type) \
|
||||
template <> \
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type psqrt<packet_type>(const packet_type& a) { \
|
||||
return cl::sycl::sqrt(a); \
|
||||
}
|
||||
|
||||
SYCL_PSQRT(cl::sycl::cl_half8)
|
||||
SYCL_PSQRT(cl::sycl::cl_float4)
|
||||
SYCL_PSQRT(cl::sycl::cl_double2)
|
||||
#undef SYCL_PSQRT
|
||||
|
||||
#define SYCL_PRSQRT(packet_type) \
|
||||
template <> \
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type prsqrt<packet_type>(const packet_type& a) { \
|
||||
return cl::sycl::rsqrt(a); \
|
||||
}
|
||||
|
||||
SYCL_PRSQRT(cl::sycl::cl_half8)
|
||||
SYCL_PRSQRT(cl::sycl::cl_float4)
|
||||
SYCL_PRSQRT(cl::sycl::cl_double2)
|
||||
#undef SYCL_PRSQRT
|
||||
|
||||
/** \internal \returns the hyperbolic sine of \a a (coeff-wise) */
|
||||
#define SYCL_PSIN(packet_type) \
|
||||
template <> \
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type psin<packet_type>(const packet_type& a) { \
|
||||
return cl::sycl::sin(a); \
|
||||
}
|
||||
|
||||
SYCL_PSIN(cl::sycl::cl_half8)
|
||||
SYCL_PSIN(cl::sycl::cl_float4)
|
||||
SYCL_PSIN(cl::sycl::cl_double2)
|
||||
#undef SYCL_PSIN
|
||||
|
||||
/** \internal \returns the hyperbolic cosine of \a a (coeff-wise) */
|
||||
#define SYCL_PCOS(packet_type) \
|
||||
template <> \
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pcos<packet_type>(const packet_type& a) { \
|
||||
return cl::sycl::cos(a); \
|
||||
}
|
||||
|
||||
SYCL_PCOS(cl::sycl::cl_half8)
|
||||
SYCL_PCOS(cl::sycl::cl_float4)
|
||||
SYCL_PCOS(cl::sycl::cl_double2)
|
||||
#undef SYCL_PCOS
|
||||
|
||||
/** \internal \returns the hyperbolic tan of \a a (coeff-wise) */
|
||||
#define SYCL_PTAN(packet_type) \
|
||||
template <> \
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type ptan<packet_type>(const packet_type& a) { \
|
||||
return cl::sycl::tan(a); \
|
||||
}
|
||||
|
||||
SYCL_PTAN(cl::sycl::cl_half8)
|
||||
SYCL_PTAN(cl::sycl::cl_float4)
|
||||
SYCL_PTAN(cl::sycl::cl_double2)
|
||||
#undef SYCL_PTAN
|
||||
|
||||
/** \internal \returns the hyperbolic sine of \a a (coeff-wise) */
|
||||
#define SYCL_PASIN(packet_type) \
|
||||
template <> \
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pasin<packet_type>(const packet_type& a) { \
|
||||
return cl::sycl::asin(a); \
|
||||
}
|
||||
|
||||
SYCL_PASIN(cl::sycl::cl_half8)
|
||||
SYCL_PASIN(cl::sycl::cl_float4)
|
||||
SYCL_PASIN(cl::sycl::cl_double2)
|
||||
#undef SYCL_PASIN
|
||||
|
||||
/** \internal \returns the hyperbolic cosine of \a a (coeff-wise) */
|
||||
#define SYCL_PACOS(packet_type) \
|
||||
template <> \
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pacos<packet_type>(const packet_type& a) { \
|
||||
return cl::sycl::acos(a); \
|
||||
}
|
||||
|
||||
SYCL_PACOS(cl::sycl::cl_half8)
|
||||
SYCL_PACOS(cl::sycl::cl_float4)
|
||||
SYCL_PACOS(cl::sycl::cl_double2)
|
||||
#undef SYCL_PACOS
|
||||
|
||||
/** \internal \returns the hyperbolic tan of \a a (coeff-wise) */
|
||||
#define SYCL_PATAN(packet_type) \
|
||||
template <> \
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type patan<packet_type>(const packet_type& a) { \
|
||||
return cl::sycl::atan(a); \
|
||||
}
|
||||
|
||||
SYCL_PATAN(cl::sycl::cl_half8)
|
||||
SYCL_PATAN(cl::sycl::cl_float4)
|
||||
SYCL_PATAN(cl::sycl::cl_double2)
|
||||
#undef SYCL_PATAN
|
||||
|
||||
/** \internal \returns the hyperbolic sine of \a a (coeff-wise) */
|
||||
#define SYCL_PSINH(packet_type) \
|
||||
template <> \
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type psinh<packet_type>(const packet_type& a) { \
|
||||
return cl::sycl::sinh(a); \
|
||||
}
|
||||
|
||||
SYCL_PSINH(cl::sycl::cl_half8)
|
||||
SYCL_PSINH(cl::sycl::cl_float4)
|
||||
SYCL_PSINH(cl::sycl::cl_double2)
|
||||
#undef SYCL_PSINH
|
||||
|
||||
/** \internal \returns the hyperbolic cosine of \a a (coeff-wise) */
|
||||
#define SYCL_PCOSH(packet_type) \
|
||||
template <> \
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pcosh<packet_type>(const packet_type& a) { \
|
||||
return cl::sycl::cosh(a); \
|
||||
}
|
||||
|
||||
SYCL_PCOSH(cl::sycl::cl_half8)
|
||||
SYCL_PCOSH(cl::sycl::cl_float4)
|
||||
SYCL_PCOSH(cl::sycl::cl_double2)
|
||||
#undef SYCL_PCOSH
|
||||
|
||||
/** \internal \returns the hyperbolic tan of \a a (coeff-wise) */
|
||||
#define SYCL_PTANH(packet_type) \
|
||||
template <> \
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type ptanh<packet_type>(const packet_type& a) { \
|
||||
return cl::sycl::tanh(a); \
|
||||
}
|
||||
|
||||
SYCL_PTANH(cl::sycl::cl_half8)
|
||||
SYCL_PTANH(cl::sycl::cl_float4)
|
||||
SYCL_PTANH(cl::sycl::cl_double2)
|
||||
#undef SYCL_PTANH
|
||||
|
||||
#define SYCL_PCEIL(packet_type) \
|
||||
template <> \
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pceil<packet_type>(const packet_type& a) { \
|
||||
return cl::sycl::ceil(a); \
|
||||
}
|
||||
|
||||
SYCL_PCEIL(cl::sycl::cl_half)
|
||||
SYCL_PCEIL(cl::sycl::cl_float4)
|
||||
SYCL_PCEIL(cl::sycl::cl_double2)
|
||||
#undef SYCL_PCEIL
|
||||
|
||||
#define SYCL_PROUND(packet_type) \
|
||||
template <> \
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pround<packet_type>(const packet_type& a) { \
|
||||
return cl::sycl::round(a); \
|
||||
}
|
||||
|
||||
SYCL_PROUND(cl::sycl::cl_half8)
|
||||
SYCL_PROUND(cl::sycl::cl_float4)
|
||||
SYCL_PROUND(cl::sycl::cl_double2)
|
||||
#undef SYCL_PROUND
|
||||
|
||||
#define SYCL_PRINT(packet_type) \
|
||||
template <> \
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type print<packet_type>(const packet_type& a) { \
|
||||
return cl::sycl::rint(a); \
|
||||
}
|
||||
|
||||
SYCL_PRINT(cl::sycl::cl_half8)
|
||||
SYCL_PRINT(cl::sycl::cl_float4)
|
||||
SYCL_PRINT(cl::sycl::cl_double2)
|
||||
#undef SYCL_PRINT
|
||||
|
||||
#define SYCL_FLOOR(packet_type) \
|
||||
template <> \
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pfloor<packet_type>(const packet_type& a) { \
|
||||
return cl::sycl::floor(a); \
|
||||
}
|
||||
|
||||
SYCL_FLOOR(cl::sycl::cl_half8)
|
||||
SYCL_FLOOR(cl::sycl::cl_float4)
|
||||
SYCL_FLOOR(cl::sycl::cl_double2)
|
||||
#undef SYCL_FLOOR
|
||||
|
||||
#define SYCL_PMIN(packet_type, expr) \
|
||||
template <> \
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pmin<packet_type>(const packet_type& a, const packet_type& b) { \
|
||||
return expr; \
|
||||
}
|
||||
|
||||
SYCL_PMIN(cl::sycl::cl_half8, cl::sycl::fmin(a, b))
|
||||
SYCL_PMIN(cl::sycl::cl_float4, cl::sycl::fmin(a, b))
|
||||
SYCL_PMIN(cl::sycl::cl_double2, cl::sycl::fmin(a, b))
|
||||
#undef SYCL_PMIN
|
||||
|
||||
#define SYCL_PMAX(packet_type, expr) \
|
||||
template <> \
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pmax<packet_type>(const packet_type& a, const packet_type& b) { \
|
||||
return expr; \
|
||||
}
|
||||
|
||||
SYCL_PMAX(cl::sycl::cl_half8, cl::sycl::fmax(a, b))
|
||||
SYCL_PMAX(cl::sycl::cl_float4, cl::sycl::fmax(a, b))
|
||||
SYCL_PMAX(cl::sycl::cl_double2, cl::sycl::fmax(a, b))
|
||||
#undef SYCL_PMAX
|
||||
#undef SYCL_BINARY_FUNCTION
|
||||
|
||||
// pldexp requires integer conversion of the exponent.
|
||||
#define SYCL_PLDEXP(packet_type) \
|
||||
template <> \
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pldexp(const packet_type& a, const packet_type& exponent) { \
|
||||
|
||||
@@ -4,41 +4,8 @@
|
||||
namespace Eigen {
|
||||
namespace internal {
|
||||
|
||||
F16_PACKET_FUNCTION(Packet8f, Packet8h, pbessel_i0)
|
||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pbessel_i0)
|
||||
|
||||
F16_PACKET_FUNCTION(Packet8f, Packet8h, pbessel_i0e)
|
||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pbessel_i0e)
|
||||
|
||||
F16_PACKET_FUNCTION(Packet8f, Packet8h, pbessel_i1)
|
||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pbessel_i1)
|
||||
|
||||
F16_PACKET_FUNCTION(Packet8f, Packet8h, pbessel_i1e)
|
||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pbessel_i1e)
|
||||
|
||||
F16_PACKET_FUNCTION(Packet8f, Packet8h, pbessel_j0)
|
||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pbessel_j0)
|
||||
|
||||
F16_PACKET_FUNCTION(Packet8f, Packet8h, pbessel_j1)
|
||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pbessel_j1)
|
||||
|
||||
F16_PACKET_FUNCTION(Packet8f, Packet8h, pbessel_k0)
|
||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pbessel_k0)
|
||||
|
||||
F16_PACKET_FUNCTION(Packet8f, Packet8h, pbessel_k0e)
|
||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pbessel_k0e)
|
||||
|
||||
F16_PACKET_FUNCTION(Packet8f, Packet8h, pbessel_k1)
|
||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pbessel_k1)
|
||||
|
||||
F16_PACKET_FUNCTION(Packet8f, Packet8h, pbessel_k1e)
|
||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pbessel_k1e)
|
||||
|
||||
F16_PACKET_FUNCTION(Packet8f, Packet8h, pbessel_y0)
|
||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pbessel_y0)
|
||||
|
||||
F16_PACKET_FUNCTION(Packet8f, Packet8h, pbessel_y1)
|
||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pbessel_y1)
|
||||
EIGEN_INSTANTIATE_BESSEL_FUNCS_F16(Packet8f, Packet8h)
|
||||
EIGEN_INSTANTIATE_BESSEL_FUNCS_BF16(Packet8f, Packet8bf)
|
||||
|
||||
} // namespace internal
|
||||
} // namespace Eigen
|
||||
|
||||
@@ -4,13 +4,10 @@
|
||||
namespace Eigen {
|
||||
namespace internal {
|
||||
|
||||
F16_PACKET_FUNCTION(Packet8f, Packet8h, perf)
|
||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, perf)
|
||||
|
||||
F16_PACKET_FUNCTION(Packet8f, Packet8h, pndtri)
|
||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pndtri)
|
||||
EIGEN_INSTANTIATE_SPECIAL_FUNCS_F16(Packet8f, Packet8h)
|
||||
EIGEN_INSTANTIATE_SPECIAL_FUNCS_BF16(Packet8f, Packet8bf)
|
||||
|
||||
} // namespace internal
|
||||
} // namespace Eigen
|
||||
|
||||
#endif // EIGEN_AVX_SPECIAL_FUNCTIONS_H
|
||||
#endif // EIGEN_AVX_SPECIALFUNCTIONS_H
|
||||
|
||||
@@ -4,41 +4,8 @@
|
||||
namespace Eigen {
|
||||
namespace internal {
|
||||
|
||||
F16_PACKET_FUNCTION(Packet16f, Packet16h, pbessel_i0)
|
||||
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pbessel_i0)
|
||||
|
||||
F16_PACKET_FUNCTION(Packet16f, Packet16h, pbessel_i0e)
|
||||
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pbessel_i0e)
|
||||
|
||||
F16_PACKET_FUNCTION(Packet16f, Packet16h, pbessel_i1)
|
||||
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pbessel_i1)
|
||||
|
||||
F16_PACKET_FUNCTION(Packet16f, Packet16h, pbessel_i1e)
|
||||
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pbessel_i1e)
|
||||
|
||||
F16_PACKET_FUNCTION(Packet16f, Packet16h, pbessel_j0)
|
||||
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pbessel_j0)
|
||||
|
||||
F16_PACKET_FUNCTION(Packet16f, Packet16h, pbessel_j1)
|
||||
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pbessel_j1)
|
||||
|
||||
F16_PACKET_FUNCTION(Packet16f, Packet16h, pbessel_k0)
|
||||
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pbessel_k0)
|
||||
|
||||
F16_PACKET_FUNCTION(Packet16f, Packet16h, pbessel_k0e)
|
||||
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pbessel_k0e)
|
||||
|
||||
F16_PACKET_FUNCTION(Packet16f, Packet16h, pbessel_k1)
|
||||
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pbessel_k1)
|
||||
|
||||
F16_PACKET_FUNCTION(Packet16f, Packet16h, pbessel_k1e)
|
||||
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pbessel_k1e)
|
||||
|
||||
F16_PACKET_FUNCTION(Packet16f, Packet16h, pbessel_y0)
|
||||
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pbessel_y0)
|
||||
|
||||
F16_PACKET_FUNCTION(Packet16f, Packet16h, pbessel_y1)
|
||||
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pbessel_y1)
|
||||
EIGEN_INSTANTIATE_BESSEL_FUNCS_F16(Packet16f, Packet16h)
|
||||
EIGEN_INSTANTIATE_BESSEL_FUNCS_BF16(Packet16f, Packet16bf)
|
||||
|
||||
} // namespace internal
|
||||
} // namespace Eigen
|
||||
|
||||
@@ -4,13 +4,10 @@
|
||||
namespace Eigen {
|
||||
namespace internal {
|
||||
|
||||
F16_PACKET_FUNCTION(Packet16f, Packet16h, perf)
|
||||
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, perf)
|
||||
|
||||
F16_PACKET_FUNCTION(Packet16f, Packet16h, pndtri)
|
||||
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pndtri)
|
||||
EIGEN_INSTANTIATE_SPECIAL_FUNCS_F16(Packet16f, Packet16h)
|
||||
EIGEN_INSTANTIATE_SPECIAL_FUNCS_BF16(Packet16f, Packet16bf)
|
||||
|
||||
} // namespace internal
|
||||
} // namespace Eigen
|
||||
|
||||
#endif // EIGEN_AVX512_SPECIAL_FUNCTIONS_H
|
||||
#endif // EIGEN_AVX512_SPECIALFUNCTIONS_H
|
||||
|
||||
@@ -35,18 +35,7 @@ NEON_HALF_TO_FLOAT_FUNCTIONS(pbessel_y1)
|
||||
#undef NEON_HALF_TO_FLOAT_FUNCTIONS
|
||||
#endif
|
||||
|
||||
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pbessel_i0)
|
||||
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pbessel_i0e)
|
||||
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pbessel_i1)
|
||||
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pbessel_i1e)
|
||||
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pbessel_j0)
|
||||
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pbessel_j1)
|
||||
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pbessel_k0)
|
||||
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pbessel_k0e)
|
||||
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pbessel_k1)
|
||||
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pbessel_k1e)
|
||||
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pbessel_y0)
|
||||
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pbessel_y1)
|
||||
EIGEN_INSTANTIATE_BESSEL_FUNCS_BF16(Packet4f, Packet4bf)
|
||||
|
||||
} // namespace internal
|
||||
} // namespace Eigen
|
||||
|
||||
@@ -25,8 +25,7 @@ NEON_HALF_TO_FLOAT_FUNCTIONS(pndtri)
|
||||
#undef NEON_HALF_TO_FLOAT_FUNCTIONS
|
||||
#endif
|
||||
|
||||
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, perf)
|
||||
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pndtri)
|
||||
EIGEN_INSTANTIATE_SPECIAL_FUNCS_BF16(Packet4f, Packet4bf)
|
||||
|
||||
} // namespace internal
|
||||
} // namespace Eigen
|
||||
|
||||
Reference in New Issue
Block a user