Skip to content

Commit

Permalink
feat: EnsembleStatistics
Browse files Browse the repository at this point in the history
  • Loading branch information
oameye committed Jan 1, 2025
1 parent bc2572f commit 027a46e
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 47 deletions.
1 change: 1 addition & 0 deletions src/CriticalTransitions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ include("extension_functions.jl")
include("utils.jl")
include("sde_utils.jl")

include("trajectories/TransitionEnsemble.jl")
include("trajectories/simulation.jl")
include("trajectories/transition.jl")

Expand Down
55 changes: 55 additions & 0 deletions src/trajectories/TransitionEnsemble.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
struct TransitionStatistics{T}
success_rate::T
residence_time::T
transition_time::T
rareness::T

function TransitionStatistics(sim, success_rate)
mean_res_time = mean([sol.t[1] for sol in sim])
mean_trans_time = mean([(sol.t[end] - sol.t[1]) for sol in sim])

return new{typeof(mean_res_time)}(
success_rate, mean_res_time, mean_trans_time, mean_res_time / mean_trans_time
)
end
end;


"""
$(TYPEDEF)
Ensemble of transition paths between two points in a state space.
# Fields
$(TYPEDFIELDS)
# Constructors
$(METHODLIST)
"""
struct TransitionEnsemble{SSS, T, ES}
paths::Vector{SSS}
times::Vector{Vector{T}}
stats::TransitionStatistics{T}
sciml_ensemble::ES

function TransitionEnsemble(sim, success_rate)
stats = TransitionStatistics(sim, success_rate)

samples = [StateSpaceSet(sol.u) for sol in sim]
times = [sol.t for sol in sim]

return new{eltype(samples),eltype(eltype(times)),typeof(sim)}(samples, times, stats, sim)
end
end;

function prettyprint(te::TransitionEnsemble)
ts = te.stats
return "Transition path ensemble of $(length(te.times)) samples
- sampling success rate: $(round(ts.success_rate, digits=3))
- mean residence time: $(round(ts.residence_time, digits=3))
- mean transition time: $(round(ts.transition_time, digits=3))
- rareness: $(round(ts.rareness, digits=1))"
end

Base.show(io::IO, te::TransitionEnsemble) = print(io, prettyprint(te))
44 changes: 3 additions & 41 deletions src/trajectories/transition.jl
Original file line number Diff line number Diff line change
@@ -1,34 +1,3 @@
"""
$(TYPEDEF)
Ensemble of transition paths between two points in a state space.
# Fields
$(TYPEDFIELDS)
# Constructors
$(METHODLIST)
"""
struct TransitionEnsemble{SSS,T,Tstat,ES}
paths::Vector{SSS}
times::Vector{T}
success_rate::Tstat
residence_time::Tstat
transition_time::Tstat
sciml_ensemble::ES
end;

function prettyprint(te::TransitionEnsemble)
return "Transition path ensemble of $(length(te.times)) samples
- sampling success rate: $(round(te.success_rate, digits=3))
- mean residence time: $(round(te.residence_time, digits=3))
- mean transition time: $(round(te.transition_time, digits=3))
- normalized transition rate: $(round(te.residence_time/te.transition_time, digits=1))"
end

Base.show(io::IO, te::TransitionEnsemble) = print(io, prettyprint(te))

"""
$(TYPEDSIGNATURES)
Expand Down Expand Up @@ -159,23 +128,16 @@ function transitions(
end
return (sol, rerun)
end

seed = sys.integ.sol.prob.seed
function prob_func(prob, i, repeat)
return remake(prob; seed=rand(Random.MersenneTwister(seed + i + repeat), UInt32))
end

ensemble = EnsembleProblem(prob; output_func=output_func, prob_func=prob_func)
sim = solve(
ensemble, solver(sys), EnsembleAlg; callback=cb_ball, trajectories=N, kwargs...
)

success_rate = success / tries
mean_res_time = mean([sol.t[1] for sol in sim])
mean_trans_time = mean([(sol.t[end] - sol.t[1]) for sol in sim])

samples = [StateSpaceSet(sol.u) for sol in sim]
times = [sol.t for sol in sim]

return TransitionEnsemble(
samples, times, success_rate, mean_res_time, mean_trans_time, sim
)
return TransitionEnsemble(sim, success / tries)
end;
13 changes: 7 additions & 6 deletions test/trajectories/transition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@
@test norm(tr[end, :] - fp2) < 0.1

ensemble = transitions(sys, fp1, fp2, 10)
@test isapprox(ensemble.success_rate, 0.833; atol=1e-2) ||
isapprox(ensemble.success_rate, 0.909; atol=1e-2)
@test isapprox(ensemble.transition_time, 5.213; atol=1e-2) ||
isapprox(ensemble.transition_time, 5.6512; atol=1e-2)
stats = ensemble.stats
@test isapprox(stats.success_rate, 0.833; atol=1e-2) ||
isapprox(stats.success_rate, 0.909; atol=1e-2)
@test isapprox(stats.transition_time, 5.213; atol=1e-2) ||
isapprox(stats.transition_time, 5.6512; atol=1e-2)
# SEED is different on github
# SEED doesn;t work on github
@test length(ensemble.times) == 10
@test isapprox(ensemble.residence_time, 346.5424; atol=1e-2) ||
isapprox(ensemble.residence_time, 177.70; atol=1e-2)
@test isapprox(stats.residence_time, 346.5424; atol=1e-2) ||
isapprox(stats.residence_time, 177.70; atol=1e-2)
end

0 comments on commit 027a46e

Please sign in to comment.