diff --git a/slime/rollout/sft_rollout.py b/slime/rollout/sft_rollout.py
index 6b914a964..67d33e383 100644
--- a/slime/rollout/sft_rollout.py
+++ b/slime/rollout/sft_rollout.py
@@ -48,6 +48,12 @@ def generate_rollout(args, rollout_id, data_buffer, evaluation=False):
 
         token_ids, loss_mask = MASK_GENERATOR.get_loss_mask(messages, tools=tools)
 
+        # Optional hard cap on sequence length; disabled when the arg is absent/None.
+        max_len = getattr(args, "rollout_max_context_len", None)
+        if max_len is not None and len(token_ids) > max_len:
+            # FIX: interpolate the actual limit instead of the hard-coded "11,792".
+            logger.warning(f"sft_rollout: truncating sequence from {len(token_ids)} to {max_len} tokens")
+            token_ids = token_ids[:max_len]
+            loss_mask = loss_mask[:max_len]
+
         response_length = MASK_GENERATOR.get_response_lengths([loss_mask])[0]
 
         sample.tokens = token_ids
diff --git a/slime/utils/mask_utils.py b/slime/utils/mask_utils.py
index 0ddb3a141..75a025c37 100644
--- a/slime/utils/mask_utils.py
+++ b/slime/utils/mask_utils.py
@@ -36,11 +36,15 @@ def get_system_message_length(self) -> tuple[int, int]:
         chat_template_token_ids = self.tokenizer(chat_template_token, add_special_tokens=False)["input_ids"]
         idx_1, idx_2 = self.find_all_sublist_indices(chat_template_token_ids, raw_token_ids)
         end_interval = len(chat_template_token_ids) - len(raw_token_ids) - idx_2
-        gen_token_length = len(
-            self.tokenizer.apply_chat_template(
-                test_messages, add_special_tokens=False, tokenize=True, add_generation_prompt=True
-            )
-        ) - len(chat_template_token_ids)
+
+        gen_prompt_token_ids = self.tokenizer.apply_chat_template(
+            test_messages, add_special_tokens=False, tokenize=True, add_generation_prompt=True
+        )
+        # Handle transformers 5.2.0+ API change: apply_chat_template now returns dict when tokenize=True
+        if not isinstance(gen_prompt_token_ids, list):
+            gen_prompt_token_ids = gen_prompt_token_ids["input_ids"]
+
+        gen_token_length = len(gen_prompt_token_ids) - len(chat_template_token_ids)
         system_message_length = idx_1 - ((idx_2 - idx_1) - end_interval - len(raw_token_ids))
         return system_message_length, gen_token_length
 
@@ -57,6 +61,10 @@ def gen_multi_turn_loss_mask_qwen(
             else:
                 message_ids = self.tokenizer.apply_chat_template([message], tokenize=True)
+                # Handle transformers 5.2.0+ API change: apply_chat_template now returns dict when tokenize=True
+                if not isinstance(message_ids, list):
+                    message_ids = message_ids["input_ids"]
+
             if message["role"] != "system" and i > 0:
                 message_ids = message_ids[self.system_message_length :]
 
 
@@ -82,14 +90,24 @@ def gen_multi_turn_loss_mask_qwen3(
 
         prefix_message = {"role": "user", "content": "FOR CALCULATING LOSS MASK ONLY"}
         prefix_token_ids = self.tokenizer.apply_chat_template([prefix_message], tokenize=True)
+        # Handle transformers 5.2.0+ API change: apply_chat_template now returns dict when tokenize=True
+        if not isinstance(prefix_token_ids, list):
+            prefix_token_ids = prefix_token_ids["input_ids"]
+
         for i, message in enumerate(messages):
             if i == 0:
                 tailed_message_ids = self.tokenizer.apply_chat_template(
                     [message, prefix_message], tokenize=True, tools=tools
                 )
+                # Handle transformers 5.2.0+ API change
+                if not isinstance(tailed_message_ids, list):
+                    tailed_message_ids = tailed_message_ids["input_ids"]
                 message_ids = tailed_message_ids[: -len(prefix_token_ids)]
             else:
                 prefixed_message_ids = self.tokenizer.apply_chat_template([prefix_message, message], tokenize=True)
+                # Handle transformers 5.2.0+ API change
+                if not isinstance(prefixed_message_ids, list):
+                    prefixed_message_ids = prefixed_message_ids["input_ids"]
                 message_ids = prefixed_message_ids[len(prefix_token_ids) :]
 
             if message["role"] != "system" and i > 0:
@@ -181,3 +199,30 @@ def get_text_from_loss_mask(self, token_ids: list[int], loss_masks: list[int]) -
                 selected_texts.append(self.tokenizer.decode(current_tokens))
 
         return selected_texts
+
+
+if __name__ == "__main__":
+    # NOTE(review): assumes AutoTokenizer is already imported in mask_utils.py — confirm.
+    tokenizer = AutoTokenizer.from_pretrained("/root/Qwen3.5-35B-A3B")
+    mask_utils = MultiTurnLossMaskGenerator(tokenizer, tokenizer_type="qwen3")
+    messages = [
+        {"role": "user", "content": "hi"},
+        {"role": "assistant", "reasoning_content": "hihi", "content": "hello"},
+    ]
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_weather",
+                "description": "Get the weather",
+                "parameters": {
+                    "type": "object",
+                    "properties": {"city": {"type": "string", "description": "The city to get the weather for"}},
+                },
+            },
+        }
+    ]
+    token_ids, loss_mask = mask_utils.get_loss_mask(messages, tools=tools)
+    for i in range(len(token_ids)):
+        print(f"Token: {repr(tokenizer.decode(token_ids[i]))}, Loss Mask: {loss_mask[i]}")
+
+    print(tokenizer.apply_chat_template(messages, tokenize=False, tools=tools))
diff --git a/slime/utils/profile_utils.py b/slime/utils/profile_utils.py
index 504d1ce86..4be4136cb 100644
--- a/slime/utils/profile_utils.py
+++ b/slime/utils/profile_utils.py
@@ -1,4 +1,5 @@
 import logging
+import os
 import time
 import traceback
 from pathlib import Path
@@ -58,6 +59,11 @@ def _profile_simple_loop(iterator, args, name):
 
 
 def _create_torch_profiler(args, name):
+    tensorboard_dir = args.tensorboard_dir
+    if tensorboard_dir is not None:
+        tensorboard_dir = str(Path(tensorboard_dir).resolve())
+        os.makedirs(tensorboard_dir, exist_ok=True)
+
     return torch.profiler.profile(
         schedule=torch.profiler.schedule(
             # TODO the train_actor and train_log_probs ones may need to have different args to control step
@@ -67,7 +73,7 @@ def _create_torch_profiler(args, name):
             repeat=1,
         ),
         on_trace_ready=torch.profiler.tensorboard_trace_handler(
-            args.tensorboard_dir,
+            tensorboard_dir,
             worker_name=f"{name}_rank_{torch.distributed.get_rank()}",
             use_gzip=True,
         ),
@@ -88,8 +94,10 @@ def create(args):
         return c(args)
 
     def __init__(self, args):
+        snapshot_dir = Path(args.memory_snapshot_dir).resolve()
+        snapshot_dir.mkdir(parents=True, exist_ok=True)
         self._path_dump = (
-            Path(args.memory_snapshot_dir)
+            snapshot_dir
            / f"memory_snapshot_time{time.time()}_rank{torch.distributed.get_rank()}_{args.memory_snapshot_path}"
         )