Skip to content

Commit

Permalink
remove pad to match functions
Browse files Browse the repository at this point in the history
  • Loading branch information
sbidari committed Sep 12, 2024
1 parent a625fde commit 91be92f
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 92 deletions.
7 changes: 3 additions & 4 deletions docs/source/tutorials/extending_pyrenew.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ InfFeedbackSample = namedtuple(
)
```

The next step is to create the actual class. The bulk of its implementation lies in the function `pyrenew.latent.compute_infections_from_rt_with_feedback()`. We will also use the `pyrenew.arrayutils.pad_x_to_match_y()` function to ensure the passed vectors match their lengths. The following code-chunk shows most of the implementation of the `InfectionsWithFeedback` class:
The next step is to create the actual class. The bulk of its implementation lies in the function `pyrenew.latent.compute_infections_from_rt_with_feedback()`. We will also use the `pyrenew.arrayutils.pad_edges_to_match()` function to ensure the passed vectors match their lengths. The following code-chunk shows most of the implementation of the `InfectionsWithFeedback` class:

```{python}
# | label: new-model-def
Expand Down Expand Up @@ -224,11 +224,10 @@ class InfFeedback(RandomVariable):
inf_feedback_strength = jnp.atleast_1d(inf_feedback_strength)
inf_feedback_strength = au.pad_x_to_match_y(
inf_feedback_strength = au.pad_edges_to_match(
x=inf_feedback_strength,
y=Rt,
fill_value=inf_feedback_strength[0],
)
)[0]
# Sampling inf feedback and adjusting the shape
inf_feedback_pmf = self.infection_feedback_pmf(**kwargs)
Expand Down
88 changes: 0 additions & 88 deletions pyrenew/arrayutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,94 +68,6 @@ def pad_edges_to_match(
return x, y


def pad_to_match(
x: ArrayLike,
y: ArrayLike,
fill_value: float = 0.0,
pad_direction: str = "end",
fix_y: bool = False,
) -> tuple[ArrayLike, ArrayLike]:
"""
Pad the shorter array at the start or end to match the length of the longer array.
Parameters
----------
x : ArrayLike
First array.
y : ArrayLike
Second array.
fill_value : float, optional
Value to use for padding, by default 0.0.
pad_direction : str, optional
Direction to pad the shorter array, either "start" or "end", by default "end".
fix_y : bool, optional
If True, raise an error when `y` is shorter than `x`, by default False.
Returns
-------
tuple[ArrayLike, ArrayLike]
Tuple of the two arrays with the same length.
"""
x = jnp.atleast_1d(x)
y = jnp.atleast_1d(y)
x_len = x.size
y_len = y.size
pad_size = abs(x_len - y_len)

pad_width = {"start": (pad_size, 0), "end": (0, pad_size)}.get(
pad_direction, None
)

if pad_width is None:
raise ValueError(
"pad_direction must be either 'start' or 'end'."
f" Got {pad_direction}."
)

if x_len > y_len:
if fix_y:
raise ValueError(
"Cannot fix y when x is longer than y."
f" x_len: {x_len}, y_len: {y_len}."
)
y = jnp.pad(y, pad_width, constant_values=fill_value)

elif y_len > x_len:
x = jnp.pad(x, pad_width, constant_values=fill_value)

return x, y


def pad_x_to_match_y(
x: ArrayLike,
y: ArrayLike,
fill_value: float = 0.0,
pad_direction: str = "end",
) -> ArrayLike:
"""
Pad the `x` array at the start or end to match the length of the `y` array.
Parameters
----------
x : ArrayLike
First array.
y : ArrayLike
Second array.
fill_value : float, optional
Value to use for padding, by default 0.0.
pad_direction : str, optional
Direction to pad the shorter array, either "start" or "end", by default "end".
Returns
-------
Array
Padded array.
"""
return pad_to_match(
x, y, fill_value=fill_value, pad_direction=pad_direction, fix_y=True
)[0]


class PeriodicProcessSample(NamedTuple):
"""
A container for holding the output from `process.PeriodicProcess()`.
Expand Down

0 comments on commit 91be92f

Please sign in to comment.