Skip to content

Commit 36184b7

Browse files
authored
Simplify RL Colocate Trainer initialization by using cfg.build (#1520)
* Build XtunerMeta and TrainController by cfg.build
* Build RolloutController by cfg.build
* Simplify RL colocate trainer init
* Fix some lint errors
* Fix some bugs
1 parent ddf3ec3 commit 36184b7

18 files changed

Lines changed: 213 additions & 244 deletions

File tree

examples/v1/config/rl_grpo_gsm8k_judge.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
from xtuner.v1.model import get_model_config_from_hf
1515
from xtuner.v1.ray.base import AcceleratorResourcesConfig
1616
from xtuner.v1.ray.config.worker import RolloutConfig
17-
from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig
17+
from xtuner.v1.ray.judger.gsm8k import GSM8KRouterJudgerConfig
18+
from xtuner.v1.rl.base.replay_buffer import SyncReplayBufferConfig
1819
from xtuner.v1.rl.base import WorkerConfig
1920
from xtuner.v1.rl.base.agent_loop import SingleTurnAgentLoopConfig
2021
from xtuner.v1.rl.base.agent_loop_manager import AgentLoopManagerConfig
@@ -67,7 +68,7 @@
6768
)
6869

6970
# 3. judger
70-
judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k")
71+
judger_config = GSM8KRouterJudgerConfig(judger_name="openai/gsm8k")
7172

7273
# 4. train worker
7374
lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6)
@@ -181,7 +182,7 @@
181182
rollout_config=rollout_config,
182183
judger_config=judger_config,
183184
tokenizer_path=model_path,
184-
replay_buffer_config=dict(),
185+
replay_buffer_config=SyncReplayBufferConfig(),
185186
agent_loop_manager_cfg=agent_loop_manager_cfg,
186187
eval_agent_loop_manager_cfg=eval_agent_loop_manager_cfg,
187188
evaluator_config=evaluator_config,

xtuner/v1/data_proto/rl_data.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,12 @@
1010
# ====================================
1111
# ====== DataFlow 数据流 ==============
1212
# ====================================
13+
from xtuner.v1.utils.cache import CacheObj
1314
from xtuner.v1.utils.logger import get_logger
1415

1516

1617
if TYPE_CHECKING:
17-
import ray
18-
19-
RayObjectRef = ray.ObjectRef
18+
from ray import ObjectRef as RayObjectRef
2019
else:
2120
RayObjectRef: TypeAlias = Any
2221

@@ -60,12 +59,12 @@ class Status(Enum):
6059

6160
class MultimodalInfo(TypedDict):
6261
# 使用TypedDict给出pixel_values的类型提示
63-
pixel_values: NotRequired[torch.Tensor | RayObjectRef | None] # type: ignore[valid-type]
62+
pixel_values: NotRequired[torch.Tensor | RayObjectRef | None]
6463
image_grid_thw: NotRequired[torch.Tensor]
6564
position_ids: NotRequired[torch.Tensor]
6665

6766

68-
class RolloutState(BaseModel):
67+
class RolloutState(CacheObj, BaseModel):
6968
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
7069

7170
# --- 数据 ---
@@ -88,22 +87,22 @@ class RolloutState(BaseModel):
8887
response: str | None = None
8988
response_ids: list[int] | None = None
9089
logprobs: list[float] | None = None
91-
routed_experts: list[int] | RayObjectRef | None = None # type: ignore[valid-type]
90+
routed_experts: list[int] | RayObjectRef | None = None
9291
finish_reason: str | None = None
9392

94-
@field_serializer('routed_experts')
93+
@field_serializer("routed_experts")
9594
def _serialize_routed_experts(self, value: list[int] | RayObjectRef | None) -> list[int] | None:
9695
"""Dump 时跳过 ray.ObjectRef,序列化为 None,避免 PydanticSerializationError。"""
9796
if value is None:
9897
return None
9998
try:
10099
import ray
100+
101101
if isinstance(value, ray.ObjectRef):
102102
return None
103103
except ImportError:
104104
pass
105-
if type(value).__name__ == 'ObjectRef' and 'ray' in getattr(
106-
type(value), '__module__', ''):
105+
if type(value).__name__ == "ObjectRef" and "ray" in getattr(type(value), "__module__", ""):
107106
return None
108107
return value # list[int]
109108

xtuner/v1/datasets/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from .rl_tokenize_fn import RLTextTokenizeFnConfig
2323
from .sampler import LengthGroupedSampler, ParallelSampler
2424
from .sft_tokenize_fn import OpenaiTokenizeFunction, OpenaiTokenizeFunctionConfig
25-
from .utils import CachableTokenizeFunction, CacheObj, calculate_file_sha256, calculate_xxhash, tokenizer_hash
25+
from .utils import CachableTokenizeFunction, CacheDict, calculate_file_sha256, calculate_xxhash, tokenizer_hash
2626
from .vlm_jsonl import VLMJsonlDataset
2727

2828

