Skip to content

Commit

Permalink
Add callback overload that returns void
Browse files Browse the repository at this point in the history
Also add a missing SolverExitCondition test whose inputs were discovered
while modifying the "user requested stop" test.
  • Loading branch information
calcmogul committed Dec 14, 2023
1 parent 627340b commit e449957
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 14 deletions.
29 changes: 28 additions & 1 deletion include/sleipnir/optimization/OptimizationProblem.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

#pragma once

#include <concepts>
#include <fstream>
#include <functional>
#include <optional>
#include <type_traits>
#include <utility>
#include <vector>

Expand Down Expand Up @@ -331,10 +333,35 @@ class SLEIPNIR_DLLEXPORT OptimizationProblem {
/**
* Sets a callback to be called at each solver iteration.
*
* The callback for this overload should return void.
*
* @param callback The callback.
*/
template <typename F>
requires std::invocable<F, const SolverIterationInfo&> &&
std::same_as<std::result_of_t<F(const SolverIterationInfo&)>, void>
void Callback(F&& callback) {
m_callback = [=, callback = std::forward<F>(callback)](
const SolverIterationInfo& info) {
callback(info);
return false;
};
}

/**
* Sets a callback to be called at each solver iteration.
*
* The callback for this overload should return bool.
*
* @param callback The callback. Returning true from the callback causes the
* solver to exit early with the solution it has so far.
*/
void Callback(std::function<bool(const SolverIterationInfo&)> callback);
template <typename F>
requires std::invocable<F, const SolverIterationInfo&> &&
std::same_as<std::result_of_t<F(const SolverIterationInfo&)>, bool>
void Callback(F&& callback) {
m_callback = std::forward<F>(callback);
}

private:
// GCC incorrectly applies C++14 rules for const static data members, so an
Expand Down
6 changes: 5 additions & 1 deletion jormungandr/cpp/optimization/BindOptimizationProblem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,11 @@ void BindOptimizationProblem(py::module_& optimization) {

return self.Solve(config);
});
cls.def("callback", &OptimizationProblem::Callback);
cls.def("callback",
[](OptimizationProblem& self,
std::function<bool(const SolverIterationInfo&)> callback) {
self.Callback(std::move(callback));
});
}

} // namespace sleipnir
35 changes: 32 additions & 3 deletions jormungandr/test/optimization/solver_exit_condition_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,28 @@ def test_callback_requested_stop():
problem = OptimizationProblem()

x = problem.decision_variable()
problem.minimize(x)
problem.minimize(x * x)

problem.callback(lambda info: True)
problem.callback(lambda info: None)
status = problem.solve(diagnostics=True)

assert status.cost_function_type == ExpressionType.QUADRATIC
assert status.equality_constraint_type == ExpressionType.NONE
assert status.inequality_constraint_type == ExpressionType.NONE
assert status.exit_condition == SolverExitCondition.SUCCESS

problem.callback(lambda info: False)
status = problem.solve(diagnostics=True)

assert status.cost_function_type == ExpressionType.LINEAR
assert status.cost_function_type == ExpressionType.QUADRATIC
assert status.equality_constraint_type == ExpressionType.NONE
assert status.inequality_constraint_type == ExpressionType.NONE
assert status.exit_condition == SolverExitCondition.SUCCESS

problem.callback(lambda info: True)
status = problem.solve(diagnostics=True)

assert status.cost_function_type == ExpressionType.QUADRATIC
assert status.equality_constraint_type == ExpressionType.NONE
assert status.inequality_constraint_type == ExpressionType.NONE
assert status.exit_condition == SolverExitCondition.CALLBACK_REQUESTED_STOP
Expand Down Expand Up @@ -78,6 +93,20 @@ def test_locally_infeasible_inequality_constraints():
assert status.exit_condition == SolverExitCondition.LOCALLY_INFEASIBLE


def test_diverging_iterates():
problem = OptimizationProblem()

x = problem.decision_variable()
problem.minimize(x)

status = problem.solve(diagnostics=True)

