@@ -36,11 +36,15 @@ def get_system_message_length(self) -> tuple[int, int]:
3636 chat_template_token_ids = self .tokenizer (chat_template_token , add_special_tokens = False )["input_ids" ]
3737 idx_1 , idx_2 = self .find_all_sublist_indices (chat_template_token_ids , raw_token_ids )
3838 end_interval = len (chat_template_token_ids ) - len (raw_token_ids ) - idx_2
39- gen_token_length = len (
40- self .tokenizer .apply_chat_template (
41- test_messages , add_special_tokens = False , tokenize = True , add_generation_prompt = True
42- )
43- ) - len (chat_template_token_ids )
39+
40+ gen_prompt_token_ids = self .tokenizer .apply_chat_template (
41+ test_messages , add_special_tokens = False , tokenize = True , add_generation_prompt = True
42+ )
43+ # Handle transformers 5.2.0+ API change: apply_chat_template now returns dict when tokenize=True
44+ if not isinstance (gen_prompt_token_ids , list ):
45+ gen_prompt_token_ids = gen_prompt_token_ids ["input_ids" ]
46+
47+ gen_token_length = len (gen_prompt_token_ids ) - len (chat_template_token_ids )
4448
4549 system_message_length = idx_1 - ((idx_2 - idx_1 ) - end_interval - len (raw_token_ids ))
4650 return system_message_length , gen_token_length
@@ -57,6 +61,10 @@ def gen_multi_turn_loss_mask_qwen(
5761 else :
5862 message_ids = self .tokenizer .apply_chat_template ([message ], tokenize = True )
5963
64+ # Handle transformers 5.2.0+ API change: apply_chat_template now returns dict when tokenize=True
65+ if not isinstance (message_ids , list ):
66+ message_ids = message_ids ["input_ids" ]
67+
6068 if message ["role" ] != "system" and i > 0 :
6169 message_ids = message_ids [self .system_message_length :]
6270
@@ -81,15 +89,25 @@ def gen_multi_turn_loss_mask_qwen3(
8189
8290 prefix_message = {"role" : "user" , "content" : "FOR CALCULATING LOSS MASK ONLY" }
8391 prefix_token_ids = self .tokenizer .apply_chat_template ([prefix_message ], tokenize = True )
92+
93+ # Handle transformers 5.2.0+ API change: apply_chat_template now returns dict when tokenize=True
94+ if not isinstance (prefix_token_ids , list ):
95+ prefix_token_ids = prefix_token_ids ["input_ids" ]
8496
8597 for i , message in enumerate (messages ):
8698 if i == 0 :
8799 tailed_message_ids = self .tokenizer .apply_chat_template (
88100 [message , prefix_message ], tokenize = True , tools = tools
89101 )
102+ # Handle transformers 5.2.0+ API change
103+ if not isinstance (tailed_message_ids , list ):
104+ tailed_message_ids = tailed_message_ids ["input_ids" ]
90105 message_ids = tailed_message_ids [: - len (prefix_token_ids )]
91106 else :
92107 prefixed_message_ids = self .tokenizer .apply_chat_template ([prefix_message , message ], tokenize = True )
108+ # Handle transformers 5.2.0+ API change
109+ if not isinstance (prefixed_message_ids , list ):
110+ prefixed_message_ids = prefixed_message_ids ["input_ids" ]
93111 message_ids = prefixed_message_ids [len (prefix_token_ids ) :]
94112
95113 if message ["role" ] != "system" and i > 0 :
@@ -181,3 +199,12 @@ def get_text_from_loss_mask(self, token_ids: list[int], loss_masks: list[int]) -
181199 selected_texts .append (self .tokenizer .decode (current_tokens ))
182200
183201 return selected_texts
202+
if __name__ == "__main__":
    import sys

    # Smoke test: compute a multi-turn loss mask for a tiny two-turn chat
    # (with one tool declared) and print every token alongside its mask bit.
    # The model path defaults to the original hard-coded local checkout but
    # can be overridden on the command line so the script runs elsewhere.
    model_path = sys.argv[1] if len(sys.argv) > 1 else "/workspace/Qwen3.5-35B-A3B"

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    mask_utils = MultiTurnLossMaskGenerator(tokenizer, tokenizer_type="qwen")

    messages = [
        {"role": "user", "content": "hi"},
        {"role": "assistant", "content": "hello"},
    ]
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the weather",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "city": {
                            "type": "string",
                            "description": "The city to get the weather for",
                        }
                    },
                },
            },
        }
    ]

    token_ids, loss_mask = mask_utils.get_loss_mask(messages, tools=tools)
    # Iterate the (token, mask) pairs directly rather than indexing by
    # range(len(...)); get_loss_mask returns parallel sequences.
    for token_id, mask in zip(token_ids, loss_mask):
        print(f"Token: {repr(tokenizer.decode(token_id))}, Loss Mask: {mask}")
0 commit comments