feat: 后向自动微分
This commit is contained in:
122
include/dual.hpp
122
include/dual.hpp
@@ -3,6 +3,7 @@
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
#include "types/common.hpp"
|
||||
#include "types/dual.hpp"
|
||||
|
||||
// 前向自动微分的运算部分,两个deriv相乘为0
|
||||
@@ -62,4 +63,123 @@ namespace forwardad{
|
||||
inline Dual atan(const Dual& x) {
|
||||
return Dual(std::atan(x.value), x.deriv / (1 + x.value * x.value));
|
||||
}
|
||||
}// namespace forwardad
|
||||
}// namespace forwardad
|
||||
|
||||
namespace backwardad {
|
||||
inline Node operator+(const Node& a, const Node& b) {
|
||||
Node out(a.value() + b.value());
|
||||
out.inputs() = {a, b};
|
||||
out.backward() = [out, a, b]() {
|
||||
a.add_gradient(out.gradient()); // da = dout * 1
|
||||
b.add_gradient(out.gradient()); // db = dout * 1
|
||||
};
|
||||
return out;
|
||||
}
|
||||
|
||||
inline Node operator-(const Node& a, const Node& b) {
|
||||
Node out(a.value() - b.value());
|
||||
out.inputs() = {a, b};
|
||||
out.backward() = [out, a, b]() {
|
||||
a.add_gradient(out.gradient()); // da = dout * 1
|
||||
b.add_gradient(-out.gradient()); // db = dout * (-1)
|
||||
};
|
||||
return out;
|
||||
}
|
||||
|
||||
inline Node operator*(const Node& a, const Node& b) {
|
||||
Node out(a.value() * b.value());
|
||||
out.inputs() = {a, b};
|
||||
out.backward() = [out, a, b]() {
|
||||
a.add_gradient(out.gradient() * b.value()); // da = dout * b
|
||||
b.add_gradient(out.gradient() * a.value()); // db = dout * a
|
||||
};
|
||||
return out;
|
||||
}
|
||||
|
||||
inline Node operator/(const Node& a, const Node& b) {
|
||||
Node out(a.value() / b.value());
|
||||
out.inputs() = {a, b};
|
||||
out.backward() = [out, a, b]() {
|
||||
a.add_gradient(out.gradient() / b.value()); // da = dout * (1/b)
|
||||
b.add_gradient(-out.gradient() * a.value() / (b.value() * b.value())); // db = dout * (-a/b^2)
|
||||
};
|
||||
return out;
|
||||
}
|
||||
|
||||
// 下面的函数需要用到链式法则(使用泰勒展开后保留一次项)
|
||||
// 三角函数
|
||||
inline Node sin(const Node& x) {
|
||||
Node out(std::sin(x.value()));
|
||||
out.inputs() = {x};
|
||||
out.backward() = [out, x]() {
|
||||
x.add_gradient(out.gradient() * std::cos(x.value())); // dx = dout * cos(x)
|
||||
};
|
||||
return out;
|
||||
}
|
||||
|
||||
inline Node cos(const Node& x) {
|
||||
Node out(std::cos(x.value()));
|
||||
out.inputs() = {x};
|
||||
out.backward() = [out, x]() {
|
||||
x.add_gradient(-out.gradient() * std::sin(x.value())); // dx = dout * (-sin(x))
|
||||
};
|
||||
return out;
|
||||
}
|
||||
|
||||
// 指数和对数
|
||||
inline Node exp(const Node& x) {
|
||||
Node out(std::exp(x.value()));
|
||||
out.inputs() = {x};
|
||||
out.backward() = [out, x]() {
|
||||
x.add_gradient(out.gradient() * out.value()); // dx = dout * exp(x) = dout * out.value
|
||||
};
|
||||
return out;
|
||||
}
|
||||
|
||||
inline Node log(const Node& x) {
|
||||
Node out(std::log(x.value()));
|
||||
out.inputs() = {x};
|
||||
out.backward() = [out, x]() {
|
||||
x.add_gradient(out.gradient() / x.value()); // dx = dout * (1/x)
|
||||
};
|
||||
return out;
|
||||
}
|
||||
|
||||
inline Node pow(const Node& x, double n) {
|
||||
Node out(std::pow(x.value(), n));
|
||||
out.inputs() = {x};
|
||||
out.backward() = [out, x, n]() {
|
||||
x.add_gradient(out.gradient() * n * std::pow(x.value(), n - 1)); // dx = dout * (n * x^(n-1))
|
||||
};
|
||||
return out;
|
||||
}
|
||||
|
||||
// 反三角函数
|
||||
inline Node asin(const Node& x) {
|
||||
Node out(std::asin(x.value()));
|
||||
out.inputs() = {x};
|
||||
out.backward() = [out, x]() {
|
||||
x.add_gradient(out.gradient() / std::sqrt(1 - x.value() * x.value())); // dx = dout * (1/sqrt(1-x^2))
|
||||
};
|
||||
return out;
|
||||
}
|
||||
|
||||
inline Node acos(const Node& x) {
|
||||
Node out(std::acos(x.value()));
|
||||
out.inputs() = {x};
|
||||
out.backward() = [out, x]() {
|
||||
x.add_gradient(-out.gradient() / std::sqrt(1 - x.value() * x.value())); // dx = dout * (-1/sqrt(1-x^2))
|
||||
};
|
||||
return out;
|
||||
}
|
||||
|
||||
inline Node atan(const Node& x) {
|
||||
Node out(std::atan(x.value()));
|
||||
out.inputs() = {x};
|
||||
out.backward() = [out, x]() {
|
||||
x.add_gradient(out.gradient() / (1 + x.value() * x.value())); // dx = dout * (1/(1+x^2))
|
||||
};
|
||||
return out;
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user