From a77bc123ee16d171ff62019b9112a681ff23524f Mon Sep 17 00:00:00 2001 From: Raffi Sanna Date: Tue, 2 Jul 2024 13:53:56 -0400 Subject: [PATCH] Make miniPyro's JIT elbo work with our impl --- effectful/handlers/minipyro.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/effectful/handlers/minipyro.py b/effectful/handlers/minipyro.py index e2f4d17d..e9c17890 100644 --- a/effectful/handlers/minipyro.py +++ b/effectful/handlers/minipyro.py @@ -393,7 +393,7 @@ def __call__(self, model, guide, *args): # On first call, initialize params and save their names. if self._param_trace is None: with block(), trace() as tr, block( - hide_fn=lambda op, *_, **__: op == param + hide_fn=lambda op, *_, **__: op != param ): elbo(model, guide, *args) self._param_trace = tr @@ -415,8 +415,9 @@ def compiled(*params_and_args): ): constrained_param = param(name) # assume param has been initialized assert constrained_param.unconstrained() is unconstrained_param - self._param_trace[name]["value"] = constrained_param - return replay(elbo, guide_trace=self._param_trace)(model, guide, *args) + self._param_trace[name].value = constrained_param + with replay(self._param_trace): + return elbo(model, guide, *args) with validation_enabled(False), warnings.catch_warnings(): if self.ignore_jit_warnings: