Skip to content

Commit

Permalink
Pass triplets to AppendAdjointTriplets instead of using lambda (Sleip…
Browse files Browse the repository at this point in the history
  • Loading branch information
calcmogul authored Jan 20, 2025
1 parent c88b3fa commit 06bcae0
Show file tree
Hide file tree
Showing 12 changed files with 76 additions and 94 deletions.
10 changes: 5 additions & 5 deletions cart-pole-scalability-results-casadi.csv
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Samples,Setup time (ms),Solve time (ms)
100,38.637,1826.95
150,63.45,3635.56
200,94.95,4734.67
250,131.586,7043.58
300,178.675,9238.07
100,46.984,1813.21
150,63.382,3544.96
200,94.538,4703.68
250,131.917,6964.4
300,176.799,9236
10 changes: 5 additions & 5 deletions cart-pole-scalability-results-sleipnir.csv
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Samples,Setup time (ms),Solve time (ms)
100,1.515,411.42
150,0.999,1235.13
200,1.346,1967.54
250,1.713,2342.6
300,2.139,3613.15
100,1.679,433.736
150,1.008,995.212
200,1.384,1742.71
250,1.773,2327.23
300,2.302,2830.75
Binary file modified cart-pole-scalability-results.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
28 changes: 14 additions & 14 deletions flywheel-scalability-results-casadi.csv
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
Samples,Setup time (ms),Solve time (ms)
100,2.057,21.727
200,3.646,36.476
300,5.402,59.129
400,7.351,78.306
500,8.968,102.043
600,10.933,128.042
700,12.609,146.622
800,14.693,169.145
900,16.269,189.67
1000,19.966,215.395
2000,36.407,553.81
3000,56.344,1018.4
4000,75.164,1546.43
5000,93.18,2225.34
100,2.079,46.592
200,3.617,40.134
300,5.434,56.029
400,7.271,76.859
500,8.855,98.327
600,10.575,128.926
700,12.636,146.844
800,16.671,160.866
900,18.494,188.714
1000,20.053,214.457
2000,38.314,561.117
3000,56.819,1011.85
4000,77.04,1519.76
5000,92.532,2185.81
28 changes: 14 additions & 14 deletions flywheel-scalability-results-sleipnir.csv
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
Samples,Setup time (ms),Solve time (ms)
100,0.224,1.927
200,0.082,6.263
300,0.129,5.932
400,0.168,8.273
500,0.217,9.847
600,0.256,14.863
700,0.305,17.475
800,0.404,18.325
900,0.43,21.534
1000,0.439,23.686
2000,0.919,50.931
3000,1.375,85.562
4000,1.814,116.722
5000,2.329,157.785
100,0.222,1.931
200,0.08,3.873
300,0.122,5.863
400,0.166,8.01
500,0.209,9.938
600,0.248,16.683
700,0.311,14.822
800,0.348,16.789
900,0.389,19.368
1000,0.435,23.399
2000,0.993,50.993
3000,1.326,83.74
4000,1.772,119.669
5000,2.474,158.64
Binary file modified flywheel-scalability-results.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
5 changes: 2 additions & 3 deletions include/sleipnir/autodiff/Expression.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,8 @@ struct Expression {
/// an expression tree.
uint32_t duplications = 0;

/// This expression's row in wrt for autodiff gradient, Jacobian, or Hessian.
/// This is -1 if the expression isn't in wrt.
int32_t row = -1;
/// This expression's column in a Jacobian, or -1 otherwise.
int32_t col = -1;

/// The adjoint of the expression node used during gradient expression tree
/// generation.
Expand Down
31 changes: 16 additions & 15 deletions include/sleipnir/autodiff/ExpressionGraph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
#include <ranges>
#include <utility>

#include <Eigen/SparseCore>

#include "sleipnir/autodiff/Variable.hpp"
#include "sleipnir/autodiff/VariableMatrix.hpp"
#include "sleipnir/util/FunctionRef.hpp"
#include "sleipnir/util/small_vector.hpp"

namespace sleipnir::detail {
Expand Down Expand Up @@ -38,8 +39,8 @@ class ExpressionGraph {

small_vector<Expression*> stack;

// Initialize the number of instances of each node in the tree
// (Expression::duplications)
// Assign each node's number of instances in the tree to
// Expression::duplications
stack.emplace_back(root.expr.Get());
while (!stack.empty()) {
auto node = stack.back();
Expand All @@ -63,7 +64,7 @@ class ExpressionGraph {
auto node = stack.back();
stack.pop_back();

m_rowList.emplace_back(node->row);
m_colList.emplace_back(node->col);
m_adjointList.emplace_back(node);
if (node->args[0] != nullptr) {
// Constants (expressions with no arguments) are skipped because they
Expand Down Expand Up @@ -158,12 +159,13 @@ class ExpressionGraph {

/**
* Updates the adjoints in the expression graph (computes the gradient) then
* appends the adjoints of wrt to the sparse matrix triplets via a callback.
* appends the adjoints of wrt to the sparse matrix triplets.
*
* @param func A function that takes two arguments: an int for the gradient
* row, and a double for the adjoint (gradient value).
* @param triplets The sparse matrix triplets.
* @param row The row of wrt.
*/
void AppendAdjointTriplets(function_ref<void(int row, double adjoint)> func) {
void AppendAdjointTriplets(small_vector<Eigen::Triplet<double>>& triplets,
int row) const {
// Read docs/algorithms.md#Reverse_accumulation_automatic_differentiation
// for background on reverse accumulation automatic differentiation.

Expand All @@ -178,8 +180,8 @@ class ExpressionGraph {
// multiplied by dy/dx. If there are multiple "paths" from the root node to
// variable; the variable's adjoint is the sum of each path's adjoint
// contribution.
for (size_t col = 0; col < m_adjointList.size(); ++col) {
auto& node = m_adjointList[col];
for (size_t i = 0; i < m_adjointList.size(); ++i) {
auto& node = m_adjointList[i];
auto& lhs = node->args[0];
auto& rhs = node->args[1];

Expand All @@ -196,9 +198,8 @@ class ExpressionGraph {
}

// Append adjoints of wrt to sparse matrix triplets
int row = m_rowList[col];
if (row != -1) {
func(row, node->adjoint);
if (const int& col = m_colList[i]; col != -1 && node->adjoint != 0.0) {
triplets.emplace_back(row, col, node->adjoint);
}
}

Expand All @@ -209,8 +210,8 @@ class ExpressionGraph {
}

private:
// List that maps nodes to their respective row.
small_vector<int> m_rowList;
// List that maps nodes to their respective column
small_vector<int> m_colList;

// List for updating adjoints
small_vector<Expression*> m_adjointList;
Expand Down
18 changes: 8 additions & 10 deletions include/sleipnir/autodiff/Jacobian.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ class SLEIPNIR_DLLEXPORT Jacobian {
: m_variables{std::move(variables)}, m_wrt{std::move(wrt)} {
m_profiler.StartSetup();

for (int row = 0; row < m_wrt.Rows(); ++row) {
m_wrt(row).expr->row = row;
// Initialize column each expression's adjoint occupies in the Jacobian
for (size_t col = 0; col < m_wrt.size(); ++col) {
m_wrt(col).expr->col = col;
}

for (auto& variable : m_variables) {
Expand All @@ -48,18 +49,17 @@ class SLEIPNIR_DLLEXPORT Jacobian {
// If the row is linear, compute its gradient once here and cache its
// triplets. Constant rows are ignored because their gradients have no
// nonzero triplets.
m_graphs[row].AppendAdjointTriplets([&](int col, double adjoint) {
m_cachedTriplets.emplace_back(row, col, adjoint);
});
m_graphs[row].AppendAdjointTriplets(m_cachedTriplets, row);
} else if (m_variables(row).Type() > ExpressionType::kLinear) {
// If the row is quadratic or nonlinear, add it to the list of nonlinear
// rows to be recomputed in Value().
m_nonlinearRows.emplace_back(row);
}
}

for (int row = 0; row < m_wrt.Rows(); ++row) {
m_wrt(row).expr->row = -1;
// Reset col to -1
for (auto& node : m_wrt) {
node.expr->col = -1;
}

if (m_nonlinearRows.empty()) {
Expand Down Expand Up @@ -114,9 +114,7 @@ class SLEIPNIR_DLLEXPORT Jacobian {

// Compute each nonlinear row of the Jacobian
for (int row : m_nonlinearRows) {
m_graphs[row].AppendAdjointTriplets([&](int col, double adjoint) {
triplets.emplace_back(row, col, adjoint);
});
m_graphs[row].AppendAdjointTriplets(triplets, row);
}

if (triplets.size() > 0) {
Expand Down
19 changes: 9 additions & 10 deletions jormungandr/cpp/Docstrings.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2186,12 +2186,13 @@ static const char *__doc_sleipnir_detail_ExpressionGraph_2 = R"doc()doc";

static const char *__doc_sleipnir_detail_ExpressionGraph_AppendAdjointTriplets =
R"doc(Updates the adjoints in the expression graph (computes the gradient)
then appends the adjoints of wrt to the sparse matrix triplets via a
callback.
then appends the adjoints of wrt to the sparse matrix triplets.
Parameter ``func``:
A function that takes two arguments: an int for the gradient row,
and a double for the adjoint (gradient value).)doc";
Parameter ``triplets``:
The sparse matrix triplets.
Parameter ``row``:
The row of wrt.)doc";

static const char *__doc_sleipnir_detail_ExpressionGraph_ExpressionGraph =
R"doc(Generates the deduplicated computational graph for the given
Expand All @@ -2212,7 +2213,7 @@ values of their dependent nodes.)doc";

static const char *__doc_sleipnir_detail_ExpressionGraph_m_adjointList = R"doc()doc";

static const char *__doc_sleipnir_detail_ExpressionGraph_m_rowList = R"doc()doc";
static const char *__doc_sleipnir_detail_ExpressionGraph_m_colList = R"doc()doc";

static const char *__doc_sleipnir_detail_ExpressionGraph_m_valueList = R"doc()doc";

Expand Down Expand Up @@ -2316,16 +2317,14 @@ tree generation.)doc";

static const char *__doc_sleipnir_detail_Expression_args = R"doc(Expression arguments.)doc";

static const char *__doc_sleipnir_detail_Expression_col = R"doc(This expression's column in a Jacobian, or -1 otherwise.)doc";

static const char *__doc_sleipnir_detail_Expression_duplications =
R"doc(Tracks the number of instances of this expression yet to be
encountered in an expression tree.)doc";

static const char *__doc_sleipnir_detail_Expression_refCount = R"doc(Reference count for intrusive shared pointer.)doc";

static const char *__doc_sleipnir_detail_Expression_row =
R"doc(This expression's row in wrt for autodiff gradient, Jacobian, or
Hessian. This is -1 if the expression isn't in wrt.)doc";

static const char *__doc_sleipnir_detail_Expression_value = R"doc(The value of the expression node.)doc";

static const char *__doc_sleipnir_detail_HypotExpression = R"doc()doc";
Expand Down
14 changes: 3 additions & 11 deletions jormungandr/test/control/ocp_solver_cart_pole_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import math
import platform

import numpy as np
import pytest
Expand Down Expand Up @@ -77,16 +76,9 @@ def each(x: VariableMatrix, u: VariableMatrix):
assert status.equality_constraint_type == ExpressionType.NONLINEAR
assert status.inequality_constraint_type == ExpressionType.LINEAR

if platform.system() == "Darwin" and platform.machine() == "arm64":
# FIXME: Fails on macOS arm64 with "feasibility restoration failed"
assert (
status.exit_condition == SolverExitCondition.FEASIBILITY_RESTORATION_FAILED
)
return
else:
# FIXME: Fails on other platforms with "locally infeasible"
assert status.exit_condition == SolverExitCondition.LOCALLY_INFEASIBLE
return
# FIXME: Fails on other platforms with "locally infeasible"
assert status.exit_condition == SolverExitCondition.LOCALLY_INFEASIBLE
return

# Verify initial state
assert X.value(0, 0) == pytest.approx(x_initial[0, 0], abs=1e-8)
Expand Down
7 changes: 0 additions & 7 deletions test/src/control/OCPSolverTest_CartPole.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,17 +82,10 @@ TEST_CASE("OCPSolver - Cart-pole", "[OCPSolver]") {
CHECK(status.equalityConstraintType == sleipnir::ExpressionType::kNonlinear);
CHECK(status.inequalityConstraintType == sleipnir::ExpressionType::kLinear);

#if defined(__APPLE__) && defined(__aarch64__)
// FIXME: Fails on macOS arm64 with "feasibility restoration failed"
CHECK(status.exitCondition ==
sleipnir::SolverExitCondition::kFeasibilityRestorationFailed);
SKIP("Fails on macOS arm64 with \"feasibility restoration failed\"");
#else
// FIXME: Fails on other platforms with "locally infeasible"
CHECK(status.exitCondition ==
sleipnir::SolverExitCondition::kLocallyInfeasible);
SKIP("Fails with \"locally infeasible\"");
#endif

// Verify initial state
CHECK(X.Value(0, 0) == Catch::Approx(x_initial(0)).margin(1e-8));
Expand Down

0 comments on commit 06bcae0

Please sign in to comment.