diff --git a/docs/source/conf.py b/docs/source/conf.py index 0d991949..ac07b219 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -26,6 +26,7 @@ "sphinx.ext.autodoc", "sphinx.ext.autosectionlabel", "sphinx.ext.autosummary", + "sphinx.ext.intersphinx", "sphinx.ext.doctest", "sphinx.ext.napoleon", # numpydoc "sphinx.ext.duration", @@ -87,3 +88,22 @@ myst_fence_as_directive = ["mermaid"] myst_enable_extensions = ["amsmath", "dollarmath"] + +intersphinx_mapping = { + "python": ("https://docs.python.org/3/", None), + "matplotlib": ("https://matplotlib.org/stable/", None), + "numpy": ("https://numpy.org/doc/stable/", None), + "numpyro": ("https://num.pyro.ai/en/latest/", None), + "jax": ("https://jax.readthedocs.io/en/latest/", None), + "polars": ("https://docs.pola.rs/api/python/stable/", None), +} + +napoleon_preprocess_types = True +autodoc_typehints = "description" +autodoc_typehints_format = "short" +autodoc_type_aliases = { + "ArrayLike": ":obj:`ArrayLike `", + "RandomVariable": ":class:`RandomVariable `", + "Any": ":obj:`Any `", +} +napoleon_type_aliases = autodoc_type_aliases diff --git a/pyrenew/convolve.py b/pyrenew/convolve.py index 8b99853c..5ab8ae44 100755 --- a/pyrenew/convolve.py +++ b/pyrenew/convolve.py @@ -8,8 +8,9 @@ :py:func:`jax.lax.scan`. Factories generate functions that can be passed to -:py:func:`jax.lax.scan` with an -appropriate array to scan along. +:func:`jax.lax.scan` or +:func:`numpyro.contrib.control_flow.scan` +with an appropriate array to scan along. """ from __future__ import annotations @@ -26,7 +27,8 @@ def new_convolve_scanner( ) -> Callable: r""" Factory function to create a "scanner" function - that can be used with :py:func:`jax.lax.scan` to + that can be used with :func:`jax.lax.scan` or + :func:`numpyro.contrib.control_flow.scan` to construct an array via backward-looking iterative convolution. @@ -44,7 +46,9 @@ def new_convolve_scanner( ------- Callable A scanner function that can be used with - :py:func:`jax.lax.scan` for convolution. + :func:`jax.lax.scan` or + :func:`numpyro.contrib.control_flow.scan` + for convolution. This function takes a history subset array and a scalar, computes the dot product of the supplied convolution array with the history diff --git a/pyrenew/latent/hospitaladmissions.py b/pyrenew/latent/hospitaladmissions.py index 99f46c13..81cedb5b 100644 --- a/pyrenew/latent/hospitaladmissions.py +++ b/pyrenew/latent/hospitaladmissions.py @@ -16,7 +16,8 @@ class HospitalAdmissionsSample(NamedTuple): """ - A container to hold the output of `latent.HospAdmissions()`. + A container to hold the output of + :meth:`HospitalAdmissions.sample`. Attributes ---------- @@ -192,10 +193,12 @@ def sample( Parameters ---------- latent_infections : ArrayLike - Latent infections. Possibly the output of the `latent.Infections()`. + Latent infections. **kwargs : dict, optional - Additional keyword arguments passed through to internal `sample()` - calls, should there be any. + Additional keyword arguments passed through to + internal :meth:`sample() + ` calls, + should there be any. Returns ------- @@ -217,7 +220,8 @@ def sample( # Applying the day of the week effect. For this we need to: # 1. Get the day of the week effect # 2. Identify the offset of the latent_infections - # 3. Apply the day of the week effect to the latent_hospital_admissions + # 3. Apply the day of the week effect to the + # latent_hospital_admissions dow_effect_sampled = self.day_of_week_effect_rv(**kwargs) if dow_effect_sampled.size != 7: diff --git a/pyrenew/latent/infection_functions.py b/pyrenew/latent/infection_functions.py index 5fdbff6b..b7d7e766 100755 --- a/pyrenew/latent/infection_functions.py +++ b/pyrenew/latent/infection_functions.py @@ -14,11 +14,11 @@ def compute_infections_from_rt( I0: ArrayLike, Rt: ArrayLike, reversed_generation_interval_pmf: ArrayLike, -) -> ArrayLike: +) -> jnp.ndarray: """ Generate infections according to a renewal process with a time-varying - reproduction number R(t) + reproduction number :math:`\\mathcal{R}(t)` Parameters ---------- @@ -27,7 +27,7 @@ def compute_infections_from_rt( same length as the generation interval pmf vector. Rt : ArrayLike - Timeseries of R(t) values + Timeseries of :math:`\\mathcal{R}(t)` values reversed_generation_interval_pmf : ArrayLike discrete probability mass vector representing the generation interval @@ -38,8 +38,8 @@ def compute_infections_from_rt( Returns ------- - ArrayLike - The timeseries of infections, as a JAX array + jnp.ndarray + The timeseries of infections. """ incidence_func = new_convolve_scanner( reversed_generation_interval_pmf, IdentityTransform() @@ -58,8 +58,9 @@ def logistic_susceptibility_adjustment( """ Apply the logistic susceptibility adjustment to a potential new - incidence I_unadjusted proposed in - equation 6 of Bhatt et al 2023 [1]_ + incidence ``I_raw_t`` proposed in + equation 6 of `Bhatt et al 2023 + `_. Parameters ---------- @@ -76,16 +77,7 @@ def logistic_susceptibility_adjustment( Returns ------- float - The adjusted value of I(t) - - References - ---------- - .. [1] Bhatt, Samir, et al. - "Semi-mechanistic Bayesian modelling of - COVID-19 with renewal processes." - Journal of the Royal Statistical Society - Series A: Statistics in Society 186.4 (2023): 601-615. - https://doi.org/10.1093/jrsssa/qnad030 + The adjusted value of :math:`I(t)`. """ approx_frac_infected = 1 - jnp.exp(-I_raw_t / n_population) return n_population * frac_susceptible * approx_frac_infected @@ -101,8 +93,8 @@ def compute_infections_from_rt_with_feedback( r""" Generate infections according to a renewal process with infection - feedback (generalizing Asher 2018: - https://doi.org/10.1016/j.epidem.2017.02.009) + feedback (generalizing `Asher 2018 + `_). Parameters ---------- @@ -111,7 +103,7 @@ def compute_infections_from_rt_with_feedback( same length as the generation interval pmf vector. Rt_raw : ArrayLike - Timeseries of raw R(t) values not + Timeseries of raw :math:`\mathcal{R}(t)` values not adjusted by infection feedback infection_feedback_strength : ArrayLike Strength of the infection feedback. @@ -139,21 +131,21 @@ def compute_infections_from_rt_with_feedback( Returns ------- tuple - A tuple `(infections, Rt_adjusted)`, - where `Rt_adjusted` is the infection-feedback-adjusted - timeseries of the reproduction number R(t) and - infections is the incident infection timeseries. + A tuple ``(infections, Rt_adjusted)``, + where ``Rt_adjusted`` is the infection-feedback-adjusted + timeseries of the reproduction number :math:`\mathcal{R}(t)` + and ``infections`` is the incident infection timeseries. Notes ----- This function implements the following renewal process: .. math:: - - I(t) & = \mathcal{R}(t)\sum_{\tau=1}^{T_g}I(t - \tau)g(\tau) - + \begin{aligned} + I(t) & = \mathcal{R}(t)\sum_{\tau=1}^{T_g}I(t - \tau)g(\tau) \\ \mathcal{R}(t) & = \mathcal{R}^u(t)\exp\left(\gamma(t)\ \sum_{\tau=1}^{T_f}I(t - \tau)f(\tau)\right) + \end{aligned} where :math:`\mathcal{R}(t)` is the reproductive number, :math:`\gamma(t)` is the infection feedback strength, @@ -166,7 +158,7 @@ def compute_infections_from_rt_with_feedback( that recent incident infections reduce :math:`\mathcal{R}(t)` below its raw value in the absence of feedback, while positive :math:`\gamma` implies that recent incident infections - _increase_ :math:`\mathcal{R}(t)` above its raw value, and + *increase* :math:`\mathcal{R}(t)` above its raw value, and :math:`\gamma(t)=0` implies no feedback. In general, negative :math:`\gamma` is the more common modeling diff --git a/pyrenew/latent/infection_initialization_method.py b/pyrenew/latent/infection_initialization_method.py index 116785f0..e482d5e6 100644 --- a/pyrenew/latent/infection_initialization_method.py +++ b/pyrenew/latent/infection_initialization_method.py @@ -11,12 +11,14 @@ class InfectionInitializationMethod(metaclass=ABCMeta): """Method for initializing infections in a renewal process.""" def __init__(self, n_timepoints: int): - """Default constructor for the ``InfectionInitializationMethod`` class. + """Default constructor for + :class:`InfectionInitializationMethod`. Parameters ---------- n_timepoints : int - the number of time points to generate initial infections for + the number of time points for which to + generate initial infections Returns ------- @@ -27,7 +29,10 @@ def __init__(self, n_timepoints: int): @staticmethod def validate(n_timepoints: int) -> None: - """Validate inputs for the ``InfectionInitializationMethod`` class constructor + """ + Validate inputs to the + :class:`InfectionInitializationMethod` + constructor. Parameters ---------- @@ -99,7 +104,7 @@ def initialize_infections(self, I_pre_init: ArrayLike): class InitializeInfectionsFromVec(InfectionInitializationMethod): """Create initial infections from a vector of infections.""" - def initialize_infections(self, I_pre_init: ArrayLike): + def initialize_infections(self, I_pre_init: ArrayLike) -> ArrayLike: """Create initial infections from a vector of infections. Parameters @@ -112,7 +117,7 @@ def initialize_infections(self, I_pre_init: ArrayLike): ------- ArrayLike An array of length ``n_timepoints`` with the number of - initialized infections at each time point. + initialized infections at each time point. """ I_pre_init = jnp.array(I_pre_init) if I_pre_init.size != self.n_timepoints: diff --git a/pyrenew/latent/infection_initialization_process.py b/pyrenew/latent/infection_initialization_process.py index ef9525f4..4fcc4a35 100644 --- a/pyrenew/latent/infection_initialization_process.py +++ b/pyrenew/latent/infection_initialization_process.py @@ -1,4 +1,6 @@ # numpydoc ignore=GL08 +from __future__ import annotations + from jax.typing import ArrayLike from pyrenew.latent.infection_initialization_method import ( diff --git a/pyrenew/latent/infections.py b/pyrenew/latent/infections.py index 7a8b02a8..cb13e53e 100644 --- a/pyrenew/latent/infections.py +++ b/pyrenew/latent/infections.py @@ -13,24 +13,23 @@ class InfectionsSample(NamedTuple): """ - A container for holding the output from `latent.Infections()`. + A container for holding the output from + :meth:`Infections.sample() + `. Attributes ---------- - post_initialization_infections : SampledValue | None, optional - The estimated latent infections. Defaults to None. + post_initialization_infections: + The estimated latent infections. Default :obj:`None`. """ post_initialization_infections: ArrayLike | None = None - def __repr__(self): - return f"InfectionsSample(post_initialization_infections={self.post_initialization_infections})" - class Infections(RandomVariable): r"""Latent infections - This class samples infections given Rt, + This class samples infections given :math:`\mathcal{R}(t)`, initial infections, and generation interval. Notes @@ -57,9 +56,10 @@ def sample( gen_int: ArrayLike, **kwargs, ) -> InfectionsSample: - """ - Samples infections given Rt, initial infections, and generation - interval. + r""" + Sample infections given + :math:`\mathcal{R}(t)`, initial infections, + and generation interval. Parameters ---------- @@ -78,7 +78,8 @@ def sample( Returns ------- InfectionsSample - Named tuple with "infections". + A named tuple with a + ``post_initialization_infections`` field. """ if I0.shape[0] < gen_int.size: raise ValueError( @@ -90,7 +91,8 @@ def sample( if I0.shape[1:] != Rt.shape[1:]: raise ValueError( - "Initial infections and Rt must have the same batch shapes. " + "Initial infections and Rt must have the " + "same batch shapes. " f"Got initial infections of batch shape {I0.shape[1:]} " f"and Rt of batch shape {Rt.shape[1:]}." ) diff --git a/pyrenew/math.py b/pyrenew/math.py index a43a1b89..db942008 100755 --- a/pyrenew/math.py +++ b/pyrenew/math.py @@ -19,7 +19,8 @@ def get_leslie_matrix( Create the Leslie matrix corresponding to a basic renewal process with the - given R value and discrete + given :math:`\\mathcal{R}` + value and discrete generation interval pmf vector. @@ -79,7 +80,8 @@ def get_asymptotic_growth_rate_and_age_dist( Raises ------ ValueError - If an age distribution vector with non-zero imaginary part is produced. + If an age distribution vector with non-zero + imaginary part is produced. """ L = get_leslie_matrix(R, generation_interval_pmf) eigenvals, eigenvecs = jnp.linalg.eig(L) @@ -148,8 +150,8 @@ def get_asymptotic_growth_rate( """ Get the asymptotic per timestep growth rate for a renewal process with a given value of - R and a given discrete generation interval - probability mass vector. + :math:`\\mathcal{R}` and a given discrete + generation interval probability mass vector. This function computes that growth rate finding the dominant eigenvalue of the diff --git a/pyrenew/mcmcutils.py b/pyrenew/mcmcutils.py index f5f0c404..ee631871 100644 --- a/pyrenew/mcmcutils.py +++ b/pyrenew/mcmcutils.py @@ -24,16 +24,18 @@ def spread_draws( Parameters ---------- posteriors: dict - A dictionary of posteriors with variable names as keys and numpy - ndarrays as values (with the first axis corresponding to the posterior - draw number. + A dictionary of posteriors with variable names + as keys and numpy ndarrays as values (with the + first axis corresponding to the posterior + draw number). variables_names: list[str] | list[tuple] - list of strings or of tuples identifying which variables to retrieve. + list of strings or of tuples identifying which + variables to retrieve. Returns ------- pl.DataFrame - A dataframe of draw-indexed + A polars dataframe of draw-indexed posterior samples. """ for i_var, v in enumerate(variables_names): diff --git a/pyrenew/metaclass.py b/pyrenew/metaclass.py index 3bfdba81..26e68cb2 100644 --- a/pyrenew/metaclass.py +++ b/pyrenew/metaclass.py @@ -69,8 +69,8 @@ def sample( Parameters ---------- **kwargs : dict, optional - Additional keyword arguments passed through to internal `sample()` - calls, should there be any. + Additional keyword arguments passed through to internal + :meth:`sample` calls, should there be any. Returns ------- @@ -88,7 +88,7 @@ def validate(**kwargs) -> None: def __call__(self, **kwargs): """ - Alias for `sample()`. + Alias for :meth:`sample`. """ return self.sample(**kwargs) @@ -123,7 +123,7 @@ def sample( ---------- **kwargs : dict, optional Additional keyword arguments passed through to internal - `sample()` calls, should there be any. + :meth:`sample` calls, should there be any. Returns ------- @@ -138,8 +138,8 @@ def model(self, **kwargs) -> tuple: Parameters ---------- **kwargs : dict, optional - Additional keyword arguments passed through to internal `sample()` - calls, should there be any. + Additional keyword arguments passed through to + internal :meth:`sample` calls, should there be any. Returns ------- @@ -160,10 +160,13 @@ def _init_model( Parameters ---------- nuts_args : dict, optional - Dictionary of arguments passed to NUTS. Defaults to None. + Dictionary of arguments passed to the + :class:`numpyro.infer.hmc.NUTS` constructor. + Default None. mcmc_args : dict, optional - Dictionary of arguments passed to the MCMC sampler. Defaults to - None. + Dictionary of arguments passed to the + :class:`numpyro.infer.mcmc.MCMC` constructor. + Default None. Returns ------- @@ -211,12 +214,12 @@ def run( Parameters ---------- nuts_args : dict, optional - Dictionary of arguments passed to the - :class:`numpyro.infer.NUTS` kernel. + Dictionary of arguments passed to the kernel + (:class:`numpyro.infer.hmc.NUTS`) constructor. Defaults to None. mcmc_args : dict, optional - Dictionary of arguments passed to the - :class:`numpyro.infer.MCMC` constructor. + Dictionary of arguments passed to the MCMC runner + (:class:`numpyro.infer.mcmc.MCMC`) constructor. Defaults to None. Returns @@ -246,7 +249,8 @@ def print_summary( exclude_deterministic: bool = True, ) -> None: """ - A wrapper of :meth:`numpyro.infer.MCMC.print_summary` + A wrapper of :meth:`MCMC.print_summary() + `. Parameters ---------- @@ -264,7 +268,7 @@ def print_summary( def spread_draws(self, variables_names: list) -> pl.DataFrame: """ - A wrapper of mcmcutils.spread_draws + A wrapper of :func:`pyrenew.mcmcutils.spread_draws` Parameters ---------- @@ -309,7 +313,7 @@ def posterior_predictive( **kwargs, ) -> dict: """ - A wrapper for :class:`numpyro.infer.Predictive` to generate + A wrapper of :class:`numpyro.infer.util.Predictive` to generate posterior predictive samples. Parameters @@ -318,10 +322,11 @@ def posterior_predictive( Random key for the Predictive function call. Defaults to None. numpyro_predictive_args : dict, optional Dictionary of arguments to be passed to the - :class:`numpyro.infer.Predictive` constructor. + :class:`numpyro.infer.util.Predictive` constructor. **kwargs Additional named arguments passed to the - `__call__()` method of :class:`numpyro.infer.Predictive` + :meth:`__call__()` method of + :class:`numpyro.infer.util.Predictive`. Returns ------- @@ -353,16 +358,22 @@ def prior_predictive( **kwargs, ) -> dict: """ - A wrapper for numpyro.infer.Predictive to generate prior predictive samples. + A wrapper for :class:`numpyro.infer.util.Predictive` + to generate prior predictive samples. Parameters ---------- rng_key : ArrayLike, optional - Random key for the Predictive function call. Defaults to None. + Random key for the Predictive function call. + Default None. numpyro_predictive_args : dict, optional - Dictionary of arguments to be passed to the numpyro.infer.Predictive constructor. + Dictionary of arguments to be passed to + the :class:`numpyro.infer.util.Predictive` + constructor. Default None. **kwargs - Additional named arguments passed to the `__call__()` method of numpyro.infer.Predictive + Additional named arguments passed to the + :meth:`__call__()` method of + :class:`numpyro.infer.util.Predictive`. Returns ------- diff --git a/pyrenew/process/ar.py b/pyrenew/process/ar.py index ec8b88bb..67c784a4 100644 --- a/pyrenew/process/ar.py +++ b/pyrenew/process/ar.py @@ -37,13 +37,14 @@ def sample( noise_name: str A name for the sample site holding the Normal(`0`, `noise_sd`) noise for the AR process. - Passed to :func:`numpyro.sample`. + Passed to :func:`numpyro.sample() + `. n: int Length of the sequence. autoreg: ArrayLike Autoregressive coefficients. The length of the array's first - dimension determines the order :math`p` + dimension determines the order :math:`p` of the AR process. init_vals : ArrayLike Array of initial values. Must have the diff --git a/pyrenew/process/differencedprocess.py b/pyrenew/process/differencedprocess.py index 513e2524..94c7ad55 100644 --- a/pyrenew/process/differencedprocess.py +++ b/pyrenew/process/differencedprocess.py @@ -2,6 +2,8 @@ from __future__ import annotations +from typing import Any + import jax.numpy as jnp from jax.typing import ArrayLike @@ -65,19 +67,22 @@ def __init__( super().__init__(**kwargs) @staticmethod - def assert_valid_differencing_order(differencing_order: any): + def assert_valid_differencing_order(differencing_order: Any): """ To be valid, a differencing order must be an integer and must be strictly positive. This function raises a value error if its argument is not a valid differencing order. - Parameter - --------- - differcing_order : any + + Parameters + ---------- + differencing_order : Any Potential differencing order to validate. + Returns ------- - None or raises a ValueError + None + or raises a :class:`ValueError` """ if not isinstance(differencing_order, int): raise ValueError( @@ -105,7 +110,7 @@ def sample( *args, fundamental_process_init_vals: ArrayLike = None, **kwargs, - ) -> ArrayLike: + ) -> jnp.ndarray: """ Sample from the process @@ -114,8 +119,8 @@ def sample( init_vals : ArrayLike initial values for the :math:`0^{th}` through :math:`(n-1)^{st}` differences, passed as the - ``init_diff_vals`` argument to - :func:`integrate_discrete()` + :code:`init_diff_vals` argument to + :func:`~pyrenew.math.integrate_discrete()` n : int Number of values to sample. Will sample @@ -126,12 +131,13 @@ def sample( *args : Additional positional arguments passed to - :meth:`self.fundamental_process.sample()` + :meth:`self.fundamental_process.sample` - fundamental_process_init_vals : ArrayLike + fundamental_process_init_vals : ArrayLike, optional Initial values for the fundamental process. - Passed as the :arg:`init_vals` keyword argument - to :meth:`self.fundamental_process.sample()`. + Passed as the :code:`init_vals` keyword argument + to :meth:`self.fundamental_process.sample`. + Default :obj:`None`. **kwargs : dict, optional Keyword arguments passed to @@ -139,8 +145,8 @@ def sample( Returns ------- - ArrayLike - representing the undifferenced timeseries + jnp.ndarray + An array representing the undifferenced timeseries """ if not isinstance(n, int): raise ValueError("n must be an integer. " f"Got {type(n)}") diff --git a/pyrenew/process/iidrandomsequence.py b/pyrenew/process/iidrandomsequence.py index 34df097a..8b3b4d84 100644 --- a/pyrenew/process/iidrandomsequence.py +++ b/pyrenew/process/iidrandomsequence.py @@ -57,21 +57,23 @@ def sample( to self.element_rv.sample() vectorize: bool - Sample vectorized? If True, use - :meth:`RandomVariable.expand_by()`, - whenever available, and fall back on - :meth:`numpyro.contrib.control_flow.scan`. - If False, always use :meth:`scan()`. + Sample vectorized? If True, use the + :class:`~pyrenew.metaclass.RandomVariable`'s + :meth:`expand_by()` method, if available, + and fall back on :func:`numpyro.contrib.control_flow.scan` + otherwise. + If False, always use + :func:`~numpyro.contrib.control_flow.scan`. Default False. **kwargs: Additional keyword arguments passed to - self.element_rv.sample(). + :meth:`self.element_rv.sample`. Returns ------- ArrayLike - `n` samples from `self.distribution` + `n` samples from :code:`self.distribution`. """ if vectorize and hasattr(self.element_rv, "expand_by"): diff --git a/pyrenew/process/rtperiodicdiffar.py b/pyrenew/process/rtperiodicdiffar.py index 8a6fc6f0..c951450e 100644 --- a/pyrenew/process/rtperiodicdiffar.py +++ b/pyrenew/process/rtperiodicdiffar.py @@ -121,20 +121,21 @@ def sample( **kwargs, ) -> ArrayLike: """ - Samples the periodic Rt with autoregressive difference. + Samples the periodic :math:`\\mathcal{R}(t)` + with autoregressive first differences. Parameters ---------- duration : int Duration of the sequence. **kwargs : dict, optional - Additional keyword arguments passed through to internal sample() - calls, should there be any. + Additional keyword arguments passed through to + internal :meth:`sample` calls, should there be any. Returns ------- ArrayLike - Sampled Rt values. + Sampled :math:`\\mathcal{R}(t)` values. """ # Initial sample