From f77579ea11d1493468a449c38a5b1bf380687e8b Mon Sep 17 00:00:00 2001 From: Damon Bayer <xum8@cdc.gov> Date: Mon, 29 Jul 2024 19:57:25 -0500 Subject: [PATCH 1/4] Rename random variables in rtperiodicdiff (#339) --- docs/source/tutorials/basic_renewal_model.qmd | 2 +- docs/source/tutorials/extending_pyrenew.qmd | 4 +- .../tutorials/hospital_admissions_model.qmd | 2 +- docs/source/tutorials/periodic_effects.qmd | 8 +-- model/src/pyrenew/process/rtperiodicdiff.py | 70 +++++++++---------- model/src/test/test_forecast.py | 2 +- model/src/test/test_latent_admissions.py | 2 +- model/src/test/test_latent_infections.py | 2 +- model/src/test/test_model_basic_renewal.py | 2 +- model/src/test/test_model_hosp_admissions.py | 2 +- model/src/test/test_predictive.py | 4 +- model/src/test/test_random_key.py | 2 +- model/src/test/test_rtperiodicdiff.py | 52 +++++++------- 13 files changed, 75 insertions(+), 79 deletions(-) diff --git a/docs/source/tutorials/basic_renewal_model.qmd b/docs/source/tutorials/basic_renewal_model.qmd index eec79ec2..3ef776fe 100644 --- a/docs/source/tutorials/basic_renewal_model.qmd +++ b/docs/source/tutorials/basic_renewal_model.qmd @@ -152,7 +152,7 @@ class MyRt(RandomVariable): reparam=LocScaleReparam(0), ), init_rv=DistributionalRV( - name="init_log_Rt_rv", + name="init_log_rt", dist=dist.Normal(jnp.log(1), jnp.log(1.2)), ), ), diff --git a/docs/source/tutorials/extending_pyrenew.qmd b/docs/source/tutorials/extending_pyrenew.qmd index 664cd3bd..bcfc2749 100644 --- a/docs/source/tutorials/extending_pyrenew.qmd +++ b/docs/source/tutorials/extending_pyrenew.qmd @@ -69,9 +69,7 @@ rt = TransformedRandomVariable( step_rv=DistributionalRV( name="rw_step_rv", dist=dist.Normal(0, 0.025) ), - init_rv=DistributionalRV( - name="init_log_Rt_rv", dist=dist.Normal(0, 0.2) - ), + init_rv=DistributionalRV(name="init_log_rt", dist=dist.Normal(0, 0.2)), ), transforms=t.ExpTransform(), ) diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd index 32bab32d..8d491cd9 100644 --- a/docs/source/tutorials/hospital_admissions_model.qmd +++ b/docs/source/tutorials/hospital_admissions_model.qmd @@ -202,7 +202,7 @@ class MyRt(metaclass.RandomVariable): name="rw_step_rv", dist=dist.Normal(0, sd_rt.value) ), init_rv=metaclass.DistributionalRV( - name="init_log_Rt_rv", dist=dist.Normal(0, 0.2) + name="init_log_rt", dist=dist.Normal(0, 0.2) ), ), transforms=transformation.ExpTransform(), diff --git a/docs/source/tutorials/periodic_effects.qmd b/docs/source/tutorials/periodic_effects.qmd index 2cd1db8a..6eb0f876 100644 --- a/docs/source/tutorials/periodic_effects.qmd +++ b/docs/source/tutorials/periodic_effects.qmd @@ -27,13 +27,13 @@ from pyrenew import process, deterministic rt_proc = process.RtWeeklyDiffProcess( name="rt_weekly_diff", offset=0, - log_rt_prior=deterministic.DeterministicVariable( - name="log_rt_prior", value=jnp.array([0.1, 0.2]) + log_rt_rv=deterministic.DeterministicVariable( + name="log_rt", value=jnp.array([0.1, 0.2]) ), - autoreg=deterministic.DeterministicVariable( + autoreg_rv=deterministic.DeterministicVariable( name="autoreg", value=jnp.array([0.7]) ), - periodic_diff_sd=deterministic.DeterministicVariable( + periodic_diff_sd_rv=deterministic.DeterministicVariable( name="periodic_diff_sd", value=jnp.array([0.1]) ), ) diff --git a/model/src/pyrenew/process/rtperiodicdiff.py b/model/src/pyrenew/process/rtperiodicdiff.py index 1fd5da86..710d5d22 100644 --- a/model/src/pyrenew/process/rtperiodicdiff.py +++ b/model/src/pyrenew/process/rtperiodicdiff.py @@ -52,9 +52,9 @@ def __init__( name: str, offset: int, period_size: int, - log_rt_prior: RandomVariable, - autoreg: RandomVariable, - periodic_diff_sd: RandomVariable, + log_rt_rv: RandomVariable, + autoreg_rv: RandomVariable, + periodic_diff_sd_rv: RandomVariable, ) -> None: """ Default constructor for RtPeriodicDiffProcess class. @@ -66,11 +66,11 @@ def __init__( offset : int Relative point at which data starts, must be between 0 and period_size - 1. - log_rt_prior : RandomVariable + log_rt_rv : RandomVariable Log Rt prior for the first two observations. - autoreg : RandomVariable + autoreg_rv : RandomVariable Autoregressive parameter. - periodic_diff_sd : RandomVariable + periodic_diff_sd_rv : RandomVariable Standard deviation of the noise. Returns @@ -85,35 +85,35 @@ def __init__( ) self.validate( - log_rt_prior=log_rt_prior, - autoreg=autoreg, - periodic_diff_sd=periodic_diff_sd, + log_rt_rv=log_rt_rv, + autoreg_rv=autoreg_rv, + periodic_diff_sd_rv=periodic_diff_sd_rv, ) self.period_size = period_size self.offset = offset - self.log_rt_prior = log_rt_prior - self.autoreg = autoreg - self.periodic_diff_sd = periodic_diff_sd + self.log_rt_rv = log_rt_rv + self.autoreg_rv = autoreg_rv + self.periodic_diff_sd_rv = periodic_diff_sd_rv return None @staticmethod def validate( - log_rt_prior: any, - autoreg: any, - periodic_diff_sd: any, + log_rt_rv: any, + autoreg_rv: any, + periodic_diff_sd_rv: any, ) -> None: """ Validate the input parameters. Parameters ---------- - log_rt_prior : any + log_rt_rv : any Log Rt prior for the first two observations. - autoreg : any + autoreg_rv : any Autoregressive parameter. - periodic_diff_sd : any + periodic_diff_sd_rv : any Standard deviation of the noise. Returns @@ -121,9 +121,9 @@ def validate( None """ - _assert_sample_and_rtype(log_rt_prior) - _assert_sample_and_rtype(autoreg) - _assert_sample_and_rtype(periodic_diff_sd) + _assert_sample_and_rtype(log_rt_rv) + _assert_sample_and_rtype(autoreg_rv) + _assert_sample_and_rtype(periodic_diff_sd_rv) return None @@ -175,9 +175,9 @@ def sample( """ # Initial sample - log_rt_prior = self.log_rt_prior.sample(**kwargs)[0].value - b = self.autoreg.sample(**kwargs)[0].value - s_r = self.periodic_diff_sd.sample(**kwargs)[0].value + log_rt_rv = self.log_rt_rv.sample(**kwargs)[0].value + b = self.autoreg_rv.sample(**kwargs)[0].value + s_r = self.periodic_diff_sd_rv.sample(**kwargs)[0].value # How many periods to sample? n_periods = int(jnp.ceil(duration / self.period_size)) @@ -186,8 +186,8 @@ def sample( ar_diff = FirstDifferenceARProcess("trend_rw", autoreg=b, noise_sd=s_r) log_rt = ar_diff.sample( duration=n_periods, - init_val=log_rt_prior[1], - init_rate_of_change=log_rt_prior[1] - log_rt_prior[0], + init_val=log_rt_rv[1], + init_rate_of_change=log_rt_rv[1] - log_rt_rv[0], )[0] return RtPeriodicDiffProcessSample( @@ -208,9 +208,9 @@ def __init__( self, name: str, offset: int, - log_rt_prior: RandomVariable, - autoreg: RandomVariable, - periodic_diff_sd: RandomVariable, + log_rt_rv: RandomVariable, + autoreg_rv: RandomVariable, + periodic_diff_sd_rv: RandomVariable, ) -> None: """ Default constructor for RtWeeklyDiffProcess class. @@ -221,11 +221,11 @@ def __init__( Name of the site. offset : int Relative point at which data starts, must be between 0 and 6. - log_rt_prior : RandomVariable + log_rt_rv : RandomVariable Log Rt prior for the first two observations. - autoreg : RandomVariable + autoreg_rv : RandomVariable Autoregressive parameter. - periodic_diff_sd : RandomVariable + periodic_diff_sd_rv : RandomVariable Standard deviation of the noise. Returns @@ -237,9 +237,9 @@ def __init__( name=name, offset=offset, period_size=7, - log_rt_prior=log_rt_prior, - autoreg=autoreg, - periodic_diff_sd=periodic_diff_sd, + log_rt_rv=log_rt_rv, + autoreg_rv=autoreg_rv, + periodic_diff_sd_rv=periodic_diff_sd_rv, ) return None diff --git a/model/src/test/test_forecast.py b/model/src/test/test_forecast.py index f6b736a0..1e293b47 100644 --- a/model/src/test/test_forecast.py +++ b/model/src/test/test_forecast.py @@ -38,7 +38,7 @@ def test_forecast(): name="rw_step_rv", dist=dist.Normal(0, 0.025) ), init_rv=DistributionalRV( - name="init_log_Rt_rv", dist=dist.Normal(0, 0.2) + name="init_log_rt", dist=dist.Normal(0, 0.2) ), ), transforms=t.ExpTransform(), diff --git a/model/src/test/test_latent_admissions.py b/model/src/test/test_latent_admissions.py index cf841a8d..73f41c17 100644 --- a/model/src/test/test_latent_admissions.py +++ b/model/src/test/test_latent_admissions.py @@ -28,7 +28,7 @@ def test_admissions_sample(): name="rw_step_rv", dist=dist.Normal(0, 0.025) ), init_rv=DistributionalRV( - name="init_log_Rt_rv", dist=dist.Normal(0, 0.2) + name="init_log_rt", dist=dist.Normal(0, 0.2) ), ), transforms=t.ExpTransform(), diff --git a/model/src/test/test_latent_infections.py b/model/src/test/test_latent_infections.py index d55c7dff..fcfd3f99 100755 --- a/model/src/test/test_latent_infections.py +++ b/model/src/test/test_latent_infections.py @@ -26,7 +26,7 @@ def test_infections_as_deterministic(): name="rw_step_rv", dist=dist.Normal(0, 0.025) ), init_rv=DistributionalRV( - name="init_log_Rt_rv", dist=dist.Normal(0, 0.2) + name="init_log_rt", dist=dist.Normal(0, 0.2) ), ), transforms=t.ExpTransform(), diff --git a/model/src/test/test_model_basic_renewal.py b/model/src/test/test_model_basic_renewal.py index c44737c9..34fce28b 100644 --- a/model/src/test/test_model_basic_renewal.py +++ b/model/src/test/test_model_basic_renewal.py @@ -41,7 +41,7 @@ def get_default_rt(): name="rw_step_rv", dist=dist.Normal(0, 0.025) ), init_rv=DistributionalRV( - name="init_log_Rt_rv", dist=dist.Normal(0, 0.2) + name="init_log_rt", dist=dist.Normal(0, 0.2) ), ), transforms=t.ExpTransform(), diff --git a/model/src/test/test_model_hosp_admissions.py b/model/src/test/test_model_hosp_admissions.py index a573ec06..4766d9e6 100644 --- a/model/src/test/test_model_hosp_admissions.py +++ b/model/src/test/test_model_hosp_admissions.py @@ -51,7 +51,7 @@ def get_default_rt(): name="rw_step_rv", dist=dist.Normal(0, 0.025) ), init_rv=DistributionalRV( - name="init_log_Rt_rv", dist=dist.Normal(0, 0.2) + name="init_log_rt", dist=dist.Normal(0, 0.2) ), ), transforms=t.ExpTransform(), diff --git a/model/src/test/test_predictive.py b/model/src/test/test_predictive.py index 9c53521d..1a15a3c2 100644 --- a/model/src/test/test_predictive.py +++ b/model/src/test/test_predictive.py @@ -36,9 +36,7 @@ step_rv=DistributionalRV( name="rw_step_rv", dist=dist.Normal(0, 0.025) ), - init_rv=DistributionalRV( - name="init_log_Rt_rv", dist=dist.Normal(0, 0.2) - ), + init_rv=DistributionalRV(name="init_log_rt", dist=dist.Normal(0, 0.2)), ), transforms=t.ExpTransform(), ) diff --git a/model/src/test/test_random_key.py b/model/src/test/test_random_key.py index d032bb93..4bcc644f 100644 --- a/model/src/test/test_random_key.py +++ b/model/src/test/test_random_key.py @@ -42,7 +42,7 @@ def create_test_model(): # numpydoc ignore=GL08 name="rw_step_rv", dist=dist.Normal(0, 0.025) ), init_rv=DistributionalRV( - name="init_log_Rt_rv", dist=dist.Normal(0, 0.2) + name="init_log_rt", dist=dist.Normal(0, 0.2) ), ), transforms=t.ExpTransform(), diff --git a/model/src/test/test_rtperiodicdiff.py b/model/src/test/test_rtperiodicdiff.py index 4fb2cbfb..d29681ae 100644 --- a/model/src/test/test_rtperiodicdiff.py +++ b/model/src/test/test_rtperiodicdiff.py @@ -52,14 +52,14 @@ def test_rtweeklydiff() -> None: params = { "name": "test", "offset": 0, - "log_rt_prior": DeterministicVariable( - name="log_rt_prior", value=jnp.array([0.1, 0.2]) + "log_rt_rv": DeterministicVariable( + name="log_rt", value=jnp.array([0.1, 0.2]) ), - "autoreg": DeterministicVariable( - name="autoreg", value=jnp.array([0.7]) + "autoreg_rv": DeterministicVariable( + name="autoreg_rv", value=jnp.array([0.7]) ), - "periodic_diff_sd": DeterministicVariable( - name="periodic_diff_sd", value=jnp.array([0.1]) + "periodic_diff_sd_rv": DeterministicVariable( + name="periodic_diff_sd_rv", value=jnp.array([0.1]) ), } duration = 30 @@ -100,15 +100,15 @@ def test_rtweeklydiff_no_autoregressive() -> None: params = { "name": "test", "offset": 0, - "log_rt_prior": DeterministicVariable( - name="log_rt_prior", value=jnp.array([0.0, 0.0]) + "log_rt_rv": DeterministicVariable( + name="log_rt", value=jnp.array([0.0, 0.0]) ), # No autoregression! - "autoreg": DeterministicVariable( - name="autoreg", value=jnp.array([0.0]) + "autoreg_rv": DeterministicVariable( + name="autoreg_rv", value=jnp.array([0.0]) ), - "periodic_diff_sd": DeterministicVariable( - name="periodic_diff_sd", + "periodic_diff_sd_rv": DeterministicVariable( + name="periodic_diff_sd_rv", value=jnp.array([0.1]), ), } @@ -141,15 +141,15 @@ def test_rtweeklydiff_manual_reconstruction() -> None: params = { "name": "test", "offset": 0, - "log_rt_prior": DeterministicVariable( - name="log_rt_prior", + "log_rt_rv": DeterministicVariable( + name="log_rt", value=jnp.array([0.1, 0.2]), ), - "autoreg": DeterministicVariable( - name="autoreg", value=jnp.array([0.7]) + "autoreg_rv": DeterministicVariable( + name="autoreg_rv", value=jnp.array([0.7]) ), - "periodic_diff_sd": DeterministicVariable( - name="periodic_diff_sd", + "periodic_diff_sd_rv": DeterministicVariable( + name="periodic_diff_sd_rv", value=jnp.array([0.1]), ), } @@ -161,12 +161,12 @@ def test_rtweeklydiff_manual_reconstruction() -> None: _, ans0 = lax.scan( f=rtwd.autoreg_process, - init=np.hstack([params["log_rt_prior"]()[0].value, b]), + init=np.hstack([params["log_rt_rv"]()[0].value, b]), xs=noise, ) ans1 = _manual_rt_weekly_diff( - log_seed=params["log_rt_prior"]()[0].value, sd=noise, b=b + log_seed=params["log_rt_rv"]()[0].value, sd=noise, b=b ) assert_array_almost_equal(ans0, ans1) @@ -180,15 +180,15 @@ def test_rtperiodicdiff_smallsample(): params = { "name": "test", "offset": 0, - "log_rt_prior": DeterministicVariable( - name="log_rt_prior", + "log_rt_rv": DeterministicVariable( + name="log_rt", value=jnp.array([0.1, 0.2]), ), - "autoreg": DeterministicVariable( - name="autoreg", value=jnp.array([0.7]) + "autoreg_rv": DeterministicVariable( + name="autoreg_rv", value=jnp.array([0.7]) ), - "periodic_diff_sd": DeterministicVariable( - name="periodic_diff_sd", + "periodic_diff_sd_rv": DeterministicVariable( + name="periodic_diff_sd_rv", value=jnp.array([0.1]), ), } From fd4af556b4e821e9c9bfca0e564cccec4a087d73 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" <dylanhmorris@users.noreply.github.com> Date: Tue, 30 Jul 2024 13:29:28 -0400 Subject: [PATCH 2/4] Fix some docstrings and add validation testing for `deterministic.py` (#342) * Add new tests for deterministicvariable * Fix docs, error classes, and order of validation in deterministic.py * More meaningful error message and test for it --- .../pyrenew/deterministic/deterministic.py | 12 +++-- model/src/test/test_deterministic.py | 50 +++++++++++++++++++ 2 files changed, 58 insertions(+), 4 deletions(-) diff --git a/model/src/pyrenew/deterministic/deterministic.py b/model/src/pyrenew/deterministic/deterministic.py index 2bb03333..2b626ef2 100644 --- a/model/src/pyrenew/deterministic/deterministic.py +++ b/model/src/pyrenew/deterministic/deterministic.py @@ -40,8 +40,8 @@ def __init__( None """ self.name = name - self.value = jnp.atleast_1d(value) self.validate(value) + self.value = jnp.atleast_1d(value) self.set_timeseries(t_start, t_unit) return None @@ -49,7 +49,7 @@ def __init__( @staticmethod def validate(value: ArrayLike) -> None: """ - Validates input to DeterministicPMF + Validates input to DeterministicVariable Parameters ---------- @@ -63,10 +63,14 @@ def validate(value: ArrayLike) -> None: Raises ------ Exception - If the input value object is not a ArrayLike. + If the input value object is not an ArrayLike object. """ if not isinstance(value, ArrayLike): - raise Exception("value is not a ArrayLike") + raise ValueError( + f"value {value} passed to a DeterministicVariable " + f"is of type {type(value).__name__}, expected " + "an ArrayLike object" + ) return None diff --git a/model/src/test/test_deterministic.py b/model/src/test/test_deterministic.py index 6e72f8cb..39e08e65 100644 --- a/model/src/test/test_deterministic.py +++ b/model/src/test/test_deterministic.py @@ -1,7 +1,11 @@ # numpydoc ignore=GL08 +import re + import jax.numpy as jnp +import numpy as np import numpy.testing as testing +import pytest from pyrenew.deterministic import ( DeterministicPMF, DeterministicProcess, @@ -62,3 +66,49 @@ def test_deterministic(): testing.assert_equal(var4()[0].value, None) testing.assert_equal(var5(duration=1)[0].value, None) + + +def test_deterministic_validation(): + """ + Check that validation methods for DeterministicVariable + work as expected. + """ + # validation should fail on construction + some_non_array_likes = [ + {"a": jnp.array([1, 2.5, 3])}, + # a valid pytree, but not an arraylike + "a string", + ] + some_array_likes = [ + 5, + -3.023523, + np.array([1, 3.32, 5]), + jnp.array([-32, 23]), + jnp.array(-32), + np.array(5), + ] + + for non_arraylike_val in some_non_array_likes: + matchval = re.escape( + f"value {non_arraylike_val} passed to a " + "DeterministicVariable is of type " + f"{type(non_arraylike_val).__name__}, expected " + "an ArrayLike object" + ) + + with pytest.raises(ValueError, match=matchval): + # the class's validation function itself + # should raise an error when passed a + # non arraylike value + DeterministicVariable.validate(non_arraylike_val) + + with pytest.raises(ValueError, match=matchval): + # validation should fail on constructor call + DeterministicVariable( + value=non_arraylike_val, name="invalid_variable" + ) + + # validation should succeed with ArrayLike + for arraylike_val in some_array_likes: + DeterministicVariable.validate(arraylike_val) + DeterministicVariable(value=arraylike_val, name="valid_variable") From dd3e0229949c2968e41a9503c7a35dd45e5033cd Mon Sep 17 00:00:00 2001 From: Damon Bayer <xum8@cdc.gov> Date: Tue, 30 Jul 2024 13:44:56 -0500 Subject: [PATCH 3/4] Fix name in `rtperiodicdiff.py` (#343) Update rtperiodicdiff.py --- model/src/pyrenew/process/rtperiodicdiff.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model/src/pyrenew/process/rtperiodicdiff.py b/model/src/pyrenew/process/rtperiodicdiff.py index 710d5d22..658a5339 100644 --- a/model/src/pyrenew/process/rtperiodicdiff.py +++ b/model/src/pyrenew/process/rtperiodicdiff.py @@ -183,7 +183,7 @@ def sample( n_periods = int(jnp.ceil(duration / self.period_size)) # Running the process - ar_diff = FirstDifferenceARProcess("trend_rw", autoreg=b, noise_sd=s_r) + ar_diff = FirstDifferenceARProcess(self.name, autoreg=b, noise_sd=s_r) log_rt = ar_diff.sample( duration=n_periods, init_val=log_rt_rv[1], From 64e0ed832166d6c5ec2d0ca817c3cfdd1eeb0e0d Mon Sep 17 00:00:00 2001 From: Subekshya Bidari <37636707+sbidari@users.noreply.github.com> Date: Tue, 30 Jul 2024 17:12:58 -0400 Subject: [PATCH 4/4] add dependency check workflow (#346) --- .github/workflows/deptry.yml | 32 ++++++++++++++++++++++++++++++++ pyproject.toml | 2 +- 2 files changed, 33 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/deptry.yml diff --git a/.github/workflows/deptry.yml b/.github/workflows/deptry.yml new file mode 100644 index 00000000..0607c736 --- /dev/null +++ b/.github/workflows/deptry.yml @@ -0,0 +1,32 @@ +name: model + +on: + pull_request: + push: + branches: [main] + +jobs: + dependency-check: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: cache poetry + uses: actions/cache@v4 + with: + path: ~/.local + key: ${{ runner.os }}-poetry + + - name: install poetry + run: pip install poetry + + - name: install package + run: poetry install --with dev + + - name: run deptry + run: | + poetry run deptry . --per-rule-ignores "DEP001=pyrenew,DEP003=pytest" diff --git a/pyproject.toml b/pyproject.toml index d7df60ca..3e0fe1fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,6 @@ exclude = [{path = "datasets/*.rds"}] python = "^3.12" numpyro = ">=0.15.1" jax = ">=0.4.30" -jaxlib = ">=0.4.30" numpy = "^1.26.4" polars = "^1.2.1" matplotlib = "^3.8.3" @@ -39,6 +38,7 @@ sphinx-book-theme = "^1.1.2" pillow = "^10.3.0" nbconvert = "^7.16.4" ipywidgets = "^8.1.3" +deptry = "^0.17.0" [tool.numpydoc_validation] checks = [