Skip to content

Commit

Permalink
feat: add hooks for OverrideInit (#517)
Browse files Browse the repository at this point in the history
* feat: add `get_abstol` and `get_reltol` interface methods

* feat: add `initialize_cache!`

* feat: implement initialization for polyalg cache

* feat: implement initialization for no-init cache

* feat: implement initialization for first order cache

* feat: implement initialization for `QuasiNewtonCache`

* feat: implement initialization for `GeneralizedDFSaneCache`

* fix: fix `SII.parameter_values`

* feat: implement initialization for `SimpleNonlinearSolve`

* fix: fix `InternalAPI.reinit_self!` for `GeneralizedDFSaneCache`

* fix: fix `SII.state_values` for `NoInitCache`

* feat: run initialiation on `solve!`

* build: bump SciMLBase compat in NonlinearSolveBase
  • Loading branch information
AayushSabharwal authored Dec 15, 2024
1 parent 873aae4 commit 8295933
Show file tree
Hide file tree
Showing 13 changed files with 222 additions and 18 deletions.
2 changes: 1 addition & 1 deletion lib/NonlinearSolveBase/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ MaybeInplace = "0.1.4"
Preferences = "1.4"
Printf = "1.10"
RecursiveArrayTools = "3"
SciMLBase = "2.58"
SciMLBase = "2.68.1"
SciMLJacobianOperators = "0.1.1"
SciMLOperators = "0.3.10"
SparseArrays = "1.10"
Expand Down
1 change: 1 addition & 0 deletions lib/NonlinearSolveBase/src/NonlinearSolveBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ include("descent/damped_newton.jl")
include("descent/dogleg.jl")
include("descent/geodesic_acceleration.jl")

include("initialization.jl")
include("solve.jl")

include("forward_diff.jl")
Expand Down
11 changes: 10 additions & 1 deletion lib/NonlinearSolveBase/src/abstract_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,8 @@ Abstract Type for all NonlinearSolveBase Caches.
`u0` and any additional keyword arguments.
- `SciMLBase.isinplace(cache)`: whether or not the solver is inplace.
- `CommonSolve.step!(cache; kwargs...)`: See [`CommonSolve.step!`](@ref) for more details.
- `get_abstol(cache)`: get the `abstol` provided to the cache.
- `get_reltol(cache)`: get the `reltol` provided to the cache.
Additionally implements `SymbolicIndexingInterface` interface Functions.
Expand Down Expand Up @@ -304,9 +306,16 @@ end

SciMLBase.isinplace(cache::AbstractNonlinearSolveCache) = SciMLBase.isinplace(cache.prob)

function get_abstol(cache::AbstractNonlinearSolveCache)
get_abstol(cache.termination_cache)
end
function get_reltol(cache::AbstractNonlinearSolveCache)
get_reltol(cache.termination_cache)
end

## SII Interface
SII.symbolic_container(cache::AbstractNonlinearSolveCache) = cache.prob
SII.parameter_values(cache::AbstractNonlinearSolveCache) = SII.parameter_values(cache.prob)
SII.parameter_values(cache::AbstractNonlinearSolveCache) = cache.p
SII.state_values(cache::AbstractNonlinearSolveCache) = get_u(cache)

function Base.getproperty(cache::AbstractNonlinearSolveCache, sym::Symbol)
Expand Down
7 changes: 7 additions & 0 deletions lib/NonlinearSolveBase/src/forward_diff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,10 @@
values_p
partials_p
end

function NonlinearSolveBase.get_abstol(cache::NonlinearSolveForwardDiffCache)
NonlinearSolveBase.get_abstol(cache.cache)
end
function NonlinearSolveBase.get_reltol(cache::NonlinearSolveForwardDiffCache)
NonlinearSolveBase.get_reltol(cache.cache)
end
60 changes: 60 additions & 0 deletions lib/NonlinearSolveBase/src/initialization.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
struct NonlinearSolveDefaultInit <: SciMLBase.DAEInitializationAlgorithm end

function run_initialization!(cache, initializealg = cache.initializealg, prob = cache.prob)
_run_initialization!(cache, initializealg, prob, Val(SciMLBase.isinplace(cache)))
end

function _run_initialization!(
cache, ::NonlinearSolveDefaultInit, prob, isinplace::Union{Val{true}, Val{false}})
if SciMLBase.has_initialization_data(prob.f) &&
prob.f.initialization_data isa SciMLBase.OverrideInitData
return _run_initialization!(cache, SciMLBase.OverrideInit(), prob, isinplace)
end
return cache, true
end

function _run_initialization!(cache, initalg::SciMLBase.OverrideInit, prob,
isinplace::Union{Val{true}, Val{false}})
if cache isa AbstractNonlinearSolveCache && isdefined(cache.alg, :autodiff)
autodiff = cache.alg.autodiff
else
autodiff = ADTypes.AutoForwardDiff()
end
alg = initialization_alg(prob.f.initialization_data.initializeprob, autodiff)
if alg === nothing && cache isa AbstractNonlinearSolveCache
alg = cache.alg
end
u0, p, success = SciMLBase.get_initial_values(
prob, cache, prob.f, initalg, isinplace; nlsolve_alg = alg,
abstol = get_abstol(cache), reltol = get_reltol(cache))
cache = update_initial_values!(cache, u0, p)
if cache isa AbstractNonlinearSolveCache && isdefined(cache, :retcode) && !success
cache.retcode = ReturnCode.InitialFailure
end

return cache, success
end

function get_abstol(prob::AbstractNonlinearProblem)
get_tolerance(get(prob.kwargs, :abstol, nothing), eltype(SII.state_values(prob)))
end
function get_reltol(prob::AbstractNonlinearProblem)
get_tolerance(get(prob.kwargs, :reltol, nothing), eltype(SII.state_values(prob)))
end

initialization_alg(initprob, autodiff) = nothing

function update_initial_values!(cache::AbstractNonlinearSolveCache, u0, p)
InternalAPI.reinit!(cache; u0, p)
cache.prob = SciMLBase.remake(cache.prob; u0, p)
return cache
end

function update_initial_values!(prob::AbstractNonlinearProblem, u0, p)
return SciMLBase.remake(prob; u0, p)
end

function _run_initialization!(
cache::AbstractNonlinearSolveCache, ::SciMLBase.NoInit, prob, isinplace)
return cache, true
end
32 changes: 28 additions & 4 deletions lib/NonlinearSolveBase/src/polyalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,23 @@ end
u0
u0_aliased
alias_u0::Bool

initializealg
end

function update_initial_values!(cache::NonlinearSolvePolyAlgorithmCache, u0, p)
foreach(cache.caches) do subcache
update_initial_values!(subcache, u0, p)
end
cache.prob = SciMLBase.remake(cache.prob; u0, p)
return cache
end

function NonlinearSolveBase.get_abstol(cache::NonlinearSolvePolyAlgorithmCache)
NonlinearSolveBase.get_abstol(cache.caches[cache.current])
end
function NonlinearSolveBase.get_reltol(cache::NonlinearSolvePolyAlgorithmCache)
NonlinearSolveBase.get_reltol(cache.caches[cache.current])
end

function SII.symbolic_container(cache::NonlinearSolvePolyAlgorithmCache)
Expand All @@ -67,6 +84,9 @@ end
function SII.state_values(cache::NonlinearSolvePolyAlgorithmCache)
SII.state_values(SII.symbolic_container(cache))
end
function SII.parameter_values(cache::NonlinearSolvePolyAlgorithmCache)
SII.parameter_values(SII.symbolic_container(cache))
end

function Base.show(io::IO, ::MIME"text/plain", cache::NonlinearSolvePolyAlgorithmCache)
println(io, "NonlinearSolvePolyAlgorithmCache with \
Expand Down Expand Up @@ -97,7 +117,8 @@ end
function SciMLBase.__init(
prob::AbstractNonlinearProblem, alg::NonlinearSolvePolyAlgorithm, args...;
stats = NLStats(0, 0, 0, 0, 0), maxtime = nothing, maxiters = 1000,
internalnorm = L2_NORM, alias_u0 = false, verbose = true, kwargs...
internalnorm = L2_NORM, alias_u0 = false, verbose = true,
initializealg = NonlinearSolveDefaultInit(), kwargs...
)
if alias_u0 && !ArrayInterface.ismutable(prob.u0)
verbose && @warn "`alias_u0` has been set to `true`, but `u0` is \
Expand All @@ -109,18 +130,21 @@ function SciMLBase.__init(
u0_aliased = alias_u0 ? copy(u0) : u0
alias_u0 && (prob = SciMLBase.remake(prob; u0 = u0_aliased))

return NonlinearSolvePolyAlgorithmCache(
cache = NonlinearSolvePolyAlgorithmCache(
alg.static_length, prob,
map(alg.algs) do solver
SciMLBase.__init(
prob, solver, args...;
stats, maxtime, internalnorm, alias_u0, verbose, kwargs...
stats, maxtime, internalnorm, alias_u0, verbose,
initializealg = SciMLBase.NoInit(), kwargs...
)
end,
alg, -1, alg.start_index, 0, stats, 0.0, maxtime,
ReturnCode.Default, false, maxiters, internalnorm,
u0, u0_aliased, alias_u0
u0, u0_aliased, alias_u0, initializealg
)
run_initialization!(cache)
return cache
end

@generated function InternalAPI.step!(
Expand Down
61 changes: 58 additions & 3 deletions lib/NonlinearSolveBase/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@ function SciMLBase.__solve(
end

function CommonSolve.solve!(cache::AbstractNonlinearSolveCache)
if cache.retcode == ReturnCode.InitialFailure
return SciMLBase.build_solution(
cache.prob, cache.alg, get_u(cache), get_fu(cache);
cache.retcode, cache.stats, cache.trace
)
end

while not_terminated(cache)
CommonSolve.step!(cache)
end
Expand Down Expand Up @@ -40,6 +47,17 @@ end
sol_syms = [gensym("sol") for i in 1:N]
u_result_syms = [gensym("u_result") for i in 1:N]

push!(calls,
quote
if cache.retcode == ReturnCode.InitialFailure
u = $(SII.state_values)(cache)
return build_solution_less_specialize(
cache.prob, cache.alg, u, $(Utils.evaluate_f)(cache.prob, u);
retcode = cache.retcode
)
end
end)

for i in 1:N
push!(calls,
quote
Expand Down Expand Up @@ -111,7 +129,8 @@ end

@generated function __generated_polysolve(
prob::AbstractNonlinearProblem, alg::NonlinearSolvePolyAlgorithm{Val{N}}, args...;
stats = NLStats(0, 0, 0, 0, 0), alias_u0 = false, verbose = true, kwargs...
stats = NLStats(0, 0, 0, 0, 0), alias_u0 = false, verbose = true,
initializealg = NonlinearSolveDefaultInit(), kwargs...
) where {N}
sol_syms = [gensym("sol") for _ in 1:N]
prob_syms = [gensym("prob") for _ in 1:N]
Expand All @@ -123,9 +142,23 @@ end
immutable (checked using `ArrayInterface.ismutable`)."
alias_u0 = false # If immutable don't care about aliasing
end
end]

push!(calls,
quote
prob, success = $(run_initialization!)(prob, initializealg, prob)
if !success
u = $(SII.state_values)(prob)
return build_solution_less_specialize(
prob, alg, u, $(Utils.evaluate_f)(prob, u);
retcode = $(ReturnCode.InitialFailure))
end
end)

push!(calls, quote
u0 = prob.u0
u0_aliased = alias_u0 ? zero(u0) : u0
end]
end)
for i in 1:N
cur_sol = sol_syms[i]
push!(calls,
Expand Down Expand Up @@ -246,8 +279,21 @@ end
alg
args
kwargs::Any
initializealg

retcode::ReturnCode.T
end

function get_abstol(cache::NonlinearSolveNoInitCache)
get(cache.kwargs, :abstol, get_tolerance(nothing, eltype(cache.prob.u0)))
end
function get_reltol(cache::NonlinearSolveNoInitCache)
get(cache.kwargs, :reltol, get_tolerance(nothing, eltype(cache.prob.u0)))
end

SII.parameter_values(cache::NonlinearSolveNoInitCache) = SII.parameter_values(cache.prob)
SII.state_values(cache::NonlinearSolveNoInitCache) = SII.state_values(cache.prob)

get_u(cache::NonlinearSolveNoInitCache) = SII.state_values(cache.prob)

function SciMLBase.reinit!(
Expand All @@ -264,11 +310,20 @@ end

function SciMLBase.__init(
prob::AbstractNonlinearProblem, alg::AbstractNonlinearSolveAlgorithm, args...;
initializealg = NonlinearSolveDefaultInit(),
kwargs...
)
return NonlinearSolveNoInitCache(prob, alg, args, kwargs)
cache = NonlinearSolveNoInitCache(
prob, alg, args, kwargs, initializealg, ReturnCode.Default)
run_initialization!(cache)
return cache
end

function CommonSolve.solve!(cache::NonlinearSolveNoInitCache)
if cache.retcode == ReturnCode.InitialFailure
u = SII.state_values(cache)
return SciMLBase.build_solution(
cache.prob, cache.alg, u, Utils.evaluate_f(cache.prob, u); cache.retcode)
end
return CommonSolve.solve(cache.prob, cache.alg, cache.args...; cache.kwargs...)
end
3 changes: 3 additions & 0 deletions lib/NonlinearSolveBase/src/termination_conditions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ const AbsNormModes = Union{
u_diff_cache::uType
end

get_abstol(cache::NonlinearTerminationModeCache) = cache.abstol
get_reltol(cache::NonlinearTerminationModeCache) = cache.reltol

function update_u!!(cache::NonlinearTerminationModeCache, u)
cache.u === nothing && return
if cache.u isa AbstractArray && ArrayInterface.can_setindex(cache.u)
Expand Down
12 changes: 9 additions & 3 deletions lib/NonlinearSolveFirstOrder/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ end
retcode::ReturnCode.T
force_stop::Bool
kwargs

initializealg
end

function InternalAPI.reinit_self!(
Expand Down Expand Up @@ -121,7 +123,7 @@ function SciMLBase.__init(
stats = NLStats(0, 0, 0, 0, 0), alias_u0 = false, maxiters = 1000,
abstol = nothing, reltol = nothing, maxtime = nothing,
termination_condition = nothing, internalnorm = L2_NORM,
linsolve_kwargs = (;), kwargs...
linsolve_kwargs = (;), initializealg = NonlinearSolveBase.NonlinearSolveDefaultInit(), kwargs...
)
@set! alg.autodiff = NonlinearSolveBase.select_jacobian_autodiff(prob, alg.autodiff)
provided_jvp_autodiff = alg.jvp_autodiff !== nothing
Expand Down Expand Up @@ -206,13 +208,17 @@ function SciMLBase.__init(
prob, alg, u, fu, J, du; kwargs...
)

return GeneralizedFirstOrderAlgorithmCache(
cache = GeneralizedFirstOrderAlgorithmCache(
fu, u, u_cache, prob.p, du, J, alg, prob, globalization,
jac_cache, descent_cache, linesearch_cache, trustregion_cache,
stats, 0, maxiters, maxtime, alg.max_shrink_times, timer,
0.0, true, termination_cache, trace, ReturnCode.Default, false, kwargs
0.0, true, termination_cache, trace, ReturnCode.Default, false, kwargs,
initializealg
)
NonlinearSolveBase.run_initialization!(cache)
end

return cache
end

function InternalAPI.step!(
Expand Down
20 changes: 17 additions & 3 deletions lib/NonlinearSolveQuasiNewton/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,16 @@ end
force_stop::Bool
force_reinit::Bool
kwargs

# Initialization
initializealg
end

function NonlinearSolveBase.get_abstol(cache::QuasiNewtonCache)
NonlinearSolveBase.get_abstol(cache.termination_cache)
end
function NonlinearSolveBase.get_reltol(cache::QuasiNewtonCache)
NonlinearSolveBase.get_reltol(cache.termination_cache)
end

function InternalAPI.reinit_self!(
Expand Down Expand Up @@ -130,7 +140,8 @@ function SciMLBase.__init(
stats = NLStats(0, 0, 0, 0, 0), alias_u0 = false, maxtime = nothing,
maxiters = 1000, abstol = nothing, reltol = nothing,
linsolve_kwargs = (;), termination_condition = nothing,
internalnorm::F = L2_NORM, kwargs...
internalnorm::F = L2_NORM, initializealg = NonlinearSolveBase.NonlinearSolveDefaultInit(),
kwargs...
) where {F}
timer = get_timer_output()
@static_timeit timer "cache construction" begin
Expand Down Expand Up @@ -204,15 +215,18 @@ function SciMLBase.__init(
uses_jacobian_inverse = inverted_jac, kwargs...
)

return QuasiNewtonCache(
cache = QuasiNewtonCache(
fu, u, u_cache, prob.p, du, J, alg, prob, globalization,
initialization_cache, descent_cache, linesearch_cache,
trustregion_cache, update_rule_cache, reinit_rule_cache,
inv_workspace, stats, 0, 0, alg.max_resets, maxiters, maxtime,
alg.max_shrink_times, 0, timer, 0.0, termination_cache, trace,
ReturnCode.Default, false, false, kwargs
ReturnCode.Default, false, false, kwargs, initializealg
)
NonlinearSolveBase.run_initialization!(cache)
end

return cache
end

function InternalAPI.step!(
Expand Down
Loading

0 comments on commit 8295933

Please sign in to comment.