@@ -32,7 +32,7 @@
3232
__all__ = [
3333
"JsonlDataset",
3434
"CachableTokenizeFunction",
35-
"CacheObj",
35+
"CacheDict",
3636
"calculate_file_sha256",
3737
"calculate_xxhash",
3838
"tokenizer_hash",

xtuner/v1/datasets/_hardcode_patch.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@
2626
from xtuner.v1.utils import get_logger
2727

2828
from .ftdp import FtdpTokenizeFunction
29-
from .mllm_tokenize_fn import Qwen3VLTokenizeFunction
30-
from .pt_tokenize_fn import PretrainTokenizeFunction
29+
3130
# from .rl_tokenize_fn.rl_tokenize_fn import InternS1VLTokenizeFunction
32-
from .mllm_tokenize_fn import InternS1VLTokenizeFunction
31+
from .mllm_tokenize_fn import InternS1VLTokenizeFunction, Qwen3VLTokenizeFunction
32+
from .pt_tokenize_fn import PretrainTokenizeFunction
3333
from .sft_tokenize_fn import OpenaiTokenizeFunction
3434

3535

xtuner/v1/datasets/jsonl.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from xtuner.v1.datasets.data_item import CacheItem
2626
from xtuner.v1.utils import SharedMemory, get_logger
2727

28-
from .utils import CachableTokenizeFunction, CacheObj, calculate_xxhash
28+
from .utils import CachableTokenizeFunction, CacheDict, CacheObj, calculate_xxhash
2929

3030

3131
T = TypeVar("T")
@@ -439,11 +439,11 @@ def count_offsets(self, cache_dir=None):
439439
@staticmethod
440440
def _tokenize_by_offset(
441441
data: bytes,
442-
tokenize_fn: Callable[[dict], CacheObj],
442+
tokenize_fn: Callable[[dict], CacheDict | CacheObj],
443443
) -> dict:
444444
line = data.decode()
445445
tokenized = tokenize_fn(json.loads(line))
446-
if hasattr(tokenized, "num_tokens"):
446+
if isinstance(tokenized, CacheObj):
447447
num_tokens = tokenized.num_tokens
448448
else:
449449
num_tokens = tokenized["num_tokens"]

xtuner/v1/datasets/utils.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import numpy as np
1010
import xxhash
1111
from PIL import Image
12-
from typing_extensions import TypedDict
12+
from xtuner.v1.utils.cache import CacheDict, CacheObj
1313

1414
from .data_item import CacheItem
1515

@@ -20,10 +20,6 @@
2020
from transformers import PreTrainedTokenizer
2121

2222

23-
class CacheObj(TypedDict, total=False):
24-
num_tokens: int
25-
26-
2723
class CachableTokenizeFunction(ABC, Generic[T]):
2824
def __init__(self, tokenizer, *args, **kwargs):
2925
self.tokenizer = tokenizer

xtuner/v1/ray/config/worker.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22
import os
33
import socket
44
from pathlib import Path
5-
from typing import Any, List, Literal, Optional, Union
5+
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Union
6+
7+
8+
if TYPE_CHECKING:
9+
from ray.util.placement_group import PlacementGroup
610

711
from cyclopts import Group, Parameter
812
from pydantic import BaseModel, ConfigDict, PrivateAttr
@@ -313,6 +317,25 @@ def model_post_init(self, __context: Any) -> None:
313317

314318
self.worker_log_dir.mkdir(parents=True, exist_ok=True)
315319

320+
def build(self, placement_group: "PlacementGroup"):
321+
"""Build and return a Ray remote RolloutController from this config.
322+
323+
Args:
324+
placement_group: The placement group for scheduling RolloutWorker actors.
325+
326+
Returns:
327+
A Ray actor handle (proxy) of RolloutController.
328+
"""
329+
import ray
330+
331+
from xtuner.v1.ray.rollout.controller import RolloutController
332+
333+
return (
334+
ray.remote(RolloutController)
335+
.options(max_concurrency=int(os.environ.get("RAY_MAX_CONCURRENCY", 1000)))
336+
.remote(self, placement_group)
337+
)
338+
316339

317340
if __name__ == "__main__":
318341
from cyclopts import App, Group, Parameter

xtuner/v1/ray/judger/native.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
class Judger(ABC):
1717
@abstractmethod
18+
@ray_method
1819
async def judge(self, rollout_state: RolloutState) -> RolloutState: ...
1920

2021

xtuner/v1/ray/rollout/controller.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
from uuid import uuid4
1111

1212
import ray
13-
from ray.actor import ActorProxy
1413
import uvicorn
1514
from fastapi import FastAPI
15+
from ray.actor import ActorProxy
1616
from ray.util.placement_group import PlacementGroup
1717

1818
from transformers import AutoTokenizer
@@ -468,5 +468,6 @@ def _init_workers(self):
468468
self.logger.info(f"Rollout worker server URLs: {list(self.workers_info.keys())}")
469469
return engine_rank_mesh_array, worker_server_urls_map
470470

471+
471472
RayRolloutController = ray.remote(RolloutController)
472-
RolloutControllerProxy = ActorProxy[RayRolloutController]
473+
RolloutControllerProxy = ActorProxy[RayRolloutController]

xtuner/v1/ray/rollout/lmdeploy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,14 +107,14 @@ def _get_request_payload(self, rollout_state: RolloutState) -> dict:
107107
prompt_token_ids = self.tokenizer(text_prompt, add_special_tokens=False)["input_ids"]
108108
payload["input_ids"] = prompt_token_ids
109109
sample_params.return_routed_experts = True if self.enable_return_routed_experts else False
110-
lmdeploy_sample_params = self._transform_sample_params(sample_params)
110+
lmdeploy_sample_params = self._transform_sample_params(sample_params)
111111
payload.update(sample_params)
112112
else:
113113
payload = {
114114
"model": self.model_name,
115115
"messages": rollout_state.message,
116116
}
117-
lmdeploy_sample_params = self._transform_sample_params(sample_params)
117+
lmdeploy_sample_params = self._transform_sample_params(sample_params)
118118
lmdeploy_sample_params.pop("no_stop_trim", None)
119119
lmdeploy_sample_params.pop("return_logprob", None)
120120
lmdeploy_sample_params.pop("stop_token_ids", None)

0 commit comments

Comments (0)