Skip to content

Commit

Permalink
Use function_ref where possible (#501)
Browse files Browse the repository at this point in the history
  • Loading branch information
calcmogul authored Apr 13, 2024
1 parent a546475 commit f5eb0da
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 20 deletions.
9 changes: 5 additions & 4 deletions benchmarks/scalability/Util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
#include <chrono>
#include <concepts>
#include <fstream>
#include <functional>
#include <span>
#include <string>
#include <string_view>

#include <casadi/casadi.hpp>
#include <fmt/core.h>
#include <sleipnir/optimization/OptimizationProblem.hpp>
#include <sleipnir/util/FunctionRef.hpp>

/**
* Converts std::chrono::duration to a number of milliseconds rounded to three
Expand All @@ -38,8 +38,9 @@ constexpr double ToMilliseconds(
* solves it.
*/
template <typename Problem>
void RunBenchmark(std::ofstream& results, std::function<Problem()> setup,
std::function<void(Problem&)> solve) {
void RunBenchmark(std::ofstream& results,
sleipnir::function_ref<Problem()> setup,
sleipnir::function_ref<void(Problem&)> solve) {
// Record setup time
auto setupStartTime = std::chrono::system_clock::now();
auto problem = setup();
Expand Down Expand Up @@ -85,7 +86,7 @@ template <typename Problem>
int RunBenchmarksAndLog(
std::string_view filename, bool diagnostics,
std::chrono::duration<double> T, std::span<int> sampleSizesToTest,
std::function<Problem(std::chrono::duration<double>, int)> setup) {
sleipnir::function_ref<Problem(std::chrono::duration<double>, int)> setup) {
std::ofstream results{std::string{filename}};
if (!results.is_open()) {
return 1;
Expand Down
10 changes: 5 additions & 5 deletions include/sleipnir/control/OCPSolver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
#include <stdint.h>

#include <chrono>
#include <functional>
#include <utility>

#include "sleipnir/autodiff/VariableMatrix.hpp"
#include "sleipnir/optimization/OptimizationProblem.hpp"
#include "sleipnir/util/Assert.hpp"
#include "sleipnir/util/Concepts.hpp"
#include "sleipnir/util/FunctionRef.hpp"
#include "sleipnir/util/SymbolExports.hpp"

namespace sleipnir {
Expand All @@ -25,8 +25,8 @@ namespace sleipnir {
* - State transition: xₖ₊₁ = f(t, xₖ, uₖ, dt)
*/
using DynamicsFunction =
std::function<VariableMatrix(const Variable&, const VariableMatrix&,
const VariableMatrix&, const Variable&)>;
function_ref<VariableMatrix(const Variable&, const VariableMatrix&,
const VariableMatrix&, const Variable&)>;

/**
* Performs 4th order Runge-Kutta integration of dx/dt = f(t, x, u) for dt.
Expand Down Expand Up @@ -212,8 +212,8 @@ class SLEIPNIR_DLLEXPORT OCPSolver : public OptimizationProblem {
* vector, u is the input vector, and dt is the timestep duration.
*/
void ForEachStep(
const std::function<void(const Variable&, const VariableMatrix&,
const VariableMatrix&, const Variable&)>&
const function_ref<void(const Variable&, const VariableMatrix&,
const VariableMatrix&, const Variable&)>
callback) {
Variable time = 0.0;

Expand Down
4 changes: 2 additions & 2 deletions include/sleipnir/optimization/solver/InteriorPoint.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

#pragma once

#include <functional>
#include <span>

#include <Eigen/Core>
Expand All @@ -11,6 +10,7 @@
#include "sleipnir/optimization/SolverConfig.hpp"
#include "sleipnir/optimization/SolverIterationInfo.hpp"
#include "sleipnir/optimization/SolverStatus.hpp"
#include "sleipnir/util/FunctionRef.hpp"
#include "sleipnir/util/SymbolExports.hpp"

namespace sleipnir {
Expand Down Expand Up @@ -48,7 +48,7 @@ SLEIPNIR_DLLEXPORT void InteriorPoint(
std::span<Variable> decisionVariables,
std::span<Variable> equalityConstraints,
std::span<Variable> inequalityConstraints, Variable& f,
const std::function<bool(const SolverIterationInfo&)>& callback,
function_ref<bool(const SolverIterationInfo&)> callback,
const SolverConfig& config, bool feasibilityRestoration, Eigen::VectorXd& x,
Eigen::VectorXd& s, SolverStatus* status);

Expand Down
14 changes: 7 additions & 7 deletions src/optimization/solver/InteriorPoint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@

namespace sleipnir {

void InteriorPoint(
std::span<Variable> decisionVariables,
std::span<Variable> equalityConstraints,
std::span<Variable> inequalityConstraints, Variable& f,
const std::function<bool(const SolverIterationInfo&)>& callback,
const SolverConfig& config, bool feasibilityRestoration, Eigen::VectorXd& x,
Eigen::VectorXd& s, SolverStatus* status) {
void InteriorPoint(std::span<Variable> decisionVariables,
std::span<Variable> equalityConstraints,
std::span<Variable> inequalityConstraints, Variable& f,
function_ref<bool(const SolverIterationInfo&)> callback,
const SolverConfig& config, bool feasibilityRestoration,
Eigen::VectorXd& x, Eigen::VectorXd& s,
SolverStatus* status) {
const auto solveStartTime = std::chrono::system_clock::now();

// Map decision variables and constraints to VariableMatrices for Lagrangian
Expand Down
4 changes: 2 additions & 2 deletions src/optimization/solver/util/FeasibilityRestoration.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

#include <algorithm>
#include <cmath>
#include <functional>
#include <iterator>
#include <span>
#include <vector>
Expand All @@ -17,6 +16,7 @@
#include "sleipnir/optimization/SolverIterationInfo.hpp"
#include "sleipnir/optimization/SolverStatus.hpp"
#include "sleipnir/optimization/solver/InteriorPoint.hpp"
#include "sleipnir/util/FunctionRef.hpp"

namespace sleipnir {

Expand All @@ -41,7 +41,7 @@ inline void FeasibilityRestoration(
std::span<Variable> decisionVariables,
std::span<Variable> equalityConstraints,
std::span<Variable> inequalityConstraints, Variable& f, double μ,
const std::function<bool(const SolverIterationInfo&)>& callback,
function_ref<bool(const SolverIterationInfo&)> callback,
const SolverConfig& config, Eigen::VectorXd& x, Eigen::VectorXd& s,
SolverStatus* status) {
// Feasibility restoration
Expand Down

0 comments on commit f5eb0da

Please sign in to comment.