
export jax model to pytorch model #28054

Open
FanZhang91 opened this issue Apr 16, 2025 · 2 comments
Labels
bug Something isn't working

Comments

FanZhang91 commented Apr 16, 2025

Description

I am not familiar with the JAX framework. Is there a tool to convert a JAX model to a PyTorch model?

System info (python version, jaxlib version, accelerator, etc.)

None

FanZhang91 added the bug label (Something isn't working) on Apr 16, 2025
vfdev-5 (Collaborator) commented Apr 16, 2025

Check this discussion: #25618

patrick-toulme commented
Just convert the JAX arrays to NumPy, save them to disk, and then load them back into PyTorch:

JAX -> NumPy -> PyTorch
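A minimal sketch of that route, assuming the model's weights are a nested dict of JAX arrays (the way Flax stores `params`) and that you rebuild the same architecture in PyTorch yourself: only the weights move across, not the model code. The `params` dict, the `dense` layer name, and the helper functions below are hypothetical. Weight layouts can also differ between frameworks: a Flax `Dense` kernel is stored as (in, out) while `torch.nn.Linear.weight` is (out, in), so it must be transposed (convolution kernels need a similar permutation).

```python
import os
import tempfile
from collections.abc import Mapping

import jax
import jax.numpy as jnp
import numpy as np
import torch


def flatten_params(params, prefix=""):
    """Flatten a nested mapping of JAX arrays into {"a.b.c": np.ndarray}."""
    flat = {}
    for name, value in params.items():
        full_name = f"{prefix}.{name}" if prefix else str(name)
        if isinstance(value, Mapping):
            flat.update(flatten_params(value, full_name))
        else:
            # Step 1: JAX -> NumPy (copies the array from device to host memory).
            flat[full_name] = np.asarray(value)
    return flat


def save_jax_params(params, path):
    # Step 2: save all NumPy arrays to disk in a single .npz archive.
    np.savez(path, **flatten_params(params))


def load_as_torch(path):
    # Step 3: NumPy -> PyTorch. np.load returns fresh arrays, which
    # torch.from_numpy wraps as tensors without another copy.
    with np.load(path) as data:
        return {name: torch.from_numpy(data[name]) for name in data.files}


# Hypothetical Flax-style params for a single dense layer; kernel is (in, out).
rng = jax.random.PRNGKey(0)
params = {"dense": {"kernel": jax.random.normal(rng, (4, 3)),
                    "bias": jnp.zeros((3,))}}

path = os.path.join(tempfile.mkdtemp(), "params.npz")
save_jax_params(params, path)
tensors = load_as_torch(path)

# torch.nn.Linear stores its weight as (out, in), so transpose the kernel.
linear = torch.nn.Linear(4, 3)
with torch.no_grad():
    linear.weight.copy_(tensors["dense.kernel"].T)
    linear.bias.copy_(tensors["dense.bias"])

# Run the same input through the JAX computation and the PyTorch layer.
x = np.ones((2, 4), dtype=np.float32)
jax_out = jnp.asarray(x) @ params["dense"]["kernel"] + params["dense"]["bias"]
torch_out = linear(torch.from_numpy(x)).detach().numpy()
print(np.allclose(np.asarray(jax_out), torch_out, atol=1e-4))
```

If the conversion is right, the last line prints whether both outputs match; a mismatch usually points to a missed transpose or a renamed parameter.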
