diff --git a/auto_fp8/quantize.py b/auto_fp8/quantize.py
index 38a4de6..7adeb16 100644
--- a/auto_fp8/quantize.py
+++ b/auto_fp8/quantize.py
@@ -298,9 +298,8 @@ def quantize_activations(
     cleanup_memory()
 
     # Post-process step for kv cache scales to take the k/v module
-    # `output_scale` parameters, take the max of them, and store them in
-    # the parent attention module as `kv_scale`
-    # NOTE: if we want to switch to the `output_scale` representation, we can simply remove this block
+    # `output_scale` parameters, and store them in the parent attention
+    # module as `key_scale` and `value_scale`
     if hasattr(quantize_config, "kv_cache_quant_layers"):
         # Assumes that list is ordered such that [layer0.k_proj, layer0.v_proj, layer1.k_proj, layer1.v_proj, ...]
         # so we make a list of tuples [(layer0.k_proj, layer0.v_proj), (layer1.k_proj, layer1.v_proj), ...]
@@ -313,8 +312,8 @@ def quantize_activations(
             k_proj = dict(model.named_modules())[k_proj_name]
             v_proj = dict(model.named_modules())[v_proj_name]
 
-            kv_scale = max(k_proj.output_scale, v_proj.output_scale)
-            parent_module.kv_scale = torch.nn.Parameter(kv_scale, requires_grad=False)
+            parent_module.key_scale = torch.nn.Parameter(k_proj.output_scale, requires_grad=False)
+            parent_module.value_scale = torch.nn.Parameter(v_proj.output_scale, requires_grad=False)
 
             # Remove output_scale from k_proj and v_proj
             k_proj.output_scale = None
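To make the pairing and re-attachment logic these hunks touch easier to follow, here is a minimal, self-contained sketch of the post-process step after this change. `FakeProj`, `FakeAttention`, the layer names, and the scale values are hypothetical stand-ins; in the real code the modules come from the calibrated model and the flat name list comes from `quantize_config.kv_cache_quant_layers`. The key point is that each k/v projection's `output_scale` is now stored separately on the parent attention module as `key_scale` and `value_scale`, instead of being max-reduced into a single `kv_scale`.

```python
import torch
import torch.nn as nn


class FakeProj(nn.Module):
    """Hypothetical stand-in for a k_proj/v_proj layer after calibration."""

    def __init__(self, scale: float):
        super().__init__()
        # Per-tensor activation scale recorded during calibration
        self.output_scale = torch.tensor(scale)


class FakeAttention(nn.Module):
    """Hypothetical stand-in for the parent attention module."""

    def __init__(self, k_scale: float, v_scale: float):
        super().__init__()
        self.k_proj = FakeProj(k_scale)
        self.v_proj = FakeProj(v_scale)


model = nn.ModuleDict({
    "layer0": FakeAttention(0.02, 0.05),
    "layer1": FakeAttention(0.03, 0.01),
})

# Flat, ordered list as described in the diff's comment:
# [layer0.k_proj, layer0.v_proj, layer1.k_proj, layer1.v_proj, ...]
kv_cache_quant_layers = ["layer0.k_proj", "layer0.v_proj",
                         "layer1.k_proj", "layer1.v_proj"]

# Pair consecutive entries into (k_proj_name, v_proj_name) tuples
kv_proj_pairs = zip(*[iter(kv_cache_quant_layers)] * 2)

named = dict(model.named_modules())
for k_proj_name, v_proj_name in kv_proj_pairs:
    parent_name = k_proj_name.rsplit(".", 1)[0]
    assert parent_name == v_proj_name.rsplit(".", 1)[0]
    parent_module = named[parent_name]
    k_proj, v_proj = named[k_proj_name], named[v_proj_name]

    # New behavior: keep the two scales separate instead of max-reducing
    # them into a single kv_scale
    parent_module.key_scale = nn.Parameter(k_proj.output_scale, requires_grad=False)
    parent_module.value_scale = nn.Parameter(v_proj.output_scale, requires_grad=False)

    # The projection modules no longer carry output_scale
    k_proj.output_scale = None
    v_proj.output_scale = None

print(model["layer0"].key_scale)    # Parameter containing: tensor(0.0200)
print(model["layer0"].value_scale)  # Parameter containing: tensor(0.0500)
```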