feat: 后向自动微分
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
/*返回数据结构体*/
|
||||
#pragma once
|
||||
#include <functional>
|
||||
#include <vector>
|
||||
|
||||
namespace forwardad{
|
||||
struct Result{
|
||||
double value; // 函数值
|
||||
std::vector<double> gradient; // 梯度
|
||||
};
|
||||
}
|
||||
struct ADResult{
|
||||
double value; // 函数值
|
||||
std::vector<double> gradient; // 梯度
|
||||
};
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
/*对偶数的数据类型定义部分*/
|
||||
#pragma once
|
||||
#include <iostream>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
namespace forwardad {
|
||||
struct Dual {
|
||||
@@ -9,4 +13,43 @@ struct Dual {
|
||||
Dual(double v = 0.0, double d = 0.0) : value(v), deriv(d) {}
|
||||
|
||||
};
|
||||
} // namespace forwardad
|
||||
} // namespace forwardad
|
||||
|
||||
namespace backwardad{
|
||||
struct NodeData;
|
||||
|
||||
struct Node{
|
||||
std::shared_ptr<NodeData> ptr;
|
||||
|
||||
Node() = default;
|
||||
explicit Node(double v);
|
||||
|
||||
double value() const;
|
||||
double gradient() const;
|
||||
void add_gradient(double g) const;
|
||||
void set_gradient(double g) const;
|
||||
const std::vector<Node>& inputs() const;
|
||||
std::vector<Node>& inputs();
|
||||
std::function<void()>& backward();
|
||||
const std::function<void()>& backward() const;
|
||||
};
|
||||
|
||||
struct NodeData {
|
||||
double value;
|
||||
double gradient;
|
||||
std::vector<Node> inputs;
|
||||
std::function<void()> backward;
|
||||
|
||||
explicit NodeData(double v) : value(v), gradient(0.0) {}
|
||||
};
|
||||
|
||||
inline Node::Node(double v) : ptr(std::make_shared<NodeData>(v)) {}
|
||||
inline double Node::value() const { return ptr->value; }
|
||||
inline double Node::gradient() const { return ptr->gradient; }
|
||||
inline void Node::add_gradient(double g) const { ptr->gradient += g; }
|
||||
inline void Node::set_gradient(double g) const { ptr->gradient = g; }
|
||||
inline const std::vector<Node>& Node::inputs() const { return ptr->inputs; }
|
||||
inline std::vector<Node>& Node::inputs() { return ptr->inputs; }
|
||||
inline std::function<void()>& Node::backward() { return ptr->backward; }
|
||||
inline const std::function<void()>& Node::backward() const { return ptr->backward; }
|
||||
} // namespace backwardad
|
||||
Reference in New Issue
Block a user