diff --git a/auto_fp8/quantize.py b/auto_fp8/quantize.py index 38a4de6..d327d3a 100644 --- a/auto_fp8/quantize.py +++ b/auto_fp8/quantize.py @@ -298,9 +298,8 @@ def quantize_activations( cleanup_memory() # Post-process step for kv cache scales to take the k/v module - # `output_scale` parameters, take the max of them, and store them in - # the parent attention module as `kv_scale` - # NOTE: if we want to switch to the `output_scale` representation, we can simply remove this block + # `output_scale` parameters, and store them in the parent attention + # module as `k_scale` and `v_scale` if hasattr(quantize_config, "kv_cache_quant_layers"): # Assumes that list is ordered such that [layer0.k_proj, layer0.v_proj, layer1.k_proj, layer1.v_proj, ...] # so we make a list of tuples [(layer0.k_proj, layer0.v_proj), (layer1.k_proj, layer1.v_proj), ...] @@ -313,8 +312,8 @@ def quantize_activations( k_proj = dict(model.named_modules())[k_proj_name] v_proj = dict(model.named_modules())[v_proj_name] - kv_scale = max(k_proj.output_scale, v_proj.output_scale) - parent_module.kv_scale = torch.nn.Parameter(kv_scale, requires_grad=False) + parent_module.k_scale = torch.nn.Parameter(k_proj.output_scale, requires_grad=False) + parent_module.v_scale = torch.nn.Parameter(v_proj.output_scale, requires_grad=False) # Remove output_scale from k_proj and v_proj k_proj.output_scale = None diff --git a/tests/test_auto_fp8.py b/tests/test_auto_fp8.py index 6045d84..7d79ff1 100644 --- a/tests/test_auto_fp8.py +++ b/tests/test_auto_fp8.py @@ -80,14 +80,21 @@ def test_kv_cache_static_quantization(model_id, target_size): model.save_quantized(quantized_model_dir) tensors = safetensors.torch.load_file(f"{quantized_model_dir}/model.safetensors") - proj_linear_count = 0 - kv_scale_count = 0 + k_proj_count = 0 + v_proj_count = 0 + k_scale_count = 0 + v_scale_count = 0 for name, _ in tensors.items(): - if name.endswith("k_proj.weight") or name.endswith("v_proj.weight"): - proj_linear_count += 1 - if name.endswith("kv_scale"): - kv_scale_count += 1 - assert proj_linear_count // 2 == kv_scale_count + if name.endswith(".k_proj.weight"): + k_proj_count += 1 + if name.endswith(".v_proj.weight"): + v_proj_count += 1 + if name.endswith(".k_scale"): + k_scale_count += 1 + if name.endswith(".v_scale"): + v_scale_count += 1 + assert k_proj_count == k_scale_count + assert v_proj_count == v_scale_count # Measure checkpoint size and cleanup model_size = os.path.getsize(f"{quantized_model_dir}/model.safetensors")