@@ -43,6 +43,21 @@ def _get_coreml_inputs(sample_inputs, args):
4343        ) for  k , v  in  sample_inputs .items ()
4444    ]
4545
# Reduced stand-in for `DiagonalGaussianDistribution`, which diffusers defines in
# vae.py for the AutoencoderKL class:
# https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/vae.py#L312
# Only the calculations needed for sampling are kept. The noise is supplied by
# the caller because coremltools-6.1 does not yet implement the randn operation
# with the option of setting a random seed.
class CoreMLDiagonalGaussianDistribution:
    """Diagonal Gaussian whose sampling noise is provided by the caller."""

    def __init__(self, parameters, noise):
        """Split `parameters` into mean / log-variance and precompute the std.

        Args:
            parameters: tensor that is chunked in two along dim 1; the first
                half is used as the mean, the second as the log-variance.
            noise: tensor multiplied by the standard deviation in `sample()`.
        """
        self.parameters = parameters
        self.noise = noise

        mean_part, logvar_part = torch.chunk(parameters, 2, dim=1)
        self.mean = mean_part
        # Bound the log-variance to [-30, 20] before exponentiating.
        self.logvar = logvar_part.clamp(-30.0, 20.0)
        self.std = (0.5 * self.logvar).exp()

    def sample(self) -> torch.FloatTensor:
        """Return `mean + std * noise`."""
        return self.mean + self.std * self.noise
