From d9d1b8f6c61e76b4067537cd087fd446b9a3d878 Mon Sep 17 00:00:00 2001 From: kkannan Date: Tue, 22 Oct 2024 13:40:10 +0000 Subject: [PATCH] Add support to handle multi dimension expansion in expand_dims --- .../tvm/relay/op/contrib/forge/forge_passes.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/python/tvm/relay/op/contrib/forge/forge_passes.py b/python/tvm/relay/op/contrib/forge/forge_passes.py index 7dba61a72d..c91042f034 100644 --- a/python/tvm/relay/op/contrib/forge/forge_passes.py +++ b/python/tvm/relay/op/contrib/forge/forge_passes.py @@ -3987,7 +3987,23 @@ def callback(self, pre, post, node_map): return tvm.relay.nn.dense(act, dense.args[1]) +class ExpandMultipleDims(DFPatternCallback): + def __init__(self): + super().__init__() + self.act = wildcard() + self.pattern = is_op('expand_dims')(self.act) + def callback(self, pre, post, node_map): + act = node_map[self.act][0] + num_newaxis = int(post.attrs.num_newaxis) + axis = int(post.attrs.axis) + if num_newaxis > 1: + assert axis >= 0, f"Error: Axis is negative. axis: {axis}" + while num_newaxis: + act = tvm.relay.expand_dims(act, axis=axis) + num_newaxis-=1 + return act + return post def _get_callback_name(callback): @@ -4037,6 +4053,7 @@ def run_forge_compile_passes(relay_module, params=None, inputs=None, target=None return run_pattern_callbacks( relay_module, [ + ExpandMultipleDims(), DecomposeReverse(), ConvertLayout(), ResolveConvChannels(),