chore: 添加散度
This commit is contained in:
@@ -51,6 +51,17 @@ ADResult diff(const Func f, Args... args) {
|
||||
for (auto& in : inputs)
|
||||
res.gradient.push_back(in.gradient());
|
||||
|
||||
// 散度
|
||||
res.divergence = 0.0;
|
||||
for (const auto& g : res.gradient) res.divergence += g;
|
||||
|
||||
// 旋度
|
||||
// 标量场的旋度始终为0
|
||||
res.curl.resize(inputs.size());
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
res.curl[i] = 0.0; // 标量场的旋度为0
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
} // namespace backwardad
|
||||
@@ -18,6 +18,8 @@ ADResult diff(const Func& f, Args... args) {
|
||||
|
||||
// 2. 对每个输入变量求偏导(seed 依次设为 1)
|
||||
double args_arr[] = { (double)args... };
|
||||
|
||||
// 梯度
|
||||
[&]<size_t... Is>(std::index_sequence<Is...>) {
|
||||
([&]() {
|
||||
// 创建 Dual 输入,第 Is 个变量导数=1,其余=0
|
||||
@@ -31,6 +33,22 @@ ADResult diff(const Func& f, Args... args) {
|
||||
}(), ...);
|
||||
}(std::make_index_sequence<N>{});
|
||||
|
||||
// 散度
|
||||
res.divergence = 0.0;
|
||||
[&]<size_t... Is>(std::index_sequence<Is...>) {
|
||||
([&]() {
|
||||
res.divergence += [&]<size_t... Js>(std::index_sequence<Js...>) {
|
||||
return res.gradient[Is]; // 这里直接用梯度值,因为散度是梯度的和
|
||||
}(std::make_index_sequence<N>{});
|
||||
}(), ...);
|
||||
}(std::make_index_sequence<N>{});
|
||||
|
||||
// 旋度
|
||||
// 标量场的旋度始终为0
|
||||
res.curl.resize(N);
|
||||
for (size_t i = 0; i < N; ++i) {
|
||||
res.curl[i] = 0.0; // 标量场的旋度为0
|
||||
}
|
||||
return res;
|
||||
}
|
||||
} // namespace forwardad
|
||||
@@ -6,4 +6,6 @@
|
||||
struct ADResult{
|
||||
double value; // 函数值
|
||||
std::vector<double> gradient; // 梯度
|
||||
std::vector<double> curl; // 旋度
|
||||
double divergence; // 散度
|
||||
};
|
||||
|
||||
@@ -21,16 +21,23 @@ int test_backward_ad() {
|
||||
auto r = backwardad::diff(f2, 2.0);
|
||||
std::cout << "value = " << r.value << "\n";
|
||||
std::cout << "grad = " << r.gradient[0] << "\n";
|
||||
std::cout << "div = " << r.divergence << "\n";
|
||||
std::cout << "curl = " << r.curl[0] << "\n";
|
||||
|
||||
std::cout << "\nTesting g(x,y) = exp(x)*log(y) at (x,y)=(1.0, 2.0)\n";
|
||||
auto r2 = backwardad::diff(g2, 1.0, 2.0);
|
||||
std::cout << "value = " << r2.value << "\n";
|
||||
std::cout << "grad = (" << r2.gradient[0] << ", " << r2.gradient[1] << ")\n";
|
||||
std::cout << "div = " << r2.divergence << "\n";
|
||||
std::cout << "curl = (" << r2.curl[0] << ", " << r2.curl[1] << ")\n";
|
||||
|
||||
std::cout << "\nTesting h(x,y,z) = x^3 + y^2 + z + cos(x*y*z) at (x,y,z)=(1.0, 2.0, 3.0)\n";
|
||||
auto r3 = backwardad::diff(h2, 1.0, 2.0, 3.0);
|
||||
std::cout << "value = " << r3.value << "\n";
|
||||
std::cout << "grad = (" << r3.gradient[0] << ", " << r3.gradient[1] << ", " << r3.gradient[2] << ")\n";
|
||||
std::cout << "div = " << r3.divergence << "\n";
|
||||
std::cout << "curl = (" << r3.curl[0] << ", " << r3.curl[1] << ", " << r3.curl[2] << ")\n";
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
@@ -21,16 +21,22 @@ int test_forward_ad() {
|
||||
auto r = forwardad::diff(f1, 2.0);
|
||||
std::cout << "value = " << r.value << "\n";
|
||||
std::cout << "grad = " << r.gradient[0] << "\n";
|
||||
std::cout << "div = " << r.divergence << "\n";
|
||||
std::cout << "curl = " << r.curl[0] << "\n";
|
||||
|
||||
std::cout << "\nTesting g(x,y) = exp(x)*log(y) at (x,y)=(1.0, 2.0)\n";
|
||||
auto r2 = forwardad::diff(g1, 1.0, 2.0);
|
||||
std::cout << "value = " << r2.value << "\n";
|
||||
std::cout << "grad = (" << r2.gradient[0] << ", " << r2.gradient[1] << ")\n";
|
||||
std::cout << "div = " << r2.divergence << "\n";
|
||||
std::cout << "curl = (" << r2.curl[0] << ", " << r2.curl[1] << ")\n";
|
||||
|
||||
std::cout << "\nTesting h(x,y,z) = x^3 + y^2 + z + cos(x*y*z) at (x,y,z)=(1.0, 2.0, 3.0)\n";
|
||||
auto r3 = forwardad::diff(h1, 1.0, 2.0, 3.0);
|
||||
std::cout << "value = " << r3.value << "\n";
|
||||
std::cout << "grad = (" << r3.gradient[0] << ", " << r3.gradient[1] << ", " << r3.gradient[2] << ")\n";
|
||||
std::cout << "div = " << r3.divergence << "\n";
|
||||
std::cout << "curl = (" << r3.curl[0] << ", " << r3.curl[1] << ", " << r3.curl[2] << ")\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user