[refactor] move positional embeddings to patch embed layer for CogVideoX (#9263)

* remove frame limit in cogvideox

* remove debug prints

* Update src/diffusers/models/transformers/cogvideox_transformer_3d.py

* revert pipeline; remove frame limitation

* revert transformer changes

* address review comments

* add error message

* apply suggestions from review
a-r-r-o-w authored Sep 3, 2024
1 parent 81da2e1 commit 9d49b45
Showing 2 changed files with 81 additions and 34 deletions.
src/diffusers/models/embeddings.py (58 additions, 0 deletions)
@@ -342,15 +342,58 @@ def __init__(
         embed_dim: int = 1920,
         text_embed_dim: int = 4096,
         bias: bool = True,
+        sample_width: int = 90,
+        sample_height: int = 60,
+        sample_frames: int = 49,
+        temporal_compression_ratio: int = 4,
+        max_text_seq_length: int = 226,
+        spatial_interpolation_scale: float = 1.875,
+        temporal_interpolation_scale: float = 1.0,
+        use_positional_embeddings: bool = True,
     ) -> None:
         super().__init__()
+
         self.patch_size = patch_size
+        self.embed_dim = embed_dim
+        self.sample_height = sample_height
+        self.sample_width = sample_width
+        self.sample_frames = sample_frames
+        self.temporal_compression_ratio = temporal_compression_ratio
+        self.max_text_seq_length = max_text_seq_length
+        self.spatial_interpolation_scale = spatial_interpolation_scale
+        self.temporal_interpolation_scale = temporal_interpolation_scale
+        self.use_positional_embeddings = use_positional_embeddings
 
         self.proj = nn.Conv2d(
             in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
         )
         self.text_proj = nn.Linear(text_embed_dim, embed_dim)
 
+        if use_positional_embeddings:
+            pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
+            self.register_buffer("pos_embedding", pos_embedding, persistent=False)
+
+    def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
+        post_patch_height = sample_height // self.patch_size
+        post_patch_width = sample_width // self.patch_size
+        post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
+        num_patches = post_patch_height * post_patch_width * post_time_compression_frames
+
+        pos_embedding = get_3d_sincos_pos_embed(
+            self.embed_dim,
+            (post_patch_width, post_patch_height),
+            post_time_compression_frames,
+            self.spatial_interpolation_scale,
+            self.temporal_interpolation_scale,
+        )
+        pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
+        joint_pos_embedding = torch.zeros(
+            1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
+        )
+        joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)
+
+        return joint_pos_embedding
+
     def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
         r"""
         Args:
@@ -371,6 +414,21 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
         embeds = torch.cat(
             [text_embeds, image_embeds], dim=1
         ).contiguous()  # [batch, seq_length + num_frames x height x width, channels]
 
+        if self.use_positional_embeddings:
+            pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
+            if (
+                self.sample_height != height
+                or self.sample_width != width
+                or self.sample_frames != pre_time_compression_frames
+            ):
+                pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames)
+                pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
+            else:
+                pos_embedding = self.pos_embedding
+
+            embeds = embeds + pos_embedding
+
         return embeds
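A quick way to sanity-check the new buffer is to run the shape math from _get_positional_embeddings with the defaults above. This is a standalone sketch; patch_size = 2 is an assumption here, since that parameter's default sits above the first hunk and is not shown in this diff:

# Shape math of the joint positional-embedding buffer, using the defaults
# visible in the hunk above. patch_size = 2 is an assumed default.
patch_size = 2
sample_height, sample_width, sample_frames = 60, 90, 49
temporal_compression_ratio = 4
max_text_seq_length = 226
embed_dim = 1920

post_patch_height = sample_height // patch_size                                       # 30
post_patch_width = sample_width // patch_size                                         # 45
post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1  # 13
num_patches = post_patch_height * post_patch_width * post_time_compression_frames     # 17550

# The joint buffer prepends zeros for the text tokens, so only the video
# slice carries sincos values.
print((1, max_text_seq_length + num_patches, embed_dim))  # (1, 17776, 1920)

The (sample_frames - 1) // ratio + 1 form reflects that the CogVideoX VAE leaves the first frame out of temporal compression, so 49 input frames become 13 latent frames and the buffer covers 226 text tokens plus 17550 video patches.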
src/diffusers/models/transformers/cogvideox_transformer_3d.py (23 additions, 34 deletions)
@@ -23,7 +23,7 @@
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import Attention, FeedForward
 from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
-from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
+from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
@@ -239,33 +239,29 @@ def __init__(
         super().__init__()
         inner_dim = num_attention_heads * attention_head_dim
 
-        post_patch_height = sample_height // patch_size
-        post_patch_width = sample_width // patch_size
-        post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
-        self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
-
         # 1. Patch embedding
-        self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True)
-        self.embedding_dropout = nn.Dropout(dropout)
-
-        # 2. 3D positional embeddings
-        spatial_pos_embedding = get_3d_sincos_pos_embed(
-            inner_dim,
-            (post_patch_width, post_patch_height),
-            post_time_compression_frames,
-            spatial_interpolation_scale,
-            temporal_interpolation_scale,
+        self.patch_embed = CogVideoXPatchEmbed(
+            patch_size=patch_size,
+            in_channels=in_channels,
+            embed_dim=inner_dim,
+            text_embed_dim=text_embed_dim,
+            bias=True,
+            sample_width=sample_width,
+            sample_height=sample_height,
+            sample_frames=sample_frames,
+            temporal_compression_ratio=temporal_compression_ratio,
+            max_text_seq_length=max_text_seq_length,
+            spatial_interpolation_scale=spatial_interpolation_scale,
+            temporal_interpolation_scale=temporal_interpolation_scale,
+            use_positional_embeddings=not use_rotary_positional_embeddings,
         )
-        spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
-        pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False)
-        pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding)
-        self.register_buffer("pos_embedding", pos_embedding, persistent=False)
+        self.embedding_dropout = nn.Dropout(dropout)
 
-        # 3. Time embeddings
+        # 2. Time embeddings
         self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
         self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
 
-        # 4. Define spatio-temporal transformers blocks
+        # 3. Define spatio-temporal transformers blocks
         self.transformer_blocks = nn.ModuleList(
             [
                 CogVideoXBlock(
@@ -284,7 +280,7 @@ )
         )
         self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
 
-        # 5. Output blocks
+        # 4. Output blocks
         self.norm_out = AdaLayerNorm(
             embedding_dim=time_embed_dim,
             output_dim=2 * inner_dim,
@@ -422,20 +418,13 @@ def forward(
 
         # 2. Patch embedding
         hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
+        hidden_states = self.embedding_dropout(hidden_states)
 
-        # 3. Position embedding
         text_seq_length = encoder_hidden_states.shape[1]
-        if not self.config.use_rotary_positional_embeddings:
-            seq_length = height * width * num_frames // (self.config.patch_size**2)
-
-            pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
-            hidden_states = hidden_states + pos_embeds
-            hidden_states = self.embedding_dropout(hidden_states)
-
         encoder_hidden_states = hidden_states[:, :text_seq_length]
         hidden_states = hidden_states[:, text_seq_length:]
 
-        # 4. Transformer blocks
+        # 3. Transformer blocks
         for i, block in enumerate(self.transformer_blocks):
             if self.training and self.gradient_checkpointing:
 
@@ -471,11 +460,11 @@ def custom_forward(*inputs):
         hidden_states = self.norm_final(hidden_states)
         hidden_states = hidden_states[:, text_seq_length:]
 
-        # 5. Final block
+        # 4. Final block
         hidden_states = self.norm_out(hidden_states, temb=emb)
         hidden_states = self.proj_out(hidden_states)
 
-        # 6. Unpatchify
+        # 5. Unpatchify
         p = self.config.patch_size
         output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
         output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
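With the buffer owned by the patch embed layer, the transformer no longer needs a fixed-size pos_embedding of its own, and inputs that differ from the sample configuration are handled by the on-the-fly recompute in forward. A hedged usage sketch of the relocated module: constructor values mirror the defaults in this diff, in_channels=16 is an assumption (its default sits above the shown hunk), and tensor contents are illustrative:

import torch

from diffusers.models.embeddings import CogVideoXPatchEmbed

# Defaults as shown in this diff; in_channels=16 is an assumed default.
patch_embed = CogVideoXPatchEmbed(
    patch_size=2,
    in_channels=16,
    embed_dim=1920,
    text_embed_dim=4096,
    sample_width=90,
    sample_height=60,
    sample_frames=49,
    use_positional_embeddings=True,
)

text_embeds = torch.randn(1, 226, 4096)  # [batch, text_seq_length, text_embed_dim]
latents = torch.randn(1, 13, 16, 60, 90)  # [batch, num_frames, channels, height, width]

# 13 latent frames at 60x90 matches the sample config, so the cached
# pos_embedding buffer is used; a different height, width, or frame count
# would route through _get_positional_embeddings instead of erroring out.
out = patch_embed(text_embeds, latents)
print(out.shape)  # torch.Size([1, 17776, 1920])

For RoPE-based configs the transformer passes use_positional_embeddings=not use_rotary_positional_embeddings, so the sincos buffer is skipped entirely when rotary embeddings are in use.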
