Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bugs in PyroShim #223

Merged
merged 24 commits into from
Jan 31, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 63 additions & 30 deletions effectful/handlers/pyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,23 @@ def _pyro_sample(self, msg: pyro.poutine.runtime.Message) -> None:
):
return

# PyroShim turns each call to pyro.sample into two calls. The first
# dispatches to pyro_sample and the effectful stack. The effectful stack
# eventually calls pyro.sample again. We use state in PyroShim to
# recognize that we've been called twice, and we dispatch to the pyro
# stack.
#
# This branch handles the second call, so it massages the message to be
# compatible with Pyro. In particular, it removes all named dimensions
# and stores naming information in the message. Names are replaced by
# _pyro_post_sample.
if getattr(self, "_current_site", None) == msg["name"]:
if "_index_naming" in msg:
return

# We need to identify this pyro shim during post-sample.
msg["_pyro_shim_id"] = id(self) # type: ignore[typeddict-unknown-key]

if "_markov_scope" in msg["infer"] and self._current_site:
msg["infer"]["_markov_scope"].pop(self._current_site, None)

Expand All @@ -169,6 +185,9 @@ def _pyro_sample(self, msg: pyro.poutine.runtime.Message) -> None:
else:
mask = msg["mask"]

assert set(sizesof(mask).keys()) <= (
set(indices.keys()) | set(sizesof(obs).keys())
)
pos_mask, _ = PyroShim._broadcast_to_named(mask, dist.batch_shape, indices)

pos_obs: Optional[torch.Tensor] = None
Expand All @@ -177,52 +196,66 @@ def _pyro_sample(self, msg: pyro.poutine.runtime.Message) -> None:
obs, dist.shape(), indices
)

# Each of the batch dimensions on the distribution gets a
# cond_indep_stack frame.
for var, dim in naming.name_to_dim.items():
frame = pyro.poutine.indep_messenger.CondIndepStackFrame(
name=str(var), dim=dim, size=-1, counter=0
)
msg["cond_indep_stack"] = (frame,) + msg["cond_indep_stack"]
# There can be additional batch dimensions on the observation
# that do not get frames, so only consider dimensions on the
# distribution.
if var in indices:
frame = pyro.poutine.indep_messenger.CondIndepStackFrame(
name=str(var),
# dims are indexed from the right of the batch shape
dim=dim + len(pdist.event_shape),
size=indices[var],
counter=0,
)
msg["cond_indep_stack"] = (frame,) + msg["cond_indep_stack"]

msg["fn"] = pdist
msg["value"] = pos_obs
msg["mask"] = pos_mask
msg["infer"]["_index_naming"] = naming # type: ignore
msg["_index_naming"] = naming # type: ignore

assert sizesof(msg["value"]) == {}
assert sizesof(msg["mask"]) == {}

return

try:
self._current_site = msg["name"]
msg["value"] = pyro_sample(
msg["name"],
msg["fn"],
obs=msg["value"] if msg["is_observed"] else None,
infer=msg["infer"].copy(),
)
finally:
self._current_site = None
# This branch handles the first call to pyro.sample by calling pyro_sample.
else:
try:
self._current_site = msg["name"]
msg["value"] = pyro_sample(
msg["name"],
msg["fn"],
obs=msg["value"] if msg["is_observed"] else None,
infer=msg["infer"].copy(),
)
finally:
self._current_site = None

# flags to guarantee commutativity of condition, intervene, trace
msg["stop"] = True
msg["done"] = True
msg["mask"] = False
msg["is_observed"] = True
msg["infer"]["is_auxiliary"] = True
msg["infer"]["_do_not_trace"] = True
# flags to guarantee commutativity of condition, intervene, trace
msg["stop"] = True
msg["done"] = True
msg["mask"] = False
msg["is_observed"] = True
msg["infer"]["is_auxiliary"] = True
msg["infer"]["_do_not_trace"] = True

def _pyro_post_sample(self, msg: pyro.poutine.runtime.Message) -> None:
infer = msg.get("infer")
if infer is None or "_index_naming" not in infer:
assert msg["value"] is not None

# If this message has been handled already by a different pyro shim, ignore.
if "_pyro_shim_id" in msg and msg["_pyro_shim_id"] != id(self): # type: ignore[typeddict-item]
return

# note: Pyro uses a TypedDict for infer, so it doesn't know we've stored this key
naming = infer["_index_naming"] # type: ignore
if getattr(self, "_current_site", None) == msg["name"]:
assert "_index_naming" in msg

# note: Pyro uses a TypedDict for infer, so it doesn't know we've stored this key
naming = msg["_index_naming"] # type: ignore

value = msg["value"]
value = msg["value"]

if value is not None:
# note: is it safe to assume that msg['fn'] is a distribution?
dist_shape: tuple[int, ...] = msg["fn"].batch_shape + msg["fn"].event_shape # type: ignore
if len(value.shape) < len(dist_shape):
Expand Down