feat: 后向自动微分

This commit is contained in:
mayge
2026-04-07 13:00:35 +08:00
parent 58aedec43c
commit e8cbf452e8
6 changed files with 273 additions and 19 deletions

56
include/backwardad.hpp Normal file
View File

@@ -0,0 +1,56 @@
#pragma once
#include <vector>
#include <cmath>
#include "dual.hpp"
namespace backwardad {
inline void topo_sort(const Node& n, std::vector<Node>& order, std::vector<Node>& vis) {
auto same = [&](const Node& x) { return x.ptr.get() == n.ptr.get(); };
if (std::find_if(vis.begin(), vis.end(), same) != vis.end()) return;
vis.push_back(n);
for (const auto& in : n.inputs()) topo_sort(in, order, vis);
order.push_back(n);
}
inline void backward(const Node& output) {
// 第一步:拓扑排序(保证从后往前算)
std::vector<Node> order;
std::vector<Node> visited;
topo_sort(output, order, visited);
// 第二步:初始化输出梯度 = 1
output.set_gradient(1.0);
std::reverse(order.begin(), order.end());
// 第三步:反向遍历,逐个调用节点的局部导数
for (const auto& v : order) {
if (v.backward()) v.backward()();
}
}
template <typename Func, typename... Args>
ADResult diff(const Func f, Args... args) {
ADResult res;
// 1. 创建输入节点
std::vector<Node> inputs;
inputs.reserve(sizeof...(Args));
(inputs.emplace_back((double)args), ...);
// 2. 前向计算
Node output = [&]<size_t... Is>(std::index_sequence<Is...>) {
return f(inputs[Is]...);
}(std::make_index_sequence<sizeof...(Args)>{});
// 3. 反向传播
backward(output);
// 4. 提取结果
res.value = output.value();
for (auto& in : inputs)
res.gradient.push_back(in.gradient());
return res;
}
} // namespace backwardad

View File

@@ -3,6 +3,7 @@
#include <iostream>
#include <vector>
#include <cmath>
#include "types/common.hpp"
#include "types/dual.hpp"
// 前向自动微分的运算部分两个deriv相乘为0
@@ -62,4 +63,123 @@ namespace forwardad{
inline Dual atan(const Dual& x) {
return Dual(std::atan(x.value), x.deriv / (1 + x.value * x.value));
}
}// namespace forwardad
}// namespace forwardad
namespace backwardad {
inline Node operator+(const Node& a, const Node& b) {
Node out(a.value() + b.value());
out.inputs() = {a, b};
out.backward() = [out, a, b]() {
a.add_gradient(out.gradient()); // da = dout * 1
b.add_gradient(out.gradient()); // db = dout * 1
};
return out;
}
inline Node operator-(const Node& a, const Node& b) {
Node out(a.value() - b.value());
out.inputs() = {a, b};
out.backward() = [out, a, b]() {
a.add_gradient(out.gradient()); // da = dout * 1
b.add_gradient(-out.gradient()); // db = dout * (-1)
};
return out;
}
inline Node operator*(const Node& a, const Node& b) {
Node out(a.value() * b.value());
out.inputs() = {a, b};
out.backward() = [out, a, b]() {
a.add_gradient(out.gradient() * b.value()); // da = dout * b
b.add_gradient(out.gradient() * a.value()); // db = dout * a
};
return out;
}
inline Node operator/(const Node& a, const Node& b) {
Node out(a.value() / b.value());
out.inputs() = {a, b};
out.backward() = [out, a, b]() {
a.add_gradient(out.gradient() / b.value()); // da = dout * (1/b)
b.add_gradient(-out.gradient() * a.value() / (b.value() * b.value())); // db = dout * (-a/b^2)
};
return out;
}
// 下面的函数需要用到链式法则(使用泰勒展开后保留一次项)
// 三角函数
inline Node sin(const Node& x) {
Node out(std::sin(x.value()));
out.inputs() = {x};
out.backward() = [out, x]() {
x.add_gradient(out.gradient() * std::cos(x.value())); // dx = dout * cos(x)
};
return out;
}
inline Node cos(const Node& x) {
Node out(std::cos(x.value()));
out.inputs() = {x};
out.backward() = [out, x]() {
x.add_gradient(-out.gradient() * std::sin(x.value())); // dx = dout * (-sin(x))
};
return out;
}
// 指数和对数
inline Node exp(const Node& x) {
Node out(std::exp(x.value()));
out.inputs() = {x};
out.backward() = [out, x]() {
x.add_gradient(out.gradient() * out.value()); // dx = dout * exp(x) = dout * out.value
};
return out;
}
inline Node log(const Node& x) {
Node out(std::log(x.value()));
out.inputs() = {x};
out.backward() = [out, x]() {
x.add_gradient(out.gradient() / x.value()); // dx = dout * (1/x)
};
return out;
}
inline Node pow(const Node& x, double n) {
Node out(std::pow(x.value(), n));
out.inputs() = {x};
out.backward() = [out, x, n]() {
x.add_gradient(out.gradient() * n * std::pow(x.value(), n - 1)); // dx = dout * (n * x^(n-1))
};
return out;
}
// 反三角函数
inline Node asin(const Node& x) {
Node out(std::asin(x.value()));
out.inputs() = {x};
out.backward() = [out, x]() {
x.add_gradient(out.gradient() / std::sqrt(1 - x.value() * x.value())); // dx = dout * (1/sqrt(1-x^2))
};
return out;
}
inline Node acos(const Node& x) {
Node out(std::acos(x.value()));
out.inputs() = {x};
out.backward() = [out, x]() {
x.add_gradient(-out.gradient() / std::sqrt(1 - x.value() * x.value())); // dx = dout * (-1/sqrt(1-x^2))
};
return out;
}
inline Node atan(const Node& x) {
Node out(std::atan(x.value()));
out.inputs() = {x};
out.backward() = [out, x]() {
x.add_gradient(out.gradient() / (1 + x.value() * x.value())); // dx = dout * (1/(1+x^2))
};
return out;
}
}

