/*对偶数的声明部分*/ #pragma once #include #include #include #include "types/common.hpp" #include "types/dual.hpp" // 前向自动微分的运算部分,两个deriv相乘为0 namespace forwardad{ // 加减乘除 inline Dual operator+(const Dual& a, const Dual& b) { return Dual(a.value + b.value, a.deriv + b.deriv); } inline Dual operator-(const Dual& a, const Dual& b) { return Dual(a.value - b.value, a.deriv - b.deriv); } inline Dual operator*(const Dual& a, const Dual& b) { return Dual(a.value * b.value, a.value * b.deriv + a.deriv * b.value); } inline Dual operator/(const Dual& a, const Dual& b) { return Dual(a.value / b.value, (a.deriv * b.value - a.value * b.deriv) / (b.value * b.value)); } // 下面的函数需要用到链式法则(使用泰勒展开后保留一次项) // 三角函数 inline Dual sin(const Dual& x) { return Dual(std::sin(x.value), std::cos(x.value) * x.deriv); } inline Dual cos(const Dual& x) { return Dual(std::cos(x.value), -std::sin(x.value) * x.deriv); } // 指数和对数 inline Dual exp(const Dual& x) { double exp_val = std::exp(x.value); return Dual(exp_val, exp_val * x.deriv); } inline Dual log(const Dual& x) { return Dual(std::log(x.value), x.deriv / x.value); } // 幂函数 inline Dual pow(const Dual& x, double n) { double pow_val = std::pow(x.value, n); return Dual(pow_val, n * std::pow(x.value, n - 1) * x.deriv); } // 反三角函数 inline Dual asin(const Dual& x) { return Dual(std::asin(x.value), x.deriv / std::sqrt(1 - x.value * x.value)); } inline Dual acos(const Dual& x) { return Dual(std::acos(x.value), -x.deriv / std::sqrt(1 - x.value * x.value)); } inline Dual atan(const Dual& x) { return Dual(std::atan(x.value), x.deriv / (1 + x.value * x.value)); } }// namespace forwardad namespace backwardad { inline Node operator+(const Node& a, const Node& b) { Node out(a.value() + b.value()); out.inputs() = {a, b}; out.backward() = [out, a, b]() { a.add_gradient(out.gradient()); // da = dout * 1 b.add_gradient(out.gradient()); // db = dout * 1 }; return out; } inline Node operator-(const Node& a, const Node& b) { Node out(a.value() - b.value()); out.inputs() = {a, b}; out.backward() = [out, a, b]() { a.add_gradient(out.gradient()); // da = dout * 1 b.add_gradient(-out.gradient()); // db = dout * (-1) }; return out; } inline Node operator*(const Node& a, const Node& b) { Node out(a.value() * b.value()); out.inputs() = {a, b}; out.backward() = [out, a, b]() { a.add_gradient(out.gradient() * b.value()); // da = dout * b b.add_gradient(out.gradient() * a.value()); // db = dout * a }; return out; } inline Node operator/(const Node& a, const Node& b) { Node out(a.value() / b.value()); out.inputs() = {a, b}; out.backward() = [out, a, b]() { a.add_gradient(out.gradient() / b.value()); // da = dout * (1/b) b.add_gradient(-out.gradient() * a.value() / (b.value() * b.value())); // db = dout * (-a/b^2) }; return out; } // 下面的函数需要用到链式法则(使用泰勒展开后保留一次项) // 三角函数 inline Node sin(const Node& x) { Node out(std::sin(x.value())); out.inputs() = {x}; out.backward() = [out, x]() { x.add_gradient(out.gradient() * std::cos(x.value())); // dx = dout * cos(x) }; return out; } inline Node cos(const Node& x) { Node out(std::cos(x.value())); out.inputs() = {x}; out.backward() = [out, x]() { x.add_gradient(-out.gradient() * std::sin(x.value())); // dx = dout * (-sin(x)) }; return out; } // 指数和对数 inline Node exp(const Node& x) { Node out(std::exp(x.value())); out.inputs() = {x}; out.backward() = [out, x]() { x.add_gradient(out.gradient() * out.value()); // dx = dout * exp(x) = dout * out.value }; return out; } inline Node log(const Node& x) { Node out(std::log(x.value())); out.inputs() = {x}; out.backward() = [out, x]() { x.add_gradient(out.gradient() / x.value()); // dx = dout * (1/x) }; return out; } inline Node pow(const Node& x, double n) { Node out(std::pow(x.value(), n)); out.inputs() = {x}; out.backward() = [out, x, n]() { x.add_gradient(out.gradient() * n * std::pow(x.value(), n - 1)); // dx = dout * (n * x^(n-1)) }; return out; } // 反三角函数 inline Node asin(const Node& x) { Node out(std::asin(x.value())); out.inputs() = {x}; out.backward() = [out, x]() { x.add_gradient(out.gradient() / std::sqrt(1 - x.value() * x.value())); // dx = dout * (1/sqrt(1-x^2)) }; return out; } inline Node acos(const Node& x) { Node out(std::acos(x.value())); out.inputs() = {x}; out.backward() = [out, x]() { x.add_gradient(-out.gradient() / std::sqrt(1 - x.value() * x.value())); // dx = dout * (-1/sqrt(1-x^2)) }; return out; } inline Node atan(const Node& x) { Node out(std::atan(x.value())); out.inputs() = {x}; out.backward() = [out, x]() { x.add_gradient(out.gradient() / (1 + x.value() * x.value())); // dx = dout * (1/(1+x^2)) }; return out; } }