#12955: cleanup device close
aliuTT committed Oct 23, 2024
1 parent f3b2069 commit 5205204
Showing 5 changed files with 105 additions and 100 deletions.
2 changes: 0 additions & 2 deletions tt_metal/impl/device/device.cpp
@@ -2946,8 +2946,6 @@ bool Device::close() {
}
}

tt::Cluster::instance().set_internal_routing_info_for_ethernet_cores(false);

tt::Cluster::instance().l1_barrier(id_);
allocator::clear(*this->allocator_);
// After device close, no buffers on this device should be used
57 changes: 55 additions & 2 deletions tt_metal/impl/device/device_pool.cpp
@@ -424,6 +424,9 @@ std::vector<v1::DeviceHandle> DevicePool::get_all_active_devices() const {
}

bool DevicePool::close_device(chip_id_t device_id) {
// Sync and close one device.
// Currently this can only be called on mmio chips; once we split dispatch kernel shutdown
// from device close, we can call this on remote devices too.
tt::Cluster::instance().set_internal_routing_info_for_ethernet_cores(false);
bool pass = true;
const auto& mmio_device_id = tt::Cluster::instance().get_associated_mmio_device(device_id);
@@ -439,12 +442,62 @@ bool DevicePool::close_device(chip_id_t device_id) {
return pass;
}

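For context, a minimal sketch of how this entry point might be driven from host code. The chip id and the setup around it are hypothetical; only DevicePool::instance(), is_device_active(), and close_device() come from this diff:

    // Minimal sketch, assuming a pool already initialized elsewhere.
    chip_id_t device_id = 0;  // hypothetical chip id
    auto &pool = tt::DevicePool::instance();
    if (pool.is_device_active(device_id)) {
        // Disables internal eth routing, then syncs and closes the device.
        bool ok = pool.close_device(device_id);
        (void)ok;  // callers may check this; error handling omitted here
    }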
DevicePool::~DevicePool() {
log_debug(tt::LogMetal, "DevicePool destructor");
void DevicePool::close_devices(const std::vector<Device*>& devices) {
// Ordered, because we need to shut down tunnels from the farthest to the closest.
std::vector<chip_id_t> devices_to_close;

// Loop over all devices and add remote devices to devices_to_close
// For Galaxy, if an mmio device's tunnels are being closed, close the mmio device as well
std::unordered_set<chip_id_t> mmio_devices_to_close;
for (const auto &dev : devices) {
const auto &mmio_device_id = tt::Cluster::instance().get_associated_mmio_device(dev->id());
if (mmio_devices_to_close.find(mmio_device_id) != mmio_devices_to_close.end()) {
continue;
}
auto mmio_dev_handle = tt::DevicePool::instance().get_active_device(mmio_device_id);
auto tunnels_from_mmio = mmio_dev_handle->tunnels_from_mmio_;
//iterate over all tunnels originating from this mmio device
for (auto t : tunnels_from_mmio) {
//iterate over all tunneled devices (tunnel stops) in this tunnel
for (uint32_t ts = t.size() - 1; ts > 0; ts--) {
if (this->is_device_active(t[ts])) {
devices_to_close.push_back(t[ts]);
}
}
}
devices_to_close.push_back(mmio_device_id);
mmio_devices_to_close.insert(mmio_device_id);
}


// Global Sync across all devices that are being closed
// We need to ensure that commands sent to each device have been completed
// before closing any device + modifying routing info.
// If this is not done, non-blocking CCLs followed by a close will hang, since
// the main thread will modify device state while the CCL is running on device.
for (const auto &dev_id : devices_to_close) {
auto dev = tt::DevicePool::instance().get_active_device(dev_id);
dev->synchronize(); // Synchronize worker queue
Synchronize(dev); // Synchronize device
}

tt::Cluster::instance().set_internal_routing_info_for_ethernet_cores(false);
for (const auto &dev_id : devices_to_close) {
auto dev = tt::DevicePool::instance().get_active_device(dev_id);
dev->close();
// When a device is closed, its worker thread is joined. Stop tracking this
// worker thread.
this->unregister_worker_thread_for_device(this->get_handle(dev));
}
}

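The farthest-first ordering above can be illustrated in isolation. A small, self-contained sketch (the tunnel contents are hypothetical; the layout, with the mmio chip at index 0 and tunnel stops after it, follows the loop in close_devices()):

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    using chip_id_t = int;  // assumption: chip ids are plain integers

    int main() {
        // Hypothetical tunnel: index 0 is the mmio gateway, later entries are stops.
        std::vector<chip_id_t> tunnel = {0, 4, 5};
        auto close_chip = [](chip_id_t id) { std::printf("closing chip %d\n", id); };
        // Walk the tunnel back to front so the farthest device closes first,
        // mirroring the ts loop in close_devices().
        for (std::size_t ts = tunnel.size() - 1; ts > 0; ts--) {
            close_chip(tunnel[ts]);
        }
        close_chip(tunnel[0]);  // the mmio gateway is closed last
    }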
DevicePool::~DevicePool() {
log_debug(tt::LogMetal, "DevicePool destructor");
for (const auto& dev : this->devices) {
if (dev != nullptr and dev->is_initialized()) {
// TODO: #13876, Was encountering issues with the dispatch_constants being destroyed before the DevicePool destructor,
// which leads to device->close() hitting asserts. We need to move the ownership of dispatch_constants
// to the device, so it doesn't go out of scope before the device is closed.
dev->close();
}
}
2 changes: 2 additions & 0 deletions tt_metal/impl/device/device_pool.hpp
@@ -4,6 +4,7 @@

#pragma once

#include "tt_metal/host_api.hpp"
#include "impl/debug/dprint_server.hpp"
#include "impl/debug/noc_logging.hpp"
#include "impl/debug/watcher_server.hpp"
@@ -45,6 +46,7 @@ class DevicePool {
v1::DeviceHandle get_active_device(chip_id_t device_id) const;
std::vector<v1::DeviceHandle> get_all_active_devices() const;
bool close_device(chip_id_t device_id);
void close_devices(const std::vector<Device *> &devices);
bool is_device_active(chip_id_t id) const;
void register_worker_thread_for_device(v1::DeviceHandle device, std::thread::id worker_thread_id);
void unregister_worker_thread_for_device(v1::DeviceHandle device);
86 changes: 44 additions & 42 deletions tt_metal/llrt/tt_cluster.cpp
@@ -909,56 +909,58 @@ void Cluster::set_internal_routing_info_for_ethernet_cores(bool enable_internal_
// TODO: initialize devices if user does not
// Must initialize remote chips first, then mmio chips since once mmio chips are doing fd routing
// we do not always context switch to base FW
const routing_info_t routing_info_disabled = {
.routing_enabled = 0,
.src_sent_valid_cmd = 0,
.dst_acked_valid_cmd = 0,
};
const routing_info_t routing_info_enabled = {
.routing_enabled = 1,
.src_sent_valid_cmd = 0,
.dst_acked_valid_cmd = 0,
};
std::vector<chip_id_t> mmio_devices;
mmio_devices.reserve(this->devices_grouped_by_assoc_mmio_device_.size());
std::vector<chip_id_t> non_mmio_devices;
for (const auto &[assoc_mmio_device, devices] : this->devices_grouped_by_assoc_mmio_device_) {
mmio_devices.emplace_back(assoc_mmio_device);
for (const auto &chip_id : devices) {
non_mmio_devices.emplace_back(chip_id);
}
}

if (enable_internal_routing) {
const routing_info_t routing_info_enabled = {
.routing_enabled = 1,
.src_sent_valid_cmd = 0,
.dst_acked_valid_cmd = 0,
};
for (const auto &chip_id : non_mmio_devices) {
for (const auto &[eth_core, routing_info] : this->device_eth_routing_info_.at(chip_id)) {
tt_cxy_pair eth_phys_core(chip_id, ethernet_core_from_logical_core(chip_id, eth_core));
if (chip_id == assoc_mmio_device and not enable_internal_routing) {
// Disable internal ethernet routing for mmio devices
write_core(
(void *)&routing_info_disabled,
sizeof(routing_info_t),
eth_phys_core,
routing_info_addr,
false);
} else if (chip_id != assoc_mmio_device and enable_internal_routing) {
// Enable internal ethernet routing for non-mmio devices
write_core(
(void *)&routing_info_enabled, sizeof(routing_info_t), eth_phys_core, routing_info_addr, false);

} else {
continue;
}
// Enable internal ethernet routing for non-mmio devices
write_core(
(void *)&routing_info_enabled, sizeof(routing_info_t), eth_phys_core, routing_info_addr, false);
}
}
for (const auto &chip_id : devices) {
for (const auto &chip_id : mmio_devices) {
for (const auto &[eth_core, routing_info] : this->device_eth_routing_info_.at(chip_id)) {
tt_cxy_pair eth_phys_core(chip_id, ethernet_core_from_logical_core(chip_id, eth_core));
if (chip_id != assoc_mmio_device and not enable_internal_routing) {
// Disable internal ethernet routing for non-mmio devices
write_core(
(void *)&routing_info_disabled,
sizeof(routing_info_t),
eth_phys_core,
routing_info_addr,
false);
} else if (chip_id == assoc_mmio_device and enable_internal_routing) {
// Enable internal ethernet routing for mmio devices
write_core(
(void *)&routing_info_enabled, sizeof(routing_info_t), eth_phys_core, routing_info_addr, false);
} else {
continue;
}
// Enable internal ethernet routing for mmio devices
write_core(
(void *)&routing_info_enabled, sizeof(routing_info_t), eth_phys_core, routing_info_addr, false);
}
}
} else {
const routing_info_t routing_info_disabled = {
.routing_enabled = 0,
.src_sent_valid_cmd = 0,
.dst_acked_valid_cmd = 0,
};
for (const auto &chip_id : mmio_devices) {
for (const auto &[eth_core, routing_info] : this->device_eth_routing_info_.at(chip_id)) {
tt_cxy_pair eth_phys_core(chip_id, ethernet_core_from_logical_core(chip_id, eth_core));
// Disable internal ethernet routing for mmio devices
write_core(
(void *)&routing_info_disabled, sizeof(routing_info_t), eth_phys_core, routing_info_addr, false);
}
}
for (const auto &chip_id : non_mmio_devices) {
for (const auto &[eth_core, routing_info] : this->device_eth_routing_info_.at(chip_id)) {
tt_cxy_pair eth_phys_core(chip_id, ethernet_core_from_logical_core(chip_id, eth_core));
// Disable internal ethernet routing for non-mmio devices
write_core(
(void *)&routing_info_disabled, sizeof(routing_info_t), eth_phys_core, routing_info_addr, false);
}
}
}
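The two branches are deliberate mirrors: enabling writes remote chips before mmio chips (per the comment at the top of this hunk), and disabling reverses that order. A self-contained sketch of just the ordering, where the chip lists and the print are hypothetical stand-ins for the write_core() calls above:

    #include <cstdio>
    #include <vector>

    using chip_id_t = int;

    // Stand-in for writing a routing_info_t to a chip's ethernet cores.
    static void set_routing(chip_id_t chip, bool enabled) {
        std::printf("chip %d: routing_enabled=%d\n", chip, enabled ? 1 : 0);
    }

    int main() {
        std::vector<chip_id_t> mmio = {0};       // hypothetical gateway chip
        std::vector<chip_id_t> remote = {4, 5};  // hypothetical tunneled chips
        for (bool enable : {true, false}) {
            if (enable) {
                for (auto c : remote) set_routing(c, true);   // remote first on enable
                for (auto c : mmio) set_routing(c, true);     // mmio last
            } else {
                for (auto c : mmio) set_routing(c, false);    // mmio first on disable
                for (auto c : remote) set_routing(c, false);  // remote last
            }
        }
    }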
58 changes: 4 additions & 54 deletions tt_metal/tt_metal.cpp
@@ -326,61 +326,11 @@ std::map<chip_id_t, Device *> CreateDevices(
}

void CloseDevices(std::map<chip_id_t, Device *> devices) {
// Global Sync across all devices in the pool.
// We need to ensure that commands sent to each device have been completed
// before closing any device + modifying routing info.
// If this is not done, non-blocking CCLs followed by a close will hang, since
// the main thread will modify device state while the CCL is running on device.
for (const auto &[device_id, dev] : devices) {
dev->synchronize(); // Synchronize worker queue
Synchronize(dev); // Synchronize device
}
tt::Cluster::instance().set_internal_routing_info_for_ethernet_cores(false);
std::map<chip_id_t, v1::DeviceHandle> mmio_devices = {};
bool is_galaxy = tt::Cluster::instance().is_galaxy_cluster();

if (is_galaxy) {
//On Galaxy, gateway wormhole devices (mmio devices) are not included in the set of devices
//created by CreateDevices(). So when closing devices, we need to find the corresponding
//gateway chips for all the tunneled devices.
for (const auto &[device_id, dev] : devices) {
const auto &mmio_device_id = tt::Cluster::instance().get_associated_mmio_device(device_id);
if (mmio_devices.find(mmio_device_id) == mmio_devices.end()) {
auto dev_handle = tt::DevicePool::instance().get_active_device(mmio_device_id);
mmio_devices.insert({mmio_device_id, dev_handle});
}
}
} else {
for (const auto &[device_id, dev] : devices) {
if(dev->is_mmio_capable()) {
mmio_devices.insert({device_id, tt::DevicePool::instance().get_handle(dev)});
}
}
for (const auto &[device_id, dev] : mmio_devices) {
devices.erase(device_id);
}
}

for (const auto &[device_id, dev] : mmio_devices) {
//For each mmio device, first close all the remote tunneled devices.
//Close the farthest tunneled device first.
auto tunnels_from_mmio = dev->tunnels_from_mmio_;
//iterate over all tunnels originating from this mmio device
for (auto t : tunnels_from_mmio) {
//iterate over all tunneled devices (tunnel stops) in this tunnel and close them.
for (uint32_t ts = t.size() - 1; ts > 0; ts--) {
if (devices.find(t[ts]) != devices.end()) {
devices[t[ts]]->close();
// When a device is closed, its worker thread is joined. Stop tracking this
// worker thread.
tt::DevicePool::instance().unregister_worker_thread_for_device(tt::DevicePool::instance().get_handle(devices[t[ts]]));
}
}
}
//finally close the mmio device
dev->close();
tt::DevicePool::instance().unregister_worker_thread_for_device(dev);
std::vector<Device *> devices_to_close;
for (auto& [id, device] : devices) {
devices_to_close.push_back(device);
}
tt::DevicePool::instance().close_devices(devices_to_close);
}
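After this change the call site is unchanged for host programs; teardown might look like the following sketch, assuming tt_metal's host API as used in tt_metal.cpp above (the chip ids, the function wrapper, and CreateDevices taking only a list of ids are assumptions):

    #include "tt_metal/host_api.hpp"

    void run_and_teardown() {
        std::vector<chip_id_t> ids = {0, 1};  // hypothetical chip ids
        std::map<chip_id_t, Device *> devices = CreateDevices(ids);
        // ... build, enqueue, and run programs on each device ...
        CloseDevices(devices);  // syncs every device, then delegates to
                                // DevicePool::close_devices() for ordered teardown
    }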

bool InWorkerThread() {