Skip to content

Commit

Permalink
DHM updates
Browse files Browse the repository at this point in the history
  • Loading branch information
AFg6K7h4fhy2 committed Aug 1, 2024
1 parent 47dc45c commit cf34382
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 27 deletions.
3 changes: 2 additions & 1 deletion model/src/pyrenew/transformation/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Built-in pyrenew transformations created using `numpyro.distributions.transforms`.
"""


import numpyro.distributions.transforms as nt


Expand All @@ -27,5 +28,5 @@ def ScaledLogitTransform(
- numpyro.distributions.transforms.SigmoidTransform().inv
"""
return nt.ComposeTransform(
[nt.AffineTransform(0.0, 1.0 / x_max), nt.SigmoidTransform().inv]
[nt.SigmoidTransform().inv, nt.AffineTransform(0.0, 1.0 / x_max)]
)
84 changes: 58 additions & 26 deletions scratch/tut_epim_port_msr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""

import datetime as dt
import inspect
import os
from pprint import pprint

Expand Down Expand Up @@ -77,14 +78,15 @@ def load_influenza_hosp_data(
"""
data = pl.read_csv(data_path, separator="\t", infer_schema_length=10000)
# verification: columns
print(data.columns)

if print_first_n_rows and (5 <= n_row_count <= 50):
pl.Config.set_tbl_hide_dataframe_shape(True)
pl.Config.set_tbl_formatting("ASCII_MARKDOWN")
pl.Config.set_tbl_hide_column_data_types(True)
with pl.Config(tbl_rows=n_row_count, tbl_cols=6):
# verification: rows and columns of data
print(data)
print(f"DATASET:\n{data}\n")
print(f"Dataset Columns:\n{data.columns}\n\n")
return data


Expand Down Expand Up @@ -333,6 +335,9 @@ def __init__(
self.predictors = predictors
self.alpha_intercept_prior_mode = alpha_intercept_prior_mode
self.alpha_intercept_prior_scale = alpha_intercept_prior_scale

# update: this should not depend on explicit covariates, it should
# receive all_coefficient_priors? (works fine)
self.day_of_week_effect_prior_modes = day_of_week_effect_prior_modes
self.day_of_week_effect_prior_scales = day_of_week_effect_prior_scales
self.holiday_eff_prior_mode = holiday_eff_prior_mode
Expand All @@ -352,6 +357,7 @@ def _init_alpha_t(self): # numpydoc ignore=GL08
alpha_intercept_prior = dist.Normal(
self.alpha_intercept_prior_mode, self.alpha_intercept_prior_scale
)
# update: fine for
all_coefficient_priors = dist.Normal(
loc=jnp.array(
self.day_of_week_effect_prior_modes
Expand All @@ -375,7 +381,7 @@ def _init_alpha_t(self): # numpydoc ignore=GL08
fixed_predictor_values=predictor_values,
intercept_prior=alpha_intercept_prior,
coefficient_priors=all_coefficient_priors,
transform=t.ScaledLogitTransform(x_max=self.max_rt),
transform=t.ScaledLogitTransform(x_max=self.max_rt).inv,
)

def _init_negative_binomial(self): # numpydoc ignore=GL08
Expand All @@ -398,6 +404,7 @@ def sample(
**kwargs,
) -> tuple: # numpydoc ignore=GL08
alpha_samples = self.alpha_process.sample()["prediction"]
alpha_samples = alpha_samples[: infections.shape[0]]
expected_hosp = (
alpha_samples
* jnp.convolve(infections, delay_distribution, mode="full")[
Expand All @@ -407,9 +414,6 @@ def sample(
nb_samples = self.nb_observation.sample(mu=expected_hosp, **kwargs)
return nb_samples

# update: ensure alpha_samples is the correct shape
# update: this is starting to look good, verify though


class CFAEPIM_Rt(RandomVariable): # numpydoc ignore=GL08
def __init__(
Expand Down Expand Up @@ -442,10 +446,12 @@ def sample(self, n_steps: int, **kwargs) -> tuple: # numpydoc ignore=GL08
dist=self.intercept_RW_prior,
),
)
wt_samples = wt_rv.sample(n_steps=n_steps, **kwargs)
print(f"Non-Transformed Samples: {wt_samples}")
transformed_rt_samples = TransformedRandomVariable(
name="transformed_rt_rw",
base_rv=wt_rv,
transforms=t.ScaledLogitTransform(x_max=self.max_rt),
transforms=t.ScaledLogitTransform(x_max=self.max_rt).inv,
).sample(n_steps=n_steps, **kwargs)
return transformed_rt_samples

Expand Down Expand Up @@ -494,24 +500,25 @@ def __init__(
self.rt_intercept_prior_mode, self.rt_intercept_prior_scale
)

# confirm: not needed
# transmission: prior for gamma term
self.gamma_RW_prior_scale = dist.HalfNormal(
self.weekly_rw_prior_scale
).sample(jax.random.PRNGKey(self.seed))
self.gamma_RW_prior = dist.Normal(0, self.gamma_RW_prior_scale)
# self.gamma_RW_prior_scale = dist.HalfNormal(
# self.weekly_rw_prior_scale
# ).sample(jax.random.PRNGKey(self.seed))
# self.gamma_RW_prior = dist.Normal(0, self.gamma_RW_prior_scale)

# transmission: Rt process
self.Rt_process = CFAEPIM_Rt(
intercept_RW_prior=self.intercept_RW_prior,
max_rt=self.max_rt,
gamma_RW_prior_scale=self.gamma_RW_prior_scale,
gamma_RW_prior_scale=self.weekly_rw_prior_scale,
)

# infections: get value rate for infection seeding (initialization)
first_week_hosp = (
self.dataset.filter((pl.col("location") == state))
.select(["first_week_hosp"])
.to_numpy()[0]
.to_numpy()[0][0]
)
self.mean_inf_val = (
self.inf_model_prior_infections_per_capita * self.population
Expand All @@ -522,7 +529,9 @@ def __init__(
name="I0_initialization",
I_pre_init_rv=DistributionalRV(
name="I0",
dist=dist.Exponential(rate=1 / self.mean_inf_val),
dist=dist.Exponential(rate=1 / self.mean_inf_val).expand(
[self.inf_model_seed_days]
),
),
infection_init_method=InitializeInfectionsFromVec(
n_timepoints=self.inf_model_seed_days
Expand Down Expand Up @@ -593,13 +602,13 @@ def predict(self, rng_key, **kwargs): # numpydoc ignore=GL08

def verify_cfaepim_MSR(cfaepim_MSR_model) -> None: # numpydoc ignore=GL08
# verification: population
print(f"Population Value: {cfaepim_MSR_model.population}")
print(f"Population Value:\n{cfaepim_MSR_model.population}\n\n")
# verification: predictors
print(f"Predictors:\n{cfaepim_MSR_model.predictors}")
print(f"Predictors:\n{cfaepim_MSR_model.predictors}\n\n")
# verification: (transmission) generation interval deterministic PMF
cfaepim_MSR_model.gen_int.validate(cfaepim_MSR_model.pmf_array)
sampled_gen_int = cfaepim_MSR_model.gen_int.sample()
print(f"SAMPLED GENERATION INTERVAL:\n{sampled_gen_int}")
print(f"CFAEPIM GENERATION INTERVAL:\n{sampled_gen_int}\n\n")
base_object_plot(
y=sampled_gen_int[0].value,
X=np.arange(0, len(sampled_gen_int[0].value)),
Expand All @@ -612,24 +621,47 @@ def verify_cfaepim_MSR(cfaepim_MSR_model) -> None: # numpydoc ignore=GL08
print(
f"CFAEPIM RT PROCESS:\n{cfaepim_MSR_model.Rt_process}\n{dir(cfaepim_MSR_model.Rt_process)}"
)
with numpyro.handlers.seed(
rng_seed=jax.random.key(cfaepim_MSR_model.seed)
):
print(
f"(Sample Method For Rt Process):\n{inspect.signature(cfaepim_MSR_model.Rt_process.sample)}"
)
with numpyro.handlers.seed(rng_seed=cfaepim_MSR_model.seed):
sampled_Rt = cfaepim_MSR_model.Rt_process.sample(n_steps=100)
print(sampled_Rt)
print(f"TRANSFORMED Samples:\n{sampled_Rt}\n\n")
# verification: (infections) first week hosp
print(cfaepim_MSR_model.mean_inf_val)
print(f"First Week Mean Infections:\n{cfaepim_MSR_model.mean_inf_val}\n\n")
# verification: (infections) initial infections
print(f"CFAEPIM I0:\n{cfaepim_MSR_model.I0}\n{dir(cfaepim_MSR_model.I0)}")
print(
f"(Sample Method For I0):\n{inspect.signature(cfaepim_MSR_model.I0.sample)}"
)
with numpyro.handlers.seed(
rng_seed=jax.random.key(cfaepim_MSR_model.seed)
):
sampled_I0 = cfaepim_MSR_model.I0.sample()
print(sampled_I0)
print(f"Samples:\n{sampled_I0}\n\n")
# verification: observation process
print(
f"CFAEPIM OBSERVATION PROCESS:\n{cfaepim_MSR_model.obs_process}\n{dir(cfaepim_MSR_model.obs_process)}"
)
print(
f"(Sample Method For Obs. Process):\n{inspect.signature(cfaepim_MSR_model.obs_process.sample)}\n\n"
)
with numpyro.handlers.seed(
rng_seed=jax.random.key(cfaepim_MSR_model.seed)
):
sampled_alpha = cfaepim_MSR_model.obs_process.alpha_process.sample()[
"prediction"
]
print(f"CFAEPIM ALPHA PROCESS:\n{sampled_alpha}\n\n")
random_infs = jnp.array(np.random.randint(low=1000, high=5000, size=20))
delay_dist = jnp.array(cfaepim_MSR_model.inf_to_hosp_dist)
with numpyro.handlers.seed(
rng_seed=jax.random.key(cfaepim_MSR_model.seed)
):
sampled_obs = cfaepim_MSR_model.obs_process(
infections=random_infs, delay_distribution=delay_dist
)
print(f"Samples:\n{sampled_obs}\n\n")


def main(): # numpydoc ignore=GL08
Expand All @@ -640,7 +672,7 @@ def main(): # numpydoc ignore=GL08
# load parameters config (2024-01-20)
config = toml.load("./config/params_2024-01-20.toml")
# verification: config file
pprint(config)
pprint(f"CONFIG:\n{config}\n\n")

# load NHSN data w/ population counts (2024-01-20)
data_path_01 = "./data/2024-01-20/2024-01-20_clean_data.tsv"
Expand All @@ -656,7 +688,7 @@ def main(): # numpydoc ignore=GL08
upper_date="2024-03-10",
use_log=False,
use_legend=True,
save_as_img=True,
save_as_img=False,
save_to_pdf=False,
display=False,
)
Expand All @@ -669,7 +701,7 @@ def main(): # numpydoc ignore=GL08
).select(["date", "hosp"])["hosp"]
)
# verification: data_observed_hosp_admissions
print(data_observed_hosp_admissions)
print(f"HOSPITALIZATIONS:\n{data_observed_hosp_admissions}\n\n")

# instantiate cfaepim-MSR
cfaepim_MSR = CFAEPIM_Model(
Expand Down

0 comments on commit cf34382

Please sign in to comment.