From 91be92fe0bb12aad6d43148dc95ab03966a8102f Mon Sep 17 00:00:00 2001 From: sbidari Date: Thu, 12 Sep 2024 11:52:39 -0400 Subject: [PATCH] remove pad to match functions --- docs/source/tutorials/extending_pyrenew.qmd | 7 +- pyrenew/arrayutils.py | 88 --------------------- 2 files changed, 3 insertions(+), 92 deletions(-) diff --git a/docs/source/tutorials/extending_pyrenew.qmd b/docs/source/tutorials/extending_pyrenew.qmd index 2cef0d2a..34e92a15 100644 --- a/docs/source/tutorials/extending_pyrenew.qmd +++ b/docs/source/tutorials/extending_pyrenew.qmd @@ -168,7 +168,7 @@ InfFeedbackSample = namedtuple( ) ``` -The next step is to create the actual class. The bulk of its implementation lies in the function `pyrenew.latent.compute_infections_from_rt_with_feedback()`. We will also use the `pyrenew.arrayutils.pad_x_to_match_y()` function to ensure the passed vectors match their lengths. The following code-chunk shows most of the implementation of the `InfectionsWithFeedback` class: +The next step is to create the actual class. The bulk of its implementation lies in the function `pyrenew.latent.compute_infections_from_rt_with_feedback()`. We will also use the `pyrenew.arrayutils.pad_edges_to_match()` function to ensure the passed vectors match their lengths. The following code-chunk shows most of the implementation of the `InfectionsWithFeedback` class: ```{python} # | label: new-model-def @@ -224,11 +224,10 @@ class InfFeedback(RandomVariable): inf_feedback_strength = jnp.atleast_1d(inf_feedback_strength) - inf_feedback_strength = au.pad_x_to_match_y( + inf_feedback_strength = au.pad_edges_to_match( x=inf_feedback_strength, y=Rt, - fill_value=inf_feedback_strength[0], - ) + )[0] # Sampling inf feedback and adjusting the shape inf_feedback_pmf = self.infection_feedback_pmf(**kwargs) diff --git a/pyrenew/arrayutils.py b/pyrenew/arrayutils.py index 915c6816..f8ae8590 100644 --- a/pyrenew/arrayutils.py +++ b/pyrenew/arrayutils.py @@ -68,94 +68,6 @@ def pad_edges_to_match( return x, y -def pad_to_match( - x: ArrayLike, - y: ArrayLike, - fill_value: float = 0.0, - pad_direction: str = "end", - fix_y: bool = False, -) -> tuple[ArrayLike, ArrayLike]: - """ - Pad the shorter array at the start or end to match the length of the longer array. - - Parameters - ---------- - x : ArrayLike - First array. - y : ArrayLike - Second array. - fill_value : float, optional - Value to use for padding, by default 0.0. - pad_direction : str, optional - Direction to pad the shorter array, either "start" or "end", by default "end". - fix_y : bool, optional - If True, raise an error when `y` is shorter than `x`, by default False. - - Returns - ------- - tuple[ArrayLike, ArrayLike] - Tuple of the two arrays with the same length. - """ - x = jnp.atleast_1d(x) - y = jnp.atleast_1d(y) - x_len = x.size - y_len = y.size - pad_size = abs(x_len - y_len) - - pad_width = {"start": (pad_size, 0), "end": (0, pad_size)}.get( - pad_direction, None - ) - - if pad_width is None: - raise ValueError( - "pad_direction must be either 'start' or 'end'." - f" Got {pad_direction}." - ) - - if x_len > y_len: - if fix_y: - raise ValueError( - "Cannot fix y when x is longer than y." - f" x_len: {x_len}, y_len: {y_len}." - ) - y = jnp.pad(y, pad_width, constant_values=fill_value) - - elif y_len > x_len: - x = jnp.pad(x, pad_width, constant_values=fill_value) - - return x, y - - -def pad_x_to_match_y( - x: ArrayLike, - y: ArrayLike, - fill_value: float = 0.0, - pad_direction: str = "end", -) -> ArrayLike: - """ - Pad the `x` array at the start or end to match the length of the `y` array. - - Parameters - ---------- - x : ArrayLike - First array. - y : ArrayLike - Second array. - fill_value : float, optional - Value to use for padding, by default 0.0. - pad_direction : str, optional - Direction to pad the shorter array, either "start" or "end", by default "end". - - Returns - ------- - Array - Padded array. - """ - return pad_to_match( - x, y, fill_value=fill_value, pad_direction=pad_direction, fix_y=True - )[0] - - class PeriodicProcessSample(NamedTuple): """ A container for holding the output from `process.PeriodicProcess()`.