Files
AutomaticDifferentiation/include/forwardad.hpp
2026-04-07 13:55:32 +08:00

36 lines
1.2 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
/*前向自动微分的声明部分*/
#pragma once
#include <cstddef>
#include <utility>
#include <vector>
#include "dual_ops.hpp"
#include "types/common.hpp"
namespace forwardad{
/**
 * @brief Evaluate `f` at a point and compute its gradient by forward-mode
 *        automatic differentiation.
 *
 * `f` is invoked N + 1 times, where N = sizeof...(Args):
 *   - once with every derivative seed set to 0, to obtain the function value;
 *   - once per input, in argument order, with that input's seed set to 1 and
 *     all other seeds set to 0, to read off one partial derivative.
 *
 * @tparam Func  Callable taking N `Dual` arguments; its result must be
 *               convertible to `Dual` and expose a `.deriv` member.
 * @tparam Args  Argument types convertible to `double` via `static_cast`.
 * @param f      Function to differentiate.
 * @param args   Coordinates of the evaluation point, one per input of `f`.
 * @return       `ADResult` whose `value` is the `.value` of the seed-free
 *               evaluation and whose `gradient[i]` is the `.deriv` of the
 *               evaluation seeded on the i-th argument. With no arguments the
 *               gradient is resized to 0 and only the value pass runs.
 */
template <typename Func, typename... Args>
ADResult diff(const Func& f, Args... args) {
    constexpr size_t N = sizeof...(Args);

    ADResult res;
    res.gradient.resize(N);

    // Pass 1: plain function value (every derivative seed is 0).
    const Dual value_res = f(Dual(static_cast<double>(args), 0.0)...);
    res.value = value_res.value;

    // Passes 2..N+1: one seeded evaluation per input. Guarded so that an empty
    // argument pack never declares a zero-length array.
    if constexpr (N > 0) {
        const double point[N] = { static_cast<double>(args)... };

        [&]<size_t... Is>(std::index_sequence<Is...>) {
            // Evaluate f with the derivative seed set to 1 on input `seeded`
            // and 0 on every other input, returning the resulting derivative.
            // The Dual arguments are built directly in the call, so no
            // default-constructed temporary Dual array is needed.
            const auto partial = [&](size_t seeded) {
                return f(Dual(point[Is], Is == seeded ? 1.0 : 0.0)...).deriv;
            };

            // The comma fold runs the passes left to right: input 0, 1, ..., N-1.
            ((res.gradient[Is] = partial(Is)), ...);
        }(std::make_index_sequence<N>{});
    }

    return res;
}
} // namespace forwardad