@@ -969,32 +969,80 @@ def load_networks(self, epoch):
                     state_dict[new_key] = state_dict[key].clone()
                     del state_dict[key]
 
-                state1 = list(state_dict.keys())
-                state2 = list(net.state_dict().keys())
-                state1.sort()
-                state2.sort()
-
-                for key1, key2 in zip(state1, state2):
-                    if key1 != key2:
-                        print(key1 == key2, key1, key2)
-
-                if hasattr(state_dict, "_ema"):
-                    net.load_state_dict(
-                        state_dict["_ema"], strict=self.opt.model_load_no_strictness
+                if self.opt.alg_diffusion_ddpm_cm_ft:
+                    model_dict = net.state_dict()
+                    filtered = {}
+
+                    for k, v in state_dict.items():
+                        if "denoise_fn.model.cond_embed" in k:
+                            new_k = k.replace(
+                                "denoise_fn.model.cond_embed",
+                                "cm_cond_embed.projection",
+                            )
+                        elif k.startswith("cond_embed."):
+                            new_k = k.replace("cond_embed", "cm_cond_embed.projection")
+                        elif "denoise_fn.model." in k:
+                            new_k = k.replace("denoise_fn.model.", "cm_model.")
+                        else:
+                            new_k = k
+
+                        if new_k in model_dict and v.shape == model_dict[new_k].shape:
+                            filtered[new_k] = v
+                        else:
+                            if "cond_embed" in k:
+                                print(f"⚠️ unmatched cond_embed key {k} → {new_k}")
+                            else:
+                                print(
+                                    f"⚠️ skipping {new_k}: shape {v.shape if hasattr(v, 'shape') else 'N/A'}"
+                                )
+
+                    missing = set(model_dict.keys()) - set(filtered.keys())
+                    extra = set(state_dict.keys()) - set(model_dict.keys())
+
+                    print(
+                        f"Loaded {len(filtered)}/{len(model_dict)} params; {len(missing)} missing.",
+                        flush=True,
                     )
-                else:
-                    if (
-                        name == "G_A"
-                        and hasattr(net, "unet")
-                        and hasattr(net, "vae")
-                        and any("lora" in n for n, _ in net.unet.named_parameters())
-                    ):
-                        net.load_lora_config(load_path)
-                        print("loading the lora")
-                    else:
+
+                    if missing:
+                        print("\n⚠️ Missing keys:")
+                        for k in sorted(missing):
+                            print("  ", k)
+
+                    net.load_state_dict(filtered, strict=False)
+
+                    print(
+                        "✅ Loaded pretrained DDPM weights (with partial embedding transfer).",
+                        flush=True,
+                    )
+
+                if not self.opt.alg_diffusion_ddpm_cm_ft:
+                    state1 = list(state_dict.keys())
+                    state2 = list(net.state_dict().keys())
+                    state1.sort()
+                    state2.sort()
+
+                    for key1, key2 in zip(state1, state2):
+                        if key1 != key2:
+                            print(key1 == key2, key1, key2)
+
+                    if hasattr(state_dict, "_ema"):
                         net.load_state_dict(
-                            state_dict, strict=self.opt.model_load_no_strictness
+                            state_dict["_ema"], strict=self.opt.model_load_no_strictness
                         )
+                    else:
+                        if (
+                            name == "G_A"
+                            and hasattr(net, "unet")
+                            and hasattr(net, "vae")
+                            and any("lora" in n for n, _ in net.unet.named_parameters())
+                        ):
+                            net.load_lora_config(load_path)
+                            print("loading the lora")
+                        else:
+                            net.load_state_dict(
+                                state_dict, strict=self.opt.model_load_no_strictness
+                            )
 
     def get_nets(self):
         return_nets = {}
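
For reference, a minimal standalone sketch of the key remapping and shape filtering that the new alg_diffusion_ddpm_cm_ft branch performs when transferring pretrained DDPM weights into the consistency model. The helper name remap_ddpm_keys and the toy state dicts are illustrative only and not part of this diff; the key prefixes (denoise_fn.model., cm_model., cm_cond_embed.projection) are taken from the change above.

import torch

def remap_ddpm_keys(ddpm_state, target_state):
    """Rename DDPM checkpoint keys to their consistency-model counterparts and
    keep only entries whose renamed key exists in the target with the same shape."""
    filtered = {}
    for k, v in ddpm_state.items():
        if "denoise_fn.model.cond_embed" in k:
            new_k = k.replace("denoise_fn.model.cond_embed", "cm_cond_embed.projection")
        elif k.startswith("cond_embed."):
            new_k = k.replace("cond_embed", "cm_cond_embed.projection")
        elif "denoise_fn.model." in k:
            new_k = k.replace("denoise_fn.model.", "cm_model.")
        else:
            new_k = k
        # Transfer the weight only when the renamed key matches in name and shape.
        if new_k in target_state and v.shape == target_state[new_k].shape:
            filtered[new_k] = v
    return filtered

# Toy example: the conv weight transfers, the mismatched cond_embed is skipped.
ddpm_state = {
    "denoise_fn.model.conv.weight": torch.zeros(8, 3, 3, 3),
    "denoise_fn.model.cond_embed.weight": torch.zeros(16, 4),
}
target_state = {
    "cm_model.conv.weight": torch.zeros(8, 3, 3, 3),
    "cm_cond_embed.projection.weight": torch.zeros(32, 4),  # different shape
}
print(sorted(remap_ddpm_keys(ddpm_state, target_state)))
# -> ['cm_model.conv.weight']
# The filtered dict is then loaded non-strictly, as in net.load_state_dict(filtered, strict=False).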