Accelerate Utilities (#193)
* wip

* add modify_offload_module

* update docs

* WIP

* cleanup functions, begin deprecation

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* remove extra space

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* revert get_offloaded_device

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* update to align_module_device

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* add requires skip for accelerate

* fix per token initialization

* remove align_module_device

* respond to nits

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* Accelerate Utilities Follow-up (#224)

* rename

* implement recursive case

* remove print

* support OffloadedWeightsLoader

* add lifecycle docstring

* implement offload_to_weights_map with recursive definition

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* add docstring

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* fix type hint

* add check_accelerate guard

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* make device used by  clearer

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* update update_prefix_dict

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* reuse fixture

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* use apply rather than recursion

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* clearer delete_from_weights_map

* add offload_device argument (#228)

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

---------

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
kylesayrs authored Dec 20, 2024
1 parent 975cb22 commit 85b473e
Showing 5 changed files with 708 additions and 91 deletions.
61 changes: 17 additions & 44 deletions src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -29,7 +29,11 @@
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
-from compressed_tensors.utils import get_execution_device, is_module_offloaded
+from compressed_tensors.utils import (
+    disable_hf_hook,
+    has_offloaded_params,
+    register_offload_parameter,
+)
 from torch.nn import Module, Parameter


@@ -112,43 +116,10 @@ def initialize_module_for_quantization(
     module.quantization_scheme = scheme
     module.quantization_status = QuantizationStatus.INITIALIZED

-    offloaded = False
-    # What is this doing/why isn't this in the attn case?
-    if is_module_offloaded(module):
-        try:
-            from accelerate.hooks import add_hook_to_module, remove_hook_from_module
-            from accelerate.utils import PrefixedDataset
-        except ModuleNotFoundError:
-            raise ModuleNotFoundError(
-                "Offloaded model detected. To use CPU offloading with "
-                "compressed-tensors the `accelerate` package must be installed, "
-                "run `pip install compressed-tensors[accelerate]`"
-            )
-
-        offloaded = True
-        hook = module._hf_hook
-        prefix_dict = module._hf_hook.weights_map
-        new_prefix = {}
-
-        # recreate the prefix dict (since it is immutable)
-        # and add quantization parameters
-        for key, data in module.named_parameters():
-            if key not in prefix_dict:
-                new_prefix[f"{prefix_dict.prefix}{key}"] = data
-            else:
-                new_prefix[f"{prefix_dict.prefix}{key}"] = prefix_dict[key]
-        new_prefix_dict = PrefixedDataset(new_prefix, prefix_dict.prefix)
-        remove_hook_from_module(module)
-
-    # wrap forward call of module to perform
-    # quantized actions based on calltime status
-    wrap_module_forward_quantized(module, scheme)
-
-    if offloaded:
-        # we need to re-add the hook for offloading now that we've wrapped forward
-        add_hook_to_module(module, hook)
-        if prefix_dict is not None:
-            module._hf_hook.weights_map = new_prefix_dict
+    with disable_hf_hook(module):
+        # wrap forward call of module to perform
+        # quantized actions based on calltime status
+        wrap_module_forward_quantized(module, scheme)


def is_attention_module(module: Module):
@@ -169,9 +140,11 @@ def _initialize_scale_zero_point(
     if quantization_args.dynamic:
         return

-    device = next(module.parameters()).device
-    if is_module_offloaded(module):
-        device = get_execution_device(module)
+    # begin on the same device as other parameters or cpu if offloaded.
+    # in the offloaded case, there's no point moving tensors to the execution device
+    # if they're going to be immediately offloaded by `register_offload_parameter`
+    params_device = next(module.parameters()).device
+    device = "cpu" if has_offloaded_params(module) else params_device

     # infer expected scale/zero point shape
     if quantization_args.strategy == QuantizationStrategy.TOKEN:
@@ -196,15 +169,15 @@
         torch.empty(expected_shape, dtype=scale_dtype, device=device),
         requires_grad=False,
     )
-    module.register_parameter(f"{base_name}_scale", init_scale)
+    register_offload_parameter(module, f"{base_name}_scale", init_scale)

     if force_zero_point or not quantization_args.symmetric:
         zp_dtype = quantization_args.pytorch_dtype()
         init_zero_point = Parameter(
             torch.zeros(expected_shape, device=device, dtype=zp_dtype),
             requires_grad=False,
         )
-        module.register_parameter(f"{base_name}_zero_point", init_zero_point)
+        register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point)

     # only grouped activation ordering has g_idx
     if quantization_args.actorder == ActivationOrdering.GROUP:
@@ -214,7 +187,7 @@
             torch.full(g_idx_shape, -1, device=device, dtype=g_idx_dtype),
             requires_grad=False,
         )
-        module.register_parameter(f"{base_name}_g_idx", init_g_idx)
+        register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)


def _initialize_attn_scales(module: Module) -> None:
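
For orientation, a minimal sketch (not part of this commit) of the pattern the rewritten initialize.py follows, using the utilities imported above. The Linear layer, tensor shape, and parameter name are illustrative assumptions, and the forward wrapper is only a trivial stand-in for wrap_module_forward_quantized:

import torch
from torch.nn import Linear, Parameter
from compressed_tensors.utils import (
    disable_hf_hook,
    has_offloaded_params,
    register_offload_parameter,
)

module = Linear(16, 16)  # stand-in module; may or may not carry an accelerate hook

# mirror the device choice in _initialize_scale_zero_point: stay on cpu when the
# module is offloaded, otherwise follow the device of its existing parameters
device = "cpu" if has_offloaded_params(module) else next(module.parameters()).device

# create the new quantization parameter and hand it to the offload-aware helper,
# which registers it on the module and keeps any offloaded weights map in sync
scale = Parameter(torch.empty(1, dtype=torch.float32, device=device), requires_grad=False)
register_offload_parameter(module, "weight_scale", scale)

# temporarily detach any accelerate hook while mutating forward, then restore it
with disable_hf_hook(module):
    original_forward = module.forward

    def wrapped_forward(*args, **kwargs):
        # quantization-aware logic would be installed here
        return original_forward(*args, **kwargs)

    module.forward = wrapped_forward
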
65 changes: 64 additions & 1 deletion src/compressed_tensors/utils/helpers.py
@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from typing import Any, Dict, Optional
+import warnings
+from functools import wraps
+from typing import Any, Callable, Dict, Optional

import torch
from transformers import AutoConfig
Expand All @@ -24,6 +26,8 @@
"tensor_follows_mask_structure",
"replace_module",
"is_compressed_tensors_config",
"getattr_chain",
"deprecated",
"Aliasable",
]

@@ -122,6 +126,65 @@ def is_compressed_tensors_config(compression_config: Any) -> bool:
     return False


+def getattr_chain(obj: Any, chain_str: str, *args, **kwargs) -> Any:
+    """
+    Chain multiple getattr calls, separated by `.`
+
+    :param obj: base object whose attributes are being retrieved
+    :param chain_str: attribute names separated by `.`
+    :param default: default value to return if any attribute in the chain is
+        missing; if no default is provided, an AttributeError is raised instead
+    """
+    if len(args) >= 1:
+        has_default = True
+        default = args[0]
+    elif "default" in kwargs:
+        has_default = True
+        default = kwargs["default"]
+    else:
+        has_default = False
+
+    attr_names = chain_str.split(".")
+
+    res = obj
+    for attr_name in attr_names:
+        if not hasattr(res, attr_name):
+            if has_default:
+                return default
+            else:
+                raise AttributeError(f"{res} object has no attribute {attr_name}")
+        res = getattr(res, attr_name)
+
+    return res


+def deprecated(future_name: Optional[str] = None, message: Optional[str] = None):
+    """
+    Decorator to mark functions as deprecated
+
+    :param future_name: name of the function to use in place of the deprecated one
+    :param message: deprecation message, replaces the default deprecation message
+    """
+
+    def decorator(func: Callable[[Any], Any]):
+        nonlocal message
+
+        if message is None:
+            message = (
+                f"{func.__name__} is deprecated and will be removed in a future release"
+            )
+            if future_name is not None:
+                message += f". Please use {future_name} instead."
+
+        @wraps(func)
+        def wrapped(*args, **kwargs):
+            warnings.warn(message, DeprecationWarning, stacklevel=2)
+            return func(*args, **kwargs)
+
+        return wrapped
+
+    return decorator


class Aliasable:
"""
A mixin for enums to allow aliasing of enum members
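
Usage sketches (not part of the diff) for the two helpers added above. The config object, attribute paths, and function names below are invented for illustration:

import warnings

from compressed_tensors.utils.helpers import deprecated, getattr_chain


class _Config:
    # hypothetical nested config used only to demonstrate dotted lookups
    class quantization:
        bits = 4


# getattr_chain walks dotted attribute paths and falls back to the default
cfg = _Config()
assert getattr_chain(cfg, "quantization.bits") == 4
assert getattr_chain(cfg, "quantization.group_size", None) is None  # missing -> default


# deprecated emits a DeprecationWarning and points callers at the replacement
@deprecated(future_name="new_helper")
def old_helper() -> int:
    return 1


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    old_helper()
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
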