@@ -47,7 +47,7 @@ def create_queue() -> Queue:
47
47
return q
48
48
49
49
50
- def train_model (queue : Queue , max_epochs : int , ckpt_path : Path ) -> Trainer :
50
+ def train_model (queue : Queue , max_epochs : int , ckpt_path : Path ) -> None :
51
51
dataloader = DataLoader (QueueDataset (queue ), num_workers = 1 , batch_size = None , persistent_workers = True )
52
52
trainer = Trainer (
53
53
max_epochs = max_epochs ,
@@ -61,21 +61,17 @@ def train_model(queue: Queue, max_epochs: int, ckpt_path: Path) -> Trainer:
61
61
else :
62
62
trainer .fit (BoringModel (), dataloader )
63
63
trainer .save_checkpoint (str (ckpt_path ))
64
- return trainer
65
64
66
65
67
66
def test_resume_training_with (tmp_path ):
68
67
"""Test resuming training from checkpoint file using a IterableDataset."""
69
68
queue = create_queue ()
70
69
max_epoch = 2
71
70
ckpt_path = tmp_path / "model.ckpt"
72
- trainer = train_model (queue , max_epoch , ckpt_path )
73
- assert trainer is not None
71
+ train_model (queue , max_epoch , ckpt_path )
74
72
75
73
assert os .path .exists (ckpt_path ), f"Checkpoint file '{ ckpt_path } ' wasn't created"
76
-
77
74
ckpt_size = os .path .getsize (ckpt_path )
78
75
assert ckpt_size > 0 , f"Checkpoint file is empty (size: { ckpt_size } bytes)"
79
76
80
- trainer = train_model (queue , max_epoch + 2 , ckpt_path )
81
- assert trainer is not None
77
+ train_model (queue , max_epoch + 2 , ckpt_path )
0 commit comments