36 lines
1.2 KiB
C++
36 lines
1.2 KiB
C++
/*前向自动微分的声明部分*/
|
||
#pragma once
|
||
#include <vector>
|
||
#include "dual_ops.hpp"
|
||
#include "types/common.hpp"
|
||
|
||
namespace forwardad{
|
||
|
||
template <typename Func, typename... Args>
|
||
ADResult diff(const Func& f, Args... args) {
|
||
ADResult res;
|
||
constexpr size_t N = sizeof...(Args);
|
||
res.gradient.resize(N);
|
||
|
||
// 1. 计算函数值(所有导数为 0)
|
||
Dual value_res = f(Dual((double)args, 0.0)...);
|
||
res.value = value_res.value;
|
||
|
||
// 2. 对每个输入变量求偏导(seed 依次设为 1)
|
||
double args_arr[] = { (double)args... };
|
||
[&]<size_t... Is>(std::index_sequence<Is...>) {
|
||
([&]() {
|
||
// 创建 Dual 输入,第 Is 个变量导数=1,其余=0
|
||
Dual inputs[N];
|
||
for (size_t j = 0; j < N; ++j)
|
||
inputs[j] = Dual(args_arr[j], j == Is ? 1.0 : 0.0);
|
||
// 用内层 index_sequence 展开所有 N 个参数传给 f
|
||
res.gradient[Is] = [&]<size_t... Js>(std::index_sequence<Js...>) {
|
||
return f(inputs[Js]...).deriv;
|
||
}(std::make_index_sequence<N>{});
|
||
}(), ...);
|
||
}(std::make_index_sequence<N>{});
|
||
|
||
return res;
|
||
}
|
||
} // namespace forwardad
|