Skip to content

Commit

Permalink
Minor fix in JAX_porting_PyTorch_model tutorial
Browse files Browse the repository at this point in the history
  • Loading branch information
vfdev-5 committed Nov 8, 2024
1 parent 3cc817a commit 369bed6
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
5 changes: 4 additions & 1 deletion docs/JAX_porting_PyTorch_model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -592,9 +592,12 @@
" self.p = p\n",
" self.mode = mode\n",
" self.deterministic = False\n",
" self.rngs = rngs\n",
"\n",
" def __call__(self, x: jax.Array) -> jax.Array:\n",
" return stochastic_depth(x, self.p, self.mode, self.deterministic)"
" return stochastic_depth(\n",
" x, self.p, self.mode, self.deterministic, rngs=self.rngs\n",
" )"
]
},
{
Expand Down
5 changes: 4 additions & 1 deletion docs/JAX_porting_PyTorch_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -361,9 +361,12 @@ class StochasticDepth(nnx.Module):
self.p = p
self.mode = mode
self.deterministic = False
self.rngs = rngs
def __call__(self, x: jax.Array) -> jax.Array:
return stochastic_depth(x, self.p, self.mode, self.deterministic)
return stochastic_depth(
x, self.p, self.mode, self.deterministic, rngs=self.rngs
)
```

```{code-cell} ipython3
Expand Down

0 comments on commit 369bed6

Please sign in to comment.