diff --git a/docs/source/en/api/pipelines/animatediff.md b/docs/source/en/api/pipelines/animatediff.md
index bfd6ab973d5e..7cacad87d78c 100644
--- a/docs/source/en/api/pipelines/animatediff.md
+++ b/docs/source/en/api/pipelines/animatediff.md
@@ -29,6 +29,7 @@ The abstract of the paper is the following:
| [AnimateDiffSparseControlNetPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py) | *Controlled Video-to-Video Generation with AnimateDiff using SparseCtrl* |
| [AnimateDiffSDXLPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py) | *Video-to-Video Generation with AnimateDiff* |
| [AnimateDiffVideoToVideoPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py) | *Video-to-Video Generation with AnimateDiff* |
+| [AnimateDiffVideoToVideoControlNetPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py) | *Video-to-Video Generation with AnimateDiff using ControlNet* |
## Available checkpoints
@@ -518,6 +519,97 @@ Here are some sample outputs:
+
+
+### AnimateDiffVideoToVideoControlNetPipeline
+
+AnimateDiff can be used together with ControlNets to enhance video-to-video generation by allowing for precise control over the output. ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala, and allows you to condition Stable Diffusion with an additional control image to ensure that the spatial information is preserved throughout the video.
+
+This pipeline allows you to condition your generation both on the original video and on a sequence of control images.
+
+```python
+import torch
+from PIL import Image
+from tqdm.auto import tqdm
+
+from controlnet_aux.processor import OpenposeDetector
+from diffusers import AnimateDiffVideoToVideoControlNetPipeline
+from diffusers.utils import export_to_gif, load_video
+from diffusers import AutoencoderKL, ControlNetModel, MotionAdapter, LCMScheduler
+
+# Load the ControlNet
+controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16)
+# Load the motion adapter
+motion_adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM")
+# Load SD 1.5 based finetuned model
+vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
+pipe = AnimateDiffVideoToVideoControlNetPipeline.from_pretrained(
+ "SG161222/Realistic_Vision_V5.1_noVAE",
+ motion_adapter=motion_adapter,
+ controlnet=controlnet,
+ vae=vae,
+).to(device="cuda", dtype=torch.float16)
+
+# Enable LCM to speed up inference
+pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
+pipe.load_lora_weights("wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora")
+pipe.set_adapters(["lcm-lora"], [0.8])
+
+video = load_video("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/dance.gif")
+video = [frame.convert("RGB") for frame in video]
+
+prompt = "astronaut in space, dancing"
+negative_prompt = "bad quality, worst quality, jpeg artifacts, ugly"
+
+# Create controlnet preprocessor
+open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators").to("cuda")
+
+# Preprocess controlnet images
+conditioning_frames = []
+for frame in tqdm(video):
+ conditioning_frames.append(open_pose(frame))
+
+strength = 0.8
+with torch.inference_mode():
+ video = pipe(
+ video=video,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ num_inference_steps=10,
+ guidance_scale=2.0,
+ controlnet_conditioning_scale=0.75,
+ conditioning_frames=conditioning_frames,
+ strength=strength,
+ generator=torch.Generator().manual_seed(42),
+ ).frames[0]
+
+video = [frame.resize(conditioning_frames[0].size) for frame in video]
+export_to_gif(video, f"animatediff_vid2vid_controlnet.gif", fps=8)
+```
+
+Here are some sample outputs:
+
+
+
+ Source Video |
+ Output Video |
+
+
+
+ anime girl, dancing
+
+
+ |
+
+ astronaut in space, dancing
+
+
+ |
+
+
+
+**The lights and composition were transferred from the Source Video.**
+
## Using Motion LoRAs
Motion LoRAs are a collection of LoRAs that work with the `guoyww/animatediff-motion-adapter-v1-5-2` checkpoint. These LoRAs are responsible for adding specific types of motion to the animations.
@@ -866,6 +958,12 @@ pipe = AnimateDiffPipeline.from_pretrained("emilianJR/epiCRealism", motion_adapt
- all
- __call__
+## AnimateDiffVideoToVideoControlNetPipeline
+
+[[autodoc]] AnimateDiffVideoToVideoControlNetPipeline
+ - all
+ - __call__
+
## AnimateDiffPipelineOutput
[[autodoc]] pipelines.animatediff.AnimateDiffPipelineOutput
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index af28b383b563..5b505b6a1f3a 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -245,6 +245,7 @@
"AnimateDiffPipeline",
"AnimateDiffSDXLPipeline",
"AnimateDiffSparseControlNetPipeline",
+ "AnimateDiffVideoToVideoControlNetPipeline",
"AnimateDiffVideoToVideoPipeline",
"AudioLDM2Pipeline",
"AudioLDM2ProjectionModel",
@@ -694,6 +695,7 @@
AnimateDiffPipeline,
AnimateDiffSDXLPipeline,
AnimateDiffSparseControlNetPipeline,
+ AnimateDiffVideoToVideoControlNetPipeline,
AnimateDiffVideoToVideoPipeline,
AudioLDM2Pipeline,
AudioLDM2ProjectionModel,
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index ad7ea2872ac5..e4d37a905b86 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -123,6 +123,7 @@
"AnimateDiffSDXLPipeline",
"AnimateDiffSparseControlNetPipeline",
"AnimateDiffVideoToVideoPipeline",
+ "AnimateDiffVideoToVideoControlNetPipeline",
]
_import_structure["flux"] = [
"FluxControlNetPipeline",
@@ -449,6 +450,7 @@
AnimateDiffPipeline,
AnimateDiffSDXLPipeline,
AnimateDiffSparseControlNetPipeline,
+ AnimateDiffVideoToVideoControlNetPipeline,
AnimateDiffVideoToVideoPipeline,
)
from .audioldm import AudioLDMPipeline
diff --git a/src/diffusers/pipelines/animatediff/__init__.py b/src/diffusers/pipelines/animatediff/__init__.py
index 3ee72bc44003..d916abf2d85d 100644
--- a/src/diffusers/pipelines/animatediff/__init__.py
+++ b/src/diffusers/pipelines/animatediff/__init__.py
@@ -26,6 +26,7 @@
_import_structure["pipeline_animatediff_sdxl"] = ["AnimateDiffSDXLPipeline"]
_import_structure["pipeline_animatediff_sparsectrl"] = ["AnimateDiffSparseControlNetPipeline"]
_import_structure["pipeline_animatediff_video2video"] = ["AnimateDiffVideoToVideoPipeline"]
+ _import_structure["pipeline_animatediff_video2video_controlnet"] = ["AnimateDiffVideoToVideoControlNetPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -40,6 +41,7 @@
from .pipeline_animatediff_sdxl import AnimateDiffSDXLPipeline
from .pipeline_animatediff_sparsectrl import AnimateDiffSparseControlNetPipeline
from .pipeline_animatediff_video2video import AnimateDiffVideoToVideoPipeline
+ from .pipeline_animatediff_video2video_controlnet import AnimateDiffVideoToVideoControlNetPipeline
from .pipeline_output import AnimateDiffPipelineOutput
else:
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py
new file mode 100644
index 000000000000..1d26f95a2f58
--- /dev/null
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py
@@ -0,0 +1,1341 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+
+from ...image_processor import PipelineImageInput
+from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel, UNetMotionModel
+from ...models.lora import adjust_lora_scale_text_encoder
+from ...models.unets.unet_motion_model import MotionAdapter
+from ...schedulers import (
+ DDIMScheduler,
+ DPMSolverMultistepScheduler,
+ EulerAncestralDiscreteScheduler,
+ EulerDiscreteScheduler,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+)
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import is_compiled_module, randn_tensor
+from ...video_processor import VideoProcessor
+from ..controlnet.multicontrolnet import MultiControlNetModel
+from ..free_init_utils import FreeInitMixin
+from ..free_noise_utils import AnimateDiffFreeNoiseMixin
+from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from .pipeline_output import AnimateDiffPipelineOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from PIL import Image
+ >>> from tqdm.auto import tqdm
+
+ >>> from diffusers import AnimateDiffVideoToVideoControlNetPipeline
+ >>> from diffusers.utils import export_to_gif, load_video
+ >>> from diffusers import AutoencoderKL, ControlNetModel, MotionAdapter, LCMScheduler
+
+ >>> controlnet = ControlNetModel.from_pretrained(
+ ... "lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16
+ ... )
+ >>> motion_adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM")
+ >>> vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
+
+ >>> pipe = AnimateDiffVideoToVideoControlNetPipeline.from_pretrained(
+ ... "SG161222/Realistic_Vision_V5.1_noVAE",
+ ... motion_adapter=motion_adapter,
+ ... controlnet=controlnet,
+ ... vae=vae,
+ ... ).to(device="cuda", dtype=torch.float16)
+
+ >>> pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
+ >>> pipe.load_lora_weights(
+ ... "wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora"
+ ... )
+ >>> pipe.set_adapters(["lcm-lora"], [0.8])
+
+ >>> video = load_video(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/dance.gif"
+ ... )
+ >>> video = [frame.convert("RGB") for frame in video]
+
+ >>> from controlnet_aux.processor import OpenposeDetector
+
+ >>> open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators").to("cuda")
+ >>> for frame in tqdm(video):
+ ... conditioning_frames.append(open_pose(frame))
+
+ >>> prompt = "astronaut in space, dancing"
+ >>> negative_prompt = "bad quality, worst quality, jpeg artifacts, ugly"
+
+ >>> strength = 0.8
+ >>> with torch.inference_mode():
+ ... video = pipe(
+ ... video=video,
+ ... prompt=prompt,
+ ... negative_prompt=negative_prompt,
+ ... num_inference_steps=10,
+ ... guidance_scale=2.0,
+ ... controlnet_conditioning_scale=0.75,
+ ... conditioning_frames=conditioning_frames,
+ ... strength=strength,
+ ... generator=torch.Generator().manual_seed(42),
+ ... ).frames[0]
+
+ >>> video = [frame.resize(conditioning_frames[0].size) for frame in video]
+ >>> export_to_gif(video, f"animatediff_vid2vid_controlnet.gif", fps=8)
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class AnimateDiffVideoToVideoControlNetPipeline(
+ DiffusionPipeline,
+ StableDiffusionMixin,
+ TextualInversionLoaderMixin,
+ IPAdapterMixin,
+ StableDiffusionLoraLoaderMixin,
+ FreeInitMixin,
+ AnimateDiffFreeNoiseMixin,
+):
+ r"""
+ Pipeline for video-to-video generation with ControlNet guidance.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ tokenizer (`CLIPTokenizer`):
+ A [`~transformers.CLIPTokenizer`] to tokenize text.
+ unet ([`UNet2DConditionModel`]):
+ A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents.
+ motion_adapter ([`MotionAdapter`]):
+ A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents.
+ controlnet ([`ControlNetModel`] or `List[ControlNetModel]` or `Tuple[ControlNetModel]` or `MultiControlNetModel`):
+ Provides additional conditioning to the `unet` during the denoising process. If you set multiple
+ ControlNets as a list, the outputs from each ControlNet are added together to create one combined
+ additional conditioning.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ """
+
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
+ _optional_components = ["feature_extractor", "image_encoder", "motion_adapter"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ motion_adapter: MotionAdapter,
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
+ scheduler: Union[
+ DDIMScheduler,
+ PNDMScheduler,
+ LMSDiscreteScheduler,
+ EulerDiscreteScheduler,
+ EulerAncestralDiscreteScheduler,
+ DPMSolverMultistepScheduler,
+ ],
+ feature_extractor: CLIPImageProcessor = None,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ ):
+ super().__init__()
+ if isinstance(unet, UNet2DConditionModel):
+ unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
+
+ if isinstance(controlnet, (list, tuple)):
+ controlnet = MultiControlNetModel(controlnet)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ motion_adapter=motion_adapter,
+ controlnet=controlnet,
+ scheduler=scheduler,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.control_video_processor = VideoProcessor(
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
+ )
+
+ # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if prompt is not None and isinstance(prompt, (str, dict)):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ if clip_skip is None:
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+ prompt_embeds = prompt_embeds[0]
+ else:
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
+ )
+ # Access the `hidden_states` first, that contains a tuple of
+ # all the hidden states from the encoder layers. Then index into
+ # the tuple to access the hidden states from the desired layer.
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+ # We also need to apply the final LayerNorm here to not mess with the
+ # representations. The `last_hidden_states` that we typically use for
+ # obtaining the final prompt representations passes through the LayerNorm
+ # layer.
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
+
+ if self.text_encoder is not None:
+ prompt_embeds_dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ prompt_embeds_dtype = self.unet.dtype
+ else:
+ prompt_embeds_dtype = prompt_embeds.dtype
+
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ if self.text_encoder is not None:
+ if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
+
+ return image_embeds, uncond_image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+ ):
+ image_embeds = []
+ if do_classifier_free_guidance:
+ negative_image_embeds = []
+ if ip_adapter_image_embeds is None:
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+ ):
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
+ single_ip_adapter_image, device, 1, output_hidden_state
+ )
+
+ image_embeds.append(single_image_embeds[None, :])
+ if do_classifier_free_guidance:
+ negative_image_embeds.append(single_negative_image_embeds[None, :])
+ else:
+ for single_image_embeds in ip_adapter_image_embeds:
+ if do_classifier_free_guidance:
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+ negative_image_embeds.append(single_negative_image_embeds)
+ image_embeds.append(single_image_embeds)
+
+ ip_adapter_image_embeds = []
+ for i, single_image_embeds in enumerate(image_embeds):
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+ if do_classifier_free_guidance:
+ single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
+
+ single_image_embeds = single_image_embeds.to(device=device)
+ ip_adapter_image_embeds.append(single_image_embeds)
+
+ return ip_adapter_image_embeds
+
+ # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.encode_video
+ def encode_video(self, video, generator, decode_chunk_size: int = 16) -> torch.Tensor:
+ latents = []
+ for i in range(0, len(video), decode_chunk_size):
+ batch_video = video[i : i + decode_chunk_size]
+ batch_video = retrieve_latents(self.vae.encode(batch_video), generator=generator)
+ latents.append(batch_video)
+ return torch.cat(latents)
+
+ # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents
+ def decode_latents(self, latents, decode_chunk_size: int = 16):
+ latents = 1 / self.vae.config.scaling_factor * latents
+
+ batch_size, channels, num_frames, height, width = latents.shape
+ latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
+
+ video = []
+ for i in range(0, latents.shape[0], decode_chunk_size):
+ batch_latents = latents[i : i + decode_chunk_size]
+ batch_latents = self.vae.decode(batch_latents).sample
+ video.append(batch_latents)
+
+ video = torch.cat(video)
+ video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ video = video.float()
+ return video
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ strength,
+ height,
+ width,
+ video=None,
+ conditioning_frames=None,
+ latents=None,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ip_adapter_image=None,
+ ip_adapter_image_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ controlnet_conditioning_scale=1.0,
+ control_guidance_start=0.0,
+ control_guidance_end=1.0,
+ ):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and not isinstance(prompt, (str, list, dict)):
+ raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if video is not None and latents is not None:
+ raise ValueError("Only one of `video` or `latents` should be provided")
+
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+ raise ValueError(
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+ )
+
+ if ip_adapter_image_embeds is not None:
+ if not isinstance(ip_adapter_image_embeds, list):
+ raise ValueError(
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
+ )
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
+ raise ValueError(
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+ )
+
+ if isinstance(self.controlnet, MultiControlNetModel):
+ if isinstance(prompt, list):
+ logger.warning(
+ f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
+ " prompts. The conditionings will be fixed across the prompts."
+ )
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
+ self.controlnet, torch._dynamo.eval_frame.OptimizedModule
+ )
+
+ num_frames = len(video) if latents is None else latents.shape[2]
+
+ if (
+ isinstance(self.controlnet, ControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
+ ):
+ if not isinstance(conditioning_frames, list):
+ raise TypeError(
+ f"For single controlnet, `image` must be of type `list` but got {type(conditioning_frames)}"
+ )
+ if len(conditioning_frames) != num_frames:
+ raise ValueError(f"Excepted image to have length {num_frames} but got {len(conditioning_frames)=}")
+ elif (
+ isinstance(self.controlnet, MultiControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+ ):
+ if not isinstance(conditioning_frames, list) or not isinstance(conditioning_frames[0], list):
+ raise TypeError(
+ f"For multiple controlnets: `image` must be type list of lists but got {type(conditioning_frames)=}"
+ )
+ if len(conditioning_frames[0]) != num_frames:
+ raise ValueError(
+ f"Expected length of image sublist as {num_frames} but got {len(conditioning_frames)=}"
+ )
+ if any(len(img) != len(conditioning_frames[0]) for img in conditioning_frames):
+ raise ValueError("All conditioning frame batches for multicontrolnet must be same size")
+ else:
+ assert False
+
+ # Check `controlnet_conditioning_scale`
+ if (
+ isinstance(self.controlnet, ControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
+ ):
+ if not isinstance(controlnet_conditioning_scale, float):
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
+ elif (
+ isinstance(self.controlnet, MultiControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+ ):
+ if isinstance(controlnet_conditioning_scale, list):
+ if any(isinstance(i, list) for i in controlnet_conditioning_scale):
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
+ elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
+ self.controlnet.nets
+ ):
+ raise ValueError(
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
+ " the same length as the number of controlnets"
+ )
+ else:
+ assert False
+
+ if not isinstance(control_guidance_start, (tuple, list)):
+ control_guidance_start = [control_guidance_start]
+
+ if not isinstance(control_guidance_end, (tuple, list)):
+ control_guidance_end = [control_guidance_end]
+
+ if len(control_guidance_start) != len(control_guidance_end):
+ raise ValueError(
+ f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
+ )
+
+ if isinstance(self.controlnet, MultiControlNetModel):
+ if len(control_guidance_start) != len(self.controlnet.nets):
+ raise ValueError(
+ f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
+ )
+
+ for start, end in zip(control_guidance_start, control_guidance_end):
+ if start >= end:
+ raise ValueError(
+ f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
+ )
+ if start < 0.0:
+ raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
+ if end > 1.0:
+ raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
+
+ # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, timesteps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = timesteps[t_start * self.scheduler.order :]
+
+ return timesteps, num_inference_steps - t_start
+
+ # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.prepare_latents
+ def prepare_latents(
+ self,
+ video: Optional[torch.Tensor] = None,
+ height: int = 64,
+ width: int = 64,
+ num_channels_latents: int = 4,
+ batch_size: int = 1,
+ timestep: Optional[int] = None,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ decode_chunk_size: int = 16,
+ add_noise: bool = False,
+ ) -> torch.Tensor:
+ num_frames = video.shape[1] if latents is None else latents.shape[2]
+ shape = (
+ batch_size,
+ num_channels_latents,
+ num_frames,
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor,
+ )
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ if self.vae.config.force_upcast:
+ video = video.float()
+ self.vae.to(dtype=torch.float32)
+
+ if isinstance(generator, list):
+ if len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ init_latents = [
+ self.encode_video(video[i], generator[i], decode_chunk_size).unsqueeze(0)
+ for i in range(batch_size)
+ ]
+ else:
+ init_latents = [self.encode_video(vid, generator, decode_chunk_size).unsqueeze(0) for vid in video]
+
+ init_latents = torch.cat(init_latents, dim=0)
+
+ # restore vae to original dtype
+ if self.vae.config.force_upcast:
+ self.vae.to(dtype)
+
+ init_latents = init_latents.to(dtype)
+ init_latents = self.vae.config.scaling_factor * init_latents
+
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ error_message = (
+ f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
+ " images (`image`). Please make sure to update your script to pass as many initial images as text prompts"
+ )
+ raise ValueError(error_message)
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ init_latents = torch.cat([init_latents], dim=0)
+
+ noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype)
+ latents = self.scheduler.add_noise(init_latents, noise, timestep).permute(0, 2, 1, 3, 4)
+ else:
+ if shape != latents.shape:
+ # [B, C, F, H, W]
+ raise ValueError(f"`latents` expected to have {shape=}, but found {latents.shape=}")
+
+ latents = latents.to(device, dtype=dtype)
+
+ if add_noise:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self.scheduler.add_noise(latents, noise, timestep)
+
+ return latents
+
+ # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_controlnet.AnimateDiffControlNetPipeline.prepare_video
+ def prepare_conditioning_frames(
+ self,
+ video,
+ width,
+ height,
+ batch_size,
+ num_videos_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ video = self.control_video_processor.preprocess_video(video, height=height, width=width).to(
+ dtype=torch.float32
+ )
+ video = video.permute(0, 2, 1, 3, 4).flatten(0, 1)
+ video_batch_size = video.shape[0]
+
+ if video_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_videos_per_prompt
+
+ video = video.repeat_interleave(repeat_by, dim=0)
+ video = video.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ video = torch.cat([video] * 2)
+
+ return video
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ video: List[List[PipelineImageInput]] = None,
+ prompt: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ enforce_inference_steps: bool = False,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: float = 7.5,
+ strength: float = 0.8,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_videos_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ conditioning_frames: Optional[List[PipelineImageInput]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+ guess_mode: bool = False,
+ control_guidance_start: Union[float, List[float]] = 0.0,
+ control_guidance_end: Union[float, List[float]] = 1.0,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ decode_chunk_size: int = 16,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ video (`List[PipelineImageInput]`):
+ The input video to condition the generation on. Must be a list of images/frames of the video.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated video.
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated video.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality videos at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ strength (`float`, *optional*, defaults to 0.8):
+ Higher strength leads to more differences between original video and generated video.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`. Latents should be of shape
+ `(batch_size, num_channel, num_frames, height, width)`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+ ip_adapter_image: (`PipelineImageInput`, *optional*):
+ Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ conditioning_frames (`List[PipelineImageInput]`, *optional*):
+ The ControlNet input condition to provide guidance to the `unet` for generation. If multiple
+ ControlNets are specified, images must be passed as a list such that each element of the list can be
+ correctly batched for input to a single ControlNet.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`AnimateDiffPipelineOutput`] instead of a plain tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
+ to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
+ the corresponding scale as a list.
+ guess_mode (`bool`, *optional*, defaults to `False`):
+ The ControlNet encoder tries to recognize the content of the input image even if you remove all
+ prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
+ The percentage of total steps at which the ControlNet starts applying.
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The percentage of total steps at which the ControlNet stops applying.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ decode_chunk_size (`int`, defaults to `16`):
+ The number of frames to decode at a time when calling `decode_latents` method.
+
+ Examples:
+
+ Returns:
+ [`pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
+ returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
+ """
+
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+ # align format for control guidance
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
+ control_guidance_start, control_guidance_end = (
+ mult * [control_guidance_start],
+ mult * [control_guidance_end],
+ )
+
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ num_videos_per_prompt = 1
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt=prompt,
+ strength=strength,
+ height=height,
+ width=width,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ video=video,
+ conditioning_frames=conditioning_frames,
+ latents=latents,
+ ip_adapter_image=ip_adapter_image,
+ ip_adapter_image_embeds=ip_adapter_image_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
+ control_guidance_start=control_guidance_start,
+ control_guidance_end=control_guidance_end,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, (str, dict)):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ dtype = self.dtype
+
+ # 3. Prepare timesteps
+ if not enforce_inference_steps:
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
+ )
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
+ else:
+ denoising_inference_steps = int(num_inference_steps / strength)
+ timesteps, denoising_inference_steps = retrieve_timesteps(
+ self.scheduler, denoising_inference_steps, device, timesteps, sigmas
+ )
+ timesteps = timesteps[-num_inference_steps:]
+ latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
+
+ # 4. Prepare latent variables
+ if latents is None:
+ video = self.video_processor.preprocess_video(video, height=height, width=width)
+ # Move the number of frames before the number of channels.
+ video = video.permute(0, 2, 1, 3, 4)
+ video = video.to(device=device, dtype=dtype)
+
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ video=video,
+ height=height,
+ width=width,
+ num_channels_latents=num_channels_latents,
+ batch_size=batch_size * num_videos_per_prompt,
+ timestep=latent_timestep,
+ dtype=dtype,
+ device=device,
+ generator=generator,
+ latents=latents,
+ decode_chunk_size=decode_chunk_size,
+ add_noise=enforce_inference_steps,
+ )
+
+ # 5. Encode input prompt
+ text_encoder_lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+ num_frames = latents.shape[2]
+ if self.free_noise_enabled:
+ prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise(
+ prompt=prompt,
+ num_frames=num_frames,
+ device=device,
+ num_videos_per_prompt=num_videos_per_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ clip_skip=self.clip_skip,
+ )
+ else:
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ device,
+ num_videos_per_prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
+
+ # 6. Prepare IP-Adapter embeddings
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_videos_per_prompt,
+ self.do_classifier_free_guidance,
+ )
+
+ # 7. Prepare ControlNet conditions
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
+
+ global_pool_conditions = (
+ controlnet.config.global_pool_conditions
+ if isinstance(controlnet, ControlNetModel)
+ else controlnet.nets[0].config.global_pool_conditions
+ )
+ guess_mode = guess_mode or global_pool_conditions
+
+ controlnet_keep = []
+ for i in range(len(timesteps)):
+ keeps = [
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+ for s, e in zip(control_guidance_start, control_guidance_end)
+ ]
+ controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
+
+ if isinstance(controlnet, ControlNetModel):
+ conditioning_frames = self.prepare_conditioning_frames(
+ video=conditioning_frames,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_videos_per_prompt * num_frames,
+ num_videos_per_prompt=num_videos_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+ elif isinstance(controlnet, MultiControlNetModel):
+ cond_prepared_videos = []
+ for frame_ in conditioning_frames:
+ prepared_video = self.prepare_conditioning_frames(
+ video=frame_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_videos_per_prompt * num_frames,
+ num_videos_per_prompt=num_videos_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+ cond_prepared_videos.append(prepared_video)
+ conditioning_frames = cond_prepared_videos
+ else:
+ assert False
+
+ # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 9. Add image embeds for IP-Adapter
+ added_cond_kwargs = (
+ {"image_embeds": image_embeds}
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None
+ else None
+ )
+
+ num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1
+ for free_init_iter in range(num_free_init_iters):
+ if self.free_init_enabled:
+ latents, timesteps = self._apply_free_init(
+ latents, free_init_iter, num_inference_steps, device, latents.dtype, generator
+ )
+ num_inference_steps = len(timesteps)
+ # make sure to readjust timesteps based on strength
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
+
+ self._num_timesteps = len(timesteps)
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+
+ # 10. Denoising loop
+ with self.progress_bar(total=self._num_timesteps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ if guess_mode and self.do_classifier_free_guidance:
+ # Infer ControlNet only for the conditional batch.
+ control_model_input = latents
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+ else:
+ control_model_input = latent_model_input
+ controlnet_prompt_embeds = prompt_embeds
+
+ if isinstance(controlnet_keep[i], list):
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+ else:
+ controlnet_cond_scale = controlnet_conditioning_scale
+ if isinstance(controlnet_cond_scale, list):
+ controlnet_cond_scale = controlnet_cond_scale[0]
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
+
+ control_model_input = torch.transpose(control_model_input, 1, 2)
+ control_model_input = control_model_input.reshape(
+ (-1, control_model_input.shape[2], control_model_input.shape[3], control_model_input.shape[4])
+ )
+
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
+ control_model_input,
+ t,
+ encoder_hidden_states=controlnet_prompt_embeds,
+ controlnet_cond=conditioning_frames,
+ conditioning_scale=cond_scale,
+ guess_mode=guess_mode,
+ return_dict=False,
+ )
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=self.cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ ).sample
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ # 11. Post-processing
+ if output_type == "latent":
+ video = latents
+ else:
+ video_tensor = self.decode_latents(latents, decode_chunk_size)
+ video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
+
+ # 12. Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return AnimateDiffPipelineOutput(frames=video)
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index ff1f38d7318b..732488721598 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -152,6 +152,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class AnimateDiffVideoToVideoControlNetPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class AnimateDiffVideoToVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
diff --git a/tests/pipelines/animatediff/test_animatediff_video2video_controlnet.py b/tests/pipelines/animatediff/test_animatediff_video2video_controlnet.py
new file mode 100644
index 000000000000..5e598e67ec11
--- /dev/null
+++ b/tests/pipelines/animatediff/test_animatediff_video2video_controlnet.py
@@ -0,0 +1,535 @@
+import unittest
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
+
+import diffusers
+from diffusers import (
+ AnimateDiffVideoToVideoControlNetPipeline,
+ AutoencoderKL,
+ ControlNetModel,
+ DDIMScheduler,
+ DPMSolverMultistepScheduler,
+ LCMScheduler,
+ MotionAdapter,
+ StableDiffusionPipeline,
+ UNet2DConditionModel,
+ UNetMotionModel,
+)
+from diffusers.models.attention import FreeNoiseTransformerBlock
+from diffusers.utils import is_xformers_available, logging
+from diffusers.utils.testing_utils import torch_device
+
+from ..pipeline_params import TEXT_TO_IMAGE_PARAMS, VIDEO_TO_VIDEO_BATCH_PARAMS
+from ..test_pipelines_common import IPAdapterTesterMixin, PipelineFromPipeTesterMixin, PipelineTesterMixin
+
+
+def to_np(tensor):
+ if isinstance(tensor, torch.Tensor):
+ tensor = tensor.detach().cpu().numpy()
+
+ return tensor
+
+
+class AnimateDiffVideoToVideoControlNetPipelineFastTests(
+ IPAdapterTesterMixin, PipelineTesterMixin, PipelineFromPipeTesterMixin, unittest.TestCase
+):
+ pipeline_class = AnimateDiffVideoToVideoControlNetPipeline
+ params = TEXT_TO_IMAGE_PARAMS
+ batch_params = VIDEO_TO_VIDEO_BATCH_PARAMS.union({"conditioning_frames"})
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+
+ def get_dummy_components(self):
+ cross_attention_dim = 8
+ block_out_channels = (8, 8)
+
+ torch.manual_seed(0)
+ unet = UNet2DConditionModel(
+ block_out_channels=block_out_channels,
+ layers_per_block=2,
+ sample_size=8,
+ in_channels=4,
+ out_channels=4,
+ down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
+ up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=2,
+ )
+ scheduler = DDIMScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="linear",
+ clip_sample=False,
+ )
+ torch.manual_seed(0)
+ controlnet = ControlNetModel(
+ block_out_channels=block_out_channels,
+ layers_per_block=2,
+ in_channels=4,
+ down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
+ cross_attention_dim=cross_attention_dim,
+ conditioning_embedding_out_channels=(8, 8),
+ norm_num_groups=1,
+ )
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ block_out_channels=block_out_channels,
+ in_channels=3,
+ out_channels=3,
+ down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+ latent_channels=4,
+ norm_num_groups=2,
+ )
+ torch.manual_seed(0)
+ text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=cross_attention_dim,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ )
+ text_encoder = CLIPTextModel(text_encoder_config)
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+ torch.manual_seed(0)
+ motion_adapter = MotionAdapter(
+ block_out_channels=block_out_channels,
+ motion_layers_per_block=2,
+ motion_norm_num_groups=2,
+ motion_num_attention_heads=4,
+ )
+
+ components = {
+ "unet": unet,
+ "controlnet": controlnet,
+ "scheduler": scheduler,
+ "vae": vae,
+ "motion_adapter": motion_adapter,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "feature_extractor": None,
+ "image_encoder": None,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0, num_frames: int = 2):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ video_height = 32
+ video_width = 32
+ video = [Image.new("RGB", (video_width, video_height))] * num_frames
+
+ video_height = 32
+ video_width = 32
+ conditioning_frames = [Image.new("RGB", (video_width, video_height))] * num_frames
+
+ inputs = {
+ "video": video,
+ "conditioning_frames": conditioning_frames,
+ "prompt": "A painting of a squirrel eating a burger",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 7.5,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_from_pipe_consistent_config(self):
+ assert self.original_pipeline_class == StableDiffusionPipeline
+ original_repo = "hf-internal-testing/tinier-stable-diffusion-pipe"
+ original_kwargs = {"requires_safety_checker": False}
+
+ # create original_pipeline_class(sd)
+ pipe_original = self.original_pipeline_class.from_pretrained(original_repo, **original_kwargs)
+
+ # original_pipeline_class(sd) -> pipeline_class
+ pipe_components = self.get_dummy_components()
+ pipe_additional_components = {}
+ for name, component in pipe_components.items():
+ if name not in pipe_original.components:
+ pipe_additional_components[name] = component
+
+ pipe = self.pipeline_class.from_pipe(pipe_original, **pipe_additional_components)
+
+ # pipeline_class -> original_pipeline_class(sd)
+ original_pipe_additional_components = {}
+ for name, component in pipe_original.components.items():
+ if name not in pipe.components or not isinstance(component, pipe.components[name].__class__):
+ original_pipe_additional_components[name] = component
+
+ pipe_original_2 = self.original_pipeline_class.from_pipe(pipe, **original_pipe_additional_components)
+
+ # compare the config
+ original_config = {k: v for k, v in pipe_original.config.items() if not k.startswith("_")}
+ original_config_2 = {k: v for k, v in pipe_original_2.config.items() if not k.startswith("_")}
+ assert original_config_2 == original_config
+
+ def test_motion_unet_loading(self):
+ components = self.get_dummy_components()
+ pipe = AnimateDiffVideoToVideoControlNetPipeline(**components)
+
+ assert isinstance(pipe.unet, UNetMotionModel)
+
+ @unittest.skip("Attention slicing is not enabled in this pipeline")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+ def test_ip_adapter(self):
+ expected_pipe_slice = None
+ if torch_device == "cpu":
+ expected_pipe_slice = np.array(
+ [
+ 0.5569,
+ 0.6250,
+ 0.4144,
+ 0.5613,
+ 0.5563,
+ 0.5213,
+ 0.5091,
+ 0.4950,
+ 0.4950,
+ 0.5684,
+ 0.3858,
+ 0.4863,
+ 0.6457,
+ 0.4311,
+ 0.5517,
+ 0.5608,
+ 0.4417,
+ 0.5377,
+ ]
+ )
+ return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
+
+ def test_inference_batch_single_identical(
+ self,
+ batch_size=2,
+ expected_max_diff=1e-4,
+ additional_params_copy_to_batched_inputs=["num_inference_steps"],
+ ):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for components in pipe.components.values():
+ if hasattr(components, "set_default_attn_processor"):
+ components.set_default_attn_processor()
+
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ inputs = self.get_dummy_inputs(torch_device)
+ # Reset generator in case it is has been used in self.get_dummy_inputs
+ inputs["generator"] = self.get_generator(0)
+
+ logger = logging.get_logger(pipe.__module__)
+ logger.setLevel(level=diffusers.logging.FATAL)
+
+ # batchify inputs
+ batched_inputs = {}
+ batched_inputs.update(inputs)
+
+ for name in self.batch_params:
+ if name not in inputs:
+ continue
+
+ value = inputs[name]
+ if name == "prompt":
+ len_prompt = len(value)
+ batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
+ batched_inputs[name][-1] = 100 * "very long"
+
+ else:
+ batched_inputs[name] = batch_size * [value]
+
+ if "generator" in inputs:
+ batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)]
+
+ if "batch_size" in inputs:
+ batched_inputs["batch_size"] = batch_size
+
+ for arg in additional_params_copy_to_batched_inputs:
+ batched_inputs[arg] = inputs[arg]
+
+ output = pipe(**inputs)
+ output_batch = pipe(**batched_inputs)
+
+ assert output_batch[0].shape[0] == batch_size
+
+ max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max()
+ assert max_diff < expected_max_diff
+
+ @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
+ def test_to_device(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+
+ pipe.to("cpu")
+ # pipeline creates a new motion UNet under the hood. So we need to check the device from pipe.components
+ model_devices = [
+ component.device.type for component in pipe.components.values() if hasattr(component, "device")
+ ]
+ self.assertTrue(all(device == "cpu" for device in model_devices))
+
+ output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
+ self.assertTrue(np.isnan(output_cpu).sum() == 0)
+
+ pipe.to("cuda")
+ model_devices = [
+ component.device.type for component in pipe.components.values() if hasattr(component, "device")
+ ]
+ self.assertTrue(all(device == "cuda" for device in model_devices))
+
+ output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
+ self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
+
+ def test_to_dtype(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+
+ # pipeline creates a new motion UNet under the hood. So we need to check the dtype from pipe.components
+ model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
+ self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
+
+ pipe.to(dtype=torch.float16)
+ model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
+ self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
+
+ def test_prompt_embeds(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+ pipe.to(torch_device)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs.pop("prompt")
+ inputs["prompt_embeds"] = torch.randn((1, 4, pipe.text_encoder.config.hidden_size), device=torch_device)
+ pipe(**inputs)
+
+ def test_latent_inputs(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+ pipe.to(torch_device)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ sample_size = pipe.unet.config.sample_size
+ num_frames = len(inputs["conditioning_frames"])
+ inputs["latents"] = torch.randn((1, 4, num_frames, sample_size, sample_size), device=torch_device)
+ inputs.pop("video")
+ pipe(**inputs)
+
+ @unittest.skipIf(
+ torch_device != "cuda" or not is_xformers_available(),
+ reason="XFormers attention is only available with CUDA and `xformers` installed",
+ )
+ def test_xformers_attention_forwardGenerator_pass(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output_without_offload = pipe(**inputs).frames[0]
+ output_without_offload = (
+ output_without_offload.cpu() if torch.is_tensor(output_without_offload) else output_without_offload
+ )
+
+ pipe.enable_xformers_memory_efficient_attention()
+ inputs = self.get_dummy_inputs(torch_device)
+ output_with_offload = pipe(**inputs).frames[0]
+ output_with_offload = (
+ output_with_offload.cpu() if torch.is_tensor(output_with_offload) else output_without_offload
+ )
+
+ max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
+ self.assertLess(max_diff, 1e-4, "XFormers attention should not affect the inference results")
+
+ def test_free_init(self):
+ components = self.get_dummy_components()
+ pipe: AnimateDiffVideoToVideoControlNetPipeline = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+ pipe.to(torch_device)
+
+ inputs_normal = self.get_dummy_inputs(torch_device)
+ frames_normal = pipe(**inputs_normal).frames[0]
+
+ pipe.enable_free_init(
+ num_iters=2,
+ use_fast_sampling=True,
+ method="butterworth",
+ order=4,
+ spatial_stop_frequency=0.25,
+ temporal_stop_frequency=0.25,
+ )
+ inputs_enable_free_init = self.get_dummy_inputs(torch_device)
+ frames_enable_free_init = pipe(**inputs_enable_free_init).frames[0]
+
+ pipe.disable_free_init()
+ inputs_disable_free_init = self.get_dummy_inputs(torch_device)
+ frames_disable_free_init = pipe(**inputs_disable_free_init).frames[0]
+
+ sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_init)).sum()
+ max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_init)).max()
+ self.assertGreater(
+ sum_enabled, 1e1, "Enabling of FreeInit should lead to results different from the default pipeline results"
+ )
+ self.assertLess(
+ max_diff_disabled,
+ 1e-4,
+ "Disabling of FreeInit should lead to results similar to the default pipeline results",
+ )
+
+ def test_free_init_with_schedulers(self):
+ components = self.get_dummy_components()
+ pipe: AnimateDiffVideoToVideoControlNetPipeline = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+ pipe.to(torch_device)
+
+ inputs_normal = self.get_dummy_inputs(torch_device)
+ frames_normal = pipe(**inputs_normal).frames[0]
+
+ schedulers_to_test = [
+ DPMSolverMultistepScheduler.from_config(
+ components["scheduler"].config,
+ timestep_spacing="linspace",
+ beta_schedule="linear",
+ algorithm_type="dpmsolver++",
+ steps_offset=1,
+ clip_sample=False,
+ ),
+ LCMScheduler.from_config(
+ components["scheduler"].config,
+ timestep_spacing="linspace",
+ beta_schedule="linear",
+ steps_offset=1,
+ clip_sample=False,
+ ),
+ ]
+ components.pop("scheduler")
+
+ for scheduler in schedulers_to_test:
+ components["scheduler"] = scheduler
+ pipe: AnimateDiffVideoToVideoControlNetPipeline = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+ pipe.to(torch_device)
+
+ pipe.enable_free_init(num_iters=2, use_fast_sampling=False)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ frames_enable_free_init = pipe(**inputs).frames[0]
+ sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_init)).sum()
+
+ self.assertGreater(
+ sum_enabled,
+ 1e1,
+ "Enabling of FreeInit should lead to results different from the default pipeline results",
+ )
+
+ def test_free_noise_blocks(self):
+ components = self.get_dummy_components()
+ pipe: AnimateDiffVideoToVideoControlNetPipeline = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+ pipe.to(torch_device)
+
+ pipe.enable_free_noise()
+ for block in pipe.unet.down_blocks:
+ for motion_module in block.motion_modules:
+ for transformer_block in motion_module.transformer_blocks:
+ self.assertTrue(
+ isinstance(transformer_block, FreeNoiseTransformerBlock),
+ "Motion module transformer blocks must be an instance of `FreeNoiseTransformerBlock` after enabling FreeNoise.",
+ )
+
+ pipe.disable_free_noise()
+ for block in pipe.unet.down_blocks:
+ for motion_module in block.motion_modules:
+ for transformer_block in motion_module.transformer_blocks:
+ self.assertFalse(
+ isinstance(transformer_block, FreeNoiseTransformerBlock),
+ "Motion module transformer blocks must not be an instance of `FreeNoiseTransformerBlock` after disabling FreeNoise.",
+ )
+
+ def test_free_noise(self):
+ components = self.get_dummy_components()
+ pipe: AnimateDiffVideoToVideoControlNetPipeline = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+ pipe.to(torch_device)
+
+ inputs_normal = self.get_dummy_inputs(torch_device, num_frames=16)
+ inputs_normal["num_inference_steps"] = 2
+ inputs_normal["strength"] = 0.5
+ frames_normal = pipe(**inputs_normal).frames[0]
+
+ for context_length in [8, 9]:
+ for context_stride in [4, 6]:
+ pipe.enable_free_noise(context_length, context_stride)
+
+ inputs_enable_free_noise = self.get_dummy_inputs(torch_device, num_frames=16)
+ inputs_enable_free_noise["num_inference_steps"] = 2
+ inputs_enable_free_noise["strength"] = 0.5
+ frames_enable_free_noise = pipe(**inputs_enable_free_noise).frames[0]
+
+ pipe.disable_free_noise()
+ inputs_disable_free_noise = self.get_dummy_inputs(torch_device, num_frames=16)
+ inputs_disable_free_noise["num_inference_steps"] = 2
+ inputs_disable_free_noise["strength"] = 0.5
+ frames_disable_free_noise = pipe(**inputs_disable_free_noise).frames[0]
+
+ sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_noise)).sum()
+ max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_noise)).max()
+ self.assertGreater(
+ sum_enabled,
+ 1e1,
+ "Enabling of FreeNoise should lead to results different from the default pipeline results",
+ )
+ self.assertLess(
+ max_diff_disabled,
+ 1e-4,
+ "Disabling of FreeNoise should lead to results similar to the default pipeline results",
+ )
+
+ def test_free_noise_multi_prompt(self):
+ components = self.get_dummy_components()
+ pipe: AnimateDiffVideoToVideoControlNetPipeline = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+ pipe.to(torch_device)
+
+ context_length = 8
+ context_stride = 4
+ pipe.enable_free_noise(context_length, context_stride)
+
+ # Make sure that pipeline works when prompt indices are within num_frames bounds
+ inputs = self.get_dummy_inputs(torch_device, num_frames=16)
+ inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf"}
+ inputs["num_inference_steps"] = 2
+ inputs["strength"] = 0.5
+ pipe(**inputs).frames[0]
+
+ with self.assertRaises(ValueError):
+ # Ensure that prompt indices are within bounds
+ inputs = self.get_dummy_inputs(torch_device, num_frames=16)
+ inputs["num_inference_steps"] = 2
+ inputs["strength"] = 0.5
+ inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf", 42: "Error on a leaf"}
+ pipe(**inputs).frames[0]