[Pipeline] AnimateDiff SDXL (#6721)
* update conversion script to handle motion adapter sdxl checkpoint

* add animatediff xl

* handle addition_embed_type

* fix output

* update

* add imports

* make fix-copies

* add decode latents

* update docstrings

* add animatediff sdxl to docs

* remove unnecessary lines

* update example

* add test

* revert conv_in conv_out kernel param

* remove unused param addition_embed_type_num_heads

* latest IPAdapter impl

* make fix-copies

* fix return

* add IPAdapterTesterMixin to tests

* fix return

* revert based on suggestion

* add freeinit

* fix test_to_dtype test

* use StableDiffusionMixin instead of different helper methods

* fix progress bar iterations

* apply suggestions from review

* hardcode flip_sin_to_cos and freq_shift

* make fix-copies

* fix ip adapter implementation

* fix last failing test

* make style

* Update docs/source/en/api/pipelines/animatediff.md

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* remove todo

* fix doc-builder errors

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
a-r-r-o-w and DN6 authored May 8, 2024
1 parent f29b934 commit 818f760
Showing 10 changed files with 1,740 additions and 9 deletions.
53 changes: 53 additions & 0 deletions docs/source/en/api/pipelines/animatediff.md
@@ -101,6 +101,53 @@ AnimateDiff tends to work better with finetuned Stable Diffusion models. If you

</Tip>

### AnimateDiffSDXLPipeline

AnimateDiff can also be used with SDXL models. This is currently an experimental feature, as only a beta release of the motion adapter checkpoint is available.

```python
import torch
from diffusers.models import MotionAdapter
from diffusers import AnimateDiffSDXLPipeline, DDIMScheduler
from diffusers.utils import export_to_gif

adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-sdxl-beta", torch_dtype=torch.float16)

model_id = "stabilityai/stable-diffusion-xl-base-1.0"
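# Configure DDIM with the settings commonly recommended for AnimateDiff:
# linear betas, "linspace" timestep spacing, and sample clipping disabled.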
scheduler = DDIMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
clip_sample=False,
timestep_spacing="linspace",
beta_schedule="linear",
steps_offset=1,
)
pipe = AnimateDiffSDXLPipeline.from_pretrained(
model_id,
motion_adapter=adapter,
scheduler=scheduler,
torch_dtype=torch.float16,
variant="fp16",
).to("cuda")

# enable memory savings
pipe.enable_vae_slicing()
pipe.enable_vae_tiling()

output = pipe(
prompt="a panda surfing in the ocean, realistic, high quality",
negative_prompt="low quality, worst quality",
num_inference_steps=20,
guidance_scale=8,
width=1024,
height=1024,
num_frames=16,
)

frames = output.frames[0]
export_to_gif(frames, "animation.gif")
```
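
Both memory-saving calls are optional: `enable_vae_slicing` decodes the generated frames one at a time instead of as a single batch, and `enable_vae_tiling` decodes each latent in tiles, trading a little speed for a much lower peak memory footprint at 1024x1024 across 16 frames.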

### AnimateDiffVideoToVideoPipeline

AnimateDiff can also be used to generate videos that are visually similar to an initial input video, or to edit its style, characters, background, and more, letting you seamlessly explore creative variations.
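
As a minimal sketch of how the pipeline is driven (the file's full example is collapsed in this view; the checkpoint names and parameter values below are illustrative assumptions, not the collapsed example itself):

```python
import torch
from diffusers import AnimateDiffVideoToVideoPipeline, MotionAdapter
from diffusers.utils import export_to_gif
from PIL import Image

# Illustrative checkpoints; any finetuned Stable Diffusion 1.5 model should work.
adapter = MotionAdapter.from_pretrained(
    "guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16
)
pipe = AnimateDiffVideoToVideoPipeline.from_pretrained(
    "emilianJR/epiCRealism", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")

# Load the source clip as a list of PIL frames (paths are placeholders).
video = [Image.open(f"frames/{i:03d}.png").convert("RGB") for i in range(16)]

output = pipe(
    prompt="an origami panda surfing, paper texture, high quality",
    video=video,
    strength=0.7,  # lower values stay closer to the input video
)
export_to_gif(output.frames[0], "edited.gif")
```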
@@ -522,6 +569,12 @@ export_to_gif(frames, "animatelcm-motion-lora.gif")
- all
- __call__

## AnimateDiffSDXLPipeline

[[autodoc]] AnimateDiffSDXLPipeline
- all
- __call__

## AnimateDiffVideoToVideoPipeline

[[autodoc]] AnimateDiffVideoToVideoPipeline
7 changes: 5 additions & 2 deletions scripts/convert_animatediff_motion_module_to_diffusers.py
@@ -31,6 +31,7 @@ def get_args():
parser.add_argument("--output_path", type=str, required=True)
parser.add_argument("--use_motion_mid_block", action="store_true")
parser.add_argument("--motion_max_seq_length", type=int, default=32)
parser.add_argument("--block_out_channels", nargs="+", default=[320, 640, 1280, 1280], type=int)
parser.add_argument("--save_fp16", action="store_true")

return parser.parse_args()
@@ -49,11 +50,13 @@ def get_args():

conv_state_dict = convert_motion_module(state_dict)
adapter = MotionAdapter(
- use_motion_mid_block=args.use_motion_mid_block, motion_max_seq_length=args.motion_max_seq_length
+ block_out_channels=args.block_out_channels,
+ use_motion_mid_block=args.use_motion_mid_block,
+ motion_max_seq_length=args.motion_max_seq_length,
)
# skip loading position embeddings
adapter.load_state_dict(conv_state_dict, strict=False)
adapter.save_pretrained(args.output_path)

if args.save_fp16:
- adapter.to(torch.float16).save_pretrained(args.output_path, variant="fp16")
+ adapter.to(dtype=torch.float16).save_pretrained(args.output_path, variant="fp16")
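
For context on the new `--block_out_channels` flag: the SDXL beta motion module has three down-block stages rather than SD1.5's four, so the adapter must be constructed with matching widths. A minimal sketch (the SDXL widths here are an assumption based on SDXL's UNet layout, not taken from this diff):

```python
from diffusers.models import MotionAdapter

# SD1.5-shaped adapter: matches the script's default of 320 640 1280 1280.
sd15_adapter = MotionAdapter(block_out_channels=(320, 640, 1280, 1280))

# SDXL-beta-shaped adapter: pass `--block_out_channels 320 640 1280` to the script.
sdxl_adapter = MotionAdapter(block_out_channels=(320, 640, 1280))
```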
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
@@ -216,6 +216,7 @@
"AmusedInpaintPipeline",
"AmusedPipeline",
"AnimateDiffPipeline",
"AnimateDiffSDXLPipeline",
"AnimateDiffVideoToVideoPipeline",
"AudioLDM2Pipeline",
"AudioLDM2ProjectionModel",
@@ -595,6 +596,7 @@
AmusedInpaintPipeline,
AmusedPipeline,
AnimateDiffPipeline,
AnimateDiffSDXLPipeline,
AnimateDiffVideoToVideoPipeline,
AudioLDM2Pipeline,
AudioLDM2ProjectionModel,
2 changes: 2 additions & 0 deletions src/diffusers/models/unets/unet_3d_blocks.py
@@ -121,6 +121,7 @@ def get_down_block(
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMotion")
return CrossAttnDownBlockMotion(
num_layers=num_layers,
transformer_layers_per_block=transformer_layers_per_block,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
@@ -255,6 +256,7 @@ def get_up_block(
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMotion")
return CrossAttnUpBlockMotion(
num_layers=num_layers,
transformer_layers_per_block=transformer_layers_per_block,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
74 changes: 68 additions & 6 deletions src/diffusers/models/unets/unet_motion_model.py
@@ -211,13 +211,18 @@ def __init__(
norm_num_groups: int = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
use_linear_projection: bool = False,
num_attention_heads: Union[int, Tuple[int, ...]] = 8,
motion_max_seq_length: int = 32,
motion_num_attention_heads: int = 8,
use_motion_mid_block: bool = True,
encoder_hid_dim: Optional[int] = None,
encoder_hid_dim_type: Optional[str] = None,
addition_embed_type: Optional[str] = None,
addition_time_embed_dim: Optional[int] = None,
projection_class_embeddings_input_dim: Optional[int] = None,
time_cond_proj_dim: Optional[int] = None,
):
super().__init__()
@@ -240,6 +245,21 @@ def __init__(
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)

if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
)

if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
)

if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
for layer_number_per_block in transformer_layers_per_block:
if isinstance(layer_number_per_block, list):
raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")

# input
conv_in_kernel = 3
conv_out_kernel = 3
@@ -260,13 +280,26 @@ def __init__(
if encoder_hid_dim_type is None:
self.encoder_hid_proj = None

if addition_embed_type == "text_time":
self.add_time_proj = Timesteps(addition_time_embed_dim, True, 0)
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)

# class embedding
self.down_blocks = nn.ModuleList([])
self.up_blocks = nn.ModuleList([])

if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(down_block_types)

if isinstance(cross_attention_dim, int):
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)

if isinstance(layers_per_block, int):
layers_per_block = [layers_per_block] * len(down_block_types)

if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
@@ -276,21 +309,22 @@ def __init__(

down_block = get_down_block(
down_block_type,
- num_layers=layers_per_block,
+ num_layers=layers_per_block[i],
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
- cross_attention_dim=cross_attention_dim,
+ cross_attention_dim=cross_attention_dim[i],
num_attention_heads=num_attention_heads[i],
downsample_padding=downsample_padding,
use_linear_projection=use_linear_projection,
dual_cross_attention=False,
temporal_num_attention_heads=motion_num_attention_heads,
temporal_max_seq_length=motion_max_seq_length,
transformer_layers_per_block=transformer_layers_per_block[i],
)
self.down_blocks.append(down_block)

@@ -302,13 +336,14 @@ def __init__(
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
- cross_attention_dim=cross_attention_dim,
+ cross_attention_dim=cross_attention_dim[-1],
num_attention_heads=num_attention_heads[-1],
resnet_groups=norm_num_groups,
dual_cross_attention=False,
use_linear_projection=use_linear_projection,
temporal_num_attention_heads=motion_num_attention_heads,
temporal_max_seq_length=motion_max_seq_length,
transformer_layers_per_block=transformer_layers_per_block[-1],
)

else:
@@ -318,11 +353,12 @@ def __init__(
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
- cross_attention_dim=cross_attention_dim,
+ cross_attention_dim=cross_attention_dim[-1],
num_attention_heads=num_attention_heads[-1],
resnet_groups=norm_num_groups,
dual_cross_attention=False,
use_linear_projection=use_linear_projection,
transformer_layers_per_block=transformer_layers_per_block[-1],
)

# count how many layers upsample the images
@@ -331,6 +367,9 @@ def __init__(
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_num_attention_heads = list(reversed(num_attention_heads))
reversed_layers_per_block = list(reversed(layers_per_block))
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))

output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
@@ -349,7 +388,7 @@ def __init__(

up_block = get_up_block(
up_block_type,
- num_layers=layers_per_block + 1,
+ num_layers=reversed_layers_per_block[i] + 1,
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
@@ -358,13 +397,14 @@ def __init__(
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
- cross_attention_dim=cross_attention_dim,
+ cross_attention_dim=reversed_cross_attention_dim[i],
num_attention_heads=reversed_num_attention_heads[i],
dual_cross_attention=False,
resolution_idx=i,
use_linear_projection=use_linear_projection,
temporal_num_attention_heads=motion_num_attention_heads,
temporal_max_seq_length=motion_max_seq_length,
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
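
Taken together, the constructor changes above let `UNetMotionModel` accept per-block settings, which an SDXL-shaped motion UNet requires. A hypothetical instantiation (the values mirror SDXL's UNet and are assumptions, not taken from this diff):

```python
from diffusers import UNetMotionModel

# Hypothetical SDXL-shaped motion UNet; scalar settings are still broadcast to
# one entry per block by the normalization added above.
unet = UNetMotionModel(
    block_out_channels=(320, 640, 1280),
    down_block_types=("DownBlockMotion", "CrossAttnDownBlockMotion", "CrossAttnDownBlockMotion"),
    up_block_types=("CrossAttnUpBlockMotion", "CrossAttnUpBlockMotion", "UpBlockMotion"),
    transformer_layers_per_block=(1, 2, 10),
    cross_attention_dim=2048,
    num_attention_heads=(5, 10, 20),
    addition_embed_type="text_time",
    addition_time_embed_dim=256,
    projection_class_embeddings_input_dim=2816,
)
```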
@@ -835,6 +875,28 @@ def forward(
t_emb = t_emb.to(dtype=self.dtype)

emb = self.time_embedding(t_emb, timestep_cond)
aug_emb = None

if self.config.addition_embed_type == "text_time":
if "text_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
)

text_embeds = added_cond_kwargs.get("text_embeds")
if "time_ids" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
)
time_ids = added_cond_kwargs.get("time_ids")
time_embeds = self.add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(emb.dtype)
aug_emb = self.add_embedding(add_embeds)

emb = emb if aug_emb is None else emb + aug_emb
emb = emb.repeat_interleave(repeats=num_frames, dim=0)
encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)

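For orientation, the new `text_time` branch consumes the same micro-conditioning SDXL's UNet does. A sketch of the tensors involved (shapes assume `addition_time_embed_dim=256` and 1280-dim pooled embeddings; the names are illustrative):

```python
import torch

batch = 2
# Pooled output of SDXL's second text encoder.
text_embeds = torch.randn(batch, 1280)
# (original_size, crops_coords_top_left, target_size) flattened to six ids per sample.
time_ids = torch.tensor([[1024, 1024, 0, 0, 1024, 1024]] * batch, dtype=torch.float32)
added_cond_kwargs = {"text_embeds": text_embeds, "time_ids": time_ids}

# Inside forward: add_time_proj maps each id to a 256-dim sinusoidal embedding,
# giving (batch, 6 * 256) = (2, 1536); concatenated with text_embeds this becomes
# (2, 2816), matching projection_class_embeddings_input_dim for SDXL.
```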
3 changes: 2 additions & 1 deletion src/diffusers/pipelines/__init__.py
@@ -114,6 +114,7 @@
_import_structure["amused"] = ["AmusedImg2ImgPipeline", "AmusedInpaintPipeline", "AmusedPipeline"]
_import_structure["animatediff"] = [
"AnimateDiffPipeline",
"AnimateDiffSDXLPipeline",
"AnimateDiffVideoToVideoPipeline",
]
_import_structure["audioldm"] = ["AudioLDMPipeline"]
@@ -367,7 +368,7 @@
from ..utils.dummy_torch_and_transformers_objects import *
else:
from .amused import AmusedImg2ImgPipeline, AmusedInpaintPipeline, AmusedPipeline
- from .animatediff import AnimateDiffPipeline, AnimateDiffVideoToVideoPipeline
+ from .animatediff import AnimateDiffPipeline, AnimateDiffSDXLPipeline, AnimateDiffVideoToVideoPipeline
from .audioldm import AudioLDMPipeline
from .audioldm2 import (
AudioLDM2Pipeline,
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/animatediff/__init__.py
@@ -22,6 +22,7 @@
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_animatediff"] = ["AnimateDiffPipeline"]
_import_structure["pipeline_animatediff_sdxl"] = ["AnimateDiffSDXLPipeline"]
_import_structure["pipeline_animatediff_video2video"] = ["AnimateDiffVideoToVideoPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -33,6 +34,7 @@

else:
from .pipeline_animatediff import AnimateDiffPipeline
from .pipeline_animatediff_sdxl import AnimateDiffSDXLPipeline
from .pipeline_animatediff_video2video import AnimateDiffVideoToVideoPipeline
from .pipeline_output import AnimateDiffPipelineOutput
