Skip to content

Commit

Permalink
feat: enable ragged fa3 by default on hopper 12.4+ (#3442)
Browse files (browse the repository at this point in the history)
  • Loading branch information
zhyncs authored Feb 9, 2025
1 parent d872727 commit 36f6fc5
Showing 1 changed file with 11 additions and 14 deletions.
25 changes: 11 additions & 14 deletions python/sglang/srt/layers/attention/flashinfer_backend.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -70,6 +70,8 @@ def __init__(
):
super().__init__()

self.is_multimodal = model_runner.model_config.is_multimodal

# Parse constants
self.decode_use_tensor_cores = should_use_tensor_core(
kv_cache_dtype=model_runner.kv_cache_dtype,
Expand Down Expand Up @@ -130,12 +132,8 @@ def __init__(
for _ in range(self.num_wrappers)
]

# Create wrappers
# NOTE: we do not use ragged attention when there are multiple wrappers
self.prefill_wrapper_ragged = (
BatchPrefillWithRaggedKVCacheWrapper(self.workspace_buffer, "NHD")
if self.num_wrappers == 1
else None
self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
self.workspace_buffer, "NHD"
)

# Two wrappers: one for sliding window attention and one for full attention.
Expand Down Expand Up @@ -217,13 +215,12 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
else:
prefix_lens = forward_batch.extend_prefix_lens

# Some heuristics to check whether to use ragged forward
if forward_batch.extend_num_tokens >= 4096 and self.num_wrappers == 1:
use_ragged = True
extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
else:
if self.is_multimodal:
use_ragged = False
extend_no_prefix = False
else:
use_ragged = True
extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)

self.indices_updater_prefill.update(
forward_batch.req_pool_indices,
Expand Down Expand Up @@ -640,7 +637,6 @@ def call_begin_forward(
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
bs = kv_indptr.shape[0] - 1

wrapper.end_forward()
wrapper.begin_forward(
kv_indptr,
kv_indices,
Expand All @@ -651,6 +647,7 @@ def call_begin_forward(
1,
data_type=self.data_type,
q_data_type=self.q_data_type,
non_blocking=True,
)


Expand Down Expand Up @@ -860,7 +857,6 @@ def call_begin_forward(

# extend part
if use_ragged:
wrapper_ragged.end_forward()
wrapper_ragged.begin_forward(
qo_indptr,
qo_indptr,
Expand All @@ -871,7 +867,6 @@ def call_begin_forward(
)

# cached part
wrapper_paged.end_forward()
wrapper_paged.begin_forward(
qo_indptr,
kv_indptr,
Expand All @@ -883,6 +878,7 @@ def call_begin_forward(
1,
q_data_type=self.q_data_type,
custom_mask=custom_mask,
non_blocking=True,
)


Expand Down Expand Up @@ -1125,6 +1121,7 @@ def fast_decode_plan(
sm_scale: Optional[float] = None,
rope_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
**kwargs,
) -> None:
"""A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend."""
batch_size = len(last_page_len)
Expand Down

0 comments on commit 36f6fc5

Please sign in to comment.