Skip to content

Commit

Permalink
#0: Refactor enqueue_write_buffer
Browse files Browse the repository at this point in the history
  - Move logic for generating dispatch commands
    out of EnqueueWriteBufferCommand to allow reuse
    with MeshBuffer
  • Loading branch information
tt-asaigal committed Jan 19, 2025
1 parent 348e4f6 commit 0763212
Show file tree
Hide file tree
Showing 5 changed files with 520 additions and 505 deletions.
131 changes: 0 additions & 131 deletions tt_metal/api/tt-metalium/command_queue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,137 +168,6 @@ class EnqueueReadShardedBufferCommand : public EnqueueReadBufferCommand {
core(core) {}
};

class EnqueueWriteShardedBufferCommand;
class EnqueueWriteInterleavedBufferCommand;
class EnqueueWriteBufferCommand : public Command {
private:
SystemMemoryManager& manager;
CoreType dispatch_core_type;

virtual void add_dispatch_write(HugepageDeviceCommand& command) = 0;
virtual void add_buffer_data(HugepageDeviceCommand& command) = 0;

protected:
IDevice* device;
uint32_t command_queue_id;
NOC noc_index;
const void* src;
const Buffer& buffer;
tt::stl::Span<const uint32_t> expected_num_workers_completed;
tt::stl::Span<const SubDeviceId> sub_device_ids;
uint32_t bank_base_address;
uint32_t padded_page_size;
uint32_t dst_page_index;
uint32_t pages_to_write;
bool issue_wait;

public:
EnqueueWriteBufferCommand(
uint32_t command_queue_id,
IDevice* device,
NOC noc_index,
const Buffer& buffer,
const void* src,
SystemMemoryManager& manager,
bool issue_wait,
tt::stl::Span<const uint32_t> expected_num_workers_completed,
tt::stl::Span<const SubDeviceId> sub_device_ids,
uint32_t bank_base_address,
uint32_t padded_page_size,
uint32_t dst_page_index = 0,
std::optional<uint32_t> pages_to_write = std::nullopt);

void process();

EnqueueCommandType type() { return EnqueueCommandType::ENQUEUE_WRITE_BUFFER; }

constexpr bool has_side_effects() { return true; }
};

class EnqueueWriteInterleavedBufferCommand : public EnqueueWriteBufferCommand {
private:
void add_dispatch_write(HugepageDeviceCommand& command) override;
void add_buffer_data(HugepageDeviceCommand& command) override;

uint32_t initial_src_addr_offset;

public:
EnqueueWriteInterleavedBufferCommand(
uint32_t command_queue_id,
IDevice* device,
NOC noc_index,
const Buffer& buffer,
const void* src,
const uint32_t initial_src_addr_offset,
SystemMemoryManager& manager,
bool issue_wait,
tt::stl::Span<const uint32_t> expected_num_workers_completed,
tt::stl::Span<const SubDeviceId> sub_device_ids,
uint32_t bank_base_address,
uint32_t padded_page_size,
uint32_t dst_page_index = 0,
std::optional<uint32_t> pages_to_write = std::nullopt) :
EnqueueWriteBufferCommand(
command_queue_id,
device,
noc_index,
buffer,
src,
manager,
issue_wait,
expected_num_workers_completed,
sub_device_ids,
bank_base_address,
padded_page_size,
dst_page_index,
pages_to_write) {
this->initial_src_addr_offset = initial_src_addr_offset;
}
};

class EnqueueWriteShardedBufferCommand : public EnqueueWriteBufferCommand {
private:
void add_dispatch_write(HugepageDeviceCommand& command) override;
void add_buffer_data(HugepageDeviceCommand& command) override;

const std::shared_ptr<const BufferPageMapping>& buffer_page_mapping;
const CoreCoord core;

public:
EnqueueWriteShardedBufferCommand(
uint32_t command_queue_id,
IDevice* device,
NOC noc_index,
const Buffer& buffer,
const void* src,
SystemMemoryManager& manager,
bool issue_wait,
tt::stl::Span<const uint32_t> expected_num_workers_completed,
tt::stl::Span<const SubDeviceId> sub_device_ids,
uint32_t bank_base_address,
const std::shared_ptr<const BufferPageMapping>& buffer_page_mapping,
const CoreCoord& core,
uint32_t padded_page_size,
uint32_t dst_page_index = 0,
std::optional<uint32_t> pages_to_write = std::nullopt) :
EnqueueWriteBufferCommand(
command_queue_id,
device,
noc_index,
buffer,
src,
manager,
issue_wait,
expected_num_workers_completed,
sub_device_ids,
bank_base_address,
padded_page_size,
dst_page_index,
pages_to_write),
buffer_page_mapping(buffer_page_mapping),
core(core) {}
};

class EnqueueProgramCommand : public Command {
private:
uint32_t command_queue_id;
Expand Down
1 change: 1 addition & 0 deletions tt_metal/impl/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ set(IMPL_SRC
${CMAKE_CURRENT_SOURCE_DIR}/device/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device/device_pool.cpp
${CMAKE_CURRENT_SOURCE_DIR}/buffers/buffer.cpp
${CMAKE_CURRENT_SOURCE_DIR}/buffers/dispatch.cpp
${CMAKE_CURRENT_SOURCE_DIR}/buffers/circular_buffer.cpp
${CMAKE_CURRENT_SOURCE_DIR}/buffers/circular_buffer_types.cpp
${CMAKE_CURRENT_SOURCE_DIR}/buffers/global_circular_buffer.cpp
Expand Down
Loading

0 comments on commit 0763212

Please sign in to comment.