feat: 前向自动微分纯头文件库
This commit is contained in:
65
include/dual.hpp
Normal file
65
include/dual.hpp
Normal file
@@ -0,0 +1,65 @@
|
||||
/*对偶数的声明部分*/
|
||||
#pragma once
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
#include "types/dual.hpp"
|
||||
|
||||
// 前向自动微分的运算部分,两个deriv相乘为0
|
||||
namespace forwardad{
|
||||
// 加减乘除
|
||||
inline Dual operator+(const Dual& a, const Dual& b) {
|
||||
return Dual(a.value + b.value, a.deriv + b.deriv);
|
||||
}
|
||||
|
||||
inline Dual operator-(const Dual& a, const Dual& b) {
|
||||
return Dual(a.value - b.value, a.deriv - b.deriv);
|
||||
}
|
||||
|
||||
inline Dual operator*(const Dual& a, const Dual& b) {
|
||||
return Dual(a.value * b.value, a.value * b.deriv + a.deriv * b.value);
|
||||
}
|
||||
|
||||
inline Dual operator/(const Dual& a, const Dual& b) {
|
||||
return Dual(a.value / b.value, (a.deriv * b.value - a.value * b.deriv) / (b.value * b.value));
|
||||
}
|
||||
|
||||
// 下面的函数需要用到链式法则(使用泰勒展开后保留一次项)
|
||||
// 三角函数
|
||||
inline Dual sin(const Dual& x) {
|
||||
return Dual(std::sin(x.value), std::cos(x.value) * x.deriv);
|
||||
}
|
||||
|
||||
inline Dual cos(const Dual& x) {
|
||||
return Dual(std::cos(x.value), -std::sin(x.value) * x.deriv);
|
||||
}
|
||||
|
||||
// 指数和对数
|
||||
inline Dual exp(const Dual& x) {
|
||||
double exp_val = std::exp(x.value);
|
||||
return Dual(exp_val, exp_val * x.deriv);
|
||||
}
|
||||
|
||||
inline Dual log(const Dual& x) {
|
||||
return Dual(std::log(x.value), x.deriv / x.value);
|
||||
}
|
||||
|
||||
// 幂函数
|
||||
inline Dual pow(const Dual& x, double n) {
|
||||
double pow_val = std::pow(x.value, n);
|
||||
return Dual(pow_val, n * std::pow(x.value, n - 1) * x.deriv);
|
||||
}
|
||||
|
||||
// 反三角函数
|
||||
inline Dual asin(const Dual& x) {
|
||||
return Dual(std::asin(x.value), x.deriv / std::sqrt(1 - x.value * x.value));
|
||||
}
|
||||
|
||||
inline Dual acos(const Dual& x) {
|
||||
return Dual(std::acos(x.value), -x.deriv / std::sqrt(1 - x.value * x.value));
|
||||
}
|
||||
|
||||
inline Dual atan(const Dual& x) {
|
||||
return Dual(std::atan(x.value), x.deriv / (1 + x.value * x.value));
|
||||
}
|
||||
}// namespace forwardad
|
||||
38
include/forwardad.hpp
Normal file
38
include/forwardad.hpp
Normal file
@@ -0,0 +1,38 @@
|
||||
/*前向自动微分的声明部分*/
|
||||
#pragma once
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include <functional>
|
||||
#include "dual.hpp"
|
||||
#include "types/common.hpp"
|
||||
|
||||
namespace forwardad{
|
||||
|
||||
template <typename Func, typename... Args>
|
||||
Result diff(const Func& f, Args... args) {
|
||||
Result res;
|
||||
constexpr size_t N = sizeof...(Args);
|
||||
res.gradient.resize(N);
|
||||
|
||||
// 1. 计算函数值(所有导数为 0)
|
||||
Dual value_res = f(Dual((double)args, 0.0)...);
|
||||
res.value = value_res.value;
|
||||
|
||||
// 2. 对每个输入变量求偏导(seed 依次设为 1)
|
||||
double args_arr[] = { (double)args... };
|
||||
[&]<size_t... Is>(std::index_sequence<Is...>) {
|
||||
([&]() {
|
||||
// 创建 Dual 输入,第 Is 个变量导数=1,其余=0
|
||||
Dual inputs[N];
|
||||
for (size_t j = 0; j < N; ++j)
|
||||
inputs[j] = Dual(args_arr[j], j == Is ? 1.0 : 0.0);
|
||||
// 用内层 index_sequence 展开所有 N 个参数传给 f
|
||||
res.gradient[Is] = [&]<size_t... Js>(std::index_sequence<Js...>) {
|
||||
return f(inputs[Js]...).deriv;
|
||||
}(std::make_index_sequence<N>{});
|
||||
}(), ...);
|
||||
}(std::make_index_sequence<N>{});
|
||||
|
||||
return res;
|
||||
}
|
||||
} // namespace forwardad
|
||||
1
include/types/README.md
Normal file
1
include/types/README.md
Normal file
@@ -0,0 +1 @@
|
||||
### 类模型定义
|
||||
9
include/types/common.hpp
Normal file
9
include/types/common.hpp
Normal file
@@ -0,0 +1,9 @@
|
||||
/*返回数据结构体*/
|
||||
#include <vector>
|
||||
|
||||
namespace forwardad{
|
||||
struct Result{
|
||||
double value; // 函数值
|
||||
std::vector<double> gradient; // 梯度
|
||||
};
|
||||
}
|
||||
12
include/types/dual.hpp
Normal file
12
include/types/dual.hpp
Normal file
@@ -0,0 +1,12 @@
|
||||
/*对偶数的数据类型定义部分*/
|
||||
#pragma once
|
||||
|
||||
namespace forwardad {
|
||||
struct Dual {
|
||||
double value;
|
||||
double deriv;
|
||||
|
||||
Dual(double v = 0.0, double d = 0.0) : value(v), deriv(d) {}
|
||||
|
||||
};
|
||||
} // namespace forwardad
|
||||
Reference in New Issue
Block a user