From 755fec85bc0c5401884352359ad36c65c6146db1 Mon Sep 17 00:00:00 2001 From: huiyuxie Date: Thu, 30 Mar 2023 17:12:18 -0400 Subject: [PATCH] more tests --- lib/SimpleNonlinearSolve/src/alefeld.jl | 8 +++++-- lib/SimpleNonlinearSolve/test/basictests.jl | 25 +++++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/lib/SimpleNonlinearSolve/src/alefeld.jl b/lib/SimpleNonlinearSolve/src/alefeld.jl index 47b01d9ce..f468add9f 100644 --- a/lib/SimpleNonlinearSolve/src/alefeld.jl +++ b/lib/SimpleNonlinearSolve/src/alefeld.jl @@ -18,12 +18,16 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, c = a - (b - a) / (f(b) - f(a)) * f(a) fc = f(c) - if iszero(fc) + (a == c || b == c) && + return SciMLBase.build_solution(prob, alg, c, fc; + retcode = ReturnCode.FloatingPointLimit, + left = a, + right = b) + iszero(fc) && return SciMLBase.build_solution(prob, alg, c, fc; retcode = ReturnCode.Success, left = a, right = b) - end a, b, d = _bracket(f, a, b, c) e = zero(a) # Set e as 0 before iteration to avoid a non-value f(e) diff --git a/lib/SimpleNonlinearSolve/test/basictests.jl b/lib/SimpleNonlinearSolve/test/basictests.jl index 5b76ce15f..310ca0638 100644 --- a/lib/SimpleNonlinearSolve/test/basictests.jl +++ b/lib/SimpleNonlinearSolve/test/basictests.jl @@ -222,6 +222,18 @@ for p in 1.1:0.1:100.0 @test ForwardDiff.derivative(g, p) ≈ 1 / (2 * sqrt(p)) end +f, tspan = (u, p) -> p[1] * u * u - p[2], (1.0, 100.0) +t = (p) -> [sqrt(p[2] / p[1])] +p = [0.9, 50.0] +g = function (p) + probN = IntervalNonlinearProblem{false}(f, tspan, p) + sol = solve(probN, Alefeld()) + return [sol.u] +end + +@test g(p) ≈ [sqrt(p[2] / p[1])] +@test ForwardDiff.jacobian(g, p) ≈ ForwardDiff.jacobian(t, p) + f, tspan = (u, p) -> p[1] * u * u - p[2], (1.0, 100.0) t = (p) -> [sqrt(p[2] / p[1])] p = [0.9, 50.0] @@ -288,6 +300,7 @@ probB = IntervalNonlinearProblem(f, tspan) sol = solve(probB, Falsi()) @test sol.left ≈ sqrt(2.0) +# Bisection sol = solve(probB, Bisection()) @test sol.left ≈ sqrt(2.0) @@ -315,6 +328,18 @@ probB = IntervalNonlinearProblem(f, tspan) sol = solve(probB, Brent()) @test sol.left ≈ sqrt(2.0) +# Alefeld +sol = solve(probB, Alefeld()) +@test sol.u ≈ sqrt(2.0) +tspan = (sqrt(2.0), 10.0) +probB = IntervalNonlinearProblem(f, tspan) +sol = solve(probB, Alefeld()) +@test sol.u ≈ sqrt(2.0) +tspan = (0.0, sqrt(2.0)) +probB = IntervalNonlinearProblem(f, tspan) +sol = solve(probB, Alefeld()) +@test sol.u ≈ sqrt(2.0) + # Garuntee Tests for Bisection f = function (u, p) if u < 2.0