Skip to content

Commit

Permalink
refactor infectionswithfeedback.py to allow shared infection feedback…
Browse files Browse the repository at this point in the history
… strength across sites (#470)
  • Loading branch information
sbidari authored Oct 3, 2024
1 parent 9a996a3 commit 32e80fd
Showing 1 changed file with 10 additions and 17 deletions.
27 changes: 10 additions & 17 deletions pyrenew/latent/infectionswithfeedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import jax.numpy as jnp
from numpy.typing import ArrayLike

import pyrenew.arrayutils as au
import pyrenew.latent.infection_functions as inf
from pyrenew.metaclass import RandomVariable

Expand Down Expand Up @@ -168,23 +167,17 @@ def sample(
)
)

if inf_feedback_strength.ndim == Rt.ndim - 1:
inf_feedback_strength = inf_feedback_strength[jnp.newaxis]

# Making sure inf_feedback_strength spans the Rt length
if inf_feedback_strength.shape[0] == 1:
inf_feedback_strength = au.pad_edges_to_match(
x=inf_feedback_strength,
y=Rt,
axis=0,
)[0]
if inf_feedback_strength.shape != Rt.shape:
raise ValueError(
"Infection feedback strength must be of length 1 "
"or the same length as the reproduction number array. "
f"Got {inf_feedback_strength.shape} "
f"and {Rt.shape} respectively."
try:
inf_feedback_strength = jnp.broadcast_to(
inf_feedback_strength, Rt.shape
)
except Exception as e:
raise ValueError(
"Could not broadcast inf_feedback_strength "
f"(shape {inf_feedback_strength.shape}) "
"to the shape of Rt"
f"{Rt.shape}"
) from e

# Sampling inf feedback pmf
inf_feedback_pmf = self.infection_feedback_pmf(**kwargs)
Expand Down

0 comments on commit 32e80fd

Please sign in to comment.