[Pipeline] animatediff + vid2vid + controlnet (#9337)
* add animatediff + vid2vide + controlnet

* post tests fixes

* PR discussion fixes

* update docs

* change input video to links on HF + update an example

* make quality fix

* fix ip adapter test

* fix ip adapter test input

* update ip adapter test
reallyigor authored Sep 9, 2024
1 parent 485b8bb commit a7361dc
Showing 7 changed files with 1,995 additions and 0 deletions.
98 changes: 98 additions & 0 deletions docs/source/en/api/pipelines/animatediff.md
@@ -29,6 +29,7 @@ The abstract of the paper is the following:
| [AnimateDiffSparseControlNetPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py) | *Controlled Video-to-Video Generation with AnimateDiff using SparseCtrl* |
| [AnimateDiffSDXLPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py) | *Video-to-Video Generation with AnimateDiff* |
| [AnimateDiffVideoToVideoPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py) | *Video-to-Video Generation with AnimateDiff* |
| [AnimateDiffVideoToVideoControlNetPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py) | *Video-to-Video Generation with AnimateDiff using ControlNet* |

## Available checkpoints

@@ -518,6 +519,97 @@ Here are some sample outputs:
</tr>
</table>



### AnimateDiffVideoToVideoControlNetPipeline

AnimateDiff can be used together with ControlNets to enhance video-to-video generation by allowing for precise control over the output. ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala, and lets you condition Stable Diffusion on an additional control image so that spatial information is preserved throughout the video.

This pipeline allows you to condition your generation both on the original video and on a sequence of control images.

```python
import torch
from PIL import Image
from tqdm.auto import tqdm

from controlnet_aux.processor import OpenposeDetector
from diffusers import AnimateDiffVideoToVideoControlNetPipeline
from diffusers.utils import export_to_gif, load_video
from diffusers import AutoencoderKL, ControlNetModel, MotionAdapter, LCMScheduler

# Load the ControlNet
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16)
# Load the motion adapter
motion_adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM")
# Load SD 1.5 based finetuned model
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
pipe = AnimateDiffVideoToVideoControlNetPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V5.1_noVAE",
    motion_adapter=motion_adapter,
    controlnet=controlnet,
    vae=vae,
).to(device="cuda", dtype=torch.float16)

# Enable LCM to speed up inference
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
pipe.load_lora_weights("wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora")
pipe.set_adapters(["lcm-lora"], [0.8])

video = load_video("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/dance.gif")
video = [frame.convert("RGB") for frame in video]

prompt = "astronaut in space, dancing"
negative_prompt = "bad quality, worst quality, jpeg artifacts, ugly"

# Create controlnet preprocessor
open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators").to("cuda")

# Preprocess controlnet images
conditioning_frames = []
for frame in tqdm(video):
    conditioning_frames.append(open_pose(frame))

strength = 0.8
with torch.inference_mode():
    video = pipe(
        video=video,
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=10,
        guidance_scale=2.0,
        controlnet_conditioning_scale=0.75,
        conditioning_frames=conditioning_frames,
        strength=strength,
        generator=torch.Generator().manual_seed(42),
    ).frames[0]

video = [frame.resize(conditioning_frames[0].size) for frame in video]
export_to_gif(video, "animatediff_vid2vid_controlnet.gif", fps=8)
```
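On GPUs with limited memory, the usual diffusers memory helpers can be layered on top of the example above. A minimal, optional sketch, assuming the pipeline exposes the standard offloading and VAE-slicing helpers and reusing the `pipe` object created above:

```python
# Optional memory savings, reusing the `pipe` object from the example above.
# If you enable CPU offloading, skip the explicit `.to(device="cuda", ...)` call
# and let the pipeline manage device placement instead.
pipe.enable_model_cpu_offload()  # move submodules to the GPU only while they are used
pipe.enable_vae_slicing()        # decode latents in slices to lower peak memory
```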

Here are some sample outputs:

<table align="center">
<tr>
<th align="center">Source Video</th>
<th align="center">Output Video</th>
</tr>
<tr>
<td align="center">
anime girl, dancing
<br />
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/dance.gif" alt="anime girl, dancing" />
</td>
<td align="center">
astronaut in space, dancing
<br/>
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff_vid2vid_controlnet.gif" alt="astronaut in space, dancing" />
</td>
</tr>
</table>

**The lighting and composition were transferred from the Source Video.**

## Using Motion LoRAs

Motion LoRAs are a collection of LoRAs that work with the `guoyww/animatediff-motion-adapter-v1-5-2` checkpoint. These LoRAs are responsible for adding specific types of motion to the animations.
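For instance, a Motion LoRA is attached with the regular `load_lora_weights` API on top of a pipeline that uses the motion adapter. A minimal sketch, where the base model and LoRA repositories are illustrative Hub checkpoints:

```python
import torch

from diffusers import AnimateDiffPipeline, MotionAdapter

# Minimal sketch: attach a Motion LoRA (here, a "zoom out" camera motion) to an
# AnimateDiff pipeline. The base model and LoRA repositories are illustrative examples.
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
pipe = AnimateDiffPipeline.from_pretrained(
    "emilianJR/epiCRealism", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")

# Motion LoRAs load like any other LoRA weights.
pipe.load_lora_weights("guoyww/animatediff-motion-lora-zoom-out", adapter_name="zoom-out")
```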
@@ -866,6 +958,12 @@ pipe = AnimateDiffPipeline.from_pretrained("emilianJR/epiCRealism", motion_adapt
- all
- __call__

## AnimateDiffVideoToVideoControlNetPipeline

[[autodoc]] AnimateDiffVideoToVideoControlNetPipeline
- all
- __call__

## AnimateDiffPipelineOutput

[[autodoc]] pipelines.animatediff.AnimateDiffPipelineOutput
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
@@ -245,6 +245,7 @@
"AnimateDiffPipeline",
"AnimateDiffSDXLPipeline",
"AnimateDiffSparseControlNetPipeline",
"AnimateDiffVideoToVideoControlNetPipeline",
"AnimateDiffVideoToVideoPipeline",
"AudioLDM2Pipeline",
"AudioLDM2ProjectionModel",
@@ -694,6 +695,7 @@
AnimateDiffPipeline,
AnimateDiffSDXLPipeline,
AnimateDiffSparseControlNetPipeline,
AnimateDiffVideoToVideoControlNetPipeline,
AnimateDiffVideoToVideoPipeline,
AudioLDM2Pipeline,
AudioLDM2ProjectionModel,
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/__init__.py
@@ -123,6 +123,7 @@
"AnimateDiffSDXLPipeline",
"AnimateDiffSparseControlNetPipeline",
"AnimateDiffVideoToVideoPipeline",
"AnimateDiffVideoToVideoControlNetPipeline",
]
_import_structure["flux"] = [
"FluxControlNetPipeline",
@@ -449,6 +450,7 @@
AnimateDiffPipeline,
AnimateDiffSDXLPipeline,
AnimateDiffSparseControlNetPipeline,
AnimateDiffVideoToVideoControlNetPipeline,
AnimateDiffVideoToVideoPipeline,
)
from .audioldm import AudioLDMPipeline
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/animatediff/__init__.py
@@ -26,6 +26,7 @@
_import_structure["pipeline_animatediff_sdxl"] = ["AnimateDiffSDXLPipeline"]
_import_structure["pipeline_animatediff_sparsectrl"] = ["AnimateDiffSparseControlNetPipeline"]
_import_structure["pipeline_animatediff_video2video"] = ["AnimateDiffVideoToVideoPipeline"]
_import_structure["pipeline_animatediff_video2video_controlnet"] = ["AnimateDiffVideoToVideoControlNetPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -40,6 +41,7 @@
from .pipeline_animatediff_sdxl import AnimateDiffSDXLPipeline
from .pipeline_animatediff_sparsectrl import AnimateDiffSparseControlNetPipeline
from .pipeline_animatediff_video2video import AnimateDiffVideoToVideoPipeline
from .pipeline_animatediff_video2video_controlnet import AnimateDiffVideoToVideoControlNetPipeline
from .pipeline_output import AnimateDiffPipelineOutput

else:
