From 2185ae3f75a53b3beb5074781d935865eb589bbd Mon Sep 17 00:00:00 2001 From: Tyler Veness Date: Thu, 23 Jan 2025 11:23:04 -0800 Subject: [PATCH] Fix crash on free variables (#713) --- .../autodiff/AdjointExpressionGraph.hpp | 5 +++++ include/sleipnir/autodiff/Jacobian.hpp | 4 ++++ jormungandr/cpp/Docstrings.hpp | 5 +++++ .../test/optimization/linear_problem_test.py | 20 +++++++++++++++++++ test/src/optimization/LinearProblemTest.cpp | 20 +++++++++++++++++++ 5 files changed, 54 insertions(+) diff --git a/include/sleipnir/autodiff/AdjointExpressionGraph.hpp b/include/sleipnir/autodiff/AdjointExpressionGraph.hpp index 16c7b853..b3646d96 100644 --- a/include/sleipnir/autodiff/AdjointExpressionGraph.hpp +++ b/include/sleipnir/autodiff/AdjointExpressionGraph.hpp @@ -107,6 +107,11 @@ class AdjointExpressionGraph { /** * Returns the variable's gradient tree. * + * This function lazily allocates variables, so elements of the returned + * VariableMatrix will be empty if the corresponding element of wrt had no + * adjoint. Ensure Variable::expr isn't nullptr before calling member + * functions. + * * @param wrt Variables with respect to which to compute the gradient. */ VariableMatrix GenerateGradientTree(const VariableMatrix& wrt) const { diff --git a/include/sleipnir/autodiff/Jacobian.hpp b/include/sleipnir/autodiff/Jacobian.hpp index de459b0b..8c9febe8 100644 --- a/include/sleipnir/autodiff/Jacobian.hpp +++ b/include/sleipnir/autodiff/Jacobian.hpp @@ -45,6 +45,10 @@ class SLEIPNIR_DLLEXPORT Jacobian { } for (int row = 0; row < m_variables.Rows(); ++row) { + if (m_variables(row).expr == nullptr) { + continue; + } + if (m_variables(row).Type() == ExpressionType::kLinear) { // If the row is linear, compute its gradient once here and cache its // triplets. Constant rows are ignored because their gradients have no diff --git a/jormungandr/cpp/Docstrings.hpp b/jormungandr/cpp/Docstrings.hpp index 6f1ea212..eb6236e6 100644 --- a/jormungandr/cpp/Docstrings.hpp +++ b/jormungandr/cpp/Docstrings.hpp @@ -1989,6 +1989,11 @@ Parameter ``row``: static const char *__doc_sleipnir_detail_AdjointExpressionGraph_GenerateGradientTree = R"doc(Returns the variable's gradient tree. +This function lazily allocates variables, so elements of the returned +VariableMatrix will be empty if the corresponding element of wrt had +no adjoint. Ensure Variable::expr isn't nullptr before calling member +functions. + Parameter ``wrt``: Variables with respect to which to compute the gradient.)doc"; diff --git a/jormungandr/test/optimization/linear_problem_test.py b/jormungandr/test/optimization/linear_problem_test.py index 15873416..6424f6f1 100644 --- a/jormungandr/test/optimization/linear_problem_test.py +++ b/jormungandr/test/optimization/linear_problem_test.py @@ -30,3 +30,23 @@ def test_maximize(): assert x.value() == pytest.approx(375.0, abs=1e-6) assert y.value() == pytest.approx(250.0, abs=1e-6) + + +def test_free_variable(): + problem = OptimizationProblem() + + x = problem.decision_variable(2) + x[0].set_value(1.0) + x[1].set_value(2.0) + + problem.subject_to(x[0] == 0) + + status = problem.solve(diagnostics=True) + + assert status.cost_function_type == ExpressionType.NONE + assert status.equality_constraint_type == ExpressionType.LINEAR + assert status.inequality_constraint_type == ExpressionType.NONE + assert status.exit_condition == SolverExitCondition.SUCCESS + + assert x[0].value() == pytest.approx(0.0, abs=1e-6) + assert x[1].value() == pytest.approx(2.0, abs=1e-6) diff --git a/test/src/optimization/LinearProblemTest.cpp b/test/src/optimization/LinearProblemTest.cpp index 47869f46..6a55bd74 100644 --- a/test/src/optimization/LinearProblemTest.cpp +++ b/test/src/optimization/LinearProblemTest.cpp @@ -33,3 +33,23 @@ TEST_CASE("LinearProblem - Maximize", "[LinearProblem]") { CHECK(x.Value() == Catch::Approx(375.0).margin(1e-6)); CHECK(y.Value() == Catch::Approx(250.0).margin(1e-6)); } + +TEST_CASE("LinearProblem - Free variable", "[LinearProblem]") { + sleipnir::OptimizationProblem problem; + + auto x = problem.DecisionVariable(2); + x(0).SetValue(1.0); + x(1).SetValue(2.0); + + problem.SubjectTo(x(0) == 0); + + auto status = problem.Solve({.diagnostics = true}); + + CHECK(status.costFunctionType == sleipnir::ExpressionType::kNone); + CHECK(status.equalityConstraintType == sleipnir::ExpressionType::kLinear); + CHECK(status.inequalityConstraintType == sleipnir::ExpressionType::kNone); + CHECK(status.exitCondition == sleipnir::SolverExitCondition::kSuccess); + + CHECK(x(0).Value() == Catch::Approx(0.0).margin(1e-6)); + CHECK(x(1).Value() == Catch::Approx(2.0).margin(1e-6)); +}