Skip to content

Commit

Permalink
First pass of refactoring. All tests pass.
Browse files Browse the repository at this point in the history
  • Loading branch information
gvegayon committed Apr 3, 2024
1 parent 5a06ab2 commit 82e20fe
Show file tree
Hide file tree
Showing 24 changed files with 196 additions and 380 deletions.
13 changes: 4 additions & 9 deletions model/docs/getting-started.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,7 @@ rt_proc = RtRandomWalkProcess()
latent_infections = Infections()

# (5) The observed infections process (with mean at the latent infections)
observed_infections = PoissonObservation(
rate_varname = 'latent',
counts_varname = 'observed_infections',
)
observed_infections = PoissonObservation()
```

With these five pieces, we can build the basic renewal model:
Expand Down Expand Up @@ -136,7 +133,7 @@ function of `RtInfectionsRenewalModel`:
``` python
np.random.seed(223)
with npro.handlers.seed(rng_seed = np.random.randint(1, 60)):
sim_data = model1.sample(constants = dict(n_timepoints=30))
sim_data = model1.sample(n_timepoints=30)

sim_data
```
Expand Down Expand Up @@ -187,13 +184,11 @@ To fit the model, we can use the `run()` method of the model
``` python
import jax

model_data = {'n_timepoints': len(sim_data[1])-1}

model1.run(
num_warmup=2000,
num_samples=1000,
random_variables=dict(observed_infections=sim_data.observed),
constants=model_data,
observed_infections=sim_data.observed,
n_timepoints = len(sim_data[1])-1,
rng_key=jax.random.PRNGKey(54),
mcmc_args=dict(progress_bar=False),
)
Expand Down
13 changes: 4 additions & 9 deletions model/docs/getting-started.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,7 @@ rt_proc = RtRandomWalkProcess()
latent_infections = Infections()
# (5) The observed infections process (with mean at the latent infections)
observed_infections = PoissonObservation(
rate_varname = 'latent',
counts_varname = 'observed_infections',
)
observed_infections = PoissonObservation()
```

With these five pieces, we can build the basic renewal model:
Expand Down Expand Up @@ -112,7 +109,7 @@ Using `numpyro`, we can simulate data using the `sample()` member function of `R
#| label: simulate
np.random.seed(223)
with npro.handlers.seed(rng_seed = np.random.randint(1, 60)):
sim_data = model1.sample(constants = dict(n_timepoints=30))
sim_data = model1.sample(n_timepoints=30)
sim_data
```
Expand Down Expand Up @@ -147,13 +144,11 @@ To fit the model, we can use the `run()` method of the model `RtInfectionsRenewa
#| label: model-fit
import jax
model_data = {'n_timepoints': len(sim_data[1])-1}
model1.run(
num_warmup=2000,
num_samples=1000,
random_variables=dict(observed_infections=sim_data.observed),
constants=model_data,
observed_infections=sim_data.observed,
n_timepoints = len(sim_data[1])-1,
rng_key=jax.random.PRNGKey(54),
mcmc_args=dict(progress_bar=False),
)
Expand Down
16 changes: 4 additions & 12 deletions model/docs/pyrenew_demo.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ import numpyro.distributions as dist
from pyrenew.process import SimpleRandomWalkProcess
```

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.

``` python
np.random.seed(3312)
q = SimpleRandomWalkProcess(dist.Normal(0, 0.001))
Expand Down Expand Up @@ -71,10 +69,7 @@ latent_hospitalizations = HospitalAdmissions(
)

# And observation process for the hospitalizations
observed_hospitalizations = PoissonObservation(
rate_varname='latent',
counts_varname='observed_hospitalizations',
)
observed_hospitalizations = PoissonObservation()

