Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add and configure intersphinx #463

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"sphinx.ext.autodoc",
"sphinx.ext.autosectionlabel",
"sphinx.ext.autosummary",
"sphinx.ext.intersphinx",
"sphinx.ext.doctest",
"sphinx.ext.napoleon", # numpydoc
"sphinx.ext.duration",
Expand Down Expand Up @@ -87,3 +88,22 @@

myst_fence_as_directive = ["mermaid"]
myst_enable_extensions = ["amsmath", "dollarmath"]

intersphinx_mapping = {
"python": ("https://docs.python.org/3/", None),
"matplotlib": ("https://matplotlib.org/stable/", None),
"numpy": ("https://numpy.org/doc/stable/", None),
"numpyro": ("https://num.pyro.ai/en/latest/", None),
"jax": ("https://jax.readthedocs.io/en/latest/", None),
"polars": ("https://docs.pola.rs/api/python/stable/", None),
}

napoleon_preprocess_types = True
autodoc_typehints = "description"
autodoc_typehints_format = "short"
autodoc_type_aliases = {
"ArrayLike": ":obj:`ArrayLike <jax.typing.ArrayLike>`",
"RandomVariable": ":class:`RandomVariable <pyrenew.metaclass.52RandomVariable>`",
"Any": ":obj:`Any <typing.Any>`",
}
napoleon_type_aliases = autodoc_type_aliases
12 changes: 8 additions & 4 deletions pyrenew/convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
:py:func:`jax.lax.scan`.
Factories generate functions
that can be passed to
:py:func:`jax.lax.scan` with an
appropriate array to scan along.
:func:`jax.lax.scan` or
:func:`numpyro.contrib.control_flow.scan`
with an appropriate array to scan along.
"""

from __future__ import annotations
Expand All @@ -26,7 +27,8 @@ def new_convolve_scanner(
) -> Callable:
r"""
Factory function to create a "scanner" function
that can be used with :py:func:`jax.lax.scan` to
that can be used with :func:`jax.lax.scan` or
:func:`numpyro.contrib.control_flow.scan` to
construct an array via backward-looking iterative
convolution.

