Convert trivial case of index to emb
swimdi committed Dec 26, 2024
1 parent 6fdeb43 commit 1f79ba3
Showing 2 changed files with 32 additions and 8 deletions.
11 changes: 5 additions & 6 deletions tests/utils.py
@@ -428,7 +428,7 @@ def __init__(self, op_name: str, input_strings: List[str]):
"aten.index.Tensor": self._adjust_index_tensor,
"aten.index_put.default": self._adjust_index_tensor,
"aten._native_batch_norm_legit_no_training.default": self._adjust__native_batch_norm_legit_no_training_default,
# "aten._unsafe_index.Tensor": self._adjust_index_tensor,
"aten._unsafe_index.Tensor": self._adjust_index_tensor,
}

def _adjust_bitwise_not_default(self, input_vals):
@@ -508,11 +508,10 @@ def _adjust_index_tensor(self, input_vals):
new_indices = []
for i in range(len(indices)):
indice = indices[i]
new_indice = []
for j in range(len(indice)):
new_indice.append(torch.randint(0, self_shape[i], [1]))
new_indice = torch.tensor(new_indice)
new_indices.append(new_indice)
if indice is None:
new_indices.append(None)
else:
new_indices.append(torch.randint(0, self_shape[i], indice.shape))
input_val["val"] = new_indices
break
return input_vals
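The adjusted generator now preserves None entries and matches the original index shapes, because aten.index.Tensor takes a list of optional index tensors in which None means "leave that dimension unindexed". A minimal sketch of that semantics in plain PyTorch (the tensors below are illustrative, not taken from the test suite):

import torch

x = torch.arange(12).reshape(3, 4)
idx = torch.tensor([2, 0])

# A None entry behaves like ":" for that dimension, i.e. x[:, idx]
print(torch.ops.aten.index.Tensor(x, [None, idx]).shape)  # torch.Size([3, 2])
# A single index tensor gathers along dim 0, i.e. x[idx]
print(torch.ops.aten.index.Tensor(x, [idx]).shape)        # torch.Size([2, 4])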
29 changes: 27 additions & 2 deletions torch_ttnn/passes/lowering/to_tt_pass.py
@@ -1244,26 +1244,45 @@ def decompose_aten_to_aten_ops(gm: torch.fx.GraphModule, g: GraphWrapper, node):
return g.call_function(torch.ops.aten.squeeze.default, args=(args[0],))
return None

if node.target in [torch.ops.aten.index.Tensor, torch.ops.aten._unsafe_index.Tensor]:
input_shape = get_shape(gm, args[0])
indices = get_arg(node, 1, "indices")
if len(input_shape) == 2 and len(indices) == 1 and indices[0] is not None:
index_shape = get_shape(gm, indices[0])
# Magic number: index sizes 38000 and 39000 pass, but 38809 hangs;
# it does pass if the device is opened with just ttnn.open_device(device=0).
if index_shape == torch.Size([38809]):
return None
return g.call_function(torch.ops.aten.embedding.default, args=(args[0], indices[0]))
return None

if node.target == torch.ops.aten.index_select.default:
dim = get_arg(node, 1, "dim")
indices = get_arg(node, 2, "indices")
new_indices = [None] * dim + [indices]
return g.call_function(torch.ops.aten.index.Tensor, args=(args[0], new_indices))
return None


# TODO(jerrysky3): Refactor ReplaceMoreTtManually with rewrite_graph
def rewrite_graph(gm: torch.fx.GraphModule, rewrite_node_fn) -> torch.fx.GraphModule:
nodes = list(gm.graph.nodes)
modified = False
for node in nodes:
if not can_lowering_to_ttnn(node):
continue
g = GraphWrapper(node)
with g.inserting_before(node):
new_node = rewrite_node_fn(gm, g, node)
if new_node is not None:
modified = True
node.replace_all_uses_with(
new_node,
delete_user_cb=lambda node: node != new_node,
)

gm = GraphCleanup(gm)
return gm
return gm, modified


class ToTtPass(PassBase):
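The trivial case handled above is a plain row gather: for a 2-D input and a single non-None index tensor, aten.index.Tensor(weight, [idx]) returns the same rows as aten.embedding.default(weight, idx). A minimal sketch checking that equivalence with stock PyTorch on CPU (shapes chosen arbitrarily for illustration):

import torch

weight = torch.randn(10, 4)
idx = torch.randint(0, 10, (7,))  # int64 indices, as embedding expects

indexed = torch.ops.aten.index.Tensor(weight, [idx])      # weight[idx]
embedded = torch.ops.aten.embedding.default(weight, idx)  # row gather

assert torch.equal(indexed, embedded)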
@@ -1273,7 +1292,13 @@ def __init__(self, device, use_less_ttnn_op_types):

def call(self, gm: torch.fx.GraphModule):
# Decompose some aten ops to simpler aten ops
gm = rewrite_graph(gm, decompose_aten_to_aten_ops)
max_try = 10
cnt = 0
while cnt < max_try:
cnt += 1
gm, modified = rewrite_graph(gm, decompose_aten_to_aten_ops)
if not modified:
break

# Replace more patterns with torch.fx.Transformer
gm = ReplaceMoreTt(gm, self.device, self.use_less_ttnn_op_types).transform()
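The call site now reruns the decomposition until the graph stops changing, capped at max_try, presumably because one rewrite can enable another: aten.index_select is first lowered to aten.index.Tensor, and only a later pass can turn that aten.index.Tensor into aten.embedding. A standalone sketch of this iterate-to-fixed-point pattern, with a hypothetical rewrite_once standing in for rewrite_graph(gm, decompose_aten_to_aten_ops):

# Hypothetical stand-in: rewrites one op name per call and reports whether
# anything changed, mirroring the (gm, modified) contract of rewrite_graph.
CHAIN = {"index_select": "index", "index": "embedding"}

def rewrite_once(op):
    if op in CHAIN:
        return CHAIN[op], True   # something was rewritten
    return op, False             # fixed point reached

def run_to_fixed_point(op, max_try=10):
    for _ in range(max_try):
        op, modified = rewrite_once(op)
        if not modified:
            break
    return op

print(run_to_fixed_point("index_select"))  # "embedding" after two rewrite rounds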
