Skip to content

Commit

Permalink
Separate kv_scale into key_scale and value_scale
Browse files Browse the repository at this point in the history
  • Loading branch information
mgoin committed Jul 3, 2024
1 parent 4b2092c commit 966052f
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions auto_fp8/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,9 +298,8 @@ def quantize_activations(
cleanup_memory()

# Post-process step for kv cache scales to take the k/v module
# `output_scale` parameters, take the max of them, and store them in
# the parent attention module as `kv_scale`
# NOTE: if we want to switch to the `output_scale` representation, we can simply remove this block
# `output_scale` parameters, and store them in the parent attention
# module as `key_scale` and `value_scale`
if hasattr(quantize_config, "kv_cache_quant_layers"):
# Assumes that list is ordered such that [layer0.k_proj, layer0.v_proj, layer1.k_proj, layer1.v_proj, ...]
# so we make a list of tuples [(layer0.k_proj, layer0.v_proj), (layer1.k_proj, layer1.v_proj), ...]
Expand All @@ -313,8 +312,8 @@ def quantize_activations(
k_proj = dict(model.named_modules())[k_proj_name]
v_proj = dict(model.named_modules())[v_proj_name]

kv_scale = max(k_proj.output_scale, v_proj.output_scale)
parent_module.kv_scale = torch.nn.Parameter(kv_scale, requires_grad=False)
parent_module.key_scale = torch.nn.Parameter(k_proj.output_scale, requires_grad=False)
parent_module.value_scale = torch.nn.Parameter(v_proj.output_scale, requires_grad=False)

# Remove output_scale from k_proj and v_proj
k_proj.output_scale = None
Expand Down

0 comments on commit 966052f

Please sign in to comment.