[Snippets][CPU] Batch brgemm execution for K blocking loops #28724

Open · wants to merge 3 commits into base: master
4 changes: 2 additions & 2 deletions src/common/snippets/docs/mha_optimization_guide.md
@@ -123,7 +123,7 @@ For enhancing the execution efficiency, blocking across the M, K, and N matmul d

### Blocking Parameters

The heuristics for determining the optimal block sizes can be found in [BrgemmCPUBlocking](../../../plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_cpu_blocking.cpp).
The heuristics for determining the optimal block sizes can be found in [GemmCPUBlocking](../../../plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_cpu_blocking.cpp).

**Please note: Blocking by M dimension is shared between both Brgemms. Please see [SplitLoops](../include/snippets/lowered/pass/split_loops.hpp) lowered pass for the details.**

@@ -141,7 +141,7 @@ Based on previously discussed information, we provide the following recommendati
In local experiments, some transformations might be worth to change:
- Disable [ExtractUnsupportedTransposes](#extractunsupportedtransposes) transformation in order to benchmark Snippets Transpose implementation.
- Adjust [SplitDimensionM](#splitdimensionm) heuristics in order to benchmark another splitting, or disable the pass at all.
3. [Blocking parameters](#blocking-parameters): adjust blocking heuristics in `BrgemmCPUBlocking`.
3. [Blocking parameters](#blocking-parameters): adjust blocking heuristics in `GemmCPUBlocking`.
- Please note that there are 2 Matmul nodes inside a single MHA, and each Matmul can have his own optimal K, N blocking params.
M block is better to keep the same since the corresponding blocking loop is shared between both Matmuls.
- For the BF16/INT8 blocking loops, 2 options are possible: blocking can be done only for Brgemm node, or for BrgemmCopyB repacking too.
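To make the K-blocking scheme that these heuristics control concrete, below is a minimal, self-contained sketch in plain C++ (not the plugin's brgemm kernels; `gemm_block` and `gemm_k_blocked` are illustrative names). It shows how a K blocking loop splits the reduction dimension and how the first block initializes C while the remaining blocks accumulate into it, which appears to be the loop this PR turns into a single batched brgemm call.

```cpp
// Standalone illustration of K-dimension blocking for C[M x N] = A[M x K] * B[K x N].
// Each K block produces a partial product; beta == 0 overwrites C, beta == 1 accumulates.
#include <algorithm>
#include <cstddef>

// Hypothetical stand-in for a single brgemm block call over columns [k_off, k_off + k_len)
// of A and the matching rows of B.
void gemm_block(const float* A, const float* B, float* C,
                std::size_t M, std::size_t N, std::size_t K,
                std::size_t k_off, std::size_t k_len, float beta) {
    for (std::size_t m = 0; m < M; ++m) {
        for (std::size_t n = 0; n < N; ++n) {
            float acc = (beta == 0.0f) ? 0.0f : C[m * N + n];
            for (std::size_t k = k_off; k < k_off + k_len; ++k) {
                acc += A[m * K + k] * B[k * N + n];
            }
            C[m * N + n] = acc;
        }
    }
}

// The K blocking loop: the first block initializes C, the rest accumulate into it.
// A batched execution runs all of these block calls inside one kernel invocation
// instead of re-entering the kernel on every loop iteration.
void gemm_k_blocked(const float* A, const float* B, float* C,
                    std::size_t M, std::size_t N, std::size_t K, std::size_t k_blk) {
    for (std::size_t k0 = 0; k0 < K; k0 += k_blk) {
        const std::size_t k_len = std::min(k_blk, K - k0);
        gemm_block(A, B, C, M, N, K, k0, k_len, k0 == 0 ? 0.0f : 1.0f);
    }
}
```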
14 changes: 10 additions & 4 deletions src/common/snippets/src/lowered/linear_ir.cpp
@@ -432,10 +432,16 @@ LinearIR::exprIt LinearIR::replace_with_expr(const std::vector<ExpressionPtr>& o
const auto input_ports = new_expr_it->get()->get_input_ports();
const auto output_ports = new_expr_it->get()->get_output_ports();
for (const auto& old_expr : old_exprs) {
for (size_t i = 0; i < old_expr->get_input_count(); ++i)
m_loop_manager->replace_loop_ports(loop_ids, old_expr->get_input_port(i), input_ports);
for (size_t i = 0; i < old_expr->get_input_count(); ++i)
m_loop_manager->replace_loop_ports(loop_ids, old_expr->get_output_port(i), output_ports);
for (size_t i = 0; i < old_expr->get_input_count(); ++i) {
m_loop_manager->replace_loop_ports(loop_ids,
old_expr->get_input_port(i),
{new_expr_it->get()->get_input_port(i)});
}
for (size_t i = 0; i < old_expr->get_output_count(); ++i) {
m_loop_manager->replace_loop_ports(loop_ids,
old_expr->get_output_port(i),
{new_expr_it->get()->get_output_port(i)});
}
erase(find(old_expr));
}
return new_expr_it;
4 changes: 2 additions & 2 deletions src/common/snippets/src/op/brgemm.cpp
@@ -50,7 +50,7 @@ Brgemm::Brgemm(const Output<Node>& A, const Output<Node>& B,
}

void Brgemm::custom_constructor_validate_and_infer_types(std::vector<size_t> layout_a, std::vector<size_t> layout_b, std::vector<size_t> layout_c) {
INTERNAL_OP_SCOPE(BrgemmCPU_constructor_validate_and_infer_types);
INTERNAL_OP_SCOPE(GemmCPU_constructor_validate_and_infer_types);

// During ctor call, Brgemm doesn't know his port descriptors.
// So we use explicit layouts from parameters
@@ -100,7 +100,7 @@ ov::element::Type Brgemm::get_output_type(const ov::element::Type& in_type0, con
ov::element::Type Brgemm::get_output_type() const {
auto output_type = get_output_type(get_input_element_type(0), get_input_element_type(1));
if (output_type == element::undefined) {
OPENVINO_THROW("BrgemmCPU node has incompatible input element types: " +
OPENVINO_THROW("GemmCPU node has incompatible input element types: " +
get_input_element_type(0).get_type_name() +
" and " +
get_input_element_type(1).get_type_name());
@@ -24,6 +24,7 @@
#include "transformations/snippets/common/op/fused_mul_add.hpp"
#include "transformations/snippets/x64/op/brgemm_copy_b.hpp"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/snippets/x64/op/gemm_cpu.hpp"
#include "transformations/snippets/x64/op/load_convert.hpp"
#include "transformations/snippets/x64/op/perf_count_rdtsc.hpp"
#include "transformations/snippets/x64/op/store_convert.hpp"
@@ -260,6 +261,10 @@ intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_t ho

// Note: jit_brgemm_emitter and jit_brgemm_copy_b_emitter support runtime recompilation, so their constructor takes
// additional arguments
jitters[intel_cpu::GemmCPU::get_type_info_static()] =
CREATE_SNIPPETS_EMITTER(intel_cpu::jit_brgemm_emitter,
configurator->get_kernel_executor_table(),
compiled_kernel_cache);
jitters[intel_cpu::BrgemmCPU::get_type_info_static()] =
CREATE_SNIPPETS_EMITTER(intel_cpu::jit_brgemm_emitter,
configurator->get_kernel_executor_table(),
@@ -431,7 +436,7 @@ std::shared_ptr<snippets::Generator> intel_cpu::CPUGenerator::clone() const {

ov::snippets::RegType intel_cpu::CPUGenerator::get_specific_op_out_reg_type(const ov::Output<ov::Node>& out) const {
const auto op = out.get_node_shared_ptr();
if (is_type<intel_cpu::BrgemmCPU>(op) ||
if (is_type<intel_cpu::GemmCPU>(op) ||
#ifdef SNIPPETS_LIBXSMM_TPP
std::dynamic_pointer_cast<intel_cpu::tpp::modifier::TensorProcessingPrimitive>(op) ||
is_type<intel_cpu::tpp::op::Scalar>(op) ||
@@ -10,7 +10,7 @@
#include "emitters/plugin/x64/utils.hpp"
#include "emitters/snippets/x64/utils.hpp"
#include "snippets/utils/utils.hpp"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/snippets/x64/op/gemm_cpu.hpp"

using namespace Xbyak;
using namespace dnnl::impl;
@@ -7,8 +7,10 @@
#include "emitters/plugin/x64/utils.hpp"
#include "emitters/snippets/x64/kernel_executors/brgemm.hpp"
#include "emitters/snippets/x64/kernel_executors/brgemm_amx.hpp"
#include "emitters/snippets/x64/kernel_executors/brgemm_batched.hpp"
#include "snippets/utils/utils.hpp"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/snippets/x64/op/gemm_cpu.hpp"
#include "transformations/snippets/x64/op/brgemm_utils.hpp"
#include "utils.hpp"

@@ -26,44 +28,68 @@ jit_brgemm_emitter::jit_brgemm_emitter(jit_generator* h,
const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache)
: jit_binary_call_emitter(h, isa, expr->get_live_regs()) {
in_out_type_ = emitter_in_out_map::gpr_to_gpr;
const auto& brgemm_node = as_type_ptr<ov::intel_cpu::BrgemmCPU>(expr->get_node());
const auto& brg0Prc = brgemm_node->get_input_element_type(0);
const auto& brg1Prc = brgemm_node->get_input_element_type(1);
const auto brgemm_type = brgemm_node->get_type();
m_is_with_amx = brgemm_utils::with_amx(brgemm_type);
if (m_is_with_amx) {
BrgemmAMXKernelConfig kernel_config(brg0Prc, brg1Prc, brgemm_utils::get_primitive_isa(brg0Prc, true));
if (is_type<ov::intel_cpu::BrgemmCPU>(expr->get_node())) {
const auto& gemm_node = as_type_ptr<ov::intel_cpu::BrgemmCPU>(expr->get_node());
const auto& brg0Prc = gemm_node->get_input_element_type(0);
const auto& brg1Prc = gemm_node->get_input_element_type(1);
const auto brgemm_type = gemm_node->get_type();
m_is_with_amx = false;

BrgemmBatchedKernelConfig kernel_config(brg0Prc,
brg1Prc,
gemm_node->get_iter_count(),
with_compensations(brgemm_type),
brgemm_utils::get_primitive_isa(brg0Prc, false));
m_kernel_executor =
kernel_table->register_kernel<BrgemmAMXKernelExecutor>(expr, compiled_kernel_cache, kernel_config);
kernel_table->register_kernel<BrgemmBatchedKernelExecutor>(expr, compiled_kernel_cache, kernel_config);

m_memory_offsets = {gemm_node->get_offset_a(), gemm_node->get_offset_b(), gemm_node->get_offset_c()};
m_buffer_ids = {utils::get_buffer_cluster_id(expr->get_input_port(0)),
utils::get_buffer_cluster_id(expr->get_input_port(1)),
utils::get_buffer_cluster_id(expr->get_output_port(0))};
} else if (is_type<ov::intel_cpu::GemmCPU>(expr->get_node())) {
const auto& brgemm_node = as_type_ptr<ov::intel_cpu::GemmCPU>(expr->get_node());
const auto& brg0Prc = brgemm_node->get_input_element_type(0);
const auto& brg1Prc = brgemm_node->get_input_element_type(1);
const auto brgemm_type = brgemm_node->get_type();
m_is_with_amx = brgemm_utils::with_amx(brgemm_type);
if (m_is_with_amx) {
BrgemmAMXKernelConfig kernel_config(brg0Prc, brg1Prc, brgemm_utils::get_primitive_isa(brg0Prc, true));
m_kernel_executor =
kernel_table->register_kernel<BrgemmAMXKernelExecutor>(expr, compiled_kernel_cache, kernel_config);
} else {
BrgemmKernelConfig kernel_config(brg0Prc,
brg1Prc,
with_compensations(brgemm_type),
brgemm_utils::get_primitive_isa(brg0Prc, false));
m_kernel_executor =
kernel_table->register_kernel<BrgemmKernelExecutor>(expr, compiled_kernel_cache, kernel_config);
}
// Note: even if the Brgemm node is dynamic, the first shapeInfer and RuntimeConfigurator::update()
// are performed before the BrgemmKernelExecutor registration. So we have to trigger update() manually
// for both static and the 1st dynamic shapes.
OV_CPU_JIT_EMITTER_ASSERT(
!snippets::utils::is_dynamic_vdims(expr->get_input_port_descriptor(0)->get_shape()) &&
!snippets::utils::is_dynamic_vdims(expr->get_input_port_descriptor(1)->get_shape()),
"Jit emitter is called when the shapes are unknown");

m_memory_offsets = {brgemm_node->get_offset_a(), brgemm_node->get_offset_b(), brgemm_node->get_offset_c()};
m_buffer_ids = {utils::get_buffer_cluster_id(expr->get_input_port(0)),
utils::get_buffer_cluster_id(expr->get_input_port(1)),
utils::get_buffer_cluster_id(expr->get_output_port(0))};
if (with_scratchpad(brgemm_type)) {
m_memory_offsets.push_back(brgemm_node->get_offset_scratch());
m_buffer_ids.push_back(utils::get_buffer_cluster_id(expr->get_input_port(2)));
}
} else {
BrgemmKernelConfig kernel_config(brg0Prc,
brg1Prc,
with_compensations(brgemm_type),
brgemm_utils::get_primitive_isa(brg0Prc, false));
m_kernel_executor =
kernel_table->register_kernel<BrgemmKernelExecutor>(expr, compiled_kernel_cache, kernel_config);
}
// Note: even if the Brgemm node is dynamic, the first shapeInfer and RuntimeConfigurator::update()
// are performed before the BrgemmKernelExecutor registration. So we have to trigger update() manually
// for both static and the 1st dynamic shapes.
OV_CPU_JIT_EMITTER_ASSERT(!snippets::utils::is_dynamic_vdims(expr->get_input_port_descriptor(0)->get_shape()) &&
!snippets::utils::is_dynamic_vdims(expr->get_input_port_descriptor(1)->get_shape()),
"Jit emitter is called when the shapes are unknown");

m_memory_offsets = {brgemm_node->get_offset_a(), brgemm_node->get_offset_b(), brgemm_node->get_offset_c()};
m_buffer_ids = {utils::get_buffer_cluster_id(expr->get_input_port(0)),
utils::get_buffer_cluster_id(expr->get_input_port(1)),
utils::get_buffer_cluster_id(expr->get_output_port(0))};
if (with_scratchpad(brgemm_type)) {
m_memory_offsets.push_back(brgemm_node->get_offset_scratch());
m_buffer_ids.push_back(utils::get_buffer_cluster_id(expr->get_input_port(2)));
OV_CPU_JIT_EMITTER_THROW("got unsupported node type");
}
}

std::set<std::vector<element::Type>> jit_brgemm_emitter::get_supported_precisions(
const std::shared_ptr<ov::Node>& node) {
const auto brgemm = as_type_ptr<ov::intel_cpu::BrgemmCPU>(node);
OV_CPU_JIT_EMITTER_ASSERT(brgemm, "get_supported_precisions() expects BrgemmCPU node");
const auto brgemm = as_type_ptr<ov::intel_cpu::GemmCPU>(node);
OV_CPU_JIT_EMITTER_ASSERT(brgemm, "get_supported_precisions() expects GemmCPU node");
using brgemm_utils::BRGEMM_TYPE;
if (brgemm->get_type() == BRGEMM_TYPE::STAND_ALONE) {
return {{element::f32, element::f32}};
@@ -83,7 +109,7 @@ std::set<std::vector<element::Type>> jit_brgemm_emitter::get_supported_precision
{element::bf16, element::bf16, element::u8},
{element::f16, element::f16, element::u8}};
}
OV_CPU_JIT_EMITTER_THROW("got BrgemmCPU node with unsupported type");
OV_CPU_JIT_EMITTER_THROW("got GemmCPU node with unsupported type");
}

void jit_brgemm_emitter::validate_arguments(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
@@ -103,6 +129,8 @@ void jit_brgemm_emitter::emit_impl(const std::vector<size_t>& in, const std::vec
emit_call<BrgemmAMXKernelExecutor>(mem_ptrs_idxs);
} else if (std::dynamic_pointer_cast<BrgemmKernelExecutor>(m_kernel_executor)) {
emit_call<BrgemmKernelExecutor>(mem_ptrs_idxs);
} else if (std::dynamic_pointer_cast<BrgemmBatchedKernelExecutor>(m_kernel_executor)) {
emit_call<BrgemmBatchedKernelExecutor>(mem_ptrs_idxs);
} else {
OV_CPU_JIT_EMITTER_THROW("uknown execuor type");
}
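For orientation, here is a hedged sketch (hypothetical types and names, not the actual `BrgemmBatchedKernelExecutor` API) of how a batched executor could consume the iteration count passed into `BrgemmBatchedKernelConfig` above via `get_iter_count()`: one emitter call walks all K blocks, advancing the A/B pointers by per-block strides and using beta to overwrite C on the first block and accumulate afterwards.

```cpp
// Hypothetical sketch of batched execution over iter_count K blocks in a single call.
#include <cstddef>
#include <cstdint>

struct BatchedCallArgs {
    const std::uint8_t* A;   // base pointer of the A panel for this tile
    const std::uint8_t* B;   // base pointer of the (possibly repacked) B panel
    std::uint8_t* C;         // output tile, shared by all K blocks
    std::size_t iter_count;  // number of K blocks handled by one call (cf. get_iter_count())
    std::size_t a_stride;    // bytes between consecutive K blocks of A
    std::size_t b_stride;    // bytes between consecutive K blocks of B
};

// KernelFn is any callable with the signature
// void(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* c, float beta).
template <typename KernelFn>
void execute_batched(const KernelFn& kernel, const BatchedCallArgs& args) {
    const std::uint8_t* a_ptr = args.A;
    const std::uint8_t* b_ptr = args.B;
    for (std::size_t i = 0; i < args.iter_count; ++i) {
        const float beta = (i == 0) ? 0.0f : 1.0f;  // first block writes C, later blocks accumulate
        kernel(a_ptr, b_ptr, args.C, beta);
        a_ptr += args.a_stride;
        b_ptr += args.b_stride;
    }
}
```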
@@ -7,7 +7,7 @@
#include "common/utils.hpp"
#include "dnnl_extension_utils.h"
#include "snippets/lowered/pass/insert_specific_iterations.hpp"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/snippets/x64/op/gemm_cpu.hpp"
#include "transformations/snippets/x64/op/brgemm_utils.hpp"

using namespace Xbyak;
@@ -6,7 +6,7 @@

#include <cpu/x64/amx_tile_configure.hpp>

#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/snippets/x64/op/gemm_cpu.hpp"
#include "transformations/snippets/x64/op/brgemm_utils.hpp"

#define INNER_K_BLK(dtype) static_cast<dnnl_dim_t>((brgemm_utils::repacking::compute_inner_k_block(in0_dtype)))
@@ -293,7 +293,7 @@ void BrgemmAMXKernelExecutor::execute(const BrgemmAMXKernelExecutor* executor, c

if (K_tail != 0) {
if (config.need_copy_a(K_tail)) {
auto* tr_src = scratch + BrgemmCPU::SCRATCH_BYTE_SIZE;
auto* tr_src = scratch + GemmCPU::SCRATCH_BYTE_SIZE;

execute_brgemm_copy_a_kernel(kernel->brgemm_copy_a_kernel, src_ptr, tr_src, config.get_M(), K_tail);
src_ptr = tr_src;
@@ -7,6 +7,7 @@
#include "common/utils.hpp"
#include "dnnl_extension_utils.h"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/snippets/x64/op/gemm_cpu.hpp"
#include "transformations/snippets/x64/op/brgemm_utils.hpp"

#define DIM_CAST(X) static_cast<dnnl_dim_t>(X)
@@ -163,7 +164,9 @@ void BrgemmBaseKernelExecutor::update_config(const ov::snippets::lowered::Expres
const auto in0_shape = snippets::utils::get_planar_vdims(input_pds[0]->get_shape(), input_pds[0]->get_layout());
const auto in1_shape = snippets::utils::get_planar_vdims(input_pds[1]->get_shape(), input_pds[1]->get_layout());
auto in0_subtensor = input_pds[0]->get_subtensor();
OPENVINO_ASSERT(!in0_subtensor.empty(), "Incorrect in0 subtensor size");
auto in1_subtensor = input_pds[1]->get_subtensor();
OPENVINO_ASSERT(!in1_subtensor.empty(), "Incorrect in1 subtensor size");

// Need to update M, K, N
// 1. If the original value in subtensor is `FULL_DIM`, it means that
@@ -254,11 +257,15 @@ void BrgemmBaseKernelExecutor::update_config(const ov::snippets::lowered::Expres
const auto LDC = DIM_CAST(snippets::utils::get_dim_stride(expr->get_output_port(0)));
auto LDB = DIM_CAST(snippets::utils::get_dim_stride(expr->get_input_port(1)));

const auto& brgemm_node = as_type_ptr<ov::intel_cpu::BrgemmCPU>(expr->get_node());
OV_CPU_JIT_EMITTER_ASSERT(brgemm_node, "Got invalid node type in update_config");
// In case of data repacking LDB is chosen in accordance with repacking buffer size
if (with_repacking(brgemm_node->get_type())) {
LDB = DIM_CAST(brgemm_utils::repacking::compute_repacked_n_dim(LDB, brgemm_node->get_input_element_type(1)));
if (is_type<ov::intel_cpu::BrgemmCPU>(expr->get_node())) {
} else if (is_type<ov::intel_cpu::GemmCPU>(expr->get_node())) {
const auto& brgemm_node = as_type_ptr<ov::intel_cpu::GemmCPU>(expr->get_node());
// In case of data repacking LDB is chosen in accordance with repacking buffer size
if (with_repacking(brgemm_node->get_type())) {
LDB = DIM_CAST(brgemm_utils::repacking::compute_repacked_n_dim(LDB, brgemm_node->get_input_element_type(1)));
}
} else {
OV_CPU_JIT_EMITTER_ASSERT(false, "Got invalid node type in update_config");
}

config.update(DIM_CAST(M), DIM_CAST(N), DIM_CAST(K), LDA, LDB, LDC, beta);