diff --git a/include/backwardad.hpp b/include/backwardad.hpp
new file mode 100644
index 0000000..edc85e4
--- /dev/null
+++ b/include/backwardad.hpp
@@ -0,0 +1,65 @@
+#pragma once
+#include <vector>
+#include <algorithm>
+#include <utility>
+#include "dual.hpp"
+
+namespace backwardad {
+
+// Post-order DFS over the computation graph; nodes are identified by the
+// address of their shared NodeData, so shared subexpressions are visited
+// once. The linear visited scan is O(V^2) overall — fine for small graphs.
+inline void topo_sort(const Node& n, std::vector<Node>& order, std::vector<Node>& vis) {
+    auto same = [&](const Node& x) { return x.ptr.get() == n.ptr.get(); };
+    if (std::find_if(vis.begin(), vis.end(), same) != vis.end()) return;
+    vis.push_back(n);
+    for (const auto& in : n.inputs()) topo_sort(in, order, vis);
+    order.push_back(n);
+}
+
+// Reverse-mode sweep: seed d(output)/d(output) = 1, then propagate
+// gradients from the output back toward the leaves.
+inline void backward(const Node& output) {
+    // Step 1: topological sort, so each node is processed only after every
+    // node that consumes it (guarantees back-to-front accumulation).
+    std::vector<Node> order;
+    std::vector<Node> visited;
+    topo_sort(output, order, visited);
+
+    // Step 2: initialize the output gradient to 1.
+    output.set_gradient(1.0);
+    std::reverse(order.begin(), order.end());
+
+    // Step 3: walk backwards, invoking each node's local-derivative closure.
+    for (const auto& v : order) {
+        if (v.backward()) v.backward()();
+    }
+}
+
+// Evaluate f at the given arguments and return its value plus the gradient
+// with respect to every argument, computed by reverse-mode AD.
+template <typename Func, typename... Args>
+ADResult diff(const Func f, Args... args) {
+    ADResult res;
+
+    // 1. Create one leaf node per input.
+    std::vector<Node> inputs;
+    inputs.reserve(sizeof...(Args));
+    (inputs.emplace_back(static_cast<double>(args)), ...);
+
+    // 2. Forward pass (C++20 templated lambda unpacks the argument pack).
+    Node output = [&]<std::size_t... Is>(std::index_sequence<Is...>) {
+        return f(inputs[Is]...);
+    }(std::make_index_sequence<sizeof...(Args)>{});
+
+    // 3. Backward pass.
+    backward(output);
+
+    // 4. Collect results.
+    res.value = output.value();
+    for (auto& in : inputs)
+        res.gradient.push_back(in.gradient());
+
+    return res;
+}
+} // namespace backwardad
\ No newline at end of file
diff --git a/include/dual.hpp b/include/dual.hpp
index eb2ea6e..b112a8d 100644
--- a/include/dual.hpp
+++ b/include/dual.hpp
@@ -3,6 +3,7 @@
 #include <cmath>
 #include <iostream>
 #include <vector>
+#include "types/common.hpp"
 #include "types/dual.hpp"
 
 // 前向自动微分的运算部分,两个deriv相乘为0
@@ -62,4 +63,123 @@
 inline Dual atan(const Dual& x) {
     return Dual(std::atan(x.value), x.deriv / (1 + x.value * x.value));
 }
