diff --git a/include/backwardad.hpp b/include/backwardad.hpp index edc85e4..190527e 100644 --- a/include/backwardad.hpp +++ b/include/backwardad.hpp @@ -1,7 +1,7 @@ #pragma once #include -#include -#include "dual.hpp" +#include "types/gradient.hpp" +#include "gradient.hpp" namespace backwardad { diff --git a/include/dual.hpp b/include/dual.hpp index b112a8d..eacc8c4 100644 --- a/include/dual.hpp +++ b/include/dual.hpp @@ -1,9 +1,6 @@ /*对偶数的声明部分*/ #pragma once -#include -#include #include -#include "types/common.hpp" #include "types/dual.hpp" // 前向自动微分的运算部分,两个deriv相乘为0 @@ -63,123 +60,4 @@ namespace forwardad{ inline Dual atan(const Dual& x) { return Dual(std::atan(x.value), x.deriv / (1 + x.value * x.value)); } -}// namespace forwardad - -namespace backwardad { - inline Node operator+(const Node& a, const Node& b) { - Node out(a.value() + b.value()); - out.inputs() = {a, b}; - out.backward() = [out, a, b]() { - a.add_gradient(out.gradient()); // da = dout * 1 - b.add_gradient(out.gradient()); // db = dout * 1 - }; - return out; - } - - inline Node operator-(const Node& a, const Node& b) { - Node out(a.value() - b.value()); - out.inputs() = {a, b}; - out.backward() = [out, a, b]() { - a.add_gradient(out.gradient()); // da = dout * 1 - b.add_gradient(-out.gradient()); // db = dout * (-1) - }; - return out; - } - - inline Node operator*(const Node& a, const Node& b) { - Node out(a.value() * b.value()); - out.inputs() = {a, b}; - out.backward() = [out, a, b]() { - a.add_gradient(out.gradient() * b.value()); // da = dout * b - b.add_gradient(out.gradient() * a.value()); // db = dout * a - }; - return out; - } - - inline Node operator/(const Node& a, const Node& b) { - Node out(a.value() / b.value()); - out.inputs() = {a, b}; - out.backward() = [out, a, b]() { - a.add_gradient(out.gradient() / b.value()); // da = dout * (1/b) - b.add_gradient(-out.gradient() * a.value() / (b.value() * b.value())); // db = dout * (-a/b^2) - }; - return out; - } - - // 下面的函数需要用到链式法则(使用泰勒展开后保留一次项) - // 三角函数 - inline Node sin(const Node& x) { - Node out(std::sin(x.value())); - out.inputs() = {x}; - out.backward() = [out, x]() { - x.add_gradient(out.gradient() * std::cos(x.value())); // dx = dout * cos(x) - }; - return out; - } - - inline Node cos(const Node& x) { - Node out(std::cos(x.value())); - out.inputs() = {x}; - out.backward() = [out, x]() { - x.add_gradient(-out.gradient() * std::sin(x.value())); // dx = dout * (-sin(x)) - }; - return out; - } - - // 指数和对数 - inline Node exp(const Node& x) { - Node out(std::exp(x.value())); - out.inputs() = {x}; - out.backward() = [out, x]() { - x.add_gradient(out.gradient() * out.value()); // dx = dout * exp(x) = dout * out.value - }; - return out; - } - - inline Node log(const Node& x) { - Node out(std::log(x.value())); - out.inputs() = {x}; - out.backward() = [out, x]() { - x.add_gradient(out.gradient() / x.value()); // dx = dout * (1/x) - }; - return out; - } - - inline Node pow(const Node& x, double n) { - Node out(std::pow(x.value(), n)); - out.inputs() = {x}; - out.backward() = [out, x, n]() { - x.add_gradient(out.gradient() * n * std::pow(x.value(), n - 1)); // dx = dout * (n * x^(n-1)) - }; - return out; - } - - // 反三角函数 - inline Node asin(const Node& x) { - Node out(std::asin(x.value())); - out.inputs() = {x}; - out.backward() = [out, x]() { - x.add_gradient(out.gradient() / std::sqrt(1 - x.value() * x.value())); // dx = dout * (1/sqrt(1-x^2)) - }; - return out; - } - - inline Node acos(const Node& x) { - Node out(std::acos(x.value())); - out.inputs() = {x}; - out.backward() = [out, x]() { - x.add_gradient(-out.gradient() / std::sqrt(1 - x.value() * x.value())); // dx = dout * (-1/sqrt(1-x^2)) - }; - return out; - } - - inline Node atan(const Node& x) { - Node out(std::atan(x.value())); - out.inputs() = {x}; - out.backward() = [out, x]() { - x.add_gradient(out.gradient() / (1 + x.value() * x.value())); // dx = dout * (1/(1+x^2)) - }; - return out; - } - -} \ No newline at end of file +}// namespace forwardad \ No newline at end of file diff --git a/include/forwardad.hpp b/include/forwardad.hpp index c1b125a..d084b57 100644 --- a/include/forwardad.hpp +++ b/include/forwardad.hpp @@ -1,8 +1,6 @@ /*前向自动微分的声明部分*/ #pragma once -#include #include -#include #include "dual.hpp" #include "types/common.hpp" diff --git a/include/gradient.hpp b/include/gradient.hpp new file mode 100644 index 0000000..b588eeb --- /dev/null +++ b/include/gradient.hpp @@ -0,0 +1,121 @@ +#pragma once +#include "types/gradient.hpp" + +namespace backwardad { + inline Node operator+(const Node& a, const Node& b) { + Node out(a.value() + b.value()); + out.inputs() = {a, b}; + out.backward() = [out, a, b]() { + a.add_gradient(out.gradient()); // da = dout * 1 + b.add_gradient(out.gradient()); // db = dout * 1 + }; + return out; + } + + inline Node operator-(const Node& a, const Node& b) { + Node out(a.value() - b.value()); + out.inputs() = {a, b}; + out.backward() = [out, a, b]() { + a.add_gradient(out.gradient()); // da = dout * 1 + b.add_gradient(-out.gradient()); // db = dout * (-1) + }; + return out; + } + + inline Node operator*(const Node& a, const Node& b) { + Node out(a.value() * b.value()); + out.inputs() = {a, b}; + out.backward() = [out, a, b]() { + a.add_gradient(out.gradient() * b.value()); // da = dout * b + b.add_gradient(out.gradient() * a.value()); // db = dout * a + }; + return out; + } + + inline Node operator/(const Node& a, const Node& b) { + Node out(a.value() / b.value()); + out.inputs() = {a, b}; + out.backward() = [out, a, b]() { + a.add_gradient(out.gradient() / b.value()); // da = dout * (1/b) + b.add_gradient(-out.gradient() * a.value() / (b.value() * b.value())); // db = dout * (-a/b^2) + }; + return out; + } + + // 下面的函数需要用到链式法则(使用泰勒展开后保留一次项) + // 三角函数 + inline Node sin(const Node& x) { + Node out(std::sin(x.value())); + out.inputs() = {x}; + out.backward() = [out, x]() { + x.add_gradient(out.gradient() * std::cos(x.value())); // dx = dout * cos(x) + }; + return out; + } + + inline Node cos(const Node& x) { + Node out(std::cos(x.value())); + out.inputs() = {x}; + out.backward() = [out, x]() { + x.add_gradient(-out.gradient() * std::sin(x.value())); // dx = dout * (-sin(x)) + }; + return out; + } + + // 指数和对数 + inline Node exp(const Node& x) { + Node out(std::exp(x.value())); + out.inputs() = {x}; + out.backward() = [out, x]() { + x.add_gradient(out.gradient() * out.value()); // dx = dout * exp(x) = dout * out.value + }; + return out; + } + + inline Node log(const Node& x) { + Node out(std::log(x.value())); + out.inputs() = {x}; + out.backward() = [out, x]() { + x.add_gradient(out.gradient() / x.value()); // dx = dout * (1/x) + }; + return out; + } + + inline Node pow(const Node& x, double n) { + Node out(std::pow(x.value(), n)); + out.inputs() = {x}; + out.backward() = [out, x, n]() { + x.add_gradient(out.gradient() * n * std::pow(x.value(), n - 1)); // dx = dout * (n * x^(n-1)) + }; + return out; + } + + // 反三角函数 + inline Node asin(const Node& x) { + Node out(std::asin(x.value())); + out.inputs() = {x}; + out.backward() = [out, x]() { + x.add_gradient(out.gradient() / std::sqrt(1 - x.value() * x.value())); // dx = dout * (1/sqrt(1-x^2)) + }; + return out; + } + + inline Node acos(const Node& x) { + Node out(std::acos(x.value())); + out.inputs() = {x}; + out.backward() = [out, x]() { + x.add_gradient(-out.gradient() / std::sqrt(1 - x.value() * x.value())); // dx = dout * (-1/sqrt(1-x^2)) + }; + return out; + } + + inline Node atan(const Node& x) { + Node out(std::atan(x.value())); + out.inputs() = {x}; + out.backward() = [out, x]() { + x.add_gradient(out.gradient() / (1 + x.value() * x.value())); // dx = dout * (1/(1+x^2)) + }; + return out; + } + +} \ No newline at end of file diff --git a/include/types/common.hpp b/include/types/common.hpp index a084212..5efbe43 100644 --- a/include/types/common.hpp +++ b/include/types/common.hpp @@ -1,6 +1,6 @@ /*返回数据结构体*/ #pragma once -#include +#include "functional" #include struct ADResult{ diff --git a/include/types/dual.hpp b/include/types/dual.hpp index f8294cf..d99a964 100644 --- a/include/types/dual.hpp +++ b/include/types/dual.hpp @@ -1,9 +1,6 @@ /*对偶数的数据类型定义部分*/ #pragma once -#include -#include -#include -#include +#include "types/common.hpp" namespace forwardad { struct Dual { @@ -13,43 +10,4 @@ struct Dual { Dual(double v = 0.0, double d = 0.0) : value(v), deriv(d) {} }; -} // namespace forwardad - -namespace backwardad{ - struct NodeData; - - struct Node{ - std::shared_ptr ptr; - - Node() = default; - explicit Node(double v); - - double value() const; - double gradient() const; - void add_gradient(double g) const; - void set_gradient(double g) const; - const std::vector& inputs() const; - std::vector& inputs(); - std::function& backward(); - const std::function& backward() const; - }; - - struct NodeData { - double value; - double gradient; - std::vector inputs; - std::function backward; - - explicit NodeData(double v) : value(v), gradient(0.0) {} - }; - - inline Node::Node(double v) : ptr(std::make_shared(v)) {} - inline double Node::value() const { return ptr->value; } - inline double Node::gradient() const { return ptr->gradient; } - inline void Node::add_gradient(double g) const { ptr->gradient += g; } - inline void Node::set_gradient(double g) const { ptr->gradient = g; } - inline const std::vector& Node::inputs() const { return ptr->inputs; } - inline std::vector& Node::inputs() { return ptr->inputs; } - inline std::function& Node::backward() { return ptr->backward; } - inline const std::function& Node::backward() const { return ptr->backward; } -} // namespace backwardad \ No newline at end of file +} // namespace forwardad \ No newline at end of file diff --git a/include/types/gradient.hpp b/include/types/gradient.hpp new file mode 100644 index 0000000..bc801d7 --- /dev/null +++ b/include/types/gradient.hpp @@ -0,0 +1,40 @@ +#pragma once +#include "types/common.hpp" +namespace backwardad{ + struct NodeData; + + struct Node{ + std::shared_ptr ptr; + + Node() = default; + explicit Node(double v); + + double value() const; + double gradient() const; + void add_gradient(double g) const; + void set_gradient(double g) const; + const std::vector& inputs() const; + std::vector& inputs(); + std::function& backward(); + const std::function& backward() const; + }; + + struct NodeData { + double value; + double gradient; + std::vector inputs; + std::function backward; + + explicit NodeData(double v) : value(v), gradient(0.0) {} + }; + + inline Node::Node(double v) : ptr(std::make_shared(v)) {} + inline double Node::value() const { return ptr->value; } + inline double Node::gradient() const { return ptr->gradient; } + inline void Node::add_gradient(double g) const { ptr->gradient += g; } + inline void Node::set_gradient(double g) const { ptr->gradient = g; } + inline const std::vector& Node::inputs() const { return ptr->inputs; } + inline std::vector& Node::inputs() { return ptr->inputs; } + inline std::function& Node::backward() { return ptr->backward; } + inline const std::function& Node::backward() const { return ptr->backward; } +} // namespace backwardad \ No newline at end of file