@@ -196,21 +196,20 @@ def load_diffusers_checkpoint(self):
         precision=precision,
     )
 
-    if len(self.config.unet_checkpoint) > 0:
-      unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
-          self.config.unet_checkpoint,
-          split_head_dim=self.config.split_head_dim,
-          norm_num_groups=self.config.norm_num_groups,
-          attention_kernel=self.config.attention,
-          flash_block_sizes=flash_block_sizes,
-          dtype=self.activations_dtype,
-          weights_dtype=self.weights_dtype,
-          mesh=self.mesh,
-      )
-      params["unet"] = unet_params
-      pipeline.unet = unet
-    params = jax.tree_util.tree_map(lambda x: x.astype(self.config.weights_dtype), params)
-
+    if len(self.config.unet_checkpoint) > 0:
+      unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
+          self.config.unet_checkpoint,
+          split_head_dim=self.config.split_head_dim,
+          norm_num_groups=self.config.norm_num_groups,
+          attention_kernel=self.config.attention,
+          flash_block_sizes=flash_block_sizes,
+          dtype=self.activations_dtype,
+          weights_dtype=self.weights_dtype,
+          mesh=self.mesh,
+      )
+      params["unet"] = unet_params
+      pipeline.unet = unet
+    params = jax.tree_util.tree_map(lambda x: x.astype(self.config.weights_dtype), params)
     return pipeline, params
 
   def save_checkpoint(self, train_step, pipeline, params, train_states):
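For reference, the final statement in this hunk casts every parameter in the loaded checkpoint to the configured weights dtype in a single call: jax.tree_util.tree_map applies a function to each leaf of a nested params pytree, so no per-module loop is needed. A minimal standalone sketch of that pattern follows; the toy params dict and the bfloat16 target are illustrative only, not taken from the MaxDiffusion config.

import jax
import jax.numpy as jnp

# Toy params pytree; a real diffusers checkpoint nests much deeper.
params = {
    "unet": {"kernel": jnp.ones((4, 4), dtype=jnp.float32)},
    "vae": {"bias": jnp.zeros((4,), dtype=jnp.float32)},
}

# Cast every leaf to the target dtype, mirroring the tree_map call above.
weights_dtype = jnp.bfloat16
params = jax.tree_util.tree_map(lambda x: x.astype(weights_dtype), params)

assert params["unet"]["kernel"].dtype == jnp.bfloat16
assert params["vae"]["bias"].dtype == jnp.bfloat16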