Skip to content

Commit

Permalink
build brgemm
Browse files Browse the repository at this point in the history
  • Loading branch information
aobolensk committed Feb 5, 2025
1 parent c27f796 commit abc20bd
Show file tree
Hide file tree
Showing 35 changed files with 862 additions and 238 deletions.
4 changes: 2 additions & 2 deletions src/common/snippets/docs/mha_optimization_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ For enhancing the execution efficiency, blocking across the M, K, and N matmul d

### Blocking Parameters

The heuristics for determining the optimal block sizes can be found in [BrgemmCPUBlocking](../../../plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_cpu_blocking.cpp).
The heuristics for determining the optimal block sizes can be found in [GemmCPUBlocking](../../../plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_cpu_blocking.cpp).

**Please note: Blocking by M dimension is shared between both Brgemms. Please see [SplitLoops](../include/snippets/lowered/pass/split_loops.hpp) lowered pass for the details.**

Expand All @@ -141,7 +141,7 @@ Based on previously discussed information, we provide the following recommendati
In local experiments, some transformations might be worth to change:
- Disable [ExtractUnsupportedTransposes](#extractunsupportedtransposes) transformation in order to benchmark Snippets Transpose implementation.
- Adjust [SplitDimensionM](#splitdimensionm) heuristics in order to benchmark another splitting, or disable the pass at all.
3. [Blocking parameters](#blocking-parameters): adjust blocking heuristics in `BrgemmCPUBlocking`.
3. [Blocking parameters](#blocking-parameters): adjust blocking heuristics in `GemmCPUBlocking`.
- Please note that there are 2 Matmul nodes inside a single MHA, and each Matmul can have his own optimal K, N blocking params.
M block is better to keep the same since the corresponding blocking loop is shared between both Matmuls.
- For the BF16/INT8 blocking loops, 2 options are possible: blocking can be done only for Brgemm node, or for BrgemmCopyB repacking too.
Expand Down
14 changes: 10 additions & 4 deletions src/common/snippets/src/lowered/linear_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -432,10 +432,16 @@ LinearIR::exprIt LinearIR::replace_with_expr(const std::vector<ExpressionPtr>& o
const auto input_ports = new_expr_it->get()->get_input_ports();
const auto output_ports = new_expr_it->get()->get_output_ports();
for (const auto& old_expr : old_exprs) {
for (size_t i = 0; i < old_expr->get_input_count(); ++i)
m_loop_manager->replace_loop_ports(loop_ids, old_expr->get_input_port(i), input_ports);
for (size_t i = 0; i < old_expr->get_input_count(); ++i)
m_loop_manager->replace_loop_ports(loop_ids, old_expr->get_output_port(i), output_ports);
for (size_t i = 0; i < old_expr->get_input_count(); ++i) {
m_loop_manager->replace_loop_ports(loop_ids,
old_expr->get_input_port(i),
{new_expr_it->get()->get_input_port(i)});
}
for (size_t i = 0; i < old_expr->get_output_count(); ++i) {
m_loop_manager->replace_loop_ports(loop_ids,
old_expr->get_output_port(i),
{new_expr_it->get()->get_output_port(i)});
}
erase(find(old_expr));
}
return new_expr_it;
Expand Down
4 changes: 2 additions & 2 deletions src/common/snippets/src/op/brgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ Brgemm::Brgemm(const Output<Node>& A, const Output<Node>& B,
}

void Brgemm::custom_constructor_validate_and_infer_types(std::vector<size_t> layout_a, std::vector<size_t> layout_b, std::vector<size_t> layout_c) {
INTERNAL_OP_SCOPE(BrgemmCPU_constructor_validate_and_infer_types);
INTERNAL_OP_SCOPE(GemmCPU_constructor_validate_and_infer_types);

// During ctor call, Brgemm doesn't know his port descriptors.
// So we use explicit layouts from parameters
Expand Down Expand Up @@ -100,7 +100,7 @@ ov::element::Type Brgemm::get_output_type(const ov::element::Type& in_type0, con
ov::element::Type Brgemm::get_output_type() const {
auto output_type = get_output_type(get_input_element_type(0), get_input_element_type(1));
if (output_type == element::undefined) {
OPENVINO_THROW("BrgemmCPU node has incompatible input element types: " +
OPENVINO_THROW("GemmCPU node has incompatible input element types: " +
get_input_element_type(0).get_type_name() +
" and " +
get_input_element_type(1).get_type_name());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "transformations/snippets/common/op/fused_mul_add.hpp"
#include "transformations/snippets/x64/op/brgemm_copy_b.hpp"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/snippets/x64/op/gemm_cpu.hpp"
#include "transformations/snippets/x64/op/load_convert.hpp"
#include "transformations/snippets/x64/op/perf_count_rdtsc.hpp"
#include "transformations/snippets/x64/op/store_convert.hpp"
Expand Down Expand Up @@ -260,6 +261,10 @@ intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_t ho

// Note: jit_brgemm_emitter and jit_brgemm_copy_b_emitter support runtime recompilation, so their constructor takes
// additional arguments
jitters[intel_cpu::GemmCPU::get_type_info_static()] =
CREATE_SNIPPETS_EMITTER(intel_cpu::jit_brgemm_emitter,
configurator->get_kernel_executor_table(),
compiled_kernel_cache);
jitters[intel_cpu::BrgemmCPU::get_type_info_static()] =
CREATE_SNIPPETS_EMITTER(intel_cpu::jit_brgemm_emitter,
configurator->get_kernel_executor_table(),
Expand Down Expand Up @@ -431,7 +436,7 @@ std::shared_ptr<snippets::Generator> intel_cpu::CPUGenerator::clone() const {

ov::snippets::RegType intel_cpu::CPUGenerator::get_specific_op_out_reg_type(const ov::Output<ov::Node>& out) const {
const auto op = out.get_node_shared_ptr();
if (is_type<intel_cpu::BrgemmCPU>(op) ||
if (is_type<intel_cpu::GemmCPU>(op) ||
#ifdef SNIPPETS_LIBXSMM_TPP
std::dynamic_pointer_cast<intel_cpu::tpp::modifier::TensorProcessingPrimitive>(op) ||
is_type<intel_cpu::tpp::op::Scalar>(op) ||
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#include "emitters/plugin/x64/utils.hpp"
#include "emitters/snippets/x64/utils.hpp"
#include "snippets/utils/utils.hpp"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/snippets/x64/op/gemm_cpu.hpp"

using namespace Xbyak;
using namespace dnnl::impl;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
#include "emitters/plugin/x64/utils.hpp"
#include "emitters/snippets/x64/kernel_executors/brgemm.hpp"
#include "emitters/snippets/x64/kernel_executors/brgemm_amx.hpp"
#include "emitters/snippets/x64/kernel_executors/brgemm_batched.hpp"
#include "snippets/utils/utils.hpp"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/snippets/x64/op/gemm_cpu.hpp"
#include "transformations/snippets/x64/op/brgemm_utils.hpp"
#include "utils.hpp"

Expand All @@ -26,44 +28,67 @@ jit_brgemm_emitter::jit_brgemm_emitter(jit_generator* h,
const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache)
: jit_binary_call_emitter(h, isa, expr->get_live_regs()) {
in_out_type_ = emitter_in_out_map::gpr_to_gpr;
const auto& brgemm_node = as_type_ptr<ov::intel_cpu::BrgemmCPU>(expr->get_node());
const auto& brg0Prc = brgemm_node->get_input_element_type(0);
const auto& brg1Prc = brgemm_node->get_input_element_type(1);
const auto brgemm_type = brgemm_node->get_type();
m_is_with_amx = brgemm_utils::with_amx(brgemm_type);
if (m_is_with_amx) {
BrgemmAMXKernelConfig kernel_config(brg0Prc, brg1Prc, brgemm_utils::get_primitive_isa(brg0Prc, true));
if (is_type<ov::intel_cpu::BrgemmCPU>(expr->get_node())) {
const auto& gemm_node = as_type_ptr<ov::intel_cpu::BrgemmCPU>(expr->get_node());
const auto& brg0Prc = gemm_node->get_input_element_type(0);
const auto& brg1Prc = gemm_node->get_input_element_type(1);
const auto brgemm_type = gemm_node->get_type();
m_is_with_amx = false;

BrgemmBatchedKernelConfig kernel_config(brg0Prc,
brg1Prc,
with_compensations(brgemm_type),
brgemm_utils::get_primitive_isa(brg0Prc, false));
m_kernel_executor =
kernel_table->register_kernel<BrgemmAMXKernelExecutor>(expr, compiled_kernel_cache, kernel_config);
kernel_table->register_kernel<BrgemmBatchedKernelExecutor>(expr, compiled_kernel_cache, kernel_config);

m_memory_offsets = {gemm_node->get_offset_a(), gemm_node->get_offset_b(), gemm_node->get_offset_c()};
m_buffer_ids = {utils::get_buffer_cluster_id(expr->get_input_port(0)),
utils::get_buffer_cluster_id(expr->get_input_port(1)),
utils::get_buffer_cluster_id(expr->get_output_port(0))};
} else if (is_type<ov::intel_cpu::GemmCPU>(expr->get_node())) {
const auto& brgemm_node = as_type_ptr<ov::intel_cpu::GemmCPU>(expr->get_node());
const auto& brg0Prc = brgemm_node->get_input_element_type(0);
const auto& brg1Prc = brgemm_node->get_input_element_type(1);
const auto brgemm_type = brgemm_node->get_type();
m_is_with_amx = brgemm_utils::with_amx(brgemm_type);
if (m_is_with_amx) {
BrgemmAMXKernelConfig kernel_config(brg0Prc, brg1Prc, brgemm_utils::get_primitive_isa(brg0Prc, true));
m_kernel_executor =
kernel_table->register_kernel<BrgemmAMXKernelExecutor>(expr, compiled_kernel_cache, kernel_config);
} else {
BrgemmKernelConfig kernel_config(brg0Prc,
brg1Prc,
with_compensations(brgemm_type),
brgemm_utils::get_primitive_isa(brg0Prc, false));
m_kernel_executor =
kernel_table->register_kernel<BrgemmKernelExecutor>(expr, compiled_kernel_cache, kernel_config);
}
// Note: even if the Brgemm node is dynamic, the first shapeInfer and RuntimeConfigurator::update()
// are performed before the BrgemmKernelExecutor registration. So we have to trigger update() manually
// for both static and the 1st dynamic shapes.
OV_CPU_JIT_EMITTER_ASSERT(
!snippets::utils::is_dynamic_vdims(expr->get_input_port_descriptor(0)->get_shape()) &&
!snippets::utils::is_dynamic_vdims(expr->get_input_port_descriptor(1)->get_shape()),
"Jit emitter is called when the shapes are unknown");

m_memory_offsets = {brgemm_node->get_offset_a(), brgemm_node->get_offset_b(), brgemm_node->get_offset_c()};
m_buffer_ids = {utils::get_buffer_cluster_id(expr->get_input_port(0)),
utils::get_buffer_cluster_id(expr->get_input_port(1)),
utils::get_buffer_cluster_id(expr->get_output_port(0))};
if (with_scratchpad(brgemm_type)) {
m_memory_offsets.push_back(brgemm_node->get_offset_scratch());
m_buffer_ids.push_back(utils::get_buffer_cluster_id(expr->get_input_port(2)));
}
} else {
BrgemmKernelConfig kernel_config(brg0Prc,
brg1Prc,
with_compensations(brgemm_type),
brgemm_utils::get_primitive_isa(brg0Prc, false));
m_kernel_executor =
kernel_table->register_kernel<BrgemmKernelExecutor>(expr, compiled_kernel_cache, kernel_config);
}
// Note: even if the Brgemm node is dynamic, the first shapeInfer and RuntimeConfigurator::update()
// are performed before the BrgemmKernelExecutor registration. So we have to trigger update() manually
// for both static and the 1st dynamic shapes.
OV_CPU_JIT_EMITTER_ASSERT(!snippets::utils::is_dynamic_vdims(expr->get_input_port_descriptor(0)->get_shape()) &&
!snippets::utils::is_dynamic_vdims(expr->get_input_port_descriptor(1)->get_shape()),
"Jit emitter is called when the shapes are unknown");

m_memory_offsets = {brgemm_node->get_offset_a(), brgemm_node->get_offset_b(), brgemm_node->get_offset_c()};
m_buffer_ids = {utils::get_buffer_cluster_id(expr->get_input_port(0)),
utils::get_buffer_cluster_id(expr->get_input_port(1)),
utils::get_buffer_cluster_id(expr->get_output_port(0))};
if (with_scratchpad(brgemm_type)) {
m_memory_offsets.push_back(brgemm_node->get_offset_scratch());
m_buffer_ids.push_back(utils::get_buffer_cluster_id(expr->get_input_port(2)));
OV_CPU_JIT_EMITTER_THROW("got unsupported node type");
}
}

std::set<std::vector<element::Type>> jit_brgemm_emitter::get_supported_precisions(
const std::shared_ptr<ov::Node>& node) {
const auto brgemm = as_type_ptr<ov::intel_cpu::BrgemmCPU>(node);
OV_CPU_JIT_EMITTER_ASSERT(brgemm, "get_supported_precisions() expects BrgemmCPU node");
const auto brgemm = as_type_ptr<ov::intel_cpu::GemmCPU>(node);
OV_CPU_JIT_EMITTER_ASSERT(brgemm, "get_supported_precisions() expects GemmCPU node");
using brgemm_utils::BRGEMM_TYPE;
if (brgemm->get_type() == BRGEMM_TYPE::STAND_ALONE) {
return {{element::f32, element::f32}};
Expand All @@ -83,7 +108,7 @@ std::set<std::vector<element::Type>> jit_brgemm_emitter::get_supported_precision
{element::bf16, element::bf16, element::u8},
{element::f16, element::f16, element::u8}};
}
OV_CPU_JIT_EMITTER_THROW("got BrgemmCPU node with unsupported type");
OV_CPU_JIT_EMITTER_THROW("got GemmCPU node with unsupported type");
}

void jit_brgemm_emitter::validate_arguments(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
Expand All @@ -103,6 +128,8 @@ void jit_brgemm_emitter::emit_impl(const std::vector<size_t>& in, const std::vec
emit_call<BrgemmAMXKernelExecutor>(mem_ptrs_idxs);
} else if (std::dynamic_pointer_cast<BrgemmKernelExecutor>(m_kernel_executor)) {
emit_call<BrgemmKernelExecutor>(mem_ptrs_idxs);
} else if (std::dynamic_pointer_cast<BrgemmBatchedKernelExecutor>(m_kernel_executor)) {
emit_call<BrgemmBatchedKernelExecutor>(mem_ptrs_idxs);
} else {
OV_CPU_JIT_EMITTER_THROW("uknown execuor type");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include "common/utils.hpp"
#include "dnnl_extension_utils.h"
#include "snippets/lowered/pass/insert_specific_iterations.hpp"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/snippets/x64/op/gemm_cpu.hpp"
#include "transformations/snippets/x64/op/brgemm_utils.hpp"

using namespace Xbyak;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

#include <cpu/x64/amx_tile_configure.hpp>

#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/snippets/x64/op/gemm_cpu.hpp"
#include "transformations/snippets/x64/op/brgemm_utils.hpp"

#define INNER_K_BLK(dtype) static_cast<dnnl_dim_t>((brgemm_utils::repacking::compute_inner_k_block(in0_dtype)))
Expand Down Expand Up @@ -293,7 +293,7 @@ void BrgemmAMXKernelExecutor::execute(const BrgemmAMXKernelExecutor* executor, c

if (K_tail != 0) {
if (config.need_copy_a(K_tail)) {
auto* tr_src = scratch + BrgemmCPU::SCRATCH_BYTE_SIZE;
auto* tr_src = scratch + GemmCPU::SCRATCH_BYTE_SIZE;

execute_brgemm_copy_a_kernel(kernel->brgemm_copy_a_kernel, src_ptr, tr_src, config.get_M(), K_tail);
src_ptr = tr_src;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "common/utils.hpp"
#include "dnnl_extension_utils.h"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/snippets/x64/op/gemm_cpu.hpp"
#include "transformations/snippets/x64/op/brgemm_utils.hpp"

#define DIM_CAST(X) static_cast<dnnl_dim_t>(X)
Expand Down Expand Up @@ -163,7 +164,9 @@ void BrgemmBaseKernelExecutor::update_config(const ov::snippets::lowered::Expres
const auto in0_shape = snippets::utils::get_planar_vdims(input_pds[0]->get_shape(), input_pds[0]->get_layout());
const auto in1_shape = snippets::utils::get_planar_vdims(input_pds[1]->get_shape(), input_pds[1]->get_layout());
auto in0_subtensor = input_pds[0]->get_subtensor();
OPENVINO_ASSERT(!in0_subtensor.empty(), "Incorrect in0 subtensor size");
auto in1_subtensor = input_pds[1]->get_subtensor();
OPENVINO_ASSERT(!in1_subtensor.empty(), "Incorrect in1 subtensor size");

// Need to update M, K, N
// 1. If the original value in subtensor is `FULL_DIM`, it means that
Expand Down Expand Up @@ -254,11 +257,15 @@ void BrgemmBaseKernelExecutor::update_config(const ov::snippets::lowered::Expres
const auto LDC = DIM_CAST(snippets::utils::get_dim_stride(expr->get_output_port(0)));
auto LDB = DIM_CAST(snippets::utils::get_dim_stride(expr->get_input_port(1)));

const auto& brgemm_node = as_type_ptr<ov::intel_cpu::BrgemmCPU>(expr->get_node());
OV_CPU_JIT_EMITTER_ASSERT(brgemm_node, "Got invalid node type in update_config");
// In case of data repacking LDB is chosen in accordance with repacking buffer size
if (with_repacking(brgemm_node->get_type())) {
LDB = DIM_CAST(brgemm_utils::repacking::compute_repacked_n_dim(LDB, brgemm_node->get_input_element_type(1)));
if (is_type<ov::intel_cpu::BrgemmCPU>(expr->get_node())) {
} else if (is_type<ov::intel_cpu::GemmCPU>(expr->get_node())) {
const auto& brgemm_node = as_type_ptr<ov::intel_cpu::GemmCPU>(expr->get_node());
// In case of data repacking LDB is chosen in accordance with repacking buffer size
if (with_repacking(brgemm_node->get_type())) {
LDB = DIM_CAST(brgemm_utils::repacking::compute_repacked_n_dim(LDB, brgemm_node->get_input_element_type(1)));
}
} else {
OV_CPU_JIT_EMITTER_ASSERT(false, "Got invalid node type in update_config");
}

config.update(DIM_CAST(M), DIM_CAST(N), DIM_CAST(K), LDA, LDB, LDC, beta);
Expand Down
Loading

0 comments on commit abc20bd

Please sign in to comment.