Skip to content

Commit

Permalink
prepare new version
Browse files Browse the repository at this point in the history
  • Loading branch information
gboehl committed Feb 18, 2023
1 parent ebef854 commit 605a9c6
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 11 deletions.
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
autoclass_content = "both"
autodoc_member_order = "groupwise"
autodoc_member_order = "bysource"
latex_use_parts = False


Expand Down
8 changes: 4 additions & 4 deletions grgrjax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@
import jax
import jax.numpy as jnp

from .helpers import jvp_vmap, vjp_vmap, val_and_jacfwd, val_and_jacrev as val_and_jacrev
from .newton import callback_func, newton_jax, newton_jax_jit
from .helpers import *
from .newton import *

__version__ = '0.1.1'
__all__ = ["newton_jax", "newton_jax_jit", "callback_func", "jax_print",
__all__ = ["newton_jax_jit", "newton_jax", "callback_func", "jax_print",
"jvp_vmap", "vjp_vmap", "val_and_jacfwd", "val_and_jacrev", "amax"]


def jax_print(w):
"""Print in jax compiled functions. Wrapper around `jax.experimental.host_callback.id_print`
"""Print in jax compiled functions. Wrapper around `jax.experimental.host_callback.id_print`.
"""
return jax.experimental.host_callback.id_print(w)

Expand Down
11 changes: 5 additions & 6 deletions grgrjax/newton.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@ def callback_func(cnt, err, dampening=None, ltime=None, verbose=True):


@jax.jit
def newton_jax_jit(func, x_init, maxit=30, tol=1e-8, verbose=True):
def newton_jax_jit(func, init, maxit=30, tol=1e-8, verbose=True):
"""Newton method for root finding using automatic differentiation with jax and running in jitted jax.
...
Parameters
----------
func : callable
Function returning (y, jac) where f(x)=y=0 should be found and jac is the jacobian. Must be jittable with jax. Could e.g. be the output of jacfwd_and_val. The function must be a jax.
x_init : array
init : array
Initial values of x
maxit : int, optional
Maximum number of iterations
Expand All @@ -59,11 +59,11 @@ def newton_jax_jit(func, x_init, maxit=30, tol=1e-8, verbose=True):
res: (xopt, (fopt, jacopt), niter, success)
"""
(xi, eps, cnt), _ = jax.lax.while_loop(_newton_cond_func,
_newton_body_func, ((x_init, 1., 0), (func, verbose, maxit, tol)))
_newton_body_func, ((init, 1., 0), (func, verbose, maxit, tol)))
return xi, func(xi), cnt, eps > tol


def perform_checks_newton(res, eps, cnt, jac_is_nan, tol, rtol, maxit):
def _perform_checks_newton(res, eps, cnt, jac_is_nan, tol, rtol, maxit):

if jac_is_nan.any():
res['success'] = False
Expand All @@ -90,7 +90,6 @@ def perform_checks_newton(res, eps, cnt, jac_is_nan, tol, rtol, maxit):

def newton_jax(func, init, maxit=30, tol=1e-8, rtol=None, solver=None, verbose=True, verbose_jac=False):
"""Newton method for root finding using automatic differenciation with jax. The argument `func` must be jittable with jax.
...
Parameters
Expand Down Expand Up @@ -143,7 +142,7 @@ def newton_jax(func, init, maxit=30, tol=1e-8, rtol=None, solver=None, verbose=T
jac_is_nan = jnp.isnan(jacval.data).any() if isinstance(
jacval, ssp._arrays.csr_array) else jnp.isnan(jacval).any()
eps = jnp.abs(fval).max()
if perform_checks_newton(res, eps, cnt, jac_is_nan, tol, rtol, maxit):
if _perform_checks_newton(res, eps, cnt, jac_is_nan, tol, rtol, maxit):
break

# be informative
Expand Down

0 comments on commit 605a9c6

Please sign in to comment.