Skip to content

Commit

Permalink
[Core] is_cosxl_edit arg in SDXL ip2p. (#7650)
Browse files Browse the repository at this point in the history
* is_cosxl_edit arg in SDXL ip2p.

* Empty-Commit

Co-authored-by: Yiyi Xu <yixu310@gmail.com>

* doc

* remove redundant logic.

* reflect drhuv's comments.

---------

Co-authored-by: Yiyi Xu <yixu310@gmail.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
  • Loading branch information
3 people authored Apr 16, 2024
1 parent fda1531 commit 9d50f7e
Showing 1 changed file with 7 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ class StableDiffusionXLInstructPix2PixPipeline(
Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
watermark output images. If not defined, it will default to True if the package is installed, otherwise no
watermarker will be used.
is_cosxl_edit (`bool`, *optional*):
When set the image latents are scaled.
"""

model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
Expand All @@ -185,6 +187,7 @@ def __init__(
scheduler: KarrasDiffusionSchedulers,
force_zeros_for_empty_prompt: bool = True,
add_watermarker: Optional[bool] = None,
is_cosxl_edit: Optional[bool] = False,
):
super().__init__()

Expand All @@ -201,6 +204,7 @@ def __init__(
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = self.unet.config.sample_size
self.is_cosxl_edit = is_cosxl_edit

add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()

Expand Down Expand Up @@ -551,6 +555,9 @@ def prepare_image_latents(
if image_latents.dtype != self.vae.dtype:
image_latents = image_latents.to(dtype=self.vae.dtype)

if self.is_cosxl_edit:
image_latents = image_latents * self.vae.config.scaling_factor

return image_latents

# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids
Expand Down

0 comments on commit 9d50f7e

Please sign in to comment.