ENH: Add dispatches of TruncatedNormal distribution for forward sampling #7489

Open
lucianopaz opened this issue Sep 3, 2024 · 4 comments · May be fixed by #7506

Comments

@lucianopaz
Contributor

Before

import pymc as pm

with pm.Model() as m:
    a = pm.TruncatedNormal("a", 0, 1, lower=-1, upper=1)

pm.draw(a, mode="JAX")  # Fails
pm.draw(a, mode="NUMBA")  # Fails

After

import pymc as pm

with pm.Model() as m:
    a = pm.TruncatedNormal("a", 0, 1, lower=-1, upper=1)

pm.draw(a, mode="JAX")  # Works
pm.draw(a, mode="NUMBA")  # Works

Context for the issue:

The TruncatedNormal distribution creates a TruncatedNormalRV Op that doesn't have dispatch rules for either JAX or NUMBA. The Truncated class itself seems to work though. In general, it would be nice to know where to write dispatches for these special pymc Ops.
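For reference, whatever dispatch gets written has to hand the backend a native way to draw from the truncated distribution. One common approach is inverse-CDF sampling. The sketch below uses only the standard library to show the idea; `truncated_normal_draw` is a hypothetical helper for illustration, not something PyMC or PyTensor provides, and an actual dispatch may well use a different algorithm.

```python
import random
from statistics import NormalDist


def truncated_normal_draw(mu, sigma, lower, upper, rng):
    """Draw one sample from Normal(mu, sigma) restricted to [lower, upper].

    Hypothetical helper: maps a uniform draw on [cdf(lower), cdf(upper)]
    back through the inverse CDF of the untruncated normal.
    """
    dist = NormalDist(mu, sigma)
    p_lower = dist.cdf(lower)
    p_upper = dist.cdf(upper)
    # Uniform draw restricted to the probability mass inside the bounds.
    u = rng.uniform(p_lower, p_upper)
    return dist.inv_cdf(u)


rng = random.Random(0)
draws = [truncated_normal_draw(0.0, 1.0, -1.0, 1.0, rng) for _ in range(1000)]
```

This version assumes both bounds are finite and not far out in the tails, where the CDF values get too close to 0 or 1 for `inv_cdf` to be reliable.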

@lucianopaz
Contributor Author

This might be related to #7348

@ricardoV94
Member

it would be nice to know where to write dispatches for these special pymc Ops

distributions/dispatch/jax
distributions/dispatch/numba

?

@HarshvirSandhu
Contributor

I can work on this. Also, would the dispatch for jax go to pymc.sampling.jax?

@ricardoV94
Member

We may need a dispatch/numba.py and dispatch/jax.py? sampling/jax is too specific

It will need to be based on a try/except import of jax/numba, as those are optional dependencies.
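A rough sketch of what such a guarded registration could look like, using only the standard library so it runs without either backend installed. All names here (`optional_import`, `backend_funcify`, `register_dispatches`, and the stand-in `TruncatedNormalRV` class) are hypothetical placeholders. The real dispatch modules would register against the backend's own dispatch functions rather than a local `singledispatch` registry.

```python
import importlib
from functools import singledispatch


def optional_import(name):
    """Return the named module if it is installed, otherwise None."""
    try:
        return importlib.import_module(name)
    except ImportError:
        return None


class TruncatedNormalRV:
    """Hypothetical stand-in for the real Op class."""


@singledispatch
def backend_funcify(op):
    # Fallback used when no backend-specific implementation was registered.
    raise NotImplementedError(f"No dispatch registered for {type(op).__name__}")


def register_dispatches(backend_module_name):
    """Register the dispatch only if the optional backend can be imported."""
    backend = optional_import(backend_module_name)
    if backend is None:
        return False

    @backend_funcify.register(TruncatedNormalRV)
    def _funcify_truncated_normal(op):
        # A real implementation would return a backend-native sampler here.
        return f"sampler for {type(op).__name__} via {backend.__name__}"

    return True
```

With this layout, importing the dispatch module never fails when jax or numba is missing; the Op simply stays unregistered and hits the `NotImplementedError` fallback.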

@HarshvirSandhu HarshvirSandhu linked a pull request Sep 17, 2024 that will close this issue
4 tasks