Skip to content

Commit

Permalink
Modification on the PAG community pipeline (re) (#7876)
Browse files
* edited_pag_implementation

* update

---------

Co-authored-by: yiyixuxu <yixu310@gmail.com>
  • Loading branch information
HyoungwonCho and yiyixuxu authored May 8, 2024
1 parent 8edaf3b commit c221714
Showing 1 changed file with 42 additions and 53 deletions.
95 changes: 42 additions & 53 deletions examples/community/pipeline_stable_diffusion_pag.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Implementation of StableDiffusionPAGPipeline
# Implementation of StableDiffusionPipeline with PAG
# https://ku-cvlab.github.io/Perturbed-Attention-Guidance

import inspect
from typing import Any, Callable, Dict, List, Optional, Union
@@ -134,8 +135,8 @@ def __call__(

value = attn.to_v(hidden_states_ptb)

hidden_states_ptb = torch.zeros(value.shape).to(value.get_device())
# hidden_states_ptb = value
# hidden_states_ptb = torch.zeros(value.shape).to(value.get_device())
hidden_states_ptb = value

hidden_states_ptb = hidden_states_ptb.to(query.dtype)

@@ -1045,7 +1046,7 @@ def pag_scale(self):
return self._pag_scale

@property
def do_adversarial_guidance(self):
def do_perturbed_attention_guidance(self):
return self._pag_scale > 0

@property
@@ -1056,14 +1057,6 @@ def pag_adaptive_scaling(self):
def do_pag_adaptive_scaling(self):
return self._pag_adaptive_scaling > 0

@property
def pag_drop_rate(self):
return self._pag_drop_rate

@property
def pag_applied_layers(self):
return self._pag_applied_layers

@property
def pag_applied_layers_index(self):
return self._pag_applied_layers_index
@@ -1080,8 +1073,6 @@ def __call__(
guidance_scale: float = 7.5,
pag_scale: float = 0.0,
pag_adaptive_scaling: float = 0.0,
pag_drop_rate: float = 0.5,
pag_applied_layers: List[str] = ["down"], # ['down', 'mid', 'up']
pag_applied_layers_index: List[str] = ["d4"], # ['d4', 'd5', 'm0']
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
@@ -1221,8 +1212,6 @@ def __call__(

self._pag_scale = pag_scale
self._pag_adaptive_scaling = pag_adaptive_scaling
self._pag_drop_rate = pag_drop_rate
self._pag_applied_layers = pag_applied_layers
self._pag_applied_layers_index = pag_applied_layers_index

# 2. Define call parameters
@@ -1257,13 +1246,13 @@ def __call__(
# to avoid doing two forward passes

# cfg
if self.do_classifier_free_guidance and not self.do_adversarial_guidance:
if self.do_classifier_free_guidance and not self.do_perturbed_attention_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
# pag
elif not self.do_classifier_free_guidance and self.do_adversarial_guidance:
elif not self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
prompt_embeds = torch.cat([prompt_embeds, prompt_embeds])
# both
elif self.do_classifier_free_guidance and self.do_adversarial_guidance:
elif self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds])

if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
@@ -1306,7 +1295,7 @@ def __call__(
).to(device=device, dtype=latents.dtype)

# 7. Denoising loop
if self.do_adversarial_guidance:
if self.do_perturbed_attention_guidance:
down_layers = []
mid_layers = []
up_layers = []
@@ -1322,6 +1311,29 @@ def __call__(
else:
raise ValueError(f"Invalid layer type: {layer_type}")

# change attention layer in UNet if use PAG
if self.do_perturbed_attention_guidance:
if self.do_classifier_free_guidance:
replace_processor = PAGCFGIdentitySelfAttnProcessor()
else:
replace_processor = PAGIdentitySelfAttnProcessor()

drop_layers = self.pag_applied_layers_index
for drop_layer in drop_layers:
try:
if drop_layer[0] == "d":
down_layers[int(drop_layer[1])].processor = replace_processor
elif drop_layer[0] == "m":
mid_layers[int(drop_layer[1])].processor = replace_processor
elif drop_layer[0] == "u":
up_layers[int(drop_layer[1])].processor = replace_processor
else:
raise ValueError(f"Invalid layer type: {drop_layer[0]}")
except IndexError:
raise ValueError(
f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers."
)

num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -1330,41 +1342,18 @@ def __call__(
continue

# cfg
if self.do_classifier_free_guidance and not self.do_adversarial_guidance:
if self.do_classifier_free_guidance and not self.do_perturbed_attention_guidance:
latent_model_input = torch.cat([latents] * 2)
# pag
elif not self.do_classifier_free_guidance and self.do_adversarial_guidance:
elif not self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
latent_model_input = torch.cat([latents] * 2)
# both
elif self.do_classifier_free_guidance and self.do_adversarial_guidance:
elif self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
latent_model_input = torch.cat([latents] * 3)
# no
else:
latent_model_input = latents

# change attention layer in UNet if use PAG
if self.do_adversarial_guidance:
if self.do_classifier_free_guidance:
replace_processor = PAGCFGIdentitySelfAttnProcessor()
else:
replace_processor = PAGIdentitySelfAttnProcessor()

drop_layers = self.pag_applied_layers_index
for drop_layer in drop_layers:
try:
if drop_layer[0] == "d":
down_layers[int(drop_layer[1])].processor = replace_processor
elif drop_layer[0] == "m":
mid_layers[int(drop_layer[1])].processor = replace_processor
elif drop_layer[0] == "u":
up_layers[int(drop_layer[1])].processor = replace_processor
else:
raise ValueError(f"Invalid layer type: {drop_layer[0]}")
except IndexError:
raise ValueError(
f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers."
)

latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# predict the noise residual
Expand All @@ -1381,14 +1370,14 @@ def __call__(
# perform guidance

# cfg
if self.do_classifier_free_guidance and not self.do_adversarial_guidance:
if self.do_classifier_free_guidance and not self.do_perturbed_attention_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

delta = noise_pred_text - noise_pred_uncond
noise_pred = noise_pred_uncond + self.guidance_scale * delta

# pag
elif not self.do_classifier_free_guidance and self.do_adversarial_guidance:
elif not self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
noise_pred_original, noise_pred_perturb = noise_pred.chunk(2)

signal_scale = self.pag_scale
Expand All @@ -1400,7 +1389,7 @@ def __call__(
noise_pred = noise_pred_original + signal_scale * (noise_pred_original - noise_pred_perturb)

# both
elif self.do_classifier_free_guidance and self.do_adversarial_guidance:
elif self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
noise_pred_uncond, noise_pred_text, noise_pred_text_perturb = noise_pred.chunk(3)

signal_scale = self.pag_scale
@@ -1458,11 +1447,8 @@ def __call__(
# Offload all models
self.maybe_free_model_hooks()

if not return_dict:
return (image, has_nsfw_concept)

# change attention layer in UNet if use PAG
if self.do_adversarial_guidance:
if self.do_perturbed_attention_guidance:
drop_layers = self.pag_applied_layers_index
for drop_layer in drop_layers:
try:
@@ -1479,4 +1465,7 @@ def __call__(
f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers."
)

if not return_dict:
return (image, has_nsfw_concept)

return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

0 comments on commit c221714

Please sign in to comment.