From 11e97c545a71c58ba3403bd424b1ac869465fcee Mon Sep 17 00:00:00 2001 From: mayge Date: Tue, 7 Apr 2026 17:16:29 +0800 Subject: [PATCH] =?UTF-8?q?chore:=20=E6=B7=BB=E5=8A=A0=E6=95=A3=E5=BA=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- include/backwardad.hpp | 11 +++++++++++ include/forwardad.hpp | 18 ++++++++++++++++++ include/types/common.hpp | 2 ++ tests/test_backward.cpp | 7 +++++++ tests/test_forward.cpp | 6 ++++++ 5 files changed, 44 insertions(+) diff --git a/include/backwardad.hpp b/include/backwardad.hpp index 358ed01..fd0d01c 100644 --- a/include/backwardad.hpp +++ b/include/backwardad.hpp @@ -51,6 +51,17 @@ ADResult diff(const Func f, Args... args) { for (auto& in : inputs) res.gradient.push_back(in.gradient()); + // 散度 + res.divergence = 0.0; + for (const auto& g : res.gradient) res.divergence += g; + + // 旋度 + // 标量场的旋度始终为0 + res.curl.resize(inputs.size()); + for (size_t i = 0; i < inputs.size(); ++i) { + res.curl[i] = 0.0; // 标量场的旋度为0 + } + return res; } } // namespace backwardad \ No newline at end of file diff --git a/include/forwardad.hpp b/include/forwardad.hpp index 34823cf..21a05f5 100644 --- a/include/forwardad.hpp +++ b/include/forwardad.hpp @@ -18,6 +18,8 @@ ADResult diff(const Func& f, Args... args) { // 2. 对每个输入变量求偏导(seed 依次设为 1) double args_arr[] = { (double)args... }; + + // 梯度 [&](std::index_sequence) { ([&]() { // 创建 Dual 输入,第 Is 个变量导数=1,其余=0 @@ -31,6 +33,22 @@ ADResult diff(const Func& f, Args... args) { }(), ...); }(std::make_index_sequence{}); + // 散度 + res.divergence = 0.0; + [&](std::index_sequence) { + ([&]() { + res.divergence += [&](std::index_sequence) { + return res.gradient[Is]; // 这里直接用梯度值,因为散度是梯度的和 + }(std::make_index_sequence{}); + }(), ...); + }(std::make_index_sequence{}); + + // 旋度 + // 标量场的旋度始终为0 + res.curl.resize(N); + for (size_t i = 0; i < N; ++i) { + res.curl[i] = 0.0; // 标量场的旋度为0 + } return res; } } // namespace forwardad \ No newline at end of file diff --git a/include/types/common.hpp b/include/types/common.hpp index 5efbe43..2109111 100644 --- a/include/types/common.hpp +++ b/include/types/common.hpp @@ -6,4 +6,6 @@ struct ADResult{ double value; // 函数值 std::vector gradient; // 梯度 + std::vector curl; // 旋度 + double divergence; // 散度 }; diff --git a/tests/test_backward.cpp b/tests/test_backward.cpp index 1172b6c..df3f3df 100644 --- a/tests/test_backward.cpp +++ b/tests/test_backward.cpp @@ -21,16 +21,23 @@ int test_backward_ad() { auto r = backwardad::diff(f2, 2.0); std::cout << "value = " << r.value << "\n"; std::cout << "grad = " << r.gradient[0] << "\n"; + std::cout << "div = " << r.divergence << "\n"; + std::cout << "curl = " << r.curl[0] << "\n"; std::cout << "\nTesting g(x,y) = exp(x)*log(y) at (x,y)=(1.0, 2.0)\n"; auto r2 = backwardad::diff(g2, 1.0, 2.0); std::cout << "value = " << r2.value << "\n"; std::cout << "grad = (" << r2.gradient[0] << ", " << r2.gradient[1] << ")\n"; + std::cout << "div = " << r2.divergence << "\n"; + std::cout << "curl = (" << r2.curl[0] << ", " << r2.curl[1] << ")\n"; std::cout << "\nTesting h(x,y,z) = x^3 + y^2 + z + cos(x*y*z) at (x,y,z)=(1.0, 2.0, 3.0)\n"; auto r3 = backwardad::diff(h2, 1.0, 2.0, 3.0); std::cout << "value = " << r3.value << "\n"; std::cout << "grad = (" << r3.gradient[0] << ", " << r3.gradient[1] << ", " << r3.gradient[2] << ")\n"; + std::cout << "div = " << r3.divergence << "\n"; + std::cout << "curl = (" << r3.curl[0] << ", " << r3.curl[1] << ", " << r3.curl[2] << ")\n"; + return 0; } diff --git a/tests/test_forward.cpp b/tests/test_forward.cpp index 3c14b3b..f649633 100644 --- a/tests/test_forward.cpp +++ b/tests/test_forward.cpp @@ -21,16 +21,22 @@ int test_forward_ad() { auto r = forwardad::diff(f1, 2.0); std::cout << "value = " << r.value << "\n"; std::cout << "grad = " << r.gradient[0] << "\n"; + std::cout << "div = " << r.divergence << "\n"; + std::cout << "curl = " << r.curl[0] << "\n"; std::cout << "\nTesting g(x,y) = exp(x)*log(y) at (x,y)=(1.0, 2.0)\n"; auto r2 = forwardad::diff(g1, 1.0, 2.0); std::cout << "value = " << r2.value << "\n"; std::cout << "grad = (" << r2.gradient[0] << ", " << r2.gradient[1] << ")\n"; + std::cout << "div = " << r2.divergence << "\n"; + std::cout << "curl = (" << r2.curl[0] << ", " << r2.curl[1] << ")\n"; std::cout << "\nTesting h(x,y,z) = x^3 + y^2 + z + cos(x*y*z) at (x,y,z)=(1.0, 2.0, 3.0)\n"; auto r3 = forwardad::diff(h1, 1.0, 2.0, 3.0); std::cout << "value = " << r3.value << "\n"; std::cout << "grad = (" << r3.gradient[0] << ", " << r3.gradient[1] << ", " << r3.gradient[2] << ")\n"; + std::cout << "div = " << r3.divergence << "\n"; + std::cout << "curl = (" << r3.curl[0] << ", " << r3.curl[1] << ", " << r3.curl[2] << ")\n"; return 0; }