Skip to content

Commit

Permalink
Merge branch 'pchandrasekaran/decompose_reverse' into 'main'
Browse files Browse the repository at this point in the history
Decomposed reverse op into adv_index

See merge request tenstorrent/tvm!49
  • Loading branch information
chandrasekaranpradeep committed Feb 20, 2024
2 parents 7e919d1 + 23d3d0a commit 48b2db4
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions python/tvm/relay/op/contrib/buda/buda_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,37 @@ def callback(self, pre, post, node_map):
return slice_concatenate
return post

class DecomposeReverse(DFPatternCallback):
def __init__(self):
super().__init__(rewrite_once=False, require_type=True)
self.input = wildcard()
self.pattern = is_op("reverse")(self.input)
def callback(self, pre, post, node_map):
input_shape = node_map[self.input][0].checked_type.shape
axis = int(post.attrs.axis)
if int(axis) < 0:
axis = abs(int(axis) + int(len(input_shape)))
act = post.args[0]
start = int(input_shape[axis]) - 1
stop = -1
step = -1
indices = tvm.relay.Constant(tvm.nd.array(np.arange(start,stop,step).astype(int)))
if int(axis) == 0:
adv_index_out = tvm.relay.adv_index([act,indices])
return adv_index_out
else:
transpose_1_axes = [int(axis)]
intermediate = list(set(np.arange(int(len(input_shape))).tolist()).difference(set(transpose_1_axes)))
intermediate.sort()
transpose_1_axes.extend(intermediate)
transpose_2_axes = []
transpose_2_axes.extend(np.arange(1,int(len(input_shape))).tolist())
transpose_2_axes.insert(int(axis),0)
transpose_1 = tvm.relay.transpose(act,axes=transpose_1_axes)
adv_index_1 = tvm.relay.adv_index([transpose_1,indices])
transpose_2 = tvm.relay.transpose(adv_index_1,axes=transpose_2_axes)
return transpose_2

class DecomposeDynamicResize2d(DFPatternCallback):
def __init__(self):
super().__init__(require_type=True)
Expand Down Expand Up @@ -3710,6 +3741,7 @@ def run_buda_compile_passes(relay_module, params=None, inputs=None, target=None,
return run_pattern_callbacks(
relay_module,
[
DecomposeReverse(),
ConvertLayout(),
ResolveConvChannels(),
DecomposeDynamicResize2d(),
Expand Down

0 comments on commit 48b2db4

Please sign in to comment.