Skip to content

Commit

Permalink
Fix crash on free variables
Browse files Browse the repository at this point in the history
  • Loading branch information
calcmogul committed Jan 23, 2025
1 parent a363e14 commit 201c8e3
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 0 deletions.
5 changes: 5 additions & 0 deletions include/sleipnir/autodiff/AdjointExpressionGraph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@ class AdjointExpressionGraph {
/**
* Returns the variable's gradient tree.
*
* This function lazily allocates variables, so elements of the returned
* VariableMatrix will be empty if the corresponding element of wrt had no
* adjoint. Ensure Variable::expr isn't nullptr before calling member
* functions.
*
* @param wrt Variables with respect to which to compute the gradient.
*/
VariableMatrix GenerateGradientTree(const VariableMatrix& wrt) const {
Expand Down
4 changes: 4 additions & 0 deletions include/sleipnir/autodiff/Jacobian.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ class SLEIPNIR_DLLEXPORT Jacobian {
}

for (int row = 0; row < m_variables.Rows(); ++row) {
if (m_variables(row).expr == nullptr) {
continue;
}

if (m_variables(row).Type() == ExpressionType::kLinear) {
// If the row is linear, compute its gradient once here and cache its
// triplets. Constant rows are ignored because their gradients have no
Expand Down
5 changes: 5 additions & 0 deletions jormungandr/cpp/Docstrings.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1989,6 +1989,11 @@ Parameter ``row``:
// Auto-generated Python docstring for
// sleipnir::detail::AdjointExpressionGraph::GenerateGradientTree().
// The raw-string contents mirror the C++ Doxygen comment; keep them in sync
// with AdjointExpressionGraph.hpp rather than editing this literal by hand.
static const char *__doc_sleipnir_detail_AdjointExpressionGraph_GenerateGradientTree =
R"doc(Returns the variable's gradient tree.
This function lazily allocates variables, so elements of the returned
VariableMatrix will be empty if the corresponding element of wrt had
no adjoint. Ensure Variable::expr isn't nullptr before calling member
functions.
Parameter ``wrt``:
Variables with respect to which to compute the gradient.)doc";

Expand Down
20 changes: 20 additions & 0 deletions jormungandr/test/optimization/linear_problem_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,23 @@ def test_maximize():

assert x.value() == pytest.approx(375.0, abs=1e-6)
assert y.value() == pytest.approx(250.0, abs=1e-6)


def test_free_variable():
    # Regression test for a crash on free variables: there is no cost
    # function and the single constraint only touches x[0], so x[1] has no
    # adjoint. The solver must not crash, must drive x[0] to 0, and must
    # leave the unconstrained x[1] at its initial value.
    problem = OptimizationProblem()

    decision_vars = problem.decision_variable(2)
    for index, initial in enumerate((1.0, 2.0)):
        decision_vars[index].set_value(initial)

    problem.subject_to(decision_vars[0] == 0)

    status = problem.solve(diagnostics=True)

    # No cost and no inequality constraints; one linear equality constraint.
    assert status.cost_function_type == ExpressionType.NONE
    assert status.equality_constraint_type == ExpressionType.LINEAR
    assert status.inequality_constraint_type == ExpressionType.NONE
    assert status.exit_condition == SolverExitCondition.SUCCESS

    # Constrained variable moved to 0; free variable untouched.
    assert decision_vars[0].value() == pytest.approx(0.0, abs=1e-6)
    assert decision_vars[1].value() == pytest.approx(2.0, abs=1e-6)
20 changes: 20 additions & 0 deletions test/src/optimization/LinearProblemTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,23 @@ TEST_CASE("LinearProblem - Maximize", "[LinearProblem]") {
CHECK(x.Value() == Catch::Approx(375.0).margin(1e-6));
CHECK(y.Value() == Catch::Approx(250.0).margin(1e-6));
}

TEST_CASE("LinearProblem - Free variable", "[LinearProblem]") {
  // Regression test for a crash on free variables: there is no cost function
  // and the single constraint only touches element 0, so element 1 has no
  // adjoint. The solver must not crash, must drive element 0 to 0, and must
  // leave the unconstrained element 1 at its initial value.
  sleipnir::OptimizationProblem problem;

  auto decisionVars = problem.DecisionVariable(2);
  for (int row = 0; row < 2; ++row) {
    decisionVars(row).SetValue(row + 1.0);
  }

  problem.SubjectTo(decisionVars(0) == 0);

  auto status = problem.Solve({.diagnostics = true});

  // No cost and no inequality constraints; one linear equality constraint.
  CHECK(status.costFunctionType == sleipnir::ExpressionType::kNone);
  CHECK(status.equalityConstraintType == sleipnir::ExpressionType::kLinear);
  CHECK(status.inequalityConstraintType == sleipnir::ExpressionType::kNone);
  CHECK(status.exitCondition == sleipnir::SolverExitCondition::kSuccess);

  // Constrained variable moved to 0; free variable untouched.
  CHECK(decisionVars(0).Value() == Catch::Approx(0.0).margin(1e-6));
  CHECK(decisionVars(1).Value() == Catch::Approx(2.0).margin(1e-6));
}

0 comments on commit 201c8e3

Please sign in to comment.