From b5c8b555d7d5bf76e424e616052d902b9de2fd64 Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Fri, 19 Apr 2024 02:13:27 +0200 Subject: [PATCH] Move IP Adapter Face ID to core (#7186) * Switch to peft and multi proj layers * Move Face ID loading and inference to core --------- Co-authored-by: Sayak Paul --- docs/source/en/using-diffusers/ip_adapter.md | 56 ++- .../en/using-diffusers/loading_adapters.md | 37 ++ examples/community/README.md | 2 - examples/community/ip_adapter_face_id.py | 436 +++--------------- src/diffusers/loaders/ip_adapter.py | 13 + src/diffusers/loaders/unet.py | 146 ++++++ src/diffusers/models/embeddings.py | 139 +++++- .../unets/test_models_unet_2d_condition.py | 60 ++- .../test_ip_adapter_stable_diffusion.py | 30 ++ tests/pipelines/test_pipelines_common.py | 48 +- 10 files changed, 592 insertions(+), 375 deletions(-) diff --git a/docs/source/en/using-diffusers/ip_adapter.md b/docs/source/en/using-diffusers/ip_adapter.md index 4ae403538d2b..dc64b2548529 100644 --- a/docs/source/en/using-diffusers/ip_adapter.md +++ b/docs/source/en/using-diffusers/ip_adapter.md @@ -362,14 +362,12 @@ IP-Adapter's image prompting and compatibility with other adapters and models ma ### Face model -Generating accurate faces is challenging because they are complex and nuanced. Diffusers supports two IP-Adapter checkpoints specifically trained to generate faces: +Generating accurate faces is challenging because they are complex and nuanced. Diffusers supports two IP-Adapter checkpoints specifically trained to generate faces from the [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter) repository: * [ip-adapter-full-face_sd15.safetensors](https://huggingface.co/h94/IP-Adapter/blob/main/models/ip-adapter-full-face_sd15.safetensors) is conditioned with images of cropped faces and removed backgrounds * [ip-adapter-plus-face_sd15.safetensors](https://huggingface.co/h94/IP-Adapter/blob/main/models/ip-adapter-plus-face_sd15.safetensors) uses patch embeddings and is conditioned with images of cropped faces -> [!TIP] -> -> [IP-Adapter-FaceID](https://huggingface.co/h94/IP-Adapter-FaceID) is a face-specific IP-Adapter trained with face ID embeddings instead of CLIP image embeddings, allowing you to generate more consistent faces in different contexts and styles. Try out this popular [community pipeline](https://github.com/huggingface/diffusers/tree/main/examples/community#ip-adapter-face-id) and see how it compares to the other face IP-Adapters. +Additionally, Diffusers supports all IP-Adapter checkpoints trained with face embeddings extracted by `insightface` face models. Supported models are from the [h94/IP-Adapter-FaceID](https://huggingface.co/h94/IP-Adapter-FaceID) repository. For face models, use the [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter) checkpoint. It is also recommended to use [`DDIMScheduler`] or [`EulerDiscreteScheduler`] for face models. @@ -411,6 +409,56 @@ image +To use IP-Adapter FaceID models, first extract face embeddings with `insightface`. Then pass the list of tensors to the pipeline as `ip_adapter_image_embeds`. 
+
+```py
+import cv2
+import numpy as np
+import torch
+from diffusers import StableDiffusionPipeline, DDIMScheduler
+from diffusers.utils import load_image
+from insightface.app import FaceAnalysis
+
+pipeline = StableDiffusionPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5",
+    torch_dtype=torch.float16,
+).to("cuda")
+pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
+pipeline.load_ip_adapter("h94/IP-Adapter-FaceID", subfolder=None, weight_name="ip-adapter-faceid_sd15.bin", image_encoder_folder=None)
+pipeline.set_ip_adapter_scale(0.6)
+
+image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_mask_girl1.png")
+
+ref_images_embeds = []
+app = FaceAnalysis(name="buffalo_l", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+app.prepare(ctx_id=0, det_size=(640, 640))
+image = cv2.cvtColor(np.asarray(image), cv2.COLOR_BGR2RGB)
+faces = app.get(image)
+image = torch.from_numpy(faces[0].normed_embedding)
+ref_images_embeds.append(image.unsqueeze(0))
+ref_images_embeds = torch.stack(ref_images_embeds, dim=0).unsqueeze(0)
+neg_ref_images_embeds = torch.zeros_like(ref_images_embeds)
+id_embeds = torch.cat([neg_ref_images_embeds, ref_images_embeds]).to(dtype=torch.float16, device="cuda")
+
+generator = torch.Generator(device="cpu").manual_seed(42)
+
+images = pipeline(
+    prompt="A photo of a girl",
+    ip_adapter_image_embeds=[id_embeds],
+    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
+    num_inference_steps=20, num_images_per_prompt=1,
+    generator=generator
+).images
+```
+
+Both IP-Adapter FaceID Plus and Plus v2 models require CLIP image embeddings. You can prepare face embeddings as shown previously, then extract and pass the CLIP image embeddings to the hidden image projection layers.
+
+```py
+clip_embeds = pipeline.prepare_ip_adapter_image_embeds([ip_adapter_images], None, torch.device("cuda"), num_images, True)[0]
+
+pipeline.unet.encoder_hid_proj.image_projection_layers[0].clip_embeds = clip_embeds.to(dtype=torch.float16)
+pipeline.unet.encoder_hid_proj.image_projection_layers[0].shortcut = False # True if Plus v2
+```
+
+
 ### Multi IP-Adapter
 
 More than one IP-Adapter can be used at the same time to generate specific images in more diverse styles. For example, you can use IP-Adapter-Face to generate consistent faces and characters, and IP-Adapter Plus to generate those faces in a specific style.
diff --git a/docs/source/en/using-diffusers/loading_adapters.md b/docs/source/en/using-diffusers/loading_adapters.md
index b079d2165ece..5871823aefe0 100644
--- a/docs/source/en/using-diffusers/loading_adapters.md
+++ b/docs/source/en/using-diffusers/loading_adapters.md
@@ -320,3 +320,40 @@ pipeline = AutoPipelineForText2Image.from_pretrained(
 
 pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter-plus_sdxl_vit-h.safetensors")
 ```
+
+### IP-Adapter Face ID models
+
+The IP-Adapter FaceID models are experimental IP Adapters that use image embeddings generated by `insightface` instead of CLIP image embeddings. Some of these models also use LoRA to improve ID consistency.
+You need to install `insightface` and all its requirements to use these models.
+
+
+As InsightFace pretrained models are available for non-commercial research purposes, IP-Adapter-FaceID models are released exclusively for research purposes and are not intended for commercial use.
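These checkpoints consume `insightface` face ID embeddings rather than raw images, so the embedding is usually computed once up front and later passed to the pipeline as `ip_adapter_image_embeds`. A minimal sketch of that step, assuming `insightface`, `onnxruntime`, and `opencv-python` are installed and `face.png` is a hypothetical local image:

```py
import cv2
import torch
from insightface.app import FaceAnalysis

# Detect a face and take its normalized 512-dim ID embedding
app = FaceAnalysis(name="buffalo_l", providers=["CPUExecutionProvider"])
app.prepare(ctx_id=0, det_size=(640, 640))
face_image = cv2.imread("face.png")  # hypothetical input image (BGR)
faces = app.get(face_image)
id_embed = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)  # shape (1, 512)
```

For classifier-free guidance, a zeroed copy of this embedding is concatenated as the negative input, as shown in the `ip_adapter.md` example above.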
+ + +```py +pipeline = AutoPipelineForText2Image.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + torch_dtype=torch.float16 +).to("cuda") + +pipeline.load_ip_adapter("h94/IP-Adapter-FaceID", subfolder=None, weight_name="ip-adapter-faceid_sdxl.bin", image_encoder_folder=None) +``` + +If you want to use one of the two IP-Adapter FaceID Plus models, you must also load the CLIP image encoder, as this models use both `insightface` and CLIP image embeddings to achieve better photorealism. + +```py +from transformers import CLIPVisionModelWithProjection + +image_encoder = CLIPVisionModelWithProjection.from_pretrained( + "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", + torch_dtype=torch.float16, +) + +pipeline = AutoPipelineForText2Image.from_pretrained( + "runwayml/stable-diffusion-v1-5", + image_encoder=image_encoder, + torch_dtype=torch.float16 +).to("cuda") + +pipeline.load_ip_adapter("h94/IP-Adapter-FaceID", subfolder=None, weight_name="ip-adapter-faceid-plus_sd15.bin") +``` diff --git a/examples/community/README.md b/examples/community/README.md index cc471874ca02..5cebc4f9f049 100755 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -3819,12 +3819,10 @@ export_to_gif(frames, "animation.gif") IP Adapter FaceID is an experimental IP Adapter model that uses image embeddings generated by `insightface`, so no image encoder needs to be loaded. You need to install `insightface` and all its requirements to use this model. You must pass the image embedding tensor as `image_embeds` to the StableDiffusionPipeline instead of `ip_adapter_image`. -You have to disable PEFT BACKEND in order to load weights. You can find more results [here](https://github.com/huggingface/diffusers/pull/6276). ```py import diffusers -diffusers.utils.USE_PEFT_BACKEND = False import torch from diffusers.utils import load_image import cv2 diff --git a/examples/community/ip_adapter_face_id.py b/examples/community/ip_adapter_face_id.py index b4d2446b5ce9..f644c287be04 100644 --- a/examples/community/ip_adapter_face_id.py +++ b/examples/community/ip_adapter_face_id.py @@ -26,7 +26,14 @@ from diffusers.image_processor import VaeImageProcessor from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from diffusers.models import AutoencoderKL, UNet2DConditionModel -from diffusers.models.lora import LoRALinearLayer, adjust_lora_scale_text_encoder +from diffusers.models.attention_processor import ( + AttnProcessor, + AttnProcessor2_0, + IPAdapterAttnProcessor, + IPAdapterAttnProcessor2_0, +) +from diffusers.models.embeddings import MultiIPAdapterImageProjection +from diffusers.models.lora import adjust_lora_scale_text_encoder from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker @@ -45,300 +52,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class LoRAIPAdapterAttnProcessor(nn.Module): - r""" - Attention processor for IP-Adapater. - Args: - hidden_size (`int`): - The hidden size of the attention layer. - cross_attention_dim (`int`): - The number of channels in the `encoder_hidden_states`. - rank (`int`, defaults to 4): - The dimension of the LoRA update matrices. - network_alpha (`int`, *optional*): - Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. 
- lora_scale (`float`, defaults to 1.0): - the weight scale of LoRA. - scale (`float`, defaults to 1.0): - the weight scale of image prompt. - num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): - The context length of the image features. - """ - - def __init__( - self, - hidden_size, - cross_attention_dim=None, - rank=4, - network_alpha=None, - lora_scale=1.0, - scale=1.0, - num_tokens=4, - ): - super().__init__() - - self.rank = rank - self.lora_scale = lora_scale - - self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) - self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) - self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) - self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) - - self.hidden_size = hidden_size - self.cross_attention_dim = cross_attention_dim - self.scale = scale - self.num_tokens = num_tokens - - self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) - self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) - - def __call__( - self, - attn, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - ): - residual = hidden_states - - # separate ip_hidden_states from encoder_hidden_states - if encoder_hidden_states is not None: - if isinstance(encoder_hidden_states, tuple): - encoder_hidden_states, ip_hidden_states = encoder_hidden_states - else: - deprecation_message = ( - "You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release." - " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning." 
- ) - deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False) - end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0] - encoder_hidden_states, ip_hidden_states = ( - encoder_hidden_states[:, :end_pos, :], - [encoder_hidden_states[:, end_pos:, :]], - ) - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states) - - query = attn.head_to_batch_dim(query) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = torch.bmm(attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # for ip-adapter - ip_key = self.to_k_ip(ip_hidden_states) - ip_value = self.to_v_ip(ip_hidden_states) - - ip_key = attn.head_to_batch_dim(ip_key) - ip_value = attn.head_to_batch_dim(ip_value) - - ip_attention_probs = attn.get_attention_scores(query, ip_key, None) - ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) - ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) - - hidden_states = hidden_states + self.scale * ip_hidden_states - - # linear proj - hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - - -class LoRAIPAdapterAttnProcessor2_0(nn.Module): - r""" - Attention processor for IP-Adapater for PyTorch 2.0. - Args: - hidden_size (`int`): - The hidden size of the attention layer. - cross_attention_dim (`int`): - The number of channels in the `encoder_hidden_states`. - rank (`int`, defaults to 4): - The dimension of the LoRA update matrices. - network_alpha (`int`, *optional*): - Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. - lora_scale (`float`, defaults to 1.0): - the weight scale of LoRA. - scale (`float`, defaults to 1.0): - the weight scale of image prompt. - num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): - The context length of the image features. 
- """ - - def __init__( - self, - hidden_size, - cross_attention_dim=None, - rank=4, - network_alpha=None, - lora_scale=1.0, - scale=1.0, - num_tokens=4, - ): - super().__init__() - - self.rank = rank - self.lora_scale = lora_scale - - self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) - self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) - self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) - self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) - - self.hidden_size = hidden_size - self.cross_attention_dim = cross_attention_dim - self.scale = scale - self.num_tokens = num_tokens - - self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) - self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) - - def __call__( - self, - attn, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - ): - residual = hidden_states - - # separate ip_hidden_states from encoder_hidden_states - if encoder_hidden_states is not None: - if isinstance(encoder_hidden_states, tuple): - encoder_hidden_states, ip_hidden_states = encoder_hidden_states - else: - deprecation_message = ( - "You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release." - " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning." - ) - deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False) - end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0] - encoder_hidden_states, ip_hidden_states = ( - encoder_hidden_states[:, :end_pos, :], - [encoder_hidden_states[:, end_pos:, :]], - ) - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, 
head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - # for ip-adapter - ip_key = self.to_k_ip(ip_hidden_states) - ip_value = self.to_v_ip(ip_hidden_states) - - ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - ip_hidden_states = F.scaled_dot_product_attention( - query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False - ) - - ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - ip_hidden_states = ip_hidden_states.to(query.dtype) - - hidden_states = hidden_states + self.scale * ip_hidden_states - - # linear proj - hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - - class IPAdapterFullImageProjection(nn.Module): def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1): super().__init__() @@ -615,17 +328,13 @@ def convert_ip_adapter_image_proj_to_diffusers(self, state_dict): return image_projection def _load_ip_adapter_weights(self, state_dict): - from diffusers.models.attention_processor import ( - AttnProcessor, - AttnProcessor2_0, - ) - num_image_text_embeds = 4 self.unet.encoder_hid_proj = None # set ip-adapter cross-attention processors & load state_dict attn_procs = {} + lora_dict = {} key_id = 0 for name in self.unet.attn_processors.keys(): cross_attention_dim = None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim @@ -642,94 +351,99 @@ def _load_ip_adapter_weights(self, state_dict): AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor ) attn_procs[name] = attn_processor_class() - rank = state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"].shape[0] - attn_module = self.unet - for n in name.split(".")[:-1]: - attn_module = getattr(attn_module, n) - # Set the `lora_layer` attribute of the attention-related matrices. 
- attn_module.to_q.set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_q.in_features, - out_features=attn_module.to_q.out_features, - rank=rank, - ) + + lora_dict.update( + {f"unet.{name}.to_k_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.down.weight"]} ) - attn_module.to_k.set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_k.in_features, - out_features=attn_module.to_k.out_features, - rank=rank, - ) + lora_dict.update( + {f"unet.{name}.to_q_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"]} ) - attn_module.to_v.set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_v.in_features, - out_features=attn_module.to_v.out_features, - rank=rank, - ) + lora_dict.update( + {f"unet.{name}.to_v_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.down.weight"]} ) - attn_module.to_out[0].set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_out[0].in_features, - out_features=attn_module.to_out[0].out_features, - rank=rank, - ) + lora_dict.update( + { + f"unet.{name}.to_out_lora.down.weight": state_dict["ip_adapter"][ + f"{key_id}.to_out_lora.down.weight" + ] + } + ) + lora_dict.update( + {f"unet.{name}.to_k_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.up.weight"]} + ) + lora_dict.update( + {f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.up.weight"]} + ) + lora_dict.update( + {f"unet.{name}.to_v_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.up.weight"]} + ) + lora_dict.update( + {f"unet.{name}.to_out_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_out_lora.up.weight"]} ) - - value_dict = {} - for k, module in attn_module.named_children(): - index = "." - if not hasattr(module, "set_lora_layer"): - index = ".0." 
- module = module[0] - lora_layer = getattr(module, "lora_layer") - for lora_name, w in lora_layer.state_dict().items(): - value_dict.update( - { - f"{k}{index}lora_layer.{lora_name}": state_dict["ip_adapter"][ - f"{key_id}.{k}_lora.{lora_name}" - ] - } - ) - - attn_module.load_state_dict(value_dict, strict=False) - attn_module.to(dtype=self.dtype, device=self.device) key_id += 1 else: - rank = state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"].shape[0] attn_processor_class = ( - LoRAIPAdapterAttnProcessor2_0 - if hasattr(F, "scaled_dot_product_attention") - else LoRAIPAdapterAttnProcessor + IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor ) attn_procs[name] = attn_processor_class( hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, - rank=rank, num_tokens=num_image_text_embeds, ).to(dtype=self.dtype, device=self.device) - value_dict = {} - for k, w in attn_procs[name].state_dict().items(): - value_dict.update({f"{k}": state_dict["ip_adapter"][f"{key_id}.{k}"]}) + lora_dict.update( + {f"unet.{name}.to_k_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.down.weight"]} + ) + lora_dict.update( + {f"unet.{name}.to_q_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"]} + ) + lora_dict.update( + {f"unet.{name}.to_v_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.down.weight"]} + ) + lora_dict.update( + { + f"unet.{name}.to_out_lora.down.weight": state_dict["ip_adapter"][ + f"{key_id}.to_out_lora.down.weight" + ] + } + ) + lora_dict.update( + {f"unet.{name}.to_k_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.up.weight"]} + ) + lora_dict.update( + {f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.up.weight"]} + ) + lora_dict.update( + {f"unet.{name}.to_v_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.up.weight"]} + ) + lora_dict.update( + {f"unet.{name}.to_out_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_out_lora.up.weight"]} + ) + value_dict = {} + value_dict.update({"to_k_ip.0.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]}) + value_dict.update({"to_v_ip.0.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]}) attn_procs[name].load_state_dict(value_dict) key_id += 1 self.unet.set_attn_processor(attn_procs) + self.load_lora_weights(lora_dict, adapter_name="faceid") + self.set_adapters(["faceid"], adapter_weights=[1.0]) + # convert IP-Adapter Image Projection layers to diffusers image_projection = self.convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"]) + image_projection_layers = [image_projection.to(device=self.device, dtype=self.dtype)] - self.unet.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype) + self.unet.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers) self.unet.config.encoder_hid_dim_type = "ip_image_proj" def set_ip_adapter_scale(self, scale): unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet for attn_processor in unet.attn_processors.values(): - if isinstance(attn_processor, (LoRAIPAdapterAttnProcessor, LoRAIPAdapterAttnProcessor2_0)): - attn_processor.scale = scale + if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)): + attn_processor.scale = [scale] def _encode_prompt( self, @@ -1298,7 +1012,7 @@ def __call__( negative_image_embeds = torch.zeros_like(image_embeds) if self.do_classifier_free_guidance: image_embeds = 
torch.cat([negative_image_embeds, image_embeds]) - + image_embeds = [image_embeds] # 4. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) @@ -1319,7 +1033,7 @@ def __call__( extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 6.1 Add image embeds for IP-Adapter - added_cond_kwargs = {"image_embeds": image_embeds} if image_embeds is not None else None + added_cond_kwargs = {"image_embeds": image_embeds} if image_embeds is not None else {} # 6.2 Optionally get Guidance Scale Embedding timestep_cond = None diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index c531d5a519f2..fdddc382212f 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -21,6 +21,7 @@ from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict from ..utils import ( + USE_PEFT_BACKEND, _get_model_file, is_accelerate_available, is_torch_version, @@ -228,6 +229,18 @@ def load_ip_adapter( unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage) + extra_loras = unet._load_ip_adapter_loras(state_dicts) + if extra_loras != {}: + if not USE_PEFT_BACKEND: + logger.warning("PEFT backend is required to load these weights.") + else: + # apply the IP Adapter Face ID LoRA weights + peft_config = getattr(unet, "peft_config", {}) + for k, lora in extra_loras.items(): + if f"faceid_{k}" not in peft_config: + self.load_lora_weights(lora, adapter_name=f"faceid_{k}") + self.set_adapters([f"faceid_{k}"], adapter_weights=[1.0]) + def set_ip_adapter_scale(self, scale): """ Sets the conditioning scale between text and image. 
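With the loader change above, `load_ip_adapter` now also extracts any Face ID LoRA weights from the checkpoint and registers them as PEFT adapters named `faceid_{k}`. A minimal sketch of that path, assuming the PEFT backend is available and using `faceid_0` as the name the loader assigns to the first state dict:

```py
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

# Loads the IP-Adapter weights and, because this checkpoint ships LoRA
# weights, also registers and activates the "faceid_0" PEFT adapter.
pipe.load_ip_adapter(
    "h94/IP-Adapter-FaceID",
    subfolder=None,
    weight_name="ip-adapter-faceid_sd15.bin",
    image_encoder_folder=None,
)

# The adapter is already active at weight 1.0; adjust the LoRA strength
# separately from the IP-Adapter scale if needed.
pipe.set_adapters(["faceid_0"], adapter_weights=[0.8])
pipe.set_ip_adapter_scale(0.6)
```

If the PEFT backend is not installed, the loader only emits a warning and skips the LoRA weights, so results will differ from the reference checkpoints.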
diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index a099267f40b4..294db44ee61d 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -27,6 +27,8 @@ from ..models.embeddings import ( ImageProjection, + IPAdapterFaceIDImageProjection, + IPAdapterFaceIDPlusImageProjection, IPAdapterFullImageProjection, IPAdapterPlusImageProjection, MultiIPAdapterImageProjection, @@ -756,6 +758,90 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us diffusers_name = diffusers_name.replace("proj.3", "norm") updated_state_dict[diffusers_name] = value + elif "perceiver_resampler.proj_in.weight" in state_dict: + # IP-Adapter Face ID Plus + id_embeddings_dim = state_dict["proj.0.weight"].shape[1] + embed_dims = state_dict["perceiver_resampler.proj_in.weight"].shape[0] + hidden_dims = state_dict["perceiver_resampler.proj_in.weight"].shape[1] + output_dims = state_dict["perceiver_resampler.proj_out.weight"].shape[0] + heads = state_dict["perceiver_resampler.layers.0.0.to_q.weight"].shape[0] // 64 + + with init_context(): + image_projection = IPAdapterFaceIDPlusImageProjection( + embed_dims=embed_dims, + output_dims=output_dims, + hidden_dims=hidden_dims, + heads=heads, + id_embeddings_dim=id_embeddings_dim, + ) + + for key, value in state_dict.items(): + diffusers_name = key.replace("perceiver_resampler.", "") + diffusers_name = diffusers_name.replace("0.to", "attn.to") + diffusers_name = diffusers_name.replace("0.1.0.", "0.ff.0.") + diffusers_name = diffusers_name.replace("0.1.1.weight", "0.ff.1.net.0.proj.weight") + diffusers_name = diffusers_name.replace("0.1.3.weight", "0.ff.1.net.2.weight") + diffusers_name = diffusers_name.replace("1.1.0.", "1.ff.0.") + diffusers_name = diffusers_name.replace("1.1.1.weight", "1.ff.1.net.0.proj.weight") + diffusers_name = diffusers_name.replace("1.1.3.weight", "1.ff.1.net.2.weight") + diffusers_name = diffusers_name.replace("2.1.0.", "2.ff.0.") + diffusers_name = diffusers_name.replace("2.1.1.weight", "2.ff.1.net.0.proj.weight") + diffusers_name = diffusers_name.replace("2.1.3.weight", "2.ff.1.net.2.weight") + diffusers_name = diffusers_name.replace("3.1.0.", "3.ff.0.") + diffusers_name = diffusers_name.replace("3.1.1.weight", "3.ff.1.net.0.proj.weight") + diffusers_name = diffusers_name.replace("3.1.3.weight", "3.ff.1.net.2.weight") + diffusers_name = diffusers_name.replace("layers.0.0", "layers.0.ln0") + diffusers_name = diffusers_name.replace("layers.0.1", "layers.0.ln1") + diffusers_name = diffusers_name.replace("layers.1.0", "layers.1.ln0") + diffusers_name = diffusers_name.replace("layers.1.1", "layers.1.ln1") + diffusers_name = diffusers_name.replace("layers.2.0", "layers.2.ln0") + diffusers_name = diffusers_name.replace("layers.2.1", "layers.2.ln1") + diffusers_name = diffusers_name.replace("layers.3.0", "layers.3.ln0") + diffusers_name = diffusers_name.replace("layers.3.1", "layers.3.ln1") + + if "norm1" in diffusers_name: + updated_state_dict[diffusers_name.replace("0.norm1", "0")] = value + elif "norm2" in diffusers_name: + updated_state_dict[diffusers_name.replace("0.norm2", "1")] = value + elif "to_kv" in diffusers_name: + v_chunk = value.chunk(2, dim=0) + updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0] + updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1] + elif "to_out" in diffusers_name: + updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value + elif "proj.0.weight" == diffusers_name: + 
updated_state_dict["proj.net.0.proj.weight"] = value + elif "proj.0.bias" == diffusers_name: + updated_state_dict["proj.net.0.proj.bias"] = value + elif "proj.2.weight" == diffusers_name: + updated_state_dict["proj.net.2.weight"] = value + elif "proj.2.bias" == diffusers_name: + updated_state_dict["proj.net.2.bias"] = value + else: + updated_state_dict[diffusers_name] = value + + elif "norm.weight" in state_dict: + # IP-Adapter Face ID + id_embeddings_dim_in = state_dict["proj.0.weight"].shape[1] + id_embeddings_dim_out = state_dict["proj.0.weight"].shape[0] + multiplier = id_embeddings_dim_out // id_embeddings_dim_in + norm_layer = "norm.weight" + cross_attention_dim = state_dict[norm_layer].shape[0] + num_tokens = state_dict["proj.2.weight"].shape[0] // cross_attention_dim + + with init_context(): + image_projection = IPAdapterFaceIDImageProjection( + cross_attention_dim=cross_attention_dim, + image_embed_dim=id_embeddings_dim_in, + mult=multiplier, + num_tokens=num_tokens, + ) + + for key, value in state_dict.items(): + diffusers_name = key.replace("proj.0", "ff.net.0.proj") + diffusers_name = diffusers_name.replace("proj.2", "ff.net.2") + updated_state_dict[diffusers_name] = value + else: # IP-Adapter Plus num_image_text_embeds = state_dict["latents"].shape[1] @@ -847,6 +933,7 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor ) attn_procs[name] = attn_processor_class() + else: attn_processor_class = ( IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor @@ -859,6 +946,12 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F elif "proj.3.weight" in state_dict["image_proj"]: # IP-Adapter Full Face num_image_text_embeds += [257] # 256 CLIP tokens + 1 CLS token + elif "perceiver_resampler.proj_in.weight" in state_dict["image_proj"]: + # IP-Adapter Face ID Plus + num_image_text_embeds += [4] + elif "norm.weight" in state_dict["image_proj"]: + # IP-Adapter Face ID + num_image_text_embeds += [4] else: # IP-Adapter Plus num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]] @@ -910,6 +1003,59 @@ def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False): self.to(dtype=self.dtype, device=self.device) + def _load_ip_adapter_loras(self, state_dicts): + lora_dicts = {} + for key_id, name in enumerate(self.attn_processors.keys()): + for i, state_dict in enumerate(state_dicts): + if f"{key_id}.to_k_lora.down.weight" in state_dict["ip_adapter"]: + if i not in lora_dicts: + lora_dicts[i] = {} + lora_dicts[i].update( + { + f"unet.{name}.to_k_lora.down.weight": state_dict["ip_adapter"][ + f"{key_id}.to_k_lora.down.weight" + ] + } + ) + lora_dicts[i].update( + { + f"unet.{name}.to_q_lora.down.weight": state_dict["ip_adapter"][ + f"{key_id}.to_q_lora.down.weight" + ] + } + ) + lora_dicts[i].update( + { + f"unet.{name}.to_v_lora.down.weight": state_dict["ip_adapter"][ + f"{key_id}.to_v_lora.down.weight" + ] + } + ) + lora_dicts[i].update( + { + f"unet.{name}.to_out_lora.down.weight": state_dict["ip_adapter"][ + f"{key_id}.to_out_lora.down.weight" + ] + } + ) + lora_dicts[i].update( + {f"unet.{name}.to_k_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.up.weight"]} + ) + lora_dicts[i].update( + {f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.up.weight"]} + ) + lora_dicts[i].update( + {f"unet.{name}.to_v_lora.up.weight": 
state_dict["ip_adapter"][f"{key_id}.to_v_lora.up.weight"]} + ) + lora_dicts[i].update( + { + f"unet.{name}.to_out_lora.up.weight": state_dict["ip_adapter"][ + f"{key_id}.to_out_lora.up.weight" + ] + } + ) + return lora_dicts + class FromOriginalUNetMixin: """ diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 91bbd58fa025..ced520bb8204 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -472,6 +472,22 @@ def forward(self, image_embeds: torch.FloatTensor): return self.norm(self.ff(image_embeds)) +class IPAdapterFaceIDImageProjection(nn.Module): + def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1): + super().__init__() + from .attention import FeedForward + + self.num_tokens = num_tokens + self.cross_attention_dim = cross_attention_dim + self.ff = FeedForward(image_embed_dim, cross_attention_dim * num_tokens, mult=mult, activation_fn="gelu") + self.norm = nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds: torch.FloatTensor): + x = self.ff(image_embeds) + x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) + return self.norm(x) + + class CombinedTimestepLabelEmbeddings(nn.Module): def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1): super().__init__() @@ -794,13 +810,14 @@ class IPAdapterPlusImageProjection(nn.Module): """Resampler of IP-Adapter Plus. Args: - ---- embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels, that is the same number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024. - hidden_dims (int): The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults + hidden_dims (int): + The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads. - Defaults to 16. num_queries (int): The number of queries. Defaults to 8. ffn_ratio (float): The expansion ratio + Defaults to 16. num_queries (int): + The number of queries. Defaults to 8. ffn_ratio (float): The expansion ratio of feedforward network hidden layer channels. Defaults to 4. """ @@ -851,11 +868,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass. Args: - ---- x (torch.Tensor): Input Tensor. - Returns: - ------- torch.Tensor: Output Tensor. 
""" latents = self.latents.repeat(x.size(0), 1, 1) @@ -875,6 +889,119 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.norm_out(latents) +class IPAdapterPlusImageProjectionBlock(nn.Module): + def __init__( + self, + embed_dims: int = 768, + dim_head: int = 64, + heads: int = 16, + ffn_ratio: float = 4, + ) -> None: + super().__init__() + from .attention import FeedForward + + self.ln0 = nn.LayerNorm(embed_dims) + self.ln1 = nn.LayerNorm(embed_dims) + self.attn = Attention( + query_dim=embed_dims, + dim_head=dim_head, + heads=heads, + out_bias=False, + ) + self.ff = nn.Sequential( + nn.LayerNorm(embed_dims), + FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False), + ) + + def forward(self, x, latents, residual): + encoder_hidden_states = self.ln0(x) + latents = self.ln1(latents) + encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2) + latents = self.attn(latents, encoder_hidden_states) + residual + latents = self.ff(latents) + latents + return latents + + +class IPAdapterFaceIDPlusImageProjection(nn.Module): + """FacePerceiverResampler of IP-Adapter Plus. + + Args: + embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels, + that is the same + number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024. + hidden_dims (int): + The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults + to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads. + Defaults to 16. num_tokens (int): Number of tokens num_queries (int): The number of queries. Defaults to 8. + ffn_ratio (float): The expansion ratio of feedforward network hidden + layer channels. Defaults to 4. + ffproj_ratio (float): The expansion ratio of feedforward network hidden + layer channels (for ID embeddings). Defaults to 4. + """ + + def __init__( + self, + embed_dims: int = 768, + output_dims: int = 768, + hidden_dims: int = 1280, + id_embeddings_dim: int = 512, + depth: int = 4, + dim_head: int = 64, + heads: int = 16, + num_tokens: int = 4, + num_queries: int = 8, + ffn_ratio: float = 4, + ffproj_ratio: int = 2, + ) -> None: + super().__init__() + from .attention import FeedForward + + self.num_tokens = num_tokens + self.embed_dim = embed_dims + self.clip_embeds = None + self.shortcut = False + self.shortcut_scale = 1.0 + + self.proj = FeedForward(id_embeddings_dim, embed_dims * num_tokens, activation_fn="gelu", mult=ffproj_ratio) + self.norm = nn.LayerNorm(embed_dims) + + self.proj_in = nn.Linear(hidden_dims, embed_dims) + + self.proj_out = nn.Linear(embed_dims, output_dims) + self.norm_out = nn.LayerNorm(output_dims) + + self.layers = nn.ModuleList( + [IPAdapterPlusImageProjectionBlock(embed_dims, dim_head, heads, ffn_ratio) for _ in range(depth)] + ) + + def forward(self, id_embeds: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + id_embeds (torch.Tensor): Input Tensor (ID embeds). + Returns: + torch.Tensor: Output Tensor. 
+ """ + id_embeds = id_embeds.to(self.clip_embeds.dtype) + id_embeds = self.proj(id_embeds) + id_embeds = id_embeds.reshape(-1, self.num_tokens, self.embed_dim) + id_embeds = self.norm(id_embeds) + latents = id_embeds + + clip_embeds = self.proj_in(self.clip_embeds) + x = clip_embeds.reshape(-1, clip_embeds.shape[2], clip_embeds.shape[3]) + + for block in self.layers: + residual = latents + latents = block(x, latents, residual) + + latents = self.proj_out(latents) + out = self.norm_out(latents) + if self.shortcut: + out = id_embeds + self.shortcut_scale * out + return out + + class MultiIPAdapterImageProjection(nn.Module): def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]): super().__init__() diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py index bf6b7fe99b7f..1b8a998cfd66 100644 --- a/tests/models/unets/test_models_unet_2d_condition.py +++ b/tests/models/unets/test_models_unet_2d_condition.py @@ -30,7 +30,7 @@ IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, ) -from diffusers.models.embeddings import ImageProjection, IPAdapterPlusImageProjection +from diffusers.models.embeddings import ImageProjection, IPAdapterFaceIDImageProjection, IPAdapterPlusImageProjection from diffusers.utils import logging from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import ( @@ -190,6 +190,64 @@ def create_ip_adapter_plus_state_dict(model): return ip_state_dict +def create_ip_adapter_faceid_state_dict(model): + # "ip_adapter" (cross-attention weights) + # no LoRA weights + ip_cross_attn_state_dict = {} + key_id = 1 + + for name in model.attn_processors.keys(): + cross_attention_dim = ( + None if name.endswith("attn1.processor") or "motion_module" in name else model.config.cross_attention_dim + ) + + if name.startswith("mid_block"): + hidden_size = model.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(model.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = model.config.block_out_channels[block_id] + + if cross_attention_dim is not None: + sd = IPAdapterAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0 + ).state_dict() + ip_cross_attn_state_dict.update( + { + f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"], + f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"], + } + ) + + key_id += 2 + + # "image_proj" (ImageProjection layer weights) + cross_attention_dim = model.config["cross_attention_dim"] + image_projection = IPAdapterFaceIDImageProjection( + cross_attention_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, mult=2, num_tokens=4 + ) + + ip_image_projection_state_dict = {} + sd = image_projection.state_dict() + ip_image_projection_state_dict.update( + { + "proj.0.weight": sd["ff.net.0.proj.weight"], + "proj.0.bias": sd["ff.net.0.proj.bias"], + "proj.2.weight": sd["ff.net.2.weight"], + "proj.2.bias": sd["ff.net.2.bias"], + "norm.weight": sd["norm.weight"], + "norm.bias": sd["norm.bias"], + } + ) + + del sd + ip_state_dict = {} + ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict}) + return ip_state_dict + + def create_custom_diffusion_layers(model, mock_weights: bool = True): train_kv = True train_q_out = True diff --git 
a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py index a4217e73a9a2..ef70baa05f19 100644 --- a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py +++ b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py @@ -37,6 +37,7 @@ from diffusers.utils.testing_utils import ( enable_full_determinism, is_flaky, + load_pt, numpy_cosine_similarity_distance, require_torch_gpu, slow, @@ -306,6 +307,35 @@ def test_multi(self): max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice) assert max_diff < 5e-4 + def test_text_to_image_face_id(self): + pipeline = StableDiffusionPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", safety_checker=None, torch_dtype=self.dtype + ) + pipeline.to(torch_device) + pipeline.load_ip_adapter( + "h94/IP-Adapter-FaceID", + subfolder=None, + weight_name="ip-adapter-faceid_sd15.bin", + image_encoder_folder=None, + ) + pipeline.set_ip_adapter_scale(0.7) + + inputs = self.get_dummy_inputs() + id_embeds = load_pt("https://huggingface.co/datasets/fabiorigano/testing-images/resolve/main/ai_face2.ipadpt")[ + 0 + ] + id_embeds = id_embeds.reshape((2, 1, 1, 512)) + inputs["ip_adapter_image_embeds"] = [id_embeds] + inputs["ip_adapter_image"] = None + images = pipeline(**inputs).images + image_slice = images[0, :3, :3, -1].flatten() + + expected_slice = np.array( + [0.32714844, 0.3239746, 0.3466797, 0.31835938, 0.30004883, 0.3251953, 0.3215332, 0.3552246, 0.3251953] + ) + max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice) + assert max_diff < 5e-4 + @slow @require_torch_gpu diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 2ff53374be56..b5dde78f11b2 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -48,7 +48,10 @@ get_autoencoder_tiny_config, get_consistency_vae_config, ) -from ..models.unets.test_models_unet_2d_condition import create_ip_adapter_state_dict +from ..models.unets.test_models_unet_2d_condition import ( + create_ip_adapter_faceid_state_dict, + create_ip_adapter_state_dict, +) from ..others.test_utils import TOKEN, USER, is_staging_test @@ -239,6 +242,9 @@ def test_pipeline_signature(self): def _get_dummy_image_embeds(self, cross_attention_dim: int = 32): return torch.randn((2, 1, cross_attention_dim), device=torch_device) + def _get_dummy_faceid_image_embeds(self, cross_attention_dim: int = 32): + return torch.randn((2, 1, 1, cross_attention_dim), device=torch_device) + def _get_dummy_masks(self, input_size: int = 64): _masks = torch.zeros((1, 1, input_size, input_size), device=torch_device) _masks[0, :, :, : int(input_size / 2)] = 1 @@ -416,6 +422,46 @@ def test_ip_adapter_masks(self, expected_max_diff: float = 1e-4): max_diff_with_adapter_scale, 1e-3, "Output with ip-adapter must be different from normal inference" ) + def test_ip_adapter_faceid(self, expected_max_diff: float = 1e-4): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components).to(torch_device) + pipe.set_progress_bar_config(disable=None) + cross_attention_dim = pipe.unet.config.get("cross_attention_dim", 32) + + # forward pass without ip adapter + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) + output_without_adapter = pipe(**inputs)[0] + output_without_adapter = output_without_adapter[0, -3:, -3:, -1].flatten() + + adapter_state_dict = create_ip_adapter_faceid_state_dict(pipe.unet) + 
pipe.unet._load_ip_adapter_weights(adapter_state_dict) + + # forward pass with single ip adapter, but scale=0 which should have no effect + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) + inputs["ip_adapter_image_embeds"] = [self._get_dummy_faceid_image_embeds(cross_attention_dim)] + pipe.set_ip_adapter_scale(0.0) + output_without_adapter_scale = pipe(**inputs)[0] + output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten() + + # forward pass with single ip adapter, but with scale of adapter weights + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) + inputs["ip_adapter_image_embeds"] = [self._get_dummy_faceid_image_embeds(cross_attention_dim)] + pipe.set_ip_adapter_scale(42.0) + output_with_adapter_scale = pipe(**inputs)[0] + output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten() + + max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max() + max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max() + + self.assertLess( + max_diff_without_adapter_scale, + expected_max_diff, + "Output without ip-adapter must be same as normal inference", + ) + self.assertGreater( + max_diff_with_adapter_scale, 1e-3, "Output with ip-adapter must be different from normal inference" + ) + class PipelineLatentTesterMixin: """
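The `IPAdapterFaceIDImageProjection` layer added to `src/diffusers/models/embeddings.py` above turns a single ID embedding into `num_tokens` cross-attention tokens via a feed-forward block and a reshape. A small shape-check sketch, assuming a diffusers build that includes this patch and purely illustrative dimensions:

```py
import torch
from diffusers.models.embeddings import IPAdapterFaceIDImageProjection

# 512-dim insightface ID embeddings -> 4 tokens of width 768
proj = IPAdapterFaceIDImageProjection(
    image_embed_dim=512, cross_attention_dim=768, mult=2, num_tokens=4
)
id_embeds = torch.randn(2, 512)  # batch of two face ID embeddings
tokens = proj(id_embeds)
print(tokens.shape)  # torch.Size([2, 4, 768])
```

This mirrors how `create_ip_adapter_faceid_state_dict` builds its `image_proj` weights in the UNet tests above.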