feat(trainer): Support multi-role & consecutive turns in DataCollatorForCompletionOnlyLM (#3223) #3224

Open · wants to merge 1 commit into base: main

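A minimal usage sketch of the behavior this PR adds, assuming the list-valued `instruction_template` exercised in the tests below; the tokenizer checkpoint and ChatML templates are taken from the test suite, and the example conversation is illustrative only:

```python
from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

# Same tiny tokenizer the tests below rely on.
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token

collator = DataCollatorForCompletionOnlyLM(
    # Multiple non-assistant roles: pass a list of instruction templates.
    instruction_template=["<|im_start|>tool\n", "<|im_start|>user\n"],
    response_template="<|im_start|>assistant\n",
    tokenizer=tokenizer,
    mlm=False,
)

# Illustrative ChatML conversation with a tool turn and consecutive assistant turns.
conversation = (
    "<|im_start|>user\nU U U<|im_end|>\n"
    "<|im_start|>assistant\nA A A<|im_end|>\n"
    "<|im_start|>tool\nT T T<|im_end|>\n"
    "<|im_start|>assistant\nA A A<|im_end|>\n"
    "<|im_start|>assistant\nA A A<|im_end|>"
)
tokenized = tokenizer([conversation], add_special_tokens=False)
batch = collator(
    [{"input_ids": tokenized.input_ids[0], "attention_mask": tokenized.attention_mask[0]}]
)
# In batch["labels"], only the tokens following each response_template are kept;
# everything else (system/user/tool turns and the templates themselves) is set to -100.
```
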
331 changes: 331 additions & 0 deletions tests/test_data_collator_completion_only.py
@@ -20,6 +20,124 @@
from trl import DataCollatorForCompletionOnlyLM


# Define samples globally for reuse
CHATML_SAMPLE_BASIC_MULTI_TURN = """<|im_start|>system
system prompt system prompt system prompt

<|im_end|>
<|im_start|>user
U U U U U<|im_end|>
<|im_start|>assistant
A A A A A<|im_end|>
<|im_start|>user
U U U U U<|im_end|>
<|im_start|>assistant
A A A A A<|im_end|>
<|im_start|>user
T T T T T<|im_end|>
<|im_start|>user
T T T T T<|im_end|>
<|im_start|>user
T T T T T<|im_end|>
<|im_start|>assistant
A A A A A<|im_end|>
<|im_start|>user
T T T T T<|im_end|>
<|im_start|>assistant
A A A A A<|im_end|>""" # 4 assistant turns

CHATML_SAMPLE_MULTI_ROLE_MULTI_TURN = """<|im_start|>system
system prompt system prompt system prompt

<|im_end|>
<|im_start|>user
U U U U U<|im_end|>
<|im_start|>assistant
A A A A A<|im_end|>
<|im_start|>tool
T T T T T<|im_end|>
<|im_start|>assistant
A A A A A<|im_end|>
<|im_start|>tool
T T T T T<|im_end|>
<|im_start|>assistant
A A A A A<|im_end|>
<|im_start|>tool
T T T T T<|im_end|>
<|im_start|>assistant
A A A A A<|im_end|>
<|im_start|>user
U U U U U<|im_end|>
<|im_start|>assistant
A A A A A<|im_end|>
<|im_start|>tool
T T T T T<|im_end|>
<|im_start|>assistant
A A A A A<|im_end|>
<|im_start|>user
U U U U U<|im_end|>
<|im_start|>assistant
A A A A A<|im_end|>
<|im_start|>tool
T T T T T<|im_end|>
<|im_start|>tool
T T T T T<|im_end|>
<|im_start|>assistant
A A A A A<|im_end|>""" # 8 assistant turns

CHATML_SAMPLE_CONSECUTIVE_ASSISTANT_MULTI_ROLE = """<|im_start|>system
system prompt system prompt system prompt

<|im_end|>
<|im_start|>user
U U U U U<|im_end|>
<|im_start|>assistant
A A A A A<|im_end|>
<|im_start|>assistant
A A A A A<|im_end|>
<|im_start|>assistant
A A A A A<|im_end|>
<|im_start|>tool
T T T T T<|im_end|>
<|im_start|>assistant
A A A A A<|im_end|>
<|im_start|>user
U U U U U<|im_end|>
<|im_start|>assistant
A A A A A<|im_end|>
<|im_start|>tool
T T T T T<|im_end|>
<|im_start|>assistant
A A A A A<|im_end|>
<|im_start|>user
U U U U U<|im_end|>
<|im_start|>assistant
A A A A A<|im_end|>
<|im_start|>tool
T T T T T<|im_end|>
<|im_start|>tool
T T T T T<|im_end|>
<|im_start|>assistant
A A A A A<|im_end|>""" # 8 assistant turns

CHATML_SAMPLE_CONSECUTIVE_ASSISTANT_SIMPLE = """<|im_start|>system
Prompt.
<|im_end|>
<|im_start|>user
User query.<|im_end|>
<|im_start|>assistant
Assistant response 1.<|im_end|>
<|im_start|>assistant
Assistant response 2.<|im_end|>
<|im_start|>user
Another user query.<|im_end|>
<|im_start|>assistant
Assistant response 3.<|im_end|>""" # 3 assistant turns total, 2 consecutive

# Expected decoded output for a single assistant turn based on the samples above
EXPECTED_DECODED_ASSISTANT_CHUNK = "A A A A A<|im_end|>\n"


class DataCollatorForCompletionOnlyLMTester(unittest.TestCase):
def test_data_collator_finds_response_template_llama2_tokenizer(self):
# this should ideally be tested with meta-llama/Llama-2-7b-hf
@@ -167,3 +285,216 @@ def test_data_collator_for_completion_only_lm(self):
self.assertEqual(batch["cu_seq_lens_k"].tolist(), [[0, 6, 13]]) # idem
self.assertEqual(batch["max_length_k"], torch.tensor([7])) # max length in batch, here 7 (second sequence)
self.assertEqual(batch["max_length_q"], torch.tensor([7])) # idem

def test_masking_basic_multi_turn(self):
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token

instruction_template = "<|im_start|>user\n"
response_template = "<|im_start|>assistant\n"

data_collator = DataCollatorForCompletionOnlyLM(
instruction_template=instruction_template,
response_template=response_template,
tokenizer=tokenizer,
mlm=False,
)

conversations = [
CHATML_SAMPLE_BASIC_MULTI_TURN,
CHATML_SAMPLE_BASIC_MULTI_TURN,
] # Batch of 2 identical samples
tokenized = tokenizer(conversations, add_special_tokens=False)

# Prepare input for collator in the typical dictionary format
batch_input = [
{"input_ids": tokenized.input_ids[i], "attention_mask": tokenized.attention_mask[i]}
for i in range(len(tokenized.input_ids))
]
collated_batch = data_collator(batch_input)

# Expected output: 4 assistant turns per sample
expected_decoded_output = EXPECTED_DECODED_ASSISTANT_CHUNK * 4

# Check labels for each sample in the batch
for i in range(len(collated_batch["labels"])):
valid_indices = collated_batch["labels"][i] != -100
valid_labels = collated_batch["labels"][i][valid_indices]
decoded_text = tokenizer.decode(valid_labels, skip_special_tokens=False)
# strip potential leading/trailing whitespace artefacts from decode
self.assertEqual(
decoded_text.strip(), expected_decoded_output.strip(), f"Mismatch in decoded labels for sample {i}"
)

def test_masking_multi_role_multi_template(self):
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token

# Use a list for multiple instruction templates
instruction_template = ["<|im_start|>tool\n", "<|im_start|>user\n"]
response_template = "<|im_start|>assistant\n"

data_collator = DataCollatorForCompletionOnlyLM(
instruction_template=instruction_template,
response_template=response_template,
tokenizer=tokenizer,
mlm=False,
)

conversations = [CHATML_SAMPLE_MULTI_ROLE_MULTI_TURN, CHATML_SAMPLE_CONSECUTIVE_ASSISTANT_MULTI_ROLE]
tokenized = tokenizer(conversations, add_special_tokens=False)

batch_input = [
{"input_ids": tokenized.input_ids[i], "attention_mask": tokenized.attention_mask[i]}
for i in range(len(tokenized.input_ids))
]
collated_batch = data_collator(batch_input)

# Expected outputs based on the number of assistant turns
expected_outputs = [
EXPECTED_DECODED_ASSISTANT_CHUNK * 8, # CHATML_SAMPLE_MULTI_ROLE_MULTI_TURN has 8 assistant turns
EXPECTED_DECODED_ASSISTANT_CHUNK
* 8, # CHATML_SAMPLE_CONSECUTIVE_ASSISTANT_MULTI_ROLE has 8 assistant turns
]

# Check labels for each sample in the batch
self.assertEqual(len(collated_batch["labels"]), len(expected_outputs), "Batch size mismatch")

for i in range(len(collated_batch["labels"])):
valid_indices = collated_batch["labels"][i] != -100
valid_labels = collated_batch["labels"][i][valid_indices]
decoded_text = tokenizer.decode(valid_labels, skip_special_tokens=False)
self.assertEqual(
decoded_text.strip(), expected_outputs[i].strip(), f"Mismatch in decoded labels for sample {i}"
)

def test_masking_consecutive_assistant(self):
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token

instruction_template = ["<|im_start|>tool\n", "<|im_start|>user\n"]
response_template = "<|im_start|>assistant\n"

data_collator = DataCollatorForCompletionOnlyLM(
instruction_template=instruction_template,
response_template=response_template,
tokenizer=tokenizer,
mlm=False,
)

tokenized = tokenizer([CHATML_SAMPLE_CONSECUTIVE_ASSISTANT_SIMPLE], add_special_tokens=False)
batch_input = [
{"input_ids": tokenized.input_ids[i], "attention_mask": tokenized.attention_mask[i]}
for i in range(len(tokenized.input_ids))
]
collated_batch = data_collator(batch_input)

# Expected: only the content *after* the response_template should be unmasked for every assistant turn.
# The collator handles consecutive assistant turns by keeping labels unmasked from each response_template
# up to the next instruction template (or the end of the sequence).
expected_decoded_output = (
"Assistant response 1.<|im_end|>\nAssistant response 2.<|im_end|>\nAssistant response 3.<|im_end|>\n"
)

valid_indices = collated_batch["labels"][0] != -100
valid_labels = collated_batch["labels"][0][valid_indices]
decoded_text = tokenizer.decode(valid_labels, skip_special_tokens=False)
self.assertEqual(
decoded_text.strip(),
expected_decoded_output.strip(),
"Mismatch in decoded labels for consecutive assistant test",
)

def test_masking_left_padding(self):
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
# Explicitly set left padding
tokenizer.padding_side = "left"
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token

instruction_template = ["<|im_start|>tool\n", "<|im_start|>user\n"]
response_template = "<|im_start|>assistant\n"

data_collator = DataCollatorForCompletionOnlyLM(
instruction_template=instruction_template,
response_template=response_template,
tokenizer=tokenizer,
mlm=False,
)

conversations = [CHATML_SAMPLE_MULTI_ROLE_MULTI_TURN, CHATML_SAMPLE_BASIC_MULTI_TURN]
tokenized = tokenizer(conversations, add_special_tokens=False, padding=True, truncation=True, max_length=512)

batch_input = [
{"input_ids": tokenized.input_ids[i], "attention_mask": tokenized.attention_mask[i]}
for i in range(len(tokenized.input_ids))
]
collated_batch = data_collator(batch_input)

# Expected outputs based on the number of assistant turns in the specific samples used
expected_outputs = [
EXPECTED_DECODED_ASSISTANT_CHUNK * 8, # CHATML_SAMPLE_MULTI_ROLE_MULTI_TURN
EXPECTED_DECODED_ASSISTANT_CHUNK * 4, # CHATML_SAMPLE_BASIC_MULTI_TURN
]

self.assertEqual(len(collated_batch["labels"]), len(expected_outputs), "Batch size mismatch")

for i in range(len(collated_batch["labels"])):
valid_indices = collated_batch["labels"][i] != -100
valid_labels = collated_batch["labels"][i][valid_indices]
decoded_text = tokenizer.decode(valid_labels, skip_special_tokens=False)
self.assertEqual(
decoded_text.strip(),
expected_outputs[i].strip(),
f"Mismatch in decoded labels for left padding, sample {i}",
)

def test_masking_tokenized_templates(self):
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token

# Pre-tokenize the templates
instruction_templates_str = ["<|im_start|>tool\n", "<|im_start|>user\n"]
response_template_str = "<|im_start|>assistant\n"

instruction_token_ids = [
tokenizer.encode(tmpl, add_special_tokens=False) for tmpl in instruction_templates_str
]
response_token_ids = tokenizer.encode(response_template_str, add_special_tokens=False)

data_collator = DataCollatorForCompletionOnlyLM(
instruction_template=instruction_token_ids, # Pass List[List[int]]
response_template=response_token_ids, # Pass List[int]
tokenizer=tokenizer,
mlm=False,
)

conversations = [CHATML_SAMPLE_MULTI_ROLE_MULTI_TURN, CHATML_SAMPLE_CONSECUTIVE_ASSISTANT_MULTI_ROLE]
tokenized = tokenizer(conversations, add_special_tokens=False, padding=True, truncation=True, max_length=512)

batch_input = [
{"input_ids": tokenized.input_ids[i], "attention_mask": tokenized.attention_mask[i]}
for i in range(len(tokenized.input_ids))
]
collated_batch = data_collator(batch_input)

# Expected outputs based on the number of assistant turns
expected_outputs = [
EXPECTED_DECODED_ASSISTANT_CHUNK * 8,
EXPECTED_DECODED_ASSISTANT_CHUNK * 8,
]

self.assertEqual(len(collated_batch["labels"]), len(expected_outputs), "Batch size mismatch")

for i in range(len(collated_batch["labels"])):
valid_indices = collated_batch["labels"][i] != -100
valid_labels = collated_batch["labels"][i][valid_indices]
decoded_text = tokenizer.decode(valid_labels, skip_special_tokens=False)
self.assertEqual(
decoded_text.strip(),
expected_outputs[i].strip(),
f"Mismatch in decoded labels for tokenized templates, sample {i}",
)