Skip to content

Commit

Permalink
Sorting things out (#2003)
Browse files Browse the repository at this point in the history
Fixes #1976.
Fixes most of #601.
  • Loading branch information
visr authored Jan 13, 2025
1 parent 1ad402c commit f074e0c
Show file tree
Hide file tree
Showing 27 changed files with 373 additions and 363 deletions.
34 changes: 0 additions & 34 deletions core/src/callback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,6 @@ function create_callbacks(
end
end

# Update TabulatedRatingCurve Q(h) relationships
tstops = get_tstops(tabulated_rating_curve.time.time, starttime)
tabulated_rating_curve_cb = PresetTimeCallback(
tstops,
update_tabulated_rating_curve!;
save_positions = (false, false),
)
push!(callbacks, tabulated_rating_curve_cb)

# If saveat is a vector which contains 0.0 this callback will still be called
# at t = 0.0 despite save_start = false
saveat = saveat isa Vector ? filter(x -> x != 0.0, saveat) : saveat
Expand Down Expand Up @@ -834,31 +825,6 @@ function update_allocation!(integrator)::Nothing
end
end

"Load updates from 'TabulatedRatingCurve / time' into the parameters"
function update_tabulated_rating_curve!(integrator)::Nothing
(; node_id, table, time) = integrator.p.tabulated_rating_curve
t = datetime_since(integrator.t, integrator.p.starttime)

# get groups of consecutive node_id for the current timestamp
rows = searchsorted(time.time, t)
timeblock = view(time, rows)

for group in IterTools.groupby(row -> row.node_id, timeblock)
# update the existing LinearInterpolation
id = first(group).node_id
level = [row.level for row in group]
flow_rate = [row.flow_rate for row in group]
i = searchsortedfirst(node_id, NodeID(NodeType.TabulatedRatingCurve, id, 0))
table[i] = LinearInterpolation(
flow_rate,
level;
extrapolate = true,
cache_parameters = true,
)
end
return nothing
end

function update_subgrid_level(model::Model)::Model
update_subgrid_level!(model.integrator)
return model
Expand Down
48 changes: 21 additions & 27 deletions core/src/parameter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,18 +95,18 @@ end
Base.Int32(id::NodeID) = id.value
Base.convert(::Type{Int32}, id::NodeID) = id.value
Base.broadcastable(id::NodeID) = Ref(id)
Base.:(==)(id_1::NodeID, id_2::NodeID) = id_1.type == id_2.type && id_1.value == id_2.value
Base.show(io::IO, id::NodeID) = print(io, id.type, " #", id.value)
config.snake_case(id::NodeID) = config.snake_case(id.type)
Base.to_index(id::NodeID) = Int(id.value)

function Base.isless(id_1::NodeID, id_2::NodeID)::Bool
if id_1.type != id_2.type
error("Cannot compare NodeIDs of different types")
end
return id_1.value < id_2.value
end
# Compare only by value for working with a mix of integers from tables and processed NodeIDs
Base.:(==)(id_1::NodeID, id_2::NodeID) = id_1.value == id_2.value
Base.:(==)(id_1::Integer, id_2::NodeID) = id_1 == id_2.value
Base.:(==)(id_1::NodeID, id_2::Integer) = id_1.value == id_2

Base.to_index(id::NodeID) = Int(id.value)
Base.isless(id_1::NodeID, id_2::NodeID)::Bool = id_1.value < id_2.value
Base.isless(id_1::Integer, id_2::NodeID)::Bool = id_1 < id_2.value
Base.isless(id_1::NodeID, id_2::Integer)::Bool = id_1.value < id_2

"LinearInterpolation from a Float64 to a Float64"
const ScalarInterpolation = LinearInterpolation{
Expand Down Expand Up @@ -275,7 +275,7 @@ end
Base.length(::EdgeMetadata) = 1

"""
The update of an parameter given by a value and a reference to the target
The update of a parameter given by a value and a reference to the target
location of the variable in memory
"""
struct ParameterUpdate{T}
Expand All @@ -289,13 +289,12 @@ function ParameterUpdate(name::Symbol, value::T)::ParameterUpdate{T} where {T}
end

"""
The parameter update associated with a certain control state
for discrete control
The parameter update associated with a certain control state for discrete control
"""
@kwdef struct ControlStateUpdate
@kwdef struct ControlStateUpdate{T <: AbstractInterpolation}
active::ParameterUpdate{Bool}
scalar_update::Vector{ParameterUpdate{Float64}} = []
itp_update::Vector{ParameterUpdate{ScalarInterpolation}} = []
itp_update::Vector{ParameterUpdate{T}} = ParameterUpdate{ScalarInterpolation}[]
end

"""
Expand Down Expand Up @@ -434,15 +433,10 @@ end
end

"""
struct TabulatedRatingCurve{C}
struct TabulatedRatingCurve
Rating curve from level to flow rate. The rating curve is a lookup table with linear
interpolation in between. Relation can be updated in time, which is done by moving data from
the `time` field into the `tables`, which is done in the `update_tabulated_rating_curve`
callback.
Type parameter C indicates the content backing the StructVector, which can be a NamedTuple
of Vectors or Arrow Primitives, and is added to avoid type instabilities.
interpolation in between. Relations can be updated in time.
node_id: node ID of the TabulatedRatingCurve node
inflow_edge: incoming flow edge metadata
Expand All @@ -451,18 +445,18 @@ outflow_edge: outgoing flow edge metadata
The ID of the source node is always the ID of the TabulatedRatingCurve node
active: whether this node is active and thus contributes flows
max_downstream_level: The downstream level above which the TabulatedRatingCurve flow goes to zero
table: The current Q(h) relationships
time: The time table used for updating the tables
interpolations: All Q(h) relationships for the nodes over time
current_interpolation_index: Per node 1 lookup from t to an index in `interpolations`
control_mapping: dictionary from (node_id, control_state) to Q(h) and/or active state
"""
@kwdef struct TabulatedRatingCurve{C} <: AbstractParameterNode
@kwdef struct TabulatedRatingCurve <: AbstractParameterNode
node_id::Vector{NodeID}
inflow_edge::Vector{EdgeMetadata}
outflow_edge::Vector{EdgeMetadata}
active::Vector{Bool}
max_downstream_level::Vector{Float64} = fill(Inf, length(node_id))
table::Vector{ScalarInterpolation}
time::StructVector{TabulatedRatingCurveTimeV1, C, Int}
interpolations::Vector{ScalarInterpolation}
current_interpolation_index::Vector{IndexLookup}
control_mapping::Dict{Tuple{NodeID, String}, ControlStateUpdate}
end

Expand Down Expand Up @@ -935,14 +929,14 @@ const ModelGraph = MetaGraph{
Float64,
}

@kwdef mutable struct Parameters{C1, C2, C3, C4, C5, C6, C7, C8, C9, C10, C11}
@kwdef mutable struct Parameters{C1, C2, C3, C4, C6, C7, C8, C9, C10, C11}
const starttime::DateTime
const graph::ModelGraph
const allocation::Allocation
const basin::Basin{C1, C2, C3, C4}
const linear_resistance::LinearResistance
const manning_resistance::ManningResistance
const tabulated_rating_curve::TabulatedRatingCurve{C5}
const tabulated_rating_curve::TabulatedRatingCurve
const level_boundary::LevelBoundary{C6}
const flow_boundary::FlowBoundary{C7}
const pump::Pump
Expand Down
143 changes: 89 additions & 54 deletions core/src/read.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,20 +145,14 @@ function parse_static_and_time(
)
end
elseif node_id in time_node_ids
# TODO replace (time, node_id) order by (node_id, time)
# this fits our access pattern better, so we can use views
idx = findall(==(node_id), time_node_id_vec)
time_subset = time[idx]

time_first_idx = searchsortedfirst(time_node_id_vec[idx], node_id)

time_first_idx = searchsortedfirst(time.node_id, node_id)
for parameter_name in parameter_names
# If the parameter is interpolatable, create an interpolation object
if parameter_name in time_interpolatables
val, is_valid = get_scalar_interpolation(
config.starttime,
t_end,
time_subset,
time,
node_id,
parameter_name;
default_value = hasproperty(defaults, parameter_name) ?
Expand All @@ -174,7 +168,7 @@ function parse_static_and_time(
val = true
else
# If the parameter is not interpolatable, get the instance in the first row
val = getfield(time_subset[time_first_idx], parameter_name)
val = getfield(time[time_first_idx], parameter_name)
end
end
getfield(out, parameter_name)[node_id.idx] = val
Expand Down Expand Up @@ -300,75 +294,98 @@ function TabulatedRatingCurve(
static_node_ids, time_node_ids, node_ids, valid =
static_and_time_node_ids(db, static, time, NodeType.TabulatedRatingCurve)

if !valid
error(
"Problems encountered when parsing TabulatedRatingcurve static and time node IDs.",
)
end
valid || error(
"Problems encountered when parsing TabulatedRatingcurve static and time node IDs.",
)

interpolations = ScalarInterpolation[]
current_interpolation_index = IndexLookup[]
interpolation_index = 0
control_mapping = Dict{Tuple{NodeID, String}, ControlStateUpdate}()
active = Bool[]
max_downstream_level = Float64[]
errors = false

local is_active, interpolation, max_level

qh_iterator = IterTools.groupby(row -> (row.node_id, row.time), time)
state = nothing # initial iterator state

for node_id in node_ids
if node_id in static_node_ids
# Loop over all static rating curves (groups) with this node_id.
# If it has a control_state add it to control_mapping.
# The last rating curve forms the initial condition and activity.
source = "static"
# For the static case the interpolation index does not depend on time,
# but it can be changed by DiscreteControl. For simplicity we do create an
# index lookup that doesn't change with time just like the dynamic case.
# DiscreteControl will then change this lookup object.
rows = searchsorted(
NodeID.(NodeType.TabulatedRatingCurve, static.node_id, node_id.idx),
node_id,
)
static_id = view(static, rows)
local is_active, interpolation
# coalesce control_state to nothing to avoid boolean groupby logic on missing
for group in
for qh_group in
IterTools.groupby(row -> coalesce(row.control_state, nothing), static_id)
control_state = first(group).control_state
is_active = coalesce(first(group).active, true)
max_level = coalesce(first(group).max_downstream_level, Inf)
table = StructVector(group)
rowrange =
findlastgroup(node_id, NodeID.(node_id.type, table.node_id, Ref(0)))
if !valid_tabulated_rating_curve(node_id, table, rowrange)
errors = true
end
interpolation = try
qh_interpolation(table, rowrange)
catch
LinearInterpolation(Float64[], Float64[])
end
interpolation_index += 1
first_row = first(qh_group)
control_state = first_row.control_state
is_active = coalesce(first_row.active, true)
max_level = coalesce(first_row.max_downstream_level, Inf)
qh_table = StructVector(qh_group)
interpolation =
qh_interpolation(node_id, qh_table.level, qh_table.flow_rate)
if !ismissing(control_state)
control_mapping[(
NodeID(NodeType.TabulatedRatingCurve, node_id, node_id.idx),
control_state,
)] = ControlStateUpdate(
# let control swap out the static lookup object
index_lookup = static_lookup(interpolation_index)
control_mapping[(node_id, control_state)] = ControlStateUpdate(
ParameterUpdate(:active, is_active),
ParameterUpdate{Float64}[],
[ParameterUpdate(:table, interpolation)],
[ParameterUpdate(:current_interpolation_index, index_lookup)],
)
end
push!(interpolations, interpolation)
end
push!(interpolations, interpolation)
push_lookup!(current_interpolation_index, interpolation_index)
push!(active, is_active)
push!(max_downstream_level, max_level)
elseif node_id in time_node_ids
source = "time"
# get the timestamp that applies to the model starttime
idx_starttime = searchsortedlast(time.time, config.starttime)
pre_table = view(time, 1:idx_starttime)
rowrange =
findlastgroup(node_id, NodeID.(node_id.type, pre_table.node_id, Ref(0)))

if !valid_tabulated_rating_curve(node_id, pre_table, rowrange)
errors = true
lookup_time = Float64[]
lookup_index = Int[]
while true
val_state = iterate(qh_iterator, state)
if val_state === nothing
# end of table
break
end
qh_group, new_state = val_state

first_row = first(qh_group)
group_node_id = first_row.node_id
# max_level just document that it doesn't work and use the first or last
max_level = coalesce(first_row.max_downstream_level, Inf)
t = seconds_since(first_row.time, config.starttime)

qh_table = StructVector(qh_group)
if group_node_id == node_id
# continue iterator
state = new_state

interpolation =
qh_interpolation(node_id, qh_table.level, qh_table.flow_rate)

interpolation_index += 1
push!(interpolations, interpolation)
push!(lookup_index, interpolation_index)
push!(lookup_time, t)
else
# end of group, new timeseries for different node has started,
# don't accept the new state
break
end
end
interpolation = qh_interpolation(pre_table, rowrange)
max_level = coalesce(pre_table.max_downstream_level[rowrange][begin], Inf)
push!(interpolations, interpolation)
push_lookup!(current_interpolation_index, lookup_index, lookup_time)
push!(active, true)
push!(max_downstream_level, max_level)
else
Expand All @@ -377,17 +394,16 @@ function TabulatedRatingCurve(
end
end

if errors
error("Errors occurred when parsing TabulatedRatingCurve data.")
end
errors && error("Errors occurred when parsing TabulatedRatingCurve data.")

return TabulatedRatingCurve(;
node_id = node_ids,
inflow_edge = inflow_edge.(Ref(graph), node_ids),
outflow_edge = outflow_edge.(Ref(graph), node_ids),
active,
max_downstream_level,
table = interpolations,
time,
interpolations,
current_interpolation_index,
control_mapping,
)
end
Expand Down Expand Up @@ -1238,6 +1254,7 @@ function FlowDemand(db::DB, config::Config)::FlowDemand
)
end

"Create and push a ConstantInterpolation to the current_interpolation_index."
function push_lookup!(
current_interpolation_index::Vector{IndexLookup},
lookup_index::Vector{Int},
Expand All @@ -1252,6 +1269,24 @@ function push_lookup!(
push!(current_interpolation_index, index_lookup)
end

"Create and push a static ConstantInterpolation to the current_interpolation_index."
function push_lookup!(current_interpolation_index::Vector{IndexLookup}, lookup_index::Int)
index_lookup = static_lookup(lookup_index)
push!(current_interpolation_index, index_lookup)
end

"Create an interpolation object that always returns `lookup_index`."
function static_lookup(lookup_index::Int)::IndexLookup
# TODO if https://github.com/SciML/DataInterpolations.jl/issues/373 is fixed,
# make these size 1 vectors, and remove `unique` from `valid_tabulated_curve_level`
return ConstantInterpolation(
[lookup_index, lookup_index],
[0.0, 0.0];
extrapolate = true,
cache_parameters = true,
)
end

function Subgrid(db::DB, config::Config, basin::Basin)::Subgrid
time = load_structvector(db, config, BasinSubgridTimeV1)
static = load_structvector(db, config, BasinSubgridV1)
Expand Down
Loading

0 comments on commit f074e0c

Please sign in to comment.