Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert some index case to embedding #682

Merged
merged 3 commits into from
Dec 27, 2024
Merged

Convert some index case to embedding #682

merged 3 commits into from
Dec 27, 2024

Conversation

swimdi
Copy link
Contributor

@swimdi swimdi commented Dec 26, 2024

Ticket

#535
#544
#639

Problem description

Some input variation of index like below can be replaced as embedding

["Tensor<[169, 8]> self = ?", "List[Optional[Tensor]] indices = [<[2401]>]"],
["Tensor<[169, 16]> self = ?", "List[Optional[Tensor]] indices = [<[2401]>]"],
["Tensor<[169, 32]> self = ?", "List[Optional[Tensor]] indices = [<[2401]>]"],
["Tensor<[169, 3]> self = ?", "List[Optional[Tensor]] indices = [<[2401]>]"],
["Tensor<[169, 6]> self = ?", "List[Optional[Tensor]] indices = [<[2401]>]"],
["Tensor<[169, 12]> self = ?", "List[Optional[Tensor]] indices = [<[2401]>]"],
["Tensor<[169, 24]> self = ?", "List[Optional[Tensor]] indices = [<[2401]>]"],
["Tensor<[225, 4]> self = ?", "List[Optional[Tensor]] indices = [<[4096]>]"],
["Tensor<[225, 8]> self = ?", "List[Optional[Tensor]] indices = [<[4096]>]"],
["Tensor<[225, 16]> self = ?", "List[Optional[Tensor]] indices = [<[4096]>]"],
["Tensor<[225, 32]> self = ?", "List[Optional[Tensor]] indices = [<[4096]>]"],
["Tensor<[225, 3]> self = ?", "List[Optional[Tensor]] indices = [<[4096]>]"],
["Tensor<[225, 6]> self = ?", "List[Optional[Tensor]] indices = [<[4096]>]"],
["Tensor<[225, 12]> self = ?", "List[Optional[Tensor]] indices = [<[4096]>]"],
["Tensor<[225, 24]> self = ?", "List[Optional[Tensor]] indices = [<[4096]>]"],

So add the conversion of it

What's changed

  • Convert some case from index to embedding
  • Convert aten.index_select.Tensor to aten.index.Tensor
  • Add aten.index_select.Tensor and aten.index.Tensor to const fold pass

credit: @jerrysky3

@swimdi swimdi self-assigned this Dec 26, 2024
@swimdi swimdi enabled auto-merge December 26, 2024 09:16
@swimdi swimdi changed the title Index to emb Convert some index case to embedding Dec 26, 2024
indices = get_arg(node, 1, "indices")
if len(input_shape) == 2 and len(indices) == 1 and indices[0] is not None:
index_shape = get_shape(gm, indices[0])
# magic number, 38000 can pass, 39000 can pass, but 38809 will hang
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe file a bug?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, I've open an issue

gm = rewrite_graph(gm, decompose_aten_to_aten_ops)
max_try = 10
cnt = 0
while cnt < max_try:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm thinking if we should raise an exception when running out of tries. Because that means the graph still can't not converge and probably in a bad state

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok I added

return g.call_function(torch.ops.aten.embedding.default, args=(args[0], indices[0]))
return None

if node.target == torch.ops.aten.index_select.default:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have unit tests for index_select and unsafe_index?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, I've add

@swimdi swimdi added this pull request to the merge queue Dec 27, 2024
@swimdi swimdi removed this pull request from the merge queue due to a manual request Dec 27, 2024
@swimdi swimdi added this pull request to the merge queue Dec 27, 2024
Merged via the queue into main with commit d0ccf81 Dec 27, 2024
1 check passed
@swimdi swimdi deleted the index-to-emb branch December 27, 2024 12:32
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants