Skip PEFT LoRA Scaling if the scale is 1.0 (#7576)
* Skip scaling if scale is identity

* move check for weight one to scale and unscale lora

* fix code style/quality

* Empty-Commit

---------

Co-authored-by: Steven Munn <stevenjmunn@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Steven Munn <5297082+stevenjlm@users.noreply.github.com>
4 people authored Apr 11, 2024
1 parent 33c5d12 · commit 42f25d6
Showing 1 changed file with 11 additions and 3 deletions.
src/diffusers/utils/peft_utils.py (14 changes: 11 additions & 3 deletions)
@@ -64,9 +64,11 @@ def recurse_remove_peft_layers(model):
         module_replaced = False

         if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear):
-            new_module = torch.nn.Linear(module.in_features, module.out_features, bias=module.bias is not None).to(
-                module.weight.device
-            )
+            new_module = torch.nn.Linear(
+                module.in_features,
+                module.out_features,
+                bias=module.bias is not None,
+            ).to(module.weight.device)
             new_module.weight = module.weight
             if module.bias is not None:
                 new_module.bias = module.bias
@@ -110,6 +112,9 @@ def scale_lora_layers(model, weight):
     """
     from peft.tuners.tuners_utils import BaseTunerLayer

+    if weight == 1.0:
+        return
+
     for module in model.modules():
         if isinstance(module, BaseTunerLayer):
             module.scale_layer(weight)
@@ -129,6 +134,9 @@ def unscale_lora_layers(model, weight: Optional[float] = None):
     """
     from peft.tuners.tuners_utils import BaseTunerLayer

+    if weight == 1.0:
+        return
+
     for module in model.modules():
         if isinstance(module, BaseTunerLayer):
             if weight is not None and weight != 0:
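
For context, here is a minimal usage sketch of the two guarded helpers (illustrative only, not part of this commit; `model`, `input_ids`, and `lora_scale` are placeholder names). `scale_lora_layers` multiplies each `BaseTunerLayer`'s scaling factor by `weight` before a forward pass and `unscale_lora_layers` divides it back afterwards, so when the requested scale is exactly 1.0 both traversals would leave the layers unchanged, which is what the new early returns skip.

# Illustrative sketch, not from this commit: wrapping a forward pass with the
# scale/unscale helpers that the early-return guards were added to.
# Assumes a PEFT-backed model with LoRA layers already loaded;
# `model`, `input_ids`, and `lora_scale` are placeholders.
from diffusers.utils.peft_utils import scale_lora_layers, unscale_lora_layers

lora_scale = 1.0  # with this change, a scale of 1.0 makes both calls return immediately

scale_lora_layers(model, lora_scale)        # multiply each LoRA layer's scaling by lora_scale
try:
    output = model(input_ids)               # forward pass with the temporarily scaled LoRA weights
finally:
    unscale_lora_layers(model, lora_scale)  # divide the scaling back to its original value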
