@@ -62,6 +62,30 @@ def gen_multi_turn_loss_mask_qwen(self, messages: List[Dict]) -> Tuple[List[int]
6262
6363 return all_token_ids , all_loss_masks
6464
def gen_multi_turn_loss_mask_qwen3(self, messages: List[Dict]) -> Tuple[List[int], List[int]]:
    """Tokenize a conversation message by message and build its loss mask.

    Each message is rendered through the tokenizer's chat template directly
    after a fixed dummy user message. The dummy's tokens are then cut off the
    front so that only the tokens contributed by the message itself remain.

    Args:
        messages: Chat messages. Each dict is read for its ``"role"`` key and
            passed unchanged to ``self.tokenizer.apply_chat_template``.

    Returns:
        ``(token_ids, loss_mask)``: two lists of equal length. A mask entry is
        1 only for tokens of an assistant message that come after its first
        ``self.gen_token_length`` tokens; every other entry is 0.
    """
    all_token_ids: List[int] = []
    all_loss_masks: List[int] = []

    # Dummy turn placed before every message; its token count tells us how many
    # leading tokens to discard from each two-message rendering.
    prefix_message = {"role": "user", "content": "FOR CALCULATING LOSS MASK ONLY"}
    prefix_length = len(self.tokenizer.apply_chat_template([prefix_message], tokenize=True))

    for i, message in enumerate(messages):
        prefixed_message_ids = self.tokenizer.apply_chat_template([prefix_message, message], tokenize=True)
        message_ids = prefixed_message_ids[prefix_length:]

        # Non-system messages after the first additionally drop
        # `self.system_message_length` leading tokens.
        # NOTE(review): presumably this removes a system prompt the template
        # inserts on its own; confirm against where the attribute is computed.
        if message["role"] != "system" and i > 0:
            message_ids = message_ids[self.system_message_length:]

        if message["role"] == "assistant":
            # Clamp so the mask never grows longer than the message: without
            # this, a message shorter than `gen_token_length` would emit extra
            # zeros and misalign every later mask entry against its token.
            masked_length = min(self.gen_token_length, len(message_ids))
        else:
            masked_length = len(message_ids)

        all_token_ids.extend(message_ids)
        all_loss_masks.extend([0] * masked_length + [1] * (len(message_ids) - masked_length))

    return all_token_ids, all_loss_masks
88+
6589 def gen_multi_turn_loss_mask_distill_qwen (self , messages : List [Dict ]) -> Tuple [List [int ], List [int ]]:
6690 prompt = self .tokenizer .apply_chat_template (messages [:1 ], tokenize = False , add_generation_prompt = True )
6791 response = messages [- 1 ]["content" ]
@@ -79,6 +103,8 @@ def get_loss_mask(self, messages: List[Dict]) -> List[int]:
79103 return self .gen_multi_turn_loss_mask_distill_qwen (messages )
80104
81105 return self .gen_multi_turn_loss_mask_qwen (messages )
106+ elif self .tokenizer_type == "qwen3" :
107+ return self .gen_multi_turn_loss_mask_qwen3 (messages )
82108 elif self .tokenizer_type == "distill_qwen" :
83109 return self .gen_multi_turn_loss_mask_distill_qwen (messages )
84110 else :
0 commit comments