From 42f25d601a910dceadaee6c44345896b4cfa9928 Mon Sep 17 00:00:00 2001
From: Steven Munn
Date: Wed, 10 Apr 2024 22:32:31 -0700
Subject: [PATCH] Skip PEFT LoRA Scaling if the scale is 1.0 (#7576)

* Skip scaling if scale is identity

* move check for weight one to scale and unscale lora

* fix code style/quality

* Empty-Commit

---------

Co-authored-by: Steven Munn
Co-authored-by: Sayak Paul
Co-authored-by: Steven Munn <5297082+stevenjlm@users.noreply.github.com>
---
 src/diffusers/utils/peft_utils.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py
index feececc56966..8ea12e2e3b3f 100644
--- a/src/diffusers/utils/peft_utils.py
+++ b/src/diffusers/utils/peft_utils.py
@@ -64,9 +64,11 @@ def recurse_remove_peft_layers(model):
             module_replaced = False
 
             if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear):
-                new_module = torch.nn.Linear(module.in_features, module.out_features, bias=module.bias is not None).to(
-                    module.weight.device
-                )
+                new_module = torch.nn.Linear(
+                    module.in_features,
+                    module.out_features,
+                    bias=module.bias is not None,
+                ).to(module.weight.device)
                 new_module.weight = module.weight
                 if module.bias is not None:
                     new_module.bias = module.bias
@@ -110,6 +112,9 @@ def scale_lora_layers(model, weight):
     """
     from peft.tuners.tuners_utils import BaseTunerLayer
 
+    if weight == 1.0:
+        return
+
     for module in model.modules():
         if isinstance(module, BaseTunerLayer):
             module.scale_layer(weight)
@@ -129,6 +134,9 @@ def unscale_lora_layers(model, weight: Optional[float] = None):
     """
     from peft.tuners.tuners_utils import BaseTunerLayer
 
+    if weight == 1.0:
+        return
+
     for module in model.modules():
         if isinstance(module, BaseTunerLayer):
             if weight is not None and weight != 0:
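
Note (not part of the patch): a minimal caller-side sketch of what this change buys. The helper name run_with_lora_scale and the wrapping pattern below are illustrative assumptions, not code from this PR; only scale_lora_layers and unscale_lora_layers come from the patched module.

# Illustrative sketch: with this patch, an identity scale (1.0) makes both
# helpers return immediately instead of walking every module in the model.
import torch

from diffusers.utils.peft_utils import scale_lora_layers, unscale_lora_layers


def run_with_lora_scale(model: torch.nn.Module, x: torch.Tensor, lora_scale: float = 1.0):
    # lora_scale == 1.0 is now a no-op; any other value multiplies each PEFT
    # BaseTunerLayer's LoRA contribution by lora_scale for this forward pass.
    scale_lora_layers(model, lora_scale)
    try:
        out = model(x)
    finally:
        # Undo the temporary scaling so later calls see the original LoRA scale.
        unscale_lora_layers(model, lora_scale)
    return out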