Skip to content

Commit e451080

Browse files
committed
update test
Signed-off-by: sudipto baral <[email protected]>
1 parent eb3a763 commit e451080

File tree

1 file changed

+3
-7
lines changed

1 file changed

+3
-7
lines changed

tests/tests_pytorch/loops/test_double_iter_in_iterable_dataset.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def create_queue() -> Queue:
4747
return q
4848

4949

50-
def train_model(queue: Queue, max_epochs: int, ckpt_path: Path) -> Trainer:
50+
def train_model(queue: Queue, max_epochs: int, ckpt_path: Path) -> None:
5151
dataloader = DataLoader(QueueDataset(queue), num_workers=1, batch_size=None, persistent_workers=True)
5252
trainer = Trainer(
5353
max_epochs=max_epochs,
@@ -61,21 +61,17 @@ def train_model(queue: Queue, max_epochs: int, ckpt_path: Path) -> Trainer:
6161
else:
6262
trainer.fit(BoringModel(), dataloader)
6363
trainer.save_checkpoint(str(ckpt_path))
64-
return trainer
6564

6665

6766
def test_resume_training_with(tmp_path):
6867
"""Test resuming training from checkpoint file using a IterableDataset."""
6968
queue = create_queue()
7069
max_epoch = 2
7170
ckpt_path = tmp_path / "model.ckpt"
72-
trainer = train_model(queue, max_epoch, ckpt_path)
73-
assert trainer is not None
71+
train_model(queue, max_epoch, ckpt_path)
7472

7573
assert os.path.exists(ckpt_path), f"Checkpoint file '{ckpt_path}' wasn't created"
76-
7774
ckpt_size = os.path.getsize(ckpt_path)
7875
assert ckpt_size > 0, f"Checkpoint file is empty (size: {ckpt_size} bytes)"
7976

80-
trainer = train_model(queue, max_epoch + 2, ckpt_path)
81-
assert trainer is not None
77+
train_model(queue, max_epoch + 2, ckpt_path)

0 commit comments

Comments
 (0)