Traceback (most recent call last):
File "/home/ubuntu/virtualenv/torch_ttnn/lib/python3.8/site-packages/torch/fx/passes/infra/pass_manager.py", line 270, in __call__
res = fn(module)
File "/home/ubuntu/virtualenv/torch_ttnn/lib/python3.8/site-packages/torch/fx/passes/infra/pass_base.py", line 40, in __call__
res = self.call(graph_module)
File "/home/ubuntu/repo/pytorch2.0_ttnn/torch_ttnn/passes/lowering/to_tt_pass.py", line 1337, in call
gm, modified = ReplaceMoreTtManually(gm, self.use_less_ttnn_op_types)
File "/home/ubuntu/repo/pytorch2.0_ttnn/torch_ttnn/passes/lowering/to_tt_pass.py", line 1214, in ReplaceMoreTtManually
new_node = rewrite_node(node)
File "/home/ubuntu/repo/pytorch2.0_ttnn/torch_ttnn/passes/lowering/to_tt_pass.py", line 1123, in rewrite_node
return insert_sharded_nxc_to_ncx(g, output_tensor, node.meta["val"].size())
File "/home/ubuntu/repo/pytorch2.0_ttnn/torch_ttnn/passes/lowering/to_tt_pass.py", line 102, in insert_sharded_nxc_to_ncx
output_tensor = g.call_function(ttnn.reshape, (output_tensor, target_shape))
File "/home/ubuntu/repo/pytorch2.0_ttnn/torch_ttnn/passes/lowering/to_tt_pass.py", line 382, in call_function
new_node.meta["val"] = self._get_output_val(new_node)
File "/home/ubuntu/repo/pytorch2.0_ttnn/torch_ttnn/passes/lowering/to_tt_pass.py", line 408, in _get_output_val
return torch.reshape(self._get_val(args[0]), args[1])
File "/home/ubuntu/virtualenv/torch_ttnn/lib/python3.8/site-packages/torch/utils/_stats.py", line 20, in wrapper
return fn(*args, **kwargs)
File "/home/ubuntu/virtualenv/torch_ttnn/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 1392, in __torch_dispatch__
return self.dispatch(func, types, args, kwargs)
File "/home/ubuntu/virtualenv/torch_ttnn/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 1712, in dispatch
r = func(*args, **kwargs)
File "/home/ubuntu/virtualenv/torch_ttnn/lib/python3.8/site-packages/torch/_ops.py", line 513, in __call__
return self._op(*args, **(kwargs or {}))
File "/home/ubuntu/virtualenv/torch_ttnn/lib/python3.8/site-packages/torch/_refs/__init__.py", line 4427, in view
return _reshape_view_helper(a, *shape, allow_copy=False)
File "/home/ubuntu/virtualenv/torch_ttnn/lib/python3.8/site-packages/torch/_refs/__init__.py", line 3557, in _reshape_view_helper
shape = utils.infer_size(shape, a.numel())
File "/home/ubuntu/virtualenv/torch_ttnn/lib/python3.8/site-packages/torch/_prims_common/__init__.py", line 855, in infer_size
torch._check(
File "/home/ubuntu/virtualenv/torch_ttnn/lib/python3.8/site-packages/torch/__init__.py", line 1087, in _check
_check_with(RuntimeError, cond, message)
File "/home/ubuntu/virtualenv/torch_ttnn/lib/python3.8/site-packages/torch/__init__.py", line 1070, in _check_with
raise error_type(message_evaluated)
RuntimeError: shape '[16, 112, 112, 64]' is invalid for input of size 802816
The direct cause is this line:
File "/home/ubuntu/repo/pytorch2.0_ttnn/torch_ttnn/passes/lowering/to_tt_pass.py", line 408, in _get_output_val
return torch.reshape(self._get_val(args[0]), args[1])
This tries to reshape a tensor of shape (1, 1, 7168, 112), which has 802,816 elements, to (16, 112, 112, 64), which requires 12,845,056 elements, so the two shapes are incompatible. The sequence of ops that leads to this is:
ttnn_sharded_to_interleaved inherits the shape of conv directly, so there is something wrong with the metadata for conv.
Because there are a lot of op decompositions and other indirect transformations, the shape metadata is now computed separately for some op lowerings. conv is one of those.
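The failure can be reproduced with plain shape arithmetic. The sketch below uses only the shapes from the error message. check_reshape is a hypothetical stand-in for the element-count check that torch.reshape performs (via torch._prims_common.infer_size), not the real implementation:

```python
from math import prod

# Shapes copied from the error report.
conv_meta_shape = (1, 1, 7168, 112)  # metadata attached to the conv node
target_shape = (16, 112, 112, 64)    # (N, H, W, C) requested by insert_sharded_nxc_to_ncx

# ttnn_sharded_to_interleaved only changes the memory layout, so its
# metadata shape is inherited from the conv node unchanged.
interleaved_meta_shape = conv_meta_shape


def check_reshape(src_shape, dst_shape):
    """Hypothetical mimic of torch.reshape's element-count check for fixed shapes."""
    numel = prod(src_shape)
    if prod(dst_shape) != numel:
        raise RuntimeError(
            f"shape '{list(dst_shape)}' is invalid for input of size {numel}"
        )
    return dst_shape


try:
    check_reshape(interleaved_meta_shape, target_shape)
except RuntimeError as err:
    print(err)  # shape '[16, 112, 112, 64]' is invalid for input of size 802816

# The target needs exactly 16x more elements: one batch's worth is missing.
print(prod(target_shape) // prod(conv_meta_shape))  # 16
```

The missing factor is exactly the batch size, which points at the conv metadata dropping the batch dimension rather than at the reshape itself.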
If I change line 27 of pytorch2.0_ttnn/tests/models/resnet50/test_resnet50.py (at commit 95d112e) to
batch_t = torch.stack([img_t] * 16)
for example, I get the compilation failure shown above.
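For reference, the target shape (16, 112, 112, 64) is what the convolution should produce for the batched input. A minimal check with the standard convolution output-size formula, assuming the failing node is ResNet-50's stem conv (7x7 kernel, stride 2, padding 3, 64 output channels) on 224x224 images; these hyperparameters and the input size are assumptions, not taken from the traceback:

```python
def conv2d_out_size(size, kernel, stride=1, padding=0, dilation=1):
    """Output-size formula documented for torch.nn.Conv2d (floor division)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


batch = 16          # from torch.stack([img_t] * 16)
in_hw = 224         # assumed input resolution
out_channels = 64   # assumed stem conv: 7x7, stride 2, padding 3

out_hw = conv2d_out_size(in_hw, kernel=7, stride=2, padding=3)
nchw = (batch, out_channels, out_hw, out_hw)
nhwc = (batch, out_hw, out_hw, out_channels)
print(nchw)  # (16, 64, 112, 112)
print(nhwc)  # (16, 112, 112, 64), the target shape in the error
```

Under these assumptions the reshape target is correct, so the mismatch has to come from the metadata attached to the conv node.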
The relevant code is in pytorch2.0_ttnn/torch_ttnn/passes/lowering/to_tt_pass.py, lines 403 to 406 (at commit 95d112e). In this case, the N and C dimensions in the shape (N, C, H, W) are always 1.
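Keeping N and C at 1 is not a problem by itself, as long as the batch is folded into the remaining dimensions. A sketch of consistent metadata, assuming the ttnn conv output uses a flattened (1, 1, N*H*W, C) layout that insert_sharded_nxc_to_ncx reshapes to (N, H, W, C) and then, judging by its name, permutes to (N, C, H, W). The helpers flattened_conv_meta_shape and nhwc_to_nchw are hypothetical and not taken from to_tt_pass.py:

```python
from math import prod


def flattened_conv_meta_shape(n, h_out, w_out, c_out):
    """Hypothetical helper: metadata shape for a conv output stored as (1, 1, N*H*W, C)."""
    # The leading two dims stay 1; batch and spatial dims are folded together.
    return (1, 1, n * h_out * w_out, c_out)


def nhwc_to_nchw(shape):
    """Shape-only version of an (N, H, W, C) -> (N, C, H, W) permute."""
    n, h, w, c = shape
    return (n, c, h, w)


meta = flattened_conv_meta_shape(16, 112, 112, 64)
nhwc = (16, 112, 112, 64)
# With the batch folded in, the reshape would pass the element-count check.
assert prod(meta) == prod(nhwc)
print(meta)                # (1, 1, 200704, 64)
print(nhwc_to_nchw(nhwc))  # (16, 64, 112, 112)
```

By contrast, the reported metadata (1, 1, 7168, 112) holds only 802,816 = 112 x 112 x 64 elements, i.e. a single image's worth.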