unittest: add MLA test cases where kv_len is not evenly divided by page_size. (#861)

@yzh119 Please take a look at this test. It will fail when `kv_len` is not a
multiple of `page_size`. I checked the kernel but had no clue. If I'm not
mistaken, some bound check in the MLA kernel is wrong?
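
A minimal sketch of what such a case looks like (the numbers are illustrative; `page_size = 32` is an assumption here, since the parametrized page sizes are collapsed out of the diff below): with a non-divisible `kv_len`, the last page of each request is only partially filled, and the kernel must not read the padded slots.

import math

kv_len, page_size = 114, 32  # 114 is one of the parametrized kv_len values below
pages_num = math.ceil(kv_len / page_size)                # 4 pages per request
last_page_valid = kv_len - (pages_num - 1) * page_size   # 18 valid entries
padding = pages_num * page_size - kv_len                 # 14 slots are padding
print(pages_num, last_page_valid, padding)               # -> 4 18 14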

---------

Co-authored-by: Zihao Ye <expye@outlook.com>
foreverlms and yzh119 authored Feb 17, 2025
1 parent 672c211 commit 7cd000b
36 changes: 24 additions & 12 deletions tests/test_deepseek_mla.py
@@ -152,6 +152,24 @@ def test_batch_prefill_with_ragged_kv_cache(
     torch.testing.assert_close(lse, lse_buffer, rtol=1e-3, atol=1e-3)
 
 
+def generate_kv_from_cache(ckv, kpe, kv_len, batch_size, num_heads):
+    bs_page_num, page_size, ckv_dim = ckv.shape
+    page_num = bs_page_num // batch_size
+    _, _, kpe_dim = kpe.shape
+    ckv = ckv.view(batch_size, page_num * page_size, ckv_dim)
+    kpe = kpe.view(batch_size, page_num * page_size, kpe_dim)
+    ckv = ckv[:, :kv_len, :]
+    kpe = kpe[:, :kv_len, :]
+    k = (
+        torch.cat([ckv, kpe], dim=-1)
+        .view(-1, 1, ckv_dim + kpe_dim)
+        .repeat_interleave(num_heads, dim=1)
+    )
+    v = ckv.repeat_interleave(num_heads, dim=1)
+
+    return k, v
+
+
 @pytest.mark.parametrize("batch_size", [1, 17, 37])
 @pytest.mark.parametrize("kv_len", [17, 33, 96, 97, 114, 514, 1024])
 @pytest.mark.parametrize("qo_len", [1, 17, 37, 77])
@@ -171,8 +189,6 @@ def test_batch_mla_page_attention(
     if causal and qo_len > kv_len:
         pytest.skip("qo_len > kv_len not supported for causal attention")
     torch.manual_seed(42)
-    if kv_len % page_size != 0:
-        pytest.skip("kv_len not divisible by page_size")
     head_dim_ckv = 512
     head_dim_kpe = 64
     q_nope = torch.randn(
@@ -181,15 +197,16 @@
     q_pe = torch.randn(
         batch_size * qo_len, num_heads, head_dim_kpe, dtype=torch.half, device="cuda"
     )
+    pages_num = math.ceil(kv_len / page_size)
     ckv = torch.randn(
-        batch_size * kv_len // page_size,
+        batch_size * pages_num,
         page_size,
         head_dim_ckv,
         dtype=torch.half,
         device="cuda",
     )
     kpe = torch.randn(
-        batch_size * kv_len // page_size,
+        batch_size * pages_num,
         page_size,
         head_dim_kpe,
         dtype=torch.half,
@@ -201,8 +218,8 @@
         workspace_buffer, backend=backend
     )
     q_indptr = torch.arange(0, batch_size + 1).to(0).int() * qo_len
-    kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len // page_size
-    kv_indices = torch.arange(0, batch_size * kv_len // page_size).to(0).int()
+    kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * pages_num
+    kv_indices = torch.arange(0, batch_size * pages_num).to(0).int()
     kv_lens = torch.full((batch_size,), kv_len, dtype=torch.int32).to(0)
     wrapper.plan(
         q_indptr,
@@ -220,12 +237,7 @@
     )
     o, lse = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=True)
 
-    k = (
-        torch.cat([ckv, kpe], dim=-1)
-        .view(-1, 1, head_dim_ckv + head_dim_kpe)
-        .repeat_interleave(num_heads, dim=1)
-    )
-    v = ckv.repeat_interleave(num_heads, dim=1)
+    k, v = generate_kv_from_cache(ckv, kpe, kv_len, batch_size, num_heads)
 
     q = torch.cat([q_nope, q_pe], dim=-1)
     o_ref, lse_ref = attention_ref(batch_size, q, k, v, causal, sm_scale)
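
For reference, this is the indexing rule a paged-KV gather has to follow on the final page. It is only an illustrative Python sketch (the names mirror the test above; the function `gather_paged_ckv` is hypothetical and none of this is FlashInfer's actual kernel code):

import math
import torch

def gather_paged_ckv(ckv_cache, kv_indices, kv_len, page_size):
    # ckv_cache: [num_total_pages, page_size, head_dim]; kv_indices lists the
    # pages of one request. Only kv_len entries across those pages are valid.
    num_pages = math.ceil(kv_len / page_size)
    chunks = []
    for i in range(num_pages):
        page = ckv_cache[kv_indices[i]]                 # [page_size, head_dim]
        valid = min(page_size, kv_len - i * page_size)  # last page may be short
        chunks.append(page[:valid])                     # drop padded slots
    return torch.cat(chunks, dim=0)                     # [kv_len, head_dim]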
