diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 5e9863ab0d0a..db8f4fd17297 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -342,15 +342,58 @@ def __init__(
         embed_dim: int = 1920,
         text_embed_dim: int = 4096,
         bias: bool = True,
+        sample_width: int = 90,
+        sample_height: int = 60,
+        sample_frames: int = 49,
+        temporal_compression_ratio: int = 4,
+        max_text_seq_length: int = 226,
+        spatial_interpolation_scale: float = 1.875,
+        temporal_interpolation_scale: float = 1.0,
+        use_positional_embeddings: bool = True,
     ) -> None:
         super().__init__()
+        self.patch_size = patch_size
+        self.embed_dim = embed_dim
+        self.sample_height = sample_height
+        self.sample_width = sample_width
+        self.sample_frames = sample_frames
+        self.temporal_compression_ratio = temporal_compression_ratio
+        self.max_text_seq_length = max_text_seq_length
+        self.spatial_interpolation_scale = spatial_interpolation_scale
+        self.temporal_interpolation_scale = temporal_interpolation_scale
+        self.use_positional_embeddings = use_positional_embeddings
         self.proj = nn.Conv2d(
             in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
         )
         self.text_proj = nn.Linear(text_embed_dim, embed_dim)
 
+        if use_positional_embeddings:
+            pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
+            self.register_buffer("pos_embedding", pos_embedding, persistent=False)
+
+    def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
+        post_patch_height = sample_height // self.patch_size
+        post_patch_width = sample_width // self.patch_size
+        post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
+        num_patches = post_patch_height * post_patch_width * post_time_compression_frames
+
+        pos_embedding = get_3d_sincos_pos_embed(
+            self.embed_dim,
+            (post_patch_width, post_patch_height),
+            post_time_compression_frames,
+            self.spatial_interpolation_scale,
+            self.temporal_interpolation_scale,
+        )
+        pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
+        joint_pos_embedding = torch.zeros(
+            1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
+        )
+        joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)
+
+        return joint_pos_embedding
+
     def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
         r"""
         Args:
@@ -371,6 +414,21 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
         embeds = torch.cat(
             [text_embeds, image_embeds], dim=1
         ).contiguous()  # [batch, seq_length + num_frames x height x width, channels]
+
+        if self.use_positional_embeddings:
+            pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
+            if (
+                self.sample_height != height
+                or self.sample_width != width
+                or self.sample_frames != pre_time_compression_frames
+            ):
+                pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames)
+                pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
+            else:
+                pos_embedding = self.pos_embedding
+
+            embeds = embeds + pos_embedding
+
         return embeds
diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py
index c8d4b1896346..b6ba407104d5 100644
--- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py
+++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py
@@ -23,7 +23,7 @@
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import Attention, FeedForward
 from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
-from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
+from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
@@ -239,33 +239,29 @@ def __init__(
         super().__init__()
         inner_dim = num_attention_heads * attention_head_dim
 
-        post_patch_height = sample_height // patch_size
-        post_patch_width = sample_width // patch_size
-        post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
-        self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
-
         # 1. Patch embedding
-        self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True)
-        self.embedding_dropout = nn.Dropout(dropout)
-
-        # 2. 3D positional embeddings
-        spatial_pos_embedding = get_3d_sincos_pos_embed(
-            inner_dim,
-            (post_patch_width, post_patch_height),
-            post_time_compression_frames,
-            spatial_interpolation_scale,
-            temporal_interpolation_scale,
+        self.patch_embed = CogVideoXPatchEmbed(
+            patch_size=patch_size,
+            in_channels=in_channels,
+            embed_dim=inner_dim,
+            text_embed_dim=text_embed_dim,
+            bias=True,
+            sample_width=sample_width,
+            sample_height=sample_height,
+            sample_frames=sample_frames,
+            temporal_compression_ratio=temporal_compression_ratio,
+            max_text_seq_length=max_text_seq_length,
+            spatial_interpolation_scale=spatial_interpolation_scale,
+            temporal_interpolation_scale=temporal_interpolation_scale,
+            use_positional_embeddings=not use_rotary_positional_embeddings,
         )
-        spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
-        pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False)
-        pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding)
-        self.register_buffer("pos_embedding", pos_embedding, persistent=False)
+        self.embedding_dropout = nn.Dropout(dropout)
 
-        # 3. Time embeddings
+        # 2. Time embeddings
         self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
         self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
 
-        # 4. Define spatio-temporal transformers blocks
+        # 3. Define spatio-temporal transformers blocks
         self.transformer_blocks = nn.ModuleList(
             [
                 CogVideoXBlock(
@@ -284,7 +280,7 @@ def __init__(
         )
         self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
 
-        # 5. Output blocks
+        # 4. Output blocks
         self.norm_out = AdaLayerNorm(
             embedding_dim=time_embed_dim,
             output_dim=2 * inner_dim,
@@ -422,20 +418,13 @@ def forward(
         # 2. Patch embedding
         hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
+        hidden_states = self.embedding_dropout(hidden_states)
 
-        # 3. Position embedding
         text_seq_length = encoder_hidden_states.shape[1]
-        if not self.config.use_rotary_positional_embeddings:
-            seq_length = height * width * num_frames // (self.config.patch_size**2)
-
-            pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
-            hidden_states = hidden_states + pos_embeds
-            hidden_states = self.embedding_dropout(hidden_states)
-
         encoder_hidden_states = hidden_states[:, :text_seq_length]
         hidden_states = hidden_states[:, text_seq_length:]
 
-        # 4. Transformer blocks
+        # 3. Transformer blocks
        for i, block in enumerate(self.transformer_blocks):
             if self.training and self.gradient_checkpointing:
@@ -471,11 +460,11 @@ def custom_forward(*inputs):
         hidden_states = self.norm_final(hidden_states)
         hidden_states = hidden_states[:, text_seq_length:]
 
-        # 5. Final block
+        # 4. Final block
         hidden_states = self.norm_out(hidden_states, temb=emb)
         hidden_states = self.proj_out(hidden_states)
 
-        # 6. Unpatchify
+        # 5. Unpatchify
         p = self.config.patch_size
         output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
         output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
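
For reference, here is a minimal sketch of how the refactored `CogVideoXPatchEmbed` could be exercised on its own. The sizes mirror the new constructor defaults from the diff (90x60 latent resolution, 49 sample frames, temporal compression ratio 4, 226 text tokens); `patch_size=2` and `in_channels=16` are assumptions chosen to match typical CogVideoX defaults, and the tensor shapes are illustrative only.

```python
# Sketch: exercising the refactored CogVideoXPatchEmbed in isolation.
# patch_size=2 and in_channels=16 are assumed defaults; the remaining values
# mirror the constructor defaults introduced in this diff.
import torch

from diffusers.models.embeddings import CogVideoXPatchEmbed

patch_embed = CogVideoXPatchEmbed(
    patch_size=2,
    in_channels=16,
    embed_dim=1920,
    text_embed_dim=4096,
    sample_width=90,
    sample_height=60,
    sample_frames=49,
    temporal_compression_ratio=4,
    max_text_seq_length=226,
    use_positional_embeddings=True,
)

# 49 pixel-space frames compress to (49 - 1) // 4 + 1 = 13 latent frames.
text_embeds = torch.randn(1, 226, 4096)        # [batch, max_text_seq_length, text_embed_dim]
image_embeds = torch.randn(1, 13, 16, 60, 90)  # [batch, latent_frames, channels, height, width]

out = patch_embed(text_embeds, image_embeds)
# 226 text tokens + 13 * (60 // 2) * (90 // 2) = 226 + 17550 patch tokens
print(out.shape)  # expected: torch.Size([1, 17776, 1920])
```

Because `pos_embedding` is registered with `persistent=False`, the precomputed joint positional embedding never enters the state dict, and `forward` recomputes it on the fly whenever the input height, width, or pre-compression frame count differs from the configured sample size. In the transformer, this sincos path is only enabled when rotary embeddings are off (`use_positional_embeddings=not use_rotary_positional_embeddings`).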