From f77579ea11d1493468a449c38a5b1bf380687e8b Mon Sep 17 00:00:00 2001
From: Damon Bayer <xum8@cdc.gov>
Date: Mon, 29 Jul 2024 19:57:25 -0500
Subject: [PATCH 1/4] Rename random variables in rtperiodicdiff (#339)

---
 docs/source/tutorials/basic_renewal_model.qmd |  2 +-
 docs/source/tutorials/extending_pyrenew.qmd   |  4 +-
 .../tutorials/hospital_admissions_model.qmd   |  2 +-
 docs/source/tutorials/periodic_effects.qmd    |  8 +--
 model/src/pyrenew/process/rtperiodicdiff.py   | 70 +++++++++----------
 model/src/test/test_forecast.py               |  2 +-
 model/src/test/test_latent_admissions.py      |  2 +-
 model/src/test/test_latent_infections.py      |  2 +-
 model/src/test/test_model_basic_renewal.py    |  2 +-
 model/src/test/test_model_hosp_admissions.py  |  2 +-
 model/src/test/test_predictive.py             |  4 +-
 model/src/test/test_random_key.py             |  2 +-
 model/src/test/test_rtperiodicdiff.py         | 52 +++++++-------
 13 files changed, 75 insertions(+), 79 deletions(-)

diff --git a/docs/source/tutorials/basic_renewal_model.qmd b/docs/source/tutorials/basic_renewal_model.qmd
index eec79ec2..3ef776fe 100644
--- a/docs/source/tutorials/basic_renewal_model.qmd
+++ b/docs/source/tutorials/basic_renewal_model.qmd
@@ -152,7 +152,7 @@ class MyRt(RandomVariable):
                     reparam=LocScaleReparam(0),
                 ),
                 init_rv=DistributionalRV(
-                    name="init_log_Rt_rv",
+                    name="init_log_rt",
                     dist=dist.Normal(jnp.log(1), jnp.log(1.2)),
                 ),
             ),
diff --git a/docs/source/tutorials/extending_pyrenew.qmd b/docs/source/tutorials/extending_pyrenew.qmd
index 664cd3bd..bcfc2749 100644
--- a/docs/source/tutorials/extending_pyrenew.qmd
+++ b/docs/source/tutorials/extending_pyrenew.qmd
@@ -69,9 +69,7 @@ rt = TransformedRandomVariable(
         step_rv=DistributionalRV(
             name="rw_step_rv", dist=dist.Normal(0, 0.025)
         ),
-        init_rv=DistributionalRV(
-            name="init_log_Rt_rv", dist=dist.Normal(0, 0.2)
-        ),
+        init_rv=DistributionalRV(name="init_log_rt", dist=dist.Normal(0, 0.2)),
     ),
     transforms=t.ExpTransform(),
 )
diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd
index 32bab32d..8d491cd9 100644
--- a/docs/source/tutorials/hospital_admissions_model.qmd
+++ b/docs/source/tutorials/hospital_admissions_model.qmd
@@ -202,7 +202,7 @@ class MyRt(metaclass.RandomVariable):
                     name="rw_step_rv", dist=dist.Normal(0, sd_rt.value)
                 ),
                 init_rv=metaclass.DistributionalRV(
-                    name="init_log_Rt_rv", dist=dist.Normal(0, 0.2)
+                    name="init_log_rt", dist=dist.Normal(0, 0.2)
                 ),
             ),
             transforms=transformation.ExpTransform(),
