This repository was archived by the owner on Feb 25, 2022. It is now read-only.

Commit d741ddf

Merge pull request #230 from nostalgebraist/tfrecords-prepend-fix
Fix trailing token bug in create_tfrecords
2 parents afe5e69 + b12f6b2

File tree

1 file changed: +6 -8 lines changed

data/create_tfrecords.py

Lines changed: 6 additions & 8 deletions
@@ -106,7 +106,7 @@ def split_list(l, n):
     return [l[i:i + n] for i in range(0, len(l), n)]
 
 
-def archive_to_tokens(f, encoder, args):
+def archive_to_tokens(f, encoder, args, prefix=[]):
     # Generator that yields the contents of the files in an archive
     # if data_to_prepend is not None, prepend data_to_prepend + a EOS separator to the encoded data
     reader = Reader(f)
@@ -116,7 +116,8 @@ def archive_to_tokens(f, encoder, args):
         if args.wikitext_detokenize:
             doc = wikitext_detokenizer(doc)
         doc = encoder.encode(doc) + args.separator # read document from lmd and append separator token
-        yield split_list(doc, args.chunk_size) # split into n_ctx + 1 size chunks
+        yield split_list(prefix + doc, args.chunk_size) # split into n_ctx + 1 size chunks
+        prefix = []
 
 
 def write_files(files, files_per, output_dir, out_name, start_no, write_remainder=False, process_no=None):
@@ -189,24 +190,21 @@ def create_tfrecords(params, write_remainder=True, write_every_n_files=1, save_c
     tokenized_files_array = []
 
     for f in files:
-        for tokenized_files in archive_to_tokens(f, enc, args):
+        for tokenized_files in archive_to_tokens(f, enc, args, prefix=data_to_prepend):
             files_processed += 1
             if files_processed < resume_files_processed:
                 continue # resume from checkpoint
 
             # if the last chunk < chunk size, but > minimum_size, take it and append it to the beginning of the next file
+            data_to_prepend = []
             n_tokens = len(tokenized_files[-1])
             if n_tokens < args.chunk_size:
                 data = tokenized_files.pop(-1)
                 if n_tokens >= args.minimum_size:
-                    data_to_prepend.extend(data)
+                    data_to_prepend = data
                 else:
                     discarded_files += 1
 
-            if len(data_to_prepend) >= args.chunk_size:
-                # if length of data_to_prepend becomes greater than chunk size, add concatted files to tokenized files
-                tokenized_files_array.append(data_to_prepend[:args.chunk_size])
-                data_to_prepend = data_to_prepend[args.chunk_size:]
             # add tokenized files > chunk size to main array
             tokenized_files_array.extend(tokenized_files)
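
For context, below is a minimal, self-contained sketch of the behaviour this patch introduces: the short trailing chunk left over from one archive is handed to the next archive_to_tokens call as prefix and actually prepended to that archive's first document, instead of accumulating in data_to_prepend without ever being prepended, as the pre-patch code did. Everything here apart from split_list and the prefix/data_to_prepend idea is an illustrative assumption (toy_archive_to_tokens, the integer token lists, chunk_size 4, EOS id 0), not the repository's real code or data.

# Minimal sketch of the prefix-carrying behaviour (illustrative only:
# toy_archive_to_tokens, the integer token lists, chunk_size=4 and EOS id 0
# are made up for the example; split_list mirrors the repo helper).

def split_list(l, n):
    # fixed-size chunks; the last chunk may be shorter than n
    return [l[i:i + n] for i in range(0, len(l), n)]

def toy_archive_to_tokens(docs, separator, chunk_size, prefix=[]):
    # one yield per document; carried-over tokens are prepended to the
    # first document only, then the prefix is cleared
    for doc in docs:
        tokens = doc + separator
        yield split_list(prefix + tokens, chunk_size)
        prefix = []

archives = [[[1, 2, 3, 4, 5]], [[6, 7, 8]]]   # two toy archives of encoded docs
chunks_out, data_to_prepend, minimum_size = [], [], 2

for archive in archives:
    for chunks in toy_archive_to_tokens(archive, separator=[0], chunk_size=4,
                                        prefix=data_to_prepend):
        data_to_prepend = []                  # the old prefix has been consumed
        if len(chunks[-1]) < 4:               # short trailing chunk?
            tail = chunks.pop(-1)
            if len(tail) >= minimum_size:     # carry it into the next archive
                data_to_prepend = tail
        chunks_out.extend(chunks)

print(chunks_out)       # [[1, 2, 3, 4], [5, 0, 6, 7]]
print(data_to_prepend)  # [8, 0] -- would be written out or carried further

In this toy run, the [5, 0] tail of the first archive is no longer dropped or merely concatenated onto later tails; it starts the first chunk produced from the second archive, which is the fix the diff above makes in create_tfrecords.py.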
