Skip to content

Commit

Permalink
Add and configure intersphinx (#463)
Browse files Browse the repository at this point in the history
* Add and configure intersphinx

* Fix links, format math

* More typo fixes and formatting tweaks

* Remove custom __repr__ for infectionssample, fix escapes

* Shorten line length

* Escape mathcals
  • Loading branch information
dylanhmorris authored Oct 1, 2024
1 parent 907d1c1 commit aae9399
Show file tree
Hide file tree
Showing 14 changed files with 166 additions and 112 deletions.
20 changes: 20 additions & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"sphinx.ext.autodoc",
"sphinx.ext.autosectionlabel",
"sphinx.ext.autosummary",
"sphinx.ext.intersphinx",
"sphinx.ext.doctest",
"sphinx.ext.napoleon", # numpydoc
"sphinx.ext.duration",
Expand Down Expand Up @@ -87,3 +88,22 @@

myst_fence_as_directive = ["mermaid"]
myst_enable_extensions = ["amsmath", "dollarmath"]

intersphinx_mapping = {
"python": ("https://docs.python.org/3/", None),
"matplotlib": ("https://matplotlib.org/stable/", None),
"numpy": ("https://numpy.org/doc/stable/", None),
"numpyro": ("https://num.pyro.ai/en/latest/", None),
"jax": ("https://jax.readthedocs.io/en/latest/", None),
"polars": ("https://docs.pola.rs/api/python/stable/", None),
}

napoleon_preprocess_types = True
autodoc_typehints = "description"
autodoc_typehints_format = "short"
autodoc_type_aliases = {
"ArrayLike": ":obj:`ArrayLike <jax.typing.ArrayLike>`",
"RandomVariable": ":class:`RandomVariable <pyrenew.metaclass.52RandomVariable>`",
"Any": ":obj:`Any <typing.Any>`",
}
napoleon_type_aliases = autodoc_type_aliases
12 changes: 8 additions & 4 deletions pyrenew/convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
:py:func:`jax.lax.scan`.
Factories generate functions
that can be passed to
:py:func:`jax.lax.scan` with an
appropriate array to scan along.
:func:`jax.lax.scan` or
:func:`numpyro.contrib.control_flow.scan`
with an appropriate array to scan along.
"""

from __future__ import annotations
Expand All @@ -26,7 +27,8 @@ def new_convolve_scanner(
) -> Callable:
r"""
Factory function to create a "scanner" function
that can be used with :py:func:`jax.lax.scan` to
that can be used with :func:`jax.lax.scan` or
:func:`numpyro.contrib.control_flow.scan` to
construct an array via backward-looking iterative
convolution.
Expand All @@ -44,7 +46,9 @@ def new_convolve_scanner(
-------
Callable
A scanner function that can be used with
:py:func:`jax.lax.scan` for convolution.
:func:`jax.lax.scan` or
:func:`numpyro.contrib.control_flow.scan`
for convolution.
This function takes a history subset array and
a scalar, computes the dot product of
the supplied convolution array with the history
Expand Down
14 changes: 9 additions & 5 deletions pyrenew/latent/hospitaladmissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@

class HospitalAdmissionsSample(NamedTuple):
"""
A container to hold the output of `latent.HospAdmissions()`.
A container to hold the output of
:meth:`HospitalAdmissions.sample`.
Attributes
----------
Expand Down Expand Up @@ -192,10 +193,12 @@ def sample(
Parameters
----------
latent_infections : ArrayLike
Latent infections. Possibly the output of the `latent.Infections()`.
Latent infections.
**kwargs : dict, optional
Additional keyword arguments passed through to internal `sample()`
calls, should there be any.
Additional keyword arguments passed through to
internal :meth:`sample()
<pyrenew.metaclass.RandomVariable.sample>` calls,
should there be any.
Returns
-------
Expand All @@ -217,7 +220,8 @@ def sample(
# Applying the day of the week effect. For this we need to:
# 1. Get the day of the week effect
# 2. Identify the offset of the latent_infections
# 3. Apply the day of the week effect to the latent_hospital_admissions
# 3. Apply the day of the week effect to the
# latent_hospital_admissions
dow_effect_sampled = self.day_of_week_effect_rv(**kwargs)

if dow_effect_sampled.size != 7:
Expand Down
48 changes: 20 additions & 28 deletions pyrenew/latent/infection_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ def compute_infections_from_rt(
I0: ArrayLike,
Rt: ArrayLike,
reversed_generation_interval_pmf: ArrayLike,
) -> ArrayLike:
) -> jnp.ndarray:
"""
Generate infections according to a
renewal process with a time-varying
reproduction number R(t)
reproduction number :math:`\\mathcal{R}(t)`
Parameters
----------
Expand All @@ -27,7 +27,7 @@ def compute_infections_from_rt(
same length as the generation interval
pmf vector.
Rt : ArrayLike
Timeseries of R(t) values
Timeseries of :math:`\\mathcal{R}(t)` values
reversed_generation_interval_pmf : ArrayLike
discrete probability mass vector
representing the generation interval
Expand All @@ -38,8 +38,8 @@ def compute_infections_from_rt(
Returns
-------
ArrayLike
The timeseries of infections, as a JAX array
jnp.ndarray
The timeseries of infections.
"""
incidence_func = new_convolve_scanner(
reversed_generation_interval_pmf, IdentityTransform()
Expand All @@ -58,8 +58,9 @@ def logistic_susceptibility_adjustment(
"""
Apply the logistic susceptibility
adjustment to a potential new
incidence I_unadjusted proposed in
equation 6 of Bhatt et al 2023 [1]_
incidence ``I_raw_t`` proposed in
equation 6 of `Bhatt et al 2023
<https://doi.org/10.1093/jrsssa/qnad030>`_.
Parameters
----------
Expand All @@ -76,16 +77,7 @@ def logistic_susceptibility_adjustment(
Returns
-------
float
The adjusted value of I(t)
References
----------
.. [1] Bhatt, Samir, et al.
"Semi-mechanistic Bayesian modelling of
COVID-19 with renewal processes."
Journal of the Royal Statistical Society
Series A: Statistics in Society 186.4 (2023): 601-615.
https://doi.org/10.1093/jrsssa/qnad030
The adjusted value of :math:`I(t)`.
"""
approx_frac_infected = 1 - jnp.exp(-I_raw_t / n_population)
return n_population * frac_susceptible * approx_frac_infected
Expand All @@ -101,8 +93,8 @@ def compute_infections_from_rt_with_feedback(
r"""
Generate infections according to
a renewal process with infection
feedback (generalizing Asher 2018:
https://doi.org/10.1016/j.epidem.2017.02.009)
feedback (generalizing `Asher 2018
<https://doi.org/10.1016/j.epidem.2017.02.009>`_).
Parameters
----------
Expand All @@ -111,7 +103,7 @@ def compute_infections_from_rt_with_feedback(
same length as the generation interval
pmf vector.
Rt_raw : ArrayLike
Timeseries of raw R(t) values not
Timeseries of raw :math:`\mathcal{R}(t)` values not
adjusted by infection feedback
infection_feedback_strength : ArrayLike
Strength of the infection feedback.
Expand Down Expand Up @@ -139,21 +131,21 @@ def compute_infections_from_rt_with_feedback(
Returns
-------
tuple
A tuple `(infections, Rt_adjusted)`,
where `Rt_adjusted` is the infection-feedback-adjusted
timeseries of the reproduction number R(t) and
infections is the incident infection timeseries.
A tuple ``(infections, Rt_adjusted)``,
where ``Rt_adjusted`` is the infection-feedback-adjusted
timeseries of the reproduction number :math:`\mathcal{R}(t)`
and ``infections`` is the incident infection timeseries.
Notes
-----
This function implements the following renewal process:
.. math::
I(t) & = \mathcal{R}(t)\sum_{\tau=1}^{T_g}I(t - \tau)g(\tau)
\begin{aligned}
I(t) & = \mathcal{R}(t)\sum_{\tau=1}^{T_g}I(t - \tau)g(\tau) \\
\mathcal{R}(t) & = \mathcal{R}^u(t)\exp\left(\gamma(t)\
\sum_{\tau=1}^{T_f}I(t - \tau)f(\tau)\right)
\end{aligned}
where :math:`\mathcal{R}(t)` is the reproductive number,
:math:`\gamma(t)` is the infection feedback strength,
Expand All @@ -166,7 +158,7 @@ def compute_infections_from_rt_with_feedback(
that recent incident infections reduce :math:`\mathcal{R}(t)`
below its raw value in the absence of feedback, while
positive :math:`\gamma` implies that recent incident infections
_increase_ :math:`\mathcal{R}(t)` above its raw value, and
*increase* :math:`\mathcal{R}(t)` above its raw value, and
:math:`\gamma(t)=0` implies no feedback.
In general, negative :math:`\gamma` is the more common modeling
Expand Down
15 changes: 10 additions & 5 deletions pyrenew/latent/infection_initialization_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@ class InfectionInitializationMethod(metaclass=ABCMeta):
"""Method for initializing infections in a renewal process."""

def __init__(self, n_timepoints: int):
"""Default constructor for the ``InfectionInitializationMethod`` class.
"""Default constructor for
:class:`InfectionInitializationMethod`.
Parameters
----------
n_timepoints : int
the number of time points to generate initial infections for
the number of time points for which to
generate initial infections
Returns
-------
Expand All @@ -27,7 +29,10 @@ def __init__(self, n_timepoints: int):

@staticmethod
def validate(n_timepoints: int) -> None:
"""Validate inputs for the ``InfectionInitializationMethod`` class constructor
"""
Validate inputs to the
:class:`InfectionInitializationMethod`
constructor.
Parameters
----------
Expand Down Expand Up @@ -99,7 +104,7 @@ def initialize_infections(self, I_pre_init: ArrayLike):
class InitializeInfectionsFromVec(InfectionInitializationMethod):
"""Create initial infections from a vector of infections."""

def initialize_infections(self, I_pre_init: ArrayLike):
def initialize_infections(self, I_pre_init: ArrayLike) -> ArrayLike:
"""Create initial infections from a vector of infections.
Parameters
Expand All @@ -112,7 +117,7 @@ def initialize_infections(self, I_pre_init: ArrayLike):
-------
ArrayLike
An array of length ``n_timepoints`` with the number of
initialized infections at each time point.
initialized infections at each time point.
"""
I_pre_init = jnp.array(I_pre_init)
if I_pre_init.size != self.n_timepoints:
Expand Down
2 changes: 2 additions & 0 deletions pyrenew/latent/infection_initialization_process.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# numpydoc ignore=GL08
from __future__ import annotations

from jax.typing import ArrayLike

from pyrenew.latent.infection_initialization_method import (
Expand Down
26 changes: 14 additions & 12 deletions pyrenew/latent/infections.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,23 @@

class InfectionsSample(NamedTuple):
"""
A container for holding the output from `latent.Infections()`.
A container for holding the output from
:meth:`Infections.sample()
<pyrenew.latent.infections.Infections.sample>`.
Attributes
----------
post_initialization_infections : SampledValue | None, optional
The estimated latent infections. Defaults to None.
post_initialization_infections:
The estimated latent infections. Default :obj:`None`.
"""

post_initialization_infections: ArrayLike | None = None

def __repr__(self):
return f"InfectionsSample(post_initialization_infections={self.post_initialization_infections})"


class Infections(RandomVariable):
r"""Latent infections
This class samples infections given Rt,
This class samples infections given :math:`\mathcal{R}(t)`,
initial infections, and generation interval.
Notes
Expand All @@ -57,9 +56,10 @@ def sample(
gen_int: ArrayLike,
**kwargs,
) -> InfectionsSample:
"""
Samples infections given Rt, initial infections, and generation
interval.
r"""
Sample infections given
:math:`\mathcal{R}(t)`, initial infections,
and generation interval.
Parameters
----------
Expand All @@ -78,7 +78,8 @@ def sample(
Returns
-------
InfectionsSample
Named tuple with "infections".
A named tuple with a
``post_initialization_infections`` field.
"""
if I0.shape[0] < gen_int.size:
raise ValueError(
Expand All @@ -90,7 +91,8 @@ def sample(

if I0.shape[1:] != Rt.shape[1:]:
raise ValueError(
"Initial infections and Rt must have the same batch shapes. "
"Initial infections and Rt must have the "
"same batch shapes. "
f"Got initial infections of batch shape {I0.shape[1:]} "
f"and Rt of batch shape {Rt.shape[1:]}."
)
Expand Down
10 changes: 6 additions & 4 deletions pyrenew/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ def get_leslie_matrix(
Create the Leslie matrix
corresponding to a basic
renewal process with the
given R value and discrete
given :math:`\\mathcal{R}`
value and discrete
generation interval pmf
vector.
Expand Down Expand Up @@ -79,7 +80,8 @@ def get_asymptotic_growth_rate_and_age_dist(
Raises
------
ValueError
If an age distribution vector with non-zero imaginary part is produced.
If an age distribution vector with non-zero
imaginary part is produced.
"""
L = get_leslie_matrix(R, generation_interval_pmf)
eigenvals, eigenvecs = jnp.linalg.eig(L)
Expand Down Expand Up @@ -148,8 +150,8 @@ def get_asymptotic_growth_rate(
"""
Get the asymptotic per timestep growth rate
for a renewal process with a given value of
R and a given discrete generation interval
probability mass vector.
:math:`\\mathcal{R}` and a given discrete
generation interval probability mass vector.
This function computes that growth rate
finding the dominant eigenvalue of the
Expand Down
12 changes: 7 additions & 5 deletions pyrenew/mcmcutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,18 @@ def spread_draws(
Parameters
----------
posteriors: dict
A dictionary of posteriors with variable names as keys and numpy
ndarrays as values (with the first axis corresponding to the posterior
draw number.
A dictionary of posteriors with variable names
as keys and numpy ndarrays as values (with the
first axis corresponding to the posterior
draw number).
variables_names: list[str] | list[tuple]
list of strings or of tuples identifying which variables to retrieve.
list of strings or of tuples identifying which
variables to retrieve.
Returns
-------
pl.DataFrame
A dataframe of draw-indexed
A polars dataframe of draw-indexed posterior samples.
"""

for i_var, v in enumerate(variables_names):
Expand Down
Loading

0 comments on commit aae9399

Please sign in to comment.