6 changes: 6 additions & 0 deletions slime/rollout/sft_rollout.py
@@ -48,6 +48,12 @@ def generate_rollout(args, rollout_id, data_buffer, evaluation=False):

token_ids, loss_mask = MASK_GENERATOR.get_loss_mask(messages, tools=tools)

max_len = getattr(args, "rollout_max_context_len", None)
if max_len is not None and len(token_ids) > max_len:
logger.warning(f"sft_rollout: truncating sequence from {len(token_ids)} to {max_len} tokens")
token_ids = token_ids[:max_len]
loss_mask = loss_mask[:max_len]

response_length = MASK_GENERATOR.get_response_lengths([loss_mask])[0]

sample.tokens = token_ids
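The guard above truncates token_ids and loss_mask together so the two sequences stay aligned after clipping to rollout_max_context_len. A minimal standalone sketch of the same behavior, with an illustrative helper name and example values that are not taken from the diff:

def truncate_to_context_len(token_ids, loss_mask, max_len):
    # Cut both sequences at the same point so each remaining token keeps its mask bit.
    if max_len is not None and len(token_ids) > max_len:
        token_ids = token_ids[:max_len]
        loss_mask = loss_mask[:max_len]
    return token_ids, loss_mask

# Example: a 6-token sample clipped to 4 tokens.
tokens, mask = truncate_to_context_len([1, 2, 3, 4, 5, 6], [0, 0, 1, 1, 1, 1], max_len=4)
assert tokens == [1, 2, 3, 4]
assert mask == [0, 0, 1, 1]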
55 changes: 50 additions & 5 deletions slime/utils/mask_utils.py
@@ -36,11 +36,15 @@ def get_system_message_length(self) -> tuple[int, int]:
chat_template_token_ids = self.tokenizer(chat_template_token, add_special_tokens=False)["input_ids"]
idx_1, idx_2 = self.find_all_sublist_indices(chat_template_token_ids, raw_token_ids)
end_interval = len(chat_template_token_ids) - len(raw_token_ids) - idx_2
gen_token_length = len(
self.tokenizer.apply_chat_template(
test_messages, add_special_tokens=False, tokenize=True, add_generation_prompt=True
)
) - len(chat_template_token_ids)

gen_prompt_token_ids = self.tokenizer.apply_chat_template(
test_messages, add_special_tokens=False, tokenize=True, add_generation_prompt=True
)
# Handle transformers 5.2.0+ API change: apply_chat_template now returns dict when tokenize=True
if not isinstance(gen_prompt_token_ids, list):
gen_prompt_token_ids = gen_prompt_token_ids["input_ids"]

gen_token_length = len(gen_prompt_token_ids) - len(chat_template_token_ids)

system_message_length = idx_1 - ((idx_2 - idx_1) - end_interval - len(raw_token_ids))
return system_message_length, gen_token_length
@@ -57,6 +61,10 @@ def gen_multi_turn_loss_mask_qwen(
else:
message_ids = self.tokenizer.apply_chat_template([message], tokenize=True)

# Handle transformers 5.2.0+ API change: apply_chat_template now returns dict when tokenize=True
if not isinstance(message_ids, list):
message_ids = message_ids["input_ids"]

if message["role"] != "system" and i > 0:
message_ids = message_ids[self.system_message_length :]

@@ -82,14 +90,24 @@ def gen_multi_turn_loss_mask_qwen3(
prefix_message = {"role": "user", "content": "FOR CALCULATING LOSS MASK ONLY"}
prefix_token_ids = self.tokenizer.apply_chat_template([prefix_message], tokenize=True)

# Handle transformers 5.2.0+ API change: apply_chat_template now returns dict when tokenize=True
if not isinstance(prefix_token_ids, list):
prefix_token_ids = prefix_token_ids["input_ids"]

for i, message in enumerate(messages):
if i == 0:
tailed_message_ids = self.tokenizer.apply_chat_template(
[message, prefix_message], tokenize=True, tools=tools
)
# Handle transformers 5.2.0+ API change
if not isinstance(tailed_message_ids, list):
tailed_message_ids = tailed_message_ids["input_ids"]
message_ids = tailed_message_ids[: -len(prefix_token_ids)]
else:
prefixed_message_ids = self.tokenizer.apply_chat_template([prefix_message, message], tokenize=True)
# Handle transformers 5.2.0+ API change
if not isinstance(prefixed_message_ids, list):
prefixed_message_ids = prefixed_message_ids["input_ids"]
message_ids = prefixed_message_ids[len(prefix_token_ids) :]

if message["role"] != "system" and i > 0:
@@ -181,3 +199,30 @@ def get_text_from_loss_mask(self, token_ids: list[int], loss_masks: list[int]) -> list[str]:
selected_texts.append(self.tokenizer.decode(current_tokens))

return selected_texts


if __name__ == "__main__":
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("/root/Qwen3.5-35B-A3B")
mask_utils = MultiTurnLossMaskGenerator(tokenizer, tokenizer_type="qwen3")
messages = [
{"role": "user", "content": "hi"},
{"role": "assistant", "reasoning_content": "hihi", "content": "hello"},
]
tools = [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the weather",
"parameters": {
"type": "object",
"properties": {"city": {"type": "string", "description": "The city to get the weather for"}},
},
},
}
]
token_ids, loss_mask = mask_utils.get_loss_mask(messages, tools=tools)
for i in range(len(token_ids)):
print(f"Token: {repr(tokenizer.decode(token_ids[i]))}, Loss Mask: {loss_mask[i]}")

print(tokenizer.apply_chat_template(messages, tokenize=False, tools=tools))
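The repeated isinstance checks in this file all guard the same thing: depending on the transformers version, apply_chat_template(..., tokenize=True) may return either a plain list of token ids or a dict-like object carrying "input_ids". A minimal sketch of that normalization as a single helper; the helper name is illustrative only, the diff instead inlines the check at each call site:

def to_token_id_list(output):
    # Plain list of token ids (older behavior) passes through unchanged.
    if isinstance(output, list):
        return output
    # Dict-like return value: pull the token ids out of "input_ids".
    return output["input_ids"]

Inlining the check keeps each change next to the apply_chat_template call it protects, at the cost of repeating the same two lines at five call sites.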
12 changes: 10 additions & 2 deletions slime/utils/profile_utils.py
@@ -1,4 +1,5 @@
import logging
import os
import time
import traceback
from pathlib import Path
@@ -58,6 +59,11 @@ def _profile_simple_loop(iterator, args, name):


def _create_torch_profiler(args, name):
tensorboard_dir = args.tensorboard_dir
if tensorboard_dir is not None:
tensorboard_dir = str(Path(tensorboard_dir).resolve())
os.makedirs(tensorboard_dir, exist_ok=True)

return torch.profiler.profile(
schedule=torch.profiler.schedule(
# TODO the train_actor and train_log_probs ones may need to have different args to control step
@@ -67,7 +73,7 @@ def _create_torch_profiler(args, name):
repeat=1,
),
on_trace_ready=torch.profiler.tensorboard_trace_handler(
args.tensorboard_dir,
tensorboard_dir,
worker_name=f"{name}_rank_{torch.distributed.get_rank()}",
use_gzip=True,
),
@@ -88,8 +94,10 @@ def create(args):
return c(args)

def __init__(self, args):
snapshot_dir = Path(args.memory_snapshot_dir).resolve()
snapshot_dir.mkdir(parents=True, exist_ok=True)
self._path_dump = (
Path(args.memory_snapshot_dir)
snapshot_dir
/ f"memory_snapshot_time{time.time()}_rank{torch.distributed.get_rank()}_{args.memory_snapshot_path}"
)

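Both profile_utils changes follow the same pattern: resolve the output directory to an absolute path and create it (including parents) before handing it to the writer, whether that is tensorboard_trace_handler or the memory-snapshot dump. A minimal sketch of the pattern, with an illustrative directory name that is not from the PR:

from pathlib import Path

# Resolve to an absolute path and create the directory up front,
# so the writer never encounters a missing output directory.
out_dir = Path("profiler_output").resolve()  # illustrative path, not from the PR
out_dir.mkdir(parents=True, exist_ok=True)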