diff --git a/perceptor/models/guided_diffusion/create_models.py b/perceptor/models/guided_diffusion/create_models.py new file mode 100644 index 0000000..d90ac21 --- /dev/null +++ b/perceptor/models/guided_diffusion/create_models.py @@ -0,0 +1,119 @@ +from .unet import UNetModel +from .script_util import ( + create_model_and_diffusion, + model_and_diffusion_defaults, +) + + +def create_openimages_model(): + model_config = model_and_diffusion_defaults() + model_config.update( + { + "attention_resolutions": "32, 16, 8", + "class_cond": False, + # 'diffusion_steps': 1000, #No need to edit this, it is taken care of later. + # 'rescale_timesteps': True, + # 'timestep_respacing': 250, #No need to edit this, it is taken care of later. + "image_size": 512, + "learn_sigma": True, + "noise_schedule": "linear", + "num_channels": 256, + "num_head_channels": 64, + "num_res_blocks": 2, + "resblock_updown": True, + "use_checkpoint": True, + "use_fp16": True, + "use_scale_shift_norm": True, + } + ) + + model, diffusion = create_model_and_diffusion(**model_config) + if model_config["use_fp16"]: + model.convert_to_fp16() + return model, diffusion + + +def create_pixelart_model(): + model_config = model_and_diffusion_defaults() + model_config.update( + dict( + image_size=256, + learn_sigma=True, + num_channels=128, + num_res_blocks=2, + num_heads=1, + num_heads_upsample=-1, + num_head_channels=-1, + attention_resolutions="16", + channel_mult="", + dropout=0.0, + class_cond=False, + use_checkpoint=False, + use_scale_shift_norm=False, + resblock_updown=False, + use_fp16=True, + use_new_attention_order=False, + ) + ) + + model, diffusion = create_model_and_diffusion(**model_config) + if model_config["use_fp16"]: + model.convert_to_fp16() + return model, diffusion + + +def create_model( + image_size, + num_channels, + num_res_blocks, + channel_mult="", + learn_sigma=False, + class_cond=False, + use_checkpoint=False, + attention_resolutions="16", + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + dropout=0, + resblock_updown=False, + use_fp16=False, + use_new_attention_order=False, +): + if channel_mult == "": + if image_size == 512: + channel_mult = (0.5, 1, 1, 2, 2, 4, 4) + elif image_size == 256: + channel_mult = (1, 1, 2, 2, 4, 4) + elif image_size == 128: + channel_mult = (1, 1, 2, 3, 4) + elif image_size == 64: + channel_mult = (1, 2, 3, 4) + else: + raise ValueError(f"unsupported image size: {image_size}") + else: + channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(",")) + + attention_ds = [] + for res in attention_resolutions.split(","): + attention_ds.append(image_size // int(res)) + + return UNetModel( + image_size=image_size, + in_channels=3, + model_channels=num_channels, + out_channels=(3 if not learn_sigma else 6), + num_res_blocks=num_res_blocks, + attention_resolutions=tuple(attention_ds), + dropout=dropout, + channel_mult=channel_mult, + num_classes=None, + use_checkpoint=use_checkpoint, + use_fp16=use_fp16, + num_heads=num_heads, + num_head_channels=num_head_channels, + num_heads_upsample=num_heads_upsample, + use_scale_shift_norm=use_scale_shift_norm, + resblock_updown=resblock_updown, + use_new_attention_order=use_new_attention_order, + ) diff --git a/perceptor/models/guided_diffusion/guided_diffusion.py b/perceptor/models/guided_diffusion/guided_diffusion.py index 36f6461..a57783c 100644 --- a/perceptor/models/guided_diffusion/guided_diffusion.py +++ b/perceptor/models/guided_diffusion/guided_diffusion.py @@ -1,40 +1,28 @@ +from 
typing import Optional +from contextlib import contextmanager +import copy import torch -from torch import nn +import lantern +from transformers import CLIPTokenizer, CLIPTextModel, logging from basicsr.utils.download_util import load_file_from_url from perceptor.utils import cache from . import diffusion_space -from .unet import UNetModel -from .script_util import ( - create_model_and_diffusion, - model_and_diffusion_defaults, -) +from .create_models import create_openimages_model, create_pixelart_model +from .predictions import Predictions -@cache -class GuidedDiffusion(nn.Module): +class Model(torch.nn.Module): def __init__(self, name="standard"): """ Args: - name: The name of the model. - - Usage: - - diffusion = models.GuidedDiffusion("pixelart").to(device) - - diffused_image = torch.randn((1, 3, 256, 256)) - - for from_index, to_index in model.schedule_indices(): - eps = diffusion.eps(diffused_image, from_index) - denoised_image = diffusion.denoise(diffused_image, from_index, eps) - diffused_image = diffusion.step(diffused_image, eps, from_index, to_index) - denoised_image = diffusion.denoise(diffused_image, to_index) + name: The name of the model. Available models are "standard" and "pixelart" """ super().__init__() self.name = name if name == "standard": - self.model, self.diffusion = create_openimages_model() + self.model, self.scheduler = create_openimages_model() checkpoint_path = load_file_from_url( "https://huggingface.co/lowlevelware/512x512_diffusion_unconditional_ImageNet/resolve/main/512x512_diffusion_uncond_finetune_008100.pt", # alternative: "https://set.zlkj.in/models/diffusion/512x512_diffusion_uncond_openimages_epoch28_withfilter.pt", @@ -42,7 +30,7 @@ def __init__(self, name="standard"): ) self.shape = (3, 512, 512) elif name == "pixelart": - self.model, self.diffusion = create_pixelart_model() + self.model, self.scheduler = create_pixelart_model() checkpoint_path = load_file_from_url( "https://huggingface.co/KaliYuga/PADexpanded/resolve/main/PADexpanded.pt", "models", @@ -54,254 +42,204 @@ def __init__(self, name="standard"): self.model.load_state_dict(torch.load(checkpoint_path, map_location="cpu")) self.model.requires_grad_(False).eval() + self.schedule_alphas = torch.nn.Parameter( + torch.from_numpy(self.scheduler.alphas_cumprod).sqrt().float(), + requires_grad=False, + ) + + self.schedule_sigmas = torch.nn.Parameter( + (1 - torch.from_numpy(self.scheduler.alphas_cumprod)).sqrt().float(), + requires_grad=False, + ) + @property def device(self): return next(iter(self.parameters())).device - def schedule_indices(self, from_index=999, to_index=20, n_steps=None): + def schedule_indices(self, n_steps=500, from_index=999, to_index=0, rho=7.0): if from_index < to_index: raise ValueError("from_index must be greater than to_index") - if n_steps is None: - n_steps = (from_index - to_index) // 2 - schedule_indices = torch.linspace(from_index, to_index, n_steps).long() - from_indices = schedule_indices[:-1] - to_indices = schedule_indices[1:] - if (from_indices == to_indices).any(): - raise ValueError("Schedule indices must be unique") - return torch.stack([from_indices, to_indices], dim=1) - def random_diffused(self, shape): - return diffusion_space.decode(torch.randn(shape)).to(self.device) + from_alpha, from_sigma = self.alphas(from_index), self.sigmas(from_index) + to_alpha, to_sigma = self.alphas(to_index), self.sigmas(to_index) - def forward(self, diffused, from_index): - return self.denoise(diffused, from_index) + from_log_snr = torch.log(from_alpha**2 / 
from_sigma**2) + to_log_snr = torch.log(to_alpha**2 / to_sigma**2) - def denoise(self, diffused, from_index, eps=None): - x = diffusion_space.encode(diffused) - if isinstance(from_index, int) or from_index.ndim == 0: - from_index = torch.full((x.shape[0],), from_index).to(x.device) + elucidated_from_sigma = (1 / from_log_snr.exp()).sqrt().clamp(max=150) + elucidated_to_sigma = (1 / to_log_snr.exp()).sqrt().clamp(min=1e-3) - if eps is None: - eps = self.eps(diffused, from_index) + ramp = torch.linspace(0, 1, n_steps + 1).to(self.device) + min_inv_rho = elucidated_to_sigma ** (1 / rho) + max_inv_rho = elucidated_from_sigma ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + target_log_snr = torch.log(torch.ones_like(sigmas) ** 2 / sigmas**2) - return diffusion_space.decode( - self.diffusion._predict_xstart_from_eps(x, from_index, eps) + schedule_log_snr = torch.log( + self.schedule_alphas**2 / self.schedule_sigmas**2 ) - def diffuse(self, images, to_index, noise=None): - x0 = diffusion_space.encode(images) - if isinstance(to_index, int) or to_index.ndim == 0: - to_index = torch.full((x0.shape[0],), to_index).to(x0.device) - if noise is None: - noise = torch.randn_like(x0) - assert noise.shape == x0.shape - return diffusion_space.decode(self.diffusion.q_sample(x0, to_index, noise)) + assert target_log_snr.squeeze().ndim == 1 + assert schedule_log_snr.squeeze().ndim == 1 + schedule_indices = ( + (target_log_snr.squeeze()[:, None] - schedule_log_snr.squeeze()[None, :]) + .abs() + .argmin(dim=1) + .unique() + .sort(descending=True)[0] + ) - def eps(self, diffused, from_index): - x = diffusion_space.encode(diffused) - if isinstance(from_index, int) or from_index.ndim == 0: - from_index = torch.full((x.shape[0],), from_index).to(x.device) - return self.model(x, from_index)[:, :3] + assert len(schedule_indices) >= n_steps * 0.9 - def guided_eps(self, eps, grad, from_index, guidance_scale=1000): - """ - Guided eps by differentiating through the diffusion model. 
+        assert (schedule_indices[:-1] != schedule_indices[1:]).all()
+        return torch.stack([schedule_indices[:-1], schedule_indices[1:]], dim=1)
-        Usage:
-
-            diffused_image.requires_grad_(True)
-            with torch.enable_grad():
-                eps = diffusion.eps(diffused_image, from_index)
-                denoised_image = diffusion.denoise(diffused_image, from_index, eps)
+    def random_diffused(self, shape):
+        n, c, h, w = shape
+        if h % 8 != 0:
+            raise ValueError("Height must be divisible by 8")
+        if w % 8 != 0:
+            raise ValueError("Width must be divisible by 8")
+        return diffusion_space.decode(torch.randn(shape).to(self.device))
+
+    def indices(self, indices):
+        if isinstance(indices, float) or isinstance(indices, int):
+            indices = torch.as_tensor(indices)
+        if indices.ndim == 0:
+            indices = indices[None]
+        if indices.ndim != 1:
+            raise ValueError("indices must be a scalar or a 1-dimensional tensor")
+        return indices.long().to(self.device)
+
+    def alphas(self, indices):
+        return self.schedule_alphas[self.indices(indices)][:, None, None, None].to(
+            self.device
+        )
-                augmentations_ = torch.cat([augmentations(denoised_image) for _ in range(4)])
-                torch.stack(
-                    [
-                        (text_loss(augmentations_) / len(text_losses))
-                        for text_loss in text_losses
-                    ]
-                ).mean().backward()
+    def sigmas(self, indices):
+        return self.schedule_sigmas[self.indices(indices)][:, None, None, None].to(
+            self.device
+        )
-            eps = diffusion.guided_eps(eps, diffused_image.grad, from_index)
-        """
-        return (
-            eps + guidance_scale * (1 - self.alphas_cumprod(from_index)).sqrt() * grad
+    @torch.cuda.amp.autocast()
+    def predicted_noise(
+        self,
+        diffused_images,
+        from_indices,
+    ):
+        return self.model(
+            diffusion_space.encode(diffused_images), self.indices(from_indices)
+        )[:, :3].float()
+
+    def predictions(self, diffused_images, indices):
+        indices = self.indices(indices)
+        return Predictions(
+            from_diffused_images=diffused_images,
+            from_indices=indices,
+            predicted_noise=self.predicted_noise(diffused_images, indices),
+            schedule_alphas=self.schedule_alphas,
+            schedule_sigmas=self.schedule_sigmas,
         )
-    def ts(self, index):
-        return torch.tensor([index], device=self.device, dtype=torch.float32)
+    def forward(self, diffused_images, indices):
+        return self.predictions(diffused_images, indices)
-    def alphas_cumprod(self, index):
-        return (
-            torch.as_tensor(self.diffusion.alphas_cumprod[index])
-            .float()
-            .to(self.device)[None, None, None, None]
+    def diffuse_images(self, denoised_images, indices, noise=None):
+        indices = self.indices(indices)
+        if noise is None:
+            noise = torch.randn_like(denoised_images)
+        alphas, sigmas = self.alphas(indices), self.sigmas(indices)
+        return diffusion_space.decode(
+            diffusion_space.encode(denoised_images) * alphas + noise * sigmas
         )
-    def sqrt_one_minus_alphas_cumprod(self, index):
-        return (
-            torch.as_tensor(self.diffusion.sqrt_one_minus_alphas_cumprod[index])
-            .float()
-            .to(self.device)[None, None, None, None]
-        )
-    def step(self, from_diffused, eps, from_index, to_index, noise=None, eta=1.0):
-        if to_index > from_index:
-            raise ValueError("to_index must be smaller than from_index")
-        if noise is None:
-            noise = torch.randn_like(eps)
+GuidedDiffusion: Model = cache(Model)
-        pred = diffusion_space.encode(self.denoise(from_diffused, from_index, eps))
-        from_alphas_cumprod = self.alphas_cumprod(from_index)
-        to_alphas_cumprod = self.alphas_cumprod(to_index)
+def test_guided_diffusion():
+    from tqdm import tqdm
+    from perceptor import utils
-        to_sigmas = eta * torch.sqrt(
-            (1 - to_alphas_cumprod)
-            / (1 - from_alphas_cumprod)
- * (1 - from_alphas_cumprod / to_alphas_cumprod) - ) + torch.set_grad_enabled(False) + device = torch.device("cuda") - dir_xt = (1.0 - to_alphas_cumprod - to_sigmas**2).sqrt() * eps - to_x = to_alphas_cumprod.sqrt() * pred + dir_xt + to_sigmas * noise - return diffusion_space.decode(to_x) - - -def create_openimages_model(): - model_config = model_and_diffusion_defaults() - model_config.update( - { - "attention_resolutions": "32, 16, 8", - "class_cond": False, - # 'diffusion_steps': 1000, #No need to edit this, it is taken care of later. - # 'rescale_timesteps': True, - # 'timestep_respacing': 250, #No need to edit this, it is taken care of later. - "image_size": 512, - "learn_sigma": True, - "noise_schedule": "linear", - "num_channels": 256, - "num_head_channels": 64, - "num_res_blocks": 2, - "resblock_updown": True, - "use_checkpoint": True, - "use_fp16": True, - "use_scale_shift_norm": True, - } + diffusion_model = GuidedDiffusion().to(device) + diffused_images = diffusion_model.random_diffused((1, 3, 512, 512)).to(device) + + progress_bar = tqdm( + diffusion_model.schedule_indices(to_index=0, n_steps=50, rho=3.0) ) + for from_indices, to_indices in progress_bar: + step_predictions = diffusion_model.predictions( + diffused_images, + from_indices, + ) + diffused_images = step_predictions.step(to_indices) - model, diffusion = create_model_and_diffusion(**model_config) - if model_config["use_fp16"]: - model.convert_to_fp16() - return model, diffusion - - -def create_pixelart_model(): - model_config = model_and_diffusion_defaults() - model_config.update( - dict( - image_size=256, - learn_sigma=True, - num_channels=128, - num_res_blocks=2, - num_heads=1, - num_heads_upsample=-1, - num_head_channels=-1, - attention_resolutions="16", - channel_mult="", - dropout=0.0, - class_cond=False, - use_checkpoint=False, - use_scale_shift_norm=False, - resblock_updown=False, - use_fp16=True, - use_new_attention_order=False, + utils.pil_image(step_predictions.denoised_images.clamp(0, 1)).save( + "tests/guided_diffusion.png" + ) + + progress_bar.set_postfix( + dict( + from_indices=from_indices.item(), + to_indices=to_indices.item(), + ) ) + + predictions = diffusion_model.predictions( + diffused_images, + to_indices, ) - model, diffusion = create_model_and_diffusion(**model_config) - if model_config["use_fp16"]: - model.convert_to_fp16() - return model, diffusion - - -def create_model( - image_size, - num_channels, - num_res_blocks, - channel_mult="", - learn_sigma=False, - class_cond=False, - use_checkpoint=False, - attention_resolutions="16", - num_heads=1, - num_head_channels=-1, - num_heads_upsample=-1, - use_scale_shift_norm=False, - dropout=0, - resblock_updown=False, - use_fp16=False, - use_new_attention_order=False, -): - if channel_mult == "": - if image_size == 512: - channel_mult = (0.5, 1, 1, 2, 2, 4, 4) - elif image_size == 256: - channel_mult = (1, 1, 2, 2, 4, 4) - elif image_size == 128: - channel_mult = (1, 1, 2, 3, 4) - elif image_size == 64: - channel_mult = (1, 2, 3, 4) - else: - raise ValueError(f"unsupported image size: {image_size}") - else: - channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(",")) - - attention_ds = [] - for res in attention_resolutions.split(","): - attention_ds.append(image_size // int(res)) - - return UNetModel( - image_size=image_size, - in_channels=3, - model_channels=num_channels, - out_channels=(3 if not learn_sigma else 6), - num_res_blocks=num_res_blocks, - attention_resolutions=tuple(attention_ds), - dropout=dropout, - 
channel_mult=channel_mult, - num_classes=None, - use_checkpoint=use_checkpoint, - use_fp16=use_fp16, - num_heads=num_heads, - num_head_channels=num_head_channels, - num_heads_upsample=num_heads_upsample, - use_scale_shift_norm=use_scale_shift_norm, - resblock_updown=resblock_updown, - use_new_attention_order=use_new_attention_order, + utils.pil_image(predictions.denoised_images.clamp(0, 1)).save( + "tests/guided_diffusion.png" ) -def test_pixelart_diffusion(): +def test_guided_diffusion_init_image(): + import requests + from PIL import Image + import torch + import torchvision.transforms.functional as TF + from tqdm import tqdm from perceptor import utils - model = GuidedDiffusion("pixelart").cuda() - diffused = model.random_diffused((1, 3, 256, 256)) + image_url = "http://images.cocodataset.org/val2017/000000039769.jpg" + init_image = TF.to_tensor( + Image.open(requests.get(image_url, stream=True).raw).resize((512, 512)) + )[None].cuda() - for from_index, to_index in model.schedule_indices(n_steps=50): - eps = model.eps(diffused, from_index) - diffused = model.step(diffused, eps, from_index, to_index) - denoised = model.denoise(diffused, to_index) - utils.pil_image(denoised).save("tests/pixelart.png") + torch.set_grad_enabled(False) + device = torch.device("cuda") + from_index = 400 -def test_guided_diffusion_diffusion(): - from perceptor import utils + diffusion_model = GuidedDiffusion().to(device) + diffused_images = diffusion_model.diffuse_images(init_image, from_index) - model = GuidedDiffusion("standard").cuda() - diffused = model.random_diffused((1, 3, 512, 512)) + for from_indices, to_indices in tqdm( + diffusion_model.schedule_indices(from_index=from_index, to_index=0, n_steps=50) + ): + for _ in range(4): + predictions = diffusion_model.predictions( + diffused_images, + from_indices, + ) + diffused_images = predictions.resample(to_indices) - for from_index, to_index in model.schedule_indices(n_steps=50): - eps = model.eps(diffused, from_index) - diffused = model.step(diffused, eps, from_index, to_index) - denoised = model.denoise(diffused, to_index) - utils.pil_image(denoised).save("tests/guided_diffusion.png") + predictions = diffusion_model.predictions( + diffused_images, + from_indices, + ) + diffused_images = predictions.step(to_indices) + + predictions = diffusion_model.predictions( + diffused_images, + to_indices, + ) + + utils.pil_image(predictions.denoised_images.clamp(0, 1)).save( + "tests/guided_diffusion_init_image.png" + ) diff --git a/perceptor/models/guided_diffusion/predictions.py b/perceptor/models/guided_diffusion/predictions.py new file mode 100644 index 0000000..b5b6ca6 --- /dev/null +++ b/perceptor/models/guided_diffusion/predictions.py @@ -0,0 +1,200 @@ +import torch +import lantern + +from perceptor.transforms.clamp_with_grad import clamp_with_grad +from . 
import diffusion_space
+
+
+class Predictions(lantern.FunctionalBase):
+    from_diffused_images: lantern.Tensor.dims("NCHW")
+    from_indices: lantern.Tensor.dims("N")
+    predicted_noise: lantern.Tensor.dims("NCHW")
+    schedule_alphas: lantern.Tensor.dims("N")
+    schedule_sigmas: lantern.Tensor.dims("N")
+
+    @property
+    def device(self):
+        return self.predicted_noise.device
+
+    def indices(self, indices):
+        if isinstance(indices, float) or isinstance(indices, int):
+            indices = torch.as_tensor(indices)
+        if indices.ndim == 0:
+            indices = indices[None]
+        if indices.ndim != 1:
+            raise ValueError("indices must be a scalar or a 1D tensor")
+        return indices.long().to(self.device)
+
+    def alphas(self, indices):
+        return self.schedule_alphas[self.indices(indices)][:, None, None, None].to(
+            self.device
+        )
+
+    def sigmas(self, indices):
+        return self.schedule_sigmas[self.indices(indices)][:, None, None, None].to(
+            self.device
+        )
+
+    @property
+    def from_alphas(self):
+        return self.alphas(self.from_indices)
+
+    @property
+    def from_sigmas(self):
+        return self.sigmas(self.from_indices)
+
+    @property
+    def from_diffused_xs(self):
+        return diffusion_space.encode(self.from_diffused_images)
+
+    @property
+    def denoised_xs(self):
+        return (
+            self.from_diffused_xs - self.from_sigmas * self.predicted_noise
+        ) / self.from_alphas.clamp(min=1e-7)
+
+    @property
+    def denoised_images(self):
+        return diffusion_space.decode(self.denoised_xs)
+
+    def step(self, to_indices, eta=0.0):
+        """
+        Reduce noise level to `to_indices`
+
+        Args:
+            to_indices: Union[Tensor, Tensor.shape("N"), float]
+            eta: float
+
+        Returns:
+            diffused_images: torch.Tensor.shape("NCHW")
+        """
+        to_alphas, to_sigmas = self.alphas(to_indices), self.sigmas(to_indices)
+
+        if eta > 0.0:
+            # If eta > 0, adjust the scaling factor for the predicted noise
+            # downward according to the amount of additional noise to add
+            ddim_sigma = (
+                eta
+                * (to_sigmas**2 / self.from_sigmas**2).sqrt()
+                * (1 - self.from_alphas**2 / to_alphas**2).sqrt()
+            )
+            adjusted_sigma = (to_sigmas**2 - ddim_sigma**2).sqrt()
+
+            # Recombine the predicted noise and predicted denoised image in the
+            # correct proportions for the next step
+            to_diffused_xs = (
+                self.denoised_xs * to_alphas + self.predicted_noise * adjusted_sigma
+            )
+
+            # Add the correct amount of fresh noise
+            noise = torch.randn_like(to_diffused_xs)
+            to_diffused_xs += noise * ddim_sigma
+        else:
+            to_diffused_xs = (
+                self.denoised_xs * to_alphas + self.predicted_noise * to_sigmas
+            )
+
+        return diffusion_space.decode(to_diffused_xs)
+        # TODO: do not need to calculate denoised images? this could introduce errors?
+
+    def correction(self, previous: "Predictions") -> "Predictions":
+        return previous.forced_denoised_images(
+            (self.denoised_images + previous.denoised_images) / 2
+        )
+
+    def reverse_step(self, to_indices):
+        if (torch.as_tensor(self.from_indices) > torch.as_tensor(to_indices)).any():
+            raise ValueError("from_indices must be less than or equal to to_indices")
+
+        to_alphas, to_sigmas = self.alphas(to_indices), self.sigmas(to_indices)
+        return diffusion_space.decode(
+            self.denoised_xs * to_alphas + self.predicted_noise * to_sigmas
+        )
+        # TODO: do not need to calculate denoised images? this could introduce errors?
+ + def resample(self, resample_indices): + """ + Harmonizing resampling from https://github.com/andreas128/RePaint + """ + return diffusion_space.decode( + self.denoised_xs * self.from_alphas + + self.resample_noise(resample_indices) * self.from_sigmas + ) + + def resample_noise(self, resample_indices): + if ( + torch.as_tensor(self.from_indices) < torch.as_tensor(resample_indices) + ).any(): + raise ValueError("from_indices must be greater than resample_indices") + resampled_noise_sigma = ( + self.sigmas(resample_indices) * self.predicted_noise + + ( + self.from_sigmas**2 - self.sigmas(resample_indices) ** 2 + ).sqrt() * torch.randn_like(self.predicted_noise) + ) # fmt: skip + return resampled_noise_sigma / self.from_sigmas + + def noisy_reverse_step(self, to_indices): + to_alphas, to_sigmas = self.alphas(to_indices), self.sigmas(to_indices) + + noise_sigma = self.from_sigmas * self.predicted_noise + ( + to_sigmas**2 - self.from_sigmas**2 + ).sqrt() * torch.randn_like(self.predicted_noise) + + return diffusion_space.decode(self.denoised_xs * to_alphas + noise_sigma) + + def guided(self, guiding, guidance_scale=0.5, clamp_value=1e-6) -> "Predictions": + return self.replace( + predicted_noise=self.predicted_noise + + guidance_scale + * self.from_sigmas + * guiding.clamp(-clamp_value, clamp_value) + / clamp_value + ) + + def dynamic_threshold(self, quantile=0.95) -> "Predictions": + """ + Thresholding heuristic from imagen paper + """ + dynamic_threshold = torch.quantile( + self.denoised_xs.flatten(start_dim=1).abs(), quantile, dim=1 + ).clamp(min=1.0) + denoised_xs = ( + clamp_with_grad( + self.denoised_xs, + -dynamic_threshold, + dynamic_threshold, + ) + # / dynamic_threshold + # imagen's dynamic thresholding divides by threshold but this makes the images gray + ) + return self.forced_denoised_images(diffusion_space.decode(denoised_xs)) + + def forced_denoised_images(self, denoised_images) -> "Predictions": + denoised_xs = diffusion_space.encode(denoised_images) + if (self.from_sigmas >= 1e-3).all(): + predicted_noise = ( + self.from_diffused_xs - denoised_xs * self.from_alphas + ) / self.from_sigmas + else: + predicted_noise = self.predicted_noise + return self.replace(predicted_noise=predicted_noise) + + def forced_predicted_noise(self, predicted_noise) -> "Predictions": + return self.replace(predicted_noise=predicted_noise) + + def wasserstein_distance(self): + sorted_noise = self.predicted_noise.flatten(start_dim=1).sort(dim=1)[0] + n = sorted_noise.shape[1] + margin = 0.5 / n + points = torch.linspace(margin, 1 - margin, sorted_noise.shape[1]) + expected_noise = torch.distributions.Normal(0, 1).icdf(points) + return (sorted_noise - expected_noise[None].to(sorted_noise)).abs().mean() + + def wasserstein_square_distance(self): + sorted_noise = self.predicted_noise.flatten(start_dim=1).sort(dim=1)[0] + n = sorted_noise.shape[1] + margin = 0.5 / n + points = torch.linspace(margin, 1 - margin, sorted_noise.shape[1]) + expected_noise = torch.distributions.Normal(0, 1).icdf(points) + return (sorted_noise - expected_noise[None].to(sorted_noise)).square().mean() diff --git a/perceptor/models/stable_diffusion/stable_diffusion.py b/perceptor/models/stable_diffusion/stable_diffusion.py index 74a8d57..9d07f80 100644 --- a/perceptor/models/stable_diffusion/stable_diffusion.py +++ b/perceptor/models/stable_diffusion/stable_diffusion.py @@ -211,8 +211,8 @@ def forward(self, diffused_latents, indices, conditioning=None): schedule_sigmas=self.schedule_sigmas, ) - def predictions(self, 
diffused_latents, ts, conditioning=None):
-        return self.forward(diffused_latents, ts, conditioning)
+    def predictions(self, diffused_latents, indices, conditioning=None):
+        return self.forward(diffused_latents, indices, conditioning)
 
     def conditioning(self, texts=None, images=None, encodings=None):
         """
diff --git a/perceptor/models/vgg/get_features.py b/perceptor/models/vgg/get_features.py
new file mode 100644
index 0000000..61b83f9
--- /dev/null
+++ b/perceptor/models/vgg/get_features.py
@@ -0,0 +1,20 @@
+def get_features(image, model, layers=None):
+
+    if layers is None:
+        layers = {
+            "0": "conv1_1",
+            "5": "conv2_1",
+            "10": "conv3_1",
+            "19": "conv4_1",
+            "21": "conv4_2",
+            "28": "conv5_1",
+            "30": "conv5_2",
+        }
+    features = {}
+    x = image
+    for name, layer in model._modules.items():
+        x = layer(x)
+        if name in layers:
+            features[layers[name]] = x
+
+    return features
diff --git a/perceptor/models/vgg/load_image2.py b/perceptor/models/vgg/load_image2.py
new file mode 100644
index 0000000..0ac4618
--- /dev/null
+++ b/perceptor/models/vgg/load_image2.py
@@ -0,0 +1,21 @@
+from PIL import Image
+from torchvision import transforms
+
+
+def load_image2(img_path, img_height=None, img_width=None):
+
+    image = Image.open(img_path)
+    if img_width is not None:
+        image = image.resize(
+            (img_width, img_height)
+        )  # resize to (img_width, img_height)
+
+    transform = transforms.Compose(
+        [
+            transforms.ToTensor(),
+        ]
+    )
+
+    image = transform(image)[:3, :, :].unsqueeze(0)
+
+    return image
diff --git a/perceptor/models/vgg/normalize.py b/perceptor/models/vgg/normalize.py
new file mode 100644
index 0000000..71d8a77
--- /dev/null
+++ b/perceptor/models/vgg/normalize.py
@@ -0,0 +1,11 @@
+import torch
+
+
+def normalize(image):
+    mean = torch.tensor([0.485, 0.456, 0.406]).to(image.device)
+    std = torch.tensor([0.229, 0.224, 0.225]).to(image.device)
+    mean = mean.view(1, -1, 1, 1)
+    std = std.view(1, -1, 1, 1)
+
+    image = (image - mean) / std
+    return image
diff --git a/perceptor/models/vgg/vgg.py b/perceptor/models/vgg/vgg.py
new file mode 100644
index 0000000..3dde554
--- /dev/null
+++ b/perceptor/models/vgg/vgg.py
@@ -0,0 +1,32 @@
+import torch.nn as nn
+from torchvision import models
+
+
+class VGG19(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.model = models.vgg19(pretrained=True).features
+        self.model.eval()
+        self.model.requires_grad_(False)
+
+    def forward(self, images):
+        """
+        Args:
+            images: images of shape (batch_size, 3, height, width) between 0 and 1
+        """
+
+        _, _, height, width = images.shape
+        if width % 8 != 0:
+            raise ValueError("Width must be divisible by 8")
+        if height % 8 != 0:
+            raise ValueError("Height must be divisible by 8")
+
+        return self.model(images)
+
+
+def test_vgg19():
+    import torch
+
+    model = VGG19()
+    images = torch.randn((1, 3, 256, 256))
+    model(images)
diff --git a/pyproject.toml b/pyproject.toml
index 2304f64..c523487 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "perceptor"
-version = "0.5.15"
+version = "0.6.0"
 description = "Modular image generation library"
 authors = ["Richard Löwenström ", "dribnet"]
 readme = "README.md"
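
Note: the old `guided_eps` docstring example is dropped above without an equivalent for the new API. As a rough sketch only (not part of this diff), gradient guidance with the new `Predictions` interface might look like the following; `loss_fn` and `guided_sample` are hypothetical names standing in for any differentiable image loss and driver loop:

```python
import torch
from perceptor.models.guided_diffusion.guided_diffusion import GuidedDiffusion


def guided_sample(loss_fn, n_steps=50, guidance_scale=0.5):
    device = torch.device("cuda")
    diffusion_model = GuidedDiffusion("standard").to(device)
    diffused_images = diffusion_model.random_diffused((1, 3, 512, 512))

    for from_indices, to_indices in diffusion_model.schedule_indices(
        n_steps=n_steps, to_index=0
    ):
        with torch.enable_grad():
            # leaf tensor so the image-space gradient ends up in .grad
            diffused_images = diffused_images.detach().requires_grad_()
            predictions = diffusion_model.predictions(diffused_images, from_indices)
            loss_fn(predictions.denoised_images).backward()

        # nudge the predicted noise along the (clamped) gradient before stepping
        predictions = predictions.guided(
            diffused_images.grad, guidance_scale=guidance_scale
        )
        diffused_images = predictions.step(to_indices)

    return diffusion_model.predictions(diffused_images, to_indices).denoised_images
```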
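Similarly, the new `perceptor.models.vgg` helpers land without a usage example or test. A minimal sketch of a Gram-matrix style loss built on them could look like this, assuming the modules are importable as `perceptor.models.vgg.*`; `gram_matrix` and the loss wiring are illustrative assumptions, not part of the diff:

```python
import torch
from perceptor.models.vgg.vgg import VGG19
from perceptor.models.vgg.get_features import get_features
from perceptor.models.vgg.normalize import normalize


def gram_matrix(features):
    n, c, h, w = features.shape
    flat = features.reshape(n, c, h * w)
    return flat @ flat.transpose(1, 2) / (c * h * w)


def style_loss(images, style_images, vgg=None):
    # get_features iterates the raw torchvision feature stack, i.e. VGG19().model
    vgg = (vgg or VGG19().to(images.device)).model
    image_features = get_features(normalize(images), vgg)
    style_features = get_features(normalize(style_images), vgg)
    return torch.stack(
        [
            (gram_matrix(image_features[name]) - gram_matrix(style_features[name]))
            .square()
            .mean()
            for name in image_features
        ]
    ).sum()
```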