fix: 分离文件
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
#pragma once
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
#include "dual.hpp"
|
||||
#include "types/gradient.hpp"
|
||||
#include "gradient.hpp"
|
||||
|
||||
namespace backwardad {
|
||||
|
||||
|
||||
124
include/dual.hpp
124
include/dual.hpp
@@ -1,9 +1,6 @@
|
||||
/*对偶数的声明部分*/
|
||||
#pragma once
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
#include "types/common.hpp"
|
||||
#include "types/dual.hpp"
|
||||
|
||||
// 前向自动微分的运算部分,两个deriv相乘为0
|
||||
@@ -63,123 +60,4 @@ namespace forwardad{
|
||||
inline Dual atan(const Dual& x) {
|
||||
return Dual(std::atan(x.value), x.deriv / (1 + x.value * x.value));
|
||||
}
|
||||
}// namespace forwardad
|
||||
|
||||
namespace backwardad {
|
||||
inline Node operator+(const Node& a, const Node& b) {
|
||||
Node out(a.value() + b.value());
|
||||
out.inputs() = {a, b};
|
||||
out.backward() = [out, a, b]() {
|
||||
a.add_gradient(out.gradient()); // da = dout * 1
|
||||
b.add_gradient(out.gradient()); // db = dout * 1
|
||||
};
|
||||
return out;
|
||||
}
|
||||
|
||||
inline Node operator-(const Node& a, const Node& b) {
|
||||
Node out(a.value() - b.value());
|
||||
out.inputs() = {a, b};
|
||||
out.backward() = [out, a, b]() {
|
||||
a.add_gradient(out.gradient()); // da = dout * 1
|
||||
b.add_gradient(-out.gradient()); // db = dout * (-1)
|
||||
};
|
||||
return out;
|
||||
}
|
||||
|
||||
inline Node operator*(const Node& a, const Node& b) {
|
||||
Node out(a.value() * b.value());
|
||||
out.inputs() = {a, b};
|
||||
out.backward() = [out, a, b]() {
|
||||
a.add_gradient(out.gradient() * b.value()); // da = dout * b
|
||||
b.add_gradient(out.gradient() * a.value()); // db = dout * a
|
||||
};
|
||||
return out;
|
||||
}
|
||||
|
||||
inline Node operator/(const Node& a, const Node& b) {
|
||||
Node out(a.value() / b.value());
|
||||
out.inputs() = {a, b};
|
||||
out.backward() = [out, a, b]() {
|
||||
a.add_gradient(out.gradient() / b.value()); // da = dout * (1/b)
|
||||
b.add_gradient(-out.gradient() * a.value() / (b.value() * b.value())); // db = dout * (-a/b^2)
|
||||
};
|
||||
return out;
|
||||
}
|
||||
|
||||
// 下面的函数需要用到链式法则(使用泰勒展开后保留一次项)
|
||||
// 三角函数
|
||||
inline Node sin(const Node& x) {
|
||||
Node out(std::sin(x.value()));
|
||||
out.inputs() = {x};
|
||||
out.backward() = [out, x]() {
|
||||
x.add_gradient(out.gradient() * std::cos(x.value())); // dx = dout * cos(x)
|
||||
};
|
||||
return out;
|
||||
}
|
||||
|
||||
inline Node cos(const Node& x) {
|
||||
Node out(std::cos(x.value()));
|
||||
out.inputs() = {x};
|
||||
out.backward() = [out, x]() {
|
||||
x.add_gradient(-out.gradient() * std::sin(x.value())); // dx = dout * (-sin(x))
|
||||
};
|
||||
return out;
|
||||
}
|
||||
|
||||
// 指数和对数
|
||||
inline Node exp(const Node& x) {
|
||||
Node out(std::exp(x.value()));
|
||||
out.inputs() = {x};
|
||||
out.backward() = [out, x]() {
|
||||
x.add_gradient(out.gradient() * out.value()); // dx = dout * exp(x) = dout * out.value
|
||||
};
|
||||
return out;
|
||||
}
|
||||
|
||||
inline Node log(const Node& x) {
|
||||
Node out(std::log(x.value()));
|
||||
out.inputs() = {x};
|
||||
out.backward() = [out, x]() {
|
||||
x.add_gradient(out.gradient() / x.value()); // dx = dout * (1/x)
|
||||
};
|
||||
return out;
|
||||
}
|
||||
|
||||
inline Node pow(const Node& x, double n) {
|
||||
Node out(std::pow(x.value(), n));
|
||||
out.inputs() = {x};
|
||||
out.backward() = [out, x, n]() {
|
||||
x.add_gradient(out.gradient() * n * std::pow(x.value(), n - 1)); // dx = dout * (n * x^(n-1))
|
||||
};
|
||||
return out;
|
||||
}
|
||||
|
||||
// 反三角函数
|
||||
inline Node asin(const Node& x) {
|
||||
Node out(std::asin(x.value()));
|
||||
out.inputs() = {x};
|
||||
out.backward() = [out, x]() {
|
||||
x.add_gradient(out.gradient() / std::sqrt(1 - x.value() * x.value())); // dx = dout * (1/sqrt(1-x^2))
|
||||
};
|
||||
return out;
|
||||
}
|
||||
|
||||
inline Node acos(const Node& x) {
|
||||
Node out(std::acos(x.value()));
|
||||
out.inputs() = {x};
|
||||
out.backward() = [out, x]() {
|
||||
x.add_gradient(-out.gradient() / std::sqrt(1 - x.value() * x.value())); // dx = dout * (-1/sqrt(1-x^2))
|
||||
};
|
||||
return out;
|
||||
}
|
||||
|
||||
inline Node atan(const Node& x) {
|
||||
Node out(std::atan(x.value()));
|
||||
out.inputs() = {x};
|
||||
out.backward() = [out, x]() {
|
||||
x.add_gradient(out.gradient() / (1 + x.value() * x.value())); // dx = dout * (1/(1+x^2))
|
||||
};
|
||||
return out;
|
||||
}
|
||||
|
||||
}
|
||||
}// namespace forwardad
|
||||
@@ -1,8 +1,6 @@
|
||||
/*前向自动微分的声明部分*/
|
||||
#pragma once
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include <functional>
|
||||
#include "dual.hpp"
|
||||
#include "types/common.hpp"
|
||||
|
||||
|
||||
121
include/gradient.hpp
Normal file
121
include/gradient.hpp
Normal file
@@ -0,0 +1,121 @@
|
||||
#pragma once
|
||||
#include "types/gradient.hpp"
|
||||
|
||||
namespace backwardad {
|
||||
inline Node operator+(const Node& a, const Node& b) {
|
||||
Node out(a.value() + b.value());
|
||||
out.inputs() = {a, b};
|
||||
out.backward() = [out, a, b]() {
|
||||
a.add_gradient(out.gradient()); // da = dout * 1
|
||||
b.add_gradient(out.gradient()); // db = dout * 1
|
||||
};
|
||||
return out;
|
||||
}
|
||||
|
||||
inline Node operator-(const Node& a, const Node& b) {
|
||||
Node out(a.value() - b.value());
|
||||
out.inputs() = {a, b};
|
||||
out.backward() = [out, a, b]() {
|
||||
a.add_gradient(out.gradient()); // da = dout * 1
|
||||
b.add_gradient(-out.gradient()); // db = dout * (-1)
|
||||
};
|
||||
return out;
|
||||
}
|
||||
|
||||
inline Node operator*(const Node& a, const Node& b) {
|
||||
Node out(a.value() * b.value());
|
||||
out.inputs() = {a, b};
|
||||
out.backward() = [out, a, b]() {
|
||||
a.add_gradient(out.gradient() * b.value()); // da = dout * b
|
||||
b.add_gradient(out.gradient() * a.value()); // db = dout * a
|
||||
};
|
||||
return out;
|
||||
}
|
||||
|
||||
inline Node operator/(const Node& a, const Node& b) {
|
||||
Node out(a.value() / b.value());
|
||||
out.inputs() = {a, b};
|
||||
out.backward() = [out, a, b]() {
|
||||
a.add_gradient(out.gradient() / b.value()); // da = dout * (1/b)
|
||||
b.add_gradient(-out.gradient() * a.value() / (b.value() * b.value())); // db = dout * (-a/b^2)
|
||||
};
|
||||
return out;
|
||||
}
|
||||
|
||||
// 下面的函数需要用到链式法则(使用泰勒展开后保留一次项)
|
||||
// 三角函数
|
||||
inline Node sin(const Node& x) {
|
||||
Node out(std::sin(x.value()));
|
||||
out.inputs() = {x};
|
||||
out.backward() = [out, x]() {
|
||||
x.add_gradient(out.gradient() * std::cos(x.value())); // dx = dout * cos(x)
|
||||
};
|
||||
return out;
|
||||
}
|
||||
|
||||
inline Node cos(const Node& x) {
|
||||
Node out(std::cos(x.value()));
|
||||
out.inputs() = {x};
|
||||
out.backward() = [out, x]() {
|
||||
x.add_gradient(-out.gradient() * std::sin(x.value())); // dx = dout * (-sin(x))
|
||||
};
|
||||
return out;
|
||||
}
|
||||
|
||||
// 指数和对数
|
||||
inline Node exp(const Node& x) {
|
||||
Node out(std::exp(x.value()));
|
||||
out.inputs() = {x};
|
||||
out.backward() = [out, x]() {
|
||||
x.add_gradient(out.gradient() * out.value()); // dx = dout * exp(x) = dout * out.value
|
||||
};
|
||||
return out;
|
||||
}
|
||||
|
||||
inline Node log(const Node& x) {
|
||||
Node out(std::log(x.value()));
|
||||
out.inputs() = {x};
|
||||
out.backward() = [out, x]() {
|
||||
x.add_gradient(out.gradient() / x.value()); // dx = dout * (1/x)
|
||||
};
|
||||
return out;
|
||||
}
|
||||
|
||||
inline Node pow(const Node& x, double n) {
|
||||
Node out(std::pow(x.value(), n));
|
||||
out.inputs() = {x};
|
||||
out.backward() = [out, x, n]() {
|
||||
x.add_gradient(out.gradient() * n * std::pow(x.value(), n - 1)); // dx = dout * (n * x^(n-1))
|
||||
};
|
||||
return out;
|
||||
}
|
||||
|
||||
// 反三角函数
|
||||
inline Node asin(const Node& x) {
|
||||
Node out(std::asin(x.value()));
|
||||
out.inputs() = {x};
|
||||
out.backward() = [out, x]() {
|
||||
x.add_gradient(out.gradient() / std::sqrt(1 - x.value() * x.value())); // dx = dout * (1/sqrt(1-x^2))
|
||||
};
|
||||
return out;
|
||||
}
|
||||
|
||||
inline Node acos(const Node& x) {
|
||||
Node out(std::acos(x.value()));
|
||||
out.inputs() = {x};
|
||||
out.backward() = [out, x]() {
|
||||
x.add_gradient(-out.gradient() / std::sqrt(1 - x.value() * x.value())); // dx = dout * (-1/sqrt(1-x^2))
|
||||
};
|
||||
return out;
|
||||
}
|
||||
|
||||
inline Node atan(const Node& x) {
|
||||
Node out(std::atan(x.value()));
|
||||
out.inputs() = {x};
|
||||
out.backward() = [out, x]() {
|
||||
x.add_gradient(out.gradient() / (1 + x.value() * x.value())); // dx = dout * (1/(1+x^2))
|
||||
};
|
||||
return out;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
/*返回数据结构体*/
|
||||
#pragma once
|
||||
#include <functional>
|
||||
#include "functional"
|
||||
#include <vector>
|
||||
|
||||
struct ADResult{
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
/*对偶数的数据类型定义部分*/
|
||||
#pragma once
|
||||
#include <iostream>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "types/common.hpp"
|
||||
|
||||
namespace forwardad {
|
||||
struct Dual {
|
||||
@@ -13,43 +10,4 @@ struct Dual {
|
||||
Dual(double v = 0.0, double d = 0.0) : value(v), deriv(d) {}
|
||||
|
||||
};
|
||||
} // namespace forwardad
|
||||
|
||||
namespace backwardad{
|
||||
struct NodeData;
|
||||
|
||||
struct Node{
|
||||
std::shared_ptr<NodeData> ptr;
|
||||
|
||||
Node() = default;
|
||||
explicit Node(double v);
|
||||
|
||||
double value() const;
|
||||
double gradient() const;
|
||||
void add_gradient(double g) const;
|
||||
void set_gradient(double g) const;
|
||||
const std::vector<Node>& inputs() const;
|
||||
std::vector<Node>& inputs();
|
||||
std::function<void()>& backward();
|
||||
const std::function<void()>& backward() const;
|
||||
};
|
||||
|
||||
struct NodeData {
|
||||
double value;
|
||||
double gradient;
|
||||
std::vector<Node> inputs;
|
||||
std::function<void()> backward;
|
||||
|
||||
explicit NodeData(double v) : value(v), gradient(0.0) {}
|
||||
};
|
||||
|
||||
inline Node::Node(double v) : ptr(std::make_shared<NodeData>(v)) {}
|
||||
inline double Node::value() const { return ptr->value; }
|
||||
inline double Node::gradient() const { return ptr->gradient; }
|
||||
inline void Node::add_gradient(double g) const { ptr->gradient += g; }
|
||||
inline void Node::set_gradient(double g) const { ptr->gradient = g; }
|
||||
inline const std::vector<Node>& Node::inputs() const { return ptr->inputs; }
|
||||
inline std::vector<Node>& Node::inputs() { return ptr->inputs; }
|
||||
inline std::function<void()>& Node::backward() { return ptr->backward; }
|
||||
inline const std::function<void()>& Node::backward() const { return ptr->backward; }
|
||||
} // namespace backwardad
|
||||
} // namespace forwardad
|
||||
40
include/types/gradient.hpp
Normal file
40
include/types/gradient.hpp
Normal file
@@ -0,0 +1,40 @@
|
||||
#pragma once
|
||||
#include "types/common.hpp"
|
||||
namespace backwardad{
|
||||
struct NodeData;
|
||||
|
||||
struct Node{
|
||||
std::shared_ptr<NodeData> ptr;
|
||||
|
||||
Node() = default;
|
||||
explicit Node(double v);
|
||||
|
||||
double value() const;
|
||||
double gradient() const;
|
||||
void add_gradient(double g) const;
|
||||
void set_gradient(double g) const;
|
||||
const std::vector<Node>& inputs() const;
|
||||
std::vector<Node>& inputs();
|
||||
std::function<void()>& backward();
|
||||
const std::function<void()>& backward() const;
|
||||
};
|
||||
|
||||
struct NodeData {
|
||||
double value;
|
||||
double gradient;
|
||||
std::vector<Node> inputs;
|
||||
std::function<void()> backward;
|
||||
|
||||
explicit NodeData(double v) : value(v), gradient(0.0) {}
|
||||
};
|
||||
|
||||
inline Node::Node(double v) : ptr(std::make_shared<NodeData>(v)) {}
|
||||
inline double Node::value() const { return ptr->value; }
|
||||
inline double Node::gradient() const { return ptr->gradient; }
|
||||
inline void Node::add_gradient(double g) const { ptr->gradient += g; }
|
||||
inline void Node::set_gradient(double g) const { ptr->gradient = g; }
|
||||
inline const std::vector<Node>& Node::inputs() const { return ptr->inputs; }
|
||||
inline std::vector<Node>& Node::inputs() { return ptr->inputs; }
|
||||
inline std::function<void()>& Node::backward() { return ptr->backward; }
|
||||
inline const std::function<void()>& Node::backward() const { return ptr->backward; }
|
||||
} // namespace backwardad
|
||||
Reference in New Issue
Block a user