2026-04-01 02:39:40 +08:00
|
|
|
|
/*前向自动微分的声明部分*/
|
|
|
|
|
|
#pragma once
|
|
|
|
|
|
#include <vector>
|
2026-04-07 13:55:32 +08:00
|
|
|
|
#include "dual_ops.hpp"
|
2026-04-01 02:39:40 +08:00
|
|
|
|
#include "types/common.hpp"
|
|
|
|
|
|
|
|
|
|
|
|
namespace forwardad{
|
|
|
|
|
|
|
|
|
|
|
|
template <typename Func, typename... Args>
|
2026-04-07 13:00:35 +08:00
|
|
|
|
ADResult diff(const Func& f, Args... args) {
|
|
|
|
|
|
ADResult res;
|
2026-04-01 02:39:40 +08:00
|
|
|
|
constexpr size_t N = sizeof...(Args);
|
|
|
|
|
|
res.gradient.resize(N);
|
|
|
|
|
|
|
|
|
|
|
|
// 1. 计算函数值(所有导数为 0)
|
|
|
|
|
|
Dual value_res = f(Dual((double)args, 0.0)...);
|
|
|
|
|
|
res.value = value_res.value;
|
|
|
|
|
|
|
|
|
|
|
|
// 2. 对每个输入变量求偏导(seed 依次设为 1)
|
|
|
|
|
|
double args_arr[] = { (double)args... };
|
2026-04-07 17:16:29 +08:00
|
|
|
|
|
|
|
|
|
|
// 梯度
|
2026-04-01 02:39:40 +08:00
|
|
|
|
[&]<size_t... Is>(std::index_sequence<Is...>) {
|
|
|
|
|
|
([&]() {
|
|
|
|
|
|
// 创建 Dual 输入,第 Is 个变量导数=1,其余=0
|
|
|
|
|
|
Dual inputs[N];
|
|
|
|
|
|
for (size_t j = 0; j < N; ++j)
|
|
|
|
|
|
inputs[j] = Dual(args_arr[j], j == Is ? 1.0 : 0.0);
|
|
|
|
|
|
// 用内层 index_sequence 展开所有 N 个参数传给 f
|
|
|
|
|
|
res.gradient[Is] = [&]<size_t... Js>(std::index_sequence<Js...>) {
|
|
|
|
|
|
return f(inputs[Js]...).deriv;
|
|
|
|
|
|
}(std::make_index_sequence<N>{});
|
|
|
|
|
|
}(), ...);
|
|
|
|
|
|
}(std::make_index_sequence<N>{});
|
|
|
|
|
|
|
2026-04-07 17:16:29 +08:00
|
|
|
|
// 散度
|
|
|
|
|
|
res.divergence = 0.0;
|
|
|
|
|
|
[&]<size_t... Is>(std::index_sequence<Is...>) {
|
|
|
|
|
|
([&]() {
|
|
|
|
|
|
res.divergence += [&]<size_t... Js>(std::index_sequence<Js...>) {
|
|
|
|
|
|
return res.gradient[Is]; // 这里直接用梯度值,因为散度是梯度的和
|
|
|
|
|
|
}(std::make_index_sequence<N>{});
|
|
|
|
|
|
}(), ...);
|
|
|
|
|
|
}(std::make_index_sequence<N>{});
|
|
|
|
|
|
|
|
|
|
|
|
// 旋度
|
|
|
|
|
|
// 标量场的旋度始终为0
|
|
|
|
|
|
res.curl.resize(N);
|
|
|
|
|
|
for (size_t i = 0; i < N; ++i) {
|
|
|
|
|
|
res.curl[i] = 0.0; // 标量场的旋度为0
|
|
|
|
|
|
}
|
2026-04-01 02:39:40 +08:00
|
|
|
|
return res;
|
|
|
|
|
|
}
|
|
|
|
|
|
} // namespace forwardad
|