Skip to content

Commit

Permalink
#11208: Make v0::Program opaque(ish)
Browse files Browse the repository at this point in the history
  • Loading branch information
patrickroberts committed Oct 23, 2024
1 parent 5205204 commit c96029f
Show file tree
Hide file tree
Showing 7 changed files with 497 additions and 243 deletions.
24 changes: 12 additions & 12 deletions tt_metal/detail/reports/compilation_reporter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,13 @@ std::string kernel_attributes_str(std::shared_ptr<Kernel> kernel) {
return attr_str;
}

void CompilationReporter::add_kernel_compile_stats(const Program &program, std::shared_ptr<Kernel> kernel, bool cache_hit, size_t kernel_hash) {
void CompilationReporter::add_kernel_compile_stats(uint64_t program_id, std::shared_ptr<Kernel> kernel, bool cache_hit, size_t kernel_hash) {
std::unique_lock<std::mutex> lock(mutex_);

if (cache_hit) {
this->program_id_to_cache_hit_counter_[program.get_id()].hits++;
this->program_id_to_cache_hit_counter_[program_id].hits++;
} else {
this->program_id_to_cache_hit_counter_[program.get_id()].misses++;
this->program_id_to_cache_hit_counter_[program_id].misses++;
}
std::string kernel_stats = "," + kernel->name() + ",";
std::string cache_status = cache_hit ? "cache hit" : "cache miss";
Expand All @@ -99,22 +99,22 @@ void CompilationReporter::add_kernel_compile_stats(const Program &program, std::
}
index++;
}
this->program_id_to_kernel_stats_[program.get_id()].push_back(kernel_stats);
this->program_id_to_kernel_stats_[program_id].push_back(kernel_stats);
}

