Skip to content

Commit

Permalink
Remove OOM guard for aten.max_pool2d_with_indices.default (#677)
Browse files Browse the repository at this point in the history
* Remove OOM guard for max_pool2d

* Update max_pool2d tests
  • Loading branch information
jerrysky3 authored Dec 26, 2024
1 parent 51e2076 commit 97683cb
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 13 deletions.
10 changes: 1 addition & 9 deletions tests/lowering/pool/test_max_pool_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def forward(self, *args, **kwargs):
((1, 128, 112, 112), (2, 2), 2, 0, 1, False),
((1, 512, 19, 19), (3, 3), 1, 1, 1, False),
((1, 192, 28, 28), (3, 3), 1, 1, 1, False),
((1, 64, 360, 640), (3, 3), 2, 1, 1, False),
pytest.param(
(1, 320, 28, 28),
(3, 3),
Expand All @@ -38,15 +39,6 @@ def forward(self, *args, **kwargs):
True,
marks=pytest.mark.xfail(reason="ceil_mode=True is not supported yet (tt-metal#14976)"),
),
pytest.param(
(1, 64, 360, 640),
(3, 3),
2,
1,
1,
False,
marks=pytest.mark.xfail(reason="OOM (#385)"),
),
pytest.param(
(1, 4, 14, 14),
(2, 2),
Expand Down
4 changes: 0 additions & 4 deletions torch_ttnn/passes/lowering/to_tt_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,13 +994,9 @@ def reshape_1d(code, args=args, kwargs=kwargs):
padding = params.get("padding", (0, 0))
dilation = params.get("dilation", (1, 1))
ceil_mode = params.get("ceil_mode", False)
# Assume the element size is bfloat16
volume = (batch_size * in_c * in_h * in_w) * 2
if (
# TODO(tt-metal#14976): ceil mode isn't supported yet
ceil_mode
# TODO(#385): OOM
or volume > 16 * 1024 * 1024
# TODO(tt-metal#13901): Wide input channels can only be multiple of 8 tiles
or (in_c > (ttnn.TILE_SIZE * 8) and in_c % (ttnn.TILE_SIZE * 8) != 0)
# TODO(#419): Currently fails with in_c < 16
Expand Down

0 comments on commit 97683cb

Please sign in to comment.