Skip to content

Commit

Permalink
[core] Support VideoToVideo with CogVideoX (#9333)
Browse files Browse the repository at this point in the history
* add vid2vid pipeline for cogvideox

* make fix-copies

* update docs

* fake context parallel cache, vae encode tiling

* add test for cog vid2vid

* use video link from HF docs repo

* add copied from comments; correctly rename test class
  • Loading branch information
a-r-r-o-w authored Sep 2, 2024
1 parent af6c0fb commit 0e6a840
Show file tree
Hide file tree
Showing 9 changed files with 1,190 additions and 20 deletions.
8 changes: 7 additions & 1 deletion docs/source/en/api/pipelines/cogvideox.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ It is also worth noting that torchao quantization is fully compatible with [torc
- all
- __call__

## CogVideoXVideoToVideoPipeline

[[autodoc]] CogVideoXVideoToVideoPipeline
- all
- __call__

## CogVideoXPipelineOutput

[[autodoc]] pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput
[[autodoc]] pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@
"BlipDiffusionPipeline",
"CLIPImageProjection",
"CogVideoXPipeline",
"CogVideoXVideoToVideoPipeline",
"CycleDiffusionPipeline",
"FluxControlNetPipeline",
"FluxPipeline",
Expand Down Expand Up @@ -699,6 +700,7 @@
AuraFlowPipeline,
CLIPImageProjection,
CogVideoXPipeline,
CogVideoXVideoToVideoPipeline,
CycleDiffusionPipeline,
FluxControlNetPipeline,
FluxPipeline,
Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@
"AudioLDM2UNet2DConditionModel",
]
_import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"]
_import_structure["cogvideo"] = ["CogVideoXPipeline"]
_import_structure["cogvideo"] = ["CogVideoXPipeline", "CogVideoXVideoToVideoPipeline"]
_import_structure["controlnet"].extend(
[
"BlipDiffusionControlNetPipeline",
Expand Down Expand Up @@ -454,7 +454,7 @@
)
from .aura_flow import AuraFlowPipeline
from .blip_diffusion import BlipDiffusionPipeline
from .cogvideo import CogVideoXPipeline
from .cogvideo import CogVideoXPipeline, CogVideoXVideoToVideoPipeline
from .controlnet import (
BlipDiffusionControlNetPipeline,
StableDiffusionControlNetImg2ImgPipeline,
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/cogvideo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_cogvideox"] = ["CogVideoXPipeline"]
_import_structure["pipeline_cogvideox_video2video"] = ["CogVideoXVideoToVideoPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
Expand All @@ -33,6 +34,7 @@
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_cogvideox import CogVideoXPipeline
from .pipeline_cogvideox_video2video import CogVideoXVideoToVideoPipeline

else:
import sys
Expand Down
19 changes: 2 additions & 17 deletions src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

import inspect
import math
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
Expand All @@ -26,9 +25,10 @@
from ...models.embeddings import get_3d_rotary_pos_embed
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from ...utils import BaseOutput, logging, replace_example_docstring
from ...utils import logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from .pipeline_output import CogVideoXPipelineOutput


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
Expand Down Expand Up @@ -136,21 +136,6 @@ def retrieve_timesteps(
return timesteps, num_inference_steps


@dataclass
class CogVideoXPipelineOutput(BaseOutput):
r"""
Output class for CogVideo pipelines.
Args:
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
`(batch_size, num_frames, channels, height, width)`.
"""

frames: torch.Tensor


class CogVideoXPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-video generation using CogVideoX.
Expand Down
Loading

0 comments on commit 0e6a840

Please sign in to comment.