Skip to content

Commit 8760390

Browse files
authored
Change the text name of sample for compatible with Huggingface trainer (bigscience-workshop#289)
Signed-off-by: yuanwu <[email protected]>
1 parent e7f0201 commit 8760390

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

megatron/data/gpt_dataset.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -301,14 +301,21 @@ def __getitem__(self, idx):
301301
self.doc_idx[doc_index_l],
302302
length=offset_l + 1))
303303
sample = np.concatenate(sample_list)
304+
305+
text_name = 'text'
306+
if args.use_dataset_only:
307+
text_name = 'input_ids'
308+
sample_dict = {text_name: np.array(sample, dtype=np.int64)}
304309
if args.return_data_index:
305-
return {'text': np.array(sample, dtype=np.int64),
306-
'index': np.array([orig_idx], dtype=np.int64)}
307-
elif self.return_doc_ids: # for retro preprocessing
308-
return {'text': np.array(sample, dtype=np.int64),
309-
'doc_ids': np.array(doc_ids, dtype=np.int64)}
310-
else:
311-
return {'text': np.array(sample, dtype=np.int64)}
310+
sample_dict.update({'index': np.array([orig_idx], dtype=np.int64)})
311+
312+
if self.return_doc_ids: # for retro preprocessing
313+
sample_dict.update({'doc_ids': np.array(doc_ids, dtype=np.int64)})
314+
315+
if args.use_dataset_only:
316+
sample_dict.update({'labels': np.array(sample, dtype=np.int64)})
317+
318+
return sample_dict
312319

313320

314321
def _build_index_mappings(name, data_prefix, documents, sizes,

0 commit comments

Comments
 (0)