brgemm -> gemm
aobolensk committed Jan 29, 2025
1 parent a8eaf9b commit ff038c9
Showing 30 changed files with 152 additions and 152 deletions.
4 changes: 2 additions & 2 deletions src/common/snippets/docs/mha_optimization_guide.md
@@ -123,7 +123,7 @@ For enhancing the execution efficiency, blocking across the M, K, and N matmul d

### Blocking Parameters

-The heuristics for determining the optimal block sizes can be found in [BrgemmCPUBlocking](../../../plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_cpu_blocking.cpp).
+The heuristics for determining the optimal block sizes can be found in [GemmCPUBlocking](../../../plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_cpu_blocking.cpp).

**Please note: Blocking by M dimension is shared between both Brgemms. Please see [SplitLoops](../include/snippets/lowered/pass/split_loops.hpp) lowered pass for the details.**

@@ -141,7 +141,7 @@ Based on previously discussed information, we provide the following recommendati
In local experiments, some transformations might be worth to change:
- Disable [ExtractUnsupportedTransposes](#extractunsupportedtransposes) transformation in order to benchmark Snippets Transpose implementation.
- Adjust [SplitDimensionM](#splitdimensionm) heuristics in order to benchmark another splitting, or disable the pass at all.
-3. [Blocking parameters](#blocking-parameters): adjust blocking heuristics in `BrgemmCPUBlocking`.
+3. [Blocking parameters](#blocking-parameters): adjust blocking heuristics in `GemmCPUBlocking`.
- Please note that there are 2 Matmul nodes inside a single MHA, and each Matmul can have his own optimal K, N blocking params.
M block is better to keep the same since the corresponding blocking loop is shared between both Matmuls.
- For the BF16/INT8 blocking loops, 2 options are possible: blocking can be done only for Brgemm node, or for BrgemmCopyB repacking too.
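
To make the blocking discussion in the guide concrete, the following is a minimal sketch of a matmul whose M, N, and K loops are split into blocks. It is an illustration only, not the Snippets or oneDNN implementation: the function name `blocked_gemm`, its parameters, and the row-major layout are all assumptions made for this example.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical illustration only: not the Snippets or oneDNN implementation.
// Computes C[M x N] += A[M x K] * B[K x N] for row-major buffers,
// walking the M, N and K dimensions in blocks of m_blk, n_blk and k_blk.
void blocked_gemm(const std::vector<float>& A,
                  const std::vector<float>& B,
                  std::vector<float>& C,
                  std::size_t M, std::size_t K, std::size_t N,
                  std::size_t m_blk, std::size_t k_blk, std::size_t n_blk) {
    assert(A.size() == M * K && B.size() == K * N && C.size() == M * N);
    // Clamp block sizes to at least 1 so the blocking loops always advance.
    m_blk = std::max<std::size_t>(m_blk, 1);
    k_blk = std::max<std::size_t>(k_blk, 1);
    n_blk = std::max<std::size_t>(n_blk, 1);

    for (std::size_t m0 = 0; m0 < M; m0 += m_blk) {          // blocking loop by M
        const std::size_t m1 = std::min(m0 + m_blk, M);
        for (std::size_t n0 = 0; n0 < N; n0 += n_blk) {      // blocking loop by N
            const std::size_t n1 = std::min(n0 + n_blk, N);
            for (std::size_t k0 = 0; k0 < K; k0 += k_blk) {  // blocking loop by K
                const std::size_t k1 = std::min(k0 + k_blk, K);
                // Inner kernel: one (m_blk x k_blk) by (k_blk x n_blk) tile product.
                for (std::size_t m = m0; m < m1; ++m) {
                    for (std::size_t k = k0; k < k1; ++k) {
                        const float a = A[m * K + k];
                        for (std::size_t n = n0; n < n1; ++n) {
                            C[m * N + n] += a * B[k * N + n];
                        }
                    }
                }
            }
        }
    }
}
```

In terms of the recommendation above, two consecutive matmuls would call such a routine with the same `m_blk` but possibly different `k_blk` and `n_blk` values; the real passes pick these values heuristically per platform and precision.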
4 changes: 2 additions & 2 deletions src/common/snippets/src/op/brgemm.cpp
@@ -50,7 +50,7 @@ Brgemm::Brgemm(const Output<Node>& A, const Output<Node>& B,
}

void Brgemm::custom_constructor_validate_and_infer_types(std::vector<size_t> layout_a, std::vector<size_t> layout_b, std::vector<size_t> layout_c) {
-INTERNAL_OP_SCOPE(BrgemmCPU_constructor_validate_and_infer_types);
+INTERNAL_OP_SCOPE(GemmCPU_constructor_validate_and_infer_types);

// During ctor call, Brgemm doesn't know his port descriptors.
// So we use explicit layouts from parameters
@@ -100,7 +100,7 @@ ov::element::Type Brgemm::get_output_type(const ov::element::Type& in_type0, con
ov::element::Type Brgemm::get_output_type() const {
auto output_type = get_output_type(get_input_element_type(0), get_input_element_type(1));
if (output_type == element::undefined) {
-OPENVINO_THROW("BrgemmCPU node has incompatible input element types: " +
+OPENVINO_THROW("GemmCPU node has incompatible input element types: " +
get_input_element_type(0).get_type_name() +
" and " +
get_input_element_type(1).get_type_name());
@@ -23,7 +23,7 @@
#include "transformations/cpu_opset/common/op/swish_cpu.hpp"
#include "transformations/snippets/common/op/fused_mul_add.hpp"
#include "transformations/snippets/x64/op/brgemm_copy_b.hpp"
-#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
+#include "transformations/snippets/x64/op/gemm_cpu.hpp"
#include "transformations/snippets/x64/op/load_convert.hpp"
#include "transformations/snippets/x64/op/perf_count_rdtsc.hpp"
#include "transformations/snippets/x64/op/store_convert.hpp"
@@ -260,7 +260,7 @@ intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_t ho

// Note: jit_brgemm_emitter and jit_brgemm_copy_b_emitter support runtime recompilation, so their constructor takes
// additional arguments
-jitters[intel_cpu::BrgemmCPU::get_type_info_static()] =
+jitters[intel_cpu::GemmCPU::get_type_info_static()] =
CREATE_SNIPPETS_EMITTER(intel_cpu::jit_brgemm_emitter,
configurator->get_kernel_executor_table(),
compiled_kernel_cache);
@@ -431,7 +431,7 @@ std::shared_ptr<snippets::Generator> intel_cpu::CPUGenerator::clone() const {

ov::snippets::RegType intel_cpu::CPUGenerator::get_specific_op_out_reg_type(const ov::Output<ov::Node>& out) const {
const auto op = out.get_node_shared_ptr();
-if (is_type<intel_cpu::BrgemmCPU>(op) ||
+if (is_type<intel_cpu::GemmCPU>(op) ||
#ifdef SNIPPETS_LIBXSMM_TPP
std::dynamic_pointer_cast<intel_cpu::tpp::modifier::TensorProcessingPrimitive>(op) ||
is_type<intel_cpu::tpp::op::Scalar>(op) ||
@@ -10,7 +10,7 @@
#include "emitters/plugin/x64/utils.hpp"
#include "emitters/snippets/x64/utils.hpp"
#include "snippets/utils/utils.hpp"
-#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
+#include "transformations/snippets/x64/op/gemm_cpu.hpp"

using namespace Xbyak;
using namespace dnnl::impl;
@@ -9,7 +9,7 @@
#include "emitters/snippets/x64/kernel_executors/brgemm_amx.hpp"
#include "emitters/snippets/x64/kernel_executors/brgemm_batched.hpp"
#include "snippets/utils/utils.hpp"
-#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
+#include "transformations/snippets/x64/op/gemm_cpu.hpp"
#include "transformations/snippets/x64/op/brgemm_utils.hpp"
#include "utils.hpp"

@@ -27,7 +27,7 @@ jit_brgemm_emitter::jit_brgemm_emitter(jit_generator* h,
const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache)
: jit_binary_call_emitter(h, isa, expr->get_live_regs()) {
in_out_type_ = emitter_in_out_map::gpr_to_gpr;
-const auto& brgemm_node = as_type_ptr<ov::intel_cpu::BrgemmCPU>(expr->get_node());
+const auto& brgemm_node = as_type_ptr<ov::intel_cpu::GemmCPU>(expr->get_node());
const auto& brg0Prc = brgemm_node->get_input_element_type(0);
const auto& brg1Prc = brgemm_node->get_input_element_type(1);
const auto brgemm_type = brgemm_node->get_type();
@@ -70,8 +70,8 @@ jit_brgemm_emitter::jit_brgemm_emitter(jit_generator* h,

std::set<std::vector<element::Type>> jit_brgemm_emitter::get_supported_precisions(
const std::shared_ptr<ov::Node>& node) {
-const auto brgemm = as_type_ptr<ov::intel_cpu::BrgemmCPU>(node);
-OV_CPU_JIT_EMITTER_ASSERT(brgemm, "get_supported_precisions() expects BrgemmCPU node");
+const auto brgemm = as_type_ptr<ov::intel_cpu::GemmCPU>(node);
+OV_CPU_JIT_EMITTER_ASSERT(brgemm, "get_supported_precisions() expects GemmCPU node");
using brgemm_utils::BRGEMM_TYPE;
if (brgemm->get_type() == BRGEMM_TYPE::STAND_ALONE) {
return {{element::f32, element::f32}};
@@ -91,7 +91,7 @@ std::set<std::vector<element::Type>> jit_brgemm_emitter::get_supported_precision
{element::bf16, element::bf16, element::u8},
{element::f16, element::f16, element::u8}};
}
-OV_CPU_JIT_EMITTER_THROW("got BrgemmCPU node with unsupported type");
+OV_CPU_JIT_EMITTER_THROW("got GemmCPU node with unsupported type");
}

void jit_brgemm_emitter::validate_arguments(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
@@ -7,7 +7,7 @@
#include "common/utils.hpp"
#include "dnnl_extension_utils.h"
#include "snippets/lowered/pass/insert_specific_iterations.hpp"
-#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
+#include "transformations/snippets/x64/op/gemm_cpu.hpp"
#include "transformations/snippets/x64/op/brgemm_utils.hpp"

using namespace Xbyak;
@@ -6,7 +6,7 @@

#include <cpu/x64/amx_tile_configure.hpp>

-#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
+#include "transformations/snippets/x64/op/gemm_cpu.hpp"
#include "transformations/snippets/x64/op/brgemm_utils.hpp"

#define INNER_K_BLK(dtype) static_cast<dnnl_dim_t>((brgemm_utils::repacking::compute_inner_k_block(in0_dtype)))
@@ -293,7 +293,7 @@ void BrgemmAMXKernelExecutor::execute(const BrgemmAMXKernelExecutor* executor, c

if (K_tail != 0) {
if (config.need_copy_a(K_tail)) {
-auto* tr_src = scratch + BrgemmCPU::SCRATCH_BYTE_SIZE;
+auto* tr_src = scratch + GemmCPU::SCRATCH_BYTE_SIZE;

execute_brgemm_copy_a_kernel(kernel->brgemm_copy_a_kernel, src_ptr, tr_src, config.get_M(), K_tail);
src_ptr = tr_src;
@@ -6,7 +6,7 @@

#include "common/utils.hpp"
#include "dnnl_extension_utils.h"
-#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
+#include "transformations/snippets/x64/op/gemm_cpu.hpp"
#include "transformations/snippets/x64/op/brgemm_utils.hpp"

#define DIM_CAST(X) static_cast<dnnl_dim_t>(X)
@@ -254,7 +254,7 @@ void BrgemmBaseKernelExecutor::update_config(const ov::snippets::lowered::Expres
const auto LDC = DIM_CAST(snippets::utils::get_dim_stride(expr->get_output_port(0)));
auto LDB = DIM_CAST(snippets::utils::get_dim_stride(expr->get_input_port(1)));

-const auto& brgemm_node = as_type_ptr<ov::intel_cpu::BrgemmCPU>(expr->get_node());
+const auto& brgemm_node = as_type_ptr<ov::intel_cpu::GemmCPU>(expr->get_node());
OV_CPU_JIT_EMITTER_ASSERT(brgemm_node, "Got invalid node type in update_config");
// In case of data repacking LDB is chosen in accordance with repacking buffer size
if (with_repacking(brgemm_node->get_type())) {
@@ -7,7 +7,7 @@
#include "common/utils.hpp"
#include "dnnl_extension_utils.h"
#include "snippets/lowered/pass/insert_specific_iterations.hpp"
-#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
+#include "transformations/snippets/x64/op/gemm_cpu.hpp"
#include "transformations/snippets/x64/op/brgemm_utils.hpp"

using namespace Xbyak;
4 changes: 2 additions & 2 deletions src/plugins/intel_cpu/src/extension.cpp
@@ -31,7 +31,7 @@
#include "transformations/cpu_opset/x64/op/mha.hpp"
#include "transformations/cpu_opset/x64/op/qkv_proj.hpp"
#include "transformations/snippets/x64/op/brgemm_copy_b.hpp"
-#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
+#include "transformations/snippets/x64/op/gemm_cpu.hpp"
#include "transformations/snippets/x64/op/load_convert.hpp"
#include "transformations/snippets/x64/op/perf_count_rdtsc.hpp"
#include "transformations/snippets/x64/op/store_convert.hpp"
@@ -101,7 +101,7 @@ class TypeRelaxedExtension : public ov::OpExtension<ov::op::TypeRelaxed<Op>> {
OP_EXTENSION_X64(ov::intel_cpu::LoadConvertTruncation) \
OP_EXTENSION_X64(ov::intel_cpu::StoreConvertSaturation) \
OP_EXTENSION_X64(ov::intel_cpu::StoreConvertTruncation) \
-OP_EXTENSION_X64(ov::intel_cpu::BrgemmCPU) \
+OP_EXTENSION_X64(ov::intel_cpu::GemmCPU) \
OP_EXTENSION_X64(ov::intel_cpu::BrgemmCopyB)

#define TYPE_RELAXED_EXTENSIONS \
12 changes: 6 additions & 6 deletions src/plugins/intel_cpu/src/nodes/subgraph.cpp
@@ -32,7 +32,7 @@
#else
# include "emitters/snippets/x64/cpu_generator.hpp"
# include "executors/x64/subgraph.hpp"
-# include "transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.hpp"
+# include "transformations/snippets/x64/pass/brgemm_to_gemm_cpu.hpp"
# include "transformations/snippets/x64/pass/eliminate_brgemm_copy_b.hpp"
# include "transformations/snippets/x64/pass/enforce_precision.hpp"
# include "transformations/snippets/x64/pass/lowered/adjust_brgemm_copy_b_loop_ports.hpp"
@@ -497,16 +497,16 @@ Subgraph::DataFlowPasses Subgraph::getDataFlowPasses() {
}
SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::Before,
ov::snippets::pass::PropagatePrecision,
-ov::intel_cpu::pass::BrgemmToBrgemmCPU);
+ov::intel_cpu::pass::BrgemmToGemmCPU);
SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::After,
-ov::intel_cpu::pass::BrgemmToBrgemmCPU,
+ov::intel_cpu::pass::BrgemmToGemmCPU,
ov::intel_cpu::pass::EliminateBrgemmCopyB);
SNIPPETS_REGISTER_PASS_ABSOLUTE_X86_64(Place::PipelineEnd, ov::intel_cpu::pass::RemoveConverts);
SNIPPETS_REGISTER_PASS_ABSOLUTE_COMMON(Place::PipelineEnd, ov::intel_cpu::pass::MulAddToFMA);

#ifdef SNIPPETS_LIBXSMM_TPP
SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::Before,
-ov::intel_cpu::pass::BrgemmToBrgemmCPU,
+ov::intel_cpu::pass::BrgemmToGemmCPU,
ov::intel_cpu::tpp::pass::BrgemmToBrgemmTPP);
// Note: There could be several ConvertConstantsToScalars instances in the pipeline
SNIPPETS_REGISTER_PASS_ABSOLUTE_X86_64(Place::PipelineEnd, ov::intel_cpu::tpp::pass::ScalarToScalarTPP);
@@ -541,7 +541,7 @@ Subgraph::ControlFlowPasses Subgraph::getControlFlowPasses() const {

SNIPPETS_REGISTER_PASS_RELATIVE(Place::After,
ov::snippets::lowered::pass::MarkLoops,
-ov::intel_cpu::pass::BrgemmCPUBlocking);
+ov::intel_cpu::pass::GemmCPUBlocking);

SNIPPETS_REGISTER_PASS_RELATIVE(Place::After,
ov::snippets::lowered::pass::InitLoops,
@@ -556,7 +556,7 @@ Subgraph::ControlFlowPasses Subgraph::getControlFlowPasses() const {

#ifdef SNIPPETS_LIBXSMM_TPP
SNIPPETS_REGISTER_PASS_RELATIVE(Place::Before,
-ov::intel_cpu::pass::BrgemmCPUBlocking,
+ov::intel_cpu::pass::GemmCPUBlocking,
ov::intel_cpu::tpp::pass::BrgemmTPPBlocking);
SNIPPETS_REGISTER_PASS_RELATIVE(Place::After,
ov::intel_cpu::pass::FuseLoadStoreConvert,
@@ -71,7 +71,7 @@ void BrgemmCopyB::custom_constructor_validate_and_infer_types(const std::vector<
const auto planar_pshape = snippets::utils::get_planar_pshape(get_input_partial_shape(0), layout_input);
// data repacking output
set_output_type(0, element_type, planar_pshape);
-// If compensations are needed, they are provided in 2nd output (which is used in BrgemmCPU)
+// If compensations are needed, they are provided in 2nd output (which is used in GemmCPU)
if (with_compensations(m_type)) {
set_output_type(1, ov::element::f32, planar_pshape);
}
@@ -9,7 +9,7 @@
#include "snippets/lowered/expressions/buffer_expression.hpp"
#include "snippets/op/buffer.hpp"
#include "transformations/snippets/x64/op/brgemm_copy_b.hpp"
-#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
+#include "transformations/snippets/x64/op/gemm_cpu.hpp"
#include "utils/general_utils.h"

using namespace Xbyak;
@@ -33,7 +33,7 @@ cpu_isa_t get_primitive_isa(const ov::element::Type& dt_in0, bool is_with_amx) {
#define SUPPORT_TWO(X, Y, MESSAGE) SUPPORT(X, SUPPORT_ONE(Y, MESSAGE))
#define SUPPORT_THREE(X, Y, Z, MESSAGE) SUPPORT(X, SUPPORT_TWO(Y, Z, MESSAGE))

-// Note: AMX might be not used even if it's supported by the hardware, check the BrgemmToBrgemmCPU pass for details
+// Note: AMX might be not used even if it's supported by the hardware, check the BrgemmToGemmCPU pass for details
if (is_with_amx) {
if (dt_in0 == ov::element::f16) {
SUPPORT_ONE(avx512_core_amx_fp16,
@@ -65,9 +65,9 @@ BRGEMM_TYPE get_brgemm_type(const ov::element::Type& element_type_a, bool transp
}

OPENVINO_ASSERT(element_type_a != element::bf16 || mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16),
-"BrgemmCPU BF16 precision is not supported on non avx512_core_bf16 system");
+"GemmCPU BF16 precision is not supported on non avx512_core_bf16 system");
OPENVINO_ASSERT(element_type_a != element::f16 || mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16),
-"BrgemmCPU FP16 precision is not supported on non avx512_core_amx_fp16 system");
+"GemmCPU FP16 precision is not supported on non avx512_core_amx_fp16 system");

if (one_of(element_type_a, element::u8, element::i8, element::bf16) &&
dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx)) {
@@ -124,8 +124,8 @@ size_t compute_inner_k_block(const ov::element::Type& precision) {
}

ov::snippets::lowered::ExpressionPtr get_copy_b_expr(const ov::snippets::lowered::ExpressionPtr& brgemm_expr) {
-OPENVINO_ASSERT(ov::is_type<BrgemmCPU>(brgemm_expr->get_node()),
-"get_copy_b_expr must be called only for BrgemmCPU node");
+OPENVINO_ASSERT(ov::is_type<GemmCPU>(brgemm_expr->get_node()),
+"get_copy_b_expr must be called only for GemmCPU node");
auto b_input_expr = brgemm_expr->get_input_port_connector(1)->get_source().get_expr();
if (ov::is_type<BrgemmCopyB>(b_input_expr->get_node())) {
return b_input_expr;
@@ -62,7 +62,7 @@ size_t compute_inner_n_block(const ov::element::Type& precision);
size_t compute_inner_k_block(const ov::element::Type& precision);
/**
* @brief Computes leading dimension (LDB) which must be used in brgemm and brgemm_copy_b emitters
-* @param n_block N block size shared between BrgemmCPU and BrgemmCopyB node
+* @param n_block N block size shared between GemmCPU and BrgemmCopyB node
* @param precision tensor precision
*/
template <
@@ -74,9 +74,9 @@ T compute_LDB(T n_block, const ov::element::Type& precision) {
: std::max(n_block, static_cast<T>(compute_inner_n_block(precision)));
}
/**
-* @brief Retrieves the expression pointer for the brgemm_copy_b expression corresponding to the given BrgemmCPU
+* @brief Retrieves the expression pointer for the brgemm_copy_b expression corresponding to the given GemmCPU
* expression.
-* @param brgemm_expr The expression pointer for the BrgemmCPU operation.
+* @param brgemm_expr The expression pointer for the GemmCPU operation.
* @return The expression pointer for the BrgemmCopyB operation.
*/
snippets::lowered::ExpressionPtr get_copy_b_expr(const snippets::lowered::ExpressionPtr& brgemm_expr);