Files
AutomaticDifferentiation/include/dual_ops.hpp
2026-04-07 13:55:32 +08:00

63 lines
1.9 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
/* Operations on dual numbers: arithmetic operators and elementary functions */
#pragma once
#include <cmath>
#include "types/dual.hpp"
// Operations for forward-mode automatic differentiation; the product of two deriv terms is treated as 0.
namespace forwardad{
// Addition, subtraction, multiplication and division
/// Sum of two dual numbers.
/// Both the value part and the derivative part are added component-wise.
inline Dual operator+(const Dual& a, const Dual& b) {
    const auto value_sum = a.value + b.value;
    const auto deriv_sum = a.deriv + b.deriv;
    return Dual(value_sum, deriv_sum);
}
/// Difference of two dual numbers.
/// Both the value part and the derivative part are subtracted component-wise.
inline Dual operator-(const Dual& a, const Dual& b) {
    const auto value_diff = a.value - b.value;
    const auto deriv_diff = a.deriv - b.deriv;
    return Dual(value_diff, deriv_diff);
}
/// Product of two dual numbers.
/// The derivative part follows the product rule: a.value * b.deriv + a.deriv * b.value.
inline Dual operator*(const Dual& a, const Dual& b) {
    const auto product = a.value * b.value;
    const auto product_rule = a.value * b.deriv + a.deriv * b.value;
    return Dual(product, product_rule);
}
/// Quotient of two dual numbers.
/// The derivative part follows the quotient rule:
/// (a.deriv * b.value - a.value * b.deriv) / b.value^2.
/// No check is made for b.value being zero.
inline Dual operator/(const Dual& a, const Dual& b) {
    const auto quotient = a.value / b.value;
    const auto quotient_rule =
        (a.deriv * b.value - a.value * b.deriv) / (b.value * b.value);
    return Dual(quotient, quotient_rule);
}
// The functions below apply the chain rule (first-order Taylor expansion, keeping only the linear term).
// Trigonometric functions
/// Sine of a dual number.
/// Value part: sin(x.value). Derivative part: cos(x.value) * x.deriv.
inline Dual sin(const Dual& x) {
    const auto sine = std::sin(x.value);
    const auto slope = std::cos(x.value) * x.deriv;
    return Dual(sine, slope);
}
/// Cosine of a dual number.
/// Value part: cos(x.value). Derivative part: -sin(x.value) * x.deriv.
inline Dual cos(const Dual& x) {
    const auto cosine = std::cos(x.value);
    const auto slope = -std::sin(x.value) * x.deriv;
    return Dual(cosine, slope);
}
// Exponential and logarithm
/// Exponential of a dual number.
/// exp is its own derivative, so the value computed once is reused for the
/// derivative part: exp(x.value) * x.deriv.
inline Dual exp(const Dual& x) {
    double e_to_x = std::exp(x.value);
    const auto slope = e_to_x * x.deriv;
    return Dual(e_to_x, slope);
}
/// Natural logarithm of a dual number.
/// Value part: log(x.value). Derivative part: x.deriv / x.value.
/// No domain check is performed on x.value.
inline Dual log(const Dual& x) {
    const auto log_value = std::log(x.value);
    const auto slope = x.deriv / x.value;
    return Dual(log_value, slope);
}
// Power function
/// Raises a dual number to a constant real power.
///
/// @param x Base dual number (value and derivative parts).
/// @param n Constant exponent; it carries no derivative of its own.
/// @return Dual whose value part is x.value^n and whose derivative part is
///         n * x.value^(n-1) * x.deriv (power rule combined with the chain rule).
///
/// For n == 0 the result is the constant 1, so the derivative part is 0.
/// That case is handled explicitly: at x.value == 0 the general formula would
/// evaluate 0 * pow(0, -1) = 0 * inf, which is NaN.
inline Dual pow(const Dual& x, double n) {
    const double pow_val = std::pow(x.value, n);
    if (n == 0.0) {
        // Exact comparison is intended: only a literally zero exponent makes
        // the function constant.
        return Dual(pow_val, 0.0);
    }
    return Dual(pow_val, n * std::pow(x.value, n - 1) * x.deriv);
}
// Inverse trigonometric functions
/// Arcsine of a dual number.
/// Value part: asin(x.value). Derivative part: x.deriv / sqrt(1 - x.value^2).
/// No domain check is performed on x.value.
inline Dual asin(const Dual& x) {
    const auto angle = std::asin(x.value);
    const auto slope = x.deriv / std::sqrt(1 - x.value * x.value);
    return Dual(angle, slope);
}
/// Arccosine of a dual number.
/// Value part: acos(x.value). Derivative part: -x.deriv / sqrt(1 - x.value^2).
/// No domain check is performed on x.value.
inline Dual acos(const Dual& x) {
    const auto angle = std::acos(x.value);
    const auto slope = -x.deriv / std::sqrt(1 - x.value * x.value);
    return Dual(angle, slope);
}
/// Arctangent of a dual number.
/// Value part: atan(x.value). Derivative part: x.deriv / (1 + x.value^2).
inline Dual atan(const Dual& x) {
    const auto angle = std::atan(x.value);
    const auto slope = x.deriv / (1 + x.value * x.value);
    return Dual(angle, slope);
}
}// namespace forwardad