diff --git a/auto_fp8/__init__.py b/auto_fp8/__init__.py
index ea4fbb6..d463cc8 100644
--- a/auto_fp8/__init__.py
+++ b/auto_fp8/__init__.py
@@ -1,5 +1,4 @@
-from .config import BaseQuantizeConfig
-from .modeling import AutoFP8ForCausalLM
+from .modeling import AutoFP8ForCausalLM, BaseQuantizeConfig
 
 __all__ = [
     "AutoFP8ForCausalLM",
diff --git a/auto_fp8/config.py b/auto_fp8/config.py
deleted file mode 100644
index 24c6200..0000000
--- a/auto_fp8/config.py
+++ /dev/null
@@ -1,42 +0,0 @@
-from typing import List, Optional, Tuple
-
-
-class BaseQuantizeConfig:
-    """Configuration for model quantization.
-
-    Args:
-        quant_method: Type/precision of quantization method to use.
-            At the moment, this is just "fp8" which specifically means
-            the fp8_e4m3 format in pytorch.
-        activation_scheme: Choice of either "dynamic" or "static" quantization
-            of activtions. If "static", then calibration samples are required
-            during quantization to produce accurate per-tensor scales for
-            activations of Linear modules.
-        ignore_patterns: List of patterns used to ignore layers. If a string
-            starts with "re:", then everything afterwards is used as python
-            regex style matching i.e. re.search(), for each Linear layer.
-            By default, "re:.*lm_head" is included to ignore the embedding
-            Linear layer usually at the end of decoder LLMs
-        kv_cache_quant_targets: Tuple of Linear module names to target for
-            calibration of the output scales for KV cache quantization.
-            Usually, these should be `("k_proj", "v_proj")`.
-    """
-
-    def __init__(
-        self,
-        quant_method: str = "fp8",
-        activation_scheme: str = "static",
-        ignore_patterns: List[str] = ["re:.*lm_head"],
-        kv_cache_quant_targets: Optional[Tuple[str]] = None,
-    ):
-        if quant_method != "fp8":
-            raise ValueError("Only FP8 quantization is supported.")
-        if activation_scheme not in ["static", "dynamic"]:
-            raise ValueError(
-                "Invalid activation_scheme. Choose either 'static' or 'dynamic'."
-            )
-        self.quant_method = quant_method
-        self.activation_scheme = activation_scheme
-        self.ignore_patterns = ignore_patterns
-        self.kv_cache_quant_targets = kv_cache_quant_targets
-        self.ignored_layers = []
diff --git a/auto_fp8/modeling.py b/auto_fp8/modeling.py
index 04a9e71..eb4d2ba 100644
--- a/auto_fp8/modeling.py
+++ b/auto_fp8/modeling.py
@@ -1,42 +1,49 @@
-import re
-from typing import List, Optional, Tuple
+import os
+from typing import List, Optional
 
+from transformers import AutoConfig, AutoTokenizer
+from datasets import Dataset
+from llmcompressor.transformers import SparseAutoModelForCausalLM
+from llmcompressor.transformers import oneshot
+from llmcompressor.modifiers.quantization import QuantizationModifier
+
+
+class BaseQuantizeConfig:
+    """Configuration for model quantization.
+
+    Args:
+        quant_method: Type/precision of quantization method to use.
+            At the moment, this is just "fp8" which specifically means
+            the fp8_e4m3 format in pytorch.
+        activation_scheme: Choice of either "dynamic" or "static" quantization
+            of activations. If "static", then calibration samples are required
+            during quantization to produce accurate per-tensor scales for
+            activations of Linear modules.
+        ignore_patterns: List of patterns used to ignore layers. If a string
+            starts with "re:", then everything afterwards is used as python
+            regex style matching i.e. re.search(), for each Linear layer.
+            By default, "lm_head" is included to ignore the embedding
+            Linear layer usually at the end of decoder LLMs
+    """
 
-import torch
-from transformers import AutoModelForCausalLM
-
-from auto_fp8.config import BaseQuantizeConfig
-from auto_fp8.quantize import (
-    quantize_activations,
-    quantize_weights,
-    save_quantized_model,
-)
+    def __init__(
+        self,
+        quant_method: str = "fp8",
+        activation_scheme: str = "static",
+        ignore_patterns: List[str] = ["lm_head"],
+    ):
+        self.quant_method = quant_method
+        self.activation_scheme = activation_scheme
+        self.ignore_patterns = ignore_patterns
 
 
 class AutoFP8ForCausalLM:
     def __init__(
-        self,
-        model: AutoModelForCausalLM,
-        quantize_config: BaseQuantizeConfig,
+        self, model: SparseAutoModelForCausalLM, quantize_config: BaseQuantizeConfig
     ):
         self.model = model
         self.model_type = self.model.config.model_type
         self.config = self.model.config
-
-        # Gather the Linear module names that we want to ignore
-        quantize_config.ignored_layers = get_layers_to_ignore(
-            self.model, quantize_config.ignore_patterns
-        )
-
-        if quantize_config.kv_cache_quant_targets:
-            kv_cache_quant_layers = get_kv_cache_quant_layers(
-                self.model, quantize_config.kv_cache_quant_targets
-            )
-            if len(kv_cache_quant_layers) == 0:
-                raise ValueError(
-                    f"Could not find any kv cache layers using kv_cache_quant_targets={quantize_config.kv_cache_quant_targets}, please fix your argument."
-                )
-            quantize_config.kv_cache_quant_layers = kv_cache_quant_layers
-
         self.quantize_config = quantize_config
 
     @classmethod
@@ -44,130 +51,41 @@ def from_pretrained(
         cls,
         pretrained_model_name_or_path: str,
         quantize_config: BaseQuantizeConfig,
-        **model_init_kwargs,
+        **kwargs,
     ):
-        """Load the un-quantized pretrained model"""
-
-        def skip(*args, **kwargs):
-            pass
-
-        torch.nn.init.kaiming_uniform_ = skip
-        torch.nn.init.uniform_ = skip
-        torch.nn.init.normal_ = skip
-
-        # Parameters related to loading from Hugging Face Hub
-        cache_dir = model_init_kwargs.pop("cache_dir", None)
-        force_download = model_init_kwargs.pop("force_download", False)
-        resume_download = model_init_kwargs.pop("resume_download", False)
-        proxies = model_init_kwargs.pop("proxies", None)
-        local_files_only = model_init_kwargs.pop("local_files_only", False)
-        use_auth_token = model_init_kwargs.pop("use_auth_token", None)
-        revision = model_init_kwargs.pop("revision", None)
-        subfolder = model_init_kwargs.pop("subfolder", "")
-        commit_hash = model_init_kwargs.pop("_commit_hash", None)
-
-        cached_file_kwargs = {
-            "cache_dir": cache_dir,
-            "force_download": force_download,
-            "proxies": proxies,
-            "resume_download": resume_download,
-            "local_files_only": local_files_only,
-            "use_auth_token": use_auth_token,
-            "revision": revision,
-            "subfolder": subfolder,
-            "_commit_hash": commit_hash,
-        }
-
-        torch.cuda.empty_cache()
-
-        # Important defaults
-        if "torch_dtype" not in model_init_kwargs:
-            model_init_kwargs["torch_dtype"] = "auto"
-
-        if "device_map" not in model_init_kwargs:
-            model_init_kwargs["device_map"] = "auto"
-
-        merged_kwargs = {**model_init_kwargs, **cached_file_kwargs}
-        print("Loading model with the following kwargs:", merged_kwargs)
-        model = AutoModelForCausalLM.from_pretrained(
-            pretrained_model_name_or_path, **merged_kwargs
+        config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
+        model = SparseAutoModelForCausalLM.from_pretrained(
+            pretrained_model_name_or_path,
+            config=config,
+            device_map="auto",
+            torch_dtype="auto",
+            **kwargs,
         )
-
-        model_config = model.config.to_dict()
-        seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
-        if any(k in model_config for k in seq_len_keys):
-            for key in seq_len_keys:
-                if key in model_config:
-                    model.seqlen = model_config[key]
-                    break
-        else:
-            print("Can't get model's sequence length, setting to 2048.")
-            model.seqlen = 2048
-        model.eval()
-
         return cls(model, quantize_config)
 
-    def quantize(self, calibration_tokens: Optional[torch.Tensor] = None):
-
-        # Always quantize the weights as they do not require calibration data
-        quantize_weights(self.model, self.quantize_config)
-
-        if self.quantize_config.activation_scheme == "static":
-            assert (
-                calibration_tokens is not None
-            ), "Calibration tokens required for activation quantization"
-
-
-            def _prepare_calibration_data(calibration_tokens):
-                if hasattr(calibration_tokens, "input_ids"):
-                    return calibration_tokens.input_ids
-                return calibration_tokens
-
-            quantize_activations(
-                self.model,
-                self.quantize_config,
-                _prepare_calibration_data(calibration_tokens),
-            )
+    def quantize(self, dataset: Optional[Dataset] = None):
+        assert (
+            self.quantize_config.activation_scheme == "static"
+        ), "Dynamic isn't supported yet"
+        assert (
+            dataset is not None
+        ), "Calibration tokens required for static activation quantization"
 
-    def save_quantized(self, save_dir):
-        save_quantized_model(
-            self.model,
-            quant_config=self.quantize_config,
-            save_dir=save_dir,
+        recipe = QuantizationModifier(
+            targets="Linear", scheme="FP8", ignore=self.quantize_config.ignore_patterns
         )
 
+        oneshot(
+            model=self.model,
+            dataset=dataset,
+            recipe=recipe,
+        )
 
-def get_layers_to_ignore(model, ignore_patterns) -> List[str]:
-    ignored_layers = set()
-
-    for name, linear in model.named_modules():
-        if not isinstance(linear, torch.nn.Linear):
-            continue
-
-        for ignore_pattern in ignore_patterns:
-            regex_prefix = "re:"
-            if ignore_pattern.startswith(regex_prefix):
-                # check if name matches regex and add to set if true
-                regex_pattern = ignore_pattern[len(regex_prefix) :]
-                if re.search(regex_pattern, name):
-                    ignored_layers.add(name)
-            else:
-                # else, exact match
-                if ignore_pattern == name:
-                    ignored_layers.add(name)
-
-    return list(ignored_layers)
-
-
-def get_kv_cache_quant_layers(model, kv_cache_quant_targets: Tuple[str]) -> List[str]:
-    kv_cache_quant_layers = []
-
-    for name, linear in model.named_modules():
-        if not isinstance(linear, torch.nn.Linear):
-            continue
-
-        for output_quant_target in kv_cache_quant_targets:
-            if name.endswith(output_quant_target):
-                kv_cache_quant_layers.append(name)
+    def save_quantized(self, save_directory: str):
+        self.save_pretrained(save_directory, save_compressed=True)
 
-    return kv_cache_quant_layers
+    def save_pretrained(self, save_directory: str, save_compressed: bool = True):
+        self.model.save_pretrained(save_directory, save_compressed=save_compressed)
+        tokenizer = AutoTokenizer.from_pretrained(self.model.config._name_or_path)
+        tokenizer.save_pretrained(save_directory)
+        print(f"Saved final checkpoint to {os.path.abspath(save_directory)}")
\ No newline at end of file
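For reviewers who want to see what the new quantize() path boils down to: it just builds a QuantizationModifier recipe and hands the model plus a tokenized calibration dataset to llmcompressor's oneshot(). Below is a minimal sketch of the equivalent direct calls, assuming llmcompressor and datasets are installed; the model name, the one-sample calibration set, and the output directory are placeholders for illustration only.

from datasets import Dataset
from transformers import AutoTokenizer
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot

# Load the unquantized model the same way AutoFP8ForCausalLM.from_pretrained now does.
model = SparseAutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m", device_map="auto", torch_dtype="auto"
)

# Tiny stand-in calibration set; a real run would use hundreds of samples.
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
calibration_ds = Dataset.from_dict(
    dict(tokenizer(["auto-fp8 is an easy-to-use model quantization library"]))
)

# Same recipe the wrapper builds: FP8 for every Linear module, skipping lm_head.
recipe = QuantizationModifier(targets="Linear", scheme="FP8", ignore=["lm_head"])
oneshot(model=model, dataset=calibration_ds, recipe=recipe)

model.save_pretrained("opt-125m-FP8", save_compressed=True)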
["max_position_embeddings", "seq_length", "n_positions"] - if any(k in model_config for k in seq_len_keys): - for key in seq_len_keys: - if key in model_config: - model.seqlen = model_config[key] - break - else: - print("Can't get model's sequence length, setting to 2048.") - model.seqlen = 2048 - model.eval() - return cls(model, quantize_config) - def quantize(self, calibration_tokens: Optional[torch.Tensor] = None): - - # Always quantize the weights as they do not require calibration data - quantize_weights(self.model, self.quantize_config) - - if self.quantize_config.activation_scheme == "static": - assert ( - calibration_tokens is not None - ), "Calibration tokens required for activation quantization" - - - def _prepare_calibration_data(calibration_tokens): - if hasattr(calibration_tokens, "input_ids"): - return calibration_tokens.input_ids - return calibration_tokens - - quantize_activations( - self.model, - self.quantize_config, - _prepare_calibration_data(calibration_tokens), - ) + def quantize(self, dataset: Optional[Dataset] = None): + assert ( + self.quantize_config.activation_scheme == "static" + ), "Dynamic isn't supported yet" + assert ( + dataset is not None + ), "Calibration tokens required for static activation quantization" - def save_quantized(self, save_dir): - save_quantized_model( - self.model, - quant_config=self.quantize_config, - save_dir=save_dir, + recipe = QuantizationModifier( + targets="Linear", scheme="FP8", ignore=self.quantize_config.ignore_patterns ) + oneshot( + model=self.model, + dataset=dataset, + recipe=recipe, + ) -def get_layers_to_ignore(model, ignore_patterns) -> List[str]: - ignored_layers = set() - - for name, linear in model.named_modules(): - if not isinstance(linear, torch.nn.Linear): - continue - - for ignore_pattern in ignore_patterns: - regex_prefix = "re:" - if ignore_pattern.startswith(regex_prefix): - # check if name matches regex and add to set if true - regex_pattern = ignore_pattern[len(regex_prefix) :] - if re.search(regex_pattern, name): - ignored_layers.add(name) - else: - # else, exact match - if ignore_pattern == name: - ignored_layers.add(name) - - return list(ignored_layers) - - -def get_kv_cache_quant_layers(model, kv_cache_quant_targets: Tuple[str]) -> List[str]: - kv_cache_quant_layers = [] - - for name, linear in model.named_modules(): - if not isinstance(linear, torch.nn.Linear): - continue - - for output_quant_target in kv_cache_quant_targets: - if name.endswith(output_quant_target): - kv_cache_quant_layers.append(name) + def save_quantized(self, save_directory: str): + self.save_pretrained(save_directory, save_compressed=True) - return kv_cache_quant_layers + def save_pretrained(self, save_directory: str, save_compressed: bool = True): + self.model.save_pretrained(save_directory, save_compressed=save_compressed) + tokenizer = AutoTokenizer.from_pretrained(self.model.config._name_or_path) + tokenizer.save_pretrained(save_directory) + print(f"Saved final checkpoint to {os.path.abspath(save_directory)}") \ No newline at end of file diff --git a/auto_fp8/quantize.py b/auto_fp8/quantize.py index 38a4de6..0237bc2 100644 --- a/auto_fp8/quantize.py +++ b/auto_fp8/quantize.py @@ -72,11 +72,25 @@ def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype): # Deal with empty tensors (triggeted by empty MoE experts) return torch.empty(size=(0, B.shape[0]), dtype=out_dtype, device=A.device) +<<<<<<< HEAD +<<<<<<< HEAD +======= +>>>>>>> 959bdbc (Add comment) # TODO: Disable native fp8 gemm for now, always just dequantize # 
native_fp8_support = ( # torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) # ) native_fp8_support = False +<<<<<<< HEAD +======= + native_fp8_support = ( + torch.cuda.is_available() + and torch.cuda.get_device_capability() >= (8, 9) + and False + ) +>>>>>>> 3ee9283 (Support calibrating kv cache scales) +======= +>>>>>>> 959bdbc (Add comment) if native_fp8_support: need_reshape = A.dim() == 3 if need_reshape: @@ -108,6 +122,7 @@ def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype): # Class responsible for quantizing weights class FP8DynamicLinear(torch.nn.Module): +<<<<<<< HEAD def __init__( self, weight: torch.Tensor, @@ -125,10 +140,112 @@ def forward(self, x): A=qinput, A_scale=x_scale, B=self.weight, +======= + def __init__( + self, + weight: torch.Tensor, + weight_scale: torch.Tensor, + bias: torch.nn.Parameter, + ): + super().__init__() + self.weight = torch.nn.Parameter(weight, requires_grad=False) + self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + self.bias = bias + + def forward(self, x): + qinput, x_scale = per_tensor_quantize(x) + output = fp8_gemm( + A=qinput, + A_scale=x_scale, + B=self.weight, + B_scale=self.weight_scale, + bias=self.bias, + out_dtype=x.dtype, + ) + return output + + +# Module responsible for taking already quantized weights, and recording input scales (and possibly output scales) using an activation observer +class FP8StaticLinearQuantizer(torch.nn.Module): + def __init__( + self, + weight: torch.Tensor, + weight_scale: torch.Tensor, + bias: torch.nn.Parameter, + quantize_output: bool = False, + ): + super().__init__() + self.weight = torch.nn.Parameter(weight, requires_grad=False) + self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + self.bias = bias + self.input_scale = None + self.output_scale = None + self.quantize_output = quantize_output + + def forward(self, x): + qinput, x_input_scale = per_tensor_quantize(x) + if self.input_scale is None: + self.input_scale = torch.nn.Parameter(x_input_scale, requires_grad=False) + elif x_input_scale > self.input_scale: + self.input_scale = torch.nn.Parameter(x_input_scale, requires_grad=False) + output = fp8_gemm( + A=qinput, + A_scale=self.input_scale, + B=self.weight, + B_scale=self.weight_scale, + bias=self.bias, + out_dtype=x.dtype, + ) + + # Optionally, quantize output and record scale + if self.quantize_output: + qoutput, output_scale = per_tensor_quantize(output) + if self.output_scale is None: + self.output_scale = torch.nn.Parameter(output_scale, requires_grad=False) + elif output_scale > self.output_scale: + self.output_scale = torch.nn.Parameter(output_scale, requires_grad=False) + output = qoutput.to(output.dtype) * output_scale + + return output + + +# Module responsible for representing the final checkpoint representation +class FP8StaticLinear(torch.nn.Module): + def __init__( + self, + weight: torch.nn.Parameter, + weight_scale: torch.nn.Parameter, + bias: torch.nn.Parameter, + input_scale: torch.nn.Parameter, + output_scale: Optional[torch.nn.Parameter] = None, + ): + super().__init__() + self.weight = weight + self.weight_scale = weight_scale + self.bias = bias + self.input_scale = input_scale + self.output_scale = output_scale + + def forward(self, x): + qinput = static_per_tensor_quantize(x, self.input_scale) + output = fp8_gemm( + A=qinput, + A_scale=self.input_scale, +<<<<<<< HEAD + B=self.qweight, +>>>>>>> 3ee9283 (Support calibrating kv cache scales) B_scale=self.weight_scale, bias=self.bias, out_dtype=x.dtype, ) 
+<<<<<<< HEAD +======= + + if self.output_scale: + qoutput = static_per_tensor_quantize(output, self.output_scale) + output = qoutput.to(output.dtype) * self.output_scale + +>>>>>>> 3ee9283 (Support calibrating kv cache scales) return output @@ -198,6 +315,8 @@ def forward(self, x): output = fp8_gemm( A=qinput, A_scale=self.input_scale, +======= +>>>>>>> def2049 (Fix weight name) B=self.weight, B_scale=self.weight_scale, bias=self.bias, @@ -237,7 +356,15 @@ def quantize_weights( quant_weight, weight_scale = per_tensor_quantize(linear.weight) bias = copy.deepcopy(linear.bias) if linear.bias is not None else None quant_linear = FP8DynamicLinear( +<<<<<<< HEAD +<<<<<<< HEAD weight=quant_weight, weight_scale=weight_scale, bias=bias +======= + qweight=quant_weight, weight_scale=weight_scale, bias=bias +>>>>>>> 3ee9283 (Support calibrating kv cache scales) +======= + weight=quant_weight, weight_scale=weight_scale, bias=bias +>>>>>>> def2049 (Fix weight name) ) replace_module(model, name, quant_linear) del linear.weight @@ -259,7 +386,15 @@ def quantize_activations( ): continue quantizer = FP8StaticLinearQuantizer( +<<<<<<< HEAD +<<<<<<< HEAD + weight=dynamic_quant_linear.weight, +======= + qweight=dynamic_quant_linear.qweight, +>>>>>>> 3ee9283 (Support calibrating kv cache scales) +======= weight=dynamic_quant_linear.weight, +>>>>>>> def2049 (Fix weight name) weight_scale=dynamic_quant_linear.weight_scale, bias=dynamic_quant_linear.bias, quantize_output=( @@ -272,12 +407,36 @@ def quantize_activations( cleanup_memory() # Pass through calibration data to measure activation scales +<<<<<<< HEAD +<<<<<<< HEAD with torch.inference_mode(): with tqdm.tqdm(total=calibration_tokens.shape[0], desc="Calibrating activation scales") as pbar: for row_idx in range(calibration_tokens.shape[0]): model(calibration_tokens[row_idx].reshape(1, -1)) cleanup_memory() pbar.update(1) +======= +======= +>>>>>>> 57c31bb (Use `torch.inference_mode()` for lower memory usage during calibration (#20)) + with tqdm.tqdm( + total=calibration_tokens.shape[0], desc="Calibrating activation scales" + ) as pbar: + for row_idx in range(calibration_tokens.shape[0]): + model(calibration_tokens[row_idx].reshape(1, -1)) + cleanup_memory() + pbar.update(1) +<<<<<<< HEAD +>>>>>>> 3ee9283 (Support calibrating kv cache scales) +======= +======= + with torch.inference_mode(): + with tqdm.tqdm(total=calibration_tokens.shape[0], desc="Calibrating activation scales") as pbar: + for row_idx in range(calibration_tokens.shape[0]): + model(calibration_tokens[row_idx].reshape(1, -1)) + cleanup_memory() + pbar.update(1) +>>>>>>> b1c6ad6 (Use `torch.inference_mode()` for lower memory usage during calibration (#20)) +>>>>>>> 57c31bb (Use `torch.inference_mode()` for lower memory usage during calibration (#20)) # Replace dynamic quantizer observer with StaticLinear for export for name, quantizer in model.named_modules(): @@ -287,7 +446,15 @@ def quantize_activations( ): continue static_proj = FP8StaticLinear( +<<<<<<< HEAD +<<<<<<< HEAD + weight=quantizer.weight, +======= + qweight=quantizer.qweight, +>>>>>>> 3ee9283 (Support calibrating kv cache scales) +======= weight=quantizer.weight, +>>>>>>> def2049 (Fix weight name) weight_scale=quantizer.weight_scale, bias=quantizer.bias, input_scale=quantizer.input_scale, diff --git a/example_dataset.py b/example_dataset.py index 204345f..82d336e 100644 --- a/example_dataset.py +++ b/example_dataset.py @@ -3,20 +3,27 @@ from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig -pretrained_model_dir = 
"meta-llama/Meta-Llama-3-8B-Instruct" -quantized_model_dir = "Meta-Llama-3-8B-Instruct-FP8" +pretrained_model_dir = "facebook/opt-125m" +quantized_model_dir = "opt-125m-FP8" tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True) tokenizer.pad_token = tokenizer.eos_token -ds = load_dataset("mgoin/ultrachat_2k", split="train_sft") -examples = [tokenizer.apply_chat_template(batch["messages"], tokenize=False) for batch in ds] -examples = tokenizer(examples, padding=True, truncation=True, return_tensors="pt").to("cuda") +MAX_SEQUENCE_LENGTH = 2048 +ds = load_dataset("mgoin/ultrachat_2k", split="train_sft").select(range(512)) +def preprocess(example): + example = tokenizer.apply_chat_template(example["messages"], tokenize=False) + return tokenizer( + example, + padding=False, + max_length=MAX_SEQUENCE_LENGTH, + truncation=True, + add_special_tokens=False, + ) +ds = ds.map(preprocess, remove_columns=ds.column_names) quantize_config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="static") -model = AutoFP8ForCausalLM.from_pretrained( - pretrained_model_dir, quantize_config=quantize_config -) -model.quantize(examples) +model = AutoFP8ForCausalLM.from_pretrained(pretrained_model_dir, quantize_config) +model.quantize(ds) model.save_quantized(quantized_model_dir) diff --git a/tests/test_auto_fp8.py b/tests/test_auto_fp8.py index 6045d84..bb852d9 100644 --- a/tests/test_auto_fp8.py +++ b/tests/test_auto_fp8.py @@ -1,20 +1,52 @@ import os import shutil +<<<<<<< HEAD +<<<<<<< HEAD import pytest +======= +>>>>>>> 3ee9283 (Support calibrating kv cache scales) +======= +import pytest +>>>>>>> 2739d61 (Add Qwen test) import safetensors.torch from transformers import AutoTokenizer from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig MODELS = [ +<<<<<<< HEAD +<<<<<<< HEAD + ("facebook/opt-125m", 160), + ("Qwen/Qwen2-0.5B-Instruct", 620), +] + +<<<<<<< HEAD +@pytest.mark.parametrize("model_id,target_size", MODELS) +def test_dynamic_quantization(model_id, target_size): + quantized_model_dir = model_id.split("/")[-1] + "-fp8-dynamic" +======= +def test_dynamic_quantization(): + model_id = "facebook/opt-125m" + quantized_model_dir = "opt-125m-fp8-dynamic" +>>>>>>> 3ee9283 (Support calibrating kv cache scales) +======= + "facebook/opt-125m", + "Qwen/Qwen2-0.5B-Instruct", +======= ("facebook/opt-125m", 160), +<<<<<<< HEAD + ("Qwen/Qwen2-0.5B-Instruct", 600), +>>>>>>> 415c0b7 (Add fixed target sizes) +======= ("Qwen/Qwen2-0.5B-Instruct", 620), +>>>>>>> 93c0d54 (Fix proj linear count) ] @pytest.mark.parametrize("model_id,target_size", MODELS) def test_dynamic_quantization(model_id, target_size): quantized_model_dir = model_id.split("/")[-1] + "-fp8-dynamic" +>>>>>>> 2739d61 (Add Qwen test) quantize_config = BaseQuantizeConfig( quant_method="fp8", activation_scheme="dynamic" @@ -30,6 +62,11 @@ def test_dynamic_quantization(model_id, target_size): model_size = os.path.getsize(f"{quantized_model_dir}/model.safetensors") shutil.rmtree(quantized_model_dir) +<<<<<<< HEAD +<<<<<<< HEAD +<<<<<<< HEAD +======= +>>>>>>> c3acdee (Switch from output_scale to kv_scale) # We expect the quantized model to be a certain size target_size = target_size * (1024 * 1024) assert model_size < target_size @@ -38,6 +75,31 @@ def test_dynamic_quantization(model_id, target_size): @pytest.mark.parametrize("model_id,target_size", MODELS) def test_static_quantization(model_id, target_size): quantized_model_dir = model_id.split("/")[-1] + "-fp8-static" +======= + # We expect the model to be < 160MB + 
target_size = 160 * (1024 * 1024) + assert model_size < target_size + + +<<<<<<< HEAD +def test_static_quantization(): + model_id = "facebook/opt-125m" + quantized_model_dir = "opt-125m-fp8-static" +>>>>>>> 3ee9283 (Support calibrating kv cache scales) +======= +@pytest.mark.parametrize("model_id", MODELS) +def test_static_quantization(model_id): +======= + # We expect the model to be a certain size + target_size = target_size * (1024 * 1024) + assert model_size < target_size + + +@pytest.mark.parametrize("model_id,target_size", MODELS) +def test_static_quantization(model_id, target_size): +>>>>>>> 415c0b7 (Add fixed target sizes) + quantized_model_dir = model_id.split("/")[-1] + "-fp8-static" +>>>>>>> 2739d61 (Add Qwen test) tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True) examples = ["auto-fp8 is an easy-to-use model quantization library"] @@ -55,7 +117,53 @@ def test_static_quantization(model_id, target_size): model_size = os.path.getsize(f"{quantized_model_dir}/model.safetensors") shutil.rmtree(quantized_model_dir) +<<<<<<< HEAD +<<<<<<< HEAD + # We expect the quantized model to be a certain size + target_size = target_size * (1024 * 1024) + assert model_size < target_size + +@pytest.mark.parametrize("model_id,target_size", MODELS) +def test_kv_cache_static_quantization(model_id, target_size): + quantized_model_dir = model_id.split("/")[-1] + "-fp8-static-kv" + + tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True) + examples = ["auto-fp8 is an easy-to-use model quantization library"] + examples = tokenizer(examples, return_tensors="pt") + + quantize_config = BaseQuantizeConfig( + quant_method="fp8", + activation_scheme="static", + kv_cache_quant_targets=("k_proj", "v_proj"), + ) + + model = AutoFP8ForCausalLM.from_pretrained(model_id, quantize_config) + model.model.to("cpu") + + model.quantize(examples) + model.save_quantized(quantized_model_dir) + + tensors = safetensors.torch.load_file(f"{quantized_model_dir}/model.safetensors") + proj_linear_count = 0 + kv_scale_count = 0 + for name, _ in tensors.items(): + if name.endswith("k_proj.weight") or name.endswith("v_proj.weight"): + proj_linear_count += 1 + if name.endswith("kv_scale"): + kv_scale_count += 1 + assert proj_linear_count // 2 == kv_scale_count + + # Measure checkpoint size and cleanup + model_size = os.path.getsize(f"{quantized_model_dir}/model.safetensors") + shutil.rmtree(quantized_model_dir) + + # We expect the quantized model to be a certain size +======= + # We expect the model to be < 160MB +>>>>>>> 415c0b7 (Add fixed target sizes) +======= # We expect the quantized model to be a certain size +>>>>>>> c3acdee (Switch from output_scale to kv_scale) target_size = target_size * (1024 * 1024) assert model_size < target_size
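Beyond the size assertions above, a quick way to eyeball what an FP8 export actually contains is to list the scale tensors in the saved safetensors file, mirroring the checks in test_kv_cache_static_quantization. This is an illustrative sketch only; the directory name is a placeholder and the scale suffixes are assumptions about the checkpoint layout.

import safetensors.torch

tensors = safetensors.torch.load_file("opt-125m-FP8/model.safetensors")
for name in sorted(tensors):
    # Scale names assumed from the FP8 checkpoint layout checked in the tests.
    if name.endswith(("weight_scale", "input_scale", "kv_scale")):
        print(name, tuple(tensors[name].shape))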