Skip to content

Commit

Permalink
build brgemm
Browse files Browse the repository at this point in the history
  • Loading branch information
aobolensk committed Jan 29, 2025
1 parent 6465c23 commit d5618a4
Show file tree
Hide file tree
Showing 7 changed files with 315 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "emitters/plugin/x64/utils.hpp"
#include "emitters/snippets/x64/kernel_executors/brgemm.hpp"
#include "emitters/snippets/x64/kernel_executors/brgemm_amx.hpp"
#include "emitters/snippets/x64/kernel_executors/brgemm_batched.hpp"
#include "snippets/utils/utils.hpp"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/snippets/x64/op/brgemm_utils.hpp"
Expand Down Expand Up @@ -36,12 +37,19 @@ jit_brgemm_emitter::jit_brgemm_emitter(jit_generator* h,
m_kernel_executor =
kernel_table->register_kernel<BrgemmAMXKernelExecutor>(expr, compiled_kernel_cache, kernel_config);
} else {
BrgemmKernelConfig kernel_config(brg0Prc,
brg1Prc,
with_compensations(brgemm_type),
brgemm_utils::get_primitive_isa(brg0Prc, false));
BrgemmBatchedKernelConfig kernel_config(brg0Prc,
brg1Prc,
with_compensations(brgemm_type),
brgemm_utils::get_primitive_isa(brg0Prc, false));
m_kernel_executor =
kernel_table->register_kernel<BrgemmKernelExecutor>(expr, compiled_kernel_cache, kernel_config);
kernel_table->register_kernel<BrgemmBatchedKernelExecutor>(expr, compiled_kernel_cache, kernel_config);

// BrgemmKernelConfig kernel_config(brg0Prc,
// brg1Prc,
// with_compensations(brgemm_type),
// brgemm_utils::get_primitive_isa(brg0Prc, false));
// m_kernel_executor =
// kernel_table->register_kernel<BrgemmKernelExecutor>(expr, compiled_kernel_cache, kernel_config);
}
// Note: even if the Brgemm node is dynamic, the first shapeInfer and RuntimeConfigurator::update()
// are performed before the BrgemmKernelExecutor registration. So we have to trigger update() manually
Expand Down Expand Up @@ -103,6 +111,8 @@ void jit_brgemm_emitter::emit_impl(const std::vector<size_t>& in, const std::vec
emit_call<BrgemmAMXKernelExecutor>(mem_ptrs_idxs);
} else if (std::dynamic_pointer_cast<BrgemmKernelExecutor>(m_kernel_executor)) {
emit_call<BrgemmKernelExecutor>(mem_ptrs_idxs);
} else if (std::dynamic_pointer_cast<BrgemmBatchedKernelExecutor>(m_kernel_executor)) {
emit_call<BrgemmBatchedKernelExecutor>(mem_ptrs_idxs);
} else {
OV_CPU_JIT_EMITTER_THROW("uknown execuor type");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -316,9 +316,12 @@ void BrgemmBaseKernelExecutor::execute_brgemm_kernel(
void* scratch,
bool with_comp) {
cpu::x64::brgemm_kernel_params_t brgemm_p;
brgemm_p.batch = nullptr; // default value
brgemm_p.ptr_A = src;
brgemm_p.ptr_B = wei;
brgemm_batch_element_t addr_batch;
addr_batch.ptr.A = src;
addr_batch.ptr.B = wei;
brgemm_p.batch = &addr_batch;
brgemm_p.ptr_A = nullptr;
brgemm_p.ptr_B = nullptr;
brgemm_p.ptr_C = dst;
brgemm_p.ptr_D = dst;
brgemm_p.ptr_buf = scratch;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
// Copyright (C) 2020-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "brgemm_batched.hpp"

#include "common/utils.hpp"
#include "dnnl_extension_utils.h"
#include "snippets/lowered/pass/insert_specific_iterations.hpp"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/snippets/x64/op/brgemm_utils.hpp"

using namespace Xbyak;
using namespace dnnl::impl;
using namespace dnnl::impl::cpu::x64;

namespace ov {
namespace intel_cpu {

BrgemmBatchedKernelConfig::BrgemmBatchedKernelConfig(const element::Type& in0_dtype,
const element::Type& in1_dtype,
bool is_with_comp,
dnnl::impl::cpu::x64::cpu_isa_t primitive_isa)
: BrgemmBaseKernelConfig(),
m_static_params(std::make_shared<StaticParams>(in0_dtype, in1_dtype, is_with_comp, primitive_isa)) {
m_hash = compute_hash();
}

BrgemmBatchedKernelConfig::StaticParams::StaticParams(const element::Type& in0_dtype,
const element::Type& in1_dtype,
bool is_with_comp,
dnnl::impl::cpu::x64::cpu_isa_t primitive_isa)
: StaticBaseParams(in0_dtype, in1_dtype, primitive_isa, compute_hash(is_with_comp)),
is_with_comp(is_with_comp) {}

/// Two parameter sets are equal when their base parts compare equal and the compensation flags match.
bool BrgemmBatchedKernelConfig::StaticParams::operator==(const StaticParams& rhs) const {
    if (!StaticBaseParams::operator==(rhs)) {
        return false;
    }
    return is_with_comp == rhs.is_with_comp;
}

/// Hash of the parameters specific to this config: the compensation flag combined into a zero seed.
size_t BrgemmBatchedKernelConfig::StaticParams::compute_hash(bool is_with_comp) {
    const size_t comp_hash = hash_combine(0, is_with_comp);
    return comp_hash;
}

#ifdef SNIPPETS_DEBUG_CAPS
/// Debug dump: the base parameters followed by one line with the compensation flag.
std::string BrgemmBatchedKernelConfig::StaticParams::to_string() const {
    std::stringstream out;
    out << StaticBaseParams::to_string() << "is_with_comp = " << is_with_comp << "\n";
    return out.str();
}
#endif

/// Hands the kernel cache and the initial config over to the CPUKernelExecutor base.
BrgemmBatchedKernelExecutor::BrgemmBatchedKernelExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache,
                                                         BrgemmBatchedKernelConfig config)
    : CPUKernelExecutor<BrgemmBatchedKernelConfig, BrgemmBatchedCompiledKernel>(std::move(kernel_cache),
                                                                                std::move(config)) {}

std::shared_ptr<BrgemmBatchedCompiledKernel> BrgemmBatchedKernelExecutor::compile_kernel(const BrgemmBatchedKernelConfig& config) const {
std::shared_ptr<BrgemmBatchedCompiledKernel> compiled_kernel = std::make_shared<BrgemmBatchedCompiledKernel>();

// Brgemm is not executable - nothing to compile
if (config.is_empty()) {
return compiled_kernel;
}

cpu::x64::brgemm_desc_t desc;
OV_CPU_JIT_EMITTER_ASSERT(brgemm_desc_init(&desc,
config.get_isa(),
cpu::x64::brgemm_addr, // TODO: addr
config.get_dt_in0(),
config.get_dt_in1(),
false,
false,
cpu::x64::brgemm_row_major,
1.f,
config.get_beta(),
config.get_LDA(),
config.get_LDB(),
config.get_LDC(),
config.get_M(),
config.get_N(),
config.get_K(),
nullptr) == dnnl_success,
"Cannot initialize brgemm descriptor due to invalid params");

cpu::x64::brgemm_kernel_t* kernel_ = nullptr;
OV_CPU_JIT_EMITTER_ASSERT(brgemm_kernel_create(&kernel_, desc) == dnnl_success,
"Cannot create brgemm kernel due to invalid params");
compiled_kernel->brgemm_kernel = std::unique_ptr<brgemm_kernel_t>(kernel_);

return compiled_kernel;
}

/// Updates `config` for `expr`; the work is delegated entirely to the shared base implementation.
void BrgemmBatchedKernelExecutor::update_config(const ov::snippets::lowered::ExpressionPtr& expr,
                                                const ov::snippets::lowered::LinearIRCPtr& linear_ir,
                                                BrgemmBatchedKernelConfig& config) const {
    BrgemmBaseKernelExecutor::update_config(expr, linear_ir, config);
}

/// Runtime entry point: runs the kernel owned by `executor` on the pointers stored in `args`.
/// @param executor  executor holding the compiled kernel and its config; must not be null
/// @param args      A/B/C/scratch pointers forwarded to the kernel
void BrgemmBatchedKernelExecutor::execute(const BrgemmBatchedKernelExecutor* executor, call_args* args) {
    OV_CPU_JIT_EMITTER_ASSERT(executor, "has nullptr executor");
    auto kernel = executor->get_kernel();
    const auto& config = static_cast<const BrgemmBatchedKernelConfig&>(executor->get_config());
    OV_CPU_JIT_EMITTER_ASSERT(kernel, "has nullptr compiler kernel or invalid config");

    // Compensations must be applied exactly once, i.e. only on the first iteration (beta == 0).
    bool apply_compensations = false;
    if (config.get_beta() == 0) {
        apply_compensations = config.is_with_comp();
    }
    execute_brgemm_kernel(kernel->brgemm_kernel, args->A, args->B, args->C, args->scratch, apply_compensations);
}

} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// Copyright (C) 2020-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "brgemm_base.hpp"

namespace ov {
namespace intel_cpu {

/// Configuration of the batched BRGEMM kernel executor.
/// On top of BrgemmBaseKernelConfig it carries one extra static parameter: the compensation flag.
struct BrgemmBatchedKernelConfig : public BrgemmBaseKernelConfig {
    BrgemmBatchedKernelConfig() = delete;
    BrgemmBatchedKernelConfig(const element::Type& in0_dtype,
                              const element::Type& in1_dtype,
                              bool is_with_comp,
                              dnnl::impl::cpu::x64::cpu_isa_t primitive_isa);

    /// Returns a newly allocated copy of this config.
    std::unique_ptr<snippets::KernelExecutorBase::GenericConfig> get_clone_ptr() const override {
        return std::unique_ptr<BrgemmBatchedKernelConfig>(new BrgemmBatchedKernelConfig(*this));
    }

    /// Compensation flag stored in the static parameters.
    bool is_with_comp() const {
        return m_static_params->is_with_comp;
    }

private:
    /// Static parameters of the batched kernel: the base set plus the compensation flag.
    struct StaticParams : StaticBaseParams {
        StaticParams(const element::Type& in0_dtype,
                     const element::Type& in1_dtype,
                     bool is_with_comp,
                     dnnl::impl::cpu::x64::cpu_isa_t primitive_isa);

        const bool is_with_comp = false;

        bool operator==(const StaticParams& rhs) const;
        bool operator!=(const StaticParams& rhs) const {
            return !operator==(rhs);
        }
#ifdef SNIPPETS_DEBUG_CAPS
        std::string to_string() const;
#endif
    private:
        /// Hash of the parameters added by this struct (passed to the base constructor).
        static size_t compute_hash(bool is_with_comp);
    };

    /// Gives the base class access to the static parameters through their base type.
    std::shared_ptr<StaticBaseParams> get_static_params() const override {
        return m_static_params;
    }

    std::shared_ptr<StaticParams> m_static_params{nullptr};
};

// The `update_kernel` method verifies that a compiled kernel is not nullptr.
// However, the compiled kernel might be empty in cases if nothing is to be compiled (`Config.is_empty() == true`).
// To cover this case, we wrap the `brgemm_kernel_t` in the separate structure which may contain empty `brgemm_kernel_t`
struct BrgemmBatchedCompiledKernel {
std::shared_ptr<dnnl::impl::cpu::x64::brgemm_kernel_t> brgemm_kernel = nullptr;
};

/// Kernel executor for the batched BRGEMM: compiles kernels from BrgemmBatchedKernelConfig and
/// exposes a static `execute` entry point that is called at runtime.
class BrgemmBatchedKernelExecutor : public BrgemmBaseKernelExecutor,
                                    public CPUKernelExecutor<BrgemmBatchedKernelConfig, BrgemmBatchedCompiledKernel> {
public:
    /// Runtime arguments of `execute`. Field order matters: offsets are taken with offsetof.
    struct call_args {
        const void* A{nullptr};
        const void* B{nullptr};
        void* C{nullptr};
        void* scratch{nullptr};
    };

    BrgemmBatchedKernelExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache, BrgemmBatchedKernelConfig config);

    /** Function that will be called in runtime to execute the kernel */
    static void execute(const BrgemmBatchedKernelExecutor* executor, call_args* args);

protected:
    /// Builds the compiled kernel wrapper for config `c`.
    std::shared_ptr<BrgemmBatchedCompiledKernel> compile_kernel(const BrgemmBatchedKernelConfig& c) const override;

    /// Refreshes `config` for the given expression of the Linear IR.
    void update_config(const ov::snippets::lowered::ExpressionPtr& expr,
                       const ov::snippets::lowered::LinearIRCPtr& linear_ir,
                       BrgemmBatchedKernelConfig& config) const override;
};
// Byte offset of `field` inside BrgemmBatchedKernelExecutor::call_args.
// NOTE(review): the macro name does not mention the batched executor. If another kernel-executor header
// defines a macro with the same name, including both headers in one translation unit redefines it with a
// different body - confirm, and consider a dedicated name.
#define GET_OFF_BRGEMM_ARGS(field) offsetof(BrgemmBatchedKernelExecutor::call_args, field)

} // namespace intel_cpu
} // namespace ov
6 changes: 6 additions & 0 deletions src/plugins/intel_cpu/src/nodes/subgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "onednn/dnnl.h"
#include "openvino/core/parallel.hpp"
#include "shape_inference/custom/subgraph.hpp"
#include "snippets/lowered/pass/allocate_buffers.hpp"
#include "snippets/lowered/pass/init_loops.hpp"
#include "snippets/lowered/pass/insert_buffers.hpp"
#include "snippets/lowered/pass/insert_loops.hpp"
Expand Down Expand Up @@ -35,6 +36,7 @@
# include "transformations/snippets/x64/pass/eliminate_brgemm_copy_b.hpp"
# include "transformations/snippets/x64/pass/enforce_precision.hpp"
# include "transformations/snippets/x64/pass/lowered/adjust_brgemm_copy_b_loop_ports.hpp"
# include "transformations/snippets/x64/pass/lowered/build_brgemm.hpp"
# include "transformations/snippets/x64/pass/lowered/brgemm_cpu_blocking.hpp"
# include "transformations/snippets/x64/pass/lowered/fuse_load_store_and_convert.hpp"
# include "transformations/snippets/x64/pass/lowered/insert_brgemm_copy_buffers.hpp"
Expand Down Expand Up @@ -561,6 +563,10 @@ Subgraph::ControlFlowPasses Subgraph::getControlFlowPasses() const {
ov::intel_cpu::tpp::pass::SetTPPLeadingDim);
#endif

SNIPPETS_REGISTER_PASS_RELATIVE(Place::After,
ov::snippets::lowered::pass::AllocateBuffers,
ov::intel_cpu::pass::BuildBrgemm);

#undef SNIPPETS_REGISTER_PASS_RELATIVE
return backend_passes;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "build_brgemm.hpp"

#include "cpu/x64/cpu_isa_traits.hpp"
#include "cpu_shape.h"
#include "snippets/lowered/loop_manager.hpp"
#include "snippets/lowered/loop_info.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "snippets/itt.hpp"
#include "snippets/op/brgemm.hpp"
#include "snippets/op/buffer.hpp"
#include "snippets/utils/utils.hpp"
#include "transformations/snippets/x64/op/brgemm_copy_b.hpp"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/snippets/x64/op/brgemm_utils.hpp"
#include "transformations/tpp/x64/op/modifiers.hpp"
#include "utils/general_utils.h"

namespace ov {
namespace intel_cpu {

/**
 * Walks the Linear IR and, for every static BrgemmCPU expression that lives inside at least one loop,
 * prints the parameters of its innermost loop (work amount, increment, iteration count) to stderr.
 * The Linear IR is not changed, so the function currently always returns false.
 *
 * NOTE(review): the stderr output looks like temporary bring-up diagnostics - confirm whether it should
 * stay unconditional before the pass is enabled in production builds.
 *
 * @param linear_ir  Linear IR to inspect
 * @return true if the Linear IR was modified (never, at the moment)
 */
bool pass::BuildBrgemm::run(const snippets::lowered::LinearIR& linear_ir) {
    // The task label names this pass, so profiling data is not attributed to another one.
    OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::BuildBrgemm")
    fprintf(stderr, "BuildBrgemm::run\n");
    bool modified = false;
    for (const auto& expr : linear_ir) {
        fprintf(stderr, "expr: %s\n", expr->get_node()->get_friendly_name().c_str());
        const auto brgemm_node = ov::as_type_ptr<BrgemmCPU>(expr->get_node());
        if (!brgemm_node || brgemm_node->is_dynamic()) {
            continue;
        }
        const auto& loop_manager = linear_ir.get_loop_manager();
        OPENVINO_ASSERT(loop_manager, "BrgemmCPU node should have a loop manager.");

        const auto& loop_ids = expr->get_loop_ids();
        if (loop_ids.empty()) {
            continue;
        }

        // The same loop id is used for both the loop bounds and the loop info.
        // NOTE(review): assumes the last id is the innermost loop - confirm against get_loop_ids().
        const auto inner_loop_id = loop_ids.back();
        const auto loop_begin = loop_manager->get_loop_bounds(linear_ir, inner_loop_id).first;
        fprintf(stderr, "Loop bounds: %s\n", (*loop_begin)->get_node()->get_friendly_name().c_str());

        const auto inner_loop_info =
            loop_manager->get_loop_info<snippets::lowered::UnifiedLoopInfo>(inner_loop_id);
        // Explicit size_t locals keep the %zu specifiers in sync with the argument type on every platform.
        const size_t work_amount = inner_loop_info->get_work_amount();
        const size_t increment = inner_loop_info->get_increment();
        fprintf(stderr, "work_amount: %zu\n", work_amount);
        fprintf(stderr, "increment: %zu\n", increment);
        if (increment == 0) {
            // No iteration count can be derived from a zero increment; skip instead of dividing by zero.
            continue;
        }
        const size_t iter_count = work_amount / increment;
        fprintf(stderr, "iter_count: %zu\n", iter_count);
    }

    return modified;
}

} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "snippets/lowered/loop_info.hpp"
#include "snippets/lowered/pass/pass.hpp"

namespace ov {
namespace intel_cpu {
namespace pass {

/**
* @interface BuildBrgemm
* @brief The pass explicitly insert LoadBegin and LoadEnd in Linear IR using UnifiedLoopInfo from Loop markup algorithm
* @ingroup snippets
*/
class BuildBrgemm : public snippets::lowered::pass::ConstPass {
public:
OPENVINO_RTTI("BuildBrgemm", "", ConstPass)
BuildBrgemm() = default;
bool run(const snippets::lowered::LinearIR& linear_ir) override;
const std::unordered_set<snippets::lowered::UnifiedLoopInfoPtr>& get_affected_loops() {
return m_affected_loops;
}

private:
std::unordered_set<snippets::lowered::UnifiedLoopInfoPtr> m_affected_loops;
};

} // namespace pass
} // namespace intel_cpu
} // namespace ov

0 comments on commit d5618a4

Please sign in to comment.