fix: 分离文件

This commit is contained in:
mayge
2026-04-07 13:38:19 +08:00
parent 2d52d70ae0
commit 8795c5bb37
7 changed files with 167 additions and 172 deletions

View File

@@ -1,6 +1,6 @@
/*返回数据结构体*/
#pragma once
#include <functional>
#include "functional"
#include <vector>
struct ADResult{

View File

@@ -1,9 +1,6 @@
/*对偶数的数据类型定义部分*/
#pragma once
#include <iostream>
#include <functional>
#include <memory>
#include <vector>
#include "types/common.hpp"
namespace forwardad {
struct Dual {
@@ -13,43 +10,4 @@ struct Dual {
Dual(double v = 0.0, double d = 0.0) : value(v), deriv(d) {}
};
} // namespace forwardad
namespace backwardad{
struct NodeData;
struct Node{
std::shared_ptr<NodeData> ptr;
Node() = default;
explicit Node(double v);
double value() const;
double gradient() const;
void add_gradient(double g) const;
void set_gradient(double g) const;
const std::vector<Node>& inputs() const;
std::vector<Node>& inputs();
std::function<void()>& backward();
const std::function<void()>& backward() const;
};
struct NodeData {
double value;
double gradient;
std::vector<Node> inputs;
std::function<void()> backward;
explicit NodeData(double v) : value(v), gradient(0.0) {}
};
inline Node::Node(double v) : ptr(std::make_shared<NodeData>(v)) {}
inline double Node::value() const { return ptr->value; }
inline double Node::gradient() const { return ptr->gradient; }
inline void Node::add_gradient(double g) const { ptr->gradient += g; }
inline void Node::set_gradient(double g) const { ptr->gradient = g; }
inline const std::vector<Node>& Node::inputs() const { return ptr->inputs; }
inline std::vector<Node>& Node::inputs() { return ptr->inputs; }
inline std::function<void()>& Node::backward() { return ptr->backward; }
inline const std::function<void()>& Node::backward() const { return ptr->backward; }
} // namespace backwardad
} // namespace forwardad

View File

@@ -0,0 +1,40 @@
#pragma once
#include "types/common.hpp"
namespace backwardad{
struct NodeData;
struct Node{
std::shared_ptr<NodeData> ptr;
Node() = default;
explicit Node(double v);
double value() const;
double gradient() const;
void add_gradient(double g) const;
void set_gradient(double g) const;
const std::vector<Node>& inputs() const;
std::vector<Node>& inputs();
std::function<void()>& backward();
const std::function<void()>& backward() const;
};
struct NodeData {
double value;
double gradient;
std::vector<Node> inputs;
std::function<void()> backward;
explicit NodeData(double v) : value(v), gradient(0.0) {}
};
inline Node::Node(double v) : ptr(std::make_shared<NodeData>(v)) {}
inline double Node::value() const { return ptr->value; }
inline double Node::gradient() const { return ptr->gradient; }
inline void Node::add_gradient(double g) const { ptr->gradient += g; }
inline void Node::set_gradient(double g) const { ptr->gradient = g; }
inline const std::vector<Node>& Node::inputs() const { return ptr->inputs; }
inline std::vector<Node>& Node::inputs() { return ptr->inputs; }
inline std::function<void()>& Node::backward() { return ptr->backward; }
inline const std::function<void()>& Node::backward() const { return ptr->backward; }
} // namespace backwardad