Skip to content

Commit

Permalink
Fixed noise_pred_text referenced before assignment. (#9537)
Browse files Browse the repository at this point in the history
* Fixed local variable noise_pred_text referenced before assignment when using PAG with guidance scale and guidance rescale at the same time.

* Fixed style.

* Made returning text pred noise an argument.
  • Loading branch information
LagPixelLOL authored Oct 8, 2024
1 parent 02eeb8e commit 86bd991
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 12 deletions.
10 changes: 8 additions & 2 deletions src/diffusers/pipelines/pag/pag_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,9 @@ def _get_pag_scale(self, t):
else:
return self.pag_scale

def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_guidance, guidance_scale, t):
def _apply_perturbed_attention_guidance(
self, noise_pred, do_classifier_free_guidance, guidance_scale, t, return_pred_text=False
):
r"""
Apply perturbed attention guidance to the noise prediction.
Expand All @@ -107,9 +109,11 @@ def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_gui
do_classifier_free_guidance (bool): Whether to apply classifier-free guidance.
guidance_scale (float): The scale factor for the guidance term.
t (int): The current time step.
return_pred_text (bool): Whether to return the text noise prediction.
Returns:
torch.Tensor: The updated noise prediction tensor after applying perturbed attention guidance.
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: The updated noise prediction tensor after applying
perturbed attention guidance and the text noise prediction.
"""
pag_scale = self._get_pag_scale(t)
if do_classifier_free_guidance:
Expand All @@ -122,6 +126,8 @@ def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_gui
else:
noise_pred_text, noise_pred_perturb = noise_pred.chunk(2)
noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb)
if return_pred_text:
return noise_pred, noise_pred_text
return noise_pred

def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance):
Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,8 +893,8 @@ def __call__(

# perform guidance
if self.do_perturbed_attention_guidance:
noise_pred = self._apply_perturbed_attention_guidance(
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True
)
elif self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/pipelines/pag/pipeline_pag_sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -993,8 +993,8 @@ def __call__(

# perform guidance
if self.do_perturbed_attention_guidance:
noise_pred = self._apply_perturbed_attention_guidance(
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True
)

elif self.do_classifier_free_guidance:
Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1237,8 +1237,8 @@ def __call__(

# perform guidance
if self.do_perturbed_attention_guidance:
noise_pred = self._apply_perturbed_attention_guidance(
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True
)

elif self.do_classifier_free_guidance:
Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -1437,8 +1437,8 @@ def denoising_value_valid(dnv):

# perform guidance
if self.do_perturbed_attention_guidance:
noise_pred = self._apply_perturbed_attention_guidance(
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True
)
elif self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -1649,8 +1649,8 @@ def denoising_value_valid(dnv):

# perform guidance
if self.do_perturbed_attention_guidance:
noise_pred = self._apply_perturbed_attention_guidance(
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True
)
elif self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
Expand Down

0 comments on commit 86bd991

Please sign in to comment.