Expand All @@ -44,7 +46,9 @@ def new_convolve_scanner(
-------
Callable
A scanner function that can be used with
:py:func:`jax.lax.scan` for convolution.
:func:`jax.lax.scan` or
:func:`numpyro.contrib.control_flow.scan`
for convolution.
This function takes a history subset array and
a scalar, computes the dot product of
the supplied convolution array with the history
Expand Down
14 changes: 9 additions & 5 deletions pyrenew/latent/hospitaladmissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@

class HospitalAdmissionsSample(NamedTuple):
"""
A container to hold the output of `latent.HospAdmissions()`.
A container to hold the output of
:meth:`HospitalAdmissions.sample`.

Attributes
----------
Expand Down Expand Up @@ -192,10 +193,12 @@ def sample(
Parameters
----------
latent_infections : ArrayLike
Latent infections. Possibly the output of the `latent.Infections()`.
Latent infections.
**kwargs : dict, optional
Additional keyword arguments passed through to internal `sample()`
calls, should there be any.
Additional keyword arguments passed through to
internal :meth:`sample()
<pyrenew.metaclass.RandomVariable.sample>` calls,
should there be any.

Returns
-------
Expand All @@ -217,7 +220,8 @@ def sample(
# Applying the day of the week effect. For this we need to:
# 1. Get the day of the week effect
# 2. Identify the offset of the latent_infections
# 3. Apply the day of the week effect to the latent_hospital_admissions
# 3. Apply the day of the week effect to the
# latent_hospital_admissions
dow_effect_sampled = self.day_of_week_effect_rv(**kwargs)

if dow_effect_sampled.size != 7:
Expand Down
48 changes: 20 additions & 28 deletions pyrenew/latent/infection_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ def compute_infections_from_rt(
I0: ArrayLike,
Rt: ArrayLike,
reversed_generation_interval_pmf: ArrayLike,
) -> ArrayLike:
) -> jnp.ndarray:
"""
Generate infections according to a
renewal process with a time-varying
reproduction number R(t)
reproduction number :math:`\\mathcal{R}(t)`

Parameters
----------
Expand All @@ -27,7 +27,7 @@ def compute_infections_from_rt(
same length as the generation interval
pmf vector.
Rt : ArrayLike
Timeseries of R(t) values
Timeseries of :math:`\\mathcal{R}(t)` values
reversed_generation_interval_pmf : ArrayLike
discrete probability mass vector
representing the generation interval
Expand All @@ -38,8 +38,8 @@ def compute_infections_from_rt(

Returns
-------
ArrayLike
The timeseries of infections, as a JAX array
jnp.ndarray
The timeseries of infections.
"""
incidence_func = new_convolve_scanner(
reversed_generation_interval_pmf, IdentityTransform()
Expand All @@ -58,8 +58,9 @@ def logistic_susceptibility_adjustment(
"""
Apply the logistic susceptibility
adjustment to a potential new
incidence I_unadjusted proposed in
equation 6 of Bhatt et al 2023 [1]_
incidence ``I_raw_t`` proposed in
equation 6 of `Bhatt et al 2023
<https://doi.org/10.1093/jrsssa/qnad030>`_.

Parameters
----------
Expand All @@ -76,16 +77,7 @@ def logistic_susceptibility_adjustment(
Returns
-------
float
The adjusted value of I(t)

References
----------
.. [1] Bhatt, Samir, et al.
"Semi-mechanistic Bayesian modelling of
COVID-19 with renewal processes."
Journal of the Royal Statistical Society
Series A: Statistics in Society 186.4 (2023): 601-615.
https://doi.org/10.1093/jrsssa/qnad030
The adjusted value of :math:`I(t)`.
"""
approx_frac_infected = 1 - jnp.exp(-I_raw_t / n_population)
return n_population * frac_susceptible * approx_frac_infected
Expand All @@ -101,8 +93,8 @@ def compute_infections_from_rt_with_feedback(
r"""
Generate infections according to
a renewal process with infection
feedback (generalizing Asher 2018:
https://doi.org/10.1016/j.epidem.2017.02.009)
feedback (generalizing `Asher 2018
<https://doi.org/10.1016/j.epidem.2017.02.009>`_).

Parameters
----------
Expand All @@ -111,7 +103,7 @@ def compute_infections_from_rt_with_feedback(
same length as the generation interval
pmf vector.
Rt_raw : ArrayLike
Timeseries of raw R(t) values not
Timeseries of raw :math:`\mathcal{R}(t)` values not
adjusted by infection feedback
infection_feedback_strength : ArrayLike
Strength of the infection feedback.
Expand Down Expand Up @@ -139,21 +131,21 @@ def compute_infections_from_rt_with_feedback(
Returns
-------
tuple
A tuple `(infections, Rt_adjusted)`,
where `Rt_adjusted` is the infection-feedback-adjusted
timeseries of the reproduction number R(t) and
infections is the incident infection timeseries.
A tuple ``(infections, Rt_adjusted)``,
where ``Rt_adjusted`` is the infection-feedback-adjusted
timeseries of the reproduction number :math:`\mathcal{R}(t)`
and ``infections`` is the incident infection timeseries.

Notes
-----
This function implements the following renewal process:

.. math::

I(t) & = \mathcal{R}(t)\sum_{\tau=1}^{T_g}I(t - \tau)g(\tau)

\begin{aligned}
I(t) & = \mathcal{R}(t)\sum_{\tau=1}^{T_g}I(t - \tau)g(\tau) \\
\mathcal{R}(t) & = \mathcal{R}^u(t)\exp\left(\gamma(t)\
\sum_{\tau=1}^{T_f}I(t - \tau)f(\tau)\right)
\end{aligned}

where :math:`\mathcal{R}(t)` is the reproductive number,
:math:`\gamma(t)` is the infection feedback strength,
Expand All @@ -166,7 +158,7 @@ def compute_infections_from_rt_with_feedback(
that recent incident infections reduce :math:`\mathcal{R}(t)`
below its raw value in the absence of feedback, while
positive :math:`\gamma` implies that recent incident infections
_increase_ :math:`\mathcal{R}(t)` above its raw value, and
*increase* :math:`\mathcal{R}(t)` above its raw value, and
:math:`\gamma(t)=0` implies no feedback.

In general, negative :math:`\gamma` is the more common modeling
Expand Down
15 changes: 10 additions & 5 deletions pyrenew/latent/infection_initialization_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@ class InfectionInitializationMethod(metaclass=ABCMeta):
"""Method for initializing infections in a renewal process."""

def __init__(self, n_timepoints: int):
"""Default constructor for the ``InfectionInitializationMethod`` class.
"""Default constructor for
:class:`InfectionInitializationMethod`.

Parameters
----------
n_timepoints : int
the number of time points to generate initial infections for
the number of time points for which to
generate initial infections

Returns
-------
Expand All @@ -27,7 +29,10 @@ def __init__(self, n_timepoints: int):

@staticmethod
def validate(n_timepoints: int) -> None:
"""Validate inputs for the ``InfectionInitializationMethod`` class constructor
"""
Validate inputs to the
:class:`InfectionInitializationMethod`
constructor.

Parameters
----------
Expand Down Expand Up @@ -99,7 +104,7 @@ def initialize_infections(self, I_pre_init: ArrayLike):
class InitializeInfectionsFromVec(InfectionInitializationMethod):
"""Create initial infections from a vector of infections."""

def initialize_infections(self, I_pre_init: ArrayLike):
def initialize_infections(self, I_pre_init: ArrayLike) -> ArrayLike:
"""Create initial infections from a vector of infections.

Parameters
Expand All @@ -112,7 +117,7 @@ def initialize_infections(self, I_pre_init: ArrayLike):
-------
ArrayLike
An array of length ``n_timepoints`` with the number of
initialized infections at each time point.
initialized infections at each time point.
"""
I_pre_init = jnp.array(I_pre_init)
if I_pre_init.size != self.n_timepoints:
Expand Down
2 changes: 2 additions & 0 deletions pyrenew/latent/infection_initialization_process.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# numpydoc ignore=GL08
from __future__ import annotations

from jax.typing import ArrayLike

from pyrenew.latent.infection_initialization_method import (
Expand Down
26 changes: 14 additions & 12 deletions pyrenew/latent/infections.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,23 @@

class InfectionsSample(NamedTuple):
"""
A container for holding the output from `latent.Infections()`.
A container for holding the output from
:meth:`Infections.sample()
<pyrenew.latent.infections.Infections.sample>`.

Attributes
----------
post_initialization_infections : SampledValue | None, optional
The estimated latent infections. Defaults to None.
post_initialization_infections:
The estimated latent infections. Default :obj:`None`.
"""

post_initialization_infections: ArrayLike | None = None

def __repr__(self):
return f"InfectionsSample(post_initialization_infections={self.post_initialization_infections})"


class Infections(RandomVariable):
r"""Latent infections

This class samples infections given Rt,
This class samples infections given :math:`\mathcal{R}(t)`,
initial infections, and generation interval.

Notes
Expand All @@ -57,9 +56,10 @@ def sample(
gen_int: ArrayLike,
**kwargs,
) -> InfectionsSample:
"""
Samples infections given Rt, initial infections, and generation
interval.
r"""
Sample infections given
:math:`\mathcal{R}(t)`, initial infections,
and generation interval.

Parameters
----------
Expand All @@ -78,7 +78,8 @@ def sample(
Returns
-------
InfectionsSample
Named tuple with "infections".
A named tuple with a
``post_initialization_infections`` field.
"""
if I0.shape[0] < gen_int.size:
raise ValueError(
Expand All @@ -90,7 +91,8 @@ def sample(

if I0.shape[1:] != Rt.shape[1:]:
raise ValueError(
"Initial infections and Rt must have the same batch shapes. "
"Initial infections and Rt must have the "
"same batch shapes. "
f"Got initial infections of batch shape {I0.shape[1:]} "
f"and Rt of batch shape {Rt.shape[1:]}."
)
Expand Down
10 changes: 6 additions & 4 deletions pyrenew/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ def get_leslie_matrix(
Create the Leslie matrix
corresponding to a basic
renewal process with the
given R value and discrete
given :math:`\\mathcal{R}`
value and discrete
generation interval pmf
vector.

Expand Down Expand Up @@ -79,7 +80,8 @@ def get_asymptotic_growth_rate_and_age_dist(
Raises
------
ValueError
If an age distribution vector with non-zero imaginary part is produced.
If an age distribution vector with non-zero
imaginary part is produced.
"""
L = get_leslie_matrix(R, generation_interval_pmf)
eigenvals, eigenvecs = jnp.linalg.eig(L)
Expand Down Expand Up @@ -148,8 +150,8 @@ def get_asymptotic_growth_rate(
"""
Get the asymptotic per timestep growth rate
for a renewal process with a given value of
R and a given discrete generation interval
probability mass vector.
:math:`\\mathcal{R}` and a given discrete
generation interval probability mass vector.

This function computes that growth rate
finding the dominant eigenvalue of the
Expand Down
12 changes: 7 additions & 5 deletions pyrenew/mcmcutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,18 @@ def spread_draws(
Parameters
----------
posteriors: dict
A dictionary of posteriors with variable names as keys and numpy
ndarrays as values (with the first axis corresponding to the posterior
draw number.
A dictionary of posteriors with variable names
as keys and numpy ndarrays as values (with the
first axis corresponding to the posterior
draw number).
variables_names: list[str] | list[tuple]
list of strings or of tuples identifying which variables to retrieve.
list of strings or of tuples identifying which
variables to retrieve.

Returns
-------
pl.DataFrame
A dataframe of draw-indexed
A polars dataframe of draw-indexed posterior samples.
"""

for i_var, v in enumerate(variables_names):
Expand Down
Loading