-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Convert some index case to embedding #682
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import torch | ||
import torch_ttnn | ||
import pytest | ||
import ttnn | ||
|
||
from tests.utils import assert_with_pcc | ||
|
||
|
||
class IndexModule(torch.nn.Module): | ||
def __init__(self, op): | ||
super().__init__() | ||
self.op = op | ||
|
||
def forward(self, input, indices): | ||
return self.op(input, indices) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"input_shapes, indices, converted", | ||
[ | ||
((3, 4), [[[0, 1], [1, 2]]], "embedding"), | ||
((3, 4, 5), [[0, 1], [2, 1], [2, 4]], False), | ||
((3, 4, 5), [[0, 1], [2, 1]], False), | ||
((3, 4, 5), [[[0, 1]], [[2, 1]]], False), | ||
((3, 4, 5), [[[0, 1, 1]], [[2, 1, 2]]], False), | ||
((3, 4, 5), [[[0, 1, 1], [1, 1, 0]], [[2, 1, 2]]], False), # broadcast | ||
], | ||
) | ||
@pytest.mark.parametrize("op", [torch.ops.aten.index.Tensor, torch.ops.aten._unsafe_index.Tensor]) | ||
def test_index(device, op, input_shapes, indices, converted): | ||
m = IndexModule(op) | ||
inputs = torch.rand(input_shapes, dtype=torch.bfloat16) | ||
indices = [torch.tensor(index) for index in indices] | ||
result_before = m.forward(inputs, indices) | ||
|
||
option = torch_ttnn.TorchTtnnOption(device=device) | ||
# option.gen_graphviz = True | ||
|
||
# The compilation is lazy, so we need to run forward once to trigger the compilation | ||
m = torch.compile(m, backend=torch_ttnn.backend, options=option) | ||
|
||
result_after = m.forward(inputs, indices) | ||
# option._out_fx_graphs[0].print_tabular() | ||
|
||
# Check the graph has be rewritten and contain ttnn ops | ||
if converted == "embedding": | ||
nodes = [node.target for node in option._out_fx_graphs[0].nodes] | ||
assert op not in nodes | ||
assert ttnn.embedding in nodes | ||
|
||
# Check inference result | ||
assert_with_pcc(result_before, result_after, pcc=0.99) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import torch | ||
import torch_ttnn | ||
import pytest | ||
import ttnn | ||
|
||
from tests.utils import assert_with_pcc | ||
|
||
|
||
class IndexSelectModule(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, input, dim, index): | ||
return torch.ops.aten.index_select.default(input, dim, index) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"input_shapes, dim, index, converted", | ||
[ | ||
((3, 4), 0, [0, 1, 1, 0], "embedding"), # => index => embedding | ||
((3, 4), 1, [0, 1], "index"), | ||
((3, 4, 5), 0, [1, 0], "index"), | ||
((3, 4, 5), 1, [2, 3], "index"), | ||
], | ||
) | ||
def test_index_select(device, input_shapes, dim, index, converted): | ||
m = IndexSelectModule() | ||
inputs = torch.rand(input_shapes, dtype=torch.bfloat16) | ||
index = torch.tensor(index) | ||
result_before = m.forward(inputs, dim, index) | ||
|
||
option = torch_ttnn.TorchTtnnOption(device=device) | ||
# option.gen_graphviz = True | ||
|
||
# The compilation is lazy, so we need to run forward once to trigger the compilation | ||
m = torch.compile(m, backend=torch_ttnn.backend, options=option) | ||
|
||
result_after = m.forward(inputs, dim, index) | ||
# option._out_fx_graphs[0].print_tabular() | ||
|
||
# Check the graph has be rewritten and contain ttnn ops | ||
if converted == "embedding": | ||
nodes = [node.target for node in option._out_fx_graphs[0].nodes] | ||
assert torch.ops.aten.index_select.default not in nodes | ||
assert ttnn.embedding in nodes | ||
if converted == "index": | ||
nodes = [node.target for node in option._out_fx_graphs[0].nodes] | ||
assert torch.ops.aten.index_select.default not in nodes | ||
assert torch.ops.aten.index.Tensor in nodes | ||
|
||
# Check inference result | ||
assert_with_pcc(result_before, result_after, pcc=0.99) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1244,26 +1244,46 @@ def decompose_aten_to_aten_ops(gm: torch.fx.GraphModule, g: GraphWrapper, node): | |
return g.call_function(torch.ops.aten.squeeze.default, args=(args[0],)) | ||
return None | ||
|
||
if node.target in [torch.ops.aten.index.Tensor, torch.ops.aten._unsafe_index.Tensor]: | ||
input_shape = get_shape(gm, args[0]) | ||
indices = get_arg(node, 1, "indices") | ||
if len(input_shape) == 2 and len(indices) == 1 and indices[0] is not None: | ||
index_shape = get_shape(gm, indices[0]) | ||
# magic number, 38000 can pass, 39000 can pass, but 38809 will hang | ||
# and if device is just ttnn.open_device(device=0), then it can pass | ||
# see issue #685 | ||
if index_shape == torch.Size([38809]): | ||
return None | ||
return g.call_function(torch.ops.aten.embedding.default, args=(args[0], indices[0])) | ||
return None | ||
|
||
if node.target == torch.ops.aten.index_select.default: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we have unit tests for index_select and unsafe_index? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok, I've add |
||
dim = get_arg(node, 1, "dim") | ||
indices = get_arg(node, 2, "indices") | ||
new_indices = [None] * dim + [indices] | ||
return g.call_function(torch.ops.aten.index.Tensor, args=(args[0], new_indices)) | ||
return None | ||
|
||
|
||
# TODO(jerrysky3): Refactor ReplaceMoreTtManually with rewrite_graph | ||
def rewrite_graph(gm: torch.fx.GraphModule, rewrite_node_fn) -> torch.fx.GraphModule: | ||
nodes = list(gm.graph.nodes) | ||
modified = False | ||
for node in nodes: | ||
if not can_lowering_to_ttnn(node): | ||
continue | ||
g = GraphWrapper(node) | ||
with g.inserting_before(node): | ||
new_node = rewrite_node_fn(gm, g, node) | ||
if new_node is not None: | ||
modified = True | ||
node.replace_all_uses_with( | ||
new_node, | ||
delete_user_cb=lambda node: node != new_node, | ||
) | ||
|
||
gm = GraphCleanup(gm) | ||
return gm | ||
return gm, modified | ||
|
||
|
||
class ToTtPass(PassBase): | ||
|
@@ -1273,18 +1293,28 @@ def __init__(self, device, use_less_ttnn_op_types): | |
|
||
def call(self, gm: torch.fx.GraphModule): | ||
# Decompose some aten ops to simpler aten ops | ||
gm = rewrite_graph(gm, decompose_aten_to_aten_ops) | ||
max_try = 10 | ||
cnt = 0 | ||
while True: | ||
cnt += 1 | ||
gm, modified = rewrite_graph(gm, decompose_aten_to_aten_ops) | ||
if not modified: | ||
break | ||
if cnt == max_try: | ||
raise RuntimeError("Failed to decompose aten ops to simpler aten ops") | ||
|
||
# Replace more patterns with torch.fx.Transformer | ||
gm = ReplaceMoreTt(gm, self.device, self.use_less_ttnn_op_types).transform() | ||
|
||
# Replace patterns manually | ||
max_try = 10 | ||
cnt = 0 | ||
while cnt < max_try: | ||
while True: | ||
cnt += 1 | ||
gm, modified = ReplaceMoreTtManually(gm, self.use_less_ttnn_op_types) | ||
if not modified: | ||
break | ||
if cnt == max_try: | ||
raise RuntimeError("Failed to decompose aten ops to simpler aten ops") | ||
|
||
return PassResult(gm, True) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe file a bug?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok, I've open an issue