Skip to content

Commit

Permalink
fixing solver instantiation logic for kinematic 1D env + some code cl…
Browse files Browse the repository at this point in the history
…eanups
  • Loading branch information
slayoo committed Aug 5, 2024
1 parent 6d5cc1f commit 0f910c0
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 36 deletions.
12 changes: 4 additions & 8 deletions PySDM/dynamics/eulerian_advection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,8 @@ def register(self, builder):
self.particulator = builder.particulator

def __call__(self):
self.particulator.environment.get_predicted(
"water_vapour_mixing_ratio"
).download(
self.particulator.environment.get_water_vapour_mixing_ratio(), reshape=True
)
self.particulator.environment.get_predicted("thd").download(
self.particulator.environment.get_thd(), reshape=True
)
for field in ("water_vapour_mixing_ratio", "thd"):
self.particulator.environment.get_predicted(field).download(
getattr(self.particulator.environment, f"get_{field}")(), reshape=True
)
self.solvers()
14 changes: 7 additions & 7 deletions PySDM/environments/impl/moist.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,21 +44,21 @@ def _allocate(self, variables):
result[var] = self.particulator.Storage.empty((self.mesh.n_cell,), float)
return result

def __getitem__(self, index):
"""returns a Storage representing the variable (field) at a given index or
otherwise a NaN-filled Storage if the index is not found (in order to simplify
def __getitem__(self, key: str):
"""returns a Storage representing the variable (field) at a given key or
otherwise a NaN-filled Storage if the key is not found (in order to simplify
generic code which uses optional variables, e.g. air viscosity, etc.)"""
if index in self._values["current"]:
return self._values["current"][index]
if key in self._values["current"]:
return self._values["current"][key]
return self._nan_field

def get_predicted(self, index):
def get_predicted(self, key: str):
if self._values["predicted"] is None:
raise AssertionError(
"It seems the AmbientThermodynamics dynamic was not added"
" when building particulator"
)
return self._values["predicted"][index]
return self._values["predicted"][key]

def _recalculate_temperature_pressure_relative_humidity(self, target):
self.particulator.backend.temperature_pressure_rh(
Expand Down
2 changes: 1 addition & 1 deletion PySDM/environments/kinematic_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def register(self, builder):
self._tmp["rhod"] = rhod

def get_water_vapour_mixing_ratio(self) -> np.ndarray:
return self.particulator.dynamics["EulerianAdvection"].solvers.advectee.get()
return self.particulator.dynamics["EulerianAdvection"].solvers.advectee

Check warning on line 29 in PySDM/environments/kinematic_1d.py

View check run for this annotation

Codecov / codecov/patch

PySDM/environments/kinematic_1d.py#L29

Added line #L29 was not covered by tests

def get_thd(self) -> np.ndarray:
return self.thd0
Expand Down
49 changes: 33 additions & 16 deletions examples/PySDM_examples/Shipway_and_Hill_2012/mpdata_1d.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from functools import cached_property

import numpy as np

from PyMPDATA import Options, ScalarField, Solver, Stepper, VectorField
from PyMPDATA.boundary_conditions import Extrapolated

Expand All @@ -19,41 +22,55 @@ def __init__(
self.dt = dt
self.advector_of_t = advector_of_t

grid = (nz,)
options = Options(
self._grid = (nz,)
self._options = Options(
n_iters=mpdata_settings["n_iters"],
infinite_gauge=mpdata_settings["iga"],
nonoscillatory=mpdata_settings["fct"],
third_order_terms=mpdata_settings["tot"],
)
stepper = Stepper(options=options, grid=grid, non_unit_g_factor=True)
bcs = (Extrapolated(),)
zZ_scalar = arakawa_c.z_scalar_coord(grid) / nz
g_factor = ScalarField(
zZ_scalar = arakawa_c.z_scalar_coord(self._grid) / nz
self._g_factor = ScalarField(
data=g_factor_of_zZ(zZ_scalar),
halo=options.n_halo,
halo=self._options.n_halo,
boundary_conditions=bcs,
)
advector = VectorField(
self._advector = VectorField(
data=(np.full(nz + 1, advector_of_t(0)),),
halo=options.n_halo,
halo=self._options.n_halo,
boundary_conditions=bcs,
)
self.advectee = ScalarField(
self._advectee = ScalarField(
data=advectee_of_zZ_at_t0(zZ_scalar),
halo=options.n_halo,
halo=self._options.n_halo,
boundary_conditions=bcs,
)
self.solver = Solver(
stepper=stepper,
advectee=self.advectee,
advector=advector,
g_factor=g_factor,

@cached_property
def solver(self):
return Solver(
stepper=Stepper(
options=self._options, grid=self._grid, non_unit_g_factor=True
),
advectee=self._advectee,
advector=self._advector,
g_factor=self._g_factor,
)

@property
def solver_cached(self):
return "solver" in self.__dict__

@property
def advectee(self):
return (self.solver.advectee if self.solver_cached else self._advectee).get()

@property
def advector(self):
return self.solver.advector.get_component(0)
return (
self.solver.advector if self.solver_cached else self._advector
).get_component(0)

def update_advector_field(self):
self.__t += 0.5 * self.dt
Expand Down
9 changes: 5 additions & 4 deletions examples/PySDM_examples/Shipway_and_Hill_2012/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def zZ_to_z_above_reservoir(zZ):
z_above_reservoir = zZ * (settings.nz * settings.dz) + self.z0
return z_above_reservoir

self.mpdata = MPDATA_1D(
mpdata = MPDATA_1D(
nz=settings.nz,
dt=settings.dt,
mpdata_settings=settings.mpdata_settings,
Expand Down Expand Up @@ -79,7 +79,7 @@ def zZ_to_z_above_reservoir(zZ):
update_thd=settings.condensation_update_thd,
)
)
self.builder.add_dynamic(EulerianAdvection(self.mpdata))
self.builder.add_dynamic(EulerianAdvection(mpdata))

self.products = []
if settings.precip:
Expand Down Expand Up @@ -257,10 +257,11 @@ def run(self):

self.save(0)
for step in range(self.nt):
self.mpdata.update_advector_field()
mpdata = self.particulator.dynamics["EulerianAdvection"].solvers
mpdata.update_advector_field()
if "Displacement" in self.particulator.dynamics:
self.particulator.dynamics["Displacement"].upload_courant_field(
(self.mpdata.advector / self.g_factor_vec,)
(mpdata.advector / self.g_factor_vec,)
)
self.particulator.run(steps=1)
self.save(step + 1)
Expand Down

0 comments on commit 0f910c0

Please sign in to comment.