Skip to content

Commit

Permalink
Make miniPyro's JIT elbo work with our impl
Browse files Browse the repository at this point in the history
  • Loading branch information
rvs314 committed Jul 2, 2024
1 parent b9cb09f commit a77bc12
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions effectful/handlers/minipyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def __call__(self, model, guide, *args):
# On first call, initialize params and save their names.
if self._param_trace is None:
with block(), trace() as tr, block(
hide_fn=lambda op, *_, **__: op == param
hide_fn=lambda op, *_, **__: op != param
):
elbo(model, guide, *args)
self._param_trace = tr
Expand All @@ -415,8 +415,9 @@ def compiled(*params_and_args):
):
constrained_param = param(name) # assume param has been initialized
assert constrained_param.unconstrained() is unconstrained_param
self._param_trace[name]["value"] = constrained_param
return replay(elbo, guide_trace=self._param_trace)(model, guide, *args)
self._param_trace[name].value = constrained_param
with replay(self._param_trace):
return elbo(model, guide, *args)

with validation_enabled(False), warnings.catch_warnings():
if self.ignore_jit_warnings:
Expand Down

0 comments on commit a77bc12

Please sign in to comment.