Skip to content

Commit

Permalink
Merge pull request #393 from SciML/dw/cumulative_integral
Browse files Browse the repository at this point in the history
Improve `cumulative_integral`
  • Loading branch information
ChrisRackauckas authored Feb 7, 2025
2 parents b983363 + 66f8b17 commit 5391a3c
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 8 deletions.
14 changes: 6 additions & 8 deletions src/interpolation_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,14 +189,12 @@ function get_idx(A::AbstractInterpolation, t, iguess::Union{<:Integer, Guesser};
end
end

function cumulative_integral(A, cache_parameters)
if cache_parameters && hasmethod(_integral, Tuple{typeof(A), Number, Number, Number})
integral_values = _integral.(
Ref(A), 1:(length(A.t) - 1), A.t[1:(end - 1)], A.t[2:end])
cumsum(integral_values)
else
promote_type(eltype(A.u), eltype(A.t))[]
end
cumulative_integral(::AbstractInterpolation, ::Bool) = nothing
function cumulative_integral(A::AbstractInterpolation{<:Number}, cache_parameters::Bool)
Base.require_one_based_indexing(A.u)
idxs = cache_parameters ? (1:(length(A.t) - 1)) : (1:0)
return cumsum(_integral(A, idx, t1, t2)
for (idx, t1, t2) in zip(idxs, @view(A.t[begin:(end - 1)]), @view(A.t[(begin + 1):end])))
end

function get_parameters(A::LinearInterpolation, idx)
Expand Down
20 changes: 20 additions & 0 deletions test/integral_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using DataInterpolations: integral
using Optim, ForwardDiff
using RegularizationTools
using StableRNGs
using Unitful

function test_integral(method; args = [], kwargs = [], name::String)
func = method(args...; kwargs..., extrapolation_left = ExtrapolationType.Extension,
Expand Down Expand Up @@ -213,3 +214,22 @@ end
@test_throws DataInterpolations.IntegralNotFoundError integral(A, 1.0, 100.0)
@test_throws DataInterpolations.IntegralNotFoundError integral(A, 50.0)
end

# issue #385
@testset "Integrals with unitful numbers" begin
u = rand(5)u"m"
A = ConstantInterpolation(u, (1:5)u"s")
@test @inferred(integral(A, 4u"s")) sum(u[1:3]) * u"s"
end

@testset "cumulative_integral" begin
A = ConstantInterpolation(["A", "B", "C"], [0.0, 0.25, 0.75])
for cache_parameter in (true, false)
@test @inferred(DataInterpolations.cumulative_integral(A, cache_parameter)) ===
nothing
end

A = ConstantInterpolation([3.1, 2.5, 4.7], [0.0, 0.25, 0.75])
@test @inferred(DataInterpolations.cumulative_integral(A, false)) == Float64[]
@test @inferred(DataInterpolations.cumulative_integral(A, true)) == [0.775, 2.025]
end

0 comments on commit 5391a3c

Please sign in to comment.