67 lines
1.8 KiB
C++
67 lines
1.8 KiB
C++
#pragma once
|
|
#include <vector>
|
|
#include "types/node.hpp"
|
|
#include "node_ops.hpp"
|
|
|
|
namespace backwardad {
|
|
|
|
inline void topo_sort(const Node& n, std::vector<Node>& order, std::vector<Node>& vis) {
|
|
auto same = [&](const Node& x) { return x.ptr.get() == n.ptr.get(); };
|
|
if (std::find_if(vis.begin(), vis.end(), same) != vis.end()) return;
|
|
vis.push_back(n);
|
|
for (const auto& in : n.inputs()) topo_sort(in, order, vis);
|
|
order.push_back(n);
|
|
}
|
|
|
|
inline void backward(const Node& output) {
|
|
// 第一步:拓扑排序(保证从后往前算)
|
|
std::vector<Node> order;
|
|
std::vector<Node> visited;
|
|
topo_sort(output, order, visited);
|
|
|
|
// 第二步:初始化输出梯度 = 1
|
|
output.set_gradient(1.0);
|
|
std::reverse(order.begin(), order.end());
|
|
|
|
// 第三步:反向遍历,逐个调用节点的局部导数
|
|
for (const auto& v : order) {
|
|
if (v.backward()) v.backward()();
|
|
}
|
|
}
|
|
|
|
template <typename Func, typename... Args>
|
|
ADResult diff(const Func f, Args... args) {
|
|
ADResult res;
|
|
|
|
// 1. 创建输入节点
|
|
std::vector<Node> inputs;
|
|
inputs.reserve(sizeof...(Args));
|
|
(inputs.emplace_back((double)args), ...);
|
|
|
|
// 2. 前向计算
|
|
Node output = [&]<size_t... Is>(std::index_sequence<Is...>) {
|
|
return f(inputs[Is]...);
|
|
}(std::make_index_sequence<sizeof...(Args)>{});
|
|
|
|
// 3. 反向传播
|
|
backward(output);
|
|
|
|
// 4. 提取结果
|
|
res.value = output.value();
|
|
for (auto& in : inputs)
|
|
res.gradient.push_back(in.gradient());
|
|
|
|
// 散度
|
|
res.divergence = 0.0;
|
|
for (const auto& g : res.gradient) res.divergence += g;
|
|
|
|
// 旋度
|
|
// 标量场的旋度始终为0
|
|
res.curl.resize(inputs.size());
|
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
|
res.curl[i] = 0.0; // 标量场的旋度为0
|
|
}
|
|
|
|
return res;
|
|
}
|
|
} // namespace backwardad
|