Merge pull request #1288 from AI-Hypercomputer:agagik-mla-config
PiperOrigin-RevId: 729271690
maxtext authors committed Feb 20, 2025
2 parents e2a0367 + 4a82219 commit f4484a0
Showing 5 changed files with 47 additions and 12 deletions.
16 changes: 15 additions & 1 deletion MaxText/configs/base.yml
@@ -210,6 +210,12 @@ final_logits_soft_cap: 0.0
use_post_attn_norm: False
use_post_ffw_norm: False

+# MLA parameters
+q_lora_rank: 0
+kv_lora_rank: 512
+qk_nope_head_dim: 128
+qk_rope_head_dim: 64
+v_head_dim: 128

# Combine matmuls for QKV and MLP
fused_qkv: False
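
Note: these defaults give queries/keys a wider per-head dimension than values, which is why attentions.py below relaxes its KV-cache shape check for MLA. A quick sanity check of the arithmetic (illustrative variable names, not MaxText identifiers):

# Illustrative arithmetic for the base.yml MLA defaults above.
qk_nope_head_dim = 128   # content ("no position") part of each query/key head
qk_rope_head_dim = 64    # decoupled rotary part of each query/key head
v_head_dim = 128         # per-head value width

qk_head_dim = qk_nope_head_dim + qk_rope_head_dim   # 192
assert qk_head_dim == 192 and v_head_dim == 128     # key and value heads differ in width

# q_lora_rank = 0 here presumably means the query path is not low-rank
# compressed by default; kv_lora_rank = 512 is the rank of the shared
# key/value down-projection.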
@@ -461,11 +467,19 @@ use_iota_embed: False
# use positional embedding
use_untrainable_positional_embedding: False
trainable_position_size: -1 # enable gpt3 position embedding with a positive trainable_position_size
-# Rope parameters
+# RoPE parameters
rope_type: "default" # one of "default", "llama3.1" or "yarn"
rope_min_timescale: 1
rope_max_timescale: 10_000

+# yarn RoPE parameters
+original_seq_len: 4096
+rope_theta: 10000.0
+rope_factor: 40
+beta_fast: 32
+beta_slow: 1
+mscale: 1.0

# Ahead of time Compilation (aka AOT)
# Only set these arguments if you are running train_compile or loading a compiled train step.
compiled_trainstep_file: "" # Name of saved serialized compiled train_step, e.g. compiled_train_v5e-256.pickle
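
Note: the yarn block added above (original_seq_len through mscale) carries the knobs of the standard YaRN context-extension scheme. A minimal sketch of how such parameters are typically combined, following the DeepSeek-V3 reference rather than MaxText's YarnRotaryEmbedding (whose internals are not shown in this diff):

import math

# Sketch of the standard YaRN corrections these knobs usually drive
# (reference-style; the in-repo implementation may differ in detail).
def yarn_correction_range(beta_fast=32, beta_slow=1, dim=64,
                          rope_theta=10000.0, original_seq_len=4096):
  # Rotary dimension index at which a frequency completes `num_rotations`
  # full turns over the original context length.
  def correction_dim(num_rotations):
    return (dim * math.log(original_seq_len / (num_rotations * 2 * math.pi))
            / (2 * math.log(rope_theta)))
  low = max(math.floor(correction_dim(beta_fast)), 0)
  high = min(math.ceil(correction_dim(beta_slow)), dim - 1)
  # Dims below `low` keep their original frequency, dims above `high` are
  # divided by rope_factor, and dims in between are linearly blended.
  return low, high

def yarn_mscale(rope_factor=40, mscale=1.0):
  # Attention-logit rescaling used once the context is extended
  # (rope_factor > 1); the DeepSeek reference multiplies the softmax
  # scale by the square of this value.
  return 0.1 * mscale * math.log(rope_factor) + 1.0 if rope_factor > 1 else 1.0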
10 changes: 8 additions & 2 deletions MaxText/configs/models/deepseek3-671b.yml
@@ -22,7 +22,6 @@ base_mlp_dim: 18432
base_moe_mlp_dim: 2048
base_num_decoder_layers: 61
first_num_dense_layers: 3
-head_dim: 128
mlp_activations: ["silu","linear"]
vocab_size: 32000 # TODO(b/394635939): update after adding tokenizer
enable_dropout: False
@@ -34,5 +33,12 @@ shared_experts: 1
routed_scaling_factor: 2.5
routed_score_func: "sigmoid"
routed_bias: True
-rope_max_timescale: 10_000
+# MLA
+attention_type: "mla"
+q_lora_rank: 1536
+kv_lora_rank: 512
+qk_nope_head_dim: 128
+qk_rope_head_dim: 64
+v_head_dim: 128
+rope_type: "yarn"
decoder_block: "deepseek"
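
Note: for DeepSeek-V3-scale settings these values describe the usual MLA projection chain: queries are compressed to q_lora_rank and re-expanded, while keys and values share a kv_lora_rank-dimensional latent plus a small decoupled RoPE key. A shape sketch under the config above (hypothetical variable names; the model width and head count are the published DeepSeek-V3 values, not values shown in this diff):

# Hypothetical MLA projection shapes for deepseek3-671b-style settings.
emb_dim = 7168            # published DeepSeek-V3 model width (assumption here)
n_heads = 128             # published DeepSeek-V3 head count (assumption here)
q_lora_rank = 1536
kv_lora_rank = 512
qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 128, 64, 128
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim  # 192

# Query path: down-project, then up-project to per-head queries.
w_dq_shape = (emb_dim, q_lora_rank)                       # (7168, 1536)
w_uq_shape = (q_lora_rank, n_heads * qk_head_dim)         # (1536, 128 * 192)

# Key/value path: one shared latent plus a decoupled RoPE key.
w_dkv_shape = (emb_dim, kv_lora_rank + qk_rope_head_dim)  # (7168, 512 + 64)
w_uk_shape = (kv_lora_rank, n_heads * qk_nope_head_dim)   # (512, 128 * 128)
w_uv_shape = (kv_lora_rank, n_heads * v_head_dim)         # (512, 128 * 128)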
13 changes: 10 additions & 3 deletions MaxText/layers/attentions.py
@@ -988,7 +988,7 @@ def kv_cache(
two tuples of (k, v, decoder_segments) -- either can be Nones
"""
-if key.shape != value.shape:
+if key.shape != value.shape and self.config.attention_type != AttentionType.MLA.value:
raise ValueError(f"Can't KV cache with mismatched shapes {key.shape=}, {value.shape=}")

if model_mode == common_types.MODEL_MODE_TRAIN:
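
Note: the relaxed check reflects the head-width mismatch MLA introduces: under the base.yml defaults, keys are 192-wide per head (NoPE + RoPE parts) while values are 128-wide, so a strict key/value shape equality would wrongly reject a valid MLA cache. Illustratively (the concrete shapes are an assumption derived from the config defaults, not traced from the code):

# Illustrative cache-entry shapes for MLA:
key_shape = (4, 1024, 128, 128 + 64)   # (batch, seq, kv heads, qk_nope + qk_rope) -> last dim 192
value_shape = (4, 1024, 128, 128)      # (batch, seq, kv heads, v_head_dim)
assert key_shape != value_shape        # intentionally mismatched, yet valid under MLA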
@@ -1272,6 +1272,11 @@ def apply_rotary_embedding(self, inputs: Array, inputs_positions: Array, name: s
elif rope_type.startswith("yarn"):
rotary_embedding = YarnRotaryEmbedding(
max_seq_len=self.config.max_target_length,
+original_seq_len=self.config.original_seq_len,
+beta_fast=self.config.beta_fast,
+beta_slow=self.config.beta_slow,
+rope_theta=self.config.rope_theta,
+rope_factor=self.config.rope_factor,
embedding_dims=rope_embedding_dims,
fprop_dtype=self.dtype,
name=name,
@@ -1370,7 +1375,7 @@ class MLA(Attention):
max_seq_len: int = 4096 * 4
original_seq_len: int = 4096
mscale: float = 1.0 # scaling factor for softmax
-rope_factor: float = 10000.0 # rotary embedding factor
+rope_factor: float = 40.0 # rotary embedding factor

@property
def qk_head_dim(self) -> int:
@@ -1381,7 +1386,9 @@ def setup(self):
super().setup()

# Assert required configuration parameters for MLA attention.
-assert self.config.attention_type == AttentionType.MLA.value, "MLA requires MLA attention type"
+assert (
+    self.config.attention_type == AttentionType.MLA.value
+), f"MLA requires MLA attention type {AttentionType.MLA.value}"
assert self.kv_lora_rank > 0, "KV LoRA rank must be > 0"
assert self.qk_nope_head_dim > 0, "QK NoPe head dim must be > 0"
assert self.qk_rope_head_dim > 0, "QK RoPE head dim must be > 0"
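
Note: the qk_head_dim property is truncated by the diff view; given the fields and assertions shown, it most plausibly just sums the two query/key sub-widths. A toy sketch of that inference (not the verbatim repository code):

class MLASketch:
  """Toy stand-in for the truncated property; not the repository class."""
  qk_nope_head_dim: int = 128
  qk_rope_head_dim: int = 64
  v_head_dim: int = 128

  @property
  def qk_head_dim(self) -> int:
    # Queries/keys concatenate a content ("nope") part and a decoupled
    # rotary part: 128 + 64 = 192, while values stay at v_head_dim = 128.
    return self.qk_nope_head_dim + self.qk_rope_head_dim

assert MLASketch().qk_head_dim == 192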
12 changes: 10 additions & 2 deletions MaxText/layers/deepseek.py
@@ -64,8 +64,7 @@ def self_attention_with_norm(inputs, cfg, mesh, quant, decoder_segment_ids, deco
lnx = lnx_rms(inputs)
lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_norm_length", "activation_embed"))

-# TODO: Update self-attention to MLA
-attention_layer = Attention(
+attention_layer = attentions.MLA(
config=cfg,
num_query_heads=cfg.num_query_heads,
num_kv_heads=cfg.num_kv_heads,
@@ -80,6 +79,15 @@ def self_attention_with_norm(inputs, cfg, mesh, quant, decoder_segment_ids, deco
name="self_attention",
quant=quant,
kv_quant=quantizations.configure_kv_quant(cfg),
+q_lora_rank=cfg.q_lora_rank,
+kv_lora_rank=cfg.kv_lora_rank,
+qk_nope_head_dim=cfg.qk_nope_head_dim,
+qk_rope_head_dim=cfg.qk_rope_head_dim,
+v_head_dim=cfg.v_head_dim,
+max_seq_len=cfg.max_target_length,
+original_seq_len=cfg.original_seq_len,
+mscale=cfg.mscale,
+rope_factor=cfg.rope_factor,
)

attention_lnx = attention_layer(
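
Note: the MLA-specific fields are read straight off the config and forwarded as keyword arguments, with max_target_length mapped onto the layer's max_seq_len. A small runnable toy of that plumbing pattern (the config class here is a stand-in, not MaxText's pyconfig object):

from dataclasses import dataclass

@dataclass
class ToyConfig:
  q_lora_rank: int = 1536
  kv_lora_rank: int = 512
  qk_nope_head_dim: int = 128
  qk_rope_head_dim: int = 64
  v_head_dim: int = 128
  max_target_length: int = 1024
  original_seq_len: int = 4096
  mscale: float = 1.0
  rope_factor: float = 40.0

def mla_kwargs(cfg: ToyConfig) -> dict:
  # Mirrors the keyword arguments forwarded to attentions.MLA in deepseek.py.
  return dict(
      q_lora_rank=cfg.q_lora_rank,
      kv_lora_rank=cfg.kv_lora_rank,
      qk_nope_head_dim=cfg.qk_nope_head_dim,
      qk_rope_head_dim=cfg.qk_rope_head_dim,
      v_head_dim=cfg.v_head_dim,
      max_seq_len=cfg.max_target_length,
      original_seq_len=cfg.original_seq_len,
      mscale=cfg.mscale,
      rope_factor=cfg.rope_factor,
  )

print(mla_kwargs(ToyConfig()))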
8 changes: 4 additions & 4 deletions MaxText/tests/train_compile_test.py
@@ -449,9 +449,9 @@ def test_moe_deepseek_scanned_bf16(self):
"model_name=deepseek3-671b",
"sparse_matmul=True",
"megablox=False",
"per_device_batch_size=4",
"per_device_batch_size=2",
"max_target_length=1024",
"attention=flash",
"attention=dot_product", # Change to flush attention once it works for MLA
"dtype=bfloat16",
"weight_dtype=bfloat16",
"scan_layers=True",
@@ -472,9 +472,9 @@ def test_moe_deepseek_unscanned_bf16(self):
"model_name=deepseek3-671b",
"sparse_matmul=True",
"megablox=False",
"per_device_batch_size=4",
"per_device_batch_size=2",
"max_target_length=1024",
"attention=flash",
"attention=dot_product", # Change to flush attention once it works for MLA
"dtype=bfloat16",
"weight_dtype=bfloat16",
"scan_layers=False",
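
Note: both updated tests exercise the ahead-of-time compile path, so the smaller per-device batch size and the switch to the dot-product kernel appear intended to keep that path green until flash attention supports MLA. Assuming the repository's usual pytest setup, they can be selected by name, e.g.:

python3 -m pytest MaxText/tests/train_compile_test.py -k moe_deepseek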
