
resnet50 compilation failure when changing batch size #716

Open · kevinwuTT opened this issue Jan 22, 2025 · 1 comment
Labels: bug, conversion

@kevinwuTT (Contributor)

If I change the line below to, for example, batch_t = torch.stack([img_t] * 16), I get a compilation failure:

batch_t = torch.unsqueeze(img_t, 0)
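
For context, a minimal sketch of how the script is set up, assuming the resnet50 example in this repo (the exact compile options are elided and illustrative):

import torch
import torchvision

model = torchvision.models.resnet50(weights="DEFAULT").eval()
img_t = torch.rand(3, 224, 224)  # stand-in for the preprocessed input image

batch_t = torch.stack([img_t] * 16)  # the changed line; shape (16, 3, 224, 224)

# Compiled roughly as in the repo's examples (backend/options from memory):
# compiled = torch.compile(model, backend=torch_ttnn.backend, options=...)
# compiled(batch_t)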

Traceback (most recent call last):
  File "/home/ubuntu/virtualenv/torch_ttnn/lib/python3.8/site-packages/torch/fx/passes/infra/pass_manager.py", line 270, in __call__
    res = fn(module)
  File "/home/ubuntu/virtualenv/torch_ttnn/lib/python3.8/site-packages/torch/fx/passes/infra/pass_base.py", line 40, in __call__
    res = self.call(graph_module)
  File "/home/ubuntu/repo/pytorch2.0_ttnn/torch_ttnn/passes/lowering/to_tt_pass.py", line 1337, in call
    gm, modified = ReplaceMoreTtManually(gm, self.use_less_ttnn_op_types)
  File "/home/ubuntu/repo/pytorch2.0_ttnn/torch_ttnn/passes/lowering/to_tt_pass.py", line 1214, in ReplaceMoreTtManually
    new_node = rewrite_node(node)
  File "/home/ubuntu/repo/pytorch2.0_ttnn/torch_ttnn/passes/lowering/to_tt_pass.py", line 1123, in rewrite_node
    return insert_sharded_nxc_to_ncx(g, output_tensor, node.meta["val"].size())
  File "/home/ubuntu/repo/pytorch2.0_ttnn/torch_ttnn/passes/lowering/to_tt_pass.py", line 102, in insert_sharded_nxc_to_ncx
    output_tensor = g.call_function(ttnn.reshape, (output_tensor, target_shape))
  File "/home/ubuntu/repo/pytorch2.0_ttnn/torch_ttnn/passes/lowering/to_tt_pass.py", line 382, in call_function
    new_node.meta["val"] = self._get_output_val(new_node)
  File "/home/ubuntu/repo/pytorch2.0_ttnn/torch_ttnn/passes/lowering/to_tt_pass.py", line 408, in _get_output_val
    return torch.reshape(self._get_val(args[0]), args[1])
  File "/home/ubuntu/virtualenv/torch_ttnn/lib/python3.8/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/home/ubuntu/virtualenv/torch_ttnn/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 1392, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/home/ubuntu/virtualenv/torch_ttnn/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 1712, in dispatch
    r = func(*args, **kwargs)
  File "/home/ubuntu/virtualenv/torch_ttnn/lib/python3.8/site-packages/torch/_ops.py", line 513, in __call__
    return self._op(*args, **(kwargs or {}))
  File "/home/ubuntu/virtualenv/torch_ttnn/lib/python3.8/site-packages/torch/_refs/__init__.py", line 4427, in view
    return _reshape_view_helper(a, *shape, allow_copy=False)
  File "/home/ubuntu/virtualenv/torch_ttnn/lib/python3.8/site-packages/torch/_refs/__init__.py", line 3557, in _reshape_view_helper
    shape = utils.infer_size(shape, a.numel())
  File "/home/ubuntu/virtualenv/torch_ttnn/lib/python3.8/site-packages/torch/_prims_common/__init__.py", line 855, in infer_size
    torch._check(
  File "/home/ubuntu/virtualenv/torch_ttnn/lib/python3.8/site-packages/torch/__init__.py", line 1087, in _check
    _check_with(RuntimeError, cond, message)
  File "/home/ubuntu/virtualenv/torch_ttnn/lib/python3.8/site-packages/torch/__init__.py", line 1070, in _check_with
    raise error_type(message_evaluated)
RuntimeError: shape '[16, 112, 112, 64]' is invalid for input of size 802816

The direct cause is this line:

File "/home/ubuntu/repo/pytorch2.0_ttnn/torch_ttnn/passes/lowering/to_tt_pass.py", line 408, in _get_output_val
    return torch.reshape(self._get_val(args[0]), args[1])

This is trying to reshape a tensor of shape (1, 1, 7168, 112) to (16, 112, 112, 64), which is not compatible. The sequence of ops that leads to this is:

conv = ttnn.conv(...)
ttnn_sharded_to_interleaved = ttnn.sharded_to_interleaved(conv, ...)
ttnn_reshape_1 = ttnn.reshape(ttnn_sharded_to_interleaved, [16, 112, 112, 64])
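
Checking the element counts confirms the mismatch, and the deficit is exactly the batch size:

import math

src = (1, 1, 7168, 112)    # conv output shape carried into the reshape
dst = (16, 112, 112, 64)   # target of ttnn_reshape_1

print(math.prod(src))                     # 802816, the "input of size 802816" in the error
print(math.prod(dst))                     # 12845056
print(math.prod(dst) // math.prod(src))   # 16, exactly the batch size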

ttnn_sharded_to_interleaved inherits the shape of conv directly, so there is something wrong with the metadata for conv.

Because of the many op decompositions and other indirect transformations, the shape metadata is now computed separately for some op lowerings; conv is one of them:

if node.target in [ttnn.max_pool2d, target_wrappers.conv]:
    output_tensor = self._get_val(self.node)[0]
    output_shape = list(output_tensor.size())
    return output_tensor.new_empty((1, 1, output_shape[0] * math.prod(output_shape[2:]), output_shape[1]))

In this case, the N and C dimensions of the computed (N, C, H, W) shape are always 1, regardless of the actual batch size.
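
To illustrate with the formula above (a hand-worked sketch; the second input is a hypothetical metadata shape, not confirmed, that reproduces the observed value): correct NCHW metadata for batch 16 yields an element count that matches the reshape target, while an N=1 shape reproduces the (1, 1, 7168, 112) seen in the traceback.

import math

# The (1, 1, N*H*W, C) formula from the snippet above, applied by hand.
def flattened(shape):
    return (1, 1, shape[0] * math.prod(shape[2:]), shape[1])

print(flattened((16, 64, 112, 112)))  # (1, 1, 200704, 64); numel 12845056 matches the target
print(flattened((1, 112, 112, 64)))   # (1, 1, 7168, 112); matches the shape in the error

So the element count alone pins the conv metadata's batch dimension at 1.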

@ayerofieiev-tt (Member)

Can be closed now?
