Add support to handle multi dimension expansion in expand_dims (#39)
## Description

- `expand_dims` was mapped to `unsqueeze` [here](https://github.com/tenstorrent/tt-forge-fe/blob/375c069d297f770a3b837571b5ff6f9c43600751/forge/forge/tvm_to_python.py#L1681). In BatchNorm2d, two dimensions must be added by expand_dims, e.g. **expand_dims(%3, axis=1, num_newaxis=2)**; a single unsqueeze operation cannot do that.
- New logic in the buda passes (emit one unsqueeze per `num_newaxis`) converts

  ```
  %4 = expand_dims(%3, axis=1, num_newaxis=2) /* ty=Tensor[(64, 1, 1), float32] */;
  ```

  into

  ```
  %5 = expand_dims(%4, axis=1) /* ty=Tensor[(64, 1), float32] */;
  %6 = expand_dims(%5, axis=1) /* ty=Tensor[(64, 1, 1), float32] */;
  ```

**BatchNorm2d's complete relay expression:**

- Before the fix:

  ```
  def @main(%input: Tensor[(1, 64, 112, 112), float32] /* ty=Tensor[(1, 64, 112, 112), float32] span=aten::batch_norm_0.input:0:0 */, %weight: Tensor[(64), float32] /* ty=Tensor[(64), float32] span=aten::batch_norm_0.weight:0:0 */, %bias: Tensor[(64), float32] /* ty=Tensor[(64), float32] span=aten::batch_norm_0.bias:0:0 */, %running_mean: Tensor[(64), float32] /* ty=Tensor[(64), float32] span=aten::batch_norm_0.running_mean:0:0 */, %running_var: Tensor[(64), float32] /* ty=Tensor[(64), float32] span=aten::batch_norm_0.running_var:0:0 */) -> Tensor[(1, 64, 112, 112), float32] {
    %0 = add(%running_var, 1e-05f /* ty=float32 */) /* ty=Tensor[(64), float32] */;
    %1 = sqrt(%0) /* ty=Tensor[(64), float32] */;
    %2 = divide(1f /* ty=float32 */, %1) /* ty=Tensor[(64), float32] */;
    %3 = multiply(%2, %weight) /* ty=Tensor[(64), float32] */;
    %4 = expand_dims(%3, axis=1, num_newaxis=2) /* ty=Tensor[(64, 1, 1), float32] */;
    %5 = negative(%running_mean) /* ty=Tensor[(64), float32] */;
    %6 = multiply(%5, %3) /* ty=Tensor[(64), float32] */;
    %7 = add(%6, %bias) /* ty=Tensor[(64), float32] */;
    %8 = multiply(%input, %4) /* ty=Tensor[(1, 64, 112, 112), float32] */;
    %9 = expand_dims(%7, axis=1, num_newaxis=2) /* ty=Tensor[(64, 1, 1), float32] */;
    add(%8, %9) /* ty=Tensor[(1, 64, 112, 112), float32] */
  }
  ```

- After the fix:

  ```
  def @main(%input: Tensor[(1, 64, 112, 112), float32] /* ty=Tensor[(1, 64, 112, 112), float32] span=aten::batch_norm_0.input:0:0 */, %weight: Tensor[(64), float32] /* ty=Tensor[(64), float32] span=aten::batch_norm_0.weight:0:0 */, %bias: Tensor[(64), float32] /* ty=Tensor[(64), float32] span=aten::batch_norm_0.bias:0:0 */, %running_mean: Tensor[(64), float32] /* ty=Tensor[(64), float32] span=aten::batch_norm_0.running_mean:0:0 */, %running_var: Tensor[(64), float32] /* ty=Tensor[(64), float32] span=aten::batch_norm_0.running_var:0:0 */) -> Tensor[(1, 64, 112, 112), float32] {
    %0 = add(%running_var, 1e-05f /* ty=float32 */) /* ty=Tensor[(64), float32] */;
    %1 = sqrt(%0) /* ty=Tensor[(64), float32] */;
    %2 = reciprocal(%1) /* ty=Tensor[(64), float32] */;
    %3 = multiply(1f /* ty=float32 */, %2) /* ty=Tensor[(64), float32] */;
    %4 = multiply(%3, %weight) /* ty=Tensor[(64), float32] */;
    %5 = expand_dims(%4, axis=1) /* ty=Tensor[(64, 1), float32] */;
    %6 = expand_dims(%5, axis=1) /* ty=Tensor[(64, 1, 1), float32] */;
    %7 = multiply(%running_mean, -1f /* ty=float32 */) /* ty=Tensor[(64), float32] */;
    %8 = multiply(%7, %4) /* ty=Tensor[(64), float32] */;
    %9 = add(%8, %bias) /* ty=Tensor[(64), float32] */;
    %10 = expand_dims(%9, axis=1) /* ty=Tensor[(64, 1), float32] */;
    %11 = multiply(%input, %6) /* ty=Tensor[(1, 64, 112, 112), float32] */;
    %12 = expand_dims(%10, axis=1) /* ty=Tensor[(64, 1, 1), float32] */;
    add(%11, %12) /* ty=Tensor[(1, 64, 112, 112), float32] */
  }
  ```
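The decomposition above can be sketched in NumPy. This is a hypothetical illustration (not the actual tt-forge-fe pass): it shows why one `expand_dims` with `num_newaxis=n` is equivalent to `n` chained single-axis expansions, which is what the fix emits in place of a single unsqueeze.

```python
import numpy as np

def decompose_expand_dims(x, axis, num_newaxis):
    """Emulate expand_dims(x, axis, num_newaxis) as a chain of
    single-axis expansions, one per new dimension (hypothetical helper)."""
    for _ in range(num_newaxis):
        # Each call inserts exactly one size-1 dimension at `axis`,
        # mirroring one unsqueeze in the lowered relay graph.
        x = np.expand_dims(x, axis)
    return x

# (64,) -> (64, 1, 1), matching %4 = expand_dims(%3, axis=1, num_newaxis=2)
scale = np.ones(64, dtype=np.float32)
out = decompose_expand_dims(scale, axis=1, num_newaxis=2)
print(out.shape)  # (64, 1, 1)
```

The chained form then broadcasts against the `(1, 64, 112, 112)` activation exactly as the original single `expand_dims` would.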