[WIP][LoRA] Implement hot-swapping of LoRA #9453

Draft · wants to merge 9 commits into main
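For orientation, here is a minimal usage sketch of the flow this PR is working toward, assuming a Stable Diffusion pipeline and two LoRA checkpoints with compatible ranks and target modules. The repository ids are placeholders, and reusing the same `adapter_name` for the swap is an assumption about the intended semantics of the new `hotswap` flag:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# First load: this creates the LoRA modules inside the UNet (and text encoder).
pipe.load_lora_weights("user/lora-style-a", adapter_name="default")
image_a = pipe("a photo of a cat").images[0]

# Proposed hot-swap: replace the weights of the already-loaded adapter in place
# instead of injecting a second adapter (placeholder repo id, assumed semantics).
pipe.load_lora_weights("user/lora-style-b", adapter_name="default", hotswap=True)
image_b = pipe("a photo of a cat").images[0]
```

The appeal of such a flag is that the already-initialized LoRA modules keep their structure, so swapping checkpoints does not have to re-run adapter injection.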
86 changes: 69 additions & 17 deletions src/diffusers/loaders/lora_pipeline.py
@@ -77,7 +77,11 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
text_encoder_name = TEXT_ENCODER_NAME

def load_lora_weights(
Reviewer comment (Member):
Just a note to the reviewers that we're currently only brainstorming the changes through the unet. Those changes will be propagated to lora_pipeline.py too, once we agree on the initial design.

self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name=None,
hotswap: bool = False,
**kwargs,
):
"""
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
@@ -103,6 +107,7 @@ def load_lora_weights(
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap TODO
kwargs (`dict`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
"""
@@ -133,6 +138,7 @@ def load_lora_weights(
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
self.load_lora_into_text_encoder(
state_dict,
@@ -263,7 +269,14 @@ def lora_state_dict(

@classmethod
def load_lora_into_unet(
cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
cls,
state_dict,
network_alphas,
unet,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
):
"""
This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -282,7 +295,10 @@ def load_lora_into_unet(
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
Speed up model loading only loading the pretrained LoRA weights and not initializing the random weights.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap TODO
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -306,6 +322,7 @@ def load_lora_into_unet(
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)

@classmethod
@@ -341,7 +358,9 @@ def load_lora_into_text_encoder(
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -601,7 +620,9 @@ def load_lora_weights(
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
kwargs (`dict`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
"""
@@ -786,7 +807,14 @@ def lora_state_dict(
@classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet
def load_lora_into_unet(
cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
cls,
state_dict,
network_alphas,
unet,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
):
"""
This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -805,7 +833,10 @@ def load_lora_into_unet(
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
Speed up model loading only loading the pretrained LoRA weights and not initializing the random weights.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap TODO
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -829,6 +860,7 @@ def load_lora_into_unet(
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)

@classmethod
@@ -865,7 +897,9 @@ def load_lora_into_text_encoder(
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -1226,7 +1260,9 @@ def load_lora_weights(
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
kwargs (`dict`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
"""
@@ -1301,7 +1337,9 @@ def load_lora_into_transformer(
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -1424,7 +1462,9 @@ def load_lora_into_text_encoder(
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -1819,7 +1859,9 @@ def load_lora_weights(
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -1886,7 +1928,9 @@ def load_lora_into_transformer(
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
"""
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
raise ValueError(
@@ -2014,7 +2058,9 @@ def load_lora_into_text_encoder(
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -2377,7 +2423,9 @@ def load_lora_into_text_encoder(
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -2658,7 +2706,9 @@ def load_lora_weights(
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
kwargs (`dict`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
"""
@@ -2708,7 +2758,9 @@ def load_lora_into_transformer(
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
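For readers new to the term: "hot-swapping" a LoRA means copying a new adapter's A/B matrices into the tensors of an adapter that is already attached to the model, rather than removing one adapter and injecting another. The snippet below is a stand-alone conceptual illustration in plain PyTorch, not this PR's implementation; the class, function, and shapes are made up for the example, and it assumes the old and new LoRA share the same rank and target layers.

```python
import torch
import torch.nn as nn


class LoraLinear(nn.Module):
    """Toy linear layer with a single LoRA adapter attached."""

    def __init__(self, in_features: int, out_features: int, rank: int = 4):
        super().__init__()
        self.base = nn.Linear(in_features, out_features, bias=False)
        self.lora_a = nn.Parameter(torch.zeros(rank, in_features))
        self.lora_b = nn.Parameter(torch.zeros(out_features, rank))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Base output plus the low-rank LoRA correction.
        return self.base(x) + x @ self.lora_a.T @ self.lora_b.T


@torch.no_grad()
def hotswap_lora(module: LoraLinear, new_a: torch.Tensor, new_b: torch.Tensor) -> None:
    # Copy the new adapter weights into the existing parameters in place.
    # No modules are created or replaced, so anything holding references to
    # these tensors (e.g. a compiled graph) keeps working without re-tracing.
    module.lora_a.copy_(new_a)
    module.lora_b.copy_(new_b)


layer = LoraLinear(16, 16)
hotswap_lora(layer, torch.randn(4, 16), torch.randn(16, 4))
print(layer(torch.randn(2, 16)).shape)  # torch.Size([2, 16])
```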