Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add gpu pallas flash attention for inference. Collobrate with tohaowu. #1292

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 22 additions & 11 deletions MaxText/layers/attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import common_types
from kernels.ragged_attention import ragged_gqa
from kernels.ragged_attention import ragged_mha
from jax.experimental.pallas.ops.gpu import attention as gpu_pallas_attention
from layers import embeddings
from layers import initializers
from layers import linears
Expand Down Expand Up @@ -237,17 +238,27 @@ def apply_attention(
):
return self.apply_attention_dot(query, key, value, decoder_segment_ids, model_mode)
elif self.attention_kernel == "flash" or self.attention_kernel == "autoselected":
if isinstance(key, KVTensor):
key = key.dequant()
if isinstance(value, KVTensor):
value = value.dequant()

if model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE:
raise ValueError(
"""Decode not supported with flash attention.
Use `dot_product` instead."""
)
return self.tpu_flash_attention(query, key, value, decoder_segment_ids, self.attn_logits_soft_cap), None, None
if jax.devices()[0].platform == "tpu":
if isinstance(key, KVTensor):
key = key.dequant()
if isinstance(value, KVTensor):
value = value.dequant()

if model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE:
raise ValueError(
"""Decode not supported with flash attention.
Use `dot_product` instead."""
)
return self.tpu_flash_attention(query, key, value, decoder_segment_ids, self.attn_logits_soft_cap), None, None
else:
if model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE:
# fallback to dot_product as pallas gpu flash attention doesn't support decode stage
return self.apply_attention_dot(query, key, value, decoder_segment_ids, model_mode)
else:
key = jnp.repeat(key, self.num_query_heads // self.num_kv_heads, axis=2)
value = jnp.repeat(value, self.num_query_heads // self.num_kv_heads, axis=2)
out = gpu_pallas_attention.mha(query, key, value, decoder_segment_ids, sm_scale=1.0, causal=True)
return out, None, None
elif self.attention_kernel == "cudnn_flash_te":
if isinstance(key, KVTensor):
key = key.dequant()
Expand Down
Loading