Skip to content

Commit

Permalink
Move IP Adapter Face ID to core (#7186)
Browse files Browse the repository at this point in the history
* Switch to peft and multi proj layers

* Move Face ID loading and inference to core

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
  • Loading branch information
fabiorigano and sayakpaul authored Apr 19, 2024
1 parent e23c27e commit b5c8b55
Show file tree
Hide file tree
Showing 10 changed files with 592 additions and 375 deletions.
56 changes: 52 additions & 4 deletions docs/source/en/using-diffusers/ip_adapter.md
Original file line number Diff line number Diff line change
Expand Up @@ -362,14 +362,12 @@ IP-Adapter's image prompting and compatibility with other adapters and models ma

### Face model

Generating accurate faces is challenging because they are complex and nuanced. Diffusers supports two IP-Adapter checkpoints specifically trained to generate faces:
Generating accurate faces is challenging because they are complex and nuanced. Diffusers supports two IP-Adapter checkpoints specifically trained to generate faces from the [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter) repository:

* [ip-adapter-full-face_sd15.safetensors](https://huggingface.co/h94/IP-Adapter/blob/main/models/ip-adapter-full-face_sd15.safetensors) is conditioned with images of cropped faces and removed backgrounds
* [ip-adapter-plus-face_sd15.safetensors](https://huggingface.co/h94/IP-Adapter/blob/main/models/ip-adapter-plus-face_sd15.safetensors) uses patch embeddings and is conditioned with images of cropped faces

> [!TIP]
>
> [IP-Adapter-FaceID](https://huggingface.co/h94/IP-Adapter-FaceID) is a face-specific IP-Adapter trained with face ID embeddings instead of CLIP image embeddings, allowing you to generate more consistent faces in different contexts and styles. Try out this popular [community pipeline](https://github.com/huggingface/diffusers/tree/main/examples/community#ip-adapter-face-id) and see how it compares to the other face IP-Adapters.
Additionally, Diffusers supports all IP-Adapter checkpoints trained with face embeddings extracted by `insightface` face models. Supported models are from the [h94/IP-Adapter-FaceID](https://huggingface.co/h94/IP-Adapter-FaceID) repository.

For face models, use the [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter) checkpoint. It is also recommended to use [`DDIMScheduler`] or [`EulerDiscreteScheduler`] for face models.

Expand Down Expand Up @@ -411,6 +409,56 @@ image
</div>
</div>

To use IP-Adapter FaceID models, first extract face embeddings with `insightface`. Then pass the list of tensors to the pipeline as `ip_adapter_image_embeds`.

```py
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
from diffusers.utils import load_image
from insightface.app import FaceAnalysis

pipeline = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
torch_dtype=torch.float16,
).to("cuda")
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
pipeline.load_ip_adapter("h94/IP-Adapter-FaceID", subfolder=None, weight_name="ip-adapter-faceid_sd15.bin", image_encoder_folder=None)
pipeline.set_ip_adapter_scale(0.6)

image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_mask_girl1.png")

ref_images_embeds = []
app = FaceAnalysis(name="buffalo_l", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
app.prepare(ctx_id=0, det_size=(640, 640))
image = cv2.cvtColor(np.asarray(image), cv2.COLOR_BGR2RGB)
faces = app.get(image)
image = torch.from_numpy(faces[0].normed_embedding)
ref_images_embeds.append(image.unsqueeze(0))
ref_images_embeds = torch.stack(ref_images_embeds, dim=0).unsqueeze(0)
neg_ref_images_embeds = torch.zeros_like(ref_images_embeds)
id_embeds = torch.cat([neg_ref_images_embeds, ref_images_embeds]).to(dtype=torch.float16, device="cuda"))

generator = torch.Generator(device="cpu").manual_seed(42)

images = pipeline(
prompt="A photo of a girl",
ip_adapter_image_embeds=[id_embeds],
negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
num_inference_steps=20, num_images_per_prompt=1,
generator=generator
).images
```

Both IP-Adapter FaceID Plus and Plus v2 models require CLIP image embeddings. You can prepare face embeddings as shown previously, then you can extract and pass CLIP embeddings to the hidden image projection layers.

```py
clip_embeds = pipeline.prepare_ip_adapter_image_embeds([ip_adapter_images], None, torch.device("cuda"), num_images, True)[0]

pipeline.unet.encoder_hid_proj.image_projection_layers[0].clip_embeds = clip_embeds.to(dtype=torch.float16)
pipeline.unet.encoder_hid_proj.image_projection_layers[0].shortcut = False # True if Plus v2
```


### Multi IP-Adapter

More than one IP-Adapter can be used at the same time to generate specific images in more diverse styles. For example, you can use IP-Adapter-Face to generate consistent faces and characters, and IP-Adapter Plus to generate those faces in a specific style.
Expand Down
37 changes: 37 additions & 0 deletions docs/source/en/using-diffusers/loading_adapters.md
Original file line number Diff line number Diff line change
Expand Up @@ -320,3 +320,40 @@ pipeline = AutoPipelineForText2Image.from_pretrained(

pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter-plus_sdxl_vit-h.safetensors")
```

### IP-Adapter Face ID models

The IP-Adapter FaceID models are experimental IP Adapters that use image embeddings generated by `insightface` instead of CLIP image embeddings. Some of these models also use LoRA to improve ID consistency.
You need to install `insightface` and all its requirements to use these models.

<Tip warning={true}>
As InsightFace pretrained models are available for non-commercial research purposes, IP-Adapter-FaceID models are released exclusively for research purposes and are not intended for commercial use.
</Tip>

```py
pipeline = AutoPipelineForText2Image.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16
).to("cuda")

pipeline.load_ip_adapter("h94/IP-Adapter-FaceID", subfolder=None, weight_name="ip-adapter-faceid_sdxl.bin", image_encoder_folder=None)
```

If you want to use one of the two IP-Adapter FaceID Plus models, you must also load the CLIP image encoder, as this models use both `insightface` and CLIP image embeddings to achieve better photorealism.

```py
from transformers import CLIPVisionModelWithProjection

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
"laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
torch_dtype=torch.float16,
)

pipeline = AutoPipelineForText2Image.from_pretrained(
"runwayml/stable-diffusion-v1-5",
image_encoder=image_encoder,
torch_dtype=torch.float16
).to("cuda")

pipeline.load_ip_adapter("h94/IP-Adapter-FaceID", subfolder=None, weight_name="ip-adapter-faceid-plus_sd15.bin")
```
2 changes: 0 additions & 2 deletions examples/community/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3819,12 +3819,10 @@ export_to_gif(frames, "animation.gif")
IP Adapter FaceID is an experimental IP Adapter model that uses image embeddings generated by `insightface`, so no image encoder needs to be loaded.
You need to install `insightface` and all its requirements to use this model.
You must pass the image embedding tensor as `image_embeds` to the StableDiffusionPipeline instead of `ip_adapter_image`.
You have to disable PEFT BACKEND in order to load weights.
You can find more results [here](https://github.com/huggingface/diffusers/pull/6276).

```py
import diffusers
diffusers.utils.USE_PEFT_BACKEND = False
import torch
from diffusers.utils import load_image
import cv2
Expand Down
Loading

0 comments on commit b5c8b55

Please sign in to comment.