diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 937084cf1..a319e6372 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -1,3 +1,4 @@
+
 # Copyright (c) Facebook, Inc. and its affiliates.
 #
 # This source code is licensed under the MIT license found in the
@@ -206,6 +207,17 @@ def forward(self, input: Tensor) -> Tensor:
         return emb
 
 
+import torch
+import copy
+from typing import Optional, Union, Any, TypeVar, Dict, Tuple
+from torch import device, dtype, Tensor
+
+# Assuming these are imported from bitsandbytes
+import bitsandbytes as bnb
+from bitsandbytes.functional import QuantState
+
+T = TypeVar('T', bound='Params4bit')
+
 class Params4bit(torch.nn.Parameter):
     def __new__(
         cls,
@@ -294,8 +306,45 @@ def from_prequantized(
     def __torch_function__(cls, func, types, args=(), kwargs=None):
         if kwargs is None:
             kwargs = {}
+
+        # Special handling for operations that need to preserve the Params4bit subclass
         with torch._C.DisableTorchFunctionSubclass():
-            return func(*args, **kwargs)
+            result = func(*args, **kwargs)
+
+        # For operations that return tensors or tuple of tensors,
+        # wrap the result back as Params4bit objects
+        if func in [torch.chunk, torch.split, torch.Tensor.chunk, torch.Tensor.split]:
+            # For functions that return tuple of tensors
+            return tuple(
+                cls(
+                    data=chunk,
+                    requires_grad=args[0].requires_grad if hasattr(args[0], 'requires_grad') else False,
+                    quant_state=args[0].quant_state if hasattr(args[0], 'quant_state') else None,
+                    blocksize=args[0].blocksize if hasattr(args[0], 'blocksize') else 64,
+                    compress_statistics=args[0].compress_statistics if hasattr(args[0], 'compress_statistics') else True,
+                    quant_type=args[0].quant_type if hasattr(args[0], 'quant_type') else "fp4",
+                    quant_storage=args[0].quant_storage if hasattr(args[0], 'quant_storage') else torch.uint8,
+                    module=args[0].module if hasattr(args[0], 'module') else None,
+                    bnb_quantized=args[0].bnb_quantized if hasattr(args[0], 'bnb_quantized') else False,
+                )
+                for chunk in result
+            )
+        elif isinstance(result, torch.Tensor) and not isinstance(result, cls):
+            # For functions that return a single tensor
+            return cls(
+                data=result,
+                requires_grad=args[0].requires_grad if hasattr(args[0], 'requires_grad') else False,
+                quant_state=args[0].quant_state if hasattr(args[0], 'quant_state') else None,
+                blocksize=args[0].blocksize if hasattr(args[0], 'blocksize') else 64,
+                compress_statistics=args[0].compress_statistics if hasattr(args[0], 'compress_statistics') else True,
+                quant_type=args[0].quant_type if hasattr(args[0], 'quant_type') else "fp4",
+                quant_storage=args[0].quant_storage if hasattr(args[0], 'quant_storage') else torch.uint8,
+                module=args[0].module if hasattr(args[0], 'module') else None,
+                bnb_quantized=args[0].bnb_quantized if hasattr(args[0], 'bnb_quantized') else False,
+            )
+        else:
+            # For other operations, return the result as is
+            return result
 
     def _quantize(self, device):
         w = self.data.contiguous().to(device)
@@ -322,20 +371,6 @@ def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: b
     def xpu(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
         return self.to(device="xpu" if device is None else device, non_blocking=non_blocking)
 
-    @overload
-    def to(
-        self: T,
-        device: Optional[Union[int, device]] = ...,
-        dtype: Optional[Union[dtype, str]] = ...,
-        non_blocking: bool = ...,
-    ) -> T: ...
-
-    @overload
-    def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: ...
-
-    @overload
-    def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...
-
     def to(self, *args, **kwargs):
         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
 
@@ -356,8 +391,6 @@ def to(self, *args, **kwargs):
         )
         return new_param
 
-
-
 def fix_4bit_weight_quant_state_from_module(module: Union["Embedding4bit", "Linear4bit"]):
     if getattr(module.weight, "quant_state", None) is not None:
         return
@@ -1078,4 +1111,4 @@ def forward(self, x):
         if self.weight.CB is not None:
             self.init_8bit_state()
 
-        out = bnb.matmul_mixed(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias
+        out = bnb.matmul_mixed(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias
\ No newline at end of file
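A minimal usage sketch (not part of the patch) of what the subclass-preserving __torch_function__ above is intended to enable; the tensor shape and the "nf4" quant_type below are illustrative only:

    import torch
    import bitsandbytes as bnb

    # Build a 4-bit parameter (constructor arguments as used in the hunk above).
    weight = torch.randn(128, 64, dtype=torch.float16)
    p = bnb.nn.Params4bit(data=weight, requires_grad=False, quant_type="nf4")

    # With the patched __torch_function__, chunk/split are expected to return
    # Params4bit objects that carry over the quantization metadata instead of
    # plain Tensors.
    chunks = torch.chunk(p, 2, dim=0)
    print(type(chunks[0]).__name__)  # expected: Params4bit
    print(chunks[0].quant_type)      # expected: nf4 (copied from the source parameter)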