Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert some index case to embedding #682

Merged
merged 3 commits into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions tests/lowering/misc/test_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import torch
import torch_ttnn
import pytest
import ttnn

from tests.utils import assert_with_pcc


class IndexModule(torch.nn.Module):
    """Wrap an indexing op in a module so it can be passed to ``torch.compile``.

    Args:
        op: Callable invoked as ``op(input, indices)``. The tests in this file
            pass ``torch.ops.aten.index.Tensor`` or
            ``torch.ops.aten._unsafe_index.Tensor``.
    """

    def __init__(self, op):
        super().__init__()
        self.op = op

    def forward(self, input, indices):
        """Return ``self.op(input, indices)`` unchanged."""
        return self.op(input, indices)


@pytest.mark.parametrize(
    "input_shapes, indices, converted",
    [
        ((3, 4), [[[0, 1], [1, 2]]], "embedding"),
        ((3, 4, 5), [[0, 1], [2, 1], [2, 4]], False),
        ((3, 4, 5), [[0, 1], [2, 1]], False),
        ((3, 4, 5), [[[0, 1]], [[2, 1]]], False),
        ((3, 4, 5), [[[0, 1, 1]], [[2, 1, 2]]], False),
        ((3, 4, 5), [[[0, 1, 1], [1, 1, 0]], [[2, 1, 2]]], False),  # broadcast
    ],
)
@pytest.mark.parametrize("op", [torch.ops.aten.index.Tensor, torch.ops.aten._unsafe_index.Tensor])
def test_index(device, op, input_shapes, indices, converted):
    """Compare an index op before and after compiling with the torch_ttnn backend.

    Args:
        device: Device fixture handed to ``torch_ttnn.TorchTtnnOption``.
        op: The aten index overload under test.
        input_shapes: Shape of the random bfloat16 input tensor.
        indices: Nested lists, one per indexed dimension, converted to tensors.
        converted: ``"embedding"`` when the op is expected to be rewritten to
            ``ttnn.embedding``; any other value skips the graph check.
    """
    m = IndexModule(op)
    inputs = torch.rand(input_shapes, dtype=torch.bfloat16)
    indices = [torch.tensor(index) for index in indices]
    result_before = m.forward(inputs, indices)

    option = torch_ttnn.TorchTtnnOption(device=device)
    # option.gen_graphviz = True

    # The compilation is lazy, so we need to run forward once to trigger the compilation
    m = torch.compile(m, backend=torch_ttnn.backend, options=option)

    result_after = m.forward(inputs, indices)
    # option._out_fx_graphs[0].print_tabular()

    # Check the graph has been rewritten and contains ttnn ops
    if converted == "embedding":
        nodes = [node.target for node in option._out_fx_graphs[0].nodes]
        assert op not in nodes
        assert ttnn.embedding in nodes

    # Check inference result
    assert_with_pcc(result_before, result_after, pcc=0.99)
52 changes: 52 additions & 0 deletions tests/lowering/misc/test_index_select.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import torch
import torch_ttnn
import pytest
import ttnn

from tests.utils import assert_with_pcc


class IndexSelectModule(torch.nn.Module):
    """Wrap ``aten.index_select`` in a module so it can be passed to ``torch.compile``.

    The explicit ``__init__`` was dropped because it only forwarded to
    ``torch.nn.Module.__init__``; construction with no arguments is unchanged.
    """

    def forward(self, input, dim, index):
        """Return ``torch.ops.aten.index_select.default(input, dim, index)``."""
        return torch.ops.aten.index_select.default(input, dim, index)


@pytest.mark.parametrize(
    "input_shapes, dim, index, converted",
    [
        ((3, 4), 0, [0, 1, 1, 0], "embedding"),  # => index => embedding
        ((3, 4), 1, [0, 1], "index"),
        ((3, 4, 5), 0, [1, 0], "index"),
        ((3, 4, 5), 1, [2, 3], "index"),
    ],
)
def test_index_select(device, input_shapes, dim, index, converted):
    """Compare ``aten.index_select`` before and after compiling with torch_ttnn.

    Args:
        device: Device fixture handed to ``torch_ttnn.TorchTtnnOption``.
        input_shapes: Shape of the random bfloat16 input tensor.
        dim: Dimension passed to ``index_select``.
        index: List of positions to select, converted to a tensor.
        converted: ``"embedding"`` if the compiled graph must contain
            ``ttnn.embedding``, ``"index"`` if it must contain
            ``aten.index.Tensor``; any other value skips the graph check.
    """
    m = IndexSelectModule()
    inputs = torch.rand(input_shapes, dtype=torch.bfloat16)
    index = torch.tensor(index)
    result_before = m.forward(inputs, dim, index)

    option = torch_ttnn.TorchTtnnOption(device=device)
    # option.gen_graphviz = True

    # The compilation is lazy, so we need to run forward once to trigger the compilation
    m = torch.compile(m, backend=torch_ttnn.backend, options=option)

    result_after = m.forward(inputs, dim, index)
    # option._out_fx_graphs[0].print_tabular()

    # Check the graph has been rewritten: index_select must be gone and the
    # expected replacement op must be present.
    if converted in ("embedding", "index"):
        nodes = [node.target for node in option._out_fx_graphs[0].nodes]
        assert torch.ops.aten.index_select.default not in nodes
        if converted == "embedding":
            assert ttnn.embedding in nodes
        else:
            assert torch.ops.aten.index.Tensor in nodes

    # Check inference result
    assert_with_pcc(result_before, result_after, pcc=0.99)
