diff --git a/torchtune/models/qwen2/_tokenizer.py b/torchtune/models/qwen2/_tokenizer.py
index 2b60a381a1..0b7146f49f 100644
--- a/torchtune/models/qwen2/_tokenizer.py
+++ b/torchtune/models/qwen2/_tokenizer.py
@@ -325,26 +325,85 @@ def decode(
         text = "".join(sub_texts)
         return text
 
+    def _tokenize_header(self, message: Message) -> List[int]:
+        return (
+            [self.im_start_id]
+            + self.encode(f"{message.role}\n", add_bos=False, add_eos=False)
+        )
+
+    def _tokenize_body(self, message: Message) -> List[int]:
+        tokenized_body = []
+        for item in message.content:
+            if item["type"] == "text":
+                tokenized_body += self.encode(
+                    item["content"], add_bos=False, add_eos=False,
+                )
+            else:
+                raise RuntimeError(f"Unsupported message content type: {item['type']}")
+        return tokenized_body
+
+    def _tokenize_end(self, message: Message) -> List[int]:
+        return (
+            [self.im_end_id]
+            + self.encode("\n", add_bos=False, add_eos=False)
+        )
+
+
+    def tokenize_message(
+        self,
+        message: Message,
+        *,
+        add_start_tokens: bool = True,
+        add_end_tokens: bool = True,
+    ) -> List[int]:
+        """
+        Tokenize a single message into a list of token ids.
+
+        Args:
+            message (Message): The message to tokenize.
+            add_start_tokens (bool): Whether to prepend the tokenized header (im_start and role) to the message. Default is True.
+            add_end_tokens (bool): Whether to append the end tokens (im_end and newline) to the message. Default is True.
+
+        Returns:
+            List[int]: The list of token ids.
+        """
+        tokenized_header = self._tokenize_header(message) if add_start_tokens else []
+        tokenized_body = self._tokenize_body(message)
+        tokenized_end = self._tokenize_end(message) if add_end_tokens else []
+
+        tokenized_message = tokenized_header + tokenized_body + tokenized_end
+        return tokenized_message
+
     def tokenize_messages(
         self,
         messages: List[Message],
         *,
-        add_eos: bool = True,
+        add_end_tokens: bool = True,
     ) -> Tuple[List[int], List[bool]]:
         """
-        Given a list of messages, return a list of tokens for the concatenated
-        and formatted messages.
+        Tokenize a list of messages into a list of token ids and masks.
 
         Args:
-            messages (List[Message]): The message list to tokenize.
-            add_eos (bool): Wether to add the tokenizer's eos_id at the end of the
-                sequence of messages. Default is True.
+            messages (List[Message]): The list of messages to tokenize.
+            add_end_tokens (bool): Whether to append end token ids (end-of-turn and end-of-sequence) after
+                the last message. This should be set to False for generation. Default is True.
+
+        Examples:
+            >>> # Tokenize a list of messages with default settings
+            >>> messages = [
+            ...     Message(role="user", content="Hello world!", masked=True),
+            ...     Message(role="assistant", content="How are you?", masked=False),
+            ... ]
+            >>> tokenizer = Qwen2Tokenizer("/path/to/tt_model")
+            >>> tokenizer.tokenize_messages(messages)
+            ([1, 31587, 29644, 102, 1, 31587, 29644, 102, 2], [True, True, True, True, True, False, False, False, True])
+
+            >>> # Tokenize a list of messages with add_end_tokens set to False
+            >>> tokenizer.tokenize_messages(messages, add_end_tokens=False)
+            ([1, 31587, 29644, 102, 1, 31587, 29644], [True, True, True, True, True, False, False])
 
         Returns:
             Tuple[List[int], List[bool]]: The list of token ids and the list of masks.
-
-        Raises:
-            RuntimeError: If a message contains non-text content
         """
         assert not isinstance(self.prompt_template, ChatMLTemplate), (
             "Using ChatMLTemplate with tokenize_messages will result in multiple <|im_*|> tokens wrapping each message."
@@ -355,69 +414,48 @@ def tokenize_messages(
             if self.prompt_template is not None
             else messages
         )
+        tokens = [self.bos_id]
+        # bos and eos are always masked
+        mask = [True]
+
+        num_messages = len(templated_messages)
+        for i, message in enumerate(templated_messages):
+            # Only the final message's end tokens are controlled by add_end_tokens;
+            # every earlier message always gets its end tokens
+            add_end_tokens_to_message = (
+                add_end_tokens if i == num_messages - 1 else True
+            )
+            tokenized_message = self.tokenize_message(
+                message, add_end_tokens=add_end_tokens_to_message
+            )
 
-        tokenized_messages = []
-        mask = []
-        for index, message in enumerate(templated_messages):
-            tokens = []
-
-            # message header
-            if message.role != "ipython":
-                tokens.append(self.im_start_id)
-                tokens.extend(
-                    self.encode(f"{message.role}\n", add_bos=False, add_eos=False)
-                )
-
-            # message content
-            for item in message.content:
-                if item["type"] == "text":
-                    tokens.extend(
-                        self.encode(
-                            item["content"],
-                            add_bos=False,
-                            add_eos=False,
-                        )
-                    )
-                else:
-                    raise RuntimeError(
-                        f"Unsupported message content type: {item['type']}"
-                    )
-
-            # message footer
-            if message.role != "ipython" and (
-                message.role != "assistant" or index != len(messages) - 1
-            ):
-                tokens.append(self.im_end_id)
-                tokens.extend(self.encode("\n", add_bos=False, add_eos=False))
-
-            tokenized_messages.extend(tokens)
-            mask.extend([message.masked] * len(tokens))
+            tokens = tokens + tokenized_message
+            mask = mask + ([message.masked] * len(tokenized_message))
 
             # Break out early if we reach max_seq_len
-            if self.max_seq_len and len(tokenized_messages) >= self.max_seq_len:
+            if self.max_seq_len and len(tokens) >= self.max_seq_len:
                 break
 
-        # Add the End-Of-Sequence token
-        if add_eos:
-            tokenized_messages.append(self.eos_id)
-            mask.append(mask[-1])
+        if add_end_tokens:
+            tokens = tokens + [self.eos_id]
+            mask = mask + [True]
 
         # Finally, truncate if necessary
         if self.max_seq_len:
-            tokenized_messages = truncate(
-                tokens=tokenized_messages,
+            tokens = truncate(
+                tokens=tokens,
                 max_seq_len=self.max_seq_len,
-                eos_id=self.eos_id if add_eos else None,
+                eos_id=self.eos_id if add_end_tokens else None,
                 truncation_type=self.truncation_type,
             )
             mask = truncate(
                 tokens=mask,
                 max_seq_len=self.max_seq_len,
-                eos_id=True if add_eos else None,
+                eos_id=True if add_end_tokens else None,
                 truncation_type=self.truncation_type,
             )
 
-        return tokenized_messages, mask
+        return tokens, mask
 
     def __call__(
         self, sample: Mapping[str, Any], inference: bool = False
@@ -436,7 +474,7 @@ def __call__(
             inference (bool): Whether the template is being used for inference or not.
         """
         messages = sample.pop("messages")
-        tokens, mask = self.tokenize_messages(messages)
+        tokens, mask = self.tokenize_messages(messages, add_end_tokens=not inference)
         sample["tokens"] = tokens
         sample["mask"] = mask
         return sample
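
A minimal usage sketch of the API this patch introduces, for orientation only (not part of the diff): the checkpoint paths are placeholders, and it assumes torchtune's existing `qwen2_tokenizer` builder with its `path`/`merges_file`/`prompt_template` arguments.

```python
# Usage sketch for the updated Qwen2 tokenizer API (paths are placeholders).
from torchtune.data import Message
from torchtune.models.qwen2 import qwen2_tokenizer

tokenizer = qwen2_tokenizer(
    path="/path/to/vocab.json",
    merges_file="/path/to/merges.txt",
    # tokenize_messages already wraps each turn in <|im_start|>/<|im_end|>,
    # so no ChatMLTemplate here (the assert in the patch guards against doubling them).
    prompt_template=None,
)

messages = [
    Message(role="user", content="Hello world!", masked=True),
    Message(role="assistant", content="How are you?", masked=False),
]

# Training: the final turn is closed with im_end + "\n" and eos_id is
# appended; the mask marks bos/eos and user tokens as masked (True).
tokens, mask = tokenizer.tokenize_messages(messages)

# Generation: the final message is left open (no im_end, no eos) so the
# model can continue it.
prompt_tokens, _ = tokenizer.tokenize_messages(messages, add_end_tokens=False)

# The new tokenize_message handles a single turn, e.g. without its header:
assistant_ids = tokenizer.tokenize_message(messages[1], add_start_tokens=False)

# __call__ now forwards the inference flag as add_end_tokens=not inference:
sample = tokenizer({"messages": messages}, inference=True)
print(len(sample["tokens"]), sample["mask"][:4])
```

Note that `add_end_tokens=False` only affects the last message; every earlier turn always keeps its `im_end`, so a generation prompt still closes all completed turns.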