View File

@@ -9,8 +9,8 @@
namespace forwardad{
template <typename Func, typename... Args>
Result diff(const Func& f, Args... args) {
Result res;
ADResult diff(const Func& f, Args... args) {
ADResult res;
constexpr size_t N = sizeof...(Args);
res.gradient.resize(N);

View File

@@ -1,9 +1,9 @@
/*返回数据结构体*/
#pragma once
#include <functional>
#include <vector>
namespace forwardad{
struct Result{
double value; // 函数值
std::vector<double> gradient; // 梯度
};
}
struct ADResult{
double value; // 函数值
std::vector<double> gradient; // 梯度
};

View File

@@ -1,5 +1,9 @@
/*对偶数的数据类型定义部分*/
#pragma once
#include <iostream>
#include <functional>
#include <memory>
#include <vector>
namespace forwardad {
struct Dual {
@@ -9,4 +13,43 @@ struct Dual {
Dual(double v = 0.0, double d = 0.0) : value(v), deriv(d) {}
};
} // namespace forwardad
} // namespace forwardad
namespace backwardad{
struct NodeData;
struct Node{
std::shared_ptr<NodeData> ptr;
Node() = default;
explicit Node(double v);
double value() const;
double gradient() const;
void add_gradient(double g) const;
void set_gradient(double g) const;
const std::vector<Node>& inputs() const;
std::vector<Node>& inputs();
std::function<void()>& backward();
const std::function<void()>& backward() const;
};
struct NodeData {
double value;
double gradient;
std::vector<Node> inputs;
std::function<void()> backward;
explicit NodeData(double v) : value(v), gradient(0.0) {}
};
inline Node::Node(double v) : ptr(std::make_shared<NodeData>(v)) {}
inline double Node::value() const { return ptr->value; }
inline double Node::gradient() const { return ptr->gradient; }
inline void Node::add_gradient(double g) const { ptr->gradient += g; }
inline void Node::set_gradient(double g) const { ptr->gradient = g; }
inline const std::vector<Node>& Node::inputs() const { return ptr->inputs; }
inline std::vector<Node>& Node::inputs() { return ptr->inputs; }
inline std::function<void()>& Node::backward() { return ptr->backward; }
inline const std::function<void()>& Node::backward() const { return ptr->backward; }
} // namespace backwardad

View File

@@ -1,40 +1,75 @@
#include <iostream>
#include "forwardad.hpp"
#include "backwardad.hpp"
#include "types/dual.hpp"
using namespace forwardad;
// 一阶测试函数
Dual f(Dual x) {
forwardad::Dual f1(forwardad::Dual x) {
return x + x*x + pow(x,3);
}
// 二阶测试函数
Dual g(Dual x, Dual y) {
forwardad::Dual g1(forwardad::Dual x, forwardad::Dual y) {
return exp(x) * log(y);
}
// 高阶测试函数
Dual h(Dual x, Dual y, Dual z) {
forwardad::Dual h1(forwardad::Dual x, forwardad::Dual y, forwardad::Dual z) {
return pow(x, 3) + pow(y, 2) + z + cos(x * y * z);
}
int main() {
backwardad::Node f2(backwardad::Node x) {
return x + x*x + pow(x,3);
}
backwardad::Node g2(backwardad::Node x, backwardad::Node y) {
return exp(x) * log(y);
}
backwardad::Node h2(backwardad::Node x, backwardad::Node y, backwardad::Node z) {
return pow(x, 3) + pow(y, 2) + z + cos(x * y * z);
}
int test_forward_ad() {
std::cout << "Testing f(x) = x^2 + sin(x) at x=2.0\n";
auto r = diff(f, 2.0);
auto r = forwardad::diff(f1, 2.0);
std::cout << "value = " << r.value << "\n";
std::cout << "grad = " << r.gradient[0] << "\n";
std::cout << "\nTesting g(x,y) = exp(x)*log(y) at (x,y)=(1.0, 2.0)\n";
auto r2 = diff(g, 1.0, 2.0);
auto r2 = forwardad::diff(g1, 1.0, 2.0);
std::cout << "value = " << r2.value << "\n";
std::cout << "grad = (" << r2.gradient[0] << ", " << r2.gradient[1] << ")\n";
std::cout << "\nTesting h(x,y,z) = x^3 + y^2 + z + cos(x*y*z) at (x,y,z)=(1.0, 2.0, 3.0)\n";
auto r3 = diff(h, 1.0, 2.0, 3.0);
auto r3 = forwardad::diff(h1, 1.0, 2.0, 3.0);
std::cout << "value = " << r3.value << "\n";
std::cout << "grad = (" << r3.gradient[0] << ", " << r3.gradient[1] << ", " << r3.gradient[2] << ")\n";
return 0;
}
int test_backward_ad() {
std::cout << "Testing f(x) = x^2 + sin(x) at x=2.0\n";
auto r = backwardad::diff(f2, 2.0);
std::cout << "value = " << r.value << "\n";
std::cout << "grad = " << r.gradient[0] << "\n";
std::cout << "\nTesting g(x,y) = exp(x)*log(y) at (x,y)=(1.0, 2.0)\n";
auto r2 = backwardad::diff(g2, 1.0, 2.0);
std::cout << "value = " << r2.value << "\n";
std::cout << "grad = (" << r2.gradient[0] << ", " << r2.gradient[1] << ")\n";
std::cout << "\nTesting h(x,y,z) = x^3 + y^2 + z + cos(x*y*z) at (x,y,z)=(1.0, 2.0, 3.0)\n";
auto r3 = backwardad::diff(h2, 1.0, 2.0, 3.0);
std::cout << "value = " << r3.value << "\n";
std::cout << "grad = (" << r3.gradient[0] << ", " << r3.gradient[1] << ", " << r3.gradient[2] << ")\n";
return 0;
}
int main() {
std::cout << "=== Forward AD Tests ===\n";
test_forward_ad();
std::cout << "\n=== Backward AD Tests ===\n";
test_backward_ad();
return 0;
}