@@ -29,16 +29,16 @@ def __init__(
2929 decoder_warmup_epochs : int = 0 ,
3030 cosine_period_ratio : float = 1 ,
3131 compile_mode : str = None ,
32- weights : str = None ,
32+ weights : dict = None ,
3333 load_decoder : bool = True ,
34- repeat_stem_weights : bool = True ,
3534 optimizer : str = "SGD" ,
3635 train_transforms : Optional [transforms .Compose ] = None ,
3736 test_transforms : Optional [transforms .Compose ] = None ,
3837 val_transforms : Optional [transforms .Compose ] = None ,
3938 weight_decay : float = 3e-5 ,
4039 nesterov : bool = True ,
4140 momentum : float = 0.99 ,
41+ repeat_stem_weights : bool = True ,
4242 ):
4343 super ().__init__ ()
4444 self .learning_rate = learning_rate
@@ -60,11 +60,11 @@ def __init__(
6060 self .repeat_stem_weights = repeat_stem_weights
6161 assert 0 < cosine_period_ratio <= 1
6262
63- self .save_hyperparameters (ignore = ["model" , "train_transforms" , "val_transforms" , "test_transforms" ])
63+ self .save_hyperparameters (ignore = ["model" , "weights" , " train_transforms" , "val_transforms" , "test_transforms" ])
6464 self .model = model
6565
6666 if weights is not None :
67- self .load_weights (weights , load_decoder = load_decoder )
67+ self .load_state_dict (weights , load_decoder = load_decoder , strict = False )
6868
6969 self .model = torch .compile (model , mode = compile_mode ) if compile_mode is not None else model
7070
@@ -143,11 +143,6 @@ def configure_optimizers(self):
143143
144144 return [optimizer ], [scheduler_config ]
145145
146- def load_weights (self , weights , load_decoder = True ):
147- ckpt = torch .load (weights , map_location = "cpu" , weights_only = False )
148- print (f"Loading weights trained for { ckpt ['global_step' ]} steps / { ckpt ['epoch' ]} epochs." )
149- self .load_state_dict (ckpt ["state_dict" ], load_decoder = load_decoder , strict = False )
150-
151146 def load_state_dict (self , state_dict , load_decoder = True , * args , ** kwargs ):
152147 old_params = copy .deepcopy (self .state_dict ())
153148
@@ -161,10 +156,10 @@ def load_state_dict(self, state_dict, load_decoder=True, *args, **kwargs):
161156 state_dict = {k .replace ("_orig_mod." , "" ): v for k , v in state_dict .items ()}
162157
163158 # Repeat stem weights when state_dict num_channels is smaller than new_state_dict num_channels
164- if self .model .stem_weight_name is not None and self .repeat_stem_weights :
159+ if hasattr ( self . model , "stem_weight_name" ) and self .model .stem_weight_name is not None and self .repeat_stem_weights :
165160 prefix = "model._orig_mod." if "_orig_mod" in list (state_dict .keys ())[0 ] else "model."
166161 stem_name = f"{ prefix } { self .model .stem_weight_name } "
167- pt_input_channels = state_dict [stem_name ].shape [1 ] # (N, C, H, W, Z) where N is num tokens.
162+ pt_input_channels = state_dict [stem_name ].shape [1 ]
168163 ft_input_channels = old_params [stem_name ].shape [1 ]
169164 if pt_input_channels < ft_input_channels :
170165 assert pt_input_channels == 1 , (
0 commit comments