Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 30 additions & 12 deletions examples/tokenize_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,30 +76,49 @@ def pack_sequences(
['▁toys', '▁.', '</s>', '<s>', '▁but', '▁just', '▁one', '▁look']
"""
packed_sequences = []
packed_position_ids = []
buffer = []
position_buffer = []

for input_ids in batch["input_ids"]:
# Add the current sequence to the buffer
buffer.extend(input_ids)
buffer.append(eos_token_id) # Add EOS at the end of each sequence
# Truncate sequences that individually exceed max_seq_len (including EOS token).
seq_with_eos = (input_ids + [eos_token_id])[:max_seq_len]
# Position IDs reset to 0 at the start of each sub-sequence.
seq_positions = list(range(len(seq_with_eos)))

# Check if buffer needs to be split into chunks
while len(buffer) > max_seq_len:
# Take a full chunk from the buffer and append it to packed_sequences
packed_sequences.append(buffer[:max_seq_len])
# Remove the processed chunk from the buffer
buffer = buffer[max_seq_len:]
# If adding this sequence would overflow, flush the current buffer first.
# This ensures every chunk starts at a sequence boundary (position_ids[0] == 0).
if buffer and len(buffer) + len(seq_with_eos) > max_seq_len:
padding_length = max_seq_len - len(buffer)
packed_sequences.append(buffer + [pad_token_id] * padding_length)
packed_position_ids.append(position_buffer + [0] * padding_length)
buffer = []
position_buffer = []

buffer.extend(seq_with_eos)
position_buffer.extend(seq_positions)

# Flush immediately if exactly full (no padding needed).
if len(buffer) == max_seq_len:
packed_sequences.append(buffer)
packed_position_ids.append(position_buffer)
buffer = []
position_buffer = []

    # Flush the final buffer if it is exactly max_seq_len tokens long (no padding needed).
if len(buffer) == max_seq_len:
packed_sequences.append(buffer)
packed_position_ids.append(position_buffer)
elif len(buffer) > cutoff_size:
        # If the remaining buffer is longer than cutoff_size, pad it up to max_seq_len;
        # otherwise it is too short to be worth keeping and is dropped from packed_sequences.
buffer.extend([pad_token_id] * (max_seq_len - len(buffer)))
padding_length = max_seq_len - len(buffer)
buffer.extend([pad_token_id] * padding_length)
position_buffer.extend([0] * padding_length)
packed_sequences.append(buffer)
packed_position_ids.append(position_buffer)

output = {"input_ids": packed_sequences}
output = {"input_ids": packed_sequences, "position_ids": packed_position_ids}
if add_labels:
output["labels"] = [
[
Expand All @@ -109,7 +128,6 @@ def pack_sequences(
for example in output["input_ids"]
]

# mask attention for padding tokens, a better version would also mask cross-sequence dependencies
output["attention_mask"] = [
[0 if token_id == pad_token_id else 1 for token_id in example]
for example in output["input_ids"]
Expand Down
Loading