Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support S32/U32 indices for BWD embedding & Neuron implicit downcast #8462

Merged
merged 3 commits into from
Dec 7, 2024

Conversation

rpsilva-aws
Copy link
Collaborator

In this PR, we extend embedding tensor operations to allow S32 indices. This follows suits with other operations, in order to add flexibility and potentially performance benefits for accelerator backends. Reference for embedding dense bwd: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Embedding.cpp#L117

In addition, we also re-introduce the implicit downcasting for Neuron S64/U64 types, since the Neuron compiler does not support 64 bits.

There is an ongoing effort to further extend this requirement to other tensor operations involving indices: pytorch/pytorch#142160. Once this is resolved, we adapt it on XLA as well.

@rpsilva-aws rpsilva-aws changed the title Rpsilva downcast v2 Support S32/U32 indices for BWD embedding & Neuron implicit downcast Dec 6, 2024
@rpsilva-aws rpsilva-aws marked this pull request as ready for review December 6, 2024 00:28
@rpsilva-aws
Copy link
Collaborator Author

FYI, I split the previous PR: @miladm @ManfeiBai @tengyifei, this one is needed for 2.6. Unfortunately #8463 has a dependency on PT.

@tengyifei tengyifei added the tpuci label Dec 6, 2024
Copy link
Collaborator

@tengyifei tengyifei left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to add a test at all?

torch_xla/csrc/dtype.cpp Outdated Show resolved Hide resolved
@rpsilva-aws rpsilva-aws force-pushed the rpsilva_downcast_v2 branch 2 times, most recently from c2fb7ef to 95d0f0c Compare December 6, 2024 02:26
@rpsilva-aws
Copy link
Collaborator Author

@tengyifei Ran yapf over the test file. PTAL, thanks!

@tengyifei tengyifei merged commit 00c0e96 into pytorch:master Dec 7, 2024
12 checks passed
@rpsilva-aws rpsilva-aws deleted the rpsilva_downcast_v2 branch December 9, 2024 19:03
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants