From b129e9a91b79a22e8a48484d69cdebf613a90571 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Wed, 29 Jan 2025 00:22:38 +0000 Subject: [PATCH 1/9] add the first version. All tests pass except for test_paged_attention_extreme_one_tokens_per_sequence. --- ...t_multi_queries_paged_attention_kernel.py} | 0 test/test_ragged_paged_attention_kernel.py | 406 ++++++++ test/tpu/run_tests.sh | 3 +- .../ragged_paged_attention_kernel.py | 976 ++++++++++++++++++ 4 files changed, 1384 insertions(+), 1 deletion(-) rename test/{test_tpu_paged_attention_kernel.py => test_multi_queries_paged_attention_kernel.py} (100%) create mode 100644 test/test_ragged_paged_attention_kernel.py create mode 100644 torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py diff --git a/test/test_tpu_paged_attention_kernel.py b/test/test_multi_queries_paged_attention_kernel.py similarity index 100% rename from test/test_tpu_paged_attention_kernel.py rename to test/test_multi_queries_paged_attention_kernel.py diff --git a/test/test_ragged_paged_attention_kernel.py b/test/test_ragged_paged_attention_kernel.py new file mode 100644 index 000000000000..73ce5ed1b395 --- /dev/null +++ b/test/test_ragged_paged_attention_kernel.py @@ -0,0 +1,406 @@ +from typing import List, Optional, Tuple + +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax._src import test_util as jtu +from jax.experimental.pallas.ops.tpu.paged_attention import quantization_utils +from torch_xla.experimental.pallas_kernels.ragged_paged_attention_kernel import ragged_paged_attention, make_group_metadata +import jax.numpy as jnp +import numpy as np + + +jax.config.parse_flags_with_absl() + + +# https://github.com/vllm-project/flash-attention/blob/98a4f8df6f5f50413e03f102dc319690300d4aaf/tests/test_vllm_flash_attn.py#L22 +def _ref_ragged_paged_attention( + queries: jax.Array, # [num_tokens, num_q_heads, head_dim] + k_pages: jax.Array, # [num_kv_heads, total_num_pages, page_size, head_dim] + v_pages: jax.Array, # [num_kv_heads, total_num_pages, page_size, head_dim] + kv_lens: jax.Array, # i32[num_tokens] + page_indices: jax.Array, # i32[num_tokens, pages_per_sequence] + cu_q_lens: jax.Array, # i32[num_tokens + 1] + num_seqs: int, +): + num_kv_heads, _, page_size, head_dim = k_pages.shape + num_q_heads = queries.shape[1] + assert num_q_heads % num_kv_heads == 0, "num_q_heads % num_kv_heads !=0." 
+ num_query_per_kv = num_q_heads // num_kv_heads + start_idx = 0 + outputs: List[jax.Array] = [] + for i in range(num_seqs): + cur_q_len = cu_q_lens[i+1] - cu_q_lens[i] + q = queries[start_idx:start_idx+cur_q_len] # [cur_q_len, num_q_heads, head_dim] + + cur_kv_len = kv_lens[i] + num_pages = (cur_kv_len + page_size - 1) // page_size + page_indices_to_use = page_indices[i, :num_pages] + k = k_pages[:, page_indices_to_use, :, :] + k = jnp.permute_dims(k, (1, 2, 0, 3)) + k = jnp.reshape(k, (-1, num_kv_heads, head_dim)) + k = k[:cur_kv_len] # [cur_kv_lens, num_kv_heads, head_dim] + v = v_pages[:, page_indices_to_use, :, :] + v = jnp.permute_dims(v, (1, 2, 0, 3)) + v = jnp.reshape(v, (-1, num_kv_heads, head_dim)) + v = v[:cur_kv_len] # [cur_kv_lens, num_kv_heads, head_dim] + + if num_query_per_kv != 1: + k = jnp.repeat(k, num_query_per_kv, axis=1) + v = jnp.repeat(v, num_query_per_kv, axis=1) + + attn = jnp.einsum("qhd,khd->hqk", q, k) + attn = attn.astype('float32') + q_span = (cur_kv_len - cur_q_len) + jax.lax.broadcasted_iota( + jnp.int32, (cur_q_len, cur_kv_len), 0 + ) + kv_span = jax.lax.broadcasted_iota(jnp.int32, (cur_q_len, cur_kv_len), 1) + mask = jnp.where(q_span < kv_span, float("-inf"), 0.) + with jax.numpy_rank_promotion("allow"): + attn = attn + mask + attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype) + out = jnp.einsum("hqk,khd->qhd", attn, v) # [cur_q_len, num_q_heads, head_dim] + + outputs.append(out) + start_idx += cur_q_len + + return jnp.concatenate(outputs, axis=0) + + +@jtu.with_config(jax_numpy_dtype_promotion="standard") +class RaggedPagedAttentionKernelTest(jtu.JaxTestCase): + + def _verify_ragged_paged_attention_debug( + self, + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + ): + num_seqs = len(seq_lens) + query_lens = [seq_len[0] for seq_len in seq_lens] + num_q_tokens = sum(query_lens) + kv_lens = jnp.array([seq_len[1] for seq_len in seq_lens]) + num_q_heads = num_heads[0] + num_kv_heads = num_heads[1] + assert num_q_heads % num_kv_heads == 0, "num_q_heads % num_kv_heads !=0." + + prng_key = jax.random.key(0) + k1, k2, k3, k4 = jax.random.split(prng_key, 4) + queries = jax.random.normal(k1, + (num_q_tokens, num_q_heads, head_dim), + dtype=dtype) + k_pages = jax.random.normal(k2, + (num_kv_heads, num_pages, page_size, head_dim), + dtype=dtype) + v_pages = jax.random.normal(k3, + (num_kv_heads, num_pages, page_size, head_dim), + dtype=dtype) + # Create a kv_lens: i32[num_tokens] + kv_lens_with_paddings = [0] * num_q_tokens + for i in range(num_seqs): + kv_lens_with_paddings[i] = kv_lens[i] + kv_lens_np = jnp.array(kv_lens_with_paddings) + # Create a page_indices: jax.Array, # i32[num_tokens, pages_per_sequence] + max_kv_len = max([seq_len[1] for seq_len in seq_lens]) + max_num_pages_per_seq = (max_kv_len + page_size - 1) // page_size + # The reason why we need to pad max_num_pages_per_seq is that + # page_indices[1]=max_num_pages_per_seq and max_num_pages_per_seq%num_kv_pages_per_compute_block==0 + max_num_pages_per_seq = self._get_closest_power_of_two(max_num_pages_per_seq) + print(f"xw32 max_kv_len: {max_kv_len}, {max_num_pages_per_seq=}") + # The assert below mimics the reality that each page get a unique index. + # But for testing, the assert could be omitted. + assert max_num_pages_per_seq*num_q_tokens <= num_pages, f"assert failed: max_num_pages_per_seq*num_q_tokens < num_pages. 
Got {max_num_pages_per_seq*num_q_tokens} and {num_pages}" + page_indices = jax.random.randint(k4, (num_q_tokens, max_num_pages_per_seq), 0, num_pages, dtype=jnp.int32) + # Create a cu_q_lens: jax.Array, # i32[num_tokens + 1] + q_lens_with_paddings = [0] * num_q_tokens + for i in range(num_seqs): + q_lens_with_paddings[i] = query_lens[i] + cu_q_lens = jnp.cumsum(jnp.array([0]+q_lens_with_paddings)) + + actual_output = ragged_paged_attention( + queries, + k_pages, + v_pages, + kv_lens_np, + page_indices, + cu_q_lens, + num_seqs, + ) + actual_output = jax.block_until_ready(actual_output) + print("ragged paged attention finished.") + + expected_output = _ref_ragged_paged_attention( + queries, + k_pages, + v_pages, + kv_lens_np, + page_indices, + cu_q_lens, + num_seqs, + ) + + self.assertEqual(actual_output.shape, expected_output.shape) + self.assertEqual(actual_output.dtype, expected_output.dtype) + + print(f'xw32 {expected_output[:192]=}') + print(f'xw32 {actual_output[:192]=}') + + print(f'Output max diff: {jnp.max(jnp.abs(expected_output - actual_output))}') + print(f'Output mean diff: {jnp.mean(jnp.abs(expected_output - actual_output))}') + if dtype == jnp.float32: + atol = 2e-2 + rtol = 1e-2 + elif dtype == jnp.bfloat16: + atol = 6e-1 + rtol = 1e-1 + else: + self.fail(f'Unsupported dtype: {dtype}') + self.assertTrue(jnp.allclose(actual_output[:128], expected_output[:128], atol=atol, rtol=rtol)) + self.assertTrue(jnp.allclose(actual_output[128:192], expected_output[128:192], atol=atol, rtol=rtol)) + self.assertTrue(jnp.allclose(actual_output[192:256], expected_output[192:256], atol=atol, rtol=rtol)) + self.assertTrue(jnp.allclose(actual_output, expected_output, atol=atol, rtol=rtol)) + + def _verify_ragged_paged_attention( + self, + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + ): + num_seqs = len(seq_lens) + query_lens = [seq_len[0] for seq_len in seq_lens] + num_q_tokens = sum(query_lens) + kv_lens = jnp.array([seq_len[1] for seq_len in seq_lens]) + num_q_heads = num_heads[0] + num_kv_heads = num_heads[1] + assert num_q_heads % num_kv_heads == 0, "num_q_heads % num_kv_heads !=0." + + prng_key = jax.random.key(0) + k1, k2, k3, k4 = jax.random.split(prng_key, 4) + queries = jax.random.normal(k1, + (num_q_tokens, num_q_heads, head_dim), + dtype=dtype) + k_pages = jax.random.normal(k2, + (num_kv_heads, num_pages, page_size, head_dim), + dtype=dtype) + v_pages = jax.random.normal(k3, + (num_kv_heads, num_pages, page_size, head_dim), + dtype=dtype) + # Create a kv_lens: i32[num_tokens] + kv_lens_with_paddings = [0] * num_q_tokens + for i in range(num_seqs): + kv_lens_with_paddings[i] = kv_lens[i] + kv_lens_np = jnp.array(kv_lens_with_paddings) + # Create a page_indices: jax.Array, # i32[num_tokens, pages_per_sequence] + max_kv_len = max([seq_len[1] for seq_len in seq_lens]) + max_num_pages_per_seq = (max_kv_len + page_size - 1) // page_size + # The reason why we need to pad max_num_pages_per_seq is that + # page_indices[1]=max_num_pages_per_seq and max_num_pages_per_seq%num_kv_pages_per_compute_block==0 + max_num_pages_per_seq = self._get_closest_power_of_two(max_num_pages_per_seq) + print(f"xw32 max_kv_len: {max_kv_len}, {max_num_pages_per_seq=}") + # The assert below mimics the reality that each page get a unique index. + # But for testing, the assert could be omitted. + assert max_num_pages_per_seq*num_q_tokens <= num_pages, f"assert failed: max_num_pages_per_seq*num_q_tokens < num_pages. 
Got {max_num_pages_per_seq*num_q_tokens} and {num_pages}" + page_indices = jax.random.randint(k4, (num_q_tokens, max_num_pages_per_seq), 0, num_pages, dtype=jnp.int32) + # Create a cu_q_lens: jax.Array, # i32[num_tokens + 1] + q_lens_with_paddings = [0] * num_q_tokens + for i in range(num_seqs): + q_lens_with_paddings[i] = query_lens[i] + cu_q_lens = jnp.cumsum(jnp.array([0]+q_lens_with_paddings)) + + actual_output = ragged_paged_attention( + queries, + k_pages, + v_pages, + kv_lens_np, + page_indices, + cu_q_lens, + num_seqs, + ) + actual_output = jax.block_until_ready(actual_output) + print("ragged paged attention finished.") + + expected_output = _ref_ragged_paged_attention( + queries, + k_pages, + v_pages, + kv_lens_np, + page_indices, + cu_q_lens, + num_seqs, + ) + + self.assertEqual(actual_output.shape, expected_output.shape) + self.assertEqual(actual_output.dtype, expected_output.dtype) + + print(f'Output max diff: {jnp.max(jnp.abs(expected_output - actual_output))}') + print(f'Output mean diff: {jnp.mean(jnp.abs(expected_output - actual_output))}') + if dtype == jnp.float32: + atol = 2e-2 + rtol = 1e-2 + elif dtype == jnp.bfloat16: + atol = 6e-1 + rtol = 1e-1 + else: + self.fail(f'Unsupported dtype: {dtype}') + self.assertTrue(jnp.allclose(actual_output, expected_output, atol=atol, rtol=rtol)) + + def _get_closest_power_of_two(self, x): + if x <= 0: + raise ValueError(f"x must be positive. Got {x}") + return 2 ** int(np.ceil(np.log2(x))) + + def test_paged_attention_min_two_kv_block_per_sequence( + self, + ): + # assuming q_blk_size=128, page_size=16, num_kv_pages_per_compute_block=16 + # One of the constraints of the kernel is that q.shape[0]%q_blk_size==0 as in _calculate_num_tiles. + # If we cannot get the assumption, we can pad the matrix q in the kernel. + seq_lens = [(192, 328), (128, 180), (64, 255)] # [(q_len, kv_len),...] + num_heads = (1, 1) + head_dim = 128 + page_size = 16 + dtype = jnp.float32 + num_pages = 65536 + + self._verify_ragged_paged_attention_debug( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + ) + + def test_paged_attention_basic( + self, + ): + # assuming q_blk_size=128 + seq_lens = [(192, 1328), (128, 180), (64, 463)] # [(q_len, kv_len),...] + num_heads = (1, 1) + head_dim = 128 + page_size = 16 + dtype = jnp.float32 + num_pages = 65536 + + self._verify_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + ) + + + def test_paged_attention_basic_with_one_token_per_sequence( + self, + ): + # assuming q_blk_size=128 + seq_lens = [(1, 127), (120, 1328), (1, 64), (1, 64), (1, 64), (1, 64), (256, 256), (131, 463)] # [(q_len, kv_len),...] + num_heads = (1, 1) + head_dim = 128 + page_size = 16 + dtype = jnp.float32 + num_pages = 65536 + + self._verify_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + ) + + def test_paged_attention_extreme_all_tokens_belong_to_one_sequence( + self, + ): + # assuming q_blk_size=128 + seq_lens = [(512, 1328)] # [(q_len, kv_len),...] + num_heads = (1, 1) + head_dim = 128 + page_size = 16 + dtype = jnp.float32 + num_pages = 65536 + + self._verify_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + ) + + def test_paged_attention_extreme_one_tokens_per_sequence( + self, + ): + # assuming q_blk_size=128 + seq_lens = [] # [(q_len, kv_len),...] 
+ num_seqs = 512 + for i in range(num_seqs): + seq_lens.append((1, i)) + num_heads = (1, 1) + head_dim = 128 + page_size = 16 + dtype = jnp.float32 + num_pages = 65536 + + self._verify_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + ) + + def test_make_sequence_metadata( + self, + ): + cu_q_lens = jnp.array([0, 192, 448, 512] + [512]*(512-4)) + num_q_tokens = 512 + num_queries_per_compute_block = 128 + start_group = jnp.array([0]) + num_seqs = 3 + metadata, num_logical_q_tiles = make_group_metadata( + cu_q_lens=cu_q_lens, + m=num_q_tokens, + tm=num_queries_per_compute_block, + start_group=start_group, + num_seqs=num_seqs + ) + seq_ids, physical_q_tile_ids = metadata + # print(f"xw32 metadata.physical_q_tile_ids: {metadata.physical_q_tile_ids}") + # print(f"xw32 metadata.seq_ids: {metadata.seq_ids}") + self.assertEqual(num_logical_q_tiles, 6) + self.assertTrue(jnp.array_equal(seq_ids, [0, 0, 1, 1, 1, 2])) + self.assertTrue(jnp.array_equal(physical_q_tile_ids, [0, 1, 1, 2, 3, 3])) + # print('xw32======================') + # q_lens = jnp.array([192, 256, 64] + [0]*(512-3)) + # metadata = ragged_paged_attention_kernel.original_make_group_metadata( + # group_sizes=q_lens, + # m=num_q_tokens, + # tm=num_queries_per_compute_block, + # start_group=start_group, + # num_nonzero_groups=num_seqs, + # visit_empty_groups=False, + # ) + # print(f"xw32 {metadata=}") + # self.assertEqual(metadata.num_logical_q_tiles, 6) + # print(f"xw32 metadata.seq_ids: {metadata.seq_ids}") + # print(f"xw32 metadata.physical_q_tile_ids: {metadata.physical_q_tile_ids}") + # print(f"xw32 metadata.seq_ids[:metadata.num_logical_q_tiles]: {metadata.seq_ids[:metadata.num_logical_q_tiles]}") + # self.assertTrue(jnp.array_equal(metadata.seq_ids[:metadata.num_logical_q_tiles], [0, 0, 1, 1, 1, 2])) + # self.assertTrue(jnp.array_equal(metadata.physical_q_tile_ids[:metadata.num_logical_q_tiles], [0, 1, 1, 2, 3, 3])) + + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh index e429a782f6b2..d11c8eecc2b0 100755 --- a/test/tpu/run_tests.sh +++ b/test/tpu/run_tests.sh @@ -37,7 +37,8 @@ run_xla_hlo_debug python3 "$TEST_CDIR/scan/test_scan_debug.py" python3 "$TEST_CDIR/test_pallas.py" -v python3 "$TEST_CDIR/test_pallas_spmd.py" XLA_DISABLE_FUNCTIONALIZATION=1 python3 "$TEST_CDIR/test_pallas_spmd.py" -python3 "$TEST_CDIR/test_tpu_paged_attention_kernel.py" +python3 "$TEST_CDIR/test_multi_queries_paged_attention_kernel.py" +python3 "$TEST_CDIR/test_ragged_paged_attention_kernel.py" python3 "$TEST_CDIR/test_input_output_aliases.py" python3 "$TEST_CDIR/test_gmm.py" python3 "$TEST_CDIR/eager/test_eager_spmd.py" diff --git a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py new file mode 100644 index 000000000000..4b1bd5100997 --- /dev/null +++ b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py @@ -0,0 +1,976 @@ +from collections.abc import Sequence +from collections import namedtuple +import functools +from typing import Any, Literal, Optional, cast + +import jax +from jax import lax +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +from jax.experimental.pallas.ops.tpu.paged_attention import quantization_utils +import jax.numpy as jnp +import numpy as np + + +DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max) + +class 
MultiPageAsyncCopyDescriptor: + """Descriptor for async copy of multiple K/V pages from HBM.""" + + def __init__( + self, + pages_hbm_ref, # [num_kv_heads, total_num_pages, page_size, head_dim] + scales_pages_hbm_ref, + vmem_buffer, # [pages_per_compute_block, page_size, head_dim] + scales_vmem_buffer, + sem, + page_indices, + page_indices_start_offset, + num_pages_to_load, + kv_head_index, + ): + # Original k_pages has shape [num_kv_heads, total_num_pages, page_size, head_dim] + self._vmem_buffer = vmem_buffer + self._scales_vmem_buffer = scales_vmem_buffer + self._num_pages_to_load = num_pages_to_load + if kv_head_index is not None: + self._pages_hbm_ref = pages_hbm_ref.at[kv_head_index] + if scales_pages_hbm_ref is not None: + self._scales_pages_hbm_ref = scales_pages_hbm_ref.at[kv_head_index] + else: + self._scales_pages_hbm_ref = None + else: + self._pages_hbm_ref = pages_hbm_ref + self._scales_pages_hbm_ref = scales_pages_hbm_ref + self._sem = sem + self._page_indices = page_indices + self._page_indices_start_offset = page_indices_start_offset + self._async_copies = [ + self._make_async_copy(i) for i in range(self._num_pages_to_load) + ] + if (self._scales_pages_hbm_ref is not None and + self._scales_vmem_buffer is not None): + self._async_copies += [ + self._make_scales_async_copy(i) + for i in range(self._num_pages_to_load) + ] + + def _make_async_copy(self, i): + page_index = self._page_indices[self._page_indices_start_offset + i] + return pltpu.make_async_copy(self._pages_hbm_ref.at[page_index], + self._vmem_buffer.at[i], self._sem) + + def _make_scales_async_copy(self, i): + page_index = self._page_indices[self._page_indices_start_offset + i] + return pltpu.make_async_copy( + self._scales_pages_hbm_ref.at[page_index], # pytype: disable=attribute-error + self._scales_vmem_buffer.at[i], # pytype: disable=attribute-error + self._sem, + ) + + def start(self): + """Starts the async copies.""" + for async_copy in self._async_copies: + async_copy.start() + + def _maybe_dequantize(self, x, x_scale, dtype=jnp.bfloat16): + if x_scale is None: + return x.astype(dtype) + return quantization_utils.from_int8(x, x_scale, dtype=dtype) + + def wait_and_get_loaded(self) -> jax.Array: + """Wait async copies and gets the loaded buffer as a jax.Array.""" + # Return value shape is [pages_per_compute_block*page_size, head_dim] + for async_copy in self._async_copies: + async_copy.wait() + head_dim = self._vmem_buffer.shape[-1] + jax_array = self._vmem_buffer[...].astype(jnp.float32) + if self._scales_vmem_buffer is not None: + scales_jax_array = self._scales_vmem_buffer[...].astype(jnp.float32) + else: + scales_jax_array = None + jax_array = self._maybe_dequantize(jax_array, scales_jax_array) + return jax_array.reshape(-1, head_dim) + + +def _calculate_num_tiles(x: int, tx: int) -> int: + tiles, rem = divmod(x, tx) + if rem: + raise ValueError(f"{x} must be divisible by x-dimension tile size ({tx}).") + return tiles + +SequenceMetadata = namedtuple( + "SequenceMetadata", + [ + "num_logical_q_tiles", + "seq_ids", + "physical_q_tile_ids", + ], +) + +GroupMetadata = Any + +# https://github.com/jax-ml/jax/blob/9fb29766a2130e74a85cba30420cf777d185ea5a/jax/experimental/pallas/ops/tpu/megablox/gmm.py#L79 +# TODO(xw32): need to do some renaming to adapt to our case. +# Currently, group maps to sequence. +def make_group_metadata( + *, + cu_q_lens: jnp.ndarray, + m: int, + tm: int, + start_group: jnp.ndarray, + num_seqs: int, +): + """Create the metadata needed for grouped matmul computation. 
+ + Args: + group_sizes: A 1d, jnp.ndarray with shape [num_groups] and jnp.int32 dtype. + m: The number of rows in lhs. + tm: The m-dimension tile size being used. + start_group: The group in group sizes to start computing from. This is + particularly useful for when rhs num_groups is sharded. + num_nonzero_groups: Number of groups in group sizes to compute on. Useful in + combination with group_offset. + visit_empty_groups: If True, do not squeeze tiles for empty groups out of + the metadata. This is necessary for tgmm, where we at least need to zero + the output for each group. + + Returns: + tuple of: + group_offsets: A 1d, jnp.ndarray with shape [num_groups+1] and jnp.int32 + dtype. group_offsets[i] indicates the row at which group [i] starts in + the lhs matrix and group_offsets[i-1] = m. + group_ids: A 1d, jnp.ndarray with shape [m_tiles + num_groups] and + jnp.int32 dtype. group_ids[i] indicates which group grid index 'i' will + work on. + m_tile_ids: A 1d, jnp.ndarray with shape [m_tiles + num_groups] and + jnp.int32. m_tile_ids[i] indicates which m-dimension tile grid index 'i' + will work on. + num_tiles: The number of m-dimension tiles to execute. + """ + num_groups = num_seqs + end_group = start_group + num_seqs - 1 + + # Calculate the offset of each group, starting at zero. This metadata is + # similar to row offsets in a CSR matrix. The following properties hold: + # + # group_offsets.shape = [num_groups + 1] + # group_offsets[0] = 0 + # group_offsets[num_groups] = m + # + # The row at which group 'i' starts is group_offsets[i]. + group_ends = cu_q_lens[1:] + group_offsets = cu_q_lens + + # Assign a group id to each grid index. + # + # If a group starts somewhere other than the start of a tile or ends somewhere + # other than the end of a tile we need to compute that full tile. Calculate + # the number of tiles for each group by rounding their end up to the nearest + # 'tm' and their start down to the nearest 'tm'. + + # (1) Round the group_ends up to the nearest multiple of 'tm'. + # + # NOTE: This does not change group_offsets[num_groups], which is m + # (because we enforce m is divisible by tm). + rounded_group_ends = ((group_ends + tm - 1) // tm * tm).astype(jnp.int32) + print('xw32 {rounded_group_ends=}') + + # (2) Round the group_starts down to the nearest multiple of 'tm'. + group_starts = jnp.concatenate( + [jnp.zeros(1, dtype=jnp.int32), group_ends[:-1]] + ) + rounded_group_starts = group_starts // tm * tm + + # (3) Calculate the number of rows in each group. + # + # NOTE: Handle zero-sized groups as a special case. If the start for a + # zero-sized group is not divisible by 'tm' its start will be rounded down and + # its end will be rounded up such that its size will become 1 tile here. + rounded_group_sizes = rounded_group_ends - rounded_group_starts + + # (4) Convert the group sizes from units of rows to unit of 'tm' sized tiles. + # + # An m-dimension tile is 'owned' by group 'i' if the first row of the tile + # belongs to group 'i'. In addition to owned tiles, each group can have 0 or 1 + # initial partial tiles if it's first row does not occur in the first row of a + # tile. The '0-th' group never has a partial tile because it always starts at + # the 0-th row. + # + # If no group has a partial tile, the total number of tiles is equal to + # 'm // tm'. If every group has a partial except the 0-th group, the total + # number of tiles is equal to 'm // tm + num_groups - 1'. 
Thus we know that + # + # tiles_m <= group_tiles.sum() <= tiles_m + num_groups - 1 + # + # Where tiles_m = m // tm. + # + # NOTE: All group sizes are divisible by 'tm' because of the rounding in steps + # (1) and (2) so this division is exact. + group_tiles = rounded_group_sizes // tm + + # Create the group ids for each grid index based on the tile counts for each + # group. + # + # NOTE: This repeat(...) will pad group_ids with the final group id if + # group_tiles.sum() < tiles_m + num_groups - 1. The kernel grid will be sized + # such that we only execute the necessary number of tiles. + tiles_m = _calculate_num_tiles(m, tm) + group_ids = jnp.repeat( + jnp.arange(num_groups, dtype=jnp.int32), + group_tiles[:num_groups], # would it introduce dynamic shape to impact JIT? + total_repeat_length=tiles_m + num_groups - 1, + ) + + # Assign an m-dimension tile id to each grid index. + # + # NOTE: Output tiles can only be re-visited consecutively. The following + # procedure guarantees that m-dimension tile indices respect this. + + # (1) Calculate how many times each m-dimension tile will be visited. + # + # Each tile is guaranteed to be visited once by the group that owns the tile. + # The remaining possible visits occur when a group starts inside of a tile at + # a position other than the first row. We can calculate which m-dimension tile + # each group starts in by floor-dividing its offset with `tm` and then count + # tile visits with a histogram. + # + # To avoid double counting tile visits from the group that owns the tile, + # filter these out by assigning their tile id to `tile_m` (one beyond the max) + # such that they're ignored by the subsequent histogram. Also filter out any + # group which is empty. + # + # TODO(tgale): Invert the 'partial_tile_mask' predicates to be more clear. + partial_tile_mask = ((group_offsets[:-1] % tm) == 0) + + partial_tile_ids = jnp.where( + partial_tile_mask, tiles_m, group_offsets[:-1] // tm + ) + + tile_visits = ( + jnp.histogram(partial_tile_ids, bins=tiles_m, range=(0, tiles_m - 1))[0] + + 1 + ) + + # Create the m-dimension tile ids for each grid index based on the visit + # counts for each tile. + m_tile_ids = jnp.repeat( + jnp.arange(tiles_m, dtype=jnp.int32), + tile_visits.astype(jnp.int32), + total_repeat_length=tiles_m + num_groups - 1, + ) + + # Account for sharding. + # + # Find the start of the groups owned by our shard and shift the group_ids and + # m_tile_ids s.t. the metadata for our tiles are at the front of the arrays. + # + # TODO(tgale): Move this offset into the kernel to avoid these rolls. + first_tile_in_shard = (group_ids < start_group).sum() + group_ids = jnp.roll(group_ids, shift=-first_tile_in_shard, axis=0) + m_tile_ids = jnp.roll(m_tile_ids, shift=-first_tile_in_shard, axis=0) + + # Calculate the number of tiles we need to compute for our shard. + # + # Remove tile visits that belong to a group not in our shard. 
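+ # Worked example (numbers from test_make_sequence_metadata in the test file
+ # above): with cu_q_lens = [0, 192, 448, 512, ...], m = 512, tm = 128 and
+ # num_seqs = 3, the rounded per-sequence sizes are [256, 384, 128], so
+ # group_tiles = [2, 3, 1] and num_tiles = 6. The resulting group_ids
+ # (seq_ids) are [0, 0, 1, 1, 1, 2] and m_tile_ids (physical_q_tile_ids) are
+ # [0, 1, 1, 2, 3, 3]: physical q tile 1 (rows 128..255) is visited twice
+ # because sequence 1 starts at row 192, in the middle of that tile.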
+ iota = jnp.arange(num_groups, dtype=jnp.int32) + active_group_mask = jnp.logical_and(iota <= end_group, iota >= start_group) + group_tiles = jnp.where(active_group_mask, group_tiles[:num_groups], 0) + num_tiles = group_tiles.sum() + return (group_ids, m_tile_ids), num_tiles # num_logical_q_tiles, seq_ids, physical_q_tile_ids + +def check_kernel_input(q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, + num_seqs, num_kv_pages_per_block): + num_q_heads, num_tokens, head_dim = q.shape + num_kv_heads, _, _, head_dim_k = k_pages.shape + _, pages_per_sequence = page_indices.shape + if k_pages.shape != v_pages.shape: + raise ValueError( + f"k_pages and v_pages must have the same shape. Got {k_pages.shape} and" + f" {v_pages.shape}" # pytype: disable=attribute-error + ) + if head_dim_k != head_dim: + raise ValueError("head_dim of Q must be the same as that of K/V. Got" + f" {head_dim} and {head_dim_k}.") + if kv_lens.shape[0] != num_tokens: + raise ValueError("kv_lens.shape[0] must be thet same as num_tokens. Got" + f" {kv_lens.shape[0]} and {num_tokens}") + if page_indices.shape[0] != num_tokens: + raise ValueError("page_indices.shape[0] must be thet same as num_tokens. Got" + f" {page_indices.shape[0]} and {num_tokens}") + if cu_q_lens.shape[0] != num_tokens + 1: + raise ValueError("cu_q_lens.shape[0] must be thet same as num_tokens + 1. Got" + f" {cu_q_lens.shape[0]} and {num_tokens + 1}") + if num_seqs > num_tokens: + raise ValueError(f"num_seqs must be less or equal to num_tokens. Got {num_seqs} and {num_tokens}") + # int16: will pack. need to explicit cast to int32. int64 is not supported in Pallas. for smem 1d case. + # 2d smem: int16 will be packed with an empty. So we didn't save any memory. + # scalar: use i32 (1, N). int16 for (1, N) will be padding. Need to use (2, N). + if kv_lens.dtype != jnp.int32 or page_indices.dtype != jnp.int32 or cu_q_lens.dtype != jnp.int32: + raise ValueError( + f"The dtype of `lengths` must be int32. Got {kv_lens.dtype=}, " + f"{page_indices.dtype=}, {cu_q_lens.dtype=}") + if num_kv_pages_per_block > pages_per_sequence: + raise ValueError( + f"{num_kv_pages_per_block=} should be smaller or equal to {pages_per_sequence=}" + ) + if pages_per_sequence % num_kv_pages_per_block != 0: + raise ValueError( + "pages_per_sequence must be divisible by num_kv_pages_per_block. Got" + f" {pages_per_sequence=} and {num_kv_pages_per_block=}.") + if num_q_heads % num_kv_heads != 0: + raise ValueError( + "Number of Q heads must be divisible by number of KV heads. 
Got" + f" {num_q_heads} and {num_kv_heads}.") + +# https://github.com/jax-ml/jax/blob/e3b3b913f7bcec3767e1442ace08999413f8703d/jax/experimental/pallas/ops/tpu/megablox/gmm.py#L269C1-L283C64 +def _get_store_mask( + *, + grid_id: jnp.ndarray, + group_offsets: jnp.ndarray, + group_ids: jnp.ndarray, + m_tile_ids: jnp.ndarray, + tm: int, + tn: int, +) -> jnp.ndarray: + """Mask for rows that belong to the current group in the current tile.""" + group_id = group_ids[grid_id] + group_start = group_offsets[group_id] + group_end = group_offsets[group_id + 1] + m_id = m_tile_ids[grid_id] * tm + iota = jax.lax.broadcasted_iota(jnp.int32, (tm, tn), 0) + m_id + return jnp.logical_and(iota >= group_start, iota < group_end) + +def _flash_attention( + q_head_idx_per_kv, # scalar, ranges from 0 to num_query_heads_per_kv_head + group_metadata_ref, # (seq_ids, physical_q_tile_ids) + effective_kv_lens_ref, # [num_tokens] + effective_cu_q_lens_ref, # [num_tokens + 1] + # kernel inputs + q_ref, # q_ref.shape=[num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim] + k, # [kv_blk_size, head_dim] + v, # [kv_blk_size, head_dim] + # outputs + o_ref, # [num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim] + l_ref, # [num_q_heads_per_kv_head, num_queries_per_compute_block, MIN_BLOCK_SIZE] + m_ref, # [num_q_heads_per_kv_head, num_queries_per_compute_block, MIN_BLOCK_SIZE] + # scratch space + # TODO: double check if the scratch ref shape is correct. + l_scratch_ref, # [num_q_heads_per_kv_head, num_queries_per_compute_block, MIN_BLOCK_SIZE] + m_scratch_ref, # [num_q_heads_per_kv_head, num_queries_per_compute_block, MIN_BLOCK_SIZE] + acc_scratch_ref, # [num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim] + *, + num_tokens: int, + num_seqs: int, + num_kv_pages_per_compute_block: int, + num_queries_per_compute_block: int, + mask_value: float, + page_size: int, + head_dim: int, + num_q_heads_per_kv_head: int, +): + assert q_ref.shape == (num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim) + kv_blk_size = page_size * num_kv_pages_per_compute_block + assert k.shape == (kv_blk_size, head_dim) + assert v.shape == (kv_blk_size, head_dim) + + kv_head_idx, logical_q_blk_idx, kv_blk_idx = ( + pl.program_id(0), + pl.program_id(1), + pl.program_id(2), + ) + seq_ids, physical_q_tile_ids = group_metadata_ref + + # If the q-dim physical tile is changed (meaning it is a new physical q-dim tile that has not visited before), initialize the acc_scratch_ref, m_scratch_ref, and l_scratch_ref to run the flash attention v2 algorithm. 
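+ # m_scratch holds the running row-wise max of the attention logits and
+ # l_scratch the running softmax denominator (the online-softmax statistics of
+ # flash attention v2); acc_scratch holds the running output, rescaled whenever
+ # the denominator is updated. They persist in VMEM across kv blocks and across
+ # logical q tiles that share the same physical q tile, which is why they are
+ # only re-initialized below when a new physical q tile is first seen.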
+ prev_logical_q_blk_idx = jnp.where(logical_q_blk_idx > 0, logical_q_blk_idx - 1, 0) + is_first_processed_logical_q_blk = logical_q_blk_idx == 0 + physical_q_blk_changed = (physical_q_tile_ids[logical_q_blk_idx] != physical_q_tile_ids[prev_logical_q_blk_idx]) + first_time_seeing_physical_q_blk = jnp.logical_or(is_first_processed_logical_q_blk, physical_q_blk_changed) + is_first_kv_blk = (kv_blk_idx == 0) + should_init_scratch_ref = jnp.logical_and(is_first_kv_blk, + first_time_seeing_physical_q_blk) + @pl.when(should_init_scratch_ref) + def init_scratch_ref(): # pylint: disable=unused-variable + pl.debug_print("xw32 should_init_scratch_ref begins: kv_head_idx={}, logical_q_blk_idx={}, kv_blk_idx={}", kv_head_idx, logical_q_blk_idx, kv_blk_idx) + l_scratch_ref[q_head_idx_per_kv] = jnp.zeros( + l_scratch_ref[q_head_idx_per_kv].shape, jnp.float32) + m_scratch_ref[q_head_idx_per_kv] = jnp.full( + m_scratch_ref[q_head_idx_per_kv].shape, -jnp.inf, jnp.float32) + acc_scratch_ref[q_head_idx_per_kv] = jnp.zeros( + acc_scratch_ref[q_head_idx_per_kv].shape, jnp.float32) + + m_prev = m_scratch_ref[q_head_idx_per_kv] # [num_queries_per_compute_block, MIN_BLOCK_SIZE] + l_prev = l_scratch_ref[q_head_idx_per_kv] # [num_queries_per_compute_block, MIN_BLOCK_SIZE] + + # Load the whole q_block that belongs to the current physical q_blk and compute the attention. When we write, we only write the part that belongs to the current sequence. + # I cannot just load only the part of q_block that belongs to the current sequence, because it results in dynamic shapes and then fails the JIT compilation. + # Note, q_ref.shape=[num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim] + q = q_ref[q_head_idx_per_kv, :, :].astype(jnp.float32) # [block_q, head_dim] + assert q.shape == (num_queries_per_compute_block, head_dim) + s = jnp.einsum( + 'qd,td->qt', q, k, + preferred_element_type=jnp.float32) # [block_q, block_k] + assert s.shape == (num_queries_per_compute_block, kv_blk_size) + + # Modify the mask accordingly: first form the mask. Then move the mask down to the right place. + cur_seq_idx = seq_ids[logical_q_blk_idx] + cur_seq_start = effective_cu_q_lens_ref[cur_seq_idx] + cur_seq_end = effective_cu_q_lens_ref[cur_seq_idx+1] + physical_q_blk_idx = physical_q_tile_ids[logical_q_blk_idx] + seq_start_in_cur_physical_q_blk = cur_seq_start >= physical_q_blk_idx*num_queries_per_compute_block + # seq_start_idx_in_cur_physical_q_blk = jnp.where(seq_start_in_cur_physical_q_blk, + # cur_seq_start - physical_q_blk_idx*num_queries_per_compute_block, + # 0) + # q_index = physical_q_blk_idx*num_queries_per_compute_block - seq_start_idx_in_cur_physical_q_blk # start_q_idx_for_cur_seq_in_cur_physical_q_blk. TODO: let's rename num_queries_per_compute_block to q_blk_size later. + q_index = physical_q_blk_idx*num_queries_per_compute_block-cur_seq_start + pl.debug_print("xw32 line423, kv_head_idx={}, logical_q_blk_idx={}, kv_blk_idx={}, q_index={}", kv_head_idx, logical_q_blk_idx, kv_blk_idx, q_index) + kv_index = kv_blk_idx * kv_blk_size + effective_kv_len = effective_kv_lens_ref[cur_seq_idx] + effective_q_len = cur_seq_end - cur_seq_start + row_ids = ( + effective_kv_len - effective_q_len) + q_index + jax.lax.broadcasted_iota( + jnp.int32, + (num_queries_per_compute_block, kv_blk_size), 0) + col_ids = kv_index + jax.lax.broadcasted_iota( + jnp.int32, + (num_queries_per_compute_block, kv_blk_size), 1) + causal_mask = jnp.where(row_ids < col_ids, mask_value, 0.) 
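+ # The mask above is a causal mask expressed in absolute kv positions: row_ids
+ # holds the kv position of each query token of the current sequence (the last
+ # query sits at the last kv position, hence the kv_len - q_len offset plus
+ # q_index to account for where the sequence starts inside this physical q
+ # block), and col_ids holds the kv position of each key in this kv block.
+ # Rows of the q block that belong to a different sequence get out-of-range
+ # row_ids, but their results are discarded later by the store mask, so they
+ # need no special handling here.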
+ assert causal_mask.shape == (num_queries_per_compute_block, + kv_blk_size) + s = s + causal_mask # [block_q, block_k] + assert s.shape == (num_queries_per_compute_block, + kv_blk_size) + + m_curr = jnp.max(s, axis=1)[:, None] # Row max, shape [block_q, 1]. + # why the second dim of m_prev, m_curr, or m_next is 128? + m_next = jnp.maximum(m_prev, m_curr) # Shape [block_q, 128]. + + block_k_repeats, rem = divmod(kv_blk_size, MIN_BLOCK_SIZE) + if rem: + raise NotImplementedError( + f"{kv_blk_size=} should be a multiple of {MIN_BLOCK_SIZE}" + ) + p = jnp.exp( + s - pltpu.repeat(m_next, block_k_repeats, 1)) # Shape [block_q, block_k] + + alpha = jnp.exp(m_prev - m_next) # Shape [block_q, 128] + + l_corr = alpha * l_prev # Shape [block_q, 128] + + l_next = jnp.sum(p, axis=1)[:, None] + l_corr # Shape [block_q, 128] + + head_dim_repeats, rem = divmod(head_dim, MIN_BLOCK_SIZE) + l_broadcast = lambda l: pltpu.repeat(l, head_dim_repeats, 1) + if rem: + if head_dim_repeats == 0: + l_broadcast = lambda l: l[:, :head_dim] + else: + raise NotImplementedError( + f"{head_dim=} should be a multiple of {MIN_BLOCK_SIZE} if larger") + + # Need to store these l_next and m_next which will relay to the output. + # But only update the part that belongs to the current sequence we are working on. + lm_mask = _get_store_mask(grid_id=logical_q_blk_idx, + group_offsets=effective_cu_q_lens_ref, + group_ids=seq_ids, + m_tile_ids=physical_q_tile_ids, + tm=num_queries_per_compute_block, + tn=MIN_BLOCK_SIZE, + ) + # Should I use jax.lax.select or jnp.where? What's the difference? eg: jnp.where(lm_mask, l_next, 0), jnp.where(lm_mask, m_next, 0) + # Can `lm_mask[...]` be `lm_mask`? + l_scratch_ref[q_head_idx_per_kv] = jax.lax.select(lm_mask[...], l_next, l_scratch_ref[q_head_idx_per_kv]) + m_scratch_ref[q_head_idx_per_kv] = jax.lax.select(lm_mask[...], m_next, m_scratch_ref[q_head_idx_per_kv]) + + l_next_inv_safe = jnp.where(l_next == 0.0, 1.0, + 1.0 / l_next) # [block_q, 128] + temp = acc_scratch_ref[q_head_idx_per_kv] * l_broadcast(l_corr * l_next_inv_safe) + acc_mask = _get_store_mask(grid_id=logical_q_blk_idx, + group_offsets=effective_cu_q_lens_ref, + group_ids=seq_ids, + m_tile_ids=physical_q_tile_ids, + tm=num_queries_per_compute_block, + tn=head_dim, + ) + print(f"xw32 line486 {acc_mask.shape=}, {temp.shape=}, {acc_scratch_ref[q_head_idx_per_kv]=}") + acc_scratch_ref[q_head_idx_per_kv] = jax.lax.select(acc_mask[...], temp, acc_scratch_ref[q_head_idx_per_kv]) + # Note Matmul operandlhs must have a shape divisible by (16, 1) + o_curr = jax.lax.dot( + p.astype(v.dtype), v, + preferred_element_type=jnp.float32) # [block_q, 128] + temp = (acc_scratch_ref[q_head_idx_per_kv] + o_curr * l_broadcast(l_next_inv_safe)) + acc_scratch_ref[q_head_idx_per_kv] = jax.lax.select(acc_mask[...], temp, acc_scratch_ref[q_head_idx_per_kv]) + + # Store the result from VMEM to HBM only when it is the last kv_block and the next q-dim logical tile belongs to a different q-dim physical tile. 
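+ # For example, with physical_q_tile_ids = [0, 1, 1, 2, 3, 3] (the mapping
+ # checked in test_make_sequence_metadata), logical q tiles 1 and 2 both map to
+ # physical q tile 1, so the accumulator for that physical tile is written back
+ # only while processing logical tile 2, after both sequences sharing the tile
+ # have contributed.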
+ is_last_kv_blk_idx = (kv_blk_idx == (pl.cdiv(effective_kv_len, kv_blk_size) - 1)) + num_logical_q_blks = pl.num_programs(1) # grid=(num_kv_heads, num_logical_q_tiles, num_kv_blks) + next_logical_q_blk_idx = jnp.where(logical_q_blk_idx == num_logical_q_blks - 1, + logical_q_blk_idx, + logical_q_blk_idx+1) + is_last_logical_q_blk = (logical_q_blk_idx == num_logical_q_blks-1) + physical_q_blk_will_change = (physical_q_tile_ids[logical_q_blk_idx] != physical_q_tile_ids[next_logical_q_blk_idx]) + last_time_seeing_cur_physical_q_blk = jnp.logical_or(is_last_logical_q_blk, physical_q_blk_will_change) + should_store_to_hbm = jnp.logical_and(is_last_kv_blk_idx, last_time_seeing_cur_physical_q_blk) + @pl.when(should_store_to_hbm) + def store_to_hbm(): # pylint: disable=unused-variable + pl.debug_print("xw32 store_to_hbm begins: kv_head_idx={}, logical_q_blk_idx={}, kv_blk_idx={}", kv_head_idx, logical_q_blk_idx, kv_blk_idx) + o_ref[q_head_idx_per_kv] = acc_scratch_ref[q_head_idx_per_kv].astype( + o_ref.dtype) + l_ref[q_head_idx_per_kv] = l_scratch_ref[q_head_idx_per_kv].astype( + l_ref.dtype) + m_ref[q_head_idx_per_kv] = m_scratch_ref[q_head_idx_per_kv].astype( + m_ref.dtype) + + +def paged_flash_attention_kernel( + # prefetch refs, in smem + group_metadata_ref, # (seq_ids, physical_q_tile_ids) + effective_kv_lens_ref, # [num_tokens] + # 1d vector, results from page_indices.reshape(-1) where originally page_indices.shape=[num_tokens, pages_per_sequence] + page_indices_1d_ref, + effective_cu_q_lens_ref, # [num_tokens + 1] + buffer_index_ref, + step_ref, + # kernel inputs + # At caller, q.shape= [num_q_heads, num_tokens, head_dim] + q_ref, # q_ref.shape=[num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim] + k_pages_hbm_ref, # shape=[num_kv_heads, total_num_pages, page_size, head_dim] + k_scales_pages_hbm_ref, + v_pages_hbm_ref, # shape=[num_kv_heads, total_num_pages, page_size, head_dim] + v_scales_pages_hbm_ref, + # same shape as q_ref: [1, num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim], output + # outputs + o_ref, # [num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim] + l_ref, # [num_q_heads_per_kv_head, num_queries_per_compute_block, MIN_BLOCK_SIZE] + m_ref, # [num_q_heads_per_kv_head, num_queries_per_compute_block, MIN_BLOCK_SIZE] + # scratch space + k_vmem_buffer, # (2, num_kv_pages_per_compute_block, num_kv_heads, head_dim) + k_scales_vmem_buffer, + v_vmem_buffer, # (2, num_kv_pages_per_compute_block, num_kv_heads, head_dim) + v_scales_vmem_buffer, + sem, + l_scratch_ref, + m_scratch_ref, + acc_scratch_ref, + *, + # Where do the following parameter live? SMEM? Not in smem. Not to pass in mosaic. Static value. + pages_per_sequence: int, # Note [bs, pages_per_sequence] = page_indices.shape + num_tokens: int, + num_seqs: int, + num_kv_pages_per_compute_block: int, + mask_value: float, +): + # assert the input shapes + print(f"xw32 line283 paged_flash_attention_kernel begins. 
q_ref.shape={q_ref.shape}") + kv_head_idx, logical_q_blk_idx, kv_blk_idx = ( + pl.program_id(0), + pl.program_id(1), + pl.program_id(2), + ) + num_logical_q_blks = pl.num_programs(1) + num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim = q_ref.shape + num_kv_heads, total_num_pages, page_size, head_dim = k_pages_hbm_ref.shape + kv_blk_size = page_size * num_kv_pages_per_compute_block + + seq_ids, physical_q_tile_ids = group_metadata_ref + cur_seq_idx = seq_ids[logical_q_blk_idx] + effective_kv_len_cur_seq = effective_kv_lens_ref[cur_seq_idx] + should_run = (kv_blk_idx * kv_blk_size < effective_kv_len_cur_seq) + pl.debug_print("xw32 paged_flash_attention_kernel begins kv_head_idx={}, logical_q_blk_idx={}, kv_blk_idx={}, cur_seq_idx={}, effective_kv_len_cur_seq={}", kv_head_idx, logical_q_blk_idx, kv_blk_idx, cur_seq_idx, effective_kv_len_cur_seq) # pl.debug_print can only print JAX type. So cannot print tuple such as q.shape. + + @pl.when(should_run) + def get_kv_and_run_flash_attention(): + # grid = (num_kv_heads, num_logical_q_tiles, num_kv_blks) + def compute_block_indices(kv_head_idx, logical_q_blk_idx, kv_blk_idx): + """Return next_kv_head_idx, next_logical_q_blk_idx, next_kv_blk_idx + + Note, k_pages has shape [num_kv_heads, total_num_pages, page_size, head_dim]. + To get the KV, it needs the kv_head_idx, then we need the sequence_idx + and the kv_blk_idx to get the offset. + """ + + def advance_kv_head_idx(): + next_kv_head_idx = kv_head_idx + 1 + return next_kv_head_idx, 0, 0 + + def advance_logical_q_blk_idx(): + next_logical_q_blk_idx = logical_q_blk_idx + 1 + return lax.cond( + next_logical_q_blk_idx < num_logical_q_blks, + lambda: (kv_head_idx, next_logical_q_blk_idx, 0), + advance_kv_head_idx, + ) + + cur_seq_idx = seq_ids[logical_q_blk_idx] + effective_kv_len_cur_seq = effective_kv_lens_ref[cur_seq_idx] + return lax.cond( + kv_blk_idx*kv_blk_size < effective_kv_len_cur_seq, + lambda: (kv_head_idx, logical_q_blk_idx, kv_blk_idx), + advance_logical_q_blk_idx, + ) + + def create_kv_async_copy_descriptors(seq_idx, kv_head_idx, kv_blk_idx, + buffer_index): + page_offset = seq_idx * pages_per_sequence + kv_blk_idx * num_kv_pages_per_compute_block + pages_to_load = num_kv_pages_per_compute_block + async_copy_k = MultiPageAsyncCopyDescriptor( + k_pages_hbm_ref, + k_scales_pages_hbm_ref, + k_vmem_buffer.at[buffer_index], + k_scales_vmem_buffer.at[buffer_index] + if k_scales_vmem_buffer is not None else None, + sem, + page_indices_1d_ref, # [batch_size*pages_per_sequence] + page_offset, + pages_to_load, + kv_head_idx, + ) + async_copy_v = MultiPageAsyncCopyDescriptor( + v_pages_hbm_ref, + v_scales_pages_hbm_ref, + v_vmem_buffer.at[buffer_index], + v_scales_vmem_buffer.at[buffer_index] + if v_scales_vmem_buffer is not None else None, + sem, + page_indices_1d_ref, + page_offset, + pages_to_load, + kv_head_idx, + ) + return async_copy_k, async_copy_v + + step = step_ref[0] + buffer_index = buffer_index_ref[0] + + @pl.when(step == 0) + def prefetch_first_block(): # pylint: disable=unused-variable + pl.debug_print("xw32 prefetch_first_block kv_head_idx={}, cur_seq_idx={}, kv_blk_idx={}, buffer_index={}", kv_head_idx, cur_seq_idx, kv_blk_idx, buffer_index) + async_copy_k, async_copy_v = create_kv_async_copy_descriptors( + cur_seq_idx, kv_head_idx, kv_blk_idx, buffer_index) + async_copy_k.start() + async_copy_v.start() + + # kv_head_idx, logical_q_blk_idx, kv_blk_idx + next_kv_head_idx, next_logical_q_blk_idx, next_kv_blk_idx = compute_block_indices(kv_head_idx, logical_q_blk_idx, 
kv_blk_idx+1) + + @pl.when(next_kv_head_idx < num_kv_heads) + def prefetch_next_block(): # pylint: disable=unused-variable + next_buffer_index = jnp.where(buffer_index == 0, 1, 0) + next_seq_idx = seq_ids[next_logical_q_blk_idx] + pl.debug_print("xw32 prefetch_next_block next_kv_head_idx={}, next_seq_idx={}, next_kv_blk_idx={}, buffer_index={}", next_kv_head_idx, next_seq_idx, next_kv_blk_idx, next_buffer_index) + async_copy_next_k, async_copy_next_v = create_kv_async_copy_descriptors( + next_seq_idx, next_kv_head_idx, next_kv_blk_idx, next_buffer_index) + async_copy_next_k.start() + async_copy_next_v.start() + buffer_index_ref[0] = next_buffer_index + + # xw32: is the async_copy_k and async_copy_v the same as the ones created in prefetch_first_block? + async_copy_k, async_copy_v = create_kv_async_copy_descriptors( + cur_seq_idx, kv_head_idx, kv_blk_idx, buffer_index) + k = async_copy_k.wait_and_get_loaded( + ) # [pages_per_compute_block*page_size,head_dim] + v = async_copy_v.wait_and_get_loaded() + assert k.shape == (num_kv_pages_per_compute_block*page_size, head_dim) + assert v.shape == (num_kv_pages_per_compute_block*page_size, head_dim) + + for q_head_idx in range(num_q_heads_per_kv_head): + _flash_attention( + q_head_idx, + group_metadata_ref, + effective_kv_lens_ref, + effective_cu_q_lens_ref, + # kernel inputs + q_ref, # q_ref.shape=[num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim] + k, + v, + # outputs + o_ref, # [num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim] + l_ref, # [num_q_heads_per_kv_head, num_queries_per_compute_block, MIN_BLOCK_SIZE] + m_ref, # [num_q_heads_per_kv_head, num_queries_per_compute_block, MIN_BLOCK_SIZE] + # scratch space + l_scratch_ref, + m_scratch_ref, + acc_scratch_ref, + num_tokens=num_tokens, + num_seqs=num_seqs, + num_kv_pages_per_compute_block=num_kv_pages_per_compute_block, + num_queries_per_compute_block=num_queries_per_compute_block, + mask_value=mask_value, + page_size=page_size, + head_dim=head_dim, + num_q_heads_per_kv_head=num_q_heads_per_kv_head, + ) + step_ref[0] = step + 1 + # end of get_kv_and_run_flash_attention + + +MIN_BLOCK_SIZE = 128 + +# TODO(xw32): uncomment this once the kernel output is correct. +@functools.partial( + jax.jit, + static_argnames=[ + "num_kv_pages_per_compute_block", + "num_queries_per_compute_block", + "mask_value", + "num_seqs", + ], +) +def ragged_paged_attention( + q: jax.Array, # [num_tokens, num_q_heads, head_dim] + k_pages: jax.Array, # [num_kv_heads, total_num_pages, page_size, head_dim] + v_pages: jax.Array, # [num_kv_heads, total_num_pages, page_size, head_dim] + kv_lens: jax.Array, # i32[num_tokens] + page_indices: jax.Array, # i32[num_tokens, pages_per_sequence] + cu_q_lens: jax.Array, # i32[num_tokens + 1] + num_seqs, # i32[] + *, + mask_value: float = DEFAULT_MASK_VALUE, + num_kv_pages_per_compute_block: int = 16, + num_queries_per_compute_block: int = 128, +) -> jax.Array: + """Paged grouped query attention. + + Args: + q: A [num_tokens, num_q_heads, head_dim] jax.Array. + k_pages: A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array. + v_pages: A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array. + kv_lens: A i32[num_tokens] jax.Array the effective kv length of each sequence. + page_indices: A i32[num_tokens, pages_per_sequence] jax.Array. Each entry + should be in the range of [0, total_num_pages), indicating where to locate + the page in `k_pages` or `v_pages`. 
+ cu_q_lens: A i32[num_tokens+1] jax.Array the cumulative sum of the effective + query lengths. + num_seqs: A i32[] jax.Array the number of sequences. + mask_value: The value used for padding in attention. By default it is a very + negative floating point number. + num_kv_pages_per_compute_block: how many kv pages to be processed in one flash + attention block in the pallas kernel. + num_queries_per_compute_block: how many queries to be processes in one flash + attention block in the pallas kernel. + + Returns: + The output of attention([num_tokens, query_len, num_q_heads, head_dim]). + """ + # TODO: consider remove the k_scales_pages and v_scales_pages during cleaning up. + if isinstance(k_pages, quantization_utils.QuantizedTensor): + k_pages, k_scales_pages = k_pages.weight, k_pages.scales + assert isinstance(k_scales_pages, jax.Array) # For typing. + k_scales_pages = jnp.broadcast_to( + k_scales_pages, (*k_scales_pages.shape[:-1], k_pages.shape[-1])) + else: + k_scales_pages = None + if isinstance(v_pages, quantization_utils.QuantizedTensor): + v_pages, v_scales_pages = v_pages.weight, v_pages.scales + assert isinstance(v_scales_pages, jax.Array) # For typing. + v_scales_pages = jnp.broadcast_to( + v_scales_pages, (*v_scales_pages.shape[:-1], v_pages.shape[-1])) + else: + v_scales_pages = None + + num_tokens, num_q_heads, head_dim = q.shape + # If permute_dims turns out to be expensive, try jnp.swapaxes. The compiler + # may optimize the copies away. + # Or consider unsqueeze a dimension at the 2nd last dimension and squeeze it + # out later. + # jevin: can we not do the permute_dims? + # Why the permute_dims is needed? Before permute, q.shape=[num_tokens, num_q_heads, head_dim]; then when we apply the GridSpec, the 2nd last dimension is num_q_heads which is hard to be a multiple of 8. + q = jnp.permute_dims(q, (1, 0, 2)) # [num_q_heads, num_tokens, head_dim] + num_kv_heads, total_num_pages, page_size, head_dim = k_pages.shape + check_kernel_input(q, k_pages, v_pages,kv_lens, page_indices, cu_q_lens, num_seqs, num_kv_pages_per_compute_block) + num_q_heads_per_kv_head = num_q_heads // num_kv_heads + + group_metadata, num_logical_q_tiles = make_group_metadata( + cu_q_lens=cu_q_lens, + m=num_tokens, + tm=num_queries_per_compute_block, + start_group=jnp.array([0]), + num_seqs=num_seqs, + ) + seq_ids, physical_q_tile_ids = group_metadata + + pages_per_sequence = page_indices.shape[1] + num_kv_blks = pages_per_sequence // num_kv_pages_per_compute_block + # num_logical_q_tiles has type jnp.ndarray. So we need the .item() below. + grid = (num_kv_heads, num_logical_q_tiles, num_kv_blks) + print(f"xw32 line367 grid={grid}") + + # out_shape + o_shape = jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype) + # xw32: need to double check that the out_shape of l and m are correct. + l = jax.ShapeDtypeStruct((num_q_heads, num_tokens, MIN_BLOCK_SIZE), + dtype=jnp.float32) + m = jax.ShapeDtypeStruct((num_q_heads, num_tokens, MIN_BLOCK_SIZE), + dtype=jnp.float32) + out_shape = (o_shape, l, m) + + # in-spec. 
Note currently q.shape=[num_q_heads, num_tokens, head_dim] + # Within the kernel, q.shape should be [num_q_heads_per_kv_head, q_block_size, head_dim] + def qo_index_map(kv_head_idx, logical_q_blk_idx, kv_blk_idx, group_metadata, *_): + seq_ids, physical_q_tile_ids = group_metadata + del seq_ids + physical_q_blk_idx = physical_q_tile_ids[logical_q_blk_idx] + return (kv_head_idx, physical_q_blk_idx, 0) + q_block_spec = pl.BlockSpec( + (num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim), + qo_index_map, + ) + q_dtype_for_kernel_launch = q.dtype + in_specs = [ + q_block_spec, + # Below 4 correspond to the 4 input: k_pages, k_scales_pages, q_pages, q_scales_pages. + # TODO: consider to remove the k_scales_pages and v_scales_pages during cleaning up. + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + None, + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + None, + ] + + # out_spec + # jevin: create a qo spec and reuse it. + o_specs = pl.BlockSpec( # Should be the same as q_block_spec + (num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim), + qo_index_map, + ) + + # lm_index_map is same as qo_index_map + # TODO: think about reusing q_indx_map. + def lm_index_map(kv_head_idx, logical_q_blk_idx, kv_blk_idx, group_metadata, *_): + seq_ids, physical_q_tile_ids = group_metadata + del seq_ids + physical_q_blk_idx = physical_q_tile_ids[logical_q_blk_idx] + return (kv_head_idx, physical_q_blk_idx, 0) + + out_specs = [ + o_specs, + pl.BlockSpec( + (num_q_heads_per_kv_head, num_queries_per_compute_block, MIN_BLOCK_SIZE), + lm_index_map), # l + pl.BlockSpec( + (num_q_heads_per_kv_head, num_queries_per_compute_block, MIN_BLOCK_SIZE), + lm_index_map), # m + ] + + # scratch space. Note k_pages.shape=[num_kv_heads, total_num_pages, page_size, head_dim] + l_scratch = pltpu.VMEM( + (num_q_heads_per_kv_head, num_queries_per_compute_block, MIN_BLOCK_SIZE), + jnp.float32) + m_scratch = pltpu.VMEM( + (num_q_heads_per_kv_head, num_queries_per_compute_block, MIN_BLOCK_SIZE), + jnp.float32) + acc_scratch = pltpu.VMEM( + (num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim), + jnp.float32) + scratch_shapes = [ + pltpu.VMEM( + ( + 2, # For double buffering during DMA copies. + num_kv_pages_per_compute_block, + page_size, + head_dim, + ), + k_pages.dtype, + ), # k_pages buffer, k_pages.shape=[num_kv_heads, total_num_pages, page_size, head_dim] + None, # k_scales_pages=None + pltpu.VMEM( + ( + 2, # For double buffering during DMA copies. + num_kv_pages_per_compute_block, + page_size, + head_dim, + ), + v_pages.dtype, + ), # v_pages buffer + None, # v_scales_pages=None + pltpu.SemaphoreType.DMA, + l_scratch, + m_scratch, + acc_scratch, + ] + + kernel = pl.pallas_call( + functools.partial( + paged_flash_attention_kernel, + pages_per_sequence=pages_per_sequence, + num_tokens=num_tokens, + num_seqs=num_seqs, # it they changes, need to recompile. + num_kv_pages_per_compute_block=num_kv_pages_per_compute_block, + mask_value=mask_value, + ), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=6, # TODO(xw32): may need to adjust. + in_specs=in_specs, + out_specs=out_specs, + grid=grid, + scratch_shapes=scratch_shapes, + ), + compiler_params=pltpu.TPUCompilerParams( + # due to compute_block_indices, we loop batch, kv_head, q_blk, kv_blk, the order matters. + dimension_semantics=( + "arbitrary", + "arbitrary", + "arbitrary", + )), + out_shape=out_shape, + ) + # TODO: need to slice the page_indices later to avoid the SMEM OOM. 
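+ # The kernel addresses pages through a flat offset,
+ # seq_idx * pages_per_sequence + kv_blk_idx * num_kv_pages_per_compute_block,
+ # computed in create_kv_async_copy_descriptors, so the 2D
+ # [num_tokens, pages_per_sequence] table is flattened below before being
+ # passed as a scalar-prefetch operand.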
+ page_indices_1d = page_indices.reshape(-1) + buffer_index = jnp.zeros((1,), jnp.int32) + step = jnp.zeros((1,), jnp.int32) + + # debug compile begins + # To enable debug, uncomment this section, comment out the `kernel()` below and comment out the jax.jit above. + # compiled_kernel = ( + # jax.jit(kernel) + # .lower( + # # prefetch + # group_metadata, + # kv_lens, + # page_indices_1d, + # cu_q_lens, + # buffer_index, + # step, + # # kernel inputs + # q.astype(q_dtype_for_kernel_launch), # TODO: do we need the `.astype`? Need to double check. + # k_pages, + # k_scales_pages, + # v_pages, + # v_scales_pages, + # ) + # .compile({'xla_tpu_enable_log_recorder': 'true'}) + # ) + # outputs = compiled_kernel( + # # prefetch + # group_metadata, + # kv_lens, + # page_indices_1d, + # cu_q_lens, + # buffer_index, + # step, + # # kernel inputs + # q.astype(q_dtype_for_kernel_launch), # TODO: do we need the `.astype`? Need to double check. + # k_pages, + # k_scales_pages, + # v_pages, + # v_scales_pages, + # ) + # debug compile ends + + outputs = kernel( + # prefetch + group_metadata, + kv_lens, + page_indices_1d, + cu_q_lens, + buffer_index, + step, + # kernel inputs + q.astype(q_dtype_for_kernel_launch), # TODO: do we need the `.astype`? Need to double check. + k_pages, + k_scales_pages, + v_pages, + v_scales_pages, + ) + ret = outputs[0] + # print(f"xw32 line495 ret.shape={ret.shape}, {ret=}") + return jnp.permute_dims(ret, (1, 0, 2)).astype(q.dtype) From 6c3bf7304ab5d6db68f4ecc3679109e0856e6ebf Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Thu, 30 Jan 2025 03:42:45 +0000 Subject: [PATCH 2/9] all tests passed. --- .../ragged_paged_attention_kernel.py | 248 ++++++++++-------- 1 file changed, 134 insertions(+), 114 deletions(-) diff --git a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py index 4b1bd5100997..afbf2be6aa09 100644 --- a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py +++ b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py @@ -347,30 +347,30 @@ def _flash_attention( effective_kv_lens_ref, # [num_tokens] effective_cu_q_lens_ref, # [num_tokens + 1] # kernel inputs - q_ref, # q_ref.shape=[num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim] + q_ref, # q_ref.shape=[num_q_heads_per_kv_head, num_queries_per_block, head_dim] k, # [kv_blk_size, head_dim] v, # [kv_blk_size, head_dim] # outputs - o_ref, # [num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim] - l_ref, # [num_q_heads_per_kv_head, num_queries_per_compute_block, MIN_BLOCK_SIZE] - m_ref, # [num_q_heads_per_kv_head, num_queries_per_compute_block, MIN_BLOCK_SIZE] + o_ref, # [num_q_heads_per_kv_head, num_queries_per_block, head_dim] + l_ref, # [num_q_heads_per_kv_head, num_queries_per_block, MIN_BLOCK_SIZE] + m_ref, # [num_q_heads_per_kv_head, num_queries_per_block, MIN_BLOCK_SIZE] # scratch space # TODO: double check if the scratch ref shape is correct. 
- l_scratch_ref, # [num_q_heads_per_kv_head, num_queries_per_compute_block, MIN_BLOCK_SIZE] - m_scratch_ref, # [num_q_heads_per_kv_head, num_queries_per_compute_block, MIN_BLOCK_SIZE] - acc_scratch_ref, # [num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim] + l_scratch_ref, # [num_q_heads_per_kv_head, num_queries_per_block, MIN_BLOCK_SIZE] + m_scratch_ref, # [num_q_heads_per_kv_head, num_queries_per_block, MIN_BLOCK_SIZE] + acc_scratch_ref, # [num_q_heads_per_kv_head, num_queries_per_block, head_dim] *, num_tokens: int, num_seqs: int, - num_kv_pages_per_compute_block: int, - num_queries_per_compute_block: int, + num_kv_pages_per_block: int, + num_queries_per_block: int, mask_value: float, page_size: int, head_dim: int, num_q_heads_per_kv_head: int, ): - assert q_ref.shape == (num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim) - kv_blk_size = page_size * num_kv_pages_per_compute_block + assert q_ref.shape == (num_q_heads_per_kv_head, num_queries_per_block, head_dim) + kv_blk_size = page_size * num_kv_pages_per_block assert k.shape == (kv_blk_size, head_dim) assert v.shape == (kv_blk_size, head_dim) @@ -399,30 +399,30 @@ def init_scratch_ref(): # pylint: disable=unused-variable acc_scratch_ref[q_head_idx_per_kv] = jnp.zeros( acc_scratch_ref[q_head_idx_per_kv].shape, jnp.float32) - m_prev = m_scratch_ref[q_head_idx_per_kv] # [num_queries_per_compute_block, MIN_BLOCK_SIZE] - l_prev = l_scratch_ref[q_head_idx_per_kv] # [num_queries_per_compute_block, MIN_BLOCK_SIZE] + m_prev = m_scratch_ref[q_head_idx_per_kv] # [num_queries_per_block, MIN_BLOCK_SIZE] + l_prev = l_scratch_ref[q_head_idx_per_kv] # [num_queries_per_block, MIN_BLOCK_SIZE] # Load the whole q_block that belongs to the current physical q_blk and compute the attention. When we write, we only write the part that belongs to the current sequence. # I cannot just load only the part of q_block that belongs to the current sequence, because it results in dynamic shapes and then fails the JIT compilation. - # Note, q_ref.shape=[num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim] + # Note, q_ref.shape=[num_q_heads_per_kv_head, num_queries_per_block, head_dim] q = q_ref[q_head_idx_per_kv, :, :].astype(jnp.float32) # [block_q, head_dim] - assert q.shape == (num_queries_per_compute_block, head_dim) + assert q.shape == (num_queries_per_block, head_dim) s = jnp.einsum( 'qd,td->qt', q, k, preferred_element_type=jnp.float32) # [block_q, block_k] - assert s.shape == (num_queries_per_compute_block, kv_blk_size) + assert s.shape == (num_queries_per_block, kv_blk_size) # Modify the mask accordingly: first form the mask. Then move the mask down to the right place. cur_seq_idx = seq_ids[logical_q_blk_idx] cur_seq_start = effective_cu_q_lens_ref[cur_seq_idx] cur_seq_end = effective_cu_q_lens_ref[cur_seq_idx+1] physical_q_blk_idx = physical_q_tile_ids[logical_q_blk_idx] - seq_start_in_cur_physical_q_blk = cur_seq_start >= physical_q_blk_idx*num_queries_per_compute_block + seq_start_in_cur_physical_q_blk = cur_seq_start >= physical_q_blk_idx*num_queries_per_block # seq_start_idx_in_cur_physical_q_blk = jnp.where(seq_start_in_cur_physical_q_blk, - # cur_seq_start - physical_q_blk_idx*num_queries_per_compute_block, + # cur_seq_start - physical_q_blk_idx*num_queries_per_block, # 0) - # q_index = physical_q_blk_idx*num_queries_per_compute_block - seq_start_idx_in_cur_physical_q_blk # start_q_idx_for_cur_seq_in_cur_physical_q_blk. TODO: let's rename num_queries_per_compute_block to q_blk_size later. 
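# A standalone sketch, mirroring the reference implementation in the tests, of
# how the causal mask is built with broadcasted_iota: row ids are offset by
# (kv_len - q_len) + q_index so that query token i may only attend to kv
# positions at or before its absolute position. The block sizes and lengths
# below are illustrative, not the kernel's.
import jax
import jax.numpy as jnp
q_blk, kv_blk = 4, 8            # toy block sizes
q_len, kv_len = 4, 8            # effective lengths of the current sequence
q_index, kv_index = 0, 0        # offsets of this block within the sequence
row_ids = (kv_len - q_len) + q_index + jax.lax.broadcasted_iota(
    jnp.int32, (q_blk, kv_blk), 0)
col_ids = kv_index + jax.lax.broadcasted_iota(jnp.int32, (q_blk, kv_blk), 1)
mask = jnp.where(row_ids < col_ids, -0.7 * float(jnp.finfo(jnp.float32).max), 0.)
# mask[i, j] is very negative exactly when kv position j lies in the future of query i.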
- q_index = physical_q_blk_idx*num_queries_per_compute_block-cur_seq_start + # q_index = physical_q_blk_idx*num_queries_per_block - seq_start_idx_in_cur_physical_q_blk # start_q_idx_for_cur_seq_in_cur_physical_q_blk. TODO: let's rename num_queries_per_block to q_blk_size later. + q_index = physical_q_blk_idx*num_queries_per_block-cur_seq_start pl.debug_print("xw32 line423, kv_head_idx={}, logical_q_blk_idx={}, kv_blk_idx={}, q_index={}", kv_head_idx, logical_q_blk_idx, kv_blk_idx, q_index) kv_index = kv_blk_idx * kv_blk_size effective_kv_len = effective_kv_lens_ref[cur_seq_idx] @@ -430,15 +430,20 @@ def init_scratch_ref(): # pylint: disable=unused-variable row_ids = ( effective_kv_len - effective_q_len) + q_index + jax.lax.broadcasted_iota( jnp.int32, - (num_queries_per_compute_block, kv_blk_size), 0) + (num_queries_per_block, kv_blk_size), 0) col_ids = kv_index + jax.lax.broadcasted_iota( jnp.int32, - (num_queries_per_compute_block, kv_blk_size), 1) - causal_mask = jnp.where(row_ids < col_ids, mask_value, 0.) - assert causal_mask.shape == (num_queries_per_compute_block, + (num_queries_per_block, kv_blk_size), 1) + causal_mask = jnp.where(row_ids < col_ids, mask_value, 0.) # TODO: use this mask. + # causal_mask_debug = jnp.where(row_ids < col_ids, -1, 0) # TODO: remove this line. + should_print_mask = jnp.logical_and(kv_head_idx==0, logical_q_blk_idx==2) + # @pl.when(should_print_mask) + # def print_mask(): # pylint: disable=unused-variable + # pl.debug_print("xw32 line438, causal_mask={}", causal_mask) + assert causal_mask.shape == (num_queries_per_block, kv_blk_size) s = s + causal_mask # [block_q, block_k] - assert s.shape == (num_queries_per_compute_block, + assert s.shape == (num_queries_per_block, kv_blk_size) m_curr = jnp.max(s, axis=1)[:, None] # Row max, shape [block_q, 1]. @@ -474,7 +479,7 @@ def init_scratch_ref(): # pylint: disable=unused-variable group_offsets=effective_cu_q_lens_ref, group_ids=seq_ids, m_tile_ids=physical_q_tile_ids, - tm=num_queries_per_compute_block, + tm=num_queries_per_block, tn=MIN_BLOCK_SIZE, ) # Should I use jax.lax.select or jnp.where? What's the difference? 
eg: jnp.where(lm_mask, l_next, 0), jnp.where(lm_mask, m_next, 0) @@ -482,6 +487,10 @@ def init_scratch_ref(): # pylint: disable=unused-variable l_scratch_ref[q_head_idx_per_kv] = jax.lax.select(lm_mask[...], l_next, l_scratch_ref[q_head_idx_per_kv]) m_scratch_ref[q_head_idx_per_kv] = jax.lax.select(lm_mask[...], m_next, m_scratch_ref[q_head_idx_per_kv]) + # @pl.when(should_print_mask) + # def _(): # pylint: disable=unused-variable + # print("xw32 line492, l_next.shape={}, ", l_next.shape) + # pl.debug_print("xw32 line492, l_next[6]={}", l_next[6]) l_next_inv_safe = jnp.where(l_next == 0.0, 1.0, 1.0 / l_next) # [block_q, 128] temp = acc_scratch_ref[q_head_idx_per_kv] * l_broadcast(l_corr * l_next_inv_safe) @@ -489,7 +498,7 @@ def init_scratch_ref(): # pylint: disable=unused-variable group_offsets=effective_cu_q_lens_ref, group_ids=seq_ids, m_tile_ids=physical_q_tile_ids, - tm=num_queries_per_compute_block, + tm=num_queries_per_block, tn=head_dim, ) print(f"xw32 line486 {acc_mask.shape=}, {temp.shape=}, {acc_scratch_ref[q_head_idx_per_kv]=}") @@ -499,6 +508,10 @@ def init_scratch_ref(): # pylint: disable=unused-variable p.astype(v.dtype), v, preferred_element_type=jnp.float32) # [block_q, 128] temp = (acc_scratch_ref[q_head_idx_per_kv] + o_curr * l_broadcast(l_next_inv_safe)) + # @pl.when(should_print_mask) + # def _(): # pylint: disable=unused-variable + # print("xw32 line512, temp.shape={}", temp.shape) + # pl.debug_print("xw32 line512, temp={}", temp) acc_scratch_ref[q_head_idx_per_kv] = jax.lax.select(acc_mask[...], temp, acc_scratch_ref[q_head_idx_per_kv]) # Store the result from VMEM to HBM only when it is the last kv_block and the next q-dim logical tile belongs to a different q-dim physical tile. @@ -533,20 +546,20 @@ def paged_flash_attention_kernel( step_ref, # kernel inputs # At caller, q.shape= [num_q_heads, num_tokens, head_dim] - q_ref, # q_ref.shape=[num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim] + q_ref, # q_ref.shape=[num_q_heads_per_kv_head, num_queries_per_block, head_dim] k_pages_hbm_ref, # shape=[num_kv_heads, total_num_pages, page_size, head_dim] k_scales_pages_hbm_ref, v_pages_hbm_ref, # shape=[num_kv_heads, total_num_pages, page_size, head_dim] v_scales_pages_hbm_ref, - # same shape as q_ref: [1, num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim], output + # same shape as q_ref: [1, num_q_heads_per_kv_head, num_queries_per_block, head_dim], output # outputs - o_ref, # [num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim] - l_ref, # [num_q_heads_per_kv_head, num_queries_per_compute_block, MIN_BLOCK_SIZE] - m_ref, # [num_q_heads_per_kv_head, num_queries_per_compute_block, MIN_BLOCK_SIZE] + o_ref, # [num_q_heads_per_kv_head, num_queries_per_block, head_dim] + l_ref, # [num_q_heads_per_kv_head, num_queries_per_block, MIN_BLOCK_SIZE] + m_ref, # [num_q_heads_per_kv_head, num_queries_per_block, MIN_BLOCK_SIZE] # scratch space - k_vmem_buffer, # (2, num_kv_pages_per_compute_block, num_kv_heads, head_dim) + k_vmem_buffer, # (2, num_kv_pages_per_block, num_kv_heads, head_dim) k_scales_vmem_buffer, - v_vmem_buffer, # (2, num_kv_pages_per_compute_block, num_kv_heads, head_dim) + v_vmem_buffer, # (2, num_kv_pages_per_block, num_kv_heads, head_dim) v_scales_vmem_buffer, sem, l_scratch_ref, @@ -557,7 +570,7 @@ def paged_flash_attention_kernel( pages_per_sequence: int, # Note [bs, pages_per_sequence] = page_indices.shape num_tokens: int, num_seqs: int, - num_kv_pages_per_compute_block: int, + num_kv_pages_per_block: int, mask_value: 
float, ): # assert the input shapes @@ -568,9 +581,9 @@ def paged_flash_attention_kernel( pl.program_id(2), ) num_logical_q_blks = pl.num_programs(1) - num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim = q_ref.shape + num_q_heads_per_kv_head, num_queries_per_block, head_dim = q_ref.shape num_kv_heads, total_num_pages, page_size, head_dim = k_pages_hbm_ref.shape - kv_blk_size = page_size * num_kv_pages_per_compute_block + kv_blk_size = page_size * num_kv_pages_per_block seq_ids, physical_q_tile_ids = group_metadata_ref cur_seq_idx = seq_ids[logical_q_blk_idx] @@ -611,8 +624,8 @@ def advance_logical_q_blk_idx(): def create_kv_async_copy_descriptors(seq_idx, kv_head_idx, kv_blk_idx, buffer_index): - page_offset = seq_idx * pages_per_sequence + kv_blk_idx * num_kv_pages_per_compute_block - pages_to_load = num_kv_pages_per_compute_block + page_offset = seq_idx * pages_per_sequence + kv_blk_idx * num_kv_pages_per_block + pages_to_load = num_kv_pages_per_block async_copy_k = MultiPageAsyncCopyDescriptor( k_pages_hbm_ref, k_scales_pages_hbm_ref, @@ -670,8 +683,8 @@ def prefetch_next_block(): # pylint: disable=unused-variable k = async_copy_k.wait_and_get_loaded( ) # [pages_per_compute_block*page_size,head_dim] v = async_copy_v.wait_and_get_loaded() - assert k.shape == (num_kv_pages_per_compute_block*page_size, head_dim) - assert v.shape == (num_kv_pages_per_compute_block*page_size, head_dim) + assert k.shape == (num_kv_pages_per_block*page_size, head_dim) + assert v.shape == (num_kv_pages_per_block*page_size, head_dim) for q_head_idx in range(num_q_heads_per_kv_head): _flash_attention( @@ -680,21 +693,21 @@ def prefetch_next_block(): # pylint: disable=unused-variable effective_kv_lens_ref, effective_cu_q_lens_ref, # kernel inputs - q_ref, # q_ref.shape=[num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim] + q_ref, # q_ref.shape=[num_q_heads_per_kv_head, num_queries_per_block, head_dim] k, v, # outputs - o_ref, # [num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim] - l_ref, # [num_q_heads_per_kv_head, num_queries_per_compute_block, MIN_BLOCK_SIZE] - m_ref, # [num_q_heads_per_kv_head, num_queries_per_compute_block, MIN_BLOCK_SIZE] + o_ref, # [num_q_heads_per_kv_head, num_queries_per_block, head_dim] + l_ref, # [num_q_heads_per_kv_head, num_queries_per_block, MIN_BLOCK_SIZE] + m_ref, # [num_q_heads_per_kv_head, num_queries_per_block, MIN_BLOCK_SIZE] # scratch space l_scratch_ref, m_scratch_ref, acc_scratch_ref, num_tokens=num_tokens, num_seqs=num_seqs, - num_kv_pages_per_compute_block=num_kv_pages_per_compute_block, - num_queries_per_compute_block=num_queries_per_compute_block, + num_kv_pages_per_block=num_kv_pages_per_block, + num_queries_per_block=num_queries_per_block, mask_value=mask_value, page_size=page_size, head_dim=head_dim, @@ -707,15 +720,15 @@ def prefetch_next_block(): # pylint: disable=unused-variable MIN_BLOCK_SIZE = 128 # TODO(xw32): uncomment this once the kernel output is correct. 
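The kernel above hides the latency of the kv page DMAs by flipping between the two halves of the VMEM buffers: while one kv block is being consumed, the next one is already in flight. The sketch below shows that double-buffering pattern in plain Python, outside Pallas; the callback names are illustrative and not part of the kernel's API.

def double_buffered_loop(num_blocks, start_copy, wait_and_get, consume):
  # Slots 0 and 1 alternate: block i is consumed from one slot while block i+1
  # is copied into the other, matching the buffer_index bookkeeping above.
  buffer_index = 0
  start_copy(block_idx=0, slot=buffer_index)
  for i in range(num_blocks):
    if i + 1 < num_blocks:
      start_copy(block_idx=i + 1, slot=1 - buffer_index)
    consume(wait_and_get(block_idx=i, slot=buffer_index))
    buffer_index = 1 - buffer_index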
-@functools.partial( - jax.jit, - static_argnames=[ - "num_kv_pages_per_compute_block", - "num_queries_per_compute_block", - "mask_value", - "num_seqs", - ], -) +# @functools.partial( +# jax.jit, +# static_argnames=[ +# "num_kv_pages_per_block", +# "num_queries_per_block", +# "mask_value", +# "num_seqs", +# ], +# ) def ragged_paged_attention( q: jax.Array, # [num_tokens, num_q_heads, head_dim] k_pages: jax.Array, # [num_kv_heads, total_num_pages, page_size, head_dim] @@ -726,27 +739,33 @@ def ragged_paged_attention( num_seqs, # i32[] *, mask_value: float = DEFAULT_MASK_VALUE, - num_kv_pages_per_compute_block: int = 16, - num_queries_per_compute_block: int = 128, + num_kv_pages_per_block: int = 16, + num_queries_per_block: int = 128, ) -> jax.Array: - """Paged grouped query attention. + """Paged attention kernel with ragged input. Args: q: A [num_tokens, num_q_heads, head_dim] jax.Array. k_pages: A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array. v_pages: A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array. - kv_lens: A i32[num_tokens] jax.Array the effective kv length of each sequence. + kv_lens: A i32[num_tokens] jax.Array the effective kv length of each + sequence. For example, if we have three sequences, lengths could be + [16, 3, 1024, x, x, x, x, ...] where x is any value for padding. While + lengths’s shape is [num_tokens], only the first num_seqs values are valid. + The rest should be ignored. page_indices: A i32[num_tokens, pages_per_sequence] jax.Array. Each entry should be in the range of [0, total_num_pages), indicating where to locate - the page in `k_pages` or `v_pages`. + the page in `k_pages` or `v_pages`. Similar to kv_lens, only the first + num_seqs values are valid. cu_q_lens: A i32[num_tokens+1] jax.Array the cumulative sum of the effective - query lengths. + query lengths. Similar to kv_lens, only the first num_seqs+1 values are + valid. num_seqs: A i32[] jax.Array the number of sequences. mask_value: The value used for padding in attention. By default it is a very negative floating point number. - num_kv_pages_per_compute_block: how many kv pages to be processed in one flash + num_kv_pages_per_block: how many kv pages to be processed in one flash attention block in the pallas kernel. - num_queries_per_compute_block: how many queries to be processes in one flash + num_queries_per_block: how many queries to be processes in one flash attention block in the pallas kernel. Returns: @@ -777,20 +796,21 @@ def ragged_paged_attention( # Why the permute_dims is needed? Before permute, q.shape=[num_tokens, num_q_heads, head_dim]; then when we apply the GridSpec, the 2nd last dimension is num_q_heads which is hard to be a multiple of 8. 
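As a point of reference for the shapes handled here, a hedged sketch of how the tests drive this entry point. All concrete numbers are illustrative, it assumes the kernel module is importable, and in the later checkify-wrapped revision the call returns an (err, out) pair whose err.throw() surfaces the input validation failures.

import jax
import jax.numpy as jnp

key = jax.random.key(0)
k1, k2, k3 = jax.random.split(key, 3)
q_in = jax.random.normal(k1, (256, 4, 128), dtype=jnp.float32)             # 256 tokens, 4 q heads
k_pages_in = jax.random.normal(k2, (4, 2048, 16, 128), dtype=jnp.float32)  # 4 kv heads, 2048 pages of size 16
v_pages_in = jax.random.normal(k3, (4, 2048, 16, 128), dtype=jnp.float32)
# Two sequences with q lengths (192, 64) and kv lengths (328, 255); the rest is padding.
kv_lens_in = jnp.zeros((256,), jnp.int32).at[:2].set(jnp.array([328, 255]))
cu_q_lens_in = jnp.zeros((257,), jnp.int32).at[1:3].set(jnp.array([192, 256]))
page_indices_in = jnp.zeros((256, 64), jnp.int32)  # 64 pages per row, a multiple of the kv block
out = ragged_paged_attention(q_in, k_pages_in, v_pages_in, kv_lens_in,
                             page_indices_in, cu_q_lens_in, num_seqs=2)    # [256, 4, 128]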
q = jnp.permute_dims(q, (1, 0, 2)) # [num_q_heads, num_tokens, head_dim] num_kv_heads, total_num_pages, page_size, head_dim = k_pages.shape - check_kernel_input(q, k_pages, v_pages,kv_lens, page_indices, cu_q_lens, num_seqs, num_kv_pages_per_compute_block) + check_kernel_input(q, k_pages, v_pages,kv_lens, page_indices, cu_q_lens, num_seqs, num_kv_pages_per_block) num_q_heads_per_kv_head = num_q_heads // num_kv_heads group_metadata, num_logical_q_tiles = make_group_metadata( cu_q_lens=cu_q_lens, m=num_tokens, - tm=num_queries_per_compute_block, + tm=num_queries_per_block, start_group=jnp.array([0]), num_seqs=num_seqs, ) seq_ids, physical_q_tile_ids = group_metadata + pl.debug_print("xw32 line797 seq_ids={}, physical_q_tile_ids={}, num_logical_q_tiles={}", seq_ids, physical_q_tile_ids, num_logical_q_tiles) pages_per_sequence = page_indices.shape[1] - num_kv_blks = pages_per_sequence // num_kv_pages_per_compute_block + num_kv_blks = pages_per_sequence // num_kv_pages_per_block # num_logical_q_tiles has type jnp.ndarray. So we need the .item() below. grid = (num_kv_heads, num_logical_q_tiles, num_kv_blks) print(f"xw32 line367 grid={grid}") @@ -812,7 +832,7 @@ def qo_index_map(kv_head_idx, logical_q_blk_idx, kv_blk_idx, group_metadata, *_) physical_q_blk_idx = physical_q_tile_ids[logical_q_blk_idx] return (kv_head_idx, physical_q_blk_idx, 0) q_block_spec = pl.BlockSpec( - (num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim), + (num_q_heads_per_kv_head, num_queries_per_block, head_dim), qo_index_map, ) q_dtype_for_kernel_launch = q.dtype @@ -829,7 +849,7 @@ def qo_index_map(kv_head_idx, logical_q_blk_idx, kv_blk_idx, group_metadata, *_) # out_spec # jevin: create a qo spec and reuse it. o_specs = pl.BlockSpec( # Should be the same as q_block_spec - (num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim), + (num_q_heads_per_kv_head, num_queries_per_block, head_dim), qo_index_map, ) @@ -844,28 +864,28 @@ def lm_index_map(kv_head_idx, logical_q_blk_idx, kv_blk_idx, group_metadata, *_) out_specs = [ o_specs, pl.BlockSpec( - (num_q_heads_per_kv_head, num_queries_per_compute_block, MIN_BLOCK_SIZE), + (num_q_heads_per_kv_head, num_queries_per_block, MIN_BLOCK_SIZE), lm_index_map), # l pl.BlockSpec( - (num_q_heads_per_kv_head, num_queries_per_compute_block, MIN_BLOCK_SIZE), + (num_q_heads_per_kv_head, num_queries_per_block, MIN_BLOCK_SIZE), lm_index_map), # m ] # scratch space. Note k_pages.shape=[num_kv_heads, total_num_pages, page_size, head_dim] l_scratch = pltpu.VMEM( - (num_q_heads_per_kv_head, num_queries_per_compute_block, MIN_BLOCK_SIZE), + (num_q_heads_per_kv_head, num_queries_per_block, MIN_BLOCK_SIZE), jnp.float32) m_scratch = pltpu.VMEM( - (num_q_heads_per_kv_head, num_queries_per_compute_block, MIN_BLOCK_SIZE), + (num_q_heads_per_kv_head, num_queries_per_block, MIN_BLOCK_SIZE), jnp.float32) acc_scratch = pltpu.VMEM( - (num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim), + (num_q_heads_per_kv_head, num_queries_per_block, head_dim), jnp.float32) scratch_shapes = [ pltpu.VMEM( ( 2, # For double buffering during DMA copies. - num_kv_pages_per_compute_block, + num_kv_pages_per_block, page_size, head_dim, ), @@ -875,7 +895,7 @@ def lm_index_map(kv_head_idx, logical_q_blk_idx, kv_blk_idx, group_metadata, *_) pltpu.VMEM( ( 2, # For double buffering during DMA copies. 
- num_kv_pages_per_compute_block, + num_kv_pages_per_block, page_size, head_dim, ), @@ -894,7 +914,7 @@ def lm_index_map(kv_head_idx, logical_q_blk_idx, kv_blk_idx, group_metadata, *_) pages_per_sequence=pages_per_sequence, num_tokens=num_tokens, num_seqs=num_seqs, # it they changes, need to recompile. - num_kv_pages_per_compute_block=num_kv_pages_per_compute_block, + num_kv_pages_per_block=num_kv_pages_per_block, mask_value=mask_value, ), grid_spec=pltpu.PrefetchScalarGridSpec( @@ -920,43 +940,26 @@ def lm_index_map(kv_head_idx, logical_q_blk_idx, kv_blk_idx, group_metadata, *_) # debug compile begins # To enable debug, uncomment this section, comment out the `kernel()` below and comment out the jax.jit above. - # compiled_kernel = ( - # jax.jit(kernel) - # .lower( - # # prefetch - # group_metadata, - # kv_lens, - # page_indices_1d, - # cu_q_lens, - # buffer_index, - # step, - # # kernel inputs - # q.astype(q_dtype_for_kernel_launch), # TODO: do we need the `.astype`? Need to double check. - # k_pages, - # k_scales_pages, - # v_pages, - # v_scales_pages, - # ) - # .compile({'xla_tpu_enable_log_recorder': 'true'}) - # ) - # outputs = compiled_kernel( - # # prefetch - # group_metadata, - # kv_lens, - # page_indices_1d, - # cu_q_lens, - # buffer_index, - # step, - # # kernel inputs - # q.astype(q_dtype_for_kernel_launch), # TODO: do we need the `.astype`? Need to double check. - # k_pages, - # k_scales_pages, - # v_pages, - # v_scales_pages, - # ) - # debug compile ends - - outputs = kernel( + compiled_kernel = ( + jax.jit(kernel) + .lower( + # prefetch + group_metadata, + kv_lens, + page_indices_1d, + cu_q_lens, + buffer_index, + step, + # kernel inputs + q.astype(q_dtype_for_kernel_launch), # TODO: do we need the `.astype`? Need to double check. + k_pages, + k_scales_pages, + v_pages, + v_scales_pages, + ) + .compile({'xla_tpu_enable_log_recorder': 'true'}) + ) + outputs = compiled_kernel( # prefetch group_metadata, kv_lens, @@ -971,6 +974,23 @@ def lm_index_map(kv_head_idx, logical_q_blk_idx, kv_blk_idx, group_metadata, *_) v_pages, v_scales_pages, ) + # debug compile ends + + # outputs = kernel( + # # prefetch + # group_metadata, + # kv_lens, + # page_indices_1d, + # cu_q_lens, + # buffer_index, + # step, + # # kernel inputs + # q.astype(q_dtype_for_kernel_launch), # TODO: do we need the `.astype`? Need to double check. + # k_pages, + # k_scales_pages, + # v_pages, + # v_scales_pages, + # ) ret = outputs[0] # print(f"xw32 line495 ret.shape={ret.shape}, {ret=}") return jnp.permute_dims(ret, (1, 0, 2)).astype(q.dtype) From dbc3fee2825da53584feebd466018f999585b0c7 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Thu, 30 Jan 2025 22:33:01 +0000 Subject: [PATCH 3/9] all tests passed except for one test which oom'ed. --- test/test_ragged_paged_attention_kernel.py | 321 ++++++++++++++++-- .../ragged_paged_attention_kernel.py | 99 +++--- 2 files changed, 355 insertions(+), 65 deletions(-) diff --git a/test/test_ragged_paged_attention_kernel.py b/test/test_ragged_paged_attention_kernel.py index 73ce5ed1b395..720924ce5d05 100644 --- a/test/test_ragged_paged_attention_kernel.py +++ b/test/test_ragged_paged_attention_kernel.py @@ -12,6 +12,15 @@ jax.config.parse_flags_with_absl() +# Make sure the q_len is no longer than the kv_len. For example, +# seq_lens = [(1, 1328), (5, 18), (506, 463)] is not a valid test case because +# the 3rd sequence has q_len(506) > kv_len(463). + +# Just to use the same very negative value in the ref impl as in the kernel. 
+DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max) + +ATOL_FP32 = 2e-1 + # https://github.com/vllm-project/flash-attention/blob/98a4f8df6f5f50413e03f102dc319690300d4aaf/tests/test_vllm_flash_attn.py#L22 def _ref_ragged_paged_attention( @@ -36,9 +45,9 @@ def _ref_ragged_paged_attention( cur_kv_len = kv_lens[i] num_pages = (cur_kv_len + page_size - 1) // page_size page_indices_to_use = page_indices[i, :num_pages] - k = k_pages[:, page_indices_to_use, :, :] - k = jnp.permute_dims(k, (1, 2, 0, 3)) - k = jnp.reshape(k, (-1, num_kv_heads, head_dim)) + k = k_pages[:, page_indices_to_use, :, :] # [num_kv_heads, page_indices_to_use, page_size, head_dim] + k = jnp.permute_dims(k, (1, 2, 0, 3)) # [page_indices_to_use, page_size, num_kv_heads, head_dim] + k = jnp.reshape(k, (-1, num_kv_heads, head_dim)) # [kv_len, num_kv_heads, head_dim] k = k[:cur_kv_len] # [cur_kv_lens, num_kv_heads, head_dim] v = v_pages[:, page_indices_to_use, :, :] v = jnp.permute_dims(v, (1, 2, 0, 3)) @@ -55,7 +64,10 @@ def _ref_ragged_paged_attention( jnp.int32, (cur_q_len, cur_kv_len), 0 ) kv_span = jax.lax.broadcasted_iota(jnp.int32, (cur_q_len, cur_kv_len), 1) - mask = jnp.where(q_span < kv_span, float("-inf"), 0.) + # mask = jnp.where(q_span < kv_span, float("-inf"), 0.) + mask = jnp.where(q_span < kv_span, DEFAULT_MASK_VALUE, 0.) + if i == 2: + print(f"xw32 ref impl {mask.shape=}, {mask=}") with jax.numpy_rank_promotion("allow"): attn = attn + mask attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype) @@ -78,8 +90,13 @@ def _verify_ragged_paged_attention_debug( page_size, dtype, num_pages, + num_queries_per_block=128, ): num_seqs = len(seq_lens) + for i in range(num_seqs): + cur_q_len = seq_lens[i][0] + cur_kv_len = seq_lens[i][1] + assert cur_q_len <= cur_kv_len, f"cur_q_len must be less than or equal to cur_kv_len. Got {cur_q_len} and {cur_kv_len}" query_lens = [seq_len[0] for seq_len in seq_lens] num_q_tokens = sum(query_lens) kv_lens = jnp.array([seq_len[1] for seq_len in seq_lens]) @@ -112,7 +129,7 @@ def _verify_ragged_paged_attention_debug( print(f"xw32 max_kv_len: {max_kv_len}, {max_num_pages_per_seq=}") # The assert below mimics the reality that each page get a unique index. # But for testing, the assert could be omitted. - assert max_num_pages_per_seq*num_q_tokens <= num_pages, f"assert failed: max_num_pages_per_seq*num_q_tokens < num_pages. Got {max_num_pages_per_seq*num_q_tokens} and {num_pages}" + # assert max_num_pages_per_seq*num_q_tokens <= num_pages, f"assert failed: max_num_pages_per_seq*num_q_tokens < num_pages. 
Got {max_num_pages_per_seq*num_q_tokens} and {num_pages}" page_indices = jax.random.randint(k4, (num_q_tokens, max_num_pages_per_seq), 0, num_pages, dtype=jnp.int32) # Create a cu_q_lens: jax.Array, # i32[num_tokens + 1] q_lens_with_paddings = [0] * num_q_tokens @@ -120,7 +137,7 @@ def _verify_ragged_paged_attention_debug( q_lens_with_paddings[i] = query_lens[i] cu_q_lens = jnp.cumsum(jnp.array([0]+q_lens_with_paddings)) - actual_output = ragged_paged_attention( + err, actual_output = ragged_paged_attention( queries, k_pages, v_pages, @@ -128,7 +145,9 @@ def _verify_ragged_paged_attention_debug( page_indices, cu_q_lens, num_seqs, + num_queries_per_block=num_queries_per_block, ) + err.throw() actual_output = jax.block_until_ready(actual_output) print("ragged paged attention finished.") @@ -145,22 +164,40 @@ def _verify_ragged_paged_attention_debug( self.assertEqual(actual_output.shape, expected_output.shape) self.assertEqual(actual_output.dtype, expected_output.dtype) - print(f'xw32 {expected_output[:192]=}') - print(f'xw32 {actual_output[:192]=}') + print(f'xw32 {expected_output[:1]=}') + print(f'xw32 {actual_output[:1]=}') print(f'Output max diff: {jnp.max(jnp.abs(expected_output - actual_output))}') print(f'Output mean diff: {jnp.mean(jnp.abs(expected_output - actual_output))}') if dtype == jnp.float32: - atol = 2e-2 + atol = ATOL_FP32 rtol = 1e-2 elif dtype == jnp.bfloat16: atol = 6e-1 rtol = 1e-1 else: self.fail(f'Unsupported dtype: {dtype}') - self.assertTrue(jnp.allclose(actual_output[:128], expected_output[:128], atol=atol, rtol=rtol)) - self.assertTrue(jnp.allclose(actual_output[128:192], expected_output[128:192], atol=atol, rtol=rtol)) - self.assertTrue(jnp.allclose(actual_output[192:256], expected_output[192:256], atol=atol, rtol=rtol)) + + print(f'Output max diff [:1]: {jnp.max(jnp.abs(expected_output[:1] - actual_output[:1]))}') + print(f'Output mean diff [:1]: {jnp.mean(jnp.abs(expected_output[:1] - actual_output[:1]))}') + # print(f'Output max diff [1:6]: {jnp.max(jnp.abs(expected_output[1:6] - actual_output[1:6]))}') + # print(f'Output mean diff [1:6]: {jnp.mean(jnp.abs(expected_output[1:6] - actual_output[1:6]))}') + # print(f'Output max diff [6:128]: {jnp.max(jnp.abs(expected_output[6:128] - actual_output[6:128]))}') + # print(f'Output mean diff [6:128]: {jnp.mean(jnp.abs(expected_output[6:128] - actual_output[6:128]))}') + # print(f'Output max diff [128:256]: {jnp.max(jnp.abs(expected_output[128:256] - actual_output[128:256]))}') + # print(f'Output mean diff [128:256]: {jnp.mean(jnp.abs(expected_output[128:256] - actual_output[128:256]))}') + # print(f'xw32 {expected_output[6:128]=}') + # print(f'xw32 {actual_output[6:128]=}') + # print(f'Output max diff: {jnp.max(jnp.abs(expected_output[6:128] - actual_output[6:128]))}') + # print(f'Output mean diff: {jnp.mean(jnp.abs(expected_output[6:128] - actual_output[6:128]))}') + # print(f'Output max diff: {jnp.max(jnp.abs(expected_output[128:256] - actual_output[128:256]))}') + # print(f'Output max diff: {jnp.max(jnp.abs(expected_output[256:384] - actual_output[128:256]))}') + # print(f'Output max diff: {jnp.max(jnp.abs(expected_output[384:512] - actual_output[384:512]))}') + # print(f'Output mean diff: {jnp.mean(jnp.abs(expected_output[6:128] - actual_output[6:128]))}') + + self.assertTrue(jnp.allclose(actual_output[:1], expected_output[:1], atol=atol, rtol=rtol)) + self.assertTrue(jnp.allclose(actual_output[1:6], expected_output[1:6], atol=atol, rtol=rtol)) + self.assertTrue(jnp.allclose(actual_output[6:], expected_output[6:], 
atol=atol, rtol=rtol)) self.assertTrue(jnp.allclose(actual_output, expected_output, atol=atol, rtol=rtol)) def _verify_ragged_paged_attention( @@ -173,6 +210,10 @@ def _verify_ragged_paged_attention( num_pages, ): num_seqs = len(seq_lens) + for i in range(num_seqs): + cur_q_len = seq_lens[i][0] + cur_kv_len = seq_lens[i][1] + assert cur_q_len <= cur_kv_len, f"cur_q_len must be less than or equal to cur_kv_len. Got {cur_q_len} and {cur_kv_len}" query_lens = [seq_len[0] for seq_len in seq_lens] num_q_tokens = sum(query_lens) kv_lens = jnp.array([seq_len[1] for seq_len in seq_lens]) @@ -205,7 +246,7 @@ def _verify_ragged_paged_attention( print(f"xw32 max_kv_len: {max_kv_len}, {max_num_pages_per_seq=}") # The assert below mimics the reality that each page get a unique index. # But for testing, the assert could be omitted. - assert max_num_pages_per_seq*num_q_tokens <= num_pages, f"assert failed: max_num_pages_per_seq*num_q_tokens < num_pages. Got {max_num_pages_per_seq*num_q_tokens} and {num_pages}" + # assert max_num_pages_per_seq*num_q_tokens <= num_pages, f"assert failed: max_num_pages_per_seq*num_q_tokens < num_pages. Got {max_num_pages_per_seq*num_q_tokens} and {num_pages}" page_indices = jax.random.randint(k4, (num_q_tokens, max_num_pages_per_seq), 0, num_pages, dtype=jnp.int32) # Create a cu_q_lens: jax.Array, # i32[num_tokens + 1] q_lens_with_paddings = [0] * num_q_tokens @@ -213,7 +254,7 @@ def _verify_ragged_paged_attention( q_lens_with_paddings[i] = query_lens[i] cu_q_lens = jnp.cumsum(jnp.array([0]+q_lens_with_paddings)) - actual_output = ragged_paged_attention( + err, actual_output = ragged_paged_attention( queries, k_pages, v_pages, @@ -222,6 +263,7 @@ def _verify_ragged_paged_attention( cu_q_lens, num_seqs, ) + err.throw() actual_output = jax.block_until_ready(actual_output) print("ragged paged attention finished.") @@ -241,7 +283,7 @@ def _verify_ragged_paged_attention( print(f'Output max diff: {jnp.max(jnp.abs(expected_output - actual_output))}') print(f'Output mean diff: {jnp.mean(jnp.abs(expected_output - actual_output))}') if dtype == jnp.float32: - atol = 2e-2 + atol = ATOL_FP32 rtol = 1e-2 elif dtype == jnp.bfloat16: atol = 6e-1 @@ -297,6 +339,109 @@ def test_paged_attention_basic( num_pages, ) + @parameterized.product( + seq_lens=[[(1, 1328), (5, 18), (506, 563)]], + num_heads=[(4, 4), (8, 2), (16, 2)], + head_dim=[128, 256], + dtype=(jnp.float32, jnp.bfloat16), + page_size=[16, 32], + num_pages=[32768, 2048], + ) + def test_paged_attention_varlen_comprehensive( + self, + seq_lens: List[Tuple[int, int]], + num_heads: Tuple[int, int], + head_dim: int, + dtype, + page_size: int, + num_pages: int, + ): + # assuming q_blk_size=128 + # seq_lens = [(512, 1328)] # [(q_len, kv_len),...] + # num_heads = (1, 1) + # head_dim = 128 + # page_size = 16 + # dtype = jnp.float32 + # num_pages = 65536 + + self._verify_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + ) + + # @parameterized.product( + # seq_lens=[[(1, 1328), (5, 18), (129, 463)]], + # num_heads=[(4, 4), (8, 2), (16, 2)], + # head_dim=[128, 256], + # dtype=(jnp.float32, jnp.bfloat16), + # page_size=[16, 32], + # num_pages=[32768, 2048], + # ) + def test_paged_attention_varlen1( + self, + ): + # assuming q_blk_size=128 + seq_lens = [(1, 1328), (5, 18), (1, 129), (120, 229), (1, 122), # first physical q block + (1, 64), (32, 100), (250, 463), (1, 18), (1, 17), (99, 123)] # last 3 physical q blocks [(q_len, kv_len),...] 
+ num_heads = (4, 4) + head_dim = 128 + dtype = jnp.float32 + page_size = 16 + num_pages = 32768 + + self._verify_ragged_paged_attention_debug( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + ) + + def test_paged_attention_varlen2( + self, + ): + # assuming q_blk_size=128 + seq_lens = [(1, 1328), (5, 18), (506, 563)] + num_heads = (8, 2) + head_dim = 256 + dtype = jnp.float32 + page_size = 16 + num_pages = 32768 + + self._verify_ragged_paged_attention_debug( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + ) + + # this test passed. + def test_paged_attention_varlen3( + self, + ): + # assuming q_blk_size=128 + seq_lens = [(256, 256),(128, 256)] + num_heads = (8, 2) + head_dim = 128 + dtype = jnp.float32 + page_size = 16 + num_pages = 32768 + + self._verify_ragged_paged_attention_debug( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + ) def test_paged_attention_basic_with_one_token_per_sequence( self, @@ -338,6 +483,146 @@ def test_paged_attention_extreme_all_tokens_belong_to_one_sequence( num_pages, ) + def test_paged_attention_extreme_one_tokens_per_sequence_min( + self, + ): + seq_lens = [] # [(q_len, kv_len),...] + num_seqs = 64 + num_queries_per_block=16 + for i in range(num_seqs): + seq_lens.append((1, 128+i)) + num_heads = (1, 1) + head_dim = 128 + page_size = 16 + dtype = jnp.float32 + num_pages = 1024 + + self._verify_ragged_paged_attention_debug( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + num_queries_per_block=num_queries_per_block, + ) + + def test_paged_attention_extreme_one_tokens_per_sequence_min2( + self, + ): + seq_lens = [] # [(q_len, kv_len),...] + num_seqs = 64 + num_queries_per_block=16 + for i in range(num_seqs): + seq_lens.append((1, 256+i)) + num_heads = (1, 1) + head_dim = 128 + page_size = 16 + dtype = jnp.float32 + num_pages = 1024 + + self._verify_ragged_paged_attention_debug( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + num_queries_per_block=num_queries_per_block, + ) + + def test_paged_attention_extreme_one_tokens_per_sequence_min3( + self, + ): + seq_lens = [] # [(q_len, kv_len),...] + num_seqs = 16 + num_queries_per_block=8 + for i in range(num_seqs): + seq_lens.append((1, 128+i)) + num_heads = (1, 1) + head_dim = 128 + page_size = 16 + dtype = jnp.float32 + num_pages = 65536 + + self._verify_ragged_paged_attention_debug( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + num_queries_per_block=num_queries_per_block, + ) + + def test_paged_attention_q_len_should_be_no_longer_than_kv_len( + self, + ): + # assuming q_blk_size=128 + seq_lens = [(1, 0)] # [(q_len, kv_len),...] + num_seqs = 511 + for i in range(num_seqs): + seq_lens.append((1, 128+i)) + num_heads = (1, 1) + head_dim = 128 + page_size = 16 + dtype = jnp.float32 + num_pages = 65536 + + num_seqs = len(seq_lens) + query_lens = [seq_len[0] for seq_len in seq_lens] + num_q_tokens = sum(query_lens) + kv_lens = jnp.array([seq_len[1] for seq_len in seq_lens]) + num_q_heads = num_heads[0] + num_kv_heads = num_heads[1] + assert num_q_heads % num_kv_heads == 0, "num_q_heads % num_kv_heads !=0." 
+ + prng_key = jax.random.key(0) + k1, k2, k3, k4 = jax.random.split(prng_key, 4) + queries = jax.random.normal(k1, + (num_q_tokens, num_q_heads, head_dim), + dtype=dtype) + k_pages = jax.random.normal(k2, + (num_kv_heads, num_pages, page_size, head_dim), + dtype=dtype) + v_pages = jax.random.normal(k3, + (num_kv_heads, num_pages, page_size, head_dim), + dtype=dtype) + # Create a kv_lens: i32[num_tokens] + kv_lens_with_paddings = [0] * num_q_tokens + for i in range(num_seqs): + kv_lens_with_paddings[i] = kv_lens[i] + kv_lens_np = jnp.array(kv_lens_with_paddings) + # Create a page_indices: jax.Array, # i32[num_tokens, pages_per_sequence] + max_kv_len = max([seq_len[1] for seq_len in seq_lens]) + max_num_pages_per_seq = (max_kv_len + page_size - 1) // page_size + # The reason why we need to pad max_num_pages_per_seq is that + # page_indices[1]=max_num_pages_per_seq and max_num_pages_per_seq%num_kv_pages_per_compute_block==0 + max_num_pages_per_seq = self._get_closest_power_of_two(max_num_pages_per_seq) + print(f"xw32 max_kv_len: {max_kv_len}, {max_num_pages_per_seq=}") + # The assert below mimics the reality that each page get a unique index. + # But for testing, the assert could be omitted. + assert max_num_pages_per_seq*num_q_tokens <= num_pages, f"assert failed: max_num_pages_per_seq*num_q_tokens < num_pages. Got {max_num_pages_per_seq*num_q_tokens} and {num_pages}" + page_indices = jax.random.randint(k4, (num_q_tokens, max_num_pages_per_seq), 0, num_pages, dtype=jnp.int32) + # Create a cu_q_lens: jax.Array, # i32[num_tokens + 1] + q_lens_with_paddings = [0] * num_q_tokens + for i in range(num_seqs): + q_lens_with_paddings[i] = query_lens[i] + cu_q_lens = jnp.cumsum(jnp.array([0]+q_lens_with_paddings)) + print(f"xw32 {cu_q_lens=}, {kv_lens_np=}") + + with self.assertRaisesRegex(ValueError, "cur_q_len must be less or equal to cur_kv_len"): + err, _ = ragged_paged_attention(queries, + k_pages, + v_pages, + kv_lens_np, + page_indices, + cu_q_lens, + num_seqs, + ) + err.throw() + + # failing test def test_paged_attention_extreme_one_tokens_per_sequence( self, ): @@ -345,14 +630,14 @@ def test_paged_attention_extreme_one_tokens_per_sequence( seq_lens = [] # [(q_len, kv_len),...] 
num_seqs = 512 for i in range(num_seqs): - seq_lens.append((1, i)) + seq_lens.append((1, 128+i)) num_heads = (1, 1) head_dim = 128 page_size = 16 dtype = jnp.float32 num_pages = 65536 - self._verify_ragged_paged_attention( + self._verify_ragged_paged_attention_debug( seq_lens, num_heads, head_dim, @@ -400,7 +685,5 @@ def test_make_sequence_metadata( # self.assertTrue(jnp.array_equal(metadata.seq_ids[:metadata.num_logical_q_tiles], [0, 0, 1, 1, 1, 2])) # self.assertTrue(jnp.array_equal(metadata.physical_q_tile_ids[:metadata.num_logical_q_tiles], [0, 1, 1, 2, 3, 3])) - - if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py index afbf2be6aa09..f0c7d8309173 100644 --- a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py +++ b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py @@ -5,6 +5,7 @@ import jax from jax import lax +from jax.experimental import checkify from jax.experimental import pallas as pl from jax.experimental.pallas import tpu as pltpu from jax.experimental.pallas.ops.tpu.paged_attention import quantization_utils @@ -301,6 +302,11 @@ def check_kernel_input(q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, if cu_q_lens.shape[0] != num_tokens + 1: raise ValueError("cu_q_lens.shape[0] must be thet same as num_tokens + 1. Got" f" {cu_q_lens.shape[0]} and {num_tokens + 1}") + for i in range(num_seqs): + cur_q_len = cu_q_lens[i+1] - cu_q_lens[i] + cur_kv_len = kv_lens[i] + jax.debug.print("xw32 line308 {i} {cur_q_len}, {cur_kv_len}", i=i, cur_q_len=cur_q_len, cur_kv_len=cur_kv_len) + checkify.check(cur_q_len <= cur_kv_len, "cur_q_len must be less or equal to cur_kv_len. Got {} and {}", cur_q_len, cur_kv_len) if num_seqs > num_tokens: raise ValueError(f"num_seqs must be less or equal to num_tokens. Got {num_seqs} and {num_tokens}") # int16: will pack. need to explicit cast to int32. int64 is not supported in Pallas. for smem 1d case. @@ -720,15 +726,16 @@ def prefetch_next_block(): # pylint: disable=unused-variable MIN_BLOCK_SIZE = 128 # TODO(xw32): uncomment this once the kernel output is correct. -# @functools.partial( -# jax.jit, -# static_argnames=[ -# "num_kv_pages_per_block", -# "num_queries_per_block", -# "mask_value", -# "num_seqs", -# ], -# ) +@checkify.checkify +@functools.partial( + jax.jit, + static_argnames=[ + "num_kv_pages_per_block", + "num_queries_per_block", + "mask_value", + "num_seqs", + ], +) def ragged_paged_attention( q: jax.Array, # [num_tokens, num_q_heads, head_dim] k_pages: jax.Array, # [num_kv_heads, total_num_pages, page_size, head_dim] @@ -940,43 +947,26 @@ def lm_index_map(kv_head_idx, logical_q_blk_idx, kv_blk_idx, group_metadata, *_) # debug compile begins # To enable debug, uncomment this section, comment out the `kernel()` below and comment out the jax.jit above. - compiled_kernel = ( - jax.jit(kernel) - .lower( - # prefetch - group_metadata, - kv_lens, - page_indices_1d, - cu_q_lens, - buffer_index, - step, - # kernel inputs - q.astype(q_dtype_for_kernel_launch), # TODO: do we need the `.astype`? Need to double check. 
- k_pages, - k_scales_pages, - v_pages, - v_scales_pages, - ) - .compile({'xla_tpu_enable_log_recorder': 'true'}) - ) - outputs = compiled_kernel( - # prefetch - group_metadata, - kv_lens, - page_indices_1d, - cu_q_lens, - buffer_index, - step, - # kernel inputs - q.astype(q_dtype_for_kernel_launch), # TODO: do we need the `.astype`? Need to double check. - k_pages, - k_scales_pages, - v_pages, - v_scales_pages, - ) - # debug compile ends - - # outputs = kernel( + # compiled_kernel = ( + # jax.jit(kernel) + # .lower( + # # prefetch + # group_metadata, + # kv_lens, + # page_indices_1d, + # cu_q_lens, + # buffer_index, + # step, + # # kernel inputs + # q.astype(q_dtype_for_kernel_launch), # TODO: do we need the `.astype`? Need to double check. + # k_pages, + # k_scales_pages, + # v_pages, + # v_scales_pages, + # ) + # .compile({'xla_tpu_enable_log_recorder': 'true'}) + # ) + # outputs = compiled_kernel( # # prefetch # group_metadata, # kv_lens, @@ -991,6 +981,23 @@ def lm_index_map(kv_head_idx, logical_q_blk_idx, kv_blk_idx, group_metadata, *_) # v_pages, # v_scales_pages, # ) + # debug compile ends + + outputs = kernel( + # prefetch + group_metadata, + kv_lens, + page_indices_1d, + cu_q_lens, + buffer_index, + step, + # kernel inputs + q.astype(q_dtype_for_kernel_launch), # TODO: do we need the `.astype`? Need to double check. + k_pages, + k_scales_pages, + v_pages, + v_scales_pages, + ) ret = outputs[0] # print(f"xw32 line495 ret.shape={ret.shape}, {ret=}") return jnp.permute_dims(ret, (1, 0, 2)).astype(q.dtype) From e589c8608a63ffc493332b12224cbfd796f5701a Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Thu, 30 Jan 2025 23:32:18 +0000 Subject: [PATCH 4/9] Improved the tests and all tests passed except for the OOM one. Also added runtime check. --- test/test_ragged_paged_attention_kernel.py | 318 ++------------------- 1 file changed, 27 insertions(+), 291 deletions(-) diff --git a/test/test_ragged_paged_attention_kernel.py b/test/test_ragged_paged_attention_kernel.py index 720924ce5d05..9c150d6b4545 100644 --- a/test/test_ragged_paged_attention_kernel.py +++ b/test/test_ragged_paged_attention_kernel.py @@ -5,20 +5,13 @@ import jax from jax._src import test_util as jtu from jax.experimental.pallas.ops.tpu.paged_attention import quantization_utils -from torch_xla.experimental.pallas_kernels.ragged_paged_attention_kernel import ragged_paged_attention, make_group_metadata +from torch_xla.experimental.pallas_kernels.ragged_paged_attention_kernel import ragged_paged_attention, make_group_metadata, DEFAULT_MASK_VALUE import jax.numpy as jnp import numpy as np jax.config.parse_flags_with_absl() -# Make sure the q_len is no longer than the kv_len. For example, -# seq_lens = [(1, 1328), (5, 18), (506, 463)] is not a valid test case because -# the 3rd sequence has q_len(506) > kv_len(463). - -# Just to use the same very negative value in the ref impl as in the kernel. 
-DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max) - ATOL_FP32 = 2e-1 @@ -49,6 +42,7 @@ def _ref_ragged_paged_attention( k = jnp.permute_dims(k, (1, 2, 0, 3)) # [page_indices_to_use, page_size, num_kv_heads, head_dim] k = jnp.reshape(k, (-1, num_kv_heads, head_dim)) # [kv_len, num_kv_heads, head_dim] k = k[:cur_kv_len] # [cur_kv_lens, num_kv_heads, head_dim] + v = v_pages[:, page_indices_to_use, :, :] v = jnp.permute_dims(v, (1, 2, 0, 3)) v = jnp.reshape(v, (-1, num_kv_heads, head_dim)) @@ -64,10 +58,8 @@ def _ref_ragged_paged_attention( jnp.int32, (cur_q_len, cur_kv_len), 0 ) kv_span = jax.lax.broadcasted_iota(jnp.int32, (cur_q_len, cur_kv_len), 1) - # mask = jnp.where(q_span < kv_span, float("-inf"), 0.) + # Use the same DEFAULT_MASK_VALUE as in the kernel instead of float("-inf") so that the kernel can match the ref implement better. mask = jnp.where(q_span < kv_span, DEFAULT_MASK_VALUE, 0.) - if i == 2: - print(f"xw32 ref impl {mask.shape=}, {mask=}") with jax.numpy_rank_promotion("allow"): attn = attn + mask attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype) @@ -82,7 +74,7 @@ def _ref_ragged_paged_attention( @jtu.with_config(jax_numpy_dtype_promotion="standard") class RaggedPagedAttentionKernelTest(jtu.JaxTestCase): - def _verify_ragged_paged_attention_debug( + def _verify_ragged_paged_attention( self, seq_lens, num_heads, @@ -93,127 +85,14 @@ def _verify_ragged_paged_attention_debug( num_queries_per_block=128, ): num_seqs = len(seq_lens) + # Make sure the q_len is no longer than the kv_len. For example, + # seq_lens = [(1, 1328), (5, 18), (506, 463)] is not a valid test case because + # the 3rd sequence has q_len(506) > kv_len(463). for i in range(num_seqs): cur_q_len = seq_lens[i][0] cur_kv_len = seq_lens[i][1] assert cur_q_len <= cur_kv_len, f"cur_q_len must be less than or equal to cur_kv_len. Got {cur_q_len} and {cur_kv_len}" - query_lens = [seq_len[0] for seq_len in seq_lens] - num_q_tokens = sum(query_lens) - kv_lens = jnp.array([seq_len[1] for seq_len in seq_lens]) - num_q_heads = num_heads[0] - num_kv_heads = num_heads[1] - assert num_q_heads % num_kv_heads == 0, "num_q_heads % num_kv_heads !=0." - - prng_key = jax.random.key(0) - k1, k2, k3, k4 = jax.random.split(prng_key, 4) - queries = jax.random.normal(k1, - (num_q_tokens, num_q_heads, head_dim), - dtype=dtype) - k_pages = jax.random.normal(k2, - (num_kv_heads, num_pages, page_size, head_dim), - dtype=dtype) - v_pages = jax.random.normal(k3, - (num_kv_heads, num_pages, page_size, head_dim), - dtype=dtype) - # Create a kv_lens: i32[num_tokens] - kv_lens_with_paddings = [0] * num_q_tokens - for i in range(num_seqs): - kv_lens_with_paddings[i] = kv_lens[i] - kv_lens_np = jnp.array(kv_lens_with_paddings) - # Create a page_indices: jax.Array, # i32[num_tokens, pages_per_sequence] - max_kv_len = max([seq_len[1] for seq_len in seq_lens]) - max_num_pages_per_seq = (max_kv_len + page_size - 1) // page_size - # The reason why we need to pad max_num_pages_per_seq is that - # page_indices[1]=max_num_pages_per_seq and max_num_pages_per_seq%num_kv_pages_per_compute_block==0 - max_num_pages_per_seq = self._get_closest_power_of_two(max_num_pages_per_seq) - print(f"xw32 max_kv_len: {max_kv_len}, {max_num_pages_per_seq=}") - # The assert below mimics the reality that each page get a unique index. - # But for testing, the assert could be omitted. - # assert max_num_pages_per_seq*num_q_tokens <= num_pages, f"assert failed: max_num_pages_per_seq*num_q_tokens < num_pages. 
Got {max_num_pages_per_seq*num_q_tokens} and {num_pages}" - page_indices = jax.random.randint(k4, (num_q_tokens, max_num_pages_per_seq), 0, num_pages, dtype=jnp.int32) - # Create a cu_q_lens: jax.Array, # i32[num_tokens + 1] - q_lens_with_paddings = [0] * num_q_tokens - for i in range(num_seqs): - q_lens_with_paddings[i] = query_lens[i] - cu_q_lens = jnp.cumsum(jnp.array([0]+q_lens_with_paddings)) - - err, actual_output = ragged_paged_attention( - queries, - k_pages, - v_pages, - kv_lens_np, - page_indices, - cu_q_lens, - num_seqs, - num_queries_per_block=num_queries_per_block, - ) - err.throw() - actual_output = jax.block_until_ready(actual_output) - print("ragged paged attention finished.") - - expected_output = _ref_ragged_paged_attention( - queries, - k_pages, - v_pages, - kv_lens_np, - page_indices, - cu_q_lens, - num_seqs, - ) - - self.assertEqual(actual_output.shape, expected_output.shape) - self.assertEqual(actual_output.dtype, expected_output.dtype) - - print(f'xw32 {expected_output[:1]=}') - print(f'xw32 {actual_output[:1]=}') - - print(f'Output max diff: {jnp.max(jnp.abs(expected_output - actual_output))}') - print(f'Output mean diff: {jnp.mean(jnp.abs(expected_output - actual_output))}') - if dtype == jnp.float32: - atol = ATOL_FP32 - rtol = 1e-2 - elif dtype == jnp.bfloat16: - atol = 6e-1 - rtol = 1e-1 - else: - self.fail(f'Unsupported dtype: {dtype}') - - print(f'Output max diff [:1]: {jnp.max(jnp.abs(expected_output[:1] - actual_output[:1]))}') - print(f'Output mean diff [:1]: {jnp.mean(jnp.abs(expected_output[:1] - actual_output[:1]))}') - # print(f'Output max diff [1:6]: {jnp.max(jnp.abs(expected_output[1:6] - actual_output[1:6]))}') - # print(f'Output mean diff [1:6]: {jnp.mean(jnp.abs(expected_output[1:6] - actual_output[1:6]))}') - # print(f'Output max diff [6:128]: {jnp.max(jnp.abs(expected_output[6:128] - actual_output[6:128]))}') - # print(f'Output mean diff [6:128]: {jnp.mean(jnp.abs(expected_output[6:128] - actual_output[6:128]))}') - # print(f'Output max diff [128:256]: {jnp.max(jnp.abs(expected_output[128:256] - actual_output[128:256]))}') - # print(f'Output mean diff [128:256]: {jnp.mean(jnp.abs(expected_output[128:256] - actual_output[128:256]))}') - # print(f'xw32 {expected_output[6:128]=}') - # print(f'xw32 {actual_output[6:128]=}') - # print(f'Output max diff: {jnp.max(jnp.abs(expected_output[6:128] - actual_output[6:128]))}') - # print(f'Output mean diff: {jnp.mean(jnp.abs(expected_output[6:128] - actual_output[6:128]))}') - # print(f'Output max diff: {jnp.max(jnp.abs(expected_output[128:256] - actual_output[128:256]))}') - # print(f'Output max diff: {jnp.max(jnp.abs(expected_output[256:384] - actual_output[128:256]))}') - # print(f'Output max diff: {jnp.max(jnp.abs(expected_output[384:512] - actual_output[384:512]))}') - # print(f'Output mean diff: {jnp.mean(jnp.abs(expected_output[6:128] - actual_output[6:128]))}') - - self.assertTrue(jnp.allclose(actual_output[:1], expected_output[:1], atol=atol, rtol=rtol)) - self.assertTrue(jnp.allclose(actual_output[1:6], expected_output[1:6], atol=atol, rtol=rtol)) - self.assertTrue(jnp.allclose(actual_output[6:], expected_output[6:], atol=atol, rtol=rtol)) - self.assertTrue(jnp.allclose(actual_output, expected_output, atol=atol, rtol=rtol)) - def _verify_ragged_paged_attention( - self, - seq_lens, - num_heads, - head_dim, - page_size, - dtype, - num_pages, - ): - num_seqs = len(seq_lens) - for i in range(num_seqs): - cur_q_len = seq_lens[i][0] - cur_kv_len = seq_lens[i][1] - assert cur_q_len <= cur_kv_len, 
f"cur_q_len must be less than or equal to cur_kv_len. Got {cur_q_len} and {cur_kv_len}" query_lens = [seq_len[0] for seq_len in seq_lens] num_q_tokens = sum(query_lens) kv_lens = jnp.array([seq_len[1] for seq_len in seq_lens]) @@ -232,22 +111,24 @@ def _verify_ragged_paged_attention( v_pages = jax.random.normal(k3, (num_kv_heads, num_pages, page_size, head_dim), dtype=dtype) + # Create a kv_lens: i32[num_tokens] kv_lens_with_paddings = [0] * num_q_tokens for i in range(num_seqs): kv_lens_with_paddings[i] = kv_lens[i] kv_lens_np = jnp.array(kv_lens_with_paddings) + # Create a page_indices: jax.Array, # i32[num_tokens, pages_per_sequence] max_kv_len = max([seq_len[1] for seq_len in seq_lens]) max_num_pages_per_seq = (max_kv_len + page_size - 1) // page_size # The reason why we need to pad max_num_pages_per_seq is that # page_indices[1]=max_num_pages_per_seq and max_num_pages_per_seq%num_kv_pages_per_compute_block==0 max_num_pages_per_seq = self._get_closest_power_of_two(max_num_pages_per_seq) - print(f"xw32 max_kv_len: {max_kv_len}, {max_num_pages_per_seq=}") # The assert below mimics the reality that each page get a unique index. # But for testing, the assert could be omitted. # assert max_num_pages_per_seq*num_q_tokens <= num_pages, f"assert failed: max_num_pages_per_seq*num_q_tokens < num_pages. Got {max_num_pages_per_seq*num_q_tokens} and {num_pages}" page_indices = jax.random.randint(k4, (num_q_tokens, max_num_pages_per_seq), 0, num_pages, dtype=jnp.int32) + # Create a cu_q_lens: jax.Array, # i32[num_tokens + 1] q_lens_with_paddings = [0] * num_q_tokens for i in range(num_seqs): @@ -262,10 +143,10 @@ def _verify_ragged_paged_attention( page_indices, cu_q_lens, num_seqs, + num_queries_per_block=num_queries_per_block, ) - err.throw() + err.throw() # noop if there is not err. actual_output = jax.block_until_ready(actual_output) - print("ragged paged attention finished.") expected_output = _ref_ragged_paged_attention( queries, @@ -283,7 +164,7 @@ def _verify_ragged_paged_attention( print(f'Output max diff: {jnp.max(jnp.abs(expected_output - actual_output))}') print(f'Output mean diff: {jnp.mean(jnp.abs(expected_output - actual_output))}') if dtype == jnp.float32: - atol = ATOL_FP32 + atol = 2e-1 rtol = 1e-2 elif dtype == jnp.bfloat16: atol = 6e-1 @@ -297,12 +178,12 @@ def _get_closest_power_of_two(self, x): raise ValueError(f"x must be positive. Got {x}") return 2 ** int(np.ceil(np.log2(x))) - def test_paged_attention_min_two_kv_block_per_sequence( + def test_paged_attention_basic( self, ): + # Same setup as in the design doc. # assuming q_blk_size=128, page_size=16, num_kv_pages_per_compute_block=16 - # One of the constraints of the kernel is that q.shape[0]%q_blk_size==0 as in _calculate_num_tiles. - # If we cannot get the assumption, we can pad the matrix q in the kernel. + # Note one of the constraints of the kernel is that q.shape[0]%q_blk_size==0 as in _calculate_num_tiles. seq_lens = [(192, 328), (128, 180), (64, 255)] # [(q_len, kv_len),...] num_heads = (1, 1) head_dim = 128 @@ -310,26 +191,6 @@ def test_paged_attention_min_two_kv_block_per_sequence( dtype = jnp.float32 num_pages = 65536 - self._verify_ragged_paged_attention_debug( - seq_lens, - num_heads, - head_dim, - page_size, - dtype, - num_pages, - ) - - def test_paged_attention_basic( - self, - ): - # assuming q_blk_size=128 - seq_lens = [(192, 1328), (128, 180), (64, 463)] # [(q_len, kv_len),...] 
- num_heads = (1, 1) - head_dim = 128 - page_size = 16 - dtype = jnp.float32 - num_pages = 65536 - self._verify_ragged_paged_attention( seq_lens, num_heads, @@ -357,13 +218,6 @@ def test_paged_attention_varlen_comprehensive( num_pages: int, ): # assuming q_blk_size=128 - # seq_lens = [(512, 1328)] # [(q_len, kv_len),...] - # num_heads = (1, 1) - # head_dim = 128 - # page_size = 16 - # dtype = jnp.float32 - # num_pages = 65536 - self._verify_ragged_paged_attention( seq_lens, num_heads, @@ -373,15 +227,7 @@ def test_paged_attention_varlen_comprehensive( num_pages, ) - # @parameterized.product( - # seq_lens=[[(1, 1328), (5, 18), (129, 463)]], - # num_heads=[(4, 4), (8, 2), (16, 2)], - # head_dim=[128, 256], - # dtype=(jnp.float32, jnp.bfloat16), - # page_size=[16, 32], - # num_pages=[32768, 2048], - # ) - def test_paged_attention_varlen1( + def test_paged_attention_mix_prefill_and_decode1( self, ): # assuming q_blk_size=128 @@ -393,48 +239,7 @@ def test_paged_attention_varlen1( page_size = 16 num_pages = 32768 - self._verify_ragged_paged_attention_debug( - seq_lens, - num_heads, - head_dim, - page_size, - dtype, - num_pages, - ) - - def test_paged_attention_varlen2( - self, - ): - # assuming q_blk_size=128 - seq_lens = [(1, 1328), (5, 18), (506, 563)] - num_heads = (8, 2) - head_dim = 256 - dtype = jnp.float32 - page_size = 16 - num_pages = 32768 - - self._verify_ragged_paged_attention_debug( - seq_lens, - num_heads, - head_dim, - page_size, - dtype, - num_pages, - ) - - # this test passed. - def test_paged_attention_varlen3( - self, - ): - # assuming q_blk_size=128 - seq_lens = [(256, 256),(128, 256)] - num_heads = (8, 2) - head_dim = 128 - dtype = jnp.float32 - page_size = 16 - num_pages = 32768 - - self._verify_ragged_paged_attention_debug( + self._verify_ragged_paged_attention( seq_lens, num_heads, head_dim, @@ -443,7 +248,7 @@ def test_paged_attention_varlen3( num_pages, ) - def test_paged_attention_basic_with_one_token_per_sequence( + def test_paged_attention_mix_prefill_and_decode2( self, ): # assuming q_blk_size=128 @@ -485,30 +290,6 @@ def test_paged_attention_extreme_all_tokens_belong_to_one_sequence( def test_paged_attention_extreme_one_tokens_per_sequence_min( self, - ): - seq_lens = [] # [(q_len, kv_len),...] - num_seqs = 64 - num_queries_per_block=16 - for i in range(num_seqs): - seq_lens.append((1, 128+i)) - num_heads = (1, 1) - head_dim = 128 - page_size = 16 - dtype = jnp.float32 - num_pages = 1024 - - self._verify_ragged_paged_attention_debug( - seq_lens, - num_heads, - head_dim, - page_size, - dtype, - num_pages, - num_queries_per_block=num_queries_per_block, - ) - - def test_paged_attention_extreme_one_tokens_per_sequence_min2( - self, ): seq_lens = [] # [(q_len, kv_len),...] num_seqs = 64 @@ -521,31 +302,7 @@ def test_paged_attention_extreme_one_tokens_per_sequence_min2( dtype = jnp.float32 num_pages = 1024 - self._verify_ragged_paged_attention_debug( - seq_lens, - num_heads, - head_dim, - page_size, - dtype, - num_pages, - num_queries_per_block=num_queries_per_block, - ) - - def test_paged_attention_extreme_one_tokens_per_sequence_min3( - self, - ): - seq_lens = [] # [(q_len, kv_len),...] 
- num_seqs = 16 - num_queries_per_block=8 - for i in range(num_seqs): - seq_lens.append((1, 128+i)) - num_heads = (1, 1) - head_dim = 128 - page_size = 16 - dtype = jnp.float32 - num_pages = 65536 - - self._verify_ragged_paged_attention_debug( + self._verify_ragged_paged_attention( seq_lens, num_heads, head_dim, @@ -559,10 +316,7 @@ def test_paged_attention_q_len_should_be_no_longer_than_kv_len( self, ): # assuming q_blk_size=128 - seq_lens = [(1, 0)] # [(q_len, kv_len),...] - num_seqs = 511 - for i in range(num_seqs): - seq_lens.append((1, 128+i)) + seq_lens = [(1, 0), (511, 256)] # [(q_len, kv_len),...] num_heads = (1, 1) head_dim = 128 page_size = 16 @@ -588,28 +342,29 @@ def test_paged_attention_q_len_should_be_no_longer_than_kv_len( v_pages = jax.random.normal(k3, (num_kv_heads, num_pages, page_size, head_dim), dtype=dtype) + # Create a kv_lens: i32[num_tokens] kv_lens_with_paddings = [0] * num_q_tokens for i in range(num_seqs): kv_lens_with_paddings[i] = kv_lens[i] kv_lens_np = jnp.array(kv_lens_with_paddings) + # Create a page_indices: jax.Array, # i32[num_tokens, pages_per_sequence] max_kv_len = max([seq_len[1] for seq_len in seq_lens]) max_num_pages_per_seq = (max_kv_len + page_size - 1) // page_size # The reason why we need to pad max_num_pages_per_seq is that # page_indices[1]=max_num_pages_per_seq and max_num_pages_per_seq%num_kv_pages_per_compute_block==0 max_num_pages_per_seq = self._get_closest_power_of_two(max_num_pages_per_seq) - print(f"xw32 max_kv_len: {max_kv_len}, {max_num_pages_per_seq=}") # The assert below mimics the reality that each page get a unique index. # But for testing, the assert could be omitted. assert max_num_pages_per_seq*num_q_tokens <= num_pages, f"assert failed: max_num_pages_per_seq*num_q_tokens < num_pages. 
Got {max_num_pages_per_seq*num_q_tokens} and {num_pages}" page_indices = jax.random.randint(k4, (num_q_tokens, max_num_pages_per_seq), 0, num_pages, dtype=jnp.int32) + # Create a cu_q_lens: jax.Array, # i32[num_tokens + 1] q_lens_with_paddings = [0] * num_q_tokens for i in range(num_seqs): q_lens_with_paddings[i] = query_lens[i] cu_q_lens = jnp.cumsum(jnp.array([0]+q_lens_with_paddings)) - print(f"xw32 {cu_q_lens=}, {kv_lens_np=}") with self.assertRaisesRegex(ValueError, "cur_q_len must be less or equal to cur_kv_len"): err, _ = ragged_paged_attention(queries, @@ -622,8 +377,7 @@ def test_paged_attention_q_len_should_be_no_longer_than_kv_len( ) err.throw() - # failing test - def test_paged_attention_extreme_one_tokens_per_sequence( + def test_paged_attention_extreme_one_tokens_per_sequence_large( self, ): # assuming q_blk_size=128 @@ -637,7 +391,7 @@ def test_paged_attention_extreme_one_tokens_per_sequence( dtype = jnp.float32 num_pages = 65536 - self._verify_ragged_paged_attention_debug( + self._verify_ragged_paged_attention( seq_lens, num_heads, head_dim, @@ -662,28 +416,10 @@ def test_make_sequence_metadata( num_seqs=num_seqs ) seq_ids, physical_q_tile_ids = metadata - # print(f"xw32 metadata.physical_q_tile_ids: {metadata.physical_q_tile_ids}") - # print(f"xw32 metadata.seq_ids: {metadata.seq_ids}") self.assertEqual(num_logical_q_tiles, 6) self.assertTrue(jnp.array_equal(seq_ids, [0, 0, 1, 1, 1, 2])) self.assertTrue(jnp.array_equal(physical_q_tile_ids, [0, 1, 1, 2, 3, 3])) - # print('xw32======================') - # q_lens = jnp.array([192, 256, 64] + [0]*(512-3)) - # metadata = ragged_paged_attention_kernel.original_make_group_metadata( - # group_sizes=q_lens, - # m=num_q_tokens, - # tm=num_queries_per_compute_block, - # start_group=start_group, - # num_nonzero_groups=num_seqs, - # visit_empty_groups=False, - # ) - # print(f"xw32 {metadata=}") - # self.assertEqual(metadata.num_logical_q_tiles, 6) - # print(f"xw32 metadata.seq_ids: {metadata.seq_ids}") - # print(f"xw32 metadata.physical_q_tile_ids: {metadata.physical_q_tile_ids}") - # print(f"xw32 metadata.seq_ids[:metadata.num_logical_q_tiles]: {metadata.seq_ids[:metadata.num_logical_q_tiles]}") - # self.assertTrue(jnp.array_equal(metadata.seq_ids[:metadata.num_logical_q_tiles], [0, 0, 1, 1, 1, 2])) - # self.assertTrue(jnp.array_equal(metadata.physical_q_tile_ids[:metadata.num_logical_q_tiles], [0, 1, 1, 2, 3, 3])) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 9b3cdab6d5a63d8af3d407a9ac9511388bafa523 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Fri, 31 Jan 2025 00:39:24 +0000 Subject: [PATCH 5/9] clean up --- test/test_ragged_paged_attention_kernel.py | 8 +- .../ragged_paged_attention_kernel.py | 328 +++++++----------- 2 files changed, 126 insertions(+), 210 deletions(-) diff --git a/test/test_ragged_paged_attention_kernel.py b/test/test_ragged_paged_attention_kernel.py index 9c150d6b4545..bd98e0b5e1c5 100644 --- a/test/test_ragged_paged_attention_kernel.py +++ b/test/test_ragged_paged_attention_kernel.py @@ -5,7 +5,7 @@ import jax from jax._src import test_util as jtu from jax.experimental.pallas.ops.tpu.paged_attention import quantization_utils -from torch_xla.experimental.pallas_kernels.ragged_paged_attention_kernel import ragged_paged_attention, make_group_metadata, DEFAULT_MASK_VALUE +from torch_xla.experimental.pallas_kernels.ragged_paged_attention_kernel import ragged_paged_attention, make_sequence_metadata, DEFAULT_MASK_VALUE import jax.numpy as jnp import numpy as np @@ -408,12 +408,12 
@@ def test_make_sequence_metadata( num_queries_per_compute_block = 128 start_group = jnp.array([0]) num_seqs = 3 - metadata, num_logical_q_tiles = make_group_metadata( + metadata, num_logical_q_tiles = make_sequence_metadata( cu_q_lens=cu_q_lens, m=num_q_tokens, tm=num_queries_per_compute_block, - start_group=start_group, - num_seqs=num_seqs + start_sequence=start_group, + num_sequences=num_seqs ) seq_ids, physical_q_tile_ids = metadata self.assertEqual(num_logical_q_tiles, 6) diff --git a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py index f0c7d8309173..863ef1bf3c8b 100644 --- a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py +++ b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py @@ -100,127 +100,101 @@ def _calculate_num_tiles(x: int, tx: int) -> int: raise ValueError(f"{x} must be divisible by x-dimension tile size ({tx}).") return tiles -SequenceMetadata = namedtuple( - "SequenceMetadata", - [ - "num_logical_q_tiles", - "seq_ids", - "physical_q_tile_ids", - ], -) - -GroupMetadata = Any - # https://github.com/jax-ml/jax/blob/9fb29766a2130e74a85cba30420cf777d185ea5a/jax/experimental/pallas/ops/tpu/megablox/gmm.py#L79 -# TODO(xw32): need to do some renaming to adapt to our case. -# Currently, group maps to sequence. -def make_group_metadata( +def make_sequence_metadata( *, cu_q_lens: jnp.ndarray, m: int, tm: int, - start_group: jnp.ndarray, - num_seqs: int, + start_sequence: jnp.ndarray, + num_sequences: int, ): """Create the metadata needed for grouped matmul computation. Args: - group_sizes: A 1d, jnp.ndarray with shape [num_groups] and jnp.int32 dtype. - m: The number of rows in lhs. + cu_q_lens: : A 1d, jnp.ndarray with shape [num_seqs+1] and jnp.int32 dtype. + The cumulative query lengths. + m: The number of query tokens. tm: The m-dimension tile size being used. - start_group: The group in group sizes to start computing from. This is - particularly useful for when rhs num_groups is sharded. - num_nonzero_groups: Number of groups in group sizes to compute on. Useful in - combination with group_offset. - visit_empty_groups: If True, do not squeeze tiles for empty groups out of - the metadata. This is necessary for tgmm, where we at least need to zero - the output for each group. + start_sequence: The sequence in cu_q_lens to start computing from. This is useful for when num_seqs is sharded. + num_sequences: The number of sequences to compute on. Returns: tuple of: - group_offsets: A 1d, jnp.ndarray with shape [num_groups+1] and jnp.int32 - dtype. group_offsets[i] indicates the row at which group [i] starts in - the lhs matrix and group_offsets[i-1] = m. - group_ids: A 1d, jnp.ndarray with shape [m_tiles + num_groups] and - jnp.int32 dtype. group_ids[i] indicates which group grid index 'i' will - work on. - m_tile_ids: A 1d, jnp.ndarray with shape [m_tiles + num_groups] and - jnp.int32. m_tile_ids[i] indicates which m-dimension tile grid index 'i' - will work on. - num_tiles: The number of m-dimension tiles to execute. + seq_ids: A 1d, jnp.ndarray with shape [m_tiles + num_seqs] and + jnp.int32 dtype. seq_ids[i] indicates which sequence the grid index (num_logical_tiles_q) will work on. + physical_q_tile_ids: A 1d, jnp.ndarray with shape [m_tiles + num_seqs] and + jnp.int32. physical_q_tile_ids[i] indicates which query-dim physical tile the grid index (num_logical_tiles_q) will work on. 
+ + num_logical_q_tiles: The number of query-dim logical tiles to execute. """ - num_groups = num_seqs - end_group = start_group + num_seqs - 1 + end_sequence = start_sequence + num_sequences - 1 - # Calculate the offset of each group, starting at zero. This metadata is + # We need the offset of each sequence from input, starting at zero. This metadata is # similar to row offsets in a CSR matrix. The following properties hold: # - # group_offsets.shape = [num_groups + 1] - # group_offsets[0] = 0 - # group_offsets[num_groups] = m + # sequence_offsets.shape = [num_sequences + 1] + # sequence_offsets[0] = 0 + # sequence_offsets[num_sequences] = m # - # The row at which group 'i' starts is group_offsets[i]. - group_ends = cu_q_lens[1:] - group_offsets = cu_q_lens + # The row at which sequence 'i' starts is sequence_offsets[i]. + sequence_ends = cu_q_lens[1:] + sequence_offsets = cu_q_lens - # Assign a group id to each grid index. + # Assign a sequence id to each grid index. The grid index refers to the logical q tile index. # - # If a group starts somewhere other than the start of a tile or ends somewhere + # If a sequence starts somewhere other than the start of a tile or ends somewhere # other than the end of a tile we need to compute that full tile. Calculate - # the number of tiles for each group by rounding their end up to the nearest + # the number of tiles for each sequence by rounding their end up to the nearest # 'tm' and their start down to the nearest 'tm'. - # (1) Round the group_ends up to the nearest multiple of 'tm'. + # (1) Round the sequence_ends up to the nearest multiple of 'tm'. # - # NOTE: This does not change group_offsets[num_groups], which is m + # NOTE: This does not change sequence_offsets[num_sequences], which is m # (because we enforce m is divisible by tm). - rounded_group_ends = ((group_ends + tm - 1) // tm * tm).astype(jnp.int32) - print('xw32 {rounded_group_ends=}') + rounded_sequence_ends = ((sequence_ends + tm - 1) // tm * tm).astype(jnp.int32) + - # (2) Round the group_starts down to the nearest multiple of 'tm'. - group_starts = jnp.concatenate( - [jnp.zeros(1, dtype=jnp.int32), group_ends[:-1]] + # (2) Round the sequence_starts down to the nearest multiple of 'tm'. + sequence_starts = jnp.concatenate( + [jnp.zeros(1, dtype=jnp.int32), sequence_ends[:-1]] ) - rounded_group_starts = group_starts // tm * tm + rounded_sequence_starts = sequence_starts // tm * tm - # (3) Calculate the number of rows in each group. - # - # NOTE: Handle zero-sized groups as a special case. If the start for a - # zero-sized group is not divisible by 'tm' its start will be rounded down and - # its end will be rounded up such that its size will become 1 tile here. - rounded_group_sizes = rounded_group_ends - rounded_group_starts + # (3) Calculate the number of rows in each sequence. + rounded_sequence_sizes = rounded_sequence_ends - rounded_sequence_starts - # (4) Convert the group sizes from units of rows to unit of 'tm' sized tiles. + # (4) Convert the sequence sizes from units of rows to unit of 'tm' sized tiles. # - # An m-dimension tile is 'owned' by group 'i' if the first row of the tile - # belongs to group 'i'. In addition to owned tiles, each group can have 0 or 1 + # An m-dimension tile is 'owned' by sequence 'i' if the first row of the tile + # belongs to sequence 'i'. In addition to owned tiles, each sequence can have 0 or 1 # initial partial tiles if it's first row does not occur in the first row of a - # tile. 
The '0-th' group never has a partial tile because it always starts at + # tile. The '0-th' sequence never has a partial tile because it always starts at # the 0-th row. # - # If no group has a partial tile, the total number of tiles is equal to - # 'm // tm'. If every group has a partial except the 0-th group, the total - # number of tiles is equal to 'm // tm + num_groups - 1'. Thus we know that + # If no sequence has a partial tile, the total number of tiles is equal to + # 'm // tm'. If every sequence has a partial except the 0-th sequence, the total + # number of tiles is equal to 'm // tm + num_sequences - 1'. Thus we know that # - # tiles_m <= group_tiles.sum() <= tiles_m + num_groups - 1 + # tiles_m <= sequence_tiles.sum() <= tiles_m + num_sequences - 1 # # Where tiles_m = m // tm. # - # NOTE: All group sizes are divisible by 'tm' because of the rounding in steps + # NOTE: All sequence sizes are divisible by 'tm' because of the rounding in steps # (1) and (2) so this division is exact. - group_tiles = rounded_group_sizes // tm + sequence_tiles = rounded_sequence_sizes // tm - # Create the group ids for each grid index based on the tile counts for each - # group. + # Create the sequence ids for each grid index based on the tile counts for each + # sequence. # - # NOTE: This repeat(...) will pad group_ids with the final group id if - # group_tiles.sum() < tiles_m + num_groups - 1. The kernel grid will be sized + # NOTE: This repeat(...) will pad sequence_ids with the final sequence id if + # sequence_tiles.sum() < tiles_m + num_sequences - 1. The kernel grid will be sized # such that we only execute the necessary number of tiles. tiles_m = _calculate_num_tiles(m, tm) - group_ids = jnp.repeat( - jnp.arange(num_groups, dtype=jnp.int32), - group_tiles[:num_groups], # would it introduce dynamic shape to impact JIT? - total_repeat_length=tiles_m + num_groups - 1, + sequence_ids = jnp.repeat( + jnp.arange(num_sequences, dtype=jnp.int32), + sequence_tiles[:num_sequences], + total_repeat_length=tiles_m + num_sequences - 1, ) # Assign an m-dimension tile id to each grid index. @@ -230,22 +204,20 @@ def make_group_metadata( # (1) Calculate how many times each m-dimension tile will be visited. # - # Each tile is guaranteed to be visited once by the group that owns the tile. - # The remaining possible visits occur when a group starts inside of a tile at + # Each tile is guaranteed to be visited once by the sequence that owns the tile. + # The remaining possible visits occur when a sequence starts inside of a tile at # a position other than the first row. We can calculate which m-dimension tile - # each group starts in by floor-dividing its offset with `tm` and then count + # each sequence starts in by floor-dividing its offset with `tm` and then count # tile visits with a histogram. # - # To avoid double counting tile visits from the group that owns the tile, + # To avoid double counting tile visits from the sequence that owns the tile, # filter these out by assigning their tile id to `tile_m` (one beyond the max) - # such that they're ignored by the subsequent histogram. Also filter out any - # group which is empty. + # such that they're ignored by the subsequent histogram. # - # TODO(tgale): Invert the 'partial_tile_mask' predicates to be more clear. 
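To make the tile bookkeeping above concrete, the following standalone sketch (plain jax.numpy, outside the kernel; the intermediate names are local to this example) reproduces the metadata for the case used in test_make_sequence_metadata: cu_q_lens = [0, 192, 448, 512] with a q tile size tm = 128.

import jax.numpy as jnp

cu_q_lens = jnp.array([0, 192, 448, 512], dtype=jnp.int32)  # unpadded here for brevity
m, tm, num_sequences = 512, 128, 3
tiles_m = m // tm  # 4 physical q tiles

ends = cu_q_lens[1:]
starts = cu_q_lens[:-1]
rounded_ends = (ends + tm - 1) // tm * tm                 # [256, 512, 512]
rounded_starts = starts // tm * tm                        # [  0, 128, 384]
sequence_tiles = (rounded_ends - rounded_starts) // tm    # [  2,   3,   1]
num_logical_q_tiles = sequence_tiles.sum()                # 6

# Sequence id owning each logical q tile.
seq_ids = jnp.repeat(
    jnp.arange(num_sequences, dtype=jnp.int32),
    sequence_tiles,
    total_repeat_length=tiles_m + num_sequences - 1)      # [0, 0, 1, 1, 1, 2]

# Physical q tile for each logical q tile: each physical tile is visited once by
# its owner plus once per later sequence that starts somewhere inside it.
partial_tile_ids = jnp.where(starts % tm == 0, tiles_m, starts // tm)  # [4, 1, 3]
tile_visits = jnp.histogram(
    partial_tile_ids, bins=tiles_m, range=(0, tiles_m - 1))[0] + 1     # [1, 2, 1, 2]
physical_q_tile_ids = jnp.repeat(
    jnp.arange(tiles_m, dtype=jnp.int32),
    tile_visits.astype(jnp.int32),
    total_repeat_length=tiles_m + num_sequences - 1)      # [0, 1, 1, 2, 3, 3]

Logical tiles 1 and 2 both map to physical tile 1: sequence 0 spills 64 rows into it and sequence 1 starts inside it, which is exactly the situation the per-sequence store mask later in the kernel has to handle.
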
- partial_tile_mask = ((group_offsets[:-1] % tm) == 0) + partial_tile_mask = ((sequence_offsets[:-1] % tm) == 0) partial_tile_ids = jnp.where( - partial_tile_mask, tiles_m, group_offsets[:-1] // tm + partial_tile_mask, tiles_m, sequence_offsets[:-1] // tm ) tile_visits = ( @@ -258,27 +230,26 @@ def make_group_metadata( m_tile_ids = jnp.repeat( jnp.arange(tiles_m, dtype=jnp.int32), tile_visits.astype(jnp.int32), - total_repeat_length=tiles_m + num_groups - 1, + total_repeat_length=tiles_m + num_sequences - 1, ) # Account for sharding. # - # Find the start of the groups owned by our shard and shift the group_ids and + # Find the start of the sequences owned by our shard and shift the sequence_ids and # m_tile_ids s.t. the metadata for our tiles are at the front of the arrays. # - # TODO(tgale): Move this offset into the kernel to avoid these rolls. - first_tile_in_shard = (group_ids < start_group).sum() - group_ids = jnp.roll(group_ids, shift=-first_tile_in_shard, axis=0) + first_tile_in_shard = (sequence_ids < start_sequence).sum() + sequence_ids = jnp.roll(sequence_ids, shift=-first_tile_in_shard, axis=0) m_tile_ids = jnp.roll(m_tile_ids, shift=-first_tile_in_shard, axis=0) # Calculate the number of tiles we need to compute for our shard. # - # Remove tile visits that belong to a group not in our shard. - iota = jnp.arange(num_groups, dtype=jnp.int32) - active_group_mask = jnp.logical_and(iota <= end_group, iota >= start_group) - group_tiles = jnp.where(active_group_mask, group_tiles[:num_groups], 0) - num_tiles = group_tiles.sum() - return (group_ids, m_tile_ids), num_tiles # num_logical_q_tiles, seq_ids, physical_q_tile_ids + # Remove tile visits that belong to a sequence not in our shard. + iota = jnp.arange(num_sequences, dtype=jnp.int32) + active_sequence_mask = jnp.logical_and(iota <= end_sequence, iota >= start_sequence) + sequence_tiles = jnp.where(active_sequence_mask, sequence_tiles[:num_sequences], 0) + num_tiles = sequence_tiles.sum() + return (sequence_ids, m_tile_ids), num_tiles # (seq_ids, physical_q_tile_ids), num_logical_q_tiles def check_kernel_input(q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, num_seqs, num_kv_pages_per_block): @@ -305,7 +276,6 @@ def check_kernel_input(q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, for i in range(num_seqs): cur_q_len = cu_q_lens[i+1] - cu_q_lens[i] cur_kv_len = kv_lens[i] - jax.debug.print("xw32 line308 {i} {cur_q_len}, {cur_kv_len}", i=i, cur_q_len=cur_q_len, cur_kv_len=cur_kv_len) checkify.check(cur_q_len <= cur_kv_len, "cur_q_len must be less or equal to cur_kv_len. Got {} and {}", cur_q_len, cur_kv_len) if num_seqs > num_tokens: raise ValueError(f"num_seqs must be less or equal to num_tokens. Got {num_seqs} and {num_tokens}") @@ -353,7 +323,7 @@ def _flash_attention( effective_kv_lens_ref, # [num_tokens] effective_cu_q_lens_ref, # [num_tokens + 1] # kernel inputs - q_ref, # q_ref.shape=[num_q_heads_per_kv_head, num_queries_per_block, head_dim] + q_ref, # [num_q_heads_per_kv_head, num_queries_per_block, head_dim] k, # [kv_blk_size, head_dim] v, # [kv_blk_size, head_dim] # outputs @@ -361,7 +331,6 @@ def _flash_attention( l_ref, # [num_q_heads_per_kv_head, num_queries_per_block, MIN_BLOCK_SIZE] m_ref, # [num_q_heads_per_kv_head, num_queries_per_block, MIN_BLOCK_SIZE] # scratch space - # TODO: double check if the scratch ref shape is correct. 
l_scratch_ref, # [num_q_heads_per_kv_head, num_queries_per_block, MIN_BLOCK_SIZE] m_scratch_ref, # [num_q_heads_per_kv_head, num_queries_per_block, MIN_BLOCK_SIZE] acc_scratch_ref, # [num_q_heads_per_kv_head, num_queries_per_block, head_dim] @@ -379,14 +348,14 @@ def _flash_attention( kv_blk_size = page_size * num_kv_pages_per_block assert k.shape == (kv_blk_size, head_dim) assert v.shape == (kv_blk_size, head_dim) - + kv_head_idx, logical_q_blk_idx, kv_blk_idx = ( pl.program_id(0), pl.program_id(1), pl.program_id(2), ) seq_ids, physical_q_tile_ids = group_metadata_ref - + # If the q-dim physical tile is changed (meaning it is a new physical q-dim tile that has not visited before), initialize the acc_scratch_ref, m_scratch_ref, and l_scratch_ref to run the flash attention v2 algorithm. prev_logical_q_blk_idx = jnp.where(logical_q_blk_idx > 0, logical_q_blk_idx - 1, 0) is_first_processed_logical_q_blk = logical_q_blk_idx == 0 @@ -395,21 +364,21 @@ def _flash_attention( is_first_kv_blk = (kv_blk_idx == 0) should_init_scratch_ref = jnp.logical_and(is_first_kv_blk, first_time_seeing_physical_q_blk) + @pl.when(should_init_scratch_ref) def init_scratch_ref(): # pylint: disable=unused-variable - pl.debug_print("xw32 should_init_scratch_ref begins: kv_head_idx={}, logical_q_blk_idx={}, kv_blk_idx={}", kv_head_idx, logical_q_blk_idx, kv_blk_idx) l_scratch_ref[q_head_idx_per_kv] = jnp.zeros( l_scratch_ref[q_head_idx_per_kv].shape, jnp.float32) m_scratch_ref[q_head_idx_per_kv] = jnp.full( m_scratch_ref[q_head_idx_per_kv].shape, -jnp.inf, jnp.float32) acc_scratch_ref[q_head_idx_per_kv] = jnp.zeros( acc_scratch_ref[q_head_idx_per_kv].shape, jnp.float32) - + m_prev = m_scratch_ref[q_head_idx_per_kv] # [num_queries_per_block, MIN_BLOCK_SIZE] l_prev = l_scratch_ref[q_head_idx_per_kv] # [num_queries_per_block, MIN_BLOCK_SIZE] - + # Load the whole q_block that belongs to the current physical q_blk and compute the attention. When we write, we only write the part that belongs to the current sequence. - # I cannot just load only the part of q_block that belongs to the current sequence, because it results in dynamic shapes and then fails the JIT compilation. + # Cannot just load only the part of q_block that belongs to the current sequence, because it results in dynamic shapes and then fails the JIT compilation. # Note, q_ref.shape=[num_q_heads_per_kv_head, num_queries_per_block, head_dim] q = q_ref[q_head_idx_per_kv, :, :].astype(jnp.float32) # [block_q, head_dim] assert q.shape == (num_queries_per_block, head_dim) @@ -417,19 +386,13 @@ def init_scratch_ref(): # pylint: disable=unused-variable 'qd,td->qt', q, k, preferred_element_type=jnp.float32) # [block_q, block_k] assert s.shape == (num_queries_per_block, kv_blk_size) - - # Modify the mask accordingly: first form the mask. Then move the mask down to the right place. + + # Modify the mask accordingly: first form the mask. Then move the mask up/down to the right place. cur_seq_idx = seq_ids[logical_q_blk_idx] cur_seq_start = effective_cu_q_lens_ref[cur_seq_idx] cur_seq_end = effective_cu_q_lens_ref[cur_seq_idx+1] physical_q_blk_idx = physical_q_tile_ids[logical_q_blk_idx] - seq_start_in_cur_physical_q_blk = cur_seq_start >= physical_q_blk_idx*num_queries_per_block - # seq_start_idx_in_cur_physical_q_blk = jnp.where(seq_start_in_cur_physical_q_blk, - # cur_seq_start - physical_q_blk_idx*num_queries_per_block, - # 0) - # q_index = physical_q_blk_idx*num_queries_per_block - seq_start_idx_in_cur_physical_q_blk # start_q_idx_for_cur_seq_in_cur_physical_q_blk. 
TODO: let's rename num_queries_per_block to q_blk_size later. q_index = physical_q_blk_idx*num_queries_per_block-cur_seq_start - pl.debug_print("xw32 line423, kv_head_idx={}, logical_q_blk_idx={}, kv_blk_idx={}, q_index={}", kv_head_idx, logical_q_blk_idx, kv_blk_idx, q_index) kv_index = kv_blk_idx * kv_blk_size effective_kv_len = effective_kv_lens_ref[cur_seq_idx] effective_q_len = cur_seq_end - cur_seq_start @@ -440,22 +403,15 @@ def init_scratch_ref(): # pylint: disable=unused-variable col_ids = kv_index + jax.lax.broadcasted_iota( jnp.int32, (num_queries_per_block, kv_blk_size), 1) - causal_mask = jnp.where(row_ids < col_ids, mask_value, 0.) # TODO: use this mask. - # causal_mask_debug = jnp.where(row_ids < col_ids, -1, 0) # TODO: remove this line. - should_print_mask = jnp.logical_and(kv_head_idx==0, logical_q_blk_idx==2) - # @pl.when(should_print_mask) - # def print_mask(): # pylint: disable=unused-variable - # pl.debug_print("xw32 line438, causal_mask={}", causal_mask) + causal_mask = jnp.where(row_ids < col_ids, mask_value, 0.) assert causal_mask.shape == (num_queries_per_block, kv_blk_size) + s = s + causal_mask # [block_q, block_k] - assert s.shape == (num_queries_per_block, - kv_blk_size) - + m_curr = jnp.max(s, axis=1)[:, None] # Row max, shape [block_q, 1]. - # why the second dim of m_prev, m_curr, or m_next is 128? m_next = jnp.maximum(m_prev, m_curr) # Shape [block_q, 128]. - + block_k_repeats, rem = divmod(kv_blk_size, MIN_BLOCK_SIZE) if rem: raise NotImplementedError( @@ -478,7 +434,7 @@ def init_scratch_ref(): # pylint: disable=unused-variable else: raise NotImplementedError( f"{head_dim=} should be a multiple of {MIN_BLOCK_SIZE} if larger") - + # Need to store these l_next and m_next which will relay to the output. # But only update the part that belongs to the current sequence we are working on. lm_mask = _get_store_mask(grid_id=logical_q_blk_idx, @@ -488,15 +444,10 @@ def init_scratch_ref(): # pylint: disable=unused-variable tm=num_queries_per_block, tn=MIN_BLOCK_SIZE, ) - # Should I use jax.lax.select or jnp.where? What's the difference? eg: jnp.where(lm_mask, l_next, 0), jnp.where(lm_mask, m_next, 0) - # Can `lm_mask[...]` be `lm_mask`? + # Either jax.lax.select or jnp.where works here. 
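The row_ids/col_ids arithmetic above is easier to see on a toy case. Below is a standalone sketch (plain jax.numpy with made-up sizes, not kernel code): rows are shifted by effective_kv_len - effective_q_len so the last query token can see the whole kv prefix, and by q_index so the queries are positioned relative to their own sequence rather than the physical q tile.

import jax
import jax.numpy as jnp

mask_value = -0.7 * float(jnp.finfo(jnp.float32).max)

q_blk, kv_blk = 4, 8                       # toy block sizes
effective_q_len, effective_kv_len = 3, 6   # current sequence
q_index, kv_index = 0, 0                   # block offsets within the sequence / kv

row_ids = (effective_kv_len - effective_q_len) + q_index + jax.lax.broadcasted_iota(
    jnp.int32, (q_blk, kv_blk), 0)
col_ids = kv_index + jax.lax.broadcasted_iota(jnp.int32, (q_blk, kv_blk), 1)
causal_mask = jnp.where(row_ids < col_ids, mask_value, 0.0)
# Query 0 of the sequence attends kv positions 0..3, query 1 -> 0..4,
# query 2 -> 0..5. The fourth row of the block is padding; whatever it computes
# is discarded by the store mask, so its mask values are irrelevant.
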
l_scratch_ref[q_head_idx_per_kv] = jax.lax.select(lm_mask[...], l_next, l_scratch_ref[q_head_idx_per_kv]) m_scratch_ref[q_head_idx_per_kv] = jax.lax.select(lm_mask[...], m_next, m_scratch_ref[q_head_idx_per_kv]) - - # @pl.when(should_print_mask) - # def _(): # pylint: disable=unused-variable - # print("xw32 line492, l_next.shape={}, ", l_next.shape) - # pl.debug_print("xw32 line492, l_next[6]={}", l_next[6]) + l_next_inv_safe = jnp.where(l_next == 0.0, 1.0, 1.0 / l_next) # [block_q, 128] temp = acc_scratch_ref[q_head_idx_per_kv] * l_broadcast(l_corr * l_next_inv_safe) @@ -507,19 +458,13 @@ def init_scratch_ref(): # pylint: disable=unused-variable tm=num_queries_per_block, tn=head_dim, ) - print(f"xw32 line486 {acc_mask.shape=}, {temp.shape=}, {acc_scratch_ref[q_head_idx_per_kv]=}") acc_scratch_ref[q_head_idx_per_kv] = jax.lax.select(acc_mask[...], temp, acc_scratch_ref[q_head_idx_per_kv]) - # Note Matmul operandlhs must have a shape divisible by (16, 1) o_curr = jax.lax.dot( p.astype(v.dtype), v, preferred_element_type=jnp.float32) # [block_q, 128] temp = (acc_scratch_ref[q_head_idx_per_kv] + o_curr * l_broadcast(l_next_inv_safe)) - # @pl.when(should_print_mask) - # def _(): # pylint: disable=unused-variable - # print("xw32 line512, temp.shape={}", temp.shape) - # pl.debug_print("xw32 line512, temp={}", temp) acc_scratch_ref[q_head_idx_per_kv] = jax.lax.select(acc_mask[...], temp, acc_scratch_ref[q_head_idx_per_kv]) - + # Store the result from VMEM to HBM only when it is the last kv_block and the next q-dim logical tile belongs to a different q-dim physical tile. is_last_kv_blk_idx = (kv_blk_idx == (pl.cdiv(effective_kv_len, kv_blk_size) - 1)) num_logical_q_blks = pl.num_programs(1) # grid=(num_kv_heads, num_logical_q_tiles, num_kv_blks) @@ -532,7 +477,6 @@ def init_scratch_ref(): # pylint: disable=unused-variable should_store_to_hbm = jnp.logical_and(is_last_kv_blk_idx, last_time_seeing_cur_physical_q_blk) @pl.when(should_store_to_hbm) def store_to_hbm(): # pylint: disable=unused-variable - pl.debug_print("xw32 store_to_hbm begins: kv_head_idx={}, logical_q_blk_idx={}, kv_blk_idx={}", kv_head_idx, logical_q_blk_idx, kv_blk_idx) o_ref[q_head_idx_per_kv] = acc_scratch_ref[q_head_idx_per_kv].astype( o_ref.dtype) l_ref[q_head_idx_per_kv] = l_scratch_ref[q_head_idx_per_kv].astype( @@ -540,9 +484,8 @@ def store_to_hbm(): # pylint: disable=unused-variable m_ref[q_head_idx_per_kv] = m_scratch_ref[q_head_idx_per_kv].astype( m_ref.dtype) - def paged_flash_attention_kernel( - # prefetch refs, in smem + # prefetch refs group_metadata_ref, # (seq_ids, physical_q_tile_ids) effective_kv_lens_ref, # [num_tokens] # 1d vector, results from page_indices.reshape(-1) where originally page_indices.shape=[num_tokens, pages_per_sequence] @@ -552,10 +495,10 @@ def paged_flash_attention_kernel( step_ref, # kernel inputs # At caller, q.shape= [num_q_heads, num_tokens, head_dim] - q_ref, # q_ref.shape=[num_q_heads_per_kv_head, num_queries_per_block, head_dim] - k_pages_hbm_ref, # shape=[num_kv_heads, total_num_pages, page_size, head_dim] + q_ref, # [num_q_heads_per_kv_head, num_queries_per_block, head_dim] + k_pages_hbm_ref, # [num_kv_heads, total_num_pages, page_size, head_dim] k_scales_pages_hbm_ref, - v_pages_hbm_ref, # shape=[num_kv_heads, total_num_pages, page_size, head_dim] + v_pages_hbm_ref, # [num_kv_heads, total_num_pages, page_size, head_dim] v_scales_pages_hbm_ref, # same shape as q_ref: [1, num_q_heads_per_kv_head, num_queries_per_block, head_dim], output # outputs @@ -572,15 +515,13 @@ def 
paged_flash_attention_kernel( m_scratch_ref, acc_scratch_ref, *, - # Where do the following parameter live? SMEM? Not in smem. Not to pass in mosaic. Static value. + # The following parameters are not passed to Mosaic and not in SMEM. They are static values. pages_per_sequence: int, # Note [bs, pages_per_sequence] = page_indices.shape num_tokens: int, num_seqs: int, num_kv_pages_per_block: int, mask_value: float, ): - # assert the input shapes - print(f"xw32 line283 paged_flash_attention_kernel begins. q_ref.shape={q_ref.shape}") kv_head_idx, logical_q_blk_idx, kv_blk_idx = ( pl.program_id(0), pl.program_id(1), @@ -590,28 +531,27 @@ def paged_flash_attention_kernel( num_q_heads_per_kv_head, num_queries_per_block, head_dim = q_ref.shape num_kv_heads, total_num_pages, page_size, head_dim = k_pages_hbm_ref.shape kv_blk_size = page_size * num_kv_pages_per_block - + seq_ids, physical_q_tile_ids = group_metadata_ref cur_seq_idx = seq_ids[logical_q_blk_idx] effective_kv_len_cur_seq = effective_kv_lens_ref[cur_seq_idx] should_run = (kv_blk_idx * kv_blk_size < effective_kv_len_cur_seq) - pl.debug_print("xw32 paged_flash_attention_kernel begins kv_head_idx={}, logical_q_blk_idx={}, kv_blk_idx={}, cur_seq_idx={}, effective_kv_len_cur_seq={}", kv_head_idx, logical_q_blk_idx, kv_blk_idx, cur_seq_idx, effective_kv_len_cur_seq) # pl.debug_print can only print JAX type. So cannot print tuple such as q.shape. - + @pl.when(should_run) def get_kv_and_run_flash_attention(): # grid = (num_kv_heads, num_logical_q_tiles, num_kv_blks) def compute_block_indices(kv_head_idx, logical_q_blk_idx, kv_blk_idx): """Return next_kv_head_idx, next_logical_q_blk_idx, next_kv_blk_idx - + Note, k_pages has shape [num_kv_heads, total_num_pages, page_size, head_dim]. To get the KV, it needs the kv_head_idx, then we need the sequence_idx and the kv_blk_idx to get the offset. 
""" - + def advance_kv_head_idx(): next_kv_head_idx = kv_head_idx + 1 return next_kv_head_idx, 0, 0 - + def advance_logical_q_blk_idx(): next_logical_q_blk_idx = logical_q_blk_idx + 1 return lax.cond( @@ -619,7 +559,7 @@ def advance_logical_q_blk_idx(): lambda: (kv_head_idx, next_logical_q_blk_idx, 0), advance_kv_head_idx, ) - + cur_seq_idx = seq_ids[logical_q_blk_idx] effective_kv_len_cur_seq = effective_kv_lens_ref[cur_seq_idx] return lax.cond( @@ -627,7 +567,7 @@ def advance_logical_q_blk_idx(): lambda: (kv_head_idx, logical_q_blk_idx, kv_blk_idx), advance_logical_q_blk_idx, ) - + def create_kv_async_copy_descriptors(seq_idx, kv_head_idx, kv_blk_idx, buffer_index): page_offset = seq_idx * pages_per_sequence + kv_blk_idx * num_kv_pages_per_block @@ -663,27 +603,23 @@ def create_kv_async_copy_descriptors(seq_idx, kv_head_idx, kv_blk_idx, @pl.when(step == 0) def prefetch_first_block(): # pylint: disable=unused-variable - pl.debug_print("xw32 prefetch_first_block kv_head_idx={}, cur_seq_idx={}, kv_blk_idx={}, buffer_index={}", kv_head_idx, cur_seq_idx, kv_blk_idx, buffer_index) async_copy_k, async_copy_v = create_kv_async_copy_descriptors( cur_seq_idx, kv_head_idx, kv_blk_idx, buffer_index) async_copy_k.start() async_copy_v.start() - # kv_head_idx, logical_q_blk_idx, kv_blk_idx next_kv_head_idx, next_logical_q_blk_idx, next_kv_blk_idx = compute_block_indices(kv_head_idx, logical_q_blk_idx, kv_blk_idx+1) - + @pl.when(next_kv_head_idx < num_kv_heads) def prefetch_next_block(): # pylint: disable=unused-variable next_buffer_index = jnp.where(buffer_index == 0, 1, 0) next_seq_idx = seq_ids[next_logical_q_blk_idx] - pl.debug_print("xw32 prefetch_next_block next_kv_head_idx={}, next_seq_idx={}, next_kv_blk_idx={}, buffer_index={}", next_kv_head_idx, next_seq_idx, next_kv_blk_idx, next_buffer_index) async_copy_next_k, async_copy_next_v = create_kv_async_copy_descriptors( next_seq_idx, next_kv_head_idx, next_kv_blk_idx, next_buffer_index) async_copy_next_k.start() async_copy_next_v.start() buffer_index_ref[0] = next_buffer_index - - # xw32: is the async_copy_k and async_copy_v the same as the ones created in prefetch_first_block? + async_copy_k, async_copy_v = create_kv_async_copy_descriptors( cur_seq_idx, kv_head_idx, kv_blk_idx, buffer_index) k = async_copy_k.wait_and_get_loaded( @@ -691,7 +627,7 @@ def prefetch_next_block(): # pylint: disable=unused-variable v = async_copy_v.wait_and_get_loaded() assert k.shape == (num_kv_pages_per_block*page_size, head_dim) assert v.shape == (num_kv_pages_per_block*page_size, head_dim) - + for q_head_idx in range(num_q_heads_per_kv_head): _flash_attention( q_head_idx, @@ -699,7 +635,7 @@ def prefetch_next_block(): # pylint: disable=unused-variable effective_kv_lens_ref, effective_cu_q_lens_ref, # kernel inputs - q_ref, # q_ref.shape=[num_q_heads_per_kv_head, num_queries_per_block, head_dim] + q_ref, # [num_q_heads_per_kv_head, num_queries_per_block, head_dim] k, v, # outputs @@ -725,7 +661,6 @@ def prefetch_next_block(): # pylint: disable=unused-variable MIN_BLOCK_SIZE = 128 -# TODO(xw32): uncomment this once the kernel output is correct. 
@checkify.checkify @functools.partial( jax.jit, @@ -743,7 +678,7 @@ def ragged_paged_attention( kv_lens: jax.Array, # i32[num_tokens] page_indices: jax.Array, # i32[num_tokens, pages_per_sequence] cu_q_lens: jax.Array, # i32[num_tokens + 1] - num_seqs, # i32[] + num_seqs, # int *, mask_value: float = DEFAULT_MASK_VALUE, num_kv_pages_per_block: int = 16, @@ -758,7 +693,7 @@ def ragged_paged_attention( kv_lens: A i32[num_tokens] jax.Array the effective kv length of each sequence. For example, if we have three sequences, lengths could be [16, 3, 1024, x, x, x, x, ...] where x is any value for padding. While - lengths’s shape is [num_tokens], only the first num_seqs values are valid. + lengths's shape is [num_tokens], only the first num_seqs values are valid. The rest should be ignored. page_indices: A i32[num_tokens, pages_per_sequence] jax.Array. Each entry should be in the range of [0, total_num_pages), indicating where to locate @@ -767,7 +702,7 @@ def ragged_paged_attention( cu_q_lens: A i32[num_tokens+1] jax.Array the cumulative sum of the effective query lengths. Similar to kv_lens, only the first num_seqs+1 values are valid. - num_seqs: A i32[] jax.Array the number of sequences. + num_seqs: the number of sequences. mask_value: The value used for padding in attention. By default it is a very negative floating point number. num_kv_pages_per_block: how many kv pages to be processed in one flash @@ -776,7 +711,7 @@ def ragged_paged_attention( attention block in the pallas kernel. Returns: - The output of attention([num_tokens, query_len, num_q_heads, head_dim]). + The output of attention([num_tokens, num_q_heads, head_dim]). """ # TODO: consider remove the k_scales_pages and v_scales_pages during cleaning up. if isinstance(k_pages, quantization_utils.QuantizedTensor): @@ -795,36 +730,30 @@ def ragged_paged_attention( v_scales_pages = None num_tokens, num_q_heads, head_dim = q.shape + # Why the permute_dims is needed? Before permute, q.shape=[num_tokens, num_q_heads, head_dim]; then when we apply the GridSpec, the 2nd last dimension is num_q_heads which is hard to be a multiple of 8. # If permute_dims turns out to be expensive, try jnp.swapaxes. The compiler # may optimize the copies away. - # Or consider unsqueeze a dimension at the 2nd last dimension and squeeze it - # out later. - # jevin: can we not do the permute_dims? - # Why the permute_dims is needed? Before permute, q.shape=[num_tokens, num_q_heads, head_dim]; then when we apply the GridSpec, the 2nd last dimension is num_q_heads which is hard to be a multiple of 8. + # Or consider unsqueeze a dimension at the 2nd last dimension and squeeze it + # out later so that num_q_heads doesn't have to be the 2nd last dimension and hence doesn't subject to the multiple of 8 constraint. 
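Looking back at the prefetch logic in paged_flash_attention_kernel above: compute_block_indices decides which (kv head, logical q block, kv block) the next grid step will actually touch, so the right K/V pages can be copied into the spare VMEM buffer one step ahead. The exact lax.cond predicates are elided by the hunk boundaries, so the plain-Python paraphrase below fills them in from context and is illustrative only.

def next_block_indices(kv_head_idx, logical_q_blk_idx, kv_blk_idx,
                       seq_ids, kv_lens, kv_blk_size, num_logical_q_tiles):
  # Keep going on the current (kv head, logical q block) while it still has
  # kv blocks left to consume.
  if kv_blk_idx * kv_blk_size < kv_lens[seq_ids[logical_q_blk_idx]]:
    return kv_head_idx, logical_q_blk_idx, kv_blk_idx
  # Otherwise move to the next logical q block, restarting at kv block 0.
  if logical_q_blk_idx + 1 < num_logical_q_tiles:
    return kv_head_idx, logical_q_blk_idx + 1, 0
  # All logical q blocks for this kv head are done: advance the kv head.
  return kv_head_idx + 1, 0, 0

The kernel passes kv_blk_idx + 1 as the candidate and skips the prefetch entirely once the returned kv head index reaches num_kv_heads.
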
q = jnp.permute_dims(q, (1, 0, 2)) # [num_q_heads, num_tokens, head_dim] num_kv_heads, total_num_pages, page_size, head_dim = k_pages.shape check_kernel_input(q, k_pages, v_pages,kv_lens, page_indices, cu_q_lens, num_seqs, num_kv_pages_per_block) num_q_heads_per_kv_head = num_q_heads // num_kv_heads - group_metadata, num_logical_q_tiles = make_group_metadata( + group_metadata, num_logical_q_tiles = make_sequence_metadata( cu_q_lens=cu_q_lens, m=num_tokens, tm=num_queries_per_block, - start_group=jnp.array([0]), - num_seqs=num_seqs, + start_sequence=jnp.array([0]), + num_sequences=num_seqs, ) - seq_ids, physical_q_tile_ids = group_metadata - pl.debug_print("xw32 line797 seq_ids={}, physical_q_tile_ids={}, num_logical_q_tiles={}", seq_ids, physical_q_tile_ids, num_logical_q_tiles) pages_per_sequence = page_indices.shape[1] num_kv_blks = pages_per_sequence // num_kv_pages_per_block - # num_logical_q_tiles has type jnp.ndarray. So we need the .item() below. grid = (num_kv_heads, num_logical_q_tiles, num_kv_blks) - print(f"xw32 line367 grid={grid}") # out_shape o_shape = jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype) - # xw32: need to double check that the out_shape of l and m are correct. l = jax.ShapeDtypeStruct((num_q_heads, num_tokens, MIN_BLOCK_SIZE), dtype=jnp.float32) m = jax.ShapeDtypeStruct((num_q_heads, num_tokens, MIN_BLOCK_SIZE), @@ -842,11 +771,9 @@ def qo_index_map(kv_head_idx, logical_q_blk_idx, kv_blk_idx, group_metadata, *_) (num_q_heads_per_kv_head, num_queries_per_block, head_dim), qo_index_map, ) - q_dtype_for_kernel_launch = q.dtype in_specs = [ q_block_spec, # Below 4 correspond to the 4 input: k_pages, k_scales_pages, q_pages, q_scales_pages. - # TODO: consider to remove the k_scales_pages and v_scales_pages during cleaning up. pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), None, pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), @@ -854,20 +781,10 @@ def qo_index_map(kv_head_idx, logical_q_blk_idx, kv_blk_idx, group_metadata, *_) ] # out_spec - # jevin: create a qo spec and reuse it. - o_specs = pl.BlockSpec( # Should be the same as q_block_spec - (num_q_heads_per_kv_head, num_queries_per_block, head_dim), - qo_index_map, - ) - + # o_specs should be the same as q_block_spec + o_specs = q_block_spec # lm_index_map is same as qo_index_map - # TODO: think about reusing q_indx_map. - def lm_index_map(kv_head_idx, logical_q_blk_idx, kv_blk_idx, group_metadata, *_): - seq_ids, physical_q_tile_ids = group_metadata - del seq_ids - physical_q_blk_idx = physical_q_tile_ids[logical_q_blk_idx] - return (kv_head_idx, physical_q_blk_idx, 0) - + lm_index_map = qo_index_map out_specs = [ o_specs, pl.BlockSpec( @@ -920,19 +837,19 @@ def lm_index_map(kv_head_idx, logical_q_blk_idx, kv_blk_idx, group_metadata, *_) paged_flash_attention_kernel, pages_per_sequence=pages_per_sequence, num_tokens=num_tokens, - num_seqs=num_seqs, # it they changes, need to recompile. + num_seqs=num_seqs, num_kv_pages_per_block=num_kv_pages_per_block, mask_value=mask_value, ), grid_spec=pltpu.PrefetchScalarGridSpec( - num_scalar_prefetch=6, # TODO(xw32): may need to adjust. + num_scalar_prefetch=6, in_specs=in_specs, out_specs=out_specs, grid=grid, scratch_shapes=scratch_shapes, ), compiler_params=pltpu.TPUCompilerParams( - # due to compute_block_indices, we loop batch, kv_head, q_blk, kv_blk, the order matters. + # due to compute_block_indices, we loop kv_head, q_blk, kv_blk, the order matters. 
dimension_semantics=( "arbitrary", "arbitrary", @@ -958,7 +875,7 @@ def lm_index_map(kv_head_idx, logical_q_blk_idx, kv_blk_idx, group_metadata, *_) # buffer_index, # step, # # kernel inputs - # q.astype(q_dtype_for_kernel_launch), # TODO: do we need the `.astype`? Need to double check. + # q, # k_pages, # k_scales_pages, # v_pages, @@ -975,7 +892,7 @@ def lm_index_map(kv_head_idx, logical_q_blk_idx, kv_blk_idx, group_metadata, *_) # buffer_index, # step, # # kernel inputs - # q.astype(q_dtype_for_kernel_launch), # TODO: do we need the `.astype`? Need to double check. + # q, # k_pages, # k_scales_pages, # v_pages, @@ -992,12 +909,11 @@ def lm_index_map(kv_head_idx, logical_q_blk_idx, kv_blk_idx, group_metadata, *_) buffer_index, step, # kernel inputs - q.astype(q_dtype_for_kernel_launch), # TODO: do we need the `.astype`? Need to double check. + q, k_pages, k_scales_pages, v_pages, v_scales_pages, ) ret = outputs[0] - # print(f"xw32 line495 ret.shape={ret.shape}, {ret=}") return jnp.permute_dims(ret, (1, 0, 2)).astype(q.dtype) From ab69feb78407bfc69f203255ec453f61be836945 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Fri, 31 Jan 2025 06:03:11 +0000 Subject: [PATCH 6/9] linter --- test/test_ragged_paged_attention_kernel.py | 163 ++++++++------- .../ragged_paged_attention_kernel.py | 195 +++++++++++------- 2 files changed, 201 insertions(+), 157 deletions(-) diff --git a/test/test_ragged_paged_attention_kernel.py b/test/test_ragged_paged_attention_kernel.py index bd98e0b5e1c5..95f3d937da44 100644 --- a/test/test_ragged_paged_attention_kernel.py +++ b/test/test_ragged_paged_attention_kernel.py @@ -9,7 +9,6 @@ import jax.numpy as jnp import numpy as np - jax.config.parse_flags_with_absl() ATOL_FP32 = 2e-1 @@ -32,15 +31,20 @@ def _ref_ragged_paged_attention( start_idx = 0 outputs: List[jax.Array] = [] for i in range(num_seqs): - cur_q_len = cu_q_lens[i+1] - cu_q_lens[i] - q = queries[start_idx:start_idx+cur_q_len] # [cur_q_len, num_q_heads, head_dim] + cur_q_len = cu_q_lens[i + 1] - cu_q_lens[i] + q = queries[start_idx:start_idx + + cur_q_len] # [cur_q_len, num_q_heads, head_dim] cur_kv_len = kv_lens[i] num_pages = (cur_kv_len + page_size - 1) // page_size page_indices_to_use = page_indices[i, :num_pages] - k = k_pages[:, page_indices_to_use, :, :] # [num_kv_heads, page_indices_to_use, page_size, head_dim] - k = jnp.permute_dims(k, (1, 2, 0, 3)) # [page_indices_to_use, page_size, num_kv_heads, head_dim] - k = jnp.reshape(k, (-1, num_kv_heads, head_dim)) # [kv_len, num_kv_heads, head_dim] + k = k_pages[:, + page_indices_to_use, :, :] # [num_kv_heads, page_indices_to_use, page_size, head_dim] + k = jnp.permute_dims( + k, (1, 2, 0, + 3)) # [page_indices_to_use, page_size, num_kv_heads, head_dim] + k = jnp.reshape( + k, (-1, num_kv_heads, head_dim)) # [kv_len, num_kv_heads, head_dim] k = k[:cur_kv_len] # [cur_kv_lens, num_kv_heads, head_dim] v = v_pages[:, page_indices_to_use, :, :] @@ -55,15 +59,15 @@ def _ref_ragged_paged_attention( attn = jnp.einsum("qhd,khd->hqk", q, k) attn = attn.astype('float32') q_span = (cur_kv_len - cur_q_len) + jax.lax.broadcasted_iota( - jnp.int32, (cur_q_len, cur_kv_len), 0 - ) + jnp.int32, (cur_q_len, cur_kv_len), 0) kv_span = jax.lax.broadcasted_iota(jnp.int32, (cur_q_len, cur_kv_len), 1) # Use the same DEFAULT_MASK_VALUE as in the kernel instead of float("-inf") so that the kernel can match the ref implement better. mask = jnp.where(q_span < kv_span, DEFAULT_MASK_VALUE, 0.) 
with jax.numpy_rank_promotion("allow"): attn = attn + mask attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype) - out = jnp.einsum("hqk,khd->qhd", attn, v) # [cur_q_len, num_q_heads, head_dim] + out = jnp.einsum("hqk,khd->qhd", attn, + v) # [cur_q_len, num_q_heads, head_dim] outputs.append(out) start_idx += cur_q_len @@ -102,15 +106,12 @@ def _verify_ragged_paged_attention( prng_key = jax.random.key(0) k1, k2, k3, k4 = jax.random.split(prng_key, 4) - queries = jax.random.normal(k1, - (num_q_tokens, num_q_heads, head_dim), - dtype=dtype) - k_pages = jax.random.normal(k2, - (num_kv_heads, num_pages, page_size, head_dim), - dtype=dtype) - v_pages = jax.random.normal(k3, - (num_kv_heads, num_pages, page_size, head_dim), - dtype=dtype) + queries = jax.random.normal( + k1, (num_q_tokens, num_q_heads, head_dim), dtype=dtype) + k_pages = jax.random.normal( + k2, (num_kv_heads, num_pages, page_size, head_dim), dtype=dtype) + v_pages = jax.random.normal( + k3, (num_kv_heads, num_pages, page_size, head_dim), dtype=dtype) # Create a kv_lens: i32[num_tokens] kv_lens_with_paddings = [0] * num_q_tokens @@ -121,19 +122,24 @@ def _verify_ragged_paged_attention( # Create a page_indices: jax.Array, # i32[num_tokens, pages_per_sequence] max_kv_len = max([seq_len[1] for seq_len in seq_lens]) max_num_pages_per_seq = (max_kv_len + page_size - 1) // page_size - # The reason why we need to pad max_num_pages_per_seq is that + # The reason why we need to pad max_num_pages_per_seq is that # page_indices[1]=max_num_pages_per_seq and max_num_pages_per_seq%num_kv_pages_per_compute_block==0 - max_num_pages_per_seq = self._get_closest_power_of_two(max_num_pages_per_seq) + max_num_pages_per_seq = self._get_closest_power_of_two( + max_num_pages_per_seq) # The assert below mimics the reality that each page get a unique index. # But for testing, the assert could be omitted. # assert max_num_pages_per_seq*num_q_tokens <= num_pages, f"assert failed: max_num_pages_per_seq*num_q_tokens < num_pages. 
Got {max_num_pages_per_seq*num_q_tokens} and {num_pages}" - page_indices = jax.random.randint(k4, (num_q_tokens, max_num_pages_per_seq), 0, num_pages, dtype=jnp.int32) + page_indices = jax.random.randint( + k4, (num_q_tokens, max_num_pages_per_seq), + 0, + num_pages, + dtype=jnp.int32) # Create a cu_q_lens: jax.Array, # i32[num_tokens + 1] q_lens_with_paddings = [0] * num_q_tokens for i in range(num_seqs): q_lens_with_paddings[i] = query_lens[i] - cu_q_lens = jnp.cumsum(jnp.array([0]+q_lens_with_paddings)) + cu_q_lens = jnp.cumsum(jnp.array([0] + q_lens_with_paddings)) err, actual_output = ragged_paged_attention( queries, @@ -161,8 +167,11 @@ def _verify_ragged_paged_attention( self.assertEqual(actual_output.shape, expected_output.shape) self.assertEqual(actual_output.dtype, expected_output.dtype) - print(f'Output max diff: {jnp.max(jnp.abs(expected_output - actual_output))}') - print(f'Output mean diff: {jnp.mean(jnp.abs(expected_output - actual_output))}') + print( + f'Output max diff: {jnp.max(jnp.abs(expected_output - actual_output))}') + print( + f'Output mean diff: {jnp.mean(jnp.abs(expected_output - actual_output))}' + ) if dtype == jnp.float32: atol = 2e-1 rtol = 1e-2 @@ -171,16 +180,15 @@ def _verify_ragged_paged_attention( rtol = 1e-1 else: self.fail(f'Unsupported dtype: {dtype}') - self.assertTrue(jnp.allclose(actual_output, expected_output, atol=atol, rtol=rtol)) + self.assertTrue( + jnp.allclose(actual_output, expected_output, atol=atol, rtol=rtol)) def _get_closest_power_of_two(self, x): if x <= 0: raise ValueError(f"x must be positive. Got {x}") - return 2 ** int(np.ceil(np.log2(x))) + return 2**int(np.ceil(np.log2(x))) - def test_paged_attention_basic( - self, - ): + def test_paged_attention_basic(self,): # Same setup as in the design doc. # assuming q_blk_size=128, page_size=16, num_kv_pages_per_compute_block=16 # Note one of the constraints of the kernel is that q.shape[0]%q_blk_size==0 as in _calculate_num_tiles. @@ -226,13 +234,22 @@ def test_paged_attention_varlen_comprehensive( dtype, num_pages, ) - - def test_paged_attention_mix_prefill_and_decode1( - self, - ): + + def test_paged_attention_mix_prefill_and_decode1(self,): # assuming q_blk_size=128 - seq_lens = [(1, 1328), (5, 18), (1, 129), (120, 229), (1, 122), # first physical q block - (1, 64), (32, 100), (250, 463), (1, 18), (1, 17), (99, 123)] # last 3 physical q blocks [(q_len, kv_len),...] + seq_lens = [ + (1, 1328), + (5, 18), + (1, 129), + (120, 229), + (1, 122), # first physical q block + (1, 64), + (32, 100), + (250, 463), + (1, 18), + (1, 17), + (99, 123) + ] # last 3 physical q blocks [(q_len, kv_len),...] num_heads = (4, 4) head_dim = 128 dtype = jnp.float32 @@ -248,11 +265,10 @@ def test_paged_attention_mix_prefill_and_decode1( num_pages, ) - def test_paged_attention_mix_prefill_and_decode2( - self, - ): + def test_paged_attention_mix_prefill_and_decode2(self,): # assuming q_blk_size=128 - seq_lens = [(1, 127), (120, 1328), (1, 64), (1, 64), (1, 64), (1, 64), (256, 256), (131, 463)] # [(q_len, kv_len),...] + seq_lens = [(1, 127), (120, 1328), (1, 64), (1, 64), (1, 64), (1, 64), + (256, 256), (131, 463)] # [(q_len, kv_len),...] num_heads = (1, 1) head_dim = 128 page_size = 16 @@ -268,9 +284,7 @@ def test_paged_attention_mix_prefill_and_decode2( num_pages, ) - def test_paged_attention_extreme_all_tokens_belong_to_one_sequence( - self, - ): + def test_paged_attention_extreme_all_tokens_belong_to_one_sequence(self,): # assuming q_blk_size=128 seq_lens = [(512, 1328)] # [(q_len, kv_len),...] 
num_heads = (1, 1) @@ -288,14 +302,12 @@ def test_paged_attention_extreme_all_tokens_belong_to_one_sequence( num_pages, ) - def test_paged_attention_extreme_one_tokens_per_sequence_min( - self, - ): + def test_paged_attention_extreme_one_tokens_per_sequence_min(self,): seq_lens = [] # [(q_len, kv_len),...] num_seqs = 64 - num_queries_per_block=16 + num_queries_per_block = 16 for i in range(num_seqs): - seq_lens.append((1, 256+i)) + seq_lens.append((1, 256 + i)) num_heads = (1, 1) head_dim = 128 page_size = 16 @@ -312,9 +324,7 @@ def test_paged_attention_extreme_one_tokens_per_sequence_min( num_queries_per_block=num_queries_per_block, ) - def test_paged_attention_q_len_should_be_no_longer_than_kv_len( - self, - ): + def test_paged_attention_q_len_should_be_no_longer_than_kv_len(self,): # assuming q_blk_size=128 seq_lens = [(1, 0), (511, 256)] # [(q_len, kv_len),...] num_heads = (1, 1) @@ -333,15 +343,12 @@ def test_paged_attention_q_len_should_be_no_longer_than_kv_len( prng_key = jax.random.key(0) k1, k2, k3, k4 = jax.random.split(prng_key, 4) - queries = jax.random.normal(k1, - (num_q_tokens, num_q_heads, head_dim), - dtype=dtype) - k_pages = jax.random.normal(k2, - (num_kv_heads, num_pages, page_size, head_dim), - dtype=dtype) - v_pages = jax.random.normal(k3, - (num_kv_heads, num_pages, page_size, head_dim), - dtype=dtype) + queries = jax.random.normal( + k1, (num_q_tokens, num_q_heads, head_dim), dtype=dtype) + k_pages = jax.random.normal( + k2, (num_kv_heads, num_pages, page_size, head_dim), dtype=dtype) + v_pages = jax.random.normal( + k3, (num_kv_heads, num_pages, page_size, head_dim), dtype=dtype) # Create a kv_lens: i32[num_tokens] kv_lens_with_paddings = [0] * num_q_tokens @@ -352,22 +359,29 @@ def test_paged_attention_q_len_should_be_no_longer_than_kv_len( # Create a page_indices: jax.Array, # i32[num_tokens, pages_per_sequence] max_kv_len = max([seq_len[1] for seq_len in seq_lens]) max_num_pages_per_seq = (max_kv_len + page_size - 1) // page_size - # The reason why we need to pad max_num_pages_per_seq is that + # The reason why we need to pad max_num_pages_per_seq is that # page_indices[1]=max_num_pages_per_seq and max_num_pages_per_seq%num_kv_pages_per_compute_block==0 - max_num_pages_per_seq = self._get_closest_power_of_two(max_num_pages_per_seq) + max_num_pages_per_seq = self._get_closest_power_of_two( + max_num_pages_per_seq) # The assert below mimics the reality that each page get a unique index. # But for testing, the assert could be omitted. - assert max_num_pages_per_seq*num_q_tokens <= num_pages, f"assert failed: max_num_pages_per_seq*num_q_tokens < num_pages. Got {max_num_pages_per_seq*num_q_tokens} and {num_pages}" - page_indices = jax.random.randint(k4, (num_q_tokens, max_num_pages_per_seq), 0, num_pages, dtype=jnp.int32) + assert max_num_pages_per_seq * num_q_tokens <= num_pages, f"assert failed: max_num_pages_per_seq*num_q_tokens < num_pages. 
Got {max_num_pages_per_seq*num_q_tokens} and {num_pages}" + page_indices = jax.random.randint( + k4, (num_q_tokens, max_num_pages_per_seq), + 0, + num_pages, + dtype=jnp.int32) # Create a cu_q_lens: jax.Array, # i32[num_tokens + 1] q_lens_with_paddings = [0] * num_q_tokens for i in range(num_seqs): q_lens_with_paddings[i] = query_lens[i] - cu_q_lens = jnp.cumsum(jnp.array([0]+q_lens_with_paddings)) + cu_q_lens = jnp.cumsum(jnp.array([0] + q_lens_with_paddings)) - with self.assertRaisesRegex(ValueError, "cur_q_len must be less or equal to cur_kv_len"): - err, _ = ragged_paged_attention(queries, + with self.assertRaisesRegex( + ValueError, "cur_q_len must be less or equal to cur_kv_len"): + err, _ = ragged_paged_attention( + queries, k_pages, v_pages, kv_lens_np, @@ -377,14 +391,12 @@ def test_paged_attention_q_len_should_be_no_longer_than_kv_len( ) err.throw() - def test_paged_attention_extreme_one_tokens_per_sequence_large( - self, - ): + def test_paged_attention_extreme_one_tokens_per_sequence_large(self,): # assuming q_blk_size=128 seq_lens = [] # [(q_len, kv_len),...] num_seqs = 512 for i in range(num_seqs): - seq_lens.append((1, 128+i)) + seq_lens.append((1, 128 + i)) num_heads = (1, 1) head_dim = 128 page_size = 16 @@ -400,21 +412,18 @@ def test_paged_attention_extreme_one_tokens_per_sequence_large( num_pages, ) - def test_make_sequence_metadata( - self, - ): - cu_q_lens = jnp.array([0, 192, 448, 512] + [512]*(512-4)) + def test_make_sequence_metadata(self,): + cu_q_lens = jnp.array([0, 192, 448, 512] + [512] * (512 - 4)) num_q_tokens = 512 num_queries_per_compute_block = 128 start_group = jnp.array([0]) num_seqs = 3 - metadata, num_logical_q_tiles = make_sequence_metadata( + metadata, num_logical_q_tiles = make_sequence_metadata( cu_q_lens=cu_q_lens, m=num_q_tokens, tm=num_queries_per_compute_block, start_sequence=start_group, - num_sequences=num_seqs - ) + num_sequences=num_seqs) seq_ids, physical_q_tile_ids = metadata self.assertEqual(num_logical_q_tiles, 6) self.assertTrue(jnp.array_equal(seq_ids, [0, 0, 1, 1, 1, 2])) diff --git a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py index 863ef1bf3c8b..bee0d3ca8d43 100644 --- a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py +++ b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py @@ -12,9 +12,9 @@ import jax.numpy as jnp import numpy as np - DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max) + class MultiPageAsyncCopyDescriptor: """Descriptor for async copy of multiple K/V pages from HBM.""" @@ -100,6 +100,7 @@ def _calculate_num_tiles(x: int, tx: int) -> int: raise ValueError(f"{x} must be divisible by x-dimension tile size ({tx}).") return tiles + # https://github.com/jax-ml/jax/blob/9fb29766a2130e74a85cba30420cf777d185ea5a/jax/experimental/pallas/ops/tpu/megablox/gmm.py#L79 def make_sequence_metadata( *, @@ -152,13 +153,12 @@ def make_sequence_metadata( # # NOTE: This does not change sequence_offsets[num_sequences], which is m # (because we enforce m is divisible by tm). - rounded_sequence_ends = ((sequence_ends + tm - 1) // tm * tm).astype(jnp.int32) - + rounded_sequence_ends = ((sequence_ends + tm - 1) // tm * tm).astype( + jnp.int32) # (2) Round the sequence_starts down to the nearest multiple of 'tm'. 
sequence_starts = jnp.concatenate( - [jnp.zeros(1, dtype=jnp.int32), sequence_ends[:-1]] - ) + [jnp.zeros(1, dtype=jnp.int32), sequence_ends[:-1]]) rounded_sequence_starts = sequence_starts // tm * tm # (3) Calculate the number of rows in each sequence. @@ -216,14 +216,12 @@ def make_sequence_metadata( # partial_tile_mask = ((sequence_offsets[:-1] % tm) == 0) - partial_tile_ids = jnp.where( - partial_tile_mask, tiles_m, sequence_offsets[:-1] // tm - ) + partial_tile_ids = jnp.where(partial_tile_mask, tiles_m, + sequence_offsets[:-1] // tm) tile_visits = ( - jnp.histogram(partial_tile_ids, bins=tiles_m, range=(0, tiles_m - 1))[0] - + 1 - ) + jnp.histogram(partial_tile_ids, bins=tiles_m, range=(0, tiles_m - 1))[0] + + 1) # Create the m-dimension tile ids for each grid index based on the visit # counts for each tile. @@ -246,10 +244,14 @@ def make_sequence_metadata( # # Remove tile visits that belong to a sequence not in our shard. iota = jnp.arange(num_sequences, dtype=jnp.int32) - active_sequence_mask = jnp.logical_and(iota <= end_sequence, iota >= start_sequence) - sequence_tiles = jnp.where(active_sequence_mask, sequence_tiles[:num_sequences], 0) + active_sequence_mask = jnp.logical_and(iota <= end_sequence, + iota >= start_sequence) + sequence_tiles = jnp.where(active_sequence_mask, + sequence_tiles[:num_sequences], 0) num_tiles = sequence_tiles.sum() - return (sequence_ids, m_tile_ids), num_tiles # (seq_ids, physical_q_tile_ids), num_logical_q_tiles + return (sequence_ids, m_tile_ids + ), num_tiles # (seq_ids, physical_q_tile_ids), num_logical_q_tiles + def check_kernel_input(q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, num_seqs, num_kv_pages_per_block): @@ -268,17 +270,24 @@ def check_kernel_input(q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, raise ValueError("kv_lens.shape[0] must be thet same as num_tokens. Got" f" {kv_lens.shape[0]} and {num_tokens}") if page_indices.shape[0] != num_tokens: - raise ValueError("page_indices.shape[0] must be thet same as num_tokens. Got" - f" {page_indices.shape[0]} and {num_tokens}") + raise ValueError( + "page_indices.shape[0] must be thet same as num_tokens. Got" + f" {page_indices.shape[0]} and {num_tokens}") if cu_q_lens.shape[0] != num_tokens + 1: - raise ValueError("cu_q_lens.shape[0] must be thet same as num_tokens + 1. Got" - f" {cu_q_lens.shape[0]} and {num_tokens + 1}") + raise ValueError( + "cu_q_lens.shape[0] must be thet same as num_tokens + 1. Got" + f" {cu_q_lens.shape[0]} and {num_tokens + 1}") for i in range(num_seqs): - cur_q_len = cu_q_lens[i+1] - cu_q_lens[i] + cur_q_len = cu_q_lens[i + 1] - cu_q_lens[i] cur_kv_len = kv_lens[i] - checkify.check(cur_q_len <= cur_kv_len, "cur_q_len must be less or equal to cur_kv_len. Got {} and {}", cur_q_len, cur_kv_len) + checkify.check( + cur_q_len <= cur_kv_len, + "cur_q_len must be less or equal to cur_kv_len. Got {} and {}", + cur_q_len, cur_kv_len) if num_seqs > num_tokens: - raise ValueError(f"num_seqs must be less or equal to num_tokens. Got {num_seqs} and {num_tokens}") + raise ValueError( + f"num_seqs must be less or equal to num_tokens. Got {num_seqs} and {num_tokens}" + ) # int16: will pack. need to explicit cast to int32. int64 is not supported in Pallas. for smem 1d case. # 2d smem: int16 will be packed with an empty. So we didn't save any memory. # scalar: use i32 (1, N). int16 for (1, N) will be padding. Need to use (2, N). 
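One practical consequence of enforcing cur_q_len <= cur_kv_len with checkify.check rather than a plain raise is that the error only materializes when the caller throws it. Mirroring the tests, the call pattern is (shapes and setup omitted):

# ragged_paged_attention is wrapped in @checkify.checkify, so it returns an
# (error, output) pair; functionalized checks such as cur_q_len <= cur_kv_len
# surface only when the caller throws the error.
err, output = ragged_paged_attention(
    queries, k_pages, v_pages, kv_lens, page_indices, cu_q_lens,
    num_seqs=num_seqs)
err.throw()  # raises "cur_q_len must be less or equal to cur_kv_len ..." if violated
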
@@ -299,6 +308,7 @@ def check_kernel_input(q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, "Number of Q heads must be divisible by number of KV heads. Got" f" {num_q_heads} and {num_kv_heads}.") + # https://github.com/jax-ml/jax/blob/e3b3b913f7bcec3767e1442ace08999413f8703d/jax/experimental/pallas/ops/tpu/megablox/gmm.py#L269C1-L283C64 def _get_store_mask( *, @@ -317,6 +327,7 @@ def _get_store_mask( iota = jax.lax.broadcasted_iota(jnp.int32, (tm, tn), 0) + m_id return jnp.logical_and(iota >= group_start, iota < group_end) + def _flash_attention( q_head_idx_per_kv, # scalar, ranges from 0 to num_query_heads_per_kv_head group_metadata_ref, # (seq_ids, physical_q_tile_ids) @@ -344,7 +355,8 @@ def _flash_attention( head_dim: int, num_q_heads_per_kv_head: int, ): - assert q_ref.shape == (num_q_heads_per_kv_head, num_queries_per_block, head_dim) + assert q_ref.shape == (num_q_heads_per_kv_head, num_queries_per_block, + head_dim) kv_blk_size = page_size * num_kv_pages_per_block assert k.shape == (kv_blk_size, head_dim) assert v.shape == (kv_blk_size, head_dim) @@ -357,13 +369,17 @@ def _flash_attention( seq_ids, physical_q_tile_ids = group_metadata_ref # If the q-dim physical tile is changed (meaning it is a new physical q-dim tile that has not visited before), initialize the acc_scratch_ref, m_scratch_ref, and l_scratch_ref to run the flash attention v2 algorithm. - prev_logical_q_blk_idx = jnp.where(logical_q_blk_idx > 0, logical_q_blk_idx - 1, 0) + prev_logical_q_blk_idx = jnp.where(logical_q_blk_idx > 0, + logical_q_blk_idx - 1, 0) is_first_processed_logical_q_blk = logical_q_blk_idx == 0 - physical_q_blk_changed = (physical_q_tile_ids[logical_q_blk_idx] != physical_q_tile_ids[prev_logical_q_blk_idx]) - first_time_seeing_physical_q_blk = jnp.logical_or(is_first_processed_logical_q_blk, physical_q_blk_changed) + physical_q_blk_changed = ( + physical_q_tile_ids[logical_q_blk_idx] != + physical_q_tile_ids[prev_logical_q_blk_idx]) + first_time_seeing_physical_q_blk = jnp.logical_or( + is_first_processed_logical_q_blk, physical_q_blk_changed) is_first_kv_blk = (kv_blk_idx == 0) should_init_scratch_ref = jnp.logical_and(is_first_kv_blk, - first_time_seeing_physical_q_blk) + first_time_seeing_physical_q_blk) @pl.when(should_init_scratch_ref) def init_scratch_ref(): # pylint: disable=unused-variable @@ -374,8 +390,10 @@ def init_scratch_ref(): # pylint: disable=unused-variable acc_scratch_ref[q_head_idx_per_kv] = jnp.zeros( acc_scratch_ref[q_head_idx_per_kv].shape, jnp.float32) - m_prev = m_scratch_ref[q_head_idx_per_kv] # [num_queries_per_block, MIN_BLOCK_SIZE] - l_prev = l_scratch_ref[q_head_idx_per_kv] # [num_queries_per_block, MIN_BLOCK_SIZE] + m_prev = m_scratch_ref[ + q_head_idx_per_kv] # [num_queries_per_block, MIN_BLOCK_SIZE] + l_prev = l_scratch_ref[ + q_head_idx_per_kv] # [num_queries_per_block, MIN_BLOCK_SIZE] # Load the whole q_block that belongs to the current physical q_blk and compute the attention. When we write, we only write the part that belongs to the current sequence. # Cannot just load only the part of q_block that belongs to the current sequence, because it results in dynamic shapes and then fails the JIT compilation. @@ -390,22 +408,19 @@ def init_scratch_ref(): # pylint: disable=unused-variable # Modify the mask accordingly: first form the mask. Then move the mask up/down to the right place. 
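With the metadata from the earlier example (physical_q_tile_ids = [0, 1, 1, 2, 3, 3]), the predicate above initializes the scratch accumulators only on the first visit to each physical q tile, and the matching store predicate further down writes back only on the last visit. A small standalone loop (plain Python, illustrative only) makes the schedule explicit:

physical_q_tile_ids = [0, 1, 1, 2, 3, 3]  # from the make_sequence_metadata example
num_logical = len(physical_q_tile_ids)
for i, phys in enumerate(physical_q_tile_ids):
  first_visit = i == 0 or physical_q_tile_ids[i - 1] != phys
  last_visit = i == num_logical - 1 or physical_q_tile_ids[i + 1] != phys
  print(f"logical q tile {i}: physical tile {phys}, "
        f"init scratch (at kv block 0): {first_visit}, "
        f"store to HBM (at last kv block): {last_visit}")
# Physical tiles 1 and 3 are each shared by two sequences, so they are
# initialized by the first logical tile that touches them and written back by
# the second.
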
cur_seq_idx = seq_ids[logical_q_blk_idx] cur_seq_start = effective_cu_q_lens_ref[cur_seq_idx] - cur_seq_end = effective_cu_q_lens_ref[cur_seq_idx+1] + cur_seq_end = effective_cu_q_lens_ref[cur_seq_idx + 1] physical_q_blk_idx = physical_q_tile_ids[logical_q_blk_idx] - q_index = physical_q_blk_idx*num_queries_per_block-cur_seq_start + q_index = physical_q_blk_idx * num_queries_per_block - cur_seq_start kv_index = kv_blk_idx * kv_blk_size effective_kv_len = effective_kv_lens_ref[cur_seq_idx] effective_q_len = cur_seq_end - cur_seq_start - row_ids = ( - effective_kv_len - effective_q_len) + q_index + jax.lax.broadcasted_iota( - jnp.int32, - (num_queries_per_block, kv_blk_size), 0) + row_ids = (effective_kv_len - + effective_q_len) + q_index + jax.lax.broadcasted_iota( + jnp.int32, (num_queries_per_block, kv_blk_size), 0) col_ids = kv_index + jax.lax.broadcasted_iota( - jnp.int32, - (num_queries_per_block, kv_blk_size), 1) + jnp.int32, (num_queries_per_block, kv_blk_size), 1) causal_mask = jnp.where(row_ids < col_ids, mask_value, 0.) - assert causal_mask.shape == (num_queries_per_block, - kv_blk_size) + assert causal_mask.shape == (num_queries_per_block, kv_blk_size) s = s + causal_mask # [block_q, block_k] @@ -415,8 +430,7 @@ def init_scratch_ref(): # pylint: disable=unused-variable block_k_repeats, rem = divmod(kv_blk_size, MIN_BLOCK_SIZE) if rem: raise NotImplementedError( - f"{kv_blk_size=} should be a multiple of {MIN_BLOCK_SIZE}" - ) + f"{kv_blk_size=} should be a multiple of {MIN_BLOCK_SIZE}") p = jnp.exp( s - pltpu.repeat(m_next, block_k_repeats, 1)) # Shape [block_q, block_k] @@ -437,44 +451,60 @@ def init_scratch_ref(): # pylint: disable=unused-variable # Need to store these l_next and m_next which will relay to the output. # But only update the part that belongs to the current sequence we are working on. - lm_mask = _get_store_mask(grid_id=logical_q_blk_idx, - group_offsets=effective_cu_q_lens_ref, - group_ids=seq_ids, - m_tile_ids=physical_q_tile_ids, - tm=num_queries_per_block, - tn=MIN_BLOCK_SIZE, - ) + lm_mask = _get_store_mask( + grid_id=logical_q_blk_idx, + group_offsets=effective_cu_q_lens_ref, + group_ids=seq_ids, + m_tile_ids=physical_q_tile_ids, + tm=num_queries_per_block, + tn=MIN_BLOCK_SIZE, + ) # Either jax.lax.select or jnp.where works here. 
- l_scratch_ref[q_head_idx_per_kv] = jax.lax.select(lm_mask[...], l_next, l_scratch_ref[q_head_idx_per_kv]) - m_scratch_ref[q_head_idx_per_kv] = jax.lax.select(lm_mask[...], m_next, m_scratch_ref[q_head_idx_per_kv]) + l_scratch_ref[q_head_idx_per_kv] = jax.lax.select( + lm_mask[...], l_next, l_scratch_ref[q_head_idx_per_kv]) + m_scratch_ref[q_head_idx_per_kv] = jax.lax.select( + lm_mask[...], m_next, m_scratch_ref[q_head_idx_per_kv]) l_next_inv_safe = jnp.where(l_next == 0.0, 1.0, 1.0 / l_next) # [block_q, 128] - temp = acc_scratch_ref[q_head_idx_per_kv] * l_broadcast(l_corr * l_next_inv_safe) - acc_mask = _get_store_mask(grid_id=logical_q_blk_idx, - group_offsets=effective_cu_q_lens_ref, - group_ids=seq_ids, - m_tile_ids=physical_q_tile_ids, - tm=num_queries_per_block, - tn=head_dim, - ) - acc_scratch_ref[q_head_idx_per_kv] = jax.lax.select(acc_mask[...], temp, acc_scratch_ref[q_head_idx_per_kv]) + temp = acc_scratch_ref[q_head_idx_per_kv] * l_broadcast( + l_corr * l_next_inv_safe) + acc_mask = _get_store_mask( + grid_id=logical_q_blk_idx, + group_offsets=effective_cu_q_lens_ref, + group_ids=seq_ids, + m_tile_ids=physical_q_tile_ids, + tm=num_queries_per_block, + tn=head_dim, + ) + acc_scratch_ref[q_head_idx_per_kv] = jax.lax.select( + acc_mask[...], temp, acc_scratch_ref[q_head_idx_per_kv]) o_curr = jax.lax.dot( p.astype(v.dtype), v, preferred_element_type=jnp.float32) # [block_q, 128] - temp = (acc_scratch_ref[q_head_idx_per_kv] + o_curr * l_broadcast(l_next_inv_safe)) - acc_scratch_ref[q_head_idx_per_kv] = jax.lax.select(acc_mask[...], temp, acc_scratch_ref[q_head_idx_per_kv]) + temp = ( + acc_scratch_ref[q_head_idx_per_kv] + + o_curr * l_broadcast(l_next_inv_safe)) + acc_scratch_ref[q_head_idx_per_kv] = jax.lax.select( + acc_mask[...], temp, acc_scratch_ref[q_head_idx_per_kv]) # Store the result from VMEM to HBM only when it is the last kv_block and the next q-dim logical tile belongs to a different q-dim physical tile. 
- is_last_kv_blk_idx = (kv_blk_idx == (pl.cdiv(effective_kv_len, kv_blk_size) - 1)) - num_logical_q_blks = pl.num_programs(1) # grid=(num_kv_heads, num_logical_q_tiles, num_kv_blks) - next_logical_q_blk_idx = jnp.where(logical_q_blk_idx == num_logical_q_blks - 1, - logical_q_blk_idx, - logical_q_blk_idx+1) - is_last_logical_q_blk = (logical_q_blk_idx == num_logical_q_blks-1) - physical_q_blk_will_change = (physical_q_tile_ids[logical_q_blk_idx] != physical_q_tile_ids[next_logical_q_blk_idx]) - last_time_seeing_cur_physical_q_blk = jnp.logical_or(is_last_logical_q_blk, physical_q_blk_will_change) - should_store_to_hbm = jnp.logical_and(is_last_kv_blk_idx, last_time_seeing_cur_physical_q_blk) + is_last_kv_blk_idx = ( + kv_blk_idx == (pl.cdiv(effective_kv_len, kv_blk_size) - 1)) + num_logical_q_blks = pl.num_programs( + 1) # grid=(num_kv_heads, num_logical_q_tiles, num_kv_blks) + next_logical_q_blk_idx = jnp.where( + logical_q_blk_idx == num_logical_q_blks - 1, logical_q_blk_idx, + logical_q_blk_idx + 1) + is_last_logical_q_blk = (logical_q_blk_idx == num_logical_q_blks - 1) + physical_q_blk_will_change = ( + physical_q_tile_ids[logical_q_blk_idx] != + physical_q_tile_ids[next_logical_q_blk_idx]) + last_time_seeing_cur_physical_q_blk = jnp.logical_or( + is_last_logical_q_blk, physical_q_blk_will_change) + should_store_to_hbm = jnp.logical_and(is_last_kv_blk_idx, + last_time_seeing_cur_physical_q_blk) + @pl.when(should_store_to_hbm) def store_to_hbm(): # pylint: disable=unused-variable o_ref[q_head_idx_per_kv] = acc_scratch_ref[q_head_idx_per_kv].astype( @@ -484,6 +514,7 @@ def store_to_hbm(): # pylint: disable=unused-variable m_ref[q_head_idx_per_kv] = m_scratch_ref[q_head_idx_per_kv].astype( m_ref.dtype) + def paged_flash_attention_kernel( # prefetch refs group_metadata_ref, # (seq_ids, physical_q_tile_ids) @@ -563,7 +594,7 @@ def advance_logical_q_blk_idx(): cur_seq_idx = seq_ids[logical_q_blk_idx] effective_kv_len_cur_seq = effective_kv_lens_ref[cur_seq_idx] return lax.cond( - kv_blk_idx*kv_blk_size < effective_kv_len_cur_seq, + kv_blk_idx * kv_blk_size < effective_kv_len_cur_seq, lambda: (kv_head_idx, logical_q_blk_idx, kv_blk_idx), advance_logical_q_blk_idx, ) @@ -608,7 +639,8 @@ def prefetch_first_block(): # pylint: disable=unused-variable async_copy_k.start() async_copy_v.start() - next_kv_head_idx, next_logical_q_blk_idx, next_kv_blk_idx = compute_block_indices(kv_head_idx, logical_q_blk_idx, kv_blk_idx+1) + next_kv_head_idx, next_logical_q_blk_idx, next_kv_blk_idx = compute_block_indices( + kv_head_idx, logical_q_blk_idx, kv_blk_idx + 1) @pl.when(next_kv_head_idx < num_kv_heads) def prefetch_next_block(): # pylint: disable=unused-variable @@ -625,8 +657,8 @@ def prefetch_next_block(): # pylint: disable=unused-variable k = async_copy_k.wait_and_get_loaded( ) # [pages_per_compute_block*page_size,head_dim] v = async_copy_v.wait_and_get_loaded() - assert k.shape == (num_kv_pages_per_block*page_size, head_dim) - assert v.shape == (num_kv_pages_per_block*page_size, head_dim) + assert k.shape == (num_kv_pages_per_block * page_size, head_dim) + assert v.shape == (num_kv_pages_per_block * page_size, head_dim) for q_head_idx in range(num_q_heads_per_kv_head): _flash_attention( @@ -661,6 +693,7 @@ def prefetch_next_block(): # pylint: disable=unused-variable MIN_BLOCK_SIZE = 128 + @checkify.checkify @functools.partial( jax.jit, @@ -737,7 +770,8 @@ def ragged_paged_attention( # out later so that num_q_heads doesn't have to be the 2nd last dimension and hence doesn't subject to the multiple of 8 
constraint. q = jnp.permute_dims(q, (1, 0, 2)) # [num_q_heads, num_tokens, head_dim] num_kv_heads, total_num_pages, page_size, head_dim = k_pages.shape - check_kernel_input(q, k_pages, v_pages,kv_lens, page_indices, cu_q_lens, num_seqs, num_kv_pages_per_block) + check_kernel_input(q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, + num_seqs, num_kv_pages_per_block) num_q_heads_per_kv_head = num_q_heads // num_kv_heads group_metadata, num_logical_q_tiles = make_sequence_metadata( @@ -762,14 +796,16 @@ def ragged_paged_attention( # in-spec. Note currently q.shape=[num_q_heads, num_tokens, head_dim] # Within the kernel, q.shape should be [num_q_heads_per_kv_head, q_block_size, head_dim] - def qo_index_map(kv_head_idx, logical_q_blk_idx, kv_blk_idx, group_metadata, *_): + def qo_index_map(kv_head_idx, logical_q_blk_idx, kv_blk_idx, group_metadata, + *_): seq_ids, physical_q_tile_ids = group_metadata del seq_ids physical_q_blk_idx = physical_q_tile_ids[logical_q_blk_idx] return (kv_head_idx, physical_q_blk_idx, 0) + q_block_spec = pl.BlockSpec( - (num_q_heads_per_kv_head, num_queries_per_block, head_dim), - qo_index_map, + (num_q_heads_per_kv_head, num_queries_per_block, head_dim), + qo_index_map, ) in_specs = [ q_block_spec, @@ -803,8 +839,7 @@ def qo_index_map(kv_head_idx, logical_q_blk_idx, kv_blk_idx, group_metadata, *_) (num_q_heads_per_kv_head, num_queries_per_block, MIN_BLOCK_SIZE), jnp.float32) acc_scratch = pltpu.VMEM( - (num_q_heads_per_kv_head, num_queries_per_block, head_dim), - jnp.float32) + (num_q_heads_per_kv_head, num_queries_per_block, head_dim), jnp.float32) scratch_shapes = [ pltpu.VMEM( ( @@ -899,7 +934,7 @@ def qo_index_map(kv_head_idx, logical_q_blk_idx, kv_blk_idx, group_metadata, *_) # v_scales_pages, # ) # debug compile ends - + outputs = kernel( # prefetch group_metadata, From cd65cc4b107d3153d10cfacd6cf7b02834b53c9e Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Fri, 31 Jan 2025 22:38:21 +0000 Subject: [PATCH 7/9] address pr comments --- .../ragged_paged_attention_kernel.py | 249 ++++++++---------- 1 file changed, 110 insertions(+), 139 deletions(-) diff --git a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py index bee0d3ca8d43..1552100aabb3 100644 --- a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py +++ b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py @@ -14,7 +14,6 @@ DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max) - class MultiPageAsyncCopyDescriptor: """Descriptor for async copy of multiple K/V pages from HBM.""" @@ -100,7 +99,6 @@ def _calculate_num_tiles(x: int, tx: int) -> int: raise ValueError(f"{x} must be divisible by x-dimension tile size ({tx}).") return tiles - # https://github.com/jax-ml/jax/blob/9fb29766a2130e74a85cba30420cf777d185ea5a/jax/experimental/pallas/ops/tpu/megablox/gmm.py#L79 def make_sequence_metadata( *, @@ -110,7 +108,7 @@ def make_sequence_metadata( start_sequence: jnp.ndarray, num_sequences: int, ): - """Create the metadata needed for grouped matmul computation. + """Create the metadata needed for ragged paged attention computation. Args: cu_q_lens: : A 1d, jnp.ndarray with shape [num_seqs+1] and jnp.int32 dtype. @@ -153,12 +151,13 @@ def make_sequence_metadata( # # NOTE: This does not change sequence_offsets[num_sequences], which is m # (because we enforce m is divisible by tm). 
- rounded_sequence_ends = ((sequence_ends + tm - 1) // tm * tm).astype( - jnp.int32) + rounded_sequence_ends = ((sequence_ends + tm - 1) // tm * tm).astype(jnp.int32) + # (2) Round the sequence_starts down to the nearest multiple of 'tm'. sequence_starts = jnp.concatenate( - [jnp.zeros(1, dtype=jnp.int32), sequence_ends[:-1]]) + [jnp.zeros(1, dtype=jnp.int32), sequence_ends[:-1]] + ) rounded_sequence_starts = sequence_starts // tm * tm # (3) Calculate the number of rows in each sequence. @@ -216,12 +215,14 @@ def make_sequence_metadata( # partial_tile_mask = ((sequence_offsets[:-1] % tm) == 0) - partial_tile_ids = jnp.where(partial_tile_mask, tiles_m, - sequence_offsets[:-1] // tm) + partial_tile_ids = jnp.where( + partial_tile_mask, tiles_m, sequence_offsets[:-1] // tm + ) tile_visits = ( - jnp.histogram(partial_tile_ids, bins=tiles_m, range=(0, tiles_m - 1))[0] + - 1) + jnp.histogram(partial_tile_ids, bins=tiles_m, range=(0, tiles_m - 1))[0] + + 1 + ) # Create the m-dimension tile ids for each grid index based on the visit # counts for each tile. @@ -244,14 +245,10 @@ def make_sequence_metadata( # # Remove tile visits that belong to a sequence not in our shard. iota = jnp.arange(num_sequences, dtype=jnp.int32) - active_sequence_mask = jnp.logical_and(iota <= end_sequence, - iota >= start_sequence) - sequence_tiles = jnp.where(active_sequence_mask, - sequence_tiles[:num_sequences], 0) + active_sequence_mask = jnp.logical_and(iota <= end_sequence, iota >= start_sequence) + sequence_tiles = jnp.where(active_sequence_mask, sequence_tiles[:num_sequences], 0) num_tiles = sequence_tiles.sum() - return (sequence_ids, m_tile_ids - ), num_tiles # (seq_ids, physical_q_tile_ids), num_logical_q_tiles - + return (sequence_ids, m_tile_ids), num_tiles # (seq_ids, physical_q_tile_ids), num_logical_q_tiles def check_kernel_input(q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, num_seqs, num_kv_pages_per_block): @@ -270,24 +267,17 @@ def check_kernel_input(q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, raise ValueError("kv_lens.shape[0] must be thet same as num_tokens. Got" f" {kv_lens.shape[0]} and {num_tokens}") if page_indices.shape[0] != num_tokens: - raise ValueError( - "page_indices.shape[0] must be thet same as num_tokens. Got" - f" {page_indices.shape[0]} and {num_tokens}") + raise ValueError("page_indices.shape[0] must be thet same as num_tokens. Got" + f" {page_indices.shape[0]} and {num_tokens}") if cu_q_lens.shape[0] != num_tokens + 1: - raise ValueError( - "cu_q_lens.shape[0] must be thet same as num_tokens + 1. Got" - f" {cu_q_lens.shape[0]} and {num_tokens + 1}") + raise ValueError("cu_q_lens.shape[0] must be thet same as num_tokens + 1. Got" + f" {cu_q_lens.shape[0]} and {num_tokens + 1}") for i in range(num_seqs): - cur_q_len = cu_q_lens[i + 1] - cu_q_lens[i] + cur_q_len = cu_q_lens[i+1] - cu_q_lens[i] cur_kv_len = kv_lens[i] - checkify.check( - cur_q_len <= cur_kv_len, - "cur_q_len must be less or equal to cur_kv_len. Got {} and {}", - cur_q_len, cur_kv_len) + checkify.check(cur_q_len <= cur_kv_len, "cur_q_len must be less or equal to cur_kv_len. Got {} and {}", cur_q_len, cur_kv_len) if num_seqs > num_tokens: - raise ValueError( - f"num_seqs must be less or equal to num_tokens. Got {num_seqs} and {num_tokens}" - ) + raise ValueError(f"num_seqs must be less or equal to num_tokens. Got {num_seqs} and {num_tokens}") # int16: will pack. need to explicit cast to int32. int64 is not supported in Pallas. for smem 1d case. # 2d smem: int16 will be packed with an empty. 
So we didn't save any memory. # scalar: use i32 (1, N). int16 for (1, N) will be padding. Need to use (2, N). @@ -308,29 +298,27 @@ def check_kernel_input(q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, "Number of Q heads must be divisible by number of KV heads. Got" f" {num_q_heads} and {num_kv_heads}.") - # https://github.com/jax-ml/jax/blob/e3b3b913f7bcec3767e1442ace08999413f8703d/jax/experimental/pallas/ops/tpu/megablox/gmm.py#L269C1-L283C64 def _get_store_mask( *, - grid_id: jnp.ndarray, - group_offsets: jnp.ndarray, - group_ids: jnp.ndarray, - m_tile_ids: jnp.ndarray, - tm: int, - tn: int, + logical_q_blk_idx: jnp.ndarray, + sequence_offsets: jnp.ndarray, + sequence_ids: jnp.ndarray, + physical_q_tile_ids: jnp.ndarray, + tq: int, + tk: int, ) -> jnp.ndarray: - """Mask for rows that belong to the current group in the current tile.""" - group_id = group_ids[grid_id] - group_start = group_offsets[group_id] - group_end = group_offsets[group_id + 1] - m_id = m_tile_ids[grid_id] * tm - iota = jax.lax.broadcasted_iota(jnp.int32, (tm, tn), 0) + m_id - return jnp.logical_and(iota >= group_start, iota < group_end) - + """Mask for rows that belong to the current sequence in the current physical q tile.""" + sequence_id = sequence_ids[logical_q_blk_idx] + sequence_start = sequence_offsets[sequence_id] + sequence_end = sequence_offsets[sequence_id + 1] + physical_q_tile_id = physical_q_tile_ids[logical_q_blk_idx] * tq + iota = jax.lax.broadcasted_iota(jnp.int32, (tq, tk), 0) + physical_q_tile_id + return jnp.logical_and(iota >= sequence_start, iota < sequence_end) def _flash_attention( q_head_idx_per_kv, # scalar, ranges from 0 to num_query_heads_per_kv_head - group_metadata_ref, # (seq_ids, physical_q_tile_ids) + sequence_metadata_ref, # (seq_ids, physical_q_tile_ids) effective_kv_lens_ref, # [num_tokens] effective_cu_q_lens_ref, # [num_tokens + 1] # kernel inputs @@ -355,8 +343,7 @@ def _flash_attention( head_dim: int, num_q_heads_per_kv_head: int, ): - assert q_ref.shape == (num_q_heads_per_kv_head, num_queries_per_block, - head_dim) + assert q_ref.shape == (num_q_heads_per_kv_head, num_queries_per_block, head_dim) kv_blk_size = page_size * num_kv_pages_per_block assert k.shape == (kv_blk_size, head_dim) assert v.shape == (kv_blk_size, head_dim) @@ -366,20 +353,16 @@ def _flash_attention( pl.program_id(1), pl.program_id(2), ) - seq_ids, physical_q_tile_ids = group_metadata_ref + seq_ids, physical_q_tile_ids = sequence_metadata_ref # If the q-dim physical tile is changed (meaning it is a new physical q-dim tile that has not visited before), initialize the acc_scratch_ref, m_scratch_ref, and l_scratch_ref to run the flash attention v2 algorithm. 
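# A runnable NumPy sketch, under simplified assumptions (no causal mask, no
# softmax scaling, whole rows instead of per-head tiles, illustrative helper
# names), of the running max/denominator/accumulator update that the m/l/acc
# scratch refs implement: processing KV in blocks while keeping the accumulator
# normalized reproduces a full-softmax reference.
import numpy as np

def attention_reference(q, k, v):
  s = q @ k.T
  p = np.exp(s - s.max(axis=-1, keepdims=True))
  p /= p.sum(axis=-1, keepdims=True)
  return p @ v

def attention_online(q, k, v, blk=2):
  nq, d = q.shape
  m = np.full((nq, 1), -np.inf)  # running row max       (m_scratch_ref)
  l = np.zeros((nq, 1))          # running softmax denom (l_scratch_ref)
  acc = np.zeros((nq, d))        # normalized partial out (acc_scratch_ref)
  for start in range(0, k.shape[0], blk):
    s = q @ k[start:start + blk].T
    m_next = np.maximum(m, s.max(axis=-1, keepdims=True))
    alpha = np.exp(m - m_next)   # correction factor for the stale max
    p = np.exp(s - m_next)
    l_corr = alpha * l
    l_next = l_corr + p.sum(axis=-1, keepdims=True)
    l_next_inv = np.where(l_next == 0.0, 1.0, 1.0 / l_next)
    acc = acc * (l_corr * l_next_inv) + (p @ v[start:start + blk]) * l_next_inv
    m, l = m_next, l_next
  return acc

rng = np.random.default_rng(0)
q = rng.normal(size=(3, 4))
k = rng.normal(size=(6, 4))
v = rng.normal(size=(6, 4))
print(np.allclose(attention_online(q, k, v), attention_reference(q, k, v)))  # True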
- prev_logical_q_blk_idx = jnp.where(logical_q_blk_idx > 0, - logical_q_blk_idx - 1, 0) + prev_logical_q_blk_idx = jnp.where(logical_q_blk_idx > 0, logical_q_blk_idx - 1, 0) is_first_processed_logical_q_blk = logical_q_blk_idx == 0 - physical_q_blk_changed = ( - physical_q_tile_ids[logical_q_blk_idx] != - physical_q_tile_ids[prev_logical_q_blk_idx]) - first_time_seeing_physical_q_blk = jnp.logical_or( - is_first_processed_logical_q_blk, physical_q_blk_changed) + physical_q_blk_changed = (physical_q_tile_ids[logical_q_blk_idx] != physical_q_tile_ids[prev_logical_q_blk_idx]) + first_time_seeing_physical_q_blk = jnp.logical_or(is_first_processed_logical_q_blk, physical_q_blk_changed) is_first_kv_blk = (kv_blk_idx == 0) should_init_scratch_ref = jnp.logical_and(is_first_kv_blk, - first_time_seeing_physical_q_blk) + first_time_seeing_physical_q_blk) @pl.when(should_init_scratch_ref) def init_scratch_ref(): # pylint: disable=unused-variable @@ -390,10 +373,8 @@ def init_scratch_ref(): # pylint: disable=unused-variable acc_scratch_ref[q_head_idx_per_kv] = jnp.zeros( acc_scratch_ref[q_head_idx_per_kv].shape, jnp.float32) - m_prev = m_scratch_ref[ - q_head_idx_per_kv] # [num_queries_per_block, MIN_BLOCK_SIZE] - l_prev = l_scratch_ref[ - q_head_idx_per_kv] # [num_queries_per_block, MIN_BLOCK_SIZE] + m_prev = m_scratch_ref[q_head_idx_per_kv] # [num_queries_per_block, MIN_BLOCK_SIZE] + l_prev = l_scratch_ref[q_head_idx_per_kv] # [num_queries_per_block, MIN_BLOCK_SIZE] # Load the whole q_block that belongs to the current physical q_blk and compute the attention. When we write, we only write the part that belongs to the current sequence. # Cannot just load only the part of q_block that belongs to the current sequence, because it results in dynamic shapes and then fails the JIT compilation. @@ -408,19 +389,22 @@ def init_scratch_ref(): # pylint: disable=unused-variable # Modify the mask accordingly: first form the mask. Then move the mask up/down to the right place. cur_seq_idx = seq_ids[logical_q_blk_idx] cur_seq_start = effective_cu_q_lens_ref[cur_seq_idx] - cur_seq_end = effective_cu_q_lens_ref[cur_seq_idx + 1] + cur_seq_end = effective_cu_q_lens_ref[cur_seq_idx+1] physical_q_blk_idx = physical_q_tile_ids[logical_q_blk_idx] - q_index = physical_q_blk_idx * num_queries_per_block - cur_seq_start + q_index = physical_q_blk_idx*num_queries_per_block-cur_seq_start kv_index = kv_blk_idx * kv_blk_size effective_kv_len = effective_kv_lens_ref[cur_seq_idx] effective_q_len = cur_seq_end - cur_seq_start - row_ids = (effective_kv_len - - effective_q_len) + q_index + jax.lax.broadcasted_iota( - jnp.int32, (num_queries_per_block, kv_blk_size), 0) + row_ids = ( + effective_kv_len - effective_q_len) + q_index + jax.lax.broadcasted_iota( + jnp.int32, + (num_queries_per_block, kv_blk_size), 0) col_ids = kv_index + jax.lax.broadcasted_iota( - jnp.int32, (num_queries_per_block, kv_blk_size), 1) + jnp.int32, + (num_queries_per_block, kv_blk_size), 1) causal_mask = jnp.where(row_ids < col_ids, mask_value, 0.) 
- assert causal_mask.shape == (num_queries_per_block, kv_blk_size) + assert causal_mask.shape == (num_queries_per_block, + kv_blk_size) s = s + causal_mask # [block_q, block_k] @@ -430,7 +414,8 @@ def init_scratch_ref(): # pylint: disable=unused-variable block_k_repeats, rem = divmod(kv_blk_size, MIN_BLOCK_SIZE) if rem: raise NotImplementedError( - f"{kv_blk_size=} should be a multiple of {MIN_BLOCK_SIZE}") + f"{kv_blk_size=} should be a multiple of {MIN_BLOCK_SIZE}" + ) p = jnp.exp( s - pltpu.repeat(m_next, block_k_repeats, 1)) # Shape [block_q, block_k] @@ -451,60 +436,51 @@ def init_scratch_ref(): # pylint: disable=unused-variable # Need to store these l_next and m_next which will relay to the output. # But only update the part that belongs to the current sequence we are working on. - lm_mask = _get_store_mask( - grid_id=logical_q_blk_idx, - group_offsets=effective_cu_q_lens_ref, - group_ids=seq_ids, - m_tile_ids=physical_q_tile_ids, - tm=num_queries_per_block, - tn=MIN_BLOCK_SIZE, - ) + lm_mask = _get_store_mask(logical_q_blk_idx=logical_q_blk_idx, + sequence_offsets=effective_cu_q_lens_ref, + sequence_ids=seq_ids, + physical_q_tile_ids=physical_q_tile_ids, + tq=num_queries_per_block, + tk=MIN_BLOCK_SIZE, + ) # Either jax.lax.select or jnp.where works here. - l_scratch_ref[q_head_idx_per_kv] = jax.lax.select( - lm_mask[...], l_next, l_scratch_ref[q_head_idx_per_kv]) - m_scratch_ref[q_head_idx_per_kv] = jax.lax.select( - lm_mask[...], m_next, m_scratch_ref[q_head_idx_per_kv]) + pl.store( + l_scratch_ref, + # (q_head_idx_per_kv,), # not working + #(pl.ds(q_head_idx_per_kv, 1), slice(None), slice(None)), # not working + (q_head_idx_per_kv, slice(None), slice(None)), # not working + l_next[...], + mask=lm_mask[...], + ) + m_scratch_ref[q_head_idx_per_kv] = jax.lax.select(lm_mask[...], m_next, m_scratch_ref[q_head_idx_per_kv]) l_next_inv_safe = jnp.where(l_next == 0.0, 1.0, 1.0 / l_next) # [block_q, 128] - temp = acc_scratch_ref[q_head_idx_per_kv] * l_broadcast( - l_corr * l_next_inv_safe) - acc_mask = _get_store_mask( - grid_id=logical_q_blk_idx, - group_offsets=effective_cu_q_lens_ref, - group_ids=seq_ids, - m_tile_ids=physical_q_tile_ids, - tm=num_queries_per_block, - tn=head_dim, - ) - acc_scratch_ref[q_head_idx_per_kv] = jax.lax.select( - acc_mask[...], temp, acc_scratch_ref[q_head_idx_per_kv]) + temp = acc_scratch_ref[q_head_idx_per_kv] * l_broadcast(l_corr * l_next_inv_safe) + acc_mask = _get_store_mask(logical_q_blk_idx=logical_q_blk_idx, + sequence_offsets=effective_cu_q_lens_ref, + sequence_ids=seq_ids, + physical_q_tile_ids=physical_q_tile_ids, + tq=num_queries_per_block, + tk=head_dim, + ) + acc_scratch_ref[q_head_idx_per_kv] = jax.lax.select(acc_mask[...], temp, acc_scratch_ref[q_head_idx_per_kv]) o_curr = jax.lax.dot( p.astype(v.dtype), v, preferred_element_type=jnp.float32) # [block_q, 128] - temp = ( - acc_scratch_ref[q_head_idx_per_kv] + - o_curr * l_broadcast(l_next_inv_safe)) - acc_scratch_ref[q_head_idx_per_kv] = jax.lax.select( - acc_mask[...], temp, acc_scratch_ref[q_head_idx_per_kv]) + temp = (acc_scratch_ref[q_head_idx_per_kv] + o_curr * l_broadcast(l_next_inv_safe)) + acc_scratch_ref[q_head_idx_per_kv] = jax.lax.select(acc_mask[...], temp, acc_scratch_ref[q_head_idx_per_kv]) # Store the result from VMEM to HBM only when it is the last kv_block and the next q-dim logical tile belongs to a different q-dim physical tile. 
- is_last_kv_blk_idx = ( - kv_blk_idx == (pl.cdiv(effective_kv_len, kv_blk_size) - 1)) - num_logical_q_blks = pl.num_programs( - 1) # grid=(num_kv_heads, num_logical_q_tiles, num_kv_blks) - next_logical_q_blk_idx = jnp.where( - logical_q_blk_idx == num_logical_q_blks - 1, logical_q_blk_idx, - logical_q_blk_idx + 1) - is_last_logical_q_blk = (logical_q_blk_idx == num_logical_q_blks - 1) - physical_q_blk_will_change = ( - physical_q_tile_ids[logical_q_blk_idx] != - physical_q_tile_ids[next_logical_q_blk_idx]) - last_time_seeing_cur_physical_q_blk = jnp.logical_or( - is_last_logical_q_blk, physical_q_blk_will_change) - should_store_to_hbm = jnp.logical_and(is_last_kv_blk_idx, - last_time_seeing_cur_physical_q_blk) - + is_last_kv_blk_idx = (kv_blk_idx == (pl.cdiv(effective_kv_len, kv_blk_size) - 1)) + num_logical_q_blks = pl.num_programs(1) # grid=(num_kv_heads, num_logical_q_tiles, num_kv_blks) + next_logical_q_blk_idx = jnp.where(logical_q_blk_idx == num_logical_q_blks - 1, + logical_q_blk_idx, + logical_q_blk_idx+1) + is_last_logical_q_blk = (logical_q_blk_idx == num_logical_q_blks-1) + physical_q_blk_will_change = (physical_q_tile_ids[logical_q_blk_idx] != physical_q_tile_ids[next_logical_q_blk_idx]) + last_time_seeing_cur_physical_q_blk = jnp.logical_or(is_last_logical_q_blk, physical_q_blk_will_change) + should_store_to_hbm = jnp.logical_and(is_last_kv_blk_idx, last_time_seeing_cur_physical_q_blk) @pl.when(should_store_to_hbm) def store_to_hbm(): # pylint: disable=unused-variable o_ref[q_head_idx_per_kv] = acc_scratch_ref[q_head_idx_per_kv].astype( @@ -514,10 +490,9 @@ def store_to_hbm(): # pylint: disable=unused-variable m_ref[q_head_idx_per_kv] = m_scratch_ref[q_head_idx_per_kv].astype( m_ref.dtype) - def paged_flash_attention_kernel( # prefetch refs - group_metadata_ref, # (seq_ids, physical_q_tile_ids) + sequence_metadata_ref, # (seq_ids, physical_q_tile_ids) effective_kv_lens_ref, # [num_tokens] # 1d vector, results from page_indices.reshape(-1) where originally page_indices.shape=[num_tokens, pages_per_sequence] page_indices_1d_ref, @@ -563,7 +538,7 @@ def paged_flash_attention_kernel( num_kv_heads, total_num_pages, page_size, head_dim = k_pages_hbm_ref.shape kv_blk_size = page_size * num_kv_pages_per_block - seq_ids, physical_q_tile_ids = group_metadata_ref + seq_ids, physical_q_tile_ids = sequence_metadata_ref cur_seq_idx = seq_ids[logical_q_blk_idx] effective_kv_len_cur_seq = effective_kv_lens_ref[cur_seq_idx] should_run = (kv_blk_idx * kv_blk_size < effective_kv_len_cur_seq) @@ -594,7 +569,7 @@ def advance_logical_q_blk_idx(): cur_seq_idx = seq_ids[logical_q_blk_idx] effective_kv_len_cur_seq = effective_kv_lens_ref[cur_seq_idx] return lax.cond( - kv_blk_idx * kv_blk_size < effective_kv_len_cur_seq, + kv_blk_idx*kv_blk_size < effective_kv_len_cur_seq, lambda: (kv_head_idx, logical_q_blk_idx, kv_blk_idx), advance_logical_q_blk_idx, ) @@ -639,8 +614,7 @@ def prefetch_first_block(): # pylint: disable=unused-variable async_copy_k.start() async_copy_v.start() - next_kv_head_idx, next_logical_q_blk_idx, next_kv_blk_idx = compute_block_indices( - kv_head_idx, logical_q_blk_idx, kv_blk_idx + 1) + next_kv_head_idx, next_logical_q_blk_idx, next_kv_blk_idx = compute_block_indices(kv_head_idx, logical_q_blk_idx, kv_blk_idx+1) @pl.when(next_kv_head_idx < num_kv_heads) def prefetch_next_block(): # pylint: disable=unused-variable @@ -657,13 +631,13 @@ def prefetch_next_block(): # pylint: disable=unused-variable k = async_copy_k.wait_and_get_loaded( ) # 
[pages_per_compute_block*page_size,head_dim] v = async_copy_v.wait_and_get_loaded() - assert k.shape == (num_kv_pages_per_block * page_size, head_dim) - assert v.shape == (num_kv_pages_per_block * page_size, head_dim) + assert k.shape == (num_kv_pages_per_block*page_size, head_dim) + assert v.shape == (num_kv_pages_per_block*page_size, head_dim) for q_head_idx in range(num_q_heads_per_kv_head): _flash_attention( q_head_idx, - group_metadata_ref, + sequence_metadata_ref, effective_kv_lens_ref, effective_cu_q_lens_ref, # kernel inputs @@ -693,7 +667,6 @@ def prefetch_next_block(): # pylint: disable=unused-variable MIN_BLOCK_SIZE = 128 - @checkify.checkify @functools.partial( jax.jit, @@ -770,11 +743,10 @@ def ragged_paged_attention( # out later so that num_q_heads doesn't have to be the 2nd last dimension and hence doesn't subject to the multiple of 8 constraint. q = jnp.permute_dims(q, (1, 0, 2)) # [num_q_heads, num_tokens, head_dim] num_kv_heads, total_num_pages, page_size, head_dim = k_pages.shape - check_kernel_input(q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, - num_seqs, num_kv_pages_per_block) + check_kernel_input(q, k_pages, v_pages,kv_lens, page_indices, cu_q_lens, num_seqs, num_kv_pages_per_block) num_q_heads_per_kv_head = num_q_heads // num_kv_heads - group_metadata, num_logical_q_tiles = make_sequence_metadata( + sequence_metadata, num_logical_q_tiles = make_sequence_metadata( cu_q_lens=cu_q_lens, m=num_tokens, tm=num_queries_per_block, @@ -796,16 +768,14 @@ def ragged_paged_attention( # in-spec. Note currently q.shape=[num_q_heads, num_tokens, head_dim] # Within the kernel, q.shape should be [num_q_heads_per_kv_head, q_block_size, head_dim] - def qo_index_map(kv_head_idx, logical_q_blk_idx, kv_blk_idx, group_metadata, - *_): - seq_ids, physical_q_tile_ids = group_metadata + def qo_index_map(kv_head_idx, logical_q_blk_idx, kv_blk_idx, sequence_metadata, *_): + seq_ids, physical_q_tile_ids = sequence_metadata del seq_ids physical_q_blk_idx = physical_q_tile_ids[logical_q_blk_idx] return (kv_head_idx, physical_q_blk_idx, 0) - q_block_spec = pl.BlockSpec( - (num_q_heads_per_kv_head, num_queries_per_block, head_dim), - qo_index_map, + (num_q_heads_per_kv_head, num_queries_per_block, head_dim), + qo_index_map, ) in_specs = [ q_block_spec, @@ -839,7 +809,8 @@ def qo_index_map(kv_head_idx, logical_q_blk_idx, kv_blk_idx, group_metadata, (num_q_heads_per_kv_head, num_queries_per_block, MIN_BLOCK_SIZE), jnp.float32) acc_scratch = pltpu.VMEM( - (num_q_heads_per_kv_head, num_queries_per_block, head_dim), jnp.float32) + (num_q_heads_per_kv_head, num_queries_per_block, head_dim), + jnp.float32) scratch_shapes = [ pltpu.VMEM( ( @@ -903,7 +874,7 @@ def qo_index_map(kv_head_idx, logical_q_blk_idx, kv_blk_idx, group_metadata, # jax.jit(kernel) # .lower( # # prefetch - # group_metadata, + # sequence_metadata, # kv_lens, # page_indices_1d, # cu_q_lens, @@ -920,7 +891,7 @@ def qo_index_map(kv_head_idx, logical_q_blk_idx, kv_blk_idx, group_metadata, # ) # outputs = compiled_kernel( # # prefetch - # group_metadata, + # sequence_metadata, # kv_lens, # page_indices_1d, # cu_q_lens, @@ -934,10 +905,10 @@ def qo_index_map(kv_head_idx, logical_q_blk_idx, kv_blk_idx, group_metadata, # v_scales_pages, # ) # debug compile ends - + outputs = kernel( # prefetch - group_metadata, + sequence_metadata, kv_lens, page_indices_1d, cu_q_lens, From 7fe50717c5826220172b5c08bf6566cffc0cee70 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Sat, 1 Feb 2025 00:31:11 +0000 Subject: [PATCH 8/9] fix the rest of 
comments --- .../ragged_paged_attention_kernel.py | 248 +++++++++--------- 1 file changed, 123 insertions(+), 125 deletions(-) diff --git a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py index 1552100aabb3..cfcc04436969 100644 --- a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py +++ b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py @@ -14,6 +14,7 @@ DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max) + class MultiPageAsyncCopyDescriptor: """Descriptor for async copy of multiple K/V pages from HBM.""" @@ -99,6 +100,7 @@ def _calculate_num_tiles(x: int, tx: int) -> int: raise ValueError(f"{x} must be divisible by x-dimension tile size ({tx}).") return tiles + # https://github.com/jax-ml/jax/blob/9fb29766a2130e74a85cba30420cf777d185ea5a/jax/experimental/pallas/ops/tpu/megablox/gmm.py#L79 def make_sequence_metadata( *, @@ -151,13 +153,12 @@ def make_sequence_metadata( # # NOTE: This does not change sequence_offsets[num_sequences], which is m # (because we enforce m is divisible by tm). - rounded_sequence_ends = ((sequence_ends + tm - 1) // tm * tm).astype(jnp.int32) - + rounded_sequence_ends = ((sequence_ends + tm - 1) // tm * tm).astype( + jnp.int32) # (2) Round the sequence_starts down to the nearest multiple of 'tm'. sequence_starts = jnp.concatenate( - [jnp.zeros(1, dtype=jnp.int32), sequence_ends[:-1]] - ) + [jnp.zeros(1, dtype=jnp.int32), sequence_ends[:-1]]) rounded_sequence_starts = sequence_starts // tm * tm # (3) Calculate the number of rows in each sequence. @@ -215,14 +216,12 @@ def make_sequence_metadata( # partial_tile_mask = ((sequence_offsets[:-1] % tm) == 0) - partial_tile_ids = jnp.where( - partial_tile_mask, tiles_m, sequence_offsets[:-1] // tm - ) + partial_tile_ids = jnp.where(partial_tile_mask, tiles_m, + sequence_offsets[:-1] // tm) tile_visits = ( - jnp.histogram(partial_tile_ids, bins=tiles_m, range=(0, tiles_m - 1))[0] - + 1 - ) + jnp.histogram(partial_tile_ids, bins=tiles_m, range=(0, tiles_m - 1))[0] + + 1) # Create the m-dimension tile ids for each grid index based on the visit # counts for each tile. @@ -245,10 +244,14 @@ def make_sequence_metadata( # # Remove tile visits that belong to a sequence not in our shard. iota = jnp.arange(num_sequences, dtype=jnp.int32) - active_sequence_mask = jnp.logical_and(iota <= end_sequence, iota >= start_sequence) - sequence_tiles = jnp.where(active_sequence_mask, sequence_tiles[:num_sequences], 0) + active_sequence_mask = jnp.logical_and(iota <= end_sequence, + iota >= start_sequence) + sequence_tiles = jnp.where(active_sequence_mask, + sequence_tiles[:num_sequences], 0) num_tiles = sequence_tiles.sum() - return (sequence_ids, m_tile_ids), num_tiles # (seq_ids, physical_q_tile_ids), num_logical_q_tiles + return (sequence_ids, m_tile_ids + ), num_tiles # (seq_ids, physical_q_tile_ids), num_logical_q_tiles + def check_kernel_input(q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, num_seqs, num_kv_pages_per_block): @@ -264,23 +267,26 @@ def check_kernel_input(q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, raise ValueError("head_dim of Q must be the same as that of K/V. Got" f" {head_dim} and {head_dim_k}.") if kv_lens.shape[0] != num_tokens: - raise ValueError("kv_lens.shape[0] must be thet same as num_tokens. Got" + raise ValueError("kv_lens.shape[0] must be the same as num_tokens. 
Got" f" {kv_lens.shape[0]} and {num_tokens}") if page_indices.shape[0] != num_tokens: - raise ValueError("page_indices.shape[0] must be thet same as num_tokens. Got" + raise ValueError("page_indices.shape[0] must be the same as num_tokens. Got" f" {page_indices.shape[0]} and {num_tokens}") if cu_q_lens.shape[0] != num_tokens + 1: - raise ValueError("cu_q_lens.shape[0] must be thet same as num_tokens + 1. Got" - f" {cu_q_lens.shape[0]} and {num_tokens + 1}") + raise ValueError( + "cu_q_lens.shape[0] must be the same as num_tokens + 1. Got" + f" {cu_q_lens.shape[0]} and {num_tokens + 1}") for i in range(num_seqs): - cur_q_len = cu_q_lens[i+1] - cu_q_lens[i] + cur_q_len = cu_q_lens[i + 1] - cu_q_lens[i] cur_kv_len = kv_lens[i] - checkify.check(cur_q_len <= cur_kv_len, "cur_q_len must be less or equal to cur_kv_len. Got {} and {}", cur_q_len, cur_kv_len) + checkify.check( + cur_q_len <= cur_kv_len, + "cur_q_len must be less or equal to cur_kv_len. Got {} and {}", + cur_q_len, cur_kv_len) if num_seqs > num_tokens: - raise ValueError(f"num_seqs must be less or equal to num_tokens. Got {num_seqs} and {num_tokens}") - # int16: will pack. need to explicit cast to int32. int64 is not supported in Pallas. for smem 1d case. - # 2d smem: int16 will be packed with an empty. So we didn't save any memory. - # scalar: use i32 (1, N). int16 for (1, N) will be padding. Need to use (2, N). + raise ValueError( + f"num_seqs must be less or equal to num_tokens. Got {num_seqs} and {num_tokens}" + ) if kv_lens.dtype != jnp.int32 or page_indices.dtype != jnp.int32 or cu_q_lens.dtype != jnp.int32: raise ValueError( f"The dtype of `lengths` must be int32. Got {kv_lens.dtype=}, " @@ -298,6 +304,7 @@ def check_kernel_input(q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, "Number of Q heads must be divisible by number of KV heads. Got" f" {num_q_heads} and {num_kv_heads}.") + # https://github.com/jax-ml/jax/blob/e3b3b913f7bcec3767e1442ace08999413f8703d/jax/experimental/pallas/ops/tpu/megablox/gmm.py#L269C1-L283C64 def _get_store_mask( *, @@ -316,6 +323,7 @@ def _get_store_mask( iota = jax.lax.broadcasted_iota(jnp.int32, (tq, tk), 0) + physical_q_tile_id return jnp.logical_and(iota >= sequence_start, iota < sequence_end) + def _flash_attention( q_head_idx_per_kv, # scalar, ranges from 0 to num_query_heads_per_kv_head sequence_metadata_ref, # (seq_ids, physical_q_tile_ids) @@ -343,7 +351,8 @@ def _flash_attention( head_dim: int, num_q_heads_per_kv_head: int, ): - assert q_ref.shape == (num_q_heads_per_kv_head, num_queries_per_block, head_dim) + assert q_ref.shape == (num_q_heads_per_kv_head, num_queries_per_block, + head_dim) kv_blk_size = page_size * num_kv_pages_per_block assert k.shape == (kv_blk_size, head_dim) assert v.shape == (kv_blk_size, head_dim) @@ -356,13 +365,17 @@ def _flash_attention( seq_ids, physical_q_tile_ids = sequence_metadata_ref # If the q-dim physical tile is changed (meaning it is a new physical q-dim tile that has not visited before), initialize the acc_scratch_ref, m_scratch_ref, and l_scratch_ref to run the flash attention v2 algorithm. 
- prev_logical_q_blk_idx = jnp.where(logical_q_blk_idx > 0, logical_q_blk_idx - 1, 0) + prev_logical_q_blk_idx = jnp.where(logical_q_blk_idx > 0, + logical_q_blk_idx - 1, 0) is_first_processed_logical_q_blk = logical_q_blk_idx == 0 - physical_q_blk_changed = (physical_q_tile_ids[logical_q_blk_idx] != physical_q_tile_ids[prev_logical_q_blk_idx]) - first_time_seeing_physical_q_blk = jnp.logical_or(is_first_processed_logical_q_blk, physical_q_blk_changed) + physical_q_blk_changed = ( + physical_q_tile_ids[logical_q_blk_idx] != + physical_q_tile_ids[prev_logical_q_blk_idx]) + first_time_seeing_physical_q_blk = jnp.logical_or( + is_first_processed_logical_q_blk, physical_q_blk_changed) is_first_kv_blk = (kv_blk_idx == 0) should_init_scratch_ref = jnp.logical_and(is_first_kv_blk, - first_time_seeing_physical_q_blk) + first_time_seeing_physical_q_blk) @pl.when(should_init_scratch_ref) def init_scratch_ref(): # pylint: disable=unused-variable @@ -373,8 +386,10 @@ def init_scratch_ref(): # pylint: disable=unused-variable acc_scratch_ref[q_head_idx_per_kv] = jnp.zeros( acc_scratch_ref[q_head_idx_per_kv].shape, jnp.float32) - m_prev = m_scratch_ref[q_head_idx_per_kv] # [num_queries_per_block, MIN_BLOCK_SIZE] - l_prev = l_scratch_ref[q_head_idx_per_kv] # [num_queries_per_block, MIN_BLOCK_SIZE] + m_prev = m_scratch_ref[ + q_head_idx_per_kv] # [num_queries_per_block, MIN_BLOCK_SIZE] + l_prev = l_scratch_ref[ + q_head_idx_per_kv] # [num_queries_per_block, MIN_BLOCK_SIZE] # Load the whole q_block that belongs to the current physical q_blk and compute the attention. When we write, we only write the part that belongs to the current sequence. # Cannot just load only the part of q_block that belongs to the current sequence, because it results in dynamic shapes and then fails the JIT compilation. @@ -389,22 +404,19 @@ def init_scratch_ref(): # pylint: disable=unused-variable # Modify the mask accordingly: first form the mask. Then move the mask up/down to the right place. cur_seq_idx = seq_ids[logical_q_blk_idx] cur_seq_start = effective_cu_q_lens_ref[cur_seq_idx] - cur_seq_end = effective_cu_q_lens_ref[cur_seq_idx+1] + cur_seq_end = effective_cu_q_lens_ref[cur_seq_idx + 1] physical_q_blk_idx = physical_q_tile_ids[logical_q_blk_idx] - q_index = physical_q_blk_idx*num_queries_per_block-cur_seq_start + q_index = physical_q_blk_idx * num_queries_per_block - cur_seq_start kv_index = kv_blk_idx * kv_blk_size effective_kv_len = effective_kv_lens_ref[cur_seq_idx] effective_q_len = cur_seq_end - cur_seq_start - row_ids = ( - effective_kv_len - effective_q_len) + q_index + jax.lax.broadcasted_iota( - jnp.int32, - (num_queries_per_block, kv_blk_size), 0) + row_ids = (effective_kv_len - + effective_q_len) + q_index + jax.lax.broadcasted_iota( + jnp.int32, (num_queries_per_block, kv_blk_size), 0) col_ids = kv_index + jax.lax.broadcasted_iota( - jnp.int32, - (num_queries_per_block, kv_blk_size), 1) + jnp.int32, (num_queries_per_block, kv_blk_size), 1) causal_mask = jnp.where(row_ids < col_ids, mask_value, 0.) 
- assert causal_mask.shape == (num_queries_per_block, - kv_blk_size) + assert causal_mask.shape == (num_queries_per_block, kv_blk_size) s = s + causal_mask # [block_q, block_k] @@ -414,8 +426,7 @@ def init_scratch_ref(): # pylint: disable=unused-variable block_k_repeats, rem = divmod(kv_blk_size, MIN_BLOCK_SIZE) if rem: raise NotImplementedError( - f"{kv_blk_size=} should be a multiple of {MIN_BLOCK_SIZE}" - ) + f"{kv_blk_size=} should be a multiple of {MIN_BLOCK_SIZE}") p = jnp.exp( s - pltpu.repeat(m_next, block_k_repeats, 1)) # Shape [block_q, block_k] @@ -436,53 +447,70 @@ def init_scratch_ref(): # pylint: disable=unused-variable # Need to store these l_next and m_next which will relay to the output. # But only update the part that belongs to the current sequence we are working on. - lm_mask = _get_store_mask(logical_q_blk_idx=logical_q_blk_idx, - sequence_offsets=effective_cu_q_lens_ref, - sequence_ids=seq_ids, - physical_q_tile_ids=physical_q_tile_ids, - tq=num_queries_per_block, - tk=MIN_BLOCK_SIZE, - ) + lm_mask = _get_store_mask( + logical_q_blk_idx=logical_q_blk_idx, + sequence_offsets=effective_cu_q_lens_ref, + sequence_ids=seq_ids, + physical_q_tile_ids=physical_q_tile_ids, + tq=num_queries_per_block, + tk=MIN_BLOCK_SIZE, + ) # Either jax.lax.select or jnp.where works here. pl.store( l_scratch_ref, - # (q_head_idx_per_kv,), # not working - #(pl.ds(q_head_idx_per_kv, 1), slice(None), slice(None)), # not working - (q_head_idx_per_kv, slice(None), slice(None)), # not working - l_next[...], - mask=lm_mask[...], + (pl.ds(q_head_idx_per_kv, 1), slice(None), slice(None)), + l_next.reshape(1, *l_next.shape), # no-op here. + mask=lm_mask.reshape(1, *lm_mask.shape), + ) + pl.store( + m_scratch_ref, + (pl.ds(q_head_idx_per_kv, 1), slice(None), slice(None)), + m_next.reshape(1, *m_next.shape), + mask=lm_mask.reshape(1, *lm_mask.shape), ) - m_scratch_ref[q_head_idx_per_kv] = jax.lax.select(lm_mask[...], m_next, m_scratch_ref[q_head_idx_per_kv]) l_next_inv_safe = jnp.where(l_next == 0.0, 1.0, 1.0 / l_next) # [block_q, 128] - temp = acc_scratch_ref[q_head_idx_per_kv] * l_broadcast(l_corr * l_next_inv_safe) - acc_mask = _get_store_mask(logical_q_blk_idx=logical_q_blk_idx, - sequence_offsets=effective_cu_q_lens_ref, - sequence_ids=seq_ids, - physical_q_tile_ids=physical_q_tile_ids, - tq=num_queries_per_block, - tk=head_dim, - ) - acc_scratch_ref[q_head_idx_per_kv] = jax.lax.select(acc_mask[...], temp, acc_scratch_ref[q_head_idx_per_kv]) + temp = acc_scratch_ref[q_head_idx_per_kv] * l_broadcast( + l_corr * l_next_inv_safe) o_curr = jax.lax.dot( p.astype(v.dtype), v, preferred_element_type=jnp.float32) # [block_q, 128] - temp = (acc_scratch_ref[q_head_idx_per_kv] + o_curr * l_broadcast(l_next_inv_safe)) - acc_scratch_ref[q_head_idx_per_kv] = jax.lax.select(acc_mask[...], temp, acc_scratch_ref[q_head_idx_per_kv]) + temp += o_curr * l_broadcast(l_next_inv_safe) + acc_mask = _get_store_mask( + logical_q_blk_idx=logical_q_blk_idx, + sequence_offsets=effective_cu_q_lens_ref, + sequence_ids=seq_ids, + physical_q_tile_ids=physical_q_tile_ids, + tq=num_queries_per_block, + tk=head_dim, + ) + pl.store( + acc_scratch_ref, + (pl.ds(q_head_idx_per_kv, 1), slice(None), slice(None)), + temp.reshape(1, *temp.shape), + mask=acc_mask.reshape(1, *acc_mask.shape), + ) # Store the result from VMEM to HBM only when it is the last kv_block and the next q-dim logical tile belongs to a different q-dim physical tile. 
- is_last_kv_blk_idx = (kv_blk_idx == (pl.cdiv(effective_kv_len, kv_blk_size) - 1)) - num_logical_q_blks = pl.num_programs(1) # grid=(num_kv_heads, num_logical_q_tiles, num_kv_blks) - next_logical_q_blk_idx = jnp.where(logical_q_blk_idx == num_logical_q_blks - 1, - logical_q_blk_idx, - logical_q_blk_idx+1) - is_last_logical_q_blk = (logical_q_blk_idx == num_logical_q_blks-1) - physical_q_blk_will_change = (physical_q_tile_ids[logical_q_blk_idx] != physical_q_tile_ids[next_logical_q_blk_idx]) - last_time_seeing_cur_physical_q_blk = jnp.logical_or(is_last_logical_q_blk, physical_q_blk_will_change) - should_store_to_hbm = jnp.logical_and(is_last_kv_blk_idx, last_time_seeing_cur_physical_q_blk) - @pl.when(should_store_to_hbm) - def store_to_hbm(): # pylint: disable=unused-variable + is_last_kv_blk_idx = ( + kv_blk_idx == (pl.cdiv(effective_kv_len, kv_blk_size) - 1)) + num_logical_q_blks = pl.num_programs( + 1) # grid=(num_kv_heads, num_logical_q_tiles, num_kv_blks) + next_logical_q_blk_idx = jnp.where( + logical_q_blk_idx == num_logical_q_blks - 1, logical_q_blk_idx, + logical_q_blk_idx + 1) + is_last_logical_q_blk = (logical_q_blk_idx == num_logical_q_blks - 1) + physical_q_blk_will_change = ( + physical_q_tile_ids[logical_q_blk_idx] != + physical_q_tile_ids[next_logical_q_blk_idx]) + last_time_seeing_cur_physical_q_blk = jnp.logical_or( + is_last_logical_q_blk, physical_q_blk_will_change) + should_store_to_output = jnp.logical_and(is_last_kv_blk_idx, + last_time_seeing_cur_physical_q_blk) + + @pl.when(should_store_to_output) + def store_to_output(): # pylint: disable=unused-variable o_ref[q_head_idx_per_kv] = acc_scratch_ref[q_head_idx_per_kv].astype( o_ref.dtype) l_ref[q_head_idx_per_kv] = l_scratch_ref[q_head_idx_per_kv].astype( @@ -490,6 +518,7 @@ def store_to_hbm(): # pylint: disable=unused-variable m_ref[q_head_idx_per_kv] = m_scratch_ref[q_head_idx_per_kv].astype( m_ref.dtype) + def paged_flash_attention_kernel( # prefetch refs sequence_metadata_ref, # (seq_ids, physical_q_tile_ids) @@ -569,7 +598,7 @@ def advance_logical_q_blk_idx(): cur_seq_idx = seq_ids[logical_q_blk_idx] effective_kv_len_cur_seq = effective_kv_lens_ref[cur_seq_idx] return lax.cond( - kv_blk_idx*kv_blk_size < effective_kv_len_cur_seq, + kv_blk_idx * kv_blk_size < effective_kv_len_cur_seq, lambda: (kv_head_idx, logical_q_blk_idx, kv_blk_idx), advance_logical_q_blk_idx, ) @@ -614,7 +643,8 @@ def prefetch_first_block(): # pylint: disable=unused-variable async_copy_k.start() async_copy_v.start() - next_kv_head_idx, next_logical_q_blk_idx, next_kv_blk_idx = compute_block_indices(kv_head_idx, logical_q_blk_idx, kv_blk_idx+1) + next_kv_head_idx, next_logical_q_blk_idx, next_kv_blk_idx = compute_block_indices( + kv_head_idx, logical_q_blk_idx, kv_blk_idx + 1) @pl.when(next_kv_head_idx < num_kv_heads) def prefetch_next_block(): # pylint: disable=unused-variable @@ -631,8 +661,8 @@ def prefetch_next_block(): # pylint: disable=unused-variable k = async_copy_k.wait_and_get_loaded( ) # [pages_per_compute_block*page_size,head_dim] v = async_copy_v.wait_and_get_loaded() - assert k.shape == (num_kv_pages_per_block*page_size, head_dim) - assert v.shape == (num_kv_pages_per_block*page_size, head_dim) + assert k.shape == (num_kv_pages_per_block * page_size, head_dim) + assert v.shape == (num_kv_pages_per_block * page_size, head_dim) for q_head_idx in range(num_q_heads_per_kv_head): _flash_attention( @@ -667,6 +697,7 @@ def prefetch_next_block(): # pylint: disable=unused-variable MIN_BLOCK_SIZE = 128 + @checkify.checkify 
@functools.partial( jax.jit, @@ -716,6 +747,9 @@ def ragged_paged_attention( num_queries_per_block: how many queries to be processes in one flash attention block in the pallas kernel. + The num_tokens, num_seqs, and pages_per_sequence are dynamic. If they are + very dynamic, then the overhead could be high due to the recompilation. + Returns: The output of attention([num_tokens, num_q_heads, head_dim]). """ @@ -743,7 +777,8 @@ def ragged_paged_attention( # out later so that num_q_heads doesn't have to be the 2nd last dimension and hence doesn't subject to the multiple of 8 constraint. q = jnp.permute_dims(q, (1, 0, 2)) # [num_q_heads, num_tokens, head_dim] num_kv_heads, total_num_pages, page_size, head_dim = k_pages.shape - check_kernel_input(q, k_pages, v_pages,kv_lens, page_indices, cu_q_lens, num_seqs, num_kv_pages_per_block) + check_kernel_input(q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, + num_seqs, num_kv_pages_per_block) num_q_heads_per_kv_head = num_q_heads // num_kv_heads sequence_metadata, num_logical_q_tiles = make_sequence_metadata( @@ -768,14 +803,16 @@ def ragged_paged_attention( # in-spec. Note currently q.shape=[num_q_heads, num_tokens, head_dim] # Within the kernel, q.shape should be [num_q_heads_per_kv_head, q_block_size, head_dim] - def qo_index_map(kv_head_idx, logical_q_blk_idx, kv_blk_idx, sequence_metadata, *_): + def qo_index_map(kv_head_idx, logical_q_blk_idx, kv_blk_idx, + sequence_metadata, *_): seq_ids, physical_q_tile_ids = sequence_metadata del seq_ids physical_q_blk_idx = physical_q_tile_ids[logical_q_blk_idx] return (kv_head_idx, physical_q_blk_idx, 0) + q_block_spec = pl.BlockSpec( - (num_q_heads_per_kv_head, num_queries_per_block, head_dim), - qo_index_map, + (num_q_heads_per_kv_head, num_queries_per_block, head_dim), + qo_index_map, ) in_specs = [ q_block_spec, @@ -809,8 +846,7 @@ def qo_index_map(kv_head_idx, logical_q_blk_idx, kv_blk_idx, sequence_metadata, (num_q_heads_per_kv_head, num_queries_per_block, MIN_BLOCK_SIZE), jnp.float32) acc_scratch = pltpu.VMEM( - (num_q_heads_per_kv_head, num_queries_per_block, head_dim), - jnp.float32) + (num_q_heads_per_kv_head, num_queries_per_block, head_dim), jnp.float32) scratch_shapes = [ pltpu.VMEM( ( @@ -868,44 +904,6 @@ def qo_index_map(kv_head_idx, logical_q_blk_idx, kv_blk_idx, sequence_metadata, buffer_index = jnp.zeros((1,), jnp.int32) step = jnp.zeros((1,), jnp.int32) - # debug compile begins - # To enable debug, uncomment this section, comment out the `kernel()` below and comment out the jax.jit above. - # compiled_kernel = ( - # jax.jit(kernel) - # .lower( - # # prefetch - # sequence_metadata, - # kv_lens, - # page_indices_1d, - # cu_q_lens, - # buffer_index, - # step, - # # kernel inputs - # q, - # k_pages, - # k_scales_pages, - # v_pages, - # v_scales_pages, - # ) - # .compile({'xla_tpu_enable_log_recorder': 'true'}) - # ) - # outputs = compiled_kernel( - # # prefetch - # sequence_metadata, - # kv_lens, - # page_indices_1d, - # cu_q_lens, - # buffer_index, - # step, - # # kernel inputs - # q, - # k_pages, - # k_scales_pages, - # v_pages, - # v_scales_pages, - # ) - # debug compile ends - outputs = kernel( # prefetch sequence_metadata, From 51704029be40599805c45358a9e1247bac65bce7 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Mon, 3 Feb 2025 17:45:14 +0000 Subject: [PATCH 9/9] Trigger CI
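Because ragged_paged_attention is wrapped in @checkify.checkify on top of jax.jit, a caller receives an (error, output) pair and must throw the error explicitly for checks such as cur_q_len <= cur_kv_len to surface. The sketch below shows that call pattern; the positional argument order is an assumption that mirrors check_kernel_input, and the toy safe_rsqrt function underneath demonstrates the same checkify mechanics without needing a TPU.

from jax.experimental import checkify
import jax
import jax.numpy as jnp

# Assumed call pattern for the kernel (argument order mirrors check_kernel_input;
# building valid paged inputs and running the Pallas kernel requires a TPU):
#   err, out = ragged_paged_attention(
#       q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, num_seqs,
#       num_kv_pages_per_block=..., num_queries_per_block=...)
#   err.throw()  # raises if any checkify.check inside the kernel failed

@checkify.checkify
@jax.jit
def safe_rsqrt(x):
  checkify.check(jnp.all(x > 0), "input to safe_rsqrt must be positive")
  return jax.lax.rsqrt(x)

err, y = safe_rsqrt(jnp.array([1.0, 4.0]))
err.throw()  # no-op here; would raise on a failing check
print(y)     # [1.  0.5]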