From 58aedec43c568ba49a9a6ef45fc9470a42cd9cda Mon Sep 17 00:00:00 2001 From: mayge Date: Wed, 1 Apr 2026 02:39:40 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=89=8D=E5=90=91=E8=87=AA=E5=8A=A8?= =?UTF-8?q?=E5=BE=AE=E5=88=86=E7=BA=AF=E5=A4=B4=E6=96=87=E4=BB=B6=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 2 ++ CMakeLists.txt | 12 ++++++++ README.md | 1 + include/dual.hpp | 65 ++++++++++++++++++++++++++++++++++++++++ include/forwardad.hpp | 38 +++++++++++++++++++++++ include/types/README.md | 1 + include/types/common.hpp | 9 ++++++ include/types/dual.hpp | 12 ++++++++ src/.gitkeep | 0 tests/basic.cpp | 40 +++++++++++++++++++++++++ 10 files changed, 180 insertions(+) create mode 100644 .gitignore create mode 100644 CMakeLists.txt create mode 100644 include/dual.hpp create mode 100644 include/forwardad.hpp create mode 100644 include/types/README.md create mode 100644 include/types/common.hpp create mode 100644 include/types/dual.hpp create mode 100644 src/.gitkeep create mode 100644 tests/basic.cpp diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..d835f77 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +build +.cache \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..eff1055 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,12 @@ +cmake_minimum_required(VERSION 3.10) +project(forwardad) + +set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_POSITION_INDEPENDENT_CODE ON) + +# 头文件路径 +include_directories(include) + +# 测试可执行文件 +add_executable(test_basic tests/basic.cpp) \ No newline at end of file diff --git a/README.md b/README.md index e69de29..b4868d3 100644 --- a/README.md +++ b/README.md @@ -0,0 +1 @@ +# 前向自动微分 \ No newline at end of file diff --git a/include/dual.hpp b/include/dual.hpp new file mode 100644 index 0000000..eb2ea6e --- /dev/null +++ b/include/dual.hpp @@ -0,0 +1,65 @@ +/*对偶数的声明部分*/ +#pragma once +#include +#include +#include +#include "types/dual.hpp" + +// 前向自动微分的运算部分,两个deriv相乘为0 +namespace forwardad{ + // 加减乘除 + inline Dual operator+(const Dual& a, const Dual& b) { + return Dual(a.value + b.value, a.deriv + b.deriv); + } + + inline Dual operator-(const Dual& a, const Dual& b) { + return Dual(a.value - b.value, a.deriv - b.deriv); + } + + inline Dual operator*(const Dual& a, const Dual& b) { + return Dual(a.value * b.value, a.value * b.deriv + a.deriv * b.value); + } + + inline Dual operator/(const Dual& a, const Dual& b) { + return Dual(a.value / b.value, (a.deriv * b.value - a.value * b.deriv) / (b.value * b.value)); + } + + // 下面的函数需要用到链式法则(使用泰勒展开后保留一次项) + // 三角函数 + inline Dual sin(const Dual& x) { + return Dual(std::sin(x.value), std::cos(x.value) * x.deriv); + } + + inline Dual cos(const Dual& x) { + return Dual(std::cos(x.value), -std::sin(x.value) * x.deriv); + } + + // 指数和对数 + inline Dual exp(const Dual& x) { + double exp_val = std::exp(x.value); + return Dual(exp_val, exp_val * x.deriv); + } + + inline Dual log(const Dual& x) { + return Dual(std::log(x.value), x.deriv / x.value); + } + + // 幂函数 + inline Dual pow(const Dual& x, double n) { + double pow_val = std::pow(x.value, n); + return Dual(pow_val, n * std::pow(x.value, n - 1) * x.deriv); + } + + // 反三角函数 + inline Dual asin(const Dual& x) { + return Dual(std::asin(x.value), x.deriv / std::sqrt(1 - x.value * x.value)); + } + + inline Dual acos(const Dual& x) { + return Dual(std::acos(x.value), -x.deriv / std::sqrt(1 - x.value * x.value)); + } + + inline Dual atan(const Dual& x) { + return Dual(std::atan(x.value), x.deriv / (1 + x.value * x.value)); + } +}// namespace forwardad \ No newline at end of file diff --git a/include/forwardad.hpp b/include/forwardad.hpp new file mode 100644 index 0000000..a1385de --- /dev/null +++ b/include/forwardad.hpp @@ -0,0 +1,38 @@ +/*前向自动微分的声明部分*/ +#pragma once +#include +#include +#include +#include "dual.hpp" +#include "types/common.hpp" + +namespace forwardad{ + +template +Result diff(const Func& f, Args... args) { + Result res; + constexpr size_t N = sizeof...(Args); + res.gradient.resize(N); + + // 1. 计算函数值(所有导数为 0) + Dual value_res = f(Dual((double)args, 0.0)...); + res.value = value_res.value; + + // 2. 对每个输入变量求偏导(seed 依次设为 1) + double args_arr[] = { (double)args... }; + [&](std::index_sequence) { + ([&]() { + // 创建 Dual 输入,第 Is 个变量导数=1,其余=0 + Dual inputs[N]; + for (size_t j = 0; j < N; ++j) + inputs[j] = Dual(args_arr[j], j == Is ? 1.0 : 0.0); + // 用内层 index_sequence 展开所有 N 个参数传给 f + res.gradient[Is] = [&](std::index_sequence) { + return f(inputs[Js]...).deriv; + }(std::make_index_sequence{}); + }(), ...); + }(std::make_index_sequence{}); + + return res; +} +} // namespace forwardad \ No newline at end of file diff --git a/include/types/README.md b/include/types/README.md new file mode 100644 index 0000000..58969b1 --- /dev/null +++ b/include/types/README.md @@ -0,0 +1 @@ +### 类模型定义 \ No newline at end of file diff --git a/include/types/common.hpp b/include/types/common.hpp new file mode 100644 index 0000000..41ff214 --- /dev/null +++ b/include/types/common.hpp @@ -0,0 +1,9 @@ +/*返回数据结构体*/ +#include + +namespace forwardad{ + struct Result{ + double value; // 函数值 + std::vector gradient; // 梯度 + }; +} \ No newline at end of file diff --git a/include/types/dual.hpp b/include/types/dual.hpp new file mode 100644 index 0000000..630588c --- /dev/null +++ b/include/types/dual.hpp @@ -0,0 +1,12 @@ +/*对偶数的数据类型定义部分*/ +#pragma once + +namespace forwardad { +struct Dual { + double value; + double deriv; + + Dual(double v = 0.0, double d = 0.0) : value(v), deriv(d) {} + +}; +} // namespace forwardad \ No newline at end of file diff --git a/src/.gitkeep b/src/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/tests/basic.cpp b/tests/basic.cpp new file mode 100644 index 0000000..c734a15 --- /dev/null +++ b/tests/basic.cpp @@ -0,0 +1,40 @@ +#include +#include "forwardad.hpp" +#include "types/dual.hpp" + +using namespace forwardad; + +// 一阶测试函数 +Dual f(Dual x) { + return x + x*x + pow(x,3); +} + +// 二阶测试函数 +Dual g(Dual x, Dual y) { + return exp(x) * log(y); +} + +// 高阶测试函数 +Dual h(Dual x, Dual y, Dual z) { + return pow(x, 3) + pow(y, 2) + z + cos(x * y * z); +} + +int main() { + std::cout << "Testing f(x) = x^2 + sin(x) at x=2.0\n"; + auto r = diff(f, 2.0); + std::cout << "value = " << r.value << "\n"; + std::cout << "grad = " << r.gradient[0] << "\n"; + + std::cout << "\nTesting g(x,y) = exp(x)*log(y) at (x,y)=(1.0, 2.0)\n"; + auto r2 = diff(g, 1.0, 2.0); + std::cout << "value = " << r2.value << "\n"; + std::cout << "grad = (" << r2.gradient[0] << ", " << r2.gradient[1] << ")\n"; + + std::cout << "\nTesting h(x,y,z) = x^3 + y^2 + z + cos(x*y*z) at (x,y,z)=(1.0, 2.0, 3.0)\n"; + auto r3 = diff(h, 1.0, 2.0, 3.0); + std::cout << "value = " << r3.value << "\n"; + std::cout << "grad = (" << r3.gradient[0] << ", " << r3.gradient[1] << ", " << r3.gradient[2] << ")\n"; + + + return 0; +} \ No newline at end of file