Skip to content

Commit 242b073

Browse files
committed
fix transformers API change in 5.2.0
1 parent a2b16da commit 242b073

File tree

1 file changed

+32
-5
lines changed

1 file changed

+32
-5
lines changed

slime/utils/mask_utils.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,15 @@ def get_system_message_length(self) -> tuple[int, int]:
3636
chat_template_token_ids = self.tokenizer(chat_template_token, add_special_tokens=False)["input_ids"]
3737
idx_1, idx_2 = self.find_all_sublist_indices(chat_template_token_ids, raw_token_ids)
3838
end_interval = len(chat_template_token_ids) - len(raw_token_ids) - idx_2
39-
gen_token_length = len(
40-
self.tokenizer.apply_chat_template(
41-
test_messages, add_special_tokens=False, tokenize=True, add_generation_prompt=True
42-
)
43-
) - len(chat_template_token_ids)
39+
40+
gen_prompt_token_ids = self.tokenizer.apply_chat_template(
41+
test_messages, add_special_tokens=False, tokenize=True, add_generation_prompt=True
42+
)
43+
# Handle transformers 5.2.0+ API change: apply_chat_template now returns dict when tokenize=True
44+
if not isinstance(gen_prompt_token_ids, list):
45+
gen_prompt_token_ids = gen_prompt_token_ids["input_ids"]
46+
47+
gen_token_length = len(gen_prompt_token_ids) - len(chat_template_token_ids)
4448

4549
system_message_length = idx_1 - ((idx_2 - idx_1) - end_interval - len(raw_token_ids))
4650
return system_message_length, gen_token_length
@@ -57,6 +61,10 @@ def gen_multi_turn_loss_mask_qwen(
5761
else:
5862
message_ids = self.tokenizer.apply_chat_template([message], tokenize=True)
5963

64+
# Handle transformers 5.2.0+ API change: apply_chat_template now returns dict when tokenize=True
65+
if not isinstance(message_ids, list):
66+
message_ids = message_ids["input_ids"]
67+
6068
if message["role"] != "system" and i > 0:
6169
message_ids = message_ids[self.system_message_length :]
6270

@@ -81,15 +89,25 @@ def gen_multi_turn_loss_mask_qwen3(
8189

8290
prefix_message = {"role": "user", "content": "FOR CALCULATING LOSS MASK ONLY"}
8391
prefix_token_ids = self.tokenizer.apply_chat_template([prefix_message], tokenize=True)
92+
93+
# Handle transformers 5.2.0+ API change: apply_chat_template now returns dict when tokenize=True
94+
if not isinstance(prefix_token_ids, list):
95+
prefix_token_ids = prefix_token_ids["input_ids"]
8496

8597
for i, message in enumerate(messages):
8698
if i == 0:
8799
tailed_message_ids = self.tokenizer.apply_chat_template(
88100
[message, prefix_message], tokenize=True, tools=tools
89101
)
102+
# Handle transformers 5.2.0+ API change
103+
if not isinstance(tailed_message_ids, list):
104+
tailed_message_ids = tailed_message_ids["input_ids"]
90105
message_ids = tailed_message_ids[: -len(prefix_token_ids)]
91106
else:
92107
prefixed_message_ids = self.tokenizer.apply_chat_template([prefix_message, message], tokenize=True)
108+
# Handle transformers 5.2.0+ API change
109+
if not isinstance(prefixed_message_ids, list):
110+
prefixed_message_ids = prefixed_message_ids["input_ids"]
93111
message_ids = prefixed_message_ids[len(prefix_token_ids) :]
94112

95113
if message["role"] != "system" and i > 0:
@@ -181,3 +199,12 @@ def get_text_from_loss_mask(self, token_ids: list[int], loss_masks: list[int]) -
181199
selected_texts.append(self.tokenizer.decode(current_tokens))
182200

183201
return selected_texts
202+
203+
if __name__ == "__main__":
204+
tokenizer = AutoTokenizer.from_pretrained("/workspace/Qwen3.5-35B-A3B")
205+
mask_utils = MultiTurnLossMaskGenerator(tokenizer, tokenizer_type="qwen")
206+
messages = [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}]
207+
tools = [{"type": "function", "function": {"name": "get_weather", "description": "Get the weather", "parameters": {"type": "object", "properties": {"city": {"type": "string", "description": "The city to get the weather for"}}}}}]
208+
token_ids, loss_mask = mask_utils.get_loss_mask(messages, tools=tools)
209+
for i in range(len(token_ids)):
210+
print(f'Token: {repr(tokenizer.decode(token_ids[i]))}, Loss Mask: {loss_mask[i]}')

0 commit comments

Comments
 (0)