Skip to content

Commit c289334

Browse files
committed
Add a test case that previously failed with a TypeError
1 parent e8e0667 commit c289334

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

tests/test_data_loader.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,22 @@ def test_iterable_dataset_using_none_batch_size(self):
422422
for d in dataloader:
423423
assert isinstance(d, torch.Tensor)
424424

425+
def test_iterable_dataset_with_non_tensor_samples(self):
426+
dataset = SimpleIterableDataset(10)
427+
428+
def collate_fn(features):
429+
return {
430+
"tensor": torch.stack(features),
431+
"non_tensor": "non_tensor_value",
432+
}
433+
434+
dataloader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn)
435+
accelerator = Accelerator()
436+
dataloader = accelerator.prepare_data_loader(dataloader)
437+
for d in dataloader:
438+
assert isinstance(d["tensor"], torch.Tensor)
439+
assert d["non_tensor"] == "non_tensor_value"
440+
425441
@parameterized.expand([1, 2], name_func=parameterized_custom_name_func)
426442
def test_reproducibility(self, num_processes):
427443
set_seed(21)

0 commit comments

Comments
 (0)