diff --git a/cart-pole-scalability-results-sleipnir.csv b/cart-pole-scalability-results-sleipnir.csv
index b6495ca8..5ea7ee80 100644
--- a/cart-pole-scalability-results-sleipnir.csv
+++ b/cart-pole-scalability-results-sleipnir.csv
@@ -1,6 +1,6 @@
 Samples,Setup time (ms),Solve time (ms)
-100,1.69,469.773
-150,1.297,1432.54
-200,1.813,2161.16
-250,2.29,2545.94
-300,2.815,3945
+100,1.515,411.42
+150,0.999,1235.13
+200,1.346,1967.54
+250,1.713,2342.6
+300,2.139,3613.15
diff --git a/cart-pole-scalability-results.png b/cart-pole-scalability-results.png
index 1390ae47..255f8985 100644
Binary files a/cart-pole-scalability-results.png and b/cart-pole-scalability-results.png differ
diff --git a/flywheel-scalability-results-sleipnir.csv b/flywheel-scalability-results-sleipnir.csv
index d55a2669..5722f2f9 100644
--- a/flywheel-scalability-results-sleipnir.csv
+++ b/flywheel-scalability-results-sleipnir.csv
@@ -1,15 +1,15 @@
 Samples,Setup time (ms),Solve time (ms)
-100,0.353,1.989
-200,1.093,5.297
-300,0.178,6.186
-400,0.237,8.388
-500,0.296,10.557
-600,0.355,15.161
-700,0.423,15.865
-800,0.533,20.869
-900,0.504,19.665
-1000,0.545,23.423
-2000,1.096,50.603
-3000,1.644,87.56
-4000,2.279,113.649
-5000,2.838,156.224
+100,0.224,1.927
+200,0.082,6.263
+300,0.129,5.932
+400,0.168,8.273
+500,0.217,9.847
+600,0.256,14.863
+700,0.305,17.475
+800,0.404,18.325
+900,0.43,21.534
+1000,0.439,23.686
+2000,0.919,50.931
+3000,1.375,85.562
+4000,1.814,116.722
+5000,2.329,157.785
diff --git a/flywheel-scalability-results.png b/flywheel-scalability-results.png
index 0e68ef02..6619ed1a 100644
Binary files a/flywheel-scalability-results.png and b/flywheel-scalability-results.png differ
diff --git a/include/sleipnir/autodiff/ExpressionGraph.hpp b/include/sleipnir/autodiff/ExpressionGraph.hpp
index 51d325ad..06763665 100644
--- a/include/sleipnir/autodiff/ExpressionGraph.hpp
+++ b/include/sleipnir/autodiff/ExpressionGraph.hpp
@@ -36,21 +36,17 @@ class ExpressionGraph {
     //
     // https://en.wikipedia.org/wiki/Breadth-first_search

-    // BFS list sorted from parent to child.
     small_vector<Expression*> stack;
-    stack.emplace_back(root.expr.Get());
-
     // Initialize the number of instances of each node in the tree
     // (Expression::duplications)
+    stack.emplace_back(root.expr.Get());
     while (!stack.empty()) {
       auto node = stack.back();
       stack.pop_back();

       for (auto& arg : node->args) {
-        // Only continue if the node is not a constant and hasn't already been
-        // explored.
-        if (arg != nullptr && arg->Type() != ExpressionType::kConstant) {
+        if (arg != nullptr) {
           // If this is the first instance of the node encountered (it hasn't
           // been explored yet), add it to stack so it's recursed upon
           if (arg->duplications == 0) {
@@ -61,13 +57,12 @@ class ExpressionGraph {
       }
     }

+    // Generate BFS lists sorted from parent to child
     stack.emplace_back(root.expr.Get());
-
     while (!stack.empty()) {
       auto node = stack.back();
       stack.pop_back();

-      // BFS lists sorted from parent to child.
       m_rowList.emplace_back(node->row);
       m_adjointList.emplace_back(node);
       if (node->args[0] != nullptr) {
         m_valueList.emplace_back(node);
       }

       for (auto& arg : node->args) {
-        // Only add node if it's not a constant and doesn't already exist in the
-        // tape.
-        if (arg != nullptr && arg->Type() != ExpressionType::kConstant) {
+        if (arg != nullptr) {
           // Once the number of node visitations equals the number of
           // duplications (the counter hits zero), add it to the stack. Note
           // that this means the node is only enqueued once.
@@ -122,11 +115,13 @@ class ExpressionGraph {
     // Read docs/algorithms.md#Reverse_accumulation_automatic_differentiation
     // for background on reverse accumulation automatic differentiation.

-    // Zero adjoints. The root node's adjoint is 1.0 as df/df is always 1.
-    if (m_adjointList.size() > 0) {
-      m_adjointList[0]->adjointExpr = MakeExpressionPtr(1.0);
+    if (m_adjointList.empty()) {
+      return VariableMatrix(wrt.size(), 1);
     }

+    // Set root node's adjoint to 1 since df/df is 1
+    m_adjointList[0]->adjointExpr = MakeExpressionPtr(1.0);
+
     // df/dx = (df/dy)(dy/dx). The adjoint of x is equal to the adjoint of y
     // multiplied by dy/dx. If there are multiple "paths" from the root node to
     // variable; the variable's adjoint is the sum of each path's adjoint
@@ -145,6 +140,7 @@ class ExpressionGraph {
       }
     }

+    // Move gradient tree to return value
     VariableMatrix grad(VariableMatrix::empty, wrt.size(), 1);
     for (int row = 0; row < grad.Rows(); ++row) {
       grad(row) = Variable{std::move(wrt(row).expr->adjointExpr)};
@@ -154,11 +150,6 @@ class ExpressionGraph {
     // parent expressions. This ensures all expressions are returned to the free
     // list.
     for (auto& node : m_adjointList) {
-      for (auto& arg : node->args) {
-        if (arg != nullptr) {
-          arg->adjointExpr = nullptr;
-        }
-      }
       node->adjointExpr = nullptr;
     }
@@ -166,19 +157,23 @@ class ExpressionGraph {
   }

   /**
-   * Updates the adjoints in the expression graph, effectively computing the
-   * gradient.
+   * Updates the adjoints in the expression graph (computes the gradient) then
+   * appends the adjoints of wrt to the sparse matrix triplets via a callback.
    *
    * @param func A function that takes two arguments: an int for the gradient
    *             row, and a double for the adjoint (gradient value).
    */
-  void ComputeAdjoints(function_ref<void(int, double)> func) {
-    // Zero adjoints. The root node's adjoint is 1.0 as df/df is always 1.
-    m_adjointList[0]->adjoint = 1.0;
-    for (auto& node : m_adjointList | std::views::drop(1)) {
-      node->adjoint = 0.0;
+  void AppendAdjointTriplets(function_ref<void(int, double)> func) {
+    // Read docs/algorithms.md#Reverse_accumulation_automatic_differentiation
+    // for background on reverse accumulation automatic differentiation.
+
+    if (m_adjointList.empty()) {
+      return;
     }

+    // Set root node's adjoint to 1 since df/df is 1
+    m_adjointList[0]->adjoint = 1.0;
+
     // df/dx = (df/dy)(dy/dx). The adjoint of x is equal to the adjoint of y
     // multiplied by dy/dx. If there are multiple "paths" from the root node to
     // variable; the variable's adjoint is the sum of each path's adjoint
@@ -200,12 +195,17 @@ class ExpressionGraph {
         }
       }

-      // If variable is a leaf node, assign its adjoint to the gradient.
+      // Append adjoints of wrt to sparse matrix triplets
       int row = m_rowList[col];
       if (row != -1) {
         func(row, node->adjoint);
       }
     }
+
+    // Zero adjoints for next run
+    for (auto& node : m_adjointList) {
+      node->adjoint = 0.0;
+    }
   }

  private:
diff --git a/include/sleipnir/autodiff/Hessian.hpp b/include/sleipnir/autodiff/Hessian.hpp
index 0077d1f5..47791fec 100644
--- a/include/sleipnir/autodiff/Hessian.hpp
+++ b/include/sleipnir/autodiff/Hessian.hpp
@@ -2,7 +2,6 @@

 #pragma once

-#include

 #include "sleipnir/autodiff/ExpressionGraph.hpp"
diff --git a/include/sleipnir/autodiff/Jacobian.hpp b/include/sleipnir/autodiff/Jacobian.hpp
index 1c7034bd..17407584 100644
--- a/include/sleipnir/autodiff/Jacobian.hpp
+++ b/include/sleipnir/autodiff/Jacobian.hpp
@@ -48,7 +48,7 @@ class SLEIPNIR_DLLEXPORT Jacobian {
         // If the row is linear, compute its gradient once here and cache its
        // triplets. Constant rows are ignored because their gradients have no
        // nonzero triplets.
-        m_graphs[row].ComputeAdjoints([&](int col, double adjoint) {
+        m_graphs[row].AppendAdjointTriplets([&](int col, double adjoint) {
          m_cachedTriplets.emplace_back(row, col, adjoint);
        });
      } else if (m_variables(row).Type() > ExpressionType::kLinear) {
@@ -76,17 +76,16 @@ class SLEIPNIR_DLLEXPORT Jacobian {
    * them.
    */
   VariableMatrix Get() const {
-    VariableMatrix result{m_variables.Rows(), m_wrt.Rows()};
+    VariableMatrix result{VariableMatrix::empty, m_variables.Rows(),
+                          m_wrt.Rows()};

     for (int row = 0; row < m_variables.Rows(); ++row) {
-      for (auto& node : m_wrt) {
-        node.expr->adjointExpr = nullptr;
-      }
-
       auto grad = m_graphs[row].GenerateGradientTree(m_wrt);
       for (int col = 0; col < m_wrt.Rows(); ++col) {
         if (grad(col).expr != nullptr) {
           result(row, col) = std::move(grad(col));
+        } else {
+          result(row, col) = Variable{0.0};
         }
       }
     }
@@ -115,12 +114,18 @@ class SLEIPNIR_DLLEXPORT Jacobian {

     // Compute each nonlinear row of the Jacobian
     for (int row : m_nonlinearRows) {
-      m_graphs[row].ComputeAdjoints([&](int col, double adjoint) {
+      m_graphs[row].AppendAdjointTriplets([&](int col, double adjoint) {
         triplets.emplace_back(row, col, adjoint);
       });
     }

-    m_J.setFromTriplets(triplets.begin(), triplets.end());
+    if (triplets.size() > 0) {
+      m_J.setFromTriplets(triplets.begin(), triplets.end());
+    } else {
+      // setFromTriplets() is a no-op on empty triplets, so explicitly zero out
+      // the storage
+      m_J.setZero();
+    }

     m_profiler.StopSolve();
diff --git a/jormungandr/cpp/Docstrings.hpp b/jormungandr/cpp/Docstrings.hpp
index b0f2af20..58df89d5 100644
--- a/jormungandr/cpp/Docstrings.hpp
+++ b/jormungandr/cpp/Docstrings.hpp
@@ -2184,9 +2184,10 @@ expression's computational graph in a way that skips duplicates.)doc";

 static const char *__doc_sleipnir_detail_ExpressionGraph_2 = R"doc()doc";

-static const char *__doc_sleipnir_detail_ExpressionGraph_ComputeAdjoints =
-R"doc(Updates the adjoints in the expression graph, effectively computing
-the gradient.
+static const char *__doc_sleipnir_detail_ExpressionGraph_AppendAdjointTriplets =
+R"doc(Updates the adjoints in the expression graph (computes the gradient)
+then appends the adjoints of wrt to the sparse matrix triplets via a
+callback.

 Parameter ``func``:
     A function that takes two arguments: an int for the gradient row,
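
Note on the renamed API (not part of the patch): AppendAdjointTriplets() reports each nonzero gradient entry through a callback, and the caller packs those entries into Eigen triplets before assembling the sparse Jacobian, as Jacobian.hpp does above. The following is a minimal standalone sketch of that pattern; the appendAdjointTriplets lambda is a hypothetical stand-in for ExpressionGraph::AppendAdjointTriplets(), while the Eigen calls (Triplet, setFromTriplets(), setZero()) are real API.

// Sketch of the callback-to-triplets pattern used by Jacobian::Value() above.
#include <functional>
#include <vector>

#include <Eigen/SparseCore>

int main() {
  std::vector<Eigen::Triplet<double>> triplets;

  // Hypothetical stand-in producer: reports (column, adjoint) pairs for the
  // nonzero gradient entries of one Jacobian row
  auto appendAdjointTriplets =
      [](const std::function<void(int, double)>& func) {
        func(0, 2.0);  // df/dx0 = 2
        func(2, 5.0);  // df/dx2 = 5
      };

  // Pack the reported entries into sparse matrix triplets
  int row = 0;
  appendAdjointTriplets([&](int col, double adjoint) {
    triplets.emplace_back(row, col, adjoint);
  });

  Eigen::SparseMatrix<double> J(1, 3);
  if (triplets.size() > 0) {
    J.setFromTriplets(triplets.begin(), triplets.end());
  } else {
    // Mirror the patch's fallback: clear any stale storage when no nonzero
    // entries were produced
    J.setZero();
  }
}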