4661
4762def  compute_psnr (a , b ):
4863    """ Compute Peak-Signal-to-Noise-Ratio across two numpy.ndarray objects 
@@ -140,7 +155,7 @@ def _convert_to_coreml(submodule_name, torchscript_module, sample_inputs,
140155
141156def  quantize_weights_to_8bits (args ):
142157    for  model_name  in  [
143-             "text_encoder" , "vae_decoder" , "unet" , "unet_chunk1" ,
158+             "text_encoder" , "vae_decoder" , "vae_encoder"  ,  " unet" , "unet_chunk1" ,
144159            "unet_chunk2" , "safety_checker" 
145160    ]:
146161        out_path  =  _get_out_path (args , model_name )
@@ -190,6 +205,7 @@ def bundle_resources_for_swift_cli(args):
190205    # Compile model using coremlcompiler (Significantly reduces the load time for unet) 
191206    for  source_name , target_name  in  [("text_encoder" , "TextEncoder" ),
192207                                     ("vae_decoder" , "VAEDecoder" ),
208+                                      ("vae_encoder" , "VAEEncoder" ),
193209                                     ("unet" , "Unet" ),
194210                                     ("unet_chunk1" , "UnetChunk1" ),
195211                                     ("unet_chunk2" , "UnetChunk2" ),
@@ -453,6 +469,159 @@ def forward(self, z):
453469    gc .collect ()
454470
455471
def convert_vae_encoder(pipe, args):
    """ Converts the VAE Encoder component of Stable Diffusion

    The exported model encodes an image, samples the posterior with
    caller-provided noise, scales the sample by 0.18215 and mixes in a second
    caller-provided noise tensor using the two precalculated scheduler
    coefficients passed as inputs.

    Args:
        pipe: pipeline object; `pipe.vae.encoder`, `pipe.vae.quant_conv` and
            `pipe.unet.config.sample_size` are read. `pipe.vae.encoder` is
            deleted once conversion has finished.
        args: parsed command line arguments; `latent_h`, `latent_w`,
            `model_version` and `check_output_correctness` are read here.

    Raises:
        RuntimeError: if `pipe.unet` is no longer available.
    """
    out_path = _get_out_path(args, "vae_encoder")
    if os.path.exists(out_path):
        logger.info(
            f"`vae_encoder` already exists at {out_path}, skipping conversion."
        )
        return

    if not hasattr(pipe, "unet"):
        raise RuntimeError(
            "convert_unet() deletes pipe.unet to save RAM. "
            "Please use convert_vae_encoder() before convert_unet()")

    # Resolve the latent resolution once so that the image input and both
    # noise inputs always agree, including when --latent-h/--latent-w
    # override the UNet's default sample size.
    latent_h = args.latent_h or pipe.unet.config.sample_size
    latent_w = args.latent_w or pipe.unet.config.sample_size

    sample_shape = (
        1,  # B
        3,  # C (RGB range from -1 to 1)
        latent_h * 8,  # H
        latent_w * 8,  # W
    )

    noise_shape = (
        1,  # B
        4,  # C
        latent_h,  # H
        latent_w,  # W
    )

    float_value_shape = (
        1,
        1,
    )

    # Fixed example coefficients used for tracing and for the parity check.
    sqrt_alphas_cumprod_value = torch.tensor([[0.2]])
    sqrt_one_minus_alphas_cumprod_value = torch.tensor([[0.8]])

    # Example inputs handed to `_convert_to_coreml`.
    sample_vae_encoder_inputs = {
        "sample":
        torch.rand(*sample_shape, dtype=torch.float16),
        "diagonal_noise":
        torch.rand(*noise_shape, dtype=torch.float16),
        "noise":
        torch.rand(*noise_shape, dtype=torch.float16),
        "sqrt_alphas_cumprod":
        torch.rand(*float_value_shape, dtype=torch.float16),
        "sqrt_one_minus_alphas_cumprod":
        torch.rand(*float_value_shape, dtype=torch.float16),
    }

    class VAEEncoder(nn.Module):
        """ Wrapper nn.Module for the VAE encode step, followed by latent
        scaling and noising
        """

        def __init__(self):
            super().__init__()
            self.quant_conv = pipe.vae.quant_conv
            self.encoder = pipe.vae.encoder

        # Because CoreMLTools does not support the torch.randn op, we pass in
        # both the diagonal noise for the `DiagonalGaussianDistribution`
        # operation and the noise tensor combined with precalculated
        # `sqrt_alphas_cumprod` and `sqrt_one_minus_alphas_cumprod`
        # for faster computation.
        def forward(self, sample, diagonal_noise, noise, sqrt_alphas_cumprod,
                    sqrt_one_minus_alphas_cumprod):
            h = self.encoder(sample)
            moments = self.quant_conv(h)
            posterior = CoreMLDiagonalGaussianDistribution(
                moments, diagonal_noise)
            posterior_sample = posterior.sample()

            # Add the scaling operation and the latent noise for faster
            # computation
            init_latents = 0.18215 * posterior_sample
            return self.add_noise(init_latents, noise, sqrt_alphas_cumprod,
                                  sqrt_one_minus_alphas_cumprod)

        def add_noise(
            self,
            original_samples: torch.FloatTensor,
            noise: torch.FloatTensor,
            sqrt_alphas_cumprod: torch.FloatTensor,
            sqrt_one_minus_alphas_cumprod: torch.FloatTensor,
        ) -> torch.FloatTensor:
            """ Blend the latents with noise using the given coefficients
            """
            return (sqrt_alphas_cumprod * original_samples +
                    sqrt_one_minus_alphas_cumprod * noise)

    baseline_encoder = VAEEncoder().eval()

    # float32 inputs shared by tracing and the PyTorch side of the parity
    # check. Keys follow the parameter order of `VAEEncoder.forward`.
    torch_inputs = {
        "sample":
        sample_vae_encoder_inputs["sample"].to(torch.float32),
        "diagonal_noise":
        sample_vae_encoder_inputs["diagonal_noise"].to(torch.float32),
        "noise":
        sample_vae_encoder_inputs["noise"].to(torch.float32),
        "sqrt_alphas_cumprod":
        sqrt_alphas_cumprod_value.to(torch.float32),
        "sqrt_one_minus_alphas_cumprod":
        sqrt_one_minus_alphas_cumprod_value.to(torch.float32),
    }

    # No optimization needed for the VAE Encoder as it is a pure ConvNet
    traced_vae_encoder = torch.jit.trace(baseline_encoder,
                                         tuple(torch_inputs.values()))

    modify_coremltools_torch_frontend_badbmm()
    coreml_vae_encoder, out_path = _convert_to_coreml(
        "vae_encoder", traced_vae_encoder, sample_vae_encoder_inputs,
        ["latent_dist"], args)

    # Set model metadata
    coreml_vae_encoder.author = f"Please refer to the Model Card available at huggingface.co/{args.model_version}"
    coreml_vae_encoder.license = "OpenRAIL (https://huggingface.co/spaces/CompVis/stable-diffusion-license)"
    coreml_vae_encoder.version = args.model_version
    coreml_vae_encoder.short_description = \
        "Stable Diffusion generates images conditioned on text and/or other images as input through the diffusion process. " \
        "Please refer to https://arxiv.org/abs/2112.10752 for details."

    # Set the input descriptions
    coreml_vae_encoder.input_description["sample"] = \
        "An image of the correct size to create the latent space with, image2image and in-painting."
    coreml_vae_encoder.input_description["diagonal_noise"] = \
        "Latent noise for `DiagonalGaussianDistribution` operation."
    coreml_vae_encoder.input_description["noise"] = \
        "Latent noise for use with strength parameter of image2image"
    coreml_vae_encoder.input_description["sqrt_alphas_cumprod"] = \
        "Precalculated `sqrt_alphas_cumprod` value based on strength and the current schedular's alphasCumprod values"
    coreml_vae_encoder.input_description["sqrt_one_minus_alphas_cumprod"] = \
        "Precalculated `sqrt_one_minus_alphas_cumprod` value based on strength and the current schedular's alphasCumprod values"

    # Set the output descriptions. The UNet is not involved here, so the
    # text describes what this wrapper's forward pass actually returns.
    coreml_vae_encoder.output_description["latent_dist"] = \
        "The scaled and noised latent sample produced by the VAE encoder from the input image."

    _save_mlpackage(coreml_vae_encoder, out_path)

    logger.info(f"Saved vae_encoder into {out_path}")

    # Parity check PyTorch vs CoreML
    if args.check_output_correctness:
        baseline_out = baseline_encoder(**torch_inputs).numpy()

        coreml_inputs = {
            "sample":
            sample_vae_encoder_inputs["sample"].numpy(),
            "diagonal_noise":
            sample_vae_encoder_inputs["diagonal_noise"].numpy(),
            "noise":
            sample_vae_encoder_inputs["noise"].numpy(),
            "sqrt_alphas_cumprod":
            sqrt_alphas_cumprod_value.numpy(),
            "sqrt_one_minus_alphas_cumprod":
            sqrt_one_minus_alphas_cumprod_value.numpy(),
        }
        # The model is converted with a single output ("latent_dist").
        coreml_out = list(
            coreml_vae_encoder.predict(coreml_inputs).values())[0]

        report_correctness(baseline_out, coreml_out,
                           "vae_encoder baseline PyTorch to baseline CoreML")

    del traced_vae_encoder, pipe.vae.encoder, coreml_vae_encoder
    gc.collect()
623+ 
624+ 
456625def  convert_unet (pipe , args ):
457626    """ Converts the UNet component of Stable Diffusion 
458627    """ 
@@ -801,7 +970,12 @@ def main(args):
801970        logger .info ("Converting vae_decoder" )
802971        convert_vae_decoder (pipe , args )
803972        logger .info ("Converted vae_decoder" )
804- 
973+         
974+     if  args .convert_vae_encoder :
975+         logger .info ("Converting vae_encoder" )
976+         convert_vae_encoder (pipe , args )
977+         logger .info ("Converted vae_encoder" )
978+         
805979    if  args .convert_unet :
806980        logger .info ("Converting unet" )
807981        convert_unet (pipe , args )
@@ -835,6 +1009,7 @@ def parser_spec():
8351009    # Select which models to export (All are needed for text-to-image pipeline to function) 
8361010    parser .add_argument ("--convert-text-encoder" , action = "store_true" )
8371011    parser .add_argument ("--convert-vae-decoder" , action = "store_true" )
1012+     parser .add_argument ("--convert-vae-encoder" , action = "store_true" )
8381013    parser .add_argument ("--convert-unet" , action = "store_true" )
8391014    parser .add_argument ("--convert-safety-checker" , action = "store_true" )
8401015    parser .add_argument (
0 commit comments