Skip to content

Commit

Permalink
pass impl
Browse files Browse the repository at this point in the history
  • Loading branch information
aobolensk committed Jan 29, 2025
1 parent 0b72893 commit 20f93d4
Showing 1 changed file with 37 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "cpu_shape.h"
#include "snippets/lowered/loop_manager.hpp"
#include "snippets/lowered/loop_info.hpp"
#include "snippets/lowered/port_descriptor.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
Expand All @@ -16,6 +17,7 @@
#include "snippets/op/buffer.hpp"
#include "snippets/utils/utils.hpp"
#include "transformations/snippets/x64/op/brgemm_copy_b.hpp"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/snippets/x64/op/gemm_cpu.hpp"
#include "transformations/snippets/x64/op/brgemm_utils.hpp"
#include "transformations/tpp/x64/op/modifiers.hpp"
Expand All @@ -26,31 +28,51 @@ namespace intel_cpu {

bool pass::BuildBrgemm::run(const snippets::lowered::LinearIR& linear_ir) {
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::AdjustBrgemmCopyBLoopPorts")
fprintf(stderr, "BuildBrgemm::run\n");
bool modified = false;
for (const auto& expr : linear_ir) {
fprintf(stderr, "expr: %s\n", expr->get_node()->get_friendly_name().c_str());
const auto brgemm_node = ov::as_type_ptr<GemmCPU>(expr->get_node());
if (!brgemm_node || brgemm_node->is_dynamic()) {
const auto gemm_node = ov::as_type_ptr<GemmCPU>(expr->get_node());
if (!gemm_node || gemm_node->is_dynamic()) {
continue;
}
const auto& loop_manager = linear_ir.get_loop_manager();
OPENVINO_ASSERT(loop_manager, "GemmCPU node should have a loop manager.");

const auto loop_ids = expr->get_loop_ids();
if (!loop_ids.empty()) {
// Get innermost loop info
// auto loop_expr = loop_manager->get_loop_bounds(linear_ir, loop_ids.back()).first;
// fprintf(stderr, "Loop bounds: %s\n", loop_expr->get()->get_node()->get_friendly_name().c_str());
// const auto& inner_loop_info = loop_manager->get_loop_info<snippets::lowered::UnifiedLoopInfo>(loop_ids.front());
// fprintf(stderr, "work_amount: %ld\n", inner_loop_info->get_work_amount());
// fprintf(stderr, "increment: %ld\n", inner_loop_info->get_increment());
// auto iter_count = inner_loop_info->get_work_amount() / inner_loop_info->get_increment();
// fprintf(stderr, "iter_count: %ld\n", iter_count);
// const auto& inputs = expr->get_node()->inputs();
// fprintf(stderr, "Number of inputs: %lu\n", inputs.size());
if (loop_ids.empty()) {
continue;
}

const auto& gemm_in0_desc = snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(gemm_node->input(0));
const auto& gemm_in1_desc = snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(gemm_node->input(1));
const auto& gemm_out_desc = snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(gemm_node->output(0));

// Get innermost loop info
auto loop_expr = loop_manager->get_loop_bounds(linear_ir, loop_ids.back()).first;
const auto& inner_loop_info = loop_manager->get_loop_info<snippets::lowered::UnifiedLoopInfo>(loop_ids.front());
auto iter_count = inner_loop_info->get_work_amount() / inner_loop_info->get_increment();
auto brgemm_node = std::make_shared<BrgemmCPU>(gemm_node->input_value(0),
gemm_node->input_value(1),
iter_count,
gemm_node->get_type(),
gemm_node->get_offset_a(),
gemm_node->get_offset_b(),
gemm_node->get_offset_c(),
gemm_in0_desc->get_layout(),
gemm_in1_desc->get_layout(),
gemm_out_desc->get_layout());
// TODO: replace node

// Transfer ports
snippets::lowered::PortDescriptorUtils::set_port_descriptor(gemm_node->input(0), gemm_in0_desc->get_subtensor(), gemm_in0_desc->get_layout());
snippets::lowered::PortDescriptorUtils::set_port_descriptor(gemm_node->input(1), gemm_in1_desc->get_subtensor(), gemm_in1_desc->get_layout());
snippets::lowered::PortDescriptorUtils::set_port_descriptor(gemm_node->output(0), gemm_out_desc->get_subtensor(), gemm_out_desc->get_layout());

// need to run validate_and_infer_types manually: either input shapes were updated or
// output Layout was updated (out shape will be updated in validate_and_infer_types())
gemm_node->validate_and_infer_types();
brgemm_node->validate_and_infer_types();

const auto& inputs = expr->get_node()->inputs();
}

return modified;
Expand Down

0 comments on commit 20f93d4

Please sign in to comment.