void CompilationReporter::flush_program_entry(const Program &program, bool persistent_compilation_cache_enabled) {
void CompilationReporter::flush_program_entry(uint64_t program_id, size_t num_kernels, std::function<std::shared_ptr<Kernel>(size_t)> get_kernel, bool persistent_compilation_cache_enabled) {
std::unique_lock<std::mutex> lock(mutex_);
auto num_cache_misses = this->program_id_to_cache_hit_counter_.at(program.get_id()).misses;
auto num_cache_hits = this->program_id_to_cache_hit_counter_.at(program.get_id()).hits;
auto num_cache_misses = this->program_id_to_cache_hit_counter_.at(program_id).misses;
auto num_cache_hits = this->program_id_to_cache_hit_counter_.at(program_id).hits;
if (this->total_num_compile_programs_ == 0) {
this->init_reports();
}

auto get_num_compute_and_data_movement_kernels = [&]() {
uint32_t num_compute = 0;
uint32_t num_data_movement = 0;
for (size_t kernel_id = 0; kernel_id < program.num_kernels(); kernel_id++) {
const auto kernel = detail::GetKernel(program, kernel_id);
for (size_t kernel_id = 0; kernel_id < num_kernels; kernel_id++) {
const auto kernel = get_kernel(kernel_id);
if (kernel->processor() == tt::RISCV::BRISC or kernel->processor() == tt::RISCV::NCRISC) {
num_data_movement++;
} else {
Expand All @@ -126,14 +126,14 @@ void CompilationReporter::flush_program_entry(const Program &program, bool persi

auto [num_compute_kernels, num_data_movement_kernels] = get_num_compute_and_data_movement_kernels();

this->summary_report_ << program.get_id() << ", "
this->summary_report_ << program_id << ", "
<< num_compute_kernels << ", "
<< num_data_movement_kernels << ", "
<< (persistent_compilation_cache_enabled ? "Y" : "N") << ", "
<< num_cache_misses << ", "
<< num_cache_hits << "\n";

this->detailed_report_ << "Compiling Program: " << program.get_id() << "\n";
this->detailed_report_ << "Compiling Program: " << program_id << "\n";
this->detailed_report_ << "\n,Kernel Creation Report:\n";
this->detailed_report_ << ",,Number of Compute CreateKernel API calls: " << num_compute_kernels << "\n";
this->detailed_report_ << ",,Number of Datamovement CreateKernel API calls: " << num_data_movement_kernels << "\n";
Expand All @@ -144,7 +144,7 @@ void CompilationReporter::flush_program_entry(const Program &program, bool persi
this->detailed_report_ << ",,Total number of kernel compile cache hits: " << num_cache_hits << "\n";

this->detailed_report_ << "\n,Kernel File Name, Core Range, Cache Hit, Kernel Attributes, Hash\n";
auto kernel_stats_vec = this->program_id_to_kernel_stats_.at(program.get_id());
auto kernel_stats_vec = this->program_id_to_kernel_stats_.at(program_id);
for (const auto &kernel_stats : kernel_stats_vec) {
this->detailed_report_ << kernel_stats;
}
Expand Down
4 changes: 2 additions & 2 deletions tt_metal/detail/reports/compilation_reporter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ class CompilationReporter {
CompilationReporter(const CompilationReporter&) = delete;
CompilationReporter(CompilationReporter&& other) noexcept = delete;

// Records the outcome of compiling one kernel of a program: increments the
// program's cache-hit or cache-miss counter depending on `cache_hit`, and
// appends a stats entry (starting with the kernel's name) to that program's
// list of kernel stats. The definition holds `mutex_` for the whole call.
// In this diff view the first declaration is the pre-change form taking the
// Program itself; the second is its replacement, keyed by the program's id.
void add_kernel_compile_stats(const Program &program, std::shared_ptr<Kernel> kernel, bool cache_hit, size_t kernel_hash);
void add_kernel_compile_stats(uint64_t program_id, std::shared_ptr<Kernel> kernel, bool cache_hit, size_t kernel_hash);

// Writes one program's compilation results to the summary and detailed reports.
// The summary row holds: program id, number of compute kernels, number of
// data-movement kernels, "Y"/"N" for `persistent_compilation_cache_enabled`,
// cache misses, and cache hits. The detailed report then lists every stats
// entry previously stored for the program.
// In the id-based form, `get_kernel` is invoked with each index in
// [0, num_kernels) to classify kernels: BRISC/NCRISC processors count as data
// movement, anything else as compute.
// NOTE(review): the definition reads the counters and stats with `.at(...)`,
// so it presumably requires add_kernel_compile_stats to have been called for
// this program id first -- confirm against callers.
void flush_program_entry(const Program &program, bool persistent_compilation_cache_enabled);
void flush_program_entry(uint64_t program_id, size_t num_kernels, std::function<std::shared_ptr<Kernel>(size_t)> get_kernel, bool persistent_compilation_cache_enabled);
static CompilationReporter& inst();
static void toggle (bool state);
static bool enabled ();
Expand Down
8 changes: 4 additions & 4 deletions tt_metal/detail/reports/memory_reporter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,14 @@ void populate_reports(const Device *device, std::ofstream &memory_usage_summary_
write_memory_usage(device, BufferType::L1, memory_usage_summary_report, detailed_memory_usage_report, l1_usage_summary_report);
}

// Starts a new entry for a program in each of the three per-program memory
// reports (summary, L1 summary, detailed) by writing the program's id, then
// delegates to populate_reports() to write the memory usage of `device`.
// This is a rendered diff: each `program.get_id()` line is the pre-change
// version of the `program_id` line that follows it, and likewise for the two
// signature lines.
void MemoryReporter::flush_program_memory_usage(const Program &program, const Device *device) {
void MemoryReporter::flush_program_memory_usage(uint64_t program_id, const Device *device) {
// Reports are initialized on demand: only when the summary stream is not open yet.
if (not this->program_memory_usage_summary_report_.is_open()) {
this->init_reports();
}

// Each report gets the program id first; the usage data is written after it.
this->program_memory_usage_summary_report_ << program.get_id();
this->program_l1_usage_summary_report_ << program.get_id();
this->program_detailed_memory_usage_report_ << program.get_id();
this->program_memory_usage_summary_report_ << program_id;
this->program_l1_usage_summary_report_ << program_id;
this->program_detailed_memory_usage_report_ << program_id;

populate_reports(device, this->program_memory_usage_summary_report_, this->program_detailed_memory_usage_report_, this->program_l1_usage_summary_report_);
}
Expand Down
2 changes: 1 addition & 1 deletion tt_metal/detail/reports/memory_reporter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class MemoryReporter {
MemoryReporter(const MemoryReporter&) = delete;
MemoryReporter(MemoryReporter&& other) noexcept = delete;

// Writes the given program's id followed by the memory usage of `device` to
// the per-program summary, L1 summary, and detailed memory reports,
// initializing the reports first if the summary report is not open.
// The first declaration is the pre-change form taking the Program; the second
// is its replacement, which only needs the program's id.
void flush_program_memory_usage(const Program &program, const Device *device);
void flush_program_memory_usage(uint64_t program_id, const Device *device);

void dump_memory_usage_state(const Device *device, std::string prefix="") const;

Expand Down
60 changes: 30 additions & 30 deletions tt_metal/impl/dispatch/command_queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -730,16 +730,17 @@ void EnqueueProgramCommand::assemble_device_commands(ProgramCommandSequence& pro
const uint32_t max_prefetch_command_size =
dispatch_constants::get(dispatch_core_type).max_prefetch_command_size();

auto &program_transfer_info = program.get_program_transfer_info();
// Multicast Semaphore Cmd
uint32_t num_multicast_semaphores = program.program_transfer_info.multicast_semaphores.size();
uint32_t num_multicast_semaphores = program_transfer_info.multicast_semaphores.size();
std::vector<std::vector<CQDispatchWritePackedMulticastSubCmd>> multicast_sem_sub_cmds(num_multicast_semaphores);
std::vector<std::vector<std::pair<const void*, uint32_t>>> multicast_sem_data(num_multicast_semaphores);
std::vector<std::vector<std::pair<uint32_t, uint32_t>>> multicast_sem_payload(num_multicast_semaphores);
std::vector<std::pair<uint32_t, uint32_t>> multicast_sem_dst_size;
multicast_sem_dst_size.reserve(num_multicast_semaphores);
if (num_multicast_semaphores > 0) {
uint32_t i = 0;
for (const auto& [dst, transfer_info_vec] : program.program_transfer_info.multicast_semaphores) {
for (const auto& [dst, transfer_info_vec] : program_transfer_info.multicast_semaphores) {
// TODO: loop over things inside transfer_info[i]
uint32_t write_packed_len = transfer_info_vec[0].data.size();
multicast_sem_dst_size.emplace_back(std::make_pair(dst, write_packed_len * sizeof(uint32_t)));
Expand Down Expand Up @@ -768,15 +769,15 @@ void EnqueueProgramCommand::assemble_device_commands(ProgramCommandSequence& pro
}

// Unicast Semaphore Cmd
uint32_t num_unicast_semaphores = program.program_transfer_info.unicast_semaphores.size();
uint32_t num_unicast_semaphores = program_transfer_info.unicast_semaphores.size();
std::vector<std::vector<CQDispatchWritePackedUnicastSubCmd>> unicast_sem_sub_cmds(num_unicast_semaphores);
std::vector<std::vector<std::pair<const void*, uint32_t>>> unicast_sem_data(num_unicast_semaphores);
std::vector<std::vector<std::pair<uint32_t, uint32_t>>> unicast_sem_payload(num_unicast_semaphores);
std::vector<std::pair<uint32_t, uint32_t>> unicast_sem_dst_size;
unicast_sem_dst_size.reserve(num_unicast_semaphores);
if (num_unicast_semaphores > 0) {
uint32_t i = 0;
for (const auto& [dst, transfer_info_vec] : program.program_transfer_info.unicast_semaphores) {
for (const auto& [dst, transfer_info_vec] : program_transfer_info.unicast_semaphores) {
// TODO: loop over things inside transfer_info[i]
uint32_t write_packed_len = transfer_info_vec[0].data.size();
unicast_sem_dst_size.emplace_back(std::make_pair(dst, write_packed_len * sizeof(uint32_t)));
Expand Down Expand Up @@ -876,7 +877,8 @@ void EnqueueProgramCommand::assemble_device_commands(ProgramCommandSequence& pro
const uint32_t max_length_per_sub_cmd = dispatch_constants::get(this->dispatch_core_type).scratch_db_size() / 2;
const uint32_t max_paged_length_per_sub_cmd =
max_length_per_sub_cmd / HostMemDeviceCommand::PROGRAM_PAGE_SIZE * HostMemDeviceCommand::PROGRAM_PAGE_SIZE;
for (const auto& [cores, num_mcast_dests, kg_transfer_info] : program.program_transfer_info.kernel_bins) {
const auto &kernels_buffer = program.get_kernels_buffer();
for (const auto& [cores, num_mcast_dests, kg_transfer_info] : program_transfer_info.kernel_bins) {
bool write_linear;
uint32_t noc_encoding;
std::visit(
Expand Down Expand Up @@ -913,26 +915,26 @@ void EnqueueProgramCommand::assemble_device_commands(ProgramCommandSequence& pro

uint32_t base_address, page_offset;
if (kg_transfer_info.page_offsets[kernel_idx] > CQ_PREFETCH_RELAY_PAGED_START_PAGE_MASK) {
const uint32_t num_banks = this->device->num_banks(this->program.kernels_buffer->buffer_type());
const uint32_t num_banks = this->device->num_banks(kernels_buffer->buffer_type());
page_offset = kg_transfer_info.page_offsets[kernel_idx] % num_banks;
uint32_t num_full_pages_written_per_bank =
kg_transfer_info.page_offsets[kernel_idx] / num_banks;
base_address = this->program.kernels_buffer->address() +
num_full_pages_written_per_bank * this->program.kernels_buffer->page_size();
base_address = kernels_buffer->address() +
num_full_pages_written_per_bank * kernels_buffer->page_size();
} else {
base_address = this->program.kernels_buffer->address();
base_address = kernels_buffer->address();
page_offset = kg_transfer_info.page_offsets[kernel_idx];
}

kernel_bins_unicast_cmds.back().add_prefetch_relay_paged(
true, // is_dram
page_offset,
base_address,
this->program.kernels_buffer->page_size(),
relayed_bytes / this->program.kernels_buffer->page_size(),
kernels_buffer->page_size(),
relayed_bytes / kernels_buffer->page_size(),
length_adjust);
} else {
uint32_t base_address = this->program.kernels_buffer->address();
uint32_t base_address = kernels_buffer->address();
uint32_t page_offset = kg_transfer_info.page_offsets[kernel_idx];

// TODO: pack all these writes into 1 linear write
Expand Down Expand Up @@ -1070,7 +1072,7 @@ void EnqueueProgramCommand::assemble_device_commands(ProgramCommandSequence& pro
}
// if dispatch_s is enabled have dispatch_d send a semaphore update to dispatch_s (this will include a write barrier on dispatch_d if program is active)
// if not, check if the program is active on workers. If active, have dispatch_d issue a write barrier
cmd_sequence_sizeB += (this->device->dispatch_s_enabled() || program.program_transfer_info.num_active_cores > 0) * CQ_PREFETCH_CMD_BARE_MIN_SIZE;
cmd_sequence_sizeB += (this->device->dispatch_s_enabled() || program_transfer_info.num_active_cores > 0) * CQ_PREFETCH_CMD_BARE_MIN_SIZE;

// either dispatch_s or dispatch_d will send the go signal (go_signal_mcast command)
cmd_sequence_sizeB += CQ_PREFETCH_CMD_BARE_MIN_SIZE;
Expand Down Expand Up @@ -1257,11 +1259,11 @@ void EnqueueProgramCommand::assemble_device_commands(ProgramCommandSequence& pro
DispatcherSelect dispatcher_for_go_signal = DispatcherSelect::DISPATCH_MASTER;
if (this->device->dispatch_s_enabled()) {
// dispatch_d signals dispatch_s to send the go signal, use a barrier if there are cores active
device_command_sequence.add_notify_dispatch_s_go_signal_cmd(program.program_transfer_info.num_active_cores > 0);
device_command_sequence.add_notify_dispatch_s_go_signal_cmd(program_transfer_info.num_active_cores > 0);
dispatcher_for_go_signal = DispatcherSelect::DISPATCH_SLAVE;
} else {
// Wait Noc Write Barrier, wait for binaries/configs and launch_msg to be written to worker cores
if (program.program_transfer_info.num_active_cores > 0) {
if (program_transfer_info.num_active_cores > 0) {
device_command_sequence.add_dispatch_wait(true, this->dispatch_message_addr, 0, 0, false, false);
}
}
Expand Down Expand Up @@ -1464,7 +1466,7 @@ void EnqueueProgramCommand::write_program_command_sequence(const ProgramCommandS
void EnqueueProgramCommand::process() {

const std::pair<ConfigBufferSync, std::vector<ConfigBufferEntry>&> reservation =
this->manager.get_config_buffer_mgr().reserve(program.program_config_sizes_);
this->manager.get_config_buffer_mgr().reserve(program.get_program_config_sizes());
bool stall_first = reservation.first.need_sync;
// Note: since present implementation always stalls, we always free up to "now"
this->manager.get_config_buffer_mgr().free(reservation.first.sync_count);
Expand All @@ -1486,8 +1488,8 @@ void EnqueueProgramCommand::process() {
// If cache has a program entry but the program is not finalized, then the cache is stale
// Currently this is mapped by device, but will be mapped by multiple values in the future
uint64_t command_hash = this->device->id();
auto cached_cmd_iter = this->program.cached_program_command_sequences_.find(command_hash);
bool is_cached = program.is_cached() && cached_cmd_iter != this->program.cached_program_command_sequences_.end();
auto cached_cmd_iter = this->program.get_cached_program_command_sequences().find(command_hash);
bool is_cached = program.is_cached() && cached_cmd_iter != this->program.get_cached_program_command_sequences().end();

// Calculate all commands size and determine how many fetch q entries to use
// Preamble, some waits and stalls
Expand All @@ -1507,7 +1509,7 @@ void EnqueueProgramCommand::process() {
this->assemble_device_commands(program_command_sequence, kernel_config_addrs);
this->write_program_command_sequence(program_command_sequence, stall_first);
this->assemble_stall_commands(program_command_sequence, false);
this->program.cached_program_command_sequences_.insert({command_hash, std::move(program_command_sequence)});
this->program.get_cached_program_command_sequences().insert({command_hash, std::move(program_command_sequence)});
program.set_cached();
} else {
static constexpr uint32_t wait_count_offset = (sizeof(CQPrefetchCmd) + offsetof(CQDispatchCmd, wait.count));
Expand Down Expand Up @@ -2233,21 +2235,20 @@ void HWCommandQueue::enqueue_program(Program& program, bool blocking) {
if (not program.is_finalized()) {
program.finalize(device);
TT_FATAL(!this->manager.get_bypass_mode(), "Tracing should only be used when programs have been cached");
if (program.kernels_buffer != nullptr) {
if (const auto &kernels_buffer = program.get_kernels_buffer()) {
this->enqueue_write_buffer(
*program.kernels_buffer, program.program_transfer_info.binary_data.data(), false);
*kernels_buffer, program.get_program_transfer_info().binary_data.data(), false);
}
}

#ifdef DEBUG
if (tt::llrt::OptionsG.get_validate_kernel_binaries()) {
TT_FATAL(!this->manager.get_bypass_mode(), "Tracing cannot be used while validating program binaries");
if (program.kernels_buffer != nullptr) {
const auto& buffer = program.kernels_buffer;
if (const auto &buffer = program.get_kernels_buffer()) {
std::vector<uint32_t> read_data(buffer->page_size() * buffer->num_pages() / sizeof(uint32_t));
this->enqueue_read_buffer(*program.kernels_buffer, read_data.data(), true);
this->enqueue_read_buffer(*buffer, read_data.data(), true);
TT_FATAL(
program.program_transfer_info.binary_data == read_data,
program.get_program_transfer_info().binary_data == read_data,
"Binary for program to be executed is corrupted. Another program likely corrupted this binary");
}
}
Expand Down Expand Up @@ -2297,12 +2298,11 @@ void HWCommandQueue::enqueue_program(Program& program, bool blocking) {
#ifdef DEBUG
if (tt::llrt::OptionsG.get_validate_kernel_binaries()) {
TT_FATAL(!this->manager.get_bypass_mode(), "Tracing cannot be used while validating program binaries");
if (program.kernels_buffer != nullptr) {
const auto& buffer = program.kernels_buffer;
if (const auto& buffer = program.get_kernels_buffer()) {
std::vector<uint32_t> read_data(buffer->page_size() * buffer->num_pages() / sizeof(uint32_t));
this->enqueue_read_buffer(*program.kernels_buffer, read_data.data(), true);
this->enqueue_read_buffer(*buffer, read_data.data(), true);
TT_FATAL(
program.program_transfer_info.binary_data == read_data,
program.get_program_transfer_info().binary_data == read_data,
"Binary for program that executed is corrupted. This program likely corrupted its own binary.");
}
}
Expand All @@ -2311,7 +2311,7 @@ void HWCommandQueue::enqueue_program(Program& program, bool blocking) {
log_trace(
tt::LogMetal,
"Created EnqueueProgramCommand (active_cores: {} bypass_mode: {} expected_workers_completed: {})",
program.program_transfer_info.num_active_cores,
program.get_program_transfer_info().num_active_cores,
this->manager.get_bypass_mode(),
expected_workers_completed);
}
Expand Down
Loading

0 comments on commit c96029f

Please sign in to comment.