diff --git a/.github/workflows/test-check.yaml b/.github/workflows/test-check.yaml index 65fc6908ca3..6ea07f3ac5e 100644 --- a/.github/workflows/test-check.yaml +++ b/.github/workflows/test-check.yaml @@ -62,7 +62,7 @@ jobs: steps: - uses: actions/setup-python@v4 with: - python-version: '3.10' + python-version: '3.11' - uses: actions/checkout@v2 - uses: actions/checkout@v2 with: @@ -86,7 +86,7 @@ jobs: steps: - uses: actions/setup-python@v4 with: - python-version: '3.10' + python-version: '3.11' - uses: actions/checkout@v2 - uses: actions/checkout@v2 with: @@ -110,7 +110,7 @@ jobs: steps: - uses: actions/setup-python@v4 with: - python-version: '3.10' + python-version: '3.11' - uses: actions/checkout@v2 - uses: actions/checkout@v2 with: @@ -134,7 +134,7 @@ jobs: steps: - uses: actions/setup-python@v4 with: - python-version: '3.10' + python-version: '3.11' - uses: actions/checkout@v2 - uses: actions/checkout@v2 with: diff --git a/DEVELOPING.md b/DEVELOPING.md index 8a815c4b37c..a1d3105d1b6 100644 --- a/DEVELOPING.md +++ b/DEVELOPING.md @@ -16,7 +16,7 @@ limitations under the License. # Developing SparseML -SparseML is developed and tested using Python 3.8-3.10. +SparseML is developed and tested using Python 3.8-3.11. To develop SparseML, you will also need the development dependencies and to follow the styling guidelines. Here are some details to get started. diff --git a/README.md b/README.md index 9db2f9e1b16..f4098ce13a4 100644 --- a/README.md +++ b/README.md @@ -111,7 +111,7 @@ SparseML enables you to create a sparse model trained on your dataset in two way ## Installation -This repository is tested on Python 3.8-3.10, and Linux/Debian systems. +This repository is tested on Python 3.8-3.11, and Linux/Debian systems. It is recommended to install in a [virtual environment](https://docs.python.org/3/library/venv.html) to keep your system in order. Currently supported ML Frameworks are the following: `torch>=1.1.0,<=2.0`, `tensorflow>=1.8.0,<2.0.0`, `tensorflow.keras >= 2.2.0`. diff --git a/docs/source/installation.md b/docs/source/installation.md index a10a727f967..bc524f4827a 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -16,7 +16,7 @@ limitations under the License. # Installation -This repository is tested on Python 3.8-3.10, and Linux/Debian systems. +This repository is tested on Python 3.8-3.11, and Linux/Debian systems. It is recommended to install in a [virtual environment](https://docs.python.org/3/library/venv.html) to keep your system in order. Currently supported ML Frameworks are the following: `torch>=1.1.0,<1.14`, `tensorflow>=1.8.0,<=2.0.0`, `tensorflow.keras >= 2.2.0`. 
diff --git a/setup.py b/setup.py index 29c9a903dae..cae8c87358c 100644 --- a/setup.py +++ b/setup.py @@ -293,7 +293,7 @@ def _setup_long_description() -> Tuple[str, str]: install_requires=_setup_install_requires(), extras_require=_setup_extras(), entry_points=_setup_entry_points(), - python_requires=">=3.8.0,<3.11", + python_requires=">=3.8.0,<3.12", classifiers=[ "Development Status :: 5 - Production/Stable", "Programming Language :: Python :: 3", diff --git a/src/sparseml/core/lifecycle/session.py b/src/sparseml/core/lifecycle/session.py index 80f535b3c16..62b065ed603 100644 --- a/src/sparseml/core/lifecycle/session.py +++ b/src/sparseml/core/lifecycle/session.py @@ -31,7 +31,7 @@ @dataclass class SparsificationLifecycle: state: Optional[State] = None - recipe_container: RecipeContainer = RecipeContainer() + recipe_container: RecipeContainer = field(default_factory=RecipeContainer) modifiers: List[ModifierInterface] = field(default_factory=list) event_lifecycle: Optional[EventLifecycle] = None diff --git a/src/sparseml/experimental/__init__.py b/src/sparseml/experimental/__init__.py deleted file mode 100644 index 0c44f887a47..00000000000 --- a/src/sparseml/experimental/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/src/sparseml/experimental/sparsegpt/__init__.py b/src/sparseml/experimental/sparsegpt/__init__.py deleted file mode 100644 index 0c44f887a47..00000000000 --- a/src/sparseml/experimental/sparsegpt/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/src/sparseml/experimental/sparsegpt/dispatch.py b/src/sparseml/experimental/sparsegpt/dispatch.py index 4c1c80eeff2..78c98ecb036 100644 --- a/src/sparseml/experimental/sparsegpt/dispatch.py +++ b/src/sparseml/experimental/sparsegpt/dispatch.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-SUPPORTED_MODELS = ["opt", "mpt", "llama"] +SUPPORTED_MODELS = ["opt", "mpt", "llama-2"] def load_model(args, model_key: str = None, *gargs, **kwargs): @@ -21,7 +21,7 @@ def load_model(args, model_key: str = None, *gargs, **kwargs): from sparseml.experimental.sparsegpt.opt import load_model as _load_model elif model_key == "mpt": from sparseml.experimental.sparsegpt.mpt import load_model as _load_model - elif model_key == "llama": + elif model_key == "llama-2": from sparseml.experimental.sparsegpt.llama2 import load_model as _load_model else: raise ValueError(f"Unrecognized model key. Supported: {SUPPORTED_MODELS}") @@ -34,7 +34,7 @@ def load_data(args, model_key: str = None, *gargs, **kwargs): from sparseml.experimental.sparsegpt.opt import load_data as _load_data elif model_key == "mpt": from sparseml.experimental.sparsegpt.mpt import load_data as _load_data - elif model_key == "llama": + elif model_key == "llama-2": from sparseml.experimental.sparsegpt.llama2 import load_data as _load_data else: raise ValueError(f"Unrecognized model key. Supported: {SUPPORTED_MODELS}") @@ -47,7 +47,7 @@ def evaluate_perplexity( model_key = _get_model_key(args) if model_key is None else model_key if model_key == "opt": from sparseml.experimental.sparsegpt.opt import ppl_eval as _ppl_eval - elif model_key == "llama": + elif model_key == "llama-2": from sparseml.experimental.sparsegpt.llama2 import ppl_eval as _ppl_eval else: raise ValueError(f"Unrecognized model key. Supported: {SUPPORTED_MODELS}") @@ -64,7 +64,7 @@ def prepare_sparsegpt(model, dataloader, args, model_key: str = None, **kwargs): from sparseml.experimental.sparsegpt.mpt import ( prepare_sparsegpt as _prepare_sparsegpt, ) - elif model_key == "llama": + elif model_key == "llama-2": from sparseml.experimental.sparsegpt.llama2 import ( prepare_sparsegpt as _prepare_sparsegpt, ) diff --git a/src/sparseml/experimental/sparsegpt/examples/llama2/compare_obcq.py b/src/sparseml/experimental/sparsegpt/examples/llama2/compare_obcq.py index cb821781628..d2ca7439a1a 100644 --- a/src/sparseml/experimental/sparsegpt/examples/llama2/compare_obcq.py +++ b/src/sparseml/experimental/sparsegpt/examples/llama2/compare_obcq.py @@ -23,7 +23,7 @@ dataset = "open_platypus" model_name = "/home/sadkins/ml-experiments/nlg-text_generation/" -model_name += "llama_chat-llama_7b_chat-base/dense/training" +model_name += "llama_chat-llama_7b_chat-base/dense_llama-2/training" sparsity = 0.5 nbits = 8 smooth_quant = 0 @@ -73,7 +73,7 @@ class ProdArgs: def run_experimental_obcq(experimental_args): - model = load_model(experimental_args) + model, _ = load_model(experimental_args) calibration_data, _, _ = load_data(experimental_args, data_sequence_length) sequential(model, calibration_data, device, experimental_args) diff --git a/src/sparseml/experimental/sparsegpt/examples/opt/compare_obcq.py b/src/sparseml/experimental/sparsegpt/examples/opt/compare_obcq.py index 5f773435856..6cfea6d5fc7 100644 --- a/src/sparseml/experimental/sparsegpt/examples/opt/compare_obcq.py +++ b/src/sparseml/experimental/sparsegpt/examples/opt/compare_obcq.py @@ -72,7 +72,7 @@ class ProdArgs: def run_experimental_obcq(experimental_args): - model = load_model(experimental_args) + model, _ = load_model(experimental_args) calibration_data, _, _ = load_data(experimental_args, data_sequence_length) sequential(model, calibration_data, device, experimental_args) diff --git a/src/sparseml/experimental/sparsegpt/examples/opt/scripts/prune_quantize.opt.0.sh 
b/src/sparseml/experimental/sparsegpt/examples/opt/scripts/prune_quantize.opt.0.sh index e476e055d27..30c67e8c40e 100755 --- a/src/sparseml/experimental/sparsegpt/examples/opt/scripts/prune_quantize.opt.0.sh +++ b/src/sparseml/experimental/sparsegpt/examples/opt/scripts/prune_quantize.opt.0.sh @@ -2,11 +2,11 @@ export CUDA_VISIBLE_DEVICES=0 -ROOT=$HOME/sparseml/src/sparseml/experimental/sparsegpt +ROOT=$HOME/src/neuralmagic/sparseml/src/sparseml/experimental/sparsegpt DATASET=c4 -RECIPE_DIR=$ROOT/examples/opt/recipes +RECIPE_DIR=$ROOT/recipes RECIPE_NAME=opt-1.3b-opt_pretrain-pruned50_quantW8A8 SRC_MODEL_ORG=facebook diff --git a/src/sparseml/experimental/sparsegpt/layer_compressor.py b/src/sparseml/experimental/sparsegpt/layer_compressor.py index df70ed75b00..9340dc786b2 100644 --- a/src/sparseml/experimental/sparsegpt/layer_compressor.py +++ b/src/sparseml/experimental/sparsegpt/layer_compressor.py @@ -22,9 +22,6 @@ from sparseml.experimental.sparsegpt.sparsegpt import SparseGPT -DEFAULT_WBITS = 16 - - class BaseCompressor: def __init__(self, model): self.model = model diff --git a/src/sparseml/experimental/sparsegpt/llama.py b/src/sparseml/experimental/sparsegpt/llama.py deleted file mode 100644 index bcc8f18140d..00000000000 --- a/src/sparseml/experimental/sparsegpt/llama.py +++ /dev/null @@ -1,381 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import contextlib -import math -import warnings -from typing import Dict, Optional, Tuple - -import torch -import torch.nn as nn -import torch.nn.functional as F -from transformers.models.llama.modeling_llama import ( - LlamaAttention, - apply_rotary_pos_emb, - repeat_kv, -) - -from llmfoundry import ( - COMPOSER_MODEL_REGISTRY, - build_finetuning_dataloader, - build_text_denoising_dataloader, -) -from llmfoundry.data.text_data import build_text_dataloader -from llmfoundry.utils.builders import build_tokenizer -from model_preprocessor import QuantizationModelPreprocessor -from omegaconf import OmegaConf as om -from sparseml.experimental.sparsegpt.layer_compressor import ( - BaseCompressor, - LayerCompressor, -) -from sparseml.experimental.sparsegpt.quant import ( - MatMulLeftInput_PV, - MatMulLeftInput_QK, - MatMulOutput_PV, - MatMulOutput_QK, - MatMulRightInput_PV, - MatMulRightInput_QK, - QuantizableMatMul, -) -from sparseml.experimental.sparsegpt.sequential import SequentialSparseGPT - - -class SequentialSparseGPT_LLAMA(SequentialSparseGPT): - def compressible_layers(self): - return self.model.model.model.layers - - -class LLAMABottomCompressor(BaseCompressor): - def compress( - self, dataloader=None, nsamples: int = None, dev: str = "cuda:0", **kwargs - ): - args = kwargs["args"] - data_seq_len = args.data_sequence_length - - model = self.model - layers = self.model.model.transformer.blocks - - use_cache = model.config.use_cache - model.config.use_cache = False - layers = model.model.model.layers - - model.model.model.embed_tokens = model.model.model.embed_tokens.to(dev) - layers[0] = layers[0].to(dev) - - dtype = next(iter(model.parameters())).dtype - inps = torch.zeros( - (nsamples, data_seq_len, model.config.hidden_size), dtype=dtype, device=dev - ) - cache = {"i": 0, "attention_mask": None, "position_ids": None} - - class Catcher(nn.Module): - def __init__(self, module): - super().__init__() - self.module = module - - def forward(self, inp, **kwargs): - inps[cache["i"]] = inp - cache["i"] += 1 - cache["attention_mask"] = kwargs["attention_mask"] - cache["position_ids"] = kwargs["position_ids"] - raise ValueError - - layers[0] = Catcher(layers[0]) - i = 0 - for batch in dataloader: - try: - tmp = {k: v.to(dev) for k, v in batch.items()} - # cache_attn_mask.append(tmp["attention_mask"]) - model(tmp) - except ValueError: - pass - i += 1 - if i == nsamples: - break - layers[0] = layers[0].module - - layers[0] = layers[0].cpu() - model.model.model.embed_tokens = model.model.model.embed_tokens.cpu() - torch.cuda.empty_cache() - - outs = torch.zeros_like(inps) - attention_mask = cache["attention_mask"] - position_ids = cache["position_ids"] - extras = { - "use_cache": use_cache, - "outputs": outs, - "attention_mask": attention_mask, - "position_ids": position_ids, - } - self.model = model - return model, extras - - -class LLAMADecoderLayerCompressor(LayerCompressor): - ... - - -class LLAMAHeadCompressor(BaseCompressor): - ... 
- - -def llama2_get_attn_with_quantized_matmuls(attn_weights_matmul, attn_output_matmul): - def forward_with_quantized_bmms( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - if self.config.pretraining_tp > 1: - key_value_slicing = ( - self.num_key_value_heads * self.head_dim - ) // self.config.pretraining_tp - query_slices = self.q_proj.weight.split( - (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 - ) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [ - F.linear(hidden_states, query_slices[i]) - for i in range(self.config.pretraining_tp) - ] - query_states = torch.cat(query_states, dim=-1) - - key_states = [ - F.linear(hidden_states, key_slices[i]) - for i in range(self.config.pretraining_tp) - ] - key_states = torch.cat(key_states, dim=-1) - - value_states = [ - F.linear(hidden_states, value_slices[i]) - for i in range(self.config.pretraining_tp) - ] - value_states = torch.cat(value_states, dim=-1) - - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view( - bsz, q_len, self.num_heads, self.head_dim - ).transpose(1, 2) - key_states = key_states.view( - bsz, q_len, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - value_states = value_states.view( - bsz, q_len, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids - ) - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = attn_weights_matmul( - query_states, key_states.transpose(2, 3) - ) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size " - f"{(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size " - f"{(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32 - ).to(query_states.dtype) - attn_output = attn_output_matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size " - f"{(bsz, self.num_heads, q_len, 
self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - if self.config.pretraining_tp > 1: - attn_output = attn_output.split( - self.hidden_size // self.config.pretraining_tp, dim=2 - ) - o_proj_slices = self.o_proj.weight.split( - self.hidden_size // self.config.pretraining_tp, dim=1 - ) - attn_output = sum( - [ - F.linear(attn_output[i], o_proj_slices[i]) - for i in range(self.config.pretraining_tp) - ] - ) - else: - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - return forward_with_quantized_bmms - - -def MatMulQuantizationPreprocessor(ModelPreprocessor): - def __call__(self, dev: str = "cuda:0", **kwargs) -> Tuple[nn.Module, Dict]: - for name, mod in self.model.named_modules(): - if isinstance(mod, LlamaAttention): - print( - f"Overriding attention for {name} with quantization-aware matmuls" - ) - attn_weights_matmul = QuantizableMatMul( - MatMulLeftInput_QK, MatMulRightInput_QK, MatMulOutput_QK - ) - attn_output_matmul = QuantizableMatMul( - MatMulLeftInput_PV, MatMulRightInput_PV, MatMulOutput_PV - ) - mod.attn_weights_matmul = attn_weights_matmul - mod.attn_output_matmul = attn_output_matmul - - # we are overriding forward of an instance, and not a class - # we should change this to a class method when we own the - # model implementation - bound_method = llama2_get_attn_with_quantized_matmuls( - mod.attn_weights_matmul, mod.attn_output_matmul - ).__get__(mod, mod.__class__) - setattr(mod, "forward", bound_method) - return self.model, {} - - -def prepare_sparsegpt(model, dataloader, args, **kwargs) -> SequentialSparseGPT: - # TODO: Check with Eldar on additional preprocessing (e.g., weight untying) - model_preprocessors = [] - if args.recipe: - model_preprocessors.append(MatMulQuantizationPreprocessor(model)) - model_preprocessors.append( - QuantizationModelPreprocessor( - args.recipe, dataloader, args.observer_batches - ) - ) - bottom_compressor = LLAMABottomCompressor(model) - sequential_sparsegpt = SequentialSparseGPT_LLAMA( - model, - recipe=args.recipe, - model_preprocessors=model_preprocessors, - bottom_compressor=bottom_compressor, - ) - - return sequential_sparsegpt - - -def load_model(args): - cfg = _build_cfg(args) - tokenizer = build_tokenizer(cfg.tokenizer) - - print("Initializing model...") - init_context = contextlib.nullcontext() - cfg.model.init_device = "cpu" - with init_context: - model = build_composer_model(cfg.model, tokenizer) - return model, {"cfg": cfg, "tokenizer": tokenizer} - - -def load_data(args): - cfg = _build_cfg(args) - tokenizer = build_tokenizer(cfg.tokenizer) - train_loader = build_dataloader( - cfg.train_loader, - tokenizer, - cfg.device_train_batch_size, - ) - test_loader = build_dataloader( - cfg.eval_loader, tokenizer, cfg.device_eval_batch_size - ) - - return train_loader, test_loader, tokenizer - - -def build_composer_model(model_cfg, tokenizer): - warnings.filterwarnings( - action="ignore", - message="Torchmetrics v0.9 introduced a new argument class property", - ) - if model_cfg.name not in COMPOSER_MODEL_REGISTRY: - raise ValueError(f"Not sure how to build model with name={model_cfg.name}") - return COMPOSER_MODEL_REGISTRY[model_cfg.name](model_cfg, tokenizer) - - -def _build_cfg(args): - yaml_path = args.yaml_path - args_list = args.args_list - - with open(yaml_path) as f: - yaml_cfg = om.load(f) - cli_cfg = 
om.from_cli(args_list) - cfg = om.merge(yaml_cfg, cli_cfg) - return cfg - - -def build_dataloader(cfg, tokenizer, device_batch_size): - if cfg.name == "text": - return build_text_dataloader( - cfg, - tokenizer, - device_batch_size, - ) - elif cfg.name == "text_denoising": - return build_text_denoising_dataloader( - cfg, - tokenizer, - device_batch_size, - ) - elif cfg.name == "finetuning": - return build_finetuning_dataloader( - cfg, - tokenizer, - device_batch_size, - ) - else: - raise ValueError(f"Not sure how to build dataloader with config: {cfg}") diff --git a/src/sparseml/experimental/sparsegpt/llama2.py b/src/sparseml/experimental/sparsegpt/llama2.py index b6ced4e7e9a..a26231fb482 100644 --- a/src/sparseml/experimental/sparsegpt/llama2.py +++ b/src/sparseml/experimental/sparsegpt/llama2.py @@ -85,8 +85,7 @@ def load_model(args): model = LlamaForCausalLM.from_pretrained(model, torch_dtype="auto") model.eval() seqlen = model.config.max_position_embeddings - model.seqlen = seqlen - return model + return model, seqlen def load_data(args, seqlen, split=0.1): diff --git a/src/sparseml/experimental/sparsegpt/main.py b/src/sparseml/experimental/sparsegpt/main.py index 1fe27a8ddcd..d64c80c5805 100644 --- a/src/sparseml/experimental/sparsegpt/main.py +++ b/src/sparseml/experimental/sparsegpt/main.py @@ -158,8 +158,7 @@ def _save(model, tokenizer, save_path): wandb.init(config=args) print("Load model", flush=True) - model = load_model(args) - seqlen = model.seqlen + model, seqlen = load_model(args) print("Load data", flush=True) dataloader, testloader, tokenizer = load_data(args, None, seqlen) diff --git a/src/sparseml/experimental/sparsegpt/mpt.py b/src/sparseml/experimental/sparsegpt/mpt.py index a109b9c9573..9721bf5b04a 100644 --- a/src/sparseml/experimental/sparsegpt/mpt.py +++ b/src/sparseml/experimental/sparsegpt/mpt.py @@ -28,16 +28,12 @@ ) from llmfoundry.data.text_data import build_text_dataloader from llmfoundry.utils.builders import build_tokenizer +from model_preprocessor import ModelPreprocessor, QuantizationModelPreprocessor from omegaconf import OmegaConf as om -from sequential import SequentialSparseGPT from sparseml.experimental.sparsegpt.layer_compressor import ( BaseCompressor, LayerCompressor, ) -from sparseml.experimental.sparsegpt.model_preprocessor import ( - ModelPreprocessor, - QuantizationModelPreprocessor, -) from sparseml.experimental.sparsegpt.quant import ( MatMulLeftInput_PV, MatMulLeftInput_QK, @@ -47,6 +43,7 @@ MatMulRightInput_QK, QuantizableMatMul, ) +from sparseml.experimental.sparsegpt.sequential import SequentialSparseGPT class SequentialSparseGPT_MPT(SequentialSparseGPT): diff --git a/src/sparseml/experimental/sparsegpt/opt.py b/src/sparseml/experimental/sparsegpt/opt.py index a78429f14a0..9b1f0e83aeb 100644 --- a/src/sparseml/experimental/sparsegpt/opt.py +++ b/src/sparseml/experimental/sparsegpt/opt.py @@ -163,8 +163,7 @@ def skip(*args, **kwargs): model = OPTForCausalLM.from_pretrained(model, torch_dtype="auto") seqlen = model.config.max_position_embeddings - model.seqlen = seqlen - return model + return model, seqlen def load_data(args, seqlen, split=0.1): diff --git a/src/sparseml/experimental/sparsegpt/sequential.py b/src/sparseml/experimental/sparsegpt/sequential.py index 685ddfbcdcd..ab0aa69b8f3 100644 --- a/src/sparseml/experimental/sparsegpt/sequential.py +++ b/src/sparseml/experimental/sparsegpt/sequential.py @@ -28,11 +28,13 @@ def __init__( recipe: Optional[str] = None, model_preprocessors: Optional[List[ModelPreprocessor]] = None, 
bottom_compressor: Optional[LayerCompressor] = None, + head_compressor: Optional[LayerCompressor] = None, args=None, ): self.model = model self.model_preprocessors = model_preprocessors self.bottom_compressor = bottom_compressor + self.head_compressor = head_compressor self.recipe = recipe self.manager = None self.compressible_layers = self.compressible_layers() @@ -103,6 +105,10 @@ def compress(self, dev: str = "cuda:0", **kwargs): ) accum_kwargs.update(layer_kwargs) + # Step 2: Prune/quantize head + if self.head_compressor is not None: + self.model, extras = self.head_compressor.compress(dev=dev, **accum_kwargs) + return self.model, {} def post_compress(self, dev: str = "cuda:0", **kwargs): diff --git a/src/sparseml/experimental/sparsegpt/utils.py b/src/sparseml/experimental/sparsegpt/utils.py index 6b5d9f78a93..191c97d3369 100644 --- a/src/sparseml/experimental/sparsegpt/utils.py +++ b/src/sparseml/experimental/sparsegpt/utils.py @@ -232,7 +232,6 @@ def ppl_eval_general( ppl = torch.exp(neg_log_likelihood / number_tokens) print(f"Perplexity: {ppl.item():3f}") - return ppl.item() def get_wikitext2(nsamples, seed, seqlen, model): diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 1e9e9eb79b1..0ecc8a90b7f 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -48,6 +48,7 @@ "freeze_bn_stats", "fuse_module_conv_bn_relus", "prepare_embeddings_qat", + "initialize_channel_wise_scale_zp", "QConfigProperties", "LINEAR_ACTIVATION_NAMES", "CONV_ACTIVATION_NAMES", @@ -710,6 +711,58 @@ def prepare_embeddings_qat( _prepare_qat_embedding(submodule, submodule_qconfig) +def initialize_channel_wise_scale_zp(module: Module): + """ + On torch channel-wise quantization, zero points and scales are + initialized to a default size of (1,) instead of their true size + of (num_output_channels,). This can cause issues on reloading + of saved checkpoints due to shape mismatch. 
This function expands + these initial scales and zero points to match the true expected + shape + + :param module: qat ready, uncalibrated model + """ + for name, submodule in module.named_modules(): + weight_fake_quant = getattr(submodule, "weight_fake_quant", None) + if not weight_fake_quant or ( + getattr(weight_fake_quant, "qscheme", None) + not in [torch.per_channel_affine, torch.per_channel_symmetric] + ): + # only consider modules with channel-wise quantized weights + continue + num_channels = None + if hasattr(submodule, "out_features"): + # matmul layers + num_channels = submodule.out_features + elif hasattr(submodule, "out_channels"): + num_channels = submodule.out_channels + + if not num_channels: + # unable to infer num_channels or num_channels is 0 + continue + + # update scale and zero point if they are initialized to a size of 1 + scale = weight_fake_quant.scale + if scale.numel() == 1: + weight_fake_quant.scale = torch.ones(num_channels, dtype=scale.dtype) + + zero_point = weight_fake_quant.zero_point + if zero_point.numel() == 1: + weight_fake_quant.zero_point = torch.ones( + num_channels, dtype=zero_point.dtype + ) + + # update the observer min and max vals + if weight_fake_quant.activation_post_process.min_val.numel() == 0: + weight_fake_quant.activation_post_process.min_val = torch.empty_like( + weight_fake_quant.scale + ) + if weight_fake_quant.activation_post_process.max_val.numel() == 0: + weight_fake_quant.activation_post_process.max_val = torch.empty_like( + weight_fake_quant.scale + ) + + def _delete_get_block_hooks( module: Module, fuse_blocks: List[List[str]], diff --git a/src/sparseml/pytorch/sparsification/quantization/quantize.py b/src/sparseml/pytorch/sparsification/quantization/quantize.py index ef109b9f05d..9f41ed05964 100644 --- a/src/sparseml/pytorch/sparsification/quantization/quantize.py +++ b/src/sparseml/pytorch/sparsification/quantization/quantize.py @@ -357,21 +357,26 @@ def _match_submodule_name_or_type( submodule: Module, submodule_name: str, names_or_types: List[str] ) -> Optional[str]: # match preferences: - # 1. match module type name - # 2. match the submodule prefix (longest first) + # 1. match the submodule prefix (longest first) + # 2. 
match module type name submodule_match = "" for name_or_type in names_or_types: name_to_compare = submodule_name[:] if name_to_compare.startswith("module."): name_to_compare = name_to_compare[7:] - if name_or_type == submodule.__class__.__name__: - # type match, return type name - return name_or_type if name_to_compare.startswith(name_or_type) and ( len(name_or_type) > len(submodule_match) ): # match to most specific submodule name submodule_match = name_or_type + + # If no prefix match was found, fall back to matching the module type name + if not submodule_match: + for name_or_type in names_or_types: + if name_or_type == submodule.__class__.__name__: + # type match, return type name + return name_or_type + return submodule_match or None # return None if no match diff --git a/src/sparseml/transformers/sparsification/trainer.py b/src/sparseml/transformers/sparsification/trainer.py index f41f9b07c4a..51cb18b4750 100644 --- a/src/sparseml/transformers/sparsification/trainer.py +++ b/src/sparseml/transformers/sparsification/trainer.py @@ -42,6 +42,10 @@ from sparseml.core.logger import LoggerManager, TensorBoardLogger, WANDBLogger from sparseml.pytorch.optim import ScheduledModifierManager, ScheduledOptimizer from sparseml.pytorch.utils import ModuleSparsificationInfo + +from sparseml.pytorch.sparsification.quantization.helpers import ( + initialize_channel_wise_scale_zp, +) from sparseml.transformers.utils import SparseAutoModel from sparseml.transformers.utils.helpers import RECIPE_NAME @@ -667,6 +671,13 @@ def _reload_model_state(self, load_path: str, orig_state_dict: Dict[str, Any]): ) return False + # PerChannel quantization observers initialize variables + # to dummy shapes that do not match the ones saved in + # state_dict. + # Need to reshape these variables in order to load state_dict + # properly.
+ initialize_channel_wise_scale_zp(self.model) + current_state_dict = self.model.state_dict() if set(orig_state_dict.keys()) == set(current_state_dict): diff --git a/tests/sparseml/pytorch/sparsification/quantization/test_modifier_quantization.py b/tests/sparseml/pytorch/sparsification/quantization/test_modifier_quantization.py index 1b499ea13fd..31e9232d349 100644 --- a/tests/sparseml/pytorch/sparsification/quantization/test_modifier_quantization.py +++ b/tests/sparseml/pytorch/sparsification/quantization/test_modifier_quantization.py @@ -25,6 +25,7 @@ QuantizationScheme, ) from sparseml.pytorch.sparsification.quantization.quantize import ( + _match_submodule_name_or_type, is_qat_helper_module, is_quantizable_module, ) @@ -66,7 +67,7 @@ def _assert_observers_eq(observer_1, observer_2): _assert_observers_eq(qconfig_1.weight, qconfig_2.weight) -def _test_quantized_module(base_model, modifier, module, name): +def _test_quantized_module(base_model, modifier, module, name, override_key): # check quant scheme and configs are set quantization_scheme = getattr(module, "quantization_scheme", None) qconfig = getattr(module, "qconfig", None) @@ -74,9 +75,8 @@ def _test_quantized_module(base_model, modifier, module, name): assert qconfig is not None # if module type is overwritten in by scheme_overrides, check scheme set correctly - module_type_name = module.__class__.__name__ - if module_type_name in modifier.scheme_overrides: - expected_scheme = modifier.scheme_overrides[module_type_name] + if override_key is not None: + expected_scheme = modifier.scheme_overrides[override_key] assert quantization_scheme == expected_scheme is_quant_wrapper = isinstance(module, torch_quantization.QuantWrapper) @@ -148,7 +148,12 @@ def _test_qat_applied(modifier, model): _test_qat_wrapped_module(model, name) elif is_quantizable: # check each target module is quantized - _test_quantized_module(model, modifier, module, name) + override_key = _match_submodule_name_or_type( + module, + name, + list(modifier.scheme_overrides.keys()), + ) + _test_quantized_module(model, modifier, module, name, override_key) else: # check all non-target modules are not quantized assert not hasattr(module, "quantization_scheme")
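For context, and not part of the diff above: a minimal sketch, assuming a QAT-prepared `model` and a placeholder `checkpoint_path`, of how the new `initialize_channel_wise_scale_zp` helper is meant to be called before reloading a per-channel quantized checkpoint, mirroring the call added in `_reload_model_state`:

    import torch

    from sparseml.pytorch.sparsification.quantization.helpers import (
        initialize_channel_wise_scale_zp,
    )

    # Expand the default (1,)-shaped per-channel scales and zero points to
    # (num_output_channels,) so they match the tensors stored in the checkpoint.
    initialize_channel_wise_scale_zp(model)

    # The state dict now loads without per-channel shape mismatches.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict)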