From 5b2369a5b19541bde48397f1d8ffe14815aeb12f Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 7 Feb 2025 10:24:32 +0100 Subject: [PATCH 1/2] Improve `cumulative_integral` --- src/interpolation_utils.jl | 14 ++++++-------- test/integral_tests.jl | 12 ++++++++++++ 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/src/interpolation_utils.jl b/src/interpolation_utils.jl index 37ca265c..062187a0 100644 --- a/src/interpolation_utils.jl +++ b/src/interpolation_utils.jl @@ -189,14 +189,12 @@ function get_idx(A::AbstractInterpolation, t, iguess::Union{<:Integer, Guesser}; end end -function cumulative_integral(A, cache_parameters) - if cache_parameters && hasmethod(_integral, Tuple{typeof(A), Number, Number, Number}) - integral_values = _integral.( - Ref(A), 1:(length(A.t) - 1), A.t[1:(end - 1)], A.t[2:end]) - cumsum(integral_values) - else - promote_type(eltype(A.u), eltype(A.t))[] - end +cumulative_integral(::AbstractInterpolation, ::Bool) = nothing +function cumulative_integral(A::AbstractInterpolation{<:Number}, cache_parameters::Bool) + Base.require_one_based_indexing(A.u) + idxs = cache_parameters ? (1:(length(A.t) - 1)) : (1:0) + return cumsum(_integral(A, idx, t1, t2) + for (idx, t1, t2) in zip(idxs, @view(A.t[begin:(end - 1)]), @view(A.t[(begin + 1):end]))) end function get_parameters(A::LinearInterpolation, idx) diff --git a/test/integral_tests.jl b/test/integral_tests.jl index ab1ca892..f077e0b8 100644 --- a/test/integral_tests.jl +++ b/test/integral_tests.jl @@ -213,3 +213,15 @@ end @test_throws DataInterpolations.IntegralNotFoundError integral(A, 1.0, 100.0) @test_throws DataInterpolations.IntegralNotFoundError integral(A, 50.0) end + +@testset "cumulative_integral" begin + A = ConstantInterpolation(["A", "B", "C"], [0.0, 0.25, 0.75]) + for cache_parameter in (true, false) + @test @inferred(DataInterpolations.cumulative_integral(A, cache_parameter)) === + nothing + end + + A = ConstantInterpolation([3.1, 2.5, 4.7], [0.0, 0.25, 0.75]) + @test @inferred(DataInterpolations.cumulative_integral(A, false)) == Float64[] + @test @inferred(DataInterpolations.cumulative_integral(A, true)) == [0.775, 2.025] +end From 66f8b1767bce43e10362aaea01e80b22398b9f24 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 7 Feb 2025 12:12:10 +0100 Subject: [PATCH 2/2] Add test for #385 --- test/integral_tests.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/integral_tests.jl b/test/integral_tests.jl index f077e0b8..f9a99d48 100644 --- a/test/integral_tests.jl +++ b/test/integral_tests.jl @@ -4,6 +4,7 @@ using DataInterpolations: integral using Optim, ForwardDiff using RegularizationTools using StableRNGs +using Unitful function test_integral(method; args = [], kwargs = [], name::String) func = method(args...; kwargs..., extrapolation_left = ExtrapolationType.Extension, @@ -214,6 +215,13 @@ end @test_throws DataInterpolations.IntegralNotFoundError integral(A, 50.0) end +# issue #385 +@testset "Integrals with unitful numbers" begin + u = rand(5)u"m" + A = ConstantInterpolation(u, (1:5)u"s") + @test @inferred(integral(A, 4u"s")) ≈ sum(u[1:3]) * u"s" +end + @testset "cumulative_integral" begin A = ConstantInterpolation(["A", "B", "C"], [0.0, 0.25, 0.75]) for cache_parameter in (true, false)