Skip to content

Commit

Permalink
add PAG support for SD Img2Img (#9463)
Browse files Browse the repository at this point in the history
* added pag to sd img2img pipeline


---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
  • Loading branch information
SahilCarterr and yiyixuxu authored Oct 9, 2024
1 parent 31058cd commit af28ae2
Show file tree
Hide file tree
Showing 8 changed files with 1,401 additions and 0 deletions.
5 changes: 5 additions & 0 deletions docs/source/en/api/pipelines/pag.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ Since RegEx is supported as a way for matching layer identifiers, it is crucial
- all
- __call__

## StableDiffusionPAGImg2ImgPipeline
[[autodoc]] StableDiffusionPAGImg2ImgPipeline
- all
- __call__

## StableDiffusionControlNetPAGPipeline
[[autodoc]] StableDiffusionControlNetPAGPipeline

Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@
"StableDiffusionLatentUpscalePipeline",
"StableDiffusionLDM3DPipeline",
"StableDiffusionModelEditingPipeline",
"StableDiffusionPAGImg2ImgPipeline",
"StableDiffusionPAGPipeline",
"StableDiffusionPanoramaPipeline",
"StableDiffusionParadigmsPipeline",
Expand Down Expand Up @@ -795,6 +796,7 @@
StableDiffusionLatentUpscalePipeline,
StableDiffusionLDM3DPipeline,
StableDiffusionModelEditingPipeline,
StableDiffusionPAGImg2ImgPipeline,
StableDiffusionPAGPipeline,
StableDiffusionPanoramaPipeline,
StableDiffusionParadigmsPipeline,
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@
"HunyuanDiTPAGPipeline",
"StableDiffusion3PAGPipeline",
"StableDiffusionPAGPipeline",
"StableDiffusionPAGImg2ImgPipeline",
"StableDiffusionControlNetPAGPipeline",
"StableDiffusionXLPAGPipeline",
"StableDiffusionXLPAGInpaintPipeline",
Expand Down Expand Up @@ -569,6 +570,7 @@
StableDiffusion3PAGPipeline,
StableDiffusionControlNetPAGInpaintPipeline,
StableDiffusionControlNetPAGPipeline,
StableDiffusionPAGImg2ImgPipeline,
StableDiffusionPAGPipeline,
StableDiffusionXLControlNetPAGImg2ImgPipeline,
StableDiffusionXLControlNetPAGPipeline,
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/auto_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
StableDiffusion3PAGPipeline,
StableDiffusionControlNetPAGInpaintPipeline,
StableDiffusionControlNetPAGPipeline,
StableDiffusionPAGImg2ImgPipeline,
StableDiffusionPAGPipeline,
StableDiffusionXLControlNetPAGImg2ImgPipeline,
StableDiffusionXLControlNetPAGPipeline,
Expand Down Expand Up @@ -131,6 +132,7 @@
("kandinsky22", KandinskyV22Img2ImgCombinedPipeline),
("kandinsky3", Kandinsky3Img2ImgPipeline),
("stable-diffusion-controlnet", StableDiffusionControlNetImg2ImgPipeline),
("stable-diffusion-pag", StableDiffusionPAGImg2ImgPipeline),
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetImg2ImgPipeline),
("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline),
("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline),
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/pag/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
_import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"]
_import_structure["pipeline_pag_sd_3"] = ["StableDiffusion3PAGPipeline"]
_import_structure["pipeline_pag_sd_animatediff"] = ["AnimateDiffPAGPipeline"]
_import_structure["pipeline_pag_sd_img2img"] = ["StableDiffusionPAGImg2ImgPipeline"]
_import_structure["pipeline_pag_sd_xl"] = ["StableDiffusionXLPAGPipeline"]
_import_structure["pipeline_pag_sd_xl_img2img"] = ["StableDiffusionXLPAGImg2ImgPipeline"]
_import_structure["pipeline_pag_sd_xl_inpaint"] = ["StableDiffusionXLPAGInpaintPipeline"]
Expand All @@ -54,6 +55,7 @@
from .pipeline_pag_sd import StableDiffusionPAGPipeline
from .pipeline_pag_sd_3 import StableDiffusion3PAGPipeline
from .pipeline_pag_sd_animatediff import AnimateDiffPAGPipeline
from .pipeline_pag_sd_img2img import StableDiffusionPAGImg2ImgPipeline
from .pipeline_pag_sd_xl import StableDiffusionXLPAGPipeline
from .pipeline_pag_sd_xl_img2img import StableDiffusionXLPAGImg2ImgPipeline
from .pipeline_pag_sd_xl_inpaint import StableDiffusionXLPAGInpaintPipeline
Expand Down
1,091 changes: 1,091 additions & 0 deletions src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py

Large diffs are not rendered by default.

15 changes: 15 additions & 0 deletions src/diffusers/utils/dummy_torch_and_transformers_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -1592,6 +1592,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])


class StableDiffusionPAGImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])


class StableDiffusionPAGPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

Expand Down
282 changes: 282 additions & 0 deletions tests/pipelines/pag/test_pag_sd_img2img.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import inspect
import random
import unittest

import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

from diffusers import (
AutoencoderKL,
AutoencoderTiny,
AutoPipelineForImage2Image,
EulerDiscreteScheduler,
StableDiffusionImg2ImgPipeline,
StableDiffusionPAGImg2ImgPipeline,
UNet2DConditionModel,
)
from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
load_image,
require_torch_gpu,
slow,
torch_device,
)

from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
)
from ..test_pipelines_common import (
IPAdapterTesterMixin,
PipelineKarrasSchedulerTesterMixin,
PipelineLatentTesterMixin,
PipelineTesterMixin,
)


