Skip to content

Commit

Permalink
Fix bug in FD cache handling for grid variables
Browse files Browse the repository at this point in the history
  • Loading branch information
bgroenks96 committed Dec 13, 2024
1 parent 7958014 commit 4b65f40
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 10 deletions.
1 change: 1 addition & 0 deletions examples/heat_simple_autodiff_grad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using CryoGrid
# Set up forcings and boundary conditions similarly to other examples:
forcings = loadforcings(CryoGrid.Forcings.Samoylov_ERA_obs_fitted_1979_2014_spinup_extended_2044);
soilprofile, tempprofile = CryoGrid.SamoylovDefault
soilprofile = SoilProfile(0.0u"m" => SimpleSoil())
grid = CryoGrid.DefaultGrid_5cm
initT = initializer(:T, tempprofile)
tile = CryoGrid.SoilHeatTile(
Expand Down
6 changes: 3 additions & 3 deletions src/Numerics/caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ Base.show(io::IO, mime::MIME{Symbol("text/plain")}, cache::DiffCache) = show(io,
_retrieve(cache_var::AbstractArray{T}, ::AbstractArray{T}) where {T} = cache_var
_retrieve(cache_var::AbstractArray{T}, u::AbstractArray{U}) where {T,U} = copyto!(similar(u, length(cache_var)), cache_var)
retrieve(dc::DiffCache) = dc.cache.du
retrieve(dc::DiffCache, u::AbstractArray{T}) where {T<:ForwardDiff.Dual} = Prealloc.get_tmp(dc.cache, u)
retrieve(dc::DiffCache, u::AbstractArray{T}) where {T<:ForwardDiff.Dual} = copyto!(Prealloc.get_tmp(dc.cache, u), dc.cache.du)
retrieve(dc::DiffCache, u::AbstractArray{T}) where {T} = _retrieve(dc.cache.du, u)
retrieve(dc::DiffCache, u::AbstractArray{T}, t) where {T} = retrieve(dc, u)
# these cover cases for Rosenbrock solvers where only t has differentiable type
retrieve(dc::DiffCache, u::AbstractArray, t::T) where {T<:ForwardDiff.Dual} = Prealloc.get_tmp(dc.cache, t)
retrieve(dc::DiffCache, u::AbstractArray{T}, t::T) where {T<:ForwardDiff.Dual} = Prealloc.get_tmp(dc.cache, u)
retrieve(dc::DiffCache, u::AbstractArray, t::T) where {T<:ForwardDiff.Dual} = copyto!(Prealloc.get_tmp(dc.cache, t), dc.cache.du)
retrieve(dc::DiffCache, u::AbstractArray{T}, t::T) where {T<:ForwardDiff.Dual} = copyto!(Prealloc.get_tmp(dc.cache, u), dc.cache.du)

"""
ArrayCache{T,TA} <: StateVarCache
Expand Down
10 changes: 5 additions & 5 deletions src/Numerics/grid.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,12 @@ function updategrid!(grid::Grid{Edges}, z0, thick::AbstractVector)
return grid
end

function currentgrid(statevars::NamedTuple, initialgrid::Grid, u, t)
function currentgrid(state::NamedTuple, initialgrid::Grid, u, t)
# retrieve grid data from StateVars
midpoints = retrieve(statevars.midpoints, u, t)
edges = retrieve(statevars.edges, u, t)
cellthick = retrieve(statevars.cellthick, u, t)
celldist = retrieve(statevars.celldist, u, t)
midpoints = retrieve(state.midpoints, u, t)
edges = retrieve(state.edges, u, t)
cellthick = retrieve(state.cellthick, u, t)
celldist = retrieve(state.celldist, u, t)
return Grid(Edges, (edges=edges, cells=midpoints), (edges=cellthick, cells=celldist), initialgrid.geometry, initialgrid.bounds)
end

Expand Down
4 changes: 2 additions & 2 deletions src/Tiles/tile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -313,11 +313,11 @@ Returns a tuple of all variables defined in the tile.
CryoGrid.variables(tile::Tile) = Tuple(unique(Flatten.flatten(tile.state.vars, Flatten.flattenable, Var)))

"""
parameters(tile::Tile; kwargs...)
parameters(tile::Tile; include_all=false, kwargs...)
Extracts all parameters from `tile`.
"""
parameters(tile::Tile; kwargs...) = CryoGridParams(tile; kwargs...)
parameters(tile::Tile; include_all=true, kwargs...) = CryoGridParams(include_all ? tile : stripparams(FixedParam, tile); kwargs...)

"""
withaxes(u::AbstractArray, ::Tile)
Expand Down

0 comments on commit 4b65f40

Please sign in to comment.