-
-
Notifications
You must be signed in to change notification settings - Fork 44
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add hooks for
OverrideInit
(#517)
* feat: add `get_abstol` and `get_reltol` interface methods * feat: add `initialize_cache!` * feat: implement initialization for polyalg cache * feat: implement initialization for no-init cache * feat: implement initialization for first order cache * feat: implement initialization for `QuasiNewtonCache` * feat: implement initialization for `GeneralizedDFSaneCache` * fix: fix `SII.parameter_values` * feat: implement initialization for `SimpleNonlinearSolve` * fix: fix `InternalAPI.reinit_self!` for `GeneralizedDFSaneCache` * fix: fix `SII.state_values` for `NoInitCache` * feat: run initialiation on `solve!` * build: bump SciMLBase compat in NonlinearSolveBase
- Loading branch information
1 parent
873aae4
commit 8295933
Showing
13 changed files
with
222 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
struct NonlinearSolveDefaultInit <: SciMLBase.DAEInitializationAlgorithm end | ||
|
||
function run_initialization!(cache, initializealg = cache.initializealg, prob = cache.prob) | ||
_run_initialization!(cache, initializealg, prob, Val(SciMLBase.isinplace(cache))) | ||
end | ||
|
||
function _run_initialization!( | ||
cache, ::NonlinearSolveDefaultInit, prob, isinplace::Union{Val{true}, Val{false}}) | ||
if SciMLBase.has_initialization_data(prob.f) && | ||
prob.f.initialization_data isa SciMLBase.OverrideInitData | ||
return _run_initialization!(cache, SciMLBase.OverrideInit(), prob, isinplace) | ||
end | ||
return cache, true | ||
end | ||
|
||
function _run_initialization!(cache, initalg::SciMLBase.OverrideInit, prob, | ||
isinplace::Union{Val{true}, Val{false}}) | ||
if cache isa AbstractNonlinearSolveCache && isdefined(cache.alg, :autodiff) | ||
autodiff = cache.alg.autodiff | ||
else | ||
autodiff = ADTypes.AutoForwardDiff() | ||
end | ||
alg = initialization_alg(prob.f.initialization_data.initializeprob, autodiff) | ||
if alg === nothing && cache isa AbstractNonlinearSolveCache | ||
alg = cache.alg | ||
end | ||
u0, p, success = SciMLBase.get_initial_values( | ||
prob, cache, prob.f, initalg, isinplace; nlsolve_alg = alg, | ||
abstol = get_abstol(cache), reltol = get_reltol(cache)) | ||
cache = update_initial_values!(cache, u0, p) | ||
if cache isa AbstractNonlinearSolveCache && isdefined(cache, :retcode) && !success | ||
cache.retcode = ReturnCode.InitialFailure | ||
end | ||
|
||
return cache, success | ||
end | ||
|
||
function get_abstol(prob::AbstractNonlinearProblem) | ||
get_tolerance(get(prob.kwargs, :abstol, nothing), eltype(SII.state_values(prob))) | ||
end | ||
function get_reltol(prob::AbstractNonlinearProblem) | ||
get_tolerance(get(prob.kwargs, :reltol, nothing), eltype(SII.state_values(prob))) | ||
end | ||
|
||
initialization_alg(initprob, autodiff) = nothing | ||
|
||
function update_initial_values!(cache::AbstractNonlinearSolveCache, u0, p) | ||
InternalAPI.reinit!(cache; u0, p) | ||
cache.prob = SciMLBase.remake(cache.prob; u0, p) | ||
return cache | ||
end | ||
|
||
function update_initial_values!(prob::AbstractNonlinearProblem, u0, p) | ||
return SciMLBase.remake(prob; u0, p) | ||
end | ||
|
||
function _run_initialization!( | ||
cache::AbstractNonlinearSolveCache, ::SciMLBase.NoInit, prob, isinplace) | ||
return cache, true | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.