From de6eb96517d6d584fc3d4e5f4024f1c33f714a0a Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Wed, 4 Dec 2024 00:12:35 +0800 Subject: [PATCH] Put DualAbstractNonlinearProblem solving in subpackages --- .../ext/NonlinearSolveBaseForwardDiffExt.jl | 19 ++------ .../src/NonlinearSolveBase.jl | 2 + lib/NonlinearSolveBase/src/common_defaults.jl | 9 ++++ .../src/NonlinearSolveFirstOrder.jl | 4 +- .../src/forward_diff.jl | 34 ++++++++++++++ lib/NonlinearSolveQuasiNewton/Project.toml | 6 +++ ...NonlinearSolveQuasiNewtonForwardDiffExt.jl | 47 +++++++++++++++++++ .../Project.toml | 6 +++ ...inearSolveSpectralMethodsForwardDiffExt.jl | 47 +++++++++++++++++++ src/NonlinearSolve.jl | 11 +---- src/forward_diff.jl | 44 +++++++++++++++++ 11 files changed, 205 insertions(+), 24 deletions(-) create mode 100644 lib/NonlinearSolveFirstOrder/src/forward_diff.jl create mode 100644 lib/NonlinearSolveQuasiNewton/ext/NonlinearSolveQuasiNewtonForwardDiffExt.jl create mode 100644 lib/NonlinearSolveSpectralMethods/ext/NonlinearSolveSpectralMethodsForwardDiffExt.jl create mode 100644 src/forward_diff.jl diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl index 6357549ec..203d06f14 100644 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl @@ -12,12 +12,12 @@ using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem, using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, AbstractNonlinearSolveAlgorithm, Utils, InternalAPI, - AbstractNonlinearSolveCache + AbstractNonlinearSolveCache, NonlinearSolvePolyAlgorithm const DI = DifferentiationInterface -const ALL_SOLVER_TYPES = [ - Nothing, AbstractNonlinearSolveAlgorithm +const GENERAL_SOLVER_TYPES = [ + Nothing, AbstractNonlinearSolveAlgorithm, NonlinearSolvePolyAlgorithm ] const DualNonlinearProblem = NonlinearProblem{ @@ -121,7 +121,7 @@ function NonlinearSolveBase.nonlinearsolve_dual_solution( return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, Utils.restructure(u, partials))) end -for algType in ALL_SOLVER_TYPES +for algType in GENERAL_SOLVER_TYPES @eval function SciMLBase.__solve( prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs... ) @@ -157,7 +157,7 @@ function InternalAPI.reinit!( return cache end -for algType in ALL_SOLVER_TYPES +for algType in GENERAL_SOLVER_TYPES @eval function SciMLBase.__init( prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs... ) @@ -200,13 +200,4 @@ nodual_value(x) = x nodual_value(x::Dual) = ForwardDiff.value(x) nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x) -""" - pickchunksize(x) = pickchunksize(length(x)) - pickchunksize(x::Int) - -Determine the chunk size for ForwardDiff and PolyesterForwardDiff based on the input length. -""" -@inline pickchunksize(x) = pickchunksize(length(x)) -@inline pickchunksize(x::Int) = ForwardDiff.pickchunksize(x) - end diff --git a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl index 649ac79d2..8fd4b1947 100644 --- a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl +++ b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl @@ -84,4 +84,6 @@ export DescentResult, SteepestDescent, NewtonDescent, DampedNewtonDescent, Dogle export NonlinearSolvePolyAlgorithm +export pickchunksize + end diff --git a/lib/NonlinearSolveBase/src/common_defaults.jl b/lib/NonlinearSolveBase/src/common_defaults.jl index 4518063a5..5a5433ee3 100644 --- a/lib/NonlinearSolveBase/src/common_defaults.jl +++ b/lib/NonlinearSolveBase/src/common_defaults.jl @@ -45,3 +45,12 @@ function get_tolerance(::Union{StaticArray, Number}, ::Nothing, ::Type{T}) where # Rational numbers can throw an error if used inside GPU Kernels return T(real(oneunit(T)) * (eps(real(one(T)))^(real(T)(0.8)))) end + +""" + pickchunksize(x) = pickchunksize(length(x)) + pickchunksize(x::Int) + +Determine the chunk size for ForwardDiff and PolyesterForwardDiff based on the input length. +""" +@inline pickchunksize(x) = pickchunksize(length(x)) +@inline pickchunksize(x::Int) = ForwardDiff.pickchunksize(x) diff --git a/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl b/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl index 1f480fb4b..15b99c5d1 100644 --- a/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl +++ b/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl @@ -29,7 +29,7 @@ using SciMLBase: SciMLBase, AbstractNonlinearProblem, NLStats, ReturnCode, using SciMLJacobianOperators: VecJacOperator, JacVecOperator, StatefulJacobianOperator using FiniteDiff: FiniteDiff # Default Finite Difference Method -using ForwardDiff: ForwardDiff # Default Forward Mode AD +using ForwardDiff: ForwardDiff, Dual # Default Forward Mode AD include("raphson.jl") include("gauss_newton.jl") @@ -41,6 +41,8 @@ include("poly_algs.jl") include("solve.jl") +include("forward_diff.jl") + @setup_workload begin nonlinear_functions = ( (NonlinearFunction{false, NoSpecialize}((u, p) -> u .* u .- p), 0.1), diff --git a/lib/NonlinearSolveFirstOrder/src/forward_diff.jl b/lib/NonlinearSolveFirstOrder/src/forward_diff.jl new file mode 100644 index 000000000..86f4b072a --- /dev/null +++ b/lib/NonlinearSolveFirstOrder/src/forward_diff.jl @@ -0,0 +1,34 @@ +const DualNonlinearProblem = NonlinearProblem{ + <:Union{Number, <:AbstractArray}, iip, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} +} where {iip, T, V, P} +const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{ + <:Union{Number, <:AbstractArray}, iip, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} +} where {iip, T, V, P} +const DualAbstractNonlinearProblem = Union{ + DualNonlinearProblem, DualNonlinearLeastSquaresProblem +} + +function SciMLBase.__init( + prob::DualAbstractNonlinearProblem, alg::GeneralizedFirstOrderAlgorithm, args...; kwargs... +) + p = NonlinearSolveBase.nodual_value(prob.p) + newprob = SciMLBase.remake(prob; u0 = NonlinearSolveBase.nodual_value(prob.u0), p) + cache = init(newprob, alg, args...; kwargs...) + return NonlinearSolveForwardDiffCache( + cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p) + ) +end + +function SciMLBase.__solve( + prob::DualAbstractNonlinearProblem, alg::GeneralizedFirstOrderAlgorithm, args...; kwargs... +) + sol, partials = NonlinearSolveBase.nonlinearsolve_forwarddiff_solve( + prob, alg, args...; kwargs... + ) + dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, prob.p) + return SciMLBase.build_solution( + prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original + ) +end diff --git a/lib/NonlinearSolveQuasiNewton/Project.toml b/lib/NonlinearSolveQuasiNewton/Project.toml index 2f00863d8..4912e9070 100644 --- a/lib/NonlinearSolveQuasiNewton/Project.toml +++ b/lib/NonlinearSolveQuasiNewton/Project.toml @@ -18,6 +18,12 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +[weakdeps] +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" + +[extensions] +NonlinearSolveQuasiNewtonForwardDiffExt = "ForwardDiff" + [compat] ADTypes = "1.9.0" Aqua = "0.8" diff --git a/lib/NonlinearSolveQuasiNewton/ext/NonlinearSolveQuasiNewtonForwardDiffExt.jl b/lib/NonlinearSolveQuasiNewton/ext/NonlinearSolveQuasiNewtonForwardDiffExt.jl new file mode 100644 index 000000000..afba60d43 --- /dev/null +++ b/lib/NonlinearSolveQuasiNewton/ext/NonlinearSolveQuasiNewtonForwardDiffExt.jl @@ -0,0 +1,47 @@ +module NonlinearSolveQuasiNewtonForwardDiffExt + +using CommonSolve: CommonSolve, solve +using ForwardDiff: ForwardDiff, Dual +using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem, + NonlinearProblem, NonlinearLeastSquaresProblem, remake + +using NonlinearSolveBase: NonlinearSolveBase + +using NonlinearSolveQuasiNewton: QuasiNewtonAlgorithm + +const DualNonlinearProblem = NonlinearProblem{ + <:Union{Number, <:AbstractArray}, iip, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} +} where {iip, T, V, P} +const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{ + <:Union{Number, <:AbstractArray}, iip, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} +} where {iip, T, V, P} +const DualAbstractNonlinearProblem = Union{ + DualNonlinearProblem, DualNonlinearLeastSquaresProblem +} + +function SciMLBase.__solve( + prob::DualAbstractNonlinearProblem, alg::QuasiNewtonAlgorithm, args...; kwargs... +) + sol, partials = NonlinearSolveBase.nonlinearsolve_forwarddiff_solve( + prob, alg, args...; kwargs... + ) + dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, prob.p) + return SciMLBase.build_solution( + prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original + ) +end + +function SciMLBase.__init( + prob::DualAbstractNonlinearProblem, alg::QuasiNewtonAlgorithm, args...; kwargs... +) + p = nodual_value(prob.p) + newprob = SciMLBase.remake(prob; u0 = nodual_value(prob.u0), p) + cache = init(newprob, alg, args...; kwargs...) + return NonlinearSolveForwardDiffCache( + cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p) + ) +end + +end diff --git a/lib/NonlinearSolveSpectralMethods/Project.toml b/lib/NonlinearSolveSpectralMethods/Project.toml index bb9367554..7175c5ea9 100644 --- a/lib/NonlinearSolveSpectralMethods/Project.toml +++ b/lib/NonlinearSolveSpectralMethods/Project.toml @@ -14,6 +14,12 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +[weakdeps] +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" + +[extensions] +NonlinearSolveSpectralMethodsForwardDiffExt = "ForwardDiff" + [compat] Aqua = "0.8" BenchmarkTools = "1.5.0" diff --git a/lib/NonlinearSolveSpectralMethods/ext/NonlinearSolveSpectralMethodsForwardDiffExt.jl b/lib/NonlinearSolveSpectralMethods/ext/NonlinearSolveSpectralMethodsForwardDiffExt.jl new file mode 100644 index 000000000..86604d7e2 --- /dev/null +++ b/lib/NonlinearSolveSpectralMethods/ext/NonlinearSolveSpectralMethodsForwardDiffExt.jl @@ -0,0 +1,47 @@ +module NonlinearSolveSpectralMethodsForwardDiffExt + +using CommonSolve: CommonSolve, solve +using ForwardDiff: ForwardDiff, Dual +using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem, + NonlinearProblem, NonlinearLeastSquaresProblem, remake + +using NonlinearSolveBase: NonlinearSolveBase + +using NonlinearSolveSpectralMethods: GeneralizedDFSane + +const DualNonlinearProblem = NonlinearProblem{ + <:Union{Number, <:AbstractArray}, iip, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} +} where {iip, T, V, P} +const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{ + <:Union{Number, <:AbstractArray}, iip, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} +} where {iip, T, V, P} +const DualAbstractNonlinearProblem = Union{ + DualNonlinearProblem, DualNonlinearLeastSquaresProblem +} + +function SciMLBase.__solve( + prob::DualAbstractNonlinearProblem, alg::GeneralizedDFSane, args...; kwargs... +) + sol, partials = NonlinearSolveBase.nonlinearsolve_forwarddiff_solve( + prob, alg, args...; kwargs... + ) + dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, prob.p) + return SciMLBase.build_solution( + prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original + ) +end + +function SciMLBase.__init( + prob::DualAbstractNonlinearProblem, alg::GeneralizedDFSane, args...; kwargs... +) + p = nodual_value(prob.p) + newprob = SciMLBase.remake(prob; u0 = nodual_value(prob.u0), p) + cache = init(newprob, alg, args...; kwargs...) + return NonlinearSolveForwardDiffCache( + cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p) + ) +end + +end diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index c6fcc1f12..a1b759011 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -14,7 +14,7 @@ using LineSearch: BackTracking using NonlinearSolveBase: NonlinearSolveBase, InternalAPI, AbstractNonlinearSolveAlgorithm, AbstractNonlinearSolveCache, Utils, L2_NORM, enable_timer_outputs, disable_timer_outputs, - NonlinearSolvePolyAlgorithm + NonlinearSolvePolyAlgorithm, pickchunksize using Preferences: set_preferences! using SciMLBase: SciMLBase, NLStats, ReturnCode, AbstractNonlinearProblem, @@ -53,14 +53,7 @@ include("extension_algs.jl") include("default.jl") -const ALL_SOLVER_TYPES = [ - Nothing, AbstractNonlinearSolveAlgorithm, - GeneralizedDFSane, GeneralizedFirstOrderAlgorithm, QuasiNewtonAlgorithm, - LeastSquaresOptimJL, FastLevenbergMarquardtJL, NLsolveJL, NLSolversJL, - SpeedMappingJL, FixedPointAccelerationJL, SIAMFANLEquationsJL, - CMINPACK, PETScSNES, - NonlinearSolvePolyAlgorithm -] +include("forward_diff.jl") @setup_workload begin nonlinear_functions = ( diff --git a/src/forward_diff.jl b/src/forward_diff.jl new file mode 100644 index 000000000..76fdf6f52 --- /dev/null +++ b/src/forward_diff.jl @@ -0,0 +1,44 @@ +const EXTENSION_SOLVER_TYPES = [ + LeastSquaresOptimJL, FastLevenbergMarquardtJL, NLsolveJL, NLSolversJL, + SpeedMappingJL, FixedPointAccelerationJL, SIAMFANLEquationsJL, + CMINPACK, PETScSNES +] + +const DualNonlinearProblem = NonlinearProblem{ + <:Union{Number, <:AbstractArray}, iip, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} +} where {iip, T, V, P} +const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{ + <:Union{Number, <:AbstractArray}, iip, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} +} where {iip, T, V, P} +const DualAbstractNonlinearProblem = Union{ + DualNonlinearProblem, DualNonlinearLeastSquaresProblem +} + +for algType in EXTENSION_SOLVER_TYPES + @eval function SciMLBase.__init( + prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs... + ) + p = NonlinearSolveBase.nodual_value(prob.p) + newprob = SciMLBase.remake(prob; u0 = NonlinearSolveBase.nodual_value(prob.u0), p) + cache = init(newprob, alg, args...; kwargs...) + return NonlinearSolveForwardDiffCache( + cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p) + ) + end +end + +for algType in EXTENSION_SOLVER_TYPES + @eval function SciMLBase.__solve( + prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs... + ) + sol, partials = NonlinearSolveBase.nonlinearsolve_forwarddiff_solve( + prob, alg, args...; kwargs... + ) + dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, prob.p) + return SciMLBase.build_solution( + prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original + ) + end +end