diff --git a/src/common/snippets/include/snippets/lowered/loop_manager.hpp b/src/common/snippets/include/snippets/lowered/loop_manager.hpp index 387f3634cdbda1..d2532c7668c30d 100644 --- a/src/common/snippets/include/snippets/lowered/loop_manager.hpp +++ b/src/common/snippets/include/snippets/lowered/loop_manager.hpp @@ -180,7 +180,9 @@ class LoopManager { template::value || std::is_same::value), bool>::type> void replace_loop_port(size_t loop_id, const T& actual_port, const std::vector& target_ports) { const auto& loop_info = get_loop_info(loop_id); + fprintf(stderr, "replace_loop_port 1: %zu\n", loop_info->get_input_ports().size()); loop_info->replace_with_new_ports(actual_port, target_ports); + fprintf(stderr, "replace_loop_port 2: %zu\n", loop_info->get_input_ports().size()); } /** * @brief Replace Loop ports for several Unified Loops with new ports. diff --git a/src/common/snippets/src/lowered/linear_ir.cpp b/src/common/snippets/src/lowered/linear_ir.cpp index cff3bcbe927d04..7041fbe0ef22de 100644 --- a/src/common/snippets/src/lowered/linear_ir.cpp +++ b/src/common/snippets/src/lowered/linear_ir.cpp @@ -367,6 +367,7 @@ LinearIR::exprIt LinearIR::replace_with_node(const std::vector& o "Failed to replace node: node output port count is not equal to output count of last old expression"); std::vector new_inputs(new_node->get_input_size()); + fprintf(stderr, "new_node->get_input_size(): %zu\n", new_node->get_input_size()); for (size_t i = 0; i < new_node->get_input_size(); ++i) { const auto& source = new_node->get_input_source_output(i); new_inputs[i] = get_expr_by_node(source.get_node_shared_ptr())->get_output_port_connector(source.get_index()); @@ -395,6 +396,7 @@ LinearIR::exprIt LinearIR::replace_with_expr(const std::vector& o "Failed to replace expressions: new expr output port count is not equal to output count of last old expression"); const auto& new_inputs = new_expr->get_input_port_connectors(); + fprintf(stderr, "new_expr->get_inputs_count() 2s: %zu\n", new_expr->get_input_count()); auto is_old_expr = [&old_exprs](const ExpressionPtr& expr) { return std::find(old_exprs.cbegin(), old_exprs.cend(), expr) != old_exprs.cend(); @@ -424,20 +426,36 @@ LinearIR::exprIt LinearIR::replace_with_expr(const std::vector& o } } } + fprintf(stderr, "new_expr->get_inputs_count() 3s: %zu\n", new_expr->get_input_count()); update_consumers_and_regs(new_expr, consumers); + fprintf(stderr, "new_expr->get_inputs_count() 4s: %zu\n", new_expr->get_input_count()); const auto new_expr_it = insert(place, new_expr); const auto& loop_ids = new_expr_it->get()->get_loop_ids(); const auto input_ports = new_expr_it->get()->get_input_ports(); + fprintf(stderr, "b input_ports: %zu\n", input_ports.size()); const auto output_ports = new_expr_it->get()->get_output_ports(); + fprintf(stderr, "b output_ports: %zu\n", output_ports.size()); + const auto& inner_loop_info = m_loop_manager->get_loop_info(loop_ids.front()); + fprintf(stderr, "=== loop id %zu (inputs count: %zu):\n", loop_ids.front(), inner_loop_info->get_input_ports_info().size()); for (const auto& old_expr : old_exprs) { - for (size_t i = 0; i < old_expr->get_input_count(); ++i) - m_loop_manager->replace_loop_ports(loop_ids, old_expr->get_input_port(i), input_ports); - for (size_t i = 0; i < old_expr->get_input_count(); ++i) - m_loop_manager->replace_loop_ports(loop_ids, old_expr->get_output_port(i), output_ports); + fprintf(stderr, "=1a loop id %zu (inputs count: %zu %zu):\n", loop_ids.front(), inner_loop_info->get_input_ports_info().size(), inner_loop_info->get_input_count()); + for (size_t i = 0; i < old_expr->get_input_count(); ++i) { + printf(" ty: %zu\n", old_expr->get_input_port(i).get_type() == ExpressionPort::Input); + m_loop_manager->replace_loop_ports(loop_ids, old_expr->get_input_port(i), {new_expr_it->get()->get_input_port(i)}); + } + fprintf(stderr, "=1b loop id %zu (inputs count: %zu %zu):\n", loop_ids.front(), inner_loop_info->get_input_ports_info().size(), inner_loop_info->get_input_count()); + for (size_t i = 0; i < old_expr->get_output_count(); ++i) + m_loop_manager->replace_loop_ports(loop_ids, old_expr->get_output_port(i), {new_expr_it->get()->get_output_port(i)}); + fprintf(stderr, "=1c loop id %zu (inputs count: %zu):\n", loop_ids.front(), inner_loop_info->get_input_ports_info().size()); erase(find(old_expr)); + fprintf(stderr, "=1d loop id %zu (inputs count: %zu):\n", loop_ids.front(), inner_loop_info->get_input_ports_info().size()); } + fprintf(stderr, "==2 loop id %zu (inputs count: %zu):\n", loop_ids.front(), inner_loop_info->get_input_ports_info().size()); + fprintf(stderr, "a input_ports: %zu\n", input_ports.size()); + fprintf(stderr, "a output_ports: %zu\n", output_ports.size()); + fprintf(stderr, "new_expr->get_inputs_count() 5s: %zu\n", new_expr->get_input_count()); return new_expr_it; } diff --git a/src/common/snippets/src/lowered/loop_info.cpp b/src/common/snippets/src/lowered/loop_info.cpp index 185bf025baefd0..60a1f2ce50f724 100644 --- a/src/common/snippets/src/lowered/loop_info.cpp +++ b/src/common/snippets/src/lowered/loop_info.cpp @@ -136,20 +136,29 @@ void validate_new_target_ports(const std::vector& target_ports, } // namespace void LoopInfo::replace_with_new_ports(const LoopPort& actual_port, const std::vector& target_ports) { + fprintf(stderr, " -> LoopInfo::replace_with_new_ports 0\n"); const auto& actual_port_type = actual_port.get_expr_port()->get_type(); validate_new_target_ports(target_ports, actual_port_type); auto& ports = actual_port_type == ExpressionPort::Input ? m_input_ports : m_output_ports; auto port_it = find_loop_port(actual_port); + fprintf(stderr, "portss 1: %zu\n", ports.size()); port_it = ports.erase(port_it); + fprintf(stderr, "portss 2: %zu\n", ports.size()); ports.insert(port_it, target_ports.cbegin(), target_ports.cend()); + fprintf(stderr, "portss 1: %zu\n", ports.size()); } void LoopInfo::replace_with_new_ports(const ExpressionPort& actual_port, const std::vector& target_ports) { + fprintf(stderr, " -> LoopInfo::replace_with_new_ports 1\n"); const auto& actual_port_type = actual_port.get_type(); + fprintf(stderr, " =: %d\n", get_input_ports().size()); validate_new_target_ports(target_ports, actual_port_type); + fprintf(stderr, " =2: %d\n", get_input_ports().size()); auto& ports = actual_port_type == ExpressionPort::Input ? m_input_ports : m_output_ports; + fprintf(stderr, "actual_port expr_port type: %d\n", actual_port.get_type()); + fprintf(stderr, "actual_port expr_port index: %zu\n", actual_port.get_index()); auto port_it = find_loop_port(actual_port); // In some cases actual ExpressionPort may not be LoopPort. We shouldn't throw exception here since ExpressionPort is not strong condition as LoopPort // For example, not all inner loop ports are ports of outer loops @@ -166,6 +175,7 @@ void LoopInfo::replace_with_new_ports(const ExpressionPort& actual_port, const s }); port_it = ports.erase(port_it); ports.insert(port_it, target_loop_ports.cbegin(), target_loop_ports.cend()); + fprintf(stderr, " =5: %d\n", get_input_ports().size()); } std::vector LoopInfo::clone_loop_ports(const ExpressionMap& expr_map, const std::vector& loop_ports) { @@ -362,6 +372,7 @@ void UnifiedLoopInfo::replace_with_cloned_descs(size_t actual_port_idx, size_t n } void UnifiedLoopInfo::replace_with_new_ports(const LoopPort& actual_port, const std::vector& target_ports) { + fprintf(stderr, " -> UnifiedLoopInfo::replace_with_new_ports 2\n"); const auto& actual_port_type = actual_port.get_expr_port()->get_type(); validate_new_target_ports(target_ports, actual_port_type); @@ -380,19 +391,25 @@ void UnifiedLoopInfo::replace_with_new_ports(const LoopPort& actual_port, const } void UnifiedLoopInfo::replace_with_new_ports(const ExpressionPort& actual_port, const std::vector& target_ports) { + fprintf(stderr, " -> UnifiedLoopInfo::replace_with_new_ports 3\n"); + fprintf(stderr, "UnifiedLoopInfo::replace_with_new_ports 0: %zu\n", get_input_ports_info().size()); const auto& actual_port_type = actual_port.get_type(); validate_new_target_ports(target_ports, actual_port_type); const auto is_input = actual_port.get_type() == ExpressionPort::Input; auto& ports = is_input ? m_input_ports : m_output_ports; auto port_it = find_loop_port(actual_port); + fprintf(stderr, "UnifiedLoopInfo::replace_with_new_ports 1: %zu\n", get_input_ports_info().size()); // In some cases actual ExpressionPort may not be LoopPort. We shouldn't throw exception here since ExpressionPort is not strong condition as LoopPort // For example, not all inner loop ports are ports of outer loops if (port_it == ports.end()) return; + fprintf(stderr, "UnifiedLoopInfo::replace_with_new_ports 2: %zu\n", get_input_ports_info().size()); replace_with_cloned_descs(std::distance(ports.begin(), port_it), target_ports.size(), is_input); + // fprintf(stderr, "UnifiedLoopInfo::replace_with_new_ports 3: %zu\n", get_input_ports_info().size()); LoopInfo::replace_with_new_ports(actual_port, target_ports); + fprintf(stderr, "UnifiedLoopInfo::replace_with_new_ports 4: %zu\n", get_input_ports_info().size()); // Sort ports sort_ports(); @@ -536,12 +553,14 @@ void ExpandedLoopInfo::update_finalization_offsets(const std::vector& n } void ExpandedLoopInfo::replace_with_new_ports(const LoopPort& actual_port, const std::vector& target_ports) { + fprintf(stderr, " -> ExpandedLoopInfo::replace_with_new_ports 4\n"); OPENVINO_ASSERT(target_ports.size() == 1, "ExpandedLoopInfo supports replace one port with only one port!"); LoopInfo::replace_with_new_ports(actual_port, target_ports); sort_ports(); } void ExpandedLoopInfo::replace_with_new_ports(const ExpressionPort& actual_port, const std::vector& target_ports) { + fprintf(stderr, " -> ExpandedLoopInfo::replace_with_new_ports 5\n"); OPENVINO_ASSERT(target_ports.size() == 1, "ExpandedLoopInfo supports replace one port with only one port!"); LoopInfo::replace_with_new_ports(actual_port, target_ports); sort_ports(); diff --git a/src/common/snippets/src/lowered/pass/brgemm_blocking.cpp b/src/common/snippets/src/lowered/pass/brgemm_blocking.cpp index 663b4d1fe05b84..ab224f23260335 100644 --- a/src/common/snippets/src/lowered/pass/brgemm_blocking.cpp +++ b/src/common/snippets/src/lowered/pass/brgemm_blocking.cpp @@ -95,7 +95,7 @@ size_t BrgemmBlockingBase::get_default_m_blk(size_t m) const { return 32; } size_t BrgemmBlockingBase::get_default_n_blk(size_t n) const { - return 64; + return 2048; } size_t BrgemmBlockingBase::get_default_k_blk(size_t k) const { return !utils::is_dynamic_value(k) && k > 1024 ? 1024 : 512; diff --git a/src/common/snippets/src/lowered/pass/validate.cpp b/src/common/snippets/src/lowered/pass/validate.cpp index 5e6f31ae3f80ea..446ea1087802f4 100644 --- a/src/common/snippets/src/lowered/pass/validate.cpp +++ b/src/common/snippets/src/lowered/pass/validate.cpp @@ -113,6 +113,23 @@ void validate_loop_end(const ExpressionPtr& expr, const LinearIR& linear_ir) { const auto input_port_infos = loop_info->get_input_ports_info(); const auto output_port_infos = loop_info->get_output_ports_info(); + for (const auto& input_port_info : input_port_infos) { + int tmp; + try { + tmp = input_port_info.port.get_dim_idx(); + } catch (...) { + tmp = -1; + } + fprintf(stderr, "Input Port Info - is_incremented: %d, ptr_increment: %zu, finalization_offset: %zu, idx: %zu\n", + input_port_info.port.is_incremented(), + tmp, + input_port_info.desc.ptr_increment, + input_port_info.desc.finalization_offset); + } + fprintf(stderr, "Input Port Infos size: %zu\n", input_port_infos.size()); + fprintf(stderr, "Output Port Infos size: %zu\n", output_port_infos.size()); + fprintf(stderr, "LoopEnd input num: %zu\n", loop_end->get_input_num()); + fprintf(stderr, "LoopEnd output num: %zu\n", loop_end->get_output_num()); OPENVINO_ASSERT(input_port_infos.size() == loop_end->get_input_num() && output_port_infos.size() == loop_end->get_output_num(), "Incompatible LoopEnd and the corresponding LoopInfo"); diff --git a/src/common/snippets/src/lowered/pass/validate_unified_loops.cpp b/src/common/snippets/src/lowered/pass/validate_unified_loops.cpp index e127aaea0c11d3..c0cfdc2b679b54 100644 --- a/src/common/snippets/src/lowered/pass/validate_unified_loops.cpp +++ b/src/common/snippets/src/lowered/pass/validate_unified_loops.cpp @@ -67,17 +67,22 @@ void ValidateUnifiedLoops::validate_loop_infos(const LoopManagerPtr& loop_manage // Validate that iteration dimension is broadcastable std::set unique_dimensions; + fprintf(stderr, "Loop ID: %zu\n", pair.first); loop_info->iterate_through_ports([&unique_dimensions](const LoopPort& loop_port) { + fprintf(stderr, "loop_port.is_processed(): %d\n", loop_port.is_processed()); if (loop_port.is_processed()) { const auto is_input = loop_port.get_expr_port()->get_type() == ExpressionPort::Input; const auto planar_shape = is_input ? ov::snippets::utils::get_planar_vdims(*loop_port.get_expr_port()) : ov::snippets::utils::get_preordered_vdims(*loop_port.get_expr_port()); const auto& dim = *(planar_shape.rbegin() + loop_port.get_dim_idx()); // Since dim == 1 can be broadcasted to any value, it's not necessary to add it to unique dims - if (!utils::is_dynamic_value(dim) && dim != 1) + if (!utils::is_dynamic_value(dim) && dim != 1) { unique_dimensions.insert(dim); + fprintf(stderr, "Dimension: %zu\n", dim); + } } }); + OPENVINO_ASSERT(unique_dimensions.size() <= 1, "Loop ports have incompatible dimensions, by which the loop iterates"); } diff --git a/src/plugins/intel_cpu/src/.clang-tidy b/src/plugins/intel_cpu/src/.clang-tidy index c2c40baacdb90f..0d6ced7b2fcebb 100644 --- a/src/plugins/intel_cpu/src/.clang-tidy +++ b/src/plugins/intel_cpu/src/.clang-tidy @@ -38,6 +38,7 @@ Checks: > performance-*, google-*, modernize-pass-by-value, + misc-include-cleaner, cppcoreguidelines-prefer-member-initializer, -bugprone-easily-swappable-parameters, -bugprone-fold-init-type, diff --git a/src/plugins/intel_cpu/src/extension.cpp b/src/plugins/intel_cpu/src/extension.cpp index 9dd382fed8a86b..1262cac011f17d 100644 --- a/src/plugins/intel_cpu/src/extension.cpp +++ b/src/plugins/intel_cpu/src/extension.cpp @@ -31,6 +31,7 @@ #include "transformations/cpu_opset/x64/op/mha.hpp" #include "transformations/cpu_opset/x64/op/qkv_proj.hpp" #include "transformations/snippets/x64/op/brgemm_copy_b.hpp" +#include "transformations/snippets/x64/op/brgemm_cpu.hpp" #include "transformations/snippets/x64/op/gemm_cpu.hpp" #include "transformations/snippets/x64/op/load_convert.hpp" #include "transformations/snippets/x64/op/perf_count_rdtsc.hpp" @@ -101,7 +102,8 @@ class TypeRelaxedExtension : public ov::OpExtension> { OP_EXTENSION_X64(ov::intel_cpu::LoadConvertTruncation) \ OP_EXTENSION_X64(ov::intel_cpu::StoreConvertSaturation) \ OP_EXTENSION_X64(ov::intel_cpu::StoreConvertTruncation) \ - OP_EXTENSION_X64(ov::intel_cpu::GemmCPU) \ + OP_EXTENSION_X64(ov::intel_cpu::GemmCPU) \ + OP_EXTENSION_X64(ov::intel_cpu::BrgemmCPU) \ OP_EXTENSION_X64(ov::intel_cpu::BrgemmCopyB) #define TYPE_RELAXED_EXTENSIONS \ diff --git a/src/plugins/intel_cpu/src/nodes/subgraph.cpp b/src/plugins/intel_cpu/src/nodes/subgraph.cpp index 58a0175234b71b..f06f04df0a4ee3 100644 --- a/src/plugins/intel_cpu/src/nodes/subgraph.cpp +++ b/src/plugins/intel_cpu/src/nodes/subgraph.cpp @@ -564,7 +564,7 @@ Subgraph::ControlFlowPasses Subgraph::getControlFlowPasses() const { #endif SNIPPETS_REGISTER_PASS_RELATIVE(Place::After, - ov::snippets::lowered::pass::AllocateBuffers, + ov::intel_cpu::pass::GemmCPUBlocking, ov::intel_cpu::pass::BuildBrgemm); #undef SNIPPETS_REGISTER_PASS_RELATIVE diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/build_brgemm.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/build_brgemm.cpp index 5f3d76052ac8ee..ddb459ac08c55d 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/build_brgemm.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/build_brgemm.cpp @@ -26,12 +26,21 @@ namespace ov { namespace intel_cpu { -bool pass::BuildBrgemm::run(const snippets::lowered::LinearIR& linear_ir) { - OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::AdjustBrgemmCopyBLoopPorts") +bool pass::BuildBrgemm::run(snippets::lowered::LinearIR& linear_ir, + snippets::lowered::LinearIR::constExprIt begin, + snippets::lowered::LinearIR::constExprIt end) { + OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::BuildBrgemm") bool modified = false; - for (const auto& expr : linear_ir) { + + fprintf(stderr, "Dumping Linear IR :\n"); + for (auto it = begin; it != end; ++it) { + const auto& expr = *it; + fprintf(stderr, "%s\n", expr->get_node()->get_friendly_name().c_str()); + } + for (auto expr_it = begin; expr_it != end; expr_it++) { + const auto& expr = *expr_it; const auto gemm_node = ov::as_type_ptr(expr->get_node()); - if (!gemm_node || gemm_node->is_dynamic()) { + if (!gemm_node || gemm_node->is_dynamic() || with_repacking(gemm_node->get_type())) { continue; } const auto& loop_manager = linear_ir.get_loop_manager(); @@ -42,37 +51,54 @@ bool pass::BuildBrgemm::run(const snippets::lowered::LinearIR& linear_ir) { continue; } - const auto& gemm_in0_desc = snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(gemm_node->input(0)); - const auto& gemm_in1_desc = snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(gemm_node->input(1)); - const auto& gemm_out_desc = snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(gemm_node->output(0)); + // TODO: get input port descriptor + const auto& gemm_in0_desc = expr->get_input_port_descriptor(0); + const auto& gemm_in1_desc = expr->get_input_port_descriptor(1); + const auto& gemm_out_desc = expr->get_output_port_descriptor(0); + + // const auto& interm_connector = expr->get_input_port_connector(0); + // const auto gemm_expr = interm_connector->get_source().get_expr(); // Get innermost loop info - auto loop_expr = loop_manager->get_loop_bounds(linear_ir, loop_ids.back()).first; + // TODO: check K-loop const auto& inner_loop_info = loop_manager->get_loop_info(loop_ids.front()); + fprintf(stderr, "inner_loop_info for loop id %zu (inputs count: %zu):\n", loop_ids.front(), inner_loop_info->get_input_ports_info().size()); + for (size_t i = 0; i < inner_loop_info->get_input_ports_info().size(); ++i) { + fprintf(stderr, "Input port %zu is_processed: %d\n", i, inner_loop_info->get_input_ports_info()[i].port.is_processed()); + } + // fprintf(stderr, "Output port 0 is_processed: %d\n", inner_loop_info->get_output_ports_info()[1].port.is_processed()); + if (inner_loop_info->get_work_amount() % inner_loop_info->get_increment() != 0) { + continue; + } auto iter_count = inner_loop_info->get_work_amount() / inner_loop_info->get_increment(); - auto brgemm_node = std::make_shared(gemm_node->input_value(0), - gemm_node->input_value(1), - iter_count, - gemm_node->get_type(), - gemm_node->get_offset_a(), - gemm_node->get_offset_b(), - gemm_node->get_offset_c(), - gemm_in0_desc->get_layout(), - gemm_in1_desc->get_layout(), - gemm_out_desc->get_layout()); - // TODO: replace node - - // Transfer ports - snippets::lowered::PortDescriptorUtils::set_port_descriptor(gemm_node->input(0), gemm_in0_desc->get_subtensor(), gemm_in0_desc->get_layout()); - snippets::lowered::PortDescriptorUtils::set_port_descriptor(gemm_node->input(1), gemm_in1_desc->get_subtensor(), gemm_in1_desc->get_layout()); - snippets::lowered::PortDescriptorUtils::set_port_descriptor(gemm_node->output(0), gemm_out_desc->get_subtensor(), gemm_out_desc->get_layout()); - - // need to run validate_and_infer_types manually: either input shapes were updated or - // output Layout was updated (out shape will be updated in validate_and_infer_types()) - gemm_node->validate_and_infer_types(); - brgemm_node->validate_and_infer_types(); + auto brgemm_node = + std::make_shared(expr->get_input_port_connector(0)->get_source().get_expr()->get_node(), + expr->get_input_port_connector(1)->get_source().get_expr()->get_node(), + iter_count, + gemm_node->get_type(), + gemm_node->get_offset_a(), + gemm_node->get_offset_b(), + gemm_node->get_offset_c(), + gemm_in0_desc->get_layout(), + gemm_in1_desc->get_layout(), + gemm_out_desc->get_layout()); + // Replace GemmCPU node with BrgemmCPU + auto live_regs = expr->get_live_regs(); + expr_it = linear_ir.replace_with_node({expr}, brgemm_node, expr->get_loop_ids(), linear_ir.find(expr)); + expr_it->get()->set_live_regs(std::move(live_regs)); + const auto loop_ids2 = (*expr_it)->get_loop_ids(); + const auto& inner_loop_info2 = loop_manager->get_loop_info(loop_ids2.front()); + fprintf(stderr, "inner_loop_info2 for loop id %zu (inputs count: %zu):\n", loop_ids2.front(), inner_loop_info2->get_input_ports_info().size()); + for (size_t i = 0; i < inner_loop_info2->get_input_ports_info().size(); ++i) { + fprintf(stderr, "Input port %zu is_processed: %d\n", i, inner_loop_info2->get_input_ports_info()[i].port.is_processed()); + } - const auto& inputs = expr->get_node()->inputs(); + modified |= true; + } + fprintf(stderr, "Dumping Linear IR :\n"); + for (auto it = begin; it != end; ++it) { + const auto& expr = *it; + fprintf(stderr, "%s\n", expr->get_node()->get_friendly_name().c_str()); } return modified; diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/build_brgemm.hpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/build_brgemm.hpp index 1791012cc221d3..314c5163c33fea 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/build_brgemm.hpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/build_brgemm.hpp @@ -16,14 +16,13 @@ namespace pass { * @brief The pass explicitly insert LoadBegin and LoadEnd in Linear IR using UnifiedLoopInfo from Loop markup algorithm * @ingroup snippets */ -class BuildBrgemm : public snippets::lowered::pass::ConstPass { +class BuildBrgemm : public snippets::lowered::pass::RangedPass { public: - OPENVINO_RTTI("BuildBrgemm", "", ConstPass) + OPENVINO_RTTI("BuildBrgemm", "", snippets::lowered::pass::RangedPass) BuildBrgemm() = default; - bool run(const snippets::lowered::LinearIR& linear_ir) override; - const std::unordered_set& get_affected_loops() { - return m_affected_loops; - } + bool run(snippets::lowered::LinearIR& linear_ir, + snippets::lowered::LinearIR::constExprIt begin, + snippets::lowered::LinearIR::constExprIt end) override; private: std::unordered_set m_affected_loops; diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/shape_inference.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/shape_inference.cpp index e079bbb253cf93..63328c81dd472c 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/shape_inference.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/shape_inference.cpp @@ -7,6 +7,7 @@ #include #include "op/brgemm_copy_b.hpp" +#include "op/brgemm_cpu.hpp" #include "op/gemm_cpu.hpp" #include "op/load_convert.hpp" #include "op/perf_count_rdtsc.hpp" @@ -71,6 +72,7 @@ const CPUShapeInferSnippetsFactory::TRegistry CPUShapeInferSnippetsFactory::spec SHAPE_INFER_OP_SPECIFIC_EXTERNAL(ov::intel_cpu::tpp::op::ReduceSum, ReduceShapeInfer), #endif SHAPE_INFER_OP_SPECIFIC_EXTERNAL(ov::intel_cpu::GemmCPU, BrgemmShapeInfer), + SHAPE_INFER_OP_SPECIFIC_EXTERNAL(ov::intel_cpu::BrgemmCPU, BrgemmShapeInfer), SHAPE_INFER_OP_SPECIFIC(ov::intel_cpu::BrgemmCopyB), }; #undef SHAPE_INFER_OP_SPECIFIC