enable_full_determinism()


class StableDiffusionPAGImg2ImgPipelineFastTests(
IPAdapterTesterMixin,
PipelineLatentTesterMixin,
PipelineKarrasSchedulerTesterMixin,
PipelineTesterMixin,
unittest.TestCase,
):
pipeline_class = StableDiffusionPAGImg2ImgPipeline
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS.union({"pag_scale", "pag_adaptive_scale"}) - {"height", "width"}
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS

def get_dummy_components(self, time_cond_proj_dim=None):
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(32, 64),
layers_per_block=2,
time_cond_proj_dim=time_cond_proj_dim,
sample_size=32,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=32,
)
scheduler = EulerDiscreteScheduler(
beta_start=0.00085,
beta_end=0.012,
steps_offset=1,
beta_schedule="scaled_linear",
timestep_spacing="leading",
)
torch.manual_seed(0)
vae = AutoencoderKL(
block_out_channels=[32, 64],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
sample_size=128,
)
text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
)
text_encoder = CLIPTextModel(text_encoder_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

components = {
"unet": unet,
"scheduler": scheduler,
"vae": vae,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
"image_encoder": None,
}
return components

def get_dummy_tiny_autoencoder(self):
return AutoencoderTiny(in_channels=3, out_channels=3, latent_channels=4)

def get_dummy_inputs(self, device, seed=0):
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
image = image / 2 + 0.5
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"image": image,
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 6.0,
"pag_scale": 0.9,
"output_type": "np",
}
return inputs

def test_pag_disable_enable(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()

# base pipeline (expect same output when pag is disabled)
pipe_sd = StableDiffusionImg2ImgPipeline(**components)
pipe_sd = pipe_sd.to(device)
pipe_sd.set_progress_bar_config(disable=None)

inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
assert (
"pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]

# pag disabled with pag_scale=0.0
pipe_pag = self.pipeline_class(**components)
pipe_pag = pipe_pag.to(device)
pipe_pag.set_progress_bar_config(disable=None)

inputs = self.get_dummy_inputs(device)
inputs["pag_scale"] = 0.0
out_pag_disabled = pipe_pag(**inputs).images[0, -3:, -3:, -1]

# pag enabled
pipe_pag = self.pipeline_class(**components, pag_applied_layers=["mid", "up", "down"])
pipe_pag = pipe_pag.to(device)
pipe_pag.set_progress_bar_config(disable=None)

inputs = self.get_dummy_inputs(device)
out_pag_enabled = pipe_pag(**inputs).images[0, -3:, -3:, -1]

assert np.abs(out.flatten() - out_pag_disabled.flatten()).max() < 1e-3
assert np.abs(out.flatten() - out_pag_enabled.flatten()).max() > 1e-3

def test_pag_inference(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()

pipe_pag = self.pipeline_class(**components, pag_applied_layers=["mid", "up", "down"])
pipe_pag = pipe_pag.to(device)
pipe_pag.set_progress_bar_config(disable=None)

inputs = self.get_dummy_inputs(device)
image = pipe_pag(**inputs).images
image_slice = image[0, -3:, -3:, -1]

assert image.shape == (
1,
32,
32,
3,
), f"the shape of the output image should be (1, 32, 32, 3) but got {image.shape}"

expected_slice = np.array(
[0.44203848, 0.49598145, 0.42248967, 0.6707724, 0.5683791, 0.43603387, 0.58316565, 0.60077155, 0.5174199]
)
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3)


@slow
@require_torch_gpu
class StableDiffusionPAGImg2ImgPipelineIntegrationTests(unittest.TestCase):
pipeline_class = StableDiffusionPAGImg2ImgPipeline
repo_id = "Jiali/stable-diffusion-1.5"

def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()

def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()

def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
init_image = load_image(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_img2img/sketch-mountains-input.png"
)
inputs = {
"prompt": "a fantasy landscape, concept art, high resolution",
"image": init_image,
"generator": generator,
"num_inference_steps": 3,
"strength": 0.75,
"guidance_scale": 7.5,
"pag_scale": 3.0,
"output_type": "np",
}
return inputs

def test_pag_cfg(self):
pipeline = AutoPipelineForImage2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
pipeline.enable_model_cpu_offload()
pipeline.set_progress_bar_config(disable=None)

inputs = self.get_inputs(torch_device)
image = pipeline(**inputs).images

image_slice = image[0, -3:, -3:, -1].flatten()
assert image.shape == (1, 512, 512, 3)
print(image_slice.flatten())
expected_slice = np.array(
[0.58251953, 0.5722656, 0.5683594, 0.55029297, 0.52001953, 0.52001953, 0.49951172, 0.45410156, 0.50146484]
)
assert (
np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
), f"output is different from expected, {image_slice.flatten()}"

def test_pag_uncond(self):
pipeline = AutoPipelineForImage2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
pipeline.enable_model_cpu_offload()
pipeline.set_progress_bar_config(disable=None)

inputs = self.get_inputs(torch_device, guidance_scale=0.0)
image = pipeline(**inputs).images

image_slice = image[0, -3:, -3:, -1].flatten()
assert image.shape == (1, 512, 512, 3)
expected_slice = np.array(
[0.5986328, 0.52441406, 0.3972168, 0.4741211, 0.34985352, 0.22705078, 0.4128418, 0.2866211, 0.31713867]
)
print(image_slice.flatten())
assert (
np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
), f"output is different from expected, {image_slice.flatten()}"

0 comments on commit af28ae2

Please sign in to comment.