Skip to content

Commit

Permalink
refactor infectionswithfeedback.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sbidari committed Oct 3, 2024
1 parent 9bc3dae commit e4a28ab
Showing 1 changed file with 10 additions and 17 deletions.
27 changes: 10 additions & 17 deletions pyrenew/latent/infectionswithfeedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import jax.numpy as jnp
from numpy.typing import ArrayLike

import pyrenew.arrayutils as au
import pyrenew.latent.infection_functions as inf
from pyrenew.metaclass import RandomVariable

Expand Down Expand Up @@ -168,23 +167,17 @@ def sample(
)
)

if inf_feedback_strength.ndim == Rt.ndim - 1:
inf_feedback_strength = inf_feedback_strength[jnp.newaxis]

# Making sure inf_feedback_strength spans the Rt length
if inf_feedback_strength.shape[0] == 1:
inf_feedback_strength = au.pad_edges_to_match(
x=inf_feedback_strength,
y=Rt,
axis=0,
)[0]
if inf_feedback_strength.shape != Rt.shape:
raise ValueError(
"Infection feedback strength must be of length 1 "
"or the same length as the reproduction number array. "
f"Got {inf_feedback_strength.shape} "
f"and {Rt.shape} respectively."
try:
inf_feedback_strength = jnp.broadcast_to(
inf_feedback_strength, Rt.shape
)
except Exception as e:
raise ValueError(

Check warning on line 175 in pyrenew/latent/infectionswithfeedback.py

View check run for this annotation

Codecov / codecov/patch

pyrenew/latent/infectionswithfeedback.py#L174-L175

Added lines #L174 - L175 were not covered by tests
"Could not broadcast inf_feedback_strength "
f"(shape {inf_feedback_strength.shape}) "
"to the shape of Rt"
f"{Rt.shape}"
) from e

# Sampling inf feedback pmf
inf_feedback_pmf = self.infection_feedback_pmf(**kwargs)
Expand Down

0 comments on commit e4a28ab

Please sign in to comment.