
Commit acd6d2c
Fix the bug that joint_attention_kwargs is not passed to FLUX's transformer attention processors (#9517)

* Update transformer_flux.py
HorizonWind2004 authored Oct 8, 2024
1 parent 86bd991 · commit acd6d2c
Showing 1 changed file with 8 additions and 2 deletions.
src/diffusers/models/transformers/transformer_flux.py (10 changes: 8 additions & 2 deletions)
@@ -83,14 +83,16 @@ def forward(
         hidden_states: torch.FloatTensor,
         temb: torch.FloatTensor,
         image_rotary_emb=None,
+        joint_attention_kwargs=None,
     ):
         residual = hidden_states
         norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
         mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
-
+        joint_attention_kwargs = joint_attention_kwargs or {}
         attn_output = self.attn(
             hidden_states=norm_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            **joint_attention_kwargs,
         )

         hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
@@ -161,18 +163,20 @@ def forward(
         encoder_hidden_states: torch.FloatTensor,
         temb: torch.FloatTensor,
         image_rotary_emb=None,
+        joint_attention_kwargs=None,
     ):
         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)

         norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
             encoder_hidden_states, emb=temb
         )
-
+        joint_attention_kwargs = joint_attention_kwargs or {}
         # Attention.
         attn_output, context_attn_output = self.attn(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            **joint_attention_kwargs,
         )

         # Process attention outputs for the `hidden_states`.
@@ -497,6 +501,7 @@ def custom_forward(*inputs):
                     encoder_hidden_states=encoder_hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
+                    joint_attention_kwargs=joint_attention_kwargs,
                 )

             # controlnet residual
@@ -533,6 +538,7 @@ def custom_forward(*inputs):
                     hidden_states=hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
+                    joint_attention_kwargs=joint_attention_kwargs,
                 )

             # controlnet residual
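With this change, extra entries in `joint_attention_kwargs` supplied to the pipeline are forwarded all the way into every block's attention processor, including the single-stream blocks. Below is a minimal usage sketch of what that enables; the `ScaledFluxAttnProcessor` class and its `attn_scale` keyword are hypothetical, added only to illustrate the plumbing (they are not part of diffusers), and running it assumes access to the `black-forest-labs/FLUX.1-dev` weights.

```python
import torch
from diffusers import FluxPipeline
from diffusers.models.attention_processor import FluxAttnProcessor2_0


class ScaledFluxAttnProcessor(FluxAttnProcessor2_0):
    """Hypothetical processor: accepts an extra `attn_scale` kwarg and scales its output."""

    def __call__(self, attn, hidden_states, encoder_hidden_states=None,
                 attention_mask=None, image_rotary_emb=None, attn_scale=1.0):
        out = super().__call__(attn, hidden_states, encoder_hidden_states,
                               attention_mask, image_rotary_emb)
        # Dual-stream blocks return (hidden_states, encoder_hidden_states);
        # single-stream blocks return a single tensor.
        if isinstance(out, tuple):
            return tuple(o * attn_scale for o in out)
        return out * attn_scale


pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.transformer.set_attn_processor(ScaledFluxAttnProcessor())
pipe.to("cuda")

# The extra kwarg now reaches the processors of all transformer blocks.
image = pipe(
    "a photo of a cat",
    joint_attention_kwargs={"attn_scale": 0.9},
).images[0]
```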

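For context on why forwarding `**joint_attention_kwargs` into `self.attn(...)` is enough: diffusers' `Attention.forward` inspects the processor's `__call__` signature and passes through only the keyword arguments the processor declares. The snippet below is an illustrative, simplified sketch of that filtering, not the library's actual implementation; `call_processor` and `ToyProcessor` are made-up names.

```python
import inspect


def call_processor(processor, attn, hidden_states, **joint_attention_kwargs):
    # Keep only the kwargs that the processor's __call__ actually declares,
    # mirroring (in simplified form) the filtering done in Attention.forward.
    accepted = set(inspect.signature(processor.__call__).parameters)
    filtered = {k: v for k, v in joint_attention_kwargs.items() if k in accepted}
    return processor(attn, hidden_states, **filtered)


class ToyProcessor:
    # Hypothetical processor that opts in to an extra `attn_scale` kwarg.
    def __call__(self, attn, hidden_states, attn_scale=1.0):
        return hidden_states * attn_scale


print(call_processor(ToyProcessor(), attn=None, hidden_states=2.0,
                     attn_scale=0.5, unused="dropped"))  # -> 1.0; `unused` is filtered out
```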
