Skip to content

Commit

Permalink
Merge branch 'main' into 332-tests-for-_assert_sample_and_rtype
Browse files Browse the repository at this point in the history
  • Loading branch information
damonbayer authored Jul 31, 2024
2 parents b0f9461 + 64e0ed8 commit 7001ee4
Show file tree
Hide file tree
Showing 17 changed files with 167 additions and 85 deletions.
32 changes: 32 additions & 0 deletions .github/workflows/deptry.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
name: model

on:
pull_request:
push:
branches: [main]

jobs:
dependency-check:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- uses: actions/setup-python@v5
with:
python-version: "3.12"

- name: cache poetry
uses: actions/cache@v4
with:
path: ~/.local
key: ${{ runner.os }}-poetry

- name: install poetry
run: pip install poetry

- name: install package
run: poetry install --with dev

- name: run deptry
run: |
poetry run deptry . --per-rule-ignores "DEP001=pyrenew,DEP003=pytest"
2 changes: 1 addition & 1 deletion docs/source/tutorials/basic_renewal_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ class MyRt(RandomVariable):
reparam=LocScaleReparam(0),
),
init_rv=DistributionalRV(
name="init_log_Rt_rv",
name="init_log_rt",
dist=dist.Normal(jnp.log(1), jnp.log(1.2)),
),
),
Expand Down
4 changes: 1 addition & 3 deletions docs/source/tutorials/extending_pyrenew.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,7 @@ rt = TransformedRandomVariable(
step_rv=DistributionalRV(
name="rw_step_rv", dist=dist.Normal(0, 0.025)
),
init_rv=DistributionalRV(
name="init_log_Rt_rv", dist=dist.Normal(0, 0.2)
),
init_rv=DistributionalRV(name="init_log_rt", dist=dist.Normal(0, 0.2)),
),
transforms=t.ExpTransform(),
)
Expand Down
2 changes: 1 addition & 1 deletion docs/source/tutorials/hospital_admissions_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ class MyRt(metaclass.RandomVariable):
name="rw_step_rv", dist=dist.Normal(0, sd_rt.value)
),
init_rv=metaclass.DistributionalRV(
name="init_log_Rt_rv", dist=dist.Normal(0, 0.2)
name="init_log_rt", dist=dist.Normal(0, 0.2)
),
),
transforms=transformation.ExpTransform(),
Expand Down
8 changes: 4 additions & 4 deletions docs/source/tutorials/periodic_effects.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ from pyrenew import process, deterministic
rt_proc = process.RtWeeklyDiffProcess(
name="rt_weekly_diff",
offset=0,
log_rt_prior=deterministic.DeterministicVariable(
name="log_rt_prior", value=jnp.array([0.1, 0.2])
log_rt_rv=deterministic.DeterministicVariable(
name="log_rt", value=jnp.array([0.1, 0.2])
),
autoreg=deterministic.DeterministicVariable(
autoreg_rv=deterministic.DeterministicVariable(
name="autoreg", value=jnp.array([0.7])
),
periodic_diff_sd=deterministic.DeterministicVariable(
periodic_diff_sd_rv=deterministic.DeterministicVariable(
name="periodic_diff_sd", value=jnp.array([0.1])
),
)
Expand Down
12 changes: 8 additions & 4 deletions model/src/pyrenew/deterministic/deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,16 @@ def __init__(
None
"""
self.name = name
self.value = jnp.atleast_1d(value)
self.validate(value)
self.value = jnp.atleast_1d(value)
self.set_timeseries(t_start, t_unit)

return None

@staticmethod
def validate(value: ArrayLike) -> None:
"""
Validates input to DeterministicPMF
Validates input to DeterministicVariable
Parameters
----------
Expand All @@ -63,10 +63,14 @@ def validate(value: ArrayLike) -> None:
Raises
------
Exception
If the input value object is not a ArrayLike.
If the input value object is not an ArrayLike object.
"""
if not isinstance(value, ArrayLike):
raise Exception("value is not a ArrayLike")
raise ValueError(
f"value {value} passed to a DeterministicVariable "
f"is of type {type(value).__name__}, expected "
"an ArrayLike object"
)

return None

Expand Down
72 changes: 36 additions & 36 deletions model/src/pyrenew/process/rtperiodicdiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ def __init__(
name: str,
offset: int,
period_size: int,
log_rt_prior: RandomVariable,
autoreg: RandomVariable,
periodic_diff_sd: RandomVariable,
log_rt_rv: RandomVariable,
autoreg_rv: RandomVariable,
periodic_diff_sd_rv: RandomVariable,
) -> None:
"""
Default constructor for RtPeriodicDiffProcess class.
Expand All @@ -66,11 +66,11 @@ def __init__(
offset : int
Relative point at which data starts, must be between 0 and
period_size - 1.
log_rt_prior : RandomVariable
log_rt_rv : RandomVariable
Log Rt prior for the first two observations.
autoreg : RandomVariable
autoreg_rv : RandomVariable
Autoregressive parameter.
periodic_diff_sd : RandomVariable
periodic_diff_sd_rv : RandomVariable
Standard deviation of the noise.
Returns
Expand All @@ -85,45 +85,45 @@ def __init__(
)

self.validate(
log_rt_prior=log_rt_prior,
autoreg=autoreg,
periodic_diff_sd=periodic_diff_sd,
log_rt_rv=log_rt_rv,
autoreg_rv=autoreg_rv,
periodic_diff_sd_rv=periodic_diff_sd_rv,
)

self.period_size = period_size
self.offset = offset
self.log_rt_prior = log_rt_prior
self.autoreg = autoreg
self.periodic_diff_sd = periodic_diff_sd
self.log_rt_rv = log_rt_rv
self.autoreg_rv = autoreg_rv
self.periodic_diff_sd_rv = periodic_diff_sd_rv

return None

@staticmethod
def validate(
log_rt_prior: any,
autoreg: any,
periodic_diff_sd: any,
log_rt_rv: any,
autoreg_rv: any,
periodic_diff_sd_rv: any,
) -> None:
"""
Validate the input parameters.
Parameters
----------
log_rt_prior : any
log_rt_rv : any
Log Rt prior for the first two observations.
autoreg : any
autoreg_rv : any
Autoregressive parameter.
periodic_diff_sd : any
periodic_diff_sd_rv : any
Standard deviation of the noise.
Returns
-------
None
"""

_assert_sample_and_rtype(log_rt_prior)
_assert_sample_and_rtype(autoreg)
_assert_sample_and_rtype(periodic_diff_sd)
_assert_sample_and_rtype(log_rt_rv)
_assert_sample_and_rtype(autoreg_rv)
_assert_sample_and_rtype(periodic_diff_sd_rv)

return None

Expand Down Expand Up @@ -175,19 +175,19 @@ def sample(
"""

# Initial sample
log_rt_prior = self.log_rt_prior.sample(**kwargs)[0].value
b = self.autoreg.sample(**kwargs)[0].value
s_r = self.periodic_diff_sd.sample(**kwargs)[0].value
log_rt_rv = self.log_rt_rv.sample(**kwargs)[0].value
b = self.autoreg_rv.sample(**kwargs)[0].value
s_r = self.periodic_diff_sd_rv.sample(**kwargs)[0].value

# How many periods to sample?
n_periods = int(jnp.ceil(duration / self.period_size))

# Running the process
ar_diff = FirstDifferenceARProcess("trend_rw", autoreg=b, noise_sd=s_r)
ar_diff = FirstDifferenceARProcess(self.name, autoreg=b, noise_sd=s_r)
log_rt = ar_diff.sample(
duration=n_periods,
init_val=log_rt_prior[1],
init_rate_of_change=log_rt_prior[1] - log_rt_prior[0],
init_val=log_rt_rv[1],
init_rate_of_change=log_rt_rv[1] - log_rt_rv[0],
)[0]

return RtPeriodicDiffProcessSample(
Expand All @@ -208,9 +208,9 @@ def __init__(
self,
name: str,
offset: int,
log_rt_prior: RandomVariable,
autoreg: RandomVariable,
periodic_diff_sd: RandomVariable,
log_rt_rv: RandomVariable,
autoreg_rv: RandomVariable,
periodic_diff_sd_rv: RandomVariable,
) -> None:
"""
Default constructor for RtWeeklyDiffProcess class.
Expand All @@ -221,11 +221,11 @@ def __init__(
Name of the site.
offset : int
Relative point at which data starts, must be between 0 and 6.
log_rt_prior : RandomVariable
log_rt_rv : RandomVariable
Log Rt prior for the first two observations.
autoreg : RandomVariable
autoreg_rv : RandomVariable
Autoregressive parameter.
periodic_diff_sd : RandomVariable
periodic_diff_sd_rv : RandomVariable
Standard deviation of the noise.
Returns
Expand All @@ -237,9 +237,9 @@ def __init__(
name=name,
offset=offset,
period_size=7,
log_rt_prior=log_rt_prior,
autoreg=autoreg,
periodic_diff_sd=periodic_diff_sd,
log_rt_rv=log_rt_rv,
autoreg_rv=autoreg_rv,
periodic_diff_sd_rv=periodic_diff_sd_rv,
)

return None
50 changes: 50 additions & 0 deletions model/src/test/test_deterministic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
# numpydoc ignore=GL08

import re

import jax.numpy as jnp
import numpy as np
import numpy.testing as testing
import pytest
from pyrenew.deterministic import (
DeterministicPMF,
DeterministicProcess,
Expand Down Expand Up @@ -62,3 +66,49 @@ def test_deterministic():

testing.assert_equal(var4()[0].value, None)
testing.assert_equal(var5(duration=1)[0].value, None)


def test_deterministic_validation():
"""
Check that validation methods for DeterministicVariable
work as expected.
"""
# validation should fail on construction
some_non_array_likes = [
{"a": jnp.array([1, 2.5, 3])},
# a valid pytree, but not an arraylike
"a string",
]
some_array_likes = [
5,
-3.023523,
np.array([1, 3.32, 5]),
jnp.array([-32, 23]),
jnp.array(-32),
np.array(5),
]

for non_arraylike_val in some_non_array_likes:
matchval = re.escape(
f"value {non_arraylike_val} passed to a "
"DeterministicVariable is of type "
f"{type(non_arraylike_val).__name__}, expected "
"an ArrayLike object"
)

with pytest.raises(ValueError, match=matchval):
# the class's validation function itself
# should raise an error when passed a
# non arraylike value
DeterministicVariable.validate(non_arraylike_val)

with pytest.raises(ValueError, match=matchval):
# validation should fail on constructor call
DeterministicVariable(
value=non_arraylike_val, name="invalid_variable"
)

# validation should succeed with ArrayLike
for arraylike_val in some_array_likes:
DeterministicVariable.validate(arraylike_val)
DeterministicVariable(value=arraylike_val, name="valid_variable")
2 changes: 1 addition & 1 deletion model/src/test/test_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_forecast():
name="rw_step_rv", dist=dist.Normal(0, 0.025)
),
init_rv=DistributionalRV(
name="init_log_Rt_rv", dist=dist.Normal(0, 0.2)
name="init_log_rt", dist=dist.Normal(0, 0.2)
),
),
transforms=t.ExpTransform(),
Expand Down
2 changes: 1 addition & 1 deletion model/src/test/test_latent_admissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_admissions_sample():
name="rw_step_rv", dist=dist.Normal(0, 0.025)
),
init_rv=DistributionalRV(
name="init_log_Rt_rv", dist=dist.Normal(0, 0.2)
name="init_log_rt", dist=dist.Normal(0, 0.2)
),
),
transforms=t.ExpTransform(),
Expand Down
2 changes: 1 addition & 1 deletion model/src/test/test_latent_infections.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_infections_as_deterministic():
name="rw_step_rv", dist=dist.Normal(0, 0.025)
),
init_rv=DistributionalRV(
name="init_log_Rt_rv", dist=dist.Normal(0, 0.2)
name="init_log_rt", dist=dist.Normal(0, 0.2)
),
),
transforms=t.ExpTransform(),
Expand Down
2 changes: 1 addition & 1 deletion model/src/test/test_model_basic_renewal.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def get_default_rt():
name="rw_step_rv", dist=dist.Normal(0, 0.025)
),
init_rv=DistributionalRV(
name="init_log_Rt_rv", dist=dist.Normal(0, 0.2)
name="init_log_rt", dist=dist.Normal(0, 0.2)
),
),
transforms=t.ExpTransform(),
Expand Down
2 changes: 1 addition & 1 deletion model/src/test/test_model_hosp_admissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def get_default_rt():
name="rw_step_rv", dist=dist.Normal(0, 0.025)
),
init_rv=DistributionalRV(
name="init_log_Rt_rv", dist=dist.Normal(0, 0.2)
name="init_log_rt", dist=dist.Normal(0, 0.2)
),
),
transforms=t.ExpTransform(),
Expand Down
4 changes: 1 addition & 3 deletions model/src/test/test_predictive.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,7 @@
step_rv=DistributionalRV(
name="rw_step_rv", dist=dist.Normal(0, 0.025)
),
init_rv=DistributionalRV(
name="init_log_Rt_rv", dist=dist.Normal(0, 0.2)
),
init_rv=DistributionalRV(name="init_log_rt", dist=dist.Normal(0, 0.2)),
),
transforms=t.ExpTransform(),
)
Expand Down
2 changes: 1 addition & 1 deletion model/src/test/test_random_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def create_test_model(): # numpydoc ignore=GL08
name="rw_step_rv", dist=dist.Normal(0, 0.025)
),
init_rv=DistributionalRV(
name="init_log_Rt_rv", dist=dist.Normal(0, 0.2)
name="init_log_rt", dist=dist.Normal(0, 0.2)
),
),
transforms=t.ExpTransform(),
Expand Down
Loading

0 comments on commit 7001ee4

Please sign in to comment.