Adding __init__ arguments to documentation for descents (#109)
* fix typo

* fix typo

* fix typo

* fix typo

* document fix for rewrite functions, such that closure can be checked in implicit_jvp

* add __init__ method to documentation for descents

* correct a typo and add a sentence

* fix typos

* adding a wrapper such that converted fields now inherit the correct type

---------

Co-authored-by: Johanna Haffner <johanna.haffner@bsse.ethz.ch>
johannahaffner and Johanna Haffner authored Feb 9, 2025
1 parent 61d68b1 commit 0fcce03
Showing 7 changed files with 23 additions and 17 deletions.
12 changes: 6 additions & 6 deletions docs/api/searches/descents.md
@@ -10,39 +10,39 @@
 ::: optimistix.SteepestDescent
     selection:
         members:
-            false
+            - __init__
 
 ---
 
 ::: optimistix.NonlinearCGDescent
     selection:
         members:
-            false
+            - __init__
 
 ---
 
 ::: optimistix.NewtonDescent
     selection:
         members:
-            false
+            - __init__
 
 ---
 
 ::: optimistix.DampedNewtonDescent
     selection:
         members:
-            false
+            - __init__
 
 ---
 
 ::: optimistix.IndirectDampedNewtonDescent
     selection:
         members:
-            false
+            - __init__
 
 ---
 
 ::: optimistix.DoglegDescent
     selection:
         members:
-            false
+            - __init__
2 changes: 1 addition & 1 deletion optimistix/_fixed_point.py
@@ -101,7 +101,7 @@ def fixed_point(
         an error. If `False` then the returned solution object will have a `result`
         field indicating whether any failures occured. (See [`optimistix.Solution`][].)
         Keyword only argument.
-    - `tags`: Lineax [tags](https://docs.kidger.site/lineax/api/tags/) describing the
+    - `tags`: Lineax [tags](https://docs.kidger.site/lineax/api/tags/) describing
         any structure of the Jacobian of `y -> fn(y, args) - y` with respect to y. (That
         is, the structure of the matrix `dfn/dy - I`.) Used with
         [`optimistix.ImplicitAdjoint`][] to implement the implicit function theorem as
2 changes: 1 addition & 1 deletion optimistix/_least_squares.py
@@ -87,7 +87,7 @@ def least_squares(
         an error. If `False` then the returned solution object will have a `result`
         field indicating whether any failures occured. (See [`optimistix.Solution`][].)
         Keyword only argument.
-    - `tags`: Lineax [tags](https://docs.kidger.site/lineax/api/tags/) describing the
+    - `tags`: Lineax [tags](https://docs.kidger.site/lineax/api/tags/) describing
         any structure of the Hessian of `y -> sum(fn(y, args)**2)` with respect to y.
         Used with [`optimistix.ImplicitAdjoint`][] to implement the implicit function
         theorem as efficiently as possible. Keyword only argument.
5 changes: 3 additions & 2 deletions optimistix/_minimise.py
@@ -30,7 +30,8 @@ def min_no_aux(x):
     return jax.grad(min_no_aux)(minimum)
 
 
-# Keep `optx.implicit_jvp` is happy.
+# Keep `optx.implicit_jvp` happy.
+# https://github.com/patrick-kidger/optimistix/issues/102#event-15786001854
 if _rewrite_fn.__globals__["__name__"].startswith("jaxtyping"):
     _rewrite_fn = _rewrite_fn.__wrapped__  # pyright: ignore[reportFunctionMemberAccess]
 
@@ -75,7 +76,7 @@ def minimise(
         an error. If `False` then the returned solution object will have a `result`
         field indicating whether any failures occured. (See [`optimistix.Solution`][].)
         Keyword only argument.
-    - `tags`: Lineax [tags](https://docs.kidger.site/lineax/api/tags/) describing the
+    - `tags`: Lineax [tags](https://docs.kidger.site/lineax/api/tags/) describing
         any structure of the Hessian of `fn` with respect to `y`. Used with
         [`optimistix.ImplicitAdjoint`][] to implement the implicit function theorem as
         efficiently as possible. Keyword only argument.
2 changes: 1 addition & 1 deletion optimistix/_root_find.py
@@ -163,7 +163,7 @@ def root_find(
         an error. If `False` then the returned solution object will have a `result`
         field indicating whether any failures occured. (See [`optimistix.Solution`][].)
         Keyword only argument.
-    - `tags`: Lineax [tags](https://docs.kidger.site/lineax/api/tags/) describing the
+    - `tags`: Lineax [tags](https://docs.kidger.site/lineax/api/tags/) describing
         any structure of the Jacobian of `fn` with respect to `y`. Used with some
         solvers (e.g. [`optimistix.Newton`][]), and with some adjoint methods (e.g.
         [`optimistix.ImplicitAdjoint`][]) to improve the efficiency of linear solves.
6 changes: 5 additions & 1 deletion optimistix/_solver/learning_rate.py
@@ -9,10 +9,14 @@
 from .._solution import RESULTS
 
 
+def _typed_asarray(x: ScalarLike) -> Array:
+    return jnp.asarray(x)
+
+
 class LearningRate(AbstractSearch[Y, FunctionInfo, FunctionInfo, None], strict=True):
     """Move downhill by taking a step of the fixed size `learning_rate`."""
 
-    learning_rate: ScalarLike = eqx.field(converter=jnp.asarray)
+    learning_rate: ScalarLike = eqx.field(converter=_typed_asarray)
 
     def init(self, y: Y, f_info_struct: FunctionInfo) -> None:
         return None
11 changes: 6 additions & 5 deletions optimistix/_solver/nonlinear_cg.py
@@ -164,11 +164,12 @@ def step(
 NonlinearCGDescent.__init__.__doc__ = """**Arguments:**
 
 - `method`: A callable `method(vector, vector_prev, diff_prev)` describing how to
-    calculate the beta parameter of nonlinear CG. Each of these inputs has the meaning
-    described above. The "beta parameter" is the sake as can be described as e.g. the
-    β_n value
-    [on Wikipedia](https://en.wikipedia.org/wiki/Nonlinear_conjugate_gradient_method).
-    In practice Optimistix includes four built-in methods:
+    calculate the beta parameter of nonlinear CG. Nonlinear CG uses the previous search
+    direction, scaled by beta, and subtracts the gradient to find the next search
+    direction. This parameter, in the nonlinear case, is the same as the parameter β_n
+    described e.g. [on Wikipedia](https://en.wikipedia.org/wiki/Nonlinear_conjugate_gradient_method)
+    for the linear case.
+    Defaults to `polak_ribiere`. Optimistix includes four built-in methods:
     [`optimistix.polak_ribiere`][], [`optimistix.fletcher_reeves`][],
     [`optimistix.hestenes_stiefel`][], and [`optimistix.dai_yuan`][].
 """
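The rewritten docstring describes the nonlinear CG update: the next search direction is the previous direction scaled by beta, minus the gradient, i.e. d_{n+1} = -g_{n+1} + β_n d_n. A small sketch with plain Python lists, using the Polak–Ribière formula for β_n from the linked Wikipedia article (the names here are illustrative, not the Optimistix API):

```python
def polak_ribiere_beta(grad, grad_prev):
    # Polak-Ribiere: beta = g.(g - g_prev) / (g_prev.g_prev), commonly
    # clipped at zero so the direction stays a descent direction.
    num = sum(g * (g - gp) for g, gp in zip(grad, grad_prev))
    den = sum(gp * gp for gp in grad_prev)
    return max(0.0, num / den)


def next_direction(grad, grad_prev, direction_prev):
    # d_{n+1} = -g_{n+1} + beta * d_n: previous direction scaled by beta,
    # minus the current gradient.
    beta = polak_ribiere_beta(grad, grad_prev)
    return [-g + beta * d for g, d in zip(grad, direction_prev)]


g_prev = [1.0, 0.0]
g = [0.2, 0.8]
d_prev = [-1.0, 0.0]
print(next_direction(g, g_prev, d_prev))  # roughly [-0.68, -0.8]
```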