-}// namespace forwardad
\ No newline at end of file
+}// namespace forwardad
+
+namespace backwardad {
+    // NOTE: every backward closure captures out's NodeData by raw pointer
+    // (p). Capturing the Node itself would store a strong self-reference
+    // inside its own std::function and leak the whole graph (ref cycle).
+    inline Node operator+(const Node& a, const Node& b) {
+        Node out(a.value() + b.value());
+        out.inputs() = {a, b};
+        out.backward() = [p = out.ptr.get(), a, b]() {
+            a.add_gradient(p->gradient); // da = dout * 1
+            b.add_gradient(p->gradient); // db = dout * 1
+        };
+        return out;
+    }
+
+    inline Node operator-(const Node& a, const Node& b) {
+        Node out(a.value() - b.value());
+        out.inputs() = {a, b};
+        out.backward() = [p = out.ptr.get(), a, b]() {
+            a.add_gradient(p->gradient);  // da = dout * 1
+            b.add_gradient(-p->gradient); // db = dout * (-1)
+        };
+        return out;
+    }
+
+    inline Node operator*(const Node& a, const Node& b) {
+        Node out(a.value() * b.value());
+        out.inputs() = {a, b};
+        out.backward() = [p = out.ptr.get(), a, b]() {
+            a.add_gradient(p->gradient * b.value()); // da = dout * b
+            b.add_gradient(p->gradient * a.value()); // db = dout * a
+        };
+        return out;
+    }
+
+    inline Node operator/(const Node& a, const Node& b) {
+        Node out(a.value() / b.value());
+        out.inputs() = {a, b};
+        out.backward() = [p = out.ptr.get(), a, b]() {
+            a.add_gradient(p->gradient / b.value()); // da = dout * (1/b)
+            b.add_gradient(-p->gradient * a.value() / (b.value() * b.value())); // db = dout * (-a/b^2)
+        };
+        return out;
+    }
+
+    // The functions below apply the chain rule (first-order Taylor term).
+    // Trigonometric functions
+    inline Node sin(const Node& x) {
+        Node out(std::sin(x.value()));
+        out.inputs() = {x};
+        out.backward() = [p = out.ptr.get(), x]() {
+            x.add_gradient(p->gradient * std::cos(x.value())); // dx = dout * cos(x)
+        };
+        return out;
+    }
+
+    inline Node cos(const Node& x) {
+        Node out(std::cos(x.value()));
+        out.inputs() = {x};
+        out.backward() = [p = out.ptr.get(), x]() {
+            x.add_gradient(-p->gradient * std::sin(x.value())); // dx = dout * (-sin(x))
+        };
+        return out;
+    }
+
+    // Exponential and logarithm
+    inline Node exp(const Node& x) {
+        Node out(std::exp(x.value()));
+        out.inputs() = {x};
+        out.backward() = [p = out.ptr.get(), x]() {
+            x.add_gradient(p->gradient * p->value); // dx = dout * exp(x) = dout * out.value
+        };
+        return out;
+    }
+
+    inline Node log(const Node& x) {
+        Node out(std::log(x.value()));
+        out.inputs() = {x};
+        out.backward() = [p = out.ptr.get(), x]() {
+            x.add_gradient(p->gradient / x.value()); // dx = dout * (1/x)
+        };
+        return out;
+    }
+
+    inline Node pow(const Node& x, double n) {
+        Node out(std::pow(x.value(), n));
+        out.inputs() = {x};
+        out.backward() = [p = out.ptr.get(), x, n]() {
+            x.add_gradient(p->gradient * n * std::pow(x.value(), n - 1)); // dx = dout * (n * x^(n-1))
+        };
+        return out;
+    }
+
+    // Inverse trigonometric functions
+    inline Node asin(const Node& x) {
+        Node out(std::asin(x.value()));
+        out.inputs() = {x};
+        out.backward() = [p = out.ptr.get(), x]() {
+            x.add_gradient(p->gradient / std::sqrt(1 - x.value() * x.value())); // dx = dout * (1/sqrt(1-x^2))
+        };
+        return out;
+    }
+
+    inline Node acos(const Node& x) {
+        Node out(std::acos(x.value()));
+        out.inputs() = {x};
+        out.backward() = [p = out.ptr.get(), x]() {
+            x.add_gradient(-p->gradient / std::sqrt(1 - x.value() * x.value())); // dx = dout * (-1/sqrt(1-x^2))
+        };
+        return out;
+    }
+
+    inline Node atan(const Node& x) {
+        Node out(std::atan(x.value()));
+        out.inputs() = {x};
+        out.backward() = [p = out.ptr.get(), x]() {
+            x.add_gradient(p->gradient / (1 + x.value() * x.value())); // dx = dout * (1/(1+x^2))
+        };
+        return out;
+    }
+
+}
\ No newline at end of file
diff --git a/include/forwardad.hpp b/include/forwardad.hpp
index a1385de..c1b125a 100644
--- a/include/forwardad.hpp
+++ b/include/forwardad.hpp
@@ -9,8 +9,8 @@
 namespace forwardad{
 
 template <typename Func, typename... Args>
-Result diff(const Func& f, Args... args) {
-    Result res;
+ADResult diff(const Func& f, Args... args) {
+    ADResult res;
 
     constexpr size_t N = sizeof...(Args);
     res.gradient.resize(N);
diff --git a/include/types/common.hpp b/include/types/common.hpp
index 41ff214..a084212 100644
--- a/include/types/common.hpp
+++ b/include/types/common.hpp
@@ -1,8 +1,8 @@
 /*返回数据结构体*/
+#pragma once
+#include <vector>
 #include <cstddef>
-namespace forwardad{
-    struct Result{
-        double value; // 函数值
-        std::vector<double> gradient; // 梯度
-    };
-}
\ No newline at end of file
+struct ADResult{
+    double value;                  // function value
+    std::vector<double> gradient;  // gradient, one entry per input
+};
diff --git a/include/types/dual.hpp b/include/types/dual.hpp
index 630588c..f8294cf 100644
--- a/include/types/dual.hpp
+++ b/include/types/dual.hpp
@@ -1,5 +1,9 @@
 /*对偶数的数据类型定义部分*/
 #pragma once
+#include <cmath>
+#include <functional>
+#include <memory>
+#include <vector>
 
 namespace forwardad {
 struct Dual {
@@ -9,4 +13,43 @@
 
     Dual(double v = 0.0, double d = 0.0) : value(v), deriv(d) {}
 };
-} // namespace forwardad
\ No newline at end of file
+} // namespace forwardad
+
+namespace backwardad{
+    struct NodeData;   // forward decl: Node is a shared handle to it
+
+    struct Node{
+        std::shared_ptr<NodeData> ptr;
+
+        Node() = default;
+        explicit Node(double v);
+
+        double value() const;
+        double gradient() const;
+        void add_gradient(double g) const;
+        void set_gradient(double g) const;
+        const std::vector<Node>& inputs() const;
+        std::vector<Node>& inputs();
+        std::function<void()>& backward();
+        const std::function<void()>& backward() const;
+    };
+
+    struct NodeData {
+        double value;
+        double gradient;
+        std::vector<Node> inputs;
+        std::function<void()> backward;
+
+        explicit NodeData(double v) : value(v), gradient(0.0) {}
+    };
+
+    inline Node::Node(double v) : ptr(std::make_shared<NodeData>(v)) {}
+    inline double Node::value() const { return ptr->value; }
+    inline double Node::gradient() const { return ptr->gradient; }
+    inline void Node::add_gradient(double g) const { ptr->gradient += g; }
+    inline void Node::set_gradient(double g) const { ptr->gradient = g; }
+    inline const std::vector<Node>& Node::inputs() const { return ptr->inputs; }
+    inline std::vector<Node>& Node::inputs() { return ptr->inputs; }
+    inline std::function<void()>& Node::backward() { return ptr->backward; }
+    inline const std::function<void()>& Node::backward() const { return ptr->backward; }
+} // namespace backwardad
\ No newline at end of file
diff --git a/tests/basic.cpp b/tests/basic.cpp
index c734a15..f669492 100644
--- a/tests/basic.cpp
+++ b/tests/basic.cpp
@@ -1,39 +1,75 @@
 #include <iostream>
 #include "forwardad.hpp"
+#include "backwardad.hpp"
 #include "types/dual.hpp"
 
-using namespace forwardad;
-
 // 一阶测试函数
-Dual f(Dual x) {
+forwardad::Dual f1(forwardad::Dual x) {
     return x + x*x + pow(x,3);
 }
 
 // 二阶测试函数
-Dual g(Dual x, Dual y) {
+forwardad::Dual g1(forwardad::Dual x, forwardad::Dual y) {
     return exp(x) * log(y);
 }
 
 // 高阶测试函数
-Dual h(Dual x, Dual y, Dual z) {
+forwardad::Dual h1(forwardad::Dual x, forwardad::Dual y, forwardad::Dual z) {
     return pow(x, 3) + pow(y, 2) + z + cos(x * y * z);
 }
 
-int main() {
+backwardad::Node f2(backwardad::Node x) {
+    return x + x*x + pow(x,3);
+}
+
+backwardad::Node g2(backwardad::Node x, backwardad::Node y) {
+    return exp(x) * log(y);
+}
+
+backwardad::Node h2(backwardad::Node x, backwardad::Node y, backwardad::Node z) {
+    return pow(x, 3) + pow(y, 2) + z + cos(x * y * z);
+}
+
+int test_forward_ad() {
     std::cout << "Testing f(x) = x^2 + sin(x) at x=2.0\n";
-    auto r = diff(f, 2.0);
+    auto r = forwardad::diff(f1, 2.0);
     std::cout << "value = " << r.value << "\n";
     std::cout << "grad = " << r.gradient[0] << "\n";
 
     std::cout << "\nTesting g(x,y) = exp(x)*log(y) at (x,y)=(1.0, 2.0)\n";
-    auto r2 = diff(g, 1.0, 2.0);
+    auto r2 = forwardad::diff(g1, 1.0, 2.0);
     std::cout << "value = " << r2.value << "\n";
     std::cout << "grad = (" << r2.gradient[0] << ", " << r2.gradient[1] << ")\n";
 
     std::cout << "\nTesting h(x,y,z) = x^3 + y^2 + z + cos(x*y*z) at (x,y,z)=(1.0, 2.0, 3.0)\n";
-    auto r3 = diff(h, 1.0, 2.0, 3.0);
+    auto r3 = forwardad::diff(h1, 1.0, 2.0, 3.0);
     std::cout << "value = " << r3.value << "\n";
     std::cout << "grad = (" << r3.gradient[0] << ", " << r3.gradient[1] << ", " << r3.gradient[2] << ")\n";
+    return 0;
+}
 
+int test_backward_ad() {
+    std::cout << "Testing f(x) = x + x^2 + x^3 at x=2.0\n";
+    auto r = backwardad::diff(f2, 2.0);
+    std::cout << "value = " << r.value << "\n";
+    std::cout << "grad = " << r.gradient[0] << "\n";
+
+    std::cout << "\nTesting g(x,y) = exp(x)*log(y) at (x,y)=(1.0, 2.0)\n";
+    auto r2 = backwardad::diff(g2, 1.0, 2.0);
+    std::cout << "value = " << r2.value << "\n";
+    std::cout << "grad = (" << r2.gradient[0] << ", " << r2.gradient[1] << ")\n";
+
+    std::cout << "\nTesting h(x,y,z) = x^3 + y^2 + z + cos(x*y*z) at (x,y,z)=(1.0, 2.0, 3.0)\n";
+    auto r3 = backwardad::diff(h2, 1.0, 2.0, 3.0);
+    std::cout << "value = " << r3.value << "\n";
+    std::cout << "grad = (" << r3.gradient[0] << ", " << r3.gradient[1] << ", " << r3.gradient[2] << ")\n";
+    return 0;
+}
+
+int main() {
+    std::cout << "=== Forward AD Tests ===\n";
+    test_forward_ad();
+    std::cout << "\n=== Backward AD Tests ===\n";
+    test_backward_ad();
     return 0;
 }
\ No newline at end of file