Files
2026-04-07 17:16:29 +08:00

54 lines
1.7 KiB
C++
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
/* Forward-mode automatic differentiation (declarations). */
#pragma once
#include <array>
#include <cstddef>
#include <utility>
#include <vector>

#include "dual_ops.hpp"
#include "types/common.hpp"
namespace forwardad {
namespace detail {

/// @brief Calls `f` on N dual numbers built from `point`, seeding one input.
///
/// Input j is passed as `Dual(point[j], j == seed ? 1.0 : 0.0)`, so only the
/// input at index `seed` carries a unit derivative. Passing `seed >= N`
/// seeds nothing, i.e. every derivative component is 0.
///
/// @param f     Callable invoked with N separate `Dual` arguments.
/// @param point Values of the N input variables.
/// @param seed  Index of the input whose derivative seed is 1.
/// @return Whatever `f` returns for these dual inputs.
template <typename Func, std::size_t N, std::size_t... Js>
auto eval_with_seed(const Func& f,
                    [[maybe_unused]] const std::array<double, N>& point,
                    [[maybe_unused]] std::size_t seed,
                    std::index_sequence<Js...>) {
    // Building each Dual in place avoids requiring Dual to be
    // default-constructible or assignable.
    return f(Dual(point[Js], Js == seed ? 1.0 : 0.0)...);
}

}  // namespace detail

/// @brief Forward-mode differentiation of `f` at the point `args...`.
///
/// `f` is called once with all derivative seeds set to 0 to get the value.
/// It is then called once per input variable with only that variable's
/// seed set to 1, yielding one partial derivative per call. That makes
/// N + 1 calls in total, in this order: value, then partials 0..N-1.
///
/// @param f    Callable taking N `Dual` arguments. Its result must be
///             convertible to `Dual` (for `.value`) and must expose `.deriv`.
/// @param args Evaluation point; each argument is converted to `double`.
///             Zero arguments are allowed.
/// @return ADResult with
///   - `value`: the `.value` of f at the point;
///   - `gradient[i]`: the `.deriv` of f when input i is seeded, for i in 0..N-1;
///   - `divergence`: the sum of the gradient components, accumulated in
///     index order starting from 0.0;
///   - `curl`: N zeros.
template <typename Func, typename... Args>
ADResult diff(const Func& f, Args... args) {
    constexpr std::size_t N = sizeof...(Args);
    using Indices = std::make_index_sequence<N>;

    // std::array (unlike a C array) is well-formed for N == 0.
    const std::array<double, N> point{static_cast<double>(args)...};

    ADResult res{};
    res.gradient.resize(N);
    res.curl.resize(N);

    // Function value: seed index N matches no input, so all derivatives are 0.
    const Dual value_res = detail::eval_with_seed(f, point, N, Indices{});
    res.value = value_res.value;

    // NOTE(review): "divergence" is strictly defined for vector fields. For a
    // scalar f this is sum_i df/dx_i, i.e. div(f * (1, ..., 1)). It is kept
    // because it is the quantity this function has always reported.
    // The gradient of a scalar function is curl-free, hence the zero "curl".
    // NOTE(review): a curl normally only has 3 components in 3-D; N zeros are
    // kept for compatibility with existing callers.
    res.divergence = 0.0;
    for (std::size_t i = 0; i < N; ++i) {
        res.gradient[i] = detail::eval_with_seed(f, point, i, Indices{}).deriv;
        res.divergence += res.gradient[i];
        res.curl[i] = 0.0;
    }
    return res;
}

}  // namespace forwardad