
Using OptaxMinimiser results in AttributeError #21

Closed
hkortier opened this issue Nov 1, 2023 · 5 comments
Labels
question User queries

Comments

@hkortier

hkortier commented Nov 1, 2023

Using OptaxMinimiser as the solver raises "AttributeError: '_Closure' object has no attribute 'init'", whereas the BFGS solver runs without errors.

The objective function uses a custom Equinox pytree.

@patrick-kidger
Owner

Can you provide an MWE (minimal working example)?

@hkortier
Author

Sorry for the very late reply, but this is still relevant. Here is a single-shooting example:

import jax
import jax.numpy as jnp

import equinox as eqx  
import optimistix as optx
import optax
import diffrax as dfx  

from watermark import watermark

jax.config.update("jax_enable_x64", True)

def c_func(mach):
    return jnp.select([mach < 0.4, mach < 0.8, mach < 1.2], 
                      [0.1, .1 * (mach - 0.4) / 0.4 + 0.1, 0.25 * (mach - 0.8) / 0.4 + 0.25], default=.5)

class CannonODE(eqx.Module):
    c: float 
    g: float 

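    def __call__(self, t, y, args):
        v = y[1]  # y = (position, velocity)
        T, = args  # flight time: rescales t in [0, 1] to physical time
        speed = jnp.linalg.norm(v)

        mach = speed / 340.0  # speed of sound taken as 340 m/s
        c = c_func(mach)  # Mach-dependent drag coefficient

        dp = T * v
        dv = T * jnp.array([-c * v[0] * speed,
                            -c * v[1] * speed - self.g])

        return (dp, dv)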
    
class CannonTrajectory(eqx.Module):
    ode: CannonODE

    def __init__(self, ode):
        self.ode = ode

    def __call__(self, parameter, saveat: dfx.SaveAt):
        QE, v0, T = parameter
        y0 = (jnp.array([0.0, 0.0]) , jnp.array([v0*jnp.cos(QE), v0*jnp.sin(QE)]))

        term = dfx.ODETerm(self.ode)
        stepsize_controller = dfx.PIDController(rtol=1e-6, atol=1e-6)
        solver = dfx.Tsit5()
        t0 = saveat.subs.ts[0]
        t1 = saveat.subs.ts[-1]
        dt0 = 0.01

        sol = dfx.diffeqsolve(
            term,
            solver,
            t0,
            t1,
            dt0,
            y0,
            args=(T,),
            saveat=saveat,
            stepsize_controller=stepsize_controller,
            # DirectAdjoint supports forward-mode as well as reverse-mode autodiff
            # (forward mode is what Levenberg--Marquardt would use)
            adjoint=dfx.DirectAdjoint(),
            max_steps=1024,
        )
        return sol

def residuals(parameter, args):
    traj, target = args
    saveat = dfx.SaveAt(ts=jnp.array([0., 1.]))
    pred_values = traj(parameter, saveat).ys[0][-1,:]
    return target - pred_values

def residuals_min(parameter, args):
    res = residuals(parameter, args)
    return jnp.sqrt(jnp.dot(res, res))

def main(target):
    v0 = 200.0
    QE0 = 0.01  # alternatively jnp.pi/4
    T0 = 2.0

    ode = CannonODE(c=0.6, g=9.81)
    traj = CannonTrajectory(ode)

    init_parameter = jnp.array([QE0, v0, T0])

    # This line triggers the AttributeError: optax.adabelief is passed without being called
    solver = optx.OptaxMinimiser(optax.adabelief, rtol=1e-8, atol=1e-8)
    res = optx.minimise(residuals_min, solver, init_parameter, max_steps=128, throw=False, args=(traj, target))
    
    return res, traj, target

if __name__ == "__main__":
    print(watermark(packages="jax,jaxlib,optimistix,equinox,diffrax,optax"))
    target = jnp.array([100., 0.])
    res, traj, target = main(target)

Output:

jax       : 0.4.20
jaxlib    : 0.4.14
optimistix: 0.0.5
equinox   : 0.11.2
diffrax   : 0.4.1
optax     : 0.1.7

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
.....
  File "/Users/hkortier/venvs/diffrax/lib/python3.10/site-packages/optimistix/_solver/optax.py", line 90, in init
    opt_state = self.optim.init(y)
AttributeError: '_Closure' object has no attribute 'init'

@patrick-kidger
Owner

Ah! You want optax.adabelief(...), not just optax.adabelief.

patrick-kidger added the "question" (User queries) label on Dec 18, 2023
@hkortier
Author

Ah, thanks for your prompt response! I took this line from https://docs.kidger.site/optimistix/how-to-choose/:
optimistix.OptaxMinimiser(optax.adabelief, learning_rate=1e-3, rtol=1e-8, atol=1e-8)
However, the correct syntax is given further down that page.

@patrick-kidger
Owner

patrick-kidger commented Dec 19, 2023

Ah, thank you for pointing out the mistake! This should now be fixed in #29, so I'm closing this.
