From de9b5274a1ee8c6bcf5eae1288975e9ade1d9fd8 Mon Sep 17 00:00:00 2001 From: Maarten Pronk Date: Sat, 21 Dec 2024 08:55:14 +0100 Subject: [PATCH] Add generated, performant snake_case for NodeID. (#1982) Fixes #1981 This precalculates all snakecases for NodeType and compiles it. @visr could you benchmark this? --------- Co-authored-by: Martijn Visser --- core/src/callback.jl | 9 +++++---- core/src/concentration.jl | 10 +++++----- core/src/parameter.jl | 12 ++++++++++++ core/src/util.jl | 6 +++--- 4 files changed, 25 insertions(+), 12 deletions(-) diff --git a/core/src/callback.jl b/core/src/callback.jl index d332d7a00..90a492e83 100644 --- a/core/src/callback.jl +++ b/core/src/callback.jl @@ -190,15 +190,15 @@ function update_concentrations!(u, t, integrator)::Nothing # of the basins after processing inflows only cumulative_in .= 0.0 - mass .+= concentration[1, :, :] .* vertical_flux.drainage * dt + @views mass .+= concentration[1, :, :] .* vertical_flux.drainage * dt basin.concentration_data.cumulative_in .= vertical_flux.drainage * dt # Precipitation depends on fixed area for node_id in basin.node_id fixed_area = basin_areas(basin, node_id.idx)[end] added_precipitation = fixed_area * vertical_flux.precipitation[node_id.idx] * dt - - mass[node_id.idx, :] .+= concentration[2, node_id.idx, :] .* added_precipitation + @views mass[node_id.idx, :] .+= + concentration[2, node_id.idx, :] .* added_precipitation cumulative_in[node_id.idx] += added_precipitation end @@ -212,7 +212,8 @@ function update_concentrations!(u, t, integrator)::Nothing if active outflow_id = edge[1].edge[2] volume = integral(flow_rate, tprev, t) - mass[outflow_id.idx, :] .+= flow_boundary.concentration[id.idx, :] .* volume + @views mass[outflow_id.idx, :] .+= + flow_boundary.concentration[id.idx, :] .* volume cumulative_in[outflow_id.idx] += volume end end diff --git a/core/src/concentration.jl b/core/src/concentration.jl index 12fb4dd2b..2ab3043a7 100644 --- a/core/src/concentration.jl +++ b/core/src/concentration.jl @@ -6,8 +6,8 @@ function mass_updates_user_demand!(integrator::DEIntegrator)::Nothing (; basin, user_demand) = integrator.p (; concentration_state, mass) = basin.concentration_data - for (inflow_edge, outflow_edge) in - zip(user_demand.inflow_edge, user_demand.outflow_edge) + @views for (inflow_edge, outflow_edge) in + zip(user_demand.inflow_edge, user_demand.outflow_edge) from_node = inflow_edge.edge[1] to_node = outflow_edge.edge[2] userdemand_idx = outflow_edge.edge[1].idx @@ -41,7 +41,7 @@ function mass_inflows_basin!(integrator::DEIntegrator)::Nothing for (inflow_edge, outflow_edge) in zip(state_inflow_edge, state_outflow_edge) from_node = inflow_edge.edge[1] to_node = outflow_edge.edge[2] - if from_node.type == NodeType.Basin + @views if from_node.type == NodeType.Basin flow = flow_update_on_edge(integrator, inflow_edge.edge) if flow < 0 cumulative_in[from_node.idx] -= flow @@ -67,7 +67,7 @@ function mass_inflows_basin!(integrator::DEIntegrator)::Nothing flow = flow_update_on_edge(integrator, outflow_edge.edge) if flow > 0 cumulative_in[to_node.idx] += flow - if from_node.type == NodeType.Basin + @views if from_node.type == NodeType.Basin mass[to_node.idx, :] .+= concentration_state[from_node.idx, :] .* flow elseif from_node.type == NodeType.LevelBoundary mass[to_node.idx, :] .+= @@ -95,7 +95,7 @@ function mass_outflows_basin!(integrator::DEIntegrator)::Nothing (; state_inflow_edge, state_outflow_edge, basin) = integrator.p (; mass, concentration_state) = basin.concentration_data - for (inflow_edge, outflow_edge) in zip(state_inflow_edge, state_outflow_edge) + @views for (inflow_edge, outflow_edge) in zip(state_inflow_edge, state_outflow_edge) from_node = inflow_edge.edge[1] to_node = outflow_edge.edge[2] if from_node.type == NodeType.Basin diff --git a/core/src/parameter.jl b/core/src/parameter.jl index 95c8fd3f0..7dc0910d2 100644 --- a/core/src/parameter.jl +++ b/core/src/parameter.jl @@ -21,6 +21,17 @@ const SolverStats = @NamedTuple{ 5 Drainage = 6 Precipitation = 7 Base.to_index(id::Substance.T) = Int(id) # used to index into concentration matrices +@generated function config.snake_case(nt::NodeType.T) + ex = quote end + for (sym, _) in EnumX.symbol_map(NodeType.T) + sc = QuoteNode(config.snake_case(sym)) + t = NodeType.T(sym) + push!(ex.args, :(nt === $t && return $sc)) + end + push!(ex.args, :(return :nothing)) # type stability + ex +end + # Support creating a NodeType enum instance from a symbol or string function NodeType.T(s::Symbol)::NodeType.T symbol_map = EnumX.symbol_map(NodeType.T) @@ -86,6 +97,7 @@ Base.convert(::Type{Int32}, id::NodeID) = id.value Base.broadcastable(id::NodeID) = Ref(id) Base.:(==)(id_1::NodeID, id_2::NodeID) = id_1.type == id_2.type && id_1.value == id_2.value Base.show(io::IO, id::NodeID) = print(io, id.type, " #", id.value) +config.snake_case(id::NodeID) = config.snake_case(id.type) function Base.isless(id_1::NodeID, id_2::NodeID)::Bool if id_1.type != id_2.type diff --git a/core/src/util.jl b/core/src/util.jl index e9c2a2bcc..2ae1dbeff 100644 --- a/core/src/util.jl +++ b/core/src/util.jl @@ -657,7 +657,7 @@ function get_variable_ref( PreallocationRef(cache(1), flow_idx; from_du = true) end else - node = getfield(p, snake_case(Symbol(node_id.type))) + node = getfield(p, snake_case(node_id)) PreallocationRef(node.flow_rate, node_id.idx) end else @@ -814,7 +814,7 @@ function collect_control_mappings!(p)::Nothing for node_type in instances(NodeType.T) node_type == NodeType.Terminal && continue - node = getfield(p, Symbol(snake_case(string(node_type)))) + node = getfield(p, snake_case(node_type)) if hasfield(typeof(node), :control_mapping) control_mappings[node_type] = node.control_mapping end @@ -1096,7 +1096,7 @@ function get_state_index( component_name = if id.type == NodeType.UserDemand inflow ? :user_demand_inflow : :user_demand_outflow else - snake_case(Symbol(id.type)) + snake_case(id) end for (comp, range) in pairs(NT) if comp == component_name