Skip to content

Commit

Permalink
[GPU] Add input reorder for onednn gemm
Browse files Browse the repository at this point in the history
+ Add reorder to change input format from blocked to plain for dGPU to select onednn kernel for gemm.
  • Loading branch information
jade-cho committed Feb 20, 2025
1 parent d3cdfe8 commit 3f3fe88
Showing 1 changed file with 40 additions and 0 deletions.
40 changes: 40 additions & 0 deletions src/plugins/intel_gpu/src/graph/graph_optimizer/reorder_inputs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -925,6 +925,40 @@ void reorder_inputs::run(program& p, reorder_factory& rf) {
}
};

#ifdef ENABLE_ONEDNN_FOR_GPU
const auto reorder_input_gemm = [&p, &rf](typed_program_node<gemm>& gemm_node) {
if (gemm_node.get_preferred_impl_type() != impl_types::onednn || gemm_node.is_dynamic()
|| gemm_node.get_preferred_input_fmts().size() < 2) {
return;
}

for (size_t idx = 0; idx < 2; ++idx) {
auto fmt = gemm_node.get_preferred_input_fmts()[idx];
if (fmt != format::type::any && !format::is_simple_data_format(fmt)) {
return;
}
}

for (size_t idx = 0; idx < 2; idx++) {
auto dep = gemm_node.get_dependency_with_port(idx);
const auto& input = dep.first;
auto input_layout = input->get_output_layout();

if (input_layout.is_dynamic())
continue;

if (!input->is_constant() && !format::is_simple_data_format(input_layout.format)) {
auto new_layout = input_layout;
new_layout.format = format::get_default_format(input_layout.get_rank());
auto new_input = rf.get_reorder(input->id(), dep.second, input_layout, new_layout);
if (new_input.first) {
p.add_intermediate(new_input.first, gemm_node, idx, !new_input.second);
}
}
}
};
#endif // ENABLE_ONEDNN_FOR_GPU

for (auto& prim : p.get_processing_order()) {
program_helpers::do_for_types<detection_output, deconvolution, convolution, fully_connected, pooling>(
*prim,
Expand All @@ -933,6 +967,12 @@ void reorder_inputs::run(program& p, reorder_factory& rf) {
reorder_convolution,
reorder_input_fully_connected,
reorder_input_pooling);

#ifdef ENABLE_ONEDNN_FOR_GPU
program_helpers::do_for_types<gemm>(
*prim,
reorder_input_gemm);
#endif // ENABLE_ONEDNN_FOR_GPU
}

for (auto n : p.get_processing_order()) {
Expand Down

0 comments on commit 3f3fe88

Please sign in to comment.