Skip to content

Commit

Permalink
improve: support fp16 in stable diffusion
Browse files Browse the repository at this point in the history
  • Loading branch information
samedii committed Aug 29, 2022
1 parent 4992292 commit e0bdd38
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions perceptor/models/stable_diffusion/stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@


class Model(torch.nn.Module):
def __init__(self, name="CompVis/stable-diffusion-v1-4", auth_token=True):
# TODO: fp16
def __init__(
self, name="CompVis/stable-diffusion-v1-4", fp16=False, auth_token=True
):
"""
Args:
name: The name of the model. Available models are:
Expand All @@ -28,8 +29,17 @@ def __init__(self, name="CompVis/stable-diffusion-v1-4", auth_token=True):
scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
)

pipeline = StableDiffusionPipeline.from_pretrained(
name, scheduler=scheduler, use_auth_token=auth_token
name,
scheduler=scheduler,
use_auth_token=auth_token,
**dict(
revision="fp16",
torch_dtype=torch.float16,
)
if fp16
else dict(),
)

self.vae = pipeline.vae
Expand Down Expand Up @@ -134,15 +144,17 @@ def finetuneable_vae(self):
self.vae.load_state_dict(state_dict)
self.vae.requires_grad_(False)

@torch.cuda.amp.autocast()
def latents(
self, images: lantern.Tensor.dims("NCHW").float()
) -> lantern.Tensor.dims("NCHW"):
return self.encode(images)
return self.encode(images).float()

@torch.cuda.amp.autocast()
def images(
self, latents: lantern.Tensor.dims("NCHW").float()
) -> lantern.Tensor.dims("NCHW"):
return self.decode(latents)
return self.decode(latents).float()

def random_diffused_latents(self, shape):
n, c, h, w = shape
Expand Down

0 comments on commit e0bdd38

Please sign in to comment.