Disable async mode for single device use cases (#18114)
### Ticket

### Problem description
We're moving towards always using a single mesh and removing the
concurrency code, including calls to `push_work` from ttnn. This PR
disables async mode for a single device, which should be the only path once we
handle multi-device use cases with a single mesh.

### What's changed
- Ignore calls to `enable_async()` for a single device, logging a warning
that the call is being ignored (sketched below).
- Add a mutex in `push_work` for sync mode, which provides the same call
serialization guarantee as the worker queue does in async mode.
- Use a direct function call in `run_operation.cpp` when the number of
workers is 1.
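
A minimal standalone sketch of the resulting single-device behavior (an illustrative mock, not the actual tt-metal API; `MockDevice` is a hypothetical stand-in for `tt::tt_metal::Device`):

```cpp
#include <functional>
#include <iostream>
#include <mutex>

// Illustrative mock of the behavior this PR introduces for a single device:
// enable_async(true) is ignored with a warning, and push_work executes the
// work inline, serialized by a mutex (same guarantee as a worker queue).
class MockDevice {
public:
    void enable_async(bool enable) {
        if (enable) {
            std::cerr << "warning: async mode is always disabled for a single device, "
                         "ignoring enable_async call\n";
        }
        // The worker mode stays synchronous either way.
    }

    void push_work(const std::function<void()>& work) {
        std::lock_guard<std::mutex> guard(mutex_);  // serialize concurrent callers
        work();                                     // run inline (sync mode)
    }

private:
    std::mutex mutex_;
};

int main() {
    MockDevice device;
    device.enable_async(true);  // logs a warning, has no effect
    device.push_work([] { std::cout << "ran inline\n"; });
    return 0;
}
```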

### Checklist
- [x] [All post commit CI passes](https://github.com/tenstorrent/tt-metal/actions/runs/13518698157)
- [x] [Model perf CI passes](https://github.com/tenstorrent/tt-metal/actions/runs/13514345756)
- [x] [T3K model perf CI passes](https://github.com/tenstorrent/tt-metal/actions/runs/13514349012)
- [x] New/Existing tests provide coverage for changes
sminakov-tt authored Feb 25, 2025
1 parent a32c401 commit 5db78f8
Showing 11 changed files with 41 additions and 52 deletions.
7 changes: 0 additions & 7 deletions tests/tt_eager/tensors/test_async_tensor_apis.cpp
@@ -187,11 +187,7 @@ TEST_F(DispatchFixture, TestAsyncRefCountManager) {
/*layout=*/std::nullopt,
*device);
uint32_t tensor2_device_buf_addr = get_device_buffer_address(tensor2);
// Assign tensor1 to tensor2 and ensure that ref counts are appropriately updated with the buffer for tensor2
// deallocated
tensor2 = tensor1;
EXPECT_EQ(tensor2.tensor_attributes->main_thread_ref_count, 2);
EXPECT_EQ(tensor1.tensor_attributes->main_thread_ref_count, 2);
// To check if tensor2 is deallocated, create a third tensor on device and ensure that its address matches the
// prev addr for tensor2
Tensor tensor3 = ttnn::full(
@@ -215,7 +211,6 @@ TEST_F(DispatchFixture, TestAsyncRefCountManager) {
// This step will copy the tensor to a temp rval and std::move it back to the caller's instance of device_tensor
// Ensure ref count and address remain unchanged
device_tensor = tensor_identity_copy_function(device_tensor);
EXPECT_EQ(device_tensor.tensor_attributes->main_thread_ref_count, 1);
EXPECT_EQ(get_device_buffer_address(device_tensor), device_tensor_address);
}

@@ -228,7 +223,6 @@ TEST_F(DispatchFixture, TestAsyncRefCountManager) {
/*layout=*/std::nullopt,
*device);
Tensor tensor2 = std::move(tensor1);
EXPECT_EQ(tensor2.tensor_attributes->main_thread_ref_count, 1);
}

log_info(LogTest, "Testing Device tensor self-assignment");
@@ -240,7 +234,6 @@ TEST_F(DispatchFixture, TestAsyncRefCountManager) {
*device);
uint32_t tensor_to_self_assign_address = get_device_buffer_address(tensor_to_self_assign);
tensor_to_self_assign = tensor_to_self_assign;
EXPECT_EQ(tensor_to_self_assign.tensor_attributes->main_thread_ref_count, 1);
tensor_to_self_assign = std::move(tensor_to_self_assign);
EXPECT_EQ(get_device_buffer_address(tensor_to_self_assign), tensor_to_self_assign_address);
auto barrier_tensor = tensor_to_self_assign.cpu();
4 changes: 0 additions & 4 deletions tests/ttnn/unit_tests/gtests/test_async_runtime.cpp
@@ -83,10 +83,6 @@ TEST_F(MultiCommandQueueSingleDeviceFixture, TestAsyncPreallocatedOutputs) {
EXPECT_EQ(ttnn::event_query(workload_event), true);
// Read output back, once workload is complete
ttnn::read_buffer(io_cq, output_tensor, {readback_data});
// Ensure that reference count book keeping is done correctly
// Tensors only have one reference in the main thread. Ensure this is true.
EXPECT_EQ(input_tensor.tensor_attributes->main_thread_ref_count, 1);
EXPECT_EQ(output_tensor.tensor_attributes->main_thread_ref_count, 1);
// Buffers are currently jointly owned by the original buffer object, the storage object and the tensor (3).
EXPECT_EQ(input_buffer.use_count(), 3);
EXPECT_EQ(output_buffer.use_count(), 3);
3 changes: 3 additions & 0 deletions tt_metal/api/tt-metalium/device_impl.hpp
@@ -155,7 +155,10 @@ class Device : public IDevice {
// Puts device into reset
bool close() override;

// Calls to enable_async are ignored in effort to forcefully disable async for single device use-cases
// MeshDevice calls force_enable_async directly avoiding enable_async call for multi-device use-case
void enable_async(bool enable) override;
void force_enable_async(bool enable);
void synchronize() override;
WorkExecutorMode get_worker_mode() override { return work_executor_.get_worker_mode(); }
bool is_worker_queue_empty() const override { return work_executor_.worker_queue.empty(); }
1 change: 1 addition & 0 deletions tt_metal/api/tt-metalium/mesh_device.hpp
@@ -65,6 +65,7 @@ class MeshDevice : public IDevice, public std::enable_shared_from_this<MeshDevic
std::unique_ptr<SubDeviceManagerTracker> sub_device_manager_tracker_;
std::unordered_map<MeshTraceId, std::shared_ptr<MeshTraceBuffer>> trace_buffer_pool_;
uint32_t trace_buffers_size_ = 0;
std::recursive_mutex push_work_mutex_;
// This is a reference device used to query properties that are the same for all devices in the mesh.
IDevice* reference_device() const;

3 changes: 3 additions & 0 deletions tt_metal/api/tt-metalium/work_executor.hpp
@@ -143,6 +143,8 @@ class WorkExecutor {
if (use_passthrough()) {
// Worker is pushing to itself (nested work) or worker thread is not running. Execute work in current
// thread.
// Using a lock to provide the same call serialization guarantee as with worker queue.
std::lock_guard guard(passthrough_mutex);
work_executor();
} else {
// Push to worker queue.
@@ -200,6 +202,7 @@ class WorkExecutor {
int managed_device_id;
std::condition_variable cv;
std::mutex cv_mutex;
std::recursive_mutex passthrough_mutex;

inline void start_worker() {
this->worker_queue.parent_thread_id = std::this_thread::get_id();
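
For context, the passthrough path can re-enter `push_work` from inside the work itself (nested work), which is why the new lock is a `std::recursive_mutex` rather than a plain `std::mutex`. A standalone sketch of that pattern (illustrative only, not the real `WorkExecutor`):

```cpp
#include <functional>
#include <iostream>
#include <mutex>

// Sketch: inline ("passthrough") execution serialized with a recursive mutex,
// so work that pushes more work from the same thread does not self-deadlock.
struct PassthroughExecutor {
    std::recursive_mutex passthrough_mutex;

    void push_work(const std::function<void()>& work) {
        std::lock_guard<std::recursive_mutex> guard(passthrough_mutex);
        work();  // a nested push_work on this thread re-acquires the mutex safely
    }
};

int main() {
    PassthroughExecutor executor;
    executor.push_work([&] {
        std::cout << "outer work\n";
        // With a plain std::mutex this nested call would deadlock.
        executor.push_work([] { std::cout << "nested work\n"; });
    });
    return 0;
}
```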
11 changes: 9 additions & 2 deletions tt_metal/distributed/mesh_device.cpp
@@ -357,8 +357,13 @@ std::vector<std::shared_ptr<MeshDevice>> MeshDevice::get_submeshes() const { ret
std::ostream& operator<<(std::ostream& os, const MeshDevice& mesh_device) { return os << mesh_device.to_string(); }

void MeshDevice::enable_async(bool enable) {
for (auto device : this->get_devices()) {
device->enable_async(enable);
auto devices = this->get_devices();
if (enable && devices.size() == 1) {
tt::log_warning("Async mode is always disabled for a single device, ignoring enable_async call");
return;
}
for (auto device : devices) {
dynamic_cast<Device*>(device)->force_enable_async(enable);
}
}

@@ -675,6 +680,8 @@ WorkExecutorMode MeshDevice::get_worker_mode() { return WorkExecutorMode::SYNCHR
bool MeshDevice::is_worker_queue_empty() const { return true; }
void MeshDevice::push_work(std::function<void()> work, bool blocking) {
// Execute inline synchronously.
// Using a lock to provide the same call serialization guarantee as an async single device scheduling.
std::lock_guard lock(push_work_mutex_);
work();
}
program_cache::detail::ProgramCache& MeshDevice::get_program_cache() { return reference_device()->get_program_cache(); }
8 changes: 8 additions & 0 deletions tt_metal/impl/device/device.cpp
@@ -1256,6 +1256,14 @@ void Device::set_worker_mode(const WorkExecutorMode& mode) {
}

void Device::enable_async(bool enable) {
if (enable) {
tt::log_warning("Async mode is always disabled for a single device, ignoring enable_async call");
} else {
force_enable_async(false);
}
}

void Device::force_enable_async(bool enable) {
auto mode = enable ? WorkExecutorMode::ASYNCHRONOUS : WorkExecutorMode::SYNCHRONOUS;
this->set_worker_mode(mode);
// If a worker thread is spawned for a device, register/track it in a runtime structure.
4 changes: 1 addition & 3 deletions ttnn/cpp/ttnn/decorators.hpp
@@ -332,7 +332,6 @@ struct registered_operation_t {
const OptionalTensors optional_output_tensors =
detail::extract_args_to_vector<std::optional<ttnn::Tensor>>(args...);

bool enable_autoformat = false;
tt::tt_metal::operation::launch_op(
[args...](
const Tensors& input_tensors,
@@ -350,8 +349,7 @@ struct registered_operation_t {
input_tensors,
output_tensors,
optional_input_tensors,
optional_output_tensors,
enable_autoformat);
optional_output_tensors);

if constexpr (std::is_same_v<std::decay_t<execute_on_worker_thread_return_t>, Tensor>) {
return output_tensors.at(0);
33 changes: 5 additions & 28 deletions ttnn/cpp/ttnn/run_operation.cpp
@@ -607,18 +607,18 @@ void launch_op_func(
const Tensors input_tensors,
OutputType& output_tensors,
const OptionalConstTensors optional_input_tensors,
const OptionalTensors optional_output_tensors,
bool enable_autoformat_device) {
const OptionalTensors optional_output_tensors) {
// Send host side op compile and run to the worker queue
// Assert to ensure that worker threads are specified.
ZoneScopedN("LaunchOp");
auto& workers = detail::get_workers(output_tensors);
std::size_t workers_size = workers.size();
if (not enable_autoformat_device and workers.empty() or tt::tt_metal::detail::InWorkerThread()) {
if (workers.size() <= 1 || tt::tt_metal::detail::InWorkerThread()) {
// Run in main thread or immediately in worker thread
output_tensors = op_func(input_tensors, optional_input_tensors, optional_output_tensors);
return;
}

detail::check_output(output_tensors, workers);
validate_worker_modes(workers);
// Record ref counts for all tensors before pushing to worker queue.
@@ -667,27 +667,6 @@ void launch_op_func(
// If so, mark them in use by current worker. Tensors shared across workers
// are only supported when each tensor is tied to a single device/worker
// (example all-gather).
if (workers_size == 1) {
// Single worker per tensor and.
for (int i = 0; i < async_safe_input_tensors.size(); i++) {
if (async_safe_input_tensors[i].get_workers().size() and
async_safe_input_tensors[i].get_workers()[0] != workers[0]) {
// This input has a worker assigned that doesn't match the worker of the output being created (its
// shared).
async_safe_input_tensors[i].tensor_attributes->num_sibling_workers_sharing_tensor++;
cross_worker_input_tensor_idx.insert(i);
}
}
for (int i = 0; i < async_safe_optional_input_tensors.size(); i++) {
if (async_safe_optional_input_tensors[i].has_value() and
async_safe_optional_input_tensors[i].value().get_workers().size() and
async_safe_optional_input_tensors[i].value().get_workers()[0] != workers[0]) {
async_safe_optional_input_tensors[i].value().tensor_attributes->num_sibling_workers_sharing_tensor++;
cross_worker_optional_input_tensor_idx.insert(i);
}
}
}

{
ZoneScopedN("PushOpToWorkers");
auto work_lambda = std::make_shared<std::function<void(IDevice*)>>(
@@ -810,14 +789,12 @@ template void launch_op_func<Tensors>(
const Tensors input_tensors,
Tensors& output_tensors,
const OptionalConstTensors optional_input_tensors,
const OptionalTensors optional_output_tensors,
bool enable_autoformat_device);
const OptionalTensors optional_output_tensors);
template void launch_op_func<OptionalTensors>(
const std::function<OptionalTensors(const Tensors&, const OptionalConstTensors&, const OptionalTensors&)>& op_func,
const Tensors input_tensors,
OptionalTensors& output_tensors,
const OptionalConstTensors optional_input_tensors,
const OptionalTensors optional_output_tensors,
bool enable_autoformat_device);
const OptionalTensors optional_output_tensors);

} // namespace tt::tt_metal::operation
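
The net effect of the `run_operation.cpp` change is a fast path: with at most one worker, or when already on a worker thread, the op is invoked as a direct function call instead of being captured and pushed to a worker queue. A reduced sketch of that dispatch decision (a hedged illustration; `launch`, `num_workers`, and `in_worker_thread` are hypothetical stand-ins for the real `launch_op_func` arguments):

```cpp
#include <cstddef>
#include <functional>
#include <iostream>

// Sketch of the dispatch decision after this PR: run the op directly when
// there is at most one worker or we are already on a worker thread;
// otherwise hand it off to the (multi-device) worker path.
void launch(
    const std::function<void()>& op_func,
    std::size_t num_workers,
    bool in_worker_thread) {
    if (num_workers <= 1 || in_worker_thread) {
        op_func();  // direct call: no capture, no queueing, no ref-count bookkeeping
        return;
    }
    // Multi-worker path (in the real code the op is wrapped in a lambda and
    // pushed to each worker via push_work); simplified here to a plain call.
    op_func();
}

int main() {
    launch([] { std::cout << "op ran inline\n"; }, /*num_workers=*/1, /*in_worker_thread=*/false);
    return 0;
}
```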
9 changes: 3 additions & 6 deletions ttnn/cpp/ttnn/run_operation.hpp
@@ -126,8 +126,7 @@ __attribute__((noinline)) void launch_op_func(
const Tensors input_tensors,
OutputType& output_tensors,
const OptionalConstTensors optional_input_tensors = {},
const OptionalTensors optional_output_tensors = {},
bool enable_autoformat_device = true);
const OptionalTensors optional_output_tensors = {});

/*
*/
@@ -137,16 +136,14 @@ void launch_op(
const Tensors input_tensors,
OutputType& output_tensors,
const OptionalConstTensors optional_input_tensors = {},
const OptionalTensors optional_output_tensors = {},
bool enable_autoformat_device = true) {
const OptionalTensors optional_output_tensors = {}) {
using FuncType = std::function<OutputType(const Tensors&, const OptionalConstTensors&, const OptionalTensors&)>;
launch_op_func(
FuncType(std::forward<F>(op_func)),
input_tensors,
output_tensors,
optional_input_tensors,
optional_output_tensors,
enable_autoformat_device);
optional_output_tensors);
}

void launch_with_autoformat(
10 changes: 8 additions & 2 deletions ttnn/cpp/ttnn/tensor/tensor_ops.cpp
@@ -146,7 +146,7 @@ Tensor tensor_to_layout(const Tensor& input_tensor, Layout target_layout, IDevic
ZoneScoped;
GraphTracker::instance().track_function_start("Tensor::to_layout", input_tensor, target_layout, worker);
// Only push layout conversion to worker if running in async mode
if (worker and worker->get_worker_mode() == WorkExecutorMode::ASYNCHRONOUS) {
if (worker && worker->get_worker_mode() == WorkExecutorMode::ASYNCHRONOUS) {
// Tensor can be using borrowed storage. If so, when running in async mode, copy this tensor to owned storage.
Tensor async_safe_tensor = copy_borrowed_tensor_in_async_mode(worker, input_tensor);
Tensor tensor_modified_layout = Tensor(1);
@@ -163,12 +163,18 @@ Tensor tensor_to_layout(const Tensor& input_tensor, Layout target_layout, IDevic
GraphTracker::instance().track_function_end(tensor_modified_layout);
return tensor_modified_layout;
}

// Running without worker threads (non-async)
TT_ASSERT(
input_tensor.storage_type() != StorageType::DEVICE or
input_tensor.storage_type() != StorageType::MULTI_DEVICE &&
"Bring tensor to host before converting to target layout");
auto output = tensor_impl::to_layout_wrapper(input_tensor, target_layout);
Tensor output;
if (worker) {
worker->push_work([&] { output = tensor_impl::to_layout_wrapper(input_tensor, target_layout); });
} else {
output = tensor_impl::to_layout_wrapper(input_tensor, target_layout);
}
output = tt::tt_metal::set_tensor_id(output);
GraphTracker::instance().track_function_end(output);
return output;
