2026-04-07 13:38:19 +08:00
|
|
|
#pragma once
|
2026-04-07 13:55:32 +08:00
|
|
|
#include "types/node.hpp"
|
2026-04-07 13:38:19 +08:00
|
|
|
|
|
|
|
|
namespace backwardad {
|
|
|
|
|
// Addition of two graph nodes.
//
// Forward:  the result holds a.value() + b.value() and lists a and b as its
//           inputs.
// Backward: d(a+b)/da = d(a+b)/db = 1, so the result's gradient is handed to
//           both operands unchanged.
//
// NOTE(review): the backward closure stored on the result captures a copy of
// the result itself. If Node copies share state through reference counting,
// this would form an ownership cycle -- confirm against Node's implementation.
inline Node operator+(const Node& a, const Node& b) {
    Node sum(a.value() + b.value());
    sum.inputs() = {a, b};
    sum.backward() = [sum, a, b]() {
        const auto upstream = sum.gradient();
        a.add_gradient(upstream);  // contribution to a: dout * 1
        b.add_gradient(upstream);  // contribution to b: dout * 1
    };
    return sum;
}
|
|
|
|
|
|
|
|
|
|
// Subtraction of two graph nodes.
//
// Forward:  the result holds a.value() - b.value() and lists a and b as its
//           inputs.
// Backward: d(a-b)/da = 1 and d(a-b)/db = -1, so a receives the result's
//           gradient and b receives its negation.
inline Node operator-(const Node& a, const Node& b) {
    Node difference(a.value() - b.value());
    difference.inputs() = {a, b};
    difference.backward() = [difference, a, b]() {
        const auto upstream = difference.gradient();
        a.add_gradient(upstream);   // contribution to a: dout * 1
        b.add_gradient(-upstream);  // contribution to b: dout * (-1)
    };
    return difference;
}
|
|
|
|
|
|
|
|
|
|
// Multiplication of two graph nodes.
//
// Forward:  the result holds a.value() * b.value() and lists a and b as its
//           inputs.
// Backward: product rule -- each operand receives the result's gradient scaled
//           by the other operand's value.
inline Node operator*(const Node& a, const Node& b) {
    Node product(a.value() * b.value());
    product.inputs() = {a, b};
    product.backward() = [product, a, b]() {
        const auto upstream = product.gradient();
        a.add_gradient(upstream * b.value());  // contribution to a: dout * b
        b.add_gradient(upstream * a.value());  // contribution to b: dout * a
    };
    return product;
}
|
|
|
|
|
|
|
|
|
|
// Division of two graph nodes.
//
// Forward:  the result holds a.value() / b.value() and lists a and b as its
//           inputs. No check is made for a zero divisor.
// Backward: d(a/b)/da = 1/b and d(a/b)/db = -a/b^2.
inline Node operator/(const Node& a, const Node& b) {
    Node quotient(a.value() / b.value());
    quotient.inputs() = {a, b};
    quotient.backward() = [quotient, a, b]() {
        const auto upstream = quotient.gradient();
        const auto divisor = b.value();
        // contribution to a: dout * (1/b)
        a.add_gradient(upstream / divisor);
        // contribution to b: dout * (-a/b^2)
        b.add_gradient(-upstream * a.value() / (divisor * divisor));
    };
    return quotient;
}
|
|
|
|
|
|
|
|
|
|
// The functions below rely on the chain rule (keeping the first-order term of the Taylor expansion).
|
|
|
|
|
// Trigonometric functions
|
|
|
|
|
// Sine of a graph node.
//
// Forward:  the result holds std::sin(x.value()) and lists x as its only input.
// Backward: d sin(x)/dx = cos(x), evaluated from x.value() when the closure
//           runs.
inline Node sin(const Node& x) {
    Node result(std::sin(x.value()));
    result.inputs() = {x};
    result.backward() = [result, x]() {
        const auto slope = std::cos(x.value());
        x.add_gradient(result.gradient() * slope);  // dout * cos(x)
    };
    return result;
}
|
|
|
|
|
|
|
|
|
|
// Cosine of a graph node.
//
// Forward:  the result holds std::cos(x.value()) and lists x as its only input.
// Backward: d cos(x)/dx = -sin(x), evaluated from x.value() when the closure
//           runs.
inline Node cos(const Node& x) {
    Node result(std::cos(x.value()));
    result.inputs() = {x};
    result.backward() = [result, x]() {
        const auto sine = std::sin(x.value());
        x.add_gradient(-result.gradient() * sine);  // dout * (-sin(x))
    };
    return result;
}
|
|
|
|
|
|
|
|
|
|
// Exponential, logarithm and power functions
|
|
|
|
|
// Natural exponential of a graph node.
//
// Forward:  the result holds std::exp(x.value()) and lists x as its only input.
// Backward: d exp(x)/dx = exp(x), which is the result's own value, so the
//           stored value is reused instead of calling std::exp again.
inline Node exp(const Node& x) {
    Node result(std::exp(x.value()));
    result.inputs() = {x};
    result.backward() = [result, x]() {
        const auto upstream = result.gradient();
        x.add_gradient(upstream * result.value());  // dout * exp(x)
    };
    return result;
}
|
|
|
|
|
|
|
|
|
|
// Natural logarithm of a graph node.
//
// Forward:  the result holds std::log(x.value()) and lists x as its only input.
//           The argument is not range-checked.
// Backward: d log(x)/dx = 1/x.
inline Node log(const Node& x) {
    Node result(std::log(x.value()));
    result.inputs() = {x};
    result.backward() = [result, x]() {
        const auto upstream = result.gradient();
        x.add_gradient(upstream / x.value());  // dout * (1/x)
    };
    return result;
}
|
|
|
|
|
|
|
|
|
|
// Raises a graph node to a constant real exponent: out = x^n.
//
// Parameters:
//   x - the base node.
//   n - the exponent; a plain constant, so no gradient flows into it.
//
// Forward:  the result holds std::pow(x.value(), n) and lists x as its only
//           input.
// Backward: d(x^n)/dx = n * x^(n-1).
inline Node pow(const Node& x, double n) {
    Node out(std::pow(x.value(), n));
    out.inputs() = {x};
    out.backward() = [out, x, n]() {
        if (n == 0.0) {
            // x^0 is the constant 1, so its derivative is exactly zero.
            // The general formula would evaluate 0 * pow(0, -1) = 0 * inf
            // = NaN when x is zero, so contribute a zero term directly.
            x.add_gradient(out.gradient() * 0.0);
            return;
        }
        // dx = dout * (n * x^(n-1))
        x.add_gradient(out.gradient() * n * std::pow(x.value(), n - 1));
    };
    return out;
}
|
|
|
|
|
|
|
|
|
|
// Inverse trigonometric functions
|
|
|
|
|
// Arcsine of a graph node.
//
// Forward:  the result holds std::asin(x.value()) and lists x as its only
//           input. The argument is not range-checked.
// Backward: d asin(x)/dx = 1 / sqrt(1 - x^2).
inline Node asin(const Node& x) {
    Node result(std::asin(x.value()));
    result.inputs() = {x};
    result.backward() = [result, x]() {
        const auto v = x.value();
        const auto root = std::sqrt(1 - v * v);
        x.add_gradient(result.gradient() / root);  // dout * (1/sqrt(1-x^2))
    };
    return result;
}
|
|
|
|
|
|
|
|
|
|
// Arccosine of a graph node.
//
// Forward:  the result holds std::acos(x.value()) and lists x as its only
//           input. The argument is not range-checked.
// Backward: d acos(x)/dx = -1 / sqrt(1 - x^2).
inline Node acos(const Node& x) {
    Node result(std::acos(x.value()));
    result.inputs() = {x};
    result.backward() = [result, x]() {
        const auto v = x.value();
        const auto root = std::sqrt(1 - v * v);
        x.add_gradient(-result.gradient() / root);  // dout * (-1/sqrt(1-x^2))
    };
    return result;
}
|
|
|
|
|
|
|
|
|
|
// Arctangent of a graph node.
//
// Forward:  the result holds std::atan(x.value()) and lists x as its only
//           input.
// Backward: d atan(x)/dx = 1 / (1 + x^2).
inline Node atan(const Node& x) {
    Node result(std::atan(x.value()));
    result.inputs() = {x};
    result.backward() = [result, x]() {
        const auto v = x.value();
        const auto denominator = 1 + v * v;
        x.add_gradient(result.gradient() / denominator);  // dout * (1/(1+x^2))
    };
    return result;
}
|
|
|
|
|
|
|
|
|
|
}
|