[WIP][LoRA] Implement hot-swapping of LoRA #9453
base: main
Conversation
This PR adds the possibility to hot-swap LoRA adapters. It is WIP.

Description

As of now, users can already load multiple LoRA adapters. They can offload existing adapters or they can unload them (i.e. delete them). However, they cannot "hotswap" adapters yet, i.e. substitute the weights of one LoRA adapter with the weights of another, without the need to create a separate LoRA adapter.

Generally, hot-swapping may not appear super useful, but when the model is compiled, it is necessary to prevent recompilation. See huggingface#9279 for more context.

Caveats

To hot-swap one LoRA adapter for another, the two adapters should target exactly the same layers, and the "hyper-parameters" of the two adapters should be identical. For instance, the LoRA alpha has to be the same: given that we keep the alpha from the first adapter, the LoRA scaling would otherwise be incorrect for the second adapter. Theoretically, we could override the scaling dict with the alpha values derived from the second adapter's config, but changing the dict will trigger a guard for recompilation, defeating the main purpose of the feature.

I also found that compilation flags can have an impact on whether this works or not. E.g. when passing "reduce-overhead", there will be errors of the type:

> input name: arg861_1. data pointer changed from 139647332027392 to 139647331054592

I don't know enough about compilation to determine whether this is problematic or not.

Current state

This is obviously WIP right now to collect feedback and discuss which direction to take this. If this PR turns out to be useful, the hot-swapping functions will be added to PEFT itself and can be imported here (or there is a separate copy in diffusers to avoid the need for a minimum PEFT version to use this feature).

Moreover, more tests need to be added to better cover this feature, although we don't necessarily need tests for the hot-swapping functionality itself, since those tests will be added to PEFT.

Furthermore, as of now, this is only implemented for the unet. Other pipeline components have yet to implement this feature. Finally, it should be properly documented.

I would like to collect feedback on the current state of the PR before putting more time into finalizing it.
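To make the substitution idea concrete, here is a minimal illustrative sketch (the helper name is made up, not the PR's actual API): hot-swapping writes the new adapter's weights into the existing parameters instead of registering a second adapter, so the module structure seen by `torch.compile` stays the same.

```python
import torch

def hotswap_state_dict(model: torch.nn.Module, new_state_dict: dict) -> None:
    # Illustrative toy only: overwrite existing parameters in place instead of
    # creating new modules, so compiled graphs keep pointing at the same storage.
    params = dict(model.named_parameters())
    for key, new_val in new_state_dict.items():
        old_val = params[key]
        # copy_ keeps the existing storage (stable data pointer), which matters
        # for compiled models
        old_val.data.copy_(new_val.to(device=old_val.device))

model = torch.nn.Linear(4, 4, bias=False)
ptr_before = model.weight.data_ptr()
hotswap_state_dict(model, {"weight": torch.ones(4, 4)})
assert torch.equal(model.weight, torch.ones(4, 4))
assert model.weight.data_ptr() == ptr_before  # storage unchanged
```

The in-place `copy_` is the key point: deleting and re-adding an adapter would allocate fresh tensors and invalidate compiled graphs.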
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks a lot for working on this. I left some comments.

cc @apolinario
Do most LoRAs have the same scaling?
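For reference (my own summary, not from the thread): in PEFT, the scaling applied to a LoRA delta is `lora_alpha / r`, so two adapters only share the same scaling when their alpha-to-rank ratio matches.

```python
# LoRA applies W + (alpha / r) * B @ A, so the effective scaling is alpha / r
# (alpha / sqrt(r) when rslora is enabled). Two adapters hot-swap cleanly only
# if this ratio matches.
def lora_scaling(lora_alpha: float, r: int, use_rslora: bool = False) -> float:
    if use_rslora:
        return lora_alpha / (r ** 0.5)
    return lora_alpha / r

# rank 16 with alpha 8 -> scaling 0.5; rank 64 with alpha 64 -> scaling 1.0
assert lora_scaling(8, 16) == 0.5
assert lora_scaling(64, 64) == 1.0
```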
So I played around a little bit, and I have two main questions:

Do we support hotswap with different LoRA ranks? The rank is not checked in the config compatibility check, I think.

I think we should also look into supporting hot-swap with different scaling. I checked some popular LoRAs on our Hub, and I think most of them have different ranks/alphas, so this feature will be a lot more impactful if we are able to support different rank & scaling. Based on this thread #9279, I understand that a change in the "scaling" dict would trigger a recompilation. But maybe there are ways to avoid it? For example, this triggers a recompile:

```python
import torch

scaling = {}

def fn(x, key):
    return x * scaling[key]

opt_fn = torch.compile(fn, backend="eager")
x = torch.rand(4)
scaling["first"] = 1.0
opt_fn(x, "first")
print(f" finish first run, updating scaling")
scaling["first"] = 2.0
opt_fn(x, "first")
```

this won't:

```python
import torch

scaling = {}

def fn(x, key):
    return x * scaling[key]

opt_fn = torch.compile(fn, backend="eager")
x = torch.rand(4)
scaling["first"] = torch.tensor(1.0)
opt_fn(x, "first")
print(f" finish first run, updating scaling")
scaling["first"] = torch.tensor(2.0)
opt_fn(x, "first")
```

I'm very excited about having this in diffusers! I think it would be a super nice feature, especially for production use cases :)
I agree with your point on supporting LoRAs with different scaling in this context. With backend="eager", though, we may not get the full benefits of torch.compile. A good way to verify it would be to measure the performance of a pipeline compiled with the eager backend. Cc: @anijain2305.
I will let @BenjaminBossan comment further, but this might require a lot of changes within the tuner modules inside PEFT.
Thanks for all the feedback. I haven't forgotten about this PR, I was just occupied with other things. I'll come back to this as soon as I have a bit of time on my hands. The idea of using a tensor instead of a float for scaling is intriguing, thanks for testing it. It might just work OOTB, as torch broadcasts 0-dim tensors automatically. Another possibility would be to multiply the scaling directly into one of the weights, so that the original alpha can be retained, but that is probably very error-prone. Regarding different ranks, I have yet to test that.
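A quick check of the broadcasting point above: a 0-dim tensor behaves like a Python float in elementwise ops, so swapping the float `scaling` value for a tensor should leave the LoRA math unchanged while allowing in-place updates.

```python
import torch

x = torch.ones(2, 3)
scale_tensor = torch.tensor(0.5)  # 0-dim tensor

# broadcasts exactly like a Python float
assert torch.equal(x * 0.5, x * scale_tensor)
assert (x * scale_tensor).shape == (2, 3)

# in-place update: same tensor object, new value
scale_tensor.copy_(torch.tensor(2.0))
assert torch.equal(x * scale_tensor, x * 2.0)
```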
Yes.
If different ranks become a problem, then https://huggingface.co/sayakpaul/lower-rank-flux-lora could provide a meaningful direction.
Indeed, although avoiding recompilation altogether with different ranks would be even better for real-time swap applications.
Yep, that can be a nice feature indeed!
Indeed. For different ranks, things that come to mind:
A reverse direction of what I showed in #9453 is also possible (increase the rank of a LoRA).
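One way the "reverse direction" (increasing a LoRA's rank) can work mathematically is zero-padding; this is a hedged sketch with made-up names, not the linked script.

```python
import torch

def pad_lora_rank(A: torch.Tensor, B: torch.Tensor, new_rank: int):
    # Zero-padding a rank-r pair (B @ A) to rank R leaves the delta-weight
    # unchanged, since the extra rows of A / columns of B contribute nothing.
    # Helper name is illustrative, not from diffusers or PEFT.
    r, in_features = A.shape
    out_features = B.shape[0]
    A_pad = torch.zeros(new_rank, in_features)
    B_pad = torch.zeros(out_features, new_rank)
    A_pad[:r] = A
    B_pad[:, :r] = B
    return A_pad, B_pad

A = torch.randn(4, 16)   # rank-4 adapter
B = torch.randn(32, 4)
A_pad, B_pad = pad_lora_rank(A, B, new_rank=8)
assert torch.allclose(B @ A, B_pad @ A_pad)
```

Note that if the scaling is derived as alpha / r, padding the rank without adjusting alpha would change the effective scaling, so alpha would need to be rescaled accordingly.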
hi @BenjaminBossan — I made some changes and they work for the 4 LoRAs I tested (all with different ranks and scaling). I'm not as familiar with peft and just made enough changes for the purpose of the experiment & to provide a reference point, so the code is very hacky there. Sorry for that!

To test:

```python
# testing hotswap PR
# TORCH_LOGS="guards,recompiles" TORCH_COMPILE_DEBUG=1 TORCH_LOGS_OUT=traces.txt python yiyi_test_3.py
from diffusers import DiffusionPipeline
import torch
import time

torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

branch = "test-hotswap"

loras = [
    "Norod78/sd15-megaphone-lora",  # rank 16, scaling 0.5
    "artificialguybr/coloringbook-redmond-1-5v-coloring-book-lora-for-liberteredmond-sd-1-5",  # rank 64, scaling 1.0
    "Norod78/SD15-Rubber-Duck-LoRA",  # rank 16, scaling 0.5
    "wooyvern/sd-1.5-dark-fantasy-1.1",  # rank 128, scaling 1.0
]

prompts = [
    "Marge Simpson holding a megaphone in her hand with her town in the background",
    "A lion, minimalist, Coloring Book, ColoringBookAF",
    "The girl with a pearl earring Rubber duck",
    "<lora:fantasyV1.1:1>, a painting of a skeleton with a long cloak and a group of skeletons in a forest with a crescent moon in the background, David Wojnarowicz, dark art, a screenprint, psychedelic art",
]

def print_rank_scaling(pipe):
    print(f" rank: {pipe.unet.peft_config['default_0'].r}")
    print(f" scaling: {pipe.unet.down_blocks[0].attentions[0].proj_in.scaling}")

# pipe_id = "stabilityai/stable-diffusion-xl-base-1.0"
pipe_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
pipe = DiffusionPipeline.from_pretrained(pipe_id, torch_dtype=torch.float16).to("cuda")

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

pipe.unet = pipe.unet.to(memory_format=torch.channels_last)
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

for i, (lora_repo, prompt) in enumerate(zip(loras, prompts)):
    hotswap = False if i == 0 else True
    print(f"\nProcessing LoRA {i}: {lora_repo}")
    print(f" prompt: {prompt}")
    print(f" hotswap: {hotswap}")

    # Start timing for the entire iteration
    start_time = time.time()

    # Load LoRA weights
    pipe.load_lora_weights(lora_repo, hotswap=hotswap, adapter_name="default_0")
    print_rank_scaling(pipe)

    # Time image generation
    generator = torch.Generator(device="cuda").manual_seed(42)
    generate_start_time = time.time()
    image = pipe(prompt, num_inference_steps=50, generator=generator).images[0]
    generate_time = time.time() - generate_start_time

    # Save the image
    image.save(f"yiyi_test_3_out_{branch}_lora{i}.png")

    # Unload LoRA weights
    pipe.unload_lora_weights()

    # Calculate and print total time for this iteration
    total_time = time.time() - start_time
    print(f"Image generation time: {generate_time:.2f} seconds")
    print(f"Total time for LoRA {i}: {total_time:.2f} seconds")

mem_bytes = torch.cuda.max_memory_allocated()
print(f"total Memory: {mem_bytes/(1024*1024):.3f} MB")
```

I confirm the outputs are the same as on main.
Very cool! Could you also try logging the traces just to confirm it does not trigger any recompilation?

```shell
TORCH_LOGS="guards,recompiles" TORCH_LOGS_OUT=traces.txt python my_code.py
```
I did, and it doesn't.
Also, from the user-experience perspective, I think it might be more convenient to have a "hotswap" mode: once it's on, everything will be hot-swapped by default. It is not something you toggle on and off, no? Maybe a question for @apolinario.
I think that is the case, yes! I also agree that the ability to hot-swap LoRAs (with different ranks) is important. But just in case it becomes a memory problem, users can explore the LoRA resizing path to bring everything to a small unified rank (if it doesn't lead to too much quality degradation).
See also huggingface/diffusers#9453

The idea of hotswapping an adapter is the following: we can already load multiple adapters, e.g. two LoRAs, at the same time. But sometimes, we want to load one LoRA and then replace its weights in-place with the LoRA weights of another adapter. This is now possible with the hotswap_adapter function.

In general, this should be faster than deleting one adapter and loading the adapter in its place, which would be the current way to achieve the same final outcome. Another advantage of hotswapping is that it prevents re-compilation in case the PEFT model is already compiled. This can save quite a lot of time.

There are some caveats for hotswapping:

- It only works for the same PEFT method, so no swapping LoRA and LoHa.
- Right now, only LoRA is properly supported.
- The adapters must be compatible (e.g. same LoRA alpha, same target modules).
I addressed your points, could you please check again @sayakpaul?
The script for the test now runs successfully without recompilation. When hotswapping is disabled, I can confirm that we do get recompilation, but that is not explicitly tested. Do you think that's necessary?
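Regarding testing recompilation explicitly: one lightweight way (my suggestion, not what the PR's test does) is a counting backend — `torch.compile` invokes the backend once per graph compilation, so the counter stays flat when no recompilation occurs. This also reproduces the tensor-scaling trick from earlier in the thread.

```python
import torch

# The backend function runs once per graph (re)compilation, so the counter
# tells us whether dynamo recompiled.
compile_count = 0

def counting_backend(gm, example_inputs):
    global compile_count
    compile_count += 1
    return gm.forward  # run the captured graph eagerly

scaling = {"first": torch.tensor(1.0)}

def fn(x):
    return x * scaling["first"]

opt_fn = torch.compile(fn, backend=counting_backend)
opt_fn(torch.ones(4))
assert compile_count == 1

# in-place tensor update: no guard on the value, no recompilation
scaling["first"].copy_(torch.tensor(2.0))
out = opt_fn(torch.ones(4))
assert compile_count == 1
assert torch.equal(out, torch.ones(4) * 2.0)
```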
@sayakpaul I don't know if the new test is being run, as it's a "slow" test. Can we trigger this manually in case it's not run? The reason is that I found an error in the PEFT CI. Edit: I found a fix for the error on the PEFT side. I pushed the same change to this PR, even though I don't know if the error would have happened here or not. But I think it's better to have it either way.
Thanks, Benjamin! I left some comments on the tests. LMK if they make sense.
```python
            raise ValueError(
                f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name."
            )
        elif adapter_name not in getattr(self, "peft_config", {}) and hotswap:
```
Since we're using `hotswap` in the condition here, should the error message also include any info about it?
```python
@@ -77,7 +77,11 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
    text_encoder_name = TEXT_ENCODER_NAME

    def load_lora_weights(
```
Just a note to the reviewers that we're currently only brainstorming the changes through `unet`. Those changes will be propagated to `lora_pipeline.py`, too, once we agree on the initial design.
```python
        if val0 != val1:
            raise ValueError(f"Configs are incompatible: for {key}, {val0} != {val1}")


def _hotswap_adapter_from_state_dict(model, state_dict, adapter_name):
```
I guess this will go away once huggingface/peft#2120 is merged?
```python
torch_device = "cuda" if torch.cuda.is_available() else "cpu"


def get_small_unet():
```
This could go to `tests/lora`.
```python
    lora_config = get_unet_lora_config()
    unet.add_adapter(lora_config)
    torch.manual_seed(42)
    out_base = unet(**dummy_input)["sample"]
```
`out_base` gives me the impression that the output is from the base original model, whereas the model we're testing here has already been injected with an adapter config. So, `out_with_lora_config`?
```python
    unet = torch.compile(unet, mode="reduce-overhead")

    torch.manual_seed(42)
    out0 = unet(**dummy_input)["sample"]
```
Maybe `out_with_lora_loaded_compiled`?
```python
    assert torch.allclose(out_base, out0, atol=atol, rtol=rtol)

    if do_hotswap:
        unet.load_attn_procs(file_name, adapter_name="default_0", hotswap=True)
```
At this point, the LoRA within `unet` would already be present, no? We're not unloading the LoRA above; not sure if it'd affect something.
Ah probably that's what we want with hotswapping, nice!
I won't mind adding a test for that as well.
For now, we can run it manually. Got some errors while running https://github.com/huggingface/diffusers/actions/runs/11377670971/job/31652211734, checking it now. Edit: #9696. Able to run https://github.com/huggingface/diffusers/actions/runs/11384275265/job/31671631119.
@BenjaminBossan I modified the script. Here's my updated file:

```python
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This is a standalone script that checks that we can hotswap a LoRA adapter on a compiled model.

By itself, this script is not super interesting, but when we collect the compile logs, we can check that hotswapping
does not trigger recompilation. This is done in the TestLoraHotSwapping class in test_pipelines.py.

Running this script with `check_hotswap(False)` will load the LoRA adapter without hotswapping, which will result in
recompilation.
"""
import os
import tempfile

import numpy as np
import torch
from peft import LoraConfig, get_peft_model_state_dict
from peft.tuners.tuners_utils import BaseTunerLayer

from diffusers import StableDiffusionPipeline
from diffusers.utils.testing_utils import torch_device


def get_unet_lora_config(rank=4):
    # from test_models_unet_2d_condition.py
    unet_lora_config = LoraConfig(
        r=rank,
        lora_alpha=rank,
        target_modules=["to_q", "to_k", "to_v", "to_out.0"],
        init_lora_weights=False,
        use_dora=False,
    )
    return unet_lora_config


def get_dummy_input(with_generator=False):
    pipeline_inputs = {
        "prompt": "A painting of a squirrel eating a burger",
        "num_inference_steps": 5,
        "guidance_scale": 6.0,
        "output_type": "np",
        "return_dict": False,
    }
    if with_generator:
        generator = torch.manual_seed(0)
        pipeline_inputs.update({"generator": generator})
    return pipeline_inputs


def get_lora_state_dicts(modules_to_save):
    state_dicts = {}
    for module_name, module in modules_to_save.items():
        if module is not None:
            state_dicts[f"{module_name}_lora_layers"] = get_peft_model_state_dict(module)
    return state_dicts


def set_lora_device(model, adapter_names, device):
    # copied from LoraBaseMixin.set_lora_device
    for module in model.modules():
        if isinstance(module, BaseTunerLayer):
            for adapter_name in adapter_names:
                module.lora_A[adapter_name].to(device)
                module.lora_B[adapter_name].to(device)
                # this is a param, not a module, so device placement is not in-place -> re-assign
                if hasattr(module, "lora_magnitude_vector") and module.lora_magnitude_vector is not None:
                    if adapter_name in module.lora_magnitude_vector:
                        module.lora_magnitude_vector[adapter_name] = module.lora_magnitude_vector[adapter_name].to(
                            device
                        )


def check_hotswap(do_hotswap):
    dummy_input = get_dummy_input()
    pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
    lora_config = get_unet_lora_config()
    pipeline.unet.add_adapter(lora_config)
    out_base = pipeline(**dummy_input, generator=torch.manual_seed(0))[0]
    # sanity check
    assert not (out_base == 0).all()

    with tempfile.TemporaryDirectory() as tmp_dirname:
        lora_state_dicts = get_lora_state_dicts({"unet": pipeline.unet})
        StableDiffusionPipeline.save_lora_weights(
            save_directory=tmp_dirname, safe_serialization=True, **lora_state_dicts
        )
        del pipeline

        pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
        file_name = os.path.join(tmp_dirname, "pytorch_lora_weights.safetensors")
        pipeline.load_lora_weights(file_name)
        pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead")

        out0 = pipeline(**dummy_input, generator=torch.manual_seed(0))[0]

        # sanity check: still same result
        atol, rtol = 1e-3, 1e-3
        assert np.allclose(out_base, out0, atol=atol, rtol=rtol)

        if do_hotswap:
            pipeline.load_lora_weights(file_name, adapter_name="default_0", hotswap=True)
        else:
            # offloading the old and loading the new adapter will result in recompilation
            set_lora_device(pipeline.unet, adapter_names=["default_0"], device="cpu")
            pipeline.load_lora_weights(file_name, adapter_name="other_name", hotswap=False)

        out1 = pipeline(**dummy_input, generator=torch.manual_seed(0))[0]

        # sanity check: since it's the same LoRA, the results should be identical
        assert np.allclose(out0, out1, atol=atol, rtol=rtol)


if __name__ == "__main__":
    # check_hotswap(False) will trigger recompilation
    check_hotswap(True)
```

When I am running it, I get:
```
...
input name: arg349_1. data pointer changed from 140373425360384 to 140373425359360. input stack trace:
  File "/home/sayak/collabs/diffusers/src/diffusers/models/unets/unet_2d_condition.py", line 1281, in forward
    sample = upsample_block(
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sayak/collabs/diffusers/src/diffusers/models/unets/unet_2d_blocks.py", line 2551, in forward
    hidden_states = attn(
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sayak/collabs/diffusers/src/diffusers/models/transformers/transformer_2d.py", line 442, in forward
    hidden_states = block(
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sayak/collabs/diffusers/src/diffusers/models/attention.py", line 504, in forward
    attn_output = self.attn2(
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sayak/collabs/diffusers/src/diffusers/models/attention_processor.py", line 491, in forward
    return self.processor(
  File "/home/sayak/collabs/diffusers/src/diffusers/models/attention_processor.py", line 2376, in __call__
    hidden_states = attn.to_out[0](hidden_states)
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/peft/tuners/lora/layer.py", line 584, in forward
    result = result + lora_B(lora_A(dropout(x))) * scaling
```

Full trace is here: https://pastebin.com/QY5HgrUU. Is this expected? I will help cover other scenarios in the test after we're through with this one.
@sayakpaul I added the reverse test for compilation without hotswapping.
This is interesting; I could reproduce this error. However, I'm not sure where it comes from. The main differences in your script compared to the test are that:
right? To isolate the error, I changed the code a bit to avoid generation and only call forward on the unet:

```python
def get_dummy_input2():
    # from UNet2DConditionModelTests
    from diffusers.utils.testing_utils import floats_tensor

    batch_size = 4
    num_channels = 4
    sizes = (16, 16)

    noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
    time_step = torch.tensor([10]).to(torch_device)
    encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device)
    return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}


def check_hotswap2(do_hotswap):
    dummy_input = get_dummy_input2()
    pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
    lora_config = get_unet_lora_config()
    pipeline.unet.add_adapter(lora_config)

    with tempfile.TemporaryDirectory() as tmp_dirname:
        lora_state_dicts = get_lora_state_dicts({"unet": pipeline.unet})
        StableDiffusionPipeline.save_lora_weights(
            save_directory=tmp_dirname, safe_serialization=True, **lora_state_dicts
        )
        del pipeline

        pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
        file_name = os.path.join(tmp_dirname, "pytorch_lora_weights.safetensors")
        pipeline.load_lora_weights(file_name)
        pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead")

        pipeline.unet(**dummy_input)

        if do_hotswap:
            pipeline.load_lora_weights(file_name, adapter_name="default_0", hotswap=True)
        else:
            # offloading the old and loading the new adapter will result in recompilation
            set_lora_device(pipeline.unet, adapter_names=["default_0"], device="cpu")
            pipeline.load_lora_weights(file_name, adapter_name="other_name", hotswap=False)

        pipeline.unet(**dummy_input)
```

and this passed. So my guess is that it has to do with generation or some other factor that I've missed. But when it comes to the error itself, I have no idea; probably this requires a closer look.
I will give it a look, thanks!
@BenjaminBossan while I give it a look: I think this PR could directly use the utilities from huggingface/peft#2120 once it's merged. We could also look into incorporating Yiyi's PoC for supporting different scales and ranks. WDYT?
Thanks!
Do we want to make this feature dependent on the installed PEFT version? I think we could copy the function over and only switch to PEFT once older PEFT versions are no longer supported.
Yeah, I think making this feature dependent on a minimum PEFT version sounds good.
Okay, then I'll switch it when the feature is released in PEFT.
```python
    for key, new_val in state_dict.items():
        # no need to account for potential _orig_mod in key here, as torch handles that
        old_val = attrgetter(key)(model)
        old_val.data = new_val.data.to(device=old_val.device)
```
```diff
-        old_val.data = new_val.data.to(device=old_val.device)
+        old_val.data.copy_(new_val.data.to(device=old_val.device))
```
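The reason behind this suggestion, as I understand it: re-assigning `.data` points the parameter at a freshly allocated tensor (a new data pointer, which is exactly what the earlier "data pointer changed" guard error complains about), while `copy_` writes into the existing storage.

```python
import torch

p = torch.nn.Parameter(torch.zeros(4))
ptr_before = p.data.data_ptr()

# copy_ writes into the existing storage: the data pointer is stable,
# so compiled graphs that captured this parameter stay valid
p.data.copy_(torch.ones(4))
assert p.data.data_ptr() == ptr_before

# re-assigning .data swaps in a new tensor with new storage,
# changing the data pointer
p.data = torch.full((4,), 2.0)
assert p.data.data_ptr() != ptr_before
```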
@BenjaminBossan When I tested this PR, it did not work either, even for LoRAs with the same rank and same scaling factor. I think the main difference was this line #9453 (comment) (this is from memory though, so maybe you can test it out). I do think we should support different ranks + scaling factors; otherwise this feature isn't very meaningful.