[GPU] Update PagedAttention creation logic (#28802)
### Details:
- Currently, the GenAI library sets fully dynamic shapes for the Key/Value cache
buffers, which causes the GPU plugin to fail during the compile_model() call.
Therefore, the PagedAttention creation logic is updated to use the head_size and
heads_num parameters from rt_info when available (see the sketch after this list).
- GenAI related PR: openvinotoolkit/openvino.genai#1666
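
The GenAI-side change (openvinotoolkit/openvino.genai#1666) is expected to populate these rt_info entries on the PagedAttention node. Below is a minimal sketch of what setting them could look like; the helper name `set_paged_attention_rt_params` and the concrete values are illustrative assumptions, not code taken from that PR:

```cpp
#include <cstdint>
#include <memory>

#include <openvino/core/node.hpp>

// Attach static head parameters to a PagedAttention op's rt_info so the GPU
// plugin can read them even when the Key/Value cache shapes are fully dynamic.
// The key names match the ones queried in paged_attention.cpp in the diff below.
void set_paged_attention_rt_params(const std::shared_ptr<ov::Node>& pa_op,
                                   int64_t k_head_size,
                                   int64_t num_k_heads) {
    auto& rt_info = pa_op->get_rt_info();  // ov::RTMap: string -> ov::Any
    rt_info["k_head_size"] = k_head_size;  // per-head size of the K cache
    rt_info["num_k_heads"] = num_k_heads;  // number of K heads in the cache
}
```

With these keys present, the plugin can take head_size and kv_heads_num from rt_info instead of requiring static Key cache dimensions, as the diff below shows.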
sshlyapn authored Feb 4, 2025
1 parent e50d722 commit b603d00
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp
```diff
@@ -25,10 +25,16 @@ static void CreatePagedAttentionExtensionOp(ProgramBuilder& p, const std::shared
     auto inputs = p.GetInputInfo(op);
     auto prim = cldnn::paged_attention(layer_type_name_ID(op), inputs);
 
-    auto key_cache_ps = op->get_input_partial_shape(3);
+    const auto& rt_info = op->get_rt_info();
+    const auto k_head_size_id = "k_head_size";
+    const auto num_k_heads_id = "num_k_heads";
+    const auto has_rt_params = rt_info.find(k_head_size_id) != rt_info.end() &&
+                               rt_info.find(num_k_heads_id) != rt_info.end();
+
     auto query_ps = op->get_input_partial_shape(0);
-    auto head_size = key_cache_ps[2].get_length();
-    auto kv_heads_num = key_cache_ps[1].get_length();
+    auto key_cache_ps = op->get_input_partial_shape(3);
+    auto head_size = has_rt_params ? rt_info.at(k_head_size_id).as<int64_t>() : key_cache_ps[2].get_length();
+    auto kv_heads_num = has_rt_params ? rt_info.at(num_k_heads_id).as<int64_t>() : key_cache_ps[1].get_length();
 
     // WA: in some cases, the query input may have a bounded dimension
     // Use input shape of the input node in such cases
```
