Skip to content

Commit

Permalink
Add more higher-order function parameter names
Browse files Browse the repository at this point in the history
  • Loading branch information
calcmogul committed Jul 20, 2024
1 parent 2b1b98e commit 1a439cd
Show file tree
Hide file tree
Showing 7 changed files with 13 additions and 25 deletions.
2 changes: 1 addition & 1 deletion include/sleipnir/autodiff/VariableBlock.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ class VariableBlock {
* @param unaryOp The unary operator to use for the transform operation.
*/
std::remove_cv_t<Mat> CwiseTransform(
function_ref<Variable(const Variable&)> unaryOp) const {
function_ref<Variable(const Variable& x)> unaryOp) const {
std::remove_cv_t<Mat> result{Rows(), Cols()};

for (int row = 0; row < Rows(); ++row) {
Expand Down
4 changes: 2 additions & 2 deletions include/sleipnir/autodiff/VariableMatrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -736,7 +736,7 @@ class SLEIPNIR_DLLEXPORT VariableMatrix {
* @param unaryOp The unary operator to use for the transform operation.
*/
VariableMatrix CwiseTransform(
function_ref<Variable(const Variable&)> unaryOp) const {
function_ref<Variable(const Variable& x)> unaryOp) const {
VariableMatrix result{Rows(), Cols()};

for (int row = 0; row < Rows(); ++row) {
Expand Down Expand Up @@ -896,7 +896,7 @@ class SLEIPNIR_DLLEXPORT VariableMatrix {
*/
SLEIPNIR_DLLEXPORT inline VariableMatrix CwiseReduce(
const VariableMatrix& lhs, const VariableMatrix& rhs,
function_ref<Variable(const Variable&, const Variable&)> binaryOp) {
function_ref<Variable(const Variable& x, const Variable& y)> binaryOp) {
Assert(lhs.Rows() == rhs.Rows());
Assert(lhs.Rows() == rhs.Rows());

Expand Down
3 changes: 1 addition & 2 deletions jormungandr/cpp/Binders.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ void BindExpressionType(nb::enum_<ExpressionType>& e);

void BindVariable(nb::module_& autodiff, nb::class_<Variable>& cls);
void BindVariableMatrix(nb::module_& autodiff, nb::class_<VariableMatrix>& cls);
void BindVariableBlock(nb::module_& autodiff,
nb::class_<VariableBlock<VariableMatrix>>& cls);
void BindVariableBlock(nb::class_<VariableBlock<VariableMatrix>>& cls);

void BindGradient(nb::class_<Gradient>& cls);
void BindHessian(nb::class_<Hessian>& cls);
Expand Down
2 changes: 1 addition & 1 deletion jormungandr/cpp/Main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ NB_MODULE(_jormungandr, m) {

BindVariable(autodiff, variable);
BindVariableMatrix(autodiff, variable_matrix);
BindVariableBlock(autodiff, variable_block);
BindVariableBlock(variable_block);

// Implicit conversions
variable.def(nb::init_implicit<VariableMatrix>());
Expand Down
16 changes: 3 additions & 13 deletions jormungandr/cpp/autodiff/BindVariableBlock.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ namespace nb = nanobind;

namespace sleipnir {

void BindVariableBlock(nb::module_& autodiff,
nb::class_<VariableBlock<VariableMatrix>>& cls) {
void BindVariableBlock(nb::class_<VariableBlock<VariableMatrix>>& cls) {
using namespace nb::literals;

// VariableBlock-VariableMatrix overloads
Expand Down Expand Up @@ -472,8 +471,8 @@ void BindVariableBlock(nb::module_& autodiff,
cls.def(
"cwise_transform",
[](const VariableBlock<VariableMatrix>& self,
const std::function<Variable(const Variable&)>& func) {
return self.CwiseTransform(func);
const std::function<Variable(const Variable& x)>& unaryOp) {
return self.CwiseTransform(unaryOp);
},
"func"_a, DOC(sleipnir, VariableBlock, CwiseTransform));
cls.def(nb::self == nb::self, "rhs"_a, DOC(sleipnir, operator, eq));
Expand Down Expand Up @@ -547,15 +546,6 @@ void BindVariableBlock(nb::module_& autodiff,
"value_iterator", self.begin(), self.end());
},
nb::keep_alive<0, 1>());

autodiff.def(
"cwise_reduce",
[](const VariableBlock<VariableMatrix>& lhs,
const VariableBlock<VariableMatrix>& rhs,
const std::function<Variable(const Variable&, const Variable&)> func) {
return CwiseReduce(lhs, rhs, func);
},
"lhs"_a, "rhs"_a, "func"_a, DOC(sleipnir, CwiseReduce));
} // NOLINT(readability/fn_size)

} // namespace sleipnir
9 changes: 4 additions & 5 deletions jormungandr/cpp/autodiff/BindVariableMatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -503,8 +503,8 @@ void BindVariableMatrix(nb::module_& autodiff,
cls.def(
"cwise_transform",
[](const VariableMatrix& self,
const std::function<Variable(const Variable&)>& func) {
return self.CwiseTransform(func);
const std::function<Variable(const Variable& x)>& unaryOp) {
return self.CwiseTransform(unaryOp);
},
"func"_a, DOC(sleipnir, VariableMatrix, CwiseTransform));
cls.def_static("zero", &VariableMatrix::Zero, "rows"_a, "cols"_a,
Expand Down Expand Up @@ -653,9 +653,8 @@ void BindVariableMatrix(nb::module_& autodiff,
autodiff.def(
"cwise_reduce",
[](const VariableMatrix& lhs, const VariableMatrix& rhs,
const std::function<Variable(const Variable&, const Variable&)> func) {
return CwiseReduce(lhs, rhs, func);
},
const std::function<Variable(const Variable& x, const Variable& y)>&
binaryOp) { return CwiseReduce(lhs, rhs, binaryOp); },
"lhs"_a, "rhs"_a, "func"_a, DOC(sleipnir, CwiseReduce));

autodiff.def("block",
Expand Down
2 changes: 1 addition & 1 deletion jormungandr/cpp/optimization/BindOptimizationProblem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ Parameter ``spy``:
cls.def(
"callback",
[](OptimizationProblem& self,
std::function<bool(const SolverIterationInfo&)> callback) {
std::function<bool(const SolverIterationInfo& info)> callback) {
self.Callback(std::move(callback));
},
"callback"_a, DOC(sleipnir, OptimizationProblem, Callback, 2));
Expand Down

0 comments on commit 1a439cd

Please sign in to comment.