[WIP] minimization with complex arguments #71

Draft: wants to merge 12 commits into main
3 changes: 2 additions & 1 deletion optimistix/_solver/backtracking.py
@@ -2,6 +2,7 @@
from typing_extensions import TypeAlias

import equinox as eqx
import jax
import jax.numpy as jnp
from equinox.internal import ω
from jaxtyping import Array, Bool, Scalar, ScalarLike
@@ -85,7 +86,7 @@ def step(
)

y_diff = (y_eval**ω - y**ω).ω
predicted_reduction = tree_dot(grad, y_diff)
predicted_reduction = tree_dot(jax.tree_map(jnp.conj, grad), y_diff).real
# Terminate when the Armijo condition is satisfied. That is, `fn(y_eval)`
# must do better than its linear approximation:
# `fn(y_eval) < fn(y) + grad•y_diff`
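For context on the pattern above: conjugating one argument of the dot product and taking the real part yields a real scalar even when `y` is complex, and reduces to the ordinary dot product for real inputs, which is what the Armijo comparison needs. A minimal standalone sketch of that Hermitian inner product over pytrees, written against plain JAX rather than optimistix's own `tree_dot` (whose conjugation behaviour is not shown in this diff):

```python
import jax.numpy as jnp
import jax.tree_util as jtu


def hermitian_tree_dot(tree1, tree2):
    # Sum Re(conj(a) * b) over corresponding leaves: a real scalar for complex
    # leaves, and the usual dot product when everything is real.
    leaves1 = jtu.tree_leaves(tree1)
    leaves2 = jtu.tree_leaves(tree2)
    return sum(jnp.sum(jnp.conj(a) * b).real for a, b in zip(leaves1, leaves2))


grad = {"w": jnp.array([1.0 + 2.0j]), "b": jnp.array([3.0 - 1.0j])}
y_diff = {"w": jnp.array([0.5 - 0.5j]), "b": jnp.array([1.0 + 1.0j])}
print(hermitian_tree_dot(grad, y_diff))  # real-valued, usable in the Armijo test
```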
10 changes: 7 additions & 3 deletions optimistix/_solver/bfgs.py
@@ -50,10 +50,14 @@ def _identity_pytree(pytree: PyTree[Array]) -> lx.PyTreeLinearOperator:
for i2, l2 in enumerate(leaves):
if i1 == i2:
eye_leaves.append(
jnp.eye(jnp.size(l1)).reshape(jnp.shape(l1) + jnp.shape(l2))
jnp.eye(jnp.size(l1), dtype=l1.dtype).reshape(
jnp.shape(l1) + jnp.shape(l2)
)
)
else:
eye_leaves.append(jnp.zeros(jnp.shape(l1) + jnp.shape(l2)))
eye_leaves.append(
jnp.zeros(jnp.shape(l1) + jnp.shape(l2), dtype=l1.dtype)
)

# This has a Lineax positive_semidefinite tag. This is okay because the BFGS update
# preserves positive-definiteness.
@@ -111,7 +115,7 @@ def no_update(hessian, hessian_inv):
# this we jump straight to the line search.
# Likewise we get inner <= eps on convergence, and so again we make no update
# to avoid a division by zero.
inner_nonzero = inner > jnp.finfo(inner.dtype).eps
inner_nonzero = jnp.abs(inner) > jnp.finfo(inner.dtype).eps
hessian, hessian_inv = filter_cond(
inner_nonzero, bfgs_update, no_update, hessian, hessian_inv
)
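A note on the `jnp.abs(inner)` change: complex numbers are not ordered, so the near-zero guard compares the magnitude rather than the value itself. A small illustrative sketch with made-up numbers (not taken from the solver):

```python
import jax.numpy as jnp

inner = jnp.asarray(1e-20 + 1e-20j)  # curvature inner product, possibly complex
# Ordering comparisons are not defined for complex dtypes, so compare |inner|.
inner_nonzero = jnp.abs(inner) > jnp.finfo(inner.dtype).eps
print(inner_nonzero)  # False here, so the BFGS update would be skipped
```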
18 changes: 12 additions & 6 deletions optimistix/_solver/dogleg.py
@@ -2,6 +2,7 @@
from typing import Any, cast, Generic, Union

import equinox as eqx
import jax
import jax.lax as lax
import jax.numpy as jnp
import lineax as lx
@@ -73,22 +74,25 @@ def query(
state: _DoglegDescentState,
) -> _DoglegDescentState:
del state
conj_grad = jax.tree_map(jnp.conj, f_info.grad)
# Compute `denom = grad^T Hess grad.`
if isinstance(f_info, FunctionInfo.EvalGradHessian):
denom = tree_dot(f_info.grad, f_info.hessian.mv(f_info.grad))
elif isinstance(f_info, FunctionInfo.ResidualJac):
# Use Gauss--Newton approximation `Hess ~ J^T J`
denom = sum_squares(f_info.jac.mv(f_info.grad))
denom = sum_squares(f_info.jac.mv(conj_grad))
else:
raise ValueError(
"`DoglegDescent` can only be used with least-squares solvers, or "
"quasi-Newton minimisers which make approximations to the Hessian "
"(like `optx.BFGS(use_inverse=False)`)"
)
denom_nonzero = denom > jnp.finfo(denom.dtype).eps
denom_nonzero = jnp.abs(denom) > jnp.finfo(denom.dtype).eps
safe_denom = jnp.where(denom_nonzero, denom, 1)
# Compute `grad^T grad / (grad^T Hess grad)`
scaling = jnp.where(denom_nonzero, sum_squares(f_info.grad) / safe_denom, 0.0)

with jax.numpy_dtype_promotion("standard"):
scaling = jnp.where(denom_nonzero, sum_squares(conj_grad) / safe_denom, 0.0)
scaling = cast(Array, scaling)

# Downhill towards the bottom of the quadratic basin.
@@ -97,7 +101,8 @@ def query(
newton_norm = self.trust_region_norm(newton_sol)

# Downhill steepest descent.
cauchy = (-scaling * f_info.grad**ω).ω
with jax.numpy_dtype_promotion("standard"):
cauchy = (-scaling * conj_grad**ω).ω
cauchy_norm = self.trust_region_norm(cauchy)

return _DoglegDescentState(
@@ -139,7 +144,8 @@ def interpolate_cauchy_and_newton(cauchy, newton):
"""

def interpolate(t):
return (cauchy**ω + (t - 1) * (newton**ω - cauchy**ω)).ω
with jax.numpy_dtype_promotion("standard"):
return (cauchy**ω + (t - 1) * (newton**ω - cauchy**ω)).ω

# The vast majority of the time we expect users to use `two_norm`,
# ie. the classic, elliptical trust region radius. In this case, we
@@ -152,7 +158,7 @@ def interpolate(t):
# find the value which hits the trust region radius.
if self.trust_region_norm is two_norm:
a = sum_squares((newton**ω - cauchy**ω).ω)
inner_prod = tree_dot(cauchy, (newton**ω - cauchy**ω).ω)
inner_prod = tree_dot(cauchy, (newton**ω - cauchy**ω).ω).real
b = 2 * (inner_prod - a)
c = state.cauchy_norm**2 - 2 * inner_prod + a - scaled_step_size**2
quadratic_1 = jnp.clip(
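The `jax.numpy_dtype_promotion("standard")` blocks temporarily re-enable ordinary dtype promotion so that real-valued scalars (such as `scaling` or the interpolation parameter `t`) can multiply complex pytree leaves. A minimal sketch of the promotion in isolation, assuming nothing about optimistix's own configuration:

```python
import jax
import jax.numpy as jnp

scaling = jnp.asarray(0.5)             # real-valued scaling factor
direction = jnp.asarray([1.0 + 1.0j])  # complex descent direction

# Libraries sometimes run under strict dtype promotion to catch accidental
# promotions; re-entering "standard" promotion allows float * complex to give
# the usual complex result.
with jax.numpy_dtype_promotion("standard"):
    print(scaling * direction)  # a complex array
```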
5 changes: 3 additions & 2 deletions optimistix/_solver/gauss_newton.py
@@ -59,7 +59,8 @@ def newton_step(
value.)
"""
if isinstance(f_info, FunctionInfo.EvalGradHessianInv):
newton = f_info.hessian_inv.mv(f_info.grad)
conj_grad = jax.tree_map(jnp.conj, f_info.grad)
newton = f_info.hessian_inv.mv(conj_grad)
result = RESULTS.successful
else:
if isinstance(f_info, FunctionInfo.EvalGradHessian):
@@ -73,7 +74,7 @@
"Cannot use a Newton descent with a solver that only evaluates the "
"gradient, or only the function itself."
)
out = lx.linear_solve(operator, vector, linear_solver)
out = lx.linear_solve(operator, jax.tree_map(jnp.conj, vector), linear_solver)
newton = out.value
result = RESULTS.promote(out.result)
return newton, result
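To see the conjugated right-hand side in isolation, here is a hypothetical dense example with Lineax (a sketch only; the matrix and gradient below are made up and do not come from the solver):

```python
import jax.numpy as jnp
import lineax as lx

hessian = jnp.array([[2.0, 1.0j], [-1.0j, 3.0]])  # Hermitian stand-in for f_info.hessian
grad = jnp.array([1.0 + 1.0j, 2.0 - 1.0j])

operator = lx.MatrixLinearOperator(hessian)
# Solve hessian @ newton = conj(grad), mirroring the conjugation in the diff.
sol = lx.linear_solve(operator, jnp.conj(grad))
print(sol.value, sol.result)
```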
2 changes: 1 addition & 1 deletion optimistix/_solver/gradient_methods.py
@@ -70,7 +70,7 @@ def query(
)
if self.norm is not None:
grad = (grad**ω / self.norm(grad)).ω
return _SteepestDescentState(grad)
return _SteepestDescentState(jax.tree_map(jnp.conj, grad))

def step(
self, step_size: Scalar, state: _SteepestDescentState
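The steepest-descent state now stores the conjugated gradient because, under JAX's convention for gradients of real-valued functions of complex inputs (which this PR relies on), the descent direction is `-conj(grad)` rather than `-grad`. A small numerical check of that convention, independent of optimistix:

```python
import jax
import jax.numpy as jnp


def f(z):
    return jnp.sum(z * jnp.conj(z)).real  # |z|^2, a real-valued bowl


z = jnp.array([1.0 + 2.0j, -0.5 + 0.5j])
g = jax.grad(f)(z)             # equals 2 * conj(z) under JAX's convention
z_new = z - 0.1 * jnp.conj(g)  # step along -conj(grad), as in the diff
print(f(z), f(z_new))          # the objective decreases: 5.5 -> 3.52
```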
9 changes: 7 additions & 2 deletions optimistix/_solver/levenberg_marquardt.py
@@ -9,6 +9,7 @@
import lineax as lx
from equinox.internal import ω
from jaxtyping import Array, Float, PyTree, Scalar, ScalarLike
from lineax.internal import default_floating_dtype as default_floating_dtype

from .._custom_types import Aux, Out, Y
from .._misc import max_norm, tree_full_like, two_norm
@@ -57,7 +58,11 @@ def damped_newton_step(
lm_param = jnp.where(pred, 1 / safe_step_size, jnp.finfo(step_size).max)
lm_param = cast(Array, lm_param)
if isinstance(f_info, FunctionInfo.EvalGradHessian):
operator = f_info.hessian + lm_param * lx.IdentityLinearOperator(
leaves = jtu.tree_leaves(f_info.hessian.in_structure())
dtype = (
default_floating_dtype() if len(leaves) == 0 else jnp.result_type(*leaves)
)
operator = f_info.hessian + lm_param.astype(dtype) * lx.IdentityLinearOperator(
f_info.hessian.in_structure()
)
vector = f_info.grad
@@ -73,7 +78,7 @@
"provide (approximate) Hessian information."
)
linear_sol = lx.linear_solve(operator, vector, linear_solver, throw=False)
return linear_sol.value, RESULTS.promote(linear_sol.result)
return jax.tree_map(jnp.conj, linear_sol.value), RESULTS.promote(linear_sol.result)


class _DampedNewtonDescentState(eqx.Module, strict=True):
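The added dtype logic makes the Levenberg-Marquardt damping term match the dtype of the Hessian's input structure instead of staying real. A small sketch of the same dtype resolution over pytree leaves, using a plain fallback in place of `lineax.internal.default_floating_dtype` (the names below are illustrative only):

```python
import jax.numpy as jnp
import jax.tree_util as jtu

structure = {"a": jnp.ones(2, dtype=jnp.complex64), "b": jnp.ones(3, dtype=jnp.float32)}
leaves = jtu.tree_leaves(structure)
# Common result type of all leaves, with a float fallback for an empty pytree.
dtype = jnp.result_type(*leaves) if leaves else jnp.float32
lm_param = jnp.asarray(0.1)
print(dtype, lm_param.astype(dtype))  # complex64 damping to match a complex Hessian
```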
16 changes: 11 additions & 5 deletions optimistix/_solver/nonlinear_cg.py
@@ -2,6 +2,7 @@
from typing import Any, cast, Generic, Union

import equinox as eqx
import jax
import jax.numpy as jnp
from equinox.internal import ω
from jaxtyping import Array, PyTree, Scalar
@@ -31,7 +32,9 @@ def polak_ribiere(grad_vector: Y, grad_prev: Y, y_diff_prev: Y) -> Scalar:
# have a gradient. In either case we set β=0 to revert to just gradient descent.
pred = denominator > jnp.finfo(denominator.dtype).eps
safe_denom = jnp.where(pred, denominator, 1)
out = jnp.where(pred, jnp.clip(numerator / safe_denom, min=0), 0)

with jax.numpy_dtype_promotion("standard"):
out = jnp.where(pred, jnp.clip(numerator / safe_denom, min=0), 0)
return cast(Scalar, out)


@@ -67,7 +70,8 @@ def dai_yuan(grad: Y, grad_prev: Y, y_diff_prev: Y) -> Scalar:
# Triggers at initialisation and convergence, as above.
pred = jnp.abs(denominator) > jnp.finfo(denominator.dtype).eps
safe_denom = jnp.where(pred, denominator, 1)
return jnp.where(pred, numerator / safe_denom, 0)
with jax.numpy_dtype_promotion("standard"):
return jnp.where(pred, numerator / safe_denom, 0)


class _NonlinearCGDescentState(eqx.Module, Generic[Y], strict=True):
@@ -141,11 +145,13 @@ def query(
# `state.{grad, y_diff} = 0`, i.e. our previous step hit a local minima, then
# on this next step we'll again just use gradient descent, and stop.
beta = self.method(f_info.grad, state.grad, state.y_diff)
neg_grad = (-(f_info.grad**ω)).ω
nonlinear_cg_direction = (neg_grad**ω + beta * state.y_diff**ω).ω
conj_grad = jax.tree_map(jnp.conj, f_info.grad)
neg_grad = (-(conj_grad**ω)).ω
with jax.numpy_dtype_promotion("standard"):
nonlinear_cg_direction = (neg_grad**ω + beta * state.y_diff**ω).ω
# Check if this is a descent direction. Use gradient descent if it isn't.
y_diff = tree_where(
tree_dot(f_info.grad, nonlinear_cg_direction) < 0,
tree_dot(conj_grad, nonlinear_cg_direction).real < 0,
nonlinear_cg_direction,
neg_grad,
)
4 changes: 3 additions & 1 deletion optimistix/_solver/optax.py
@@ -97,7 +97,9 @@ def step(
("loss" in self.verbose, "Loss", f),
("y" in self.verbose, "y", y),
)
updates, new_opt_state = self.optim.update(grads, state.opt_state)
updates, new_opt_state = self.optim.update(
jax.tree_map(jnp.conj, grads), state.opt_state
)
new_y = eqx.apply_updates(y, updates)
terminate = cauchy_termination(
self.rtol,
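As elsewhere in this PR, the gradients handed to Optax are conjugated first so that the resulting update is a descent direction for complex parameters. A hypothetical end-to-end sketch with a plain SGD transform (the loss and learning rate here are made up for illustration):

```python
import jax
import jax.numpy as jnp
import optax

params = jnp.array([2.0 - 1.0j, 1.0 + 0.0j])
optim = optax.sgd(learning_rate=0.1)
opt_state = optim.init(params)

loss = lambda z: jnp.sum(z * jnp.conj(z)).real
grads = jax.grad(loss)(params)
# Conjugate the gradients before the Optax update, mirroring the diff above.
updates, opt_state = optim.update(jax.tree_map(jnp.conj, grads), opt_state)
params = optax.apply_updates(params, updates)
print(loss(params))  # smaller than the initial loss of 6.0
```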
9 changes: 5 additions & 4 deletions optimistix/_solver/trust_region.py
@@ -3,6 +3,7 @@
from typing_extensions import TypeAlias

import equinox as eqx
import jax
import jax.numpy as jnp
from equinox import AbstractVar
from equinox.internal import ω
@@ -166,9 +167,9 @@ def predict_reduction(
if isinstance(f_info, FunctionInfo.EvalGradHessian):
# Minimisation algorithm. Directly compute the quadratic approximation.
return tree_dot(
y_diff,
jax.tree_map(jnp.conj, f_info.grad),
(f_info.grad**ω + 0.5 * f_info.hessian.mv(y_diff) ** ω).ω,
)
).real
elif isinstance(f_info, FunctionInfo.ResidualJac):
# Least-squares algorithm. So instead of considering fn (which returns the
# residuals), instead consider `0.5*fn(y)^2`, and then apply the logic as
@@ -190,7 +191,7 @@ def predict_reduction(
jacobian_term = sum_squares(
(f_info.jac.mv(y_diff) ** ω + f_info.residual**ω).ω
)
return 0.5 * (jacobian_term - rtr)
return 0.5 * (jacobian_term - rtr).real
else:
raise ValueError(
"Cannot use `ClassicalTrustRegion` with this solver. This is because "
@@ -273,7 +274,7 @@ def predict_reduction(
FunctionInfo.ResidualJac,
),
):
return tree_dot(f_info.grad, y_diff)
return tree_dot(jax.tree_map(jnp.conj, f_info.grad), y_diff).real
else:
raise ValueError(
"Cannot use `LinearTrustRegion` with this solver. This is because "
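The `.real` on the least-squares branch keeps the predicted reduction a real scalar when the residuals are complex. A convention-free sketch of that quantity using explicit squared magnitudes (optimistix's `sum_squares` internals are not shown here, so this is illustrative only):

```python
import jax.numpy as jnp

# Gauss-Newton model: predicted change 0.5 * (|J dy + r|^2 - |r|^2), which is
# real-valued for complex residuals when the norms are sums of |.|^2.
jac = jnp.array([[1.0 + 1.0j, 0.0], [0.5, 2.0 - 1.0j]])
residual = jnp.array([0.3 - 0.2j, -0.1 + 0.4j])
y_diff = jnp.array([-0.05 + 0.02j, 0.01 - 0.03j])

rtr = jnp.sum(jnp.abs(residual) ** 2)
jacobian_term = jnp.sum(jnp.abs(jac @ y_diff + residual) ** 2)
print(0.5 * (jacobian_term - rtr))  # a real scalar predicted reduction
```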
28 changes: 25 additions & 3 deletions tests/helpers.py
@@ -221,13 +221,13 @@ def bowl(tree: PyTree[Array], args: Array):
# Trivial quadratic bowl smoke test for convergence.
(y, _) = jfu.ravel_pytree(tree)
matrix = args
return y.T @ matrix @ y
return y.T.conj() @ matrix @ y


def diagonal_quadratic_bowl(tree: PyTree[Array], args: PyTree[Array]):
# A diagonal quadratic bowl smoke test for convergence.
weight_vector = args
return (ω(tree).call(jnp.square) * (0.1 + weight_vector**ω)).ω
return (ω(tree).call(jnp.square) * (0.1 + weight_vector**ω)).call(jnp.abs).ω


def rosenbrock(tree: PyTree[Array], args: Scalar):
@@ -317,7 +317,7 @@ def loss(model, x, y):

def square_minus_one(x: Array, args: PyTree):
"""A simple ||x||^2 - 1 function."""
return jnp.sum(jnp.square(x)) - 1.0
return jnp.sum(jnp.square(jnp.abs(x))) - 1.0


#
@@ -383,6 +383,16 @@ def get_weights(model):
[jr.normal(key, leaf.shape, leaf.dtype) ** 2 for leaf in leaves]
)

diagonal_bowl_init_complex = (
{"a": (0.05 + 0.01j) * jnp.ones((2, 3, 3), dtype=jnp.complex128)},
((0.01 + 0.05j) * jnp.ones(2, dtype=jnp.complex128)),
)
leaves_complex, treedef_complex = jtu.tree_flatten(diagonal_bowl_init_complex)
key = jr.PRNGKey(17)
diagonal_bowl_args_complex = treedef_complex.unflatten(
[jr.normal(key, leaf.shape, leaf.dtype) ** 2 for leaf in leaves_complex]
)

# neural net args
ffn_data = jnp.linspace(0, 1, 100)[..., None]
ffn_args = (ffn_static, ffn_data)
@@ -394,6 +404,12 @@ def get_weights(model):
diagonal_bowl_init,
diagonal_bowl_args,
),
(
diagonal_quadratic_bowl,
jnp.array(0.0),
diagonal_bowl_init_complex,
diagonal_bowl_args_complex,
),
(
rosenbrock,
jnp.array(0.0),
@@ -463,6 +479,12 @@ def get_weights(model):
),
# Problems with initial value of 0
(square_minus_one, jnp.array(-1.0), jnp.array(1.0), None),
(
square_minus_one,
jnp.array(-1.0),
jnp.array(1.0 + 1.0j, dtype=jnp.complex128),
None,
),
)

# ROOT FIND/FIXED POINT PROBLEMS
6 changes: 4 additions & 2 deletions tests/test_least_squares.py
@@ -67,8 +67,10 @@ def test_least_squares_jvp(getkey, solver, _fn, minimum, init, args):
fn = _fn

dynamic_args, static_args = eqx.partition(args, eqx.is_array)
t_init = jtu.tree_map(lambda x: jr.normal(getkey(), x.shape), init)
t_dynamic_args = jtu.tree_map(lambda x: jr.normal(getkey(), x.shape), dynamic_args)
t_init = jtu.tree_map(lambda x: jr.normal(getkey(), x.shape, dtype=x.dtype), init)
t_dynamic_args = jtu.tree_map(
lambda x: jr.normal(getkey(), x.shape, dtype=x.dtype), dynamic_args
)

def least_squares(x, dynamic_args, *, adjoint):
args = eqx.combine(dynamic_args, static_args)
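The tangent vectors in the JVP tests now inherit the primal dtype; `jax.random.normal` only draws complex samples when given a complex dtype explicitly. A quick illustration:

```python
import jax.numpy as jnp
import jax.random as jr

key = jr.PRNGKey(0)
t_real = jr.normal(key, (3,))                          # float tangents by default
t_complex = jr.normal(key, (3,), dtype=jnp.complex64)  # complex tangents for complex primals
print(t_real.dtype, t_complex.dtype)
```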
23 changes: 14 additions & 9 deletions tests/test_minimise.py
@@ -72,8 +72,10 @@ def test_minimise_jvp(getkey, solver, _fn, minimum, init, args):
fn = _fn

dynamic_args, static_args = eqx.partition(args, eqx.is_array)
t_init = jtu.tree_map(lambda x: jr.normal(getkey(), x.shape), init)
t_dynamic_args = jtu.tree_map(lambda x: jr.normal(getkey(), x.shape), dynamic_args)
t_init = jtu.tree_map(lambda x: jr.normal(getkey(), x.shape, dtype=x.dtype), init)
t_dynamic_args = jtu.tree_map(
lambda x: jr.normal(getkey(), x.shape, dtype=x.dtype), dynamic_args
)

def minimise(x, dynamic_args, *, adjoint):
args = eqx.combine(dynamic_args, static_args)
Expand Down Expand Up @@ -138,18 +140,19 @@ def minimise(x, dynamic_args, *, adjoint):
# assert tree_allclose(t_out2, t_expected_out, atol=atol, rtol=rtol)


@pytest.mark.parametrize("dtype", [jnp.float64, jnp.complex128])
@pytest.mark.parametrize(
"method",
[optx.polak_ribiere, optx.fletcher_reeves, optx.hestenes_stiefel, optx.dai_yuan],
)
def test_nonlinear_cg_methods(method):
def test_nonlinear_cg_methods(method, dtype):
solver = optx.NonlinearCG(rtol=1e-10, atol=1e-10, method=method)

def f(y, _):
A = jnp.array([[2.0, -1.0], [-1.0, 3.0]])
b = jnp.array([-100.0, 5.0])
c = jnp.array(100.0)
return jnp.einsum("ij,i,j", A, y, y) + jnp.dot(b, y) + c
A = jnp.array([[2.0, -1.0], [-1.0, 3.0]], dtype=dtype)
b = jnp.array([-100.0, 5.0], dtype=dtype)
c = jnp.array(100.0, dtype=dtype)
return (jnp.einsum("ij,i,j", A, y, y) + jnp.dot(b, y) + c).real

# Analytic minimum:
# 0 = df/dyk
@@ -158,9 +161,11 @@ def f(y, _):
# => y = -0.5 A^{-1} b
# = [[-0.3, 0.1], [0.1, 0.2]] [-100, 5]
# = [29.5, 9]
y0 = jnp.array([2.0, 3.0])
y0 = jnp.array([2.0, 3.0], dtype=dtype)
sol = optx.minimise(f, solver, y0, max_steps=500)
assert tree_allclose(sol.value, jnp.array([29.5, 9.0]), rtol=1e-5, atol=1e-5)
assert tree_allclose(
sol.value, jnp.array([29.5, 9.0], dtype=dtype), rtol=1e-5, atol=1e-5
)


def test_optax_recompilation():
Expand Down