Skip to content

Commit

Permalink
[Core] fix img2img pipeline for Playground (#7627)
Browse files Browse the repository at this point in the history
* Playground VAE encoding should use the std and mean of the VAE.

* style.

* fix-copies.
  • Loading branch information
sayakpaul authored Apr 11, 2024
1 parent aa1f00f commit 33c5d12
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -898,6 +898,12 @@ def prepare_latents(
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
)

latents_mean = latents_std = None
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)

# Offload text encoder if `enable_model_cpu_offload` was enabled
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.text_encoder_2.to("cpu")
Expand Down Expand Up @@ -935,7 +941,12 @@ def prepare_latents(
self.vae.to(dtype)

init_latents = init_latents.to(dtype)
init_latents = self.vae.config.scaling_factor * init_latents
if latents_mean is not None and latents_std is not None:
latents_mean = latents_mean.to(device=self.device, dtype=dtype)
latents_std = latents_std.to(device=self.device, dtype=dtype)
init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std
else:
init_latents = self.vae.config.scaling_factor * init_latents

if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
# expand init_latents for batch_size
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,12 @@ def prepare_latents(
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
)

latents_mean = latents_std = None
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)

# Offload text encoder if `enable_model_cpu_offload` was enabled
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.text_encoder_2.to("cpu")
Expand Down Expand Up @@ -702,7 +708,12 @@ def prepare_latents(
self.vae.to(dtype)

init_latents = init_latents.to(dtype)
init_latents = self.vae.config.scaling_factor * init_latents
if latents_mean is not None and latents_std is not None:
latents_mean = latents_mean.to(device=self.device, dtype=dtype)
latents_std = latents_std.to(device=self.device, dtype=dtype)
init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std
else:
init_latents = self.vae.config.scaling_factor * init_latents

if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
# expand init_latents for batch_size
Expand Down

0 comments on commit 33c5d12

Please sign in to comment.