Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Concrete return type of BMI.get_value_ptr #2018

Merged
merged 2 commits into from
Jan 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions core/src/bmi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,26 +33,31 @@ function BMI.update_until(model::Model, time::Float64)::Nothing
return nothing
end

function BMI.get_value_ptr(model::Model, name::AbstractString)::AbstractVector{Float64}
"""
BMI.get_value_ptr(model::Model, name::String)::Vector{Float64}

This uses a typeassert to ensure that the return type annotation doesn't create a copy.
"""
function BMI.get_value_ptr(model::Model, name::String)::Vector{Float64}
(; u, p) = model.integrator
if name == "basin.storage"
p.basin.current_properties.current_storage[parent(u)]
p.basin.current_properties.current_storage[parent(u)]::Vector{Float64}
elseif name == "basin.level"
p.basin.current_properties.current_level[parent(u)]
p.basin.current_properties.current_level[parent(u)]::Vector{Float64}
elseif name == "basin.infiltration"
p.basin.vertical_flux.infiltration
p.basin.vertical_flux.infiltration::Vector{Float64}
elseif name == "basin.drainage"
p.basin.vertical_flux.drainage
p.basin.vertical_flux.drainage::Vector{Float64}
elseif name == "basin.cumulative_infiltration"
u.infiltration
unsafe_array(u.infiltration)::Vector{Float64}
elseif name == "basin.cumulative_drainage"
p.basin.cumulative_drainage
p.basin.cumulative_drainage::Vector{Float64}
elseif name == "basin.subgrid_level"
p.subgrid.level
p.subgrid.level::Vector{Float64}
elseif name == "user_demand.demand"
vec(p.user_demand.demand)
vec(p.user_demand.demand)::Vector{Float64}
elseif name == "user_demand.cumulative_inflow"
u.user_demand_inflow
unsafe_array(u.user_demand_inflow)::Vector{Float64}
else
error("Unknown variable $name")
end
Expand Down
30 changes: 19 additions & 11 deletions core/src/parameter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -391,28 +391,36 @@ the length of each Vector is the number of Basins.
infiltration::Vector{ScalarConstantInterpolation} = ScalarConstantInterpolation[]
end

"""Current values of the vertical fluxes in a Basin, per node ID.
evetion marked this conversation as resolved.
Show resolved Hide resolved

Current forcing is stored as separate array for BMI access.
These are updated from BasinForcing at runtime.
"""
@kwdef struct VerticalFlux
precipitation::Vector{Float64}
potential_evaporation::Vector{Float64}
drainage::Vector{Float64}
infiltration::Vector{Float64}
end

VerticalFlux(n::Int) = VerticalFlux(zeros(n), zeros(n), zeros(n), zeros(n))

"""
Requirements:

* Must be positive: precipitation, evaporation, infiltration, drainage
* Index points to a Basin
* volume, area, level must all be positive and monotonic increasing.

Type parameter C indicates the content backing the StructVector, which can be a NamedTuple
Type parameter D indicates the content backing the StructVector, which can be a NamedTuple
of vectors or Arrow Tables, and is added to avoid type instabilities.

if autodiff
T = DiffCache{Vector{Float64}}
else
T = Vector{Float64}
end
"""
@kwdef struct Basin{V, CD, D} <: AbstractParameterNode
@kwdef struct Basin{CD, D} <: AbstractParameterNode
node_id::Vector{NodeID}
inflow_ids::Vector{Vector{NodeID}} = [NodeID[]]
outflow_ids::Vector{Vector{NodeID}} = [NodeID[]]
# Vertical fluxes
vertical_flux::V = zeros(length(node_id))
vertical_flux::VerticalFlux = VerticalFlux(length(node_id))
# Initial_storage
storage0::Vector{Float64} = zeros(length(node_id))
# Storage at previous saveat without storage0
Expand Down Expand Up @@ -946,11 +954,11 @@ const ModelGraph = MetaGraph{
Float64,
}

@kwdef mutable struct Parameters{C1, C3, C4, C6, C7, C8, C9, C10, C11}
@kwdef mutable struct Parameters{C3, C4, C6, C7, C8, C9, C10, C11}
const starttime::DateTime
const graph::ModelGraph
const allocation::Allocation
const basin::Basin{C1, C3, C4}
const basin::Basin{C3, C4}
const linear_resistance::LinearResistance
const manning_resistance::ManningResistance
const tabulated_rating_curve::TabulatedRatingCurve
Expand Down
12 changes: 0 additions & 12 deletions core/src/read.jl
Original file line number Diff line number Diff line change
Expand Up @@ -745,16 +745,6 @@ function Basin(db::DB, config::Config, graph::MetaGraph)::Basin
parsed_parameters.infiltration,
)

# Current forcing is stored as separate array for BMI access
# These are updated from the interpolation objects at runtime
n = length(node_id)
vertical_flux = ComponentVector(;
precipitation = zeros(n),
potential_evaporation = zeros(n),
drainage = zeros(n),
infiltration = zeros(n),
)

# Profiles
area, level = create_storage_tables(db, config)

Expand All @@ -776,7 +766,6 @@ function Basin(db::DB, config::Config, graph::MetaGraph)::Basin
node_id,
inflow_ids = [collect(inflow_ids(graph, id)) for id in node_id],
outflow_ids = [collect(outflow_ids(graph, id)) for id in node_id],
vertical_flux,
storage_to_level,
level_to_area,
forcing,
Expand All @@ -788,7 +777,6 @@ function Basin(db::DB, config::Config, graph::MetaGraph)::Basin
update_basin!(basin, 0.0)

storage0 = get_storages_from_levels(basin, state.level)
@assert length(storage0) == n "Basin / state length differs from number of Basins"
basin.storage0 .= storage0
basin.storage_prev .= storage0
basin.concentration_data.mass .*= storage0 # was initialized by concentration_state, resulting in mass
Expand Down
13 changes: 13 additions & 0 deletions core/src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1120,3 +1120,16 @@ end

source_edges_subnetwork(p::Parameters, subnetwork_id::Int32) =
keys(mean_input_flows_subnetwork(p, subnetwork_id))

"""
Wrap the data of a SubArray into a Vector.

This function is labeled unsafe because it will crash if pointer is not a valid memory
address to data of the requested length, and it will not prevent the input array A from
being freed.
"""
function unsafe_array(
A::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true},
)::Vector{Float64}
GC.@preserve A unsafe_wrap(Array, pointer(A), length(A))
end
12 changes: 12 additions & 0 deletions core/test/utils_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -399,3 +399,15 @@ end
)
end
end

@testitem "unsafe_array" begin
using ComponentArrays: ComponentVector
x = ComponentVector(; a = [1.0, 2.0, 3.0], b = [4.0, 5.0, 6.0])
y = Ribasim.unsafe_array(x.b)
@test x.b isa SubArray
@test y isa Vector{Float64}
@test y == x.b
# changing the input changes the output; no data copy is made
x.b[2] = 10.0
@test y[2] === 10.0
end
Loading