Skip to content

Commit

Permalink
Merge branch 'main' into 135-UPX3-instantiate-cfas-epidemia-for-flu-m…
Browse files Browse the repository at this point in the history
…odel-via-msr
  • Loading branch information
AFg6K7h4fhy2 committed Jul 29, 2024
2 parents e48ce28 + ff33e31 commit 8817f69
Show file tree
Hide file tree
Showing 44 changed files with 685 additions and 392 deletions.
22 changes: 13 additions & 9 deletions docs/source/tutorials/basic_renewal_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ from pyrenew.metaclass import (
)
import pyrenew.transformation as t
from numpyro.infer.reparam import LocScaleReparam
```
By default, XLA (which is used by JAX for compilation) considers all CPU cores as one device. Depending on your system's configuration, we recommend using numpyro's [set_host_device_count()](https://num.pyro.ai/en/stable/utilities.html#set-host-device-count) function to set the number of devices available for parallel computing. Here, we set the device count to be 2.
```{python}
# | label: set-device-count
numpyro.set_host_device_count(2)
```

Expand Down Expand Up @@ -118,12 +121,12 @@ To initialize these five components within the renewal modeling framework, we es
# | label: creating-elements
# (1) The generation interval (deterministic)
pmf_array = jnp.array([0.4, 0.3, 0.2, 0.1])
gen_int = DeterministicPMF(pmf_array, name="gen_int")
gen_int = DeterministicPMF(name="gen_int", value=pmf_array)
# (2) Initial infections (inferred with a prior)
I0 = InfectionInitializationProcess(
"I0_initialization",
DistributionalRV(dist=dist.LogNormal(2.5, 1), name="I0"),
DistributionalRV(name="I0", dist=dist.LogNormal(2.5, 1)),
InitializeInfectionsZeroPad(pmf_array.size),
t_unit=1,
)
Expand All @@ -144,12 +147,13 @@ class MyRt(RandomVariable):
base_rv=SimpleRandomWalkProcess(
name="log_rt",
step_rv=DistributionalRV(
dist.Normal(0, sd_rt),
"rw_step_rv",
name="rw_step_rv",
dist=dist.Normal(0, sd_rt),
reparam=LocScaleReparam(0),
),
init_rv=DistributionalRV(
dist.Normal(jnp.log(1), jnp.log(1.2)), "init_log_Rt_rv"
name="init_log_Rt_rv",
dist=dist.Normal(jnp.log(1), jnp.log(1.2)),
),
),
transforms=t.ExpTransform(),
Expand Down Expand Up @@ -220,11 +224,11 @@ import matplotlib.pyplot as plt
fig, axs = plt.subplots(1, 2)
# Rt plot
axs[0].plot(sim_data.Rt)
axs[0].plot(sim_data.Rt.value)
axs[0].set_ylabel("Rt")
# Infections plot
axs[1].plot(sim_data.observed_infections)
axs[1].plot(sim_data.observed_infections.value)
axs[1].set_ylabel("Infections")
fig.suptitle("Basic renewal model")
Expand All @@ -242,7 +246,7 @@ import jax
model1.run(
num_warmup=2000,
num_samples=1000,
data_observed_infections=sim_data.observed_infections,
data_observed_infections=sim_data.observed_infections.value,
rng_key=jax.random.PRNGKey(54),
mcmc_args=dict(progress_bar=False, num_chains=2),
)
Expand Down
36 changes: 22 additions & 14 deletions docs/source/tutorials/extending_pyrenew.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,17 @@ The following code-chunk defines the model components. Notice that for both the
```{python}
# | label: model-components
gen_int_array = jnp.array([0.25, 0.5, 0.15, 0.1])
gen_int = DeterministicPMF(gen_int_array, name="gen_int")
feedback_strength = DeterministicVariable(0.05, name="feedback_strength")
gen_int = DeterministicPMF(name="gen_int", value=gen_int_array)
feedback_strength = DeterministicVariable(name="feedback_strength", value=0.01)
I0 = InfectionInitializationProcess(
"I0_initialization",
DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"),
DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)),
InitializeInfectionsExponentialGrowth(
gen_int_array.size,
DeterministicVariable(0.5, name="rate"),
DeterministicVariable(name="rate", value=0.05),
),
t_unit=1,
)
Expand All @@ -64,8 +66,12 @@ rt = TransformedRandomVariable(
"Rt_rv",
base_rv=SimpleRandomWalkProcess(
name="log_rt",
step_rv=DistributionalRV(dist.Normal(0, 0.025), "rw_step_rv"),
init_rv=DistributionalRV(dist.Normal(0, 0.2), "init_log_Rt_rv"),
step_rv=DistributionalRV(
name="rw_step_rv", dist=dist.Normal(0, 0.025)
),
init_rv=DistributionalRV(
name="init_log_Rt_rv", dist=dist.Normal(0, 0.2)
),
),
transforms=t.ExpTransform(),
)
Expand Down Expand Up @@ -99,7 +105,7 @@ with numpyro.handlers.seed(rng_seed=223):
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
ax.plot(model0_samp.latent_infections)
ax.plot(model0_samp.latent_infections.value)
ax.set_xlabel("Time")
ax.set_ylabel("Infections")
plt.show()
Expand Down Expand Up @@ -156,7 +162,7 @@ The next step is to create the actual class. The bulk of its implementation lies
# | label: new-model-def
# | code-line-numbers: true
# Creating the class
from pyrenew.metaclass import RandomVariable
from pyrenew.metaclass import RandomVariable, SampledValue
from pyrenew.latent import compute_infections_from_rt_with_feedback
from pyrenew import arrayutils as au
from jax.typing import ArrayLike
Expand Down Expand Up @@ -204,12 +210,14 @@ class InfFeedback(RandomVariable):
**kwargs,
)
inf_feedback_strength = au.pad_x_to_match_y(
x=inf_feedback_strength, y=Rt, fill_value=inf_feedback_strength[0]
x=inf_feedback_strength.value,
y=Rt,
fill_value=inf_feedback_strength.value[0],
)
# Sampling inf feedback and adjusting the shape
inf_feedback_pmf, *_ = self.infection_feedback_pmf(**kwargs)
inf_fb_pmf_rev = jnp.flip(inf_feedback_pmf)
inf_fb_pmf_rev = jnp.flip(inf_feedback_pmf.value)
# Generating the infections with feedback
all_infections, Rt_adj = compute_infections_from_rt_with_feedback(
Expand All @@ -226,8 +234,8 @@ class InfFeedback(RandomVariable):
# Preparing the output
return InfFeedbackSample(
infections=all_infections,
rt=Rt_adj,
infections=SampledValue(all_infections),
rt=SampledValue(Rt_adj),
)
```

Expand Down Expand Up @@ -269,8 +277,8 @@ Comparing `model0` with `model1`, these two should match:
import matplotlib.pyplot as plt
fig, ax = plt.subplots(ncols=2)
ax[0].plot(model0_samp.latent_infections)
ax[1].plot(model1_samp.latent_infections)
ax[0].plot(model0_samp.latent_infections.value)
ax[1].plot(model1_samp.latent_infections.value)
ax[0].set_xlabel("Time (model 0)")
ax[1].set_xlabel("Time (model 1)")
ax[0].set_ylabel("Infections")
Expand Down
34 changes: 18 additions & 16 deletions docs/source/tutorials/hospital_admissions_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@ format: gfm
engine: jupyter
---

This document illustrates how a hospital admissions-only model can be fitted using data from the Pyrenew package, particularly the wastewater dataset. The CFA wastewater team created this dataset, which contains simulated data.

We begin by loading `numpyro` and configuring the device count to 2 to enable running MCMC chains in parallel. By default, XLA (which is used by JAX for compilation) considers all CPU cores as one device. Depending on your system's configuration, we recommend using numpyro's [set_host_device_count()](https://num.pyro.ai/en/stable/utilities.html#set-host-device-count) function to set the number of devices available for parallel computing.

```{python}
# | label: numpyro setup
# | echo: false
import numpyro
numpyro.set_host_device_count(2)
```

This document illustrates how a hospital admissions-only model can be fitted using data from the Pyrenew package, particularly the wastewater dataset. The CFA wastewater team created this dataset, which contains simulated data.

## Model definition

In this section, we provide the formal definition of the model. The hospital admissions model is a semi-mechanistic model that describes the number of observed hospital admissions as a function of a set of latent variables. Mainly, the observed number of hospital admissions is discretely distributed with location at the number of latent hospital admissions:
Expand Down Expand Up @@ -142,12 +142,11 @@ import jax.numpy as jnp
import numpyro.distributions as dist
inf_hosp_int = deterministic.DeterministicPMF(
inf_hosp_int, name="inf_hosp_int"
name="inf_hosp_int", value=inf_hosp_int
)
hosp_rate = metaclass.DistributionalRV(
dist=dist.LogNormal(jnp.log(0.05), jnp.log(1.1)),
name="IHR",
name="IHR", dist=dist.LogNormal(jnp.log(0.05), jnp.log(1.1))
)
latent_hosp = latent.HospitalAdmissions(
Expand All @@ -172,17 +171,17 @@ latent_inf = latent.Infections()
I0 = InfectionInitializationProcess(
"I0_initialization",
metaclass.DistributionalRV(
dist=dist.LogNormal(loc=jnp.log(100), scale=jnp.log(1.75)), name="I0"
name="I0", dist=dist.LogNormal(loc=jnp.log(100), scale=jnp.log(1.75))
),
InitializeInfectionsExponentialGrowth(
gen_int_array.size,
deterministic.DeterministicVariable(0.05, name="rate"),
deterministic.DeterministicVariable(name="rate", value=0.05),
),
t_unit=1,
)
# Generation interval and Rt
gen_int = deterministic.DeterministicPMF(gen_int, name="gen_int")
gen_int = deterministic.DeterministicPMF(name="gen_int", value=gen_int)
class MyRt(metaclass.RandomVariable):
Expand All @@ -200,10 +199,10 @@ class MyRt(metaclass.RandomVariable):
base_rv=process.SimpleRandomWalkProcess(
name="log_rt",
step_rv=metaclass.DistributionalRV(
dist.Normal(0, sd_rt), "rw_step_rv"
name="rw_step_rv", dist=dist.Normal(0, sd_rt.value)
),
init_rv=metaclass.DistributionalRV(
dist.Normal(0, 0.2), "init_log_Rt_rv"
name="init_log_Rt_rv", dist=dist.Normal(0, 0.2)
),
),
transforms=transformation.ExpTransform(),
Expand All @@ -213,7 +212,9 @@ class MyRt(metaclass.RandomVariable):
rtproc = MyRt(
metaclass.DistributionalRV(dist.HalfNormal(0.025), "Rt_random_walk_sd")
metaclass.DistributionalRV(
name="Rt_random_walk_sd", dist=dist.HalfNormal(0.025)
)
)
# The observation model
Expand All @@ -223,7 +224,8 @@ rtproc = MyRt(
nb_conc_rv = metaclass.TransformedRandomVariable(
"concentration",
metaclass.DistributionalRV(
dist.TruncatedNormal(loc=0, scale=1, low=0.01), "concentration_raw"
name="concentration_raw",
dist=dist.TruncatedNormal(loc=0, scale=1, low=0.01),
),
transformation.PowerTransform(-2),
)
Expand Down Expand Up @@ -270,11 +272,11 @@ import matplotlib.pyplot as plt
fig, axs = plt.subplots(1, 2)
# Rt plot
axs[0].plot(simulated_data.Rt)
axs[0].plot(simulated_data.Rt.value)
axs[0].set_ylabel("Simulated Rt")
# Admissions plot
axs[1].plot(simulated_data.observed_hosp_admissions, "-o")
axs[1].plot(simulated_data.observed_hosp_admissions.value, "-o")
axs[1].set_ylabel("Simulated Admissions")
fig.suptitle("Basic renewal model")
Expand Down
16 changes: 10 additions & 6 deletions docs/source/tutorials/periodic_effects.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ rt_proc = process.RtWeeklyDiffProcess(
name="rt_weekly_diff",
offset=0,
log_rt_prior=deterministic.DeterministicVariable(
jnp.array([0.1, 0.2]), name="log_rt_prior"
name="log_rt_prior", value=jnp.array([0.1, 0.2])
),
autoreg=deterministic.DeterministicVariable(
jnp.array([0.7]), name="autoreg"
name="autoreg", value=jnp.array([0.7])
),
periodic_diff_sd=deterministic.DeterministicVariable(
jnp.array([0.1]), name="periodic_diff_sd"
name="periodic_diff_sd", value=jnp.array([0.1])
),
)
```
Expand All @@ -46,7 +46,7 @@ with numpyro.handlers.seed(rng_seed=20):
# Plotting the Rt values
import matplotlib.pyplot as plt
plt.step(np.arange(len(sim_data.rt)), sim_data.rt, where="post")
plt.step(np.arange(len(sim_data.rt.value)), sim_data.rt.value, where="post")
plt.xlabel("Time")
plt.ylabel("Rt")
plt.title("Simulated Rt values")
Expand Down Expand Up @@ -76,7 +76,9 @@ mysimplex = dist.TransformedDistribution(
# Constructing the day of week effect
dayofweek = process.DayOfWeekEffect(
offset=0,
quantity_to_broadcast=metaclass.DistributionalRV(mysimplex, "simp"),
quantity_to_broadcast=metaclass.DistributionalRV(
name="simp", dist=mysimplex
),
t_start=0,
)
```
Expand All @@ -90,7 +92,9 @@ with numpyro.handlers.seed(rng_seed=20):
# Plotting the effect values
import matplotlib.pyplot as plt
plt.step(np.arange(len(sim_data.value)), sim_data.value, where="post")
plt.step(
np.arange(len(sim_data.value.value)), sim_data.value.value, where="post"
)
plt.xlabel("Time")
plt.ylabel("Effect size")
plt.title("Simulated Day of Week Effect values")
Expand Down
19 changes: 13 additions & 6 deletions docs/source/tutorials/time.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,16 @@ The fundamental time unit should represent a period of fixed (or approximately f

For many infectious disease renewal models of interest, the fundamental time unit will be days, and we will proceed with this tutorial treating days as our fundamental unit.

`pyrenew` deals with time having `RandomVariable`s carry information about (i) their own time unit expressed relative to the fundamental unit (`t_unit`) and (ii) the starting time, `t_start`, measured relative to `t = 0` in model time in fundamental time units.
`pyrenew` deals with time by having `RandomVariable`s carry information about

The tuple `(t_unit, t_start)` can encode different types of time series data. For example:
1. their own time unit expressed relative to the fundamental unit (`t_unit`) and
2. the starting time, `t_start`, measured relative to `t = 0` in model time in fundamental time units.

Return values from `RandomVariable.sample()` are `tuple`s or `namedtuple`s of `SampledValue` objects. `SampledValue` objects can have `t_start` and `t_unit` attributes.

By default, `SampledValue` objects carry the `t_start` and `t_unit` of the `RandomVariable` from which they are sampled. One might override this default to allow a single `RandomVariable.sample()` call to produce multiple `SampledValue`s with different time units, or with different start points relative to the `RandomVariable`'s own `t_start`.

The `t_unit, t_start` pair can encode different types of time series data. For example:

| Description | `t_unit` | `t_start` |
|:-----------------|----------------:|-----------------:|
Expand All @@ -31,14 +38,14 @@ The `PeriodicBroadcaster()` class provides a way of tiling and repeating data ac

The following section describes some preliminary design principles that may be included in future versions of `pyrenew`.

### Validation

With random variables possibly spanning different time scales, *e.g.*, weekly, daily, hourly, the metaclass `Model` should ensure random variables within the model share the same time unit.

### Array alignment

Using `t_unit` and `t_start`, random variables should be able to align input and output data. For example, in the case of the `RtInfectionsRenewalModel()`, the computed values of `Rt` and `infections` are padded left with `nan` values to account for the initialization process. Instead, we expect to either pre-process the padding leveraging the `t_start` information of the involved variables or simplify the process via a function call that aligns the arrays. A possible implementation could be a method `align()` that takes a list of random variables and aligns them based on the `t_unit` and `t_start` information, e.g.:

```python
Rt_aligned, infections_aligned = align([Rt, infections])
```

### Retrieving time information from sites

Future versions of `pyrenew` could include a way to retrieve the time information for sites, keyed by site name, from the model.
Loading

0 comments on commit 8817f69

Please sign in to comment.