assert status.cost_function_type == ExpressionType.LINEAR
assert status.equality_constraint_type == ExpressionType.NONE
assert status.inequality_constraint_type == ExpressionType.NONE
assert status.exit_condition == SolverExitCondition.DIVERGING_ITERATES


def test_max_iterations_exceeded():
problem = OptimizationProblem()

Expand Down
5 changes: 0 additions & 5 deletions src/optimization/OptimizationProblem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,11 +280,6 @@ SolverStatus OptimizationProblem::Solve(const SolverConfig& config) {
return status;
}

void OptimizationProblem::Callback(
std::function<bool(const SolverIterationInfo&)> callback) {
m_callback = callback;
}

Eigen::VectorXd OptimizationProblem::InteriorPoint(
const Eigen::Ref<const Eigen::VectorXd>& initialGuess,
SolverStatus* status) {
Expand Down
41 changes: 37 additions & 4 deletions test/src/optimization/SolverExitConditionTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,31 @@ TEST(SolverExitConditionTest, CallbackRequestedStop) {
sleipnir::OptimizationProblem problem;

auto x = problem.DecisionVariable();
problem.Minimize(x);

problem.Callback([](const sleipnir::SolverIterationInfo&) { return true; });
problem.Minimize(x * x);

problem.Callback([](const sleipnir::SolverIterationInfo&) {});
auto status =
problem.Solve({.diagnostics = CmdlineArgPresent(kEnableDiagnostics)});

EXPECT_EQ(sleipnir::ExpressionType::kLinear, status.costFunctionType);
EXPECT_EQ(sleipnir::ExpressionType::kQuadratic, status.costFunctionType);
EXPECT_EQ(sleipnir::ExpressionType::kNone, status.equalityConstraintType);
EXPECT_EQ(sleipnir::ExpressionType::kNone, status.inequalityConstraintType);
EXPECT_EQ(sleipnir::SolverExitCondition::kSuccess, status.exitCondition);

problem.Callback([](const sleipnir::SolverIterationInfo&) { return false; });
status =
problem.Solve({.diagnostics = CmdlineArgPresent(kEnableDiagnostics)});

EXPECT_EQ(sleipnir::ExpressionType::kQuadratic, status.costFunctionType);
EXPECT_EQ(sleipnir::ExpressionType::kNone, status.equalityConstraintType);
EXPECT_EQ(sleipnir::ExpressionType::kNone, status.inequalityConstraintType);
EXPECT_EQ(sleipnir::SolverExitCondition::kSuccess, status.exitCondition);

problem.Callback([](const sleipnir::SolverIterationInfo&) { return true; });
status =
problem.Solve({.diagnostics = CmdlineArgPresent(kEnableDiagnostics)});

EXPECT_EQ(sleipnir::ExpressionType::kQuadratic, status.costFunctionType);
EXPECT_EQ(sleipnir::ExpressionType::kNone, status.equalityConstraintType);
EXPECT_EQ(sleipnir::ExpressionType::kNone, status.inequalityConstraintType);
EXPECT_EQ(sleipnir::SolverExitCondition::kCallbackRequestedStop,
Expand Down Expand Up @@ -93,6 +110,22 @@ TEST(SolverExitConditionTest, LocallyInfeasible) {
}
}

TEST(SolverExitConditionTest, DivergingIterates) {
sleipnir::OptimizationProblem problem;

auto x = problem.DecisionVariable();
problem.Minimize(x);

auto status =
problem.Solve({.diagnostics = CmdlineArgPresent(kEnableDiagnostics)});

EXPECT_EQ(sleipnir::ExpressionType::kLinear, status.costFunctionType);
EXPECT_EQ(sleipnir::ExpressionType::kNone, status.equalityConstraintType);
EXPECT_EQ(sleipnir::ExpressionType::kNone, status.inequalityConstraintType);
EXPECT_EQ(sleipnir::SolverExitCondition::kDivergingIterates,
status.exitCondition);
}

TEST(SolverExitConditionTest, MaxIterationsExceeded) {
sleipnir::OptimizationProblem problem;

Expand Down

0 comments on commit e449957

Please sign in to comment.