Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 37 additions & 2 deletions xtuner/v1/datasets/rl_tokenize_fn/qwen3_vl_tokenize_fn.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import cast
import os
from typing import Any, cast

from xtuner.v1.data_proto.rl_data import RolloutState

from ...data_proto.rl_data import MultimodalInfo
from ..mllm_tokenize_fn.qwen3_vl_tokenize_fn import Qwen3VLTokenizeFnConfig, Qwen3VLTokenizeFunction, QwenVL3DataItem
from ..utils import replace_image_context_and_collect_media_data


def remove_consecutive_img_context_tokens(tokens: list[int], img_context_id: int) -> list[int]:
Expand All @@ -20,6 +20,39 @@ def remove_consecutive_img_context_tokens(tokens: list[int], img_context_id: int
return new_tokens


def replace_image_context_and_collect_media_data(
prompt: str | list[dict[str, Any]], media_root: str, replace_image_ctx: bool
) -> tuple:
"""Collect image data from the prompt and extra_info.

Args:
prompt (str): The input prompt containing image placeholders.
media_root (str): The root directory of the media files.
replace_image_ctx (bool): Whether to replace the image context in the prompt.

Returns:
List[dict]: A list of image data dictionaries.
"""
if not isinstance(prompt, list):
return [], []

image_paths = []
video_paths = []
for msg in prompt:
if msg["role"] == "user":
content = msg["content"]
if isinstance(content, list):
for c in content:
if c["type"] in ("image_url", "image"):
key = "image_url" if "image_url" in c else "image"
image_paths.append(os.path.join(media_root, c[key]["url"]))
elif c["type"] in ("video_url", "video"):
key = "video_url" if "video_url" in c else "video"
video_paths.append(os.path.join(media_root, c[key]["url"]))

return image_paths, video_paths


class RLQwen3VLTokenizeFunction(Qwen3VLTokenizeFunction):
def __init__(self, *args, ignore_multimodal_info: bool = False, data_judger_mapping: dict | None = None, **kwargs):
self.ignore_multimodal_info = ignore_multimodal_info
Expand Down Expand Up @@ -101,6 +134,7 @@ def hash(self) -> str:

class RLQwen3VLTokenizeFnConfig(Qwen3VLTokenizeFnConfig):
ignore_multimodal_info: bool = False # eval is True
data_judger_mapping: dict | None = None # {origin_data_source: mapped_judger_name_and_weight}

def build(
self, tokenizer, tokenizer_hash: str | None = None, anno_name: str = "", **kwargs
Expand All @@ -127,4 +161,5 @@ def build(
ignore_multimodal_info=self.ignore_multimodal_info,
add_generation_prompt=self.add_generation_prompt,
enable_thinking=self.enable_thinking,
data_judger_mapping=self.data_judger_mapping,
)
2 changes: 1 addition & 1 deletion xtuner/v1/loss/mtp_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def _kl_loss_fn(
Called per-chunk in chunk mode, so tensors here may be a slice of the full sequence.
"""
from xtuner.v1.rl.loss_fn import kl_penalty
from xtuner.v1.rl.loss import kl_penalty
from xtuner.v1.rl.utils import gather_logprobs

logits = F.linear(hidden_states, head_weight, head_bias).float()
Expand Down
24 changes: 23 additions & 1 deletion xtuner/v1/rl/agent_loop/single_turn_agent_loop.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,25 @@
import asyncio

from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams, Status
from xtuner.v1.rl.judger import Judger
from xtuner.v1.rl.rollout import RolloutController
from xtuner.v1.rl.utils import create_task

from .agent_loop import AgentLoop, AgentLoopConfig
from .utils import PartialRolloutHandler


class SingleTurnAgentLoopConfig(AgentLoopConfig):
    """Configuration for :class:`SingleTurnAgentLoop`.

    Attributes:
        enable_batch_judge: When True, per-sample judging is deferred so the
            whole rollout group can be scored in a single judger call.
    """

    enable_batch_judge: bool = False

    def build_local(self, rollout_controller, judger: Judger | None = None, logger=None) -> "SingleTurnAgentLoop":
        """Instantiate the single-turn agent loop described by this config."""
        loop_kwargs = dict(
            rollout_ctl=rollout_controller,
            sample_params=self.sample_params,
            hf_checkpoint=self.hf_checkpoint,
            judger=judger,
            logger=logger,
            enable_batch_judge=self.enable_batch_judge,
        )
        return SingleTurnAgentLoop(**loop_kwargs)


Expand All @@ -25,10 +31,12 @@ def __init__(
hf_checkpoint: str,
judger: Judger | None = None,
logger=None,
enable_batch_judge: bool = False,
):
super().__init__(rollout_ctl, sample_params, hf_checkpoint, judger, logger)
self.max_tokens = self.sample_params.max_tokens
self.partial_rollout_handler = PartialRolloutHandler(max_tokens=self.max_tokens)
self.enable_batch_judge = enable_batch_judge

async def generate_sample(
self,
Expand All @@ -49,6 +57,20 @@ async def generate_sample(
# 非 COMPLETED 状态(如被截断、放弃等)直接早退,不触发打分
if rollout_state.status != Status.COMPLETED:
return rollout_state
if self.judger is not None:
if self.judger is not None and not self.enable_batch_judge:
# 如果开启了批量打分,则在 generate_group 里统一打分,不在这里逐条打分
rollout_state = await self.judger.judge(rollout_state)
return rollout_state

async def generate_group(self, rollout_state: list[RolloutState], **kwargs) -> list[RolloutState]:
pending_tasks = []
for state in rollout_state:
state.sample_params = self.sample_params
task = create_task(self.generate_sample(state, **kwargs))
pending_tasks.append(task)
generated_samples = asyncio.gather(*pending_tasks)
group_samples = await generated_samples
if self.judger is not None and self.enable_batch_judge:
# 批量打分
group_samples = await self.judger.judge(group_samples)
return group_samples
41 changes: 8 additions & 33 deletions xtuner/v1/rl/agent_loop/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,10 @@
import ray

from xtuner.v1.data_proto.rl_data import RolloutState, Status
from xtuner.v1.rl.utils import clear_rollout_response_for_rerun, free_object_refs
from xtuner.v1.rl.utils import clear_rollout_response_for_rerun
from xtuner.v1.utils import get_logger


logger = get_logger()


def _resolve_routed_experts(routed_experts: list[int] | ray.ObjectRef) -> list[int]:
if isinstance(routed_experts, ray.ObjectRef):
routed_experts = ray.get(routed_experts)
Expand All @@ -24,6 +21,7 @@ class PartialRolloutHandler:
continuation."""

def __init__(self, max_tokens: int) -> None:
self.logger = get_logger(self.__class__.__name__)
self.max_tokens = max_tokens

def preprocess(self, rollout_state: RolloutState, enable_partial_rollout: bool = False) -> RolloutState:
Expand All @@ -50,7 +48,7 @@ def preprocess(self, rollout_state: RolloutState, enable_partial_rollout: bool =
remaining_tokens = self.max_tokens - response_len # compute remaining max_tokens budget
rollout_state.sample_params = rollout_state.sample_params.copy(update={"max_tokens": remaining_tokens})

logger.debug(
self.logger.debug(
f"[PartialRolloutHandler] Sample {rollout_state.uid} continue rollout | Remaining tokens allowed: {remaining_tokens} | Status: {rollout_state.status} | Prompt len: {prompt_len} | Response len: {response_len} | Staleness: {rollout_state.seq_staleness} | Total tokens: {len(rollout_state.tokens)}"
)
# TODO: handle routed_experts
Expand Down Expand Up @@ -88,39 +86,16 @@ def postprocess(self, rollout_state: RolloutState) -> RolloutState:
)
cur_routed_experts = cur_routed_experts[history_routed_experts_len:]
concat_routed_experts = history_routed_experts + cur_routed_experts

prompt_ids = rollout_state.prompt_ids or []
response_ids = rollout_state.response_ids or []
expect_tokens_num = len(prompt_ids) + len(response_ids) - 1
assert len(concat_routed_experts) == expect_tokens_num, (
f"After concatenation, routed_experts len: {len(concat_routed_experts)}, expected tokens num: {expect_tokens_num}, prompt len: {len(prompt_ids)}, response len: {len(response_ids)}, history routed_experts len: {history_routed_experts_len}, current routed_experts len: {cur_routed_experts_len}"
)
logger.info(
f"[PartialRolloutHandler] Postprocess rollout {rollout_state.uid}: "
f"concat routed_experts len={len(concat_routed_experts)} "
f"(history={history_routed_experts_len}, new={cur_routed_experts_len}), "
f"prompt={len(prompt_ids)}, response={len(response_ids)}"
)
rollout_state.routed_experts = ray.put(concat_routed_experts)
free_object_refs(
[ref for ref in (history_routed_experts_ref, cur_routed_experts_ref) if isinstance(ref, ray.ObjectRef)]
)
# free_object_refs(
# [ref for ref in (history_routed_experts_ref, cur_routed_experts_ref) if isinstance(ref, ray.ObjectRef)]
# )
end_time = time.time()
logger.info(
self.logger.info(
f"[PartialRolloutHandler] Postprocess routed_experts concatenation time: {end_time - start_time:.4f} seconds"
)
elif history_routed_experts_ref is None and cur_routed_experts_ref is not None:
prompt_ids = rollout_state.prompt_ids or []
response_ids = rollout_state.response_ids or []
expect_tokens_num = len(prompt_ids) + len(response_ids) - 1
cur_routed_experts_data = ray.get(cur_routed_experts_ref)
if cur_routed_experts_data.shape[0] != expect_tokens_num:
logger.warning(
f"Routed experts shape {cur_routed_experts_data.shape} does not match total tokens {expect_tokens_num}, maybe due to some error in the model side. We will try to truncate the routed experts to match the tokens, but please check if there is any error in the model side."
)
cur_routed_experts_data = cur_routed_experts_data[:expect_tokens_num, :, :]
routed_experts = ray.put(cur_routed_experts_data)
rollout_state.routed_experts = routed_experts
rollout_state.routed_experts = cur_routed_experts_ref
elif history_routed_experts_ref is not None and cur_routed_experts_ref is None:
rollout_state.routed_experts = history_routed_experts_ref

Expand Down
2 changes: 1 addition & 1 deletion xtuner/v1/rl/agent_loop_manager/producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ async def _timed_generate_group(


def default_is_valid_sample_fn(samples: list[RolloutState]) -> bool:
return all(sample.status == Status.COMPLETED for sample in samples)
return True


def default_should_continue_fn(completed_count: int, batch_size: int, **kwargs) -> bool:
Expand Down
79 changes: 61 additions & 18 deletions xtuner/v1/rl/judger/composed.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from copy import deepcopy
from typing import Callable, TypeAlias

from pydantic import BaseModel, ConfigDict, Field
Expand All @@ -12,7 +13,10 @@

SelectedJudgerKeys: TypeAlias = str | list[str] | None
JudgerSelectFn: TypeAlias = Callable[[RolloutState, dict[str, Judger]], SelectedJudgerKeys]
JudgerMergeFn: TypeAlias = Callable[[RolloutState, dict[str, RolloutState]], RolloutState]
JudgerMergeFn: TypeAlias = Callable[
[RolloutState | list[RolloutState], dict[str, RolloutState | list[RolloutState]]],
RolloutState | list[RolloutState],
]


def default_select_fn(rollout_state: RolloutState, branches: dict[str, Judger]) -> SelectedJudgerKeys:
Expand All @@ -32,26 +36,61 @@ def default_select_fn(rollout_state: RolloutState, branches: dict[str, Judger])
return None


def default_merge_fn(original: RolloutState, judged: dict[str, RolloutState]) -> RolloutState:
def default_merge_fn(
original: RolloutState | list[RolloutState],
judged: dict[str, RolloutState | list[RolloutState]],
) -> RolloutState | list[RolloutState]:
"""Default merger for ``ComposedJudgerConfig``.

This merger intentionally does not combine multiple judger scores into a single aggregated value.
It writes the merged reward as ``{branch_name: score}``, where ``branch_name`` is the selected
key from ``ComposedJudgerConfig.branches`` and ``score`` is taken from each child judger's
``reward["score"]``.

Supports both single ``RolloutState`` and batched ``list[RolloutState]`` inputs. In the batch
case, each element in the list represents a different response to the same prompt, and each
branch's judged result must be a list of the same length.

Users who need weighted sums, richer reward payloads, or custom post-processing should provide
their own ``merge_fn``.
"""
merged = original.model_copy(deep=True)
merged.reward = {}

for name, state in judged.items():
if state.reward is None or "score" not in state.reward:
raise KeyError(f"Default merge_fn requires reward['score'] for branch {name!r}.")
merged.reward[name] = state.reward["score"]

return merged
if isinstance(original, list):
for name, state in judged.items():
if not isinstance(state, list):
raise TypeError(
f"default_merge_fn: branch {name!r} returned a single RolloutState "
"but original is a list. All branches must return lists when input is a list."
)
if len(state) != len(original):
raise ValueError(
f"default_merge_fn: branch {name!r} returned {len(state)} states "
f"but original has {len(original)} states."
)
results: list[RolloutState] = []
for i, orig in enumerate(original):
merged = orig.model_copy(deep=True)
merged.reward = {}
for name, states in judged.items():
assert isinstance(states, list)
state_i: RolloutState = states[i]
reward = state_i.reward
if reward is None or "score" not in reward:
raise KeyError(f"Default merge_fn requires reward['score'] for branch {name!r}.")
merged.reward[name] = reward["score"]
results.append(merged)
return results
else:
merged = original.model_copy(deep=True)
merged.reward = {}
for name, state in judged.items():
if isinstance(state, list):
raise TypeError(
f"default_merge_fn: branch {name!r} returned a list but original is a single RolloutState."
)
if state.reward is None or "score" not in state.reward:
raise KeyError(f"Default merge_fn requires reward['score'] for branch {name!r}.")
merged.reward[name] = state.reward["score"]
return merged


class ComposedJudger(Judger):
Expand All @@ -69,8 +108,11 @@ def __init__(
self.merge_fn = merge_fn
self.default_key = default_key

def _resolve_selected_keys(self, rollout_state: RolloutState) -> list[str]:
selected = self.select_fn(rollout_state, self.branches)
def _resolve_selected_keys(self, rollout_state: RolloutState | list[RolloutState]) -> list[str]:
if isinstance(rollout_state, list):
selected = self.select_fn(rollout_state[0], self.branches)
else:
selected = self.select_fn(rollout_state, self.branches)

if selected is None:
selected_keys: list[str] = []
Expand All @@ -84,20 +126,21 @@ def _resolve_selected_keys(self, rollout_state: RolloutState) -> list[str]:
return [self.default_key]
if len(self.branches) == 1:
return [next(iter(self.branches))]
state = rollout_state[0] if isinstance(rollout_state, list) else rollout_state
raise KeyError(
f"ComposedJudger could not select a branch for task_name={rollout_state.task_name!r}, "
f"data_source={rollout_state.data_source!r}, available={sorted(self.branches)}"
f"ComposedJudger could not select a branch for task_name={state.task_name!r}, "
f"data_source={state.data_source!r}, available={sorted(self.branches)}"
)
return selected_keys

async def judge(self, rollout_state: RolloutState) -> RolloutState:
async def judge(self, rollout_state: RolloutState | list[RolloutState]) -> RolloutState | list[RolloutState]: # type: ignore[override]
selected_keys = self._resolve_selected_keys(rollout_state)

judged: dict[str, RolloutState] = {}
judged: dict[str, RolloutState | list[RolloutState]] = {}
for key in selected_keys:
if key not in self.branches:
raise KeyError(f"Unknown judger branch: {key}, available={sorted(self.branches)}")
judged[key] = await self.branches[key].judge(rollout_state.model_copy(deep=True))
judged[key] = await self.branches[key].judge(deepcopy(rollout_state))
return self.merge_fn(rollout_state, judged)


Expand Down
Loading
Loading