File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -178,12 +178,30 @@ def setup(self, stage: str):
178178 metadata_id_map = self .metadata_id_map ,
179179 )
180180
def collate_fn(self, batch):
    """Collate ``(audio, text)`` samples into a batch, skipping failed ones.

    A sample is dropped when it is ``None`` itself or when either its audio
    or its text entry is ``None``. Audio and text are filtered together so
    the two returned sequences stay aligned index-for-index.

    Args:
        batch: Sequence of ``(audio, text)`` pairs. Audio entries are passed
            to ``torch.stack`` and therefore must be tensors of identical
            shape.

    Returns:
        A tuple ``(audio, text)``: ``audio`` is the stacked tensor of the
        surviving audio entries, ``text`` is the list of matching text
        entries.

    Raises:
        RuntimeError: If no usable sample remains in the batch.
    """
    samples = list(batch)
    total = len(samples)

    # Filter whole samples rather than each column separately: dropping a
    # None from only one column would shift every later audio/text pairing.
    kept = [
        sample
        for sample in samples
        if sample is not None
        and sample[0] is not None
        and sample[1] is not None
    ]

    if not kept:
        # torch.stack rejects an empty list with an opaque message; fail
        # with the same exception type but say what actually happened.
        raise RuntimeError(
            f"collate_fn: all {total} samples in the batch failed to load"
        )

    audio = torch.stack([sample[0] for sample in kept])
    text = [sample[1] for sample in kept]

    # Only warn when more than half of the batch was lost, so occasional
    # single failures do not flood the log.
    if len(kept) < total // 2:
        warnings.warn(
            f"Skipping {total - len(kept)} samples out of {total} in collate_fn"
        )

    return audio, text
197+
def train_dataloader(self):
    """Return the DataLoader that iterates over ``self.dataset_train``.

    Batch size and worker count come from the instance; memory pinning is
    always enabled and batches are assembled by ``self.collate_fn``.
    """
    loader_options = {
        "batch_size": self.batch_size,
        "num_workers": self.num_workers,
        "pin_memory": True,
        "collate_fn": self.collate_fn,
    }
    return DataLoader(self.dataset_train, **loader_options)
188206
189207 def val_dataloader (self ):
@@ -192,4 +210,5 @@ def val_dataloader(self):
192210 batch_size = self .batch_size ,
193211 num_workers = self .num_workers ,
194212 pin_memory = True ,
213+ collate_fn = self .collate_fn ,
195214 )
You can’t perform that action at this time.
0 commit comments