#include #include "forwardad.hpp" #include "backwardad.hpp" #include "types/dual.hpp" // 一阶测试函数 forwardad::Dual f1(forwardad::Dual x) { return x + x*x + pow(x,3); } // 二阶测试函数 forwardad::Dual g1(forwardad::Dual x, forwardad::Dual y) { return exp(x) * log(y); } // 高阶测试函数 forwardad::Dual h1(forwardad::Dual x, forwardad::Dual y, forwardad::Dual z) { return pow(x, 3) + pow(y, 2) + z + cos(x * y * z); } backwardad::Node f2(backwardad::Node x) { return x + x*x + pow(x,3); } backwardad::Node g2(backwardad::Node x, backwardad::Node y) { return exp(x) * log(y); } backwardad::Node h2(backwardad::Node x, backwardad::Node y, backwardad::Node z) { return pow(x, 3) + pow(y, 2) + z + cos(x * y * z); } int test_forward_ad() { std::cout << "Testing f(x) = x^2 + sin(x) at x=2.0\n"; auto r = forwardad::diff(f1, 2.0); std::cout << "value = " << r.value << "\n"; std::cout << "grad = " << r.gradient[0] << "\n"; std::cout << "\nTesting g(x,y) = exp(x)*log(y) at (x,y)=(1.0, 2.0)\n"; auto r2 = forwardad::diff(g1, 1.0, 2.0); std::cout << "value = " << r2.value << "\n"; std::cout << "grad = (" << r2.gradient[0] << ", " << r2.gradient[1] << ")\n"; std::cout << "\nTesting h(x,y,z) = x^3 + y^2 + z + cos(x*y*z) at (x,y,z)=(1.0, 2.0, 3.0)\n"; auto r3 = forwardad::diff(h1, 1.0, 2.0, 3.0); std::cout << "value = " << r3.value << "\n"; std::cout << "grad = (" << r3.gradient[0] << ", " << r3.gradient[1] << ", " << r3.gradient[2] << ")\n"; return 0; } int test_backward_ad() { std::cout << "Testing f(x) = x^2 + sin(x) at x=2.0\n"; auto r = backwardad::diff(f2, 2.0); std::cout << "value = " << r.value << "\n"; std::cout << "grad = " << r.gradient[0] << "\n"; std::cout << "\nTesting g(x,y) = exp(x)*log(y) at (x,y)=(1.0, 2.0)\n"; auto r2 = backwardad::diff(g2, 1.0, 2.0); std::cout << "value = " << r2.value << "\n"; std::cout << "grad = (" << r2.gradient[0] << ", " << r2.gradient[1] << ")\n"; std::cout << "\nTesting h(x,y,z) = x^3 + y^2 + z + cos(x*y*z) at (x,y,z)=(1.0, 2.0, 3.0)\n"; auto r3 = backwardad::diff(h2, 1.0, 2.0, 3.0); std::cout << "value = " << r3.value << "\n"; std::cout << "grad = (" << r3.gradient[0] << ", " << r3.gradient[1] << ", " << r3.gradient[2] << ")\n"; return 0; } int main() { std::cout << "=== Forward AD Tests ===\n"; test_forward_ad(); std::cout << "\n=== Backward AD Tests ===\n"; test_backward_ad(); return 0; }