Skip to content

Commit

Permalink
Cast height, width to int inside prepare latents (#7691)
Browse files Browse the repository at this point in the history
update
  • Loading branch information
DN6 authored Apr 19, 2024
1 parent e567401 commit 90250d9
Show file tree
Hide file tree
Showing 63 changed files with 378 additions and 68 deletions.
7 changes: 6 additions & 1 deletion examples/community/composable_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,12 @@ def check_inputs(self, prompt, height, width, callback_steps):
)

def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if latents is None:
if device.type == "mps":
# randn does not work reproducibly on mps
Expand Down
7 changes: 6 additions & 1 deletion examples/community/gluegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,12 @@ def check_inputs(
)

def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
Expand Down
7 changes: 6 additions & 1 deletion examples/community/instaflow_one_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,12 @@ def check_inputs(
)

def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
Expand Down
7 changes: 6 additions & 1 deletion examples/community/ip_adapter_face_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,12 @@ def check_inputs(
)

def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
Expand Down
7 changes: 6 additions & 1 deletion examples/community/latent_consistency_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,12 @@ def prepare_latents(
latents=None,
generator=None,
):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)

if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
raise ValueError(
Expand Down
7 changes: 6 additions & 1 deletion examples/community/latent_consistency_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,12 @@ def run_safety_checker(self, image, device, dtype):

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
Expand Down
7 changes: 6 additions & 1 deletion examples/community/latent_consistency_txt2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,12 @@ def run_safety_checker(self, image, device, dtype):
return image, has_nsfw_concept

def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if latents is None:
latents = torch.randn(shape, dtype=dtype).to(device)
else:
Expand Down
7 changes: 6 additions & 1 deletion examples/community/lpw_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,12 @@ def prepare_latents(
):
if image is None:
batch_size = batch_size * num_images_per_prompt
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
Expand Down
14 changes: 12 additions & 2 deletions examples/community/lpw_stable_diffusion_xl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1060,7 +1060,12 @@ def prepare_latents(
batch_size *= num_images_per_prompt

if image is None:
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
Expand Down Expand Up @@ -1140,7 +1145,12 @@ def prepare_latents(
return latents

else:
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
Expand Down
7 changes: 6 additions & 1 deletion examples/community/pipeline_demofusion_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,12 @@ def check_inputs(

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
Expand Down
14 changes: 12 additions & 2 deletions examples/community/pipeline_sdxl_style_aligned.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,7 +919,12 @@ def prepare_latents(
batch_size *= num_images_per_prompt

if image is None:
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
Expand Down Expand Up @@ -999,7 +1004,12 @@ def prepare_latents(
return latents

else:
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
Expand Down
7 changes: 6 additions & 1 deletion examples/community/pipeline_stable_diffusion_pag.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,7 +857,12 @@ def check_inputs(
)

def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,12 @@ def check_conditions(

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
Expand Down
7 changes: 6 additions & 1 deletion examples/community/pipeline_stable_diffusion_xl_ipex.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,12 @@ def check_inputs(

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
Expand Down
7 changes: 6 additions & 1 deletion examples/community/pipeline_zero1to3.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,12 @@ def check_inputs(self, image, height, width, callback_steps):
)

def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
Expand Down
7 changes: 6 additions & 1 deletion examples/community/stable_diffusion_controlnet_inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,12 @@ def check_inputs(
)

def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
Expand Down
7 changes: 6 additions & 1 deletion examples/community/stable_diffusion_ipex.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,12 @@ def check_inputs(
)

def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
Expand Down
7 changes: 6 additions & 1 deletion examples/community/stable_diffusion_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,12 @@ def prepare_latents(
Returns:
torch.Tensor: The prepared latent vectors.
"""
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -789,7 +789,12 @@ def prepare_image(

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
Expand Down
7 changes: 6 additions & 1 deletion examples/research_projects/rdm/pipeline_rdm.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,12 @@ def _encode_image(self, retrieved_images, batch_size):
return image_embeddings

def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/pipelines/audioldm/pipeline_audioldm.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,8 +330,8 @@ def prepare_latents(self, batch_size, num_channels_latents, height, dtype, devic
shape = (
batch_size,
num_channels_latents,
height // self.vae_scale_factor,
self.vocoder.config.model_in_dim // self.vae_scale_factor,
int(height) // self.vae_scale_factor,
int(self.vocoder.config.model_in_dim) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,8 +790,8 @@ def prepare_latents(self, batch_size, num_channels_latents, height, dtype, devic
shape = (
batch_size,
num_channels_latents,
height // self.vae_scale_factor,
self.vocoder.config.model_in_dim // self.vae_scale_factor,
int(height) // self.vae_scale_factor,
int(self.vocoder.config.model_in_dim) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
Expand Down
7 changes: 6 additions & 1 deletion src/diffusers/pipelines/controlnet/pipeline_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,7 +807,12 @@ def prepare_image(

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -972,7 +972,12 @@ def prepare_latents(
return_noise=False,
return_image_latents=False,
):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -880,7 +880,12 @@ def prepare_latents(
return_noise=False,
return_image_latents=False,
):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,12 @@ def prepare_image(

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
Expand Down
Loading

0 comments on commit 90250d9

Please sign in to comment.