Files
eigen/unsupported/test/cxx11_tensor_uint128.cpp
Pavel Guzenfeld 821ab7d3e6 Fix TensorUInt128 division infinite loop on overflow
libeigen/eigen!2300

Closes #3012

Co-authored-by: Pavel Guzenfeld <67074795+PavelGuzenfeld@users.noreply.github.com>
2026-03-20 15:41:00 +00:00

179 lines
6.5 KiB
C++

// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2015 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#include "main.h"
#include <Eigen/Tensor>
// Detect native 128-bit integer support. EIGEN_NO_INT128 is defined on MSVC and on any
// compiler that does not predefine __SIZEOF_INT128__; otherwise uint128_t aliases the
// compiler's __uint128_t, which the tests below use as the reference implementation.
#if EIGEN_COMP_MSVC || !defined(__SIZEOF_INT128__)
#define EIGEN_NO_INT128
#else
typedef __uint128_t uint128_t;
#endif
// Only run the test on compilers that support 128bit integers natively
#ifndef EIGEN_NO_INT128
using Eigen::internal::static_val;
using Eigen::internal::TensorUInt128;
// Checks that a TensorUInt128 holds the same value as a native 128-bit integer by
// comparing the upper and lower 64-bit words separately.
//
// On mismatch, writes the name of the running test together with the actual and expected
// words (in hex) to stderr and aborts the process. Returns normally when the values match.
void VERIFY_EQUAL(TensorUInt128<uint64_t, uint64_t> actual, uint128_t expected) {
  const uint64_t expected_lower = static_cast<uint64_t>(expected);
  const uint64_t expected_upper = static_cast<uint64_t>(expected >> 64);
  const uint64_t actual_lower = actual.lower();
  const uint64_t actual_upper = actual.upper();
  if (actual_lower == expected_lower && actual_upper == expected_upper) {
    return;
  }

  const char* testname = g_test_stack.back().c_str();
  std::cerr << "Test " << testname << " failed in " << __FILE__ << " (" << __LINE__ << ")" << std::endl;
  // __LINE__ above always refers to this helper, not to the failing call site, so the
  // mismatching values are printed as well to make the failure diagnosable.
  std::cerr << std::hex << "  actual:   upper=0x" << actual_upper << " lower=0x" << actual_lower << "\n"
            << "  expected: upper=0x" << expected_upper << " lower=0x" << expected_lower << std::endl;
  abort();
}
void test_add() {
uint64_t incr = internal::random<uint64_t>(1, 9999999999);
for (uint64_t i1 = 0; i1 < 100; ++i1) {
for (uint64_t i2 = 1; i2 < 100 * incr; i2 += incr) {
TensorUInt128<uint64_t, uint64_t> i(i1, i2);
uint128_t a = (static_cast<uint128_t>(i1) << 64) + static_cast<uint128_t>(i2);
for (uint64_t j1 = 0; j1 < 100; ++j1) {
for (uint64_t j2 = 1; j2 < 100 * incr; j2 += incr) {
TensorUInt128<uint64_t, uint64_t> j(j1, j2);
uint128_t b = (static_cast<uint128_t>(j1) << 64) + static_cast<uint128_t>(j2);
TensorUInt128<uint64_t, uint64_t> actual = i + j;
uint128_t expected = a + b;
VERIFY_EQUAL(actual, expected);
}
}
}
}
}
void test_sub() {
uint64_t incr = internal::random<uint64_t>(1, 9999999999);
for (uint64_t i1 = 0; i1 < 100; ++i1) {
for (uint64_t i2 = 1; i2 < 100 * incr; i2 += incr) {
TensorUInt128<uint64_t, uint64_t> i(i1, i2);
uint128_t a = (static_cast<uint128_t>(i1) << 64) + static_cast<uint128_t>(i2);
for (uint64_t j1 = 0; j1 < 100; ++j1) {
for (uint64_t j2 = 1; j2 < 100 * incr; j2 += incr) {
TensorUInt128<uint64_t, uint64_t> j(j1, j2);
uint128_t b = (static_cast<uint128_t>(j1) << 64) + static_cast<uint128_t>(j2);
TensorUInt128<uint64_t, uint64_t> actual = i - j;
uint128_t expected = a - b;
VERIFY_EQUAL(actual, expected);
}
}
}
}
}
void test_mul() {
uint64_t incr = internal::random<uint64_t>(1, 9999999999);
for (uint64_t i1 = 0; i1 < 100; ++i1) {
for (uint64_t i2 = 1; i2 < 100 * incr; i2 += incr) {
TensorUInt128<uint64_t, uint64_t> i(i1, i2);
uint128_t a = (static_cast<uint128_t>(i1) << 64) + static_cast<uint128_t>(i2);
for (uint64_t j1 = 0; j1 < 100; ++j1) {
for (uint64_t j2 = 1; j2 < 100 * incr; j2 += incr) {
TensorUInt128<uint64_t, uint64_t> j(j1, j2);
uint128_t b = (static_cast<uint128_t>(j1) << 64) + static_cast<uint128_t>(j2);
TensorUInt128<uint64_t, uint64_t> actual = i * j;
uint128_t expected = a * b;
VERIFY_EQUAL(actual, expected);
}
}
}
}
}
void test_div() {
uint64_t incr = internal::random<uint64_t>(1, 9999999999);
for (uint64_t i1 = 0; i1 < 100; ++i1) {
for (uint64_t i2 = 1; i2 < 100 * incr; i2 += incr) {
TensorUInt128<uint64_t, uint64_t> i(i1, i2);
uint128_t a = (static_cast<uint128_t>(i1) << 64) + static_cast<uint128_t>(i2);
for (uint64_t j1 = 0; j1 < 100; ++j1) {
for (uint64_t j2 = 1; j2 < 100 * incr; j2 += incr) {
TensorUInt128<uint64_t, uint64_t> j(j1, j2);
uint128_t b = (static_cast<uint128_t>(j1) << 64) + static_cast<uint128_t>(j2);
TensorUInt128<uint64_t, uint64_t> actual = i / j;
uint128_t expected = a / b;
VERIFY_EQUAL(actual, expected);
}
}
}
}
}
void test_misc1() {
uint64_t incr = internal::random<uint64_t>(1, 9999999999);
for (uint64_t i2 = 1; i2 < 100 * incr; i2 += incr) {
TensorUInt128<static_val<0>, uint64_t> i(0, i2);
uint128_t a = static_cast<uint128_t>(i2);
for (uint64_t j2 = 1; j2 < 100 * incr; j2 += incr) {
TensorUInt128<static_val<0>, uint64_t> j(0, j2);
uint128_t b = static_cast<uint128_t>(j2);
uint64_t actual = (i * j).upper();
uint64_t expected = (a * b) >> 64;
VERIFY_IS_EQUAL(actual, expected);
}
}
}
void test_div_overflow() {
// Regression test for infinite loop when lhs > 2^127 (issue #3012).
// Division would overflow d during doubling, causing an infinite loop.
TensorUInt128<uint64_t, uint64_t> a(1ULL << 63, 1);
TensorUInt128<uint64_t, uint64_t> b(2);
uint128_t expected = ((static_cast<uint128_t>(1ULL << 63) << 64) + 1) / 2;
VERIFY_EQUAL(a / b, expected);
// Also test with high bits in both operands
TensorUInt128<uint64_t, uint64_t> c(UINT64_MAX, UINT64_MAX);
TensorUInt128<uint64_t, uint64_t> d(1);
uint128_t c128 = (static_cast<uint128_t>(UINT64_MAX) << 64) | UINT64_MAX;
VERIFY_EQUAL(c / d, c128);
TensorUInt128<uint64_t, uint64_t> e(UINT64_MAX, UINT64_MAX);
TensorUInt128<uint64_t, uint64_t> f(0, 3);
uint128_t e128 = (static_cast<uint128_t>(UINT64_MAX) << 64) | UINT64_MAX;
VERIFY_EQUAL(e / f, e128 / 3);
}
void test_misc2() {
int64_t incr = internal::random<int64_t>(1, 100);
for (int64_t log_div = 0; log_div < 63; ++log_div) {
for (int64_t divider = 1; divider <= 1000000 * incr; divider += incr) {
uint64_t expected = (static_cast<uint128_t>(1) << (64 + log_div)) / static_cast<uint128_t>(divider) -
(static_cast<uint128_t>(1) << 64) + 1;
uint64_t shift = 1ULL << log_div;
TensorUInt128<uint64_t, uint64_t> result =
(TensorUInt128<uint64_t, static_val<0> >(shift, 0) / TensorUInt128<static_val<0>, uint64_t>(divider) -
TensorUInt128<static_val<1>, static_val<0> >(1, 0) + TensorUInt128<static_val<0>, static_val<1> >(1));
uint64_t actual = static_cast<uint64_t>(result);
VERIFY_IS_EQUAL(actual, expected);
}
}
}
#endif
// Test entry point. Runs every TensorUInt128 subtest when the compiler provides native
// 128-bit integers; otherwise the test body is empty.
EIGEN_DECLARE_TEST(cxx11_tensor_uint128) {
#ifndef EIGEN_NO_INT128
  // Basic arithmetic against the native reference.
  CALL_SUBTEST_1(test_add());
  CALL_SUBTEST_2(test_sub());
  CALL_SUBTEST_3(test_mul());
  CALL_SUBTEST_4(test_div());
  // Operands with compile-time constant words.
  CALL_SUBTEST_5(test_misc1());
  CALL_SUBTEST_6(test_misc2());
  // Regression test for the division infinite loop.
  CALL_SUBTEST_7(test_div_overflow());
#else
  // Skip the test on compilers that don't support 128bit integers natively
  return;
#endif
}