Slightly simplify ForkJoin code, and make sure the test is actually run.

This commit is contained in:
Rasmus Munk Larsen
2025-02-25 17:22:43 +00:00
parent 6aebfa9acc
commit 72adf891d5
4 changed files with 69 additions and 88 deletions

View File

@@ -12,39 +12,26 @@
#include "Eigen/ThreadPool"
// Bundles the thread pool and the input values shared by the fork-join tests.
struct TestData {
  // Pool the tests schedule work on; held through a unique_ptr and handed to
  // the scheduler as a raw pointer via tp.get().
  std::unique_ptr<ThreadPool> tp;
  // Input values that the test functors accumulate.
  std::vector<double> data;
};
// Creates a TestData whose ThreadPool is constructed with `num_threads` and
// whose data vector holds `num_shards` entries, each equal to 1.0.
TestData make_test_data(int num_threads, int num_shards) {
  return {std::make_unique<ThreadPool>(num_threads), std::vector<double>(num_shards, 1.0)};
}
static void test_unary_parallel_for(int granularity) {
static void test_parallel_for(int granularity) {
// Test correctness.
const int kNumTasks = 100000;
TestData test_data = make_test_data(/*num_threads=*/4, kNumTasks);
std::atomic<double> sum = 0.0;
std::function<void(int)> unary_do_fn = [&](int i) {
for (double new_sum = sum; !sum.compare_exchange_weak(new_sum, new_sum + test_data.data[i]);) {
};
};
ForkJoinScheduler::ParallelFor(0, kNumTasks, granularity, std::move(unary_do_fn), &test_data.tp);
VERIFY_IS_EQUAL(sum, kNumTasks);
}
static void test_binary_parallel_for(int granularity) {
// Test correctness.
const int kNumTasks = 100000;
TestData test_data = make_test_data(/*num_threads=*/4, kNumTasks);
std::atomic<double> sum = 0.0;
std::function<void(int, int)> binary_do_fn = [&](int i, int j) {
std::atomic<uint64_t> sum(0);
std::function<void(Index, Index)> binary_do_fn = [&](Index i, Index j) {
for (int k = i; k < j; ++k)
for (double new_sum = sum; !sum.compare_exchange_weak(new_sum, new_sum + test_data.data[k]);) {
for (uint64_t new_sum = sum; !sum.compare_exchange_weak(new_sum, new_sum + test_data.data[k]);) {
};
};
ForkJoinScheduler::ParallelFor(0, kNumTasks, granularity, std::move(binary_do_fn), &test_data.tp);
VERIFY_IS_EQUAL(sum, kNumTasks);
ForkJoinScheduler::ParallelFor(0, kNumTasks, granularity, std::move(binary_do_fn), test_data.tp.get());
VERIFY_IS_EQUAL(sum.load(), kNumTasks);
}
static void test_async_parallel_for() {
@@ -54,26 +41,26 @@ static void test_async_parallel_for() {
const int kNumTasks = 100;
const int kNumAsyncCalls = kNumThreads * 4;
TestData test_data = make_test_data(kNumThreads, kNumTasks);
std::atomic<double> sum = 0.0;
std::function<void(int)> unary_do_fn = [&](int i) {
for (double new_sum = sum; !sum.compare_exchange_weak(new_sum, new_sum + test_data.data[i]);) {
};
std::atomic<uint64_t> sum(0);
std::function<void(Index, Index)> binary_do_fn = [&](Index i, Index j) {
for (Index k = i; k < j; ++k) {
for (uint64_t new_sum = sum; !sum.compare_exchange_weak(new_sum, new_sum + test_data.data[i]);) {
}
}
};
Barrier barrier(kNumTasks * kNumAsyncCalls);
Barrier barrier(kNumAsyncCalls);
std::function<void()> done = [&]() { barrier.Notify(); };
for (int k = 0; k < kNumAsyncCalls; ++k) {
test_data.tp.Schedule([&]() {
ForkJoinScheduler::ParallelForAsync(0, kNumTasks, /*granularity=*/1, unary_do_fn, done, &test_data.tp);
test_data.tp->Schedule([&]() {
ForkJoinScheduler::ParallelForAsync(0, kNumTasks, /*granularity=*/1, binary_do_fn, done, test_data.tp.get());
});
}
barrier.Wait();
VERIFY_IS_EQUAL(sum, kNumTasks * kNumAsyncCalls);
VERIFY_IS_EQUAL(sum.load(), kNumTasks * kNumAsyncCalls);
}
// Test entry point: runs the synchronous test at granularity 1 and 2, then
// the asynchronous test.
EIGEN_DECLARE_TEST(fork_join) {
  CALL_SUBTEST(test_parallel_for(1));
  CALL_SUBTEST(test_parallel_for(2));
  CALL_SUBTEST(test_async_parallel_for());
}