File tree Expand file tree Collapse file tree 1 file changed +7
-2
lines changed Expand file tree Collapse file tree 1 file changed +7
-2
lines changed Original file line number Diff line number Diff line change @@ -209,6 +209,10 @@ def load(self):
209209 print ("%sLoading checkpoint from: %s...%s" % (Colors .MAGENTABG , self ._path , Colors .ENDC ))
210210 self ._data = torch .load (self ._path )
211211
212+ def was_provided (self ) -> bool :
213+ """Returns true if user has provided a checkpoint to be loaded"""
214+ return self ._path is not None
215+
212216 @property
213217 def model (self ) -> Optional [StateDict ]:
214218 """return the model weights"""
@@ -385,7 +389,8 @@ def train(device:Device, checkpoint:Checkpoint):
385389 data_loader_eval = DataLoader (eval_data , batch_size = BATCH_SIZE , shuffle = True )
386390
387391 history = History (HISTORY_FILE )
388- history .load ()
392+ if checkpoint .was_provided ():
393+ history .load ()
389394
390395 loss_fn = nn .KLDivLoss (reduction = "batchmean" )
391396 lr = 0.1
@@ -502,7 +507,7 @@ def get_checkpoint_arg() -> Optional[str]:
502507 checkpoint path.
503508 """
504509 parser = argparse .ArgumentParser ()
505- parser .add_argument ("--checkpoint" , type = str )
510+ parser .add_argument ("--checkpoint" , type = str , default = "" )
506511 args = parser .parse_args (sys .argv [1 :])
507512 checkpoint_path = args .checkpoint if args .checkpoint else None
508513 return checkpoint_path
You can’t perform that action at this time.
0 commit comments