From 69a36b8d74d3f59b6309a3a224ab9a6c24d4249e Mon Sep 17 00:00:00 2001 From: Stanislav Minakov Date: Thu, 27 Feb 2025 07:22:18 +0000 Subject: [PATCH] Fix accidentally disabled async mode for some t3k tests (#18381) ### Ticket https://github.com/tenstorrent/tt-metal/issues/18360 ### Problem description Recently we disabled async mode for single device, by ignoring enable_async call for it, assuming multi-device customers make a call to MeshDevice enable_async. However in some places including our test we actually iterate over each individual device in the mesh and call enable_async on it, which is being ignored ### What's changed Make a single call to MeshDevice::enable_async instead of iterating over individual devices and calling Device::enable_async for each one of them ### Checklist - [ ] [All post commit CI passes](https://github.com/tenstorrent/tt-metal/actions/runs/13553947437) - [x] [T3K demo tests CI passes](https://github.com/tenstorrent/tt-metal/actions/runs/13553950838) - [x] New/Existing tests provide coverage for changes --- conftest.py | 8 ++++---- .../demos/t3000/falcon40b/tests/test_falcon_end_to_end.py | 3 +-- tests/ttnn/distributed/test_multidevice_TG.py | 4 +--- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/conftest.py b/conftest.py index 9e94913a18f..05cd3c67536 100644 --- a/conftest.py +++ b/conftest.py @@ -342,13 +342,13 @@ def get_devices(request): elif "pcie_devices" in request.fixturenames: devices = request.getfixturevalue("pcie_devices") elif "mesh_device" in request.fixturenames: - devices = request.getfixturevalue("mesh_device").get_devices() + devices = [request.getfixturevalue("mesh_device")] elif "n300_mesh_device" in request.fixturenames: - devices = request.getfixturevalue("n300_mesh_device").get_devices() + devices = [request.getfixturevalue("n300_mesh_device")] elif "t3k_mesh_device" in request.fixturenames: - devices = request.getfixturevalue("t3k_mesh_device").get_devices() + devices = [request.getfixturevalue("t3k_mesh_device")] elif "pcie_mesh_device" in request.fixturenames: - devices = request.getfixturevalue("pcie_mesh_device").get_devices() + devices = [request.getfixturevalue("pcie_mesh_device")] else: devices = [] return devices diff --git a/models/demos/t3000/falcon40b/tests/test_falcon_end_to_end.py b/models/demos/t3000/falcon40b/tests/test_falcon_end_to_end.py index c686d2dda7e..2ea829068a1 100644 --- a/models/demos/t3000/falcon40b/tests/test_falcon_end_to_end.py +++ b/models/demos/t3000/falcon40b/tests/test_falcon_end_to_end.py @@ -526,8 +526,7 @@ def test_FalconCausalLM_end_to_end_with_program_cache( model_config = get_model_config(model_config_str, llm_mode, input_shape, num_devices) devices = t3k_mesh_device.get_devices() # Set async mode - for device in devices: - device.enable_async(async_mode) + t3k_mesh_device.enable_async(async_mode) compute_grid_size = devices[0].compute_with_storage_grid_size() if compute_grid_size.x < model_config["MAX_GRID_SIZE"][0] or compute_grid_size.y < model_config["MAX_GRID_SIZE"][1]: pytest.skip(f"Requires grid size of at least {model_config['MAX_GRID_SIZE']} to run") diff --git a/tests/ttnn/distributed/test_multidevice_TG.py b/tests/ttnn/distributed/test_multidevice_TG.py index 6c1c84c5dd9..82b4381c4aa 100644 --- a/tests/ttnn/distributed/test_multidevice_TG.py +++ b/tests/ttnn/distributed/test_multidevice_TG.py @@ -1448,9 +1448,7 @@ def test_device_line_all_gather_8x4_data(mesh_device, cluster_axis: int, dim: in - Every device will have the shape: [4, 1, 32, 32] """ if async_mode: - for i in mesh_device.get_device_ids(): - device = mesh_device.get_device(i) - device.enable_async(True) + mesh_device.enable_async(True) (rows, cols), tile_size = mesh_device.shape, 32 full_tensor = torch.zeros((1, 1, tile_size * rows, tile_size * cols), dtype=torch.bfloat16)