Skip to content

Commit

Permalink
Document broadcasting
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanhmorris committed Sep 12, 2024
1 parent 461e024 commit d72add5
Showing 1 changed file with 28 additions and 3 deletions.
31 changes: 28 additions & 3 deletions pyrenew/process/ar.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def sample(
----------
noise_name: str
A name for the sample site holding the
Normal(0, noise_sd) noise for the AR process.
Normal(`0`, `noise_sd`) noise for the AR process.
Passed to :func:`numpyro.sample`.
n: int
Length of the sequence.
Expand All @@ -56,7 +56,31 @@ def sample(
Returns
-------
ArrayLike
of shape (n,) + init_vals.shape[1:].
with first dimension of length `n`
and additional dimensions as inferred
from the shapes of `autoreg`,
`init_vals`, and `noise_sd`.
Notes
-----
The first dimension of the return value
with be of length `n` and represents time.
Trailing dimensions follow standard numpy
broadcasting rules and are determined from
the second through `n`th dimensions, if any,
of `autoreg` and `init_vals`, as well as the
all dimensions of `noise_sd` (i.e.
:code:`jax.numpy.shape(autoreg)[1:]`,
:code:`jax.numpy.shape(init_vals)[1:]`
and :code:`jax.numpy.shape(noise_sd)`
Those shapes must be
broadcastable together via
:func:`jax.lax.broadcast_shapes`. This can
be used to produce multiple AR processes of the
same order but with either shared or different initial
values, AR coefficient vectors, and/or
and noise standard deviation values.
"""
autoreg = jnp.atleast_1d(autoreg)
init_vals = jnp.atleast_1d(init_vals)
Expand All @@ -82,7 +106,8 @@ def sample(
"valid shape for the AR process noise "
"from the shapes of the init_vals, "
"autoreg, and noise_sd arrays. "
"See ARProcess documentation for details."
"See ARProcess.sample() documentation "
"for details."
) from e

if not n_inits == order:
Expand Down

0 comments on commit d72add5

Please sign in to comment.