Files
AutomaticDifferentiation/include/backwardad.hpp
2026-04-07 17:16:29 +08:00

67 lines
1.8 KiB
C++

#pragma once
#include <vector>
#include "types/node.hpp"
#include "node_ops.hpp"
namespace backwardad {
inline void topo_sort(const Node& n, std::vector<Node>& order, std::vector<Node>& vis) {
auto same = [&](const Node& x) { return x.ptr.get() == n.ptr.get(); };
if (std::find_if(vis.begin(), vis.end(), same) != vis.end()) return;
vis.push_back(n);
for (const auto& in : n.inputs()) topo_sort(in, order, vis);
order.push_back(n);
}
inline void backward(const Node& output) {
// 第一步:拓扑排序(保证从后往前算)
std::vector<Node> order;
std::vector<Node> visited;
topo_sort(output, order, visited);
// 第二步:初始化输出梯度 = 1
output.set_gradient(1.0);
std::reverse(order.begin(), order.end());
// 第三步:反向遍历,逐个调用节点的局部导数
for (const auto& v : order) {
if (v.backward()) v.backward()();
}
}
template <typename Func, typename... Args>
ADResult diff(const Func f, Args... args) {
ADResult res;
// 1. 创建输入节点
std::vector<Node> inputs;
inputs.reserve(sizeof...(Args));
(inputs.emplace_back((double)args), ...);
// 2. 前向计算
Node output = [&]<size_t... Is>(std::index_sequence<Is...>) {
return f(inputs[Is]...);
}(std::make_index_sequence<sizeof...(Args)>{});
// 3. 反向传播
backward(output);
// 4. 提取结果
res.value = output.value();
for (auto& in : inputs)
res.gradient.push_back(in.gradient());
// 散度
res.divergence = 0.0;
for (const auto& g : res.gradient) res.divergence += g;
// 旋度
// 标量场的旋度始终为0
res.curl.resize(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
res.curl[i] = 0.0; // 标量场的旋度为0
}
return res;
}
} // namespace backwardad