fix end of batch cache slicing
sixiang-google committed Feb 20, 2025
1 parent bea1cef commit 32171f8
Showing 2 changed files with 11 additions and 4 deletions.
6 changes: 3 additions & 3 deletions MaxText/inference_mlperf/offline_inference.py
@@ -340,10 +340,10 @@ def decode():
       else:
         assert False, "no generate fn"
       result_tokens_l = []
-      for i in range(5):
+      for i in range(10):
         self.decode_state, result_tokens = gen_fn(self.params, self.decode_state, None)
         result_tokens_l.append(result_tokens)
-      for i in range(5):
+      for i in range(10):
         # result_tokens.copy_to_host_async()
         result_tokens = result_tokens_l[i].convert_to_numpy()
         self.detokenize_backlog.put((result_tokens, False, 0, 0), block=True)
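The change above raises the number of generate steps buffered per decode call from 5 to 10 before results are copied back to the host. A minimal sketch of that buffered-decode pattern follows, assuming a hypothetical gen_fn(params, decode_state, rng) returning (new_state, result_tokens), a result_tokens.convert_to_numpy() method, and a queue-like detokenize_backlog; the names mirror the diff but are illustrative, not the module's actual API.

DECODE_STEPS_PER_BATCH = 10  # raised from 5 in this commit

def run_decode_batch(params, decode_state, gen_fn, detokenize_backlog):
  # Launch all generate steps first so the device work can queue up asynchronously.
  buffered = []
  for _ in range(DECODE_STEPS_PER_BATCH):
    decode_state, result_tokens = gen_fn(params, decode_state, None)
    buffered.append(result_tokens)
  # Only then block on host transfers and hand results to the detokenize thread.
  for result_tokens in buffered:
    detokenize_backlog.put((result_tokens.convert_to_numpy(), False, 0, 0), block=True)
  return decode_state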
@@ -414,7 +414,7 @@ def detokenize():
         self.detokenize_backlog.put((first_token, True, row.id, slot), block=True)
         continue

-      if len(self.prefill_buckets[padded_len // 2]) == 0:
+      if len(self.prefill_buckets[padded_len // 2]) != 0:
         prefill_batch(self.prefill_buckets[padded_len // 2], padded_len // 2)
         self.prefill_buckets[padded_len // 2] = []
       if padded_len == self.max_prefill_length:
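The second hunk flips the flush condition for the half-length prefill bucket from == 0 to != 0, so a bucket is dispatched only when it actually holds pending rows. An illustrative sketch of that guard, using a plain dict of lists in place of self.prefill_buckets and a stand-in prefill_batch callable:

from collections import defaultdict

prefill_buckets = defaultdict(list)

def maybe_flush_half_bucket(padded_len, prefill_batch):
  half = padded_len // 2
  # Corrected guard: flush only when the half-length bucket is non-empty;
  # the old == 0 test dispatched empty batches and left queued rows behind.
  if len(prefill_buckets[half]) != 0:
    prefill_batch(prefill_buckets[half], half)
    prefill_buckets[half] = []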
9 changes: 8 additions & 1 deletion MaxText/maxengine.py
@@ -729,6 +729,10 @@ def copy(path, partial_cache, full_cache, annotations):
           zeros = jnp.zeros((1, self.config.max_prefill_predict_length), dtype=jnp.int32)
           ## zero out in case prefill cache is too small to cover
           full_cache = jax.lax.dynamic_update_index_in_dim(full_cache, zeros, slot, batch_idx)
+          # In case partial_cache is too small to slice at the given index, pad it with an extra seqlen
+          if i == num_prompts - 1:
+            pad = jnp.zeros((1, seq_len), dtype=int)
+            partial_cache = jnp.concatenate([partial_cache, pad], axis=1)
           ## copy prefill cache
           partial_cache = jax.lax.dynamic_slice(partial_cache, (0, start_idx), (1, seq_len))
           partial_cache = (partial_cache == partial_cache[0, 0]).astype(int)
@@ -749,8 +753,11 @@ def copy(path, partial_cache, full_cache, annotations):
           slice_size[seqlen_index] = seq_len

           slice_size = tuple(slice_size)
+          # Same as in prefill_segment_id processing
+          if i == num_prompts - 1:
+            pad = jnp.zeros(slice_size, dtype=partial_cache.dtype)
+            partial_cache = jnp.concatenate([partial_cache, pad], axis=seqlen_index)
           partial_cache = jax.lax.dynamic_slice(partial_cache, start_indices, slice_size)
-          # jax.debug.print("start_indices: {}, slice_size: {}", start_indices, slice_size)

           return jax.lax.dynamic_update_index_in_dim(full_cache, partial_cache, slot, batch_idx)
         else:
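Both maxengine.py hunks pad partial_cache with one extra seq_len of zeros before slicing the last prompt of a batch. The reason is that jax.lax.dynamic_slice clamps out-of-range start indices so the slice always fits, which would otherwise silently shift the window left and copy the wrong tokens at the end of the batch. A self-contained sketch of that behaviour and of the padded fix, with made-up shapes chosen only to show the clamping:

import jax
import jax.numpy as jnp

cache = jnp.arange(8).reshape(1, 8)   # stand-in prefill cache of length 8
start_idx, seq_len = 6, 4             # last prompt's slice would end at 10 > 8

# Without padding: dynamic_slice clamps the start index from 6 down to 4.
clamped = jax.lax.dynamic_slice(cache, (0, start_idx), (1, seq_len))
# clamped == [[4, 5, 6, 7]] -- not the window that was asked for.

# With an extra seq_len of zeros appended, the requested window stays in range.
padded = jnp.concatenate([cache, jnp.zeros((1, seq_len), dtype=cache.dtype)], axis=1)
fixed = jax.lax.dynamic_slice(padded, (0, start_idx), (1, seq_len))
# fixed == [[6, 7, 0, 0]] -- the intended slice, zero-filled past the real data.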