11 changes: 5 additions & 6 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ def __init__(self, op_name: str, input_strings: List[str]):
"aten.index.Tensor": self._adjust_index_tensor,
"aten.index_put.default": self._adjust_index_tensor,
"aten._native_batch_norm_legit_no_training.default": self._adjust__native_batch_norm_legit_no_training_default,
# "aten._unsafe_index.Tensor": self._adjust_index_tensor,
"aten._unsafe_index.Tensor": self._adjust_index_tensor,
}

def _adjust_bitwise_not_default(self, input_vals):
Expand Down Expand Up @@ -508,11 +508,10 @@ def _adjust_index_tensor(self, input_vals):
new_indices = []
for i in range(len(indices)):
indice = indices[i]
new_indice = []
for j in range(len(indice)):
new_indice.append(torch.randint(0, self_shape[i], [1]))
new_indice = torch.tensor(new_indice)
new_indices.append(new_indice)
if indice is None:
new_indices.append(None)
else:
new_indices.append(torch.randint(0, self_shape[i], indice.shape))
input_val["val"] = new_indices
break
return input_vals
Expand Down
2 changes: 2 additions & 0 deletions torch_ttnn/passes/constant_folding_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ def __init__(self):
torch.ops.aten.ones.default,
torch.ops.aten.cumsum.default,
torch.ops.aten._unsafe_index.Tensor,
torch.ops.aten.index.Tensor,
torch.ops.aten.index_select.default,
torch.ops.aten.ne.Scalar,
torch.ops.aten.select.int,
torch.ops.aten.bitwise_not.default,
Expand Down
36 changes: 33 additions & 3 deletions torch_ttnn/passes/lowering/to_tt_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -1244,26 +1244,46 @@ def decompose_aten_to_aten_ops(gm: torch.fx.GraphModule, g: GraphWrapper, node):
return g.call_function(torch.ops.aten.squeeze.default, args=(args[0],))
return None

if node.target in [torch.ops.aten.index.Tensor, torch.ops.aten._unsafe_index.Tensor]:
input_shape = get_shape(gm, args[0])
indices = get_arg(node, 1, "indices")
if len(input_shape) == 2 and len(indices) == 1 and indices[0] is not None:
index_shape = get_shape(gm, indices[0])
# magic number, 38000 can pass, 39000 can pass, but 38809 will hang
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe file a bug?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I've opened an issue.

# and if device is just ttnn.open_device(device=0), then it can pass
# see issue #685
if index_shape == torch.Size([38809]):
return None
return g.call_function(torch.ops.aten.embedding.default, args=(args[0], indices[0]))
return None

if node.target == torch.ops.aten.index_select.default:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have unit tests for index_select and unsafe_index?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I've added them.

dim = get_arg(node, 1, "dim")
indices = get_arg(node, 2, "indices")
new_indices = [None] * dim + [indices]
return g.call_function(torch.ops.aten.index.Tensor, args=(args[0], new_indices))
return None


# TODO(jerrysky3): Refactor ReplaceMoreTtManually with rewrite_graph
def rewrite_graph(gm: torch.fx.GraphModule, rewrite_node_fn) -> "tuple[torch.fx.GraphModule, bool]":
    """Apply ``rewrite_node_fn`` to every lowerable node and report whether anything changed.

    Args:
        gm: Graph module whose nodes are visited.
        rewrite_node_fn: Callable invoked as ``rewrite_node_fn(gm, g, node)``,
            where ``g`` is a ``GraphWrapper`` for ``node``. It returns the
            replacement node, or ``None`` to leave ``node`` untouched.

    Returns:
        ``(gm, modified)``: the graph module after ``GraphCleanup``, and
        ``True`` if at least one node was replaced. Callers loop on
        ``modified`` to re-run the rewrite until nothing changes.
    """
    modified = False
    # Iterate over a snapshot of the node list, since the rewrite callback may
    # add nodes to the graph while we walk it.
    for node in list(gm.graph.nodes):
        if not can_lowering_to_ttnn(node):
            continue
        g = GraphWrapper(node)
        with g.inserting_before(node):
            new_node = rewrite_node_fn(gm, g, node)
            if new_node is not None:
                modified = True
                # Redirect every user except new_node itself, so a replacement
                # that takes `node` as an input keeps that input.
                node.replace_all_uses_with(
                    new_node,
                    delete_user_cb=lambda user: user != new_node,
                )

    gm = GraphCleanup(gm)
    return gm, modified


class ToTtPass(PassBase):
Expand All @@ -1273,18 +1293,28 @@ def __init__(self, device, use_less_ttnn_op_types):

def call(self, gm: torch.fx.GraphModule):
    """Lower ``gm`` to ttnn ops in three stages and return a ``PassResult``.

    1. Repeatedly decompose aten ops into simpler aten ops until the graph
       stops changing.
    2. Replace patterns with the ``ReplaceMoreTt`` transformer.
    3. Repeatedly run ``ReplaceMoreTtManually`` until the graph stops changing.

    Raises:
        RuntimeError: If stage 1 or stage 3 is still modifying the graph after
            ``max_try`` iterations.
    """
    max_try = 10

    # Decompose some aten ops to simpler aten ops. One decomposition can
    # produce an op that is itself decomposable (the test data marks
    # index_select => index => embedding), so iterate to a fixed point.
    for _ in range(max_try):
        gm, modified = rewrite_graph(gm, decompose_aten_to_aten_ops)
        if not modified:
            break
    else:
        raise RuntimeError("Failed to decompose aten ops to simpler aten ops")

    # Replace more patterns with torch.fx.Transformer
    gm = ReplaceMoreTt(gm, self.device, self.use_less_ttnn_op_types).transform()

    # Replace patterns manually
    for _ in range(max_try):
        gm, modified = ReplaceMoreTtManually(gm, self.use_less_ttnn_op_types)
        if not modified:
            break
    else:
        raise RuntimeError("Failed to replace patterns with ttnn ops manually")

    return PassResult(gm, True)
Loading