-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Convert some index case to embedding #682
Conversation
indices = get_arg(node, 1, "indices") | ||
if len(input_shape) == 2 and len(indices) == 1 and indices[0] is not None: | ||
index_shape = get_shape(gm, indices[0]) | ||
# magic number, 38000 can pass, 39000 can pass, but 38809 will hang |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe file a bug?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok, I've open an issue
gm = rewrite_graph(gm, decompose_aten_to_aten_ops) | ||
max_try = 10 | ||
cnt = 0 | ||
while cnt < max_try: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm thinking if we should raise an exception when running out of tries. Because that means the graph still can't not converge and probably in a bad state
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok I added
return g.call_function(torch.ops.aten.embedding.default, args=(args[0], indices[0])) | ||
return None | ||
|
||
if node.target == torch.ops.aten.index_select.default: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we have unit tests for index_select and unsafe_index?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok, I've add
Ticket
#535
#544
#639
Problem description
Some input variation of index like below can be replaced as embedding
pytorch2.0_ttnn/tests/autogen_op/ALL/test_ALL_aten_index_Tensor.py
Lines 45 to 59 in 2174709
So add the conversion of it
What's changed
aten.index_select.Tensor
toaten.index.Tensor
aten.index_select.Tensor
andaten.index.Tensor
to const fold passcredit: @jerrysky3