From 59a8713c052a3c8ebbb685dfebc74a1e7f9589b3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 12 Nov 2023 19:47:01 -0500 Subject: [PATCH 1/8] Krylov Version for Trust Region --- Project.toml | 2 +- src/gaussnewton.jl | 9 ----- src/jacobian.jl | 95 ++++++++++++++++++++++++++++++++++++++++++---- src/trustRegion.jl | 28 ++++++++------ 4 files changed, 106 insertions(+), 28 deletions(-) diff --git a/Project.toml b/Project.toml index 2d8e4b661..2b7abac1b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "NonlinearSolve" uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" authors = ["SciML"] -version = "2.8.2" +version = "2.9.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/gaussnewton.jl b/src/gaussnewton.jl index 012767dcf..2062a3bfd 100644 --- a/src/gaussnewton.jl +++ b/src/gaussnewton.jl @@ -6,10 +6,6 @@ An advanced GaussNewton implementation with support for efficient handling of sp matrices via colored automatic differentiation and preconditioned linear solvers. Designed for large-scale and numerically-difficult nonlinear least squares problems. -!!! note - In most practical situations, users should prefer using `LevenbergMarquardt` instead! It - is a more general extension of `Gauss-Newton` Method. - ### Keyword Arguments - `autodiff`: determines the backend used for the Jacobian. Note that this argument is @@ -33,11 +29,6 @@ for large-scale and numerically-difficult nonlinear least squares problems. - `linesearch`: the line search algorithm to use. Defaults to [`LineSearch()`](@ref), which means that no line search is performed. Algorithms from `LineSearches.jl` can be used here directly, and they will be converted to the correct `LineSearch`. - -!!! warning - - Jacobian-Free version of `GaussNewton` doesn't work yet, and it forces jacobian - construction. This will be fixed in the near future. """ @concrete struct GaussNewton{CJ, AD} <: AbstractNewtonAlgorithm{CJ, AD} ad::AD diff --git a/src/jacobian.jl b/src/jacobian.jl index ac824559b..0362b5734 100644 --- a/src/jacobian.jl +++ b/src/jacobian.jl @@ -54,7 +54,7 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val # NOTE: The deepcopy is needed here since we are using the resid_prototype elsewhere fu = f.resid_prototype === nothing ? (iip ? _mutable_zero(u) : _mutable(f(u, p))) : (iip ? deepcopy(f.resid_prototype) : f.resid_prototype) - if !has_analytic_jac && (linsolve_needs_jac || alg_wants_jac || needsJᵀJ) + if !has_analytic_jac && (linsolve_needs_jac || alg_wants_jac) # || needsJᵀJ) sd = sparsity_detection_alg(f, alg.ad) ad = alg.ad jac_cache = iip ? sparse_jacobian_cache(ad, sd, uf, fu, _maybe_mutable(u, ad)) : @@ -92,9 +92,9 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val du = _mutable_zero(u) if needsJᵀJ - JᵀJ = __init_JᵀJ(J) - # FIXME: This needs to be handled better for JacVec Operator - Jᵀfu = J' * _vec(fu) + # TODO: Pass in `jac_transpose_autodiff` + JᵀJ, Jᵀfu = __init_JᵀJ(J, _vec(fu), uf, u; + jac_autodiff = __get_nonsparse_ad(alg.ad)) end if linsolve_init @@ -120,21 +120,68 @@ function __setup_linsolve(A, b, u, p, alg) nothing)..., weight) return init(linprob, alg.linsolve; alias_A = true, alias_b = true, Pl, Pr) end +__setup_linsolve(A::KrylovJᵀJ, b, u, p, alg) = __setup_linsolve(A.JᵀJ, b, u, p, alg) __get_nonsparse_ad(::AutoSparseForwardDiff) = AutoForwardDiff() __get_nonsparse_ad(::AutoSparseFiniteDiff) = AutoFiniteDiff() __get_nonsparse_ad(::AutoSparseZygote) = AutoZygote() __get_nonsparse_ad(ad) = ad -__init_JᵀJ(J::Number) = zero(J) -__init_JᵀJ(J::AbstractArray) = J' * J -__init_JᵀJ(J::StaticArray) = MArray{Tuple{size(J, 2), size(J, 2)}, eltype(J)}(undef) +__init_JᵀJ(J::Number, args...; kwargs...) = zero(J), zero(J) +function __init_JᵀJ(J::AbstractArray, fu, args...; kwargs...) + JᵀJ = J' * J + Jᵀfu = J' * fu + return JᵀJ, Jᵀfu +end +function __init_JᵀJ(J::StaticArray, fu, args...; kwargs...) + JᵀJ = MArray{Tuple{size(J, 2), size(J, 2)}, eltype(J)}(undef) + return JᵀJ, J' * fu +end +function __init_JᵀJ(J::FunctionOperator, fu, uf, u, args...; + jac_transpose_autodiff = nothing, jac_autodiff = nothing, kwargs...) + autodiff = __concrete_jac_transpose_autodiff(jac_transpose_autodiff, jac_autodiff, uf) + Jᵀ = VecJac(uf, u; autodiff) + JᵀJ_op = SciMLOperators.cache_operator(Jᵀ * J, u) + JᵀJ = KrylovJᵀJ(JᵀJ_op, Jᵀ) + Jᵀfu = Jᵀ * fu + return JᵀJ, Jᵀfu +end + +@concrete struct KrylovJᵀJ + JᵀJ + Jᵀ +end + +SciMLBase.isinplace(JᵀJ::KrylovJᵀJ) = isinplace(JᵀJ.Jᵀ) + +function __concrete_jac_transpose_autodiff(jac_transpose_autodiff, jac_autodiff, uf) + if jac_transpose_autodiff === nothing + if isinplace(uf) + # VecJac can be only FiniteDiff + return AutoFiniteDiff() + else + # Short circuit if we see that FiniteDiff was used for J computation + jac_autodiff isa AutoFiniteDiff && return jac_autodiff + # Check if Zygote is loaded then use Zygote else use FiniteDiff + if haskey(Base.loaded_modules, + Base.PkgId(Base.UUID("e88e6eb3-aa80-5325-afca-941959d7151f"), "Zygote")) + return AutoZygote() + else + return AutoFiniteDiff() + end + end + else + return __get_nonsparse_ad(jac_transpose_autodiff) + end +end __maybe_symmetric(x) = Symmetric(x) __maybe_symmetric(x::Number) = x # LinearSolve with `nothing` doesn't dispatch correctly here __maybe_symmetric(x::StaticArray) = x __maybe_symmetric(x::SparseArrays.AbstractSparseMatrix) = x +__maybe_symmetric(x::SciMLOperators.AbstractSciMLOperator) = x +__maybe_symmetric(x::KrylovJᵀJ) = x ## Special Handling for Scalars function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u::Number, p, @@ -145,3 +192,37 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u::Number, needsJᵀJ && return uf, nothing, u, nothing, nothing, u, u, u return uf, nothing, u, nothing, nothing, u end + +function __update_JᵀJ!(iip::Val, cache, sym::Symbol, J) + return __update_JᵀJ!(iip, cache, sym, getproperty(cache, sym), J) +end +__update_JᵀJ!(::Val{false}, cache, sym::Symbol, _, J) = setproperty!(cache, sym, J' * J) +__update_JᵀJ!(::Val{true}, cache, sym::Symbol, _, J) = mul!(getproperty(cache, sym), J', J) +__update_JᵀJ!(::Val{false}, cache, sym::Symbol, H::KrylovJᵀJ, J) = H +__update_JᵀJ!(::Val{true}, cache, sym::Symbol, H::KrylovJᵀJ, J) = H + +function __update_Jᵀf!(iip::Val, cache, sym1::Symbol, sym2::Symbol, J, fu) + return __update_Jᵀf!(iip, cache, sym1, sym2, getproperty(cache, sym2), J, fu) +end +function __update_Jᵀf!(::Val{false}, cache, sym1::Symbol, sym2::Symbol, _, J, fu) + return setproperty!(cache, sym1, J' * fu) +end +function __update_Jᵀf!(::Val{true}, cache, sym1::Symbol, sym2::Symbol, _, J, fu) + return mul!(getproperty(cache, sym1), J', fu) +end +function __update_Jᵀf!(::Val{false}, cache, sym1::Symbol, sym2::Symbol, H::KrylovJᵀJ, J, fu) + return setproperty!(cache, sym1, H.Jᵀ * fu) +end +function __update_Jᵀf!(::Val{true}, cache, sym1::Symbol, sym2::Symbol, H::KrylovJᵀJ, J, fu) + return mul!(getproperty(cache, sym1), H.Jᵀ, fu) +end + +# Left-Right Multiplication +__lr_mul(::Val, H, g) = dot(g, H, g) +## TODO: Use a cache here to avoid allocations +__lr_mul(::Val{false}, H::KrylovJᵀJ, g) = dot(g, H.JᵀJ, g) +function __lr_mul(::Val{true}, H::KrylovJᵀJ, g) + c = similar(g) + mul!(c, H.JᵀJ, g) + return dot(g, c) +end diff --git a/src/trustRegion.jl b/src/trustRegion.jl index 651591e84..47db955db 100644 --- a/src/trustRegion.jl +++ b/src/trustRegion.jl @@ -239,15 +239,19 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion, fu_prev = zero(fu1) loss = get_loss(fu1) - uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip); - linsolve_kwargs) + # uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip); + # linsolve_kwargs) + uf, _, J, fu2, jac_cache, du, H, g = jacobian_caches(alg, f, u, p, Val(iip); + linsolve_kwargs, linsolve_with_JᵀJ = Val(true), lininit = Val(false)) + linsolve = u isa Number ? nothing : __setup_linsolve(J, fu2, u, p, alg) + u_tmp = zero(u) u_cauchy = zero(u) u_gauss_newton = _mutable_zero(u) loss_new = loss - H = zero(J' * J) - g = _mutable_zero(fu1) + # H = zero(J' * J) + # g = _mutable_zero(fu1) shrink_counter = 0 fu_new = zero(fu1) make_new_J = true @@ -346,8 +350,10 @@ function perform_step!(cache::TrustRegionCache{true}) @unpack make_new_J, J, fu, f, u, p, u_gauss_newton, alg, linsolve = cache if cache.make_new_J jacobian!!(J, cache) - mul!(cache.H, J', J) - mul!(_vec(cache.g), J', _vec(fu)) + __update_JᵀJ!(Val{true}(), cache, :H, J) + # mul!(cache.H, J', J) + __update_Jᵀf!(Val{true}(), cache, :g, :H, J, fu) + # mul!(_vec(cache.g), J', _vec(fu)) cache.stats.njacs += 1 # do not use A = cache.H, b = _vec(cache.g) since it is equivalent @@ -376,8 +382,8 @@ function perform_step!(cache::TrustRegionCache{false}) if make_new_J J = jacobian!!(cache.J, cache) - cache.H = J' * J - cache.g = _restructure(fu, J' * _vec(fu)) + __update_JᵀJ!(Val{false}(), cache, :H, J) + __update_Jᵀf!(Val{false}(), cache, :g, :H, J, fu) cache.stats.njacs += 1 if cache.linsolve === nothing @@ -431,7 +437,7 @@ function trust_region_step!(cache::TrustRegionCache) # Compute the ratio of the actual reduction to the predicted reduction. cache.r = -(loss - cache.loss_new) / - (dot(_vec(du), _vec(g)) + dot(_vec(du), H, _vec(du)) / 2) + (dot(_vec(du), _vec(g)) + __lr_mul(Val(isinplace(cache)), H, _vec(du)) / 2) @unpack r = cache if radius_update_scheme === RadiusUpdateSchemes.Simple @@ -594,7 +600,7 @@ function dogleg!(cache::TrustRegionCache{true}) # Take intersection of steepest descent direction and trust region if Cauchy point lies outside of trust region l_grad = norm(cache.g) # length of the gradient - d_cauchy = l_grad^3 / dot(_vec(cache.g), cache.H, _vec(cache.g)) # distance of the cauchy point from the current iterate + d_cauchy = l_grad^3 / __lr_mul(Val{true}(), cache.H, _vec(cache.g)) # distance of the cauchy point from the current iterate if d_cauchy >= trust_r @. cache.du = -(trust_r / l_grad) * cache.g # step to the end of the trust region return @@ -624,7 +630,7 @@ function dogleg!(cache::TrustRegionCache{false}) ## Take intersection of steepest descent direction and trust region if Cauchy point lies outside of trust region l_grad = norm(cache.g) - d_cauchy = l_grad^3 / dot(_vec(cache.g), cache.H, _vec(cache.g)) # distance of the cauchy point from the current iterate + d_cauchy = l_grad^3 / __lr_mul(Val{false}(), cache.H, _vec(cache.g)) # distance of the cauchy point from the current iterate if d_cauchy > trust_r # cauchy point lies outside of trust region cache.du = -(trust_r / l_grad) * cache.g # step to the end of the trust region return From 8f739fb7498c7bcd300b1f6709966debd0c885ec Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 12 Nov 2023 20:28:34 -0500 Subject: [PATCH 2/8] Some progress on LM --- src/jacobian.jl | 11 +++++------ src/levenberg.jl | 3 ++- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/jacobian.jl b/src/jacobian.jl index 0362b5734..e262f6964 100644 --- a/src/jacobian.jl +++ b/src/jacobian.jl @@ -1,3 +1,7 @@ +@concrete struct KrylovJᵀJ + JᵀJ + Jᵀ +end sparsity_detection_alg(_, _) = NoSparsityDetection() function sparsity_detection_alg(f, ad::AbstractSparseADType) if f.sparsity === nothing @@ -54,7 +58,7 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val # NOTE: The deepcopy is needed here since we are using the resid_prototype elsewhere fu = f.resid_prototype === nothing ? (iip ? _mutable_zero(u) : _mutable(f(u, p))) : (iip ? deepcopy(f.resid_prototype) : f.resid_prototype) - if !has_analytic_jac && (linsolve_needs_jac || alg_wants_jac) # || needsJᵀJ) + if !has_analytic_jac && (linsolve_needs_jac || alg_wants_jac) sd = sparsity_detection_alg(f, alg.ad) ad = alg.ad jac_cache = iip ? sparse_jacobian_cache(ad, sd, uf, fu, _maybe_mutable(u, ad)) : @@ -147,11 +151,6 @@ function __init_JᵀJ(J::FunctionOperator, fu, uf, u, args...; return JᵀJ, Jᵀfu end -@concrete struct KrylovJᵀJ - JᵀJ - Jᵀ -end - SciMLBase.isinplace(JᵀJ::KrylovJᵀJ) = isinplace(JᵀJ.Jᵀ) function __concrete_jac_transpose_autodiff(jac_transpose_autodiff, jac_autodiff, uf) diff --git a/src/levenberg.jl b/src/levenberg.jl index fa3189332..a9b0bf89f 100644 --- a/src/levenberg.jl +++ b/src/levenberg.jl @@ -106,7 +106,8 @@ function LevenbergMarquardt(; concrete_jac = nothing, linsolve = nothing, α_geodesic::Real = 0.75, b_uphill::Real = 1.0, min_damping_D::AbstractFloat = 1e-8, adkwargs...) ad = default_adargs_to_adtype(; adkwargs...) - return LevenbergMarquardt{_unwrap_val(concrete_jac)}(ad, linsolve, precs, + _concrete_jac = ifelse(concrete_jac === nothing, true, concrete_jac) + return LevenbergMarquardt{_unwrap_val(_concrete_jac)}(ad, linsolve, precs, damping_initial, damping_increase_factor, damping_decrease_factor, finite_diff_step_geodesic, α_geodesic, b_uphill, min_damping_D) end From 5052a6e90f44a0f7a6e13275aca3a2c35b5a229c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 13 Nov 2023 12:16:25 -0500 Subject: [PATCH 3/8] Fix StaticArray Case --- src/trustRegion.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/trustRegion.jl b/src/trustRegion.jl index 47db955db..601a4204a 100644 --- a/src/trustRegion.jl +++ b/src/trustRegion.jl @@ -243,7 +243,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion, # linsolve_kwargs) uf, _, J, fu2, jac_cache, du, H, g = jacobian_caches(alg, f, u, p, Val(iip); linsolve_kwargs, linsolve_with_JᵀJ = Val(true), lininit = Val(false)) - linsolve = u isa Number ? nothing : __setup_linsolve(J, fu2, u, p, alg) + linsolve = u isa Number ? nothing : __setup_linsolve(J, fu2, du, p, alg) u_tmp = zero(u) u_cauchy = zero(u) From 77164c75484494f596ab0be7edc19ef76377d2eb Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 13 Nov 2023 13:03:00 -0500 Subject: [PATCH 4/8] Fix matrix resizing --- src/gaussnewton.jl | 8 ++++---- src/jacobian.jl | 10 +++++----- src/trustRegion.jl | 15 +++++---------- 3 files changed, 14 insertions(+), 19 deletions(-) diff --git a/src/gaussnewton.jl b/src/gaussnewton.jl index 2062a3bfd..261a4e1d0 100644 --- a/src/gaussnewton.jl +++ b/src/gaussnewton.jl @@ -113,8 +113,8 @@ function perform_step!(cache::GaussNewtonCache{true}) jacobian!!(J, cache) if JᵀJ !== nothing - __matmul!(JᵀJ, J', J) - __matmul!(Jᵀf, J', fu1) + __update_JᵀJ!(Val{true}(), cache, :JᵀJ, J) + __update_Jᵀf!(Val{true}(), cache, :Jᵀf, :JᵀJ, J, fu1) end # u = u - JᵀJ \ Jᵀfu @@ -151,8 +151,8 @@ function perform_step!(cache::GaussNewtonCache{false}) cache.J = jacobian!!(cache.J, cache) if cache.JᵀJ !== nothing - cache.JᵀJ = cache.J' * cache.J - cache.Jᵀf = cache.J' * fu1 + __update_JᵀJ!(Val{false}(), cache, :JᵀJ, cache.J) + __update_Jᵀf!(Val{false}(), cache, :Jᵀf, :JᵀJ, cache.J, fu1) end # u = u - J \ fu diff --git a/src/jacobian.jl b/src/jacobian.jl index e262f6964..8cd919246 100644 --- a/src/jacobian.jl +++ b/src/jacobian.jl @@ -180,7 +180,7 @@ __maybe_symmetric(x::Number) = x __maybe_symmetric(x::StaticArray) = x __maybe_symmetric(x::SparseArrays.AbstractSparseMatrix) = x __maybe_symmetric(x::SciMLOperators.AbstractSciMLOperator) = x -__maybe_symmetric(x::KrylovJᵀJ) = x +__maybe_symmetric(x::KrylovJᵀJ) = x.JᵀJ ## Special Handling for Scalars function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u::Number, p, @@ -204,16 +204,16 @@ function __update_Jᵀf!(iip::Val, cache, sym1::Symbol, sym2::Symbol, J, fu) return __update_Jᵀf!(iip, cache, sym1, sym2, getproperty(cache, sym2), J, fu) end function __update_Jᵀf!(::Val{false}, cache, sym1::Symbol, sym2::Symbol, _, J, fu) - return setproperty!(cache, sym1, J' * fu) + return setproperty!(cache, sym1, _restructure(getproperty(cache, sym1), J' * fu)) end function __update_Jᵀf!(::Val{true}, cache, sym1::Symbol, sym2::Symbol, _, J, fu) - return mul!(getproperty(cache, sym1), J', fu) + return mul!(_vec(getproperty(cache, sym1)), J', fu) end function __update_Jᵀf!(::Val{false}, cache, sym1::Symbol, sym2::Symbol, H::KrylovJᵀJ, J, fu) - return setproperty!(cache, sym1, H.Jᵀ * fu) + return setproperty!(cache, sym1, _restructure(getproperty(cache, sym1), H.Jᵀ * fu)) end function __update_Jᵀf!(::Val{true}, cache, sym1::Symbol, sym2::Symbol, H::KrylovJᵀJ, J, fu) - return mul!(getproperty(cache, sym1), H.Jᵀ, fu) + return mul!(_vec(getproperty(cache, sym1)), H.Jᵀ, fu) end # Left-Right Multiplication diff --git a/src/trustRegion.jl b/src/trustRegion.jl index 601a4204a..f46451820 100644 --- a/src/trustRegion.jl +++ b/src/trustRegion.jl @@ -239,10 +239,9 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion, fu_prev = zero(fu1) loss = get_loss(fu1) - # uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip); - # linsolve_kwargs) uf, _, J, fu2, jac_cache, du, H, g = jacobian_caches(alg, f, u, p, Val(iip); linsolve_kwargs, linsolve_with_JᵀJ = Val(true), lininit = Val(false)) + g = _restructure(fu1, g) linsolve = u isa Number ? nothing : __setup_linsolve(J, fu2, du, p, alg) u_tmp = zero(u) @@ -250,8 +249,6 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion, u_gauss_newton = _mutable_zero(u) loss_new = loss - # H = zero(J' * J) - # g = _mutable_zero(fu1) shrink_counter = 0 fu_new = zero(fu1) make_new_J = true @@ -351,9 +348,7 @@ function perform_step!(cache::TrustRegionCache{true}) if cache.make_new_J jacobian!!(J, cache) __update_JᵀJ!(Val{true}(), cache, :H, J) - # mul!(cache.H, J', J) - __update_Jᵀf!(Val{true}(), cache, :g, :H, J, fu) - # mul!(_vec(cache.g), J', _vec(fu)) + __update_Jᵀf!(Val{true}(), cache, :g, :H, J, _vec(fu)) cache.stats.njacs += 1 # do not use A = cache.H, b = _vec(cache.g) since it is equivalent @@ -383,7 +378,7 @@ function perform_step!(cache::TrustRegionCache{false}) if make_new_J J = jacobian!!(cache.J, cache) __update_JᵀJ!(Val{false}(), cache, :H, J) - __update_Jᵀf!(Val{false}(), cache, :g, :H, J, fu) + __update_Jᵀf!(Val{false}(), cache, :g, :H, J, _vec(fu)) cache.stats.njacs += 1 if cache.linsolve === nothing @@ -420,8 +415,8 @@ function retrospective_step!(cache::TrustRegionCache) cache.H = J' * J cache.g = J' * fu else - mul!(cache.H, J', J) - mul!(cache.g, J', fu) + __update_JᵀJ!(Val{isinplace(cache)}(), cache, :H, J) + __update_Jᵀf!(Val{isinplace(cache)}(), cache, :g, :H, J, fu) end cache.stats.njacs += 1 @unpack H, g, du = cache From 8f70a4e814bd4c2b8f5d777e5fb41cbc054deb9c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 13 Nov 2023 23:20:11 -0500 Subject: [PATCH 5/8] Use Zygote for LineSearch if loaded --- src/linesearch.jl | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/src/linesearch.jl b/src/linesearch.jl index 598934d03..30e057a2c 100644 --- a/src/linesearch.jl +++ b/src/linesearch.jl @@ -1,5 +1,5 @@ """ - LineSearch(method = Static(), autodiff = AutoFiniteDiff(), alpha = true) + LineSearch(method = nothing, autodiff = nothing, alpha = true) Wrapper over algorithms from [LineSeaches.jl](https://github.com/JuliaNLSolvers/LineSearches.jl/). Allows automatic @@ -13,7 +13,7 @@ differentiation for fast Vector Jacobian Products. - `autodiff`: the automatic differentiation backend to use for the line search. Defaults to `AutoFiniteDiff()`, which means that finite differencing is used to compute the VJP. `AutoZygote()` will be faster in most cases, but it requires `Zygote.jl` to be manually - installed and loaded + installed and loaded. - `alpha`: the initial step size to use. Defaults to `true` (which is equivalent to `1`). """ @concrete struct LineSearch @@ -22,7 +22,7 @@ differentiation for fast Vector Jacobian Products. α end -function LineSearch(; method = nothing, autodiff = AutoFiniteDiff(), alpha = true) +function LineSearch(; method = nothing, autodiff = nothing, alpha = true) return LineSearch(method, autodiff, alpha) end @@ -113,12 +113,21 @@ function LineSearchesJLCache(ls::LineSearch, f::F, u, p, fu1, IIP::Val{iip}) whe g₀ = _mutable_zero(u) - autodiff = if iip && (ls.autodiff isa AutoZygote || ls.autodiff isa AutoSparseZygote) - @warn "Attempting to use Zygote.jl for linesearch on an in-place problem. Falling \ - back to finite differencing." - AutoFiniteDiff() + autodiff = if ls.autodiff === nothing + if !iip && haskey(Base.loaded_modules, + Base.PkgId(Base.UUID("e88e6eb3-aa80-5325-afca-941959d7151f"), "Zygote")) + AutoZygote() + else + AutoFiniteDiff() + end else - ls.autodiff + if iip && (ls.autodiff isa AutoZygote || ls.autodiff isa AutoSparseZygote) + @warn "Attempting to use Zygote.jl for linesearch on an in-place problem. \ + Falling back to finite differencing." + AutoFiniteDiff() + else + ls.autodiff + end end function g!(u, fu) From b79228d89b45317f64eb2cd01609ba3f370b2878 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 21 Nov 2023 14:58:07 -0500 Subject: [PATCH 6/8] Add tests --- Project.toml | 4 +-- src/gaussnewton.jl | 13 +++++++--- src/jacobian.jl | 44 +++++++++++++++++++------------- src/linesearch.jl | 4 +++ src/trustRegion.jl | 17 ++++++++----- src/utils.jl | 6 +++++ test/basictests.jl | 39 ++++++++++++++++++---------- test/nonlinear_least_squares.jl | 45 ++++++++++++++++++++++++++++++--- 8 files changed, 126 insertions(+), 46 deletions(-) diff --git a/Project.toml b/Project.toml index 2b7abac1b..bab6ce5a6 100644 --- a/Project.toml +++ b/Project.toml @@ -50,7 +50,7 @@ FiniteDiff = "2" ForwardDiff = "0.10.3" LeastSquaresOptim = "0.8" LineSearches = "7" -LinearAlgebra = "1.9" +LinearAlgebra = "<0.0.1, 1" LinearSolve = "2.12" NonlinearProblemLibrary = "0.1" PrecompileTools = "1" @@ -59,7 +59,7 @@ Reexport = "0.2, 1" SciMLBase = "2.8.2" SciMLOperators = "0.3" SimpleNonlinearSolve = "0.1.23" -SparseArrays = "1.9" +SparseArrays = "<0.0.1, 1" SparseDiffTools = "2.12" StaticArraysCore = "1.4" UnPack = "1.0" diff --git a/src/gaussnewton.jl b/src/gaussnewton.jl index 261a4e1d0..07b1adaec 100644 --- a/src/gaussnewton.jl +++ b/src/gaussnewton.jl @@ -29,23 +29,30 @@ for large-scale and numerically-difficult nonlinear least squares problems. - `linesearch`: the line search algorithm to use. Defaults to [`LineSearch()`](@ref), which means that no line search is performed. Algorithms from `LineSearches.jl` can be used here directly, and they will be converted to the correct `LineSearch`. + - `vjp_autodiff`: Automatic Differentiation Backend used for vector-jacobian products. + This is applicable if the linear solver doesn't require a concrete jacobian, for eg., + Krylov Methods. Defaults to `nothing`, which means if the problem is out of place and + `Zygote` is loaded then, we use `AutoZygote`. In all other, cases `FiniteDiff` is used. """ @concrete struct GaussNewton{CJ, AD} <: AbstractNewtonAlgorithm{CJ, AD} ad::AD linsolve precs linesearch + vjp_autodiff end function set_ad(alg::GaussNewton{CJ}, ad) where {CJ} - return GaussNewton{CJ}(ad, alg.linsolve, alg.precs, alg.linesearch) + return GaussNewton{CJ}(ad, alg.linsolve, alg.precs, alg.linesearch, alg.vjp_autodiff) end function GaussNewton(; concrete_jac = nothing, linsolve = nothing, - linesearch = LineSearch(), precs = DEFAULT_PRECS, adkwargs...) + linesearch = LineSearch(), precs = DEFAULT_PRECS, vjp_autodiff = nothing, + adkwargs...) ad = default_adargs_to_adtype(; adkwargs...) linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method = linesearch) - return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs, linesearch) + return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs, linesearch, + vjp_autodiff) end @concrete mutable struct GaussNewtonCache{iip} <: AbstractNonlinearSolveCache{iip} diff --git a/src/jacobian.jl b/src/jacobian.jl index 8cd919246..6fef600af 100644 --- a/src/jacobian.jl +++ b/src/jacobian.jl @@ -2,6 +2,9 @@ JᵀJ Jᵀ end + +SciMLBase.isinplace(JᵀJ::KrylovJᵀJ) = isinplace(JᵀJ.Jᵀ) + sparsity_detection_alg(_, _) = NoSparsityDetection() function sparsity_detection_alg(f, ad::AbstractSparseADType) if f.sparsity === nothing @@ -67,12 +70,10 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val jac_cache = nothing end - # FIXME: To properly support needsJᵀJ without Jacobian, we need to implement - # a reverse diff operation with the seed being `Jx`, this is not yet implemented - J = if !(linsolve_needs_jac || alg_wants_jac || needsJᵀJ) + J = if !(linsolve_needs_jac || alg_wants_jac) if f.jvp === nothing # We don't need to construct the Jacobian - JacVec(uf, u; autodiff = __get_nonsparse_ad(alg.ad)) + JacVec(uf, u; fu, autodiff = __get_nonsparse_ad(alg.ad)) else if iip jvp = (_, u, v) -> (du = similar(fu); f.jvp(du, v, u, p); du) @@ -96,9 +97,9 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val du = _mutable_zero(u) if needsJᵀJ - # TODO: Pass in `jac_transpose_autodiff` - JᵀJ, Jᵀfu = __init_JᵀJ(J, _vec(fu), uf, u; - jac_autodiff = __get_nonsparse_ad(alg.ad)) + JᵀJ, Jᵀfu = __init_JᵀJ(J, _vec(fu), uf, u; f, + vjp_autodiff = __get_nonsparse_ad(_getproperty(alg, Val(:vjp_autodiff))), + jvp_autodiff = __get_nonsparse_ad(alg.ad)) end if linsolve_init @@ -141,26 +142,29 @@ function __init_JᵀJ(J::StaticArray, fu, args...; kwargs...) JᵀJ = MArray{Tuple{size(J, 2), size(J, 2)}, eltype(J)}(undef) return JᵀJ, J' * fu end -function __init_JᵀJ(J::FunctionOperator, fu, uf, u, args...; - jac_transpose_autodiff = nothing, jac_autodiff = nothing, kwargs...) - autodiff = __concrete_jac_transpose_autodiff(jac_transpose_autodiff, jac_autodiff, uf) - Jᵀ = VecJac(uf, u; autodiff) +function __init_JᵀJ(J::FunctionOperator, fu, uf, u, args...; f = nothing, + vjp_autodiff = nothing, jvp_autodiff = nothing, kwargs...) + # FIXME: Proper fix to this requires the FunctionOperator patch + if f !== nothing && f.vjp !== nothing + @warn "Currently we don't make use of user provided `jvp`. This is planned to be \ + fixed in the near future." + end + autodiff = __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf) + Jᵀ = VecJac(uf, u; fu, autodiff) JᵀJ_op = SciMLOperators.cache_operator(Jᵀ * J, u) JᵀJ = KrylovJᵀJ(JᵀJ_op, Jᵀ) Jᵀfu = Jᵀ * fu return JᵀJ, Jᵀfu end -SciMLBase.isinplace(JᵀJ::KrylovJᵀJ) = isinplace(JᵀJ.Jᵀ) - -function __concrete_jac_transpose_autodiff(jac_transpose_autodiff, jac_autodiff, uf) - if jac_transpose_autodiff === nothing +function __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf) + if vjp_autodiff === nothing if isinplace(uf) # VecJac can be only FiniteDiff return AutoFiniteDiff() else # Short circuit if we see that FiniteDiff was used for J computation - jac_autodiff isa AutoFiniteDiff && return jac_autodiff + jvp_autodiff isa AutoFiniteDiff && return jvp_autodiff # Check if Zygote is loaded then use Zygote else use FiniteDiff if haskey(Base.loaded_modules, Base.PkgId(Base.UUID("e88e6eb3-aa80-5325-afca-941959d7151f"), "Zygote")) @@ -170,7 +174,13 @@ function __concrete_jac_transpose_autodiff(jac_transpose_autodiff, jac_autodiff, end end else - return __get_nonsparse_ad(jac_transpose_autodiff) + ad = __get_nonsparse_ad(vjp_autodiff) + if isinplace(uf) && ad isa AutoZygote + @warn "Attempting to use Zygote.jl for linesearch on an in-place problem. \ + Falling back to finite differencing." + return AutoFiniteDiff() + end + return ad end end diff --git a/src/linesearch.jl b/src/linesearch.jl index 30e057a2c..c9e87a4cb 100644 --- a/src/linesearch.jl +++ b/src/linesearch.jl @@ -131,6 +131,10 @@ function LineSearchesJLCache(ls::LineSearch, f::F, u, p, fu1, IIP::Val{iip}) whe end function g!(u, fu) + if f.jvp !== nothing + @warn "Currently we don't make use of user provided `jvp` in linesearch. This \ + is planned to be fixed in the near future." maxlog=1 + end op = VecJac(SciMLBase.JacobianWrapper(f, p), u; fu = fu1, autodiff) if iip mul!(g₀, op, fu) diff --git a/src/trustRegion.jl b/src/trustRegion.jl index f46451820..d0149c31e 100644 --- a/src/trustRegion.jl +++ b/src/trustRegion.jl @@ -141,9 +141,12 @@ for large-scale and numerically-difficult nonlinear systems. `expand_threshold < r` (with `r` defined in `shrink_threshold`). Defaults to `2.0`. - `max_shrink_times`: the maximum number of times to shrink the trust region radius in a row, `max_shrink_times` is exceeded, the algorithm returns. Defaults to `32`. + - `vjp_autodiff`: Automatic Differentiation Backend used for vector-jacobian products. + This is applicable if the linear solver doesn't require a concrete jacobian, for eg., + Krylov Methods. Defaults to `nothing`, which means if the problem is out of place and + `Zygote` is loaded then, we use `AutoZygote`. In all other, cases `FiniteDiff` is used. """ -@concrete struct TrustRegion{CJ, AD, MTR} <: - AbstractNewtonAlgorithm{CJ, AD} +@concrete struct TrustRegion{CJ, AD, MTR} <: AbstractNewtonAlgorithm{CJ, AD} ad::AD linsolve precs @@ -156,13 +159,14 @@ for large-scale and numerically-difficult nonlinear systems. shrink_factor::MTR expand_factor::MTR max_shrink_times::Int + vjp_autodiff end function set_ad(alg::TrustRegion{CJ}, ad) where {CJ} return TrustRegion{CJ}(ad, alg.linsolve, alg.precs, alg.radius_update_scheme, alg.max_trust_radius, alg.initial_trust_radius, alg.step_threshold, alg.shrink_threshold, alg.expand_threshold, alg.shrink_factor, alg.expand_factor, - alg.max_shrink_times) + alg.max_shrink_times, alg.vjp_autodiff) end function TrustRegion(; concrete_jac = nothing, linsolve = nothing, precs = DEFAULT_PRECS, @@ -170,11 +174,12 @@ function TrustRegion(; concrete_jac = nothing, linsolve = nothing, precs = DEFAU max_trust_radius::Real = 0 // 1, initial_trust_radius::Real = 0 // 1, step_threshold::Real = 1 // 10000, shrink_threshold::Real = 1 // 4, expand_threshold::Real = 3 // 4, shrink_factor::Real = 1 // 4, - expand_factor::Real = 2 // 1, max_shrink_times::Int = 32, adkwargs...) + expand_factor::Real = 2 // 1, max_shrink_times::Int = 32, vjp_autodiff = nothing, + adkwargs...) ad = default_adargs_to_adtype(; adkwargs...) return TrustRegion{_unwrap_val(concrete_jac)}(ad, linsolve, precs, radius_update_scheme, max_trust_radius, initial_trust_radius, step_threshold, shrink_threshold, - expand_threshold, shrink_factor, expand_factor, max_shrink_times) + expand_threshold, shrink_factor, expand_factor, max_shrink_times, vjp_autodiff) end @concrete mutable struct TrustRegionCache{iip, trustType, floatType} <: @@ -422,7 +427,7 @@ function retrospective_step!(cache::TrustRegionCache) @unpack H, g, du = cache return -(get_loss(fu_prev) - get_loss(fu)) / - (dot(du, g) + dot(du, H, du) / 2) + (dot(_vec(du), _vec(g)) + __lr_mul(Val(isinplace(cache)), H, _vec(du)) / 2) end function trust_region_step!(cache::TrustRegionCache) diff --git a/src/utils.jl b/src/utils.jl index 6a43acc80..f0660831c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -305,6 +305,12 @@ end __issingular(x::AbstractMatrix{T}) where {T} = cond(x) > inv(sqrt(eps(real(T)))) __issingular(x) = false ## If SciMLOperator and such +# Safe getproperty +@generated function _getproperty(s::S, ::Val{X}) where {S, X} + hasfield(S, X) && return :(s.$X) + return :(nothing) +end + # If factorization is LU then perform that and update the linsolve cache # else check if the matrix is singular function _try_factorize_and_check_singular!(linsolve, X) diff --git a/test/basictests.jl b/test/basictests.jl index 2ab059502..3cd9fb9f4 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -142,41 +142,52 @@ end # --- TrustRegion tests --- @testset "TrustRegion" begin - function benchmark_nlsolve_oop(f, u0, p = 2.0; radius_update_scheme, kwargs...) + function benchmark_nlsolve_oop(f, u0, p = 2.0; radius_update_scheme, linsolve = nothing, + vjp_autodiff = nothing, kwargs...) prob = NonlinearProblem{false}(f, u0, p) - return solve(prob, TrustRegion(; radius_update_scheme); abstol = 1e-9, kwargs...) + return solve(prob, TrustRegion(; radius_update_scheme, linsolve, vjp_autodiff); + abstol = 1e-9, kwargs...) end - function benchmark_nlsolve_iip(f, u0, p = 2.0; radius_update_scheme, kwargs...) + function benchmark_nlsolve_iip(f, u0, p = 2.0; radius_update_scheme, linsolve = nothing, + vjp_autodiff = nothing, kwargs...) prob = NonlinearProblem{true}(f, u0, p) - return solve(prob, TrustRegion(; radius_update_scheme); abstol = 1e-9, kwargs...) + return solve(prob, TrustRegion(; radius_update_scheme, linsolve, vjp_autodiff); + abstol = 1e-9, kwargs...) end radius_update_schemes = [RadiusUpdateSchemes.Simple, RadiusUpdateSchemes.NocedalWright, RadiusUpdateSchemes.NLsolve, RadiusUpdateSchemes.Hei, RadiusUpdateSchemes.Yuan, RadiusUpdateSchemes.Fan, RadiusUpdateSchemes.Bastin] u0s = ([1.0, 1.0], @SVector[1.0, 1.0], 1.0) + linear_solvers = [nothing, LUFactorization(), KrylovJL_GMRES()] - @testset "[OOP] u0: $(typeof(u0)) radius_update_scheme: $(radius_update_scheme)" for u0 in u0s, - radius_update_scheme in radius_update_schemes + @testset "[OOP] u0: $(typeof(u0)) radius_update_scheme: $(radius_update_scheme) linear_solver: $(linsolve)" for u0 in u0s, + radius_update_scheme in radius_update_schemes, linsolve in linear_solvers + + !(u0 isa Array) && linsolve !== nothing && continue + + abstol = ifelse(linsolve isa KrylovJL, 1e-6, 1e-9) - sol = benchmark_nlsolve_oop(quadratic_f, u0; radius_update_scheme) + sol = benchmark_nlsolve_oop(quadratic_f, u0; radius_update_scheme, linsolve, abstol) @test SciMLBase.successful_retcode(sol) - @test all(abs.(sol.u .* sol.u .- 2) .< 1e-9) + @test all(abs.(sol.u .* sol.u .- 2) .< abstol) cache = init(NonlinearProblem{false}(quadratic_f, u0, 2.0), - TrustRegion(; radius_update_scheme); abstol = 1e-9) + TrustRegion(; radius_update_scheme, linsolve); abstol) @test (@ballocated solve!($cache)) < 200 end - @testset "[IIP] u0: $(typeof(u0)) radius_update_scheme: $(radius_update_scheme)" for u0 in ([ - 1.0, 1.0],), radius_update_scheme in radius_update_schemes - sol = benchmark_nlsolve_iip(quadratic_f!, u0; radius_update_scheme) + @testset "[IIP] u0: $(typeof(u0)) radius_update_scheme: $(radius_update_scheme) linear_solver: $(linsolve)" for u0 in ([ + 1.0, 1.0],), radius_update_scheme in radius_update_schemes, linsolve in linear_solvers + abstol = ifelse(linsolve isa KrylovJL, 1e-6, 1e-9) + sol = benchmark_nlsolve_iip(quadratic_f!, u0; radius_update_scheme, linsolve, + abstol) @test SciMLBase.successful_retcode(sol) - @test all(abs.(sol.u .* sol.u .- 2) .< 1e-9) + @test all(abs.(sol.u .* sol.u .- 2) .< abstol) cache = init(NonlinearProblem{true}(quadratic_f!, u0, 2.0), - TrustRegion(; radius_update_scheme); abstol = 1e-9) + TrustRegion(; radius_update_scheme); abstol) @test (@ballocated solve!($cache)) ≤ 64 end diff --git a/test/nonlinear_least_squares.jl b/test/nonlinear_least_squares.jl index 7b8354a9d..ddfdfc03f 100644 --- a/test/nonlinear_least_squares.jl +++ b/test/nonlinear_least_squares.jl @@ -1,4 +1,4 @@ -using NonlinearSolve, LinearSolve, LinearAlgebra, Test, Random, ForwardDiff +using NonlinearSolve, LinearSolve, LinearAlgebra, Test, Random, ForwardDiff, Zygote import FastLevenbergMarquardt, LeastSquaresOptim true_function(x, θ) = @. θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4]) @@ -27,9 +27,16 @@ prob_iip = NonlinearLeastSquaresProblem(NonlinearFunction(loss_function; resid_prototype = zero(y_target)), θ_init, x) nlls_problems = [prob_oop, prob_iip] -solvers = vec(Any[GaussNewton(; linsolve, linesearch) - for linsolve in [nothing, LUFactorization()], -linesearch in [Static(), BackTracking(), HagerZhang(), StrongWolfe(), MoreThuente()]]) +solvers = [] +for linsolve in [nothing, LUFactorization(), KrylovJL_GMRES()] + vjp_autodiffs = linsolve isa KrylovJL ? [nothing, AutoZygote(), AutoFiniteDiff()] : + [nothing] + for linesearch in [Static(), BackTracking(), HagerZhang(), StrongWolfe(), MoreThuente()], + vjp_autodiff in vjp_autodiffs + + push!(solvers, GaussNewton(; linsolve, linesearch, vjp_autodiff)) + end +end append!(solvers, [ LevenbergMarquardt(), @@ -45,6 +52,36 @@ for prob in nlls_problems, solver in solvers @test norm(sol.resid) < 1e-6 end +# This is just for testing that we can use vjp provided by the user +function vjp(v, θ, p) + resid = zeros(length(p)) + J = ForwardDiff.jacobian((resid, θ) -> loss_function(resid, θ, p), resid, θ) + return vec(v' * J) +end + +function vjp!(Jv, v, θ, p) + resid = zeros(length(p)) + J = ForwardDiff.jacobian((resid, θ) -> loss_function(resid, θ, p), resid, θ) + mul!(vec(Jv), v', J) + return nothing +end + +probs = [ + NonlinearLeastSquaresProblem(NonlinearFunction{true}(loss_function; + resid_prototype = zero(y_target), vjp = vjp!), θ_init, x), + NonlinearLeastSquaresProblem(NonlinearFunction{false}(loss_function; + resid_prototype = zero(y_target), vjp = vjp), θ_init, x), +] + +for prob in probs, solver in solvers + !(solver isa GaussNewton) && continue + !(solver.linsolve isa KrylovJL) && continue + @test_warn "Currently we don't make use of user provided `jvp`. This is planned to be \ + fixed in the near future." sol=solve(prob, solver; maxiters = 10000, abstol = 1e-8) + sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8) + @test norm(sol.resid) < 1e-6 +end + function jac!(J, θ, p) resid = zeros(length(p)) ForwardDiff.jacobian!(J, (resid, θ) -> loss_function(resid, θ, p), resid, θ) From 6c52956b3ed401fdb17f117d5ba27f04e56de5db Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 21 Nov 2023 15:34:32 -0500 Subject: [PATCH 7/8] Fix aliasing issue --- src/levenberg.jl | 9 ++++++--- src/utils.jl | 20 ++++++++++++++++---- test/basictests.jl | 4 ++-- test/infeasible.jl | 4 ++-- 4 files changed, 26 insertions(+), 11 deletions(-) diff --git a/src/levenberg.jl b/src/levenberg.jl index a9b0bf89f..83729a634 100644 --- a/src/levenberg.jl +++ b/src/levenberg.jl @@ -366,9 +366,10 @@ function perform_step!(cache::LevenbergMarquardtCache{false, fastls}) where {fas if linsolve === nothing cache.v = -cache.mat_tmp \ (J' * fu1) else - linres = dolinsolve(alg.precs, linsolve; A = -__maybe_symmetric(cache.mat_tmp), + linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(cache.mat_tmp), b = _vec(J' * _vec(fu1)), linu = _vec(cache.v), p, reltol = cache.abstol) cache.linsolve = linres.cache + cache.v .*= -1 end end @@ -384,9 +385,11 @@ function perform_step!(cache::LevenbergMarquardtCache{false, fastls}) where {fas if linsolve === nothing cache.a = -cache.mat_tmp \ _vec(J' * rhs_term) else - linres = dolinsolve(alg.precs, linsolve; b = _mutable(_vec(J' * rhs_term)), - linu = _vec(cache.a), p, reltol = cache.abstol) + linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(cache.mat_tmp), + b = _mutable(_vec(J' * rhs_term)), linu = _vec(cache.a), p, + reltol = cache.abstol, reuse_A_if_factorization = true) cache.linsolve = linres.cache + cache.a .*= -1 end end cache.stats.nsolve += 1 diff --git a/src/utils.jl b/src/utils.jl index f0660831c..9d96e7b75 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -82,16 +82,28 @@ end DEFAULT_PRECS(W, du, u, p, t, newW, Plprev, Prprev, cachedata) = nothing, nothing function dolinsolve(precs::P, linsolve; A = nothing, linu = nothing, b = nothing, - du = nothing, u = nothing, p = nothing, t = nothing, weight = nothing, - cachedata = nothing, reltol = nothing) where {P} - A !== nothing && (linsolve.A = A) + du = nothing, p = nothing, weight = nothing, cachedata = nothing, reltol = nothing, + reuse_A_if_factorization = false) where {P} + # Some Algorithms would reuse factorization but it causes the cache to not reset in + # certain cases + if A !== nothing + alg = linsolve.alg + if (alg isa LinearSolve.AbstractFactorization) || + (alg isa LinearSolve.DefaultLinearSolver && !(alg == + LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.KrylovJL_GMRES))) + # Factorization Algorithm + !reuse_A_if_factorization && (linsolve.A = A) + else + linsolve.A = A + end + end b !== nothing && (linsolve.b = b) linu !== nothing && (linsolve.u = linu) Plprev = linsolve.Pl isa ComposePreconditioner ? linsolve.Pl.outer : linsolve.Pl Prprev = linsolve.Pr isa ComposePreconditioner ? linsolve.Pr.outer : linsolve.Pr - _Pl, _Pr = precs(linsolve.A, du, u, p, nothing, A !== nothing, Plprev, Prprev, + _Pl, _Pr = precs(linsolve.A, du, linu, p, nothing, A !== nothing, Plprev, Prprev, cachedata) if (_Pl !== nothing || _Pr !== nothing) _weight = weight === nothing ? diff --git a/test/basictests.jl b/test/basictests.jl index 3cd9fb9f4..127b11681 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -1013,12 +1013,12 @@ end u0 = rand(100) prob = NonlinearProblem(NonlinearFunction{false}(F; jvp = JVP), u0, u0) - sol = solve(prob, NewtonRaphson(; linsolve = KrylovJL_GMRES())) + sol = solve(prob, NewtonRaphson(; linsolve = KrylovJL_GMRES()); abstol = 1e-13) @test norm(F(sol.u, u0)) ≤ 1e-8 prob = NonlinearProblem(NonlinearFunction{true}(F!; jvp = JVP!), u0, u0) - sol = solve(prob, NewtonRaphson(; linsolve = KrylovJL_GMRES())) + sol = solve(prob, NewtonRaphson(; linsolve = KrylovJL_GMRES()); abstol = 1e-13) @test norm(F(sol.u, u0)) ≤ 1e-8 end diff --git a/test/infeasible.jl b/test/infeasible.jl index 001ce1f6e..db5d31f1b 100644 --- a/test/infeasible.jl +++ b/test/infeasible.jl @@ -29,8 +29,8 @@ function f1(u, p) v_x = 8.550491684548064e-12 + u[1] v_y = 6631.60076191005 + u[2] v_z = 3600.665431405663 + u[3] - r = @SVector [x, y, z] - v = @SVector [v_x, v_y, v_z] + r = [x, y, z] + v = [v_x, v_y, v_z] h = cross(r, v) ev = cross(v, h) / μ - r / norm(r) i = acos(h[3] / norm(h)) From bcfcc16727ad84cc79cac56aa7a2683c43f6f572 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 22 Nov 2023 11:55:56 -0500 Subject: [PATCH 8/8] Avoid Runtime Checks for Zygote Being loaded --- .JuliaFormatter.toml | 1 + Project.toml | 2 ++ ext/NonlinearSolveZygoteExt.jl | 7 +++++++ src/NonlinearSolve.jl | 3 +++ src/extension_algs.jl | 21 +++++++++++++-------- src/jacobian.jl | 8 ++------ src/linesearch.jl | 3 +-- src/pseudotransient.jl | 1 + src/utils.jl | 1 + test/basictests.jl | 4 ++-- 10 files changed, 33 insertions(+), 18 deletions(-) create mode 100644 ext/NonlinearSolveZygoteExt.jl diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml index 320e0c073..1768a1a7f 100644 --- a/.JuliaFormatter.toml +++ b/.JuliaFormatter.toml @@ -1,3 +1,4 @@ style = "sciml" format_markdown = true annotate_untyped_fields_with_any = false +format_docstrings = true diff --git a/Project.toml b/Project.toml index bab6ce5a6..37513ae88 100644 --- a/Project.toml +++ b/Project.toml @@ -30,11 +30,13 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" BandedMatrices = "aae01518-5342-5314-be14-df237901396f" FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce" LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] NonlinearSolveBandedMatricesExt = "BandedMatrices" NonlinearSolveFastLevenbergMarquardtExt = "FastLevenbergMarquardt" NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim" +NonlinearSolveZygoteExt = "Zygote" [compat] ADTypes = "0.2" diff --git a/ext/NonlinearSolveZygoteExt.jl b/ext/NonlinearSolveZygoteExt.jl new file mode 100644 index 000000000..d58faabbd --- /dev/null +++ b/ext/NonlinearSolveZygoteExt.jl @@ -0,0 +1,7 @@ +module NonlinearSolveZygoteExt + +import NonlinearSolve, Zygote + +NonlinearSolve.is_extension_loaded(::Val{:Zygote}) = true + +end diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index 58d10b290..369de3669 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -38,6 +38,9 @@ import DiffEqBase: AbstractNonlinearTerminationMode, const AbstractSparseADType = Union{ADTypes.AbstractSparseFiniteDifferences, ADTypes.AbstractSparseForwardMode, ADTypes.AbstractSparseReverseMode} +# Type-Inference Friendly Check for Extension Loading +is_extension_loaded(::Val) = false + abstract type AbstractNonlinearSolveLineSearchAlgorithm end abstract type AbstractNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end diff --git a/src/extension_algs.jl b/src/extension_algs.jl index c4e56b6a5..8f9ed4400 100644 --- a/src/extension_algs.jl +++ b/src/extension_algs.jl @@ -8,13 +8,15 @@ for solving `NonlinearLeastSquaresProblem`. ## Arguments: -- `alg`: Algorithm to use. Can be `:lm` or `:dogleg`. -- `linsolve`: Linear solver to use. Can be `:qr`, `:cholesky` or `:lsmr`. If - `nothing`, then `LeastSquaresOptim.jl` will choose the best linear solver based - on the Jacobian structure. -- `autodiff`: Automatic differentiation / Finite Differences. Can be `:central` or `:forward`. + - `alg`: Algorithm to use. Can be `:lm` or `:dogleg`. + - `linsolve`: Linear solver to use. Can be `:qr`, `:cholesky` or `:lsmr`. If `nothing`, + then `LeastSquaresOptim.jl` will choose the best linear solver based on the Jacobian + structure. + - `autodiff`: Automatic differentiation / Finite Differences. Can be `:central` or + `:forward`. !!! note + This algorithm is only available if `LeastSquaresOptim.jl` is installed. """ struct LeastSquaresOptimJL{alg, linsolve} <: AbstractNonlinearSolveAlgorithm @@ -36,21 +38,24 @@ end """ FastLevenbergMarquardtJL(linsolve = :cholesky) -Wrapper over [FastLevenbergMarquardt.jl](https://github.com/kamesy/FastLevenbergMarquardt.jl) for solving -`NonlinearLeastSquaresProblem`. +Wrapper over [FastLevenbergMarquardt.jl](https://github.com/kamesy/FastLevenbergMarquardt.jl) +for solving `NonlinearLeastSquaresProblem`. !!! warning + This is not really the fastest solver. It is called that since the original package is called "Fast". `LevenbergMarquardt()` is almost always a better choice. !!! warning + This algorithm requires the jacobian function to be provided! ## Arguments: -- `linsolve`: Linear solver to use. Can be `:qr` or `:cholesky`. + - `linsolve`: Linear solver to use. Can be `:qr` or `:cholesky`. !!! note + This algorithm is only available if `FastLevenbergMarquardt.jl` is installed. """ @concrete struct FastLevenbergMarquardtJL{linsolve} <: AbstractNonlinearSolveAlgorithm diff --git a/src/jacobian.jl b/src/jacobian.jl index 6fef600af..41c7319a1 100644 --- a/src/jacobian.jl +++ b/src/jacobian.jl @@ -166,12 +166,8 @@ function __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf) # Short circuit if we see that FiniteDiff was used for J computation jvp_autodiff isa AutoFiniteDiff && return jvp_autodiff # Check if Zygote is loaded then use Zygote else use FiniteDiff - if haskey(Base.loaded_modules, - Base.PkgId(Base.UUID("e88e6eb3-aa80-5325-afca-941959d7151f"), "Zygote")) - return AutoZygote() - else - return AutoFiniteDiff() - end + is_extension_loaded(Val{:Zygote}()) && return AutoZygote() + return AutoFiniteDiff() end else ad = __get_nonsparse_ad(vjp_autodiff) diff --git a/src/linesearch.jl b/src/linesearch.jl index c9e87a4cb..d67ac978c 100644 --- a/src/linesearch.jl +++ b/src/linesearch.jl @@ -114,8 +114,7 @@ function LineSearchesJLCache(ls::LineSearch, f::F, u, p, fu1, IIP::Val{iip}) whe g₀ = _mutable_zero(u) autodiff = if ls.autodiff === nothing - if !iip && haskey(Base.loaded_modules, - Base.PkgId(Base.UUID("e88e6eb3-aa80-5325-afca-941959d7151f"), "Zygote")) + if !iip && is_extension_loaded(Val{:Zygote}()) AutoZygote() else AutoFiniteDiff() diff --git a/src/pseudotransient.jl b/src/pseudotransient.jl index 5da1375d6..b343138de 100644 --- a/src/pseudotransient.jl +++ b/src/pseudotransient.jl @@ -12,6 +12,7 @@ the time-stepping and algorithm, please see the paper: SIAM Journal on Scientific Computing,25, 553-569.](https://doi.org/10.1137/S106482750241044X) ### Keyword Arguments + - `autodiff`: determines the backend used for the Jacobian. Note that this argument is ignored if an analytical Jacobian is passed, as that will be used instead. Defaults to `nothing` which means that a default is selected according to the problem specification! diff --git a/src/utils.jl b/src/utils.jl index 9d96e7b75..c5161df7c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -23,6 +23,7 @@ Construct the AD type from the arguments. This is mostly needed for compatibilit code. !!! warning + `chunk_size`, `standardtag`, `diff_type`, and `autodiff::Union{Val, Bool}` are deprecated and will be removed in v3. Update your code to directly specify `autodiff=`. diff --git a/test/basictests.jl b/test/basictests.jl index 127b11681..e3928a7bc 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -1015,10 +1015,10 @@ end prob = NonlinearProblem(NonlinearFunction{false}(F; jvp = JVP), u0, u0) sol = solve(prob, NewtonRaphson(; linsolve = KrylovJL_GMRES()); abstol = 1e-13) - @test norm(F(sol.u, u0)) ≤ 1e-8 + @test norm(F(sol.u, u0)) ≤ 1e-6 prob = NonlinearProblem(NonlinearFunction{true}(F!; jvp = JVP!), u0, u0) sol = solve(prob, NewtonRaphson(; linsolve = KrylovJL_GMRES()); abstol = 1e-13) - @test norm(F(sol.u, u0)) ≤ 1e-8 + @test norm(F(sol.u, u0)) ≤ 1e-6 end