From 02eeb8e77e5067eaf9fc1953c5e832ec894424aa Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Tue, 8 Oct 2024 21:47:44 +0530
Subject: [PATCH] [LoRA] Handle DoRA better (#9547)

* handle dora.
* print test
* debug
* fix
* fix-copies
* update logits
* add warning in the test.
* make is_dora check consistent.
* fix-copies
---
 src/diffusers/loaders/lora_pipeline.py | 39 ++++++++++++++++++++++----
 tests/lora/test_lora_layers_sdxl.py    | 20 ++++++++-----
 2 files changed, 46 insertions(+), 13 deletions(-)

diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index ba1435a8cbdc..8c8f2dfa84f8 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -99,7 +99,7 @@ def load_lora_weights(
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
 
-        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
+        is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
 
@@ -211,6 +211,11 @@ def lora_state_dict(
             user_agent=user_agent,
             allow_pickle=allow_pickle,
         )
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated to `dora_scale` from the state dict. If you think this is a mistake please open an issue at https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
 
         network_alphas = None
         # TODO: replace it with a method from `state_dict_utils`
@@ -562,7 +567,8 @@ def load_lora_weights(
             unet_config=self.unet.config,
             **kwargs,
         )
-        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
+
+        is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
 
@@ -684,6 +690,11 @@ def lora_state_dict(
             user_agent=user_agent,
             allow_pickle=allow_pickle,
         )
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated to `dora_scale` from the state dict. If you think this is a mistake please open an issue at https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
 
         network_alphas = None
         # TODO: replace it with a method from `state_dict_utils`
 
@@ -1089,6 +1100,12 @@ def lora_state_dict(
             allow_pickle=allow_pickle,
         )
 
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated to `dora_scale` from the state dict. If you think this is a mistake please open an issue at https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
         return state_dict
 
     def load_lora_weights(
@@ -1125,7 +1142,7 @@ def load_lora_weights(
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
 
-        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
+        is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
 
@@ -1587,9 +1604,13 @@ def lora_state_dict(
             user_agent=user_agent,
             allow_pickle=allow_pickle,
         )
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated to `dora_scale` from the state dict. If you think this is a mistake please open an issue at https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
 
         # TODO (sayakpaul): to a follow-up to clean and try to unify the conditions.
-
         is_kohya = any(".lora_down.weight" in k for k in state_dict)
         if is_kohya:
             state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict)
@@ -1659,7 +1680,7 @@ def load_lora_weights(
             pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
         )
 
-        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
+        is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
 
@@ -2374,6 +2395,12 @@ def lora_state_dict(
             allow_pickle=allow_pickle,
         )
 
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated to `dora_scale` from the state dict. If you think this is a mistake please open an issue at https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
         return state_dict
 
     def load_lora_weights(
@@ -2405,7 +2432,7 @@ def load_lora_weights(
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
 
-        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
+        is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
 
diff --git a/tests/lora/test_lora_layers_sdxl.py b/tests/lora/test_lora_layers_sdxl.py
index 4ec7ef897485..8deecd770c31 100644
--- a/tests/lora/test_lora_layers_sdxl.py
+++ b/tests/lora/test_lora_layers_sdxl.py
@@ -33,8 +33,10 @@
     StableDiffusionXLPipeline,
     T2IAdapter,
 )
+from diffusers.utils import logging
 from diffusers.utils.import_utils import is_accelerate_available
 from diffusers.utils.testing_utils import (
+    CaptureLogger,
     load_image,
     nightly,
     numpy_cosine_similarity_distance,
@@ -620,14 +622,18 @@ def test_integration_logits_for_dora_lora(self):
         pipeline.load_lora_weights("hf-internal-testing/dora-trained-on-kohya")
         pipeline.enable_model_cpu_offload()
 
-        images = pipeline(
-            "photo of ohwx dog",
-            num_inference_steps=10,
-            generator=torch.manual_seed(0),
-            output_type="np",
-        ).images
+        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+        logger.setLevel(logging.WARNING)
+        with CaptureLogger(logger) as cap_logger:
+            images = pipeline(
+                "photo of ohwx dog",
+                num_inference_steps=10,
+                generator=torch.manual_seed(0),
+                output_type="np",
+            ).images
+        assert "It seems like you are using a DoRA checkpoint" in cap_logger.out
 
         predicted_slice = images[0, -3:, -3:, -1].flatten()
-        expected_slice_scale = np.array([0.3932, 0.3742, 0.4429, 0.3737, 0.3504, 0.433, 0.3948, 0.3769, 0.4516])
+        expected_slice_scale = np.array([0.1817, 0.0697, 0.2346, 0.0900, 0.1261, 0.2279, 0.1767, 0.1991, 0.2886])
 
         max_diff = numpy_cosine_similarity_distance(expected_slice_scale, predicted_slice)
         assert max_diff < 1e-3
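
Note on the change (not part of the patch): the behavior the diff adds to every `lora_state_dict()` variant can be summarized with a minimal, self-contained sketch. `drop_dora_scale_keys` and the toy state dict below are hypothetical names used for illustration, not diffusers APIs; only the key-filtering logic mirrors the added code.

    import logging

    logger = logging.getLogger(__name__)

    def drop_dora_scale_keys(state_dict):
        # Detect a DoRA checkpoint by the presence of "dora_scale" in any key,
        # warn once, and filter those keys out of the state dict.
        if any("dora_scale" in k for k in state_dict):
            logger.warning("DoRA checkpoint detected; dropping `dora_scale` keys.")
            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
        return state_dict

    # Toy checkpoint: two LoRA weights plus one DoRA magnitude entry.
    toy = {
        "unet.down_blocks.0.lora_A.weight": 0,
        "unet.down_blocks.0.lora_B.weight": 1,
        "unet.down_blocks.0.dora_scale": 2,
    }
    filtered = drop_dora_scale_keys(toy)
    assert all("lora" in key for key in filtered)  # the simplified format check passes

This is also why the `is_correct_format` checks drop the `or "dora_scale" in key` clause: by the time `load_lora_weights()` validates the keys, any `dora_scale` entries have already been filtered out inside `lora_state_dict()`, so checking for `"lora"` alone is sufficient and consistent.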