Skip to content

Commit

Permalink
[Kernel] Flashinfer correctness fix for v0.1.3 (vllm-project#7319)
Browse files Browse the repository at this point in the history
  • Loading branch information
LiuXiaoxuanPKU authored Aug 12, 2024
1 parent 86ab567 commit ec2affa
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 24 deletions.
5 changes: 0 additions & 5 deletions .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,6 @@ steps:
- vllm/
- tests/basic_correctness
commands:
# This flashinfer installation will fail on AMD ROCm, so it is set as optional.
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl || true
- pytest -v -s basic_correctness/test_basic_correctness.py
- pytest -v -s basic_correctness/test_cpu_offload.py
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
Expand Down Expand Up @@ -157,7 +155,6 @@ steps:
- vllm/
- tests/models
commands:
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl
- pytest -v -s models -m \"not vlm\"

- label: Vision Language Models Test # 42min
Expand Down Expand Up @@ -212,7 +209,6 @@ steps:
- vllm/attention
- tests/kernels
commands:
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl
- pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
parallelism: 4

Expand Down Expand Up @@ -331,7 +327,6 @@ steps:
# NOTE: don't test llama model here, it seems hf implementation is buggy
# see https://github.com/vllm-project/vllm/pull/5689 for details
- pytest -v -s distributed/test_custom_all_reduce.py
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl
- TARGET_TEST_SUITE=A100 pytest -v -s distributed/test_basic_distributed_correctness.py
- pytest -v -s -x lora/test_mixtral.py

Expand Down
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamb
python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir

RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl
python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.3/flashinfer-0.1.3+cu121torch2.4-cp310-cp310-linux_x86_64.whl
#################### vLLM installation IMAGE ####################


Expand Down
37 changes: 19 additions & 18 deletions vllm/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ class FlashInferMetadata(AttentionMetadata):
# The data type of the paged kv cache
data_type: torch.dtype = None
device: torch.device = torch.device("cuda")
is_profile_run: bool = False

def __post_init__(self):
# Refer to
Expand All @@ -127,7 +128,6 @@ def __post_init__(self):
raise ValueError(
f"Only {supported_head_sizes} are supported for head_dim,",
f"received {self.head_dim}.")
self.is_profile_run = is_block_tables_empty(self.block_tables)

def begin_forward(self):
if self.num_prefill_tokens > 0:
Expand All @@ -141,23 +141,20 @@ def begin_forward(self):
assert self.paged_kv_last_page_len is not None
batch_size = self.query_start_loc.shape[0] - 1
assert batch_size >= 0
# The profile run does not read kv cache.
# Both paged_kv_indices and paged_kv_last_page_len are empty.
# paged_kv_indptr is a zero tensor with size batch_size + 1.
if self.is_profile_run:
self.paged_kv_indptr = torch.zeros(batch_size + 1,
device=self.device)
else:
# We will use flash attention for profiling to
# determine the number of blocks. Therefore,
# we don't need to prepare the input for flashinfer for profile run.
if not self.is_profile_run:
self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
self.device)
self.paged_kv_indices = self.paged_kv_indices.to(self.device)
self.prefill_wrapper.end_forward()
self.prefill_wrapper.begin_forward(
self.query_start_loc, self.paged_kv_indptr,
self.paged_kv_indices, self.paged_kv_last_page_len,
self.num_qo_heads, self.num_kv_heads, self.head_dim,
self.page_size)
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
self.device)
self.paged_kv_indices = self.paged_kv_indices.to(self.device)
self.prefill_wrapper.end_forward()
self.prefill_wrapper.begin_forward(
self.query_start_loc, self.paged_kv_indptr,
self.paged_kv_indices, self.paged_kv_last_page_len,
self.num_qo_heads, self.num_kv_heads, self.head_dim,
self.page_size)
else:
if not self.use_cuda_graph:
assert self.paged_kv_indices is not None
Expand Down Expand Up @@ -249,6 +246,8 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
# paged_kv_last_page_len is the length of the last page of each request
self.paged_kv_last_page_len: List[int] = []

self.is_profile_run: bool = False

def _add_seq_group(
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
chunked_prefill_enabled: bool):
Expand Down Expand Up @@ -305,6 +304,7 @@ def _add_seq_group(
# and paged_kv_last_page_len for profile run because we will
# create dummy inputs.
if is_profile_run:
self.is_profile_run = is_profile_run
return

block_table = block_tables[seq_id]
Expand Down Expand Up @@ -435,7 +435,8 @@ def build(self, seq_lens: List[int], query_lens: List[int],
query_start_loc=query_start_loc,
device=device,
data_type=kv_cache_dtype,
use_cuda_graph=use_captured_graph)
use_cuda_graph=use_captured_graph,
is_profile_run=self.is_profile_run)


class FlashInferImpl(AttentionImpl):
Expand Down

0 comments on commit ec2affa

Please sign in to comment.