Skip to content

Commit

Permalink
Add support to handle multi dimension expansion in expand_dims (#39)
Browse files Browse the repository at this point in the history
## Description

- expand_dims was mapped to unsqueeze
[here](https://github.com/tenstorrent/tt-forge-fe/blob/375c069d297f770a3b837571b5ff6f9c43600751/forge/forge/tvm_to_python.py#L1681).
In BatchNorm2d, 2 dimensions should be added by expand_dims —
**expand_dims(%3, axis=1, num_newaxis=2)**. A single unsqueeze operation
can't do that.

- The added logic (emit one unsqueeze operation per `num_newaxis`) in the
buda passes will convert `%4 = expand_dims(%3, axis=1, num_newaxis=2) /*
ty=Tensor[(64, 1, 1), float32] */;` to
```
%5 = expand_dims(%4, axis=1) /* ty=Tensor[(64, 1), float32] */;
%6 = expand_dims(%5, axis=1) /* ty=Tensor[(64, 1, 1), float32] */;
```

**BatchNorm2d's Complete relay expression:**

- Before Fix

```
def @main(%input: Tensor[(1, 64, 112, 112), float32] /* ty=Tensor[(1, 64, 112, 112), float32] span=aten::batch_norm_0.input:0:0 */, %weight: Tensor[(64), float32] /* ty=Tensor[(64), float32] span=aten::batch_norm_0.weight:0:0 */, %bias: Tensor[(64), float32] /* ty=Tensor[(64), float32] span=aten::batch_norm_0.bias:0:0 */, %running_mean: Tensor[(64), float32] /* ty=Tensor[(64), float32] span=aten::batch_norm_0.running_mean:0:0 */, %running_var: Tensor[(64), float32] /* ty=Tensor[(64), float32] span=aten::batch_norm_0.running_var:0:0 */) -> Tensor[(1, 64, 112, 112), float32] {
  %0 = add(%running_var, 1e-05f /* ty=float32 */) /* ty=Tensor[(64), float32] */;
  %1 = sqrt(%0) /* ty=Tensor[(64), float32] */;
  %2 = divide(1f /* ty=float32 */, %1) /* ty=Tensor[(64), float32] */;
  %3 = multiply(%2, %weight) /* ty=Tensor[(64), float32] */;
  %4 = expand_dims(%3, axis=1, num_newaxis=2) /* ty=Tensor[(64, 1, 1), float32] */;
  %5 = negative(%running_mean) /* ty=Tensor[(64), float32] */;
  %6 = multiply(%5, %3) /* ty=Tensor[(64), float32] */;
  %7 = add(%6, %bias) /* ty=Tensor[(64), float32] */;
  %8 = multiply(%input, %4) /* ty=Tensor[(1, 64, 112, 112), float32] */;
  %9 = expand_dims(%7, axis=1, num_newaxis=2) /* ty=Tensor[(64, 1, 1), float32] */;
  add(%8, %9) /* ty=Tensor[(1, 64, 112, 112), float32] */
}
```

- After Fix

```
def @main(%input: Tensor[(1, 64, 112, 112), float32] /* ty=Tensor[(1, 64, 112, 112), float32] span=aten::batch_norm_0.input:0:0 */, %weight: Tensor[(64), float32] /* ty=Tensor[(64), float32] span=aten::batch_norm_0.weight:0:0 */, %bias: Tensor[(64), float32] /* ty=Tensor[(64), float32] span=aten::batch_norm_0.bias:0:0 */, %running_mean: Tensor[(64), float32] /* ty=Tensor[(64), float32] span=aten::batch_norm_0.running_mean:0:0 */, %running_var: Tensor[(64), float32] /* ty=Tensor[(64), float32] span=aten::batch_norm_0.running_var:0:0 */) -> Tensor[(1, 64, 112, 112), float32] {
  %0 = add(%running_var, 1e-05f /* ty=float32 */) /* ty=Tensor[(64), float32] */;
  %1 = sqrt(%0) /* ty=Tensor[(64), float32] */;
  %2 = reciprocal(%1) /* ty=Tensor[(64), float32] */;
  %3 = multiply(1f /* ty=float32 */, %2) /* ty=Tensor[(64), float32] */;
  %4 = multiply(%3, %weight) /* ty=Tensor[(64), float32] */;
  %5 = expand_dims(%4, axis=1) /* ty=Tensor[(64, 1), float32] */;
  %6 = expand_dims(%5, axis=1) /* ty=Tensor[(64, 1, 1), float32] */;
  %7 = multiply(%running_mean, -1f /* ty=float32 */) /* ty=Tensor[(64), float32] */;
  %8 = multiply(%7, %4) /* ty=Tensor[(64), float32] */;
  %9 = add(%8, %bias) /* ty=Tensor[(64), float32] */;
  %10 = expand_dims(%9, axis=1) /* ty=Tensor[(64, 1), float32] */;
  %11 = multiply(%input, %6) /* ty=Tensor[(1, 64, 112, 112), float32] */;
  %12 = expand_dims(%10, axis=1) /* ty=Tensor[(64, 1, 1), float32] */;
  add(%11, %12) /* ty=Tensor[(1, 64, 112, 112), float32] */
}
```
  • Loading branch information
kamalrajkannan78 authored Oct 24, 2024
2 parents 1303a00 + d9d1b8f commit 69089f2
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions python/tvm/relay/op/contrib/forge/forge_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3987,7 +3987,23 @@ def callback(self, pre, post, node_map):

return tvm.relay.nn.dense(act, dense.args[1])

class ExpandMultipleDims(DFPatternCallback):
    """Rewrite an ``expand_dims`` op with ``num_newaxis > 1`` into a chain of
    single-axis ``expand_dims`` ops.

    Downstream lowering maps ``expand_dims`` to a single unsqueeze, which can
    only insert one dimension at a time, so a multi-axis expansion must be
    decomposed into ``num_newaxis`` consecutive single expansions.
    """

    def __init__(self):
        super().__init__()
        # Match any producer feeding an expand_dims op.
        self.act = wildcard()
        self.pattern = is_op('expand_dims')(self.act)

    def callback(self, pre, post, node_map):
        axis = int(post.attrs.axis)
        num_newaxis = int(post.attrs.num_newaxis)

        # Single-axis expansion needs no rewriting; keep the matched op as-is.
        if num_newaxis <= 1:
            return post

        # NOTE(review): negative axes are rejected rather than normalized;
        # repeated insertion at a negative axis would not be equivalent.
        assert axis >= 0, f"Error: Axis is negative. axis: {axis}"

        # Insert the new unit dimensions one at a time at the same axis.
        expanded = node_map[self.act][0]
        for _ in range(num_newaxis):
            expanded = tvm.relay.expand_dims(expanded, axis=axis)
        return expanded


def _get_callback_name(callback):
Expand Down Expand Up @@ -4037,6 +4053,7 @@ def run_forge_compile_passes(relay_module, params=None, inputs=None, target=None
return run_pattern_callbacks(
relay_module,
[
ExpandMultipleDims(),
DecomposeReverse(),
ConvertLayout(),
ResolveConvChannels(),
Expand Down

0 comments on commit 69089f2

Please sign in to comment.