Skip to content

Commit 978ff97

Browse files
committed
updated history load
1 parent f00cc1c commit 978ff97

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

train.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,10 @@ def load(self):
209209
print("%sLoading checkpoint from: %s...%s"%(Colors.MAGENTABG, self._path, Colors.ENDC))
210210
self._data = torch.load(self._path)
211211

212+
def was_provided(self) -> bool:
213+
"""Returns true if user has provided a checkpoint to be loaded"""
214+
return self._path is not None
215+
212216
@property
213217
def model(self) -> Optional[StateDict]:
214218
"""return the model weights"""
@@ -385,7 +389,8 @@ def train(device:Device, checkpoint:Checkpoint):
385389
data_loader_eval = DataLoader(eval_data, batch_size=BATCH_SIZE, shuffle=True)
386390

387391
history = History(HISTORY_FILE)
388-
history.load()
392+
if checkpoint.was_provided():
393+
history.load()
389394

390395
loss_fn = nn.KLDivLoss(reduction="batchmean")
391396
lr = 0.1
@@ -502,7 +507,7 @@ def get_checkpoint_arg() -> Optional[str]:
502507
checkpoint path.
503508
"""
504509
parser = argparse.ArgumentParser()
505-
parser.add_argument("--checkpoint", type=str)
510+
parser.add_argument("--checkpoint", type=str, default="")
506511
args = parser.parse_args(sys.argv[1:])
507512
checkpoint_path = args.checkpoint if args.checkpoint else None
508513
return checkpoint_path

0 commit comments

Comments
 (0)