diff --git a/effectful/handlers/minipyro.py b/effectful/handlers/minipyro.py index e2f4d17d..e9c17890 100644 --- a/effectful/handlers/minipyro.py +++ b/effectful/handlers/minipyro.py @@ -393,7 +393,7 @@ def __call__(self, model, guide, *args): # On first call, initialize params and save their names. if self._param_trace is None: with block(), trace() as tr, block( - hide_fn=lambda op, *_, **__: op == param + hide_fn=lambda op, *_, **__: op != param ): elbo(model, guide, *args) self._param_trace = tr @@ -415,8 +415,9 @@ def compiled(*params_and_args): ): constrained_param = param(name) # assume param has been initialized assert constrained_param.unconstrained() is unconstrained_param - self._param_trace[name]["value"] = constrained_param - return replay(elbo, guide_trace=self._param_trace)(model, guide, *args) + self._param_trace[name].value = constrained_param + with replay(self._param_trace): + return elbo(model, guide, *args) with validation_enabled(False), warnings.catch_warnings(): if self.ignore_jit_warnings: