@@ -301,14 +301,21 @@ def __getitem__(self, idx):
301301 self .doc_idx [doc_index_l ],
302302 length = offset_l + 1 ))
303303 sample = np .concatenate (sample_list )
304+
305+ text_name = 'text'
306+ if args .use_dataset_only :
307+ text_name = 'input_ids'
308+ sample_dict = {text_name : np .array (sample , dtype = np .int64 )}
304309 if args .return_data_index :
305- return {'text' : np .array (sample , dtype = np .int64 ),
306- 'index' : np .array ([orig_idx ], dtype = np .int64 )}
307- elif self .return_doc_ids : # for retro preprocessing
308- return {'text' : np .array (sample , dtype = np .int64 ),
309- 'doc_ids' : np .array (doc_ids , dtype = np .int64 )}
310- else :
311- return {'text' : np .array (sample , dtype = np .int64 )}
310+ sample_dict .update ({'index' : np .array ([orig_idx ], dtype = np .int64 )})
311+
312+ if self .return_doc_ids : # for retro preprocessing
313+ sample_dict .update ({'doc_ids' : np .array (doc_ids , dtype = np .int64 )})
314+
315+ if args .use_dataset_only :
316+ sample_dict .update ({'labels' : np .array (sample , dtype = np .int64 )})
317+
318+ return sample_dict
312319
313320
314321def _build_index_mappings (name , data_prefix , documents , sizes ,
0 commit comments