
Commit fa7c36a
Fix bug with some LoRA variants when applied to a BnB NF4 quantized model. Note the previous commit, which added a unit test to trigger this bug.
RyanJDick authored and psychedelicious committed Jan 21, 2025
1 parent 7724bd5 commit fa7c36a
Showing 2 changed files with 22 additions and 2 deletions.
5 changes: 3 additions & 2 deletions invokeai/backend/patches/layers/lora_layer_base.py
@@ -4,6 +4,7 @@

 import invokeai.backend.util.logging as logger
 from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
+from invokeai.backend.patches.layers.param_shape_utils import get_param_shape
 from invokeai.backend.util.calc_tensor_size import calc_tensors_size


@@ -67,8 +68,8 @@ def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float
         # Reshape all params to match the original module's shape.
         for param_name, param_weight in params.items():
             orig_param = orig_parameters[param_name]
-            if param_weight.shape != orig_param.shape:
-                params[param_name] = param_weight.reshape(orig_param.shape)
+            if param_weight.shape != get_param_shape(orig_param):
+                params[param_name] = param_weight.reshape(get_param_shape(orig_param))

         return params

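For context, a minimal sketch of the fixed reshape path in isolation (hypothetical shapes; for a plain tensor, get_param_shape simply returns `.shape`, so the fix is a no-op for unquantized models):

import torch

from invokeai.backend.patches.layers.param_shape_utils import get_param_shape

# Hypothetical patched weight produced by a LoRA variant at a flattened shape.
orig_param = torch.zeros(64, 32)
param_weight = torch.randn(64 * 32)

# Same logic as the fixed loop above: compare against the logical shape and
# reshape to it. For this plain tensor, get_param_shape(orig_param) is just
# orig_param.shape, i.e. torch.Size([64, 32]).
if param_weight.shape != get_param_shape(orig_param):
    param_weight = param_weight.reshape(get_param_shape(orig_param))

assert param_weight.shape == torch.Size([64, 32])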
19 changes: 19 additions & 0 deletions invokeai/backend/patches/layers/param_shape_utils.py
@@ -0,0 +1,19 @@
import torch

try:
    from bitsandbytes.nn.modules import Params4bit

    bnb_available: bool = True
except ImportError:
    bnb_available: bool = False


def get_param_shape(param: torch.Tensor) -> torch.Size:
    """A helper function to get the shape of a parameter that handles `bitsandbytes.nn.Params4bit` correctly."""
    # Accessing the `.shape` attribute of `bitsandbytes.nn.Params4bit` will return an incorrect result. Instead, we
    # must access the `.quant_state.shape` attribute.
    if bnb_available and type(param) is Params4bit:  # type: ignore
        quant_state = param.quant_state
        if quant_state is not None:
            return quant_state.shape
    return param.shape
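To see why the raw `.shape` is wrong for NF4, a hedged sketch of the quantized case (assumes bitsandbytes and a CUDA device are available, since Params4bit packs its data on transfer to the GPU; sizes are illustrative):

import torch
from bitsandbytes.nn.modules import Params4bit

from invokeai.backend.patches.layers.param_shape_utils import get_param_shape

# Quantize an illustrative 64x32 weight to NF4. On .cuda(), Params4bit packs
# the values two-per-byte into a flattened uint8 buffer, so the parameter's
# raw .shape reflects the packed storage, not the logical weight.
weight = torch.randn(64, 32, dtype=torch.bfloat16)
param = Params4bit(weight, requires_grad=False, quant_type="nf4").cuda()

print(param.shape)              # packed storage shape, e.g. torch.Size([1024, 1])
print(param.quant_state.shape)  # logical shape: torch.Size([64, 32])
print(get_param_shape(param))   # torch.Size([64, 32])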
