Skip to content

Commit d116da6

Browse files
authored
fix incorrect sft loss mask for qwen3 thinking series models. (THUDM#330)
* fix incorrect sft loss mask for qwen3 thinking series models. * Merge Qwen3MultiTurnLossMaskGenerator into MultiTurnLossMaskGenerator
1 parent e35f933 commit d116da6

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

slime/utils/arguments.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -831,7 +831,7 @@ def add_rollout_buffer_arguments(parser):
831831
"--loss-mask-type",
832832
type=str,
833833
default="qwen",
834-
choices=["qwen", "distill_qwen"],
834+
choices=["qwen", "qwen3", "distill_qwen"],
835835
help="Loss mask type",
836836
)
837837
return parser

slime/utils/mask_utils.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,30 @@ def gen_multi_turn_loss_mask_qwen(self, messages: List[Dict]) -> Tuple[List[int]
6262

6363
return all_token_ids, all_loss_masks
6464

65+
def gen_multi_turn_loss_mask_qwen3(self, messages: List[Dict]) -> Tuple[List[int], List[int]]:
    """Build the token ids and per-token loss mask for qwen3 thinking-series templates.

    A throwaway user turn is prepended before templating each message so the
    template-level prefix can be measured once and sliced off, leaving only
    the tokens that belong to the message itself.

    Args:
        messages: Chat turns as dicts with at least "role" and "content".

    Returns:
        A (token_ids, loss_mask) pair of equal-length lists; mask entries are
        1 only for supervised assistant tokens.
    """
    token_ids: List[int] = []
    loss_masks: List[int] = []

    # Sentinel turn used purely to measure the chat-template prefix length.
    sentinel = {"role": "user", "content": "FOR CALCULATING LOSS MASK ONLY"}
    sentinel_len = len(self.tokenizer.apply_chat_template([sentinel], tokenize=True))

    for idx, message in enumerate(messages):
        with_sentinel = self.tokenizer.apply_chat_template([sentinel, message], tokenize=True)
        ids = with_sentinel[sentinel_len:]

        # Non-leading, non-system turns still carry the implicit header the
        # template emits; drop those leading tokens.
        if idx > 0 and message["role"] != "system":
            ids = ids[self.system_message_length:]

        if message["role"] == "assistant":
            # Generation-prompt tokens are not supervised; the reply body is.
            mask = [0] * self.gen_token_length + [1] * (len(ids) - self.gen_token_length)
        else:
            mask = [0] * len(ids)

        token_ids.extend(ids)
        loss_masks.extend(mask)

    return token_ids, loss_masks
88+
6589
def gen_multi_turn_loss_mask_distill_qwen(self, messages: List[Dict]) -> Tuple[List[int], List[int]]:
6690
prompt = self.tokenizer.apply_chat_template(messages[:1], tokenize=False, add_generation_prompt=True)
6791
response = messages[-1]["content"]
@@ -79,6 +103,8 @@ def get_loss_mask(self, messages: List[Dict]) -> List[int]:
79103
return self.gen_multi_turn_loss_mask_distill_qwen(messages)
80104

81105
return self.gen_multi_turn_loss_mask_qwen(messages)
106+
elif self.tokenizer_type == "qwen3":
107+
return self.gen_multi_turn_loss_mask_qwen3(messages)
82108
elif self.tokenizer_type == "distill_qwen":
83109
return self.gen_multi_turn_loss_mask_distill_qwen(messages)
84110
else:

0 commit comments

Comments
 (0)