feat: 后向自动微分

This commit is contained in:
mayge
2026-04-07 13:00:35 +08:00
parent 58aedec43c
commit e8cbf452e8
6 changed files with 273 additions and 19 deletions

56
include/backwardad.hpp Normal file
View File

@@ -0,0 +1,56 @@
#pragma once
#include <vector>
#include <cmath>
#include "dual.hpp"
namespace backwardad {
inline void topo_sort(const Node& n, std::vector<Node>& order, std::vector<Node>& vis) {
auto same = [&](const Node& x) { return x.ptr.get() == n.ptr.get(); };
if (std::find_if(vis.begin(), vis.end(), same) != vis.end()) return;
vis.push_back(n);
for (const auto& in : n.inputs()) topo_sort(in, order, vis);
order.push_back(n);
}
inline void backward(const Node& output) {
// 第一步:拓扑排序(保证从后往前算)
std::vector<Node> order;
std::vector<Node> visited;
topo_sort(output, order, visited);
// 第二步:初始化输出梯度 = 1
output.set_gradient(1.0);
std::reverse(order.begin(), order.end());
// 第三步:反向遍历,逐个调用节点的局部导数
for (const auto& v : order) {
if (v.backward()) v.backward()();
}
}
template <typename Func, typename... Args>
ADResult diff(const Func f, Args... args) {
ADResult res;
// 1. 创建输入节点
std::vector<Node> inputs;
inputs.reserve(sizeof...(Args));
(inputs.emplace_back((double)args), ...);
// 2. 前向计算
Node output = [&]<size_t... Is>(std::index_sequence<Is...>) {
return f(inputs[Is]...);
}(std::make_index_sequence<sizeof...(Args)>{});
// 3. 反向传播
backward(output);
// 4. 提取结果
res.value = output.value();
for (auto& in : inputs)
res.gradient.push_back(in.gradient());
return res;
}
} // namespace backwardad