Adding EOS Tokens to Qwen Models #2512

Open · wants to merge 1 commit into base: main
150 changes: 94 additions & 56 deletions torchtune/models/qwen2/_tokenizer.py
@@ -325,26 +325,85 @@ def decode(
text = "".join(sub_texts)
return text

+    def _tokenize_header(self, message: Message) -> List[int]:
+        return (
+            [self.im_start_id]
+            + self.encode(f"{message.role}\n", add_bos=False, add_eos=False)
+        )
+
+    def _tokenize_body(self, message: Message) -> List[int]:
+        tokenized_body = []
+        for item in message.content:
+            if item["type"] == "text":
+                tokenized_body += self.encode(
+                    item["content"], add_bos=False, add_eos=False,
+                )
+            else:
+                raise RuntimeError(f"Unsupported message content type: {item['type']}")
+        return tokenized_body
+
+    def _tokenize_end(self, message: Message) -> List[int]:
+        return (
+            [self.im_end_id]
+            + self.encode("\n", add_bos=False, add_eos=False)
+        )


+    def tokenize_message(
+        self,
+        message: Message,
+        *,
+        add_start_tokens: bool = True,
+        add_end_tokens: bool = True,
+    ) -> List[int]:
+        """
+        Tokenize a message into a list of token ids.
+
+        Args:
+            message (Message): The message to tokenize.
+            add_start_tokens (bool): Whether to prepend the tokenized header to the message. Default is True.
+            add_end_tokens (bool): Whether to append the end-of-turn tokens to the message. Default is True.
+
+        Returns:
+            List[int]: The list of token ids.
+        """
+        tokenized_header = self._tokenize_header(message) if add_start_tokens else []
+        tokenized_body = self._tokenize_body(message)
+        tokenized_end = self._tokenize_end(message) if add_end_tokens else []
+
+        tokenized_message = tokenized_header + tokenized_body + tokenized_end
+        return tokenized_message

def tokenize_messages(
self,
messages: List[Message],
*,
-        add_eos: bool = True,
+        add_end_tokens: bool = True,
) -> Tuple[List[int], List[bool]]:
"""
-        Given a list of messages, return a list of tokens for the concatenated
-        and formatted messages.
+        Tokenize a list of messages into a list of token ids and masks.

Args:
-            messages (List[Message]): The message list to tokenize.
-            add_eos (bool): Wether to add the tokenizer's eos_id at the end of the
-                sequence of messages. Default is True.
+            messages (List[Message]): The list of messages to tokenize.
+            add_end_tokens (bool): Whether to append end token ids (end-of-seq, end-of-turn,
+                end-of-message) at the end of the last assistant message. This value should be
+                set to False for generation. Default is True.

+        Examples:
+            >>> # Tokenize a list of messages with default settings
+            >>> messages = [
+            ...     Message(role="user", content="Hello world!", masked=True),
+            ...     Message(role="assistant", content="How are you?", masked=False),
+            ... ]
+            >>> tokenizer = Qwen2Tokenizer("/path/to/tt_model")
+            >>> tokenizer.tokenize_messages(messages)
+            ([1, 31587, 29644, 102, 1, 31587, 29644, 102, 2], [True, True, True, True, True, False, False, False, True])
+
+            >>> # Tokenize a list of messages with add_end_tokens set to False
+            >>> tokenizer.tokenize_messages(messages, add_end_tokens=False)
+            ([1, 31587, 29644, 102, 1, 31587, 29644], [True, True, True, True, True, False, False])
+
+        Returns:
+            Tuple[List[int], List[bool]]: The list of token ids and the list of masks.
+
+        Raises:
+            RuntimeError: If a message contains non-text content.
"""
assert not isinstance(self.prompt_template, ChatMLTemplate), (
"Using ChatMLTemplate with tokenize_messages will result in multiple <|im_*|> tokens wrapping each message."
@@ -355,69 +414,48 @@ def tokenize_messages(
if self.prompt_template is not None
else messages
)
-        tokenized_messages = []
-        mask = []
-        for index, message in enumerate(templated_messages):
-            tokens = []
-
-            # message header
-            if message.role != "ipython":
-                tokens.append(self.im_start_id)
-                tokens.extend(
-                    self.encode(f"{message.role}\n", add_bos=False, add_eos=False)
-                )
-
-            # message content
-            for item in message.content:
-                if item["type"] == "text":
-                    tokens.extend(
-                        self.encode(
-                            item["content"],
-                            add_bos=False,
-                            add_eos=False,
-                        )
-                    )
-                else:
-                    raise RuntimeError(
-                        f"Unsupported message content type: {item['type']}"
-                    )
-
-            # message footer
-            if message.role != "ipython" and (
-                message.role != "assistant" or index != len(messages) - 1
-            ):
-                tokens.append(self.im_end_id)
-                tokens.extend(self.encode("\n", add_bos=False, add_eos=False))
-
-            tokenized_messages.extend(tokens)
-            mask.extend([message.masked] * len(tokens))
+        tokens = [self.bos_id]
+        # bos and eos are always masked
+        mask = [True]
+
+        num_messages = len(templated_messages)
+        for i, message in enumerate(templated_messages):
+            # Add end tokens to the last assistant message if add_end_tokens is True
+            # Otherwise, end tokens should always be added
+            add_end_tokens_to_message = (
+                add_end_tokens if i == num_messages - 1 else True
+            )
+            tokenized_message = self.tokenize_message(
+                message, add_end_tokens=add_end_tokens_to_message
+            )
+
+            tokens = tokens + tokenized_message
+            mask = mask + ([message.masked] * len(tokenized_message))

# Break out early if we reach max_seq_len
-            if self.max_seq_len and len(tokenized_messages) >= self.max_seq_len:
+            if self.max_seq_len and len(tokens) >= self.max_seq_len:
break

# Add the End-Of-Sequence token
-        if add_eos:
-            tokenized_messages.append(self.eos_id)
-            mask.append(mask[-1])
+        if add_end_tokens:
+            tokens = tokens + [self.eos_id]
+            mask = mask + [True]

# Finally, truncate if necessary
if self.max_seq_len:
-            tokenized_messages = truncate(
-                tokens=tokenized_messages,
+            tokens = truncate(
+                tokens=tokens,
max_seq_len=self.max_seq_len,
-                eos_id=self.eos_id if add_eos else None,
+                eos_id=self.eos_id if add_end_tokens else None,
truncation_type=self.truncation_type,
)
mask = truncate(
tokens=mask,
max_seq_len=self.max_seq_len,
-                eos_id=True if add_eos else None,
+                eos_id=True if add_end_tokens else None,
truncation_type=self.truncation_type,
)

-        return tokenized_messages, mask
+        return tokens, mask

def __call__(
self, sample: Mapping[str, Any], inference: bool = False
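The behavioral crux of the hunk above: end tokens are always added to every non-final message, while the final message's end tokens and the trailing eos both follow add_end_tokens, so generation can continue from an open assistant turn. A usage sketch under the same assumption of a constructed `tokenizer` (illustrative, not part of the diff):

# Illustrative only: contrasts training- and generation-time tokenization.
from torchtune.data import Message

messages = [
    Message(role="user", content="Hello world!", masked=True),
    Message(role="assistant", content="How are you?", masked=False),
]

# Training: every turn is closed and the eos id is appended (masked True).
train_tokens, train_mask = tokenizer.tokenize_messages(messages, add_end_tokens=True)

# Generation: the last assistant turn stays open (no im_end, no newline, no eos),
# so decoding can keep extending that turn.
gen_tokens, gen_mask = tokenizer.tokenize_messages(messages, add_end_tokens=False)

assert len(gen_tokens) < len(train_tokens)
assert len(train_tokens) == len(train_mask) and len(gen_tokens) == len(gen_mask)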
@@ -436,7 +474,7 @@ def __call__(
inference (bool): Whether the template is being used for inference or not.
"""
messages = sample.pop("messages")
-        tokens, mask = self.tokenize_messages(messages)
+        tokens, mask = self.tokenize_messages(messages, add_end_tokens=not inference)
sample["tokens"] = tokens
sample["mask"] = mask
return sample
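With the __call__ change above, one tokenizer config serves both dataset building and prompting, since inference=True now propagates as add_end_tokens=False. A final sketch, again assuming a constructed `tokenizer`; the sample dict shape matches the diff, everything else is illustrative:

# Illustrative only: end-to-end flow through __call__.
from torchtune.data import Message

sample = {
    "messages": [
        Message(role="user", content="Hello world!", masked=True),
        Message(role="assistant", content="How are you?", masked=False),
    ]
}

# Pass shallow copies: __call__ pops "messages" out of its argument.
train_out = tokenizer({**sample}, inference=False)  # end tokens and eos appended
infer_out = tokenizer({**sample}, inference=True)   # assistant turn left open

assert set(train_out) >= {"tokens", "mask"}
assert len(infer_out["tokens"]) < len(train_out["tokens"])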