diff --git a/docs/source/tutorials/periodic_effects.qmd b/docs/source/tutorials/periodic_effects.qmd
index 2cd1db8a..6eb0f876 100644
--- a/docs/source/tutorials/periodic_effects.qmd
+++ b/docs/source/tutorials/periodic_effects.qmd
@@ -27,13 +27,13 @@ from pyrenew import process, deterministic
 rt_proc = process.RtWeeklyDiffProcess(
     name="rt_weekly_diff",
     offset=0,
-    log_rt_prior=deterministic.DeterministicVariable(
-        name="log_rt_prior", value=jnp.array([0.1, 0.2])
+    log_rt_rv=deterministic.DeterministicVariable(
+        name="log_rt", value=jnp.array([0.1, 0.2])
     ),
-    autoreg=deterministic.DeterministicVariable(
+    autoreg_rv=deterministic.DeterministicVariable(
         name="autoreg", value=jnp.array([0.7])
     ),
-    periodic_diff_sd=deterministic.DeterministicVariable(
+    periodic_diff_sd_rv=deterministic.DeterministicVariable(
         name="periodic_diff_sd", value=jnp.array([0.1])
     ),
 )
diff --git a/model/src/pyrenew/process/rtperiodicdiff.py b/model/src/pyrenew/process/rtperiodicdiff.py
index 1fd5da86..710d5d22 100644
--- a/model/src/pyrenew/process/rtperiodicdiff.py
+++ b/model/src/pyrenew/process/rtperiodicdiff.py
@@ -52,9 +52,9 @@ def __init__(
         name: str,
         offset: int,
         period_size: int,
-        log_rt_prior: RandomVariable,
-        autoreg: RandomVariable,
-        periodic_diff_sd: RandomVariable,
+        log_rt_rv: RandomVariable,
+        autoreg_rv: RandomVariable,
+        periodic_diff_sd_rv: RandomVariable,
     ) -> None:
         """
         Default constructor for RtPeriodicDiffProcess class.
@@ -66,11 +66,11 @@ def __init__(
         offset : int
             Relative point at which data starts, must be between 0 and
             period_size - 1.
-        log_rt_prior : RandomVariable
+        log_rt_rv : RandomVariable
             Log Rt prior for the first two observations.
-        autoreg : RandomVariable
+        autoreg_rv : RandomVariable
             Autoregressive parameter.
-        periodic_diff_sd : RandomVariable
+        periodic_diff_sd_rv : RandomVariable
             Standard deviation of the noise.
 
         Returns
@@ -85,35 +85,35 @@ def __init__(
         )
 
         self.validate(
-            log_rt_prior=log_rt_prior,
-            autoreg=autoreg,
-            periodic_diff_sd=periodic_diff_sd,
+            log_rt_rv=log_rt_rv,
+            autoreg_rv=autoreg_rv,
+            periodic_diff_sd_rv=periodic_diff_sd_rv,
         )
 
         self.period_size = period_size
         self.offset = offset
-        self.log_rt_prior = log_rt_prior
-        self.autoreg = autoreg
-        self.periodic_diff_sd = periodic_diff_sd
+        self.log_rt_rv = log_rt_rv
+        self.autoreg_rv = autoreg_rv
+        self.periodic_diff_sd_rv = periodic_diff_sd_rv
 
         return None
 
     @staticmethod
     def validate(
-        log_rt_prior: any,
-        autoreg: any,
-        periodic_diff_sd: any,
+        log_rt_rv: any,
+        autoreg_rv: any,
+        periodic_diff_sd_rv: any,
     ) -> None:
         """
         Validate the input parameters.
 
         Parameters
         ----------
-        log_rt_prior : any
+        log_rt_rv : any
             Log Rt prior for the first two observations.
-        autoreg : any
+        autoreg_rv : any
             Autoregressive parameter.
-        periodic_diff_sd : any
+        periodic_diff_sd_rv : any
             Standard deviation of the noise.
 
         Returns
@@ -121,9 +121,9 @@ def validate(
         None
         """
 
-        _assert_sample_and_rtype(log_rt_prior)
-        _assert_sample_and_rtype(autoreg)
-        _assert_sample_and_rtype(periodic_diff_sd)
+        _assert_sample_and_rtype(log_rt_rv)
+        _assert_sample_and_rtype(autoreg_rv)
+        _assert_sample_and_rtype(periodic_diff_sd_rv)
 
         return None
 
@@ -175,9 +175,9 @@ def sample(
         """
 
         # Initial sample
-        log_rt_prior = self.log_rt_prior.sample(**kwargs)[0].value
-        b = self.autoreg.sample(**kwargs)[0].value
-        s_r = self.periodic_diff_sd.sample(**kwargs)[0].value
+        log_rt_rv = self.log_rt_rv.sample(**kwargs)[0].value
+        b = self.autoreg_rv.sample(**kwargs)[0].value
+        s_r = self.periodic_diff_sd_rv.sample(**kwargs)[0].value
 
         # How many periods to sample?
         n_periods = int(jnp.ceil(duration / self.period_size))
@@ -186,8 +186,8 @@ def sample(
         ar_diff = FirstDifferenceARProcess("trend_rw", autoreg=b, noise_sd=s_r)
         log_rt = ar_diff.sample(
             duration=n_periods,
-            init_val=log_rt_prior[1],
-            init_rate_of_change=log_rt_prior[1] - log_rt_prior[0],
+            init_val=log_rt_rv[1],
+            init_rate_of_change=log_rt_rv[1] - log_rt_rv[0],
         )[0]
 
         return RtPeriodicDiffProcessSample(
@@ -208,9 +208,9 @@ def __init__(
         self,
         name: str,
         offset: int,
-        log_rt_prior: RandomVariable,
-        autoreg: RandomVariable,
-        periodic_diff_sd: RandomVariable,
+        log_rt_rv: RandomVariable,
+        autoreg_rv: RandomVariable,
+        periodic_diff_sd_rv: RandomVariable,
     ) -> None:
         """
         Default constructor for RtWeeklyDiffProcess class.
@@ -221,11 +221,11 @@ def __init__(
             Name of the site.
         offset : int
             Relative point at which data starts, must be between 0 and 6.
-        log_rt_prior : RandomVariable
+        log_rt_rv : RandomVariable
             Log Rt prior for the first two observations.
-        autoreg : RandomVariable
+        autoreg_rv : RandomVariable
             Autoregressive parameter.
-        periodic_diff_sd : RandomVariable
+        periodic_diff_sd_rv : RandomVariable
             Standard deviation of the noise.
 
         Returns
@@ -237,9 +237,9 @@ def __init__(
             name=name,
             offset=offset,
             period_size=7,
-            log_rt_prior=log_rt_prior,
-            autoreg=autoreg,
-            periodic_diff_sd=periodic_diff_sd,
+            log_rt_rv=log_rt_rv,
+            autoreg_rv=autoreg_rv,
+            periodic_diff_sd_rv=periodic_diff_sd_rv,
         )
 
         return None
diff --git a/model/src/test/test_forecast.py b/model/src/test/test_forecast.py
index f6b736a0..1e293b47 100644
--- a/model/src/test/test_forecast.py
+++ b/model/src/test/test_forecast.py
@@ -38,7 +38,7 @@ def test_forecast():
                 name="rw_step_rv", dist=dist.Normal(0, 0.025)
             ),
             init_rv=DistributionalRV(
-                name="init_log_Rt_rv", dist=dist.Normal(0, 0.2)
+                name="init_log_rt", dist=dist.Normal(0, 0.2)
             ),
         ),
         transforms=t.ExpTransform(),
diff --git a/model/src/test/test_latent_admissions.py b/model/src/test/test_latent_admissions.py
index cf841a8d..73f41c17 100644
--- a/model/src/test/test_latent_admissions.py
+++ b/model/src/test/test_latent_admissions.py
@@ -28,7 +28,7 @@ def test_admissions_sample():
                 name="rw_step_rv", dist=dist.Normal(0, 0.025)
             ),
             init_rv=DistributionalRV(
-                name="init_log_Rt_rv", dist=dist.Normal(0, 0.2)
+                name="init_log_rt", dist=dist.Normal(0, 0.2)
             ),
         ),
         transforms=t.ExpTransform(),
diff --git a/model/src/test/test_latent_infections.py b/model/src/test/test_latent_infections.py
index d55c7dff..fcfd3f99 100755
--- a/model/src/test/test_latent_infections.py
+++ b/model/src/test/test_latent_infections.py
@@ -26,7 +26,7 @@ def test_infections_as_deterministic():
                 name="rw_step_rv", dist=dist.Normal(0, 0.025)
             ),
             init_rv=DistributionalRV(
-                name="init_log_Rt_rv", dist=dist.Normal(0, 0.2)
+                name="init_log_rt", dist=dist.Normal(0, 0.2)
             ),
         ),
         transforms=t.ExpTransform(),
diff --git a/model/src/test/test_model_basic_renewal.py b/model/src/test/test_model_basic_renewal.py
index c44737c9..34fce28b 100644
--- a/model/src/test/test_model_basic_renewal.py
+++ b/model/src/test/test_model_basic_renewal.py
@@ -41,7 +41,7 @@ def get_default_rt():
                 name="rw_step_rv", dist=dist.Normal(0, 0.025)
             ),
             init_rv=DistributionalRV(
-                name="init_log_Rt_rv", dist=dist.Normal(0, 0.2)
+                name="init_log_rt", dist=dist.Normal(0, 0.2)
             ),
         ),
         transforms=t.ExpTransform(),
diff --git a/model/src/test/test_model_hosp_admissions.py b/model/src/test/test_model_hosp_admissions.py
index a573ec06..4766d9e6 100644
--- a/model/src/test/test_model_hosp_admissions.py
+++ b/model/src/test/test_model_hosp_admissions.py
@@ -51,7 +51,7 @@ def get_default_rt():
                 name="rw_step_rv", dist=dist.Normal(0, 0.025)
             ),
             init_rv=DistributionalRV(
-                name="init_log_Rt_rv", dist=dist.Normal(0, 0.2)
+                name="init_log_rt", dist=dist.Normal(0, 0.2)
             ),
         ),
         transforms=t.ExpTransform(),
diff --git a/model/src/test/test_predictive.py b/model/src/test/test_predictive.py
index 9c53521d..1a15a3c2 100644
--- a/model/src/test/test_predictive.py
+++ b/model/src/test/test_predictive.py
@@ -36,9 +36,7 @@
         step_rv=DistributionalRV(
             name="rw_step_rv", dist=dist.Normal(0, 0.025)
         ),
-        init_rv=DistributionalRV(
-            name="init_log_Rt_rv", dist=dist.Normal(0, 0.2)
-        ),
+        init_rv=DistributionalRV(name="init_log_rt", dist=dist.Normal(0, 0.2)),
     ),
     transforms=t.ExpTransform(),
 )
diff --git a/model/src/test/test_random_key.py b/model/src/test/test_random_key.py
index d032bb93..4bcc644f 100644
--- a/model/src/test/test_random_key.py
+++ b/model/src/test/test_random_key.py
@@ -42,7 +42,7 @@ def create_test_model():  # numpydoc ignore=GL08
                 name="rw_step_rv", dist=dist.Normal(0, 0.025)
             ),
             init_rv=DistributionalRV(
-                name="init_log_Rt_rv", dist=dist.Normal(0, 0.2)
+                name="init_log_rt", dist=dist.Normal(0, 0.2)
             ),
         ),
         transforms=t.ExpTransform(),
diff --git a/model/src/test/test_rtperiodicdiff.py b/model/src/test/test_rtperiodicdiff.py
index 4fb2cbfb..d29681ae 100644
--- a/model/src/test/test_rtperiodicdiff.py
+++ b/model/src/test/test_rtperiodicdiff.py
@@ -52,14 +52,14 @@ def test_rtweeklydiff() -> None:
     params = {
         "name": "test",
         "offset": 0,
-        "log_rt_prior": DeterministicVariable(
-            name="log_rt_prior", value=jnp.array([0.1, 0.2])
+        "log_rt_rv": DeterministicVariable(
+            name="log_rt", value=jnp.array([0.1, 0.2])
         ),
-        "autoreg": DeterministicVariable(
-            name="autoreg", value=jnp.array([0.7])
+        "autoreg_rv": DeterministicVariable(
+            name="autoreg_rv", value=jnp.array([0.7])
         ),
-        "periodic_diff_sd": DeterministicVariable(
-            name="periodic_diff_sd", value=jnp.array([0.1])
+        "periodic_diff_sd_rv": DeterministicVariable(
+            name="periodic_diff_sd_rv", value=jnp.array([0.1])
         ),
     }
     duration = 30
@@ -100,15 +100,15 @@ def test_rtweeklydiff_no_autoregressive() -> None:
     params = {
         "name": "test",
         "offset": 0,
-        "log_rt_prior": DeterministicVariable(
-            name="log_rt_prior", value=jnp.array([0.0, 0.0])
+        "log_rt_rv": DeterministicVariable(
+            name="log_rt", value=jnp.array([0.0, 0.0])
         ),
         # No autoregression!
-        "autoreg": DeterministicVariable(
-            name="autoreg", value=jnp.array([0.0])
+        "autoreg_rv": DeterministicVariable(
+            name="autoreg_rv", value=jnp.array([0.0])
         ),
-        "periodic_diff_sd": DeterministicVariable(
-            name="periodic_diff_sd",
+        "periodic_diff_sd_rv": DeterministicVariable(
+            name="periodic_diff_sd_rv",
             value=jnp.array([0.1]),
         ),
     }
@@ -141,15 +141,15 @@ def test_rtweeklydiff_manual_reconstruction() -> None:
     params = {
         "name": "test",
         "offset": 0,
-        "log_rt_prior": DeterministicVariable(
-            name="log_rt_prior",
+        "log_rt_rv": DeterministicVariable(
+            name="log_rt",
             value=jnp.array([0.1, 0.2]),
         ),
-        "autoreg": DeterministicVariable(
-            name="autoreg", value=jnp.array([0.7])
+        "autoreg_rv": DeterministicVariable(
+            name="autoreg_rv", value=jnp.array([0.7])
         ),
-        "periodic_diff_sd": DeterministicVariable(
-            name="periodic_diff_sd",
+        "periodic_diff_sd_rv": DeterministicVariable(
+            name="periodic_diff_sd_rv",
             value=jnp.array([0.1]),
         ),
     }
@@ -161,12 +161,12 @@ def test_rtweeklydiff_manual_reconstruction() -> None:
 
     _, ans0 = lax.scan(
         f=rtwd.autoreg_process,
-        init=np.hstack([params["log_rt_prior"]()[0].value, b]),
+        init=np.hstack([params["log_rt_rv"]()[0].value, b]),
         xs=noise,
     )
 
     ans1 = _manual_rt_weekly_diff(
-        log_seed=params["log_rt_prior"]()[0].value, sd=noise, b=b
+        log_seed=params["log_rt_rv"]()[0].value, sd=noise, b=b
     )
 
     assert_array_almost_equal(ans0, ans1)
@@ -180,15 +180,15 @@ def test_rtperiodicdiff_smallsample():
     params = {
         "name": "test",
         "offset": 0,
-        "log_rt_prior": DeterministicVariable(
-            name="log_rt_prior",
+        "log_rt_rv": DeterministicVariable(
+            name="log_rt",
             value=jnp.array([0.1, 0.2]),
         ),
-        "autoreg": DeterministicVariable(
-            name="autoreg", value=jnp.array([0.7])
+        "autoreg_rv": DeterministicVariable(
+            name="autoreg_rv", value=jnp.array([0.7])
         ),
-        "periodic_diff_sd": DeterministicVariable(
-            name="periodic_diff_sd",
+        "periodic_diff_sd_rv": DeterministicVariable(
+            name="periodic_diff_sd_rv",
             value=jnp.array([0.1]),
         ),
     }

From fd4af556b4e821e9c9bfca0e564cccec4a087d73 Mon Sep 17 00:00:00 2001
From: "Dylan H. Morris" <dylanhmorris@users.noreply.github.com>
Date: Tue, 30 Jul 2024 13:29:28 -0400
Subject: [PATCH 2/4] Fix some docstrings and add validation testing for
 `deterministic.py` (#342)

* Add new tests for deterministicvariable

* Fix docs, error classes, and order of validation in deterministic.py

* More meaningful error message and test for it
---
 .../pyrenew/deterministic/deterministic.py    | 12 +++--
 model/src/test/test_deterministic.py          | 50 +++++++++++++++++++
 2 files changed, 58 insertions(+), 4 deletions(-)

diff --git a/model/src/pyrenew/deterministic/deterministic.py b/model/src/pyrenew/deterministic/deterministic.py
index 2bb03333..2b626ef2 100644
--- a/model/src/pyrenew/deterministic/deterministic.py
+++ b/model/src/pyrenew/deterministic/deterministic.py
@@ -40,8 +40,8 @@ def __init__(
         None
         """
         self.name = name
-        self.value = jnp.atleast_1d(value)
         self.validate(value)
+        self.value = jnp.atleast_1d(value)
         self.set_timeseries(t_start, t_unit)
 
         return None
@@ -49,7 +49,7 @@ def __init__(
     @staticmethod
     def validate(value: ArrayLike) -> None:
         """
-        Validates input to DeterministicPMF
+        Validates input to DeterministicVariable
 
         Parameters
         ----------
@@ -63,10 +63,14 @@ def validate(value: ArrayLike) -> None:
         Raises
         ------
         Exception
-            If the input value object is not a ArrayLike.
+            If the input value object is not an ArrayLike object.
         """
         if not isinstance(value, ArrayLike):
-            raise Exception("value is not a ArrayLike")
+            raise ValueError(
+                f"value {value} passed to a DeterministicVariable "
+                f"is of type {type(value).__name__}, expected "
+                "an ArrayLike object"
+            )
 
         return None
 
diff --git a/model/src/test/test_deterministic.py b/model/src/test/test_deterministic.py
index 6e72f8cb..39e08e65 100644
--- a/model/src/test/test_deterministic.py
+++ b/model/src/test/test_deterministic.py
@@ -1,7 +1,11 @@
 # numpydoc ignore=GL08
 
+import re
+
 import jax.numpy as jnp
+import numpy as np
 import numpy.testing as testing
+import pytest
 from pyrenew.deterministic import (
     DeterministicPMF,
     DeterministicProcess,
@@ -62,3 +66,49 @@ def test_deterministic():
 
     testing.assert_equal(var4()[0].value, None)
     testing.assert_equal(var5(duration=1)[0].value, None)
+
+
+def test_deterministic_validation():
+    """
+    Check that validation methods for DeterministicVariable
+    work as expected.
+    """
+    # validation should fail on construction
+    some_non_array_likes = [
+        {"a": jnp.array([1, 2.5, 3])},
+        # a valid pytree, but not an arraylike
+        "a string",
+    ]
+    some_array_likes = [
+        5,
+        -3.023523,
+        np.array([1, 3.32, 5]),
+        jnp.array([-32, 23]),
+        jnp.array(-32),
+        np.array(5),
+    ]
+
+    for non_arraylike_val in some_non_array_likes:
+        matchval = re.escape(
+            f"value {non_arraylike_val} passed to a "
+            "DeterministicVariable is of type "
+            f"{type(non_arraylike_val).__name__}, expected "
+            "an ArrayLike object"
+        )
+
+        with pytest.raises(ValueError, match=matchval):
+            # the class's validation function itself
+            # should raise an error when passed a
+            # non arraylike value
+            DeterministicVariable.validate(non_arraylike_val)
+
+        with pytest.raises(ValueError, match=matchval):
+            # validation should fail on constructor call
+            DeterministicVariable(
+                value=non_arraylike_val, name="invalid_variable"
+            )
+
+    # validation should succeed with ArrayLike
+    for arraylike_val in some_array_likes:
+        DeterministicVariable.validate(arraylike_val)
+        DeterministicVariable(value=arraylike_val, name="valid_variable")

From dd3e0229949c2968e41a9503c7a35dd45e5033cd Mon Sep 17 00:00:00 2001
From: Damon Bayer <xum8@cdc.gov>
Date: Tue, 30 Jul 2024 13:44:56 -0500
Subject: [PATCH 3/4] Fix name in `rtperiodicdiff.py` (#343)

Update rtperiodicdiff.py
---
 model/src/pyrenew/process/rtperiodicdiff.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/model/src/pyrenew/process/rtperiodicdiff.py b/model/src/pyrenew/process/rtperiodicdiff.py
index 710d5d22..658a5339 100644
--- a/model/src/pyrenew/process/rtperiodicdiff.py
+++ b/model/src/pyrenew/process/rtperiodicdiff.py
@@ -183,7 +183,7 @@ def sample(
         n_periods = int(jnp.ceil(duration / self.period_size))
 
         # Running the process
-        ar_diff = FirstDifferenceARProcess("trend_rw", autoreg=b, noise_sd=s_r)
+        ar_diff = FirstDifferenceARProcess(self.name, autoreg=b, noise_sd=s_r)
         log_rt = ar_diff.sample(
             duration=n_periods,
             init_val=log_rt_rv[1],

From 64e0ed832166d6c5ec2d0ca817c3cfdd1eeb0e0d Mon Sep 17 00:00:00 2001
From: Subekshya Bidari <37636707+sbidari@users.noreply.github.com>
Date: Tue, 30 Jul 2024 17:12:58 -0400
Subject: [PATCH 4/4] add dependency check workflow (#346)

---
 .github/workflows/deptry.yml | 32 ++++++++++++++++++++++++++++++++
 pyproject.toml               |  2 +-
 2 files changed, 33 insertions(+), 1 deletion(-)
 create mode 100644 .github/workflows/deptry.yml

diff --git a/.github/workflows/deptry.yml b/.github/workflows/deptry.yml
new file mode 100644
index 00000000..0607c736
--- /dev/null
+++ b/.github/workflows/deptry.yml
@@ -0,0 +1,32 @@
+name: model
+
+on:
+  pull_request:
+  push:
+    branches: [main]
+
+jobs:
+  dependency-check:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+
+      - uses: actions/setup-python@v5
+        with:
+          python-version: "3.12"
+
+      - name: cache poetry
+        uses: actions/cache@v4
+        with:
+          path: ~/.local
+          key: ${{ runner.os }}-poetry
+
+      - name: install poetry
+        run: pip install poetry
+
+      - name: install package
+        run: poetry install --with dev
+
+      - name: run deptry
+        run: |
+          poetry run deptry . --per-rule-ignores "DEP001=pyrenew,DEP003=pytest"
diff --git a/pyproject.toml b/pyproject.toml
index d7df60ca..3e0fe1fe 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,7 +15,6 @@ exclude = [{path = "datasets/*.rds"}]
 python = "^3.12"
 numpyro = ">=0.15.1"
 jax = ">=0.4.30"
-jaxlib = ">=0.4.30"
 numpy = "^1.26.4"
 polars = "^1.2.1"
 matplotlib = "^3.8.3"
@@ -39,6 +38,7 @@ sphinx-book-theme = "^1.1.2"
 pillow = "^10.3.0"
 nbconvert = "^7.16.4"
 ipywidgets = "^8.1.3"
+deptry = "^0.17.0"
 
 [tool.numpydoc_validation]
 checks = [