Skip to content

Commit

Permalink
Fix promote_batch_shape logic to take batch shapes of all parameters (#…
Browse files Browse the repository at this point in the history
…1973)

* fix promote batch shape logic

* lint
  • Loading branch information
fehiepsi authored Feb 6, 2025
1 parent fa3f731 commit 56f63eb
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 6 deletions.
16 changes: 10 additions & 6 deletions numpyro/distributions/batch_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,12 +504,16 @@ def promote_batch_shape(d: Distribution):

@promote_batch_shape.register
def _default_promote_batch_shape(d: Distribution):
attr_name = list(d.arg_constraints.keys())[0]
attr_event_dim = d.arg_constraints[attr_name].event_dim
attr = getattr(d, attr_name)
resolved_batch_shape = attr.shape[
: max(0, attr.ndim - d.event_dim - attr_event_dim)
]
attr_batch_shapes = [d.batch_shape]
for attr_name, constraint in d.arg_constraints.items():
try:
attr_event_dim = constraint.event_dim
except NotImplementedError:
continue
attr = getattr(d, attr_name)
attr_batch_ndim = max(0, jnp.ndim(attr) - attr_event_dim)
attr_batch_shapes.append(jnp.shape(attr)[:attr_batch_ndim])
resolved_batch_shape = jnp.broadcast_shapes(*attr_batch_shapes)
new_self = copy.deepcopy(d)
new_self._batch_shape = resolved_batch_shape
return new_self
Expand Down
16 changes: 16 additions & 0 deletions test/contrib/test_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,3 +272,19 @@ def transition(x_prev, y_curr):

xhat = results.params["x_auto_loc"]
assert_allclose(xhat, tr["x"]["value"], rtol=0.1, atol=0.2)


def test_scan_mvn():
def model():
def transition(c, a):
with numpyro.plate("foo", 5):
c2 = numpyro.sample(
"val", dist.MultivariateNormal(c + a, scale_tril=jnp.eye(2))
)
return c2, c2

scan(transition, jnp.zeros((5, 2)), jnp.ones((4, 5, 2)))

with numpyro.handlers.seed(rng_seed=0), numpyro.handlers.trace() as tr:
model()
assert tr["val"]["fn"].batch_shape == (4, 5)

0 comments on commit 56f63eb

Please sign in to comment.