#0: Simplify MeshDevice construction by eradicating mesh_type
cfjchu committed Jan 31, 2025
1 parent f708212 commit ad216ae
Showing 24 changed files with 48 additions and 100 deletions.
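The shape of the migration, as a sketch in the ttnn Python API exercised by the tests below (assumes a machine exposing a 2x4 mesh, e.g. a T3000):

```python
import ttnn

# Before this commit, topology was baked into the mesh at construction time:
#   mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4), mesh_type=ttnn.MeshType.Ring)
#   submeshes = mesh_device.create_submeshes(ttnn.MeshShape(2, 2), ttnn.MeshType.Ring)

# After, a mesh is just a 2D grid of devices; the argument disappears:
mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4))
submeshes = mesh_device.create_submeshes(ttnn.MeshShape(2, 2))

# Ops that depend on a device ordering now request it explicitly, e.g.
# all_gather(..., topology=ttnn.Topology.Ring) in the test changes below.
ttnn.close_mesh_device(mesh_device)
```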
6 changes: 2 additions & 4 deletions conftest.py
@@ -254,10 +254,9 @@ def pcie_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, device_params):

updated_device_params = get_updated_device_params(device_params)
mesh_device = ttnn.open_mesh_device(
- mesh_shape=ttnn.MeshShape(2, 2),
+ mesh_shape=ttnn.MeshShape(1, num_pcie_devices_requested),
**updated_device_params,
- offset=ttnn.MeshOffset(0, 1),
- mesh_type=ttnn.MeshType.Ring,
+ physical_device_ids=device_ids[:num_pcie_devices_requested],
)

logger.debug(f"multidevice with {mesh_device.get_num_devices()} devices is created")
@@ -305,7 +304,6 @@ def t3k_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, device_params):
mesh_device = ttnn.open_mesh_device(
mesh_shape=ttnn.MeshShape(2, 4),
**updated_device_params,
- mesh_type=ttnn.MeshType.Ring,
)

logger.debug(f"multidevice with {mesh_device.get_num_devices()} devices is created")
2 changes: 1 addition & 1 deletion models/MODEL_HYBRID_TP_DP.md
@@ -16,7 +16,7 @@ The main changes involve:

```python
# Work with submesh device as you would with a regular ttnn.MeshDevice
- submesh_devices: List[ttnn.MeshDevice] = mesh_device.create_submeshes((2, 4), ttnn.MeshType.Ring)
+ submesh_devices: List[ttnn.MeshDevice] = mesh_device.create_submeshes((2, 4))
```

### 2. Compile & Run the Model on Each Submesh
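For context, a skeleton of the hybrid pattern with the new signature (a sketch only: the 2x8 parent shape and the `build_and_run` helper are placeholders for the machine and model code elided above):

```python
import ttnn

# Data-parallel across submeshes, tensor-parallel inside each 2x4 submesh.
mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 8))  # placeholder shape

def build_and_run(submesh):
    # Stand-in for per-submesh weight loading, compilation, and inference.
    print(f"submesh with {submesh.get_num_devices()} devices")

for submesh in mesh_device.create_submeshes((2, 4)):  # no MeshType argument
    build_and_run(submesh)

ttnn.close_mesh_device(mesh_device)
```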
@@ -296,7 +296,7 @@ def run_test_LlamaModel_end_to_end_hybrid_data_tensor_parallel(
profiler.clear()

submesh_to_metadata = defaultdict(dict)
- submeshes = mesh_device.create_submeshes((2, 4), ttnn.MeshType.Ring)
+ submeshes = mesh_device.create_submeshes((2, 4))
for submesh in submeshes:
# Set up model -----------------------------------------------------------------------
logger.info("Moving weights to devices; might take some time...")
2 changes: 1 addition & 1 deletion tech_reports/LLMs/llms.md
@@ -1195,7 +1195,7 @@ Below is a summary and example code of the most important concepts for mapping a
import ttnn

# 2x4 mesh_device, Topology Ring: devices are connected in a ring
- mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4), mesh_type=ttnn.MeshType.Ring)
+ mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4))

# Construct initial torch tensor
torch_tensor = torch.rand((1,1,32,256), dtype=torch.bfloat16)
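Completing the truncated snippet, a host-device round trip under the new constructor (a sketch; `ShardTensorToMesh`/`ConcatMeshToTensor` are the mappers used throughout the ttnn multi-device docs):

```python
import torch
import ttnn

mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4))

torch_tensor = torch.rand((1, 1, 32, 256), dtype=torch.bfloat16)

# Shard dim 3 into 8 chunks of 32x32, one per device in the 2x4 mesh.
ttnn_tensor = ttnn.from_torch(torch_tensor, mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=3))
ttnn_tensor = ttnn.to_device(ttnn_tensor, mesh_device)

# Concatenate the shards back on the host to verify the round trip.
read_back = ttnn.to_torch(ttnn_tensor, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=3))
assert read_back.shape == torch_tensor.shape

ttnn.close_mesh_device(mesh_device)
```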
@@ -296,7 +296,7 @@ Let's see an example of how to use the Ring All-Gather operation:
```py
import ttnn

- mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4), mesh_type=ttnn.MeshType.Ring)
+ mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4))

# Construct test tensor of data; 8 chunks of 32x32
torch_tensor = torch.rand((1,1,32,256), dtype=torch.bfloat16)
@@ -328,7 +328,7 @@ The result tensor for each device in the column is the concatenation in `dim=3`
```py
import ttnn

- mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4), mesh_type=ttnn.MeshType.Ring)
+ mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4))

# Construct test tensor of data; 8 chunks of 32x32
torch_tensor = torch.rand((1,1,32,256), dtype=torch.bfloat16)
@@ -534,7 +534,7 @@ torch_hidden_states = (torch.rand(batch_size, 1, sequence_length, config.hidden_
torch_output = model.forward(torch_hidden_states)

# Device Initialization
- mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2,4), mesh_type=ttnn.MeshType.Ring)
+ mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2,4))

# Initialize input activations on all devices in the mesh
# Alternatively, we can shard the input activations on the height dimension and
@@ -602,7 +602,7 @@ See `models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py::test_Llama_p
1. Submesh Creation

```py
- submesh_devices: List[ttnn.MeshDevice] = mesh_device.create_submeshes((2, 4), ttnn.MeshType.Ring)
+ submesh_devices: List[ttnn.MeshDevice] = mesh_device.create_submeshes((2, 4))
```

2. Compile & Run the Model on Each Submesh
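With mesh_type gone from construction, the ring-vs-line choice lands on the CCL op itself. A sketch using the `topology` keyword exactly as the updated test_multi_device.py below does:

```python
import torch
import ttnn

mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4))

torch_tensor = torch.rand((1, 1, 32, 256), dtype=torch.bfloat16)
ttnn_tensor = ttnn.from_torch(torch_tensor, mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=3))
ttnn_tensor = ttnn.to_device(ttnn_tensor, mesh_device)

# Ring all-gather across all 8 devices; afterwards every device holds the
# full 32x256 tensor.
gathered = ttnn.all_gather(ttnn_tensor, dim=3, num_links=1, topology=ttnn.Topology.Ring)

ttnn.close_mesh_device(mesh_device)
```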
@@ -223,10 +223,6 @@ struct MeshConfig {

// Offset into Logical Device Coordinate Space
MeshOffset offset;

- // TODO: consider whether this should be automatically inferred.
- // Interpret as e.g. {Ring, Line}
- MeshType type;
};

// Class exposing host and device dispatch state
@@ -986,8 +982,8 @@ Below, we include snippets from both the TT-Mesh and TT-Metal examples to illust
*Specify MeshConfig when creating two Virtual Meshes on a Physical Mesh.*

```cpp
- MeshConfig mesh_config_0 = MeshConfig{.shape = virtual_mesh_shape, .offset = {0, 0}, .type=mesh_type};
- MeshConfig mesh_config_1 = MeshConfig{.shape = virtual_mesh_shape, .offset = {0, 4}, .type=mesh_type};
+ MeshConfig mesh_config_0 = MeshConfig{.shape = virtual_mesh_shape, .offset = {0, 0}};
+ MeshConfig mesh_config_1 = MeshConfig{.shape = virtual_mesh_shape, .offset = {0, 4}};

DeviceHandle virtual_mesh_0 = CreateMeshDevice(mesh_config_0, 2 /* num_command_queues */, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE);
DeviceHandle virtual_mesh_1 = CreateMeshDevice(mesh_config_1, 2 /* num_command_queues */, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE);
2 changes: 1 addition & 1 deletion tests/sweep_framework/sweeps/ccl/line_all_gather.py
@@ -66,7 +66,7 @@ def mesh_device_fixture():
assert ttnn.get_num_devices() >= 8, "Not T3000!"
device_ids = ttnn.get_t3k_physical_device_ids_ring()
num_devices_requested = len(device_ids)
- mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(1, num_devices_requested), mesh_type=ttnn.MeshType.Line)
+ mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(1, num_devices_requested))
print("ALL GATHER: Opened device mesh")

yield (mesh_device, "T3000 Mesh")
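The sweep previously encoded the line ordering in `mesh_type`; after this commit the natural home is the op's `topology` argument. A sketch (`ttnn.Topology.Linear` is assumed here by analogy with the `ttnn.Topology.Ring` used elsewhere in this commit):

```python
import torch
import ttnn

device_ids = ttnn.get_t3k_physical_device_ids_ring()
mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(1, len(device_ids)))

torch_tensor = torch.rand((1, 1, 32, 32 * len(device_ids)), dtype=torch.bfloat16)
ttnn_tensor = ttnn.from_torch(torch_tensor, mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=3))
ttnn_tensor = ttnn.to_device(ttnn_tensor, mesh_device)

# Line (non-wraparound) all-gather across the 1xN mesh.
gathered = ttnn.all_gather(ttnn_tensor, dim=3, num_links=1, topology=ttnn.Topology.Linear)

ttnn.close_mesh_device(mesh_device)
```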
11 changes: 2 additions & 9 deletions tests/ttnn/distributed/test_distributed_reshape.cpp
@@ -27,7 +27,7 @@ void check_test_environment() {

std::vector<chip_id_t> get_physical_device_ids(const MeshDevice& mesh) {
std::vector<chip_id_t> device_ids;
- for (auto* device : mesh.get_devices(ttnn::distributed::MeshType::RowMajor)) {
+ for (auto* device : mesh.get_devices()) {
device_ids.push_back(device->id());
}
return device_ids;
@@ -138,12 +138,7 @@ TEST_F(T3000ReshapeTest, From1x8To2x4) {

TEST_F(T3000ReshapeTest, OnRingTopology) {
auto mesh = ttnn::distributed::open_mesh_device(
- {1, 8},
- DEFAULT_L1_SMALL_SIZE,
- DEFAULT_TRACE_REGION_SIZE,
- 1,
- tt::tt_metal::DispatchCoreType::WORKER,
- ttnn::distributed::MeshType::Ring);
+ {1, 8}, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, tt::tt_metal::DispatchCoreType::WORKER);

EXPECT_EQ(mesh->num_rows(), 1);
EXPECT_EQ(mesh->num_cols(), 8);
@@ -228,7 +223,6 @@ TEST_F(T3000ReshapeTest, From1x4To2x2Valid) {
// Fetch the device ids for a physically connected 2x2 mesh.
auto physical_device_ids = system_mesh.get_mapped_physical_device_ids(MeshDeviceConfig{
.mesh_shape = MeshShape{2, 2},
- .mesh_type = ttnn::distributed::MeshType::Line,
});

// Supply the physical device ids to the mesh constructor that we know is 2x2 physically connected.
@@ -239,7 +233,6 @@
DEFAULT_TRACE_REGION_SIZE,
1,
tt::tt_metal::DispatchCoreType::WORKER,
- ttnn::distributed::MeshType::Line,
MeshOffset{0, 0},
physical_device_ids);

@@ -36,7 +36,6 @@ def test_tensor_parallel_falcon_mlp():

mesh_device = ttnn.open_mesh_device(
ttnn.MeshShape(2, 4),
- mesh_type=ttnn.MeshType.Ring,
)

# Set PyTorch seed for reproducibility
2 changes: 1 addition & 1 deletion tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp
@@ -67,7 +67,7 @@ class T3kMultiDeviceFixture : public ::testing::Test {
if (num_devices < 8 or arch != tt::ARCH::WORMHOLE_B0) {
GTEST_SKIP() << "Skipping T3K Multi-Device test suite on non T3K machine.";
}
- mesh_device_ = MeshDevice::create(MeshDeviceConfig{.mesh_shape = MeshShape{2, 4}, .mesh_type = MeshType::Ring});
+ mesh_device_ = MeshDevice::create(MeshDeviceConfig{.mesh_shape = MeshShape{2, 4}});
}

void TearDown() override {
12 changes: 9 additions & 3 deletions tests/ttnn/unit_tests/test_multi_device.py
@@ -672,22 +672,28 @@ def test_visualize_mesh_device(t3k_mesh_device):
ttnn.visualize_mesh_device(t3k_mesh_device)


- def test_all_gather_multiple_submeshes(t3k_mesh_device):
+ def test_all_gather_multiple_submeshes():
"""Test all_gather with multiple submeshes"""

+ mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4))
+ ttnn.visualize_mesh_device(mesh_device)
+
def model(submesh):
full_tensor = torch.ones((1, 1, 32, 32 * submesh.get_num_devices()), dtype=torch.bfloat16)
for i in range(submesh.get_num_devices()):
full_tensor[..., i * 32 : (i + 1) * 32] = i

+ for device in submesh.get_devices():
+     print(device.id())
+
ttnn_tensor = ttnn.from_torch(full_tensor, mesh_mapper=ShardTensorToMesh(submesh, dim=3))
ttnn_tensor = ttnn.to_device(ttnn_tensor, submesh)
- ttnn_tensor = ttnn.all_gather(ttnn_tensor, dim=3, num_links=1)
+ ttnn_tensor = ttnn.all_gather(ttnn_tensor, dim=3, num_links=1, topology=ttnn.Topology.Ring)

for device_tensor in ttnn.get_device_tensors(ttnn_tensor):
device_tensor_torch = ttnn.to_torch(device_tensor)
assert torch.all(device_tensor_torch == full_tensor)

- submesh_devices = t3k_mesh_device.create_submeshes(ttnn.MeshShape(2, 2), ttnn.MeshType.Ring)
+ submesh_devices = mesh_device.create_submeshes(ttnn.MeshShape(2, 2))
for submesh in submesh_devices:
model(submesh)
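Note how `create_submeshes` tiles the parent (see mesh_device.cpp below): offsets advance by the submesh shape in both dimensions, so a 2x4 parent tiled by 2x2 yields exactly two submeshes, at offsets (0, 0) and (0, 2):

```python
import ttnn

mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4))
submeshes = mesh_device.create_submeshes(ttnn.MeshShape(2, 2))
assert len(submeshes) == (2 // 2) * (4 // 2)  # == 2
ttnn.close_mesh_device(mesh_device)
```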
3 changes: 1 addition & 2 deletions tt-train/sources/ttml/core/mesh_device.cpp
@@ -12,8 +12,7 @@ MeshDevice::MeshDevice(tt::tt_metal::distributed::MeshShape shape) :
DEFAULT_L1_SMALL_SIZE,
DEFAULT_TRACE_REGION_SIZE,
/* num_command_queues*/ 1,
- DispatchCoreConfig{},
- ttnn::distributed::MeshType::RowMajor)) {
+ DispatchCoreConfig{})) {
assert(m_mesh_device);
}

2 changes: 0 additions & 2 deletions tt_metal/api/tt-metalium/mesh_config.hpp
@@ -36,13 +36,11 @@ struct MeshShape {
*
* - Line: Devices are arranged linearly in a single dimension.
*/
- enum class MeshType { RowMajor, Ring, Line };

struct MeshDeviceConfig {
MeshShape mesh_shape{0, 0};
MeshOffset offset{0, 0};
std::vector<chip_id_t> physical_device_ids{};
- MeshType mesh_type{MeshType::RowMajor};
};

} // namespace tt::tt_metal::distributed
13 changes: 4 additions & 9 deletions tt_metal/api/tt-metalium/mesh_device.hpp
@@ -55,7 +55,6 @@ class MeshDevice : public IDevice, public std::enable_shared_from_this<MeshDevice> {
std::shared_ptr<ScopedDevices> scoped_devices_;
MeshDeviceID mesh_id_;
MeshShape mesh_shape_;
- MeshType type_;
std::unique_ptr<MeshDeviceView> view_;
std::vector<std::shared_ptr<MeshDevice>>
submeshes_; // Parent owns submeshes and is responsible for their destruction
@@ -71,7 +70,6 @@ class MeshDevice : public IDevice, public std::enable_shared_from_this<MeshDevice> {
MeshDevice(
std::shared_ptr<ScopedDevices> mesh_handle,
const MeshShape& mesh_shape,
- MeshType type,
std::weak_ptr<MeshDevice> parent_mesh = {});
~MeshDevice() override;

@@ -200,9 +198,9 @@ class MeshDevice : public IDevice, public std::enable_shared_from_this<MeshDevice> {

// A MeshDevice is a collection of devices arranged in a 2D grid.
// The type parameter allows the caller to specify how to linearize the devices in the mesh.
- // If type is not provided, the default behavior is to return the devices based on the MeshType of the MeshDevice.

- std::vector<IDevice*> get_devices(const std::optional<MeshType>& type = std::nullopt) const;
+ // Returns the devices in the mesh in row-major order.
+ std::vector<IDevice*> get_devices() const;
IDevice* get_device_index(size_t logical_device_id) const;
IDevice* get_device(chip_id_t physical_device_id) const;
IDevice* get_device(size_t row_idx, size_t col_idx) const;
@@ -238,12 +236,9 @@ class MeshDevice : public IDevice, public std::enable_shared_from_this<MeshDevice> {
std::vector<std::shared_ptr<MeshDevice>> get_submeshes() const;

std::shared_ptr<MeshDevice> create_submesh(
-     const MeshShape& submesh_shape,
-     const MeshOffset& offset = MeshOffset{0, 0},
-     MeshType type = MeshType::RowMajor);
+     const MeshShape& submesh_shape, const MeshOffset& offset = MeshOffset{0, 0});

- std::vector<std::shared_ptr<MeshDevice>> create_submeshes(
-     const MeshShape& submesh_shape, MeshType type = MeshType::RowMajor);
+ std::vector<std::shared_ptr<MeshDevice>> create_submeshes(const MeshShape& submesh_shape);

// These methods will get removed once in favour of the ones in IDevice* and TT-Mesh bringup
// These are prefixed with "mesh_" to avoid conflicts with the IDevice* methods
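The practical consequence of the header change: `get_devices()` now always returns row-major order, and ring or line walks must be requested from the view (`get_ring_devices` / `get_line_devices` remain on `MeshDeviceView`). A Python-level sketch of the default ordering, using only calls that appear in the tests above:

```python
import ttnn

mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4))

# Row-major: row 0 left to right, then row 1 -- no longer dependent on a
# per-mesh MeshType.
for device in mesh_device.get_devices():
    print(device.id())

ttnn.close_mesh_device(mesh_device)
```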
2 changes: 1 addition & 1 deletion tt_metal/api/tt-metalium/mesh_device_view.hpp
@@ -74,7 +74,7 @@ class MeshDeviceView {
// devices are returned in row-major order with start/end coordinates inclusive
[[nodiscard]] DeviceView get_devices(const Coordinate& start, const Coordinate& end) const;
[[nodiscard]] DeviceView get_devices(const MeshShape& submesh_shape) const;
- [[nodiscard]] DeviceView get_devices(MeshType type = MeshType::RowMajor) const;
+ [[nodiscard]] DeviceView get_devices() const;

[[nodiscard]] DeviceView get_devices_on_row(size_t row) const;
[[nodiscard]] DeviceView get_devices_on_column(size_t col) const;
28 changes: 10 additions & 18 deletions tt_metal/distributed/mesh_device.cpp
@@ -119,16 +119,11 @@ uint32_t MeshDevice::dram_size_per_channel() const {
IDevice* MeshDevice::reference_device() const { return this->get_devices().at(0); }

MeshDevice::MeshDevice(
-     std::shared_ptr<ScopedDevices> mesh_handle,
-     const MeshShape& mesh_shape,
-     MeshType type,
-     std::weak_ptr<MeshDevice> parent_mesh) :
+     std::shared_ptr<ScopedDevices> mesh_handle, const MeshShape& mesh_shape, std::weak_ptr<MeshDevice> parent_mesh) :
scoped_devices_(std::move(mesh_handle)),
mesh_shape_(mesh_shape),
- type_(type),
mesh_id_(generate_unique_mesh_id()),
- parent_mesh_(std::move(parent_mesh))
- {
+ parent_mesh_(std::move(parent_mesh)) {
work_executor_ = std::make_unique<WorkExecutor>(0 /* worker_core */, mesh_id_);
work_executor_->initialize();
work_executor_->set_worker_mode(WorkExecutorMode::SYNCHRONOUS);
@@ -142,16 +137,15 @@ std::shared_ptr<MeshDevice> MeshDevice::create(
const DispatchCoreConfig& dispatch_core_config,
tt::stl::Span<const std::uint32_t> l1_bank_remap) {
auto mesh_device = std::make_shared<MeshDevice>(
- std::make_shared<ScopedDevices>(l1_small_size, trace_region_size, num_command_queues, dispatch_core_config, config),
- config.mesh_shape,
- config.mesh_type);
+ std::make_shared<ScopedDevices>(
+     l1_small_size, trace_region_size, num_command_queues, dispatch_core_config, config),
+ config.mesh_shape);

mesh_device->initialize(num_command_queues, l1_small_size, trace_region_size, l1_bank_remap);
return mesh_device;
}

- std::shared_ptr<MeshDevice> MeshDevice::create_submesh(
-     const MeshShape& submesh_shape, const MeshOffset& offset, MeshType type) {
+ std::shared_ptr<MeshDevice> MeshDevice::create_submesh(const MeshShape& submesh_shape, const MeshOffset& offset) {
if (submesh_shape.num_rows <= 0 || submesh_shape.num_cols <= 0) {
TT_THROW(
"Invalid submesh shape: ({}, {}). Both dimensions must be positive.",
@@ -175,7 +169,7 @@ std::shared_ptr<MeshDevice> MeshDevice::create_submesh(
mesh_shape_.num_cols);
}

- auto submesh = std::make_shared<MeshDevice>(scoped_devices_, submesh_shape, type, shared_from_this());
+ auto submesh = std::make_shared<MeshDevice>(scoped_devices_, submesh_shape, shared_from_this());
auto start_coordinate = Coordinate{offset.row, offset.col};
auto end_coordinate = Coordinate{offset.row + submesh_shape.num_rows - 1, offset.col + submesh_shape.num_cols - 1};

@@ -196,11 +190,11 @@ std::shared_ptr<MeshDevice> MeshDevice::create_submesh(
return submesh;
}

- std::vector<std::shared_ptr<MeshDevice>> MeshDevice::create_submeshes(const MeshShape& submesh_shape, MeshType type) {
+ std::vector<std::shared_ptr<MeshDevice>> MeshDevice::create_submeshes(const MeshShape& submesh_shape) {
std::vector<std::shared_ptr<MeshDevice>> submeshes;
for (int row = 0; row < this->num_rows(); row += submesh_shape.num_rows) {
for (int col = 0; col < this->num_cols(); col += submesh_shape.num_cols) {
- auto submesh = this->create_submesh(submesh_shape, MeshOffset{row, col}, type);
+ auto submesh = this->create_submesh(submesh_shape, MeshOffset{row, col});
submeshes.push_back(submesh);
}
}
@@ -224,9 +218,7 @@ IDevice* MeshDevice::get_device(chip_id_t physical_device_id) const {
TT_THROW("Physical Device ID: {} not found in assigned devices", physical_device_id);
}

- std::vector<IDevice*> MeshDevice::get_devices(const std::optional<MeshType>& requested_type) const {
-     return view_->get_devices(requested_type.value_or(type_));
- }
+ std::vector<IDevice*> MeshDevice::get_devices() const { return view_->get_devices(); }

// TODO: Remove this function once we have a proper view interface
IDevice* MeshDevice::get_device(size_t row_idx, size_t col_idx) const {
11 changes: 2 additions & 9 deletions tt_metal/distributed/mesh_device_view.cpp
@@ -39,7 +39,7 @@ MeshDeviceView::MeshDeviceView(const std::vector<IDevice*>& devices, Coordinate
}

MeshDeviceView::MeshDeviceView(const MeshDevice& mesh_device) :
- MeshDeviceView(mesh_device.get_devices(MeshType::RowMajor), mesh_device.shape()) {}
+ MeshDeviceView(mesh_device.get_devices(), mesh_device.shape()) {}

MeshDeviceView::MeshDeviceView(const std::vector<IDevice*>& devices, const MeshShape& shape) :
MeshDeviceView(devices, Coordinate{0, 0}, Coordinate{shape.num_rows - 1, shape.num_cols - 1}) {}
@@ -261,13 +261,6 @@ std::vector<IDevice*> MeshDeviceView::get_ring_devices() const {
return get_devices_from_coordinates(*this, boundary_coords);
}

- MeshDeviceView::DeviceView MeshDeviceView::get_devices(MeshType type) const {
-     switch (type) {
-         case MeshType::RowMajor: return this->devices_;
-         case MeshType::Ring: return this->get_ring_devices();
-         case MeshType::Line: return this->get_line_devices();
-         default: TT_THROW("Unsupported Mesh type: {}", type);
-     }
- }
+ MeshDeviceView::DeviceView MeshDeviceView::get_devices() const { return this->devices_; }

} // namespace tt::tt_metal::distributed
6 changes: 3 additions & 3 deletions ttnn/cpp/ttnn/distributed/api.cpp
@@ -24,11 +24,10 @@ std::shared_ptr<MeshDevice> open_mesh_device(
size_t trace_region_size,
size_t num_command_queues,
const DispatchCoreConfig& dispatch_core_config,
- MeshType mesh_type,
const MeshOffset& offset,
const std::vector<int>& physical_device_ids) {
- auto config = MeshDeviceConfig{
-     .mesh_shape = mesh_shape, .offset = offset, .physical_device_ids = physical_device_ids, .mesh_type = mesh_type};
+ auto config =
+     MeshDeviceConfig{.mesh_shape = mesh_shape, .offset = offset, .physical_device_ids = physical_device_ids};
return MeshDevice::create(config, l1_small_size, trace_region_size, num_command_queues, dispatch_core_config);
}

Expand Down Expand Up @@ -152,6 +151,7 @@ std::vector<IDevice*> get_mapped_devices(const Tensor& tensor, MeshDevice& mesh_
[&](const ShardTensor2D& s) {
return mesh_device.get_view().get_devices(MeshShape{s.shard_mesh.y, s.shard_mesh.x});
},
+ [&](const ShardTensor& s) { return mesh_device.get_view().get_line_devices(); },
[&](const auto&) { return get_workers_for_tensor(); }},
host_storage.strategy);
} else if (std::holds_alternative<MultiDeviceStorage>(tensor.get_storage())) {
Expand Down