Skip to content

Commit 788f0a8

Browse files
committed
Implement custom collate_fn that skips nones
This allows skipping samples that fail to load in the dataloader.
1 parent 2bba276 commit 788f0a8

1 file changed

Lines changed: 19 additions & 0 deletions

File tree

src/data/discotube_text_audio.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,12 +178,30 @@ def setup(self, stage: str):
178178
metadata_id_map=self.metadata_id_map,
179179
)
180180

181+
def collate_fn(self, batch):
    """Collate (audio, text) samples, skipping samples that failed to load.

    A sample is dropped when it is ``None`` itself or when either of its two
    fields is ``None``. Audio and text are filtered together, so ``text[i]``
    always belongs to ``audio[i]`` in the returned batch.

    Args:
        batch: Iterable of ``(audio, text)`` pairs, or ``None`` entries for
            failed samples. The audio entries are combined with
            ``torch.stack``, so they must be tensors of identical shape.

    Returns:
        A tuple ``(audio, text)`` where ``audio`` is the stacked tensor of the
        surviving audio entries and ``text`` is a list of the matching text
        entries.

    Raises:
        RuntimeError: If no valid sample is left in the batch.
    """
    samples = list(batch)

    # Filter whole samples rather than each field on its own: dropping a
    # ``None`` from only one of the two sequences would shift every following
    # pair and silently train on mismatched audio/text.
    valid = [
        sample
        for sample in samples
        if sample is not None and sample[0] is not None and sample[1] is not None
    ]

    if not valid:
        # torch.stack fails on an empty list with an unrelated-looking
        # message; report the actual cause instead.
        raise RuntimeError(
            f"All {len(samples)} samples in the batch failed to load in collate_fn"
        )

    audio_in, text_in = zip(*valid)
    audio = torch.stack(audio_in)
    text = list(text_in)

    # Only warn when more than half of the batch was lost, to avoid flooding
    # the logs on occasional single failures.
    if len(valid) < len(samples) // 2:
        warnings.warn(
            f"Skipping {len(samples) - len(valid)} samples out of {len(samples)} in collate_fn "
        )

    return audio, text
197+
181198
def train_dataloader(self):
    """Build the DataLoader that serves the training dataset.

    The loader reads ``self.dataset_train`` using the configured batch size
    and worker count, requests pinned memory, and collates samples with
    ``self.collate_fn``.
    """
    loader_options = {
        "batch_size": self.batch_size,
        "num_workers": self.num_workers,
        "pin_memory": True,
        "collate_fn": self.collate_fn,
    }
    return DataLoader(self.dataset_train, **loader_options)
188206

189207
def val_dataloader(self):
@@ -192,4 +210,5 @@ def val_dataloader(self):
192210
batch_size=self.batch_size,
193211
num_workers=self.num_workers,
194212
pin_memory=True,
213+
collate_fn=self.collate_fn,
195214
)

0 commit comments

Comments
 (0)