#pragma once #include #include "types/node.hpp" #include "node_ops.hpp" namespace backwardad { inline void topo_sort(const Node& n, std::vector& order, std::vector& vis) { auto same = [&](const Node& x) { return x.ptr.get() == n.ptr.get(); }; if (std::find_if(vis.begin(), vis.end(), same) != vis.end()) return; vis.push_back(n); for (const auto& in : n.inputs()) topo_sort(in, order, vis); order.push_back(n); } inline void backward(const Node& output) { // 第一步:拓扑排序(保证从后往前算) std::vector order; std::vector visited; topo_sort(output, order, visited); // 第二步:初始化输出梯度 = 1 output.set_gradient(1.0); std::reverse(order.begin(), order.end()); // 第三步:反向遍历,逐个调用节点的局部导数 for (const auto& v : order) { if (v.backward()) v.backward()(); } } template ADResult diff(const Func f, Args... args) { ADResult res; // 1. 创建输入节点 std::vector inputs; inputs.reserve(sizeof...(Args)); (inputs.emplace_back((double)args), ...); // 2. 前向计算 Node output = [&](std::index_sequence) { return f(inputs[Is]...); }(std::make_index_sequence{}); // 3. 反向传播 backward(output); // 4. 提取结果 res.value = output.value(); for (auto& in : inputs) res.gradient.push_back(in.gradient()); // 散度 res.divergence = 0.0; for (const auto& g : res.gradient) res.divergence += g; // 旋度 // 标量场的旋度始终为0 res.curl.resize(inputs.size()); for (size_t i = 0; i < inputs.size(); ++i) { res.curl[i] = 0.0; // 标量场的旋度为0 } return res; } } // namespace backwardad