FIX: Small fixes to hotswapping (#2366)
A couple of smaller issues that surfaced when working on the diffusers
integration are now fixed (a minimal usage sketch follows the list below).

- Better detection of whether the model is compiled in
  prepare_model_for_compiled_hotswap
- Fix handling of models that are compiled but where compilation is not
  detected (from "inside" the model)
- Handle the device of swapped-in adapter weights.
- Fix a wrong adapter name in the compiled diffusion model test
- Add hotswap tests for different alphas and ranks with the model not being
  compiled (Linear and Conv2d)
- Make _check_hotswap_configs_compatible "public"
- Don't import diffusers in test root
- Add support for compiled Conv2d
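
The changes above all feed into the same workflow, sketched below end to end: save two adapters of different rank, prepare the model for hotswapping, compile once, and swap without recompilation. The small MLP, the temporary directory, and the `target_rank` keyword are illustrative assumptions, not taken from this diff; the entry points (`get_peft_model`, `prepare_model_for_compiled_hotswap`, `hotswap_adapter`) are the ones touched here.

```python
import tempfile

import torch
from torch import nn

from peft import LoraConfig, PeftModel, get_peft_model
from peft.utils.hotswap import hotswap_adapter, prepare_model_for_compiled_hotswap


class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(10, 20)
        self.lin1 = nn.Linear(20, 2)

    def forward(self, x):
        return self.lin1(torch.relu(self.lin0(x)))


with tempfile.TemporaryDirectory() as tmp:
    # save two adapters that differ in rank and alpha
    for idx, rank in enumerate((7, 13)):
        config = LoraConfig(target_modules=["lin0", "lin1"], r=rank, lora_alpha=rank, init_lora_weights=False)
        get_peft_model(MLP(), config).save_pretrained(f"{tmp}/adapter{idx}")

    # load adapter 0, pad the LoRA weights to the larger rank, then compile once
    model = PeftModel.from_pretrained(MLP(), f"{tmp}/adapter0")
    prepare_model_for_compiled_hotswap(model, target_rank=13)  # target_rank kwarg is an assumption
    model = torch.compile(model)
    model(torch.rand(3, 10))  # inference with adapter 0

    # swap adapter 1 into the already loaded adapter slot; no recompilation expected
    hotswap_adapter(model, f"{tmp}/adapter1", adapter_name="default")
    model(torch.rand(3, 10))  # inference with adapter 1
```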
BenjaminBossan authored Feb 12, 2025
1 parent 363c14e commit 6d03360
Showing 4 changed files with 176 additions and 24 deletions.
5 changes: 5 additions & 0 deletions src/peft/import_utils.py
@@ -160,3 +160,8 @@ def is_xpu_available(check_device=False):
except RuntimeError:
return False
return hasattr(torch, "xpu") and torch.xpu.is_available()


@lru_cache
def is_diffusers_available():
return importlib.util.find_spec("diffusers") is not None
32 changes: 24 additions & 8 deletions src/peft/utils/hotswap.py
@@ -296,7 +296,7 @@ def prepare_model_for_compiled_hotswap(
# do inference with adapter 1
```
"""
is_compiled = hasattr(model, "_orig_mod")
is_compiled = hasattr(model, "_orig_mod") or getattr(model, "_compiled_call_impl", False)
if is_compiled:
raise ValueError("Call prepare_model_for_compiled_hotswap *before* compiling the model")
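
For context, a model can look "compiled" in two ways, which is what the widened check above covers: `torch.compile(model)` returns a wrapper exposing `_orig_mod`, while `module.compile()` compiles in place and only sets the private `_compiled_call_impl`. A rough illustration (the attribute names are PyTorch implementation details, as used by the check itself):

```python
import torch
from torch import nn

wrapped = torch.compile(nn.Linear(4, 4))  # returns an OptimizedModule wrapper
inplace = nn.Linear(4, 4)
inplace.compile()                          # compiles in place (recent PyTorch)

print(hasattr(wrapped, "_orig_mod"))                         # True
print(hasattr(inplace, "_orig_mod"))                         # False
print(bool(getattr(inplace, "_compiled_call_impl", False)))  # True
```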

@@ -416,18 +416,34 @@ def hotswap_adapter_from_state_dict(
# swap actual weights
# no need to account for potential _orig_mod in key here, as torch handles that
old_val = attrgetter(key)(model)
new_val = new_val.to(old_val.data.device)

# We try to detect if the model is compiled, but this does not always work, e.g. if hotswapping is called
# from within the model itself. In that case, swap_tensors raises a RuntimeError and we continue without
# swap_tensors.
if not is_compiled and not is_compiled_inplace:
torch.utils.swap_tensors(old_val, new_val)
continue
try:
torch.utils.swap_tensors(old_val, new_val)
continue
except RuntimeError:
is_compiled = True

# Compiled models don't work with swap_tensors because there are weakrefs to the tensor. It is unclear
# whether this workaround could cause trouble, but the tests indicate that it works.
if old_val.shape == new_val.shape:
# either
# - adapters had the same rank
# - adapters were padded with prepare_model_for_compiled_hotswap and 2nd adapter was larger
old_val.data = new_val.data
else:
if old_val.dim() != 2:
# TODO conv2d
raise NotImplementedError
# if the 2nd adapter was smaller, fill up to the adapter dimension and set the rest to zeros
if old_val.dim() not in (2, 4):
raise NotImplementedError(
f"Trying to hotswap an adapter whose weight has {old_val.dim()} dimensions, but only Conv2d and "
"Linear are supported"
)

# Linear or Conv2d: the check for dim 0 or 1 works for both of these layer types
if old_val.shape[0] > new_val.shape[0]:
old_val.data.fill_(0)
old_val.data[: new_val.shape[0]] = new_val.data
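
To make the fallback above concrete: when the incoming adapter has a smaller rank than the padded, already-compiled buffer, the data is copied in place along the rank dimension, which is dim 0 for the weight shown here, and the same slicing works for 2-d Linear and 4-d Conv2d LoRA weights. A standalone sketch with made-up shapes:

```python
import torch


def pad_copy_(old_val: torch.Tensor, new_val: torch.Tensor) -> None:
    # in-place copy of a smaller adapter weight into a larger, zero-padded buffer
    if old_val.shape[0] > new_val.shape[0]:
        old_val.data.fill_(0)
        old_val.data[: new_val.shape[0]] = new_val.data


pad_copy_(torch.zeros(13, 32), torch.randn(7, 32))            # Linear lora_A: rank 7 into a rank-13 buffer
pad_copy_(torch.zeros(13, 3, 3, 3), torch.randn(7, 3, 3, 3))  # Conv2d lora_A: same slicing along dim 0
```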
@@ -442,7 +458,7 @@ def hotswap_adapter_from_state_dict(
)


def _check_hotswap_configs_compatible(config0: PeftConfig, config1: PeftConfig) -> None:
def check_hotswap_configs_compatible(config0: PeftConfig, config1: PeftConfig) -> None:
"""
Check if two configs are compatible for hot-swapping.
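
Dropping the leading underscore lets downstream integrations (such as diffusers) validate configs up front before attempting a hotswap. A small sketch, assuming two LoRA configs that differ only in rank and alpha, which the hotswap tests in this commit exercise:

```python
from peft import LoraConfig
from peft.utils.hotswap import check_hotswap_configs_compatible

config0 = LoraConfig(r=7, lora_alpha=7, target_modules=["to_q"])
config1 = LoraConfig(r=13, lora_alpha=13, target_modules=["to_q"])

# raises if the configs cannot be hotswapped into one another; passes here
check_hotswap_configs_compatible(config0, config1)
```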
@@ -548,7 +564,7 @@ def hotswap_adapter(model, model_name_or_path, adapter_name, torch_device=None,
]
config = config_cls.from_pretrained(model_name_or_path, **kwargs)
# config keys that could affect the model output besides what is determined by the state_dict
_check_hotswap_configs_compatible(model.active_peft_config, config)
check_hotswap_configs_compatible(model.active_peft_config, config)

state_dict = load_peft_weights(model_name_or_path, device=torch_device, **kwargs)

45 changes: 29 additions & 16 deletions tests/test_gpu_examples.py
@@ -28,8 +28,6 @@
from accelerate.test_utils.testing import run_command
from accelerate.utils import patch_environment
from datasets import Audio, Dataset, DatasetDict, load_dataset
from diffusers import UNet2DConditionModel
from diffusers.utils.testing_utils import floats_tensor
from packaging import version
from parameterized import parameterized
from torch.distributed import init_process_group
@@ -71,7 +69,7 @@
replace_lora_weights_loftq,
set_peft_model_state_dict,
)
from peft.import_utils import is_xpu_available
from peft.import_utils import is_diffusers_available, is_xpu_available
from peft.tuners import boft
from peft.tuners.tuners_utils import BaseTunerLayer
from peft.utils import SAFETENSORS_WEIGHTS_NAME, infer_device
@@ -4237,7 +4235,7 @@ def check_hotswap(self, do_hotswap, ranks, alpha_scalings):
assert torch.allclose(output1, output_after1, atol=tol, rtol=tol)

# it is important to check hotswapping small to large ranks and large to small ranks
@pytest.mark.parametrize("ranks", [(7, 13), (13, 7)])
@pytest.mark.parametrize("ranks", [(11, 11), (7, 13), (13, 7)])
def test_hotswapping_compiled_model_does_not_trigger_recompilation(self, ranks):
with torch._dynamo.config.patch(error_on_recompile=True): # raise an error on recompilation
self.check_hotswap(do_hotswap=True, ranks=ranks, alpha_scalings=ranks)
@@ -4255,8 +4253,8 @@ def test_no_hotswapping_compiled_model_triggers_recompilation(self):

def get_small_unet(self):
# from diffusers UNet2DConditionModelTests
# TODO: This appears not to work yet in full pipeline context, see:
# https://github.com/huggingface/diffusers/pull/9453#issuecomment-2418508871
from diffusers import UNet2DConditionModel

torch.manual_seed(0)
init_dict = {
"block_out_channels": (4, 8),
@@ -4273,19 +4271,22 @@ def get_small_unet(self):
model = UNet2DConditionModel(**init_dict)
return model.to(self.torch_device)

def get_unet_lora_config(self, lora_rank, lora_alpha):
def get_unet_lora_config(self, lora_rank, lora_alpha, target_modules):
# from diffusers test_models_unet_2d_condition.py
# note that this only targets linear layers by default
unet_lora_config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
target_modules=target_modules,
init_lora_weights=False,
use_dora=False,
)
return unet_lora_config

def get_dummy_input(self):
# from UNet2DConditionModelTests
from diffusers.utils.testing_utils import floats_tensor

batch_size = 4
num_channels = 4
sizes = (16, 16)
@@ -4310,13 +4311,13 @@ def set_lora_device(self, model, adapter_names, device):
device
)

def check_hotswap_diffusion(self, do_hotswap, ranks, alpha_scalings):
def check_hotswap_diffusion(self, do_hotswap, ranks, alpha_scalings, target_modules):
dummy_input = self.get_dummy_input()
unet = self.get_small_unet()
rank0, rank1 = ranks
alpha0, alpha1 = alpha_scalings
lora_config0 = self.get_unet_lora_config(rank0, alpha0)
lora_config1 = self.get_unet_lora_config(rank1, alpha1)
lora_config0 = self.get_unet_lora_config(rank0, alpha0, target_modules=target_modules)
lora_config1 = self.get_unet_lora_config(rank1, alpha1, target_modules=target_modules)
unet.add_adapter(lora_config0, adapter_name="adapter0")
unet.add_adapter(lora_config1, adapter_name="adapter1")

@@ -4337,19 +4338,31 @@ def check_hotswap_diffusion(self, do_hotswap, ranks, alpha_scalings):
unet(**dummy_input)["sample"]

if do_hotswap:
unet.load_lora_adapter(file_name1, adapter_name="default_0", hotswap=True)
unet.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True)
else:
# offloading the old and loading the new adapter will result in recompilation
self.set_lora_device(unet, adapter_names=["default_0"], device="cpu")
self.set_lora_device(unet, adapter_names=["adapter0"], device="cpu")
unet.load_lora_adapter(file_name1, adapter_name="other_name", hotswap=False)

# we need to call forward to potentially trigger recompilation
unet(**dummy_input)["sample"]

@pytest.mark.skipif(not is_diffusers_available(), reason="Test requires diffusers to be installed")
@pytest.mark.xfail(
strict=True, reason="Requires hotswap to be implemented in diffusers", raises=torch._dynamo.exc.RecompileError
)
def test_hotswapping_compiled_diffusers_model_does_not_trigger_recompilation(self):
ranks = 7, 13
# it is important to check hotswapping small to large ranks and large to small ranks
@pytest.mark.parametrize("ranks", [(11, 11), (7, 13), (13, 7)])
@pytest.mark.parametrize(
"target_modules",
[
["to_q", "to_k", "to_v", "to_out.0"], # Linear layers
["conv", "conv1", "conv2"], # Conv2d layers
["to_q", "conv"], # mix of Linear and Conv2d
],
)
def test_hotswapping_compiled_diffusers_model_does_not_trigger_recompilation(self, ranks, target_modules):
with torch._dynamo.config.patch(error_on_recompile=True): # raise an error on recompilation
self.check_hotswap_diffusion(do_hotswap=True, ranks=ranks, alpha_scalings=ranks)
self.check_hotswap_diffusion(
do_hotswap=True, ranks=ranks, alpha_scalings=ranks, target_modules=target_modules
)
118 changes: 118 additions & 0 deletions tests/test_initialization.py
@@ -3027,6 +3027,124 @@ def test_hotswap_extra_key_raises(self, tmp_path):
with pytest.raises(RuntimeError, match=msg):
hotswap_adapter(model, tmp_path / "adapter1", adapter_name="default")

@pytest.mark.parametrize("ranks", [(7, 13), (13, 7)])
def test_hotswap_works_different_ranks_alphas(self, ranks, tmp_path):
# same as test_hotswap_works but different rank and alpha
# Load 2 different adapters and check that we can hotswap between them, with the model optionally being
# compiled.
atol, rtol = 1e-4, 1e-4
inputs = torch.rand(3, 10).to(self.torch_device)

# create adapter 0
config0 = LoraConfig(target_modules=["lin0", "lin1"], r=ranks[0], lora_alpha=ranks[0], init_lora_weights=False)
model = self.get_model()
torch.manual_seed(0)
model = get_peft_model(model, config0)
model.eval()
with torch.inference_mode():
output0 = model(inputs)
model.save_pretrained(tmp_path / "adapter0")

del model

# create adapter 1
config1 = LoraConfig(target_modules=["lin0"], r=ranks[1], lora_alpha=ranks[1], init_lora_weights=False)
model = self.get_model()
torch.manual_seed(1)
model = get_peft_model(model, config1)
model.eval()
with torch.inference_mode():
output1 = model(inputs)
model.save_pretrained(tmp_path / "adapter1")

# sanity check: they're not the same
assert not torch.allclose(output0, output1, atol=atol, rtol=rtol)

del model

# load adapter 0
model = self.get_model()
model = PeftModel.from_pretrained(model, tmp_path / "adapter0")
with torch.inference_mode():
output_loaded0 = model(inputs)

# sanity check: same output after loading for adapter 0
assert torch.allclose(output0, output_loaded0, atol=atol, rtol=rtol)

# hotswap with adapter 1
hotswap_adapter(model, tmp_path / "adapter1", adapter_name="default")
with torch.inference_mode():
output_loaded1 = model(inputs)

# real check: model now behaves like adapter 1
assert torch.allclose(output1, output_loaded1, atol=atol, rtol=rtol)

# hotswap back to adapter 0
hotswap_adapter(model, tmp_path / "adapter0", adapter_name="default")
with torch.inference_mode():
output_loaded_back0 = model(inputs)

# real check: model now behaves again like adapter 0
assert torch.allclose(output0, output_loaded_back0, atol=atol, rtol=rtol)

@pytest.mark.parametrize("ranks", [(7, 13), (13, 7)])
def test_hotswap_works_different_ranks_alphas_conv2d(self, ranks, tmp_path):
# same as previous test, but for a Conv2d model
atol, rtol = 1e-4, 1e-4
inputs = torch.rand(3, 3, 10, 10).to(self.torch_device)

# create adapter 0
config0 = LoraConfig(target_modules=["conv"], r=ranks[0], init_lora_weights=False)
model = self.get_model_conv2d()
torch.manual_seed(0)
model = get_peft_model(model, config0)
model.eval()
with torch.inference_mode():
output0 = model(inputs)
model.save_pretrained(tmp_path / "adapter0")

del model

# create adapter 1
config1 = LoraConfig(target_modules=["conv"], r=ranks[1], init_lora_weights=False)
model = self.get_model_conv2d()
torch.manual_seed(1)
model = get_peft_model(model, config1)
model.eval()
with torch.inference_mode():
output1 = model(inputs)
model.save_pretrained(tmp_path / "adapter1")

# sanity check: they're not the same
assert not torch.allclose(output0, output1, atol=atol, rtol=rtol)

del model

# load adapter 0
model = self.get_model_conv2d()
model = PeftModel.from_pretrained(model, tmp_path / "adapter0")
with torch.inference_mode():
output_loaded0 = model(inputs)

# sanity check: same output after loading for adapter 0
assert torch.allclose(output0, output_loaded0, atol=atol, rtol=rtol)

# hotswap with adapter 1
hotswap_adapter(model, tmp_path / "adapter1", adapter_name="default")
with torch.inference_mode():
output_loaded1 = model(inputs)

# real check: model now behaves like adapter 1
assert torch.allclose(output1, output_loaded1, atol=atol, rtol=rtol)

# hotswap back to adapter 0
hotswap_adapter(model, tmp_path / "adapter0", adapter_name="default")
with torch.inference_mode():
output_loaded_back0 = model(inputs)

# real check: model now behaves again like adapter 0
assert torch.allclose(output0, output_loaded_back0, atol=atol, rtol=rtol)

def test_prepare_model_for_compiled_hotswap_scalings_are_tensors(self):
config = LoraConfig(target_modules=["lin0", "lin1"])
model = self.get_model()
