From e63d9f6ccb7f6f29f31241b87c542f3f0ab3112b Mon Sep 17 00:00:00 2001 From: Charles Schlosser Date: Fri, 29 Mar 2024 21:49:27 +0000 Subject: [PATCH] Fix random again --- Eigen/Core | 1 + Eigen/src/Core/MathFunctions.h | 223 ++++------------------------- Eigen/src/Core/RandomImpl.h | 253 +++++++++++++++++++++++++++++++++ test/AnnoyingScalar.h | 13 -- test/MovableScalar.h | 20 +-- test/SafeScalar.h | 37 ++--- test/rand.cpp | 107 ++++++++++---- 7 files changed, 380 insertions(+), 274 deletions(-) create mode 100644 Eigen/src/Core/RandomImpl.h diff --git a/Eigen/Core b/Eigen/Core index f9d9974b0..ed7d3538f 100644 --- a/Eigen/Core +++ b/Eigen/Core @@ -178,6 +178,7 @@ using std::ptrdiff_t; #include "src/Core/NumTraits.h" #include "src/Core/MathFunctions.h" +#include "src/Core/RandomImpl.h" #include "src/Core/GenericPacketMath.h" #include "src/Core/MathFunctionsImpl.h" #include "src/Core/arch/Default/ConjHelper.h" diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h index 3f28068ce..29e00ff8e 100644 --- a/Eigen/src/Core/MathFunctions.h +++ b/Eigen/src/Core/MathFunctions.h @@ -604,7 +604,6 @@ template struct count_bits_impl { static_assert(std::is_integral::value && std::is_unsigned::value, "BitsType must be an unsigned integer"); - static EIGEN_DEVICE_FUNC inline int clz(BitsType bits) { int n = CHAR_BIT * sizeof(BitsType); int shift = n / 2; @@ -655,7 +654,8 @@ EIGEN_DEVICE_FUNC inline int ctz(BitsType bits) { #if EIGEN_COMP_GNUC || EIGEN_COMP_CLANG template -struct count_bits_impl> { +struct count_bits_impl< + BitsType, std::enable_if_t::value && sizeof(BitsType) <= sizeof(unsigned int)>> { static constexpr int kNumBits = static_cast(sizeof(BitsType) * CHAR_BIT); static_assert(std::is_integral::value, "BitsType must be a built-in integer"); static EIGEN_DEVICE_FUNC inline int clz(BitsType bits) { @@ -669,8 +669,9 @@ struct count_bits_impl -struct count_bits_impl< - BitsType, std::enable_if_t> { +struct count_bits_impl::value && sizeof(unsigned int) < sizeof(BitsType) && + sizeof(BitsType) <= sizeof(unsigned long)>> { static constexpr int kNumBits = static_cast(sizeof(BitsType) * CHAR_BIT); static_assert(std::is_integral::value, "BitsType must be a built-in integer"); static EIGEN_DEVICE_FUNC inline int clz(BitsType bits) { @@ -684,8 +685,9 @@ struct count_bits_impl< }; template -struct count_bits_impl> { +struct count_bits_impl::value && sizeof(unsigned long) < sizeof(BitsType) && + sizeof(BitsType) <= sizeof(unsigned long long)>> { static constexpr int kNumBits = static_cast(sizeof(BitsType) * CHAR_BIT); static_assert(std::is_integral::value, "BitsType must be a built-in integer"); static EIGEN_DEVICE_FUNC inline int clz(BitsType bits) { @@ -701,7 +703,8 @@ struct count_bits_impl -struct count_bits_impl> { +struct count_bits_impl< + BitsType, std::enable_if_t::value && sizeof(BitsType) <= sizeof(unsigned long)>> { static constexpr int kNumBits = static_cast(sizeof(BitsType) * CHAR_BIT); static_assert(std::is_integral::value, "BitsType must be a built-in integer"); static EIGEN_DEVICE_FUNC inline int clz(BitsType bits) { @@ -720,8 +723,9 @@ struct count_bits_impl -struct count_bits_impl< - BitsType, std::enable_if_t> { +struct count_bits_impl::value && sizeof(unsigned long) < sizeof(BitsType) && + sizeof(BitsType) <= sizeof(__int64)>> { static constexpr int kNumBits = static_cast(sizeof(BitsType) * CHAR_BIT); static_assert(std::is_integral::value, "BitsType must be a built-in integer"); static EIGEN_DEVICE_FUNC inline int clz(BitsType bits) { @@ -742,192 +746,27 @@ struct count_bits_impl< #endif // EIGEN_COMP_GNUC || EIGEN_COMP_CLANG template -int log2_ceil(BitsType x) { - int n = CHAR_BIT * sizeof(BitsType) - clz(x); - bool powerOfTwo = (x & (x - 1)) == 0; - return x == 0 ? 0 : powerOfTwo ? n - 1 : n; +struct log_2_impl { + static constexpr int kTotalBits = sizeof(BitsType) * CHAR_BIT; + static EIGEN_DEVICE_FUNC inline int run_ceil(const BitsType& x) { + const int n = kTotalBits - clz(x); + bool power_of_two = (x & (x - 1)) == 0; + return x == 0 ? 0 : power_of_two ? (n - 1) : n; + } + static EIGEN_DEVICE_FUNC inline int run_floor(const BitsType& x) { + const int n = kTotalBits - clz(x); + return x == 0 ? 0 : n - 1; + } +}; + +template +int log2_ceil(const BitsType& x) { + return log_2_impl::run_ceil(x); } template -int log2_floor(BitsType x) { - int n = CHAR_BIT * sizeof(BitsType) - clz(x); - return x == 0 ? 0 : n - 1; -} - -/**************************************************************************** - * Implementation of random * - ****************************************************************************/ - -// return a Scalar filled with numRandomBits beginning from the least significant bit -template -Scalar getRandomBits(int numRandomBits) { - using BitsType = typename numext::get_integer_by_size::unsigned_type; - enum : int { - StdRandBits = meta_floor_log2<(unsigned int)(RAND_MAX) + 1>::value, - ScalarBits = sizeof(Scalar) * CHAR_BIT - }; - eigen_assert((numRandomBits >= 0) && (numRandomBits <= ScalarBits)); - const BitsType mask = BitsType(-1) >> ((ScalarBits - numRandomBits) & (ScalarBits - 1)); - BitsType randomBits = BitsType(0); - for (int shift = 0; shift < numRandomBits; shift += StdRandBits) { - int r = std::rand(); - randomBits |= static_cast(r) << shift; - } - // clear the excess bits - randomBits &= mask; - return numext::bit_cast(randomBits); -} - -template -struct random_default_impl {}; - -template -struct random_impl : random_default_impl::IsComplex, NumTraits::IsInteger> {}; - -template -struct random_retval { - typedef Scalar type; -}; - -template -inline EIGEN_MATHFUNC_RETVAL(random, Scalar) random(const Scalar& x, const Scalar& y); -template -inline EIGEN_MATHFUNC_RETVAL(random, Scalar) random(); - -template -struct random_default_impl { - using BitsType = typename numext::get_integer_by_size::unsigned_type; - static EIGEN_DEVICE_FUNC inline Scalar run(const Scalar& x, const Scalar& y, int numRandomBits) { - Scalar half_x = Scalar(0.5) * x; - Scalar half_y = Scalar(0.5) * y; - Scalar result = (half_x + half_y) + (half_y - half_x) * run(numRandomBits); - // result is in the half-open interval [x, y) -- provided that x < y - return result; - } - static EIGEN_DEVICE_FUNC inline Scalar run(const Scalar& x, const Scalar& y) { - const int mantissa_bits = NumTraits::digits() - 1; - return run(x, y, mantissa_bits); - } - static EIGEN_DEVICE_FUNC inline Scalar run(int numRandomBits) { - const int mantissa_bits = NumTraits::digits() - 1; - eigen_assert(numRandomBits >= 0 && numRandomBits <= mantissa_bits); - BitsType randomBits = getRandomBits(numRandomBits); - // if fewer than MantissaBits is requested, shift them to the left - randomBits <<= (mantissa_bits - numRandomBits); - // randomBits is in the half-open interval [2,4) - randomBits |= numext::bit_cast(Scalar(2)); - // result is in the half-open interval [-1,1) - Scalar result = numext::bit_cast(randomBits) - Scalar(3); - return result; - } - static EIGEN_DEVICE_FUNC inline Scalar run() { - const int mantissa_bits = NumTraits::digits() - 1; - return run(mantissa_bits); - } -}; - -// TODO: fix this for PPC -template -struct random_longdouble_impl { - enum : int { - Size = sizeof(long double), - MantissaBits = NumTraits::digits() - 1, - LowBits = MantissaBits > 64 ? 64 : MantissaBits, - HighBits = MantissaBits > 64 ? MantissaBits - 64 : 0 - }; - static EIGEN_DEVICE_FUNC inline long double run() { - EIGEN_USING_STD(memcpy) - uint64_t randomBits[2]; - long double result = 2.0L; - memcpy(&randomBits, &result, Size); - randomBits[0] |= getRandomBits(LowBits); - randomBits[1] |= getRandomBits(HighBits); - memcpy(&result, &randomBits, Size); - result -= 3.0L; - return result; - } -}; - -// GPUs treat long double as double. -#ifndef EIGEN_GPU_COMPILE_PHASE -template <> -struct random_longdouble_impl { - using Impl = random_impl; - static EIGEN_DEVICE_FUNC inline long double run() { return static_cast(Impl::run()); } -}; - -template <> -struct random_impl { - static EIGEN_DEVICE_FUNC inline long double run(const long double& x, const long double& y) { - long double half_x = 0.5L * x; - long double half_y = 0.5L * y; - long double result = (half_x + half_y) + (half_y - half_x) * run(); - return result; - } - static EIGEN_DEVICE_FUNC inline long double run() { return random_longdouble_impl<>::run(); } -}; -#endif - -template -struct random_default_impl { - using BitsType = typename numext::get_integer_by_size::unsigned_type; - enum : int { ScalarBits = sizeof(Scalar) * CHAR_BIT }; - static EIGEN_DEVICE_FUNC inline Scalar run(const Scalar& x, const Scalar& y) { - if (y <= x) return x; - const BitsType range = static_cast(y) - static_cast(x) + 1; - // handle edge case where [x,y] spans the entire range of Scalar - if (range == 0) return getRandomBits(ScalarBits); - // calculate the number of random bits needed to fill range - const int numRandomBits = log2_ceil(range); - BitsType randomBits; - do { - randomBits = getRandomBits(numRandomBits); - // if the random draw is outside [0, range), try again (rejection sampling) - // in the worst-case scenario, the probability of rejection is: 1/2 - 1/2^numRandomBits < 50% - } while (randomBits >= range); - // Avoid overflow in the case where `x` is negative and there is a large range so - // `randomBits` would also be negative if cast to `Scalar` first. - Scalar result = static_cast(static_cast(x) + randomBits); - return result; - } - - static EIGEN_DEVICE_FUNC inline Scalar run() { -#ifdef EIGEN_MAKING_DOCS - return run(Scalar(NumTraits::IsSigned ? -10 : 0), Scalar(10)); -#else - return getRandomBits(ScalarBits); -#endif - } -}; - -template <> -struct random_impl { - static EIGEN_DEVICE_FUNC inline bool run(const bool& x, const bool& y) { - if (y <= x) return x; - return run(); - } - static EIGEN_DEVICE_FUNC inline bool run() { return getRandomBits(1) ? true : false; } -}; - -template -struct random_default_impl { - static EIGEN_DEVICE_FUNC inline Scalar run(const Scalar& x, const Scalar& y) { - return Scalar(random(x.real(), y.real()), random(x.imag(), y.imag())); - } - static EIGEN_DEVICE_FUNC inline Scalar run() { - typedef typename NumTraits::Real RealScalar; - return Scalar(random(), random()); - } -}; - -template -inline EIGEN_MATHFUNC_RETVAL(random, Scalar) random(const Scalar& x, const Scalar& y) { - return EIGEN_MATHFUNC_IMPL(random, Scalar)::run(x, y); -} - -template -inline EIGEN_MATHFUNC_RETVAL(random, Scalar) random() { - return EIGEN_MATHFUNC_IMPL(random, Scalar)::run(); +int log2_floor(const BitsType& x) { + return log_2_impl::run_floor(x); } // Implementation of is* functions diff --git a/Eigen/src/Core/RandomImpl.h b/Eigen/src/Core/RandomImpl.h new file mode 100644 index 000000000..d5cd335fd --- /dev/null +++ b/Eigen/src/Core/RandomImpl.h @@ -0,0 +1,253 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2024 Charles Schlosser +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_RANDOM_IMPL_H +#define EIGEN_RANDOM_IMPL_H + +// IWYU pragma: private +#include "./InternalHeaderCheck.h" + +namespace Eigen { + +namespace internal { + +/**************************************************************************** + * Implementation of random * + ****************************************************************************/ + +template +struct random_default_impl {}; + +template +struct random_impl : random_default_impl::IsComplex, NumTraits::IsInteger> {}; + +template +struct random_retval { + typedef Scalar type; +}; + +template +inline EIGEN_MATHFUNC_RETVAL(random, Scalar) random(const Scalar& x, const Scalar& y) { + return EIGEN_MATHFUNC_IMPL(random, Scalar)::run(x, y); +} + +template +inline EIGEN_MATHFUNC_RETVAL(random, Scalar) random() { + return EIGEN_MATHFUNC_IMPL(random, Scalar)::run(); +} + +// TODO: replace or provide alternatives to this, e.g. std::random_device +struct eigen_random_device { + using ReturnType = int; + static constexpr int Entropy = meta_floor_log2<(unsigned int)(RAND_MAX) + 1>::value; + static constexpr ReturnType Highest = RAND_MAX; + static EIGEN_DEVICE_FUNC inline ReturnType run() { return std::rand(); }; +}; + +// Fill a built-in unsigned integer with numRandomBits beginning with the least significant bit +template +struct random_bits_impl { + EIGEN_STATIC_ASSERT(std::is_unsigned::value, SCALAR MUST BE A BUILT - IN UNSIGNED INTEGER) + using RandomDevice = eigen_random_device; + using RandomReturnType = typename RandomDevice::ReturnType; + static constexpr int kEntropy = RandomDevice::Entropy; + static constexpr int kTotalBits = sizeof(Scalar) * CHAR_BIT; + // return a Scalar filled with numRandomBits beginning from the least significant bit + static EIGEN_DEVICE_FUNC inline Scalar run(int numRandomBits) { + eigen_assert((numRandomBits >= 0) && (numRandomBits <= kTotalBits)); + const Scalar mask = Scalar(-1) >> ((kTotalBits - numRandomBits) & (kTotalBits - 1)); + Scalar randomBits = 0; + for (int shift = 0; shift < numRandomBits; shift += kEntropy) { + RandomReturnType r = RandomDevice::run(); + randomBits |= static_cast(r) << shift; + } + // clear the excess bits + randomBits &= mask; + return randomBits; + } +}; + +template +EIGEN_DEVICE_FUNC inline BitsType getRandomBits(int numRandomBits) { + return random_bits_impl::run(numRandomBits); +} + +// random implementation for a built-in floating point type +template ::value> +struct random_float_impl { + using BitsType = typename numext::get_integer_by_size::unsigned_type; + static constexpr EIGEN_DEVICE_FUNC inline int mantissaBits() { + const int digits = NumTraits::digits(); + return digits - 1; + } + static EIGEN_DEVICE_FUNC inline Scalar run(int numRandomBits) { + eigen_assert(numRandomBits >= 0 && numRandomBits <= mantissaBits()); + BitsType randomBits = getRandomBits(numRandomBits); + // if fewer than MantissaBits is requested, shift them to the left + randomBits <<= (mantissaBits() - numRandomBits); + // randomBits is in the half-open interval [2,4) + randomBits |= numext::bit_cast(Scalar(2)); + // result is in the half-open interval [-1,1) + Scalar result = numext::bit_cast(randomBits) - Scalar(3); + return result; + } +}; +// random implementation for a custom floating point type +// uses double as the implementation with a mantissa with a size equal to either the target scalar's mantissa or that of +// double, whichever is smaller +template +struct random_float_impl { + static EIGEN_DEVICE_FUNC inline int mantissaBits() { + const int digits = NumTraits::digits(); + constexpr int kDoubleDigits = NumTraits::digits(); + return numext::mini(digits, kDoubleDigits) - 1; + } + static EIGEN_DEVICE_FUNC inline Scalar run(int numRandomBits) { + eigen_assert(numRandomBits >= 0 && numRandomBits <= mantissaBits()); + Scalar result = static_cast(random_float_impl::run(numRandomBits)); + return result; + } +}; + +// random implementation for long double +// this specialization is not compatible with double-double scalars +template ::digits != (2 * std::numeric_limits::digits)))> +struct random_longdouble_impl { + static constexpr int Size = sizeof(long double); + static constexpr EIGEN_DEVICE_FUNC inline int mantissaBits() { return NumTraits::digits() - 1; } + static EIGEN_DEVICE_FUNC inline long double run(int numRandomBits) { + eigen_assert(numRandomBits >= 0 && numRandomBits <= mantissaBits()); + EIGEN_USING_STD(memcpy); + int numLowBits = numext::mini(numRandomBits, 64); + int numHighBits = numext::maxi(numRandomBits - 64, 0); + uint64_t randomBits[2]; + long double result = 2.0L; + memcpy(&randomBits, &result, Size); + randomBits[0] |= getRandomBits(numLowBits); + randomBits[1] |= getRandomBits(numHighBits); + memcpy(&result, &randomBits, Size); + result -= 3.0L; + return result; + } +}; +template <> +struct random_longdouble_impl { + static constexpr EIGEN_DEVICE_FUNC inline int mantissaBits() { return NumTraits::digits() - 1; } + static EIGEN_DEVICE_FUNC inline long double run(int numRandomBits) { + return static_cast(random_float_impl::run(numRandomBits)); + } +}; +template <> +struct random_float_impl : random_longdouble_impl<> {}; + +template +struct random_default_impl { + using Impl = random_float_impl; + static EIGEN_DEVICE_FUNC inline Scalar run(const Scalar& x, const Scalar& y, int numRandomBits) { + Scalar half_x = Scalar(0.5) * x; + Scalar half_y = Scalar(0.5) * y; + Scalar result = (half_x + half_y) + (half_y - half_x) * run(numRandomBits); + // result is in the half-open interval [x, y) -- provided that x < y + return result; + } + static EIGEN_DEVICE_FUNC inline Scalar run(const Scalar& x, const Scalar& y) { + return run(x, y, Impl::mantissaBits()); + } + static EIGEN_DEVICE_FUNC inline Scalar run(int numRandomBits) { return Impl::run(numRandomBits); } + static EIGEN_DEVICE_FUNC inline Scalar run() { return run(Impl::mantissaBits()); } +}; + +template ::IsSigned, bool BuiltIn = std::is_integral::value> +struct random_int_impl; + +// random implementation for a built-in unsigned integer type +template +struct random_int_impl { + static constexpr int kTotalBits = sizeof(Scalar) * CHAR_BIT; + static EIGEN_DEVICE_FUNC inline Scalar run(const Scalar& x, const Scalar& y) { + if (y <= x) return x; + Scalar range = y - x; + // handle edge case where [x,y] spans the entire range of Scalar + if (range == NumTraits::highest()) return run(); + Scalar count = range + 1; + // calculate the number of random bits needed to fill range + int numRandomBits = log2_ceil(count); + Scalar randomBits; + do { + randomBits = getRandomBits(numRandomBits); + // if the random draw is outside [0, range), try again (rejection sampling) + // in the worst-case scenario, the probability of rejection is: 1/2 - 1/2^numRandomBits < 50% + } while (randomBits >= count); + Scalar result = x + randomBits; + return result; + } + static EIGEN_DEVICE_FUNC inline Scalar run() { return getRandomBits(kTotalBits); } +}; + +// random implementation for a built-in signed integer type +template +struct random_int_impl { + static constexpr int kTotalBits = sizeof(Scalar) * CHAR_BIT; + using BitsType = typename make_unsigned::type; + static EIGEN_DEVICE_FUNC inline Scalar run(const Scalar& x, const Scalar& y) { + if (y <= x) return x; + // Avoid overflow by representing `range` as an unsigned type + BitsType range = static_cast(y) - static_cast(x); + BitsType randomBits = random_int_impl::run(0, range); + // Avoid overflow in the case where `x` is negative and there is a large range so + // `randomBits` would also be negative if cast to `Scalar` first. + Scalar result = static_cast(static_cast(x) + randomBits); + return result; + } + static EIGEN_DEVICE_FUNC inline Scalar run() { return static_cast(getRandomBits(kTotalBits)); } +}; + +// todo: custom integers +template +struct random_int_impl { + static EIGEN_DEVICE_FUNC inline Scalar run(const Scalar&, const Scalar&) { return run(); } + static EIGEN_DEVICE_FUNC inline Scalar run() { + eigen_assert(std::false_type::value && "RANDOM FOR CUSTOM INTEGERS NOT YET SUPPORTED"); + return Scalar(0); + } +}; + +template +struct random_default_impl : random_int_impl {}; + +template <> +struct random_impl { + static EIGEN_DEVICE_FUNC inline bool run(const bool& x, const bool& y) { + if (y <= x) return x; + return run(); + } + static EIGEN_DEVICE_FUNC inline bool run() { return getRandomBits(1) ? true : false; } +}; + +template +struct random_default_impl { + typedef typename NumTraits::Real RealScalar; + using Impl = random_impl; + static EIGEN_DEVICE_FUNC inline Scalar run(const Scalar& x, const Scalar& y, int numRandomBits) { + return Scalar(Impl::run(x.real(), y.real(), numRandomBits), Impl::run(x.imag(), y.imag(), numRandomBits)); + } + static EIGEN_DEVICE_FUNC inline Scalar run(const Scalar& x, const Scalar& y) { + return Scalar(Impl::run(x.real(), y.real()), Impl::run(x.imag(), y.imag())); + } + static EIGEN_DEVICE_FUNC inline Scalar run(int numRandomBits) { + return Scalar(Impl::run(numRandomBits), Impl::run(numRandomBits)); + } + static EIGEN_DEVICE_FUNC inline Scalar run() { return Scalar(Impl::run(), Impl::run()); } +}; + +} // namespace internal +} // namespace Eigen + +#endif // EIGEN_RANDOM_IMPL_H diff --git a/test/AnnoyingScalar.h b/test/AnnoyingScalar.h index 637fdbfe1..00a20c7c7 100644 --- a/test/AnnoyingScalar.h +++ b/test/AnnoyingScalar.h @@ -184,19 +184,6 @@ EIGEN_STRONG_INLINE float cast(const AnnoyingScalar& x) { return *x.v; } -template <> -struct random_impl { - using Impl = random_impl; - static EIGEN_DEVICE_FUNC inline AnnoyingScalar run(const AnnoyingScalar& x, const AnnoyingScalar& y) { - float result = Impl::run(*x.v, *y.v); - return AnnoyingScalar(result); - } - static EIGEN_DEVICE_FUNC inline AnnoyingScalar run() { - float result = Impl::run(); - return AnnoyingScalar(result); - } -}; - } // namespace internal } // namespace Eigen diff --git a/test/MovableScalar.h b/test/MovableScalar.h index 56a873ee6..c8bf546df 100644 --- a/test/MovableScalar.h +++ b/test/MovableScalar.h @@ -26,24 +26,10 @@ struct MovableScalar : public Base { operator Scalar() const { return this->size() > 0 ? this->back() : Scalar(); } }; -template <> -struct NumTraits> : GenericNumTraits {}; - -namespace internal { -template -struct random_impl> { - using MoveableT = MovableScalar; - using Impl = random_impl; - static EIGEN_DEVICE_FUNC inline MoveableT run(const MoveableT& x, const MoveableT& y) { - T result = Impl::run(x, y); - return MoveableT(result); - } - static EIGEN_DEVICE_FUNC inline MoveableT run() { - T result = Impl::run(); - return MoveableT(result); - } +template +struct NumTraits> : GenericNumTraits { + enum { RequireInitialization = 1 }; }; -} // namespace internal } // namespace Eigen diff --git a/test/SafeScalar.h b/test/SafeScalar.h index 4f4da5605..33a54c5af 100644 --- a/test/SafeScalar.h +++ b/test/SafeScalar.h @@ -4,43 +4,30 @@ template class SafeScalar { public: SafeScalar() : initialized_(false) {} - SafeScalar(const SafeScalar& other) { *this = other; } - SafeScalar& operator=(const SafeScalar& other) { - val_ = T(other); - initialized_ = true; - return *this; - } - SafeScalar(T val) : val_(val), initialized_(true) {} - SafeScalar& operator=(T val) { - val_ = val; - initialized_ = true; - } + SafeScalar(const T& val) : val_(val), initialized_(true) {} + + template + explicit SafeScalar(const Source& val) : SafeScalar(T(val)) {} operator T() const { VERIFY(initialized_ && "Uninitialized access."); return val_; } + template + explicit operator Target() const { + return Target(this->operator T()); + } + private: T val_; bool initialized_; }; namespace Eigen { -namespace internal { template -struct random_impl> { - using SafeT = SafeScalar; - using Impl = random_impl; - static EIGEN_DEVICE_FUNC inline SafeT run(const SafeT& x, const SafeT& y) { - T result = Impl::run(x, y); - return SafeT(result); - } - static EIGEN_DEVICE_FUNC inline SafeT run() { - T result = Impl::run(); - return SafeT(result); - } +struct NumTraits> : GenericNumTraits { + enum { RequireInitialization = 1 }; }; -} // namespace internal -} // namespace Eigen +} // namespace Eigen \ No newline at end of file diff --git a/test/rand.cpp b/test/rand.cpp index 6a7c316d8..4131f3837 100644 --- a/test/rand.cpp +++ b/test/rand.cpp @@ -9,6 +9,10 @@ #include #include "main.h" +#include "SafeScalar.h" + +// SafeScalar is used to simulate custom Scalar types, which use a more generalized approach to generate random +// numbers // For GCC-6, if this function is inlined then there seems to be an optimization // bug that triggers a failure. This failure goes away if you access `r` in @@ -25,15 +29,28 @@ EIGEN_DONT_INLINE Scalar check_in_range(Scalar x, Scalar y) { template void check_all_in_range(Scalar x, Scalar y) { - Array mask(y - x + 1); - mask.fill(0); - int64_t n = (y - x + 1) * 32; - for (int64_t k = 0; k < n; ++k) { - mask(check_in_range(x, y) - x)++; - } + constexpr int repeats = 32; + uint64_t count = static_cast(y) - static_cast(x) + 1; + ArrayX mask(count); + // ensure that `count` does not overflow the return type of `mask.size()` + VERIFY(count == static_cast(mask.size())); + mask.setConstant(false); + for (uint64_t k = 0; k < count; k++) + for (int repeat = 0; repeat < repeats; repeat++) { + Scalar r = check_in_range(x, y); + Index i = static_cast(r) - static_cast(x); + mask(i) = true; + } for (Index i = 0; i < mask.size(); ++i) - if (mask(i) == 0) std::cout << "WARNING: value " << x + i << " not reached." << std::endl; - VERIFY((mask > 0).all()); + if (mask(i) == false) std::cout << "WARNING: value " << x + i << " not reached." << std::endl; + VERIFY(mask.cwiseEqual(true).all()); +} + +template +void check_all_in_range() { + const Scalar x = NumTraits::lowest(); + const Scalar y = NumTraits::highest(); + check_all_in_range(x, y); } template @@ -66,10 +83,16 @@ class HistogramHelper { double bin_width_; }; +// helper class to avoid extending std:: namespace +template +struct get_range_type : internal::make_unsigned {}; +template +struct get_range_type> : internal::make_unsigned {}; + template class HistogramHelper::IsInteger>> { public: - using RangeType = typename Eigen::internal::make_unsigned::type; + using RangeType = typename get_range_type::type; HistogramHelper(int nbins) : HistogramHelper(Eigen::NumTraits::lowest(), Eigen::NumTraits::highest(), nbins) {} HistogramHelper(Scalar lower, Scalar upper, int nbins) @@ -109,38 +132,59 @@ class HistogramHelper::IsInteg template void check_histogram(Scalar x, Scalar y, int bins) { + constexpr int repeats = 10000; + double count = double(bins) * double(repeats); Eigen::VectorXd hist = Eigen::VectorXd::Zero(bins); HistogramHelper hist_helper(x, y, bins); - int64_t n = static_cast(bins) * 10000; // Approx 10000 per bin. - for (int64_t k = 0; k < n; ++k) { - Scalar r = check_in_range(x, y); - int bin = hist_helper.bin(r); - hist(bin)++; - } - // Normalize bins by probability. + for (int k = 0; k < bins; k++) + for (int repeat = 0; repeat < repeats; repeat++) { + Scalar r = check_in_range(x, y); + int bin = hist_helper.bin(r); + hist(bin)++; + } + // Normalize bins by probability. + hist /= count; for (int i = 0; i < bins; ++i) { - hist(i) = hist(i) / n / hist_helper.uniform_bin_probability(i); + hist(i) = hist(i) / hist_helper.uniform_bin_probability(i); } VERIFY(((hist.array() - 1.0).abs() < 0.05).all()); } template void check_histogram(int bins) { + constexpr int repeats = 10000; + double count = double(bins) * double(repeats); Eigen::VectorXd hist = Eigen::VectorXd::Zero(bins); HistogramHelper hist_helper(bins); - int64_t n = static_cast(bins) * 10000; // Approx 10000 per bin. - for (int64_t k = 0; k < n; ++k) { - Scalar r = Eigen::internal::random(); - int bin = hist_helper.bin(r); - hist(bin)++; - } - // Normalize bins by probability. + for (int k = 0; k < bins; k++) + for (int repeat = 0; repeat < repeats; repeat++) { + Scalar r = Eigen::internal::random(); + int bin = hist_helper.bin(r); + hist(bin)++; + } + // Normalize bins by probability. + hist /= count; for (int i = 0; i < bins; ++i) { - hist(i) = hist(i) / n / hist_helper.uniform_bin_probability(i); + hist(i) = hist(i) / hist_helper.uniform_bin_probability(i); } VERIFY(((hist.array() - 1.0).abs() < 0.05).all()); } +template <> +void check_histogram(int) { + constexpr int bins = 2; + constexpr int repeats = 10000; + double count = double(bins) * double(repeats); + double true_count = 0.0; + for (int k = 0; k < bins; k++) + for (int repeat = 0; repeat < repeats; repeat++) { + bool r = Eigen::internal::random(); + if (r) true_count += 1.0; + } + double p = true_count / count; + VERIFY(numext::abs(p - 0.5) < 0.05); +} + EIGEN_DECLARE_TEST(rand) { int64_t int64_ref = NumTraits::highest() / 10; // the minimum guarantees that these conversions are safe @@ -191,14 +235,16 @@ EIGEN_DECLARE_TEST(rand) { CALL_SUBTEST_7(check_all_in_range(-11 - int8t_offset, -11)); CALL_SUBTEST_7(check_all_in_range(-126, -126 + int8t_offset)); CALL_SUBTEST_7(check_all_in_range(126 - int8t_offset, 126)); - CALL_SUBTEST_7(check_all_in_range(-126, 126)); + CALL_SUBTEST_7(check_all_in_range()); + CALL_SUBTEST_7(check_all_in_range()); CALL_SUBTEST_8(check_all_in_range(11, 11)); CALL_SUBTEST_8(check_all_in_range(11, 11 + int16t_offset)); CALL_SUBTEST_8(check_all_in_range(-5, 5)); CALL_SUBTEST_8(check_all_in_range(-11 - int16t_offset, -11)); CALL_SUBTEST_8(check_all_in_range(-24345, -24345 + int16t_offset)); - CALL_SUBTEST_8(check_all_in_range(24345, 24345 + int16t_offset)); + CALL_SUBTEST_8(check_all_in_range()); + CALL_SUBTEST_8(check_all_in_range()); CALL_SUBTEST_9(check_all_in_range(11, 11)); CALL_SUBTEST_9(check_all_in_range(11, 11 + g_repeat)); @@ -223,6 +269,7 @@ EIGEN_DECLARE_TEST(rand) { CALL_SUBTEST_11(check_histogram(-RAND_MAX + 10, -int64_t(RAND_MAX) + 10 + bins * (2 * int64_t(RAND_MAX) / bins) - 1, bins)); + CALL_SUBTEST_12(check_histogram(/*bins=*/2)); CALL_SUBTEST_12(check_histogram(/*bins=*/16)); CALL_SUBTEST_12(check_histogram(/*bins=*/1024)); CALL_SUBTEST_12(check_histogram(/*bins=*/1024)); @@ -238,10 +285,16 @@ EIGEN_DECLARE_TEST(rand) { CALL_SUBTEST_14(check_histogram(-10.0L, 10.0L, /*bins=*/1024)); CALL_SUBTEST_14(check_histogram(half(-10.0f), half(10.0f), /*bins=*/512)); CALL_SUBTEST_14(check_histogram(bfloat16(-10.0f), bfloat16(10.0f), /*bins=*/64)); + CALL_SUBTEST_14(check_histogram>(-10.0f, 10.0f, /*bins=*/1024)); + CALL_SUBTEST_14(check_histogram>(half(-10.0f), half(10.0f), /*bins=*/512)); + CALL_SUBTEST_14(check_histogram>(bfloat16(-10.0f), bfloat16(10.0f), /*bins=*/64)); CALL_SUBTEST_15(check_histogram(/*bins=*/1024)); CALL_SUBTEST_15(check_histogram(/*bins=*/1024)); CALL_SUBTEST_15(check_histogram(/*bins=*/1024)); CALL_SUBTEST_15(check_histogram(/*bins=*/512)); CALL_SUBTEST_15(check_histogram(/*bins=*/64)); + CALL_SUBTEST_15(check_histogram>(/*bins=*/1024)); + CALL_SUBTEST_15(check_histogram>(/*bins=*/512)); + CALL_SUBTEST_15(check_histogram>(/*bins=*/64)); }