From 23d3d0aad31725c8ea58531f9565c6048efe07c0 Mon Sep 17 00:00:00 2001 From: pchandrasekaran Date: Thu, 8 Feb 2024 06:17:34 +0000 Subject: [PATCH] Decomposed reverse op into adv_index in buda passes --- .../tvm/relay/op/contrib/buda/buda_passes.py | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/python/tvm/relay/op/contrib/buda/buda_passes.py b/python/tvm/relay/op/contrib/buda/buda_passes.py index ddf87c58f9..cbc8597173 100644 --- a/python/tvm/relay/op/contrib/buda/buda_passes.py +++ b/python/tvm/relay/op/contrib/buda/buda_passes.py @@ -191,6 +191,37 @@ def callback(self, pre, post, node_map): return slice_concatenate return post +class DecomposeReverse(DFPatternCallback): + def __init__(self): + super().__init__(rewrite_once=False, require_type=True) + self.input = wildcard() + self.pattern = is_op("reverse")(self.input) + def callback(self, pre, post, node_map): + input_shape = node_map[self.input][0].checked_type.shape + axis = int(post.attrs.axis) + if int(axis) < 0: + axis = abs(int(axis) + int(len(input_shape))) + act = post.args[0] + start = int(input_shape[axis]) - 1 + stop = -1 + step = -1 + indices = tvm.relay.Constant(tvm.nd.array(np.arange(start,stop,step).astype(int))) + if int(axis) == 0: + adv_index_out = tvm.relay.adv_index([act,indices]) + return adv_index_out + else: + transpose_1_axes = [int(axis)] + intermediate = list(set(np.arange(int(len(input_shape))).tolist()).difference(set(transpose_1_axes))) + intermediate.sort() + transpose_1_axes.extend(intermediate) + transpose_2_axes = [] + transpose_2_axes.extend(np.arange(1,int(len(input_shape))).tolist()) + transpose_2_axes.insert(int(axis),0) + transpose_1 = tvm.relay.transpose(act,axes=transpose_1_axes) + adv_index_1 = tvm.relay.adv_index([transpose_1,indices]) + transpose_2 = tvm.relay.transpose(adv_index_1,axes=transpose_2_axes) + return transpose_2 + class DecomposeDynamicResize2d(DFPatternCallback): def __init__(self): super().__init__(require_type=True) @@ -3710,6 +3741,7 @@ def run_buda_compile_passes(relay_module, params=None, inputs=None, target=None, return run_pattern_callbacks( relay_module, [ + DecomposeReverse(), ConvertLayout(), ResolveConvChannels(), DecomposeDynamicResize2d(),