FSDP2 integration: torch.chunks(Params4bit) #1612

Open · wants to merge 6 commits into main
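
This PR teaches `Params4bit` to survive the tensor-splitting ops FSDP2 uses for sharding: a `__torch_function__` override intercepts `torch.chunk`/`torch.split` (and their `Tensor` method forms) and re-wraps each resulting chunk as a `Params4bit` carrying the parent's quantization metadata (`quant_state`, `blocksize`, `quant_type`, and so on). A minimal sketch of the intended behavior, assuming the override lands as written below (the shape and `quant_type` here are arbitrary):

import torch
from bitsandbytes.nn import Params4bit

# A not-yet-quantized 4-bit parameter; quant_state stays None until _quantize runs.
p = Params4bit(data=torch.randn(128, 64), requires_grad=False, quant_type="nf4")

# With the override, chunking yields Params4bit shards rather than bare tensors,
# and each shard inherits the parent's quantization attributes.
shards = torch.chunk(p, 2, dim=0)
assert all(isinstance(s, Params4bit) for s in shards)
assert shards[0].quant_type == p.quant_type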
69 changes: 51 additions & 18 deletions bitsandbytes/nn/modules.py
@@ -1,3 +1,4 @@

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
@@ -206,6 +207,17 @@ def forward(self, input: Tensor) -> Tensor:
        return emb


import torch
import copy
from typing import Optional, Union, Any, TypeVar, Dict, Tuple
from torch import device, dtype, Tensor

# Assuming these are imported from bitsandbytes
import bitsandbytes as bnb
from bitsandbytes.functional import QuantState

T = TypeVar('T', bound='Params4bit')

class Params4bit(torch.nn.Parameter):
    def __new__(
        cls,
@@ -294,8 +306,45 @@ def from_prequantized(
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        # Special handling for operations that need to preserve the Params4bit subclass
        with torch._C.DisableTorchFunctionSubclass():
            result = func(*args, **kwargs)

        # For operations that return tensors or tuples of tensors,
        # wrap the result back as Params4bit objects
        if func in [torch.chunk, torch.split, torch.Tensor.chunk, torch.Tensor.split]:
            # For functions that return a tuple of tensors
            return tuple(
                cls(
                    data=chunk,
                    requires_grad=args[0].requires_grad if hasattr(args[0], 'requires_grad') else False,
                    quant_state=args[0].quant_state if hasattr(args[0], 'quant_state') else None,
                    blocksize=args[0].blocksize if hasattr(args[0], 'blocksize') else 64,
                    compress_statistics=args[0].compress_statistics if hasattr(args[0], 'compress_statistics') else True,
                    quant_type=args[0].quant_type if hasattr(args[0], 'quant_type') else "fp4",
                    quant_storage=args[0].quant_storage if hasattr(args[0], 'quant_storage') else torch.uint8,
                    module=args[0].module if hasattr(args[0], 'module') else None,
                    bnb_quantized=args[0].bnb_quantized if hasattr(args[0], 'bnb_quantized') else False,
                )
                for chunk in result
            )
        elif isinstance(result, torch.Tensor) and not isinstance(result, cls):
            # For functions that return a single tensor
            return cls(
                data=result,
                requires_grad=args[0].requires_grad if hasattr(args[0], 'requires_grad') else False,
                quant_state=args[0].quant_state if hasattr(args[0], 'quant_state') else None,
                blocksize=args[0].blocksize if hasattr(args[0], 'blocksize') else 64,
                compress_statistics=args[0].compress_statistics if hasattr(args[0], 'compress_statistics') else True,
                quant_type=args[0].quant_type if hasattr(args[0], 'quant_type') else "fp4",
                quant_storage=args[0].quant_storage if hasattr(args[0], 'quant_storage') else torch.uint8,
                module=args[0].module if hasattr(args[0], 'module') else None,
                bnb_quantized=args[0].bnb_quantized if hasattr(args[0], 'bnb_quantized') else False,
            )
        else:
            # For other operations, return the result as is
            return result

    def _quantize(self, device):
        w = self.data.contiguous().to(device)
@@ -322,20 +371,6 @@ def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool
    def xpu(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
        return self.to(device="xpu" if device is None else device, non_blocking=non_blocking)

    @overload
    def to(
        self: T,
        device: Optional[Union[int, device]] = ...,
        dtype: Optional[Union[dtype, str]] = ...,
        non_blocking: bool = ...,
    ) -> T: ...

    @overload
    def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: ...

    @overload
    def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...

    def to(self, *args, **kwargs):
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)

@@ -356,8 +391,6 @@ def to(self, *args, **kwargs):
            )

            return new_param


def fix_4bit_weight_quant_state_from_module(module: Union["Embedding4bit", "Linear4bit"]):
    if getattr(module.weight, "quant_state", None) is not None:
        return
@@ -1078,4 +1111,4 @@ def forward(self, x):
        if self.weight.CB is not None:
            self.init_8bit_state()

        out = bnb.matmul_mixed(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias
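
For context on how FSDP2 exercises this path, here is a hedged sketch (not part of this PR): `fully_shard` splits each parameter across ranks by chunking along dim 0, which is exactly where plain-tensor chunks previously dropped the quantization state. The import path is an assumption (it has moved between torch releases), and this only runs under an initialized process group, e.g. via torchrun:

import torch.nn as nn
import bitsandbytes as bnb
from torch.distributed.fsdp import fully_shard  # location varies by torch version

# A module holding a 4-bit quantized linear layer.
model = nn.Sequential(bnb.nn.Linear4bit(1024, 1024, quant_type="nf4"))

# fully_shard chunks each parameter across ranks; with this PR the
# shards remain Params4bit and keep their quantization metadata.
fully_shard(model)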