diff --git a/docs/JAX_porting_PyTorch_model.ipynb b/docs/JAX_porting_PyTorch_model.ipynb index 92ae53f..feeae47 100644 --- a/docs/JAX_porting_PyTorch_model.ipynb +++ b/docs/JAX_porting_PyTorch_model.ipynb @@ -592,9 +592,12 @@ " self.p = p\n", " self.mode = mode\n", " self.deterministic = False\n", + " self.rngs = rngs\n", "\n", " def __call__(self, x: jax.Array) -> jax.Array:\n", - " return stochastic_depth(x, self.p, self.mode, self.deterministic)" + " return stochastic_depth(\n", + " x, self.p, self.mode, self.deterministic, rngs=self.rngs\n", + " )" ] }, { diff --git a/docs/JAX_porting_PyTorch_model.md b/docs/JAX_porting_PyTorch_model.md index 4060296..1eedfdf 100644 --- a/docs/JAX_porting_PyTorch_model.md +++ b/docs/JAX_porting_PyTorch_model.md @@ -361,9 +361,12 @@ class StochasticDepth(nnx.Module): self.p = p self.mode = mode self.deterministic = False + self.rngs = rngs def __call__(self, x: jax.Array) -> jax.Array: - return stochastic_depth(x, self.p, self.mode, self.deterministic) + return stochastic_depth( + x, self.p, self.mode, self.deterministic, rngs=self.rngs + ) ``` ```{code-cell} ipython3