Skip to content

Commit 086cc5e

Browse files
authored
Image2image - python (#115)
* Add Encoder model to torch2coreml for image2image and later for in-paining * diagonal test with randn * Revert "diagonal test with randn" This reverts commit 270afe1. * readme updates for encoder * pr comments
1 parent 6cd5c7a commit 086cc5e

File tree

2 files changed

+182
-2
lines changed

2 files changed

+182
-2
lines changed

README.md

+5
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,10 @@ Both of these products require the Core ML models and tokenization resources to
220220
- `vocab.json` (tokenizer vocabulary file)
221221
- `merges.text` (merges for byte pair encoding file)
222222

223+
Optionally, for image2image, in-painting, or similar:
224+
225+
- `VAEEncoder.mlmodelc` (image encoder model)
226+
223227
Optionally, it may also include the safety checker model that some versions of Stable Diffusion include:
224228

225229
- `SafetyChecker.mlmodelc`
@@ -321,6 +325,7 @@ Differences may be less or more pronounced for different inputs. Please see the
321325
<b> A3: </b> In order to minimize the memory impact of the model conversion process, please execute the following command instead:
322326

323327
```bash
328+
python -m python_coreml_stable_diffusion.torch2coreml --convert-vae-encoder -o <output-mlpackages-directory> && \
324329
python -m python_coreml_stable_diffusion.torch2coreml --convert-vae-decoder -o <output-mlpackages-directory> && \
325330
python -m python_coreml_stable_diffusion.torch2coreml --convert-unet -o <output-mlpackages-directory> && \
326331
python -m python_coreml_stable_diffusion.torch2coreml --convert-text-encoder -o <output-mlpackages-directory> && \

python_coreml_stable_diffusion/torch2coreml.py

+177-2
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,21 @@ def _get_coreml_inputs(sample_inputs, args):
4343
) for k, v in sample_inputs.items()
4444
]
4545

46+
# Simpler version of `DiagonalGaussianDistribution` with only needed calculations
47+
# as implemented in vae.py as part of the AutoencoderKL class
48+
# https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/vae.py#L312
49+
# coremltools-6.1 does not yet implement the randn operation with the option of setting a random seed
50+
class CoreMLDiagonalGaussianDistribution(object):
51+
def __init__(self, parameters, noise):
52+
self.parameters = parameters
53+
self.noise = noise
54+
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
55+
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
56+
self.std = torch.exp(0.5 * self.logvar)
57+
58+
def sample(self) -> torch.FloatTensor:
59+
x = self.mean + self.std * self.noise
60+
return x
4661

4762
def compute_psnr(a, b):
4863
""" Compute Peak-Signal-to-Noise-Ratio across two numpy.ndarray objects
@@ -140,7 +155,7 @@ def _convert_to_coreml(submodule_name, torchscript_module, sample_inputs,
140155

141156
def quantize_weights_to_8bits(args):
142157
for model_name in [
143-
"text_encoder", "vae_decoder", "unet", "unet_chunk1",
158+
"text_encoder", "vae_decoder", "vae_encoder", "unet", "unet_chunk1",
144159
"unet_chunk2", "safety_checker"
145160
]:
146161
out_path = _get_out_path(args, model_name)
@@ -190,6 +205,7 @@ def bundle_resources_for_swift_cli(args):
190205
# Compile model using coremlcompiler (Significantly reduces the load time for unet)
191206
for source_name, target_name in [("text_encoder", "TextEncoder"),
192207
("vae_decoder", "VAEDecoder"),
208+
("vae_encoder", "VAEEncoder"),
193209
("unet", "Unet"),
194210
("unet_chunk1", "UnetChunk1"),
195211
("unet_chunk2", "UnetChunk2"),
@@ -453,6 +469,159 @@ def forward(self, z):
453469
gc.collect()
454470

455471

472+
def convert_vae_encoder(pipe, args):
473+
""" Converts the VAE Encoder component of Stable Diffusion
474+
"""
475+
out_path = _get_out_path(args, "vae_encoder")
476+
if os.path.exists(out_path):
477+
logger.info(
478+
f"`vae_encoder` already exists at {out_path}, skipping conversion."
479+
)
480+
return
481+
482+
if not hasattr(pipe, "unet"):
483+
raise RuntimeError(
484+
"convert_unet() deletes pipe.unet to save RAM. "
485+
"Please use convert_vae_encoder() before convert_unet()")
486+
487+
sample_shape = (
488+
1, # B
489+
3, # C (RGB range from -1 to 1)
490+
(args.latent_h or pipe.unet.config.sample_size) * 8, # H
491+
(args.latent_w or pipe.unet.config.sample_size) * 8, # w
492+
)
493+
494+
noise_shape = (
495+
1, # B
496+
4, # C
497+
pipe.unet.config.sample_size, # H
498+
pipe.unet.config.sample_size, # w
499+
)
500+
501+
float_value_shape = (
502+
1,
503+
1,
504+
)
505+
506+
sqrt_alphas_cumprod_torch_shape = torch.tensor([[0.2,]])
507+
sqrt_one_minus_alphas_cumprod_torch_shape = torch.tensor([[0.8,]])
508+
509+
sample_vae_encoder_inputs = {
510+
"sample": torch.rand(*sample_shape, dtype=torch.float16),
511+
"diagonal_noise": torch.rand(*noise_shape, dtype=torch.float16),
512+
"noise": torch.rand(*noise_shape, dtype=torch.float16),
513+
"sqrt_alphas_cumprod": torch.rand(*float_value_shape, dtype=torch.float16),
514+
"sqrt_one_minus_alphas_cumprod": torch.rand(*float_value_shape, dtype=torch.float16),
515+
}
516+
517+
class VAEEncoder(nn.Module):
518+
""" Wrapper nn.Module wrapper for pipe.encode() method
519+
"""
520+
521+
def __init__(self):
522+
super().__init__()
523+
self.quant_conv = pipe.vae.quant_conv
524+
self.alphas_cumprod = pipe.scheduler.alphas_cumprod
525+
self.encoder = pipe.vae.encoder
526+
527+
# Because CoreMLTools does not support the torch.randn op, we pass in both
528+
# the diagonal Noise for the `DiagonalGaussianDistribution` operation and
529+
# the noise tensor combined with precalculated `sqrt_alphas_cumprod` and `sqrt_one_minus_alphas_cumprod`
530+
# for faster computation.
531+
def forward(self, sample, diagonal_noise, noise, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod):
532+
h = self.encoder(sample)
533+
moments = self.quant_conv(h)
534+
posterior = CoreMLDiagonalGaussianDistribution(moments, diagonal_noise)
535+
posteriorSample = posterior.sample()
536+
537+
# Add the scaling operation and the latent noise for faster computation
538+
init_latents = 0.18215 * posteriorSample
539+
result = self.add_noise(init_latents, noise, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod)
540+
return result
541+
542+
def add_noise(
543+
self,
544+
original_samples: torch.FloatTensor,
545+
noise: torch.FloatTensor,
546+
sqrt_alphas_cumprod: torch.FloatTensor,
547+
sqrt_one_minus_alphas_cumprod: torch.FloatTensor
548+
) -> torch.FloatTensor:
549+
noisy_samples = sqrt_alphas_cumprod * original_samples + sqrt_one_minus_alphas_cumprod * noise
550+
return noisy_samples
551+
552+
553+
baseline_encoder = VAEEncoder().eval()
554+
555+
# No optimization needed for the VAE Encoder as it is a pure ConvNet
556+
traced_vae_encoder = torch.jit.trace(
557+
baseline_encoder, (
558+
sample_vae_encoder_inputs["sample"].to(torch.float32),
559+
sample_vae_encoder_inputs["diagonal_noise"].to(torch.float32),
560+
sample_vae_encoder_inputs["noise"].to(torch.float32),
561+
sqrt_alphas_cumprod_torch_shape.to(torch.float32),
562+
sqrt_one_minus_alphas_cumprod_torch_shape.to(torch.float32)
563+
))
564+
565+
modify_coremltools_torch_frontend_badbmm()
566+
coreml_vae_encoder, out_path = _convert_to_coreml(
567+
"vae_encoder", traced_vae_encoder, sample_vae_encoder_inputs,
568+
["latent_dist"], args)
569+
570+
# Set model metadata
571+
coreml_vae_encoder.author = f"Please refer to the Model Card available at huggingface.co/{args.model_version}"
572+
coreml_vae_encoder.license = "OpenRAIL (https://huggingface.co/spaces/CompVis/stable-diffusion-license)"
573+
coreml_vae_encoder.version = args.model_version
574+
coreml_vae_encoder.short_description = \
575+
"Stable Diffusion generates images conditioned on text and/or other images as input through the diffusion process. " \
576+
"Please refer to https://arxiv.org/abs/2112.10752 for details."
577+
578+
# Set the input descriptions
579+
coreml_vae_encoder.input_description["sample"] = \
580+
"An image of the correct size to create the latent space with, image2image and in-painting."
581+
coreml_vae_encoder.input_description["diagonal_noise"] = \
582+
"Latent noise for `DiagonalGaussianDistribution` operation."
583+
coreml_vae_encoder.input_description["noise"] = \
584+
"Latent noise for use with strength parameter of image2image"
585+
coreml_vae_encoder.input_description["sqrt_alphas_cumprod"] = \
586+
"Precalculated `sqrt_alphas_cumprod` value based on strength and the current schedular's alphasCumprod values"
587+
coreml_vae_encoder.input_description["sqrt_one_minus_alphas_cumprod"] = \
588+
"Precalculated `sqrt_one_minus_alphas_cumprod` value based on strength and the current schedular's alphasCumprod values"
589+
590+
# Set the output descriptions
591+
coreml_vae_encoder.output_description[
592+
"latent_dist"] = "The latent embeddings from the unet model from the input image."
593+
594+
_save_mlpackage(coreml_vae_encoder, out_path)
595+
596+
logger.info(f"Saved vae_encoder into {out_path}")
597+
598+
# Parity check PyTorch vs CoreML
599+
if args.check_output_correctness:
600+
baseline_out = baseline_encoder(
601+
sample=sample_vae_encoder_inputs["sample"].to(torch.float32),
602+
diagonal_noise=sample_vae_encoder_inputs["diagonal_noise"].to(torch.float32),
603+
noise=sample_vae_encoder_inputs["noise"].to(torch.float32),
604+
sqrt_alphas_cumprod=sqrt_alphas_cumprod_torch_shape,
605+
sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod_torch_shape,
606+
).numpy(),
607+
608+
coreml_out = list(
609+
coreml_vae_encoder.predict(
610+
{
611+
"sample": sample_vae_encoder_inputs["sample"].numpy(),
612+
"diagonal_noise": sample_vae_encoder_inputs["diagonal_noise"].numpy(),
613+
"noise": sample_vae_encoder_inputs["noise"].numpy(),
614+
"sqrt_alphas_cumprod": sqrt_alphas_cumprod_torch_shape.numpy(),
615+
"sqrt_one_minus_alphas_cumprod": sqrt_one_minus_alphas_cumprod_torch_shape.numpy()
616+
}).values())
617+
618+
report_correctness(baseline_out[0], coreml_out[0],
619+
"vae_encoder baseline PyTorch to baseline CoreML")
620+
621+
del traced_vae_encoder, pipe.vae.encoder, coreml_vae_encoder
622+
gc.collect()
623+
624+
456625
def convert_unet(pipe, args):
457626
""" Converts the UNet component of Stable Diffusion
458627
"""
@@ -801,7 +970,12 @@ def main(args):
801970
logger.info("Converting vae_decoder")
802971
convert_vae_decoder(pipe, args)
803972
logger.info("Converted vae_decoder")
804-
973+
974+
if args.convert_vae_encoder:
975+
logger.info("Converting vae_encoder")
976+
convert_vae_encoder(pipe, args)
977+
logger.info("Converted vae_encoder")
978+
805979
if args.convert_unet:
806980
logger.info("Converting unet")
807981
convert_unet(pipe, args)
@@ -835,6 +1009,7 @@ def parser_spec():
8351009
# Select which models to export (All are needed for text-to-image pipeline to function)
8361010
parser.add_argument("--convert-text-encoder", action="store_true")
8371011
parser.add_argument("--convert-vae-decoder", action="store_true")
1012+
parser.add_argument("--convert-vae-encoder", action="store_true")
8381013
parser.add_argument("--convert-unet", action="store_true")
8391014
parser.add_argument("--convert-safety-checker", action="store_true")
8401015
parser.add_argument(

0 commit comments

Comments
 (0)