@@ -43,6 +43,21 @@ def _get_coreml_inputs(sample_inputs, args):
43
43
) for k , v in sample_inputs .items ()
44
44
]
45
45
46
+ # Simpler version of `DiagonalGaussianDistribution` with only needed calculations
47
+ # as implemented in vae.py as part of the AutoencoderKL class
48
+ # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/vae.py#L312
49
+ # coremltools-6.1 does not yet implement the randn operation with the option of setting a random seed
50
+ class CoreMLDiagonalGaussianDistribution (object ):
51
+ def __init__ (self , parameters , noise ):
52
+ self .parameters = parameters
53
+ self .noise = noise
54
+ self .mean , self .logvar = torch .chunk (parameters , 2 , dim = 1 )
55
+ self .logvar = torch .clamp (self .logvar , - 30.0 , 20.0 )
56
+ self .std = torch .exp (0.5 * self .logvar )
57
+
58
+ def sample (self ) -> torch .FloatTensor :
59
+ x = self .mean + self .std * self .noise
60
+ return x
46
61
47
62
def compute_psnr (a , b ):
48
63
""" Compute Peak-Signal-to-Noise-Ratio across two numpy.ndarray objects
@@ -140,7 +155,7 @@ def _convert_to_coreml(submodule_name, torchscript_module, sample_inputs,
140
155
141
156
def quantize_weights_to_8bits (args ):
142
157
for model_name in [
143
- "text_encoder" , "vae_decoder" , "unet" , "unet_chunk1" ,
158
+ "text_encoder" , "vae_decoder" , "vae_encoder" , " unet" , "unet_chunk1" ,
144
159
"unet_chunk2" , "safety_checker"
145
160
]:
146
161
out_path = _get_out_path (args , model_name )
@@ -190,6 +205,7 @@ def bundle_resources_for_swift_cli(args):
190
205
# Compile model using coremlcompiler (Significantly reduces the load time for unet)
191
206
for source_name , target_name in [("text_encoder" , "TextEncoder" ),
192
207
("vae_decoder" , "VAEDecoder" ),
208
+ ("vae_encoder" , "VAEEncoder" ),
193
209
("unet" , "Unet" ),
194
210
("unet_chunk1" , "UnetChunk1" ),
195
211
("unet_chunk2" , "UnetChunk2" ),
@@ -453,6 +469,159 @@ def forward(self, z):
453
469
gc .collect ()
454
470
455
471
472
+ def convert_vae_encoder (pipe , args ):
473
+ """ Converts the VAE Encoder component of Stable Diffusion
474
+ """
475
+ out_path = _get_out_path (args , "vae_encoder" )
476
+ if os .path .exists (out_path ):
477
+ logger .info (
478
+ f"`vae_encoder` already exists at { out_path } , skipping conversion."
479
+ )
480
+ return
481
+
482
+ if not hasattr (pipe , "unet" ):
483
+ raise RuntimeError (
484
+ "convert_unet() deletes pipe.unet to save RAM. "
485
+ "Please use convert_vae_encoder() before convert_unet()" )
486
+
487
+ sample_shape = (
488
+ 1 , # B
489
+ 3 , # C (RGB range from -1 to 1)
490
+ (args .latent_h or pipe .unet .config .sample_size ) * 8 , # H
491
+ (args .latent_w or pipe .unet .config .sample_size ) * 8 , # w
492
+ )
493
+
494
+ noise_shape = (
495
+ 1 , # B
496
+ 4 , # C
497
+ pipe .unet .config .sample_size , # H
498
+ pipe .unet .config .sample_size , # w
499
+ )
500
+
501
+ float_value_shape = (
502
+ 1 ,
503
+ 1 ,
504
+ )
505
+
506
+ sqrt_alphas_cumprod_torch_shape = torch .tensor ([[0.2 ,]])
507
+ sqrt_one_minus_alphas_cumprod_torch_shape = torch .tensor ([[0.8 ,]])
508
+
509
+ sample_vae_encoder_inputs = {
510
+ "sample" : torch .rand (* sample_shape , dtype = torch .float16 ),
511
+ "diagonal_noise" : torch .rand (* noise_shape , dtype = torch .float16 ),
512
+ "noise" : torch .rand (* noise_shape , dtype = torch .float16 ),
513
+ "sqrt_alphas_cumprod" : torch .rand (* float_value_shape , dtype = torch .float16 ),
514
+ "sqrt_one_minus_alphas_cumprod" : torch .rand (* float_value_shape , dtype = torch .float16 ),
515
+ }
516
+
517
+ class VAEEncoder (nn .Module ):
518
+ """ Wrapper nn.Module wrapper for pipe.encode() method
519
+ """
520
+
521
+ def __init__ (self ):
522
+ super ().__init__ ()
523
+ self .quant_conv = pipe .vae .quant_conv
524
+ self .alphas_cumprod = pipe .scheduler .alphas_cumprod
525
+ self .encoder = pipe .vae .encoder
526
+
527
+ # Because CoreMLTools does not support the torch.randn op, we pass in both
528
+ # the diagonal Noise for the `DiagonalGaussianDistribution` operation and
529
+ # the noise tensor combined with precalculated `sqrt_alphas_cumprod` and `sqrt_one_minus_alphas_cumprod`
530
+ # for faster computation.
531
+ def forward (self , sample , diagonal_noise , noise , sqrt_alphas_cumprod , sqrt_one_minus_alphas_cumprod ):
532
+ h = self .encoder (sample )
533
+ moments = self .quant_conv (h )
534
+ posterior = CoreMLDiagonalGaussianDistribution (moments , diagonal_noise )
535
+ posteriorSample = posterior .sample ()
536
+
537
+ # Add the scaling operation and the latent noise for faster computation
538
+ init_latents = 0.18215 * posteriorSample
539
+ result = self .add_noise (init_latents , noise , sqrt_alphas_cumprod , sqrt_one_minus_alphas_cumprod )
540
+ return result
541
+
542
+ def add_noise (
543
+ self ,
544
+ original_samples : torch .FloatTensor ,
545
+ noise : torch .FloatTensor ,
546
+ sqrt_alphas_cumprod : torch .FloatTensor ,
547
+ sqrt_one_minus_alphas_cumprod : torch .FloatTensor
548
+ ) -> torch .FloatTensor :
549
+ noisy_samples = sqrt_alphas_cumprod * original_samples + sqrt_one_minus_alphas_cumprod * noise
550
+ return noisy_samples
551
+
552
+
553
+ baseline_encoder = VAEEncoder ().eval ()
554
+
555
+ # No optimization needed for the VAE Encoder as it is a pure ConvNet
556
+ traced_vae_encoder = torch .jit .trace (
557
+ baseline_encoder , (
558
+ sample_vae_encoder_inputs ["sample" ].to (torch .float32 ),
559
+ sample_vae_encoder_inputs ["diagonal_noise" ].to (torch .float32 ),
560
+ sample_vae_encoder_inputs ["noise" ].to (torch .float32 ),
561
+ sqrt_alphas_cumprod_torch_shape .to (torch .float32 ),
562
+ sqrt_one_minus_alphas_cumprod_torch_shape .to (torch .float32 )
563
+ ))
564
+
565
+ modify_coremltools_torch_frontend_badbmm ()
566
+ coreml_vae_encoder , out_path = _convert_to_coreml (
567
+ "vae_encoder" , traced_vae_encoder , sample_vae_encoder_inputs ,
568
+ ["latent_dist" ], args )
569
+
570
+ # Set model metadata
571
+ coreml_vae_encoder .author = f"Please refer to the Model Card available at huggingface.co/{ args .model_version } "
572
+ coreml_vae_encoder .license = "OpenRAIL (https://huggingface.co/spaces/CompVis/stable-diffusion-license)"
573
+ coreml_vae_encoder .version = args .model_version
574
+ coreml_vae_encoder .short_description = \
575
+ "Stable Diffusion generates images conditioned on text and/or other images as input through the diffusion process. " \
576
+ "Please refer to https://arxiv.org/abs/2112.10752 for details."
577
+
578
+ # Set the input descriptions
579
+ coreml_vae_encoder .input_description ["sample" ] = \
580
+ "An image of the correct size to create the latent space with, image2image and in-painting."
581
+ coreml_vae_encoder .input_description ["diagonal_noise" ] = \
582
+ "Latent noise for `DiagonalGaussianDistribution` operation."
583
+ coreml_vae_encoder .input_description ["noise" ] = \
584
+ "Latent noise for use with strength parameter of image2image"
585
+ coreml_vae_encoder .input_description ["sqrt_alphas_cumprod" ] = \
586
+ "Precalculated `sqrt_alphas_cumprod` value based on strength and the current schedular's alphasCumprod values"
587
+ coreml_vae_encoder .input_description ["sqrt_one_minus_alphas_cumprod" ] = \
588
+ "Precalculated `sqrt_one_minus_alphas_cumprod` value based on strength and the current schedular's alphasCumprod values"
589
+
590
+ # Set the output descriptions
591
+ coreml_vae_encoder .output_description [
592
+ "latent_dist" ] = "The latent embeddings from the unet model from the input image."
593
+
594
+ _save_mlpackage (coreml_vae_encoder , out_path )
595
+
596
+ logger .info (f"Saved vae_encoder into { out_path } " )
597
+
598
+ # Parity check PyTorch vs CoreML
599
+ if args .check_output_correctness :
600
+ baseline_out = baseline_encoder (
601
+ sample = sample_vae_encoder_inputs ["sample" ].to (torch .float32 ),
602
+ diagonal_noise = sample_vae_encoder_inputs ["diagonal_noise" ].to (torch .float32 ),
603
+ noise = sample_vae_encoder_inputs ["noise" ].to (torch .float32 ),
604
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod_torch_shape ,
605
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod_torch_shape ,
606
+ ).numpy (),
607
+
608
+ coreml_out = list (
609
+ coreml_vae_encoder .predict (
610
+ {
611
+ "sample" : sample_vae_encoder_inputs ["sample" ].numpy (),
612
+ "diagonal_noise" : sample_vae_encoder_inputs ["diagonal_noise" ].numpy (),
613
+ "noise" : sample_vae_encoder_inputs ["noise" ].numpy (),
614
+ "sqrt_alphas_cumprod" : sqrt_alphas_cumprod_torch_shape .numpy (),
615
+ "sqrt_one_minus_alphas_cumprod" : sqrt_one_minus_alphas_cumprod_torch_shape .numpy ()
616
+ }).values ())
617
+
618
+ report_correctness (baseline_out [0 ], coreml_out [0 ],
619
+ "vae_encoder baseline PyTorch to baseline CoreML" )
620
+
621
+ del traced_vae_encoder , pipe .vae .encoder , coreml_vae_encoder
622
+ gc .collect ()
623
+
624
+
456
625
def convert_unet (pipe , args ):
457
626
""" Converts the UNet component of Stable Diffusion
458
627
"""
@@ -801,7 +970,12 @@ def main(args):
801
970
logger .info ("Converting vae_decoder" )
802
971
convert_vae_decoder (pipe , args )
803
972
logger .info ("Converted vae_decoder" )
804
-
973
+
974
+ if args .convert_vae_encoder :
975
+ logger .info ("Converting vae_encoder" )
976
+ convert_vae_encoder (pipe , args )
977
+ logger .info ("Converted vae_encoder" )
978
+
805
979
if args .convert_unet :
806
980
logger .info ("Converting unet" )
807
981
convert_unet (pipe , args )
@@ -835,6 +1009,7 @@ def parser_spec():
835
1009
# Select which models to export (All are needed for text-to-image pipeline to function)
836
1010
parser .add_argument ("--convert-text-encoder" , action = "store_true" )
837
1011
parser .add_argument ("--convert-vae-decoder" , action = "store_true" )
1012
+ parser .add_argument ("--convert-vae-encoder" , action = "store_true" )
838
1013
parser .add_argument ("--convert-unet" , action = "store_true" )
839
1014
parser .add_argument ("--convert-safety-checker" , action = "store_true" )
840
1015
parser .add_argument (
0 commit comments