From acd6d2c42f0fa4fade262e8814279748a544b0ce Mon Sep 17 00:00:00 2001
From: sanaka <50254737+HorizonWind2004@users.noreply.github.com>
Date: Wed, 9 Oct 2024 05:25:48 +0800
Subject: [PATCH] Fix the bug that `joint_attention_kwargs` is not passed to
 the FLUX transformer's attention processors (#9517)

* Update transformer_flux.py

---
 src/diffusers/models/transformers/transformer_flux.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index e38efe668c6c..6238ab8044bb 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -83,14 +83,16 @@ def forward(
         hidden_states: torch.FloatTensor,
         temb: torch.FloatTensor,
         image_rotary_emb=None,
+        joint_attention_kwargs=None,
     ):
         residual = hidden_states
         norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
         mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
-
+        joint_attention_kwargs = joint_attention_kwargs or {}
         attn_output = self.attn(
             hidden_states=norm_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            **joint_attention_kwargs,
         )
 
         hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
@@ -161,18 +163,20 @@ def forward(
         encoder_hidden_states: torch.FloatTensor,
         temb: torch.FloatTensor,
         image_rotary_emb=None,
+        joint_attention_kwargs=None,
     ):
         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
 
         norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
             encoder_hidden_states, emb=temb
         )
-
+        joint_attention_kwargs = joint_attention_kwargs or {}
         # Attention.
         attn_output, context_attn_output = self.attn(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            **joint_attention_kwargs,
         )
 
         # Process attention outputs for the `hidden_states`.
@@ -497,6 +501,7 @@ def custom_forward(*inputs):
                     encoder_hidden_states=encoder_hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
+                    joint_attention_kwargs=joint_attention_kwargs,
                 )
 
             # controlnet residual
@@ -533,6 +538,7 @@ def custom_forward(*inputs):
                     hidden_states=hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
+                    joint_attention_kwargs=joint_attention_kwargs,
                 )
 
             # controlnet residual
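
Note (illustrative, not part of the patch): `FluxPipeline.__call__` already accepts
`joint_attention_kwargs` and forwards it to `FluxTransformer2DModel`; with this fix the
transformer blocks also forward it into their `Attention` modules, so extra keyword
arguments can reach a custom attention processor. The sketch below assumes that flow.
`ScaledFluxAttnProcessor` and the `attn_scale` key are hypothetical names invented for
this example (a `scale` key is not used because the transformer pops it for LoRA
scaling); treat it as a sketch under those assumptions rather than a verified recipe.

import torch
from diffusers import FluxPipeline
from diffusers.models.attention_processor import FluxAttnProcessor2_0


class ScaledFluxAttnProcessor(FluxAttnProcessor2_0):
    # Hypothetical processor: multiplies the attention output by an extra keyword
    # argument. `attn_scale` only reaches this processor because the patched blocks
    # now pass `**joint_attention_kwargs` into `self.attn(...)`.
    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        image_rotary_emb=None,
        attn_scale: float = 1.0,
    ):
        out = super().__call__(
            attn,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            image_rotary_emb=image_rotary_emb,
        )
        # FluxAttnProcessor2_0 returns a (hidden_states, encoder_hidden_states) tuple
        # for the joint blocks and a single tensor for the single blocks.
        if isinstance(out, tuple):
            return tuple(attn_scale * t for t in out)
        return attn_scale * out


pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
pipe.transformer.set_attn_processor(ScaledFluxAttnProcessor())

# With the patch applied, `attn_scale` travels pipeline -> transformer -> block -> Attention -> processor.
image = pipe("a photo of a cat", joint_attention_kwargs={"attn_scale": 0.8}).images[0]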