Merge branch 'main' into dreambooth-lora-flux-exploration
sayakpaul authored Oct 16, 2024
2 parents b1b2128 + cef4f65 commit bfb0741
Showing 5 changed files with 217 additions and 40 deletions.
104 changes: 84 additions & 20 deletions src/diffusers/loaders/lora_pipeline.py
@@ -1358,14 +1358,30 @@ def load_lora_into_transformer(
         inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
         incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
 
+        warn_msg = ""
         if incompatible_keys is not None:
-            # check only for unexpected keys
+            # Check only for unexpected keys.
             unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
             if unexpected_keys:
-                logger.warning(
-                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
-                    f" {unexpected_keys}. "
-                )
+                lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
+                if lora_unexpected_keys:
+                    warn_msg = (
+                        f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
+                        f" {', '.join(lora_unexpected_keys)}. "
+                    )
+
+            # Filter missing keys specific to the current adapter.
+            missing_keys = getattr(incompatible_keys, "missing_keys", None)
+            if missing_keys:
+                lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
+                if lora_missing_keys:
+                    warn_msg += (
+                        f"Loading adapter weights from state_dict led to missing keys in the model:"
+                        f" {', '.join(lora_missing_keys)}."
+                    )
+
+        if warn_msg:
+            logger.warning(warn_msg)
 
         # Offload back.
         if is_model_cpu_offload:
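Note on the rewritten block above: only keys that belong to the adapter being loaded are reported, so keys from other adapters or non-LoRA buffers no longer produce noise, and both categories are folded into a single warning. A minimal sketch of the filtering, using a stand-in for the incompatible_keys object returned by peft's set_peft_model_state_dict (the key names are illustrative):

from types import SimpleNamespace

adapter_name = "default_0"
incompatible_keys = SimpleNamespace(
    unexpected_keys=["foo.lora_A.default_0.weight.diffusers_cat"],
    missing_keys=["bar.lora_B.default_0.weight", "pos_embed"],
)

warn_msg = ""
unexpected = [k for k in incompatible_keys.unexpected_keys if "lora_" in k and adapter_name in k]
if unexpected:
    warn_msg = f"Unexpected LoRA keys: {', '.join(unexpected)}. "

missing = [k for k in incompatible_keys.missing_keys if "lora_" in k and adapter_name in k]
if missing:
    warn_msg += f"Missing LoRA keys: {', '.join(missing)}."

# "pos_embed" is ignored: it is neither a lora_ key nor specific to this adapter.
print(warn_msg)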
@@ -1932,14 +1948,30 @@ def load_lora_into_transformer(
         inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
         incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
 
+        warn_msg = ""
         if incompatible_keys is not None:
-            # check only for unexpected keys
+            # Check only for unexpected keys.
             unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
             if unexpected_keys:
-                logger.warning(
-                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
-                    f" {unexpected_keys}. "
-                )
+                lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
+                if lora_unexpected_keys:
+                    warn_msg = (
+                        f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
+                        f" {', '.join(lora_unexpected_keys)}. "
+                    )
+
+            # Filter missing keys specific to the current adapter.
+            missing_keys = getattr(incompatible_keys, "missing_keys", None)
+            if missing_keys:
+                lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
+                if lora_missing_keys:
+                    warn_msg += (
+                        f"Loading adapter weights from state_dict led to missing keys in the model:"
+                        f" {', '.join(lora_missing_keys)}."
+                    )
+
+        if warn_msg:
+            logger.warning(warn_msg)
 
         # Offload back.
         if is_model_cpu_offload:
@@ -2279,14 +2311,30 @@ def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, ada
         inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
         incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
 
+        warn_msg = ""
         if incompatible_keys is not None:
-            # check only for unexpected keys
+            # Check only for unexpected keys.
             unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
             if unexpected_keys:
-                logger.warning(
-                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
-                    f" {unexpected_keys}. "
-                )
+                lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
+                if lora_unexpected_keys:
+                    warn_msg = (
+                        f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
+                        f" {', '.join(lora_unexpected_keys)}. "
+                    )
+
+            # Filter missing keys specific to the current adapter.
+            missing_keys = getattr(incompatible_keys, "missing_keys", None)
+            if missing_keys:
+                lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
+                if lora_missing_keys:
+                    warn_msg += (
+                        f"Loading adapter weights from state_dict led to missing keys in the model:"
+                        f" {', '.join(lora_missing_keys)}."
+                    )
+
+        if warn_msg:
+            logger.warning(warn_msg)
 
         # Offload back.
         if is_model_cpu_offload:
@@ -2717,14 +2765,30 @@ def load_lora_into_transformer(
         inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
         incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
 
+        warn_msg = ""
         if incompatible_keys is not None:
-            # check only for unexpected keys
+            # Check only for unexpected keys.
             unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
             if unexpected_keys:
-                logger.warning(
-                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
-                    f" {unexpected_keys}. "
-                )
+                lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
+                if lora_unexpected_keys:
+                    warn_msg = (
+                        f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
+                        f" {', '.join(lora_unexpected_keys)}. "
+                    )
+
+            # Filter missing keys specific to the current adapter.
+            missing_keys = getattr(incompatible_keys, "missing_keys", None)
+            if missing_keys:
+                lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
+                if lora_missing_keys:
+                    warn_msg += (
+                        f"Loading adapter weights from state_dict led to missing keys in the model:"
+                        f" {', '.join(lora_missing_keys)}."
+                    )
+
+        if warn_msg:
+            logger.warning(warn_msg)
 
         # Offload back.
         if is_model_cpu_offload:
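How the consolidated warning surfaces to a user of any of these four loaders: a single logger.warning now covers both unexpected and missing adapter keys. A hypothetical session (the LoRA repo id and the reported key are illustrative, not taken from this commit):

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
# A checkpoint missing one adapter weight might log, e.g.:
#   Loading adapter weights from state_dict led to missing keys in the model:
#   transformer.x_embedder.lora_A.default_0.weight.
pipe.load_lora_weights("some-user/some-flux-lora")  # hypothetical repo id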
26 changes: 21 additions & 5 deletions src/diffusers/loaders/unet.py
@@ -354,14 +354,30 @@ def _process_lora(
         inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
         incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
 
+        warn_msg = ""
         if incompatible_keys is not None:
-            # check only for unexpected keys
+            # Check only for unexpected keys.
             unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
             if unexpected_keys:
-                logger.warning(
-                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
-                    f" {unexpected_keys}. "
-                )
+                lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
+                if lora_unexpected_keys:
+                    warn_msg = (
+                        f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
+                        f" {', '.join(lora_unexpected_keys)}. "
+                    )
+
+            # Filter missing keys specific to the current adapter.
+            missing_keys = getattr(incompatible_keys, "missing_keys", None)
+            if missing_keys:
+                lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
+                if lora_missing_keys:
+                    warn_msg += (
+                        f"Loading adapter weights from state_dict led to missing keys in the model:"
+                        f" {', '.join(lora_missing_keys)}."
+                    )
+
+        if warn_msg:
+            logger.warning(warn_msg)
 
         return is_model_cpu_offload, is_sequential_cpu_offload
 
11 changes: 6 additions & 5 deletions src/diffusers/training_utils.py
@@ -24,6 +24,9 @@
 if is_transformers_available():
     import transformers
 
+    if transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
+        import deepspeed
+
 if is_peft_available():
     from peft import set_peft_model_state_dict
 
@@ -442,15 +445,13 @@ def step(self, parameters: Iterable[torch.nn.Parameter]):
         self.cur_decay_value = decay
         one_minus_decay = 1 - decay
 
-        context_manager = contextlib.nullcontext
-        if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
-            import deepspeed
+        context_manager = contextlib.nullcontext()
 
         if self.foreach:
             if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
                 context_manager = deepspeed.zero.GatheredParameters(parameters, modifier_rank=None)
 
-            with context_manager():
+            with context_manager:
                 params_grad = [param for param in parameters if param.requires_grad]
                 s_params_grad = [
                     s_param for s_param, param in zip(self.shadow_params, parameters) if param.requires_grad
@@ -472,7 +473,7 @@ def step(self, parameters: Iterable[torch.nn.Parameter]):
                 if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
                     context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)
 
-                with context_manager():
+                with context_manager:
                     if param.requires_grad:
                         s_param.sub_(one_minus_decay * (s_param - param))
                     else:
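Besides hoisting the deepspeed import to module scope, the step() changes fix a latent bug: the old code stored the contextlib.nullcontext class and called context_manager() at the with site, which breaks once context_manager is rebound to a deepspeed.zero.GatheredParameters instance, because instances are not callable. A runnable sketch with a stand-in class (not the real DeepSpeed API):

import contextlib

class FakeGatheredParameters:
    # Stand-in for deepspeed.zero.GatheredParameters, which is used as an instance.
    def __enter__(self):
        return self

    def __exit__(self, *exc):
        return False

# Old pattern: store something callable, call it at use time.
cm = contextlib.nullcontext  # the class itself
with cm():  # fine for nullcontext...
    pass

cm = FakeGatheredParameters()
try:
    with cm():  # ...but calling an instance raises TypeError
        pass
except TypeError as err:
    print(err)

# New pattern: always store an instance and enter it directly.
for cm in (contextlib.nullcontext(), FakeGatheredParameters()):
    with cm:
        pass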
25 changes: 17 additions & 8 deletions tests/lora/test_lora_layers_flux.py
@@ -27,6 +27,7 @@
 from diffusers.utils.testing_utils import (
     floats_tensor,
     is_peft_available,
+    numpy_cosine_similarity_distance,
     require_peft_backend,
     require_torch_gpu,
     slow,
@@ -166,7 +167,7 @@ def test_modify_padding_mode(self):
 @slow
 @require_torch_gpu
 @require_peft_backend
-@unittest.skip("We cannot run inference on this model with the current CI hardware")
+# @unittest.skip("We cannot run inference on this model with the current CI hardware")
 # TODO (DN6, sayakpaul): move these tests to a beefier GPU
 class FluxLoRAIntegrationTests(unittest.TestCase):
     """internal note: The integration slices were obtained on audace.
@@ -208,9 +209,11 @@ def test_flux_the_last_ben(self):
             generator=torch.manual_seed(self.seed),
         ).images
         out_slice = out[0, -3:, -3:, -1].flatten()
-        expected_slice = np.array([0.1719, 0.1719, 0.1699, 0.1719, 0.1719, 0.1738, 0.1641, 0.1621, 0.2090])
+        expected_slice = np.array([0.1855, 0.1855, 0.1836, 0.1855, 0.1836, 0.1875, 0.1777, 0.1758, 0.2246])
 
-        assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)
+        max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
+
+        assert max_diff < 1e-3
 
     def test_flux_kohya(self):
         self.pipeline.load_lora_weights("Norod78/brain-slug-flux")
@@ -230,7 +233,9 @@ def test_flux_kohya(self):
         out_slice = out[0, -3:, -3:, -1].flatten()
         expected_slice = np.array([0.6367, 0.6367, 0.6328, 0.6367, 0.6328, 0.6289, 0.6367, 0.6328, 0.6484])
 
-        assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)
+        max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
+
+        assert max_diff < 1e-3
 
     def test_flux_kohya_with_text_encoder(self):
         self.pipeline.load_lora_weights("cocktailpeanut/optimus", weight_name="optimus.safetensors")
@@ -248,9 +253,11 @@
         ).images
 
         out_slice = out[0, -3:, -3:, -1].flatten()
-        expected_slice = np.array([0.4023, 0.4043, 0.4023, 0.3965, 0.3984, 0.3984, 0.3906, 0.3906, 0.4219])
+        expected_slice = np.array([0.4023, 0.4023, 0.4023, 0.3965, 0.3984, 0.3965, 0.3926, 0.3906, 0.4219])
 
-        assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)
+        max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
+
+        assert max_diff < 1e-3
 
     def test_flux_xlabs(self):
         self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
@@ -268,6 +275,8 @@
             generator=torch.manual_seed(self.seed),
         ).images
         out_slice = out[0, -3:, -3:, -1].flatten()
-        expected_slice = np.array([0.3984, 0.4199, 0.4453, 0.4102, 0.4375, 0.4590, 0.4141, 0.4355, 0.4980])
+        expected_slice = np.array([0.3965, 0.4180, 0.4434, 0.4082, 0.4375, 0.4590, 0.4141, 0.4375, 0.4980])
+
+        max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
 
-        assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)
+        assert max_diff < 1e-3
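Replacing np.allclose with a tight atol by a cosine-similarity distance makes these slice checks robust to small per-element drift across GPU generations. numpy_cosine_similarity_distance comes from diffusers.utils.testing_utils; a sketch of the idea, assuming it is essentially one minus the cosine similarity of the flattened arrays:

import numpy as np

def cosine_distance(a, b):
    a, b = a.flatten(), b.flatten()
    return float(1.0 - np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

expected = np.array([0.1855, 0.1855, 0.1836, 0.1855, 0.1836, 0.1875, 0.1777, 0.1758, 0.2246])
observed = expected + 5e-4  # uniform drift that trips np.allclose(atol=1e-4, rtol=1e-4)
assert not np.allclose(observed, expected, atol=1e-4, rtol=1e-4)
assert cosine_distance(expected, observed) < 1e-3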
91 changes: 89 additions & 2 deletions tests/lora/utils.py
@@ -27,8 +27,10 @@
     LCMScheduler,
     UNet2DConditionModel,
 )
+from diffusers.utils import logging
 from diffusers.utils.import_utils import is_peft_available
 from diffusers.utils.testing_utils import (
+    CaptureLogger,
     floats_tensor,
     require_peft_backend,
     require_peft_version_greater,
@@ -219,10 +221,18 @@ def _get_modules_to_save(self, pipe, has_denoiser=False):
         modules_to_save = {}
         lora_loadable_modules = self.pipeline_class._lora_loadable_modules
 
-        if "text_encoder" in lora_loadable_modules and hasattr(pipe, "text_encoder"):
+        if (
+            "text_encoder" in lora_loadable_modules
+            and hasattr(pipe, "text_encoder")
+            and getattr(pipe.text_encoder, "peft_config", None) is not None
+        ):
             modules_to_save["text_encoder"] = pipe.text_encoder
 
-        if "text_encoder_2" in lora_loadable_modules and hasattr(pipe, "text_encoder_2"):
+        if (
+            "text_encoder_2" in lora_loadable_modules
+            and hasattr(pipe, "text_encoder_2")
+            and getattr(pipe.text_encoder_2, "peft_config", None) is not None
+        ):
             modules_to_save["text_encoder_2"] = pipe.text_encoder_2
 
         if has_denoiser:
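The extra peft_config check in _get_modules_to_save matters because hasattr(pipe, "text_encoder") is true whether or not an adapter was ever injected; peft only sets peft_config on a module when it injects one. This lets the new warning tests below call the helper after unload_lora_weights() without collecting adapter-free text encoders. A minimal sketch with stand-in objects (not the real pipeline classes):

class TextEncoder:
    pass

te_plain = TextEncoder()
te_with_adapter = TextEncoder()
te_with_adapter.peft_config = {"default_0": object()}  # what peft injection leaves behind

for te in (te_plain, te_with_adapter):
    print(getattr(te, "peft_config", None) is not None)  # False, then True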
@@ -1747,6 +1757,83 @@ def test_simple_inference_with_dora(self):
             "DoRA lora should change the output",
         )
 
+    def test_missing_keys_warning(self):
+        scheduler_cls = self.scheduler_classes[0]
+        # Skip text encoder check for now as that is handled with `transformers`.
+        components, _, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+        denoiser.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+            self.pipeline_class.save_lora_weights(
+                save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
+            )
+            pipe.unload_lora_weights()
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True)
+
+        # To make things dynamic since we cannot settle with a single key for all the models where we
+        # offer PEFT support.
+        missing_key = [k for k in state_dict if "lora_A" in k][0]
+        del state_dict[missing_key]
+
+        logger = (
+            logging.get_logger("diffusers.loaders.unet")
+            if self.unet_kwargs is not None
+            else logging.get_logger("diffusers.loaders.lora_pipeline")
+        )
+        logger.setLevel(30)
+        with CaptureLogger(logger) as cap_logger:
+            pipe.load_lora_weights(state_dict)
+
+        # Since the missing key won't contain the adapter name ("default_0").
+        # Also strip out the component prefix (such as "unet." from `missing_key`).
+        component = list({k.split(".")[0] for k in state_dict})[0]
+        self.assertTrue(missing_key.replace(f"{component}.", "") in cap_logger.out.replace("default_0.", ""))
+
+    def test_unexpected_keys_warning(self):
+        scheduler_cls = self.scheduler_classes[0]
+        # Skip text encoder check for now as that is handled with `transformers`.
+        components, _, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+        denoiser.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+            self.pipeline_class.save_lora_weights(
+                save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
+            )
+            pipe.unload_lora_weights()
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True)
+
+        unexpected_key = [k for k in state_dict if "lora_A" in k][0] + ".diffusers_cat"
+        state_dict[unexpected_key] = torch.tensor(1.0, device=torch_device)
+
+        logger = (
+            logging.get_logger("diffusers.loaders.unet")
+            if self.unet_kwargs is not None
+            else logging.get_logger("diffusers.loaders.lora_pipeline")
+        )
+        logger.setLevel(30)
+        with CaptureLogger(logger) as cap_logger:
+            pipe.load_lora_weights(state_dict)
+
+        self.assertTrue(".diffusers_cat" in cap_logger.out)
+
     @unittest.skip("This is failing for now - need to investigate")
     def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self):
         """
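Both new tests capture the warning through diffusers' CaptureLogger helper, with the bare level value 30 (logging.WARNING). For readers unfamiliar with the pattern, a stdlib-only sketch of the same assertion shape (the handler class below is an illustrative substitute, not the real helper):

import logging

logger = logging.getLogger("diffusers.loaders.lora_pipeline")
logger.setLevel(30)  # 30 == logging.WARNING

class CaptureHandler(logging.Handler):
    def __init__(self):
        super().__init__()
        self.out = ""

    def emit(self, record):
        self.out += record.getMessage()

handler = CaptureHandler()
logger.addHandler(handler)
logger.warning("Loading adapter weights from state_dict led to missing keys in the model: foo.lora_A.default_0.weight.")
assert "lora_A" in handler.out
logger.removeHandler(handler)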
