From ad216ae16af8823f8c03e25b22b159bcae74c8dd Mon Sep 17 00:00:00 2001 From: Joseph Chu Date: Fri, 31 Jan 2025 09:51:15 +0000 Subject: [PATCH] #0: Simplify MeshDevice construction by eradicating mesh_type --- conftest.py | 6 ++-- models/MODEL_HYBRID_TP_DP.md | 2 +- .../tests/test_llama_perf_decode.py | 2 +- tech_reports/LLMs/llms.md | 2 +- .../Programming_Mesh_of_Devices_with_TT-NN.md | 8 +++--- .../TT-Distributed-Architecture-1219.md | 8 ++---- .../sweeps/ccl/line_all_gather.py | 2 +- .../distributed/test_distributed_reshape.cpp | 11 ++------ .../test_tensor_parallel_example_T3000.py | 1 - .../unit_tests/gtests/ttnn_test_fixtures.hpp | 2 +- tests/ttnn/unit_tests/test_multi_device.py | 12 ++++++-- tt-train/sources/ttml/core/mesh_device.cpp | 3 +- tt_metal/api/tt-metalium/mesh_config.hpp | 2 -- tt_metal/api/tt-metalium/mesh_device.hpp | 13 +++------ tt_metal/api/tt-metalium/mesh_device_view.hpp | 2 +- tt_metal/distributed/mesh_device.cpp | 28 +++++++------------ tt_metal/distributed/mesh_device_view.cpp | 11 ++------ ttnn/cpp/ttnn/distributed/api.cpp | 6 ++-- ttnn/cpp/ttnn/distributed/api.hpp | 1 - .../ttnn/distributed/distributed_pybind.cpp | 17 ++--------- ttnn/cpp/ttnn/distributed/types.hpp | 2 -- .../operations/ccl/barrier/barrier_pybind.cpp | 2 +- ttnn/ttnn/distributed/__init__.py | 1 - ttnn/ttnn/distributed/distributed.py | 4 --- 24 files changed, 48 insertions(+), 100 deletions(-) diff --git a/conftest.py b/conftest.py index f965b25e257..7e60cfe6fed 100644 --- a/conftest.py +++ b/conftest.py @@ -254,10 +254,9 @@ def pcie_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, devic updated_device_params = get_updated_device_params(device_params) mesh_device = ttnn.open_mesh_device( - mesh_shape=ttnn.MeshShape(2, 2), + mesh_shape=ttnn.MeshShape(1, num_pcie_devices_requested), **updated_device_params, - offset=ttnn.MeshOffset(0, 1), - mesh_type=ttnn.MeshType.Ring, + physical_device_ids=device_ids[:num_pcie_devices_requested], ) logger.debug(f"multidevice with {mesh_device.get_num_devices()} devices is created") @@ -305,7 +304,6 @@ def t3k_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, device mesh_device = ttnn.open_mesh_device( mesh_shape=ttnn.MeshShape(2, 4), **updated_device_params, - mesh_type=ttnn.MeshType.Ring, ) logger.debug(f"multidevice with {mesh_device.get_num_devices()} devices is created") diff --git a/models/MODEL_HYBRID_TP_DP.md b/models/MODEL_HYBRID_TP_DP.md index 299cfdc369c..148b8e05b13 100644 --- a/models/MODEL_HYBRID_TP_DP.md +++ b/models/MODEL_HYBRID_TP_DP.md @@ -16,7 +16,7 @@ The main changes involve: ```python # Work with submesh device as you would with a regular ttnn.MeshDevice - submesh_devices: List[ttnn.MeshDevice] = mesh_device.create_submeshes((2, 4), ttnn.MeshType.Ring) + submesh_devices: List[ttnn.MeshDevice] = mesh_device.create_submeshes((2, 4)) ``` ### 2. 
Compile & Run the Model on Each Submesh diff --git a/models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py b/models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py index 1a8d8e2a3eb..e4fec61ceef 100644 --- a/models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py +++ b/models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py @@ -296,7 +296,7 @@ def run_test_LlamaModel_end_to_end_hybrid_data_tensor_parallel( profiler.clear() submesh_to_metadata = defaultdict(dict) - submeshes = mesh_device.create_submeshes((2, 4), ttnn.MeshType.Ring) + submeshes = mesh_device.create_submeshes((2, 4)) for submesh in submeshes: # Set up model ----------------------------------------------------------------------- logger.info("Moving weights to devices; might take some time...") diff --git a/tech_reports/LLMs/llms.md b/tech_reports/LLMs/llms.md index db07c296c59..12abeb815b7 100644 --- a/tech_reports/LLMs/llms.md +++ b/tech_reports/LLMs/llms.md @@ -1195,7 +1195,7 @@ Below is a summary and example code of the most important concepts for mapping a import ttnn # 2x4 mesh_device, Topology Ring: devices are connected in a ring -mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4), mesh_type=ttnn.MeshType.Ring) +mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4)) # Construct initial torch tensor torch_tensor = torch.rand((1,1,32,256), dtype=torch.bfloat16) diff --git a/tech_reports/Programming_Mesh_of_Devices/Programming_Mesh_of_Devices_with_TT-NN.md b/tech_reports/Programming_Mesh_of_Devices/Programming_Mesh_of_Devices_with_TT-NN.md index fc38a435373..862921f5d33 100644 --- a/tech_reports/Programming_Mesh_of_Devices/Programming_Mesh_of_Devices_with_TT-NN.md +++ b/tech_reports/Programming_Mesh_of_Devices/Programming_Mesh_of_Devices_with_TT-NN.md @@ -296,7 +296,7 @@ Let's see an example of how to use the Ring All-Gather operation: ```py import ttnn -mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4), mesh_type=ttnn.MeshType.Ring) +mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4)) # Construct test tensor of data; 8 chunks of 32x32 torch_tensor = torch.rand((1,1,32,256), dtype=torch.bfloat16) @@ -328,7 +328,7 @@ The result tensor for each device in the column is the concatenation in `dim=3` ```py import ttnn -mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4), mesh_type=ttnn.MeshType.Ring) +mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4)) # Construct test tensor of data; 8 chunks of 32x32 torch_tensor = torch.rand((1,1,32,256), dtype=torch.bfloat16) @@ -534,7 +534,7 @@ torch_hidden_states = (torch.rand(batch_size, 1, sequence_length, config.hidden_ torch_output = model.forward(torch_hidden_states) # Device Initialization -mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2,4), mesh_type=ttnn.MeshType.Ring) +mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2,4)) # Initialize input activations on all devices in the mesh # Alternatively, we can shard the input activations on the height dimension and @@ -602,7 +602,7 @@ See `models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py::test_Llama_p 1. Submesh Creation ```py - submesh_devices: List[ttnn.MeshDevice] = mesh_device.create_submeshes((2, 4), ttnn.MeshType.Ring) + submesh_devices: List[ttnn.MeshDevice] = mesh_device.create_submeshes((2, 4)) ``` 2. 
Compile & Run the Model on Each Submesh
diff --git a/tech_reports/TT-Distributed/TT-Distributed-Architecture-1219.md b/tech_reports/TT-Distributed/TT-Distributed-Architecture-1219.md
index a15f6dee9da..869d52930df 100644
--- a/tech_reports/TT-Distributed/TT-Distributed-Architecture-1219.md
+++ b/tech_reports/TT-Distributed/TT-Distributed-Architecture-1219.md
@@ -223,10 +223,6 @@ struct MeshConfig {
 
   // Offset into Logical Device Coordinate Space
   MeshOffset offset;
-
-  // TODO: consider whether this should be automatically inferred.
-  // Interpret as e.g. {Ring, Line}
-  MeshType type;
 };
 
 // Class exposing host and device dispatch state
@@ -986,8 +982,8 @@ Below, we include snippets from both the TT-Mesh and TT-Metal examples to illust
 *Specify MeshConfig when creating two Virtual Meshes on a Physical Mesh.*
 
 ```cpp
-MeshConfig mesh_config_0 = MeshConfig{.shape = virtual_mesh_shape, .offset = {0, 0}, .type=mesh_type};
-MeshConfig mesh_config_1 = MeshConfig{.shape = virtual_mesh_shape, .offset = {0, 4}, .type=mesh_type};
+MeshConfig mesh_config_0 = MeshConfig{.shape = virtual_mesh_shape, .offset = {0, 0}};
+MeshConfig mesh_config_1 = MeshConfig{.shape = virtual_mesh_shape, .offset = {0, 4}};
 
 DeviceHandle virtual_mesh_0 = CreateMeshDevice(mesh_config_0, 2 /* num_command_queues */, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE);
 DeviceHandle virtual_mesh_1 = CreateMeshDevice(mesh_config_1, 2 /* num_command_queues */, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE);
diff --git a/tests/sweep_framework/sweeps/ccl/line_all_gather.py b/tests/sweep_framework/sweeps/ccl/line_all_gather.py
index bdd30d3393c..b30cd0f9f1e 100644
--- a/tests/sweep_framework/sweeps/ccl/line_all_gather.py
+++ b/tests/sweep_framework/sweeps/ccl/line_all_gather.py
@@ -66,7 +66,7 @@ def mesh_device_fixture():
     assert ttnn.get_num_devices() >= 8, "Not T3000!"
     device_ids = ttnn.get_t3k_physical_device_ids_ring()
     num_devices_requested = len(device_ids)
-    mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(1, num_devices_requested), mesh_type=ttnn.MeshType.Line)
+    mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(1, num_devices_requested))
     print("ALL GATHER: Opened device mesh")
 
     yield (mesh_device, "T3000 Mesh")
diff --git a/tests/ttnn/distributed/test_distributed_reshape.cpp b/tests/ttnn/distributed/test_distributed_reshape.cpp
index fd218327348..a537c45eb8a 100644
--- a/tests/ttnn/distributed/test_distributed_reshape.cpp
+++ b/tests/ttnn/distributed/test_distributed_reshape.cpp
@@ -27,7 +27,7 @@ void check_test_environment() {
 
 std::vector<chip_id_t> get_physical_device_ids(const MeshDevice& mesh) {
     std::vector<chip_id_t> device_ids;
-    for (auto* device : mesh.get_devices(ttnn::distributed::MeshType::RowMajor)) {
+    for (auto* device : mesh.get_devices()) {
         device_ids.push_back(device->id());
     }
     return device_ids;
@@ -138,12 +138,7 @@ TEST_F(T3000ReshapeTest, From1x8To2x4) {
 
 TEST_F(T3000ReshapeTest, OnRingTopology) {
     auto mesh = ttnn::distributed::open_mesh_device(
-        {1, 8},
-        DEFAULT_L1_SMALL_SIZE,
-        DEFAULT_TRACE_REGION_SIZE,
-        1,
-        tt::tt_metal::DispatchCoreType::WORKER,
-        ttnn::distributed::MeshType::Ring);
+        {1, 8}, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, tt::tt_metal::DispatchCoreType::WORKER);
 
     EXPECT_EQ(mesh->num_rows(), 1);
     EXPECT_EQ(mesh->num_cols(), 8);
@@ -228,7 +223,6 @@ TEST_F(T3000ReshapeTest, From1x4To2x2Valid) {
     // Fetch the device ids for a physically connected 2x2 mesh.
     auto physical_device_ids = system_mesh.get_mapped_physical_device_ids(MeshDeviceConfig{
         .mesh_shape = MeshShape{2, 2},
-        .mesh_type = ttnn::distributed::MeshType::Line,
     });
 
     // Supply the physical device ids to the mesh constructor that we know is 2x2 physically connected.
@@ -239,7 +233,6 @@ TEST_F(T3000ReshapeTest, From1x4To2x2Valid) {
         DEFAULT_TRACE_REGION_SIZE,
         1,
         tt::tt_metal::DispatchCoreType::WORKER,
-        ttnn::distributed::MeshType::Line,
         MeshOffset{0, 0},
         physical_device_ids);
 
diff --git a/tests/ttnn/distributed/test_tensor_parallel_example_T3000.py b/tests/ttnn/distributed/test_tensor_parallel_example_T3000.py
index 673646a950e..65e1c18dc7a 100644
--- a/tests/ttnn/distributed/test_tensor_parallel_example_T3000.py
+++ b/tests/ttnn/distributed/test_tensor_parallel_example_T3000.py
@@ -36,7 +36,6 @@ def test_tensor_parallel_falcon_mlp():
 
     mesh_device = ttnn.open_mesh_device(
         ttnn.MeshShape(2, 4),
-        mesh_type=ttnn.MeshType.Ring,
     )
 
     # Set PyTorch seed for reproducibility
diff --git a/tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp b/tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp
index ac0951e1975..c4ad28babc8 100644
--- a/tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp
+++ b/tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp
@@ -67,7 +67,7 @@ class T3kMultiDeviceFixture : public ::testing::Test {
         if (num_devices < 8 or arch != tt::ARCH::WORMHOLE_B0) {
             GTEST_SKIP() << "Skipping T3K Multi-Device test suite on non T3K machine.";
         }
-        mesh_device_ = MeshDevice::create(MeshDeviceConfig{.mesh_shape = MeshShape{2, 4}, .mesh_type = MeshType::Ring});
+        mesh_device_ = MeshDevice::create(MeshDeviceConfig{.mesh_shape = MeshShape{2, 4}});
     }
 
     void TearDown() override {
diff --git a/tests/ttnn/unit_tests/test_multi_device.py b/tests/ttnn/unit_tests/test_multi_device.py
index a3a1e491824..f1c82cf8409 100644
--- a/tests/ttnn/unit_tests/test_multi_device.py
+++ b/tests/ttnn/unit_tests/test_multi_device.py
@@ -672,22 +672,28 @@ def test_visualize_mesh_device(t3k_mesh_device):
     ttnn.visualize_mesh_device(t3k_mesh_device)
 
 
-def test_all_gather_multiple_submeshes(t3k_mesh_device):
+def test_all_gather_multiple_submeshes():
     """Test all_gather with multiple submeshes"""
+    mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4))
+    ttnn.visualize_mesh_device(mesh_device)
+
     def model(submesh):
         full_tensor = torch.ones((1, 1, 32, 32 * submesh.get_num_devices()), dtype=torch.bfloat16)
         for i in range(submesh.get_num_devices()):
             full_tensor[..., i * 32 : (i + 1) * 32] = i
 
+        for device in submesh.get_devices():
+            print(device.id())
+
         ttnn_tensor = ttnn.from_torch(full_tensor, mesh_mapper=ShardTensorToMesh(submesh, dim=3))
         ttnn_tensor = ttnn.to_device(ttnn_tensor, submesh)
-        ttnn_tensor = ttnn.all_gather(ttnn_tensor, dim=3, num_links=1)
+        ttnn_tensor = ttnn.all_gather(ttnn_tensor, dim=3, num_links=1, topology=ttnn.Topology.Ring)
 
         for device_tensor in ttnn.get_device_tensors(ttnn_tensor):
             device_tensor_torch = ttnn.to_torch(device_tensor)
             assert torch.all(device_tensor_torch == full_tensor)
 
-    submesh_devices = t3k_mesh_device.create_submeshes(ttnn.MeshShape(2, 2), ttnn.MeshType.Ring)
+    submesh_devices = mesh_device.create_submeshes(ttnn.MeshShape(2, 2))
     for submesh in submesh_devices:
         model(submesh)
 
diff --git a/tt-train/sources/ttml/core/mesh_device.cpp b/tt-train/sources/ttml/core/mesh_device.cpp
index b6fff694bf0..079604b0d9b 100644
--- a/tt-train/sources/ttml/core/mesh_device.cpp
+++ b/tt-train/sources/ttml/core/mesh_device.cpp
@@ -12,8 +12,7 @@ MeshDevice::MeshDevice(tt::tt_metal::distributed::MeshShape shape) :
         DEFAULT_L1_SMALL_SIZE,
         DEFAULT_TRACE_REGION_SIZE,
         /* num_command_queues*/ 1,
-        DispatchCoreConfig{},
-        ttnn::distributed::MeshType::RowMajor)) {
+        DispatchCoreConfig{})) {
     assert(m_mesh_device);
 }
 
diff --git a/tt_metal/api/tt-metalium/mesh_config.hpp b/tt_metal/api/tt-metalium/mesh_config.hpp
index b04e16e1649..a37111f076e 100644
--- a/tt_metal/api/tt-metalium/mesh_config.hpp
+++ b/tt_metal/api/tt-metalium/mesh_config.hpp
@@ -36,13 +36,11 @@ struct MeshShape {
  *
  * - Line: Devices are arranged linearly in a single dimension.
  */
-enum class MeshType { RowMajor, Ring, Line };
 
 struct MeshDeviceConfig {
     MeshShape mesh_shape{0, 0};
     MeshOffset offset{0, 0};
     std::vector<chip_id_t> physical_device_ids{};
-    MeshType mesh_type{MeshType::RowMajor};
 };
 
 }  // namespace tt::tt_metal::distributed
diff --git a/tt_metal/api/tt-metalium/mesh_device.hpp b/tt_metal/api/tt-metalium/mesh_device.hpp
index 11d72846a7f..0f4a6aa5e4c 100644
--- a/tt_metal/api/tt-metalium/mesh_device.hpp
+++ b/tt_metal/api/tt-metalium/mesh_device.hpp
@@ -55,7 +55,6 @@ class MeshDevice : public IDevice, public std::enable_shared_from_this<MeshDevice> {
     std::shared_ptr<ScopedDevices> scoped_devices_;
     MeshDeviceID mesh_id_;
     MeshShape mesh_shape_;
-    MeshType type_;
     std::unique_ptr<MeshDeviceView> view_;
     std::vector<std::shared_ptr<MeshDevice>> submeshes_;  // Parent owns submeshes and is responsible for their destruction
@@ -71,7 +70,6 @@ class MeshDevice : public IDevice, public std::enable_shared_from_this<MeshDevice> {
     MeshDevice(
         std::shared_ptr<ScopedDevices> mesh_handle,
         const MeshShape& mesh_shape,
-        MeshType type,
        std::weak_ptr<MeshDevice> parent_mesh = {});
     ~MeshDevice() override;
@@ -200,9 +198,9 @@ class MeshDevice : public IDevice, public std::enable_shared_from_this<MeshDevice> {
-    std::vector<IDevice*> get_devices(const std::optional<MeshType>& type = std::nullopt) const;
+    // Returns the devices in the mesh in row-major order.
+    std::vector<IDevice*> get_devices() const;
     IDevice* get_device_index(size_t logical_device_id) const;
     IDevice* get_device(chip_id_t physical_device_id) const;
     IDevice* get_device(size_t row_idx, size_t col_idx) const;
@@ -238,12 +236,9 @@ class MeshDevice : public IDevice, public std::enable_shared_from_this<MeshDevice> {
     std::vector<std::shared_ptr<MeshDevice>> get_submeshes() const;
 
     std::shared_ptr<MeshDevice> create_submesh(
-        const MeshShape& submesh_shape,
-        const MeshOffset& offset = MeshOffset{0, 0},
-        MeshType type = MeshType::RowMajor);
+        const MeshShape& submesh_shape, const MeshOffset& offset = MeshOffset{0, 0});
 
-    std::vector<std::shared_ptr<MeshDevice>> create_submeshes(
-        const MeshShape& submesh_shape, MeshType type = MeshType::RowMajor);
+    std::vector<std::shared_ptr<MeshDevice>> create_submeshes(const MeshShape& submesh_shape);
 
     // These methods will get removed once in favour of the ones in IDevice* and TT-Mesh bringup
     // These are prefixed with "mesh_" to avoid conflicts with the IDevice* methods
diff --git a/tt_metal/api/tt-metalium/mesh_device_view.hpp b/tt_metal/api/tt-metalium/mesh_device_view.hpp
index 4db29921404..224a37598bb 100644
--- a/tt_metal/api/tt-metalium/mesh_device_view.hpp
+++ b/tt_metal/api/tt-metalium/mesh_device_view.hpp
@@ -74,7 +74,7 @@ class MeshDeviceView {
     // devices are returned in row-major order with start/end coordinates inclusive
     [[nodiscard]] DeviceView get_devices(const Coordinate& start, const Coordinate& end) const;
     [[nodiscard]] DeviceView get_devices(const MeshShape& submesh_shape) const;
-    [[nodiscard]] DeviceView get_devices(MeshType type = MeshType::RowMajor) const;
+    [[nodiscard]] DeviceView get_devices() const;
 
     [[nodiscard]] DeviceView get_devices_on_row(size_t row) const;
     [[nodiscard]] DeviceView get_devices_on_column(size_t col) const;
diff --git a/tt_metal/distributed/mesh_device.cpp b/tt_metal/distributed/mesh_device.cpp
index 518a8c4455f..4f866c3a0a9 100644
--- a/tt_metal/distributed/mesh_device.cpp
+++ b/tt_metal/distributed/mesh_device.cpp
@@ -119,16 +119,11 @@ uint32_t MeshDevice::dram_size_per_channel() const {
 
 IDevice* MeshDevice::reference_device() const { return this->get_devices().at(0); }
 
 MeshDevice::MeshDevice(
-    std::shared_ptr<ScopedDevices> mesh_handle,
-    const MeshShape& mesh_shape,
-    MeshType type,
-    std::weak_ptr<MeshDevice> parent_mesh) :
+    std::shared_ptr<ScopedDevices> mesh_handle, const MeshShape& mesh_shape, std::weak_ptr<MeshDevice> parent_mesh) :
     scoped_devices_(std::move(mesh_handle)),
     mesh_shape_(mesh_shape),
-    type_(type),
     mesh_id_(generate_unique_mesh_id()),
-    parent_mesh_(std::move(parent_mesh))
-{
+    parent_mesh_(std::move(parent_mesh)) {
     work_executor_ = std::make_unique<WorkExecutor>(0 /* worker_core */, mesh_id_);
     work_executor_->initialize();
     work_executor_->set_worker_mode(WorkExecutorMode::SYNCHRONOUS);
@@ -142,16 +137,15 @@ std::shared_ptr<MeshDevice> MeshDevice::create(
     const DispatchCoreConfig& dispatch_core_config,
     tt::stl::Span<const std::uint32_t> l1_bank_remap) {
     auto mesh_device = std::make_shared<MeshDevice>(
-        std::make_shared<ScopedDevices>(l1_small_size, trace_region_size, num_command_queues, dispatch_core_config, config),
-        config.mesh_shape,
-        config.mesh_type);
+        std::make_shared<ScopedDevices>(
+            l1_small_size, trace_region_size, num_command_queues, dispatch_core_config, config),
+        config.mesh_shape);
 
     mesh_device->initialize(num_command_queues, l1_small_size, trace_region_size, l1_bank_remap);
     return mesh_device;
 }
 
-std::shared_ptr<MeshDevice> MeshDevice::create_submesh(
-    const MeshShape& submesh_shape, const MeshOffset& offset, MeshType type) {
+std::shared_ptr<MeshDevice> MeshDevice::create_submesh(const MeshShape& submesh_shape, const MeshOffset& offset) {
     if (submesh_shape.num_rows <= 0 || submesh_shape.num_cols <= 0) {
         TT_THROW(
             "Invalid submesh shape: ({}, {}). Both dimensions must be positive.",
Both dimensions must be positive.", @@ -175,7 +169,7 @@ std::shared_ptr MeshDevice::create_submesh( mesh_shape_.num_cols); } - auto submesh = std::make_shared(scoped_devices_, submesh_shape, type, shared_from_this()); + auto submesh = std::make_shared(scoped_devices_, submesh_shape, shared_from_this()); auto start_coordinate = Coordinate{offset.row, offset.col}; auto end_coordinate = Coordinate{offset.row + submesh_shape.num_rows - 1, offset.col + submesh_shape.num_cols - 1}; @@ -196,11 +190,11 @@ std::shared_ptr MeshDevice::create_submesh( return submesh; } -std::vector> MeshDevice::create_submeshes(const MeshShape& submesh_shape, MeshType type) { +std::vector> MeshDevice::create_submeshes(const MeshShape& submesh_shape) { std::vector> submeshes; for (int row = 0; row < this->num_rows(); row += submesh_shape.num_rows) { for (int col = 0; col < this->num_cols(); col += submesh_shape.num_cols) { - auto submesh = this->create_submesh(submesh_shape, MeshOffset{row, col}, type); + auto submesh = this->create_submesh(submesh_shape, MeshOffset{row, col}); submeshes.push_back(submesh); } } @@ -224,9 +218,7 @@ IDevice* MeshDevice::get_device(chip_id_t physical_device_id) const { TT_THROW("Physical Device ID: {} not found in assigned devices", physical_device_id); } -std::vector MeshDevice::get_devices(const std::optional& requested_type) const { - return view_->get_devices(requested_type.value_or(type_)); -} +std::vector MeshDevice::get_devices() const { return view_->get_devices(); } // TODO: Remove this function once we have a proper view interface IDevice* MeshDevice::get_device(size_t row_idx, size_t col_idx) const { diff --git a/tt_metal/distributed/mesh_device_view.cpp b/tt_metal/distributed/mesh_device_view.cpp index 3888fbf1e4f..883b9a38ebb 100644 --- a/tt_metal/distributed/mesh_device_view.cpp +++ b/tt_metal/distributed/mesh_device_view.cpp @@ -39,7 +39,7 @@ MeshDeviceView::MeshDeviceView(const std::vector& devices, Coordinate } MeshDeviceView::MeshDeviceView(const MeshDevice& mesh_device) : - MeshDeviceView(mesh_device.get_devices(MeshType::RowMajor), mesh_device.shape()) {} + MeshDeviceView(mesh_device.get_devices(), mesh_device.shape()) {} MeshDeviceView::MeshDeviceView(const std::vector& devices, const MeshShape& shape) : MeshDeviceView(devices, Coordinate{0, 0}, Coordinate{shape.num_rows - 1, shape.num_cols - 1}) {} @@ -261,13 +261,6 @@ std::vector MeshDeviceView::get_ring_devices() const { return get_devices_from_coordinates(*this, boundary_coords); } -MeshDeviceView::DeviceView MeshDeviceView::get_devices(MeshType type) const { - switch (type) { - case MeshType::RowMajor: return this->devices_; - case MeshType::Ring: return this->get_ring_devices(); - case MeshType::Line: return this->get_line_devices(); - default: TT_THROW("Unsupported Mesh type: {}", type); - } -} +MeshDeviceView::DeviceView MeshDeviceView::get_devices() const { return this->devices_; } } // namespace tt::tt_metal::distributed diff --git a/ttnn/cpp/ttnn/distributed/api.cpp b/ttnn/cpp/ttnn/distributed/api.cpp index 8c9e9a6f971..bf7f0d17db6 100644 --- a/ttnn/cpp/ttnn/distributed/api.cpp +++ b/ttnn/cpp/ttnn/distributed/api.cpp @@ -24,11 +24,10 @@ std::shared_ptr open_mesh_device( size_t trace_region_size, size_t num_command_queues, const DispatchCoreConfig& dispatch_core_config, - MeshType mesh_type, const MeshOffset& offset, const std::vector& physical_device_ids) { - auto config = MeshDeviceConfig{ - .mesh_shape = mesh_shape, .offset = offset, .physical_device_ids = physical_device_ids, .mesh_type = mesh_type}; + 
+        MeshDeviceConfig{.mesh_shape = mesh_shape, .offset = offset, .physical_device_ids = physical_device_ids};
     return MeshDevice::create(config, l1_small_size, trace_region_size, num_command_queues, dispatch_core_config);
 }
 
@@ -152,6 +151,7 @@ std::vector<IDevice*> get_mapped_devices(const Tensor& tensor, MeshDevice& mesh_device) {
                 [&](const ShardTensor2D& s) {
                     return mesh_device.get_view().get_devices(MeshShape{s.shard_mesh.y, s.shard_mesh.x});
                 },
+                [&](const ShardTensor& s) { return mesh_device.get_view().get_line_devices(); },
                 [&](const auto&) { return get_workers_for_tensor(); }},
             host_storage.strategy);
     } else if (std::holds_alternative<MultiDeviceStorage>(tensor.get_storage())) {
diff --git a/ttnn/cpp/ttnn/distributed/api.hpp b/ttnn/cpp/ttnn/distributed/api.hpp
index e25afd3ee38..868aa553d73 100644
--- a/ttnn/cpp/ttnn/distributed/api.hpp
+++ b/ttnn/cpp/ttnn/distributed/api.hpp
@@ -18,7 +18,6 @@ std::shared_ptr<MeshDevice> open_mesh_device(
     size_t trace_region_size,
     size_t num_command_queues,
     const tt::tt_metal::DispatchCoreConfig& dispatch_core_config,
-    MeshType mesh_type = MeshType::RowMajor,
     const MeshOffset& offset = MeshOffset(0, 0),
     const std::vector<int>& physical_device_ids = {});
 
diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp
index 7f11a33ceed..fc49c0cdf09 100644
--- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp
+++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp
@@ -26,12 +26,6 @@ void py_module_types(py::module& module) {
 }
 
 void py_module(py::module& module) {
-    py::enum_<MeshType>(module, "MeshType")
-        .value("RowMajor", MeshType::RowMajor)
-        .value("Ring", MeshType::Ring)
-        .value("Line", MeshType::Line)
-        .export_values();
-
     static_cast<py::class_<MeshShape>>(module.attr("MeshShape"))
         .def(
             py::init([](size_t num_rows, size_t num_cols) { return MeshShape(num_rows, num_cols); }),
@@ -71,14 +65,12 @@ void py_module(py::module& module) {
                size_t num_command_queues,
                const DispatchCoreConfig& dispatch_core_config,
                const MeshOffset& offset,
-               const std::vector<int>& physical_device_ids,
-               MeshType mesh_type) {
+               const std::vector<int>& physical_device_ids) {
                 return MeshDevice::create(
                     MeshDeviceConfig{
                         .mesh_shape = mesh_device_shape,
                         .offset = offset,
                         .physical_device_ids = physical_device_ids,
-                        .mesh_type = mesh_type,
                     },
                     l1_small_size,
                     trace_region_size,
@@ -92,8 +84,7 @@ void py_module(py::module& module) {
             py::arg("num_command_queues"),
             py::arg("dispatch_core_config"),
             py::arg("offset"),
-            py::arg("physical_device_ids"),
-            py::arg("mesh_type"))
+            py::arg("physical_device_ids"))
         .def("get_num_devices", &MeshDevice::num_devices)
         .def("id", &MeshDevice::id)
         .def("get_device_ids", &MeshDevice::get_device_ids)
@@ -109,7 +100,6 @@ void py_module(py::module& module) {
             "get_devices",
             &MeshDevice::get_devices,
             py::return_value_policy::reference,
-            py::arg("type") = py::none(),
             R"doc(
             Get the devices in the device mesh.
@@ -121,13 +111,11 @@ void py_module(py::module& module) {
             &MeshDevice::create_submesh,
             py::arg("submesh_shape"),
             py::arg("offset"),
-            py::arg("mesh_type"),
             py::keep_alive<1, 0>())  // Keep MeshDevice alive as long as SubmeshDevice is alive
         .def(
             "create_submeshes",
             &MeshDevice::create_submeshes,
             py::arg("submesh_shape"),
-            py::arg("mesh_type"),
             py::keep_alive<1, 0>())  // Keep MeshDevice alive as long as SubmeshDevices are alive
         .def(
             "compute_with_storage_grid_size",
@@ -310,7 +298,6 @@ void py_module(py::module& module) {
         py::arg("num_command_queues"),
         py::arg("offset"),
        py::arg("physical_device_ids"),
-        py::arg("mesh_type"),
         py::arg("dispatch_core_config"));
 
     module.def("close_mesh_device", &close_mesh_device, py::arg("mesh_device"), py::kw_only());
diff --git a/ttnn/cpp/ttnn/distributed/types.hpp b/ttnn/cpp/ttnn/distributed/types.hpp
index be033b58fef..c31993a3d01 100644
--- a/ttnn/cpp/ttnn/distributed/types.hpp
+++ b/ttnn/cpp/ttnn/distributed/types.hpp
@@ -18,7 +18,6 @@ using DeviceIds = tt::tt_metal::distributed::DeviceIds;
 using MeshDevice = tt::tt_metal::distributed::MeshDevice;
 using SystemMesh = tt::tt_metal::distributed::SystemMesh;
 using MeshDeviceView = tt::tt_metal::distributed::MeshDeviceView;
-using MeshType = tt::tt_metal::distributed::MeshType;
 using MeshDeviceConfig = tt::tt_metal::distributed::MeshDeviceConfig;
 using MeshSubDeviceManagerId = tt::tt_metal::distributed::MeshSubDeviceManagerId;
 
@@ -34,7 +33,6 @@ using ttnn::distributed::MeshDeviceView;
 using ttnn::distributed::MeshOffset;
 using ttnn::distributed::MeshShape;
 using ttnn::distributed::MeshSubDeviceManagerId;
-using ttnn::distributed::MeshType;
 using ttnn::distributed::SystemMesh;
 
 }  // namespace ttnn
diff --git a/ttnn/cpp/ttnn/operations/ccl/barrier/barrier_pybind.cpp b/ttnn/cpp/ttnn/operations/ccl/barrier/barrier_pybind.cpp
index 84fcf2b37e7..33a22506091 100644
--- a/ttnn/cpp/ttnn/operations/ccl/barrier/barrier_pybind.cpp
+++ b/ttnn/cpp/ttnn/operations/ccl/barrier/barrier_pybind.cpp
@@ -57,7 +57,7 @@ void py_bind_barrier(pybind11::module& module) {
             Example:
                 >>> full_tensor = torch.randn([1, 1, 256, 256], dtype=torch.bfloat16)
                 >>> input_tensors = torch.chunk(full_tensor, num_devices, dim)
-                >>> physical_device_ids = ttnn.open_mesh_device(ttnn.MeshShape(1, 8), mesh_type=ttnn.MeshType.Ring)
+                >>> physical_device_ids = ttnn.get_t3k_physical_device_ids_ring()
                 >>> mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(1, 8), physical_device_ids=physical_device_ids[:8])
                 >>> tt_input_tensors = []
                 >>> for i, t in enumerate(input_tensors):
diff --git a/ttnn/ttnn/distributed/__init__.py b/ttnn/ttnn/distributed/__init__.py
index 635a60b04fa..02b0c03e677 100644
--- a/ttnn/ttnn/distributed/__init__.py
+++ b/ttnn/ttnn/distributed/__init__.py
@@ -23,5 +23,4 @@
     visualize_mesh_device,
     ConcatMesh2dToTensor,
     distribute,
-    MeshType,
 )
diff --git a/ttnn/ttnn/distributed/distributed.py b/ttnn/ttnn/distributed/distributed.py
index 63ec5e9e2da..cf3221e8158 100644
--- a/ttnn/ttnn/distributed/distributed.py
+++ b/ttnn/ttnn/distributed/distributed.py
@@ -18,7 +18,6 @@ def get_mesh_device_core_grid(mesh_device):
 MeshDevice = ttnn._ttnn.multi_device.MeshDevice
 MeshDevice.core_grid = property(get_mesh_device_core_grid)
 DispatchCoreType = ttnn._ttnn.device.DispatchCoreType
-MeshType = ttnn._ttnn.multi_device.MeshType
 
 
 def _get_rich_table(
@@ -141,7 +140,6 @@ def open_mesh_device(
     dispatch_core_config: ttnn.DispatchCoreConfig = ttnn.DispatchCoreConfig(),
     offset: ttnn.MeshOffset = ttnn.MeshOffset(row=0, col=0),
     physical_device_ids: List[int] = [],
-    mesh_type: "MeshType" = MeshType.RowMajor,
 ):
     """
     Open a mesh device with the specified configuration.
@@ -154,7 +152,6 @@ def open_mesh_device(
         dispatch_core_config (ttnn.DispatchCoreConfig, optional): Configuration for the dispatch cores. Defaults to ttnn.DispatchCoreConfig().
         offset (ttnn.MeshOffset, optional): Offset in logical mesh coordinates for the mesh device. Defaults to (0, 0).
         physical_device_ids (List[int], optional): List of physical device IDs to use. Defaults to [].
-        mesh_type (MeshType, optional): Defines type of mesh requested. Type imposes connectivity constraints and defines device iteration order.
 
     Returns:
         ttnn._ttnn.multi_device.MeshDevice: The opened mesh device.
@@ -168,7 +165,6 @@ def open_mesh_device(
         dispatch_core_config=dispatch_core_config,
         offset=offset,
         physical_device_ids=physical_device_ids,
-        mesh_type=mesh_type,
     )
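
With `mesh_type` gone, a mesh is described purely by its shape (plus an optional offset and explicit physical device ids), and connectivity is chosen per CCL operation via the `topology` argument, as in the updated `test_multi_device.py` above. A minimal usage sketch of the simplified API, assuming a T3000-class (2x4) system; it uses only calls that appear in this patch:

```py
import torch
import ttnn

# Shape alone defines the mesh; there is no mesh_type argument anymore.
mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4))

# Submeshes are likewise carved out by shape alone.
for submesh in mesh_device.create_submeshes(ttnn.MeshShape(2, 2)):
    # One 32-wide shard per device along dim 3.
    full_tensor = torch.rand((1, 1, 32, 32 * submesh.get_num_devices()), dtype=torch.bfloat16)
    ttnn_tensor = ttnn.from_torch(full_tensor, mesh_mapper=ttnn.ShardTensorToMesh(submesh, dim=3))
    ttnn_tensor = ttnn.to_device(ttnn_tensor, submesh)
    # Topology now travels with the CCL op rather than being baked into the mesh.
    ttnn_tensor = ttnn.all_gather(ttnn_tensor, dim=3, num_links=1, topology=ttnn.Topology.Ring)

ttnn.close_mesh_device(mesh_device)
```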