# And a random walk process (it could be deterministic using
# pyrenew.process.DeterministicProcess())
Expand All @@ -93,7 +88,7 @@ hospmodel = HospitalizationsModel(

``` python
with seed(rng_seed=np.random.randint(1, 60)):
x = hospmodel.sample(constants=dict(n_timepoints=30))
x = hospmodel.sample(n_timepoints=30)
x
```

Expand Down Expand Up @@ -130,15 +125,12 @@ for axis in ax[:-1]:
![](pyrenew_demo_files/figure-commonmark/fig-hosp-output-1.png)

``` python
sim_dat={"observed_hospitalizations": x.sampled}
constants = {"n_timepoints":len(x.sampled)-1}

# from numpyro.infer import MCMC, NUTS
hospmodel.run(
num_warmup=1000,
num_samples=1000,
random_variables=sim_dat,
constants=constants,
observed_hospitalizations=x.sampled,
n_timepoints = len(x.sampled)-1,
rng_key=jax.random.PRNGKey(54),
mcmc_args=dict(progress_bar=False),
)
Expand Down
14 changes: 4 additions & 10 deletions model/docs/pyrenew_demo.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,7 @@ latent_hospitalizations = HospitalAdmissions(
)
# And observation process for the hospitalizations
observed_hospitalizations = PoissonObservation(
rate_varname='latent',
counts_varname='observed_hospitalizations',
)
observed_hospitalizations = PoissonObservation()
# And a random walk process (it could be deterministic using
# pyrenew.process.DeterministicProcess())
Expand All @@ -96,7 +93,7 @@ hospmodel = HospitalizationsModel(

```{python}
with seed(rng_seed=np.random.randint(1, 60)):
x = hospmodel.sample(constants=dict(n_timepoints=30))
x = hospmodel.sample(n_timepoints=30)
x
```

Expand All @@ -113,15 +110,12 @@ for axis in ax[:-1]:
```

```{python}
sim_dat={"observed_hospitalizations": x.sampled}
constants = {"n_timepoints":len(x.sampled)-1}
# from numpyro.infer import MCMC, NUTS
hospmodel.run(
num_warmup=1000,
num_samples=1000,
random_variables=sim_dat,
constants=constants,
observed_hospitalizations=x.sampled,
n_timepoints = len(x.sampled)-1,
rng_key=jax.random.PRNGKey(54),
mcmc_args=dict(progress_bar=False),
)
Expand Down
10 changes: 3 additions & 7 deletions model/src/pyrenew/deterministic/deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,14 @@ def validate(vars: tuple) -> None:

def sample(
self,
random_variables: dict = None,
constants: dict = None,
**kwargs,
) -> tuple:
"""Retrieve the value of the deterministic Rv
Parameters
----------
random_variables : dict
Ignored. Default None.
constants : dict
Ignored. Default None.
kwargs : dict
Ignored.
Returns
-------
Expand Down
15 changes: 4 additions & 11 deletions model/src/pyrenew/deterministic/deterministicpmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,26 +52,19 @@ def validate(vars: tuple) -> None:

def sample(
self,
random_variables: dict = None,
constants: dict = None,
**kwargs,
) -> tuple:
"""Retrieves the deterministic PMF
Parameters
----------
random_variables : dict
Ignored. Default None.
constants : dict
Ignored. Default None.
kwargs : dict
Arguments to pass to the sample method.
Returns
-------
tuple
Containing the stored values during construction.
"""

return self.basevar.sample(
random_variables=random_variables,
constants=constants,
)
return self.basevar.sample(**kwargs)
62 changes: 15 additions & 47 deletions model/src/pyrenew/latent/hospitaladmissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import jax.numpy as jnp
import numpyro as npro
import numpyro.distributions as dist
from numpy.typing import ArrayLike
from pyrenew.deterministic import DeterministicVariable
from pyrenew.metaclass import RandomVariable

Expand Down Expand Up @@ -57,16 +58,11 @@ def __init__(
def validate(distr: dist.Distribution) -> None:
assert isinstance(distr, dist.Distribution)

def sample(
self,
random_variables: dict = None,
constants: dict = None,
) -> InfectHospRateSample:
def sample(self, **kwargs) -> InfectHospRateSample:
return InfectHospRateSample(
npro.sample(
"IHR",
self.dist,
obs=random_variables.get(self.varname, None),
)
)

Expand Down Expand Up @@ -106,7 +102,6 @@ def __init__(
self,
infection_to_admission_interval: RandomVariable,
infect_hosp_rate_dist: RandomVariable,
infections_varname: str = "infections",
hospitalizations_predicted_varname: str = "predicted_hospitalizations",
weekday_effect_dist: RandomVariable = DeterministicVariable((1,)),
hosp_report_prob_dist: RandomVariable = DeterministicVariable((1,)),
Expand All @@ -120,21 +115,14 @@ def __init__(
pyrenew.observations.Deterministic).
infect_hosp_rate_dist : RandomVariable
Infection to hospitalization rate distribution.
infections_varname : str
Name of the entry in random_variables that holds the vector of
infections.
infect_hosp_rate_varname : str
Name of the entry in random_variables that holds the observed
infection-hospitalization rate (IHR).
(if available).
hospitalizations_predicted_varname : str
Name to assign to the deterministic component in numpyro of
predicted hospitalizations.
weekday_effect_dist : RandomVariable, optional
Weekday effect.
hosp_report_prob_dist : RandomVariable, optional
Distribution or fixed value for the hospital admission reporting probability. Defaults to 1 (full
reporting).
Distribution or fixed value for the hospital admission reporting
probability. Defaults to 1 (full reporting).
Returns
-------
Expand All @@ -146,7 +134,6 @@ def __init__(
hosp_report_prob_dist,
)

self.infections_varname = infections_varname
self.hospitalizations_predicted_varname = (
hospitalizations_predicted_varname
)
Expand All @@ -170,44 +157,31 @@ def validate(

def sample(
self,
random_variables: dict = None,
constants: dict = None,
latent: ArrayLike,
**kwargs,
) -> HospAdmissionsSample:
"""Samples from the observation process
Parameters
----------
random_variables : dict
A dictionary `self.infections_varname` with the observed
infections. Optionally, with IHR passed to obs in npyro.sample().
constants : dict, optional
Ignored.
latent : ArrayLike
Latent infections.
kwargs : dict
Keyword arguments passed to the sampling methods.
Returns
-------
HospAdmissionsSample
"""

if random_variables is None:
random_variables = dict()

if constants is None:
constants = dict()
IHR, *_ = self.infect_hosp_rate_dist.sample(**kwargs)

IHR, *_ = self.infect_hosp_rate_dist.sample(
random_variables=random_variables,
constants=constants,
)

IHR_t = IHR * random_variables.get(self.infections_varname)
IHR_t = IHR * latent

(
infection_to_admission_interval,
*_,
) = self.infection_to_admission_interval.sample(
random_variables=random_variables,
constants=constants,
)
) = self.infection_to_admission_interval.sample(**kwargs)

predicted_hospitalizations = jnp.convolve(
IHR_t, infection_to_admission_interval, mode="full"
Expand All @@ -216,19 +190,13 @@ def sample(
# Applying weekday effect
predicted_hospitalizations = (
predicted_hospitalizations
* self.weekday_effect_dist.sample(
random_variables=random_variables,
constants=constants,
)[0]
* self.weekday_effect_dist.sample(**kwargs)[0]
)

# Applying probability of hospitalization effect
predicted_hospitalizations = (
predicted_hospitalizations
* self.hosp_report_prob_dist.sample(
random_variables=random_variables,
constants=constants,
)[0]
* self.hosp_report_prob_dist.sample(**kwargs)[0]
)

npro.deterministic(
Expand Down
9 changes: 3 additions & 6 deletions model/src/pyrenew/latent/i0.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,14 @@ def validate(i0_dist):

def sample(
self,
random_variables: dict,
constants: dict,
**kwargs,
) -> tuple:
"""Sample the initial infections.
Parameters
----------
random_variables : dict
Dictionary of random variables.
constants : dict
Dictionary of constants.
kwargs : dict
Ignored
Returns
-------
Expand Down
Loading

0 comments on commit 82e20fe

Please sign in to comment.