// This file is part of Eigen, a lightweight C++ template library // for linear algebra. // // Copyright (C) 2026 Rasmus Munk Larsen // // This Source Code Form is subject to the terms of the Mozilla // Public License v. 2.0. If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. // Tests for cuBLAS GEMM dispatch via DeviceMatrix expression syntax. // Covers: d_C = d_A * d_B, adjoint, transpose, scaled, +=, .device(ctx). #define EIGEN_USE_GPU #include "main.h" #include using namespace Eigen; // ---- Basic GEMM: C = A * B ------------------------------------------------- template void test_gemm_basic(Index m, Index n, Index k) { using Mat = Matrix; using RealScalar = typename NumTraits::Real; Mat A = Mat::Random(m, k); Mat B = Mat::Random(k, n); auto d_A = DeviceMatrix::fromHost(A); auto d_B = DeviceMatrix::fromHost(B); // Expression: d_C = d_A * d_B DeviceMatrix d_C; d_C = d_A * d_B; Mat C = d_C.toHost(); Mat C_ref = A * B; RealScalar tol = RealScalar(k) * NumTraits::epsilon() * C_ref.norm(); VERIFY((C - C_ref).norm() < tol); } // ---- GEMM with adjoint: C = A^H * B ---------------------------------------- template void test_gemm_adjoint_lhs(Index m, Index n, Index k) { using Mat = Matrix; using RealScalar = typename NumTraits::Real; Mat A = Mat::Random(k, m); // A is k×m, A^H is m×k Mat B = Mat::Random(k, n); auto d_A = DeviceMatrix::fromHost(A); auto d_B = DeviceMatrix::fromHost(B); DeviceMatrix d_C; d_C = d_A.adjoint() * d_B; Mat C = d_C.toHost(); Mat C_ref = A.adjoint() * B; RealScalar tol = RealScalar(k) * NumTraits::epsilon() * C_ref.norm(); VERIFY((C - C_ref).norm() < tol); } // ---- GEMM with transpose: C = A * B^T -------------------------------------- template void test_gemm_transpose_rhs(Index m, Index n, Index k) { using Mat = Matrix; using RealScalar = typename NumTraits::Real; Mat A = Mat::Random(m, k); Mat B = Mat::Random(n, k); // B is n×k, B^T is k×n auto d_A = DeviceMatrix::fromHost(A); auto d_B = DeviceMatrix::fromHost(B); DeviceMatrix d_C; d_C = d_A * d_B.transpose(); Mat C = d_C.toHost(); Mat C_ref = A * B.transpose(); RealScalar tol = RealScalar(k) * NumTraits::epsilon() * C_ref.norm(); VERIFY((C - C_ref).norm() < tol); } // ---- GEMM with scaled: C = alpha * A * B ------------------------------------ template void test_gemm_scaled(Index m, Index n, Index k) { using Mat = Matrix; using RealScalar = typename NumTraits::Real; Mat A = Mat::Random(m, k); Mat B = Mat::Random(k, n); Scalar alpha = Scalar(2.5); auto d_A = DeviceMatrix::fromHost(A); auto d_B = DeviceMatrix::fromHost(B); DeviceMatrix d_C; d_C = alpha * d_A * d_B; Mat C = d_C.toHost(); Mat C_ref = alpha * A * B; RealScalar tol = RealScalar(k) * NumTraits::epsilon() * C_ref.norm(); VERIFY((C - C_ref).norm() < tol); } // ---- GEMM accumulate: C += A * B (beta=1) ----------------------------------- template void test_gemm_accumulate(Index m, Index n, Index k) { using Mat = Matrix; using RealScalar = typename NumTraits::Real; Mat A = Mat::Random(m, k); Mat B = Mat::Random(k, n); Mat C_init = Mat::Random(m, n); auto d_A = DeviceMatrix::fromHost(A); auto d_B = DeviceMatrix::fromHost(B); auto d_C = DeviceMatrix::fromHost(C_init); d_C += d_A * d_B; Mat C = d_C.toHost(); Mat C_ref = C_init + A * B; RealScalar tol = RealScalar(k) * NumTraits::epsilon() * C_ref.norm(); VERIFY((C - C_ref).norm() < tol); } // ---- GEMM accumulate into empty destination --------------------------------- template void test_gemm_accumulate_empty(Index m, Index n, Index k) { using Mat = Matrix; using RealScalar = typename NumTraits::Real; Mat A = Mat::Random(m, k); Mat B = Mat::Random(k, n); auto d_A = DeviceMatrix::fromHost(A); auto d_B = DeviceMatrix::fromHost(B); DeviceMatrix d_C; d_C += d_A * d_B; Mat C = d_C.toHost(); Mat C_ref = A * B; RealScalar tol = RealScalar(k) * NumTraits::epsilon() * C_ref.norm(); VERIFY((C - C_ref).norm() < tol); } // ---- GEMM subtract: C -= A * B (beta=1, alpha=-1) -------------------------- template void test_gemm_subtract(Index m, Index n, Index k) { using Mat = Matrix; using RealScalar = typename NumTraits::Real; Mat A = Mat::Random(m, k); Mat B = Mat::Random(k, n); Mat C_init = Mat::Random(m, n); auto d_A = DeviceMatrix::fromHost(A); auto d_B = DeviceMatrix::fromHost(B); auto d_C = DeviceMatrix::fromHost(C_init); GpuContext ctx; d_C.device(ctx) -= d_A * d_B; Mat C = d_C.toHost(); Mat C_ref = C_init - A * B; RealScalar tol = RealScalar(k) * NumTraits::epsilon() * C_ref.norm(); VERIFY((C - C_ref).norm() < tol); } // ---- GEMM subtract from empty destination ----------------------------------- template void test_gemm_subtract_empty(Index m, Index n, Index k) { using Mat = Matrix; using RealScalar = typename NumTraits::Real; Mat A = Mat::Random(m, k); Mat B = Mat::Random(k, n); auto d_A = DeviceMatrix::fromHost(A); auto d_B = DeviceMatrix::fromHost(B); GpuContext ctx; DeviceMatrix d_C; d_C.device(ctx) -= d_A * d_B; Mat C = d_C.toHost(); Mat C_ref = -(A * B); RealScalar tol = RealScalar(k) * NumTraits::epsilon() * C_ref.norm(); VERIFY((C - C_ref).norm() < tol); } // ---- GEMM with scaled RHS: C = A * (alpha * B) ----------------------------- template void test_gemm_scaled_rhs(Index m, Index n, Index k) { using Mat = Matrix; using RealScalar = typename NumTraits::Real; Mat A = Mat::Random(m, k); Mat B = Mat::Random(k, n); Scalar alpha = Scalar(3.0); auto d_A = DeviceMatrix::fromHost(A); auto d_B = DeviceMatrix::fromHost(B); DeviceMatrix d_C; d_C = d_A * (alpha * d_B); Mat C = d_C.toHost(); Mat C_ref = A * (alpha * B); RealScalar tol = RealScalar(k) * NumTraits::epsilon() * C_ref.norm(); VERIFY((C - C_ref).norm() < tol); } // ---- GEMM dimension mismatch must assert ------------------------------------ template void test_gemm_dimension_mismatch() { using Mat = Matrix; Mat A = Mat::Random(8, 5); Mat B = Mat::Random(6, 7); // inner dimension mismatch auto d_A = DeviceMatrix::fromHost(A); auto d_B = DeviceMatrix::fromHost(B); DeviceMatrix d_C; VERIFY_RAISES_ASSERT(d_C = d_A * d_B); } // ---- GEMM with explicit GpuContext ------------------------------------------ template void test_gemm_explicit_context(Index m, Index n, Index k) { using Mat = Matrix; using RealScalar = typename NumTraits::Real; Mat A = Mat::Random(m, k); Mat B = Mat::Random(k, n); auto d_A = DeviceMatrix::fromHost(A); auto d_B = DeviceMatrix::fromHost(B); GpuContext ctx; DeviceMatrix d_C; d_C.device(ctx) = d_A * d_B; Mat C = d_C.toHost(); Mat C_ref = A * B; RealScalar tol = RealScalar(k) * NumTraits::epsilon() * C_ref.norm(); VERIFY((C - C_ref).norm() < tol); } // ---- GEMM cross-context reuse of the same destination ----------------------- template void test_gemm_cross_context_reuse(Index n) { using Mat = Matrix; using RealScalar = typename NumTraits::Real; Mat A = Mat::Random(n, n); Mat B = Mat::Random(n, n); Mat D = Mat::Random(n, n); Mat E = Mat::Random(n, n); auto d_A = DeviceMatrix::fromHost(A); auto d_B = DeviceMatrix::fromHost(B); auto d_D = DeviceMatrix::fromHost(D); auto d_E = DeviceMatrix::fromHost(E); GpuContext ctx1; GpuContext ctx2; DeviceMatrix d_C; d_C.device(ctx1) = d_A * d_B; d_C.device(ctx2) += d_D * d_E; Mat C = d_C.toHost(); Mat C_ref = A * B + D * E; RealScalar tol = RealScalar(2) * RealScalar(n) * NumTraits::epsilon() * C_ref.norm(); VERIFY((C - C_ref).norm() < tol); } // ---- GEMM cross-context resize of the destination --------------------------- template void test_gemm_cross_context_resize() { using Mat = Matrix; using RealScalar = typename NumTraits::Real; Mat A = Mat::Random(64, 64); Mat B = Mat::Random(64, 64); Mat D = Mat::Random(32, 16); Mat E = Mat::Random(16, 8); auto d_A = DeviceMatrix::fromHost(A); auto d_B = DeviceMatrix::fromHost(B); auto d_D = DeviceMatrix::fromHost(D); auto d_E = DeviceMatrix::fromHost(E); GpuContext ctx1; GpuContext ctx2; DeviceMatrix d_C; d_C.device(ctx1) = d_A * d_B; d_C.device(ctx2) = d_D * d_E; Mat C = d_C.toHost(); Mat C_ref = D * E; RealScalar tol = RealScalar(16) * NumTraits::epsilon() * C_ref.norm(); VERIFY((C - C_ref).norm() < tol); } // ---- GEMM chaining: C = (A * B) then D = C * E ----------------------------- template void test_gemm_chain(Index n) { using Mat = Matrix; using RealScalar = typename NumTraits::Real; Mat A = Mat::Random(n, n); Mat B = Mat::Random(n, n); Mat E = Mat::Random(n, n); auto d_A = DeviceMatrix::fromHost(A); auto d_B = DeviceMatrix::fromHost(B); auto d_E = DeviceMatrix::fromHost(E); DeviceMatrix d_C; d_C = d_A * d_B; DeviceMatrix d_D; d_D = d_C * d_E; Mat D = d_D.toHost(); Mat D_ref = (A * B) * E; RealScalar tol = RealScalar(2) * RealScalar(n) * NumTraits::epsilon() * D_ref.norm(); VERIFY((D - D_ref).norm() < tol); } // ---- Square identity check: A * I = A --------------------------------------- template void test_gemm_identity(Index n) { using Mat = Matrix; Mat A = Mat::Random(n, n); Mat eye = Mat::Identity(n, n); auto d_A = DeviceMatrix::fromHost(A); auto d_I = DeviceMatrix::fromHost(eye); DeviceMatrix d_C; d_C = d_A * d_I; Mat C = d_C.toHost(); VERIFY_IS_APPROX(C, A); } // ---- LLT solve expression: d_X = d_A.llt().solve(d_B) ---------------------- template MatrixType make_spd(Index n) { using Scalar = typename MatrixType::Scalar; MatrixType M = MatrixType::Random(n, n); return M.adjoint() * M + MatrixType::Identity(n, n) * static_cast(n); } template void test_llt_solve_expr(Index n, Index nrhs) { using Mat = Matrix; using RealScalar = typename NumTraits::Real; Mat A = make_spd(n); Mat B = Mat::Random(n, nrhs); auto d_A = DeviceMatrix::fromHost(A); auto d_B = DeviceMatrix::fromHost(B); DeviceMatrix d_X; d_X = d_A.llt().solve(d_B); Mat X = d_X.toHost(); RealScalar residual = (A * X - B).norm() / B.norm(); VERIFY(residual < RealScalar(n) * NumTraits::epsilon()); } // ---- LLT solve with explicit context ---------------------------------------- template void test_llt_solve_expr_context(Index n, Index nrhs) { using Mat = Matrix; using RealScalar = typename NumTraits::Real; Mat A = make_spd(n); Mat B = Mat::Random(n, nrhs); auto d_A = DeviceMatrix::fromHost(A); auto d_B = DeviceMatrix::fromHost(B); GpuContext ctx; DeviceMatrix d_X; d_X.device(ctx) = d_A.llt().solve(d_B); Mat X = d_X.toHost(); RealScalar residual = (A * X - B).norm() / B.norm(); VERIFY(residual < RealScalar(n) * NumTraits::epsilon()); } // ---- LU solve expression: d_X = d_A.lu().solve(d_B) ------------------------ template void test_lu_solve_expr(Index n, Index nrhs) { using Mat = Matrix; using RealScalar = typename NumTraits::Real; Mat A = Mat::Random(n, n); Mat B = Mat::Random(n, nrhs); auto d_A = DeviceMatrix::fromHost(A); auto d_B = DeviceMatrix::fromHost(B); DeviceMatrix d_X; d_X = d_A.lu().solve(d_B); Mat X = d_X.toHost(); RealScalar residual = (A * X - B).norm() / (A.norm() * X.norm()); VERIFY(residual < RealScalar(10) * RealScalar(n) * NumTraits::epsilon()); } // ---- GEMM + solver chain: C = A * B, X = C.llt().solve(D) ------------------ template void test_gemm_then_solve(Index n) { using Mat = Matrix; using RealScalar = typename NumTraits::Real; Mat A = Mat::Random(n, n); Mat D = Mat::Random(n, 1); // Make SPD: C = A^H * A + n*I auto d_A = DeviceMatrix::fromHost(A); DeviceMatrix d_C; d_C = d_A.adjoint() * d_A; // Add n*I on host (no element-wise ops on DeviceMatrix yet). Mat C = d_C.toHost(); C += Mat::Identity(n, n) * static_cast(n); d_C = DeviceMatrix::fromHost(C); auto d_D = DeviceMatrix::fromHost(D); DeviceMatrix d_X; d_X = d_C.llt().solve(d_D); Mat X = d_X.toHost(); RealScalar residual = (C * X - D).norm() / D.norm(); VERIFY(residual < RealScalar(n) * NumTraits::epsilon()); } // ---- LLT solve with Upper triangle ----------------------------------------- template void test_llt_solve_upper(Index n, Index nrhs) { using Mat = Matrix; using RealScalar = typename NumTraits::Real; Mat A = make_spd(n); Mat B = Mat::Random(n, nrhs); auto d_A = DeviceMatrix::fromHost(A); auto d_B = DeviceMatrix::fromHost(B); DeviceMatrix d_X; d_X = d_A.template llt().solve(d_B); Mat X = d_X.toHost(); RealScalar residual = (A * X - B).norm() / B.norm(); VERIFY(residual < RealScalar(n) * NumTraits::epsilon()); } // ---- LU solve with explicit context ----------------------------------------- template void test_lu_solve_expr_context(Index n, Index nrhs) { using Mat = Matrix; using RealScalar = typename NumTraits::Real; Mat A = Mat::Random(n, n); Mat B = Mat::Random(n, nrhs); auto d_A = DeviceMatrix::fromHost(A); auto d_B = DeviceMatrix::fromHost(B); GpuContext ctx; DeviceMatrix d_X; d_X.device(ctx) = d_A.lu().solve(d_B); Mat X = d_X.toHost(); RealScalar residual = (A * X - B).norm() / (A.norm() * X.norm()); VERIFY(residual < RealScalar(10) * RealScalar(n) * NumTraits::epsilon()); } // ---- Zero-nrhs solver expressions ------------------------------------------ template void test_llt_solve_zero_nrhs(Index n) { using Mat = Matrix; Mat A = make_spd(n); Mat B = Mat::Random(n, 0); auto d_A = DeviceMatrix::fromHost(A); auto d_B = DeviceMatrix::fromHost(B); DeviceMatrix d_X; d_X = d_A.llt().solve(d_B); VERIFY_IS_EQUAL(d_X.rows(), n); VERIFY_IS_EQUAL(d_X.cols(), 0); } template void test_lu_solve_zero_nrhs(Index n) { using Mat = Matrix; Mat A = Mat::Random(n, n); Mat B = Mat::Random(n, 0); auto d_A = DeviceMatrix::fromHost(A); auto d_B = DeviceMatrix::fromHost(B); DeviceMatrix d_X; d_X = d_A.lu().solve(d_B); VERIFY_IS_EQUAL(d_X.rows(), n); VERIFY_IS_EQUAL(d_X.cols(), 0); } // ---- TRSM: triangularView().solve(B) ---------------------------------- template void test_trsm(Index n, Index nrhs) { using Mat = Matrix; using RealScalar = typename NumTraits::Real; // Build a well-conditioned triangular matrix. Mat A = Mat::Random(n, n); A.diagonal().array() += static_cast(n); // ensure non-singular if (UpLo == Lower) A = A.template triangularView(); else A = A.template triangularView(); Mat B = Mat::Random(n, nrhs); auto d_A = DeviceMatrix::fromHost(A); auto d_B = DeviceMatrix::fromHost(B); DeviceMatrix d_X; d_X = d_A.template triangularView().solve(d_B); Mat X = d_X.toHost(); RealScalar residual = (A * X - B).norm() / B.norm(); VERIFY(residual < RealScalar(n) * NumTraits::epsilon()); } // ---- SYMM/HEMM: selfadjointView() * B -------------------------------- template void test_symm(Index n, Index nrhs) { using Mat = Matrix; using RealScalar = typename NumTraits::Real; Mat A = make_spd(n); // SPD is also self-adjoint Mat B = Mat::Random(n, nrhs); auto d_A = DeviceMatrix::fromHost(A); auto d_B = DeviceMatrix::fromHost(B); DeviceMatrix d_C; d_C = d_A.template selfadjointView() * d_B; Mat C = d_C.toHost(); Mat C_ref = A * B; // A is symmetric, so full multiply == symm RealScalar tol = RealScalar(n) * NumTraits::epsilon() * C_ref.norm(); VERIFY((C - C_ref).norm() < tol); } // ---- SYRK/HERK: rankUpdate(A) → C = A * A^H -------------------------------- template void test_syrk(Index n, Index k) { using Mat = Matrix; using RealScalar = typename NumTraits::Real; Mat A = Mat::Random(n, k); auto d_A = DeviceMatrix::fromHost(A); DeviceMatrix d_C; d_C.template selfadjointView().rankUpdate(d_A); Mat C = d_C.toHost(); // Only lower triangle is meaningful for SYRK. Compare lower triangle. Mat C_ref = A * A.adjoint(); // Extract lower triangle for comparison. Mat C_lower = C.template triangularView(); Mat C_ref_lower = C_ref.template triangularView(); RealScalar tol = RealScalar(k) * NumTraits::epsilon() * C_ref.norm(); VERIFY((C_lower - C_ref_lower).norm() < tol); } // ---- Per-scalar driver ------------------------------------------------------ template void test_scalar() { CALL_SUBTEST(test_gemm_basic(64, 64, 64)); CALL_SUBTEST(test_gemm_basic(128, 64, 32)); CALL_SUBTEST(test_gemm_basic(1, 1, 1)); CALL_SUBTEST(test_gemm_basic(256, 256, 256)); CALL_SUBTEST(test_gemm_adjoint_lhs(64, 64, 64)); CALL_SUBTEST(test_gemm_adjoint_lhs(128, 32, 64)); CALL_SUBTEST(test_gemm_transpose_rhs(64, 64, 64)); CALL_SUBTEST(test_gemm_transpose_rhs(128, 32, 64)); CALL_SUBTEST(test_gemm_scaled(64, 64, 64)); CALL_SUBTEST(test_gemm_scaled_rhs(64, 64, 64)); CALL_SUBTEST(test_gemm_accumulate(64, 64, 64)); CALL_SUBTEST(test_gemm_accumulate_empty(64, 64, 64)); CALL_SUBTEST(test_gemm_subtract(64, 64, 64)); CALL_SUBTEST(test_gemm_subtract_empty(64, 64, 64)); CALL_SUBTEST(test_gemm_dimension_mismatch()); CALL_SUBTEST(test_gemm_explicit_context(64, 64, 64)); CALL_SUBTEST(test_gemm_cross_context_reuse(64)); CALL_SUBTEST(test_gemm_cross_context_resize()); CALL_SUBTEST(test_gemm_chain(64)); CALL_SUBTEST(test_gemm_identity(64)); // Solver expressions — zero-size edge cases (use dedicated tests, not residual-based) // Solver expressions CALL_SUBTEST(test_llt_solve_expr(64, 1)); CALL_SUBTEST(test_llt_solve_expr(64, 4)); CALL_SUBTEST(test_llt_solve_expr(256, 8)); CALL_SUBTEST(test_llt_solve_expr_context(64, 4)); CALL_SUBTEST(test_llt_solve_upper(64, 4)); CALL_SUBTEST(test_lu_solve_expr(64, 1)); CALL_SUBTEST(test_lu_solve_expr(64, 4)); CALL_SUBTEST(test_lu_solve_expr(256, 8)); CALL_SUBTEST(test_lu_solve_expr_context(64, 4)); CALL_SUBTEST(test_llt_solve_zero_nrhs(64)); CALL_SUBTEST(test_llt_solve_zero_nrhs(0)); CALL_SUBTEST(test_lu_solve_zero_nrhs(64)); CALL_SUBTEST(test_lu_solve_zero_nrhs(0)); CALL_SUBTEST(test_gemm_then_solve(64)); // TRSM CALL_SUBTEST((test_trsm(64, 1))); CALL_SUBTEST((test_trsm(64, 4))); CALL_SUBTEST((test_trsm(64, 4))); CALL_SUBTEST((test_trsm(256, 8))); // SYMM/HEMM CALL_SUBTEST((test_symm(64, 4))); CALL_SUBTEST((test_symm(64, 4))); CALL_SUBTEST((test_symm(128, 8))); // SYRK/HERK CALL_SUBTEST(test_syrk(64, 64)); CALL_SUBTEST(test_syrk(64, 32)); CALL_SUBTEST(test_syrk(128, 64)); } // ---- Solver failure mode tests (not templated on Scalar) -------------------- void test_llt_not_spd() { // Negative definite matrix — LLT factorization must fail. MatrixXd A = -MatrixXd::Identity(8, 8); MatrixXd B = MatrixXd::Random(8, 1); auto d_A = DeviceMatrix::fromHost(A); auto d_B = DeviceMatrix::fromHost(B); DeviceMatrix d_X; VERIFY_RAISES_ASSERT(d_X = d_A.llt().solve(d_B)); } void test_lu_singular() { // Zero matrix — LU factorization must detect singularity. MatrixXd A = MatrixXd::Zero(8, 8); MatrixXd B = MatrixXd::Random(8, 1); auto d_A = DeviceMatrix::fromHost(A); auto d_B = DeviceMatrix::fromHost(B); DeviceMatrix d_X; VERIFY_RAISES_ASSERT(d_X = d_A.lu().solve(d_B)); } EIGEN_DECLARE_TEST(gpu_cublas) { CALL_SUBTEST(test_scalar()); CALL_SUBTEST(test_scalar()); CALL_SUBTEST(test_scalar>()); CALL_SUBTEST(test_scalar>()); CALL_SUBTEST(test_llt_not_spd()); CALL_SUBTEST(test_lu_singular()); }