diff --git a/docs/source/en/api/pipelines/animatediff.md b/docs/source/en/api/pipelines/animatediff.md index bfd6ab973d5e..7cacad87d78c 100644 --- a/docs/source/en/api/pipelines/animatediff.md +++ b/docs/source/en/api/pipelines/animatediff.md @@ -29,6 +29,7 @@ The abstract of the paper is the following: | [AnimateDiffSparseControlNetPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py) | *Controlled Video-to-Video Generation with AnimateDiff using SparseCtrl* | | [AnimateDiffSDXLPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py) | *Video-to-Video Generation with AnimateDiff* | | [AnimateDiffVideoToVideoPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py) | *Video-to-Video Generation with AnimateDiff* | +| [AnimateDiffVideoToVideoControlNetPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py) | *Video-to-Video Generation with AnimateDiff using ControlNet* | ## Available checkpoints @@ -518,6 +519,97 @@ Here are some sample outputs: + + +### AnimateDiffVideoToVideoControlNetPipeline + +AnimateDiff can be used together with ControlNets to enhance video-to-video generation by allowing for precise control over the output. ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala, and allows you to condition Stable Diffusion with an additional control image to ensure that the spatial information is preserved throughout the video. + +This pipeline allows you to condition your generation both on the original video and on a sequence of control images. + +```python +import torch +from PIL import Image +from tqdm.auto import tqdm + +from controlnet_aux.processor import OpenposeDetector +from diffusers import AnimateDiffVideoToVideoControlNetPipeline +from diffusers.utils import export_to_gif, load_video +from diffusers import AutoencoderKL, ControlNetModel, MotionAdapter, LCMScheduler + +# Load the ControlNet +controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16) +# Load the motion adapter +motion_adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM") +# Load SD 1.5 based finetuned model +vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16) +pipe = AnimateDiffVideoToVideoControlNetPipeline.from_pretrained( + "SG161222/Realistic_Vision_V5.1_noVAE", + motion_adapter=motion_adapter, + controlnet=controlnet, + vae=vae, +).to(device="cuda", dtype=torch.float16) + +# Enable LCM to speed up inference +pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear") +pipe.load_lora_weights("wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora") +pipe.set_adapters(["lcm-lora"], [0.8]) + +video = load_video("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/dance.gif") +video = [frame.convert("RGB") for frame in video] + +prompt = "astronaut in space, dancing" +negative_prompt = "bad quality, worst quality, jpeg artifacts, ugly" + +# Create controlnet preprocessor +open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators").to("cuda") + +# Preprocess controlnet images +conditioning_frames = [] +for frame in tqdm(video): + conditioning_frames.append(open_pose(frame)) + +strength = 0.8 +with torch.inference_mode(): + video = pipe( + video=video, + prompt=prompt, + negative_prompt=negative_prompt, + num_inference_steps=10, + guidance_scale=2.0, + controlnet_conditioning_scale=0.75, + conditioning_frames=conditioning_frames, + strength=strength, + generator=torch.Generator().manual_seed(42), + ).frames[0] + +video = [frame.resize(conditioning_frames[0].size) for frame in video] +export_to_gif(video, f"animatediff_vid2vid_controlnet.gif", fps=8) +``` + +Here are some sample outputs: + + + + + + + + + + +
Source VideoOutput Video
+ anime girl, dancing +
+ anime girl, dancing +
+ astronaut in space, dancing +
+ astronaut in space, dancing +
+ +**The lights and composition were transferred from the Source Video.** + ## Using Motion LoRAs Motion LoRAs are a collection of LoRAs that work with the `guoyww/animatediff-motion-adapter-v1-5-2` checkpoint. These LoRAs are responsible for adding specific types of motion to the animations. @@ -866,6 +958,12 @@ pipe = AnimateDiffPipeline.from_pretrained("emilianJR/epiCRealism", motion_adapt - all - __call__ +## AnimateDiffVideoToVideoControlNetPipeline + +[[autodoc]] AnimateDiffVideoToVideoControlNetPipeline + - all + - __call__ + ## AnimateDiffPipelineOutput [[autodoc]] pipelines.animatediff.AnimateDiffPipelineOutput diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index af28b383b563..5b505b6a1f3a 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -245,6 +245,7 @@ "AnimateDiffPipeline", "AnimateDiffSDXLPipeline", "AnimateDiffSparseControlNetPipeline", + "AnimateDiffVideoToVideoControlNetPipeline", "AnimateDiffVideoToVideoPipeline", "AudioLDM2Pipeline", "AudioLDM2ProjectionModel", @@ -694,6 +695,7 @@ AnimateDiffPipeline, AnimateDiffSDXLPipeline, AnimateDiffSparseControlNetPipeline, + AnimateDiffVideoToVideoControlNetPipeline, AnimateDiffVideoToVideoPipeline, AudioLDM2Pipeline, AudioLDM2ProjectionModel, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index ad7ea2872ac5..e4d37a905b86 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -123,6 +123,7 @@ "AnimateDiffSDXLPipeline", "AnimateDiffSparseControlNetPipeline", "AnimateDiffVideoToVideoPipeline", + "AnimateDiffVideoToVideoControlNetPipeline", ] _import_structure["flux"] = [ "FluxControlNetPipeline", @@ -449,6 +450,7 @@ AnimateDiffPipeline, AnimateDiffSDXLPipeline, AnimateDiffSparseControlNetPipeline, + AnimateDiffVideoToVideoControlNetPipeline, AnimateDiffVideoToVideoPipeline, ) from .audioldm import AudioLDMPipeline diff --git a/src/diffusers/pipelines/animatediff/__init__.py b/src/diffusers/pipelines/animatediff/__init__.py index 3ee72bc44003..d916abf2d85d 100644 --- a/src/diffusers/pipelines/animatediff/__init__.py +++ b/src/diffusers/pipelines/animatediff/__init__.py @@ -26,6 +26,7 @@ _import_structure["pipeline_animatediff_sdxl"] = ["AnimateDiffSDXLPipeline"] _import_structure["pipeline_animatediff_sparsectrl"] = ["AnimateDiffSparseControlNetPipeline"] _import_structure["pipeline_animatediff_video2video"] = ["AnimateDiffVideoToVideoPipeline"] + _import_structure["pipeline_animatediff_video2video_controlnet"] = ["AnimateDiffVideoToVideoControlNetPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -40,6 +41,7 @@ from .pipeline_animatediff_sdxl import AnimateDiffSDXLPipeline from .pipeline_animatediff_sparsectrl import AnimateDiffSparseControlNetPipeline from .pipeline_animatediff_video2video import AnimateDiffVideoToVideoPipeline + from .pipeline_animatediff_video2video_controlnet import AnimateDiffVideoToVideoControlNetPipeline from .pipeline_output import AnimateDiffPipelineOutput else: diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py new file mode 100644 index 000000000000..1d26f95a2f58 --- /dev/null +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py @@ -0,0 +1,1341 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...image_processor import PipelineImageInput +from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel, UNetMotionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...models.unets.unet_motion_model import MotionAdapter +from ...schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import is_compiled_module, randn_tensor +from ...video_processor import VideoProcessor +from ..controlnet.multicontrolnet import MultiControlNetModel +from ..free_init_utils import FreeInitMixin +from ..free_noise_utils import AnimateDiffFreeNoiseMixin +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from .pipeline_output import AnimateDiffPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from PIL import Image + >>> from tqdm.auto import tqdm + + >>> from diffusers import AnimateDiffVideoToVideoControlNetPipeline + >>> from diffusers.utils import export_to_gif, load_video + >>> from diffusers import AutoencoderKL, ControlNetModel, MotionAdapter, LCMScheduler + + >>> controlnet = ControlNetModel.from_pretrained( + ... "lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16 + ... ) + >>> motion_adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM") + >>> vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16) + + >>> pipe = AnimateDiffVideoToVideoControlNetPipeline.from_pretrained( + ... "SG161222/Realistic_Vision_V5.1_noVAE", + ... motion_adapter=motion_adapter, + ... controlnet=controlnet, + ... vae=vae, + ... ).to(device="cuda", dtype=torch.float16) + + >>> pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear") + >>> pipe.load_lora_weights( + ... "wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora" + ... ) + >>> pipe.set_adapters(["lcm-lora"], [0.8]) + + >>> video = load_video( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/dance.gif" + ... ) + >>> video = [frame.convert("RGB") for frame in video] + + >>> from controlnet_aux.processor import OpenposeDetector + + >>> open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators").to("cuda") + >>> for frame in tqdm(video): + ... conditioning_frames.append(open_pose(frame)) + + >>> prompt = "astronaut in space, dancing" + >>> negative_prompt = "bad quality, worst quality, jpeg artifacts, ugly" + + >>> strength = 0.8 + >>> with torch.inference_mode(): + ... video = pipe( + ... video=video, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... num_inference_steps=10, + ... guidance_scale=2.0, + ... controlnet_conditioning_scale=0.75, + ... conditioning_frames=conditioning_frames, + ... strength=strength, + ... generator=torch.Generator().manual_seed(42), + ... ).frames[0] + + >>> video = [frame.resize(conditioning_frames[0].size) for frame in video] + >>> export_to_gif(video, f"animatediff_vid2vid_controlnet.gif", fps=8) + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class AnimateDiffVideoToVideoControlNetPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, + StableDiffusionLoraLoaderMixin, + FreeInitMixin, + AnimateDiffFreeNoiseMixin, +): + r""" + Pipeline for video-to-video generation with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer (`CLIPTokenizer`): + A [`~transformers.CLIPTokenizer`] to tokenize text. + unet ([`UNet2DConditionModel`]): + A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents. + motion_adapter ([`MotionAdapter`]): + A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents. + controlnet ([`ControlNetModel`] or `List[ControlNetModel]` or `Tuple[ControlNetModel]` or `MultiControlNetModel`): + Provides additional conditioning to the `unet` during the denoising process. If you set multiple + ControlNets as a list, the outputs from each ControlNet are added together to create one combined + additional conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["feature_extractor", "image_encoder", "motion_adapter"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + motion_adapter: MotionAdapter, + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], + feature_extractor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModelWithProjection = None, + ): + super().__init__() + if isinstance(unet, UNet2DConditionModel): + unet = UNetMotionModel.from_unet2d(unet, motion_adapter) + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + motion_adapter=motion_adapter, + controlnet=controlnet, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) + self.control_video_processor = VideoProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, (str, dict)): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.encode_video + def encode_video(self, video, generator, decode_chunk_size: int = 16) -> torch.Tensor: + latents = [] + for i in range(0, len(video), decode_chunk_size): + batch_video = video[i : i + decode_chunk_size] + batch_video = retrieve_latents(self.vae.encode(batch_video), generator=generator) + latents.append(batch_video) + return torch.cat(latents) + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents + def decode_latents(self, latents, decode_chunk_size: int = 16): + latents = 1 / self.vae.config.scaling_factor * latents + + batch_size, channels, num_frames, height, width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + + video = [] + for i in range(0, latents.shape[0], decode_chunk_size): + batch_latents = latents[i : i + decode_chunk_size] + batch_latents = self.vae.decode(batch_latents).sample + video.append(batch_latents) + + video = torch.cat(video) + video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() + return video + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + strength, + height, + width, + video=None, + conditioning_frames=None, + latents=None, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and not isinstance(prompt, (str, list, dict)): + raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if video is not None and latents is not None: + raise ValueError("Only one of `video` or `latents` should be provided") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + + num_frames = len(video) if latents is None else latents.shape[2] + + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(conditioning_frames, list): + raise TypeError( + f"For single controlnet, `image` must be of type `list` but got {type(conditioning_frames)}" + ) + if len(conditioning_frames) != num_frames: + raise ValueError(f"Excepted image to have length {num_frames} but got {len(conditioning_frames)=}") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(conditioning_frames, list) or not isinstance(conditioning_frames[0], list): + raise TypeError( + f"For multiple controlnets: `image` must be type list of lists but got {type(conditioning_frames)=}" + ) + if len(conditioning_frames[0]) != num_frames: + raise ValueError( + f"Expected length of image sublist as {num_frames} but got {len(conditioning_frames)=}" + ) + if any(len(img) != len(conditioning_frames[0]) for img in conditioning_frames): + raise ValueError("All conditioning frame batches for multicontrolnet must be same size") + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if not isinstance(control_guidance_start, (tuple, list)): + control_guidance_start = [control_guidance_start] + + if not isinstance(control_guidance_end, (tuple, list)): + control_guidance_end = [control_guidance_end] + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, timesteps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.prepare_latents + def prepare_latents( + self, + video: Optional[torch.Tensor] = None, + height: int = 64, + width: int = 64, + num_channels_latents: int = 4, + batch_size: int = 1, + timestep: Optional[int] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + decode_chunk_size: int = 16, + add_noise: bool = False, + ) -> torch.Tensor: + num_frames = video.shape[1] if latents is None else latents.shape[2] + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.config.force_upcast: + video = video.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + if len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + init_latents = [ + self.encode_video(video[i], generator[i], decode_chunk_size).unsqueeze(0) + for i in range(batch_size) + ] + else: + init_latents = [self.encode_video(vid, generator, decode_chunk_size).unsqueeze(0) for vid in video] + + init_latents = torch.cat(init_latents, dim=0) + + # restore vae to original dtype + if self.vae.config.force_upcast: + self.vae.to(dtype) + + init_latents = init_latents.to(dtype) + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + error_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Please make sure to update your script to pass as many initial images as text prompts" + ) + raise ValueError(error_message) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.add_noise(init_latents, noise, timestep).permute(0, 2, 1, 3, 4) + else: + if shape != latents.shape: + # [B, C, F, H, W] + raise ValueError(f"`latents` expected to have {shape=}, but found {latents.shape=}") + + latents = latents.to(device, dtype=dtype) + + if add_noise: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.add_noise(latents, noise, timestep) + + return latents + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_controlnet.AnimateDiffControlNetPipeline.prepare_video + def prepare_conditioning_frames( + self, + video, + width, + height, + batch_size, + num_videos_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + video = self.control_video_processor.preprocess_video(video, height=height, width=width).to( + dtype=torch.float32 + ) + video = video.permute(0, 2, 1, 3, 4).flatten(0, 1) + video_batch_size = video.shape[0] + + if video_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_videos_per_prompt + + video = video.repeat_interleave(repeat_by, dim=0) + video = video.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + video = torch.cat([video] * 2) + + return video + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + def __call__( + self, + video: List[List[PipelineImageInput]] = None, + prompt: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + enforce_inference_steps: bool = False, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 7.5, + strength: float = 0.8, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + conditioning_frames: Optional[List[PipelineImageInput]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + decode_chunk_size: int = 16, + ): + r""" + The call function to the pipeline for generation. + + Args: + video (`List[PipelineImageInput]`): + The input video to condition the generation on. Must be a list of images/frames of the video. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated video. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality videos at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + strength (`float`, *optional*, defaults to 0.8): + Higher strength leads to more differences between original video and generated video. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. Latents should be of shape + `(batch_size, num_channel, num_frames, height, width)`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + conditioning_frames (`List[PipelineImageInput]`, *optional*): + The ControlNet input condition to provide guidance to the `unet` for generation. If multiple + ControlNets are specified, images must be passed as a list such that each element of the list can be + correctly batched for input to a single ControlNet. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`AnimateDiffPipelineOutput`] instead of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + The ControlNet encoder tries to recognize the content of the input image even if you remove all + prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + decode_chunk_size (`int`, defaults to `16`): + The number of frames to decode at a time when calling `decode_latents` method. + + Examples: + + Returns: + [`pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. + """ + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + strength=strength, + height=height, + width=width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + video=video, + conditioning_frames=conditioning_frames, + latents=latents, + ip_adapter_image=ip_adapter_image, + ip_adapter_image_embeds=ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + controlnet_conditioning_scale=controlnet_conditioning_scale, + control_guidance_start=control_guidance_start, + control_guidance_end=control_guidance_end, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, (str, dict)): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + dtype = self.dtype + + # 3. Prepare timesteps + if not enforce_inference_steps: + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) + else: + denoising_inference_steps = int(num_inference_steps / strength) + timesteps, denoising_inference_steps = retrieve_timesteps( + self.scheduler, denoising_inference_steps, device, timesteps, sigmas + ) + timesteps = timesteps[-num_inference_steps:] + latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) + + # 4. Prepare latent variables + if latents is None: + video = self.video_processor.preprocess_video(video, height=height, width=width) + # Move the number of frames before the number of channels. + video = video.permute(0, 2, 1, 3, 4) + video = video.to(device=device, dtype=dtype) + + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + video=video, + height=height, + width=width, + num_channels_latents=num_channels_latents, + batch_size=batch_size * num_videos_per_prompt, + timestep=latent_timestep, + dtype=dtype, + device=device, + generator=generator, + latents=latents, + decode_chunk_size=decode_chunk_size, + add_noise=enforce_inference_steps, + ) + + # 5. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + num_frames = latents.shape[2] + if self.free_noise_enabled: + prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise( + prompt=prompt, + num_frames=num_frames, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + else: + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0) + + # 6. Prepare IP-Adapter embeddings + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_videos_per_prompt, + self.do_classifier_free_guidance, + ) + + # 7. Prepare ControlNet conditions + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + if isinstance(controlnet, ControlNetModel): + conditioning_frames = self.prepare_conditioning_frames( + video=conditioning_frames, + width=width, + height=height, + batch_size=batch_size * num_videos_per_prompt * num_frames, + num_videos_per_prompt=num_videos_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + elif isinstance(controlnet, MultiControlNetModel): + cond_prepared_videos = [] + for frame_ in conditioning_frames: + prepared_video = self.prepare_conditioning_frames( + video=frame_, + width=width, + height=height, + batch_size=batch_size * num_videos_per_prompt * num_frames, + num_videos_per_prompt=num_videos_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + cond_prepared_videos.append(prepared_video) + conditioning_frames = cond_prepared_videos + else: + assert False + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 9. Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1 + for free_init_iter in range(num_free_init_iters): + if self.free_init_enabled: + latents, timesteps = self._apply_free_init( + latents, free_init_iter, num_inference_steps, device, latents.dtype, generator + ) + num_inference_steps = len(timesteps) + # make sure to readjust timesteps based on strength + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device) + + self._num_timesteps = len(timesteps) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + # 10. Denoising loop + with self.progress_bar(total=self._num_timesteps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if guess_mode and self.do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + control_model_input = torch.transpose(control_model_input, 1, 2) + control_model_input = control_model_input.reshape( + (-1, control_model_input.shape[2], control_model_input.shape[3], control_model_input.shape[4]) + ) + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=conditioning_frames, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + return_dict=False, + ) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ).sample + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + # 11. Post-processing + if output_type == "latent": + video = latents + else: + video_tensor = self.decode_latents(latents, decode_chunk_size) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) + + # 12. Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return AnimateDiffPipelineOutput(frames=video) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index ff1f38d7318b..732488721598 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -152,6 +152,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class AnimateDiffVideoToVideoControlNetPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class AnimateDiffVideoToVideoPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/animatediff/test_animatediff_video2video_controlnet.py b/tests/pipelines/animatediff/test_animatediff_video2video_controlnet.py new file mode 100644 index 000000000000..5e598e67ec11 --- /dev/null +++ b/tests/pipelines/animatediff/test_animatediff_video2video_controlnet.py @@ -0,0 +1,535 @@ +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +import diffusers +from diffusers import ( + AnimateDiffVideoToVideoControlNetPipeline, + AutoencoderKL, + ControlNetModel, + DDIMScheduler, + DPMSolverMultistepScheduler, + LCMScheduler, + MotionAdapter, + StableDiffusionPipeline, + UNet2DConditionModel, + UNetMotionModel, +) +from diffusers.models.attention import FreeNoiseTransformerBlock +from diffusers.utils import is_xformers_available, logging +from diffusers.utils.testing_utils import torch_device + +from ..pipeline_params import TEXT_TO_IMAGE_PARAMS, VIDEO_TO_VIDEO_BATCH_PARAMS +from ..test_pipelines_common import IPAdapterTesterMixin, PipelineFromPipeTesterMixin, PipelineTesterMixin + + +def to_np(tensor): + if isinstance(tensor, torch.Tensor): + tensor = tensor.detach().cpu().numpy() + + return tensor + + +class AnimateDiffVideoToVideoControlNetPipelineFastTests( + IPAdapterTesterMixin, PipelineTesterMixin, PipelineFromPipeTesterMixin, unittest.TestCase +): + pipeline_class = AnimateDiffVideoToVideoControlNetPipeline + params = TEXT_TO_IMAGE_PARAMS + batch_params = VIDEO_TO_VIDEO_BATCH_PARAMS.union({"conditioning_frames"}) + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + + def get_dummy_components(self): + cross_attention_dim = 8 + block_out_channels = (8, 8) + + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=block_out_channels, + layers_per_block=2, + sample_size=8, + in_channels=4, + out_channels=4, + down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=cross_attention_dim, + norm_num_groups=2, + ) + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="linear", + clip_sample=False, + ) + torch.manual_seed(0) + controlnet = ControlNetModel( + block_out_channels=block_out_channels, + layers_per_block=2, + in_channels=4, + down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"), + cross_attention_dim=cross_attention_dim, + conditioning_embedding_out_channels=(8, 8), + norm_num_groups=1, + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=block_out_channels, + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + norm_num_groups=2, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=cross_attention_dim, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + torch.manual_seed(0) + motion_adapter = MotionAdapter( + block_out_channels=block_out_channels, + motion_layers_per_block=2, + motion_norm_num_groups=2, + motion_num_attention_heads=4, + ) + + components = { + "unet": unet, + "controlnet": controlnet, + "scheduler": scheduler, + "vae": vae, + "motion_adapter": motion_adapter, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "feature_extractor": None, + "image_encoder": None, + } + return components + + def get_dummy_inputs(self, device, seed=0, num_frames: int = 2): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + video_height = 32 + video_width = 32 + video = [Image.new("RGB", (video_width, video_height))] * num_frames + + video_height = 32 + video_width = 32 + conditioning_frames = [Image.new("RGB", (video_width, video_height))] * num_frames + + inputs = { + "video": video, + "conditioning_frames": conditioning_frames, + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 7.5, + "output_type": "pt", + } + return inputs + + def test_from_pipe_consistent_config(self): + assert self.original_pipeline_class == StableDiffusionPipeline + original_repo = "hf-internal-testing/tinier-stable-diffusion-pipe" + original_kwargs = {"requires_safety_checker": False} + + # create original_pipeline_class(sd) + pipe_original = self.original_pipeline_class.from_pretrained(original_repo, **original_kwargs) + + # original_pipeline_class(sd) -> pipeline_class + pipe_components = self.get_dummy_components() + pipe_additional_components = {} + for name, component in pipe_components.items(): + if name not in pipe_original.components: + pipe_additional_components[name] = component + + pipe = self.pipeline_class.from_pipe(pipe_original, **pipe_additional_components) + + # pipeline_class -> original_pipeline_class(sd) + original_pipe_additional_components = {} + for name, component in pipe_original.components.items(): + if name not in pipe.components or not isinstance(component, pipe.components[name].__class__): + original_pipe_additional_components[name] = component + + pipe_original_2 = self.original_pipeline_class.from_pipe(pipe, **original_pipe_additional_components) + + # compare the config + original_config = {k: v for k, v in pipe_original.config.items() if not k.startswith("_")} + original_config_2 = {k: v for k, v in pipe_original_2.config.items() if not k.startswith("_")} + assert original_config_2 == original_config + + def test_motion_unet_loading(self): + components = self.get_dummy_components() + pipe = AnimateDiffVideoToVideoControlNetPipeline(**components) + + assert isinstance(pipe.unet, UNetMotionModel) + + @unittest.skip("Attention slicing is not enabled in this pipeline") + def test_attention_slicing_forward_pass(self): + pass + + def test_ip_adapter(self): + expected_pipe_slice = None + if torch_device == "cpu": + expected_pipe_slice = np.array( + [ + 0.5569, + 0.6250, + 0.4144, + 0.5613, + 0.5563, + 0.5213, + 0.5091, + 0.4950, + 0.4950, + 0.5684, + 0.3858, + 0.4863, + 0.6457, + 0.4311, + 0.5517, + 0.5608, + 0.4417, + 0.5377, + ] + ) + return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice) + + def test_inference_batch_single_identical( + self, + batch_size=2, + expected_max_diff=1e-4, + additional_params_copy_to_batched_inputs=["num_inference_steps"], + ): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for components in pipe.components.values(): + if hasattr(components, "set_default_attn_processor"): + components.set_default_attn_processor() + + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + inputs = self.get_dummy_inputs(torch_device) + # Reset generator in case it is has been used in self.get_dummy_inputs + inputs["generator"] = self.get_generator(0) + + logger = logging.get_logger(pipe.__module__) + logger.setLevel(level=diffusers.logging.FATAL) + + # batchify inputs + batched_inputs = {} + batched_inputs.update(inputs) + + for name in self.batch_params: + if name not in inputs: + continue + + value = inputs[name] + if name == "prompt": + len_prompt = len(value) + batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)] + batched_inputs[name][-1] = 100 * "very long" + + else: + batched_inputs[name] = batch_size * [value] + + if "generator" in inputs: + batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)] + + if "batch_size" in inputs: + batched_inputs["batch_size"] = batch_size + + for arg in additional_params_copy_to_batched_inputs: + batched_inputs[arg] = inputs[arg] + + output = pipe(**inputs) + output_batch = pipe(**batched_inputs) + + assert output_batch[0].shape[0] == batch_size + + max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max() + assert max_diff < expected_max_diff + + @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices") + def test_to_device(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + + pipe.to("cpu") + # pipeline creates a new motion UNet under the hood. So we need to check the device from pipe.components + model_devices = [ + component.device.type for component in pipe.components.values() if hasattr(component, "device") + ] + self.assertTrue(all(device == "cpu" for device in model_devices)) + + output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0] + self.assertTrue(np.isnan(output_cpu).sum() == 0) + + pipe.to("cuda") + model_devices = [ + component.device.type for component in pipe.components.values() if hasattr(component, "device") + ] + self.assertTrue(all(device == "cuda" for device in model_devices)) + + output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0] + self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0) + + def test_to_dtype(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + + # pipeline creates a new motion UNet under the hood. So we need to check the dtype from pipe.components + model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] + self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes)) + + pipe.to(dtype=torch.float16) + model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] + self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes)) + + def test_prompt_embeds(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) + + inputs = self.get_dummy_inputs(torch_device) + inputs.pop("prompt") + inputs["prompt_embeds"] = torch.randn((1, 4, pipe.text_encoder.config.hidden_size), device=torch_device) + pipe(**inputs) + + def test_latent_inputs(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) + + inputs = self.get_dummy_inputs(torch_device) + sample_size = pipe.unet.config.sample_size + num_frames = len(inputs["conditioning_frames"]) + inputs["latents"] = torch.randn((1, 4, num_frames, sample_size, sample_size), device=torch_device) + inputs.pop("video") + pipe(**inputs) + + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_xformers_attention_forwardGenerator_pass(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + output_without_offload = pipe(**inputs).frames[0] + output_without_offload = ( + output_without_offload.cpu() if torch.is_tensor(output_without_offload) else output_without_offload + ) + + pipe.enable_xformers_memory_efficient_attention() + inputs = self.get_dummy_inputs(torch_device) + output_with_offload = pipe(**inputs).frames[0] + output_with_offload = ( + output_with_offload.cpu() if torch.is_tensor(output_with_offload) else output_without_offload + ) + + max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max() + self.assertLess(max_diff, 1e-4, "XFormers attention should not affect the inference results") + + def test_free_init(self): + components = self.get_dummy_components() + pipe: AnimateDiffVideoToVideoControlNetPipeline = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) + + inputs_normal = self.get_dummy_inputs(torch_device) + frames_normal = pipe(**inputs_normal).frames[0] + + pipe.enable_free_init( + num_iters=2, + use_fast_sampling=True, + method="butterworth", + order=4, + spatial_stop_frequency=0.25, + temporal_stop_frequency=0.25, + ) + inputs_enable_free_init = self.get_dummy_inputs(torch_device) + frames_enable_free_init = pipe(**inputs_enable_free_init).frames[0] + + pipe.disable_free_init() + inputs_disable_free_init = self.get_dummy_inputs(torch_device) + frames_disable_free_init = pipe(**inputs_disable_free_init).frames[0] + + sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_init)).sum() + max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_init)).max() + self.assertGreater( + sum_enabled, 1e1, "Enabling of FreeInit should lead to results different from the default pipeline results" + ) + self.assertLess( + max_diff_disabled, + 1e-4, + "Disabling of FreeInit should lead to results similar to the default pipeline results", + ) + + def test_free_init_with_schedulers(self): + components = self.get_dummy_components() + pipe: AnimateDiffVideoToVideoControlNetPipeline = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) + + inputs_normal = self.get_dummy_inputs(torch_device) + frames_normal = pipe(**inputs_normal).frames[0] + + schedulers_to_test = [ + DPMSolverMultistepScheduler.from_config( + components["scheduler"].config, + timestep_spacing="linspace", + beta_schedule="linear", + algorithm_type="dpmsolver++", + steps_offset=1, + clip_sample=False, + ), + LCMScheduler.from_config( + components["scheduler"].config, + timestep_spacing="linspace", + beta_schedule="linear", + steps_offset=1, + clip_sample=False, + ), + ] + components.pop("scheduler") + + for scheduler in schedulers_to_test: + components["scheduler"] = scheduler + pipe: AnimateDiffVideoToVideoControlNetPipeline = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) + + pipe.enable_free_init(num_iters=2, use_fast_sampling=False) + + inputs = self.get_dummy_inputs(torch_device) + frames_enable_free_init = pipe(**inputs).frames[0] + sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_init)).sum() + + self.assertGreater( + sum_enabled, + 1e1, + "Enabling of FreeInit should lead to results different from the default pipeline results", + ) + + def test_free_noise_blocks(self): + components = self.get_dummy_components() + pipe: AnimateDiffVideoToVideoControlNetPipeline = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) + + pipe.enable_free_noise() + for block in pipe.unet.down_blocks: + for motion_module in block.motion_modules: + for transformer_block in motion_module.transformer_blocks: + self.assertTrue( + isinstance(transformer_block, FreeNoiseTransformerBlock), + "Motion module transformer blocks must be an instance of `FreeNoiseTransformerBlock` after enabling FreeNoise.", + ) + + pipe.disable_free_noise() + for block in pipe.unet.down_blocks: + for motion_module in block.motion_modules: + for transformer_block in motion_module.transformer_blocks: + self.assertFalse( + isinstance(transformer_block, FreeNoiseTransformerBlock), + "Motion module transformer blocks must not be an instance of `FreeNoiseTransformerBlock` after disabling FreeNoise.", + ) + + def test_free_noise(self): + components = self.get_dummy_components() + pipe: AnimateDiffVideoToVideoControlNetPipeline = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) + + inputs_normal = self.get_dummy_inputs(torch_device, num_frames=16) + inputs_normal["num_inference_steps"] = 2 + inputs_normal["strength"] = 0.5 + frames_normal = pipe(**inputs_normal).frames[0] + + for context_length in [8, 9]: + for context_stride in [4, 6]: + pipe.enable_free_noise(context_length, context_stride) + + inputs_enable_free_noise = self.get_dummy_inputs(torch_device, num_frames=16) + inputs_enable_free_noise["num_inference_steps"] = 2 + inputs_enable_free_noise["strength"] = 0.5 + frames_enable_free_noise = pipe(**inputs_enable_free_noise).frames[0] + + pipe.disable_free_noise() + inputs_disable_free_noise = self.get_dummy_inputs(torch_device, num_frames=16) + inputs_disable_free_noise["num_inference_steps"] = 2 + inputs_disable_free_noise["strength"] = 0.5 + frames_disable_free_noise = pipe(**inputs_disable_free_noise).frames[0] + + sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_noise)).sum() + max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_noise)).max() + self.assertGreater( + sum_enabled, + 1e1, + "Enabling of FreeNoise should lead to results different from the default pipeline results", + ) + self.assertLess( + max_diff_disabled, + 1e-4, + "Disabling of FreeNoise should lead to results similar to the default pipeline results", + ) + + def test_free_noise_multi_prompt(self): + components = self.get_dummy_components() + pipe: AnimateDiffVideoToVideoControlNetPipeline = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) + + context_length = 8 + context_stride = 4 + pipe.enable_free_noise(context_length, context_stride) + + # Make sure that pipeline works when prompt indices are within num_frames bounds + inputs = self.get_dummy_inputs(torch_device, num_frames=16) + inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf"} + inputs["num_inference_steps"] = 2 + inputs["strength"] = 0.5 + pipe(**inputs).frames[0] + + with self.assertRaises(ValueError): + # Ensure that prompt indices are within bounds + inputs = self.get_dummy_inputs(torch_device, num_frames=16) + inputs["num_inference_steps"] = 2 + inputs["strength"] = 0.5 + inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf", 42: "Error on a leaf"} + pipe(**inputs).frames[0]