feat: 后向自动微分
This commit is contained in:
56
include/backwardad.hpp
Normal file
56
include/backwardad.hpp
Normal file
@@ -0,0 +1,56 @@
|
||||
#pragma once
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
#include "dual.hpp"
|
||||
|
||||
namespace backwardad {
|
||||
|
||||
inline void topo_sort(const Node& n, std::vector<Node>& order, std::vector<Node>& vis) {
|
||||
auto same = [&](const Node& x) { return x.ptr.get() == n.ptr.get(); };
|
||||
if (std::find_if(vis.begin(), vis.end(), same) != vis.end()) return;
|
||||
vis.push_back(n);
|
||||
for (const auto& in : n.inputs()) topo_sort(in, order, vis);
|
||||
order.push_back(n);
|
||||
}
|
||||
|
||||
inline void backward(const Node& output) {
|
||||
// 第一步:拓扑排序(保证从后往前算)
|
||||
std::vector<Node> order;
|
||||
std::vector<Node> visited;
|
||||
topo_sort(output, order, visited);
|
||||
|
||||
// 第二步:初始化输出梯度 = 1
|
||||
output.set_gradient(1.0);
|
||||
std::reverse(order.begin(), order.end());
|
||||
|
||||
// 第三步:反向遍历,逐个调用节点的局部导数
|
||||
for (const auto& v : order) {
|
||||
if (v.backward()) v.backward()();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Func, typename... Args>
|
||||
ADResult diff(const Func f, Args... args) {
|
||||
ADResult res;
|
||||
|
||||
// 1. 创建输入节点
|
||||
std::vector<Node> inputs;
|
||||
inputs.reserve(sizeof...(Args));
|
||||
(inputs.emplace_back((double)args), ...);
|
||||
|
||||
// 2. 前向计算
|
||||
Node output = [&]<size_t... Is>(std::index_sequence<Is...>) {
|
||||
return f(inputs[Is]...);
|
||||
}(std::make_index_sequence<sizeof...(Args)>{});
|
||||
|
||||
// 3. 反向传播
|
||||
backward(output);
|
||||
|
||||
// 4. 提取结果
|
||||
res.value = output.value();
|
||||
for (auto& in : inputs)
|
||||
res.gradient.push_back(in.gradient());
|
||||
|
||||
return res;
|
||||
}
|
||||
} // namespace backwardad
|
||||
Reference in New Issue
Block a user