Implement HooksMixin (#917)
* Implement HooksMixin

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* add docstring

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* integrate with smoothquant

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* integrate with QuantizationModifier

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* update hooks in tests

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* integrate with wanda

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* integrate with magnitude and constant

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* integrate with SparseGPTModifier

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* add hooksmixin to modifier

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* nits

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

---------

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Co-authored-by: Dipika Sikka <dipikasikka1@gmail.com>
kylesayrs and dsikka authored Dec 6, 2024
1 parent 1830382 commit 9f58887
Showing 11 changed files with 292 additions and 172 deletions.
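At a high level, the new HooksMixin centralizes hook bookkeeping: as the diffs below show, callers register hooks through self.register_hook(target, hook_fn, hook_type) and tear everything down with a single self.remove_hooks() call, and the base-class swaps (Modifier drops BaseModel, LayerParamMasking keeps its pydantic-style fields) suggest the mixin itself subclasses pydantic's BaseModel. The following is a minimal, self-contained sketch of what such a mixin could look like, inferred only from that usage; the real implementation lives in src/llmcompressor/modifiers/utils/hooks.py and may differ in names and details.

from typing import Any, Callable, List

from pydantic import BaseModel
from torch.utils.hooks import RemovableHandle


class HooksMixinSketch(BaseModel):
    """Illustrative stand-in for HooksMixin; tracks every handle it registers."""

    _hooks: List[RemovableHandle] = []

    def register_hook(
        self,
        target: Any,
        hook: Callable[..., Any],
        hook_type: str = "",
        **kwargs,
    ) -> RemovableHandle:
        # hook_type "forward"     -> target.register_forward_hook(hook)
        # hook_type "forward_pre" -> target.register_forward_pre_hook(hook)
        # hook_type ""            -> target.register_hook(hook), e.g. Parameter gradient hooks
        register_name = f"register_{hook_type}_hook" if hook_type else "register_hook"
        handle = getattr(target, register_name)(hook, **kwargs)
        self._hooks.append(handle)
        return handle

    def remove_hooks(self):
        # detach everything registered through this mixin in one call
        for handle in self._hooks:
            handle.remove()
        self._hooks = []

With a helper like this mixed into Modifier, every hook created during compression can be released by one remove_hooks() call instead of per-modifier handle dictionaries, which is exactly the bookkeeping the diffs below delete.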
7 changes: 3 additions & 4 deletions src/llmcompressor/modifiers/modifier.py
@@ -1,16 +1,15 @@
from abc import ABC, abstractmethod
from abc import abstractmethod
from typing import Optional

from pydantic import BaseModel

from llmcompressor.core.events import Event, EventType
from llmcompressor.core.state import State
from llmcompressor.modifiers.interface import ModifierInterface
from llmcompressor.modifiers.utils.hooks import HooksMixin

__all__ = ["Modifier"]


class Modifier(BaseModel, ModifierInterface, ABC):
class Modifier(ModifierInterface, HooksMixin):
"""
A base class for all modifiers to inherit from.
Modifiers are used to modify the training process for a model.
72 changes: 35 additions & 37 deletions src/llmcompressor/modifiers/obcq/base.py
@@ -1,3 +1,4 @@
from functools import partial
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
@@ -130,7 +131,8 @@ def initialize_compression(
"Inferring layer-wise sparsities from "
f"{len(dataloader)} calibration samples..."
)
self.sparsity = self._infer_layer_sparsity(dataloader)
activations = self._get_activations(dataloader)
self.sparsity = self._infer_layer_sparsity(activations)
self._validate_layerwise_sparsity()

for idx, (name, layer) in enumerate(self.compressible_layers_.items()):
@@ -254,19 +256,17 @@ def _infer_mask_block_size(self):

self.prunen_, self.prunem_ = list(map(int, self.mask_structure.split(":")))

def _infer_layer_sparsity(self, calibration_dataloader):
acts = _get_activations(self.model, calibration_dataloader)
def _infer_layer_sparsity(self, activations):
sparsegpt_groups = {}
for name, layer in self.compressible_layers_.items():
prunable_layers = get_prunable_layers(layer)
z = [
m.weight.abs() * acts[f"{name}.{n}"].unsqueeze(0)
m.weight.abs() * activations[f"{name}.{n}"].unsqueeze(0)
for n, m in prunable_layers.items()
]
sparsegpt_groups[name] = torch.cat([item.flatten().cpu() for item in z])

acts = None
del acts
del activations
torch.cuda.empty_cache()

outlier_ratios = {}
@@ -300,36 +300,34 @@ def _infer_layer_sparsity(self, calibration_dataloader):
logger.info(f"Sparsity for {k}: {sparsities[k]}")
return sparsities

@torch.no_grad()
def _get_activations(self, data_loader, nsamples=128):
self.model.eval()
acts = {}

def save_acts(module, input, name):
if isinstance(input, tuple):
input = input[0]
if name not in acts:
acts[name] = (
1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt()
)
else:
acts[name] += (
1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt()
)

for name, mod in self.model.named_modules():
if isinstance(mod, torch.nn.Linear) and "lm_head" not in name:
self.register_hook(mod, partial(save_acts, name=name), "forward_pre")

device = next(self.model.parameters()).device
for batch in tqdm(data_loader):
batch = {k: v.to(device) for k, v in batch.items()}
self.model(**batch)
batch = None
torch.cuda.empty_cache()

@torch.no_grad()
def _get_activations(model, data_loader, nsamples=128):
import functools

model.eval()
acts = {}

def save_acts(module, input, name):
if isinstance(input, tuple):
input = input[0]
if name not in acts:
acts[name] = 1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt()
else:
acts[name] += 1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt()

hooks = []
for name, mod in model.named_modules():
if isinstance(mod, torch.nn.Linear) and "lm_head" not in name:
hooks.append(
mod.register_forward_pre_hook(functools.partial(save_acts, name=name))
)
device = next(model.parameters()).device
for batch in tqdm(data_loader):
batch = {k: v.to(device) for k, v in batch.items()}
model(**batch)
batch = None
torch.cuda.empty_cache()

for h in hooks:
h.remove()
self.remove_hooks()

return acts
return acts
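For context, the _get_activations logic above (now a method that registers its forward-pre hooks through the mixin and ends with self.remove_hooks()) accumulates a per-channel root-mean-square-style statistic of each Linear layer's inputs over the calibration batches. A self-contained toy version of that pattern, using plain PyTorch hooks and a made-up model instead of an LLM, looks like this:

from functools import partial

import torch

model = torch.nn.Sequential(
    torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 4)
)
nsamples = 4
acts = {}


def save_acts(module, inputs, name):
    # forward-pre hooks receive the positional inputs as a tuple
    x = inputs[0] if isinstance(inputs, tuple) else inputs
    # running sqrt(sum(x^2)) over (batch, seq), averaged over nsamples batches
    stat = 1.0 / nsamples * x.detach().pow(2).sum(dim=(0, 1)).sqrt()
    acts[name] = acts.get(name, 0) + stat


handles = [
    mod.register_forward_pre_hook(partial(save_acts, name=name))
    for name, mod in model.named_modules()
    if isinstance(mod, torch.nn.Linear)
]

with torch.no_grad():
    for _ in range(nsamples):
        model(torch.randn(2, 10, 8))  # (batch, seq, hidden)

for h in handles:
    h.remove()

print({name: tuple(stat.shape) for name, stat in acts.items()})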
21 changes: 5 additions & 16 deletions src/llmcompressor/modifiers/pruning/utils/pytorch/layer_mask.py
@@ -2,11 +2,10 @@
from typing import Dict

import torch
from pydantic import BaseModel
from torch.nn import Parameter
from torch.utils.hooks import RemovableHandle

from llmcompressor.core import ModelParameterizedLayer
from llmcompressor.modifiers.utils.hooks import HooksMixin

__all__ = ["LayerParamMasking", "param_mask_name"]

@@ -39,11 +38,9 @@ class ParameterizedLayerMaskSettings:
use_hooks: bool = False


class LayerParamMasking(BaseModel):
class LayerParamMasking(HooksMixin):
_mask_settings: Dict[str, ParameterizedLayerMaskSettings] = {}
_masked_layer_params: Dict[str, ModelParameterizedLayer] = {}
_forward_hooks: Dict[str, RemovableHandle] = {}
_backward_hooks: Dict[str, RemovableHandle] = {}
enabled_: bool = False

def add_mask(
@@ -100,12 +97,8 @@ def _backward_hook_fn(gradients):

return gradients

self._forward_hooks[layer_param_name] = (
parameterized_layer.layer.register_forward_hook(_forward_hook_fn)
)
self._backward_hooks[layer_param_name] = (
parameterized_layer.param.register_hook(_backward_hook_fn)
)
self.register_hook(parameterized_layer.layer, _forward_hook_fn, "forward")
self.register_hook(parameterized_layer.param, _backward_hook_fn, "")

def update_mask(
self,
@@ -131,11 +124,7 @@ def remove_mask(self, layer_param_name: str):
del self._mask_settings[layer_param_name]

if mask_settings.use_hooks:
self._forward_hooks[layer_param_name].remove()
self._backward_hooks[layer_param_name].remove()

del self._forward_hooks[layer_param_name]
del self._backward_hooks[layer_param_name]
self.remove_hooks()

def apply_mask_weight(self, layer_param_name: str):
if not self.enabled_:
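One detail worth noting in the change above: register_hook is called with hook type "forward" for the layer but with an empty hook type for the parameter, which maps onto PyTorch's two distinct hook flavors, module forward hooks versus Tensor/Parameter gradient hooks. A small standalone illustration of the two (toy model, hypothetical hook bodies):

import torch

linear = torch.nn.Linear(4, 4)

# Module forward hook: fn(module, inputs, output), fired after forward()
fwd_handle = linear.register_forward_hook(
    lambda module, inputs, output: print("forward output shape:", output.shape)
)


# Parameter/Tensor hook: fn(grad) -> optional replacement grad, fired during backward()
def zero_out_grad(grad: torch.Tensor) -> torch.Tensor:
    # e.g. apply a pruning mask to gradients, as LayerParamMasking does
    return grad * 0.0


bwd_handle = linear.weight.register_hook(zero_out_grad)

out = linear(torch.randn(2, 4)).sum()
out.backward()

fwd_handle.remove()
bwd_handle.remove()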
72 changes: 35 additions & 37 deletions src/llmcompressor/modifiers/pruning/wanda/base.py
@@ -1,3 +1,4 @@
from functools import partial
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
@@ -121,7 +122,8 @@ def initialize_compression(
"Inferring layer-wise sparsities from "
f"{len(dataloader) if dataloader else 0} calibration samples..."
)
self.sparsity = self._infer_layer_sparsity(dataloader)
activations = self._get_activations(dataloader)
self.sparsity = self._infer_layer_sparsity(activations)
self._validate_layerwise_sparsity()

for idx, (name, layer) in enumerate(self.compressible_layers_.items()):
@@ -224,19 +226,17 @@ def _infer_mask_block_size(self):

self.prunen_, self.prunem_ = list(map(int, self.mask_structure.split(":")))

def _infer_layer_sparsity(self, calibration_dataloader):
acts = _get_activations(self.model, calibration_dataloader)
def _infer_layer_sparsity(self, activations):
wanda = {}
for name, layer in self.compressible_layers_.items():
prunable_layers = get_prunable_layers(layer)
z = [
m.weight.abs() * acts[f"{name}.{n}"].unsqueeze(0)
m.weight.abs() * activations[f"{name}.{n}"].unsqueeze(0)
for n, m in prunable_layers.items()
]
wanda[name] = torch.cat([item.flatten().cpu() for item in z])

acts = None
del acts
del activations
torch.cuda.empty_cache()

outlier_ratios = {}
@@ -268,36 +268,34 @@ def _infer_layer_sparsity(self, calibration_dataloader):
logger.info(f"Sparsity for {k}: {sparsities[k]}")
return sparsities

@torch.no_grad()
def _get_activations(self, data_loader, nsamples=128):
self.model.eval()
acts = {}

def save_acts(module, input, name):
if isinstance(input, tuple):
input = input[0]
if name not in acts:
acts[name] = (
1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt()
)
else:
acts[name] += (
1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt()
)

for name, mod in self.model.named_modules():
if isinstance(mod, torch.nn.Linear) and "lm_head" not in name:
self.register_hook(mod, partial(save_acts, name=name), "forward_pre")

device = next(self.model.parameters()).device
for batch in tqdm(data_loader):
batch = {k: v.to(device) for k, v in batch.items()}
self.model(**batch)
batch = None
torch.cuda.empty_cache()

@torch.no_grad()
def _get_activations(model, data_loader, nsamples=128):
import functools

model.eval()
acts = {}

def save_acts(module, input, name):
if isinstance(input, tuple):
input = input[0]
if name not in acts:
acts[name] = 1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt()
else:
acts[name] += 1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt()

hooks = []
for name, mod in model.named_modules():
if isinstance(mod, torch.nn.Linear) and "lm_head" not in name:
hooks.append(
mod.register_forward_pre_hook(functools.partial(save_acts, name=name))
)
device = next(model.parameters()).device
for batch in tqdm(data_loader):
batch = {k: v.to(device) for k, v in batch.items()}
model(**batch)
batch = None
torch.cuda.empty_cache()

for h in hooks:
h.remove()
self.remove_hooks()

return acts
return acts
70 changes: 28 additions & 42 deletions src/llmcompressor/modifiers/quantization/calibration.py
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Any, Dict, Optional, Tuple

import torch
from compressed_tensors.quantization import QuantizationStatus, is_attention_module
@@ -146,71 +146,57 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
)


def calibrate_input_hook():
def calibrate_input_hook(module: Module, args: Any):
"""
Hook to calibrate input activations.
Will call the observers to update the scales/zp before applying
input QDQ in the module's forward pass.
"""
args = args[0] if isinstance(args, tuple) else args
calibrate_activations(module, value=args, base_name="input")

def hook_fn(module: Module, inp):
inp = inp[0] if isinstance(inp, tuple) else inp
calibrate_activations(module, value=inp, base_name="input")

return hook_fn


def calibrate_output_hook():
def calibrate_output_hook(module: Module, _args: Any, output: torch.Tensor):
"""
Hook to calibrate output activations.
Will call the observers to update the scales/zp before applying
output QDQ.
"""

def hook_fn(module: Module, inp, output: torch.Tensor):
calibrate_activations(
module,
value=output,
base_name="output",
)
output = forward_quantize(
module=module,
value=output,
base_name="output",
args=module.quantization_scheme.output_activations,
)
return output

return hook_fn
calibrate_activations(
module,
value=output,
base_name="output",
)
output = forward_quantize(
module=module,
value=output,
base_name="output",
args=module.quantization_scheme.output_activations,
)
return output


def calibrate_kv_cache_input_hook():
def calibrate_kv_cache_input_hook(
module: Module, args: Any, kwargs: Dict[str, Any]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
"""
Hook to update inputs to attention layers when running
kv_cache quantization. Will update the passed in
kv_cache to singleton QuantizedKVParameterCache.
"""
kv_cache = getattr(module, "kv_cache")
kwargs["past_key_value"] = kv_cache
kwargs["use_cache"] = False
return args, kwargs

def hook_fn(module: Module, args, kwargs):
kv_cache = getattr(module, "kv_cache")
kwargs["past_key_value"] = kv_cache
kwargs["use_cache"] = False
return args, kwargs

return hook_fn


def calibrate_kv_cache_output_hook():
def calibrate_kv_cache_output_hook(module: Module, _args: Any, _output: torch.Tensor):
"""
Hook to update k_scale and v_scale parameters when running kv_cache quantization.
"""

def hook_fn(module: Module, inpt, output: torch.Tensor):
kv_cache = getattr(module, "kv_cache")
update_parameter_data(module, kv_cache.k_scales[module.layer_idx], "k_scale")
update_parameter_data(module, kv_cache.v_scales[module.layer_idx], "v_scale")

return hook_fn
kv_cache = getattr(module, "kv_cache")
update_parameter_data(module, kv_cache.k_scales[module.layer_idx], "k_scale")
update_parameter_data(module, kv_cache.v_scales[module.layer_idx], "v_scale")


def set_unset_kv_cache(module: Module):
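The calibration hooks above switch from closure factories (calibrate_input_hook() returning an inner hook_fn) to plain functions with PyTorch's native hook signatures, presumably so they can be registered directly through the new mixin or bound with functools.partial when extra arguments are needed; the registration site itself is not part of this excerpt. A generic, self-contained contrast of the two styles (toy names, not llm-compressor code):

import functools

import torch

model = torch.nn.Linear(4, 4)


# Old style: a factory that returns the real hook.
def make_input_hook():
    def hook_fn(module, inputs):
        print("factory-style hook saw", inputs[0].shape)

    return hook_fn


# New style: the hook itself, with PyTorch's native forward-pre signature.
def input_hook(module, inputs, tag: str = ""):
    print(f"plain hook {tag} saw", inputs[0].shape)


h1 = model.register_forward_pre_hook(make_input_hook())
# Extra arguments are bound with functools.partial instead of a closure.
h2 = model.register_forward_pre_hook(functools.partial(input_hook, tag="calib"))

model(torch.randn(2, 4))
h1.remove()
h2.remove()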