Skip to content

Commit

Permalink
Fix accidentally disabled async mode for some t3k tests (#18381)
Browse files Browse the repository at this point in the history
### Ticket
#18360

### Problem description
Recently we disabled async mode for single device, by ignoring
enable_async call for it, assuming multi-device customers make a call to
MeshDevice enable_async. However in some places including our test we
actually iterate over each individual device in the mesh and call
enable_async on it, which is being ignored

### What's changed
Make a single call to MeshDevice::enable_async instead of iterating over
individual devices and calling Device::enable_async for each one of them

### Checklist
- [ ] [All post commit CI
passes](https://github.com/tenstorrent/tt-metal/actions/runs/13553947437)
- [x] [T3K demo tests CI
passes](https://github.com/tenstorrent/tt-metal/actions/runs/13553950838)
- [x] New/Existing tests provide coverage for changes
  • Loading branch information
sminakov-tt authored Feb 27, 2025
1 parent 45186bc commit 69a36b8
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 9 deletions.
8 changes: 4 additions & 4 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,13 +342,13 @@ def get_devices(request):
elif "pcie_devices" in request.fixturenames:
devices = request.getfixturevalue("pcie_devices")
elif "mesh_device" in request.fixturenames:
devices = request.getfixturevalue("mesh_device").get_devices()
devices = [request.getfixturevalue("mesh_device")]
elif "n300_mesh_device" in request.fixturenames:
devices = request.getfixturevalue("n300_mesh_device").get_devices()
devices = [request.getfixturevalue("n300_mesh_device")]
elif "t3k_mesh_device" in request.fixturenames:
devices = request.getfixturevalue("t3k_mesh_device").get_devices()
devices = [request.getfixturevalue("t3k_mesh_device")]
elif "pcie_mesh_device" in request.fixturenames:
devices = request.getfixturevalue("pcie_mesh_device").get_devices()
devices = [request.getfixturevalue("pcie_mesh_device")]
else:
devices = []
return devices
Expand Down
3 changes: 1 addition & 2 deletions models/demos/t3000/falcon40b/tests/test_falcon_end_to_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,8 +526,7 @@ def test_FalconCausalLM_end_to_end_with_program_cache(
model_config = get_model_config(model_config_str, llm_mode, input_shape, num_devices)
devices = t3k_mesh_device.get_devices()
# Set async mode
for device in devices:
device.enable_async(async_mode)
t3k_mesh_device.enable_async(async_mode)
compute_grid_size = devices[0].compute_with_storage_grid_size()
if compute_grid_size.x < model_config["MAX_GRID_SIZE"][0] or compute_grid_size.y < model_config["MAX_GRID_SIZE"][1]:
pytest.skip(f"Requires grid size of at least {model_config['MAX_GRID_SIZE']} to run")
Expand Down
4 changes: 1 addition & 3 deletions tests/ttnn/distributed/test_multidevice_TG.py
Original file line number Diff line number Diff line change
Expand Up @@ -1448,9 +1448,7 @@ def test_device_line_all_gather_8x4_data(mesh_device, cluster_axis: int, dim: in
- Every device will have the shape: [4, 1, 32, 32]
"""
if async_mode:
for i in mesh_device.get_device_ids():
device = mesh_device.get_device(i)
device.enable_async(True)
mesh_device.enable_async(True)

(rows, cols), tile_size = mesh_device.shape, 32
full_tensor = torch.zeros((1, 1, tile_size * rows, tile_size * cols), dtype=torch.bfloat16)
Expand Down

0 comments on commit 69a36b8

Please sign in to comment.