Enable XeTLA LSTM for GPU #28817

Open · wants to merge 13 commits into master
96 changes: 96 additions & 0 deletions src/plugins/intel_gpu/src/graph/impls/cm/lstm_seq.cpp
@@ -0,0 +1,96 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "impls/ocl/primitive_base.hpp"

#include "lstm_seq.hpp"
#include "lstm/cm/lstm_seq_cm_kernel_selector.h"
#include "lstm/lstm_kernel_base.h"
#include "openvino/op/lstm_sequence.hpp"
#include "impls/registry/implementation_manager.hpp"

namespace cldnn {
namespace cm {

struct lstm_seq_impl : ocl::typed_primitive_impl_ocl<lstm_seq> {
using parent = ocl::typed_primitive_impl_ocl<lstm_seq>;
using parent::parent;
using kernel_selector_t = kernel_selector::lstm_seq_cm_kernel_selector;
using kernel_params_t = kernel_selector::lstm_params;

DECLARE_OBJECT_TYPE_SERIALIZATION(cldnn::cm::lstm_seq_impl)

std::unique_ptr<primitive_impl> clone() const override {
return std::make_unique<lstm_seq_impl>(*this);
}

protected:
kernel_arguments_data get_arguments(const typed_primitive_inst<lstm_seq>& instance) const override {
kernel_arguments_data args;
for (size_t i = 0; i < instance.inputs_memory_count(); i++) {
args.inputs.push_back(instance.input_memory_ptr(i));
}

for (size_t i = 0; i < instance.outputs_memory_count(); i++) {
args.outputs.push_back(instance.output_memory_ptr(i));
}
return args;
}

public:
static kernel_params_t get_kernel_params(const kernel_impl_params& impl_param) {
const auto& primitive = impl_param.typed_desc<lstm_seq>();
auto params = get_default_params<kernel_selector::lstm_params>(impl_param);
for (size_t i = 1; i < impl_param.input_layouts.size(); ++i) {
params.inputs.push_back(convert_data_tensor(impl_param.get_input_layout(i)));
}

if (!primitive->activations.empty()) {
auto a_sz = primitive->activations.size();
auto param_sz = primitive->activation_params.size();
OPENVINO_ASSERT(param_sz == 0 || a_sz == param_sz, "[GPU] Unexpected activation params count in lstm_seq impl: ", param_sz);
for (size_t i = 0; i < a_sz; i++) {
params.activations.emplace_back(get_kernel_selector_activation_param(primitive->activations[i]),
param_sz ? primitive->activation_params[i].a : 0.0f,
param_sz ? primitive->activation_params[i].b : 0.0f);
}
}

if (primitive->clip > 0.0f) {
params.activations.emplace_back(get_kernel_selector_activation_param(activation_func::clamp), -primitive->clip, primitive->clip);
}

params.SetOffsetOrder(static_cast<int32_t>(primitive->offset_order));
params.clip = primitive->clip;
params.direction = primitive->direction;
// Legacy multi-output: the Hy/Cy output layouts reuse the initial hidden state layout (input 1)
params.outputs.push_back(convert_data_tensor(impl_param.input_layouts[1]));
if (!primitive->initial_cell_state.pid.empty()) {
params.outputs.push_back(convert_data_tensor(impl_param.input_layouts[1]));
}
return params;
}

static kernel_impl_params static_canonicalize_shapes(const kernel_impl_params& impl_params) {
if (impl_params.get_input_layout().get_partial_shape().size() != 3) {
return primitive_impl::static_canonicalize_shapes(impl_params);
}
auto updated_impl_params = canonicalize_fused_shapes(impl_params);
return updated_impl_params;
}

kernel_impl_params canonicalize_shapes(const kernel_impl_params& impl_params) const override {
return static_canonicalize_shapes(impl_params);
}
};

std::unique_ptr<primitive_impl> LSTMSeqImplementationManager::create_impl(const program_node& node, const kernel_impl_params& params) const {
OPENVINO_ASSERT(node.is_type<lstm_seq>());
return ocl::typed_primitive_impl_ocl<lstm_seq>::create<lstm_seq_impl>(static_cast<const lstm_seq_node&>(node), params);
}

} // namespace cm
} // namespace cldnn

BIND_BINARY_BUFFER_WITH_TYPE(cldnn::cm::lstm_seq_impl)
102 changes: 102 additions & 0 deletions src/plugins/intel_gpu/src/graph/impls/cm/lstm_seq.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <memory>

#include "impls/registry/implementation_manager.hpp"
#include "intel_gpu/runtime/layout.hpp"
#include "lstm_seq_inst.h"
namespace cldnn {
namespace cm {

struct LSTMSeqImplementationManager : public ImplementationManager {
OV_GPU_PRIMITIVE_IMPL("cm::lstm_seq")
LSTMSeqImplementationManager(shape_types shape_type, ValidateFunc vf = nullptr)
: ImplementationManager(impl_types::cm, shape_type, vf) {}

in_out_fmts_t query_formats(const program_node& node) const override {
assert(node.is_type<lstm_seq>());
std::vector<format::type> in_fmts(node.get_dependencies().size(), format::any);
std::vector<format::type> out_fmts(node.get_outputs_count(), format::any);

for (size_t idx = 0; idx < node.get_dependencies().size(); idx++) {
in_fmts[idx] = format::bfyx;
}
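// Output 0 (Y) is produced in byfx; the remaining state outputs stay bfyx.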
out_fmts[0] = format::byfx;
for (size_t idx = 1; idx < node.get_outputs_count(); idx++) {
out_fmts[idx] = format::bfyx;
}

return {in_fmts, out_fmts};
}

std::unique_ptr<primitive_impl> create_impl(const program_node& node,
const kernel_impl_params& params) const override;

bool validate_impl(const program_node& node) const override {
assert(node.is_type<lstm_seq>());

auto &engine = node.get_program().get_engine();
auto &config = node.get_program().get_config();
const auto& info = engine.get_device_info();

// XeTLA LSTM optimized for Xe2 architectures
if (!check_cm_jit_support(engine, config) || info.arch != gpu_arch::xe2) {
return false;
}

const auto& lstm_node = node.as<lstm_seq>();
const auto& lstm_prim = lstm_node.get_primitive();
if (lstm_prim->clip > 0.0f) {
return false;
}

if (lstm_prim->activations.size() != 3 ||
lstm_prim->activations[0] != activation_func::logistic ||
lstm_prim->activations[1] != activation_func::hyperbolic_tan ||
lstm_prim->activations[2] != activation_func::hyperbolic_tan) {
return false;
}

auto in_layouts = node.get_input_layouts();
unsigned int expected_inputs = 7;
if (in_layouts.size() != expected_inputs) {
return false;
}
{
auto &seq_lengths = in_layouts[expected_inputs-1];
if (seq_lengths.format != format::bfyx || seq_lengths.data_type != data_types::i32) {
return false;
}
in_layouts.pop_back();
}

auto out_layouts = node.get_output_layouts();
for (auto &layout : in_layouts) {
if (layout.format != format::bfyx || layout.data_type != data_types::f16) {
return false;
}
}
for (auto &layout : out_layouts) {
if (layout.data_type != data_types::f16) {
return false;
}
}

auto num_gates = 4;
auto batch_size = in_layouts[0].get_dim(0);
Contributor: Please check that none of the input/output tensors have dynamic shapes; otherwise this call will trigger an exception.
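A minimal sketch of such a guard, assuming cldnn::layout::is_dynamic() is available on these layouts; it would sit before the first shape query:

    // Shape queries below throw on dynamic layouts, so reject them up front.
    for (const auto& layout : in_layouts) {
        if (layout.is_dynamic())
            return false;
    }
    for (const auto& layout : out_layouts) {
        if (layout.is_dynamic())
            return false;
    }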

auto input_size = in_layouts[0].get_dim(2);
Contributor: This get_dim() method is deprecated; please use get_shape() and then work with ov::Shape instead.
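A sketch of the suggested rewrite, assuming static shapes at this point and the LSTMSequence input order used above (X at input 0, W at input 3):

    const auto x_shape = in_layouts[0].get_shape();  // [batch, seq_len, input_size]
    const auto w_shape = in_layouts[3].get_shape();  // [num_dir, num_gates * hidden_size, input_size]
    const auto batch_size  = x_shape[0];
    const auto input_size  = x_shape[2];
    const auto num_dir     = w_shape[0];
    const auto hidden_size = w_shape[1] / num_gates;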

auto hidden_size = in_layouts[3].get_dim(1) / num_gates;
auto num_dir = in_layouts[3].get_dim(0);
if (hidden_size != 128 || batch_size != 1 || num_dir != 2 || (input_size != 64 && input_size != 256)) {
return false;
}

return true;
}
};

} // namespace cm
} // namespace cldnn
@@ -86,6 +86,7 @@ class kernels_cache {
}
}
} else {
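// CM batches: inject the CM runtime headers ahead of the shared batch headers.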
source.push_back("#include <cm/cm.h>\n#include <cm/cmtl.h>\n");
for (const auto& kv : batch_headers)
source.push_back(kv.second);
}
@@ -10,6 +10,10 @@
#include "impls/ocl/rnn_seq.hpp"
#endif

#if OV_GPU_WITH_CM
#include "impls/cm/lstm_seq.hpp"
#endif

#if OV_GPU_WITH_ONEDNN
#include "impls/onednn/lstm_seq_onednn.hpp"
#endif
@@ -20,6 +24,7 @@ using namespace cldnn;

const std::vector<std::shared_ptr<cldnn::ImplementationManager>>& Registry<lstm_seq>::get_implementations() {
static const std::vector<std::shared_ptr<ImplementationManager>> impls = {
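// Listed in priority order: the CM (XeTLA) implementation is preferred when its validate_impl() accepts the node.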
OV_GPU_CREATE_INSTANCE_CM(cm::LSTMSeqImplementationManager, shape_types::static_shape)
OV_GPU_CREATE_INSTANCE_ONEDNN(onednn::LSTMSeqImplementationManager, shape_types::static_shape)
OV_GPU_CREATE_INSTANCE_OCL(ocl::RNNSeqImplementationManager, shape_types::static_shape)
};
4 changes: 2 additions & 2 deletions src/plugins/intel_gpu/src/graph/lstm_seq.cpp
@@ -20,7 +20,7 @@ layout lstm_seq_inst::calc_output_layout(lstm_seq_node const& node, kernel_impl_
const auto& lstm_hidden_size = input_pshape_hidden[2];

auto first_out_fmt = cldnn::format::bfyx;
-    if (node.get_preferred_impl_type() == impl_types::onednn && node.get_preferred_output_fmt() != format::any) {
+    if (node.get_preferred_output_fmt() != format::any) {
first_out_fmt = node.get_preferred_output_fmt();
}

@@ -41,7 +41,7 @@ std::vector<layout> lstm_seq_inst::calc_output_layouts(lstm_seq_node const& node
auto first_out_fmt = cldnn::format::bfyx;
auto second_out_fmt = input_layout.format;
auto third_out_fmt = input_layout.format;
-    if (node.get_preferred_impl_type() == impl_types::onednn && node.get_preferred_output_fmt() != format::any) {
+    if (node.get_preferred_output_fmt() != format::any) {
first_out_fmt = node.get_preferred_output_fmt();
second_out_fmt = node.get_preferred_output_fmt(1);
third_out_fmt = node.get_preferred_output_fmt(2);
27 changes: 24 additions & 3 deletions src/plugins/intel_gpu/src/kernel_selector/CMakeLists.txt
@@ -18,7 +18,9 @@ file(GLOB_RECURSE KERNELS
)

file(GLOB_RECURSE CM_KERNELS
"${CMAKE_CURRENT_SOURCE_DIR}/cm_kernels/*"
"${CMAKE_CURRENT_SOURCE_DIR}/cm_kernels/*.h"
"${CMAKE_CURRENT_SOURCE_DIR}/cm_kernels/*.hpp"
"${CMAKE_CURRENT_SOURCE_DIR}/cm_kernels/*.cpp"
)

# Path which points to root directory where code generated elements are created
@@ -61,12 +63,31 @@ add_custom_command(OUTPUT "${CODEGEN_INCDIR}/${PRIM_DB}"
COMMENT "Updating file if the file changed (${PRIM_DB}) ..."
)

set(XETLA_HEADER "cm_xetla.h")
if(WIN32)
add_custom_command(OUTPUT "${CODEGEN_CACHE_DIR}/cm_kernels"
COMMAND "${CMAKE_COMMAND}" -E copy_directory "${CMAKE_CURRENT_SOURCE_DIR}/cm_kernels/" "${CODEGEN_CACHE_DIR}/cm_kernels"
COMMAND "${CMAKE_CXX_COMPILER}" "${CODEGEN_CACHE_DIR}/cm_kernels/include/batch_headers/${XETLA_HEADER}" -I ${XETLA_INCLUDE_DIR} -D _WIN32 -EP > "${CODEGEN_CACHE_DIR}/${XETLA_HEADER}"
COMMAND "${CMAKE_COMMAND}" -E rename "${CODEGEN_CACHE_DIR}/${XETLA_HEADER}" "${CODEGEN_CACHE_DIR}/cm_kernels/include/batch_headers/${XETLA_HEADER}"
DEPENDS "${CM_KERNELS}" "${CODEGEN_INCDIR}/${PRIM_DB}"
COMMENT "Copying CM sources and preprocessing XeTLA headers ..."
)
else()
add_custom_command(OUTPUT "${CODEGEN_CACHE_DIR}/cm_kernels"
COMMAND "${CMAKE_COMMAND}" -E copy_directory "${CMAKE_CURRENT_SOURCE_DIR}/cm_kernels/" "${CODEGEN_CACHE_DIR}/cm_kernels"
COMMAND "${CMAKE_CXX_COMPILER}" "${CODEGEN_CACHE_DIR}/cm_kernels/include/batch_headers/${XETLA_HEADER}" -I ${XETLA_INCLUDE_DIR} -E -P > "${CODEGEN_CACHE_DIR}/${XETLA_HEADER}"
COMMAND "${CMAKE_COMMAND}" -E rename "${CODEGEN_CACHE_DIR}/${XETLA_HEADER}" "${CODEGEN_CACHE_DIR}/cm_kernels/include/batch_headers/${XETLA_HEADER}"
DEPENDS "${CM_KERNELS}" "${CODEGEN_INCDIR}/${PRIM_DB}"
COMMENT "Copying CM sources and preprocessing XeTLA headers ..."
)
endif()

add_custom_command(OUTPUT "${CODEGEN_CACHE_DIR}/${CM_PRIM_DB}"
COMMAND "${Python3_EXECUTABLE}" "${CODEGEN_SCRIPT}" -out_path "${CODEGEN_CACHE_DIR}"
-out_file_name_prim_db "${CM_PRIM_DB}"
-out_file_name_batch_headers "${CM_PRIM_DB_BATCH_HEADERS}"
-kernels "${CMAKE_CURRENT_SOURCE_DIR}/cm_kernels" -cm
DEPENDS ${CM_KERNELS} "${CODEGEN_SCRIPT}" "${CODEGEN_INCDIR}/${PRIM_DB}"
-kernels "${CODEGEN_CACHE_DIR}/cm_kernels" -cm
DEPENDS ${CM_KERNELS} "${CODEGEN_SCRIPT}" "${CODEGEN_CACHE_DIR}/cm_kernels" "${CODEGEN_INCDIR}/${PRIM_DB}"
COMMENT "Generating ${CODEGEN_CACHE_DIR}/${CM_PRIM_DB} ..."
)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
// This file is preprocessed by the C++ compiler before it is added to the CM batch-headers database

#define XETLA_CODE_BASE __CM__
#define XETLA_NO_CM_INCLUDE
#include <xetla.hpp>