[GPU] Paged attention cache rotation (openvinotoolkit#28232)
### Details:
- This PR adds cache rotation support for PagedAttention, along with related tests (see the conceptual sketch after this list)
- Should be merged after openvinotoolkit#27088
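
Conceptually, cache rotation re-applies rotary position embedding (RoPE) to key vectors that already sit in the paged KV cache, which becomes necessary when a cached token's effective position shifts (e.g. after cache block eviction). A minimal sketch of the underlying identity, assuming standard RoPE, where each key pair (k1, k2) at frequency theta is rotated by a position shift delta:

```
k1' = k1 * cos(delta * theta) - k2 * sin(delta * theta)
k2' = k1 * sin(delta * theta) + k2 * cos(delta * theta)
```

Here `rotation_trig_lut` stores these cos/sin values precomputed per shift, and `rotation_deltas` selects which LUT row applies to each rotated block (or to each token within it).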

### Tickets:
 - *ticket-id*

---------

Co-authored-by: Vasily Shamporov <vasily.shamporov@intel.com>
Co-authored-by: Pavel Durandin <pavel.durandin@intel.com>
Co-authored-by: cecilia peng <cecilia.peng@intel.com>
4 people authored Jan 14, 2025
1 parent d5b2966 commit c8c1438
Showing 13 changed files with 625 additions and 50 deletions.
3 changes: 2 additions & 1 deletion src/core/src/op/paged_attention.cpp
@@ -177,7 +177,8 @@ void PagedAttentionExtension::validate_and_infer_types() {
get_input_partial_shape(15).rank().get_length(),
".");
NODE_VALIDATION_CHECK(this,
- get_input_element_type(15).is_dynamic() || get_input_element_type(15) == element::f32,
- "Element type of `rotation_trig_lut` input should be f32, but it is ",
+ get_input_element_type(15).is_dynamic() || get_input_element_type(15) == element::f32 ||
+     get_input_element_type(15) == element::f16,
+ "Element type of `rotation_trig_lut` input should be f32 or f16, but it is ",
get_input_element_type(15),
".");
@@ -21,7 +21,7 @@ struct paged_attention : public primitive_base<paged_attention> {
paged_attention(const primitive_id& id,
const std::vector<input_info>& inputs)
: primitive_base(id, inputs) {
- OPENVINO_ASSERT((inputs.size() == 13) || (inputs.size() == 15),
+ OPENVINO_ASSERT((inputs.size() == 13) || (inputs.size() == 16),
"[GPU] Unexpected inputs number for PagedAttention primitive: ",
inputs.size());
}
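
For orientation, the 13- vs 16-input forms presumably correspond to the following PagedAttention input order, inferred from the index usage elsewhere in this PR (key_cache at 3, alibi_slopes at 11, rotation inputs at 13-15):

// Inferred PagedAttention input order (a sketch, not authoritative):
//   0 query   1 key   2 value   3 key_cache   4 value_cache   5 past_lens
//   6 subsequence_begins   7 block_indices   8 block_indices_begins
//   9 scale   10 sliding_window   11 alibi_slopes   12 max_context_len
//   13 rotated_block_indices   14 rotation_deltas   15 rotation_trig_lut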
113 changes: 105 additions & 8 deletions src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp
@@ -12,6 +12,7 @@

#include "sdpa/sdpa_kernel_base.h"
#include "sdpa/sdpa_kernel_selector.h"
#include "sdpa/pa_kv_cache_rotate_kernel_ref.h"
#include "sdpa/pa_kv_cache_update_kernel_ref.h"
#include "sdpa/pa_sdpa_kernel_opt.h"

@@ -28,6 +29,9 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
using pa_sdpa_kernel_selector_t = kernel_selector::pa_sdpa_kernel_selector;
using pa_sdpa_kernel_params_t = kernel_selector::pa_sdpa_params;

using kv_cache_rotate_kernel_selector_t = kernel_selector::kv_cache_rotate_kernel_selector;
using kv_cache_rotate_kernel_params_t = kernel_selector::kv_cache_rotate_params;

using kv_cache_update_kernel_selector_t = kernel_selector::kv_cache_update_kernel_selector;
using kv_cache_update_kernel_params_t = kernel_selector::kv_cache_update_params;

@@ -50,6 +54,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
KV_CACHE_UPDATE,
SDPA,
PA_SDPA,
KV_CACHE_ROTATE,
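// Appended after the existing stages, presumably so each Stage value keeps matching
// the position at which its kernel is pushed into _kernels_data (the rotation kernel
// is pushed last in get_implementation()).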
};

bool requires_update(primitive_inst& inst, const kernel_impl_params& impl_params) const override {
@@ -64,6 +69,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
void load(BinaryInputBuffer& ib) override {
parent::load(ib);
ib >> make_data(&has_scores_output, sizeof(bool));
ib >> make_data(&has_rotated_blocks, sizeof(bool));
if (is_dynamic()) {
auto& kv_cache_update_kernel_selector = kv_cache_update_kernel_selector_t::Instance();
auto kv_cache_update_kernel_impl = kv_cache_update_kernel_selector.GetImplementation(_kernels_data[Stage::KV_CACHE_UPDATE].kernelName);
@@ -76,12 +82,19 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
auto& pa_sdpa_kernel_selector = pa_sdpa_kernel_selector_t::Instance();
auto pa_sdpa_kernel_impl = pa_sdpa_kernel_selector.GetImplementation(_kernels_data[Stage::PA_SDPA].kernelName);
pa_sdpa_kernel_impl->GetUpdateDispatchDataFunc(_kernels_data[Stage::PA_SDPA]);

if (has_rotated_blocks) {
auto& kv_cache_rotate_kernel_selector = kv_cache_rotate_kernel_selector_t::Instance();
auto kv_cache_rotate_kernel_impl = kv_cache_rotate_kernel_selector.GetImplementation(_kernels_data[Stage::KV_CACHE_ROTATE].kernelName);
kv_cache_rotate_kernel_impl->GetUpdateDispatchDataFunc(_kernels_data[Stage::KV_CACHE_ROTATE]);
}
}
}

void save(BinaryOutputBuffer& ob) const override {
parent::save(ob);
ob << make_data(&has_scores_output, sizeof(bool));
ob << make_data(&has_rotated_blocks, sizeof(bool));
}

std::vector<layout> get_internal_buffer_layouts_impl() const override {
@@ -142,10 +155,16 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
const auto desc = instance.get_node().as<paged_attention>().get_primitive();

kernel_arguments_data args;
- if (stage == Stage::KV_CACHE_UPDATE || stage == Stage::SDPA)
+ if (stage == Stage::KV_CACHE_UPDATE || stage == Stage::SDPA || stage == Stage::KV_CACHE_ROTATE)
args.shape_info = instance.shape_info_memory_ptr();

- if (stage == Stage::KV_CACHE_UPDATE) {
+ if (stage == Stage::KV_CACHE_ROTATE) {
+     args.inputs = { instance.rotated_block_indices_memory_ptr(),
+                     instance.rotation_deltas_memory_ptr(),
+                     instance.rotation_trig_lut_memory_ptr() };
+
+     args.outputs = { instance.key_cache_memory_ptr() };
+ } else if (stage == Stage::KV_CACHE_UPDATE) {
args.inputs = { instance.key_memory_ptr(),
instance.value_memory_ptr(),
instance.past_lens_memory_ptr(),
@@ -195,6 +214,12 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
if (desc->has_alibi) {
args.inputs.push_back(instance.alibi_memory_ptr());
}

if (desc->has_rotated_blocks) {
args.inputs.push_back(instance.rotated_block_indices_memory_ptr());
args.inputs.push_back(instance.rotation_deltas_memory_ptr());
args.inputs.push_back(instance.rotation_trig_lut_memory_ptr());
}
} else if (kernel_idx == 2 || kernel_idx == 3) {
// Finalization kernel or mixed stage finalization kernel
args.inputs = { instance.past_lens_memory_ptr() };
@@ -322,19 +347,25 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
}

event::ptr execute_impl(const std::vector<event::ptr>& events, paged_attention_inst& instance) override {
- std::vector<event::ptr> res_events;
const auto stage = get_paged_attention_stage(*instance.get_impl_params());
const auto is_mixed_mode = stage == PagedAttentionStage::MIXED;

- execute_stage(events, instance, res_events, Stage::KV_CACHE_UPDATE, is_mixed_mode);
+ std::vector<event::ptr> res_events;
+ std::vector<event::ptr> dep_events = events;
+ if (has_rotated_blocks) {
+     execute_stage(dep_events, instance, res_events, Stage::KV_CACHE_ROTATE, is_mixed_mode);
+     dep_events = res_events;
+ }
+
+ execute_stage(dep_events, instance, res_events, Stage::KV_CACHE_UPDATE, is_mixed_mode);

if (stage == PagedAttentionStage::PREFILL) {
- std::vector<event::ptr> dep_events(res_events.begin(), res_events.end());
+ dep_events = res_events;
execute_stage(dep_events, instance, res_events, Stage::SDPA, is_mixed_mode);
}

if (stage == PagedAttentionStage::GENERATE || stage == PagedAttentionStage::MIXED || has_scores_output) {
- std::vector<event::ptr> dep_events(res_events.begin(), res_events.end());
+ dep_events = res_events;
execute_stage(dep_events, instance, res_events, Stage::PA_SDPA, is_mixed_mode);
}

@@ -446,6 +477,8 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
config.has_const_scale_val = false;
}

config.has_rotated_blocks = desc->has_rotated_blocks;

if (desc->heads_num != desc->kv_heads_num) {
config.broadcast_axis = 1;
config.group_size = desc->heads_num / desc->kv_heads_num;
@@ -461,6 +494,42 @@
return config;
}

static kv_cache_rotate_kernel_params_t get_kv_cache_rotate_kernel_params(const kernel_impl_params& impl_param,
const kernel_selector::MultiDataTensor& input_tensors,
bool is_dynamic = false) {
auto params = get_default_params<kv_cache_rotate_kernel_params_t>(impl_param, is_dynamic);

const auto& key_cache_tensor = input_tensors[3];
const auto& rotated_block_indices_tensor = input_tensors[13];
const auto& rotation_deltas_tensor = input_tensors[14];
const auto& rotation_trig_lut_tensor = input_tensors[15];

const auto inputs_number = 3;
const auto outputs_number = 1;
params.inputs.resize(inputs_number);
params.outputs.resize(outputs_number);
params.inputs[0] = rotated_block_indices_tensor;
params.inputs[1] = rotation_deltas_tensor;
params.inputs[2] = rotation_trig_lut_tensor;
params.outputs[0] = key_cache_tensor;
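// Note: the rotation is applied in place, so key_cache is the kernel's single output
// and is read through that same buffer as well.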

params.conf = get_sdpa_configuration(impl_param, is_dynamic);

const auto& in_offsets_map = impl_param.in_port_to_shape_info_offset;
std::map<size_t, size_t> in_tensor_to_offset_map = {
{0, in_offsets_map.at(13)},
{1, in_offsets_map.at(14)},
{2, in_offsets_map.at(15)},
};
std::map<size_t, size_t> out_tensor_to_offset_map = {
{0, in_offsets_map.at(3)},
};

params.set_dynamic_shape_offsets(in_tensor_to_offset_map, out_tensor_to_offset_map);

return params;
}

static kv_cache_update_kernel_params_t get_kv_cache_update_kernel_params(const kernel_impl_params& impl_param,
const PagedAttentionStage& stage,
const kernel_selector::MultiDataTensor& input_tensors,
@@ -618,6 +687,10 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
if (has_alibi)
inputs_number++;

const auto has_rotation = impl_param.input_layouts.size() == 16;
if (has_rotation)
inputs_number += 3;

auto input_idx = 0;
params.inputs.resize(inputs_number);
params.inputs[input_idx++] = query_tensor;
@@ -636,6 +709,12 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
if (has_alibi)
params.inputs[input_idx++] = alibi_tensor;

if (has_rotation) {
params.inputs[input_idx++] = input_tensors[13];
params.inputs[input_idx++] = input_tensors[14];
params.inputs[input_idx++] = input_tensors[15];
}

if (has_scores_output) {
params.outputs.resize(2);
params.outputs[1] = convert_data_tensor(impl_param.get_output_layout(1));
@@ -673,6 +752,12 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
if (has_alibi)
in_tensor_to_offset_map.insert({input_idx++, in_offsets_map.at(11)});

if (has_rotation) {
in_tensor_to_offset_map.insert({input_idx++, in_offsets_map.at(13)});
in_tensor_to_offset_map.insert({input_idx++, in_offsets_map.at(14)});
in_tensor_to_offset_map.insert({input_idx++, in_offsets_map.at(15)});
}

if (has_scores_output)
out_tensor_to_offset_map.insert({1, out_offsets_map.at(1)});

@@ -688,6 +773,11 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
for (const auto& input_layout : impl_param.input_layouts)
input_tensors.emplace_back(convert_data_tensor(input_layout));

if (has_rotated_blocks) {
auto kv_cache_rotate_kernel_params = get_kv_cache_rotate_kernel_params(impl_param, input_tensors, impl_param.is_dynamic());
(_kernels_data[Stage::KV_CACHE_ROTATE].update_dispatch_data_func)(kv_cache_rotate_kernel_params, _kernels_data[Stage::KV_CACHE_ROTATE]);
}

auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, stage, input_tensors, impl_param.is_dynamic());
(_kernels_data[Stage::KV_CACHE_UPDATE].update_dispatch_data_func)(kv_cache_update_kernel_params, _kernels_data[Stage::KV_CACHE_UPDATE]);

Expand All @@ -710,6 +800,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
for (const auto& input_layout : impl_param.input_layouts)
input_tensors.emplace_back(convert_data_tensor(input_layout));

const auto& desc = impl_param.typed_desc<paged_attention>();
auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, stage, input_tensors, impl_param.is_dynamic());
auto& kv_cache_update_kernel_selector = kv_cache_update_kernel_selector_t::Instance();
kernels_data.push_back(kv_cache_update_kernel_selector.get_best_kernel(kv_cache_update_kernel_params));
@@ -722,16 +813,22 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
auto& pa_sdpa_kernel_selector = pa_sdpa_kernel_selector_t::Instance();
kernels_data.push_back(pa_sdpa_kernel_selector.get_best_kernel(pa_sdpa_kernel_params));

- auto impl = cldnn::make_unique<paged_attention_impl>(kernels_data);
+ if (desc->has_rotated_blocks) {
+     auto kv_cache_rotate_kernel_params = get_kv_cache_rotate_kernel_params(impl_param, input_tensors, impl_param.is_dynamic());
+     auto& kv_cache_rotate_kernel_selector = kv_cache_rotate_kernel_selector_t::Instance();
+     kernels_data.push_back(kv_cache_rotate_kernel_selector.get_best_kernel(kv_cache_rotate_kernel_params));
+ }

- const auto& desc = impl_param.typed_desc<paged_attention>();
+ auto impl = cldnn::make_unique<paged_attention_impl>(kernels_data);
impl->has_scores_output = desc->has_scores_output();
+ impl->has_rotated_blocks = desc->has_rotated_blocks;

return impl;
}

private:
bool has_scores_output = false;
bool has_rotated_blocks = false;
};

namespace detail {
@@ -0,0 +1,83 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "include/batch_headers/common.cl"

#define SUBGROUPS_PER_WG KV_HEADS_NUM

REQD_SUB_GROUP_SIZE(SUBGROUP_SIZE)
__attribute__((reqd_work_group_size(SUBGROUP_SIZE, KV_HEADS_NUM, 1)))
KERNEL(pa_kv_cache_rotate)(
OPTIONAL_SHAPE_INFO_ARG
__global const INPUT0_TYPE* rotated_block_indices,
__global const INPUT1_TYPE* rotation_deltas,
__global const INPUT2_TYPE* rotation_trig_lut,
__global OUTPUT_TYPE* key_cache
) {
// Input shapes:
// rotated_block_indices: [num_blocks_to_rotate]
// rotation_deltas: [num_blocks_to_rotate, PAGED_ATTENTION_BLOCK_SIZE] || [num_blocks_to_rotate, 1]
// rotation_trig_lut: [max_num_batched_tokens / PAGED_ATTENTION_BLOCK_SIZE, HEAD_SIZE] || [max_num_batched_tokens, HEAD_SIZE]
// key_cache: [num_blocks, KV_HEADS_NUM, HEAD_SIZE, PAGED_ATTENTION_BLOCK_SIZE]

// Output shapes:
// key_cache (updated): [num_blocks, KV_HEADS_NUM, HEAD_SIZE, PAGED_ATTENTION_BLOCK_SIZE]

const uint head_idx = get_global_id(1);
const uint block_idx = get_global_id(2);
const uint sglid = get_sub_group_local_id();
const uint sgid = get_sub_group_id();

__local INPUT2_TYPE rotation_coefficients[HEAD_SIZE][PAGED_ATTENTION_BLOCK_SIZE];

const bool per_token_rotation = INPUT1_FEATURE_NUM == PAGED_ATTENTION_BLOCK_SIZE;
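// rotation_deltas holds one trig-LUT row index per token of the block
// ([num_blocks_to_rotate, PAGED_ATTENTION_BLOCK_SIZE]) or a single row index shared by
// the whole block ([num_blocks_to_rotate, 1]); the feature size distinguishes the two modes.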

if (per_token_rotation) {
// Need to load HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE coefficients in total, each subgroup loads SUBGROUP_SIZE values
for (uint i = sgid; i < HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE / SUBGROUP_SIZE; i += SUBGROUPS_PER_WG) {
const uint token_idx = (i / (HEAD_SIZE / SUBGROUP_SIZE));
const uint rotation_trig_lut_start_offset = rotation_deltas[block_idx * INPUT1_FEATURE_NUM + token_idx] * HEAD_SIZE;
const uint inner_offset = (i % (HEAD_SIZE / SUBGROUP_SIZE)) * SUBGROUP_SIZE;
const uint rotation_trig_lut_offset = rotation_trig_lut_start_offset + inner_offset;

INPUT2_TYPE coefficient = rotation_trig_lut[rotation_trig_lut_offset + sglid];

rotation_coefficients[inner_offset + sglid][token_idx] = coefficient;
}
} else {
// Need to load HEAD_SIZE coefficients in total, each subgroup loads SUBGROUP_SIZE values
for (uint i = sgid; i < HEAD_SIZE / SUBGROUP_SIZE; i += SUBGROUPS_PER_WG) {
const uint token_idx = 0;
const uint rotation_trig_lut_start_offset = rotation_deltas[block_idx * INPUT1_FEATURE_NUM + token_idx] * HEAD_SIZE;
const uint inner_offset = i * SUBGROUP_SIZE;
const uint rotation_trig_lut_offset = rotation_trig_lut_start_offset + inner_offset;

INPUT2_TYPE coefficient = rotation_trig_lut[rotation_trig_lut_offset + sglid];

rotation_coefficients[inner_offset + sglid][token_idx] = coefficient;
}
}

barrier(CLK_LOCAL_MEM_FENCE);

const uint token_coefficient_idx = per_token_rotation ? sglid : 0;
const uint block_offset = rotated_block_indices[block_idx] * KV_HEADS_NUM * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE +
head_idx * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE + sglid;
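// Rotate the pair (x[i], x[i + HEAD_SIZE/2]) for every token lane; in the shared-local
// coefficients, rows [0, HEAD_SIZE/2) hold cos values and rows [HEAD_SIZE/2, HEAD_SIZE)
// hold the matching sin values.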
for (uint i = 0; i < HEAD_SIZE / 2; i++) {
const uint cache_offset = block_offset + i * PAGED_ATTENTION_BLOCK_SIZE;
OUTPUT_TYPE cache_value_first = key_cache[cache_offset];
OUTPUT_TYPE cache_value_second = key_cache[cache_offset + (HEAD_SIZE / 2) * PAGED_ATTENTION_BLOCK_SIZE];

INPUT2_TYPE rotation_value_cos = rotation_coefficients[i][token_coefficient_idx];
INPUT2_TYPE rotation_value_sin = rotation_coefficients[i + (HEAD_SIZE / 2)][token_coefficient_idx];

OUTPUT_TYPE new_cache_value_first = cache_value_first * rotation_value_cos - cache_value_second * rotation_value_sin;
OUTPUT_TYPE new_cache_value_second = cache_value_first * rotation_value_sin + cache_value_second * rotation_value_cos;

key_cache[cache_offset] = new_cache_value_first;
key_cache[cache_offset + (HEAD_SIZE / 2) * PAGED_ATTENTION_BLOCK_SIZE] = new_cache_value_second;
}
}

#undef SUBGROUPS_PER_WG
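
For reference, the loop above is the standard RoPE pairwise rotation applied in place. A minimal host-side C++ sketch of the same update (a hypothetical scalar reference for testing, not part of this PR; it assumes the kernel's [HEAD_SIZE, BLOCK_SIZE] per-head cache layout and the cos/sin split of the LUT rows):

#include <cstddef>
#include <vector>

// Scalar reference for pa_kv_cache_rotate: rotates one head's key block in place.
// key_block is laid out [head_size][block_size]; trig_lut row d holds cos values in
// [0, head_size/2) and sin values in [head_size/2, head_size); deltas holds one LUT
// row index per token, or a single shared index for the whole block.
void rotate_key_block(std::vector<float>& key_block,
                      const std::vector<float>& trig_lut,
                      const std::vector<int>& deltas,
                      std::size_t head_size,
                      std::size_t block_size) {
    for (std::size_t token = 0; token < block_size; ++token) {
        const std::size_t row = (deltas.size() == 1) ? static_cast<std::size_t>(deltas[0])
                                                     : static_cast<std::size_t>(deltas[token]);
        const float* lut = trig_lut.data() + row * head_size;
        for (std::size_t i = 0; i < head_size / 2; ++i) {
            const float cos_v = lut[i];
            const float sin_v = lut[i + head_size / 2];
            float& first = key_block[i * block_size + token];
            float& second = key_block[(i + head_size / 2) * block_size + token];
            const float rotated_first = first * cos_v - second * sin_v;
            const float rotated_second = first * sin_v + second * cos_v;
            first = rotated_first;
            second = rotated_second;
        }
    }
}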
@@ -43,10 +43,11 @@ KERNEL(pa_sdpa_opt)(
#if HAS_ALIBI
const __global ALIBI_INPUT_TYPE* alibi_slopes,
#endif

#if HAS_ROTATED_BLOCKS
- const __global INPUT8_TYPE* rotated_block_indices,
- const __global INPUT9_TYPE* rotation_deltas,
- const __global INPUT10_TYPE* rotated_block_indices,
+ const __global INPUT7_TYPE* rotated_block_indices,
+ const __global INPUT8_TYPE* rotation_deltas,
+ const __global INPUT9_TYPE* rotation_trig_lut,
#endif
__global OUTPUT_TYPE* output,
#if PAGED_ATTENTION_SCORES_OUTPUT
@@ -156,10 +157,6 @@
}
#endif

- #ifdef HAS_ROTATED_BLOCKS
- // TODO (vshampor): add cache block rotation at this spot
- #endif

const uint blocks_num_per_partition = min(total_blocks_num - partition_idx * PAGED_ATTENTION_BLOCKS_PER_PARTITION, (uint)PAGED_ATTENTION_BLOCKS_PER_PARTITION);

uint blocks_num = blocks_num_per_partition / SUBGROUPS_PER_WG;
@@ -1004,6 +1004,7 @@ KERNEL(sdpa_opt)(
const uint partition_seq_len = min((uint)SOURCE_SEQ_LEN - start_partition_idx, (uint)SEQ_LEN_PARTITION_SIZE);
#endif

MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc = INPUT0_VAL_ZERO;
#if IS_CAUSAL
if (seq_len <= target_seq_idx) { // keep tril i.e. m >= n
#endif
@@ -1037,11 +1038,7 @@
#endif

int seq_len_calc_size = min((int)(SOURCE_SEQ_LEN) - (int)seq_len, (int)SUBGROUP_SIZE);
- #if IS_CAUSAL
- MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc = INPUT0_VAL_ZERO;
- #else // !IS_CAUSAL
- MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc;
-
+ #if !IS_CAUSAL
qk_acc = FUNC_CALL(load_attn_mask)(OPTIONAL_SHAPE_INFO_TENSOR
b0_idx,
b1_idx,