Skip to content

Commit

Permalink
state
Browse files Browse the repository at this point in the history
  • Loading branch information
aobolensk committed Jan 30, 2025
1 parent 20f93d4 commit 5a7b16b
Show file tree
Hide file tree
Showing 12 changed files with 135 additions and 44 deletions.
2 changes: 2 additions & 0 deletions src/common/snippets/include/snippets/lowered/loop_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,9 @@ class LoopManager {
template<typename T, typename = typename std::enable_if<(std::is_same<T, ExpressionPort>::value || std::is_same<T, LoopPort>::value), bool>::type>
void replace_loop_port(size_t loop_id, const T& actual_port, const std::vector<T>& target_ports) {
const auto& loop_info = get_loop_info(loop_id);
fprintf(stderr, "replace_loop_port 1: %zu\n", loop_info->get_input_ports().size());
loop_info->replace_with_new_ports(actual_port, target_ports);
fprintf(stderr, "replace_loop_port 2: %zu\n", loop_info->get_input_ports().size());
}
/**
* @brief Replace Loop ports for several Unified Loops with new ports.
Expand Down
26 changes: 22 additions & 4 deletions src/common/snippets/src/lowered/linear_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ LinearIR::exprIt LinearIR::replace_with_node(const std::vector<ExpressionPtr>& o
"Failed to replace node: node output port count is not equal to output count of last old expression");

std::vector<PortConnectorPtr> new_inputs(new_node->get_input_size());
fprintf(stderr, "new_node->get_input_size(): %zu\n", new_node->get_input_size());
for (size_t i = 0; i < new_node->get_input_size(); ++i) {
const auto& source = new_node->get_input_source_output(i);
new_inputs[i] = get_expr_by_node(source.get_node_shared_ptr())->get_output_port_connector(source.get_index());
Expand Down Expand Up @@ -395,6 +396,7 @@ LinearIR::exprIt LinearIR::replace_with_expr(const std::vector<ExpressionPtr>& o
"Failed to replace expressions: new expr output port count is not equal to output count of last old expression");

const auto& new_inputs = new_expr->get_input_port_connectors();
fprintf(stderr, "new_expr->get_inputs_count() 2s: %zu\n", new_expr->get_input_count());

auto is_old_expr = [&old_exprs](const ExpressionPtr& expr) {
return std::find(old_exprs.cbegin(), old_exprs.cend(), expr) != old_exprs.cend();
Expand Down Expand Up @@ -424,20 +426,36 @@ LinearIR::exprIt LinearIR::replace_with_expr(const std::vector<ExpressionPtr>& o
}
}
}
fprintf(stderr, "new_expr->get_inputs_count() 3s: %zu\n", new_expr->get_input_count());

update_consumers_and_regs(new_expr, consumers);

fprintf(stderr, "new_expr->get_inputs_count() 4s: %zu\n", new_expr->get_input_count());
const auto new_expr_it = insert(place, new_expr);
const auto& loop_ids = new_expr_it->get()->get_loop_ids();
const auto input_ports = new_expr_it->get()->get_input_ports();
fprintf(stderr, "b input_ports: %zu\n", input_ports.size());
const auto output_ports = new_expr_it->get()->get_output_ports();
fprintf(stderr, "b output_ports: %zu\n", output_ports.size());
const auto& inner_loop_info = m_loop_manager->get_loop_info<snippets::lowered::UnifiedLoopInfo>(loop_ids.front());
fprintf(stderr, "=== loop id %zu (inputs count: %zu):\n", loop_ids.front(), inner_loop_info->get_input_ports_info().size());
for (const auto& old_expr : old_exprs) {
for (size_t i = 0; i < old_expr->get_input_count(); ++i)
m_loop_manager->replace_loop_ports(loop_ids, old_expr->get_input_port(i), input_ports);
for (size_t i = 0; i < old_expr->get_input_count(); ++i)
m_loop_manager->replace_loop_ports(loop_ids, old_expr->get_output_port(i), output_ports);
fprintf(stderr, "=1a loop id %zu (inputs count: %zu %zu):\n", loop_ids.front(), inner_loop_info->get_input_ports_info().size(), inner_loop_info->get_input_count());
for (size_t i = 0; i < old_expr->get_input_count(); ++i) {
printf(" ty: %zu\n", old_expr->get_input_port(i).get_type() == ExpressionPort::Input);
m_loop_manager->replace_loop_ports(loop_ids, old_expr->get_input_port(i), {new_expr_it->get()->get_input_port(i)});
}
fprintf(stderr, "=1b loop id %zu (inputs count: %zu %zu):\n", loop_ids.front(), inner_loop_info->get_input_ports_info().size(), inner_loop_info->get_input_count());
for (size_t i = 0; i < old_expr->get_output_count(); ++i)
m_loop_manager->replace_loop_ports(loop_ids, old_expr->get_output_port(i), {new_expr_it->get()->get_output_port(i)});
fprintf(stderr, "=1c loop id %zu (inputs count: %zu):\n", loop_ids.front(), inner_loop_info->get_input_ports_info().size());
erase(find(old_expr));
fprintf(stderr, "=1d loop id %zu (inputs count: %zu):\n", loop_ids.front(), inner_loop_info->get_input_ports_info().size());
}
fprintf(stderr, "==2 loop id %zu (inputs count: %zu):\n", loop_ids.front(), inner_loop_info->get_input_ports_info().size());
fprintf(stderr, "a input_ports: %zu\n", input_ports.size());
fprintf(stderr, "a output_ports: %zu\n", output_ports.size());
fprintf(stderr, "new_expr->get_inputs_count() 5s: %zu\n", new_expr->get_input_count());
return new_expr_it;
}

Expand Down
19 changes: 19 additions & 0 deletions src/common/snippets/src/lowered/loop_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,20 +136,29 @@ void validate_new_target_ports(const std::vector<ExpressionPort>& target_ports,
} // namespace

void LoopInfo::replace_with_new_ports(const LoopPort& actual_port, const std::vector<LoopPort>& target_ports) {
fprintf(stderr, " -> LoopInfo::replace_with_new_ports 0\n");
const auto& actual_port_type = actual_port.get_expr_port()->get_type();
validate_new_target_ports(target_ports, actual_port_type);

auto& ports = actual_port_type == ExpressionPort::Input ? m_input_ports : m_output_ports;
auto port_it = find_loop_port(actual_port);
fprintf(stderr, "portss 1: %zu\n", ports.size());
port_it = ports.erase(port_it);
fprintf(stderr, "portss 2: %zu\n", ports.size());
ports.insert(port_it, target_ports.cbegin(), target_ports.cend());
fprintf(stderr, "portss 1: %zu\n", ports.size());
}

void LoopInfo::replace_with_new_ports(const ExpressionPort& actual_port, const std::vector<ExpressionPort>& target_ports) {
fprintf(stderr, " -> LoopInfo::replace_with_new_ports 1\n");
const auto& actual_port_type = actual_port.get_type();
fprintf(stderr, " =: %d\n", get_input_ports().size());
validate_new_target_ports(target_ports, actual_port_type);
fprintf(stderr, " =2: %d\n", get_input_ports().size());

auto& ports = actual_port_type == ExpressionPort::Input ? m_input_ports : m_output_ports;
fprintf(stderr, "actual_port expr_port type: %d\n", actual_port.get_type());
fprintf(stderr, "actual_port expr_port index: %zu\n", actual_port.get_index());
auto port_it = find_loop_port(actual_port);
// In some cases actual ExpressionPort may not be LoopPort. We shouldn't throw exception here since ExpressionPort is not strong condition as LoopPort
// For example, not all inner loop ports are ports of outer loops
Expand All @@ -166,6 +175,7 @@ void LoopInfo::replace_with_new_ports(const ExpressionPort& actual_port, const s
});
port_it = ports.erase(port_it);
ports.insert(port_it, target_loop_ports.cbegin(), target_loop_ports.cend());
fprintf(stderr, " =5: %d\n", get_input_ports().size());
}

std::vector<LoopPort> LoopInfo::clone_loop_ports(const ExpressionMap& expr_map, const std::vector<LoopPort>& loop_ports) {
Expand Down Expand Up @@ -362,6 +372,7 @@ void UnifiedLoopInfo::replace_with_cloned_descs(size_t actual_port_idx, size_t n
}

void UnifiedLoopInfo::replace_with_new_ports(const LoopPort& actual_port, const std::vector<LoopPort>& target_ports) {
fprintf(stderr, " -> UnifiedLoopInfo::replace_with_new_ports 2\n");
const auto& actual_port_type = actual_port.get_expr_port()->get_type();
validate_new_target_ports(target_ports, actual_port_type);

Expand All @@ -380,19 +391,25 @@ void UnifiedLoopInfo::replace_with_new_ports(const LoopPort& actual_port, const
}

void UnifiedLoopInfo::replace_with_new_ports(const ExpressionPort& actual_port, const std::vector<ExpressionPort>& target_ports) {
fprintf(stderr, " -> UnifiedLoopInfo::replace_with_new_ports 3\n");
fprintf(stderr, "UnifiedLoopInfo::replace_with_new_ports 0: %zu\n", get_input_ports_info().size());
const auto& actual_port_type = actual_port.get_type();
validate_new_target_ports(target_ports, actual_port_type);

const auto is_input = actual_port.get_type() == ExpressionPort::Input;
auto& ports = is_input ? m_input_ports : m_output_ports;
auto port_it = find_loop_port(actual_port);
fprintf(stderr, "UnifiedLoopInfo::replace_with_new_ports 1: %zu\n", get_input_ports_info().size());
// In some cases actual ExpressionPort may not be LoopPort. We shouldn't throw exception here since ExpressionPort is not strong condition as LoopPort
// For example, not all inner loop ports are ports of outer loops
if (port_it == ports.end())
return;
fprintf(stderr, "UnifiedLoopInfo::replace_with_new_ports 2: %zu\n", get_input_ports_info().size());

replace_with_cloned_descs(std::distance(ports.begin(), port_it), target_ports.size(), is_input);
// fprintf(stderr, "UnifiedLoopInfo::replace_with_new_ports 3: %zu\n", get_input_ports_info().size());
LoopInfo::replace_with_new_ports(actual_port, target_ports);
fprintf(stderr, "UnifiedLoopInfo::replace_with_new_ports 4: %zu\n", get_input_ports_info().size());

// Sort ports
sort_ports();
Expand Down Expand Up @@ -536,12 +553,14 @@ void ExpandedLoopInfo::update_finalization_offsets(const std::vector<int64_t>& n
}

void ExpandedLoopInfo::replace_with_new_ports(const LoopPort& actual_port, const std::vector<LoopPort>& target_ports) {
fprintf(stderr, " -> ExpandedLoopInfo::replace_with_new_ports 4\n");
OPENVINO_ASSERT(target_ports.size() == 1, "ExpandedLoopInfo supports replace one port with only one port!");
LoopInfo::replace_with_new_ports(actual_port, target_ports);
sort_ports();
}

void ExpandedLoopInfo::replace_with_new_ports(const ExpressionPort& actual_port, const std::vector<ExpressionPort>& target_ports) {
fprintf(stderr, " -> ExpandedLoopInfo::replace_with_new_ports 5\n");
OPENVINO_ASSERT(target_ports.size() == 1, "ExpandedLoopInfo supports replace one port with only one port!");
LoopInfo::replace_with_new_ports(actual_port, target_ports);
sort_ports();
Expand Down
2 changes: 1 addition & 1 deletion src/common/snippets/src/lowered/pass/brgemm_blocking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ size_t BrgemmBlockingBase::get_default_m_blk(size_t m) const {
return 32;
}
size_t BrgemmBlockingBase::get_default_n_blk(size_t n) const {
return 64;
return 2048;
}
size_t BrgemmBlockingBase::get_default_k_blk(size_t k) const {
return !utils::is_dynamic_value(k) && k > 1024 ? 1024 : 512;
Expand Down
17 changes: 17 additions & 0 deletions src/common/snippets/src/lowered/pass/validate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,23 @@ void validate_loop_end(const ExpressionPtr& expr, const LinearIR& linear_ir) {

const auto input_port_infos = loop_info->get_input_ports_info();
const auto output_port_infos = loop_info->get_output_ports_info();
for (const auto& input_port_info : input_port_infos) {
int tmp;
try {
tmp = input_port_info.port.get_dim_idx();
} catch (...) {
tmp = -1;
}
fprintf(stderr, "Input Port Info - is_incremented: %d, ptr_increment: %zu, finalization_offset: %zu, idx: %zu\n",
input_port_info.port.is_incremented(),
tmp,
input_port_info.desc.ptr_increment,
input_port_info.desc.finalization_offset);
}
fprintf(stderr, "Input Port Infos size: %zu\n", input_port_infos.size());
fprintf(stderr, "Output Port Infos size: %zu\n", output_port_infos.size());
fprintf(stderr, "LoopEnd input num: %zu\n", loop_end->get_input_num());
fprintf(stderr, "LoopEnd output num: %zu\n", loop_end->get_output_num());
OPENVINO_ASSERT(input_port_infos.size() == loop_end->get_input_num() &&
output_port_infos.size() == loop_end->get_output_num(),
"Incompatible LoopEnd and the corresponding LoopInfo");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,22 @@ void ValidateUnifiedLoops::validate_loop_infos(const LoopManagerPtr& loop_manage

// Validate that iteration dimension is broadcastable
std::set<size_t> unique_dimensions;
fprintf(stderr, "Loop ID: %zu\n", pair.first);
loop_info->iterate_through_ports([&unique_dimensions](const LoopPort& loop_port) {
fprintf(stderr, "loop_port.is_processed(): %d\n", loop_port.is_processed());
if (loop_port.is_processed()) {
const auto is_input = loop_port.get_expr_port()->get_type() == ExpressionPort::Input;
const auto planar_shape = is_input ? ov::snippets::utils::get_planar_vdims(*loop_port.get_expr_port())
: ov::snippets::utils::get_preordered_vdims(*loop_port.get_expr_port());
const auto& dim = *(planar_shape.rbegin() + loop_port.get_dim_idx());
// Since dim == 1 can be broadcasted to any value, it's not necessary to add it to unique dims
if (!utils::is_dynamic_value(dim) && dim != 1)
if (!utils::is_dynamic_value(dim) && dim != 1) {
unique_dimensions.insert(dim);
fprintf(stderr, "Dimension: %zu\n", dim);
}
}
});

OPENVINO_ASSERT(unique_dimensions.size() <= 1,
"Loop ports have incompatible dimensions, by which the loop iterates");
}
Expand Down
1 change: 1 addition & 0 deletions src/plugins/intel_cpu/src/.clang-tidy
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ Checks: >
performance-*,
google-*,
modernize-pass-by-value,
misc-include-cleaner,
cppcoreguidelines-prefer-member-initializer,
-bugprone-easily-swappable-parameters,
-bugprone-fold-init-type,
Expand Down
4 changes: 3 additions & 1 deletion src/plugins/intel_cpu/src/extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "transformations/cpu_opset/x64/op/mha.hpp"
#include "transformations/cpu_opset/x64/op/qkv_proj.hpp"
#include "transformations/snippets/x64/op/brgemm_copy_b.hpp"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/snippets/x64/op/gemm_cpu.hpp"
#include "transformations/snippets/x64/op/load_convert.hpp"
#include "transformations/snippets/x64/op/perf_count_rdtsc.hpp"
Expand Down Expand Up @@ -101,7 +102,8 @@ class TypeRelaxedExtension : public ov::OpExtension<ov::op::TypeRelaxed<Op>> {
OP_EXTENSION_X64(ov::intel_cpu::LoadConvertTruncation) \
OP_EXTENSION_X64(ov::intel_cpu::StoreConvertSaturation) \
OP_EXTENSION_X64(ov::intel_cpu::StoreConvertTruncation) \
OP_EXTENSION_X64(ov::intel_cpu::GemmCPU) \
OP_EXTENSION_X64(ov::intel_cpu::GemmCPU) \
OP_EXTENSION_X64(ov::intel_cpu::BrgemmCPU) \
OP_EXTENSION_X64(ov::intel_cpu::BrgemmCopyB)

#define TYPE_RELAXED_EXTENSIONS \
Expand Down
2 changes: 1 addition & 1 deletion src/plugins/intel_cpu/src/nodes/subgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,7 @@ Subgraph::ControlFlowPasses Subgraph::getControlFlowPasses() const {
#endif

SNIPPETS_REGISTER_PASS_RELATIVE(Place::After,
ov::snippets::lowered::pass::AllocateBuffers,
ov::intel_cpu::pass::GemmCPUBlocking,
ov::intel_cpu::pass::BuildBrgemm);

#undef SNIPPETS_REGISTER_PASS_RELATIVE
Expand Down
Loading

0 comments on commit 5a7b16b

Please sign in to comment.