Commit 542c78e

pretrain bugfix + add demo data (#2824)
1 parent 5e11ea0 commit 542c78e

2 files changed: 30 additions and 1 deletion

examples/README.md

Lines changed: 7 additions & 0 deletions
@@ -45,6 +45,13 @@ save_to_hf: false
 ...
 ```
 
+For easier testing, we also provide a [demo dataset](https://paddleformers.bj.bcebos.com/datasets/pt_data.tar.gz) that can be used directly:
+
+```shell
+wget https://paddleformers.bj.bcebos.com/datasets/pt_data.tar.gz
+mkdir -p data/pt && tar -xf pt_data.tar.gz -C data/sft/
+```
+
 #### 1.1.2. Offline data stream
 
 We can also use an offline binary-format pretraining data stream, which saves more memory. An offline data stream is built as follows:
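
The demo-dataset step added in this hunk (create a directory, then unpack the archive into it) could be sketched in Python roughly as follows. Note that the README command creates `data/pt` but extracts into `data/sft/`; the sketch uses one directory for both steps, which is an assumption about the intended layout, and `extract_archive` is a hypothetical helper, not part of PaddleFormers.

```python
import tarfile
from pathlib import Path


def extract_archive(archive_path: str, dest_dir: str) -> list:
    """Create dest_dir if needed, unpack the archive into it, return member names."""
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)  # same effect as `mkdir -p`
    with tarfile.open(archive_path) as tar:  # default mode detects gzip automatically
        names = tar.getnames()
        tar.extractall(dest)
    return names


# Usage, assuming pt_data.tar.gz was already downloaded with wget and that
# data/pt is the intended target directory (an assumption, see above):
# extract_archive("pt_data.tar.gz", "data/pt")
```

Creating and extracting into the same path avoids the failure `tar -C` reports when its target directory does not exist.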

paddleformers/datasets/finetuning.py

Lines changed: 23 additions & 1 deletion
@@ -485,8 +485,30 @@ def __iter_func(self):
                             self.estimate = False
                             yield []
 
-            if len(batch_sequence) > 0:
+            # If the entire dataset has been fully traversed, return the remaining data.
+            if len(all_tokenized_tokens) > 0:
+                cut_tokens = all_tokenized_tokens
+                cut_tokens = cut_tokens + [self.tokenizer.eos_token_id]
+                res_tokens = cut_tokens[:-1]
+                res_labels = cut_tokens[1:]
+                loss_mask = [1] * len(res_tokens)
+                pos_ids = list(range(len(res_tokens)))
+                sequence = Sequence(
+                    token_ids=res_tokens,
+                    position_ids=pos_ids,
+                    labels=res_labels,
+                    loss_mask=loss_mask,
+                    num_examples=actual_example_num,
+                )
+                batch_sequence = [sequence]
                 yield batch_sequence
+                if self.estimate:
+                    self.used_estimate_samples += actual_example_num
+                    if self.used_estimate_samples >= self.max_estimate_samples:
+                        self.used_estimate_samples = 0
+                        # Set flag to False and yield empty list to signal the end of estimation
+                        self.estimate = False
+                        yield []
         else:
             if not self.packing:
                 for _ in range(len(self.mix_datasets)):
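
The tail-flush logic added in this hunk can be illustrated in isolation. The following is a minimal sketch, not the library's implementation: the `Sequence` dataclass is a hypothetical stand-in that only mirrors the keyword arguments seen in the diff, and `flush_remaining` is an invented helper name.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Sequence:
    """Hypothetical stand-in for the Sequence container used in the diff."""
    token_ids: List[int]
    position_ids: List[int]
    labels: List[int]
    loss_mask: List[int]
    num_examples: int


def flush_remaining(tokens: List[int], eos_token_id: int, num_examples: int) -> List[Sequence]:
    """Turn leftover tokens into a final one-sequence batch, or [] if none remain."""
    if not tokens:
        return []
    cut_tokens = tokens + [eos_token_id]  # terminate the tail with EOS
    res_tokens = cut_tokens[:-1]          # inputs: everything except the last token
    res_labels = cut_tokens[1:]           # labels: the inputs shifted left by one
    return [
        Sequence(
            token_ids=res_tokens,
            position_ids=list(range(len(res_tokens))),
            labels=res_labels,
            loss_mask=[1] * len(res_tokens),  # every position contributes to the loss
            num_examples=num_examples,
        )
    ]


batch = flush_remaining([11, 12, 13], eos_token_id=2, num_examples=1)
print(batch[0].token_ids)  # [11, 12, 13]
print(batch[0].labels)     # [12, 13, 2]
```

Appending EOS before the shift keeps inputs and labels the same length, so the final real token still has a target to predict.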
