diff --git a/examples/score_based_diffusion.ipynb b/examples/score_based_diffusion.ipynb index ced6f295..fe31847e 100644 --- a/examples/score_based_diffusion.ipynb +++ b/examples/score_based_diffusion.ipynb @@ -158,7 +158,7 @@ " self.t1 = t1\n", "\n", " def __call__(self, t, y):\n", - " t = t / self.t1\n", + " t = jnp.array(t / self.t1)\n", " _, height, width = y.shape\n", " t = einops.repeat(t, \"-> 1 h w\", h=height, w=width)\n", " y = jnp.concatenate([y, t])\n",