diff --git a/tests/test_data_collator_completion_only.py b/tests/test_data_collator_completion_only.py
index 661a1dcf8f..a29df7e23d 100644
--- a/tests/test_data_collator_completion_only.py
+++ b/tests/test_data_collator_completion_only.py
@@ -20,6 +20,124 @@ from trl import DataCollatorForCompletionOnlyLM
 
 
+# Define samples globally for reuse
+CHATML_SAMPLE_BASIC_MULTI_TURN = """<|im_start|>system
+system prompt system prompt system prompt
+
+<|im_end|>
+<|im_start|>user
+U U U U U<|im_end|>
+<|im_start|>assistant
+A A A A A<|im_end|>
+<|im_start|>user
+U U U U U<|im_end|>
+<|im_start|>assistant
+A A A A A<|im_end|>
+<|im_start|>user
+T T T T T<|im_end|>
+<|im_start|>user
+T T T T T<|im_end|>
+<|im_start|>user
+T T T T T<|im_end|>
+<|im_start|>assistant
+A A A A A<|im_end|>
+<|im_start|>user
+T T T T T<|im_end|>
+<|im_start|>assistant
+A A A A A<|im_end|>"""  # 4 assistant turns
+
+CHATML_SAMPLE_MULTI_ROLE_MULTI_TURN = """<|im_start|>system
+system prompt system prompt system prompt
+
+<|im_end|>
+<|im_start|>user
+U U U U U<|im_end|>
+<|im_start|>assistant
+A A A A A<|im_end|>
+<|im_start|>tool
+T T T T T<|im_end|>
+<|im_start|>assistant
+A A A A A<|im_end|>
+<|im_start|>tool
+T T T T T<|im_end|>
+<|im_start|>assistant
+A A A A A<|im_end|>
+<|im_start|>tool
+T T T T T<|im_end|>
+<|im_start|>assistant
+A A A A A<|im_end|>
+<|im_start|>user
+U U U U U<|im_end|>
+<|im_start|>assistant
+A A A A A<|im_end|>
+<|im_start|>tool
+T T T T T<|im_end|>
+<|im_start|>assistant
+A A A A A<|im_end|>
+<|im_start|>user
+U U U U U<|im_end|>
+<|im_start|>assistant
+A A A A A<|im_end|>
+<|im_start|>tool
+T T T T T<|im_end|>
+<|im_start|>tool
+T T T T T<|im_end|>
+<|im_start|>assistant
+A A A A A<|im_end|>"""  # 8 assistant turns
+
+CHATML_SAMPLE_CONSECUTIVE_ASSISTANT_MULTI_ROLE = """<|im_start|>system
+system prompt system prompt system prompt
+
+<|im_end|>
+<|im_start|>user
+U U U U U<|im_end|>
+<|im_start|>assistant
+A A A A A<|im_end|>
+<|im_start|>assistant
+A A A A A<|im_end|>
+<|im_start|>assistant
+A A A A A<|im_end|>
+<|im_start|>tool
+T T T T T<|im_end|>
+<|im_start|>assistant
+A A A A A<|im_end|>
+<|im_start|>user
+U U U U U<|im_end|>
+<|im_start|>assistant
+A A A A A<|im_end|>
+<|im_start|>tool
+T T T T T<|im_end|>
+<|im_start|>assistant
+A A A A A<|im_end|>
+<|im_start|>user
+U U U U U<|im_end|>
+<|im_start|>assistant
+A A A A A<|im_end|>
+<|im_start|>tool
+T T T T T<|im_end|>
+<|im_start|>tool
+T T T T T<|im_end|>
+<|im_start|>assistant
+A A A A A<|im_end|>"""  # 8 assistant turns
+
+CHATML_SAMPLE_CONSECUTIVE_ASSISTANT_SIMPLE = """<|im_start|>system
+Prompt.
+
+<|im_end|>
+<|im_start|>user
+User query.<|im_end|>
+<|im_start|>assistant
+Assistant response 1.<|im_end|>
+<|im_start|>assistant
+Assistant response 2.<|im_end|>
+<|im_start|>user
+Another user query.<|im_end|>
+<|im_start|>assistant
+Assistant response 3.<|im_end|>"""  # 3 assistant turns total, 2 consecutive
+
+# Expected decoded output for a single assistant turn based on the samples above
+EXPECTED_DECODED_ASSISTANT_CHUNK = "A A A A A<|im_end|>\n"
+
+
 class DataCollatorForCompletionOnlyLMTester(unittest.TestCase):
     def test_data_collator_finds_response_template_llama2_tokenizer(self):
         # this should ideally be tested with meta-llama/Llama-2-7b-hf
@@ -167,3 +285,216 @@ def test_data_collator_for_completion_only_lm(self):
         self.assertEqual(batch["cu_seq_lens_k"].tolist(), [[0, 6, 13]])  # idem
         self.assertEqual(batch["max_length_k"], torch.tensor([7]))  # max length in batch, here 7 (second sequence)
         self.assertEqual(batch["max_length_q"], torch.tensor([7]))  # idem
+
+    def test_masking_basic_multi_turn(self):
+        tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
+        if tokenizer.pad_token_id is None:
+            tokenizer.pad_token = tokenizer.eos_token
+
+        instruction_template = "<|im_start|>user\n"
+        response_template = "<|im_start|>assistant\n"
+
+        data_collator = DataCollatorForCompletionOnlyLM(
+            instruction_template=instruction_template,
+            response_template=response_template,
+            tokenizer=tokenizer,
+            mlm=False,
+        )
+
+        conversations = [
+            CHATML_SAMPLE_BASIC_MULTI_TURN,
+            CHATML_SAMPLE_BASIC_MULTI_TURN,
+        ]  # Batch of 2 identical samples
+        tokenized = tokenizer(conversations, add_special_tokens=False)
+
+        # Prepare input for collator in the typical dictionary format
+        batch_input = [
+            {"input_ids": tokenized.input_ids[i], "attention_mask": tokenized.attention_mask[i]}
+            for i in range(len(tokenized.input_ids))
+        ]
+        collated_batch = data_collator(batch_input)
+
+        # Expected output: 4 assistant turns per sample
+        expected_decoded_output = EXPECTED_DECODED_ASSISTANT_CHUNK * 4
+
+        # Check labels for each sample in the batch
+        for i in range(len(collated_batch["labels"])):
+            valid_indices = collated_batch["labels"][i] != -100
+            valid_labels = collated_batch["labels"][i][valid_indices]
+            decoded_text = tokenizer.decode(valid_labels, skip_special_tokens=False)
+            # strip potential leading/trailing whitespace artefacts from decode
+            self.assertEqual(
+                decoded_text.strip(), expected_decoded_output.strip(), f"Mismatch in decoded labels for sample {i}"
+            )
+
+    def test_masking_multi_role_multi_template(self):
+        tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
+        if tokenizer.pad_token_id is None:
+            tokenizer.pad_token = tokenizer.eos_token
+
+        # Use a list for multiple instruction templates
+        instruction_template = ["<|im_start|>tool\n", "<|im_start|>user\n"]
+        response_template = "<|im_start|>assistant\n"
+
+        data_collator = DataCollatorForCompletionOnlyLM(
+            instruction_template=instruction_template,
+            response_template=response_template,
+            tokenizer=tokenizer,
+            mlm=False,
+        )
+
+        conversations = [CHATML_SAMPLE_MULTI_ROLE_MULTI_TURN, CHATML_SAMPLE_CONSECUTIVE_ASSISTANT_MULTI_ROLE]
+        tokenized = tokenizer(conversations, add_special_tokens=False)
+
+        batch_input = [
+            {"input_ids": tokenized.input_ids[i], "attention_mask": tokenized.attention_mask[i]}
+            for i in range(len(tokenized.input_ids))
+        ]
+        collated_batch = data_collator(batch_input)
+
+        # Expected outputs based on the number of assistant turns
+        expected_outputs = [
+            EXPECTED_DECODED_ASSISTANT_CHUNK * 8,  # CHATML_SAMPLE_MULTI_ROLE_MULTI_TURN has 8 assistant turns
+            EXPECTED_DECODED_ASSISTANT_CHUNK
+            * 8,  # CHATML_SAMPLE_CONSECUTIVE_ASSISTANT_MULTI_ROLE has 8 assistant turns
+        ]
+
+        # Check labels for each sample in the batch
+        self.assertEqual(len(collated_batch["labels"]), len(expected_outputs), "Batch size mismatch")
+
+        for i in range(len(collated_batch["labels"])):
+            valid_indices = collated_batch["labels"][i] != -100
+            valid_labels = collated_batch["labels"][i][valid_indices]
+            decoded_text = tokenizer.decode(valid_labels, skip_special_tokens=False)
+            self.assertEqual(
+                decoded_text.strip(), expected_outputs[i].strip(), f"Mismatch in decoded labels for sample {i}"
+            )
+
+    def test_masking_consecutive_assistant(self):
+        tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
+        if tokenizer.pad_token_id is None:
+            tokenizer.pad_token = tokenizer.eos_token
+
+        instruction_template = ["<|im_start|>tool\n", "<|im_start|>user\n"]
+        response_template = "<|im_start|>assistant\n"
+
+        data_collator = DataCollatorForCompletionOnlyLM(
+            instruction_template=instruction_template,
+            response_template=response_template,
+            tokenizer=tokenizer,
+            mlm=False,
+        )
+
+        tokenized = tokenizer([CHATML_SAMPLE_CONSECUTIVE_ASSISTANT_SIMPLE], add_special_tokens=False)
+        batch_input = [
+            {"input_ids": tokenized.input_ids[i], "attention_mask": tokenized.attention_mask[i]}
+            for i in range(len(tokenized.input_ids))
+        ]
+        collated_batch = data_collator(batch_input)
+
+        # Expected: only the content *after* the response_template should be unmasked for all assistant turns.
+        # Consecutive assistant turns are handled by unmasking from each response up to the *next* instruction or the end.
+        expected_decoded_output = (
+            "Assistant response 1.<|im_end|>\nAssistant response 2.<|im_end|>\nAssistant response 3.<|im_end|>\n"
+        )
+
+        valid_indices = collated_batch["labels"][0] != -100
+        valid_labels = collated_batch["labels"][0][valid_indices]
+        decoded_text = tokenizer.decode(valid_labels, skip_special_tokens=False)
+        self.assertEqual(
+            decoded_text.strip(),
+            expected_decoded_output.strip(),
+            "Mismatch in decoded labels for consecutive assistant test",
+        )
+
+    def test_masking_left_padding(self):
+        tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
+        # Explicitly set left padding
+        tokenizer.padding_side = "left"
+        if tokenizer.pad_token_id is None:
+            tokenizer.pad_token = tokenizer.eos_token
+
+        instruction_template = ["<|im_start|>tool\n", "<|im_start|>user\n"]
+        response_template = "<|im_start|>assistant\n"
+
+        data_collator = DataCollatorForCompletionOnlyLM(
+            instruction_template=instruction_template,
+            response_template=response_template,
+            tokenizer=tokenizer,
+            mlm=False,
+        )
+
+        conversations = [CHATML_SAMPLE_MULTI_ROLE_MULTI_TURN, CHATML_SAMPLE_BASIC_MULTI_TURN]
+        tokenized = tokenizer(conversations, add_special_tokens=False, padding=True, truncation=True, max_length=512)
+
+        batch_input = [
+            {"input_ids": tokenized.input_ids[i], "attention_mask": tokenized.attention_mask[i]}
+            for i in range(len(tokenized.input_ids))
+        ]
+        collated_batch = data_collator(batch_input)
+
+        # Expected outputs based on the number of assistant turns in the specific samples used
+        expected_outputs = [
+            EXPECTED_DECODED_ASSISTANT_CHUNK * 8,  # CHATML_SAMPLE_MULTI_ROLE_MULTI_TURN
+            EXPECTED_DECODED_ASSISTANT_CHUNK * 4,  # CHATML_SAMPLE_BASIC_MULTI_TURN
+        ]
+
+        self.assertEqual(len(collated_batch["labels"]), len(expected_outputs), "Batch size mismatch")
+
+        for i in range(len(collated_batch["labels"])):
+            valid_indices = collated_batch["labels"][i] != -100
+            valid_labels = collated_batch["labels"][i][valid_indices]
+            decoded_text = tokenizer.decode(valid_labels, skip_special_tokens=False)
+            self.assertEqual(
+                decoded_text.strip(),
+                expected_outputs[i].strip(),
+                f"Mismatch in decoded labels for left padding, sample {i}",
+            )
+
+    def test_masking_tokenized_templates(self):
+        tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
+        if tokenizer.pad_token_id is None:
+            tokenizer.pad_token = tokenizer.eos_token
+
+        # Pre-tokenize the templates
+        instruction_templates_str = ["<|im_start|>tool\n", "<|im_start|>user\n"]
+        response_template_str = "<|im_start|>assistant\n"
+
+        instruction_token_ids = [
+            tokenizer.encode(tmpl, add_special_tokens=False) for tmpl in instruction_templates_str
+        ]
+        response_token_ids = tokenizer.encode(response_template_str, add_special_tokens=False)
+
+        data_collator = DataCollatorForCompletionOnlyLM(
+            instruction_template=instruction_token_ids,  # Pass List[List[int]]
+            response_template=response_token_ids,  # Pass List[int]
+            tokenizer=tokenizer,
+            mlm=False,
+        )
+
+        conversations = [CHATML_SAMPLE_MULTI_ROLE_MULTI_TURN, CHATML_SAMPLE_CONSECUTIVE_ASSISTANT_MULTI_ROLE]
+        tokenized = tokenizer(conversations, add_special_tokens=False, padding=True, truncation=True, max_length=512)
+
+        batch_input = [
+            {"input_ids": tokenized.input_ids[i], "attention_mask": tokenized.attention_mask[i]}
+            for i in range(len(tokenized.input_ids))
+        ]
+        collated_batch = data_collator(batch_input)
+
+        # Expected outputs based on the number of assistant turns
+        expected_outputs = [
+            EXPECTED_DECODED_ASSISTANT_CHUNK * 8,
+            EXPECTED_DECODED_ASSISTANT_CHUNK * 8,
+        ]
+
+        self.assertEqual(len(collated_batch["labels"]), len(expected_outputs), "Batch size mismatch")
+
+        for i in range(len(collated_batch["labels"])):
+            valid_indices = collated_batch["labels"][i] != -100
+            valid_labels = collated_batch["labels"][i][valid_indices]
+            decoded_text = tokenizer.decode(valid_labels, skip_special_tokens=False)
+            self.assertEqual(
+                decoded_text.strip(),
+                expected_outputs[i].strip(),
+                f"Mismatch in decoded labels for tokenized templates, sample {i}",
+            )
diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index 7b20f7c290..26134ad314 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -77,19 +77,25 @@ class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
         response_template (`Union[str, list[int]]`): the template form that indicates the start of the response, typically something like
             '### Response:\n'. It can also be passed as tokenized ids, which can be useful when using a tokenizer that encodes the response
             differently if it does not have proper context.
-        instruction_template (`Union[str, list[int]]`): the template form that indicates the start of the human instruction, typically something like
-            '### Human:\n'. Useful for assistant-style conversation datasets. It can also be passed as tokenized ids.
-        mlm (`bool`, *optional*, defaults to `False`): Whether to use masked language modeling in the underlying
+        instruction_template (`Union[str, list[int], list[str]]`, *optional*, defaults to `None`):
+            The template form that indicates the start of the human instruction, typically something like
+            '### Human:\n'. Useful for assistant-style conversation datasets. It can also be passed as tokenized ids
+            or as a list of strings when multiple instruction templates need to be detected, e.g. `["<|im_start|>user\n", "<|im_start|>tool\n"]` for multi-turn conversations with several non-assistant roles.
+        mlm (`bool`, *optional*, defaults to `False`):
+            Whether to use masked language modeling in the underlying
             `DataCollatorForLanguageModeling` class. Note that this option currently has no effect but is present
             for flexibility and backwards-compatibility.
         ignore_index (`int`, *optional*, defaults to `-100`): The index to use to ignore the initial tokens with
+        padding_free (`bool`, *optional*, defaults to `False`):
+            Whether to use padding-free training. When set to True, padding tokens are removed and positional ids are
+            added to the inputs to enable proper attention.
     """

     def __init__(
         self,
         response_template: Union[str, list[int]],
-        instruction_template: Optional[Union[str, list[int]]] = None,
+        instruction_template: Optional[Union[str, list[int], list[str]]] = None,
         *args,
         mlm: bool = False,
         ignore_index: int = -100,
@@ -99,12 +105,27 @@ def __init__(
         super().__init__(*args, mlm=mlm, **kwargs)

         self.instruction_template = instruction_template
+        self.has_multiple_instruction_templates = False
+
         if isinstance(instruction_template, str):
             # The user provides a string, must tokenize
             self.instruction_token_ids = self.tokenizer.encode(self.instruction_template, add_special_tokens=False)
+        elif isinstance(instruction_template, list) and isinstance(instruction_template[0], str):
+            # The user provides a list of strings, must tokenize each template
+            self.instruction_token_ids = []
+            for template in self.instruction_template:
+                self.instruction_token_ids.append(self.tokenizer.encode(template, add_special_tokens=False))
+            self.has_multiple_instruction_templates = True
         else:
             # The user already provides the token ids
             self.instruction_token_ids = instruction_template
+            # Check if it's a list of lists (multiple templates)
+            if (
+                isinstance(instruction_template, list)
+                and instruction_template
+                and isinstance(instruction_template[0], list)
+            ):
+                self.has_multiple_instruction_templates = True

         self.response_template = response_template
         if isinstance(response_template, str):
@@ -129,6 +150,12 @@ def __init__(
     def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]:
         batch = super().torch_call(examples)
+        sequence_lengths = (batch["input_ids"] != self.tokenizer.pad_token_id).sum(dim=1)
+        content_starts = (
+            batch["input_ids"].shape[1] - sequence_lengths
+            if self.tokenizer.padding_side == "left"
+            else torch.zeros_like(sequence_lengths)
+        )

         if self.instruction_template is None:
             for i in range(len(examples)):
                 response_token_ids_start_idx = None
@@ -157,8 +184,8 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d
         else:
             for i in range(len(examples)):
-                response_token_ids_idxs = []
-                human_token_ids_idxs = []
+                response_start_positions = []
+                instruction_start_positions = []

                 for assistant_idx in np.where(batch["labels"][i] == self.response_token_ids[0])[0]:
                     # find the indexes of the start of a response.
@@ -166,9 +193,9 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d
                         self.response_token_ids
                         == batch["labels"][i][assistant_idx : assistant_idx + len(self.response_token_ids)].tolist()
                     ):
-                        response_token_ids_idxs.append(assistant_idx + len(self.response_token_ids))
+                        response_start_positions.append(assistant_idx + len(self.response_token_ids))

-                if len(response_token_ids_idxs) == 0:
+                if len(response_start_positions) == 0:
                     warnings.warn(
                         f"Could not find response key `{self.response_template}` in the following instance: "
                         f"{self.tokenizer.decode(batch['input_ids'][i])}. This instance will be ignored in loss "
@@ -176,14 +203,34 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d
                         UserWarning,
                     )
                     batch["labels"][i, :] = self.ignore_index
-
-                human_token_ids = self.instruction_token_ids
-                for human_idx in np.where(batch["labels"][i] == human_token_ids[0])[0]:
-                    # find the indexes of the start of a human answer.
-                    if human_token_ids == batch["labels"][i][human_idx : human_idx + len(human_token_ids)].tolist():
-                        human_token_ids_idxs.append(human_idx)
-
-                if len(human_token_ids_idxs) == 0:
+                    continue
+
+                # Find all instruction token positions
+                if self.has_multiple_instruction_templates:
+                    # Handle multiple instruction templates
+                    for instruction_token_ids in self.instruction_token_ids:
+                        for instruction_idx in np.where(batch["labels"][i] == instruction_token_ids[0])[0]:
+                            if (
+                                instruction_token_ids
+                                == batch["labels"][i][
+                                    instruction_idx : instruction_idx + len(instruction_token_ids)
+                                ].tolist()
+                            ):
+                                instruction_start_positions.append(instruction_idx)
+                    instruction_start_positions = sorted(instruction_start_positions)
+                else:
+                    instruction_token_ids = self.instruction_token_ids
+                    for instruction_idx in np.where(batch["labels"][i] == instruction_token_ids[0])[0]:
+                        # find the indexes of the start of an instruction.
+                        if (
+                            instruction_token_ids
+                            == batch["labels"][i][
+                                instruction_idx : instruction_idx + len(instruction_token_ids)
+                            ].tolist()
+                        ):
+                            instruction_start_positions.append(instruction_idx)
+
+                if len(instruction_start_positions) == 0:
                     warnings.warn(
                         f"Could not find instruction key `{self.instruction_template}` in the following instance: "
                         f"{self.tokenizer.decode(batch['input_ids'][i])}. This instance will be ignored in loss "
@@ -191,23 +238,40 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d
                         UserWarning,
                     )
                     batch["labels"][i, :] = self.ignore_index
-
-            if (
-                len(human_token_ids_idxs) > 0
-                and len(response_token_ids_idxs) > 0
-                and human_token_ids_idxs[0] > response_token_ids_idxs[0]
-            ):
-                human_token_ids_idxs = [0] + human_token_ids_idxs
-
-            for idx, (start, end) in enumerate(zip(human_token_ids_idxs, response_token_ids_idxs)):
-                # Make pytorch loss function ignore all non response tokens
-                if idx != 0:
-                    batch["labels"][i, start:end] = self.ignore_index
+                    continue
+
+                # Mask everything first, then unmask the response spans step by step
+                batch["labels"][i, :] = self.ignore_index
+
+                # Unmask the region between each response and the next instruction (or the end of the sequence)
+                sequence_length = sequence_lengths[i].item()
+                content_start = content_starts[i].item()
+                last_processed_instruction_pos = -1
+                for response_pos in response_start_positions:
+                    # Find the first instruction position that comes after this response
+                    next_instruction_pos = None
+                    for instruction_pos in instruction_start_positions:
+                        if instruction_pos > response_pos:
+                            next_instruction_pos = instruction_pos
+                            break
+
+                    # If no instruction position is found after the response, use the sequence length from input_ids
+                    if next_instruction_pos is None:
+                        # Calculate the actual sequence length using pad token positions
+                        next_instruction_pos = content_start + sequence_length
+
+                    # Handle consecutive responses
+                    if response_pos > last_processed_instruction_pos:
+                        # Unmask from response start to instruction start (or end); base case
+                        batch["labels"][i, response_pos:next_instruction_pos] = batch["input_ids"][
+                            i, response_pos:next_instruction_pos
+                        ]
+                        last_processed_instruction_pos = next_instruction_pos
                     else:
-                    batch["labels"][i, :end] = self.ignore_index
-
-            if len(response_token_ids_idxs) < len(human_token_ids_idxs):
-                batch["labels"][i, human_token_ids_idxs[-1] :] = self.ignore_index
+                        # Two responses in a row, so re-mask the response template tokens of the later response
+                        batch["labels"][i, response_pos - len(self.response_token_ids) : response_pos] = (
+                            self.ignore_index
+                        )

         if self.padding_free:
             # remove padding, `attention_mask` and add `position_ids`
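Not part of the diff: a minimal usage sketch of the new list-of-templates behaviour, assuming this branch is installed. The tokenizer checkpoint and the ChatML templates are the same ones the tests above use; the short conversation string and the printed result in the comment are illustrative only.

```python
# Sketch: mask everything except assistant turns when a chat has both user and tool turns.
from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token

collator = DataCollatorForCompletionOnlyLM(
    response_template="<|im_start|>assistant\n",
    # New in this PR: several instruction templates can be detected at once.
    instruction_template=["<|im_start|>user\n", "<|im_start|>tool\n"],
    tokenizer=tokenizer,
    mlm=False,
)

# Made-up two-role conversation in the same ChatML layout as the test fixtures.
chat = (
    "<|im_start|>user\nU U U<|im_end|>\n"
    "<|im_start|>assistant\nA A A<|im_end|>\n"
    "<|im_start|>tool\nT T T<|im_end|>\n"
    "<|im_start|>assistant\nA A A<|im_end|>"
)
tokenized = tokenizer([chat], add_special_tokens=False)
batch = collator(
    [{"input_ids": tokenized.input_ids[0], "attention_mask": tokenized.attention_mask[0]}]
)

# Only the assistant spans keep real label ids; everything else is set to -100.
labels = batch["labels"][0]
print(tokenizer.decode(labels[labels != -100]))  # roughly "A A A<|im_end|>\nA A A<|im_end|>"
```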