File tree Expand file tree Collapse file tree 1 file changed +16
-0
lines changed
Expand file tree Collapse file tree 1 file changed +16
-0
lines changed Original file line number Diff line number Diff line change @@ -422,6 +422,22 @@ def test_iterable_dataset_using_none_batch_size(self):
422422 for d in dataloader :
423423 assert isinstance (d , torch .Tensor )
424424
425+ def test_iterable_dataset_with_non_tensor_samples (self ):
426+ dataset = SimpleIterableDataset (10 )
427+
428+ def collate_fn (features ):
429+ return {
430+ "tensor" : torch .stack (features ),
431+ "non_tensor" : "non_tensor_value" ,
432+ }
433+
434+ dataloader = DataLoader (dataset , batch_size = 4 , collate_fn = collate_fn )
435+ accelerator = Accelerator ()
436+ dataloader = accelerator .prepare_data_loader (dataloader )
437+ for d in dataloader :
438+ assert isinstance (d ["tensor" ], torch .Tensor )
439+ assert d ["non_tensor" ] == "non_tensor_value"
440+
425441 @parameterized .expand ([1 , 2 ], name_func = parameterized_custom_name_func )
426442 def test_reproducibility (self , num_processes ):
427443 set_seed (21 )
You can’t perform that action at this time.
0 commit comments