cxx11_tensor_random: use retry loop for low-precision RNG collisions

libeigen/eigen!2269

Co-authored-by: Rasmus Munk Larsen <rmlarsen@gmail.com>
This commit is contained in:
Rasmus Munk Larsen
2026-03-08 16:19:48 -07:00
parent f80d7b8254
commit a3cb1c6591

View File

@@ -14,44 +14,44 @@
template <typename Scalar>
static void test_default() {
Tensor<Scalar, 1> vec(6);
vec.setRandom();
// Fixme: we should check that the generated numbers follow a uniform
// distribution instead.
// For low-precision types (half, bfloat16), the RNG has limited distinct
// values (e.g. 128 for bfloat16), so adjacent collisions are statistically
// inevitable. Only verify that not all values are identical.
if (sizeof(Scalar) <= 2) {
bool has_distinct = false;
for (int i = 1; i < 6 && !has_distinct; ++i) {
if (vec(i) != vec(i - 1)) has_distinct = true;
}
VERIFY(has_distinct);
} else {
// values (e.g. 128 for bfloat16), so adjacent collisions are possible.
// Retry a few times to avoid spurious failures.
bool all_distinct = false;
for (int attempt = 0; attempt < 10 && !all_distinct; ++attempt) {
vec.setRandom();
all_distinct = true;
for (int i = 1; i < 6; ++i) {
VERIFY_IS_NOT_EQUAL(vec(i), vec(i - 1));
if (vec(i) == vec(i - 1)) {
all_distinct = false;
break;
}
}
}
VERIFY(all_distinct);
}
template <typename Scalar>
static void test_normal() {
Tensor<Scalar, 1> vec(6);
vec.template setRandom<Eigen::internal::NormalRandomGenerator<Scalar>>();
// Fixme: we should check that the generated numbers follow a gaussian
// distribution instead.
if (sizeof(Scalar) <= 2) {
bool has_distinct = false;
for (int i = 1; i < 6 && !has_distinct; ++i) {
if (vec(i) != vec(i - 1)) has_distinct = true;
}
VERIFY(has_distinct);
} else {
bool all_distinct = false;
for (int attempt = 0; attempt < 10 && !all_distinct; ++attempt) {
vec.template setRandom<Eigen::internal::NormalRandomGenerator<Scalar>>();
all_distinct = true;
for (int i = 1; i < 6; ++i) {
VERIFY_IS_NOT_EQUAL(vec(i), vec(i - 1));
if (vec(i) == vec(i - 1)) {
all_distinct = false;
break;
}
}
}
VERIFY(all_distinct);
}
struct MyGenerator {