diff --git a/tests/ray/test_auto.py b/.codex similarity index 100% rename from tests/ray/test_auto.py rename to .codex diff --git a/.dev_scripts/debug_gateway.py b/.dev_scripts/debug_gateway.py new file mode 100644 index 0000000000..db7a087cfe --- /dev/null +++ b/.dev_scripts/debug_gateway.py @@ -0,0 +1,239 @@ +#!/usr/bin/env python +"""Start a RolloutController-backed Gateway for manual protocol debugging. + +This script is intended for end-to-end debugging with real clients such as +Claude Code, Codex, curl, or the OpenAI SDK. It starts the RolloutController, +waits for rollout workers to become ready, then serves the Gateway in the +current process. +""" + +from __future__ import annotations + +import argparse +import os +import sys +import time +from pathlib import Path +from typing import Any + + +DEFAULT_WORK_DIR = Path("/tmp/xtuner_debug_gateway") +DEFAULT_MODEL_PATH = os.environ.get("ROLLOUT_MODEL_PATH") + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description=( + "Start a local XTuner Gateway backed by a RolloutController for manual protocol debugging.\n\n" + "Example:\n" + " python .dev_scripts/debug_gateway.py --model-path /path/to/model --model-name local-test" + ), + formatter_class=argparse.RawTextHelpFormatter, + ) + + parser.add_argument( + "--model-path", + default=DEFAULT_MODEL_PATH, + required=DEFAULT_MODEL_PATH is None, + help="Model path for rollout workers. Defaults to the ROLLOUT_MODEL_PATH environment variable.", + ) + parser.add_argument("--model-name", default=None, help="Model name exposed by the Gateway.") + parser.add_argument("--tokenizer-path", default=None, help="Tokenizer path. Defaults to --model-path.") + parser.add_argument("--rollout-env", default="debug_gateway", help="Rollout environment name.") + parser.add_argument("--ray-address", default="local", help="Ray cluster address. 
Use 'local' to start one.") + parser.add_argument("--ray-namespace", default="xtuner-debug-gateway", help="Ray namespace for this debug run.") + parser.add_argument("--controller-name", default=None, help="Optional Ray actor name for the RolloutController.") + parser.add_argument( + "--ray-max-concurrency", + type=int, + default=int(os.environ.get("RAY_MAX_CONCURRENCY", 1000)), + help="max_concurrency for the RolloutController actor.", + ) + + parser.add_argument("--tensor-parallel-size", type=int, default=4) + parser.add_argument("--expert-parallel-size", type=int, default=1) + parser.add_argument("--num-workers", type=int, default=4) + parser.add_argument("--num-cpus-per-worker", type=int, default=16) + parser.add_argument("--cpu-memory-per-worker-gb", type=int, default=8) + parser.add_argument("--context-length", type=int, default=32768) + parser.add_argument("--dist-port-base", type=int, default=42000) + parser.add_argument("--api-host", default="127.0.0.1") + parser.add_argument("--api-port", type=int, default=30080) + parser.add_argument("--worker-log-dir", default=str(DEFAULT_WORK_DIR / "worker_logs")) + parser.add_argument("--placement-group-name", default="xtuner_debug_gateway_pg") + parser.add_argument( + "--ready-poll-seconds", + type=float, + default=5.0, + help="Polling interval while waiting for rollout workers to become ready.", + ) + parser.add_argument("--tool-call-parser", default="qwen3", help="Tool call parser used by the rollout backend.") + parser.add_argument("--reasoning-parser", default="qwen3", help="Reasoning parser used by the rollout backend.") + + parser.add_argument("--host", default="127.0.0.1", help="Gateway bind host.") + parser.add_argument("--port", type=int, default=8091, help="Gateway bind port.") + parser.add_argument("--log-level", default="info", help="Uvicorn log level.") + parser.add_argument( + "--capture-folder", + default=None, + help="Optional request capture folder. 
If omitted, defaults to /gateway_captures.", + ) + + return parser.parse_args() + + +def resolve_capture_output_file(capture_folder: str | Path | None) -> Path | None: + if capture_folder is None: + return None + from xtuner.v1.rl.gateway.adapters.capture import resolve_capture_output_path + + return resolve_capture_output_path(capture_folder) + + +def describe_capture_output(capture_folder: str | Path | None) -> str: + capture_output_file = resolve_capture_output_file(capture_folder) + if capture_output_file is None: + return "disabled" + return f"{capture_output_file} (requests with API keys are split into api_key_.jsonl)" + + +def init_ray(address: str, namespace: str) -> dict[str, Any]: + import ray + + ctx = ray.init(address=address, namespace=namespace, ignore_reinit_error=True) + address_info = getattr(ctx, "address_info", {}) or {} + return { + "requested_ray_address": address, + "ray_address": address_info.get("address") or address_info.get("gcs_address") or address, + "namespace": namespace, + "ray_context": address_info, + } + + +def build_rollout_config(args: argparse.Namespace): + from xtuner.v1.rl.rollout.worker import RolloutConfig + + model_path = str(args.model_path) + tokenizer_path = str(args.tokenizer_path or args.model_path) + model_name = args.model_name or Path(model_path).name.lower() + return RolloutConfig( + env=args.rollout_env, + device="GPU", + model_path=model_path, + model_name=model_name, + tokenizer_path=tokenizer_path, + tensor_parallel_size=args.tensor_parallel_size, + expert_parallel_size=args.expert_parallel_size, + context_length=args.context_length, + worker_log_dir=args.worker_log_dir, + dist_port_base=args.dist_port_base, + api_host=args.api_host, + api_port=args.api_port, + tool_call_parser=args.tool_call_parser, + reasoning_parser=args.reasoning_parser, + ) + + +def build_controller(args: argparse.Namespace): + import ray + + from xtuner.v1.rl.rollout.controller import RolloutController + from xtuner.v1.rl.utils import 
AcceleratorResourcesConfig, AutoAcceleratorWorkers + + resource_config = AcceleratorResourcesConfig( + accelerator="GPU", + num_workers=args.num_workers, + num_cpus_per_worker=args.num_cpus_per_worker, + cpu_memory_per_worker=args.cpu_memory_per_worker_gb * 1024**3, + ) + placement_group = AutoAcceleratorWorkers.build_placement_group( + resource_config, + name=args.placement_group_name, + ) + rollout_config = build_rollout_config(args) + actor_options: dict[str, Any] = { + "max_concurrency": args.ray_max_concurrency, + } + if args.controller_name: + actor_options["name"] = args.controller_name + controller = ray.remote(RolloutController).options(**actor_options).remote(rollout_config, placement_group) + print("Created rollout controller.") + return controller, placement_group + + +def wait_for_controller_ready(controller, poll_seconds: float) -> dict[str, Any]: + import ray + + while True: + ready, status = ray.get(controller.get_ready_status.remote()) + if ready: + print(f"Rollout controller ready: {status}") + return status + print(f"Waiting for rollout workers... 
{status}") + time.sleep(poll_seconds) + + +def start_gateway(args: argparse.Namespace, controller) -> None: + from xtuner.v1.rl.gateway.config import GatewayConfig + from xtuner.v1.rl.gateway.server import build_local_gateway_app, serve_gateway + + capture_folder = args.capture_folder + if capture_folder is None: + capture_folder = str(Path(args.worker_log_dir) / GatewayConfig._CAPTURE_PATH_FOLDER) + + cfg = GatewayConfig( + host=args.host, + port=args.port, + auto_start=False, + capture_folder=capture_folder, + log_level=args.log_level, + ) + + app = build_local_gateway_app(controller, config=cfg) + print(f"Starting gateway at http://{cfg.host}:{cfg.port}") + print(f"Gateway capture output: {describe_capture_output(cfg.capture_folder)}") + serve_gateway(app, cfg) + + +def cleanup_controller(controller, placement_group) -> None: + import ray + + try: + ray.get(controller.shutdown.remote(), timeout=300) + except Exception as exc: + print(f"Failed to shutdown rollout controller cleanly: {exc}", file=sys.stderr) + try: + ray.kill(controller, no_restart=True) + except Exception as exc: + print(f"Failed to kill rollout controller: {exc}", file=sys.stderr) + if placement_group is not None: + try: + ray.util.remove_placement_group(placement_group) + except Exception as exc: + print(f"Failed to remove placement group: {exc}", file=sys.stderr) + + +def main() -> None: + args = parse_args() + controller = None + placement_group = None + try: + init_info = init_ray(args.ray_address, args.ray_namespace) + print( + "Initialized Ray: " + f"requested_address={init_info['requested_ray_address']}, " + f"address={init_info['ray_address']}, namespace={init_info['namespace']}" + ) + controller, placement_group = build_controller(args) + wait_for_controller_ready(controller, args.ready_poll_seconds) + start_gateway(args, controller) + finally: + ray_module = sys.modules.get("ray") + if ray_module is not None and ray_module.is_initialized(): + if controller is not None: + 
cleanup_controller(controller, placement_group) + ray_module.shutdown() + + +if __name__ == "__main__": + main() diff --git a/.dev_scripts/rl_config_factory.py b/.dev_scripts/rl_config_factory.py deleted file mode 100644 index d8a9e42f1c..0000000000 --- a/.dev_scripts/rl_config_factory.py +++ /dev/null @@ -1,129 +0,0 @@ -from pathlib import Path -from typing import Any, Dict, Optional - -from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig -from xtuner.v1.datasets import RLTokenizeFnConfig -from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig -from xtuner.v1.ray.base import AcceleratorResourcesConfig -from xtuner.v1.ray.config.worker import RolloutConfig -from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig -from xtuner.v1.ray.evaluator import EvaluatorConfig -from xtuner.v1.ray.judger.controller import JudgerConfig -from xtuner.v1.rl.base import WorkerConfig -from xtuner.v1.rl.grpo import GRPOLossConfig -from xtuner.v1.utils.rl_test_utils import get_eos_token - - -def _filter_pydantic_kwargs(target_class: Any, kwargs: Dict) -> Dict: - accepted_keys = set(target_class.model_fields.keys()) - return {k: v for k, v in kwargs.items() if k in accepted_keys} - - -def _build_config(config_class, **kwargs): - filtered_params = _filter_pydantic_kwargs(config_class, kwargs) - return config_class(**filtered_params) - - -def get_resources_config(**kwargs) -> AcceleratorResourcesConfig: - return _build_config(AcceleratorResourcesConfig, **kwargs) - - -def get_rollout_config(**kwargs) -> RolloutConfig: - return _build_config(RolloutConfig, **kwargs) - - -def get_dataflow_config(**kwargs) -> DataFlowConfig: - return _build_config(DataFlowConfig, **kwargs) - - -def get_replay_buffer_config(tokenizer: Any, **kwargs) -> ReplayBufferConfig: - tokenizer_config = RLTokenizeFnConfig(max_length=kwargs["max_prompt_length"]) - train_dataset = DatasetConfig(anno_path=kwargs["data_path"]) - train_dataset_cfg = [{"dataset": train_dataset, 
"tokenize_fn": tokenizer_config}] - dataloader_config = DataloaderConfig(collator="fake_collator", pack_level="none") - return ReplayBufferConfig( - dataset_cfg=train_dataset_cfg, - dataloader_cfg=dataloader_config, - tokenizer=tokenizer, - postprocessor_func=kwargs.get("filter_func"), - ) - - -def get_dapo_judger_config(tokenizer: Any, **kwargs): - dapo_defaults_args = { - "enable_overlong_buffer": True, - "overlong_buffer_len": 4096, - "overlong_penalty_factor": 1.0, - } - dapo_config_params = {**dapo_defaults_args, **kwargs} - eos_token_id = get_eos_token(kwargs["model_path"]) - eos_token_str = tokenizer.convert_ids_to_tokens(eos_token_id) - from xtuner.v1.ray.judger.dapo_math import DapoMathJudgerConfig - - filtered_params = _filter_pydantic_kwargs(DapoMathJudgerConfig, dapo_config_params) - dapomath_judger_config = DapoMathJudgerConfig( - judger_name="dapo_math", - eos_token=eos_token_str, - max_response_len=kwargs["max_response_length"], - tokenizer=tokenizer, - **filtered_params, - ) - return JudgerConfig(reward_judger_configs=[dapomath_judger_config]) - - -def get_gsm8k_judger_config(): - from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig - - gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k") - judger_cfg = JudgerConfig(reward_judger_configs=[gsm8k_judger_config]) - return judger_cfg - - -def get_evaluator_config(tokenizer: Any, **kwargs) -> Optional[EvaluatorConfig]: - if not kwargs["enable_evaluate"]: - return None - - eval_dataset = DatasetConfig(anno_path=kwargs["eval_data_path"]) - tokenizer_config = RLTokenizeFnConfig(max_length=kwargs["max_prompt_length"]) - eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}] - - filtered_params = _filter_pydantic_kwargs(EvaluatorConfig, kwargs) - - return EvaluatorConfig( - dataset_cfg=eval_dataset_cfg, - tokenizer=tokenizer, - **filtered_params, - ) - - -def get_train_worker_config(**kwargs) -> WorkerConfig: - from xtuner.v1.model import get_model_config_from_hf - 
- model_cfg = get_model_config_from_hf(Path(kwargs["model_path"])) - defaults = { - "optim_cfg": AdamWConfig(lr=1e-6, betas=(0.9, 0.999), max_grad_norm=1.0, weight_decay=0.1, foreach=False), - "loss_cfg": GRPOLossConfig( - policy_loss_cfg={ - "cliprange_high": 0.28, - "cliprange_low": 0.2, - "loss_type": "vanilla", - "clip_ratio_c": 10.0, - "log_prob_diff_min": -20.0, - "log_prob_diff_max": 20.0, - }, - ignore_idx=-100, - use_kl_loss=False, - kl_loss_coef=0.0, - kl_loss_type="low_var_kl", - mode="chunk", - chunk_size=512, - ), - "lr_cfg": LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6), - "fsdp_cfg": FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1), - "sp_size": 1, - "optimizer_steps": 16, - "pack_max_length": 4096, - } - config_params = {**defaults, **kwargs} - filtered_params = _filter_pydantic_kwargs(WorkerConfig, config_params) - return WorkerConfig(load_from=config_params["model_path"], model_cfg=model_cfg, **filtered_params) diff --git a/.github/workflows/e2e_test.yaml b/.github/workflows/e2e_test.yaml index ad3e376ddc..5be1d1a0a2 100644 --- a/.github/workflows/e2e_test.yaml +++ b/.github/workflows/e2e_test.yaml @@ -5,6 +5,9 @@ permissions: pages: write on: + pull_request: + branches: + - "rl_design" workflow_dispatch: inputs: repo_org: @@ -29,16 +32,16 @@ jobs: run: sudo git clean -ffdx - name: Clone repository uses: actions/checkout@v2 - with: - repository: ${{ github.event.inputs.repo_org || 'InternLM/xtuner' }} - ref: ${{github.event.inputs.repo_ref || 'main'}} + #with: + #repository: ${{ github.event.inputs.repo_org || 'InternLM/xtuner' }} + #ref: ${{github.event.inputs.repo_ref || 'main'}} - name: run-test run: | source /mnt/shared-storage-user/opencompass-shared/qa-llm-cicd/miniconda3/bin/activate conda activate clusterx conda env list unset HTTP_PROXY;unset HTTPS_PROXY;unset http_proxy;unset https_proxy; - pytest autotest/test_all.py -m all -n 1 -vv --run_id ${{ github.run_id }} + pytest 
autotest/test_all.py::test_all[qwen3-rl-lmdeploy] -m all -n 1 -vv --run_id ${{ github.run_id }} - name: Upload Artifacts if: ${{ !cancelled() }} @@ -49,9 +52,12 @@ jobs: retention-days: 7 name: xtuner-e2e-${{ github.run_id }} + - name: Copy deploy action + run: cp -r ../JamesIves ./ + - name: Deploy to GitHub Pages if: ${{ !cancelled() }} - uses: JamesIves/github-pages-deploy-action@v4 + uses: ./JamesIves/github-pages-deploy-action/v4 with: token: ${{ github.token }} branch: gh-pages diff --git a/.github/workflows/unit_test.yaml b/.github/workflows/unit_test.yaml index 488cd141c2..3964b6cb08 100644 --- a/.github/workflows/unit_test.yaml +++ b/.github/workflows/unit_test.yaml @@ -4,6 +4,7 @@ on: branches: - "main" - "refactor" + - "rl_design" paths-ignore: - "docs/**" - "**.md" @@ -49,4 +50,4 @@ jobs: - name: unit-test run: | export PYTHONPYCACHEPREFIX=/tmp - python ci/scripts/xtuner_unittest.py "$IMAGE" "source ${{env.WORKSPACE_PREFIX}}/BASE_ENV.sh;source ci/scripts/CI_ENV.sh" "pytest tests" + python ci/scripts/xtuner_unittest.py "$IMAGE" "source ${{env.WORKSPACE_PREFIX}}/BASE_ENV.sh;source ci/scripts/CI_ENV.sh" "pytest -s tests/rl --ignore=tests/rl/test_evaluator.py --ignore=tests/rl/test_rl_trainer.py --ignore=tests/rl/test_vl_rollout.py --ignore=tests/rl/test_rl_train_with_sft.py" diff --git a/.gitignore b/.gitignore index c1873d13a1..d029bea53c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,8 @@ +zdev/ +old/ +bak/ +exp*/ + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/autotest/config.yaml b/autotest/config.yaml index b33bfa82f9..730d551713 100644 --- a/autotest/config.yaml +++ b/autotest/config.yaml @@ -188,3 +188,46 @@ case: lr: 0 runtime_info/tgs: 0.05 timeout: 10800 + + qwen3-rl-lmdeploy: + - + type: rl + parameters: + config: autotest/config/rl_qwen3_gsk8k_grpo.py + infer_backend: lmdeploy + output_path: /mnt/shared-storage-user/llmrazor-share/qa-llm-cicd/test_output + resource: + envs: + - 
MODEL_PATH=/mnt/shared-storage-user/llmrazor-share/model/Qwen3-8B + - DATA_PATH=/mnt/shared-storage-user/llmrazor-share/data/gsm8k/train-mini.jsonl + - EVAL_DATA_PATH=/mnt/shared-storage-user/llmrazor-share/data/gsm8k/test.jsonl + - XTUNER_DETERMINISTIC=true + assert_info: + base_metric: qwen3-rl-lmdeploy/20260203/tracker.jsonl + check_metrics: + - + metric: eval/accuracy + threshold: 0.05 + method: absolute + operator: < + - + metric: response/rewards/mean + threshold: 0.1 + method: absolute + operator: < + - + metric: mismatch/mismatch_k3_kl + threshold: 0.0001 + method: absolute + operator: <= + - + metric: response/response_len/mean + threshold: 0.12 + method: relative + operator: < + - + metric: time/step + threshold: 10 + method: absolute + operator: < + timeout: 2460 diff --git a/autotest/config/rl_qwen3_gsk8k_grpo.py b/autotest/config/rl_qwen3_gsk8k_grpo.py deleted file mode 100644 index 18f4eb92a7..0000000000 --- a/autotest/config/rl_qwen3_gsk8k_grpo.py +++ /dev/null @@ -1,162 +0,0 @@ -import os -from copy import deepcopy -from pathlib import Path - -from transformers import AutoTokenizer -from xtuner.v1.config import ( - AdamWConfig, - FSDPConfig, - LRConfig, -) -from xtuner.v1.data_proto.rl_data import SampleParams -from xtuner.v1.datasets import RLTokenizeFnConfig -from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig -from xtuner.v1.model import get_model_config_from_hf -from xtuner.v1.ray.base import AcceleratorResourcesConfig -from xtuner.v1.ray.config.worker import RolloutConfig -from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig -from xtuner.v1.ray.evaluator import EvaluatorConfig -from xtuner.v1.ray.judger.controller import JudgerConfig -from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig -from xtuner.v1.rl.base import WorkerConfig -from xtuner.v1.rl.grpo import GRPOLossConfig -from xtuner.v1.train.rl_trainer import RLTrainerConfig - - -work_dir = os.environ["WORK_DIR"] -model_path = os.environ["MODEL_PATH"] 
-data_path = os.environ["DATA_PATH"] -eval_data_path = os.environ["EVAL_DATA_PATH"] -enable_evaluate = True if eval_data_path != "" else False - -# basic settings -experimental_name = "grpo_gsm8k_tiny" -total_epochs = 3 -global_batch_size = 64 -prompt_repeat_k = 5 -rollout_tp_size = 1 -rollout_ep_size = 1 -max_prompt_length = 512 -max_response_length = 1024 -pack_max_length = 32768 -train_optimizer_steps = 1 -evaluate_step = 15 - -# 1. resources -resources = AcceleratorResourcesConfig( - accelerator="GPU", - num_workers=8, - num_cpus_per_worker=12, - cpu_memory_per_worker=16 * 1024**3, # 16 GB -) - -# 2. rollout -rollout_config = RolloutConfig( - env=experimental_name, - device=resources.accelerator, - model_path=model_path, - dtype="bfloat16", - tensor_parallel_size=rollout_tp_size, - expert_parallel_size=rollout_ep_size, - gpu_memory_utilization=0.75, - context_length=max_prompt_length + max_response_length, - rollout_max_batch_size_per_instance=1024, -) - -# sampling params -training_sample_params = SampleParams( - max_tokens=max_response_length, -) -evaluation_sample_params = deepcopy(training_sample_params) -evaluation_sample_params.top_p = 1.0 -evaluation_sample_params.temperature = 0.0 -evaluation_sample_params.top_k = 1 - -# dataset: 不需要修改 -train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path) -eval_dataset = DatasetConfig(name=experimental_name, anno_path=eval_data_path) if enable_evaluate else None -tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) -tokenizer_config = RLTokenizeFnConfig(max_length=max_prompt_length) - -train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}] -eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}] if enable_evaluate else [] - -dataloader_config = DataloaderConfig(pack_max_length=pack_max_length, collator="fake_collator", pack_level="none") - -# 3. 
judger -gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k") -judger_cfg = JudgerConfig(reward_judger_configs=[gsm8k_judger_config]) - -# 4. dataflow and evaluator -dataflow_config = DataFlowConfig( - env=experimental_name, - prompt_repeat_k=prompt_repeat_k, - global_batch_size=global_batch_size, - sample_params=training_sample_params, - max_concurrent=512, -) - -evaluator_cfg = ( - EvaluatorConfig( - enable_evaluate=enable_evaluate, - enable_initial_evaluate=True, - dataset_cfg=eval_dataset_cfg, - tokenizer=tokenizer, - evaluate_step=evaluate_step, - compute_metric_func=None, - sample_params=evaluation_sample_params, - ) - if enable_evaluate - else None -) - -# replay buffer config: : 不需要修改 -replay_buffer_cfg = ReplayBufferConfig( - dataset_cfg=train_dataset_cfg, dataloader_cfg=dataloader_config, tokenizer=tokenizer -) - -# 5. Train worker -model_cfg = get_model_config_from_hf(Path(model_path)) -optim_cfg = AdamWConfig(lr=1e-6, foreach=False) -loss_cfg = GRPOLossConfig( - policy_loss_cfg=dict( - cliprange_high=0.2, - cliprange_low=0.2, - loss_type="vanilla", - ), - ignore_idx=-100, - use_kl_loss=True, - kl_loss_coef=0.001, - kl_loss_type="low_var_kl", - mode="chunk", - chunk_size=512, -) -lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) -fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) -train_worker_cfg: WorkerConfig = WorkerConfig( - model_cfg=model_cfg, - load_from=model_path, - optim_cfg=optim_cfg, - loss_cfg=loss_cfg, - lr_cfg=lr_cfg, - fsdp_cfg=fsdp_cfg, - sp_size=1, - optimizer_steps=train_optimizer_steps, - pack_max_length=pack_max_length, -) - -# 6. 
RL Trainer -trainer = RLTrainerConfig( - load_from=model_path, - resources=resources, - rollout_config=rollout_config, - dataflow_config=dataflow_config, - judger_config=judger_cfg, - replay_buffer_config=replay_buffer_cfg, - evaluator_config=evaluator_cfg, - train_worker_config=train_worker_cfg, - tokenizer_path=model_path, - work_dir=work_dir, - total_epochs=total_epochs, - exp_tracker="jsonl", -) diff --git a/autotest/config/rl_qwen3_gsm8k_grpo.py b/autotest/config/rl_qwen3_gsm8k_grpo.py new file mode 100644 index 0000000000..cce29c9ada --- /dev/null +++ b/autotest/config/rl_qwen3_gsm8k_grpo.py @@ -0,0 +1,198 @@ +"""RL Colocate Trainer 示例配置(GRPO + GSM8K)。 + +用法:通过环境变量传入路径后,由 CLI 加载本配置并 trainer_cfg.build().fit()。 需设置: WORK_DIR, MODEL_PATH, DATA_PATH, EVAL_DATA_PATH 可选: +WORLD_SIZE, ENABLE_RETURN_ROUTED_EXPERTS, LOSS_TYPE, LOSS_MODE, SP_SIZE +""" + +import os +from pathlib import Path + +from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig +from xtuner.v1.data_proto import SampleParams +from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig +from xtuner.v1.datasets.rl_tokenize_fn import RLTextTokenizeFnConfig +from xtuner.v1.model import get_model_config_from_hf +from xtuner.v1.rl.agent_loop import ( + AgentLoopManagerConfig, + SamplerConfig, + SingleTurnAgentLoopConfig, + SyncProduceStrategyConfig, +) +from xtuner.v1.rl.evaluator import EvaluatorConfig +from xtuner.v1.rl.judger.gsm8k import GSM8KJudgerConfig +from xtuner.v1.rl.loss import GRPOLossConfig +from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.rl.trainer import WorkerConfig +from xtuner.v1.rl.utils import AcceleratorResourcesConfig +from xtuner.v1.train.rl_colocate_trainer import RLColocateTrainerConfig + + +# env +work_dir = os.environ["WORK_DIR"] +model_path = os.environ["MODEL_PATH"] +data_path = os.environ["DATA_PATH"] +eval_data_path = os.environ["EVAL_DATA_PATH"] +enable_return_routed_experts = os.environ.get("ENABLE_RETURN_ROUTED_EXPERTS", "0") 
+WORLD_SIZE = int(os.environ.get("WORLD_SIZE", "1")) + +# basic settings +experimental_name = "grpo_gsm8k" +rollout_steps = 45 +evaluate_step = 45 +train_optimizer_steps = 1 +global_batch_size = 64 * train_optimizer_steps +prompt_repeat_k = 5 +rollout_tp_size = 1 +rollout_ep_size = 1 +max_prompt_length = 512 +max_response_length = 1024 +pack_max_length = 32 * 1024 + +# 1. resources +resources = AcceleratorResourcesConfig( + accelerator="GPU", + num_workers=8 * WORLD_SIZE, + num_cpus_per_worker=12, + cpu_memory_per_worker=16 * 1024**3, # 16 GB +) + +# 2. rollout +rollout_config = RolloutConfig( + env=experimental_name, + device=resources.accelerator, + model_path=model_path, + dtype="bfloat16", + tensor_parallel_size=rollout_tp_size, + expert_parallel_size=rollout_ep_size, + gpu_memory_utilization=0.8, + context_length=max_response_length + max_prompt_length, + enable_return_routed_experts=(enable_return_routed_experts == "1"), +) + +# 3. judger +judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k", num_ray_actors=1) + +# 4. 
train worker +lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) +fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) +model_cfg = get_model_config_from_hf(Path(model_path)) +if hasattr(model_cfg, "balancing_loss_cfg"): + model_cfg.balancing_loss_cfg = None +if hasattr(model_cfg, "z_loss_cfg"): + model_cfg.z_loss_cfg = None +optim_cfg = AdamWConfig(lr=1e-6, foreach=False, weight_decay=0.1) +loss_cfg = GRPOLossConfig( + policy_loss_cfg=dict( + cliprange_high=0.28, + cliprange_low=0.2, + loss_type=os.environ.get("LOSS_TYPE", "vanilla"), + clip_ratio_c=10.0, + log_prob_diff_min=-20.0, + log_prob_diff_max=20.0, + ), + ignore_idx=-100, + use_kl_loss=False, + kl_loss_coef=0.0, + kl_loss_type="low_var_kl", + mode=os.environ.get("LOSS_MODE", "chunk"), + chunk_size=512, +) +train_worker_cfg = WorkerConfig( + model_cfg=model_cfg, + load_from=model_path, + optim_cfg=optim_cfg, + loss_cfg=loss_cfg, + lr_cfg=lr_cfg, + fsdp_cfg=fsdp_cfg, + sp_size=int(os.environ.get("SP_SIZE", "1")), + optimizer_steps=train_optimizer_steps, + pack_max_length=pack_max_length, +) + +# 5. 
train agent loop manager +train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path) +tokenizer_config = RLTextTokenizeFnConfig(max_length=max_prompt_length) +train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}] +dataloader_cfg = DataloaderConfig( + dataset_config_list=train_dataset_cfg, + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", +) +sampler_config = SamplerConfig( + dataloader_cfg=dataloader_cfg, + prompt_repeat_k=prompt_repeat_k, +) +training_sample_params = SampleParams( + max_tokens=max_response_length, + top_k=0, + top_p=1.0, + temperature=1.0, + min_tokens=0, +) +agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=training_sample_params, +) +produce_strategy_config = SyncProduceStrategyConfig() +agent_loop_manager_cfg = AgentLoopManagerConfig( + task_name="train_task", + agent_loop_config=agent_loop_config, + produce_strategy_config=produce_strategy_config, + sampler_config=sampler_config, +) + +# 6. eval agent loop manager +eval_dataset = DatasetConfig(name=experimental_name, anno_path=eval_data_path, sample_ratio=1.0) +eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}] +eval_dataloader_cfg = DataloaderConfig( + dataset_config_list=eval_dataset_cfg, + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", +) +eval_sampler_config = SamplerConfig( + dataloader_cfg=eval_dataloader_cfg, + prompt_repeat_k=1, +) +evaluation_sample_params = SampleParams( + max_tokens=max_response_length, + top_k=1, + top_p=1.0, + temperature=0.0, + min_tokens=0, +) +eval_agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=evaluation_sample_params, +) +eval_agent_loop_manager_cfg = AgentLoopManagerConfig( + task_name="eval_task", + agent_loop_config=eval_agent_loop_config, + sampler_config=eval_sampler_config, +) + +# 7. 
evaluator +evaluator_config = EvaluatorConfig(compute_metric_func=None) + +# 8. RL Colocate Trainer Config(CLI 通过 config["trainer"].build() 得到 Trainer) +trainer = RLColocateTrainerConfig( + resources=resources, + train_worker_cfg=train_worker_cfg, # TODO: uniform naming of cfg and config + rollout_config=rollout_config, + judger_config=judger_config, + tokenizer_path=model_path, + replay_buffer_config=dict(), + agent_loop_manager_cfg=agent_loop_manager_cfg, + eval_agent_loop_manager_cfg=eval_agent_loop_manager_cfg, + evaluator_config=evaluator_config, + load_from=model_path, + rollout_steps=rollout_steps, + global_batch_size=global_batch_size, + enable_evaluate=True, + enable_initial_evaluate=False, + evaluate_step=evaluate_step, + work_dir=work_dir, + seed=123, + debug_rollout=False, +) diff --git a/ci/scripts/test_ray_sft.py b/ci/scripts/test_ray_sft.py index 10943072a8..64fac48762 100644 --- a/ci/scripts/test_ray_sft.py +++ b/ci/scripts/test_ray_sft.py @@ -12,8 +12,8 @@ from xtuner.v1.model.moe.moe import BalancingLossConfig, ZLossConfig from xtuner.v1.datasets import FTDPTokenizeFnConfig import ray -from xtuner.v1.rl.base.worker import TrainingWorker -from xtuner.v1.ray.base import AutoAcceleratorWorkers, AcceleratorResourcesConfig +from xtuner.v1.rl.trainer import TrainingWorker +from xtuner.v1.rl.utils import AutoAcceleratorWorkers, AcceleratorResourcesConfig from xtuner.v1.train import TrainerConfig from xtuner.v1.train.trainer import Trainer from xtuner.v1.loss.ce_loss import CELossConfig diff --git a/design/component_rl.py b/design/component_rl.py new file mode 100644 index 0000000000..af49c5c7ed --- /dev/null +++ b/design/component_rl.py @@ -0,0 +1,363 @@ +################################### imports ###################################### +from typing import Any, Callable +import asyncio +from enum import Enum +from torch.utils.data import DataLoader +import threading +from typing import List +from collections import deque + +from 
xtuner.v1.ray.rollout.controller import SampleParams +from xtuner.v1.data_proto.rl_data import SampleParams # TODO: 删掉一个? +from xtuner.v1.data_proto.sequence_context import SequenceContext +from xtuner.v1.loss.base_loss_ctx import BaseLossContext + +def load_tokenizer(hf_checkpoint, trust_remote_code=True): ... +def load_processor(hf_checkpoint, trust_remote_code=True): ... + +class PlacementGroup: ... + +def log_metrics(metrics: dict): ... + +class TrainItem: + seq_ctx: SequenceContext + loss_ctxs: BaseLossContext # 考虑更通用的多 loss 场景,时间原因暂时不改 + + +################################### Main components ###################################### +class Status(Enum): + INIT = "init" + COMPLETED = "completed" + ABORTED = "aborted" + FAILED = "failed" + + +class RolloutState: # RolloutState: + # message: list + tokens: list[int] # 每一次实际输入 + + uid: int + session_id: int | None = None + prompt_ids: list[int] + response: str + response_ids: list[int] # 每一次实际输出,覆盖写 + logprobs: list[float] + routed_experts: list[int] | None = None + reward: float | list[float] | list[dict] | None = None + loss_mask: list[int] | None = None # tokens + response_ids的长度 + state: Status = Status.INIT + sample_parms: SampleParams | None = None + tools: list | None = None + tool_choice: str | None = None + mm_infer_info: dict[str, Any] + mm_train_info: dict[str, Any] + finish_reason: str | None = None + staleness: int = 0 + extra_fields: dict[str, Any] = {} + + +class RolloutController: + async def generate(self, rollout_state: RolloutState) -> RolloutState: ... + + +class Judge: + def judge(self, rollout_state: RolloutState) -> RolloutState: ... 
# Generates one trajectory or one group of trajectories; intentionally minimal.
class AgentLoop:
    """Drive rollout generation for a single sample or a group of samples.

    Args:
        rollout_ctl: Controller that performs the actual generation.
        hf_checkpoint: HF checkpoint path used to load the tokenizer/processor.
        sample_params: Sampling parameters. A fresh ``SampleParams()`` is built
            per instance when omitted; the previous ``sample_params=SampleParams()``
            default was evaluated once at def time and shared by every instance.
        judge_cfg: When not ``None``, a ``Judge`` is attached to score rollouts.
    """

    def __init__(
        self,
        rollout_ctl: "RolloutController",
        hf_checkpoint,
        sample_params: "SampleParams | None" = None,
        judge_cfg: "dict | None" = None,
    ) -> None:
        self.rollout_ctl = rollout_ctl
        self.hf_checkpoint = hf_checkpoint
        self.tokenizer = load_tokenizer(hf_checkpoint, trust_remote_code=True)
        self.processor = load_processor(hf_checkpoint, trust_remote_code=True)
        # Per-instance default instead of one object shared across all instances.
        self.sample_params = SampleParams() if sample_params is None else sample_params
        self.judge = Judge() if judge_cfg is not None else None
        self.task_name = 'aa'  # TODO(review): placeholder task name — confirm intended value

    async def generate_sample(self, rollout_state: "RolloutState") -> "RolloutState": ...

    async def generate_group(self, rollout_states: "list[RolloutState]") -> "list[RolloutState]":
        """Generate all samples of a group concurrently; output order matches input."""
        return await asyncio.gather(*(self.generate_sample(state) for state in rollout_states))


class SingleTurnAgentLoop(AgentLoop):
    async def generate_sample(self, rollout_state: "RolloutState") -> "RolloutState":
        """One-shot generation, optionally scored by the attached judge."""
        rollout_state = await self.rollout_ctl.generate(rollout_state)
        if self.judge is not None:
            rollout_state = self.judge.judge(rollout_state)
        return rollout_state


class MultiTurnAgentLoop(AgentLoop):
    async def generate_sample(self, rollout_state: "RolloutState") -> "RolloutState":
        ...


# Centralised store for all rollout-time data. Mid-training data is not kept
# here for now; this may later be unified into a global data manager.
# It only stores data, it does not control it. A backend interface may be
# abstracted later to support different storage backends.
# Open question: does this need to be a Ray object?
class Storage:
    # NOTE: annotations are string forward references — ``StorageIndices`` is
    # not defined anywhere in this file, so an eager annotation would raise
    # NameError when the class body is executed.
    async def put(self, items: "list[RolloutState]", storage_indices: "StorageIndices"): ...
    async def get(self, count: int, storage_indices: "StorageIndices") -> "list[RolloutState]": ...
    def __len__(self): ...
+ +class FIFOBackend(Storage): # 同步RL用 + limit: int = 0 + _storage: deque[RolloutState] = deque(maxlen=limit) + + +class StalenessBackend(Storage): # 异步RL用 + max_staleness : int + min_staleness : int + +class BaseReplayBuffer: + def __init__(self, storage_backend: Storage): + self._storage = storage_backend + + async def put(self, items: list[RolloutState], task_name: str, **kwargs): ... + + async def get(self, batch_size: int, task_name: str, group_state: Status, **kwargs) -> list[RolloutState]: ... + +class Sampler: + prompt_k: int + + async def sample(self) -> list[RolloutState]: ... + +class SamplerWithReplayBuffer(Sampler): + replay_buffer: BaseReplayBuffer + async def sample(self) -> list[RolloutState]: ... + +class ProduceStrategy: # Scheduler负责调度多个样本的生成,里面可以有超发、异步、重排长短样本等优化 + replay_buffer: BaseReplayBuffer + + async def produce_batch(self, agent_loop: AgentLoop, sampler: Sampler, batch_size: int): ... + + +class SyncProduceStrategy(ProduceStrategy): + async def produce_batch(self, agent_loop: AgentLoop, sampler: Sampler, batch_size: int): + # TODO: 将 batch_size 封装成 while stop_condition ? 
+ data_concurrency = batch_size + + rollout_states = await sampler.sample() + + pending_tasks = [] + for _ in range(data_concurrency): + task = asyncio.create_task(agent_loop.generate_group(rollout_states)) + pending_tasks.append(task) + + completed_sample_count = 0 + while completed_sample_count < data_concurrency: + if not pending_tasks: + print("All tasks are done but not enough samples collected.") + break + done_tasks, pending_tasks = await asyncio.wait(pending_tasks, timeout=1, return_when=asyncio.FIRST_COMPLETED) + + # 如果要过滤,在这个地方处理,然后加入到 replay buffer + # 如果被过滤的数据就放到 put_to_filtered pool 中 + for task in done_tasks: + try: + await self.replay_buffer.put(task.result(), task_name=agent_loop.task_name) + completed_sample_count += 1 + except Exception as e: + print(f"Error in generating trajectory: {e}") + + +class AsyncProduceStrategy(ProduceStrategy): + staleness_threshold: float = 0.0 + enable_partial_rollout: bool = False + tail_batch_trigger_size: int = 0 + tail_batch_candidate_step: int = 0 + + async def produce_batch(self, agent_loop: AgentLoop, sampler: SamplerWithReplayBuffer, batch_size: int): + # hack sampler with replay buffer + rollout_state = await sampler.sample() + ... 
# Supports a single task.
class AgentLoopManager:
    """Bind one agent loop, produce strategy, sampler and replay buffer together."""

    def __init__(
        self,
        agent_loop: "AgentLoop",
        produce_strategy: "ProduceStrategy",
        sampler: "Sampler",
        replay_buffer: "BaseReplayBuffer",
    ):
        # One-to-one binding.
        self._agent_loop: "AgentLoop" = agent_loop                  # generates one sample or one group
        self._produce_strategy: "ProduceStrategy" = produce_strategy  # generates + schedules a batch
        self._sampler: "Sampler" = sampler                          # draws prompts
        self._replay_buffer: "BaseReplayBuffer" = replay_buffer

    # Colocated mode (rollout and training share devices).
    async def produce_batch(self, batch_size: int):
        await self._produce_strategy.produce_batch(self._agent_loop, self._sampler, batch_size)
        return await self._replay_buffer.get(
            batch_size, task_name=self._agent_loop.task_name, group_state=Status.COMPLETED
        )

    # Disaggregated mode (meant to run continuously on a dedicated thread/task).
    async def disaggregate_produce_batch(self, batch_size: int):
        # BUG FIX: the coroutine was previously created but never awaited,
        # so nothing was ever produced (and Python warned "never awaited").
        await self._produce_strategy.produce_batch(self._agent_loop, self._sampler, batch_size)

    async def disaggregate_get_batch(self, batch_size: int):
        # Draw a completed batch from the replay buffer for training.
        return await self._replay_buffer.get(
            batch_size, task_name=self._agent_loop.task_name, group_state=Status.COMPLETED
        )


# For multiple tasks, compose managers yourself.
class MulitiAgentLoopManager(AgentLoopManager):  # NOTE(review): "Muliti" typo kept — public name, callers may exist
    def __init__(self, agent_loop_managers: "list[AgentLoopManager]"):
        # Deliberately does not call super().__init__: this variant only fans
        # out to the wrapped per-task managers and has no bindings of its own.
        self._agent_loop_managers = agent_loop_managers

    async def produce_batch(self, batch_size: int):
        pass

    async def disaggregate_produce_batch(self, batch_size: int):
        # BUG FIX: previously indexed ``self._produce_strategy[0]`` /
        # ``self._agent_loop`` / ``self._sampler`` — none of which exist on
        # this subclass (AttributeError). Delegate to the wrapped managers.
        for manager in self._agent_loop_managers:
            await manager.disaggregate_produce_batch(batch_size)

    async def disaggregate_get_batch(self, batch_size: int):
        pass


# 1. GRPO invokes AdvantageEstimator inside TrainController.
# 2. PPO invokes AdvantageEstimator inside TrainWorker and adds extra methods
#    (e.g. compute_ref_logprobs, compute_values) without rewriting
#    TrainController or TrainWorker.
class AdvantageEstimator:
    def compute_advantages(self, batch: "list[TrainItem]") -> "list[TrainItem]": ...
+ + +# ppo 算法是通过在 Trainworker 中新增额外方法实现,无需重写 TrainController 和 Trainworker +class TrainController: + advantage_estimator: AdvantageEstimator + # high level API + def fit(self, batch: list[TrainItem]) -> dict: ... + def train(self, batch: list[TrainItem]) -> dict: ... + + +class CheckpointEngine: + rollout_ctl: RolloutController + train_ctl: TrainController + def update_weights(self): ... + +class Evaluator: # 根据rollout输出的batch,计算评估指标。本身并不负责rollout。 + def evaluate(self, batch: list[RolloutState]) -> dict: ... + + +################################### Usage example with components ######################################### +# 弱化Trainer:Trainer中代码尽量少,尽量用componet来组织代码。下面是几种典型Trainer的组织方式。 + +class Packer: + def pack_pad_dispatch(self, samples: list[RolloutState]) -> list[TrainItem]: ... + + +def main_colocate_with_train_highlevel(): + # rollout_ctl, train_ctl, data_mgr, env, evaluator等对象都是主进程中本地对象,并不是ray actor。这样: + # 1. 保证一大部分的数据传递无需跨机传输,方便统一管理 + # 2. 减少ray引入的debug和维护难度 + pg: PlacementGroup + rollout_ctl: RolloutController(pg) + train_ctl: TrainController(pg) + checkpoint_engine: CheckpointEngine(rollout_ctl, train_ctl) + + global_batch_size: int + dataloader: DataLoader + hf_checkpoint: str + agent_loop: AgentLoop = AgentLoop(rollout_ctl, hf_checkpoint) + produce_strategy: AsyncProduceStrategy = AsyncProduceStrategy(replay_buffer) + sampler: SamplerWithReplayBuffer + replay_buffer: BaseReplayBuffer + env: AgentLoopManager = AgentLoopManager(agent_loop, produce_strategy, sampler, replay_buffer) + + eval_batch_size: int + eval_produce_strategy: ProduceStrategy = SyncProduceStrategy(eval_replay_buffer) + eval_sampler: Sampler = Sampler(replay_buffer) + eval_replay_buffer: BaseReplayBuffer + eval_env: AgentLoopManager = AgentLoopManager(agent_loop, eval_produce_strategy, eval_sampler, eval_replay_buffer) + evaluator: Evaluator + total_rollouts: int + + for i in range(total_rollouts): + train_batch: list[RolloutState] = asyncio.run(env.produce_batch(global_batch_size)) + + 
train_batch = Packer.pack_pad_dispatch(train_batch) + + metrics = train_ctl.fit(train_batch) + log_metrics(metrics) + + checkpoint_engine.update_weights() + + eval_batch: list[RolloutState] = asyncio.run(eval_env.produce_batch(eval_batch_size)) + eval_metrics = evaluator.evaluate(eval_batch) + log_metrics(eval_metrics) + + +def main_colocate_with_train_lowlevel(): + data_mgr: DataManager + pg: PlacementGroup + rollout_ctl: RolloutController(pg) + env: AgentLoopManager(rollout_ctl) + train_ctl: TrainController(pg) + checkpoint_engine: CheckpointEngine(rollout_ctl, train_ctl) + + eval_data_mgr: DataManager + evaluator: Evaluator + total_rollouts: int + + for i in range(total_rollouts): + asyncio.run(env.produce_batch(data_mgr)) + + batch: list[TrainItem] = data_mgr.get_batch() + + # below is equivalent to train_ctl.fit(batch) + batch = Packer.pack_pad_dispatch(batch) + batch = train_ctl.compute_old_logprobs(batch) + batch = train_ctl.compute_ref_logprobs(batch) + batch = train_ctl.compute_values(batch) + batch = train_ctl.compute_advantages(batch) # TODO: AdvEstimator + metrics = train_ctl.train(batch) + + log_metrics(metrics) + + checkpoint_engine.update_weights() + + env.produce_batch(eval_data_mgr) + eval_metrics = evaluator.evaluate(eval_data_mgr.get_batch()) + log_metrics(eval_metrics) + + +def main_separate(): + data_mgr: DataManager + pg1: PlacementGroup + rollout_ctl: RolloutController(pg1) + # pg1_2: PlacementGroup + # rollout_ctl_2: RolloutController(pg1_2) + env: AgentLoopManager(rollout_ctl) # Environment(rollout_ctl, rollout_ctl_2) + + pg2: PlacementGroup + train_ctl: TrainController(pg2) + checkpoint_engine: CheckpointEngine(rollout_ctl, train_ctl) + + eval_data_mgr: DataManager + evaluator: Evaluator + + producer_thread = threading.Thread(target=env.produce_loop, args=(data_mgr,)) + producer_thread.start() + + total_rollouts: int + for i in range(total_rollouts): + batch: list[TrainItem] = data_mgr.get_batch() + metrics = train_ctl.fit(batch) + 
log_metrics(metrics) + + checkpoint_engine.update_weights() + + env.produce_batch(eval_data_mgr) # 优先级高于env.produce_loop + eval_metrics = evaluator.evaluate(eval_data_mgr.get_batch()) + log_metrics(eval_metrics) diff --git a/docs/en/rl/advanced_tutorial/loss.md b/docs/en/rl/advanced_tutorial/loss.md index 130517d2e3..c8f54df849 100644 --- a/docs/en/rl/advanced_tutorial/loss.md +++ b/docs/en/rl/advanced_tutorial/loss.md @@ -9,8 +9,7 @@ All loss calculations in XTuner involve two core components: `LossConfig` and `L ```python import torch import torch.nn as nn -from xtuner.v1.rl.grpo import GRPOLossConfig, GRPOLossContext -from xtuner.v1.rl.base import RLLossContextInputItem +from xtuner.v1.rl.loss import GRPOLossConfig, GRPOLossContext, RLLossContextInputItem from xtuner.v1.data_proto import SequenceContext def gather_logprobs(logits, shifted_labels): diff --git a/docs/en/rl/tutorial/rl_grpo_trainer.md b/docs/en/rl/tutorial/rl_grpo_trainer.md index 95f0efd5a5..fa37a424e7 100644 --- a/docs/en/rl/tutorial/rl_grpo_trainer.md +++ b/docs/en/rl/tutorial/rl_grpo_trainer.md @@ -91,7 +91,7 @@ If you need more fine-grained control (such as distributed inference, inference ```{code-block} python :caption: Configure Inference Environment -from xtuner.v1.ray.config.worker import RolloutConfig +from xtuner.v1.rl.rollout.worker import RolloutConfig model_path = "/path/to/qwen3-8B" # Replace with your model path @@ -143,8 +143,8 @@ For more configuration parameters, please refer to the API documentation: {class :caption: Configure Training Strategy from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig from xtuner.v1.model.dense.qwen3 import Qwen3Dense8BConfig -from xtuner.v1.rl.base import WorkerConfig -from xtuner.v1.rl.grpo import GRPOLossConfig +from xtuner.v1.rl.trainer import WorkerConfig +from xtuner.v1.rl.loss import GRPOLossConfig model_path = "/path/to/qwen3-8B" # Fill in your model path train_optimizer_steps = 4 # Training optimization steps diff --git 
a/docs/zh_cn/rl/advanced_tutorial/gateway_api_debug.md b/docs/zh_cn/rl/advanced_tutorial/gateway_api_debug.md new file mode 100644 index 0000000000..06b23c6ab2 --- /dev/null +++ b/docs/zh_cn/rl/advanced_tutorial/gateway_api_debug.md @@ -0,0 +1,207 @@ +# Gateway 兼容接口联调 + +本文记录如何使用真实的 Agent 客户端和 OpenAI SDK 联调 XTuner Gateway,验证 Gateway 对 Anthropic Messages、OpenAI Responses 和 OpenAI Chat Completions 接口的兼容情况。 + +## 适用场景 + +当你修改 Gateway、Rollout Controller、Agent Loop 或协议适配层后,可以按本文流程做一次端到端验证,确认: + +- Gateway 能够接收 `/v1/messages`、`/v1/responses` 和 `/v1/chat/completions` 请求。 +- Claude Code、Codex 等真实 Agent 客户端能够连接到本地 Gateway。 +- 普通对话和工具调用链路都能正常返回。 +- Gateway 的请求捕获日志能够记录调试所需的协议转换信息。 + +## 前置条件 + +1. 已安装 XTuner 运行环境,并能启动 Rollout Controller 和 Gateway。 +2. Gateway 服务默认监听 `http://127.0.0.1:8091`。 +3. Gateway 模型名配置为 `local-test`。 +4. 鉴权 token 使用本地调试值 `dummy`。 +5. 启动 Gateway 时建议打开 `capture_folder`,便于回看请求、协议适配结果和模型输出。 + +```{note} +真实 Agent 客户端会携带较长的系统提示词和工具定义。联调 Claude Code 时建议将上下文长度设置到 32K;联调 Codex 时建议至少设置到 16K。 +``` + +## 启动 Gateway + +先启动 Rollout Controller 和 Gateway。以下命令是本地调试脚本示例: + +```bash +python .dev_scripts/debug_gateway.py \ + --model-path \ + --model-name local-test \ + --context-length 32768 +``` + +启动时需要确认: + +- Gateway 端口为 `8091`。 +- 模型名为 `local-test`。 +- 上下文长度满足当前客户端需求。 +- 已配置 `capture_folder`。 + +## 验证 Anthropic Messages 接口 + +Claude Code 通过 Anthropic Messages API 访问 Gateway,可用于验证 `/v1/messages` 的协议适配和工具调用链路。 + +### 安装 Claude Code + +```bash +curl -fsSL https://claude.ai/install.sh | bash +``` + +### 配置环境变量 + +```bash +export ANTHROPIC_BASE_URL=http://127.0.0.1:8091 +export ANTHROPIC_AUTH_TOKEN=dummy +export ANTHROPIC_MODEL=local-test +export API_TIMEOUT_MS=600000 +``` + +### 验证普通对话 + +启动 Claude Code 后发送: + +```text +Reply with exactly: OK +``` + +如果客户端能够收到模型回复,说明 `/v1/messages` 的基础请求链路可用。 + +### 验证工具调用 + +继续发送以下 prompt: + +```text +Use your tools to find the gateway route definitions, then add a single log line for every incoming request to /v1/messages. 
Show me the exact file you changed and the patch you would apply. +``` + +如果 Claude Code 能够正常调用工具、读取仓库文件,并返回拟修改的文件和 patch,说明工具调用链路可用。 + +## 验证 OpenAI Responses 接口 + +Codex 通过 OpenAI Responses API 访问 Gateway,可用于验证 `/v1/responses` 的协议适配和工具调用链路。 + +### 安装 Codex + +按 Codex 官方安装方式完成安装后,配置本地模型提供方。 + +### 配置 Codex + +在 Codex 的 `config.toml` 中添加本地 Gateway provider: + +```toml +model = "local-test" +model_provider = "xtuner" + +[model_providers.xtuner] +name = "xtuner gateway" +base_url = "http://127.0.0.1:8091/v1" +env_key = "XTUNER_GATEWAY_KEY" +``` + +配置访问 token: + +```bash +export XTUNER_GATEWAY_KEY=dummy +``` + +### 先用 curl 验证接口 + +启动 Codex 前,先确认 `/v1/responses` 能直接返回: + +```bash +curl http://127.0.0.1:8091/v1/responses \ + -H 'content-type: application/json' \ + -H 'authorization: Bearer dummy' \ + -d '{ + "model": "local-test", + "input": "Reply with exactly OK" + }' +``` + +如果返回状态为 `completed`,且 `output` 中包含模型回复,说明 Responses 接口基础链路可用。 + +### 验证普通对话 + +启动 Codex 后发送: + +```text +你好 +``` + +如果 Codex 能收到中文回复,说明客户端能够通过本地 Gateway 完成基础对话。 + +### 验证工具调用 + +继续发送以下 prompt: + +```text +Use your tools to list the top-level files and directories in the current repository. +Do not explain your plan. +Do not answer from memory. 
+If you cannot access tools, reply exactly: NO_TOOLS +``` + +如果 Codex 返回了仓库顶层文件和目录,而不是 `NO_TOOLS`,说明 Responses 接口下的工具调用链路可用。 + +## 验证 OpenAI Chat Completions 接口 + +除了真实 Agent 客户端,也可以使用 OpenAI Python SDK 验证 `/v1/chat/completions`。 + +```python +from openai import OpenAI + +client = OpenAI( + base_url="http://127.0.0.1:8091/v1", + api_key="dummy", +) + +resp = client.chat.completions.create( + model="local-test", + messages=[ + {"role": "system", "content": "You are a helpful coding assistant."}, + {"role": "user", "content": "Reply with exactly: OK"}, + ], + max_tokens=32, + temperature=0, +) + +print(resp.choices[0].message.content) +``` + +如果输出包含 `OK`,说明 Chat Completions 接口基础链路可用。 + +## 检查 capture 日志 + +联调过程中建议同步检查 Gateway 的 `capture_folder` 输出。重点确认每条记录中是否包含: + +- `source_protocol`:请求来源协议,例如 `anthropic_messages` 或 `openai_responses`。 +- `internal_messages`:Gateway 转换后发送给 Rollout 的内部消息。 +- `output_messages` 或 `output_text`:模型输出转换回客户端协议后的结果。 +- `rollout_tools` 和 `rollout_tool_choice`:工具定义和工具选择策略。 +- `request_id`:用于串联客户端请求、Gateway 记录和 Rollout 结果。 + +这些字段能帮助定位问题出在客户端请求、协议适配、Rollout 生成还是响应转换阶段。 + +## 常见问题 + +### 客户端请求超时 + +先检查 Gateway 是否仍在运行,并适当增大客户端超时时间。Claude Code 可设置: + +```bash +export API_TIMEOUT_MS=600000 +``` + +同时检查 Rollout Controller 是否收到请求,以及推理服务是否有可用并发。 + +### 客户端上下文过长 + +真实 Agent 客户端会注入系统提示词、工具 schema 和历史消息。如果请求被截断或报 context length 相关错误,需要增大 Gateway 和推理后端的上下文长度。 + +### 工具调用没有触发 + +先使用本文中的工具调用 prompt 做最小复现,再检查 `capture_folder` 中是否记录了工具定义。如果 `rollout_tools` 为空,问题通常出在客户端请求到 Gateway 的协议适配阶段;如果工具定义存在但没有工具调用结果,需要继续检查模型输出和 Agent 客户端的工具执行日志。 diff --git a/docs/zh_cn/rl/advanced_tutorial/loss.md b/docs/zh_cn/rl/advanced_tutorial/loss.md index 278129d6bd..5578b30a44 100644 --- a/docs/zh_cn/rl/advanced_tutorial/loss.md +++ b/docs/zh_cn/rl/advanced_tutorial/loss.md @@ -9,8 +9,7 @@ XTuner 中所有的 loss 计算均涉及两个核心组件 `LossConfig` 和 `Los ```python import torch import torch.nn as nn -from xtuner.v1.rl.grpo import GRPOLossConfig, GRPOLossContext -from xtuner.v1.rl.base import 
RLLossContextInputItem +from xtuner.v1.rl.loss import GRPOLossConfig, GRPOLossContext, RLLossContextInputItem from xtuner.v1.data_proto import SequenceContext def gather_logprobs(logits, shifted_labels): diff --git a/docs/zh_cn/rl/tutorial/rl_grpo_trainer.md b/docs/zh_cn/rl/tutorial/rl_grpo_trainer.md index 267bd14d49..e8f1486741 100644 --- a/docs/zh_cn/rl/tutorial/rl_grpo_trainer.md +++ b/docs/zh_cn/rl/tutorial/rl_grpo_trainer.md @@ -92,7 +92,7 @@ replay_buffer_cfg = ReplayBufferConfig( ```{code-block} python :caption: 配置推理环境 -from xtuner.v1.ray.config.worker import RolloutConfig +from xtuner.v1.rl.rollout.worker import RolloutConfig model_path = "/path/to/qwen3-8B" # 替换为您的模型路径 @@ -144,8 +144,8 @@ judger_cfg = JudgerConfig( :caption: 配置训练策略 from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig from xtuner.v1.model.dense.qwen3 import Qwen3Dense8BConfig -from xtuner.v1.rl.base import WorkerConfig -from xtuner.v1.rl.grpo import GRPOLossConfig +from xtuner.v1.rl.rollout.worker import WorkerConfig +from xtuner.v1.rl.loss import GRPOLossConfig model_path = "/path/to/qwen3-8B" # 填入您的模型路径 train_optimizer_steps = 4 # 训练优化步数 @@ -201,7 +201,7 @@ evaluator_cfg = EvaluatorConfig( 除以上的生成和训练配置外,我们需要配置系统所需资源(如GPU、CPU、内存)等,此处我们使用默认的资源配置,示例如下。 ```{code-block} python -from xtuner.v1.ray.base import AcceleratorResourcesConfig +from xtuner.v1.rl.utils import AcceleratorResourcesConfig resources = AcceleratorResourcesConfig( accelerator="GPU", num_accelerators_per_worker=1, diff --git a/examples/v1/config/rl_dapo_math.py b/examples/v1/config/rl_dapo_math.py new file mode 100644 index 0000000000..238f18ae0a --- /dev/null +++ b/examples/v1/config/rl_dapo_math.py @@ -0,0 +1,211 @@ +import os +from pathlib import Path + +from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig +from xtuner.v1.data_proto import SampleParams +from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig +from xtuner.v1.datasets.rl_tokenize_fn import RLTextTokenizeFnConfig +from 
xtuner.v1.model import get_model_config_from_hf +from xtuner.v1.rl.utils import AcceleratorResourcesConfig +from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.rl.judger import DapoMathJudgerConfig +from xtuner.v1.rl.replay_buffer import SyncReplayBufferConfig +from xtuner.v1.rl.trainer import WorkerConfig +from xtuner.v1.rl.agent_loop import AgentLoopManagerConfig, TaskSpecConfig, SingleTurnAgentLoopConfig, SyncProduceStrategyConfig, SamplerConfig +from xtuner.v1.rl.evaluator import EvaluatorConfig +from xtuner.v1.rl.loss import GRPOLossConfig +from xtuner.v1.train.rl_colocate_trainer import RLColocateTrainerConfig + +work_dir = os.environ["WORK_DIR"] +model_path = os.environ["MODEL_PATH"] +data_path = os.environ["DATA_PATH"] +eval_data_path = os.environ["EVAL_DATA_PATH"] +enable_return_routed_experts = os.environ.get("ENABLE_RETURN_ROUTED_EXPERTS", "0") +NNODE = int(os.environ.get("WORLD_SIZE", "1")) + +# basic settings +experimental_name = "dapo_math" +total_epochs = 1 +global_batch_size = 512 +prompt_repeat_k = 16 +rollout_tp_size = 1 +rollout_ep_size = 1 +max_prompt_length = 2048 +max_response_length = 8192 +pack_max_length = 32768 +train_optimizer_steps = 16 +hf_interval = 50 +enable_initial_evaluate = True +evaluate_step = 5 + +# 1. resources +resources = AcceleratorResourcesConfig( + accelerator="GPU", + num_workers=8 * NNODE, + num_cpus_per_worker=12, + cpu_memory_per_worker=16 * 1024**3, # 16 GB +) + +# 2. rollout +rollout_config = RolloutConfig( + env=experimental_name, + device=resources.accelerator, + model_path=model_path, + dtype="bfloat16", + tensor_parallel_size=rollout_tp_size, + expert_parallel_size=rollout_ep_size, + gpu_memory_utilization=0.8, + context_length=max_response_length + max_prompt_length, + enable_return_routed_experts=(enable_return_routed_experts == "1"), + rollout_max_batch_size_per_instance=2048 +) + +# 3. 
judger +from xtuner.v1.rl.utils import get_eos_token +from transformers import AutoTokenizer +eos_token_id = get_eos_token(model_path) +tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) +eos_token_str = tokenizer.convert_ids_to_tokens(eos_token_id) +judger_config = DapoMathJudgerConfig( + judger_name="dapo_math", + num_ray_actors=1, + eos_token=eos_token_str, + enable_overlong_buffer = True, + max_response_len =max_response_length, + overlong_buffer_len=4096, + overlong_penalty_factor=1.0, + tokenizer=tokenizer +) + +# 4. train worker +lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) +fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) +model_cfg = get_model_config_from_hf(Path(model_path)) +if hasattr(model_cfg, "balancing_loss_cfg"): + model_cfg.balancing_loss_cfg = None +if hasattr(model_cfg, "z_loss_cfg"): + model_cfg.z_loss_cfg = None +optim_cfg = AdamWConfig(lr=1e-6, foreach=False, weight_decay=0.1) +loss_cfg = GRPOLossConfig( + policy_loss_cfg=dict( + cliprange_high=0.28, + cliprange_low=0.2, + loss_type=os.environ.get("LOSS_TYPE", "vanilla"), + clip_ratio_c=10.0, + log_prob_diff_min=-20.0, + log_prob_diff_max=20.0, + ), + ignore_idx=-100, + use_kl_loss=False, + kl_loss_coef=0.0, + kl_loss_type="low_var_kl", + mode=os.environ.get("LOSS_MODE", "chunk"), + chunk_size=512, +) +train_worker_cfg = WorkerConfig( + model_cfg=model_cfg, + load_from=model_path, + optim_cfg=optim_cfg, + loss_cfg=loss_cfg, + lr_cfg=lr_cfg, + fsdp_cfg=fsdp_cfg, + sp_size=int(os.environ.get("SP_SIZE", "1")), + optimizer_steps=train_optimizer_steps, + pack_max_length=pack_max_length, +) + +# 5. 
train agent loop manager +train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path) +tokenizer_config = RLTextTokenizeFnConfig(max_length=max_prompt_length) +train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}] +dataloader_cfg = DataloaderConfig( + dataset_config_list=train_dataset_cfg, + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", +) +sampler_config = SamplerConfig( + dataloader_cfg=dataloader_cfg, + prompt_repeat_k=prompt_repeat_k, +) +training_sample_params = SampleParams( + max_tokens=max_response_length, + top_k=0, + top_p=1.0, + temperature=1.0, + min_tokens=0, +) +agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=training_sample_params, +) +produce_strategy_config = SyncProduceStrategyConfig() +agent_loop_manager_cfg = AgentLoopManagerConfig( + tasks=TaskSpecConfig( + task_name="train_task", + agent_loop_config=agent_loop_config, + judger_config=judger_config, + produce_strategy_config=produce_strategy_config, + sampler_config=sampler_config, + ), +) + +# 6. 
eval agent loop manager +eval_dataset = DatasetConfig( + name=experimental_name, anno_path=eval_data_path, sample_ratio=1.0 +) +eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}] +eval_dataloader_cfg = DataloaderConfig( + dataset_config_list=eval_dataset_cfg, + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", +) +eval_sampler_config = SamplerConfig( + dataloader_cfg=eval_dataloader_cfg, + prompt_repeat_k=1, +) +evaluation_sample_params = SampleParams( + max_tokens=max_response_length, + top_k=1, + top_p=0.7, + temperature=0.0, + min_tokens=0, +) +eval_agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=evaluation_sample_params, +) +eval_agent_loop_manager_cfg = AgentLoopManagerConfig( + tasks=TaskSpecConfig( + task_name="eval_task", + agent_loop_config=eval_agent_loop_config, + judger_config=judger_config, + sampler_config=eval_sampler_config, + ), +) + +def dapo_compute_metric(samples): + return {"accuracy": sum(s.reward["acc"] > 0 for s in samples) / len(samples)} + +evaluator_config = EvaluatorConfig(compute_metric_func=dapo_compute_metric) + +trainer = RLColocateTrainerConfig( + resources=resources, + train_worker_cfg=train_worker_cfg, # TODO: uniform naming of cfg and config + rollout_config=rollout_config, + tokenizer_path=model_path, + replay_buffer_config=SyncReplayBufferConfig(), + agent_loop_manager_cfg=agent_loop_manager_cfg, + eval_agent_loop_manager_cfg=eval_agent_loop_manager_cfg, + evaluator_config=evaluator_config, + load_from=model_path, + global_batch_size=global_batch_size, + enable_evaluate=True, + enable_initial_evaluate=False, + rollout_steps=500, + evaluate_step=evaluate_step, + work_dir=work_dir, + seed=123, + debug_rollout=False, +) diff --git a/examples/v1/config/rl_dapo_math_async.py b/examples/v1/config/rl_dapo_math_async.py new file mode 100644 index 0000000000..5641397fec --- /dev/null +++ b/examples/v1/config/rl_dapo_math_async.py @@ 
-0,0 +1,214 @@ +import os +from pathlib import Path + +from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig +from xtuner.v1.data_proto import SampleParams +from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig +from xtuner.v1.datasets.rl_tokenize_fn import RLTextTokenizeFnConfig +from xtuner.v1.model import get_model_config_from_hf +from xtuner.v1.rl.utils import AcceleratorResourcesConfig +from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.rl.judger import DapoMathJudgerConfig +from xtuner.v1.rl.replay_buffer import AsyncReplayBufferConfig +from xtuner.v1.rl.trainer import WorkerConfig +from xtuner.v1.rl.agent_loop import AgentLoopManagerConfig, TaskSpecConfig, SingleTurnAgentLoopConfig, AsyncProduceStrategyConfig, SamplerConfig +from xtuner.v1.rl.evaluator import EvaluatorConfig +from xtuner.v1.rl.loss import GRPOLossConfig +from xtuner.v1.train.rl_colocate_trainer import RLColocateTrainerConfig + +work_dir = os.environ["WORK_DIR"] +model_path = os.environ["MODEL_PATH"] +data_path = os.environ["DATA_PATH"] +eval_data_path = os.environ["EVAL_DATA_PATH"] +enable_return_routed_experts = os.environ.get("ENABLE_RETURN_ROUTED_EXPERTS", "0") +NNODE = int(os.environ.get("WORLD_SIZE", "1")) + +# basic settings +experimental_name = "dapo_math" +total_epochs = 1 +global_batch_size = 512 +prompt_repeat_k = 16 +rollout_tp_size = 1 +rollout_ep_size = 1 +max_prompt_length = 2048 +max_response_length = 8192 +pack_max_length = 32768 +train_optimizer_steps = 16 +hf_interval = 50 +enable_initial_evaluate = True +evaluate_step = 5 + +# 1. resources +resources = AcceleratorResourcesConfig( + accelerator="GPU", + num_workers=8 * NNODE, + num_cpus_per_worker=12, + cpu_memory_per_worker=16 * 1024**3, # 16 GB +) + +# 2. 
rollout +rollout_config = RolloutConfig( + env=experimental_name, + device=resources.accelerator, + model_path=model_path, + dtype="bfloat16", + tensor_parallel_size=rollout_tp_size, + expert_parallel_size=rollout_ep_size, + gpu_memory_utilization=0.8, + context_length=max_response_length + max_prompt_length, + enable_return_routed_experts=(enable_return_routed_experts == "1"), + rollout_max_batch_size_per_instance=2048 +) + +# 3. judger +from xtuner.v1.rl.utils import get_eos_token +from transformers import AutoTokenizer +eos_token_id = get_eos_token(model_path) +tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) +eos_token_str = tokenizer.convert_ids_to_tokens(eos_token_id) +judger_config = DapoMathJudgerConfig( + judger_name="dapo_math", + num_ray_actors=1, + eos_token=eos_token_str, + enable_overlong_buffer = True, + max_response_len=max_response_length, + overlong_buffer_len=4096, + overlong_penalty_factor=1.0, + tokenizer=tokenizer) +# 4. train worker +lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) +fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) +model_cfg = get_model_config_from_hf(Path(model_path)) +if hasattr(model_cfg, "balancing_loss_cfg"): + model_cfg.balancing_loss_cfg = None +if hasattr(model_cfg, "z_loss_cfg"): + model_cfg.z_loss_cfg = None +optim_cfg = AdamWConfig(lr=1e-6, foreach=False, weight_decay=0.1) +loss_cfg = GRPOLossConfig( + policy_loss_cfg=dict( + cliprange_high=0.28, + cliprange_low=0.2, + loss_type=os.environ.get("LOSS_TYPE", "vanilla"), + clip_ratio_c=10.0, + log_prob_diff_min=-20.0, + log_prob_diff_max=20.0, + ), + ignore_idx=-100, + use_kl_loss=False, + kl_loss_coef=0.0, + kl_loss_type="low_var_kl", + mode=os.environ.get("LOSS_MODE", "chunk"), + chunk_size=512, +) +train_worker_cfg = WorkerConfig( + model_cfg=model_cfg, + load_from=model_path, + optim_cfg=optim_cfg, + loss_cfg=loss_cfg, + lr_cfg=lr_cfg, + fsdp_cfg=fsdp_cfg, + 
sp_size=int(os.environ.get("SP_SIZE", "1")), + optimizer_steps=train_optimizer_steps, + pack_max_length=pack_max_length, +) + +# 5. train agent loop manager +train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path) +tokenizer_config = RLTextTokenizeFnConfig(max_length=max_prompt_length) +train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}] +dataloader_cfg = DataloaderConfig( + dataset_config_list=train_dataset_cfg, + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", +) +sampler_config = SamplerConfig( + dataloader_cfg=dataloader_cfg, + prompt_repeat_k=prompt_repeat_k, +) +training_sample_params = SampleParams( + max_tokens=max_response_length, + top_k=0, + top_p=1.0, + temperature=1.0, + min_tokens=0, +) +agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=training_sample_params, +) +produce_strategy_config = AsyncProduceStrategyConfig( + over_sample_threshold=0.2, + enable_partial_rollout=True, + tail_batch_stale_threshold=1, + tail_batch_trigger_size=256, +) +agent_loop_manager_cfg = AgentLoopManagerConfig( + tasks=TaskSpecConfig( + task_name="train_task", + agent_loop_config=agent_loop_config, + judger_config=judger_config, + produce_strategy_config=produce_strategy_config, + sampler_config=sampler_config, + ), +) + +# 6. 
eval agent loop manager +eval_dataset = DatasetConfig( + name=experimental_name, anno_path=eval_data_path, sample_ratio=1.0 +) +eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}] +eval_dataloader_cfg = DataloaderConfig( + dataset_config_list=eval_dataset_cfg, + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", +) +eval_sampler_config = SamplerConfig( + dataloader_cfg=eval_dataloader_cfg, + prompt_repeat_k=1, +) +evaluation_sample_params = SampleParams( + max_tokens=max_response_length, + top_k=1, + top_p=0.7, + temperature=0.0, + min_tokens=0, +) +eval_agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=evaluation_sample_params, +) +eval_agent_loop_manager_cfg = AgentLoopManagerConfig( + tasks=TaskSpecConfig( + task_name="eval_task", + agent_loop_config=eval_agent_loop_config, + judger_config=judger_config, + sampler_config=eval_sampler_config, + ), +) + +def dapo_compute_metric(samples): + return {"accuracy": sum(s.reward["acc"] > 0 for s in samples) / len(samples)} + +evaluator_config = EvaluatorConfig(compute_metric_func=dapo_compute_metric) + +trainer = RLColocateTrainerConfig( + resources=resources, + train_worker_cfg=train_worker_cfg, # TODO: uniform naming of cfg and config + rollout_config=rollout_config, + tokenizer_path=model_path, + replay_buffer_config=AsyncReplayBufferConfig(), + agent_loop_manager_cfg=agent_loop_manager_cfg, + eval_agent_loop_manager_cfg=eval_agent_loop_manager_cfg, + evaluator_config=evaluator_config, + load_from=model_path, + global_batch_size=global_batch_size, + enable_evaluate=True, + enable_initial_evaluate=False, + rollout_steps=500, + evaluate_step=evaluate_step, + work_dir=work_dir, + seed=123, + debug_rollout=False, +) diff --git a/examples/v1/config/rl_dapo_math_async_filter.py b/examples/v1/config/rl_dapo_math_async_filter.py new file mode 100644 index 0000000000..5b26fd9c68 --- /dev/null +++ 
"""RL colocate trainer config: DAPO math, async rollout with group filtering.

Loaded by the CLI, which builds ``trainer`` (an ``RLColocateTrainerConfig``)
and runs ``.fit()``.
Required env vars: WORK_DIR, MODEL_PATH, DATA_PATH, EVAL_DATA_PATH
Optional env vars: WORLD_SIZE, ENABLE_RETURN_ROUTED_EXPERTS, LOSS_TYPE,
LOSS_MODE, SP_SIZE
"""
import os
from pathlib import Path

from transformers import AutoTokenizer

from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig
from xtuner.v1.data_proto import SampleParams
from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig
from xtuner.v1.datasets.rl_tokenize_fn import RLTextTokenizeFnConfig
from xtuner.v1.model import get_model_config_from_hf
from xtuner.v1.rl.agent_loop import (
    AgentLoopManagerConfig,
    AsyncProduceStrategyConfig,
    SamplerConfig,
    SingleTurnAgentLoopConfig,
    TaskSpecConfig,
)
from xtuner.v1.rl.evaluator import EvaluatorConfig
from xtuner.v1.rl.judger import DapoMathJudgerConfig
from xtuner.v1.rl.loss import GRPOLossConfig
from xtuner.v1.rl.replay_buffer import AsyncReplayBufferConfig
from xtuner.v1.rl.rollout.worker import RolloutConfig
from xtuner.v1.rl.trainer import WorkerConfig
from xtuner.v1.rl.utils import AcceleratorResourcesConfig, get_eos_token
from xtuner.v1.train.rl_colocate_trainer import RLColocateTrainerConfig

# env
work_dir = os.environ["WORK_DIR"]
model_path = os.environ["MODEL_PATH"]
data_path = os.environ["DATA_PATH"]
eval_data_path = os.environ["EVAL_DATA_PATH"]
enable_return_routed_experts = os.environ.get("ENABLE_RETURN_ROUTED_EXPERTS", "0")
NNODE = int(os.environ.get("WORLD_SIZE", "1"))

# basic settings
experimental_name = "dapo_math"
total_epochs = 1  # NOTE(review): not consumed below — confirm whether the CLI reads it
global_batch_size = 512
prompt_repeat_k = 16
rollout_tp_size = 1
rollout_ep_size = 1
max_prompt_length = 2048
max_response_length = 8192
pack_max_length = 32768
train_optimizer_steps = 16
hf_interval = 50  # NOTE(review): not consumed below — confirm whether the CLI reads it
# Fix: this flag used to be True while the trainer hard-coded False; the
# variable is now the single source of truth and keeps the effective value.
enable_initial_evaluate = False
evaluate_step = 5

# 1. resources
resources = AcceleratorResourcesConfig(
    accelerator="GPU",
    num_workers=8 * NNODE,
    num_cpus_per_worker=12,
    cpu_memory_per_worker=16 * 1024**3,  # 16 GB
)

# 2. rollout
rollout_config = RolloutConfig(
    env=experimental_name,
    device=resources.accelerator,
    model_path=model_path,
    dtype="bfloat16",
    tensor_parallel_size=rollout_tp_size,
    expert_parallel_size=rollout_ep_size,
    gpu_memory_utilization=0.8,
    context_length=max_response_length + max_prompt_length,
    enable_return_routed_experts=(enable_return_routed_experts == "1"),
    rollout_max_batch_size_per_instance=1024,
)

# 3. judger
# DAPO needs the textual EOS token and a tokenizer to apply the overlong
# (soft length) penalty to truncated responses.
eos_token_id = get_eos_token(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
eos_token_str = tokenizer.convert_ids_to_tokens(eos_token_id)
# Fix: the fully configured DAPO judger was previously shadowed by a second,
# bare DapoMathJudgerConfig(...) assignment, so the overlong-buffer settings
# were dead code. The single config below merges both.
judger_config = DapoMathJudgerConfig(
    judger_name="dapo_math",
    eos_token=eos_token_str,
    enable_overlong_buffer=True,
    max_response_len=max_response_length,
    overlong_buffer_len=4096,
    overlong_penalty_factor=1.0,
    tokenizer=tokenizer,
    num_ray_actors=1,
)

# 4. train worker
lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6)
fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1)
model_cfg = get_model_config_from_hf(Path(model_path))
# Disable MoE auxiliary losses (when present) so only the RL loss trains.
if hasattr(model_cfg, "balancing_loss_cfg"):
    model_cfg.balancing_loss_cfg = None
if hasattr(model_cfg, "z_loss_cfg"):
    model_cfg.z_loss_cfg = None
optim_cfg = AdamWConfig(lr=1e-6, foreach=False, weight_decay=0.1)
loss_cfg = GRPOLossConfig(
    policy_loss_cfg=dict(
        cliprange_high=0.28,
        cliprange_low=0.2,
        loss_type=os.environ.get("LOSS_TYPE", "vanilla"),
        clip_ratio_c=10.0,
        log_prob_diff_min=-20.0,
        log_prob_diff_max=20.0,
    ),
    ignore_idx=-100,
    use_kl_loss=False,
    kl_loss_coef=0.0,
    kl_loss_type="low_var_kl",
    mode=os.environ.get("LOSS_MODE", "chunk"),
    chunk_size=512,
)
train_worker_cfg = WorkerConfig(
    model_cfg=model_cfg,
    load_from=model_path,
    optim_cfg=optim_cfg,
    loss_cfg=loss_cfg,
    lr_cfg=lr_cfg,
    fsdp_cfg=fsdp_cfg,
    sp_size=int(os.environ.get("SP_SIZE", "1")),
    optimizer_steps=train_optimizer_steps,
    pack_max_length=pack_max_length,
)

# 5. train agent loop manager
train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path)
tokenizer_config = RLTextTokenizeFnConfig(max_length=max_prompt_length)
train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}]
dataloader_cfg = DataloaderConfig(
    dataset_config_list=train_dataset_cfg,
    pack_max_length=pack_max_length,
    collator="fake_collator",
    pack_level="none",
)
sampler_config = SamplerConfig(
    dataloader_cfg=dataloader_cfg,
    prompt_repeat_k=prompt_repeat_k,
)
training_sample_params = SampleParams(
    max_tokens=max_response_length,
    top_k=0,
    top_p=1.0,
    temperature=1.0,
    min_tokens=0,
)
agent_loop_config = SingleTurnAgentLoopConfig(
    hf_checkpoint=model_path,
    sample_params=training_sample_params,
)


def group_samples_filter_func(rollout_states):
    """Keep a prompt group only if its finished responses disagree on reward.

    Group-relative advantages are all zero when every response in a group
    receives the same reward, so such groups carry no learning signal.

    Fix: a group with zero (or one) finished responses is now dropped as
    well; the previous version returned True for an empty reward list.
    """
    rewards = [
        state.reward["score"]
        for state in rollout_states
        if state.response_ids is not None
    ]
    return len(set(rewards)) > 1


produce_strategy_config = AsyncProduceStrategyConfig(
    over_sample_threshold=0.2,
    enable_partial_rollout=True,
    tail_batch_stale_threshold=1,
    tail_batch_trigger_size=256,
    is_valid_sample_fn=group_samples_filter_func,
)
agent_loop_manager_cfg = AgentLoopManagerConfig(
    tasks=TaskSpecConfig(
        task_name="train_task",
        agent_loop_config=agent_loop_config,
        judger_config=judger_config,
        produce_strategy_config=produce_strategy_config,
        sampler_config=sampler_config,
    ),
)

# 6. eval agent loop manager
eval_dataset = DatasetConfig(
    name=experimental_name, anno_path=eval_data_path, sample_ratio=1.0
)
eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}]
eval_dataloader_cfg = DataloaderConfig(
    dataset_config_list=eval_dataset_cfg,
    pack_max_length=pack_max_length,
    collator="fake_collator",
    pack_level="none",
)
eval_sampler_config = SamplerConfig(
    dataloader_cfg=eval_dataloader_cfg,
    prompt_repeat_k=1,
)
# Near-greedy decoding for evaluation (top_k=1, temperature=0).
evaluation_sample_params = SampleParams(
    max_tokens=max_response_length,
    top_k=1,
    top_p=0.7,
    temperature=0.0,
    min_tokens=0,
)
eval_agent_loop_config = SingleTurnAgentLoopConfig(
    hf_checkpoint=model_path,
    sample_params=evaluation_sample_params,
)
eval_agent_loop_manager_cfg = AgentLoopManagerConfig(
    tasks=TaskSpecConfig(
        task_name="eval_task",
        agent_loop_config=eval_agent_loop_config,
        judger_config=judger_config,
        sampler_config=eval_sampler_config,
    ),
)


def dapo_compute_metric(samples):
    """Return the fraction of samples whose ``reward["acc"]`` is positive."""
    return {"accuracy": sum(s.reward["acc"] > 0 for s in samples) / len(samples)}


evaluator_config = EvaluatorConfig(compute_metric_func=dapo_compute_metric)

# 7. trainer (the CLI builds this object and calls fit())
trainer = RLColocateTrainerConfig(
    resources=resources,
    train_worker_cfg=train_worker_cfg,  # TODO: uniform naming of cfg and config
    rollout_config=rollout_config,
    tokenizer_path=model_path,
    replay_buffer_config=AsyncReplayBufferConfig(),
    agent_loop_manager_cfg=agent_loop_manager_cfg,
    eval_agent_loop_manager_cfg=eval_agent_loop_manager_cfg,
    evaluator_config=evaluator_config,
    load_from=model_path,
    global_batch_size=global_batch_size,
    enable_evaluate=True,
    enable_initial_evaluate=enable_initial_evaluate,
    # Consistency fix: the sibling rl_dapo_math_async.py config sets
    # rollout_steps=500; it was omitted here.
    rollout_steps=500,
    evaluate_step=evaluate_step,
    work_dir=work_dir,
    seed=123,
    debug_rollout=False,
)
+"""RL Colocate Trainer 示例配置(GRPO + GSM8K)。 + +用法:通过环境变量传入路径后,由 CLI 加载本配置并 trainer_cfg.build().fit()。 +需设置: WORK_DIR, MODEL_PATH, DATA_PATH, EVAL_DATA_PATH +可选: WORLD_SIZE, ENABLE_RETURN_ROUTED_EXPERTS, LOSS_TYPE, LOSS_MODE, SP_SIZE +""" +import os + +from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig +from xtuner.v1.data_proto import SampleParams +from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig +from xtuner.v1.datasets.rl_tokenize_fn import RLQwen3VLTokenizeFnConfig +from xtuner.v1.model.compose.qwen3_vl import Qwen3VLDense8BConfig +from xtuner.v1.rl.utils import AcceleratorResourcesConfig +from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.rl.judger import GEO3KJudgerConfig +from xtuner.v1.rl.replay_buffer import SyncReplayBufferConfig +from xtuner.v1.rl.trainer import WorkerConfig +from xtuner.v1.rl.agent_loop import AgentLoopManagerConfig, TaskSpecConfig, SingleTurnAgentLoopConfig, SyncProduceStrategyConfig, SamplerConfig +from xtuner.v1.rl.evaluator import EvaluatorConfig +from xtuner.v1.rl.loss import GRPOLossConfig +from xtuner.v1.train.rl_colocate_trainer import RLColocateTrainerConfig + +# env +work_dir = os.environ["WORK_DIR"] +model_path = os.environ["MODEL_PATH"] +data_path = os.environ["DATA_PATH"] +eval_data_path = os.environ["EVAL_DATA_PATH"] +enable_return_routed_experts = os.environ.get("ENABLE_RETURN_ROUTED_EXPERTS", "0") +NNODE = int(os.environ.get("WORLD_SIZE", "1")) +media_root = os.environ["MEDIA_ROOT"] + +# basic settings +experimental_name = "grpo_geo3k" +rollout_steps = 45 # TODO: total_epoch +evaluate_step = 45 +train_optimizer_steps = 4 +global_batch_size = 1024 +prompt_repeat_k = 5 +rollout_tp_size = 1 +rollout_ep_size = 1 +max_prompt_length = 1024 +max_response_length = 2048 +pack_max_length = 32 * 1024 + +# 1. 
resources +resources = AcceleratorResourcesConfig( + accelerator="GPU", + num_workers=8 * NNODE, + num_cpus_per_worker=12, + cpu_memory_per_worker=16 * 1024**3, # 16 GB +) + +# 2. rollout +rollout_config = RolloutConfig( + env=experimental_name, + device=resources.accelerator, + model_path=model_path, + dtype="bfloat16", + tensor_parallel_size=rollout_tp_size, + expert_parallel_size=rollout_ep_size, + gpu_memory_utilization=0.8, + context_length=max_response_length + max_prompt_length, + enable_return_routed_experts=(enable_return_routed_experts == "1"), +) + +# 3. judger +judger_config = GEO3KJudgerConfig(num_ray_actors=1) + +# 4. train worker +lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) +fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) + +# TODO: support get_model_config_from_hf +model_cfg = Qwen3VLDense8BConfig(freeze_vision=True, freeze_projector=True) + +if hasattr(model_cfg.text_config, "balancing_loss_cfg"): + model_cfg.text_config.balancing_loss_cfg = None +if hasattr(model_cfg.text_config, "z_loss_cfg"): + model_cfg.text_config.z_loss_cfg = None +optim_cfg = AdamWConfig(lr=1e-6, foreach=False, weight_decay=0.1) +loss_cfg = GRPOLossConfig( + policy_loss_cfg=dict( + cliprange_high=0.28, + cliprange_low=0.2, + loss_type=os.environ.get("LOSS_TYPE", "vanilla"), + clip_ratio_c=10.0, + log_prob_diff_min=-20.0, + log_prob_diff_max=20.0, + ), + ignore_idx=-100, + use_kl_loss=False, + kl_loss_coef=0.0, + kl_loss_type="low_var_kl", + mode=os.environ.get("LOSS_MODE", "chunk"), + chunk_size=512, +) +train_worker_cfg = WorkerConfig( + model_cfg=model_cfg, + load_from=model_path, + optim_cfg=optim_cfg, + loss_cfg=loss_cfg, + lr_cfg=lr_cfg, + fsdp_cfg=fsdp_cfg, + sp_size=int(os.environ.get("SP_SIZE", "1")), + optimizer_steps=train_optimizer_steps, + pack_max_length=pack_max_length, +) + +# 5. 
train agent loop manager +train_dataset_cfg = [ + { + "dataset": DatasetConfig(name="geo3k", + anno_path=data_path, + class_name='VLMJsonlDataset', + media_root=media_root, + sample_ratio=1.0), + "tokenize_fn": RLQwen3VLTokenizeFnConfig(processor_path=model_path, + max_length=max_prompt_length), + } +] + +dataloader_cfg = DataloaderConfig( + dataset_config_list=train_dataset_cfg, + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", + num_workers=8, +) +sampler_config = SamplerConfig( + dataloader_cfg=dataloader_cfg, + prompt_repeat_k=prompt_repeat_k, +) +training_sample_params = SampleParams( + max_tokens=max_response_length, + top_k=0, + top_p=1.0, + temperature=1.0, + min_tokens=0, +) +agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=training_sample_params, +) +produce_strategy_config = SyncProduceStrategyConfig() +agent_loop_manager_cfg = AgentLoopManagerConfig( + tasks=TaskSpecConfig( + task_name="train_task", + agent_loop_config=agent_loop_config, + judger_config=judger_config, + produce_strategy_config=produce_strategy_config, + sampler_config=sampler_config, + ), +) + +# 6. 
eval agent loop manager +eval_dataset_cfg = [ + { + "dataset": DatasetConfig(name="geo3k", + anno_path=eval_data_path, + class_name='VLMJsonlDataset', + media_root=media_root, + sample_ratio=1.0), + "tokenize_fn": RLQwen3VLTokenizeFnConfig(processor_path=model_path, + max_length=max_prompt_length, + ignore_multimodal_info=True), + } +] + +eval_dataloader_cfg = DataloaderConfig( + dataset_config_list=eval_dataset_cfg, + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", + num_workers=8, +) +eval_sampler_config = SamplerConfig( + dataloader_cfg=eval_dataloader_cfg, + prompt_repeat_k=1, +) +evaluation_sample_params = SampleParams( + max_tokens=max_response_length, + top_k=1, + top_p=1.0, + temperature=0.0, + min_tokens=0, +) +eval_agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=evaluation_sample_params, +) +eval_agent_loop_manager_cfg = AgentLoopManagerConfig( + tasks=TaskSpecConfig( + task_name="eval_task", + agent_loop_config=eval_agent_loop_config, + judger_config=judger_config, + sampler_config=eval_sampler_config, + ), +) + +# 7. evaluator +evaluator_config = EvaluatorConfig(compute_metric_func=None) + +# 8. 
RL Colocate Trainer Config(CLI 通过 config["trainer"].build() 得到 Trainer) +trainer = RLColocateTrainerConfig( + resources=resources, + train_worker_cfg=train_worker_cfg, # TODO: uniform naming of cfg and config + rollout_config=rollout_config, + tokenizer_path=model_path, + replay_buffer_config=SyncReplayBufferConfig(), + agent_loop_manager_cfg=agent_loop_manager_cfg, + eval_agent_loop_manager_cfg=eval_agent_loop_manager_cfg, + evaluator_config=evaluator_config, + load_from=model_path, + rollout_steps=rollout_steps, + global_batch_size=global_batch_size, + enable_evaluate=True, + enable_initial_evaluate=False, + evaluate_step=evaluate_step, + work_dir=work_dir, + seed=123, + debug_rollout=False, +) diff --git a/examples/v1/config/rl_grpo_gsm8k_async.py b/examples/v1/config/rl_grpo_gsm8k_async.py new file mode 100644 index 0000000000..3316851338 --- /dev/null +++ b/examples/v1/config/rl_grpo_gsm8k_async.py @@ -0,0 +1,205 @@ +"""RL Colocate Trainer 示例配置(GRPO + GSM8K)。 + +用法:通过环境变量传入路径后,由 CLI 加载本配置并 trainer_cfg.build().fit()。 +需设置: WORK_DIR, MODEL_PATH, DATA_PATH, EVAL_DATA_PATH +可选: WORLD_SIZE, ENABLE_RETURN_ROUTED_EXPERTS, LOSS_TYPE, LOSS_MODE, SP_SIZE +""" +import os +from pathlib import Path + +from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig +from xtuner.v1.data_proto import SampleParams +from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig +from xtuner.v1.datasets.rl_tokenize_fn import RLTextTokenizeFnConfig +from xtuner.v1.model import get_model_config_from_hf +from xtuner.v1.rl.utils import AcceleratorResourcesConfig +from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.rl.judger import GSM8KJudgerConfig +from xtuner.v1.rl.replay_buffer import AsyncReplayBufferConfig +from xtuner.v1.rl.trainer import WorkerConfig +from xtuner.v1.rl.agent_loop import AgentLoopManagerConfig, TaskSpecConfig, SingleTurnAgentLoopConfig, AsyncProduceStrategyConfig, SamplerConfig +from xtuner.v1.rl.evaluator import EvaluatorConfig +from 
"""RL colocate trainer example config (GRPO + GSM8K, fully async rollout).

Usage: set the environment variables below, then let the CLI load this config
and run ``trainer_cfg.build().fit()``.
Required: WORK_DIR, MODEL_PATH, DATA_PATH, EVAL_DATA_PATH
Optional: WORLD_SIZE, ENABLE_RETURN_ROUTED_EXPERTS, LOSS_TYPE, LOSS_MODE, SP_SIZE
"""
import os
from pathlib import Path

from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig
from xtuner.v1.data_proto import SampleParams
from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig
from xtuner.v1.datasets.rl_tokenize_fn import RLTextTokenizeFnConfig
from xtuner.v1.model import get_model_config_from_hf
from xtuner.v1.rl.agent_loop import (
    AgentLoopManagerConfig,
    AsyncProduceStrategyConfig,
    SamplerConfig,
    SingleTurnAgentLoopConfig,
    TaskSpecConfig,
)
from xtuner.v1.rl.evaluator import EvaluatorConfig
from xtuner.v1.rl.judger import GSM8KJudgerConfig
from xtuner.v1.rl.loss import GRPOLossConfig
from xtuner.v1.rl.replay_buffer import AsyncReplayBufferConfig
from xtuner.v1.rl.rollout.worker import RolloutConfig
from xtuner.v1.rl.trainer import WorkerConfig
from xtuner.v1.rl.utils import AcceleratorResourcesConfig
from xtuner.v1.train.rl_colocate_trainer import RLColocateTrainerConfig

# --- environment ---
work_dir = os.environ["WORK_DIR"]
model_path = os.environ["MODEL_PATH"]
data_path = os.environ["DATA_PATH"]
eval_data_path = os.environ["EVAL_DATA_PATH"]
enable_return_routed_experts = os.environ.get("ENABLE_RETURN_ROUTED_EXPERTS", "0")
NNODE = int(os.environ.get("WORLD_SIZE", "1"))

# --- basic settings ---
experimental_name = "grpo_gsm8k"
rollout_steps = 45
evaluate_step = 45
train_optimizer_steps = 1
global_batch_size = 64 * train_optimizer_steps
prompt_repeat_k = 5
rollout_tp_size = 1
rollout_ep_size = 1
max_prompt_length = 2048
max_response_length = 8192
pack_max_length = 10 * 1024

# 1. accelerator resources
resources = AcceleratorResourcesConfig(
    accelerator="GPU",
    num_workers=8 * NNODE,
    num_cpus_per_worker=12,
    cpu_memory_per_worker=16 * 1024**3,  # 16 GB
)

# 2. rollout backend
rollout_config = RolloutConfig(
    env=experimental_name,
    device=resources.accelerator,
    model_path=model_path,
    dtype="bfloat16",
    tensor_parallel_size=rollout_tp_size,
    expert_parallel_size=rollout_ep_size,
    gpu_memory_utilization=0.8,
    context_length=max_response_length + max_prompt_length,
    enable_return_routed_experts=(enable_return_routed_experts == "1"),
)

# 3. reward judger
judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k", num_ray_actors=1)

# 4. train worker
lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6)
fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1)
model_cfg = get_model_config_from_hf(Path(model_path))
# Disable MoE auxiliary losses (when present) so only the RL loss trains.
if hasattr(model_cfg, "balancing_loss_cfg"):
    model_cfg.balancing_loss_cfg = None
if hasattr(model_cfg, "z_loss_cfg"):
    model_cfg.z_loss_cfg = None
optim_cfg = AdamWConfig(lr=1e-6, foreach=False, weight_decay=0.1)
loss_cfg = GRPOLossConfig(
    policy_loss_cfg=dict(
        cliprange_high=0.28,
        cliprange_low=0.2,
        loss_type=os.environ.get("LOSS_TYPE", "vanilla"),
        clip_ratio_c=10.0,
        log_prob_diff_min=-20.0,
        log_prob_diff_max=20.0,
    ),
    ignore_idx=-100,
    use_kl_loss=False,
    kl_loss_coef=0.0,
    kl_loss_type="low_var_kl",
    mode=os.environ.get("LOSS_MODE", "chunk"),
    chunk_size=512,
)
train_worker_cfg = WorkerConfig(
    model_cfg=model_cfg,
    load_from=model_path,
    optim_cfg=optim_cfg,
    loss_cfg=loss_cfg,
    lr_cfg=lr_cfg,
    fsdp_cfg=fsdp_cfg,
    sp_size=int(os.environ.get("SP_SIZE", "1")),
    optimizer_steps=train_optimizer_steps,
    pack_max_length=pack_max_length,
)

# 5. train agent loop manager
train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path)
tokenizer_config = RLTextTokenizeFnConfig(max_length=max_prompt_length)
train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}]
dataloader_cfg = DataloaderConfig(
    dataset_config_list=train_dataset_cfg,
    pack_max_length=pack_max_length,
    collator="fake_collator",
    pack_level="none",
)
sampler_config = SamplerConfig(
    dataloader_cfg=dataloader_cfg,
    prompt_repeat_k=prompt_repeat_k,
)
training_sample_params = SampleParams(
    max_tokens=max_response_length,
    top_k=0,
    top_p=1.0,
    temperature=1.0,
    min_tokens=0,
)
agent_loop_config = SingleTurnAgentLoopConfig(
    hf_checkpoint=model_path,
    sample_params=training_sample_params,
)
# Fully async production: over-sample and allow partially finished rollouts
# to be resumed after a weight sync.
produce_strategy_config = AsyncProduceStrategyConfig(
    over_sample_threshold=0.8,
    enable_partial_rollout=True,
    tail_batch_stale_threshold=1,
    tail_batch_trigger_size=64,
)
agent_loop_manager_cfg = AgentLoopManagerConfig(
    tasks=TaskSpecConfig(
        task_name="train_task",
        agent_loop_config=agent_loop_config,
        judger_config=judger_config,
        produce_strategy_config=produce_strategy_config,
        sampler_config=sampler_config,
    ),
)

# 6. eval agent loop manager
eval_dataset = DatasetConfig(
    name=experimental_name, anno_path=eval_data_path, sample_ratio=1.0
)
eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}]
eval_dataloader_cfg = DataloaderConfig(
    dataset_config_list=eval_dataset_cfg,
    pack_max_length=pack_max_length,
    collator="fake_collator",
    pack_level="none",
)
eval_sampler_config = SamplerConfig(
    dataloader_cfg=eval_dataloader_cfg,
    prompt_repeat_k=1,
)
# Near-greedy decoding for evaluation.
evaluation_sample_params = SampleParams(
    max_tokens=max_response_length,
    top_k=1,
    top_p=1.0,
    temperature=0.0,
    min_tokens=0,
)
eval_agent_loop_config = SingleTurnAgentLoopConfig(
    hf_checkpoint=model_path,
    sample_params=evaluation_sample_params,
)
eval_agent_loop_manager_cfg = AgentLoopManagerConfig(
    tasks=TaskSpecConfig(
        task_name="eval_task",
        agent_loop_config=eval_agent_loop_config,
        judger_config=judger_config,
        sampler_config=eval_sampler_config,
    ),
)

# 7. evaluator (no custom metric: the judger's default metric is used)
evaluator_config = EvaluatorConfig(compute_metric_func=None)

# 8. RL colocate trainer config (the CLI builds config["trainer"] and runs it)
trainer = RLColocateTrainerConfig(
    resources=resources,
    train_worker_cfg=train_worker_cfg,  # TODO: uniform naming of cfg and config
    rollout_config=rollout_config,
    tokenizer_path=model_path,
    replay_buffer_config=AsyncReplayBufferConfig(),
    agent_loop_manager_cfg=agent_loop_manager_cfg,
    eval_agent_loop_manager_cfg=eval_agent_loop_manager_cfg,
    evaluator_config=evaluator_config,
    load_from=model_path,
    rollout_steps=rollout_steps,
    global_batch_size=global_batch_size,
    enable_evaluate=True,
    enable_initial_evaluate=False,
    evaluate_step=evaluate_step,
    work_dir=work_dir,
    seed=123,
    debug_rollout=False,
)
"""RL colocate trainer example config (GRPO + GSM8K, synchronous rollout).

Usage: set the environment variables below, then let the CLI load this config
and run ``trainer_cfg.build().fit()``.
Required: WORK_DIR, MODEL_PATH, DATA_PATH, EVAL_DATA_PATH
Optional: WORLD_SIZE, ENABLE_RETURN_ROUTED_EXPERTS, LOSS_TYPE, LOSS_MODE, SP_SIZE
"""
import os
from pathlib import Path

from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig
from xtuner.v1.data_proto import SampleParams
from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig
from xtuner.v1.datasets.rl_tokenize_fn import RLTextTokenizeFnConfig
from xtuner.v1.model import get_model_config_from_hf
from xtuner.v1.rl.agent_loop import (
    AgentLoopManagerConfig,
    SamplerConfig,
    SingleTurnAgentLoopConfig,
    SyncProduceStrategyConfig,
    TaskSpecConfig,
)
from xtuner.v1.rl.evaluator import EvaluatorConfig
from xtuner.v1.rl.judger import GSM8KJudgerConfig
from xtuner.v1.rl.loss import GRPOLossConfig
from xtuner.v1.rl.replay_buffer import SyncReplayBufferConfig
from xtuner.v1.rl.rollout.worker import RolloutConfig
from xtuner.v1.rl.trainer import WorkerConfig
from xtuner.v1.rl.utils import AcceleratorResourcesConfig
from xtuner.v1.train.rl_colocate_trainer import RLColocateTrainerConfig

# --- environment ---
work_dir = os.environ["WORK_DIR"]
model_path = os.environ["MODEL_PATH"]
data_path = os.environ["DATA_PATH"]
eval_data_path = os.environ["EVAL_DATA_PATH"]
enable_return_routed_experts = os.environ.get("ENABLE_RETURN_ROUTED_EXPERTS", "0")
NNODE = int(os.environ.get("WORLD_SIZE", "1"))

# --- basic settings ---
experimental_name = "grpo_gsm8k"
rollout_steps = 45
evaluate_step = 45
train_optimizer_steps = 1
global_batch_size = 64 * train_optimizer_steps
prompt_repeat_k = 5
rollout_tp_size = 1
rollout_ep_size = 1
max_prompt_length = 512
max_response_length = 1024
pack_max_length = 32 * 1024

# 1. accelerator resources
resources = AcceleratorResourcesConfig(
    accelerator="GPU",
    num_workers=8 * NNODE,
    num_cpus_per_worker=12,
    cpu_memory_per_worker=16 * 1024**3,  # 16 GB
)

# 2. rollout backend
rollout_config = RolloutConfig(
    env=experimental_name,
    device=resources.accelerator,
    model_path=model_path,
    dtype="bfloat16",
    tensor_parallel_size=rollout_tp_size,
    expert_parallel_size=rollout_ep_size,
    gpu_memory_utilization=0.8,
    context_length=max_response_length + max_prompt_length,
    enable_return_routed_experts=(enable_return_routed_experts == "1"),
)

# 3. reward judger
judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k", num_ray_actors=1)

# 4. train worker
lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6)
fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1)
model_cfg = get_model_config_from_hf(Path(model_path))
# Disable MoE auxiliary losses (when present) so only the RL loss trains.
if hasattr(model_cfg, "balancing_loss_cfg"):
    model_cfg.balancing_loss_cfg = None
if hasattr(model_cfg, "z_loss_cfg"):
    model_cfg.z_loss_cfg = None
optim_cfg = AdamWConfig(lr=1e-6, foreach=False, weight_decay=0.1)
loss_cfg = GRPOLossConfig(
    policy_loss_cfg=dict(
        cliprange_high=0.28,
        cliprange_low=0.2,
        loss_type=os.environ.get("LOSS_TYPE", "vanilla"),
        clip_ratio_c=10.0,
        log_prob_diff_min=-20.0,
        log_prob_diff_max=20.0,
    ),
    ignore_idx=-100,
    use_kl_loss=False,
    kl_loss_coef=0.0,
    kl_loss_type="low_var_kl",
    mode=os.environ.get("LOSS_MODE", "chunk"),
    chunk_size=512,
)
train_worker_cfg = WorkerConfig(
    model_cfg=model_cfg,
    load_from=model_path,
    optim_cfg=optim_cfg,
    loss_cfg=loss_cfg,
    lr_cfg=lr_cfg,
    fsdp_cfg=fsdp_cfg,
    sp_size=int(os.environ.get("SP_SIZE", "1")),
    optimizer_steps=train_optimizer_steps,
    pack_max_length=pack_max_length,
)

# 5. train agent loop manager
train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path)
tokenizer_config = RLTextTokenizeFnConfig(max_length=max_prompt_length)
train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}]
dataloader_cfg = DataloaderConfig(
    dataset_config_list=train_dataset_cfg,
    pack_max_length=pack_max_length,
    collator="fake_collator",
    pack_level="none",
)
sampler_config = SamplerConfig(
    dataloader_cfg=dataloader_cfg,
    prompt_repeat_k=prompt_repeat_k,
)
training_sample_params = SampleParams(
    max_tokens=max_response_length,
    top_k=0,
    top_p=1.0,
    temperature=1.0,
    min_tokens=0,
)
agent_loop_config = SingleTurnAgentLoopConfig(
    hf_checkpoint=model_path,
    sample_params=training_sample_params,
)
# Synchronous production: each train step waits for its full rollout batch.
produce_strategy_config = SyncProduceStrategyConfig()
agent_loop_manager_cfg = AgentLoopManagerConfig(
    tasks=TaskSpecConfig(
        task_name="train_task",
        agent_loop_config=agent_loop_config,
        judger_config=judger_config,
        produce_strategy_config=produce_strategy_config,
        sampler_config=sampler_config,
    ),
)

# 6. eval agent loop manager
eval_dataset = DatasetConfig(
    name=experimental_name, anno_path=eval_data_path, sample_ratio=1.0
)
eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}]
eval_dataloader_cfg = DataloaderConfig(
    dataset_config_list=eval_dataset_cfg,
    pack_max_length=pack_max_length,
    collator="fake_collator",
    pack_level="none",
)
eval_sampler_config = SamplerConfig(
    dataloader_cfg=eval_dataloader_cfg,
    prompt_repeat_k=1,
)
# Near-greedy decoding for evaluation.
evaluation_sample_params = SampleParams(
    max_tokens=max_response_length,
    top_k=1,
    top_p=1.0,
    temperature=0.0,
    min_tokens=0,
)
eval_agent_loop_config = SingleTurnAgentLoopConfig(
    hf_checkpoint=model_path,
    sample_params=evaluation_sample_params,
)
eval_agent_loop_manager_cfg = AgentLoopManagerConfig(
    tasks=TaskSpecConfig(
        task_name="eval_task",
        agent_loop_config=eval_agent_loop_config,
        judger_config=judger_config,
        sampler_config=eval_sampler_config,
    ),
)

# 7. evaluator (no custom metric: the judger's default metric is used)
evaluator_config = EvaluatorConfig(compute_metric_func=None)

# 8. RL colocate trainer config (the CLI builds config["trainer"] and runs it)
trainer = RLColocateTrainerConfig(
    resources=resources,
    train_worker_cfg=train_worker_cfg,  # TODO: uniform naming of cfg and config
    rollout_config=rollout_config,
    tokenizer_path=model_path,
    replay_buffer_config=SyncReplayBufferConfig(),
    agent_loop_manager_cfg=agent_loop_manager_cfg,
    eval_agent_loop_manager_cfg=eval_agent_loop_manager_cfg,
    evaluator_config=evaluator_config,
    load_from=model_path,
    rollout_steps=rollout_steps,
    global_batch_size=global_batch_size,
    enable_evaluate=True,
    enable_initial_evaluate=False,
    evaluate_step=evaluate_step,
    work_dir=work_dir,
    seed=123,
    debug_rollout=False,
)
RL Colocate Trainer Config(CLI 通过 config["trainer"].build() 得到 Trainer) +trainer = RLColocateTrainerConfig( + resources=resources, + train_worker_cfg=train_worker_cfg, # TODO: uniform naming of cfg and config + rollout_config=rollout_config, + tokenizer_path=model_path, + replay_buffer_config=SyncReplayBufferConfig(), + agent_loop_manager_cfg=agent_loop_manager_cfg, + eval_agent_loop_manager_cfg=eval_agent_loop_manager_cfg, + evaluator_config=evaluator_config, + load_from=model_path, + rollout_steps=rollout_steps, + global_batch_size=global_batch_size, + enable_evaluate=True, + enable_initial_evaluate=False, + evaluate_step=evaluate_step, + work_dir=work_dir, + seed=123, + debug_rollout=False, +) diff --git a/examples/v1/config/rl_grpo_gsm8k_with_tool.py b/examples/v1/config/rl_grpo_gsm8k_with_tool.py new file mode 100644 index 0000000000..239adbe9e1 --- /dev/null +++ b/examples/v1/config/rl_grpo_gsm8k_with_tool.py @@ -0,0 +1,222 @@ +"""RL Colocate Trainer 示例配置(GRPO + GSM8K)。 + +用法:通过环境变量传入路径后,由 CLI 加载本配置并 trainer_cfg.build().fit()。 +需设置: WORK_DIR, MODEL_PATH, DATA_PATH, EVAL_DATA_PATH +可选: WORLD_SIZE, ENABLE_RETURN_ROUTED_EXPERTS, LOSS_TYPE, LOSS_MODE, SP_SIZE +""" +import os +from pathlib import Path + +from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig +from xtuner.v1.data_proto import SampleParams +from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig +from xtuner.v1.datasets.rl_tokenize_fn import RLTextTokenizeFnConfig +from xtuner.v1.model import get_model_config_from_hf +from xtuner.v1.rl.utils import AcceleratorResourcesConfig +from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.rl.judger import GSM8KJudgerConfig +from xtuner.v1.rl.replay_buffer import SyncReplayBufferConfig +from xtuner.v1.rl.trainer import WorkerConfig +from xtuner.v1.rl.agent_loop import AgentLoopManagerConfig, TaskSpecConfig, SyncProduceStrategyConfig, SamplerConfig +from xtuner.v1.rl.evaluator import EvaluatorConfig +from xtuner.v1.rl.loss 
import GRPOLossConfig +from xtuner.v1.train.rl_colocate_trainer import RLColocateTrainerConfig +from xtuner.v1.rl.agent_loop.gsm8k_with_tool import GSM8KToolAgentLoopConfig + +# env +work_dir = os.environ["WORK_DIR"] +model_path = os.environ["MODEL_PATH"] +data_path = os.environ["DATA_PATH"] +eval_data_path = os.environ["EVAL_DATA_PATH"] +enable_return_routed_experts = os.environ.get("ENABLE_RETURN_ROUTED_EXPERTS", "0") +NNODE = int(os.environ.get("WORLD_SIZE", "1")) + +# basic settings +experimental_name = "grpo_gsm8k_with_tool" +rollout_steps = 45 +evaluate_step = 45 +train_optimizer_steps = 1 +global_batch_size = 64 * train_optimizer_steps +prompt_repeat_k = 5 +rollout_tp_size = 1 +rollout_ep_size = 1 +max_prompt_length = 1024 +max_response_length = 2048 +pack_max_length = 8 * 1024 + +# 1. resources +resources = AcceleratorResourcesConfig( + accelerator="GPU", + num_workers=8 * NNODE, + num_cpus_per_worker=12, + cpu_memory_per_worker=16 * 1024**3, # 16 GB +) + +# 2. rollout +rollout_config = RolloutConfig( + env=experimental_name, + device=resources.accelerator, + model_path=model_path, + dtype="bfloat16", + tensor_parallel_size=rollout_tp_size, + expert_parallel_size=rollout_ep_size, + gpu_memory_utilization=0.8, + context_length=max_response_length + max_prompt_length, + enable_return_routed_experts=(enable_return_routed_experts == "1"), +) + +# 3. judger +judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k", num_ray_actors=1) + +# 4. 
train worker +lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) +fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) +model_cfg = get_model_config_from_hf(Path(model_path)) +if hasattr(model_cfg, "balancing_loss_cfg"): + model_cfg.balancing_loss_cfg = None +if hasattr(model_cfg, "z_loss_cfg"): + model_cfg.z_loss_cfg = None +optim_cfg = AdamWConfig(lr=1e-6, foreach=False, weight_decay=0.1) +loss_cfg = GRPOLossConfig( + policy_loss_cfg=dict( + cliprange_high=0.28, + cliprange_low=0.2, + loss_type=os.environ.get("LOSS_TYPE", "vanilla"), + clip_ratio_c=10.0, + log_prob_diff_min=-20.0, + log_prob_diff_max=20.0, + ), + ignore_idx=-100, + use_kl_loss=False, + kl_loss_coef=0.0, + kl_loss_type="low_var_kl", + mode=os.environ.get("LOSS_MODE", "chunk"), + chunk_size=512, +) +train_worker_cfg = WorkerConfig( + model_cfg=model_cfg, + load_from=model_path, + optim_cfg=optim_cfg, + loss_cfg=loss_cfg, + lr_cfg=lr_cfg, + fsdp_cfg=fsdp_cfg, + sp_size=int(os.environ.get("SP_SIZE", "1")), + optimizer_steps=train_optimizer_steps, + pack_max_length=pack_max_length, +) + +# 5. train agent loop manager +gsm8k_tools = [ + { + "type": "function", + "function": { + "name": "calc_gsm8k_reward", + "description": "A tool for calculating the reward of gsm8k. 
(1.0 if parsed answer is correct, 0.0 if parsed answer is incorrect or not correctly parsed)", + "parameters": { + "type": "object", + "properties": { + "answer": { + "type": "string", + "description": "The model's answer to the GSM8K math problem, must be a digits", + }, + "required": ["answer"], + }, + }, + }, + } +] +train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path) +tokenizer_config = RLTextTokenizeFnConfig(max_length=max_prompt_length, tools_schema=gsm8k_tools) +train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}] +dataloader_cfg = DataloaderConfig( + dataset_config_list=train_dataset_cfg, + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", +) +sampler_config = SamplerConfig( + dataloader_cfg=dataloader_cfg, + prompt_repeat_k=prompt_repeat_k, +) +training_sample_params = SampleParams( + max_tokens=max_response_length, + top_k=0, + top_p=1.0, + temperature=1.0, + min_tokens=0, +) +agent_loop_config = GSM8KToolAgentLoopConfig( + max_turns=2, + hf_checkpoint=model_path, + sample_params=training_sample_params, +) +produce_strategy_config = SyncProduceStrategyConfig() +agent_loop_manager_cfg = AgentLoopManagerConfig( + tasks=TaskSpecConfig( + task_name="train_task", + agent_loop_config=agent_loop_config, + judger_config=judger_config, + produce_strategy_config=produce_strategy_config, + sampler_config=sampler_config, + ), +) + +# 6. 
eval agent loop manager +eval_dataset = DatasetConfig( + name=experimental_name, anno_path=eval_data_path, sample_ratio=1.0 +) +eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}] +eval_dataloader_cfg = DataloaderConfig( + dataset_config_list=eval_dataset_cfg, + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", +) +eval_sampler_config = SamplerConfig( + dataloader_cfg=eval_dataloader_cfg, + prompt_repeat_k=1, +) +evaluation_sample_params = SampleParams( + max_tokens=max_response_length, + top_k=1, + top_p=1.0, + temperature=0.0, + min_tokens=0, +) +eval_agent_loop_config = GSM8KToolAgentLoopConfig( + max_turns=2, + hf_checkpoint=model_path, + sample_params=evaluation_sample_params, +) +eval_agent_loop_manager_cfg = AgentLoopManagerConfig( + tasks=TaskSpecConfig( + task_name="eval_task", + agent_loop_config=eval_agent_loop_config, + judger_config=judger_config, + sampler_config=eval_sampler_config, + ), +) + +# 7. evaluator +evaluator_config = EvaluatorConfig(compute_metric_func=None) + +# 8. 
RL Colocate Trainer Config(CLI 通过 config["trainer"].build() 得到 Trainer) +trainer = RLColocateTrainerConfig( + resources=resources, + train_worker_cfg=train_worker_cfg, # TODO: uniform naming of cfg and config + rollout_config=rollout_config, + tokenizer_path=model_path, + replay_buffer_config=SyncReplayBufferConfig(), + agent_loop_manager_cfg=agent_loop_manager_cfg, + eval_agent_loop_manager_cfg=eval_agent_loop_manager_cfg, + evaluator_config=evaluator_config, + load_from=model_path, + rollout_steps=rollout_steps, + global_batch_size=global_batch_size, + enable_evaluate=True, + enable_initial_evaluate=False, + evaluate_step=evaluate_step, + work_dir=work_dir, + seed=123, + debug_rollout=False, +) diff --git a/examples/v1/config/rl_interns1_mini_grpo.py b/examples/v1/config/rl_interns1_mini_grpo.py deleted file mode 100644 index 8872395a34..0000000000 --- a/examples/v1/config/rl_interns1_mini_grpo.py +++ /dev/null @@ -1,204 +0,0 @@ -import os -from copy import deepcopy - -from transformers import AutoTokenizer -from xtuner.v1.config import ( - AdamWConfig, - FSDPConfig, - LRConfig, -) -from xtuner.v1.data_proto.rl_data import SampleParams -from xtuner.v1.model.compose.intern_s1 import InternS1MiniConfig -from xtuner.v1.ray.base import AcceleratorResourcesConfig -from xtuner.v1.ray.config.worker import RolloutConfig -from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig -from xtuner.v1.ray.evaluator import EvaluatorConfig -from xtuner.v1.ray.judger.controller import JudgerConfig -from xtuner.v1.rl.base import WorkerConfig -from xtuner.v1.rl.grpo import GRPOLossConfig -from xtuner.v1.train.rl_trainer import RLTrainerConfig -from xtuner.v1.datasets import RLTokenizeFnConfig, DatasetConfig, InternS1VLTokenizeFnConfig, DataloaderConfig -from xtuner.v1.ray.judger.geo3k import GEO3KJudgerConfig - - -work_dir = os.environ["WORK_DIR"] -model_path = os.environ["MODEL_PATH"] -data_path = os.environ["DATA_PATH"] -eval_data_path = os.environ["EVAL_DATA_PATH"] 
-enable_evaluate = True if eval_data_path != "" else False -media_root = os.environ["MEDIA_ROOT"] - -# basic settings -experimental_name = "grpo_geo3k" -total_epochs = 15 -global_batch_size = 1024 -prompt_repeat_k = 5 -rollout_tp_size = 2 -rollout_ep_size = 1 -max_prompt_length = 4096 # Note: 不设置大一点,大部分数据都会被过滤掉 -max_response_length = 2048 -pack_max_length = 32768 -train_optimizer_steps = 4 -hf_interval = 15 -enable_initial_evaluate = True -evaluate_step = 10 - -# grpo quick test: -# total_epochs = 3 -# global_batch_size = 64 -# prompt_repeat_k = 5 -# rollout_tp_size = 1 -# rollout_ep_size = 1 -# max_prompt_length = 512 -# max_response_length = 1024 -# pack_max_length = 32768 -# train_optimizer_steps = 1 -# hf_interval = 100 -# enable_initial_evaluate = True -# evaluate_step = 15 - -# 1. resources -resources = AcceleratorResourcesConfig( - accelerator="GPU", - num_workers=8, - num_cpus_per_worker=12, - cpu_memory_per_worker=16 * 1024**3, # 16 GB -) - -# 2. rollout -rollout_config = RolloutConfig( - env=experimental_name, - device=resources.accelerator, - model_path=model_path, - dtype="bfloat16", - tensor_parallel_size=rollout_tp_size, - expert_parallel_size=rollout_ep_size, - gpu_memory_utilization=0.75, - context_length=max_prompt_length+max_response_length, - extra_rollout_config={ - "sglang_grammar_backend": 'none', - } - # rollout_max_batch_size_per_instance=16, # optional -) - -# sampling params -training_sample_params = SampleParams( - max_tokens=max_response_length, -) -evaluation_sample_params = deepcopy(training_sample_params) -evaluation_sample_params.top_p = 1.0 -evaluation_sample_params.temperature = 0.0 -evaluation_sample_params.top_k = 1 - -# dataset: 不需要修改 -train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path) -eval_dataset = DatasetConfig(name=experimental_name, anno_path=eval_data_path) if enable_evaluate else None -tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) - -tokenize_fn_cfg = 
InternS1VLTokenizeFnConfig(model_cfg=InternS1MiniConfig(), max_length=max_prompt_length) -train_dataset_cfg = [ - { - "dataset": DatasetConfig(name="geo3k", - anno_path=data_path, - class_name='VLMJsonlDataset', - media_root=media_root, - sample_ratio=1.0), - "tokenize_fn": RLTokenizeFnConfig(max_length=max_prompt_length, - tokenize_fn_cfg=tokenize_fn_cfg), - } -] - -eval_dataset_cfg = [] -if enable_evaluate: - eval_dataset_cfg = [ - { - "dataset": DatasetConfig(name="geo3k", - anno_path=eval_data_path, - class_name='VLMJsonlDataset', - media_root=media_root, - sample_ratio=1.0), - "tokenize_fn": RLTokenizeFnConfig(max_length=max_prompt_length, - tokenize_fn_cfg=tokenize_fn_cfg, - ignore_multimodal_info=True), - } - ] - -dataloader_config = DataloaderConfig(num_workers=8, - collator="fake_collator", - pack_level="none") - -# 3. judger -geo3k_judger_config = GEO3KJudgerConfig() -judger_cfg = JudgerConfig(reward_judger_configs=[geo3k_judger_config]) - -# 4. dataflow and evaluator -dataflow_config = DataFlowConfig( - env=experimental_name, - prompt_repeat_k=prompt_repeat_k, - global_batch_size=global_batch_size, - sample_params=training_sample_params, - # max_concurrent=64, # optional -) - -evaluator_cfg = EvaluatorConfig( - enable_evaluate=enable_evaluate, - enable_initial_evaluate=enable_initial_evaluate, - dataset_cfg=eval_dataset_cfg, - dataloader_cfg=dataloader_config, - tokenizer=model_path, - evaluate_step=evaluate_step, - compute_metric_func=None, - sample_params=evaluation_sample_params, -) if enable_evaluate else None - -# replay buffer config: : 不需要修改 -replay_buffer_cfg = ReplayBufferConfig( - dataset_cfg=train_dataset_cfg, dataloader_cfg=dataloader_config, tokenizer=model_path -) - -# 5. 
Train worker -# NOTE: modify model_cfg -model_cfg = InternS1MiniConfig(freeze_vision=True, freeze_projector=True) -optim_cfg = AdamWConfig(lr=1e-6, foreach=False) -loss_cfg = GRPOLossConfig( - policy_loss_cfg=dict( - cliprange_high=0.2, - cliprange_low=0.2, - loss_type="vanilla", - ), - ignore_idx=-100, - use_kl_loss=True, - kl_loss_coef=0.001, - kl_loss_type="low_var_kl", - mode="chunk", - chunk_size=512, -) -lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) -fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False) -train_worker_cfg: WorkerConfig = WorkerConfig( - model_cfg=model_cfg, - load_from=model_path, - optim_cfg=optim_cfg, - loss_cfg=loss_cfg, - lr_cfg=lr_cfg, - fsdp_cfg=fsdp_cfg, - sp_size=1, - optimizer_steps=train_optimizer_steps, - pack_max_length=pack_max_length, -) - -# 6. RL Trainer -trainer = RLTrainerConfig( - load_from=model_path, - resources=resources, - rollout_config=rollout_config, - dataflow_config=dataflow_config, - judger_config=judger_cfg, - replay_buffer_config=replay_buffer_cfg, - evaluator_config=evaluator_cfg, - train_worker_config=train_worker_cfg, - tokenizer_path=model_path, - work_dir=work_dir, - total_epochs=total_epochs, - hf_interval=hf_interval, -) diff --git a/examples/v1/config/rl_multi_task_gsm8k_dapo_math.py b/examples/v1/config/rl_multi_task_gsm8k_dapo_math.py new file mode 100644 index 0000000000..5a6f992d42 --- /dev/null +++ b/examples/v1/config/rl_multi_task_gsm8k_dapo_math.py @@ -0,0 +1,311 @@ +"""RL Colocate Trainer 示例配置(Multi-Task: GSM8K + DAPO Math)。 + +需设置环境变量: + WORK_DIR + MODEL_PATH + GSM8K_DATA_PATH + GSM8K_EVAL_DATA_PATH + DAPO_DATA_PATH + DAPO_EVAL_DATA_PATH + +可选环境变量: + WORLD_SIZE + ENABLE_RETURN_ROUTED_EXPERTS + LOSS_TYPE + LOSS_MODE + SP_SIZE + GSM8K_TASK_WEIGHT + DAPO_TASK_WEIGHT +""" + +import os +from pathlib import Path + +from transformers import AutoTokenizer + +from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig +from xtuner.v1.data_proto import SampleParams +from 
xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig +from xtuner.v1.datasets.rl_tokenize_fn import RLTextTokenizeFnConfig +from xtuner.v1.model import get_model_config_from_hf +from xtuner.v1.rl.agent_loop import ( + AgentLoopManagerConfig, + SamplerConfig, + SingleTurnAgentLoopConfig, + SyncProduceStrategyConfig, + TaskSpecConfig, +) +from xtuner.v1.rl.evaluator import EvaluatorConfig +from xtuner.v1.rl.judger import DapoMathJudgerConfig +from xtuner.v1.rl.replay_buffer import SyncReplayBufferConfig +from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.rl.trainer import WorkerConfig +from xtuner.v1.rl.loss import GRPOLossConfig +from xtuner.v1.rl.utils import AcceleratorResourcesConfig, get_eos_token +from xtuner.v1.train.rl_colocate_trainer import RLColocateTrainerConfig + +work_dir = os.environ["WORK_DIR"] +model_path = os.environ["MODEL_PATH"] +gsm8k_data_path = os.environ["GSM8K_DATA_PATH"] +gsm8k_eval_data_path = os.environ["GSM8K_EVAL_DATA_PATH"] +dapo_data_path = os.environ["DAPO_DATA_PATH"] +dapo_eval_data_path = os.environ["DAPO_EVAL_DATA_PATH"] +enable_return_routed_experts = os.environ.get("ENABLE_RETURN_ROUTED_EXPERTS", "0") +NNODE = int(os.environ.get("WORLD_SIZE", "1")) + +experimental_name = "multi_task_gsm8k_dapo_math" +rollout_steps = 50 +evaluate_step = 5 +train_optimizer_steps = 8 +global_batch_size = 128 +gsm8k_task_weight = float(os.environ.get("GSM8K_TASK_WEIGHT", "1.0")) +dapo_task_weight = float(os.environ.get("DAPO_TASK_WEIGHT", "1.0")) +rollout_tp_size = 1 +rollout_ep_size = 1 +gsm8k_prompt_repeat_k = 5 +dapo_prompt_repeat_k = 8 +gsm8k_max_prompt_length = 512 +dapo_max_prompt_length = 2048 +gsm8k_max_response_length = 1024 +dapo_max_response_length = 8192 +max_prompt_length = dapo_max_prompt_length +max_response_length = dapo_max_response_length +pack_max_length = 32768 + +resources = AcceleratorResourcesConfig( + accelerator="GPU", + num_workers=8 * NNODE, + num_cpus_per_worker=12, + cpu_memory_per_worker=16 
* 1024**3, +) + +rollout_config = RolloutConfig( + env=experimental_name, + device=resources.accelerator, + model_path=model_path, + dtype="bfloat16", + tensor_parallel_size=rollout_tp_size, + expert_parallel_size=rollout_ep_size, + gpu_memory_utilization=0.8, + context_length=max_response_length + max_prompt_length, + enable_return_routed_experts=(enable_return_routed_experts == "1"), + rollout_max_batch_size_per_instance=2048, +) + +eos_token_id = get_eos_token(model_path) +tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) +eos_token_str = tokenizer.convert_ids_to_tokens(eos_token_id) +judger_config = DapoMathJudgerConfig( + judger_name="dapo_math", + num_ray_actors=1, + eos_token=eos_token_str, + enable_overlong_buffer=True, + max_response_len=max_response_length, + overlong_buffer_len=4096, + overlong_penalty_factor=1.0, + tokenizer=tokenizer, +) + +lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) +fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) +model_cfg = get_model_config_from_hf(Path(model_path)) +if hasattr(model_cfg, "balancing_loss_cfg"): + model_cfg.balancing_loss_cfg = None +if hasattr(model_cfg, "z_loss_cfg"): + model_cfg.z_loss_cfg = None +optim_cfg = AdamWConfig(lr=1e-6, foreach=False, weight_decay=0.1) +loss_cfg = GRPOLossConfig( + policy_loss_cfg=dict( + cliprange_high=0.28, + cliprange_low=0.2, + loss_type=os.environ.get("LOSS_TYPE", "vanilla"), + clip_ratio_c=10.0, + log_prob_diff_min=-20.0, + log_prob_diff_max=20.0, + ), + ignore_idx=-100, + use_kl_loss=False, + kl_loss_coef=0.0, + kl_loss_type="low_var_kl", + mode=os.environ.get("LOSS_MODE", "chunk"), + chunk_size=512, +) +train_worker_cfg = WorkerConfig( + model_cfg=model_cfg, + load_from=model_path, + optim_cfg=optim_cfg, + loss_cfg=loss_cfg, + lr_cfg=lr_cfg, + fsdp_cfg=fsdp_cfg, + sp_size=int(os.environ.get("SP_SIZE", "1")), + optimizer_steps=train_optimizer_steps, + pack_max_length=pack_max_length, +) + 
+gsm8k_train_tokenizer_config = RLTextTokenizeFnConfig(max_length=gsm8k_max_prompt_length) +dapo_train_tokenizer_config = RLTextTokenizeFnConfig(max_length=dapo_max_prompt_length) + +gsm8k_train_sampler_config = SamplerConfig( + dataloader_cfg=DataloaderConfig( + dataset_config_list=[ + { + "dataset": DatasetConfig(name="gsm8k", anno_path=gsm8k_data_path), + "tokenize_fn": gsm8k_train_tokenizer_config, + } + ], + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", + ), + prompt_repeat_k=gsm8k_prompt_repeat_k, +) +dapo_train_sampler_config = SamplerConfig( + dataloader_cfg=DataloaderConfig( + dataset_config_list=[ + { + "dataset": DatasetConfig(name="dapo_math", anno_path=dapo_data_path), + "tokenize_fn": dapo_train_tokenizer_config, + } + ], + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", + ), + prompt_repeat_k=dapo_prompt_repeat_k, +) + +gsm8k_train_agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=SampleParams( + max_tokens=gsm8k_max_response_length, + top_k=0, + top_p=1.0, + temperature=1.0, + min_tokens=0, + ), +) +dapo_train_agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=SampleParams( + max_tokens=dapo_max_response_length, + top_k=0, + top_p=1.0, + temperature=1.0, + min_tokens=0, + ), +) + +agent_loop_manager_cfg = AgentLoopManagerConfig( + tasks=[ + TaskSpecConfig( + task_name="train_task:dapo_math", + weight=dapo_task_weight, + agent_loop_config=dapo_train_agent_loop_config, + judger_config=judger_config, + produce_strategy_config=SyncProduceStrategyConfig(), + sampler_config=dapo_train_sampler_config, + ), + TaskSpecConfig( + task_name="train_task:gsm8k", + weight=gsm8k_task_weight, + agent_loop_config=gsm8k_train_agent_loop_config, + produce_strategy_config=SyncProduceStrategyConfig(), + sampler_config=gsm8k_train_sampler_config, + ), + ], +) + +gsm8k_eval_sampler_config = SamplerConfig( + 
dataloader_cfg=DataloaderConfig( + dataset_config_list=[ + { + "dataset": DatasetConfig(name="gsm8k_eval", anno_path=gsm8k_eval_data_path, sample_ratio=1.0), + "tokenize_fn": gsm8k_train_tokenizer_config, + } + ], + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", + ), + prompt_repeat_k=1, +) +dapo_eval_sampler_config = SamplerConfig( + dataloader_cfg=DataloaderConfig( + dataset_config_list=[ + { + "dataset": DatasetConfig(name="dapo_math_eval", anno_path=dapo_eval_data_path, sample_ratio=1.0), + "tokenize_fn": dapo_train_tokenizer_config, + } + ], + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", + ), + prompt_repeat_k=1, +) + +gsm8k_eval_agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=SampleParams( + max_tokens=gsm8k_max_response_length, + top_k=1, + top_p=1.0, + temperature=0.0, + min_tokens=0, + ), +) +dapo_eval_agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=SampleParams( + max_tokens=dapo_max_response_length, + top_k=1, + top_p=0.7, + temperature=0.0, + min_tokens=0, + ), +) + +eval_agent_loop_manager_cfg = AgentLoopManagerConfig( + tasks=[ + TaskSpecConfig( + task_name="eval_task:dapo_math", + weight=dapo_task_weight, + agent_loop_config=dapo_eval_agent_loop_config, + judger_config=judger_config, + sampler_config=dapo_eval_sampler_config, + ), + TaskSpecConfig( + task_name="eval_task:gsm8k", + weight=gsm8k_task_weight, + agent_loop_config=gsm8k_eval_agent_loop_config, + sampler_config=gsm8k_eval_sampler_config, + ), + ], +) + + +def compute_metric(samples): + return {"accuracy": sum(sample.reward["acc"] > 0 for sample in samples) / len(samples)} + + +evaluator_config = EvaluatorConfig(compute_metric_func=compute_metric) + +trainer = RLColocateTrainerConfig( + resources=resources, + train_worker_cfg=train_worker_cfg, + rollout_config=rollout_config, + tokenizer_path=model_path, + 
replay_buffer_config=SyncReplayBufferConfig(), + agent_loop_manager_cfg=agent_loop_manager_cfg, + eval_agent_loop_manager_cfg=eval_agent_loop_manager_cfg, + evaluator_config=evaluator_config, + load_from=model_path, + global_batch_size=global_batch_size, + enable_evaluate=True, + enable_initial_evaluate=False, + rollout_steps=rollout_steps, + evaluate_step=evaluate_step, + work_dir=work_dir, + seed=123, + debug_rollout=False, +) diff --git a/examples/v1/config/rl_qwen25_7B_dapo.py b/examples/v1/config/rl_qwen25_7B_dapo.py deleted file mode 100644 index c5451cec28..0000000000 --- a/examples/v1/config/rl_qwen25_7B_dapo.py +++ /dev/null @@ -1,181 +0,0 @@ -import os -from copy import deepcopy -from pathlib import Path -from transformers import AutoTokenizer -from xtuner.v1.config import ( - AdamWConfig, - FSDPConfig, - LRConfig, -) -from xtuner.v1.data_proto.rl_data import SampleParams -from xtuner.v1.datasets import RLTokenizeFnConfig -from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig -from xtuner.v1.model import get_model_config_from_hf -from xtuner.v1.model.dense.qwen2 import Qwen2Dense7BConfig -from xtuner.v1.ray.base import AcceleratorResourcesConfig -from xtuner.v1.ray.config.worker import RolloutConfig -from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig -from xtuner.v1.ray.evaluator import EvaluatorConfig -from xtuner.v1.ray.judger.controller import JudgerConfig -from xtuner.v1.ray.judger.dapo_math import DapoMathJudgerConfig -from xtuner.v1.rl.base import WorkerConfig -from xtuner.v1.rl.grpo import GRPOLossConfig -from xtuner.v1.train.rl_trainer import RLTrainerConfig - - -work_dir = os.environ["WORK_DIR"] -model_path = os.environ["MODEL_PATH"] -data_path = os.environ["DATA_PATH"] -eval_data_path = os.environ["EVAL_DATA_PATH"] -enable_evaluate = True if eval_data_path != "" else False -enbale_partial_rollout = int(os.environ.get("ENBALE_PARTIAL_ROLLOUT", "0")) - -# basic settings -experimental_name = "dapo_math" 
-total_epochs = 1 -global_batch_size = 512 -prompt_repeat_k = 16 -rollout_tp_size = 2 -rollout_ep_size = 1 -max_prompt_length = 2048 -max_response_length = 8192 -pack_max_length = 32768 -train_optimizer_steps = 16 -hf_interval = 50 -enable_initial_evaluate = True -evaluate_step = 5 - -# 1. resources -resources = AcceleratorResourcesConfig( - accelerator="GPU", - num_workers=8, - num_cpus_per_worker=12, - cpu_memory_per_worker=16 * 1024**3, # 16 GB -) - -# 2. rollout -rollout_config = RolloutConfig( - env=experimental_name, - device=resources.accelerator, - model_path=model_path, - dtype="bfloat16", - tensor_parallel_size=rollout_tp_size, - expert_parallel_size=rollout_ep_size, - gpu_memory_utilization=0.8, - context_length = max_response_length + max_prompt_length, - # rollout_max_batch_size_per_instance=64, # optional, will be determined automatically if not set -) - -# sampling params -training_sample_params = SampleParams( - max_tokens=max_response_length, - top_k=0, - top_p=1.0, - temperature=1.0, - min_tokens=0, -) -evaluation_sample_params = deepcopy(training_sample_params) -evaluation_sample_params.top_p = 0.7 - -# dataset -train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path) -eval_dataset = DatasetConfig(name=experimental_name, anno_path=eval_data_path) if enable_evaluate else None -tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) -tokenizer_config = RLTokenizeFnConfig(max_length=max_prompt_length) - -train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}] -eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}] if enable_evaluate else [] - -dataloader_config = DataloaderConfig(pack_max_length=pack_max_length, collator="fake_collator", pack_level="none") - -# 3. 
judger -from xtuner.v1.utils.rl_test_utils import get_eos_token -eos_token_id = get_eos_token(model_path) -eos_token_str = tokenizer.convert_ids_to_tokens(eos_token_id) -dapomath_judger_config = DapoMathJudgerConfig( - judger_name="dapo_math", - eos_token=eos_token_str, - enable_overlong_buffer = True, - max_response_len =max_response_length, - overlong_buffer_len=4096, - overlong_penalty_factor=1.0, - tokenizer=tokenizer) -judger_cfg = JudgerConfig(reward_judger_configs=[dapomath_judger_config]) - -# 4. dataflow and evaluator -dataflow_config = DataFlowConfig( - env=experimental_name, - prompt_repeat_k=prompt_repeat_k, - global_batch_size=global_batch_size, - sample_params=training_sample_params, - enable_partial_rollout=enbale_partial_rollout, - # max_concurrent=64, # optional, will be determined automatically if not set -) - - -def dapo_compute_metric(samples): - return {"accuracy": sum(s.env.judger.reward["acc"] > 0 for s in samples) / len(samples)} - - -evaluator_cfg = EvaluatorConfig( - enable_evaluate=enable_evaluate, - enable_initial_evaluate=enable_initial_evaluate, - dataset_cfg=eval_dataset_cfg, - tokenizer=tokenizer, - evaluate_step=evaluate_step, - compute_metric_func=dapo_compute_metric, - sample_params=evaluation_sample_params, -) if enable_evaluate else None - -replay_buffer_cfg = ReplayBufferConfig( - dataset_cfg=train_dataset_cfg, dataloader_cfg=dataloader_config, tokenizer=tokenizer -) - -# 5. 
Train worker -model_cfg = get_model_config_from_hf(Path(model_path)) -optim_cfg = AdamWConfig(lr=1e-6, betas=(0.9, 0.999), max_grad_norm=1.0, weight_decay=0.1, foreach=False) -loss_cfg = GRPOLossConfig( - policy_loss_cfg=dict( - cliprange_high=0.28, - cliprange_low=0.2, - loss_type="vanilla", - clip_ratio_c=10.0, - log_prob_diff_min=-20.0, - log_prob_diff_max=20.0, - ), - ignore_idx=-100, - use_kl_loss=False, - kl_loss_coef=0.0, - kl_loss_type="low_var_kl", - mode="chunk", - chunk_size=512, -) -lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) -fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) -train_worker_cfg: WorkerConfig = WorkerConfig( - model_cfg=model_cfg, - load_from=model_path, - optim_cfg=optim_cfg, - loss_cfg=loss_cfg, - lr_cfg=lr_cfg, - fsdp_cfg=fsdp_cfg, - sp_size=1, - optimizer_steps=train_optimizer_steps, - pack_max_length=pack_max_length, -) - -# 6. RL Trainer -trainer = RLTrainerConfig( - load_from=model_path, - resources=resources, - rollout_config=rollout_config, - dataflow_config=dataflow_config, - judger_config=judger_cfg, - replay_buffer_config=replay_buffer_cfg, - evaluator_config=evaluator_cfg, - train_worker_config=train_worker_cfg, - tokenizer_path=model_path, - work_dir=work_dir, - total_epochs=total_epochs, - hf_interval=hf_interval, -) diff --git a/examples/v1/config/rl_qwen25_7B_dapo_async.py b/examples/v1/config/rl_qwen25_7B_dapo_async.py deleted file mode 100644 index 95ec88a81b..0000000000 --- a/examples/v1/config/rl_qwen25_7B_dapo_async.py +++ /dev/null @@ -1,212 +0,0 @@ -import os -from copy import deepcopy - -from transformers import AutoTokenizer -from xtuner.v1.config import ( - AdamWConfig, - FSDPConfig, - LRConfig, -) -from xtuner.v1.data_proto.rl_data import SampleParams -from xtuner.v1.datasets import RLTokenizeFnConfig -from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig -from xtuner.v1.model.dense.qwen2 import Qwen2Dense7BConfig -from xtuner.v1.ray.base import 
AcceleratorResourcesConfig -from xtuner.v1.ray.config.worker import RolloutConfig -from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig -from xtuner.v1.ray.evaluator import EvaluatorConfig -from xtuner.v1.ray.judger.controller import JudgerConfig -from xtuner.v1.ray.judger.dapo_math import DapoMathJudgerConfig -from xtuner.v1.rl.base import WorkerConfig -from xtuner.v1.rl.grpo import GRPOLossConfig -from xtuner.v1.train.rl_trainer import RLTrainerConfig -from xtuner.v1.rl.base.rollout_is import RolloutImportanceSampling -from xtuner.v1.model import get_model_config_from_hf - - -work_dir = os.environ["WORK_DIR"] -model_path = os.environ["MODEL_PATH"] -data_path = os.environ["DATA_PATH"] -eval_data_path = os.environ.get("EVAL_DATA_PATH") -enable_evaluate = True if eval_data_path != "" else False -global_batch_size = int(os.environ.get("GLOBAL_BATCH_SIZE", "16")) -enable_return_routed_experts = 0 -enbale_partial_rollout = 1 -staleness_threshold = 0.2 -tail_batch_candidate_steps = 2 -tail_batch_trigger_size = global_batch_size -max_response_length= 8192 -enable_float8_rollout = 0 - -# basic settings -experimental_name = "dapo_math" -total_epochs = 1 -prompt_repeat_k = 16 -rollout_tp_size = 1 -rollout_ep_size = 1 -max_prompt_length = 2048 -pack_max_length = 32768 -train_optimizer_steps = 16 -hf_interval = 50 -enable_initial_evaluate = True -evaluate_step = 5 - -# 1. resources -resources = AcceleratorResourcesConfig( - accelerator="GPU", - num_workers=8, - num_cpus_per_worker=12, - cpu_memory_per_worker=16 * 1024**3, # 16 GB -) - -# 2. 
rollout -rollout_config = RolloutConfig( - env=experimental_name, - device=resources.accelerator, - model_path=model_path, - dtype="bfloat16", - tensor_parallel_size=rollout_tp_size, - expert_parallel_size=rollout_ep_size, - gpu_memory_utilization=0.8, - enable_float8=enable_float8_rollout, - context_length = max_response_length + max_prompt_length, - rollout_max_batch_size_per_instance=512, - allow_over_concurrency_ratio=4, - rollout_timeout=7200.0, - enable_return_routed_experts=enable_return_routed_experts, - extra_rollout_config=dict(lmdeploy_log_level="ERROR", lmdeploy_uvicorn_log_level="ERROR"), -) - -# sampling params -training_sample_params = SampleParams( - max_tokens=max_response_length, - top_k=0, - top_p=1.0, - temperature=1.0, - min_tokens=0, -) -evaluation_sample_params = deepcopy(training_sample_params) -evaluation_sample_params.top_p = 0.7 - -# dataset -train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path) -eval_dataset = DatasetConfig(name=experimental_name, anno_path=eval_data_path) if enable_evaluate else None -tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) -tokenizer_config = RLTokenizeFnConfig(max_length=max_prompt_length) - -train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}] -eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}] if enable_evaluate else [] - -dataloader_config = DataloaderConfig(pack_max_length=pack_max_length, collator="fake_collator", pack_level="none") - -# 3. 
judger -from xtuner.v1.utils.rl_test_utils import get_eos_token -eos_token_id = get_eos_token(model_path) -eos_token_str = tokenizer.convert_ids_to_tokens(eos_token_id) -dapomath_judger_config = DapoMathJudgerConfig( - judger_name="dapo_math", - eos_token=eos_token_str, - enable_overlong_buffer = True, - max_response_len =max_response_length, - overlong_buffer_len=4096, - overlong_penalty_factor=1.0, - tokenizer=tokenizer) -judger_cfg = JudgerConfig(reward_judger_configs=[dapomath_judger_config]) - -# 4. dataflow and evaluator -dataflow_config = DataFlowConfig( - env=experimental_name, - prompt_repeat_k=prompt_repeat_k, - global_batch_size=global_batch_size, - sample_params=training_sample_params, - enable_partial_rollout=enbale_partial_rollout, - staleness_threshold=staleness_threshold, - tail_batch_candidate_steps=tail_batch_candidate_steps, - tail_batch_trigger_size=tail_batch_trigger_size -) - -def dapo_compute_metric(samples): - return {"accuracy": sum(s.env.judger.reward["acc"] > 0 for s in samples) / len(samples)} - - -evaluator_cfg = EvaluatorConfig( - enable_evaluate=enable_evaluate, - enable_initial_evaluate=enable_initial_evaluate, - dataset_cfg=eval_dataset_cfg, - tokenizer=tokenizer, - evaluate_step=evaluate_step, - compute_metric_func=dapo_compute_metric, - sample_params=evaluation_sample_params, - max_concurrent=1024, -) if enable_evaluate else None - -def group_sample_filter_func(group_samples): - rewards = [d.env.judger.reward["score"] for d in group_samples] - if len(set(rewards)) == 1: - print(f"filter all same reward sample: {rewards}") - return [] - else: - return group_samples - -replay_buffer_cfg = ReplayBufferConfig( - dataset_cfg=train_dataset_cfg, - dataloader_cfg=dataloader_config, - tokenizer=tokenizer, - # postprocessor_func=group_sample_filter_func -) - -# 5. 
Train worker -model_cfg = Qwen2Dense7BConfig() -optim_cfg = AdamWConfig(lr=1e-6, betas=(0.9, 0.999), max_grad_norm=1.0, weight_decay=0.1, foreach=False) -loss_cfg = GRPOLossConfig( - policy_loss_cfg=dict( - cliprange_high=0.28, - cliprange_low=0.2, - loss_type="vanilla", - clip_ratio_c=10.0, - log_prob_diff_min=-20.0, - log_prob_diff_max=20.0, - ), - ignore_idx=-100, - use_kl_loss=False, - kl_loss_coef=0.0, - kl_loss_type="low_var_kl", - mode="chunk", - chunk_size=512, - rollout_is=RolloutImportanceSampling( - rollout_is_level="token", - rollout_is_mode="both", - rollout_is_threshold=(5, 0.5), - rollout_is_mask_threshold=(5, 0.5), - rollout_is_veto_threshold=(20, 0), - ), -) -lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) -fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) -train_worker_cfg: WorkerConfig = WorkerConfig( - model_cfg=model_cfg, - load_from=model_path, - optim_cfg=optim_cfg, - loss_cfg=loss_cfg, - lr_cfg=lr_cfg, - fsdp_cfg=fsdp_cfg, - sp_size=1, - optimizer_steps=train_optimizer_steps, - pack_max_length=pack_max_length, -) - -# 6. 
RL Trainer -trainer = RLTrainerConfig( - load_from=model_path, - resources=resources, - rollout_config=rollout_config, - dataflow_config=dataflow_config, - judger_config=judger_cfg, - replay_buffer_config=replay_buffer_cfg, - evaluator_config=evaluator_cfg, - train_worker_config=train_worker_cfg, - tokenizer_path=model_path, - work_dir=work_dir, - total_epochs=total_epochs, - hf_interval=hf_interval, -) diff --git a/examples/v1/config/rl_qwen3_30B_dapo.py b/examples/v1/config/rl_qwen3_30B_dapo.py deleted file mode 100644 index cf95b6a1ca..0000000000 --- a/examples/v1/config/rl_qwen3_30B_dapo.py +++ /dev/null @@ -1,179 +0,0 @@ -import os -from copy import deepcopy -from pathlib import Path -from transformers import AutoTokenizer -from xtuner.v1.config import ( - AdamWConfig, - FSDPConfig, - LRConfig, -) -from xtuner.v1.data_proto.rl_data import SampleParams -from xtuner.v1.datasets import RLTokenizeFnConfig -from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig -from xtuner.v1.model.moe.qwen3 import Qwen3MoE30BA3Config -from xtuner.v1.ray.base import AcceleratorResourcesConfig -from xtuner.v1.ray.config.worker import RolloutConfig -from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig -from xtuner.v1.ray.evaluator import EvaluatorConfig -from xtuner.v1.ray.judger.controller import JudgerConfig -from xtuner.v1.ray.judger.dapo_math import DapoMathJudgerConfig -from xtuner.v1.rl.base import WorkerConfig -from xtuner.v1.rl.grpo import GRPOLossConfig -from xtuner.v1.train.rl_trainer import RLTrainerConfig -from xtuner.v1.model import get_model_config_from_hf - -work_dir = os.environ["WORK_DIR"] -model_path = os.environ["MODEL_PATH"] -data_path = os.environ["DATA_PATH"] -eval_data_path = os.environ["EVAL_DATA_PATH"] -enable_return_routed_experts = os.environ.get("ENABLE_RETURN_ROUTED_EXPERTS", '0') -enable_evaluate = True if eval_data_path != "" else False - -# basic settings -experimental_name = "dapo_math" -total_epochs = 1 
-global_batch_size = 512 -prompt_repeat_k = 16 -rollout_tp_size = 2 -rollout_ep_size = 1 -max_prompt_length = 2048 -max_response_length = 8192 -pack_max_length = 32768 -train_optimizer_steps = 16 -hf_interval = 50 -enable_initial_evaluate = True -evaluate_step = 5 - -# 1. resources -resources = AcceleratorResourcesConfig( - accelerator="GPU", - num_workers=8, - num_cpus_per_worker=12, - cpu_memory_per_worker=16 * 1024**3, # 16 GB -) - -# 2. rollout -rollout_config = RolloutConfig( - env=experimental_name, - device=resources.accelerator, - model_path=model_path, - dtype="bfloat16", - tensor_parallel_size=rollout_tp_size, - expert_parallel_size=rollout_ep_size, - gpu_memory_utilization=0.8, - context_length = max_response_length + max_prompt_length, - # rollout_max_batch_size_per_instance=512, - enable_return_routed_experts=True if enable_return_routed_experts == "1" else False, -) - -# sampling params -training_sample_params = SampleParams( - max_tokens=max_response_length, - top_k=0, - top_p=1.0, - temperature=1.0, - min_tokens=0, -) -evaluation_sample_params = deepcopy(training_sample_params) -evaluation_sample_params.top_p = 0.7 - -# dataset -train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path) -eval_dataset = DatasetConfig(name=experimental_name, anno_path=eval_data_path) if enable_evaluate else None -tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) -tokenizer_config = RLTokenizeFnConfig(max_length=max_prompt_length) - -train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}] -eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}] if enable_evaluate else [] - -dataloader_config = DataloaderConfig(pack_max_length=pack_max_length, collator="fake_collator", pack_level="none") - -# 3. 
judger -from xtuner.v1.utils.rl_test_utils import get_eos_token -eos_token_id = get_eos_token(model_path) -eos_token_str = tokenizer.convert_ids_to_tokens(eos_token_id) -dapomath_judger_config = DapoMathJudgerConfig( - judger_name="dapo_math", - eos_token=eos_token_str, - enable_overlong_buffer = True, - max_response_len =max_response_length, - overlong_buffer_len=4096, - overlong_penalty_factor=1.0, - tokenizer=tokenizer) -judger_cfg = JudgerConfig(reward_judger_configs=[dapomath_judger_config]) - -# 4. dataflow and evaluator -dataflow_config = DataFlowConfig( - env=experimental_name, - prompt_repeat_k=prompt_repeat_k, - global_batch_size=global_batch_size, - sample_params=training_sample_params, -) - - -def dapo_compute_metric(samples): - return {"accuracy": sum(s.env.judger.reward["acc"] > 0 for s in samples) / len(samples)} - - -evaluator_cfg = EvaluatorConfig( - enable_evaluate=enable_evaluate, - enable_initial_evaluate=enable_initial_evaluate, - dataset_cfg=eval_dataset_cfg, - tokenizer=tokenizer, - evaluate_step=evaluate_step, - compute_metric_func=dapo_compute_metric, - sample_params=evaluation_sample_params, -) if enable_evaluate else None - -replay_buffer_cfg = ReplayBufferConfig( - dataset_cfg=train_dataset_cfg, dataloader_cfg=dataloader_config, tokenizer=tokenizer -) - -# 5. 
Train worker -model_cfg = get_model_config_from_hf(Path(model_path)) -optim_cfg = AdamWConfig(lr=1e-6, betas=(0.9, 0.999), max_grad_norm=1.0, weight_decay=0.1, foreach=False) -loss_cfg = GRPOLossConfig( - policy_loss_cfg=dict( - cliprange_high=0.28, - cliprange_low=0.2, - loss_type="vanilla", - clip_ratio_c=10.0, - log_prob_diff_min=-20.0, - log_prob_diff_max=20.0, - ), - ignore_idx=-100, - use_kl_loss=False, - kl_loss_coef=0.0, - kl_loss_type="low_var_kl", - mode="chunk", - chunk_size=512, -) -lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) -fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) -train_worker_cfg: WorkerConfig = WorkerConfig( - model_cfg=model_cfg, - load_from=model_path, - optim_cfg=optim_cfg, - loss_cfg=loss_cfg, - lr_cfg=lr_cfg, - fsdp_cfg=fsdp_cfg, - sp_size=1, - optimizer_steps=train_optimizer_steps, - pack_max_length=pack_max_length, -) - -# 6. RL Trainer -trainer = RLTrainerConfig( - load_from=model_path, - resources=resources, - rollout_config=rollout_config, - dataflow_config=dataflow_config, - judger_config=judger_cfg, - replay_buffer_config=replay_buffer_cfg, - evaluator_config=evaluator_cfg, - train_worker_config=train_worker_cfg, - tokenizer_path=model_path, - work_dir=work_dir, - total_epochs=total_epochs, - hf_interval=hf_interval, -) diff --git a/examples/v1/config/rl_qwen3_30B_dapo_async.py b/examples/v1/config/rl_qwen3_30B_dapo_async.py deleted file mode 100644 index fdfa6826cb..0000000000 --- a/examples/v1/config/rl_qwen3_30B_dapo_async.py +++ /dev/null @@ -1,213 +0,0 @@ -import os -from copy import deepcopy - -from transformers import AutoTokenizer -from xtuner.v1.config import ( - AdamWConfig, - FSDPConfig, - LRConfig, -) -from xtuner.v1.data_proto.rl_data import SampleParams -from xtuner.v1.datasets import RLTokenizeFnConfig -from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig -from xtuner.v1.model.moe.qwen3 import Qwen3MoE30BA3Config -from xtuner.v1.ray.base import 
AcceleratorResourcesConfig -from xtuner.v1.ray.config.worker import RolloutConfig -from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig -from xtuner.v1.ray.evaluator import EvaluatorConfig -from xtuner.v1.ray.judger.controller import JudgerConfig -from xtuner.v1.ray.judger.dapo_math import DapoMathJudgerConfig -from xtuner.v1.rl.base import WorkerConfig -from xtuner.v1.rl.grpo import GRPOLossConfig -from xtuner.v1.train.rl_trainer import RLTrainerConfig -from xtuner.v1.rl.base.rollout_is import RolloutImportanceSampling -from xtuner.v1.model import get_model_config_from_hf - - - -work_dir = os.environ["WORK_DIR"] -model_path = os.environ["MODEL_PATH"] -data_path = os.environ["DATA_PATH"] -eval_data_path = os.environ.get("EVAL_DATA_PATH") -enable_evaluate = True if eval_data_path != "" else False -global_batch_size = int(os.environ.get("GLOBAL_BATCH_SIZE", "16")) -enable_return_routed_experts = 1 -enbale_partial_rollout = 1 -staleness_threshold = 0.2 -tail_batch_candidate_steps = 2 -tail_batch_trigger_size = global_batch_size -max_response_length= 8192 -enable_float8_rollout = 0 - -# basic settings -experimental_name = "dapo_math" -total_epochs = 1 -prompt_repeat_k = 16 -rollout_tp_size = 1 -rollout_ep_size = 1 -max_prompt_length = 2048 -pack_max_length = 32768 -train_optimizer_steps = 16 -hf_interval = 50 -enable_initial_evaluate = True -evaluate_step = 5 - -# 1. resources -resources = AcceleratorResourcesConfig( - accelerator="GPU", - num_workers=8, - num_cpus_per_worker=12, - cpu_memory_per_worker=16 * 1024**3, # 16 GB -) - -# 2. 
rollout -rollout_config = RolloutConfig( - env=experimental_name, - device=resources.accelerator, - model_path=model_path, - dtype="bfloat16", - tensor_parallel_size=rollout_tp_size, - expert_parallel_size=rollout_ep_size, - gpu_memory_utilization=0.8, - enable_float8=enable_float8_rollout, - context_length = max_response_length + max_prompt_length, - rollout_max_batch_size_per_instance=512, - allow_over_concurrency_ratio=4, - rollout_timeout=7200.0, - enable_return_routed_experts=enable_return_routed_experts, - extra_rollout_config=dict(lmdeploy_log_level="ERROR", lmdeploy_uvicorn_log_level="ERROR"), -) - -# sampling params -training_sample_params = SampleParams( - max_tokens=max_response_length, - top_k=0, - top_p=1.0, - temperature=1.0, - min_tokens=0, -) -evaluation_sample_params = deepcopy(training_sample_params) -evaluation_sample_params.top_p = 0.7 - -# dataset -train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path) -eval_dataset = DatasetConfig(name=experimental_name, anno_path=eval_data_path) if enable_evaluate else None -tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) -tokenizer_config = RLTokenizeFnConfig(max_length=max_prompt_length) - -train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}] -eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}] if enable_evaluate else [] - -dataloader_config = DataloaderConfig(pack_max_length=pack_max_length, collator="fake_collator", pack_level="none") - -# 3. 
judger -from xtuner.v1.utils.rl_test_utils import get_eos_token -eos_token_id = get_eos_token(model_path) -eos_token_str = tokenizer.convert_ids_to_tokens(eos_token_id) -dapomath_judger_config = DapoMathJudgerConfig( - judger_name="dapo_math", - eos_token=eos_token_str, - enable_overlong_buffer = True, - max_response_len =max_response_length, - overlong_buffer_len=4096, - overlong_penalty_factor=1.0, - tokenizer=tokenizer) -judger_cfg = JudgerConfig(reward_judger_configs=[dapomath_judger_config]) - -# 4. dataflow and evaluator -dataflow_config = DataFlowConfig( - env=experimental_name, - prompt_repeat_k=prompt_repeat_k, - global_batch_size=global_batch_size, - sample_params=training_sample_params, - enable_partial_rollout=enbale_partial_rollout, - staleness_threshold=staleness_threshold, - tail_batch_candidate_steps=tail_batch_candidate_steps, - tail_batch_trigger_size=tail_batch_trigger_size -) - -def dapo_compute_metric(samples): - return {"accuracy": sum(s.env.judger.reward["acc"] > 0 for s in samples) / len(samples)} - - -evaluator_cfg = EvaluatorConfig( - enable_evaluate=enable_evaluate, - enable_initial_evaluate=enable_initial_evaluate, - dataset_cfg=eval_dataset_cfg, - tokenizer=tokenizer, - evaluate_step=evaluate_step, - compute_metric_func=dapo_compute_metric, - sample_params=evaluation_sample_params, - max_concurrent=1024, -) if enable_evaluate else None - -def group_sample_filter_func(group_samples): - rewards = [d.env.judger.reward["score"] for d in group_samples] - if len(set(rewards)) == 1: - print(f"filter all same reward sample: {rewards}") - return [] - else: - return group_samples - -replay_buffer_cfg = ReplayBufferConfig( - dataset_cfg=train_dataset_cfg, - dataloader_cfg=dataloader_config, - tokenizer=tokenizer, - # postprocessor_func=group_sample_filter_func -) - -# 5. 
Train worker -model_cfg = Qwen3MoE30BA3Config(freeze_routers=True) -optim_cfg = AdamWConfig(lr=1e-6, betas=(0.9, 0.95), max_grad_norm=1.0, weight_decay=0.1, foreach=False, skip_grad_norm_threshold=0.9, eps=1e-15) -loss_cfg = GRPOLossConfig( - policy_loss_cfg=dict( - cliprange_high=0.28, - cliprange_low=0.2, - loss_type="vanilla", - clip_ratio_c=10.0, - log_prob_diff_min=-20.0, - log_prob_diff_max=20.0, - ), - ignore_idx=-100, - use_kl_loss=False, - kl_loss_coef=0.0, - kl_loss_type="low_var_kl", - mode="chunk", - chunk_size=512, - rollout_is=RolloutImportanceSampling( - rollout_is_level="token", - rollout_is_mode="both", - rollout_is_threshold=(5, 0.5), - rollout_is_mask_threshold=(5, 0.5), - rollout_is_veto_threshold=(20, 0), - ), -) -lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) -fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) -train_worker_cfg: WorkerConfig = WorkerConfig( - model_cfg=model_cfg, - load_from=model_path, - optim_cfg=optim_cfg, - loss_cfg=loss_cfg, - lr_cfg=lr_cfg, - fsdp_cfg=fsdp_cfg, - sp_size=1, - optimizer_steps=train_optimizer_steps, - pack_max_length=pack_max_length, -) - -# 6. 
RL Trainer -trainer = RLTrainerConfig( - load_from=model_path, - resources=resources, - rollout_config=rollout_config, - dataflow_config=dataflow_config, - judger_config=judger_cfg, - replay_buffer_config=replay_buffer_cfg, - evaluator_config=evaluator_cfg, - train_worker_config=train_worker_cfg, - tokenizer_path=model_path, - work_dir=work_dir, - total_epochs=total_epochs, - hf_interval=hf_interval, -) diff --git a/examples/v1/config/rl_qwen3_30B_grpo.py b/examples/v1/config/rl_qwen3_30B_grpo.py deleted file mode 100644 index 02c42d14bb..0000000000 --- a/examples/v1/config/rl_qwen3_30B_grpo.py +++ /dev/null @@ -1,174 +0,0 @@ -import os -from copy import deepcopy -from pathlib import Path -from transformers import AutoTokenizer -from xtuner.v1.config import ( - AdamWConfig, - FSDPConfig, - LRConfig, -) -from xtuner.v1.data_proto.rl_data import SampleParams -from xtuner.v1.datasets import RLTokenizeFnConfig -from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig -from xtuner.v1.model.moe.qwen3 import Qwen3MoE30BA3Config -from xtuner.v1.ray.base import AcceleratorResourcesConfig -from xtuner.v1.ray.config.worker import RolloutConfig -from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig -from xtuner.v1.ray.evaluator import EvaluatorConfig -from xtuner.v1.ray.judger.controller import JudgerConfig -from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig -from xtuner.v1.rl.base import WorkerConfig -from xtuner.v1.rl.grpo import GRPOLossConfig -from xtuner.v1.train.rl_trainer import RLTrainerConfig -from xtuner.v1.model import get_model_config_from_hf - -work_dir = os.environ["WORK_DIR"] -model_path = os.environ["MODEL_PATH"] -data_path = os.environ["DATA_PATH"] -eval_data_path = os.environ["EVAL_DATA_PATH"] -enable_return_routed_experts = os.environ.get("ENABLE_RETURN_ROUTED_EXPERTS", '0') -enable_evaluate = True if eval_data_path != "" else False - -# basic settings -experimental_name = "grpo_gsm8k" -total_epochs = 15 -global_batch_size 
= 1024 -prompt_repeat_k = 5 -rollout_tp_size = 2 -rollout_ep_size = 1 -max_prompt_length = 512 -max_response_length = 1024 -pack_max_length = 32768 -train_optimizer_steps = 4 -hf_interval = 15 -enable_initial_evaluate = True -evaluate_step = 10 - -# grpo quick test settings -# total_epochs = 3 -# global_batch_size = 64 -# prompt_repeat_k = 5 -# rollout_tp_size = 1 -# rollout_ep_size = 1 -# max_prompt_length = 512 -# max_response_length = 1024 -# pack_max_length = 32768 -# train_optimizer_steps = 1 -# hf_interval = 100 -# enable_initial_evaluate = True -# evaluate_step = 15 - -# 1. resources -resources = AcceleratorResourcesConfig( - accelerator="GPU", - num_workers=8, - num_cpus_per_worker=12, - cpu_memory_per_worker=16 * 1024**3, # 16 GB -) - -# 2. rollout -rollout_config = RolloutConfig( - env=experimental_name, - device=resources.accelerator, - model_path=model_path, - dtype="bfloat16", - tensor_parallel_size=rollout_tp_size, - expert_parallel_size=rollout_ep_size, - gpu_memory_utilization=0.75, - context_length = max_response_length + max_prompt_length, - enable_return_routed_experts=True if enable_return_routed_experts == "1" else False, -) - -# sampling params -training_sample_params = SampleParams( - max_tokens=max_response_length, -) -evaluation_sample_params = deepcopy(training_sample_params) -evaluation_sample_params.top_p = 1.0 -evaluation_sample_params.temperature = 0.0 -evaluation_sample_params.top_k = 1 - -# dataset: 不需要修改 -train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path) -eval_dataset = DatasetConfig(name=experimental_name, anno_path=eval_data_path) if enable_evaluate else None -tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) -tokenizer_config = RLTokenizeFnConfig(max_length=max_prompt_length) - -train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}] -eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}] if enable_evaluate else [] - 
-dataloader_config = DataloaderConfig(pack_max_length=pack_max_length, collator="fake_collator", pack_level="none") - -# 3. judger -dapomath_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k") -judger_cfg = JudgerConfig(reward_judger_configs=[dapomath_judger_config]) - -# 4. dataflow and evaluator -dataflow_config = DataFlowConfig( - env=experimental_name, - prompt_repeat_k=prompt_repeat_k, - global_batch_size=global_batch_size, - sample_params=training_sample_params, -) - -evaluator_cfg = EvaluatorConfig( - enable_evaluate=enable_evaluate, - enable_initial_evaluate=enable_initial_evaluate, - dataset_cfg=eval_dataset_cfg, - tokenizer=tokenizer, - evaluate_step=evaluate_step, - compute_metric_func=None, - sample_params=evaluation_sample_params, -) if enable_evaluate else None - -# replay buffer config: : 不需要修改 -replay_buffer_cfg = ReplayBufferConfig( - dataset_cfg=train_dataset_cfg, dataloader_cfg=dataloader_config, tokenizer=tokenizer -) - -# 5. Train worker -# NOTE: modify model_cfg -model_cfg = get_model_config_from_hf(Path(model_path)) -optim_cfg = AdamWConfig(lr=1e-6, foreach=False) -loss_cfg = GRPOLossConfig( - policy_loss_cfg=dict( - cliprange_high=0.2, - cliprange_low=0.2, - loss_type="vanilla", - ), - ignore_idx=-100, - use_kl_loss=True, - kl_loss_coef=0.001, - kl_loss_type="low_var_kl", - mode="chunk", - chunk_size=512, -) -lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) -fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) -train_worker_cfg: WorkerConfig = WorkerConfig( - model_cfg=model_cfg, - load_from=model_path, - optim_cfg=optim_cfg, - loss_cfg=loss_cfg, - lr_cfg=lr_cfg, - fsdp_cfg=fsdp_cfg, - sp_size=1, - optimizer_steps=train_optimizer_steps, - pack_max_length=pack_max_length, -) - -# 6. 
RL Trainer -trainer = RLTrainerConfig( - load_from=model_path, - resources=resources, - rollout_config=rollout_config, - dataflow_config=dataflow_config, - judger_config=judger_cfg, - replay_buffer_config=replay_buffer_cfg, - evaluator_config=evaluator_cfg, - train_worker_config=train_worker_cfg, - tokenizer_path=model_path, - work_dir=work_dir, - total_epochs=total_epochs, - hf_interval=hf_interval, -) diff --git a/examples/v1/config/rl_qwen3_8B_grpo.py b/examples/v1/config/rl_qwen3_8B_grpo.py deleted file mode 100644 index a7246c9f7c..0000000000 --- a/examples/v1/config/rl_qwen3_8B_grpo.py +++ /dev/null @@ -1,182 +0,0 @@ -import os -from copy import deepcopy -from pathlib import Path -from transformers import AutoTokenizer -from xtuner.v1.config import ( - AdamWConfig, - FSDPConfig, - LRConfig, -) -from xtuner.v1.data_proto.rl_data import SampleParams -from xtuner.v1.datasets import RLTokenizeFnConfig -from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig -from xtuner.v1.model.dense.qwen3 import Qwen3Dense8BConfig -from xtuner.v1.ray.base import AcceleratorResourcesConfig -from xtuner.v1.ray.config.worker import RolloutConfig -from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig -from xtuner.v1.ray.evaluator import EvaluatorConfig -from xtuner.v1.ray.judger.controller import JudgerConfig -from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig -from xtuner.v1.rl.base import WorkerConfig -from xtuner.v1.rl.grpo import GRPOLossConfig -from xtuner.v1.train.rl_trainer import RLTrainerConfig -from xtuner.v1.model import get_model_config_from_hf - -work_dir = os.environ["WORK_DIR"] -model_path = os.environ["MODEL_PATH"] -data_path = os.environ["DATA_PATH"] -eval_data_path = os.environ["EVAL_DATA_PATH"] -enable_evaluate = True if eval_data_path != "" else False -enbale_partial_rollout = int(os.environ.get("ENBALE_PARTIAL_ROLLOUT", "0")) - -# basic settings -experimental_name = "grpo_gsm8k" -total_epochs = 15 -global_batch_size = 1024 
-prompt_repeat_k = 5 -rollout_tp_size = 2 -rollout_ep_size = 1 -max_prompt_length = 512 -max_response_length = 1024 -pack_max_length = 32768 -train_optimizer_steps = 4 -hf_interval = 15 -enable_initial_evaluate = True -evaluate_step = 10 -# TODO: 提供不同模型/不同输入输出长度下最优的rollout_max_batch_size_per_instance配置建议 -# NOTE: 目前Xtuner的数据流并发度由rollout_max_batch_size_per_instance控制,并且提供allow_over_concurrency_ratio来控制数据流并发度略大于推理引擎并发度, -# 具体逻辑可见 xtuner/v1/ray/dataflow/flow.py 中 max_concurrent 的计算方式 -# 当然你也可以手动调整 dataflow_config 中的 max_concurrent 参数来控制数据流并发度 -rollout_max_batch_size_per_instance = 128 - -# grpo quick test settings for rapid accuracy validation within ~30 minutes: -# - Initial eval accuracy: ~25% -# - After training: ~88% eval accuracy -# total_epochs = 3 -# global_batch_size = 64 -# prompt_repeat_k = 5 -# rollout_tp_size = 1 -# rollout_ep_size = 1 -# max_prompt_length = 512 -# max_response_length = 1024 -# pack_max_length = 32768 -# train_optimizer_steps = 1 -# hf_interval = 100 -# enable_initial_evaluate = True -# evaluate_step = 15 - -# 1. resources -resources = AcceleratorResourcesConfig( - accelerator="GPU", - num_workers=8, - num_cpus_per_worker=12, - cpu_memory_per_worker=16 * 1024**3, # 16 GB -) - -# 2. 
rollout -rollout_config = RolloutConfig( - env=experimental_name, - device=resources.accelerator, - model_path=model_path, - dtype="bfloat16", - tensor_parallel_size=rollout_tp_size, - expert_parallel_size=rollout_ep_size, - gpu_memory_utilization=0.75, - context_length = max_response_length + max_prompt_length, - # rollout_max_batch_size_per_instance=rollout_max_batch_size_per_instance, # optional, will be determined automatically if not set -) - -# sampling params -training_sample_params = SampleParams( - max_tokens=max_response_length, -) -evaluation_sample_params = deepcopy(training_sample_params) -evaluation_sample_params.top_p = 1.0 -evaluation_sample_params.temperature = 0.0 -evaluation_sample_params.top_k = 1 - -# dataset: 不需要修改 -train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path) -eval_dataset = DatasetConfig(name=experimental_name, anno_path=eval_data_path) if enable_evaluate else None -tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) -tokenizer_config = RLTokenizeFnConfig(max_length=max_prompt_length) - -train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}] -eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}] if enable_evaluate else [] - -dataloader_config = DataloaderConfig(pack_max_length=pack_max_length, collator="fake_collator", pack_level="none") - -# 3. judger -gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k") -judger_cfg = JudgerConfig(reward_judger_configs=[gsm8k_judger_config]) - -# 4. 
dataflow and evaluator -dataflow_config = DataFlowConfig( - env=experimental_name, - prompt_repeat_k=prompt_repeat_k, - global_batch_size=global_batch_size, - sample_params=training_sample_params, - enable_partial_rollout=enbale_partial_rollout, - # max_concurrent=64, # optional, will be determined automatically if not set -) - -evaluator_cfg = EvaluatorConfig( - enable_evaluate=enable_evaluate, - enable_initial_evaluate=enable_initial_evaluate, - dataset_cfg=eval_dataset_cfg, - tokenizer=tokenizer, - evaluate_step=evaluate_step, - compute_metric_func=None, - sample_params=evaluation_sample_params, -) if enable_evaluate else None - -# replay buffer config: : 不需要修改 -replay_buffer_cfg = ReplayBufferConfig( - dataset_cfg=train_dataset_cfg, dataloader_cfg=dataloader_config, tokenizer=tokenizer -) - -# 5. Train worker -model_cfg = get_model_config_from_hf(Path(model_path)) -optim_cfg = AdamWConfig(lr=1e-6, foreach=False) -loss_cfg = GRPOLossConfig( - policy_loss_cfg=dict( - cliprange_high=0.2, - cliprange_low=0.2, - loss_type="vanilla", - ), - ignore_idx=-100, - use_kl_loss=True, - kl_loss_coef=0.001, - kl_loss_type="low_var_kl", - mode="chunk", - chunk_size=512, -) -lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) -fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) -train_worker_cfg: WorkerConfig = WorkerConfig( - model_cfg=model_cfg, - load_from=model_path, - optim_cfg=optim_cfg, - loss_cfg=loss_cfg, - lr_cfg=lr_cfg, - fsdp_cfg=fsdp_cfg, - sp_size=1, - optimizer_steps=train_optimizer_steps, - pack_max_length=pack_max_length, -) - -# 6. 
RL Trainer -trainer = RLTrainerConfig( - load_from=model_path, - resources=resources, - rollout_config=rollout_config, - dataflow_config=dataflow_config, - judger_config=judger_cfg, - replay_buffer_config=replay_buffer_cfg, - evaluator_config=evaluator_cfg, - train_worker_config=train_worker_cfg, - tokenizer_path=model_path, - work_dir=work_dir, - total_epochs=total_epochs, - hf_interval=hf_interval, -) diff --git a/examples/v1/config/rl_qwen3_8B_grpo_tiny.py b/examples/v1/config/rl_qwen3_8B_grpo_tiny.py deleted file mode 100644 index 7097b3b9a0..0000000000 --- a/examples/v1/config/rl_qwen3_8B_grpo_tiny.py +++ /dev/null @@ -1,134 +0,0 @@ -import os -from copy import deepcopy -from pathlib import Path -from transformers import AutoTokenizer -from xtuner.v1.config import ( - AdamWConfig, - FSDPConfig, - LRConfig, -) -from xtuner.v1.data_proto.rl_data import SampleParams -from xtuner.v1.datasets import RLTokenizeFnConfig -from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig -from xtuner.v1.model.dense.qwen3 import Qwen3Dense8BConfig -from xtuner.v1.ray.base import AcceleratorResourcesConfig -from xtuner.v1.ray.config.worker import RolloutConfig -from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig -from xtuner.v1.ray.evaluator import EvaluatorConfig -from xtuner.v1.ray.judger.controller import JudgerConfig -from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig -from xtuner.v1.rl.base import WorkerConfig -from xtuner.v1.rl.grpo import GRPOLossConfig -from xtuner.v1.train.rl_trainer import RLTrainerConfig -from xtuner.v1.model import get_model_config_from_hf - -work_dir = os.environ["WORK_DIR"] -model_path = os.environ["MODEL_PATH"] -data_path = os.environ["DATA_PATH"] - -# basic settings -experimental_name = "grpo_gsm8k_tiny" -total_epochs = 1 -global_batch_size = 128 -prompt_repeat_k = 8 -rollout_tp_size = 1 -rollout_ep_size = 1 -max_prompt_length = 512 -max_response_length = 1024 -pack_max_length = 32768 -train_optimizer_steps = 1 
- -# 1. resources -resources = AcceleratorResourcesConfig( - accelerator="GPU", - num_workers=1, - num_cpus_per_worker=12, - cpu_memory_per_worker=16 * 1024**3, # 16 GB -) - -# 2. rollout -rollout_config = RolloutConfig( - env=experimental_name, - device=resources.accelerator, - model_path=model_path, - dtype="bfloat16", - tensor_parallel_size=rollout_tp_size, - expert_parallel_size=rollout_ep_size, - gpu_memory_utilization=0.75, - context_length=max_prompt_length+max_response_length, - # rollout_max_batch_size_per_instance=1024, # optional -) - -# sampling params -training_sample_params = SampleParams( - max_tokens=max_response_length, -) - -# dataset: 不需要修改 -tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) -tokenizer_config = RLTokenizeFnConfig(max_length=max_prompt_length) -train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path) -train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}] -dataloader_config = DataloaderConfig(pack_max_length=pack_max_length, collator="fake_collator", pack_level="none") - -# 3. judger -gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k") -judger_cfg = JudgerConfig(reward_judger_configs=[gsm8k_judger_config]) - -# 4. dataflow and evaluator -dataflow_config = DataFlowConfig( - env=experimental_name, - prompt_repeat_k=prompt_repeat_k, - global_batch_size=global_batch_size, - sample_params=training_sample_params, -) - -# replay buffer config: : 不需要修改 -replay_buffer_cfg = ReplayBufferConfig( - dataset_cfg=train_dataset_cfg, dataloader_cfg=dataloader_config, tokenizer=tokenizer -) - -# 5. 
Train worker -# NOTE: modify model_cfg -model_cfg = get_model_config_from_hf(Path(model_path)) -optim_cfg = AdamWConfig(lr=1e-6, foreach=False) -loss_cfg = GRPOLossConfig( - policy_loss_cfg=dict( - cliprange_high=0.2, - cliprange_low=0.2, - loss_type="vanilla", - ), - ignore_idx=-100, - use_kl_loss=True, - kl_loss_coef=0.001, - kl_loss_type="low_var_kl", - mode="chunk", - chunk_size=512, -) -lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) -fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) -train_worker_cfg: WorkerConfig = WorkerConfig( - model_cfg=model_cfg, - load_from=model_path, - optim_cfg=optim_cfg, - loss_cfg=loss_cfg, - lr_cfg=lr_cfg, - fsdp_cfg=fsdp_cfg, - sp_size=1, - optimizer_steps=train_optimizer_steps, - pack_max_length=pack_max_length, -) - -# 6. RL Trainer -trainer = RLTrainerConfig( - load_from=model_path, - resources=resources, - rollout_config=rollout_config, - dataflow_config=dataflow_config, - judger_config=judger_cfg, - replay_buffer_config=replay_buffer_cfg, - train_worker_config=train_worker_cfg, - tokenizer_path=model_path, - work_dir=work_dir, - total_epochs=total_epochs, -) diff --git a/examples/v1/config/rl_qwen3_vl_8B_grpo.py b/examples/v1/config/rl_qwen3_vl_8B_grpo.py deleted file mode 100644 index 92fd44761b..0000000000 --- a/examples/v1/config/rl_qwen3_vl_8B_grpo.py +++ /dev/null @@ -1,201 +0,0 @@ -import os -from copy import deepcopy - -from transformers import AutoTokenizer -from xtuner.v1.config import ( - AdamWConfig, - FSDPConfig, - LRConfig, -) -from xtuner.v1.data_proto.rl_data import SampleParams -from xtuner.v1.model.compose.qwen3_vl import Qwen3VLDense8BConfig -from xtuner.v1.ray.base import AcceleratorResourcesConfig -from xtuner.v1.ray.config.worker import RolloutConfig -from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig -from xtuner.v1.ray.evaluator import EvaluatorConfig -from xtuner.v1.ray.judger.controller import JudgerConfig -from xtuner.v1.rl.base import 
WorkerConfig -from xtuner.v1.rl.grpo import GRPOLossConfig -from xtuner.v1.train.rl_trainer import RLTrainerConfig -from xtuner.v1.datasets import RLTokenizeFnConfig, DatasetConfig, Qwen3VLTokenizeFnConfig, DataloaderConfig -from xtuner.v1.ray.judger.geo3k import GEO3KJudgerConfig - - -work_dir = os.environ["WORK_DIR"] -model_path = os.environ["MODEL_PATH"] -data_path = os.environ["DATA_PATH"] -eval_data_path = os.environ["EVAL_DATA_PATH"] -enable_evaluate = True if eval_data_path != "" else False -media_root = os.environ["MEDIA_ROOT"] - -# basic settings -experimental_name = "grpo_geo3k" -total_epochs = 15 -global_batch_size = 1024 -prompt_repeat_k = 5 -rollout_tp_size = 2 -rollout_ep_size = 1 -max_prompt_length = 1024 -max_response_length = 2048 -pack_max_length = 32768 -train_optimizer_steps = 4 -hf_interval = 15 -enable_initial_evaluate = True -evaluate_step = 10 - -# grpo quick test settings: -# total_epochs = 3 -# global_batch_size = 64 -# prompt_repeat_k = 5 -# rollout_tp_size = 1 -# rollout_ep_size = 1 -# max_prompt_length = 512 -# max_response_length = 1024 -# pack_max_length = 32768 -# train_optimizer_steps = 1 -# hf_interval = 100 -# enable_initial_evaluate = True -# evaluate_step = 15 - -# 1. resources -resources = AcceleratorResourcesConfig( - accelerator="GPU", - num_workers=8, - num_cpus_per_worker=12, - cpu_memory_per_worker=16 * 1024**3, # 16 GB -) - -# 2. 
rollout -rollout_config = RolloutConfig( - env=experimental_name, - device=resources.accelerator, - model_path=model_path, - dtype="bfloat16", - tensor_parallel_size=rollout_tp_size, - expert_parallel_size=rollout_ep_size, - gpu_memory_utilization=0.75, - context_length = max_response_length + max_prompt_length, - # rollout_max_batch_size_per_instance=64, # optional, will be determined automatically if not set -) - -# sampling params -training_sample_params = SampleParams( - max_tokens=max_response_length, -) -evaluation_sample_params = deepcopy(training_sample_params) -evaluation_sample_params.top_p = 1.0 -evaluation_sample_params.temperature = 0.0 -evaluation_sample_params.top_k = 1 - -# dataset: 不需要修改 -train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path) -eval_dataset = DatasetConfig(name=experimental_name, anno_path=eval_data_path) if enable_evaluate else None -tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) - -tokenize_fn_cfg = Qwen3VLTokenizeFnConfig(processor_path=model_path, max_length=max_prompt_length) -train_dataset_cfg = [ - { - "dataset": DatasetConfig(name="geo3k", - anno_path=data_path, - class_name='VLMJsonlDataset', - media_root=media_root, - sample_ratio=1.0), - "tokenize_fn": RLTokenizeFnConfig(max_length=max_prompt_length, - tokenize_fn_cfg=tokenize_fn_cfg), - } -] - -eval_dataset_cfg = [] -if enable_evaluate: - eval_dataset_cfg = [ - { - "dataset": DatasetConfig(name="geo3k", - anno_path=eval_data_path, - class_name='VLMJsonlDataset', - media_root=media_root, - sample_ratio=1.0), - "tokenize_fn": RLTokenizeFnConfig(max_length=max_prompt_length, - tokenize_fn_cfg=tokenize_fn_cfg, - ignore_multimodal_info=True), - } - ] - -dataloader_config = DataloaderConfig(num_workers=8, - collator="fake_collator", - pack_level="none") - -# 3. judger -geo3k_judger_config = GEO3KJudgerConfig() -judger_cfg = JudgerConfig(reward_judger_configs=[geo3k_judger_config]) - -# 4. 
dataflow and evaluator -dataflow_config = DataFlowConfig( - env=experimental_name, - prompt_repeat_k=prompt_repeat_k, - global_batch_size=global_batch_size, - sample_params=training_sample_params, - # max_concurrent=64, # optional, will be determined automatically if not set -) - -evaluator_cfg = EvaluatorConfig( - enable_evaluate=enable_evaluate, - enable_initial_evaluate=enable_initial_evaluate, - dataset_cfg=eval_dataset_cfg, - dataloader_cfg=dataloader_config, - tokenizer=tokenizer, - evaluate_step=evaluate_step, - compute_metric_func=None, - sample_params=evaluation_sample_params, -) if enable_evaluate else None - -# replay buffer config: : 不需要修改 -replay_buffer_cfg = ReplayBufferConfig( - dataset_cfg=train_dataset_cfg, dataloader_cfg=dataloader_config, tokenizer=tokenizer -) - -# 5. Train worker -# NOTE: modify model_cfg -model_cfg = Qwen3VLDense8BConfig(freeze_vision=True, freeze_projector=True) -optim_cfg = AdamWConfig(lr=1e-6, foreach=False) -loss_cfg = GRPOLossConfig( - policy_loss_cfg=dict( - cliprange_high=0.2, - cliprange_low=0.2, - loss_type="vanilla", - ), - ignore_idx=-100, - use_kl_loss=True, - kl_loss_coef=0.001, - kl_loss_type="low_var_kl", - mode="chunk", - chunk_size=512, -) -lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) -fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False) -train_worker_cfg: WorkerConfig = WorkerConfig( - model_cfg=model_cfg, - load_from=model_path, - optim_cfg=optim_cfg, - loss_cfg=loss_cfg, - lr_cfg=lr_cfg, - fsdp_cfg=fsdp_cfg, - sp_size=1, - optimizer_steps=train_optimizer_steps, - pack_max_length=pack_max_length, -) - -# 6. 
RL Trainer -trainer = RLTrainerConfig( - load_from=model_path, - resources=resources, - rollout_config=rollout_config, - dataflow_config=dataflow_config, - judger_config=judger_cfg, - replay_buffer_config=replay_buffer_cfg, - evaluator_config=evaluator_cfg, - train_worker_config=train_worker_cfg, - tokenizer_path=model_path, - work_dir=work_dir, - total_epochs=total_epochs, - hf_interval=hf_interval, -) diff --git a/examples/v1/scripts/run_rl.sh b/examples/v1/scripts/run_rl.sh index 3fe52ec44e..48eec6db54 100644 --- a/examples/v1/scripts/run_rl.sh +++ b/examples/v1/scripts/run_rl.sh @@ -23,6 +23,8 @@ else ACCELERATOR_PER_NODE=${7:-8} fi +ulimit -n 65536 # OSError: [Errno 24] Too many open files + export PYTHONPATH=$(pwd):$PYTHONPATH # ray 环境变量 diff --git a/examples/v1/scripts/run_rl_submit.sh b/examples/v1/scripts/run_rl_submit.sh index 4d268527d0..1fb09c644e 100644 --- a/examples/v1/scripts/run_rl_submit.sh +++ b/examples/v1/scripts/run_rl_submit.sh @@ -21,6 +21,8 @@ else ACCELERATOR_PER_NODE=${7:-8} fi +ulimit -n 65536 # OSError: [Errno 24] Too many open files + export PYTHONPATH=$(pwd):$PYTHONPATH # NOTE: if you add new env vars, please also add them to RUNTIME_ENV_JSON in step 4. 
# master 节点的 IP 地址 diff --git a/recipe/__init__.py b/recipe/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/recipe/rl_simulator/plan.md b/recipe/rl_simulator/plan.md new file mode 100644 index 0000000000..8b22289bfa --- /dev/null +++ b/recipe/rl_simulator/plan.md @@ -0,0 +1,891 @@ +# RL Simulator Recipe 计划 + +## 目标 + +在 `recipe/rl_simulator/` 下建设一套 RL 模拟验证 recipe,用于在不运行真实推理引擎和真实训练引擎的情况下,验证 XTuner RL 复杂控制流的正确性。 + +从 trainer 视角看,simulator 应该像正常 RL 训练任务一样运行: + +- fake rollout engine +- fake training engine +- fake weight sync +- fake judge +- recipe 专用 `SimulationSampler` +- 继续走现有 `RLColocateTrainer`、`AgentLoopManager`、`SingleTurnAgentLoop`、replay buffer、sync/async producer 路径 + +正确性不在 simulator 运行时通过断言判断,而是在运行结束后由每个 case 自己的离线 log analyzer 判断。 + +第一版只做单机 CPU-only 模拟。Ray 可以继续作为现有 `.remote()` actor 接口的本地调度机制,但不模拟真实多机、多 rank、GPU/NPU 训练。 + +## 为什么放在 `recipe/` + +这个 simulator 是实验性验证配方,不是稳定公共 API。 + +放在 `recipe/rl_simulator/` 下有几个好处: + +- 可以随着代码版本快速迭代 +- 允许每个 case 有自己的 `verify.py` 校验逻辑 +- 对 `xtuner/v1/rl` 主路径改动最少 +- 避免 fake/runtime 测试代码进入 RL 产品主接口 + +## 目录结构 + +```text +recipe/rl_simulator/ + README.md + plan.md + __init__.py + + core/ + __init__.py + deterministic.py + runner.py + + runtime/ + __init__.py + sampler.py + rollout.py + trainer.py + judger.py + + analyzer/ + __init__.py + cli.py + context.py + helpers.py + loaders.py + result.py + + cases/ + __init__.py + + smoke_sync/ + __init__.py + config.py + verify.py + + async_partial/ + __init__.py + config.py + verify.py + + multitask_weighted/ + __init__.py + config.py + verify.py + + fault_filtering/ + __init__.py + config.py + verify.py + + tail_expired/ + __init__.py + config.py + verify.py + + async_update_every_4/ + __init__.py + config.py + verify.py + + determinism_replay/ + __init__.py + config.py + verify.py +``` + +### 文件职责说明 + +`core/` 放运行 simulator 所需的通用基础能力,不直接实现 fake RL 组件: + +- `core/__init__.py`: 标记 Python package,并导出少量常用入口。 +- `core/deterministic.py`: 提供确定性随机工具。所有 response 
length、token、delay、fault、reward、fake loss 都从 `seed + component + task_name + uid + attempt + repeat_index` 派生,禁止依赖全局随机状态。 +- `core/runner.py`: simulator 运行入口。负责加载 case config、初始化 Ray、构建 trainer、调用 `trainer.fit()`,并可选在结束后调用 analyzer。 + +`runtime/` 放真正参与训练流程的 fake 组件: + +- `runtime/__init__.py`: 导出 fake runtime 配置和组件,方便 case config 统一 import。 +- `runtime/sampler.py`: 实现 `SimulationSamplerConfig` 和 `SimulationSampler`。它是第一版必选组件,用于稳定生成 uid、task_name、attempt 等 analyzer 需要的信息。 +- `runtime/rollout.py`: 实现 `FakeRolloutConfig`、`FakeRolloutController`、`DelaySpec`、`FaultSpec`。`FakeRolloutConfig.build()` 直接返回本地 Ray actor 形式的 fake rollout controller,不实现真实 rollout worker。 +- `runtime/trainer.py`: 实现 `FakeWorkerConfig`、`FakeTrainingController`、`FakeWeightSyncConfig`。`FakeWorkerConfig.build()` 直接返回本地 Ray actor 形式的 fake training controller,不再实现真实 `TrainingWorker` 或 fake model。 +- `runtime/judger.py`: 实现 `FakeJudgerConfig` 和 `FakeJudger`。负责给 completed rollout 生成确定性随机 `0/1` reward。 + +`analyzer/` 放离线分析的公共框架,不写具体 case 的业务规则: + +- `analyzer/__init__.py`: 导出 analyzer 公共 API。 +- `analyzer/cli.py`: 离线分析命令入口。构造 `AnalyzerContext`,加载 case 目录下的 `verify.py`,调用 `verify(ctx)` 并写出结果。 +- `analyzer/context.py`: 定义 `AnalyzerContext`,统一承载 case_dir、work_dir、从正常日志中提取出的 case 相关信息、tracker rows、train rollout 数据和原始 log 路径。 +- `analyzer/helpers.py`: 提供可复用的日志分组、failure 构造和基础断言 helper。helper 不包含具体 case 策略。 +- `analyzer/loaders.py`: 负责读取 `logs/rank_*.log`、`tracker.jsonl` 和 `train_rollout/*.jsonl`。其中关键 fake 日志从普通日志行里的 JSON 片段提取,使用 `json.loads()`;正则只用于非关键辅助信息。 +- `analyzer/result.py`: 定义 `CheckResult`、failure 记录格式,以及把 summary/failures 写到 `logs/simulation_analysis_summary.json` 和 `logs/simulation_analysis_failures.jsonl` 的工具。 + +## 主代码最小改动 + +主代码尽量少改,主要做 duck typing 兼容。 + +### `xtuner/v1/train/rl_colocate_trainer.py` + +放宽配置类型,让 recipe 里的 fake config 能传进来: + +- `RLColocateTrainerConfig.train_worker_cfg`: 当前是 Pydantic 字段 `WorkerConfig`,需要改成 `Any` 或 `WorkerConfig | FakeWorkerConfig`。为了保持 recipe 独立,优先改成 
`Any`,并在注释中说明要求实现 `build(placement_group)`。 +- `RLColocateTrainerConfig.rollout_config`: 当前是 Pydantic 字段 `RolloutConfig`,需要改成 `Any` 或 `RolloutConfig | FakeRolloutConfig`。优先改成 `Any`,要求实现 `build(placement_group)`。 +- `RLColocateTrainer.__init__` 的参数类型同步放宽,但运行逻辑不加 fake 分支。 + +注意:这里不能只改类型注解而不改 Pydantic 字段类型。`RLColocateTrainerConfig` 继承 `BaseModel`,字段类型如果仍是 `WorkerConfig` / `RolloutConfig`,Pydantic 会拒绝 recipe fake config,duck typing 不会自动生效。 + +当前 trainer 本身已经调用 `.build(...)`: + +```python +self.train_controller = train_worker_cfg.build(self._pg) +self.rollout_controller = rollout_config.build(self._pg) +``` + +所以字段类型放宽后,运行逻辑可以基本不变。 + +### `xtuner/v1/rl/agent_loop/agent_loop_manager.py` + +放宽 `TaskSpecConfig.sampler_config` 类型,让 recipe 里的 `SimulationSamplerConfig` 能通过 Pydantic 校验: + +```python +sampler_config: Any +``` + +并要求它实现: + +```python +build(tokenizer, replay_buffer) +``` + +不再改现有 `xtuner/v1/rl/agent_loop/sampler.py`。第一版直接使用 recipe 下的 `SimulationSampler`,避免为了 simulator 在主路径 sampler 中加入测试专用 uid 逻辑。 + +### Agent Loop 接入 + +第一版不新增 recipe 专用 agent loop,也不修改现有 `SingleTurnAgentLoop`。 + +继续直接使用现有 `SingleTurnAgentLoop`。`rollout_step` 仍由现有 producer/agent loop 链路内部使用,例如 partial rollout 的 `response_rollout_steps` 和 `seq_staleness` 后处理;fake rollout 不从 `RolloutState` 读取 `rollout_step`,也不把 `rollout_step` 放进确定性 key。 + +step 相关 correctness 通过正常日志和产物离线判断: + +- aborted/completed 发生在哪个 rollout step。 +- step 4 训练结束后 `weight_version` 是否变成 1。 +- step 5 的 completed 样本是否使用 version 1。 +- expired 样本的 `seq_staleness` 是否符合当前 step。 + +如果现有日志无法判断这些信息,再在 producer、trainer 或 agent loop 的事实发生位置补一条普通 logger 日志。不要为了提前传递 step 而新增 agent loop wrapper。 + +### Fake Judger + +优先让 `FakeJudgerConfig` 继承现有 `JudgerConfig`,并 override `build_local()`,这样不需要改 `judger/factory.py`。 + +## Runtime 组件 + +### 模型与 Tokenizer 路径 + +不实现 `fake_hf.py`。case config 需要直接提供一个现成的 HF 模型/tokenizer 路径,例如集群上已有的 Qwen3。 + +这些路径只用于满足现有 trainer、tokenizer 和 dataset tokenize 初始化需求;fake rollout 不加载真实推理引擎,fake trainer 不加载真实训练模型。 + +### 
Dataset 与 Sampler + +第一版所有 case 都使用 `SimulationSampler`。可以不依赖真实 dataset 内容;如果 case config 仍需要 dataset/tokenizer 占位,也可以共用同一套很小的 gsm8k/jsonl dataset config。 + +`SimulationSampler` 必须保持现有“被动采样”语义:producer 调用 `await sampler.sample(task_name=..., group_status=...)` 时,sampler 返回一个 `list[RolloutState]` group,而不是主动 push 到 replay buffer。 + +约定: + +- `prompt_repeat_k` 由 `SimulationSampler` 自己实现。 +- 对 fresh sample,`SimulationSampler` 直接构造最小合法 `RolloutState`,不需要真实 dataset 内容。 +- 每个 fresh group 的 uid、message_uid、attempt 必须可复现;同一 group 内 `prompt_repeat_k` 个样本共享同一个 uid/message_uid,但可以用 `extra_fields["repeat_index"]` 区分。 +- `SimulationSampler.sample(task_name=...)` 必须写入 `rollout_state.task_name = task_name` 和 `rollout_state.extra_fields["task_name"] = task_name`。 +- 对 `group_status=Status.ABORTED` 或 `Status.EXPIRED`,sampler 继续优先从 replay buffer 对应状态池取已有 group;如果没有可复用 group,再退回 fresh sample。 +- sampler 不直接判断 completed/failed 是否有效,也不决定是否进入训练;这些仍由 producer、replay buffer 和 trainer 的现有逻辑决定。 + +### Fake Rollout + +实现 agent loop 和 trainer 需要的 rollout controller 外部方法: + +- `generate` +- `pause_generation` +- `continue_generation` +- `offload` +- `onload` +- `onload_weights` +- `onload_kvcache` +- `get_rollout_metadata` +- `recover_failed_workers` +- `shutdown` +- `get_ready_status` + +行为要求: + +- response length 由 seed 和稳定 key 确定 +- 如果用户设置 `max_tokens=512`,生成长度在 `[1, 512]` +- delay 由配置确定,并且可复现 +- fault 行为由配置确定,并且可复现 +- pause 可以产生带 partial response ids 的 `Status.ABORTED` +- v1 暂不支持 routed experts + +`generate()` 不额外增加 `rollout_step` 参数,必须保持真实 rollout controller 的签名。 + +fake rollout 的确定性 key 不包含 `rollout_step`,而是使用: + +```text +seed + component + task_name + uid + attempt + repeat_index +``` + +其中 `task_name`、`attempt`、`repeat_index` 由 `SimulationSampler` 写入 `RolloutState.task_name` 和 `extra_fields`。partial/aborted/expired 复用同一个 uid 时,uid 保持不变,attempt 递增。 + +Fake rollout controller 的方法签名和返回值必须与现有调用点对齐: + +```python +class FakeRolloutController: + async def generate(self, rollout_state: 
RolloutState) -> RolloutState: ... + def pause_generation(self) -> None: ... + def continue_generation(self) -> None: ... + def offload(self) -> None: ... + def onload(self) -> None: ... + def onload_weights(self) -> None: ... + def onload_kvcache(self) -> None: ... + def recover_failed_workers(self) -> None: ... + def shutdown(self) -> None: ... + def get_ready_status(self) -> tuple[bool, dict[str, Any]]: ... + def get_rollout_metadata(self) -> dict[str, Any]: ... +``` + +`generate()` 返回的 `RolloutState` 至少要填充这些字段: + +- `status`: `Status.COMPLETED`、`Status.ABORTED` 或 `Status.FAILED` +- `response`: 非空字符串;aborted 时为 partial response +- `response_ids`: 与本次生成或 partial 结果一致的 token id list +- `logprobs`: 长度与 `response_ids` 一致 +- `response_mask`: 长度与 `response_ids` 一致 +- `finish_reason`: completed 时为 `stop` 或 `length`,aborted 时为 `abort`,failed 时为 `error` +- `error_msg`: failed 时必须非空 + +`get_rollout_metadata()` 返回的字段需要满足 `bind_train_rollout()` 和 fake trainer 消费: + +```python +{ + "engine_rank_mesh_array": [[0]], + "server_url_dict": {0: "fake://rollout/0"}, + "rollout_config": self.config, + "worker_server_urls_status": { + "fake://rollout/0": True, + }, + "api_server_url": None, +} +``` + +`get_ready_status()` 返回: + +```python +(True, {"active_workers": 1, "total_workers": 1}) +``` + +这里保持字段名和真实 `RolloutController.get_rollout_metadata()` 一致,避免 trainer/weight sync 逻辑因为 metadata shape 不一致而失败。 + +`FaultSpec` 至少支持按样本精确注入故障: + +```python +FaultSpec( + task="train_task:main", + uid=1005, + attempt=1, + status="failed", + error="timeout", +) +``` + +匹配规则:非 `None` 字段全部相等才触发。`status` 第一版支持 `failed` 和 `aborted`,分别模拟推理失败和 partial/被暂停样本。第一版 fault injection 不按 rollout step 匹配;如果后续确实需要按 step 注入,再先确认日志和调用链能稳定提供 step。 + +### Fake Judge + +返回确定性随机 reward: + +```python +{"score": 0} +``` + +或: + +```python +{"score": 1} +``` + +reward 必须能通过下面的 key 复现: + +```text +seed + task_name + uid + attempt + repeat_index +``` + +### Fake Trainer + +fake trainer 的边界是 training controller,不是 fake model。 
+ +保留 `FakeWorkerConfig` 的原因是 `RLColocateTrainer` 现有入口会调用: + +```python +self.train_controller = train_worker_cfg.build(self._pg) +``` + +但 `FakeWorkerConfig.build()` 不创建真实 `TrainingWorker`,而是直接返回本地 Ray actor 形式的 `FakeTrainingController`。这样可以跳过真实 packing、actor logprobs、loss ctx、engine train step、optimizer step 和真实 weight update,只保留 trainer loop 需要的 controller 接口。 + +`FakeTrainingController` 实现 `RLColocateTrainer` 会调用的方法: + +- `fit` +- `offload` +- `onload` +- `update_rollout_info` +- `update_weights` +- `save` +- `resume` +- `save_hf` +- `ready` + +`fit()` 根据配置 sleep 指定时间,然后返回合法的 `list[WorkerLogItem]`,让现有 `_log_step()` 和 `_log_mini_batch_metrics()` 可以继续正常工作。 + +最小 `WorkerLogItem` 结构来自 `xtuner/v1/rl/trainer/worker.py`: + +```python +{ + "train_entropy": 0.0, + "train_metrics": [ + { + "loss_log": {"loss": fake_loss}, + "rl_other_log": { + "step_consumed_tokens": num_tokens, + "efficient_attn_ratio": 1.0, + }, + } + ], + "sft_train_metrics": {}, +} +``` + +其中 `loss`、`num_tokens` 可以由 deterministic key 生成,或使用 case config 里的固定值。 + +### Fake Weight Sync + +维护一个 fake `weight_version`。 + +对于延迟参数更新场景,支持类似这样的调度: + +```python +weight_update_interval = 4 +``` + +这样 8 step 的 case 可以验证: + +- step 4 后更新一次 +- step 1、2、3 不更新 +- step 8 后再更新一次 +- step 5、6、7 不更新 +- rollout step 5 使用 version 1 + +### Fake Weight Sync 协调机制 + +`weight_version` 由 `FakeRolloutController` 持有,fake trainer 通过现有同步调用链更新它: + +1. `RLColocateTrainer._sync_weights_and_save()` 调用 `bind_train_rollout(train_controller, rollout_controller)`。 +2. `bind_train_rollout()` 通过 `rollout_controller.get_rollout_metadata()` 把 fake rollout metadata 传给 fake trainer。 +3. fake trainer 的 `update_rollout_info(info_dict)` 保存 fake rollout actor handle 或 metadata 中的 fake endpoint 信息。 +4. `train_controller.update_weights()` 被调用时,fake trainer 根据 `weight_update_interval` 判断当前 step 是否应该更新。 +5. 
如果应该更新,fake trainer 调用 fake rollout controller 的 `set_weight_version(version)` 或等价方法,把 rollout 侧版本推进。 + +对于 `async_update_every_4`: + +- step 1-3: `train_update_weights` 不推进版本,rollout 仍是 version 0。 +- step 4: step 4 的 rollout 先使用 version 0;step 4 train 结束后 sync,把 rollout version 推进到 1。 +- step 5: rollout 读取到 version 1。 +- step 8: step 8 train 结束后把 version 推进到 2。 + +如果为了少改 `RLColocateTrainer._sync_weights_and_save()`,也可以让 fake trainer 的 `update_weights()` 在非更新 step 直接 no-op 并打印普通日志说明 no-op;`verify.py` 负责确认只有 expected update step 出现 version 增长。 + +### Partial Rollout 与 Staleness 定义 + +v1 只按 rollout step 定义 staleness,不按 wall-clock time 定义。 + +语义沿用现有 `RolloutState.seq_staleness`: + +- 每段 response token 都有对应的 `response_rollout_steps`。 +- 当前 step 为 `cur_step` 时,`seq_staleness = cur_step - min(response_rollout_steps)`。 +- 如果样本在 step 1 生成了一段 partial response 并被 abort,step 2 继续生成时,这段历史 token 的 staleness 为 `2 - 1 = 1`。 +- `tail_batch_stale_threshold=N` 表示当 group 中存在 aborted/leftover sample 的 `seq_staleness >= N` 时,该 group 可以被标记为 `Status.EXPIRED`。 +- expired sample 再次被采样时,partial 历史必须清空,重新从 prompt 开始生成。 + +`tail_expired` 的 `verify.py` 只检查 step-based staleness: + +- `response_rollout_steps` 是否和生成 step 匹配。 +- `seq_staleness` 是否等于 `cur_step - min(response_rollout_steps)`。 +- 达到 threshold 的样本是否进入 expired/reset 路径。 + +## 日志与提取策略 + +不建设全局 event sink,也不要求 fake 组件打印特殊前缀日志。 + +simulator 按正常任务运行,优先复用现有日志和产物: + +```text +logs/rank_*.log +logs/exp_tracking/tracker.jsonl +train_rollout/train_rollout_*.jsonl +``` + +每个 case 的 `verify.py` 可以按自己的需求解析正常日志。允许不同 case 使用不同提取逻辑,但关键 fake 事件第一版就使用 JSON 片段,避免字段顺序、空格或措辞变化导致校验逻辑脆弱。 + +如果某个 `verify.py` 无法从现有日志判断正确性,再在最接近事实发生的位置补充一条普通 logger 日志。新增日志应满足: + +- 使用现有 logger 正常打印,不引入特殊前缀协议。 +- 从第一版开始,关键 fake 日志使用“普通说明 + JSON 片段”的形式,verify 使用 `json.loads()` 解析 JSON 片段,不依赖字段顺序、空格或自然语言。 +- 尽量包含 case 需要的 key 信息,例如 step、task、uid、attempt、status、reward、generated_tokens、weight_version。 +- 只在信息不够时新增,不预先大面积埋点。 + +例如 fake rollout 的正常日志: + +```text +Fake rollout done: 
{"step":1,"task":"train_task:main","uid":1001,"attempt":1,"status":"completed","generated_tokens":137,"max_tokens":512,"weight_version":0,"fault":null,"response_hash":"sha256:..."} +``` + +但是这仍然是普通日志,不定义 `[XTUNER_SIM]` 之类的全局特殊前缀。 + +第一版普通日志中的 JSON 片段建议统一包含 `event` 字段,便于公共 loader 提取成 `log_records`: + +```text +Fake rollout done: {"event":"rollout_done","step":1,"task":"train_task:main","uid":1001,"attempt":1,"status":"completed","generated_tokens":137,"max_tokens":512,"weight_version":0,"fault":null,"response_hash":"sha256:..."} +``` + +这不是特殊前缀协议;它只是正常 logger 文本里附带的机器可读片段。 + +第一版预期 `verify.py` 可能需要从日志中提取的信息包括: + +- sampler 实际采样的 step、task、uid/group、group_status、prompt_repeat_k +- rollout 完成时的 step、task、uid、attempt、status、generated_tokens、max_tokens、fault、weight_version +- fake judge 输出的 step、task、uid、attempt、reward +- fake trainer 完成的 step、num_groups、num_samples、loss +- fake weight sync 完成的 step、version_before、version_after + +### 日志合并策略 + +第一版 simulator 是单机 CPU-only,但 Ray actor logger 仍可能产生多个 `logs/rank_*.log`。Analyzer 读取所有 `logs/rank_*.log`,并提取其中包含 JSON 片段的 fake runtime 日志。 + +约定: + +- fake rollout、fake judge 可以在 actor 日志中打印,verify 必须按 `(event, step, task, uid, attempt)` 或 case 自己定义的 key 聚合。 +- fake trainer 的 `train_fit_done` 和 `weight_update_done` 默认只由 `FakeTrainingController` 打印一次。 +- 如果同一事件被重复打印,verify 必须显式去重,不能依赖文件读取顺序。 +- 单机 fake 运行时可能只有 `rank_0.log`,也可能因为 Ray actor logger 行为产生多个 log;loader 不假设固定文件数量。 +- loader 不按 wall-clock 合并日志;verify 用 step/task/uid/attempt/version 等业务字段判断。 + +### 确定性策略 + +即使第一版只做单机 CPU-only,Ray actor 调度顺序也不保证稳定,因此 determinism 不能依赖事件到达顺序。 + +约定: + +- 每个随机结果只由稳定 key 决定,不能由“第几个 actor 执行”或全局随机序列决定。 +- fake runtime 的随机 key 不包含 `rollout_step`,避免为了 simulator 修改 agent loop 或 producer。 +- analyzer 比较 determinism 时使用 normalized digest,忽略日志顺序、wall-clock timestamp、pid、hostname 和 log 文件名。 +- 如果某个 case 需要判断“哪个 uid 在第几步 completed / aborted”,verify 从 producer/agent loop 日志和 train rollout 产物中提取 step;如果信息不足,再补普通 logger 日志。 +- 默认 verify 应比较集合或 
multiset,而不是比较原始日志行顺序。 + +`determinism_replay` 只要求这些字段一致: + +- `(task, uid, attempt, step)` +- `status` +- `generated_tokens` +- `response_hash` +- `reward` +- `fault` +- `weight_version` + +不要求日志输出顺序完全一致。 + +## Analyzer 设计 + +每个 case 都有自己的 `verify.py`。期望值和校验逻辑放在同一个 Python 文件里,不再单独写 `expected.yaml`。 + +公共 analyzer 负责加载: + +- `logs/rank_*.log` +- `logs/exp_tracking/tracker.jsonl` +- `train_rollout/train_rollout_*.jsonl` +- `logs/simulation_run_manifest.json` + +公共 verify API: + +```python +def verify(ctx: AnalyzerContext) -> CheckResult: + ... +``` + +`AnalyzerContext`: + +```python +@dataclass +class AnalyzerContext: + case_dir: Path + work_dir: Path + log_records: list[dict] + tracker_rows: list[dict] + train_rollouts: dict[int, list[dict]] + run_manifest: dict + raw_log_paths: list[Path] +``` + +`CheckResult`: + +```python +@dataclass +class CheckResult: + case_name: str + passed: bool + summary: dict + failures: list[dict] +``` + +Analyzer 输出: + +```text +/logs/simulation_analysis_summary.json +/logs/simulation_analysis_failures.jsonl +``` + +## Case 扩展方式 + +这个框架的目标不是一次性写完所有 case,而是让后续新增 case 足够简单、可控、可 review。 + +新增一个 case 时,只需要新增一个目录: + +```text +recipe/rl_simulator/cases// + config.py + verify.py +``` + +两份文件职责固定: + +- `config.py`: 定义这个 case 怎么跑。包括 seed、rollout_steps、global_batch_size、fake rollout/fake trainer/fake judge/fake sync 配置、task 配置、producer 配置。 +- `verify.py`: 定义这个 case 怎么验。文件开头直接写期望常量,下面写提取和校验逻辑。 + +公共 analyzer 不理解 case 的业务含义,只做三件事: + +1. 加载日志和产物,构造 `AnalyzerContext`。 +2. import case 目录下的 `verify.py`。 +3. 
调用 `verify(ctx)`,写出 summary/failures。 + +这样后续新增 case 时,不需要改公共 analyzer;只有当公共 loader 缺少某种通用产物读取能力时,才扩展 analyzer。 + +## Verify 编写约定 + +每个 `verify.py` 都应该保持结构清晰,建议按四段写: + +```python +EXPECTED = { + "case_name": "async_update_every_4", + "update_steps": [4, 8], + "forbidden_update_steps": [1, 2, 3, 5, 6, 7], + "rollout_weight_version_by_step": { + 1: 0, + 2: 0, + 3: 0, + 4: 0, + 5: 1, + 6: 1, + 7: 1, + 8: 1, + }, +} + +def verify(ctx: AnalyzerContext) -> CheckResult: + extracted = extract(ctx) + failures = [] + + failures.extend(check_weight_updates(extracted, EXPECTED)) + failures.extend(check_rollout_versions(extracted, EXPECTED)) + + summary = build_summary(extracted, failures) + return CheckResult( + case_name=EXPECTED["case_name"], + passed=not failures, + summary=summary, + failures=failures, + ) +``` + +推荐每个 `verify.py` 至少拆出这些函数: + +- `extract(ctx)`: 只负责从 `ctx.log_records`、`ctx.tracker_rows`、`ctx.train_rollouts` 中提取当前 case 关心的信息,不做判断。 +- `check_*(extracted, expected)`: 每个函数只检查一类规则,例如 weight sync、partial rollout、fault filtering、task allocation。这里的 `expected` 通常就是本文件顶部的 `EXPECTED`。 +- `build_summary(extracted, failures)`: 输出人能读懂的结构化摘要,方便失败后 debug。 + +failure 记录建议统一成 dict: + +```python +{ + "check": "weight_update_steps", + "message": "unexpected weight update steps", + "expected": [4, 8], + "actual": [3, 8], + "step": 3, + "task": "train_task:main", + "uid": 1007, +} +``` + +字段不要求完全固定,但建议包含: + +- `check`: 失败的规则名 +- `message`: 人可读说明 +- `expected`: 期望值 +- `actual`: 实际值 +- `step/task/uid/attempt`: 如果能定位到具体样本,就尽量给出 + +## 公共 Verify Helper + +为了避免每个 case 重复写基础逻辑,可以在 analyzer 下提供 helper,但不强制使用: + +```text +recipe/rl_simulator/analyzer/helpers.py +``` + +初始 helper 可以包括: + +- `group_by_step(records)` +- `group_by_task(records)` +- `find_log_records(records, pattern_or_predicate)` +- `load_train_rollout_items(ctx, step)` +- `assert_exact_steps(actual, expected, check_name)` +- `assert_no_forbidden_steps(actual, forbidden, check_name)` +- `build_failure(check, message, 
expected=None, actual=None, **loc)` + +helper 只做通用数据处理和 failure 构造,不包含具体 case 策略。 + +## 推荐新增 Case 流程 + +1. 复制一个最接近的已有 case 目录。 +2. 修改 `config.py`,先让 case 能跑完。 +3. 修改 `verify.py` 文件顶部的 `EXPECTED`,写清楚这个 case 想证明什么。 +4. 写 `verify.py` 的 `extract()`,先生成 summary,不急着加很多 check。 +5. 看 summary 是否已经包含判断所需信息。 +6. 如果信息不够,再在 fake 组件或少量核心路径补普通 logger 日志。 +7. 补 `check_*()`,让 analyzer 能 pass/fail。 +8. 故意改错一个 `EXPECTED`,确认 verify 能 fail,并且 failure 信息能定位问题。 + +## Verify 不应该做的事 + +- 不应该依赖 wall-clock 精确时间,除非 case 明确测试耗时,并设置容忍区间。 +- 不应该解析高度易变的自然语言描述。 +- 不应该把多个无关规则塞进一个巨大判断里。 +- 不应该在 `verify.py` 里重新运行训练。 +- 不应该修改 work_dir 下的原始日志和训练产物。 + +## 运行与 CI 集成 + +runner 必须优先支持单 case 运行和 debug,然后再支持批量 CI。 + +常用命令: + +```text +# 跑单个 case,并在结束后分析 +python -m recipe.rl_simulator.core.runner --case recipe/rl_simulator/cases/smoke_sync --analyze + +# 跑单个 case,但不分析,方便直接看原始日志 +python -m recipe.rl_simulator.core.runner --case recipe/rl_simulator/cases/async_partial + +# 指定 work_dir,方便复现同一个失败 case +python -m recipe.rl_simulator.core.runner --case recipe/rl_simulator/cases/async_partial --work-dir /tmp/xtuner_sim_debug/async_partial --analyze + +# 只分析已有 work_dir,不重新运行 +python -m recipe.rl_simulator.core.runner --case recipe/rl_simulator/cases/async_partial --work-dir /tmp/xtuner_sim_debug/async_partial --analyze-only + +# 跑所有 case,用于本地回归或 CI +python -m recipe.rl_simulator.core.runner --case-dir recipe/rl_simulator/cases --all --analyze +``` + +`core/runner.py` 负责: + +1. 读取 case 目录下的 `config.py`。 +2. 解析命令行覆盖项,例如 `--work-dir`、`--seed`、`--keep-work-dir`。 +3. 写出 `logs/simulation_run_manifest.json`,记录 case、seed、work_dir、命令行参数和最终生效配置摘要。 +4. 初始化 Ray。 +5. 构建并运行 `trainer.fit()`。 +6. 如果指定 `--analyze` 或 `--analyze-only`,调用 analyzer。 +7. 
根据 analyzer 的 `CheckResult.passed` 返回 exit code。 + +runner 参数约定: + +- `--case `: 运行单个 case。 +- `--case-dir --all`: 顺序运行所有 case。 +- `--work-dir `: 指定输出目录;debug 时推荐固定这个路径。 +- `--analyze`: 运行结束后立即分析。 +- `--analyze-only`: 不运行训练,只分析已有 `--work-dir`。 +- `--seed `: 临时覆盖 case config 里的 seed,用于复现或缩小问题。 +- `--keep-work-dir`: 如果 `work_dir` 已存在,不删除旧内容;用于保留中间产物和手动对比日志。 + +debug 约定: + +- 每个 case 的 `config.py` 必须写默认 seed。 +- runner 默认使用 case config 中的 seed,不做随机覆盖。 +- 如果命令行传了 `--seed`,最终 seed 必须写入 `simulation_run_manifest.json`,analyzer summary 也要显示这个 seed。 +- 单 case debug 时,推荐固定 `--work-dir`;失败后可以用同一条命令复现,也可以用 `--analyze-only` 反复调 verify。 +- `--keep-work-dir` 只保证不删除旧目录;如果同名日志可能追加或混淆,runner 需要在 manifest 中记录本次 run id。 + +CI 可以先接入一个轻量脚本: + +```text +recipe/rl_simulator/run_all.sh +``` + +脚本只顺序运行初始 cases,任何 case analyzer 失败都返回非零。 + +`determinism_replay` 需要特殊处理: + +- runner 支持同一个 case 跑两次到不同 work_dir。 +- runner 固定生成 `run_1/` 和 `run_2/` 两个 work_dir,并写出 manifest。 +- verify 从 manifest 读取两次运行目录,不写死 `compare_work_dir`。 +- verify 比较两次 normalized digest,忽略 timestamp、pid、hostname、log 文件名和日志行顺序。 + +manifest 路径和格式: + +```text +/determinism_manifest.json +``` + +```json +{ + "case_name": "determinism_replay", + "runs": [ + {"name": "run_1", "work_dir": ".../run_1"}, + {"name": "run_2", "work_dir": ".../run_2"} + ] +} +``` + +第一阶段 CI 建议只跑 CPU-only fake cases,不依赖真实 GPU/NPU,不访问外网。 + +## 初始 Cases + +### `smoke_sync` + +目的: + +- 单 task +- sync producer +- 无故障 +- 每 step 都训练并同步权重 + +预期: + +- 所有 rollout group 都完成 +- 所有 group 都进入训练 +- 每 step 都发生 weight update +- 每个 group 恰好包含 `prompt_repeat_k` 个样本 +- reward 只可能是 `0/1` +- train rollout 中不包含 failed、aborted、expired、filtered 样本 + +### `async_partial` + +目的: + +- async producer +- oversampling +- partial rollout enabled + +预期: + +- 至少一个 `(task, uid)` 经历 `aborted -> completed` +- response length 从不超过 `max_tokens` +- 最终 train rollout 只包含 completed 且合法的样本 + +### `multitask_weighted` + +目的: + +- 两个 task 使用同一个 gsm8k dataset/sampler config +- task weight 控制 batch 
allocation + +预期: + +- 当 `global_batch_size=8` 且权重为 `1:3` 时,分配为 `2/6` +- 每个 group 有 `prompt_repeat_k` 个样本 +- reward 只可能是 `0/1` + +### `fault_filtering` + +目的: + +- 确定性 fault injection + +预期: + +- 配置指定的 `(task, uid, attempt)` 失败 +- failed group 不进入 train rollout +- async 路径仍能补齐需要的 train batch + +### `tail_expired` + +目的: + +- partial rollout staleness 和 expired 行为 + +预期: + +- stale partial samples 会变成 expired +- expired samples 在重新生成前会清空历史 response +- 最终 response length 合法 + +### `async_update_every_4` + +目的: + +- 延迟参数更新 + +预期: + +- 至少运行 8 step +- 只在 step 4 和 step 8 更新 +- rollout step 1-4 使用 version 0 +- rollout step 5-8 使用 version 1 + +### `determinism_replay` + +目的: + +- 可复现性 + +预期: + +- 同一个 config 运行两次,产生相同的 deterministic digest +- 忽略 wall-clock timestamps +- 比较 uid、status、generated length、response hash、reward、fault、weight version + +## 确定性 + +每个 case 必须提供 seed。 + +所有随机行为都由稳定 key 派生: + +```text +seed + component + task_name + uid + attempt + repeat_index +``` + +适用于: + +- response length +- response token ids/hash +- delay +- fault decision +- reward +- 如果 fake loss 使用随机,也必须适用 + +禁止依赖全局 `random`、`numpy` 或 `torch` 状态。 + +## V1 非目标 + +- 真实分布式训练、真实多机多 rank 行为 +- MoE routed experts replay +- 真实 judge +- 真实 inference backend 分析 +- 解析任意自由文本日志 +- 在 `xtuner/v1/rl` 下提供稳定公共 API diff --git a/recipe/verl_agent/__init__.py b/recipe/verl_agent/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/recipe/verl_agent/common/__init__.py b/recipe/verl_agent/common/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/recipe/verl_agent/common/agent_loop_verl_tool.py b/recipe/verl_agent/common/agent_loop_verl_tool.py new file mode 100644 index 0000000000..b607ff7f04 --- /dev/null +++ b/recipe/verl_agent/common/agent_loop_verl_tool.py @@ -0,0 +1,154 @@ +from typing import Any, Optional + +from omegaconf import DictConfig +from verl.experimental.agent_loop.agent_loop import AgentLoopOutput, DictConfigWrap +from 
verl.experimental.agent_loop.tool_agent_loop import ToolAgentLoop +from verl.utils.dataset.rl_dataset import get_dataset_class +from verl.workers.rollout.replica import TokenOutput + +from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams, Status +from xtuner.v1.rl.judger import Judger +from xtuner.v1.rl.rollout.controller import RolloutControllerProxy +from xtuner.v1.rl.agent_loop import AgentLoop, AgentLoopConfig + + +class VerlToolAgentLoopConfig(AgentLoopConfig): + config: DictConfig + + def build( + self, + rollout_controller: RolloutControllerProxy, + judger: Judger | None = None, + logger=None, + ) -> "VerlToolAgentLoop": + verl_tool_agent_loop = VerlToolAgentLoop( + rollout_controller=rollout_controller, + sample_params=self.sample_params, + hf_checkpoint=self.hf_checkpoint, + config=self.config, + judger=judger, + ) + return verl_tool_agent_loop + + +class XtunerAsyncLLMServerManager: + def __init__(self, rollout_controller: RolloutControllerProxy): + self.rollout_controller = rollout_controller + + async def generate( + self, + request_id: str, + *, + prompt_ids: list[int], + sampling_params: dict[str, Any], + image_data: Optional[list[Any]] = None, + video_data: Optional[list[Any]] = None, + ) -> TokenOutput: + sample_params = SampleParams( + return_token_ids=True, + temperature=sampling_params.get("temperature", 1.0), + top_p=sampling_params.get("top_p", 1.0), + top_k=sampling_params.get("top_k", 0), + repetition_penalty=sampling_params.get("repetition_penalty", 1.0), + return_logprob=bool(sampling_params.get("logprobs", True)), + ) + + # session_id is set in the VerlToolAgentLoop.generate_sample + # and ignore request_id generated by verl.ToolAgentLoop.run + session_uid = sampling_params.get("session_uid", -1) + + rollout_state = RolloutState( + message=[], + tokens=prompt_ids, + session_uid=session_uid, + sample_params=sample_params, + ) + + response: RolloutState = await self.rollout_controller.generate.remote( + 
rollout_state=rollout_state, + ) + + finish_reason = response.finish_reason + + return TokenOutput( + token_ids=response.response_ids or [], + log_probs=response.logprobs, + routed_experts=response.routed_experts, + stop_reason=finish_reason, + ) + + +class VerlToolAgentLoop(AgentLoop): + def __init__( + self, + rollout_controller: RolloutControllerProxy, + sample_params: SampleParams, + hf_checkpoint: str, + config: DictConfig, + judger: Judger | None = None, + logger=None, + ): + super().__init__(rollout_controller, sample_params, hf_checkpoint, judger, logger) + + server_manager = XtunerAsyncLLMServerManager(rollout_controller) + + dataset_cls = get_dataset_class(config.data) + + self.verl_tool_agent_loop = ToolAgentLoop( + trainer_config=DictConfigWrap(config=config), + server_manager=server_manager, + tokenizer=self.tokenizer, + processor=self.processor, + dataset_cls=dataset_cls, + data_config=DictConfigWrap(config.data), + ) + + async def generate_sample(self, rollout_state: RolloutState) -> RolloutState: + assert rollout_state.sample_params is not None, "sample_params must be set in rollout_state" + + # convert rollout_state to verl_tool_agent_loop input + sp = rollout_state.sample_params + sampling_params = dict( + temperature=sp.temperature, + top_p=sp.top_p, + top_k=sp.top_k, + repetition_penalty=sp.repetition_penalty, + logprobs=sp.return_logprob, + # session_id is used to identify the session in the server manager + session_uid=rollout_state.session_uid, + ) + + input_kwargs = { + "raw_prompt": rollout_state.message, + "tools_kwargs": rollout_state.extra_fields.get("tools_kwargs", {}), + } + + # run verl_tool_agent_loop + try: + output: AgentLoopOutput = await self.verl_tool_agent_loop.run(sampling_params, **input_kwargs) + except Exception as e: + rollout_state.status = Status.FAILED + rollout_state.error_msg = str(e) + self.logger.error(f"[VerlToolAgentLoop][{rollout_state.session_uid}] generate_sample failed: {e}") + return rollout_state + # TODO: 
handle samples with corrupted tool tokens ? + + # convert verl_tool_agent_loop output to rollout_state + rollout_state.prompt_ids = output.prompt_ids + rollout_state.response_ids = output.response_ids + rollout_state.logprobs = output.response_logprobs + rollout_state.routed_experts = output.routed_experts + rollout_state.response_mask = output.response_mask + rollout_state.status = Status.COMPLETED + rollout_state.extra_fields.update(output.extra_fields) + # judger needs response in text format + rollout_state.response = self.tokenizer.decode(rollout_state.response_ids) + # for trajectory dump, we need to add raw_prompt to extra_fields + # raw_prompt is updated in tool_agent_loop: apply_chat_template of tools + rollout_state.extra_fields["raw_prompt"] = self.tokenizer.decode(rollout_state.prompt_ids) + + # judge rollout_state + if self.judger is not None: + rollout_state = await self.judger.judge(rollout_state) + + return rollout_state diff --git a/recipe/verl_agent/gsm8k_tool_example/gsm8k_tool_grpo_config.py b/recipe/verl_agent/gsm8k_tool_example/gsm8k_tool_grpo_config.py new file mode 100644 index 0000000000..b4bf7f2409 --- /dev/null +++ b/recipe/verl_agent/gsm8k_tool_example/gsm8k_tool_grpo_config.py @@ -0,0 +1,219 @@ +import os +from pathlib import Path + +from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig +from xtuner.v1.data_proto import SampleParams +from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig +from xtuner.v1.datasets.rl_tokenize_fn import RLTextTokenizeFnConfig +from xtuner.v1.model import get_model_config_from_hf +from xtuner.v1.rl.utils import AcceleratorResourcesConfig, create_task +from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.rl.judger.gsm8k import GSM8KJudgerConfig +from xtuner.v1.rl.replay_buffer import SyncReplayBufferConfig +from xtuner.v1.rl.trainer import WorkerConfig +from xtuner.v1.rl.agent_loop import AgentLoopManagerConfig, SyncProduceStrategyConfig, SamplerConfig +from 
xtuner.v1.rl.evaluator import EvaluatorConfig +from xtuner.v1.rl.loss import GRPOLossConfig +from xtuner.v1.train.rl_colocate_trainer import RLColocateTrainerConfig +from recipe.verl_agent.common.agent_loop_verl_tool import VerlToolAgentLoopConfig +# env +work_dir = os.environ["WORK_DIR"] +model_path = os.environ["MODEL_PATH"] +data_path = os.environ["DATA_PATH"] +eval_data_path = os.environ["EVAL_DATA_PATH"] +enable_return_routed_experts = os.environ.get("ENABLE_RETURN_ROUTED_EXPERTS", "0") +WORLD_SIZE = int(os.environ.get("WORLD_SIZE", "1")) + +# basic settings +experimental_name = "verl_gsm8k_tool" +rollout_steps = 45 +evaluate_step = 45 +train_optimizer_steps = 1 +global_batch_size = 64 * train_optimizer_steps +prompt_repeat_k = 5 +rollout_tp_size = 1 +rollout_ep_size = 1 +max_prompt_length = 1024 +max_response_length = 1024 +pack_max_length = 32 * 1024 + +# 1. resources +resources = AcceleratorResourcesConfig( + accelerator="GPU", + num_workers=8 * WORLD_SIZE, + num_cpus_per_worker=12, + cpu_memory_per_worker=16 * 1024**3, # 16 GB +) + +# 2. rollout +rollout_config = RolloutConfig( + env=experimental_name, + device=resources.accelerator, + model_path=model_path, + dtype="bfloat16", + tensor_parallel_size=rollout_tp_size, + expert_parallel_size=rollout_ep_size, + gpu_memory_utilization=0.8, + context_length=max_response_length + max_prompt_length, + enable_return_routed_experts=(enable_return_routed_experts == "1"), +) + +# 3. judger +judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k", num_ray_actors=1) + +# 4. 
train worker +lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) +fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) +model_cfg = get_model_config_from_hf(Path(model_path)) +if hasattr(model_cfg, "balancing_loss_cfg"): + model_cfg.balancing_loss_cfg = None +if hasattr(model_cfg, "z_loss_cfg"): + model_cfg.z_loss_cfg = None +optim_cfg = AdamWConfig(lr=1e-6, foreach=False, weight_decay=0.1) +loss_cfg = GRPOLossConfig( + policy_loss_cfg=dict( + cliprange_high=0.28, + cliprange_low=0.2, + loss_type=os.environ.get("LOSS_TYPE", "vanilla"), + clip_ratio_c=10.0, + log_prob_diff_min=-20.0, + log_prob_diff_max=20.0, + ), + ignore_idx=-100, + use_kl_loss=False, + kl_loss_coef=0.0, + kl_loss_type="low_var_kl", + mode=os.environ.get("LOSS_MODE", "chunk"), + chunk_size=512, +) +train_worker_cfg = WorkerConfig( + model_cfg=model_cfg, + load_from=model_path, + optim_cfg=optim_cfg, + loss_cfg=loss_cfg, + lr_cfg=lr_cfg, + fsdp_cfg=fsdp_cfg, + sp_size=int(os.environ.get("SP_SIZE", "1")), + optimizer_steps=train_optimizer_steps, + pack_max_length=pack_max_length, +) + +# 5.0 verl config +# gsm8k tool config +tool_config_path = "recipe/verl_agent/gsm8k_tool_example/tool_config/gsm8k_tool_config.yaml" +tool_call_parser_name = "hermes" + +from hydra import compose, initialize_config_dir +import verl + +verl_config_dir = os.path.join(os.path.dirname(verl.__file__), "trainer/config") +with initialize_config_dir(config_dir=verl_config_dir): + verl_config = compose( + config_name="ppo_trainer", + overrides=[ + "data.max_prompt_length=" + str(max_prompt_length), # also set rollout.prompt_length by OmegaConf's oc.select + "data.max_response_length=" + str(max_response_length), # also set rollout.response_length + "+data.apply_chat_template_kwargs.enable_thinking=False", + "actor_rollout_ref.rollout.multi_turn.format=" + tool_call_parser_name, + "actor_rollout_ref.rollout.multi_turn.tool_config_path=" + tool_config_path, + 
"actor_rollout_ref.rollout.multi_turn.max_tool_response_length=" + str(max_response_length), + "actor_rollout_ref.rollout.multi_turn.max_assistant_turns=5", + "actor_rollout_ref.rollout.multi_turn.enable=True", + ], + ) + +# 5.1 train agent loop +training_sample_params = SampleParams( + max_tokens=max_response_length, + top_k=0, + top_p=1.0, + temperature=1.0, + min_tokens=0, +) +verl_tool_agent_loop_config = VerlToolAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=training_sample_params, + config=verl_config, +) + +# 5.2 train agent loop manager +train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path) +tokenizer_config = RLTextTokenizeFnConfig(max_length=max_prompt_length) +train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}] +dataloader_cfg = DataloaderConfig( + dataset_config_list=train_dataset_cfg, + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", +) +sampler_config = SamplerConfig( + dataloader_cfg=dataloader_cfg, + prompt_repeat_k=prompt_repeat_k, +) +produce_strategy_config = SyncProduceStrategyConfig() +agent_loop_manager_cfg = AgentLoopManagerConfig( + task_name="train_task", + agent_loop_config=verl_tool_agent_loop_config, + produce_strategy_config=produce_strategy_config, + sampler_config=sampler_config, +) + +# 6.1 eval agent loop +evaluation_sample_params = SampleParams( + max_tokens=max_response_length, + top_k=1, + top_p=1.0, + temperature=0.0, + min_tokens=0, +) +eval_verl_tool_agent_loop_config = VerlToolAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=evaluation_sample_params, + config=verl_config, +) + +# 6.2 eval agent loop manager +eval_dataset = DatasetConfig( + name=experimental_name, anno_path=eval_data_path, sample_ratio=1.0 +) +eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}] +eval_dataloader_cfg = DataloaderConfig( + dataset_config_list=eval_dataset_cfg, + pack_max_length=pack_max_length, + 
collator="fake_collator", + pack_level="none", +) +eval_sampler_config = SamplerConfig( + dataloader_cfg=eval_dataloader_cfg, + prompt_repeat_k=1, +) +eval_agent_loop_manager_cfg = AgentLoopManagerConfig( + task_name="eval_task", + agent_loop_config=eval_verl_tool_agent_loop_config, + sampler_config=eval_sampler_config, +) + +# 7. evaluator +evaluator_config = EvaluatorConfig(compute_metric_func=None) + +# 8. RL Colocate Trainer Config(CLI 通过 config["trainer"].build() 得到 Trainer) +trainer = RLColocateTrainerConfig( + resources=resources, + train_worker_cfg=train_worker_cfg, # TODO: uniform naming of cfg and config + rollout_config=rollout_config, + judger_config=judger_config, + tokenizer_path=model_path, + replay_buffer_config=SyncReplayBufferConfig(), + agent_loop_manager_cfg=agent_loop_manager_cfg, + eval_agent_loop_manager_cfg=eval_agent_loop_manager_cfg, + evaluator_config=evaluator_config, + load_from=model_path, + rollout_steps=rollout_steps, + global_batch_size=global_batch_size, + enable_evaluate=True, + enable_initial_evaluate=False, + evaluate_step=evaluate_step, + work_dir=work_dir, + seed=123, + debug_rollout=False, +) diff --git a/recipe/verl_agent/gsm8k_tool_example/tool_config/gsm8k_tool_config.yaml b/recipe/verl_agent/gsm8k_tool_example/tool_config/gsm8k_tool_config.yaml new file mode 100644 index 0000000000..a4197baabf --- /dev/null +++ b/recipe/verl_agent/gsm8k_tool_example/tool_config/gsm8k_tool_config.yaml @@ -0,0 +1,16 @@ +tools: + - class_name: "verl.tools.gsm8k_tool.Gsm8kTool" + config: + type: native + tool_schema: + type: "function" + function: + name: "calc_gsm8k_reward" + description: "A tool for calculating the reward of gsm8k. 
(1.0 if parsed answer is correct, 0.0 if parsed answer is incorrect or not correctly parsed)" + parameters: + type: "object" + properties: + answer: + type: "string" + description: "The model's answer to the GSM8K math problem, must be a digits" + required: ["answer"] diff --git a/recipe/verl_agent/sandbox_example/__init__.py b/recipe/verl_agent/sandbox_example/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/recipe/verl_agent/sandbox_example/sandbox.py b/recipe/verl_agent/sandbox_example/sandbox.py new file mode 100644 index 0000000000..81a9ac2543 --- /dev/null +++ b/recipe/verl_agent/sandbox_example/sandbox.py @@ -0,0 +1,53 @@ +import re + +import aiohttp +from transformers.utils import get_json_schema + +from verl.tools.base_tool import BaseTool, OpenAIFunctionToolSchema, ToolResponse + + +class SandboxTool(BaseTool): + def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): + super().__init__(config, tool_schema) + self.code_pattern = re.compile(r"```py(.*?)```", re.DOTALL) + + async def code_interpreter(self, code: str) -> str: + """Execute the code in the sandbox. + + Args: + code: The code to be executed. + + Returns: + str: The output of the code execution. 
+ """ + async with aiohttp.ClientSession() as session: + async with session.post( + self.config.get("sandbox_fusion_url"), + json={"code": code}, + ) as resp: + resp.raise_for_status() + result = await resp.json() + stdout, stderr = result["run_result"]["stdout"], result["run_result"]["stderr"] + return stdout + stderr + + def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: + schema = get_json_schema(self.code_interpreter) + return OpenAIFunctionToolSchema(**schema) + + async def execute(self, instance_id: str, parameters: dict, **kwargs) -> tuple[str, float, dict]: + code = parameters["code"] + matches = self.code_pattern.findall(code) + if matches: + code = matches[0].strip() + + lines = code.split("\n") + for i, line in reversed(list(enumerate(lines))): + if line == "": + continue + if not lines[i].startswith("print"): + lines[i] = f"print({line})" + break + code = "\n".join(lines) + + result = await self.code_interpreter(code) + return ToolResponse(text=result), 0.0, {} diff --git a/recipe/verl_agent/sandbox_example/sandbox_grpo_config.py b/recipe/verl_agent/sandbox_example/sandbox_grpo_config.py new file mode 100644 index 0000000000..cc65c7e18a --- /dev/null +++ b/recipe/verl_agent/sandbox_example/sandbox_grpo_config.py @@ -0,0 +1,313 @@ +"""RL Colocate Trainer 示例配置(GRPO + GSM8K)。 + +用法:通过环境变量传入路径后,由 CLI 加载本配置并 trainer_cfg.build().fit()。 +需设置: WORK_DIR, MODEL_PATH, DATA_PATH, EVAL_DATA_PATH +可选: WORLD_SIZE, ENABLE_RETURN_ROUTED_EXPERTS, LOSS_TYPE, LOSS_MODE, SP_SIZE +""" +import os +from pathlib import Path + +from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig +from xtuner.v1.data_proto import SampleParams +from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig +from xtuner.v1.datasets.rl_tokenize_fn import RLTextTokenizeFnConfig +from xtuner.v1.model import get_model_config_from_hf +from xtuner.v1.rl.utils import AcceleratorResourcesConfig +from xtuner.v1.rl.rollout.worker import RolloutConfig +from 
xtuner.v1.rl.judger.gsm8k import GSM8KJudgerConfig +from xtuner.v1.rl.utils import create_task +from xtuner.v1.rl.replay_buffer import SyncReplayBufferConfig +from xtuner.v1.rl.trainer import WorkerConfig +from xtuner.v1.rl.agent_loop import AgentLoopManagerConfig, SyncProduceStrategyConfig, SamplerConfig +from xtuner.v1.rl.evaluator import EvaluatorConfig +from xtuner.v1.rl.loss import GRPOLossConfig +from xtuner.v1.train.rl_colocate_trainer import RLColocateTrainerConfig +from recipe.verl_agent.common.agent_loop_verl_tool import VerlToolAgentLoopConfig +# env +work_dir = os.environ["WORK_DIR"] +model_path = os.environ["MODEL_PATH"] +data_path = os.environ["DATA_PATH"] +eval_data_path = os.environ["EVAL_DATA_PATH"] +enable_return_routed_experts = os.environ.get("ENABLE_RETURN_ROUTED_EXPERTS", "0") +WORLD_SIZE = int(os.environ.get("WORLD_SIZE", "1")) + +# basic settings +experimental_name = "grpo_gsm8k_verl_tool" +rollout_steps = 45 +evaluate_step = 45 +train_optimizer_steps = 1 +global_batch_size = 64 * train_optimizer_steps +prompt_repeat_k = 5 +rollout_tp_size = 1 +rollout_ep_size = 1 +max_prompt_length = 512 +max_response_length = 2048 +pack_max_length = 32 * 1024 + +# 1. resources +resources = AcceleratorResourcesConfig( + accelerator="GPU", + num_workers=8 * WORLD_SIZE, + num_cpus_per_worker=12, + cpu_memory_per_worker=16 * 1024**3, # 16 GB +) + +# 2. rollout +rollout_config = RolloutConfig( + env=experimental_name, + device=resources.accelerator, + model_path=model_path, + dtype="bfloat16", + tensor_parallel_size=rollout_tp_size, + expert_parallel_size=rollout_ep_size, + gpu_memory_utilization=0.8, + context_length=max_response_length + max_prompt_length, + enable_return_routed_experts=(enable_return_routed_experts == "1"), +) + +# 3. judger +judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k", num_ray_actors=1) + +# 4. 
train worker +lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) +fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) +model_cfg = get_model_config_from_hf(Path(model_path)) +if hasattr(model_cfg, "balancing_loss_cfg"): + model_cfg.balancing_loss_cfg = None +if hasattr(model_cfg, "z_loss_cfg"): + model_cfg.z_loss_cfg = None +optim_cfg = AdamWConfig(lr=1e-6, foreach=False, weight_decay=0.1) +loss_cfg = GRPOLossConfig( + policy_loss_cfg=dict( + cliprange_high=0.28, + cliprange_low=0.2, + loss_type=os.environ.get("LOSS_TYPE", "vanilla"), + clip_ratio_c=10.0, + log_prob_diff_min=-20.0, + log_prob_diff_max=20.0, + ), + ignore_idx=-100, + use_kl_loss=False, + kl_loss_coef=0.0, + kl_loss_type="low_var_kl", + mode=os.environ.get("LOSS_MODE", "chunk"), + chunk_size=512, +) +train_worker_cfg = WorkerConfig( + model_cfg=model_cfg, + load_from=model_path, + optim_cfg=optim_cfg, + loss_cfg=loss_cfg, + lr_cfg=lr_cfg, + fsdp_cfg=fsdp_cfg, + sp_size=int(os.environ.get("SP_SIZE", "1")), + optimizer_steps=train_optimizer_steps, + pack_max_length=pack_max_length, +) + +# code sand box just for toy example +import ray +import asyncio +import socket +import tempfile +import sys +import fastapi +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse +import uvicorn +import json + +@ray.remote(num_cpus=1) +class Sandbox: + """Sandbox to execute python code.""" + + def __init__(self): + self.address = ray._private.services.get_node_ip_address() + self.port = self._get_free_port() + create_task(self._start_fastapi_server()) + + async def code_execution(self, request: Request): + request_json = await request.json() + code = request_json["code"] + # print(f"execute code:\n{code}") + + _, temp_file = tempfile.mkstemp(suffix=".py", prefix="temp_code", dir=None, text=True) + with open(temp_file, "w") as f: + f.write(code) + + try: + process = await asyncio.create_subprocess_exec( + sys.executable, temp_file, 
stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE + ) + + stdout, stderr = await process.communicate() + + response = { + "status": "Success" if process.returncode == 0 else "Failed", + "run_result": { + "status": "Finished", + "stdout": stdout.decode(), + "stderr": stderr.decode(), + "return_code": process.returncode, + }, + } + return JSONResponse(content=response) + finally: + try: + os.unlink(temp_file) + except Exception: + pass + + def _get_free_port(self): + with socket.socket() as sock: + sock.bind(("", 0)) + return sock.getsockname()[1] + + async def _start_fastapi_server(self): + app = fastapi.FastAPI() + app.router.add_api_route("/run_code", self.code_execution, methods=["POST"]) + + config = uvicorn.Config(app, host=["::", "0.0.0.0"], port=self.port, log_level="warning") + server = uvicorn.Server(config) + await server.serve() + + async def get_server_address(self) -> str: + """Get FastAPI server address.""" + return f"{self.address}:{self.port}" + +sandbox = Sandbox.remote() +sandbox_address = ray.get(sandbox.get_server_address.remote()) +print(f"Sandbox server address: {sandbox_address}") +tool_config = { + "tools": [ + { + "class_name": "recipe.verl_agent.sandbox_example.sandbox.SandboxTool", + "config": { + "type": "native", + "sandbox_fusion_url": f"http://{sandbox_address}/run_code", + }, + }, + ], +} + +tool_config_path = "tool_config.json" +with open(tool_config_path, "w") as f: + json.dump(tool_config, f) + +# 5.0 verl config +tool_call_parser_name = "hermes" + +from hydra import compose, initialize_config_dir +import verl + +verl_config_dir = os.path.join(os.path.dirname(verl.__file__), "trainer/config") +with initialize_config_dir(config_dir=verl_config_dir): + verl_config = compose( + config_name="ppo_trainer", + overrides=[ + "data.max_prompt_length=" + str(max_prompt_length), # also set rollout.prompt_length by OmegaConf's oc.select + "data.max_response_length=" + str(max_response_length), # also set rollout.response_length + 
"+data.apply_chat_template_kwargs.enable_thinking=False", + "actor_rollout_ref.rollout.multi_turn.format=" + tool_call_parser_name, + "actor_rollout_ref.rollout.multi_turn.tool_config_path=" + tool_config_path, + "actor_rollout_ref.rollout.multi_turn.max_tool_response_length=" + str(max_response_length), + "actor_rollout_ref.rollout.multi_turn.max_assistant_turns=5", + "actor_rollout_ref.rollout.multi_turn.enable=True", + ], + ) + +# 5.1 train agent loop +training_sample_params = SampleParams( + max_tokens=max_response_length, + top_k=0, + top_p=1.0, + temperature=1.0, + min_tokens=0, +) +verl_tool_agent_loop_config = VerlToolAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=training_sample_params, + config=verl_config, +) + +# 5.2 train agent loop manager +train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path) +tokenizer_config = RLTextTokenizeFnConfig(max_length=max_prompt_length) +train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}] +dataloader_cfg = DataloaderConfig( + dataset_config_list=train_dataset_cfg, + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", +) +sampler_config = SamplerConfig( + dataloader_cfg=dataloader_cfg, + prompt_repeat_k=prompt_repeat_k, +) +produce_strategy_config = SyncProduceStrategyConfig() +agent_loop_manager_cfg = AgentLoopManagerConfig( + task_name="train_task", + agent_loop_config=verl_tool_agent_loop_config, + produce_strategy_config=produce_strategy_config, + sampler_config=sampler_config, +) + +# 6.1 eval agent loop +evaluation_sample_params = SampleParams( + max_tokens=max_response_length, + top_k=1, + top_p=1.0, + temperature=0.0, + min_tokens=0, +) +eval_verl_tool_agent_loop_config = VerlToolAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=evaluation_sample_params, + config=verl_config, +) + +# 6.2 eval agent loop manager +eval_dataset = DatasetConfig( + name=experimental_name, anno_path=eval_data_path, sample_ratio=1.0 +) 
+eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}] +eval_dataloader_cfg = DataloaderConfig( + dataset_config_list=eval_dataset_cfg, + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", +) +eval_sampler_config = SamplerConfig( + dataloader_cfg=eval_dataloader_cfg, + prompt_repeat_k=1, +) +eval_agent_loop_manager_cfg = AgentLoopManagerConfig( + task_name="eval_task", + agent_loop_config=eval_verl_tool_agent_loop_config, + sampler_config=eval_sampler_config, +) + +# 7. evaluator +evaluator_config = EvaluatorConfig(compute_metric_func=None) + +# 8. RL Colocate Trainer Config(CLI 通过 config["trainer"].build() 得到 Trainer) +trainer = RLColocateTrainerConfig( + resources=resources, + train_worker_cfg=train_worker_cfg, # TODO: uniform naming of cfg and config + rollout_config=rollout_config, + judger_config=judger_config, + tokenizer_path=model_path, + replay_buffer_config=SyncReplayBufferConfig(), + agent_loop_manager_cfg=agent_loop_manager_cfg, + eval_agent_loop_manager_cfg=eval_agent_loop_manager_cfg, + evaluator_config=evaluator_config, + load_from=model_path, + rollout_steps=rollout_steps, + global_batch_size=global_batch_size, + enable_evaluate=True, + enable_initial_evaluate=False, + evaluate_step=evaluate_step, + work_dir=work_dir, + seed=123, + debug_rollout=False, +) diff --git a/recipe/verl_agent/sandbox_example/test_verl_tool_agent_loop.py b/recipe/verl_agent/sandbox_example/test_verl_tool_agent_loop.py new file mode 100644 index 0000000000..fa8ff15548 --- /dev/null +++ b/recipe/verl_agent/sandbox_example/test_verl_tool_agent_loop.py @@ -0,0 +1,406 @@ +import os +import sys +import json +import socket +import asyncio +import tempfile +import unittest + +import ray +import torch +import fastapi +import uvicorn +from fastapi import Request +from fastapi.responses import JSONResponse +from transformers import AutoTokenizer + +from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.rl.utils 
import AcceleratorResourcesConfig, AutoAcceleratorWorkers +from recipe.verl_agent.common.agent_loop_verl_tool import VerlToolAgentLoopConfig +from xtuner.v1.rl.agent_loop import AgentLoopManagerConfig, SyncProduceStrategyConfig, SamplerConfig +from xtuner.v1.data_proto import RolloutState, Status, SampleParams +from xtuner.v1.rl.rollout import RolloutController +from xtuner.v1.rl.judger.gsm8k import GSM8KJudgerConfig +from xtuner.v1.rl.utils import create_task +from xtuner.v1.rl.replay_buffer import SyncReplayBufferConfig +from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig +from xtuner.v1.datasets.rl_tokenize_fn import RLTextTokenizeFnConfig + +MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"] +TRAIN_DATA_PATH = os.environ["ROLLOUT_DATA_PATH"] +TEST_DATA_PATH = os.environ["ROLLOUT_TEST_DATA_PATH"] +VERL_TRAIN_DATA_PATH = "/fake/path/to/train.parquet" +VERL_TEST_DATA_PATH = "/fake/path/to/test.parquet" + +FAKE_INPUT_ITEM = RolloutState( + message=[{ + 'role': 'user', + 'content': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? 
Let\'s think step by step and output the final answer after "####"' + }], + reward_model={'ground_truth': '72', 'style': 'rule'}, +) + +resource_map = { + "npu": "NPU", + "cuda": "GPU", +} + + +@ray.remote(num_cpus=1) +class Sandbox: + """Sandbox to execute python code for tool-calling agent tests.""" + + def __init__(self): + self.address = ray._private.services.get_node_ip_address() + self.port = self._get_free_port() + create_task(self._start_fastapi_server()) + + async def code_execution(self, request: Request): + request_json = await request.json() + code = request_json["code"] + + _, temp_file = tempfile.mkstemp( + suffix=".py", prefix="temp_code", dir=None, text=True + ) + with open(temp_file, "w") as f: + f.write(code) + + try: + process = await asyncio.create_subprocess_exec( + sys.executable, + temp_file, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, stderr = await process.communicate() + response = { + "status": "Success" if process.returncode == 0 else "Failed", + "run_result": { + "status": "Finished", + "stdout": stdout.decode(), + "stderr": stderr.decode(), + "return_code": process.returncode, + }, + } + return JSONResponse(content=response) + finally: + try: + os.unlink(temp_file) + except Exception: + pass + + def _get_free_port(self): + with socket.socket() as sock: + sock.bind(("", 0)) + return sock.getsockname()[1] + + async def _start_fastapi_server(self): + app = fastapi.FastAPI() + app.router.add_api_route( + "/run_code", self.code_execution, methods=["POST"] + ) + config = uvicorn.Config( + app, host=["::", "0.0.0.0"], port=self.port, log_level="warning" + ) + server = uvicorn.Server(config) + await server.serve() + + async def get_server_address(self) -> str: + return f"{self.address}:{self.port}" + + +def _build_verl_config( + model_path: str, + train_file: str, + test_file: str, + tool_config_path: str, + max_prompt_length: int, + max_response_length: int, + rollout_name: str = "sglang", + 
tool_call_parser_name: str = "hermes", +): + from hydra import compose, initialize_config_dir + import verl + + verl_config_dir = os.path.join( + os.path.dirname(verl.__file__), "trainer/config" + ) + with initialize_config_dir(config_dir=verl_config_dir): + verl_config = compose( + config_name="ppo_trainer", + overrides=[ + "algorithm.adv_estimator=grpo", + "data.train_files=" + train_file, + "data.val_files=" + test_file, + "data.return_raw_chat=True", + "data.train_batch_size=32", + "data.max_prompt_length=" + str(max_prompt_length), + "data.max_response_length=" + str(max_response_length), + "+data.apply_chat_template_kwargs.enable_thinking=False", + "actor_rollout_ref.model.path=" + model_path, + "actor_rollout_ref.actor.ppo_mini_batch_size=8", + "actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8", + "actor_rollout_ref.actor.fsdp_config.param_offload=True", + "actor_rollout_ref.actor.fsdp_config.optimizer_offload=True", + "actor_rollout_ref.rollout.name=" + rollout_name, + "actor_rollout_ref.rollout.mode=async", + "actor_rollout_ref.rollout.tensor_model_parallel_size=1", + "actor_rollout_ref.rollout.n=8", + "actor_rollout_ref.rollout.response_length=" + str(max_response_length), + "actor_rollout_ref.rollout.skip_tokenizer_init=False", + "+actor_rollout_ref.rollout.engine_kwargs.vllm.enable_auto_tool_choice=True", + "+actor_rollout_ref.rollout.engine_kwargs.vllm.tool_call_parser=hermes", + "+actor_rollout_ref.rollout.engine_kwargs.sglang.tool_call_parser=qwen25", + "actor_rollout_ref.rollout.multi_turn.format=" + tool_call_parser_name, + "actor_rollout_ref.rollout.multi_turn.tool_config_path=" + tool_config_path, + "+actor_rollout_ref.rollout.multi_turn.multi_turn.max_tool_response_length=" + str(max_response_length), + "actor_rollout_ref.rollout.agent.default_agent_loop=tool_agent", + "actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8", + "trainer.val_before_train=True", + "trainer.log_val_generations=10", + "trainer.n_gpus_per_node=8", + 
"trainer.test_freq=-1", + "trainer.total_training_steps=5", + "trainer.logger=['console','tensorboard']", + "trainer.project_name=verl", + "trainer.experiment_name=test_verl_tool_agent_loop", + ], + ) + return verl_config + + +class TestVerlToolAgentLoop(unittest.IsolatedAsyncioTestCase): + @classmethod + def setUpClass(cls) -> None: + os.environ["XTUNER_USE_FA3"] = "1" + os.environ["LMD_SKIP_WARMUP"] = "1" + + @classmethod + def tearDownClass(cls) -> None: + del os.environ["XTUNER_USE_FA3"] + del os.environ["LMD_SKIP_WARMUP"] + + def init_config(self): + self.resources_cfg = AcceleratorResourcesConfig( + accelerator=resource_map[torch.accelerator.current_accelerator().type], + num_workers=1, + num_cpus_per_worker=8, + cpu_memory_per_worker=16 * 1024**3, + ) + self.max_prompt_length = 512 + self.max_response_length = 4096 + self.context_length = self.max_prompt_length + self.max_response_length + self.tokenizer = AutoTokenizer.from_pretrained( + self.model_path, trust_remote_code=True + ) + + def setUp(self): + ray.init(num_cpus=80, ignore_reinit_error=True) + self.model_path = MODEL_PATH + self.data_path = TRAIN_DATA_PATH + self.temp_dir = tempfile.TemporaryDirectory() + self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs") + + def tearDown(self): + ray.shutdown() + self.temp_dir.cleanup() + + def _setup_sandbox_and_verl_config(self): + """Create sandbox actor and verl config, return (verl_config, tool_config_path).""" + sandbox = Sandbox.remote() + self._sandbox = sandbox + # TODO: replace with a real sandbox server address + sandbox_address = ray.get(sandbox.get_server_address.remote()) + + tool_config = { + "tools": [ + { + "class_name": "recipe.verl_agent.sandbox_example.sandbox.SandboxTool", + "config": { + "type": "native", + "sandbox_fusion_url": f"http://{sandbox_address}/run_code", + }, + }, + ], + } + tool_config_path = os.path.join(self.temp_dir.name, "tool_config.json") + with open(tool_config_path, "w") as f: + json.dump(tool_config, f) 
+ + verl_config = _build_verl_config( + model_path=self.model_path, + train_file=VERL_TRAIN_DATA_PATH, + test_file=VERL_TEST_DATA_PATH, + tool_config_path=tool_config_path, + max_prompt_length=self.max_prompt_length, + max_response_length=self.max_response_length, + ) + return verl_config + + async def test_verl_tool_agent_loop(self): + # 1. 初始化 config + self.init_config() + verl_config = self._setup_sandbox_and_verl_config() + + rollout_config = RolloutConfig( + env="test_verl_tool_agent_loop", + model_path=self.model_path, + model_name=os.path.basename(self.model_path).lower(), + tokenizer_path=self.model_path, + context_length=self.context_length, + worker_log_dir=self.worker_log_dir, + ) + judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k", num_ray_actors=1) + + training_sample_params = SampleParams( + max_tokens=self.max_response_length, + top_k=0, + top_p=1.0, + temperature=1.0, + min_tokens=0, + return_token_ids=True, + return_logprob=True, + ) + agent_loop_cfg = VerlToolAgentLoopConfig( + hf_checkpoint=self.model_path, + sample_params=training_sample_params, + config=verl_config, + ) + + # 2. 创建 rollout_controller, judger + pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg) + rollout_controller = ray.remote(RolloutController).remote( + rollout_config, pg + ) + gsm8k_judger = judger_config.build() + + # 3. 创建 VerlToolAgentLoop + agent_loop = agent_loop_cfg.build( + rollout_controller=rollout_controller, judger=gsm8k_judger + ) + + # 4. 构造输入数据 + prompt_repeat_k = 4 + rollout_state = FAKE_INPUT_ITEM.model_copy(deep=True) + group_in_rollout_state = [ + FAKE_INPUT_ITEM.model_copy(deep=True) for _ in range(prompt_repeat_k) + ] + + # 5. 
执行 generate_group && generate_sample + group_rollout_state = await agent_loop.generate_group(group_in_rollout_state) + single_rollout_state = await agent_loop.generate_sample(rollout_state) + + print(f"prompt: {single_rollout_state.extra_fields['raw_prompt']}") + print(f"response: {single_rollout_state.response}") + + # 6. 验证结果 + self.assertEqual(len(group_rollout_state), prompt_repeat_k) + for state in group_rollout_state: + self.assertEqual(state.status, Status.COMPLETED) + self.assertIsNotNone(state.response_ids) + self.assertGreater(len(state.response_ids), 0) + self.assertIsNotNone(state.prompt_ids) + self.assertIsNotNone(state.logprobs) + self.assertIsNotNone(state.loss_mask) + + self.assertEqual(single_rollout_state.status, Status.COMPLETED) + self.assertIsNotNone(single_rollout_state.response_ids) + self.assertGreater(len(single_rollout_state.response_ids), 0) + self.assertIsNotNone(single_rollout_state.prompt_ids) + self.assertIsNotNone(single_rollout_state.logprobs) + self.assertIsNotNone(single_rollout_state.loss_mask) + + async def test_verl_tool_agent_loop_manager(self): + # 1. 
初始化 config + self.init_config() + verl_config = self._setup_sandbox_and_verl_config() + + rollout_config = RolloutConfig( + env="test_verl_tool_agent_loop_manager", + model_path=self.model_path, + model_name=os.path.basename(self.model_path).lower(), + tokenizer_path=self.model_path, + context_length=self.context_length, + worker_log_dir=self.worker_log_dir, + ) + judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k", num_ray_actors=1) + + training_sample_params = SampleParams( + max_tokens=self.max_response_length, + top_k=0, + top_p=1.0, + temperature=1.0, + min_tokens=0, + ) + agent_loop_cfg = VerlToolAgentLoopConfig( + hf_checkpoint=self.model_path, + sample_params=training_sample_params, + config=verl_config, + ) + + prompt_repeat_k = 2 + sampler_config = SamplerConfig( + dataloader_cfg=DataloaderConfig( + dataset_config_list=[ + { + "dataset": DatasetConfig( + name="gsm8k", + anno_path=TRAIN_DATA_PATH, + sample_ratio=1.0, + ), + "tokenize_fn": RLTextTokenizeFnConfig( + max_length=self.max_prompt_length + ), + }, + ], + collator="fake_collator", + pack_level="none", + group_by_length=False, + ), + prompt_repeat_k=prompt_repeat_k, + ) + agent_loop_manager_cfg = AgentLoopManagerConfig( + task_name="test_verl_tool", + agent_loop_config=agent_loop_cfg, + produce_strategy_config=SyncProduceStrategyConfig(), + sampler_config=sampler_config, + ) + + # 2. 创建 rollout_controller, judger + pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg) + rollout_controller = ray.remote(RolloutController).remote( + rollout_config, pg + ) + gsm8k_judger = judger_config.build() + + # 3. 创建 AgentLoopManager + replay_buffer_cfg = SyncReplayBufferConfig() + replay_buffer = replay_buffer_cfg.build() + agent_loop_manager = agent_loop_manager_cfg.build( + rollout_controller=rollout_controller, + judger=gsm8k_judger, + tokenizer=self.tokenizer, + replay_buffer=replay_buffer, + ) + + # 4. 
执行 produce_batch + results = await agent_loop_manager.produce_batch(batch_size=4) + batch_rollout_states = results.rollout_states + + # 5. 验证结果 + self.assertEqual(len(batch_rollout_states), 4) + for group_state in batch_rollout_states: + self.assertEqual(len(group_state), prompt_repeat_k) + group_message = group_state[0].message + for state in group_state: + self.assertEqual(state.status, Status.COMPLETED) + self.assertIsNotNone(state.response_ids) + self.assertGreater(len(state.response_ids), 0) + self.assertEqual(state.message, group_message) + self.assertIsNotNone(state.prompt_ids) + self.assertIsNotNone(state.logprobs) + self.assertIsNotNone(state.loss_mask) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/loss/test_grpo_loss.py b/tests/loss/test_grpo_loss.py index 0ee2378941..384d045a89 100644 --- a/tests/loss/test_grpo_loss.py +++ b/tests/loss/test_grpo_loss.py @@ -7,10 +7,9 @@ import torch import torch.distributed as dist import torch.nn as nn -from xtuner.v1.rl.grpo import GRPOLossConfig, GRPOLossContext +from xtuner.v1.rl.loss import GRPOLossConfig, GRPOLossContext, kl_penalty from xtuner.v1.data_proto import SequenceContext from xtuner.v1.rl.utils import gather_logprobs -from xtuner.v1.rl.loss_fn import kl_penalty from xtuner.v1.utils.test_utils import init_data_mesh diff --git a/tests/loss/test_oreal_loss.py b/tests/loss/test_oreal_loss.py index 1e50b2e488..761f9e73c6 100644 --- a/tests/loss/test_oreal_loss.py +++ b/tests/loss/test_oreal_loss.py @@ -11,10 +11,9 @@ from torch.distributed.device_mesh import DeviceMesh, init_device_mesh import torch.nn as nn import torch.nn.functional as F -from xtuner.v1.rl.oreal.loss import OrealLossConfig, OrealLossContext +from xtuner.v1.rl.loss import OrealLossConfig, OrealLossContext, kl_penalty from xtuner.v1.data_proto import SequenceContext from xtuner.v1.rl.utils import gather_logprobs -from xtuner.v1.rl.loss_fn import kl_penalty from xtuner.v1.data_proto.utils import unpack_sequence from 
xtuner.v1.utils.test_utils import init_data_mesh diff --git a/tests/ray/test_evaluator.py b/tests/ray/test_evaluator.py deleted file mode 100644 index 321070f878..0000000000 --- a/tests/ray/test_evaluator.py +++ /dev/null @@ -1,112 +0,0 @@ -import os -import unittest -import ray -import tempfile -from transformers import AutoTokenizer - -from xtuner.v1.ray.config.worker import RolloutConfig -from xtuner.v1.ray.judger.controller import JudgerConfig -from xtuner.v1.ray.base import AcceleratorResourcesConfig, AutoAcceleratorWorkers -from xtuner.v1.ray.environment import SingleTurnEnvironment -from xtuner.v1.ray.evaluator import Evaluator, EvaluatorConfig -from xtuner.v1.data_proto.rl_data import SampleParams -from xtuner.v1.datasets import RLTokenizeFnConfig, DatasetConfig, OpenaiTokenizeFunctionConfig - - -MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"] -TEST_DATA_PATH = os.environ["ROLLOUT_TEST_DATA_PATH"] - - -class TestEvaluator(unittest.TestCase): - - @classmethod - def setUpClass(cls) -> None: - os.environ["XTUNER_USE_FA3"] = "1" - os.environ["LMD_SKIP_WARMUP"] = "1" - - @classmethod - def tearDownClass(cls) -> None: - del os.environ["XTUNER_USE_FA3"] - del os.environ["LMD_SKIP_WARMUP"] - - def init_config(self): - self.resources_cfg = AcceleratorResourcesConfig( - accelerator="GPU", - num_workers=8, - num_cpus_per_worker=8, - cpu_memory_per_worker=16 * 1024**3, # 16 GB - ) - self.max_prompt_length = 512 - self.max_response_length = 1024 - self.rollout_cfg = RolloutConfig( - env="test_rollout", - model_path=MODEL_PATH, - model_name=os.path.basename(MODEL_PATH).lower(), - tokenizer_path=MODEL_PATH, - tensor_parallel_size=8, - context_length=self.max_prompt_length + self.max_response_length, - worker_log_dir=self.worker_log_dir - ) - from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig - gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k") - self.judger_cfg = JudgerConfig( - reward_judger_configs=[gsm8k_judger_config], - 
worker_log_dir=self.worker_log_dir - ) - self.eval_dataset_cfg = [ - { - "dataset": DatasetConfig(name="gsm8k", - anno_path=TEST_DATA_PATH, - sample_ratio=1.0), - "tokenize_fn": RLTokenizeFnConfig(max_length=self.max_prompt_length) - }, - ] - self.tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) - self.pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg) - self.test_env = SingleTurnEnvironment.remote( - "test_env", - self.pg, - self.rollout_cfg, - None, - self.judger_cfg - ) - self.sample_params = SampleParams( - top_p=1.0, - temperature=0.0, - max_tokens=self.max_response_length, - top_k=1 - ) - - def setUp(self): - ray.init(num_cpus=80) - self.model_path = MODEL_PATH - self.temp_dir = tempfile.TemporaryDirectory() - self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs") - self.init_config() - - def tearDown(self): - ray.shutdown() - self.temp_dir.cleanup() - - @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") - def test_lmdeploy_evaluator(self): - def custom_compute_metric(samples): - return {"custom_accuracy": sum(s.env.judger.reward["score"] > 0 for s in samples) / len(samples)} - - evaluator_cfg = EvaluatorConfig( - dataset_cfg=self.eval_dataset_cfg, - tokenizer=self.tokenizer, - max_concurrent=16, - eval_sample_ratio=0.004, # generate 5 samples - compute_metric_func=custom_compute_metric, - sample_params=self.sample_params, - worker_log_dir=self.worker_log_dir - ) - evaluator = Evaluator.remote(evaluator_cfg, self.test_env) - try: - ray.get(evaluator.run.remote()) - except Exception as e: - self.fail(f"evaluator.run.remote() raised an exception: {e}") - -if __name__ == '__main__': - unittest.main() diff --git a/tests/ray/test_judger.py b/tests/ray/test_judger.py deleted file mode 100644 index d8c0c36080..0000000000 --- a/tests/ray/test_judger.py +++ /dev/null @@ -1,231 +0,0 @@ -import os -import copy -import json -import ray -import unittest -import 
tempfile -import numpy as np -from uuid import uuid4 -from xtuner.v1.ray.judger.controller import JudgerController, JudgerConfig -from xtuner.v1.data_proto.rl_data import RLDataFlowItem, RLDatasetItem, RLEnvDataItem, RLRolloutResponseItem, RLUIDItem -from xtuner.v1.ray.base import AutoCPUWorkers, CPUResourcesConfig -MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"] -DATA_PATH = os.environ["ROLLOUT_DATA_PATH"] -GEO_ROLLOUT_DATA_PATH = os.environ["GEO_ROLLOUT_DATA_PATH"] -VERL_ROLLOUT_DATA_PATH = os.environ["VERL_ROLLOUT_DATA_PATH"] -DAPO_DATA_PATH = os.environ.get("ROLLOUT_DAPO_DATA_PATH") - -FAKE_JUDGER_INPUT_ITEM = RLDataFlowItem( - uid=RLUIDItem(action_id=uuid4().int, - observation_id=uuid4().int), - data=RLDatasetItem( - messages=[{ - 'role': 'user', - 'content': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? Let\'s think step by step and output the final answer after "####"' - }], - num_tokens=62, - reward_model={'ground_truth': '72', 'style': 'rule'}, - ability='math', - data_source={'openai/gsm8k': 1.0} - ), - env=RLEnvDataItem( - rollout=RLRolloutResponseItem( - response="\nOkay, let's see. Natalia sold clips to 48 friends in April. Then in May, she sold half as many. So first, I need to figure out how many she sold in May. Half of 48 is 24, right? Because 48 divided by 2 is 24. So in May, she sold 24 clips.\n\nNow, to find the total number of clips sold in both months, I need to add the number from April and May together. That would be 48 (April) plus 24 (May). Let me do the addition: 48 + 24. Hmm, 40 + 20 is 60, and 8 + 4 is 12. So 60 + 12 is 72. So altogether, she sold 72 clips.\n\nWait, let me check that again. 48 plus 24. Yes, 48 + 20 is 68, then plus 4 more is 72. Yep, that seems right. So the total is 72.\n\n\nNatalia sold 48 clips in April. In May, she sold half as many, which is 48 ÷ 2 = 24 clips. Adding both months together: 48 + 24 = 72. 
\n\n#### 72<|im_end|>", - ) - ) -) -FAKE_JUDGER_INPUT_ITEM_1 = copy.deepcopy(FAKE_JUDGER_INPUT_ITEM) -FAKE_JUDGER_INPUT_ITEM_1.uid.observation_id = uuid4().int -FAKE_JUDGER_INPUT_ITEM_MULTI_DATA = [FAKE_JUDGER_INPUT_ITEM, FAKE_JUDGER_INPUT_ITEM_1] # 用action_id来标识是不同的输入数据 -FAKE_JUDGER_INPUT_ITEM_MULTI_SOURCE = copy.deepcopy(FAKE_JUDGER_INPUT_ITEM) -FAKE_JUDGER_INPUT_ITEM_MULTI_SOURCE.data.data_source = {'openai/gsm8k-1': 0.5, 'openai/gsm8k-2': 0.5} - - -def construct_judger_data(data_path): - dataitem = [] - with open(data_path, 'r', encoding='utf-8') as f: - for line_num, line in enumerate(f, 1): - data = json.loads(line.strip()) - data_item = RLDataFlowItem( - uid=RLUIDItem( - action_id=uuid4().int, - observation_id=uuid4().int - ), - data=RLDatasetItem( - messages=[{ - 'role': 'user', - 'content': data["input"][5:-11] - }], - reward_model={"ground_truth": data["gts"]}, - data_source={"openai/gsm8k": 1.0} - ), - env=RLEnvDataItem( - rollout=RLRolloutResponseItem(response=data['output']) - ) - ) - dataitem.append(data_item) - return dataitem - - -def construct_new_judger_data(data_path, judger_name='dapo_math'): - data_item_list = [] - save_reward = [] - with open(data_path, 'r', encoding='utf-8') as f: - lines = f.readlines() - for i in range(0, len(lines), 7): - group = ''.join(lines[i:i + 7]).strip() - if group: - try: - item = json.loads(group) - data_item = RLDataFlowItem( - uid=RLUIDItem( - action_id=uuid4().int, - observation_id=uuid4().int - ), - data=RLDatasetItem( - messages=[{ - 'role': 'user', - 'content': "" - }], - reward_model={"ground_truth": item["label"]}, - data_source={judger_name: 1.0} - ), - env=RLEnvDataItem( - rollout=RLRolloutResponseItem(response=item['response']) - ) - ) - data_item_list.append(data_item) - save_reward.append(item["reward"]) - except Exception as e: - print(f"Error parsing group starting at line {i + 12}: {e}") - return data_item_list, save_reward - - -class TestJudgerController(unittest.TestCase): - - def setUp(self): - 
ray.init(num_cpus=80, ignore_reinit_error=True) - self.temp_dir = tempfile.TemporaryDirectory() - self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs") - - def tearDown(self): - ray.shutdown() - self.temp_dir.cleanup() - - def test_gsm8k_judger(self): - from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig - gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k") - judger_cfg = JudgerConfig( - reward_judger_configs=[gsm8k_judger_config], - worker_log_dir=self.worker_log_dir - ) - judger_controller = JudgerController.remote(judger_cfg) - # 返回的形式为:RLJudgerResponseItem(uid=112750990920317762694895938380669501546, reward={'openai/gsm8k': 1}, extra_info={}) - res1 = ray.get(judger_controller.run.remote(FAKE_JUDGER_INPUT_ITEM)) - self.assertEqual(res1.reward["score"], 1.0) - res2 = ray.get(judger_controller.run.remote(FAKE_JUDGER_INPUT_ITEM_MULTI_DATA)) - self.assertEqual(res2[0].reward["score"], 1.0) - self.assertEqual(res2[1].reward["score"], 1.0) - - def test_dapo_judger(self): - from xtuner.v1.ray.judger.dapo_math import DapoMathJudgerConfig - from xtuner.v1.utils.rl_test_utils import get_eos_token - from transformers import AutoTokenizer - tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) - eos_token = get_eos_token(MODEL_PATH) - eos_token_str = tokenizer.convert_ids_to_tokens(eos_token) - - dapo_judger_config = DapoMathJudgerConfig( - judger_name="dapo_math", - eos_token=eos_token_str, - enable_overlong_buffer=True, - max_response_len=32768, - overlong_buffer_len=4096, - overlong_penalty_factor=1.0, - tokenizer=tokenizer - - ) - judger_cfg = JudgerConfig( - reward_judger_configs=[dapo_judger_config], - worker_log_dir=self.worker_log_dir - ) - judger_controller = JudgerController.remote(judger_cfg) - judger_data, save_reward = construct_new_judger_data(DAPO_DATA_PATH) - group_data = ray.get(judger_controller.run.remote(judger_data)) - reward = [data.reward["score"] for data in group_data] - 
self.assertEqual(np.mean(reward), np.mean(save_reward)) - - def test_geo_judger(self): - from xtuner.v1.ray.judger.geo3k import GEO3KJudgerConfig - geo_judger_config = GEO3KJudgerConfig() - judger_cfg = JudgerConfig( - reward_judger_configs=[geo_judger_config], - worker_log_dir=self.worker_log_dir - ) - judger_controller = JudgerController.remote(judger_cfg) - judger_data, save_reward = construct_new_judger_data(GEO_ROLLOUT_DATA_PATH, judger_name="hiyouga/geometry3k") - group_data = ray.get(judger_controller.run.remote(judger_data)) - reward = [data.reward["score"] for data in group_data] - self.assertEqual(np.mean(reward), np.mean(save_reward)) - - def test_gsm8k_multi_judger(self): - from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig - # 支持一个GSM8KJudgerConfig创建多个实例 - gsm8k_judger_config_1 = GSM8KJudgerConfig(judger_name="openai/gsm8k-1") - gsm8k_judger_config_2 = GSM8KJudgerConfig(judger_name="openai/gsm8k-2") - judger_cfg = JudgerConfig( - reward_judger_configs=[ - gsm8k_judger_config_1, - gsm8k_judger_config_2 - ], - enable_weighted_judgers=True, - worker_log_dir=self.worker_log_dir, - ) - cpu_resources_config = CPUResourcesConfig.from_total( - total_cpus=2, - total_memory=2 * 1024**3, - num_workers=2 - ) - pg = AutoCPUWorkers.build_placement_group(cpu_resources_config) - judger_controller = JudgerController.remote(judger_cfg, pg) - res3 = ray.get(judger_controller.run.remote(FAKE_JUDGER_INPUT_ITEM_MULTI_SOURCE)) - self.assertEqual(res3.reward["weighted_score"], 1.0) # weighted_score为固定字段,表示加权后的reward - - def test_gsm8k_judger_score(self): - """Test the judger functionality with single and multiple data sources.""" - from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig - gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k") - judger_cfg = JudgerConfig( - reward_judger_configs=[gsm8k_judger_config], - worker_log_dir=self.worker_log_dir - ) - judger_controller = JudgerController.remote(judger_cfg) - judger_data = 
construct_judger_data(VERL_ROLLOUT_DATA_PATH) - group_data = ray.get(judger_controller.run.remote(judger_data)) - reward = [data.reward["score"] for data in group_data] - verl_score = 0.2418 - self.assertEqual(round(np.mean(reward), 4), verl_score) - - def test_gsm8k_remote_judger(self): - from xtuner.v1.utils.rl_test_utils import JudgerServer, GSM8KRemoteJudgerConfig - - server = JudgerServer(port=8018) - server.start() - try: - remote_judger_config = GSM8KRemoteJudgerConfig(judger_name="openai/gsm8k", remote_url=server.url) - judger_cfg = JudgerConfig( - reward_judger_configs=[remote_judger_config], - worker_log_dir=self.worker_log_dir - ) - judger_controller = JudgerController.remote(judger_cfg) - judger_data = construct_judger_data(VERL_ROLLOUT_DATA_PATH) - group_data = ray.get(judger_controller.run.remote(judger_data)) - reward = [data.reward["score"] for data in group_data] - verl_score = 0.2418 - self.assertEqual(round(np.mean(reward), 4), verl_score) - finally: - server.stop() - - -if __name__ == "__main__": - unittest.main() \ No newline at end of file diff --git a/tests/ray/test_mock_rollout.py b/tests/ray/test_mock_rollout.py deleted file mode 100644 index c57ece9076..0000000000 --- a/tests/ray/test_mock_rollout.py +++ /dev/null @@ -1,138 +0,0 @@ -import os -import asyncio -import unittest -import ray -from transformers import AutoTokenizer -import torch -import tempfile -from xtuner.v1.ray.config.worker import RolloutConfig -from xtuner.v1.ray.base import AcceleratorResourcesConfig, AutoAcceleratorWorkers -from xtuner.v1.ray.dataflow import DataFlow, DataFlowConfig, ReplayBufferConfig -from xtuner.v1.ray.environment import SingleTurnEnvironment -from xtuner.v1.datasets import RLTokenizeFnConfig -from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig -from xtuner.v1.ray.rollout.controller import RolloutController -from xtuner.v1.utils.rl_test_utils import MockTimeoutRolloutWorker, MockRequestErrorRolloutWorker, 
MockClientErrorRolloutWorker, MockServerErrorRolloutWorker - -MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"] -TRAIN_DATA_PATH = os.environ["ROLLOUT_DATA_PATH"] -resource_map = {"npu": "NPU", "cuda": "GPU"} -@ray.remote -class MockTimeoutRolloutController(RolloutController): - def _get_worker_cls(self): - return ray.remote(MockTimeoutRolloutWorker) - def deactivate_worker_by_url(self, url): - pass -@ray.remote -class MockRequestErrorRolloutController(RolloutController): - def _get_worker_cls(self): - return ray.remote(MockRequestErrorRolloutWorker) - def deactivate_worker_by_url(self, url): - pass -@ray.remote -class MockClientErrorRolloutController(RolloutController): - def _get_worker_cls(self): - return ray.remote(MockClientErrorRolloutWorker) - def deactivate_worker_by_url(self, url): - pass -@ray.remote -class MockServerErrorRolloutController(RolloutController): - def _get_worker_cls(self): - return ray.remote(MockServerErrorRolloutWorker) - - def deactivate_worker_by_url(self, url): - pass - -class TestMockRollout(unittest.TestCase): - @classmethod - def setUpClass(cls): - os.environ["XTUNER_USE_FA3"] = "1" - - @classmethod - def tearDownClass(cls): - del os.environ["XTUNER_USE_FA3"] - - def setUp(self): - ray.init(num_cpus=80, ignore_reinit_error=True) - self.global_batch_size = 3 - self.max_prompt_length = 4096 - self.max_response_length = 128 - self.max_concurrent = 3 - self.max_retry_times = 3 - self.temp_dir = tempfile.TemporaryDirectory() - self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs") - self.rollout_cfg = RolloutConfig( - env="test_mock_rollout", - model_path=MODEL_PATH, - model_name=os.path.basename(MODEL_PATH).lower(), - tokenizer_path=MODEL_PATH, - tensor_parallel_size=1, - context_length=self.max_prompt_length + self.max_response_length, - max_retry_per_worker=2, - worker_log_dir=self.worker_log_dir - ) - tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) - - self.dataflow_cfg = DataFlowConfig( - 
max_concurrent=self.max_concurrent, - global_batch_size=self.global_batch_size, - max_retry_times=self.max_retry_times, - worker_log_dir=self.worker_log_dir - ) - train_dataset_cfg = [{ - "dataset": DatasetConfig(name="mock_data", anno_path=TRAIN_DATA_PATH), - "tokenize_fn": RLTokenizeFnConfig(max_length=self.max_prompt_length), - }] - dataloader_cfg = DataloaderConfig( - collator='fake_collator', - pack_level='none', - group_by_length=False, - ) - self.replay_buffer_cfg = ReplayBufferConfig( - dataset_cfg=train_dataset_cfg, - dataloader_cfg=dataloader_cfg, - tokenizer=tokenizer, - worker_log_dir=self.worker_log_dir - ) - - def tearDown(self): - ray.shutdown() - self.temp_dir.cleanup() - - async def _run_mock_test(self, mock_controller_cls, error_name, pg): - rollout_controller = mock_controller_cls.remote(self.rollout_cfg, pg) - self.test_env = SingleTurnEnvironment.remote("env", pg, self.rollout_cfg, rollout_controller=rollout_controller) - self.test_dataflow = DataFlow.remote("dataflow", self.dataflow_cfg, self.replay_buffer_cfg, self.test_env) - - result = await self.test_dataflow.run.remote(num=3) - completed_rollouts = result["data_groups"] - status = await self.test_dataflow.get_replaybuffer_status.remote() - self.assertEqual(len(completed_rollouts), 0, f"[{error_name}] Expected no rollouts to complete successfully.") - self.assertEqual(status["remain_completed_samples_count"], 0, f"[{error_name}] Completed count in buffer should be 0.") - self.assertEqual(status["remain_aborted_samples_count"], 0, f"[{error_name}] Expected no rollouts to be interrupted.") - await self.test_env.shutdown.remote() - - @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") - def test_parallel_mock_rollout(self): - async def run_parallel(): - res_cfg_small = AcceleratorResourcesConfig( - accelerator=resource_map[torch.accelerator.current_accelerator().type], - num_workers=2, - num_cpus_per_worker=2, - ) - - pgs = 
[AutoAcceleratorWorkers.build_placement_group(res_cfg_small, name=f"pg_{i}") for i in range(4)] - await asyncio.gather(*[pg.ready() for pg in pgs]) - - tasks = [ - self._run_mock_test(MockTimeoutRolloutController, "timeout", pgs[0]), - self._run_mock_test(MockRequestErrorRolloutController, "request_error", pgs[1]), - self._run_mock_test(MockClientErrorRolloutController, "client_error", pgs[2]), - self._run_mock_test(MockServerErrorRolloutController, "server_error", pgs[3]), - ] - await asyncio.gather(*tasks) - - asyncio.run(run_parallel()) - -if __name__ == "__main__": - unittest.main() \ No newline at end of file diff --git a/tests/ray/test_rl_train_with_sft.py b/tests/ray/test_rl_train_with_sft.py deleted file mode 100644 index be7dc93816..0000000000 --- a/tests/ray/test_rl_train_with_sft.py +++ /dev/null @@ -1,180 +0,0 @@ -import os -import unittest -from transformers import AutoTokenizer -import shutil -import tempfile -import json -import torch -from xtuner.v1.data_proto.sequence_context import SequenceContext -import ray -from xtuner.v1.ray.base import AcceleratorResourcesConfig, AutoAcceleratorWorkers -from xtuner.v1.config import ( - AdamWConfig, - FSDPConfig, - LRConfig, -) -from xtuner.v1.rl.base import WorkerConfig, TrainingController, TrainingWorker as BaseTrainingWorker -from xtuner.v1.rl.grpo.loss import GRPOLossConfig as LossConfig -from xtuner.v1.model.dense.qwen3 import Qwen3Dense8BConfig -from xtuner.v1.datasets.sft_tokenize_fn import OpenaiTokenizeFunctionConfig -from xtuner.v1.loss import CELossConfig -from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig -from xtuner.v1.train.trainer import LoadCheckpointConfig - -QWEN3_PATH = os.environ["QWEN3_PATH"] -ALPACA_PATH = os.environ["ALPACA_PATH"] - - -class TestRLTrainWithSFT(unittest.TestCase): - def setUp(self): - ray.init(num_cpus=80, ignore_reinit_error=True) - - resources = AcceleratorResourcesConfig( - accelerator="GPU", - num_accelerators_per_worker=1, - num_cpus_per_worker=8, 
- num_workers=8, - cpu_memory_per_worker=16 * 1024 ** 3, # 16 GB - ) - - pg = AutoAcceleratorWorkers.build_placement_group(resources) - self.pg = pg - - self.temp_dir = tempfile.mkdtemp() - tokenizer = AutoTokenizer.from_pretrained(QWEN3_PATH, trust_remote_code=True) - self.tokenizer = tokenizer - self.prompt_repeat_k = 8 - file = './tests/ray/rollout_output.jsonl' - with open(file, 'r') as f: - data = [json.loads(line) for line in f] - data_groups = [data[i:i + self.prompt_repeat_k] for i in range(0, len(data), self.prompt_repeat_k)] - data_groups = data_groups[:8] - data_batches = [] - for group in data_groups: - prompt_ids = tokenizer(group[0]['prompt'], return_tensors='pt')['input_ids'].flatten().tolist() - rewards = [item['reward'] for item in group] - rewards = torch.tensor(rewards, dtype=torch.float32) - advantages = (rewards - rewards.mean(0)) / (rewards.std(0) + 1e-8) - - for i in range(self.prompt_repeat_k): - item = group[i] - response_ids = tokenizer(item['response'], return_tensors='pt')['input_ids'].flatten().tolist() - input_ids = prompt_ids + response_ids - shifted_labels = [-100] * (len(prompt_ids) - 1) + response_ids + [-100] - input_ids = torch.tensor(input_ids, dtype=torch.int64).unsqueeze(0) - shifted_labels = torch.tensor(shifted_labels, dtype=torch.int64).unsqueeze(0) - data_batches.append( - dict( - seq_ctx=SequenceContext.from_input_ids((input_ids,), device="cpu"), - shifted_labels=shifted_labels, - advantage=advantages[i].item(), - ) - ) - self.data_batches = data_batches - - def tearDown(self): - shutil.rmtree(self.temp_dir) - ray.shutdown() - - def build_train_controller(self): - model_cfg = Qwen3Dense8BConfig() - optim_cfg: AdamWConfig = AdamWConfig(lr=5e-7, foreach=False) - fsdp_cfg: FSDPConfig = FSDPConfig( - torch_compile=True, - cpu_offload=False, - ep_size=1, - ) - lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=5e-7) - - dataset_config = [] - _data_cfg = {"dataset": DatasetConfig(name='apach', - 
anno_path=ALPACA_PATH), - "tokenize_fn": OpenaiTokenizeFunctionConfig( - chat_template='qwen3', - max_length=32768 - ) - } - dataset_config.append(_data_cfg) - - sft_dataloader_cfg = DataloaderConfig( - dataset_config_list=dataset_config, - pack_max_length=32768, - pack_to_max_length=True, - num_workers=0, - ) - sft_global_batch_size = 8 - loss_reduction = "square" - sft_loss_cfg = CELossConfig(mode="chunk", chunk_size=1024, loss_reduction=loss_reduction) - - worker_cfg: WorkerConfig = WorkerConfig( - sft_dataloader_cfg=sft_dataloader_cfg, - sft_global_batch_size=sft_global_batch_size, - sft_loss_cfg=sft_loss_cfg, - seed=42, - model_cfg=model_cfg, - optim_cfg=optim_cfg, - loss_cfg=LossConfig( - policy_loss_cfg=dict( - cliprange_high=0.28, - cliprange_low=0.2, - loss_type="vanilla", - ), - ignore_idx=-100, - use_kl_loss=True, - kl_loss_coef=0.001, - kl_loss_type="low_var_kl", - mode="eager"), - lr_cfg=lr_cfg, - fsdp_cfg=fsdp_cfg, - load_from=QWEN3_PATH, - sp_size=1, - pack_max_length=8192, - ) - - TrainingWorker = ray.remote( - runtime_env={ - "env_vars": { - "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1", - "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES": "1", - } - }, - )(BaseTrainingWorker) - train_workers, _ = AutoAcceleratorWorkers.from_placement_group( - TrainingWorker, worker_cfg, self.pg - ) - futures = [worker.test_all_reduce.remote() for worker in train_workers] - print(ray.get(futures)) - train_controller = TrainingController.remote( - workers=train_workers, - ) - ray.get(train_controller.__ray_ready__.remote()) - return train_controller - - def test_rl_train_with_sft(self): - train_controller = self.build_train_controller() - - ray.get(train_controller.fit.remote(self.data_batches, pack_max_length=1024, rollout_idx=0)) - ray.get(train_controller.save.remote(os.path.join(self.temp_dir, "save_test"), no_save_optimizer=True)) - - log_infos = ray.get(train_controller.fit.remote(self.data_batches, pack_max_length=1024, rollout_idx=1)) - 
efficient_attn_ratio_list = [] - for log_info in log_infos: - efficient_attn_ratio_list.append(log_info['sft_train_metrics']['efficient_attn_ratio']) - assert all([efficient_attn_ratio > 0 for efficient_attn_ratio in efficient_attn_ratio_list]) - - ray.kill(train_controller) - train_controller = self.build_train_controller() - load_checkpoint_cfg = LoadCheckpointConfig(checkpoint_path=os.path.join(self.temp_dir, "save_test"), - load_optimizer_states=False, - load_optimizer_args=False - ) - ray.get(train_controller.resume.remote(load_checkpoint_cfg)) - - log_infos = ray.get(train_controller.fit.remote(self.data_batches, pack_max_length=1024, rollout_idx=1)) - new_efficient_attn_ratio_list = [] - for log_info in log_infos: - new_efficient_attn_ratio_list.append(log_info['sft_train_metrics']['efficient_attn_ratio']) - - efficient_attn_ratio_list.sort() - new_efficient_attn_ratio_list.sort() - self.assertEqual(efficient_attn_ratio_list, new_efficient_attn_ratio_list) diff --git a/tests/ray/test_rl_trainer.py b/tests/ray/test_rl_trainer.py deleted file mode 100644 index 113c94fd8a..0000000000 --- a/tests/ray/test_rl_trainer.py +++ /dev/null @@ -1,252 +0,0 @@ -import os -import tempfile -import unittest -from pathlib import Path - -import ray -import torch - -from transformers import AutoTokenizer -from xtuner.v1.config import ( - AdamWConfig, - FSDPConfig, - LRConfig, -) -from xtuner.v1.data_proto.rl_data import SampleParams -from xtuner.v1.datasets import RLTokenizeFnConfig -from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig -from xtuner.v1.model import get_model_config_from_hf -from xtuner.v1.ray.base import AcceleratorResourcesConfig, CPUResourcesConfig -from xtuner.v1.ray.config.worker import RolloutConfig -from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig -from xtuner.v1.ray.judger.controller import JudgerConfig -from xtuner.v1.rl.base import WorkerConfig -from xtuner.v1.rl.grpo import GRPOLossConfig -from 
xtuner.v1.train.rl_trainer import RLTrainer, RLTrainerConfig - - -MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"] -TRAIN_DATA_PATH = os.environ["ROLLOUT_DATA_PATH"] -TEST_DATA_PATH = os.environ["ROLLOUT_TEST_DATA_PATH"] -resource_map = { - "npu": "NPU", - "cuda": "GPU", -} - - -class TestRLTrainer(unittest.TestCase): - @classmethod - def setUpClass(cls): - os.environ["XTUNER_USE_FA3"] = "1" - os.environ["LMD_SKIP_WARMUP"] = "1" - - @classmethod - def tearDownClass(cls): - del os.environ["XTUNER_USE_FA3"] - del os.environ["LMD_SKIP_WARMUP"] - - def init_traine_worker_config(self, train_optimizer_steps, pack_max_length): - model_cfg = get_model_config_from_hf(Path(MODEL_PATH)) - optim_cfg = AdamWConfig(lr=1e-6, betas=(0.9, 0.999), max_grad_norm=1.0, weight_decay=0.1, foreach=False) - loss_cfg = GRPOLossConfig( - policy_loss_cfg=dict( - cliprange_high=0.28, - cliprange_low=0.2, - loss_type="vanilla", - clip_ratio_c=10.0, - log_prob_diff_min=-20.0, - log_prob_diff_max=20.0, - ), - ignore_idx=-100, - use_kl_loss=False, - kl_loss_coef=0.0, - kl_loss_type="low_var_kl", - mode="chunk", - chunk_size=512, - ) - lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) - fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) - train_worker_cfg: WorkerConfig = WorkerConfig( - model_cfg=model_cfg, - load_from=MODEL_PATH, - optim_cfg=optim_cfg, - loss_cfg=loss_cfg, - lr_cfg=lr_cfg, - fsdp_cfg=fsdp_cfg, - sp_size=1, - optimizer_steps=train_optimizer_steps, - pack_max_length=pack_max_length, - ) - return train_worker_cfg - - def init_replay_buffer_config(self, max_prompt_length): - train_dataset_cfg = [ - { - "dataset": DatasetConfig(name="gsm8k", anno_path=TRAIN_DATA_PATH, sample_ratio=1.0), - "tokenize_fn": RLTokenizeFnConfig(max_length=max_prompt_length), - }, - ] - dataloader_cfg = DataloaderConfig( - collator="fake_collator", - pack_level="none", - group_by_length=False, - ) - tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, 
trust_remote_code=True) - replay_buffer_cfg = ReplayBufferConfig( - dataset_cfg=train_dataset_cfg, - dataloader_cfg=dataloader_cfg, - tokenizer=tokenizer, - worker_log_dir=self.worker_log_dir, - ) - return replay_buffer_cfg - - def init_resources_config(self, num_workers, num_cpus_per_worker, cpu_memory_per_worker): - resources = AcceleratorResourcesConfig( - accelerator=resource_map[torch.accelerator.current_accelerator().type], - num_workers=num_workers, - num_cpus_per_worker=num_cpus_per_worker, - cpu_memory_per_worker=cpu_memory_per_worker, - ) - return resources - - def init_cpu_resources_config(self, num_cpus_per_worker, cpu_memory_per_worker): - cpu_resources = CPUResourcesConfig( - num_cpus_per_worker=num_cpus_per_worker, - cpu_memory_per_worker=cpu_memory_per_worker, - ) - return cpu_resources - - def init_rollout_config(self, max_prompt_length, max_response_length): - rollout_config = RolloutConfig( - env="test_rl_trainer", - model_path=MODEL_PATH, - worker_log_dir=self.worker_log_dir, - rollout_max_batch_size_per_instance=1024, - context_length=max_response_length + max_prompt_length, - ) - return rollout_config - - def init_dataflow_config(self, max_response_length, global_batch_size, prompt_repeat_k, enable_partial_rollout): - sample_params = SampleParams( - max_tokens=max_response_length, - ) - dataflow_config = DataFlowConfig( - env="test_rl_trainer", - global_batch_size=global_batch_size, - prompt_repeat_k=prompt_repeat_k, - worker_log_dir=self.worker_log_dir, - sample_params=sample_params, - enable_partial_rollout=enable_partial_rollout, - max_concurrent=1024, - ) - return dataflow_config - - def init_judger_config(self): - from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig - - gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k") - judger_cfg = JudgerConfig(reward_judger_configs=[gsm8k_judger_config], worker_log_dir=self.worker_log_dir) - return judger_cfg - - def init_multi_judger_config(self): - from 
xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig - - # 支持一个GSM8KJudgerConfig创建多个实例 - gsm8k_judger_config_1 = GSM8KJudgerConfig(judger_name="openai/gsm8k-1") - gsm8k_judger_config_2 = GSM8KJudgerConfig(judger_name="openai/gsm8k-2") - judger_cfg = JudgerConfig( - reward_judger_configs=[gsm8k_judger_config_1, gsm8k_judger_config_2], - worker_log_dir=self.worker_log_dir, - ) - return judger_cfg - - def setUp(self): - ray.init(num_cpus=80, ignore_reinit_error=True) - self.temp_dir = tempfile.TemporaryDirectory() - self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs") - - train_optimizer_steps = 2 - pack_max_length = 32768 - max_prompt_length = 2048 - max_response_length = 1024 - global_batch_size = 4 - prompt_repeat_k = 4 - enable_partial_rollout = False - - self.train_worker_cfg = self.init_traine_worker_config(train_optimizer_steps, pack_max_length) - self.replay_buffer_cfg = self.init_replay_buffer_config(max_prompt_length) - self.resources_cfg = self.init_resources_config( - num_workers=8, num_cpus_per_worker=8, cpu_memory_per_worker=8 * 1024**3 - ) - self.cpu_resources = self.init_cpu_resources_config(num_cpus_per_worker=1, cpu_memory_per_worker=1 * 1024**3) - self.rollout_config = self.init_rollout_config( - max_response_length=max_response_length, max_prompt_length=max_prompt_length - ) - self.dataflow_config = self.init_dataflow_config( - max_response_length=max_response_length, - global_batch_size=global_batch_size, - prompt_repeat_k=prompt_repeat_k, - enable_partial_rollout=enable_partial_rollout, - ) - self.judger_config = self.init_judger_config() - - def tearDown(self): - self.temp_dir.cleanup() - ray.shutdown() - - def test_rl_trainer(self): - multi_judger_config = self.init_multi_judger_config() - cpu_resources = self.init_cpu_resources_config(num_cpus_per_worker=2, cpu_memory_per_worker=2 * 1024**3) - trainer_config = RLTrainerConfig( - load_from=MODEL_PATH, - resources=self.resources_cfg, - cpu_resources=cpu_resources, - 
rollout_config=self.rollout_config, - dataflow_config=self.dataflow_config, - judger_config=multi_judger_config, - replay_buffer_config=self.replay_buffer_cfg, - train_worker_config=self.train_worker_cfg, - work_dir=self.worker_log_dir, - tokenizer_path=MODEL_PATH, - total_epochs=1, - rollout_steps=1, - ) - trainer = RLTrainer.from_config(trainer_config) - self.assertIsNotNone(trainer, "Trainer should be created successfully") - try: - trainer.fit() - except Exception as e: - self.fail(f"trainer.fit() raised unexpected exception: {e}") - # assure all writers are closed before checking log files - del trainer - log_files = list(Path(self.worker_log_dir).rglob("*.log")) - self.assertGreater(len(log_files), 0, "Should generate log files") - trajectory_files = list(Path(self.worker_log_dir).rglob("*_trajectory.jsonl")) - self.assertGreater(len(trajectory_files), 0, "Should generate trajectory files") - - def test_judger_cpu_pg_creation_with_error(self): - """Test RLTrainer judger_cpu_pg creation.""" - multi_judger_config = self.init_multi_judger_config() - # error resource with multi-judger - cpu_resources = self.init_cpu_resources_config(num_cpus_per_worker=1, cpu_memory_per_worker=1 * 1024**3) - trainer_config = RLTrainerConfig( - load_from=MODEL_PATH, - resources=self.resources_cfg, - cpu_resources=cpu_resources, - rollout_config=self.rollout_config, - dataflow_config=self.dataflow_config, - judger_config=multi_judger_config, - replay_buffer_config=self.replay_buffer_cfg, - train_worker_config=self.train_worker_cfg, - work_dir=self.worker_log_dir, - tokenizer_path=MODEL_PATH, - total_epochs=1, - rollout_steps=1, - ) - with self.assertRaises(AssertionError) as cm: - trainer = RLTrainer.from_config(trainer_config) - - print(f"Expected AssertionError caught: {cm.exception}") - -if __name__ == "__main__": - unittest.main() diff --git a/tests/ray/test_rollout.py b/tests/ray/test_rollout.py deleted file mode 100644 index 31f8542d3c..0000000000 --- 
a/tests/ray/test_rollout.py +++ /dev/null @@ -1,401 +0,0 @@ -import os -import subprocess -from functools import wraps -import unittest -import tempfile -import ray -import torch -from pathlib import Path -from transformers import AutoTokenizer -import tempfile -from xtuner.v1.ray.config.worker import RolloutConfig -from xtuner.v1.ray.judger.controller import JudgerConfig -from xtuner.v1.ray.base import AcceleratorResourcesConfig, AutoAcceleratorWorkers -from xtuner.v1.ray.dataflow import DataFlow, DataFlowConfig, ReplayBufferConfig -from xtuner.v1.data_proto.rl_data import SampleParams -from xtuner.v1.ray.environment import SingleTurnEnvironment -from xtuner.v1.ray.rollout import RolloutController -from xtuner.v1.ray.judger import JudgerController -from xtuner.v1.datasets import RLTokenizeFnConfig, build_datasets, build_dataloader -from xtuner.v1.datasets.config import ( - DataloaderConfig, - DatasetConfig, -) -import asyncio - -TEST_TEXT_MESSAGES=[{"role": "user", "content": "Hello!"}] -MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"] -MOE_MODEL_PATH = os.environ["QWEN3_MOE_PATH"] -TRAIN_DATA_PATH = os.environ["ROLLOUT_DATA_PATH"] -TEST_DATA_PATH = os.environ["ROLLOUT_TEST_DATA_PATH"] -resource_map = { - "npu": "NPU", - "cuda": "GPU", -} -class TestRollout(unittest.TestCase): - - @classmethod - def setUpClass(cls) -> None: - os.environ["XTUNER_USE_FA3"] = "1" - os.environ["LMD_SKIP_WARMUP"] = "1" - - @classmethod - def tearDownClass(cls) -> None: - del os.environ["XTUNER_USE_FA3"] - del os.environ["LMD_SKIP_WARMUP"] - - def init_config(self): - self.resources_cfg = AcceleratorResourcesConfig( - accelerator=resource_map[torch.accelerator.current_accelerator().type], - num_workers=8, - num_cpus_per_worker=8, - cpu_memory_per_worker=16 * 1024**3, # 16 GB - ) - self.max_prompt_length = 512 - self.max_response_length = 1024 - self.context_length = self.max_prompt_length + self.max_response_length - from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig - 
gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k") - self.judger_cfg = JudgerConfig( - reward_judger_configs=[gsm8k_judger_config], - worker_log_dir=self.worker_log_dir, - ) - self.dataflow_cfg = DataFlowConfig( - env="test", - prompt_repeat_k=1, - global_batch_size=1, - enable_partial_rollout=0, - max_retry_times=1, - worker_log_dir=self.worker_log_dir, - ) - self.train_dataset_cfg = [ - { - "dataset": DatasetConfig(name="gsm8k", - anno_path=TRAIN_DATA_PATH, - sample_ratio=1.0), - "tokenize_fn": RLTokenizeFnConfig(max_length=self.max_prompt_length), - }, - ] - self.dataloader_cfg = DataloaderConfig( - collator='fake_collator', - pack_level='none', - group_by_length=False, - ) - self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True) - self.replay_buffer_cfg = ReplayBufferConfig( - dataset_cfg=self.train_dataset_cfg, - dataloader_cfg=self.dataloader_cfg, - tokenizer=self.tokenizer, - worker_log_dir=self.worker_log_dir, - ) - - def setUp(self): - ray.init(num_cpus=80, ignore_reinit_error=True) - self.data_path = TRAIN_DATA_PATH - self.model_path = MODEL_PATH - self.temp_dir = tempfile.TemporaryDirectory() - self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs") - self.init_config() - - def tearDown(self): - ray.shutdown() - # When lmdeploy enable ep>1, it uses deep_ep. Buffer implicit destroy would cause some ray actor stucked. - # Use pkill cleen up ray::WorkerWrapper process after close ray cluster connection as workaround. - # TODO(chenchiyu): add excplicit deep_ep destroy in lmdeploy. 
- self._cleanup_lmdeploy_ray_worker_wrapper() - self.temp_dir.cleanup() - - @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") - def test_parallel_rollout(self): - resource_config = AcceleratorResourcesConfig( - accelerator=resource_map[torch.accelerator.current_accelerator().type], - num_workers=4, - num_cpus_per_worker=4, - cpu_memory_per_worker=8 * 1024**3, # 8 GB - ) - pg1 = AutoAcceleratorWorkers.build_placement_group(resource_config, name="tp_pg") - pg2 = AutoAcceleratorWorkers.build_placement_group(resource_config, name="ep_pg") - dense_model_path = MODEL_PATH - moe_model_path = MOE_MODEL_PATH - dist_port_base = 38000 - async def run_both(): - return await asyncio.gather( - self._run_rollout(model_path=dense_model_path, tp_size=4, ep_size=1, pg=pg1, dist_port_base=dist_port_base), - self._run_rollout(model_path=moe_model_path, tp_size=1, ep_size=4, pg=pg2, dist_port_base=dist_port_base + 1024 * 4), - return_exceptions=False - ) - - asyncio.run(run_both()) - - @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") - def test_parallel_model_save_and_resume(self): - resource_config = AcceleratorResourcesConfig( - accelerator=resource_map[torch.accelerator.current_accelerator().type], - num_workers=4, - num_cpus_per_worker=4, - cpu_memory_per_worker=8 * 1024**3, # 8 GB - ) - pg1 = AutoAcceleratorWorkers.build_placement_group(resource_config, name="dense_pg") - pg2 = AutoAcceleratorWorkers.build_placement_group(resource_config, name="moe_pg") - - async def run_both(): - return await asyncio.wait_for( - asyncio.gather( - self._run_dense_save_resume_sync_async(pg1), - self._run_moe_save_resume_with_r3(pg2), - return_exceptions=False - ), - timeout=300 - ) - try: - asyncio.run(run_both()) - except asyncio.TimeoutError: - self.fail("test_parallel_model_save_and_resume timed out after 300s") - - def _cleanup_lmdeploy_ray_worker_wrapper(self): - try: - result = 
subprocess.run(["pkill", "-f", "ray::RayWorkerWrapper*"], capture_output=True, text=True, timeout=10) - if result.returncode != 0: - print(f"pkill command failed with return code {result.returncode}: {result.stderr}." - " Maybe no lmdeploy ray::RayWorkerWrapper processes found.") - except Exception as e: - print(f"Error stopping ray::RayWorkerWrapper cluster: {e}") - - async def _run_rollout(self, model_path, tp_size, ep_size, pg, dist_port_base): - rollout_config = RolloutConfig( - env="test_rollout", - model_path=model_path, - model_name=os.path.basename(model_path).lower(), - tokenizer_path=model_path, - tensor_parallel_size=tp_size, - expert_parallel_size=ep_size, - context_length=self.context_length, - worker_log_dir=self.worker_log_dir, - dist_port_base=dist_port_base, - - ) - rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg) - try: - result = await asyncio.wait_for(rollout_controller.rollout.remote(prompt=TEST_TEXT_MESSAGES), timeout=300) - self.assertEqual(result.finish_reason, "stop") - except asyncio.TimeoutError: - self.fail("TP Rollout timed out!") - finally: - await asyncio.wait_for(rollout_controller.shutdown.remote(), timeout=300) - - async def _run_dataflow_save_resume_test(self, test_env, dataflow_cfg: DataFlowConfig, replay_buffer_cfg: ReplayBufferConfig): - """ - Generic driver for dataflow save/resume tests. - """ - # 1. Initialize Environment and DataFlow - is_partial_rollout = dataflow_cfg.enable_partial_rollout == 1 - test_flow = DataFlow.remote("test_env", dataflow_cfg, replay_buffer_cfg, test_env) - - # 2. Initial Run - await test_flow.run.remote() - - # Capture status before saving (critical for partial rollout consistency check) - rl_status_before_save = await test_flow.get_replaybuffer_status.remote() - - # 3. 
Save - save_dir = Path(self.temp_dir.name) / 'checkpoints' / f'ckpt-step-2' - save_dir.mkdir(parents=True, exist_ok=True) - await test_flow.save.remote(save_dir) - - # Define run logic based on mode - async def run_continuation(status_ref): - if is_partial_rollout: - remain = status_ref["remain_aborted_samples_count"] + status_ref["remain_completed_samples_count"] - # Finish the remaining paused samples - result = await test_flow.run.remote(num=remain, enable_partial_rollout=0) - return result["data_groups"] - else: - # Normal run - result = await test_flow.run.remote() - return result["data_groups"] - - # continue running after save - responses_old = await run_continuation(rl_status_before_save) - rb_status_old = await test_flow.get_replaybuffer_status.remote() - - - # resume from saved checkpoint - await test_flow.resume.remote(save_dir) - rl_status_resume = await test_flow.get_replaybuffer_status.remote() - responses_new = await run_continuation(rl_status_resume) - rb_status_new = await test_flow.get_replaybuffer_status.remote() - - # Compare Data - ids_old = self._get_sorted_input_ids(responses_old) - ids_new = self._get_sorted_input_ids(responses_new) - self.assertEqual(ids_old, ids_new) - - # Compare ReplayBuffer Status (Old run vs New run) - for key in rb_status_old: - self.assertEqual(rb_status_old[key], rb_status_new[key]) - - # For partial rollout, verify the resumed state matches the saved state - if is_partial_rollout: - for key in rl_status_before_save: - self.assertEqual(rl_status_before_save[key], rl_status_resume[key]) - - async def _run_dense_save_resume_sync_async(self, pg): - model_path = MODEL_PATH - worker_log_dir = os.path.join(self.worker_log_dir, "test_dense") - rollout_config = RolloutConfig( - env="test_rollout", - model_path=model_path, - model_name=os.path.basename(model_path).lower(), - tokenizer_path=model_path, - context_length=self.context_length, - worker_log_dir=worker_log_dir, - dist_port_base=37000, - ) - test_env = 
SingleTurnEnvironment.remote( - "test_env", - pg, - rollout_cfg=rollout_config, - ) - sync_dataflow_cfg = DataFlowConfig( - env="test", - prompt_repeat_k=2, - global_batch_size=2, - enable_partial_rollout=0, - max_concurrent=2, - max_retry_times=1, - worker_log_dir=worker_log_dir, - ) - async_dataflow_cfg = DataFlowConfig( - env="test", - prompt_repeat_k=2, - global_batch_size=2, - enable_partial_rollout=1, - staleness_threshold=1, - max_retry_times=1, - worker_log_dir=self.worker_log_dir, - ) - replay_buffer_cfg = ReplayBufferConfig( - dataset_cfg=self.train_dataset_cfg, - dataloader_cfg=self.dataloader_cfg, - tokenizer=self.tokenizer, - worker_log_dir=worker_log_dir, - ) - self._run_dataflow_save_resume_test(test_env, sync_dataflow_cfg, replay_buffer_cfg) - self._run_dataflow_save_resume_test(test_env, async_dataflow_cfg, replay_buffer_cfg) - - async def _run_moe_save_resume_with_r3(self, pg): - model_path = MOE_MODEL_PATH - worker_log_dir = os.path.join(self.worker_log_dir, "test_moe_r3") - rollout_config = RolloutConfig( - env="test_rollout", - model_path=model_path, - model_name=os.path.basename(model_path).lower(), - tokenizer_path=model_path, - expert_parallel_size=2, - context_length=self.context_length, - worker_log_dir=worker_log_dir, - dist_port_base=36000, - ) - test_env = SingleTurnEnvironment.remote( - "test_env", - pg, - rollout_cfg=rollout_config, - ) - async_dataflow_cfg = DataFlowConfig( - env="test", - prompt_repeat_k=2, - global_batch_size=2, - enable_partial_rollout=1, - max_concurrent=4, - max_retry_times=1, - worker_log_dir=worker_log_dir, - ) - replay_buffer_cfg = ReplayBufferConfig( - dataset_cfg=self.train_dataset_cfg, - dataloader_cfg=self.dataloader_cfg, - tokenizer=self.tokenizer, - worker_log_dir=worker_log_dir, - ) - self._run_dataflow_save_resume_test(test_env, async_dataflow_cfg, replay_buffer_cfg) - - def _get_sorted_input_ids(self, responses): - """Helper to extract and sort input_ids from responses.""" - all_ids = [] - for 
data_items in responses[0]: - for data_item in data_items: - all_ids.extend(data_item.data.input_ids) - all_ids.sort() - return all_ids - - @unittest.skip("skip lmdeploy turbomind generate test due to ci environment issue") - def test_lmdeploy_turbomind_generate(self): - from xtuner.v1.ray.rollout import LMDeployWorker - rollout_config = RolloutConfig( - env="test_rollout", - model_path=MODEL_PATH, - model_name=os.path.basename(MODEL_PATH).lower(), - tokenizer_path=MODEL_PATH, - context_length=self.context_length, - worker_log_dir=self.worker_log_dir, - extra_rollout_config={"lmdeploy_backend": "turbomind"}, - ) - sample_params = SampleParams(temperature=0.0) - pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg) - rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg) # type: ignore[attr-defined] - res1 = ray.get(rollout_controller.rollout.remote(prompt=TEST_TEXT_MESSAGES, sample_params=sample_params)) - res2 = ray.get(rollout_controller.rollout.remote(prompt=TEST_TEXT_MESSAGES, sample_params=sample_params)) - self.assertEqual(res1, res2, f"res1 != res2, res1={res1}, res2={res2}") - ray.get(rollout_controller.shutdown.remote(), timeout=300) - - @unittest.skipIf(os.environ.get("XTUNER_USE_SGLANG", "0") == "0", "lmdeploy backend is not enabled") - def test_sglang_generate(self): - from xtuner.v1.ray.rollout import SGLangWorker - self.rollout_cfg.launch_server_method="multiprocessing" - rollout_config = RolloutConfig( - env="test_rollout", - model_path=MODEL_PATH, - model_name=os.path.basename(MODEL_PATH).lower(), - tokenizer_path=MODEL_PATH, - context_length=self.context_length, - worker_log_dir=self.worker_log_dir, - launch_server_method="multiprocessing" - ) - sample_params = SampleParams(temperature=0.0) - pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg) - rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg) # type: ignore[attr-defined] - res1 = 
ray.get(rollout_controller.rollout.remote(prompt=TEST_TEXT_MESSAGES, sample_params=sample_params)) - self.assertEqual(res1.finish_reason, "stop") - print("Response from SGLang infer:", res1) - ray.get(rollout_controller.shutdown.remote(), timeout=300) - - @unittest.skipIf(os.environ.get("XTUNER_USE_SGLANG", "0") == "0", "lmdeploy backend is not enabled") - def test_sglang_dataflow(self): - self.dataflow_cfg.enable_partial_rollout = 0 - rollout_config = RolloutConfig( - env="test_rollout", - model_path=MODEL_PATH, - model_name=os.path.basename(MODEL_PATH).lower(), - tokenizer_path=MODEL_PATH, - context_length=self.context_length, - worker_log_dir=self.worker_log_dir, - launch_server_method="multiprocessing" - ) - pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg) - test_env = SingleTurnEnvironment.remote( - "test_env", - pg, - rollout_cfg=rollout_config, - ) - test_flow = DataFlow.remote("test_env", - self.dataflow_cfg, - self.replay_buffer_cfg, - test_env - ) - responses = ray.get(test_flow.run.remote(), timeout=300)["data_groups"] - finished_samples_count = sum(1 for data in responses for item in data if item.env.rollout.finish_reason == "stop" or item.env.rollout.finish_reason == "length") - self.assertEqual(finished_samples_count // self.dataflow_cfg.prompt_repeat_k, self.dataflow_cfg.global_batch_size) - ray.get(test_env.shutdown.remote(), timeout=300) - print("responses: ", responses) - -if __name__ == "__main__": - unittest.main() diff --git a/tests/ray/test_vl_rollout.py b/tests/ray/test_vl_rollout.py deleted file mode 100644 index 81621e9d21..0000000000 --- a/tests/ray/test_vl_rollout.py +++ /dev/null @@ -1,207 +0,0 @@ -import os -import subprocess -from functools import wraps -import unittest -import tempfile -import ray -import torch -from pathlib import Path -from transformers import AutoTokenizer -import tempfile -from xtuner.v1.ray.config.worker import RolloutConfig -from xtuner.v1.ray.judger.controller import JudgerConfig -from 
xtuner.v1.ray.base import AcceleratorResourcesConfig, AutoAcceleratorWorkers -from xtuner.v1.ray.dataflow import DataFlow, DataFlowConfig, ReplayBufferConfig -from xtuner.v1.data_proto.rl_data import SampleParams -from xtuner.v1.ray.environment import SingleTurnEnvironment -from xtuner.v1.ray.rollout import RolloutController -from xtuner.v1.ray.judger import JudgerController -from xtuner.v1.datasets import RLTokenizeFnConfig, build_datasets, build_dataloader -from xtuner.v1.datasets.config import ( - DataloaderConfig, - DatasetConfig, -) - -MODEL_PATH=os.getenv("QWEN3_VL_DENSE_PATH") -TRAIN_DATA_PATH=os.getenv("GEO3K_TRAIN_DATA_PATH") -MEDIA_ROOT=os.getenv("GEO3K_MEDIA_ROOT") - -resource_map = { - "npu": "NPU", - "cuda": "GPU", -} -class TestRollout(unittest.TestCase): - - @classmethod - def setUpClass(cls) -> None: - os.environ["XTUNER_USE_FA3"] = "1" - os.environ["LMD_SKIP_WARMUP"] = "1" - - @classmethod - def tearDownClass(cls) -> None: - del os.environ["XTUNER_USE_FA3"] - del os.environ["LMD_SKIP_WARMUP"] - - def init_config(self): - self.resources_cfg = AcceleratorResourcesConfig( - accelerator=resource_map[torch.accelerator.current_accelerator().type], - num_workers=8, - num_cpus_per_worker=8, - cpu_memory_per_worker=16 * 1024**3, # 16 GB - ) - self.max_prompt_length = 2048 - self.max_response_length = 2048 - self.rollout_cfg = RolloutConfig( - env="test_rollout", - model_path=self.model_path, - model_name=os.path.basename(self.model_path).lower(), - tokenizer_path=self.model_path, - rollout_cross_node_comm=False, - tensor_parallel_size=2, - expert_parallel_size=1, - gpus_per_node=8, # gpu: 8, npu: 16 - dtype="bfloat16", - launch_server_method="ray", - context_length=self.max_prompt_length + self.max_response_length, - worker_log_dir=self.worker_log_dir, - ) - from xtuner.v1.ray.judger.geo3k import GEO3KJudgerConfig - geo3k_judger_config = GEO3KJudgerConfig() - self.judger_cfg = JudgerConfig(reward_judger_configs=[geo3k_judger_config]) - - self.dataflow_cfg = 
DataFlowConfig( - env="test", - prompt_repeat_k=2, - global_batch_size=2, - enable_partial_rollout=0, - max_retry_times=1, - worker_log_dir=self.worker_log_dir, - ) - self.training_sample_params = SampleParams( - max_tokens=self.max_response_length, - ) - self.evaluation_sample_params = SampleParams( - max_tokens=self.max_response_length, - top_p=1.0, - temperature=0.0, - top_k=1, - ) - self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True) - from xtuner.v1.datasets import Qwen3VLTokenizeFnConfig - tokenize_fn_cfg = Qwen3VLTokenizeFnConfig(processor_path=self.model_path) - train_dataset_cfg = [ - { - "dataset": DatasetConfig(name="geo3k", - anno_path=self.data_path, - class_name='VLMJsonlDataset', - media_root=self.media_root, - sample_ratio=1.0), - "tokenize_fn": RLTokenizeFnConfig(max_length=self.max_prompt_length, - tokenize_fn_cfg=tokenize_fn_cfg), - } - ] - dataloader_config = DataloaderConfig(num_workers=8, - collator="fake_collator", - pack_level="none") - - self.replay_buffer_cfg = ReplayBufferConfig( - dataset_cfg=train_dataset_cfg, - dataloader_cfg=dataloader_config, - tokenizer=self.tokenizer, - ) - - def setUp(self): - ray.init(num_cpus=80, ignore_reinit_error=True) - self.data_path = TRAIN_DATA_PATH - self.model_path = MODEL_PATH - self.media_root = MEDIA_ROOT - self.temp_dir = tempfile.TemporaryDirectory() - self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs") - self.init_config() - self.pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg) - - def tearDown(self): - ray.shutdown() - # When lmdeploy enable ep>1, it uses deep_ep. Buffer implicit destroy would cause some ray actor stucked. - # Use pkill cleen up ray::WorkerWrapper process after close ray cluster connection as workaround. - # TODO(chenchiyu): add excplicit deep_ep destroy in lmdeploy. 
- self._cleanup_lmdeploy_ray_worker_wrapper() - self.temp_dir.cleanup() - - def _cleanup_lmdeploy_ray_worker_wrapper(self): - try: - result = subprocess.run(["pkill", "-f", "ray::RayWorkerWrapper*"], capture_output=True, text=True, timeout=10) - if result.returncode != 0: - print(f"pkill command failed with return code {result.returncode}: {result.stderr}." - " Maybe no lmdeploy ray::RayWorkerWrapper processes found.") - except Exception as e: - print(f"Error stopping ray::RayWorkerWrapper cluster: {e}") - - @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") - def test_vl_resume_with_partial_rollout(self): - rollout_cfg = self.rollout_cfg - # rollout_cfg.enable_return_routed_experts = True - self.test_env = SingleTurnEnvironment.remote( - "test_env", - self.pg, - rollout_cfg=rollout_cfg, - ) - dataflow_cfg = self.dataflow_cfg - dataflow_cfg.global_batch_size = 2 - dataflow_cfg.staleness_threshold = 1 - dataflow_cfg.enable_partial_rollout = 1 - self.test_flow = DataFlow.remote("test_env", - dataflow_cfg, - self.replay_buffer_cfg, - self.test_env - ) - ray.get(self.test_flow.run.remote(), timeout=300) - rl_status_save = ray.get(self.test_flow.get_replaybuffer_status.remote()) - save_dir = Path(self.temp_dir.name) / 'checkpoints' / 'ckpt-step-2' - save_dir.mkdir(parents=True, exist_ok=True) - - ray.get(self.test_flow.save.remote(save_dir)) - remain_paused_samples_old = rl_status_save["remain_aborted_samples_count"] + rl_status_save["remain_completed_samples_count"] - responses_old = ray.get(self.test_flow.run.remote(num=remain_paused_samples_old, staleness_threshold=0), timeout=300) - rb_status_old = ray.get(self.test_flow.get_replaybuffer_status.remote()) - - mm_info_old = [] - for multimodal_train_infos in responses_old["mm_train_infos"]: - image_grid_thw = multimodal_train_infos["image_grid_thw"].numpy().flatten() - mm_info_old.extend(image_grid_thw) - - ray.get(self.test_flow.resume.remote(save_dir)) - 
rl_status_resume = ray.get(self.test_flow.get_replaybuffer_status.remote()) - remain_paused_samples_new = rl_status_resume["remain_aborted_samples_count"] + rl_status_resume["remain_completed_samples_count"] - responses_new = ray.get(self.test_flow.run.remote(num=remain_paused_samples_new, staleness_threshold=0), timeout=300) - rb_status_new = ray.get(self.test_flow.get_replaybuffer_status.remote()) - - mm_info_new = [] - for multimodal_train_infos in responses_new["mm_train_infos"]: - image_grid_thw = multimodal_train_infos["image_grid_thw"].numpy().flatten() - mm_info_new.extend(image_grid_thw) - - all_train_prompt_ids_old = [] - for data_items in responses_old["data_groups"]: - for data_item in data_items: - all_train_prompt_ids_old.extend(data_item.data.input_ids) - - all_train_prompt_ids_new = [] - for data_items in responses_new["data_groups"]: - for data_item in data_items: - all_train_prompt_ids_new.extend(data_item.data.input_ids) - - all_train_prompt_ids_old.sort() - all_train_prompt_ids_new.sort() - mm_info_old.sort() - mm_info_new.sort() - self.assertEqual(all_train_prompt_ids_old, all_train_prompt_ids_new) - self.assertEqual(mm_info_old, mm_info_new) - for key in rb_status_old: - self.assertEqual(rb_status_old[key], rb_status_new[key]) - for key in rl_status_save: - self.assertEqual(rl_status_save[key], rl_status_resume[key]) - ray.get(self.test_env.shutdown.remote(), timeout=300) - -if __name__ == "__main__": - unittest.main() \ No newline at end of file diff --git a/tests/ray/rollout_output.jsonl b/tests/rl/rollout_output.jsonl similarity index 100% rename from tests/ray/rollout_output.jsonl rename to tests/rl/rollout_output.jsonl diff --git a/tests/rl/test_agent_loop.py b/tests/rl/test_agent_loop.py new file mode 100644 index 0000000000..44f4f63239 --- /dev/null +++ b/tests/rl/test_agent_loop.py @@ -0,0 +1,228 @@ +import os +import unittest +import copy +import ray +import tempfile +import torch +from transformers import AutoTokenizer +from 
xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.rl.utils import AcceleratorResourcesConfig, AutoAcceleratorWorkers +from xtuner.v1.rl.agent_loop import ( + SingleTurnAgentLoopConfig, + AgentLoopManagerConfig, + TaskSpecConfig, + SyncProduceStrategyConfig, + SamplerConfig, +) +from xtuner.v1.data_proto import RolloutState, Status, SampleParams +from xtuner.v1.rl.rollout import RolloutController +from xtuner.v1.rl.judger.gsm8k import GSM8KJudgerConfig +from xtuner.v1.rl.replay_buffer import SyncReplayBufferConfig +from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig +from xtuner.v1.datasets.rl_tokenize_fn import RLTextTokenizeFnConfig + +MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"] +MOE_MODEL_PATH = os.environ["QWEN3_MOE_PATH"] +TRAIN_DATA_PATH = os.environ["ROLLOUT_DATA_PATH"] +TEST_DATA_PATH = os.environ["ROLLOUT_TEST_DATA_PATH"] +FAKE_INPUT_ITEM = RolloutState( + message=[{ + 'role': 'user', + 'content': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? 
Let\'s think step by step and output the final answer after "####"' + }], + reward_model={'ground_truth': '72', 'style': 'rule'}, +) +resource_map = { + "npu": "NPU", + "cuda": "GPU", +} + +class TestAgentLoop(unittest.IsolatedAsyncioTestCase): + @classmethod + def setUpClass(cls) -> None: + os.environ["XTUNER_USE_FA3"] = "1" + os.environ["LMD_SKIP_WARMUP"] = "1" + + @classmethod + def tearDownClass(cls) -> None: + del os.environ["XTUNER_USE_FA3"] + del os.environ["LMD_SKIP_WARMUP"] + + def init_config(self): + self.resources_cfg = AcceleratorResourcesConfig( + accelerator=resource_map[torch.accelerator.current_accelerator().type], + num_workers=1, + num_cpus_per_worker=8, + cpu_memory_per_worker=16 * 1024**3, # 16 GB + ) + self.max_prompt_length = 512 + self.max_response_length = 1024 + self.context_length = self.max_prompt_length + self.max_response_length + self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True) + + def setUp(self): + ray.init(num_cpus=80, ignore_reinit_error=True) + self.data_path = TRAIN_DATA_PATH + self.model_path = MODEL_PATH + self.temp_dir = tempfile.TemporaryDirectory() + self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs") + + def tearDown(self): + ray.shutdown() + self.temp_dir.cleanup() + + async def test_gsm8k_agent_loop(self): + # 1. 初始化 config + self.init_config() + rollout_config = RolloutConfig( + env="test_agent_loop", + model_path=self.model_path, + model_name=os.path.basename(self.model_path).lower(), + tokenizer_path=self.model_path, + context_length=self.context_length, + worker_log_dir=self.worker_log_dir, + ) + judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k", num_ray_actors=1) + agent_loop_cfg = SingleTurnAgentLoopConfig( + hf_checkpoint=self.model_path, + sample_params=SampleParams(max_tokens=self.max_response_length, temperature=0.0), + ) + # 2. 
创建 rollout_controller + pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg) + rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg) + # 3. 创建 AgentLoop + agent_loop = agent_loop_cfg.build( + rollout_controller=rollout_controller, + judger=judger_config.build(), + ) + # 4. 构造输入数据 + prompt_repeat_k = 4 + rollout_state = FAKE_INPUT_ITEM + group_in_rollout_state = [FAKE_INPUT_ITEM] * prompt_repeat_k + # 5. 执行 generate_group && generate_sample + group_rollout_state = await agent_loop.generate_group(group_in_rollout_state) + single_rollout_state = await agent_loop.generate_sample(rollout_state) + # 6. 验证结果 + self.assertEqual(len(group_rollout_state), 4) + for state in group_rollout_state: + self.assertEqual(state.status, Status.COMPLETED) + self.assertGreater(len(state.response_ids), 0) + self.assertEqual(single_rollout_state.reward["score"], 1) + self.assertEqual(single_rollout_state.status, Status.COMPLETED) + self.assertGreater(len(single_rollout_state.response_ids), 0) + self.assertEqual(single_rollout_state.reward["score"], 1) + + async def test_gsm8k_agent_loop_with_ray_actor_judger(self): + self.init_config() + rollout_config = RolloutConfig( + env="test_agent_loop_ray_actor", + model_path=self.model_path, + model_name=os.path.basename(self.model_path).lower(), + tokenizer_path=self.model_path, + context_length=self.context_length, + worker_log_dir=self.worker_log_dir, + ) + judger_config = GSM8KJudgerConfig( + judger_name="openai/gsm8k", + num_ray_actors=1, + num_cpus_per_actor=1, + ) + agent_loop_cfg = SingleTurnAgentLoopConfig( + hf_checkpoint=self.model_path, + sample_params=SampleParams(max_tokens=self.max_response_length, temperature=0.0), + num_ray_actors=1, + num_cpus=1, + ) + + pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg) + rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg) + agent_loop = agent_loop_cfg.build( + rollout_controller=rollout_controller, + 
judger=judger_config.build(), + ) + + prompt_repeat_k = 2 + rollout_state = copy.deepcopy(FAKE_INPUT_ITEM) + group_in_rollout_state = [copy.deepcopy(FAKE_INPUT_ITEM) for _ in range(prompt_repeat_k)] + + group_rollout_state = await agent_loop.generate_group.remote(group_in_rollout_state) + single_rollout_state = await agent_loop.generate_sample.remote(rollout_state) + + self.assertEqual(len(group_rollout_state), prompt_repeat_k) + for state in group_rollout_state: + self.assertEqual(state.status, Status.COMPLETED) + self.assertGreater(len(state.response_ids), 0) + self.assertEqual(state.reward["score"], 1) + self.assertEqual(single_rollout_state.status, Status.COMPLETED) + self.assertGreater(len(single_rollout_state.response_ids), 0) + self.assertEqual(single_rollout_state.reward["score"], 1) + + async def test_gsm8k_agent_loop_manager(self): + # 1. 初始化 config + self.init_config() + rollout_config = RolloutConfig( + env="test_agent_loop", + model_path=self.model_path, + model_name=os.path.basename(self.model_path).lower(), + tokenizer_path=self.model_path, + context_length=self.context_length, + worker_log_dir=self.worker_log_dir, + ) + judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k", num_ray_actors=1) + agent_loop_cfg = SingleTurnAgentLoopConfig( + hf_checkpoint=self.model_path, + sample_params=SampleParams(max_tokens=self.max_response_length, temperature=0.0), + ) + sampler_config = SamplerConfig( + dataloader_cfg=DataloaderConfig( + dataset_config_list=[ + { + "dataset": DatasetConfig(name="gsm8k", + anno_path=TRAIN_DATA_PATH, + sample_ratio=1.0), + "tokenize_fn": RLTextTokenizeFnConfig(max_length=self.max_prompt_length), + }, + ], + collator='fake_collator', + pack_level='none', + group_by_length=False, + ), + prompt_repeat_k=2, + ) + agent_loop_manager_cfg = AgentLoopManagerConfig( + tasks=[ + TaskSpecConfig( + task_name="test_gsm8k", + agent_loop_config=agent_loop_cfg, + judger_config=judger_config, + 
produce_strategy_config=SyncProduceStrategyConfig(), + sampler_config=sampler_config, + ) + ], + ) + # 2. 创建 rollout_controller + pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg) + rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg) + # 3. 创建 AgentLoopManager + replay_buffer_cfg = SyncReplayBufferConfig() + replay_buffer = replay_buffer_cfg.build() + agent_loop_manager = agent_loop_manager_cfg.build( + rollout_controller=rollout_controller, + tokenizer=self.tokenizer, + replay_buffer=replay_buffer, + ) + # 4. 执行 produce_batch + result = await agent_loop_manager.produce_batch(batch_size=4) + batch_rollout_states = result.rollout_states + # 5. 验证结果 + self.assertEqual(len(batch_rollout_states), 4) + for group_state in batch_rollout_states: + self.assertEqual(len(group_state), 2) + group_message = group_state[0].message + for state in group_state: + self.assertEqual(state.status, Status.COMPLETED) + self.assertGreater(len(state.response_ids), 0) + self.assertEqual(state.message, group_message) + +if __name__ == "__main__": + unittest.main() diff --git a/tests/rl/test_async_rollout.py b/tests/rl/test_async_rollout.py new file mode 100644 index 0000000000..32cd372aea --- /dev/null +++ b/tests/rl/test_async_rollout.py @@ -0,0 +1,720 @@ +from __future__ import annotations + +import os +import unittest + +import ray +import torch + +from transformers import AutoTokenizer + +from xtuner.v1.data_proto import SampleParams, Status +from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig +from xtuner.v1.datasets.rl_tokenize_fn import RLTextTokenizeFnConfig +from xtuner.v1.rl.agent_loop import ( + AgentLoopManagerConfig, + AsyncProduceStrategyConfig, + SamplerConfig, + SingleTurnAgentLoopConfig, + TaskSpecConfig, +) +from xtuner.v1.rl.replay_buffer import AsyncReplayBufferConfig +from xtuner.v1.rl.rollout import RolloutController +from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.rl.utils import 
AcceleratorResourcesConfig, AutoAcceleratorWorkers + +MODEL_PATH = os.environ.get("ROLLOUT_MODEL_PATH", "") +DATA_PATH = os.environ.get("ROLLOUT_DATA_PATH", "") +MAX_PROMPT_LENGTH = 512 +MAX_RESPONSE_LENGTH = 512 +PACK_MAX_LENGTH = MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH +EXPERIMENTAL_NAME = "async_rl_integration_test" + +_RESOURCE_MAP = {"npu": "NPU", "cuda": "GPU"} + + +def _accelerator_type() -> str: + return _RESOURCE_MAP[torch.accelerator.current_accelerator().type] + + +def _build_rollout_controller(): + """Build a RolloutController backed by a real inference engine.""" + resources_cfg = AcceleratorResourcesConfig( + accelerator=_accelerator_type(), + num_workers=1, + num_cpus_per_worker=8, + cpu_memory_per_worker=16 * 1024**3, + ) + rollout_config = RolloutConfig( + env=EXPERIMENTAL_NAME, + device=resources_cfg.accelerator, + model_path=MODEL_PATH, + gpu_memory_utilization=0.8, + context_length=MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH, + rollout_max_batch_size_per_instance=16, + max_retry_per_sample=0, + ) + pg = AutoAcceleratorWorkers.build_placement_group(resources_cfg) + rollout_ctl = ray.remote(RolloutController).remote(rollout_config, pg) + return rollout_ctl + + +def _build_agent_loop_manager( + rollout_ctl, + task_name: str, + over_sample_threshold: float = 0.0, + enable_partial_rollout: bool = False, + tail_batch_stale_threshold: int = 0, + tail_batch_trigger_size: int = 0, + prompt_repeat_k: int = 1, + max_tokens: int = MAX_RESPONSE_LENGTH, +): + """Build an AgentLoopManager backed by a fresh AsyncReplayBuffer.""" + tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) + + dataset_cfg = DatasetConfig(name=EXPERIMENTAL_NAME, anno_path=DATA_PATH) + tokenizer_fn_cfg = RLTextTokenizeFnConfig(max_length=MAX_PROMPT_LENGTH) + dataloader_cfg = DataloaderConfig( + dataset_config_list=[{"dataset": dataset_cfg, "tokenize_fn": tokenizer_fn_cfg}], + pack_max_length=PACK_MAX_LENGTH, + collator="fake_collator", + pack_level="none", + ) + 
sampler_config = SamplerConfig( + dataloader_cfg=dataloader_cfg, + prompt_repeat_k=prompt_repeat_k, + ) + + sample_params = SampleParams( + max_tokens=max_tokens, + temperature=1.0, + top_k=0, + top_p=1.0, + return_token_ids=True, + ) + agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=MODEL_PATH, + sample_params=sample_params, + ) + + produce_strategy_config = AsyncProduceStrategyConfig( + over_sample_threshold=over_sample_threshold, + enable_partial_rollout=enable_partial_rollout, + tail_batch_stale_threshold=tail_batch_stale_threshold, + tail_batch_trigger_size=tail_batch_trigger_size, + ) + + replay_buffer = AsyncReplayBufferConfig().build() + + manager_cfg = AgentLoopManagerConfig( + tasks=[ + TaskSpecConfig( + task_name=task_name, + agent_loop_config=agent_loop_config, + produce_strategy_config=produce_strategy_config, + sampler_config=sampler_config, + ) + ], + ) + manager = manager_cfg.build( + rollout_controller=rollout_ctl, + tokenizer=tokenizer, + replay_buffer=replay_buffer, + logger=None, + ) + return manager + +class TestOversampling(unittest.IsolatedAsyncioTestCase): + """Oversampling tests (mirrors debug_rollout=True: rollout only, no training). 
+ + Why ABORTED samples are guaranteed: + - over_sample_threshold=2.0 => data_concurrency = 3 * batch_size = 6 tasks + - max_tokens=512 => long responses; most tasks still in-flight + when the first batch_size completions arrive + - _cleanup_pending_tasks() => remaining tasks get abort-signalled and + stored as ABORTED in the replay buffer + """ + + OVER_SAMPLE_THRESHOLD = 2.0 # data_concurrency = 3 * batch_size + BATCH_SIZE = 2 + INITIAL_DATA_CONCURRENCY = int((1 + OVER_SAMPLE_THRESHOLD) * BATCH_SIZE) # = 6 + + @classmethod + def setUpClass(cls) -> None: + os.environ.setdefault("XTUNER_USE_FA3", "1") + os.environ.setdefault("LMD_SKIP_WARMUP", "1") + + def setUp(self): + ray.init(num_cpus=32, ignore_reinit_error=True) + self.rollout_ctl = _build_rollout_controller() + + def tearDown(self): + ray.shutdown() + + async def test_1_1_total_count_after_first_rollout(self): + """1.1: After produce_batch round 1: + + remain_completed + remain_aborted == INITIAL_DATA_CONCURRENCY + + Flow: + 1. strategy starts INITIAL_DATA_CONCURRENCY tasks concurrently. + 2. As soon as BATCH_SIZE completions are collected, the while-loop + exits; remaining pending tasks go through _cleanup_pending_tasks + and are stored as ABORTED. + 3. produce_batch() then calls replay_buffer.get(BATCH_SIZE, COMPLETED) + which consumes exactly BATCH_SIZE items. + 4. Any extras that completed during the abort window remain as + COMPLETED in the buffer. + + Therefore: + remain_completed + remain_aborted == INITIAL_DATA_CONCURRENCY - BATCH_SIZE + + Because every task either ends up COMPLETED or ABORTED in the buffer, + and exactly BATCH_SIZE items are consumed by replay_buffer.get(). 
+ """ + manager = _build_agent_loop_manager( + self.rollout_ctl, + task_name="test_1_1", + over_sample_threshold=self.OVER_SAMPLE_THRESHOLD, + ) + replay_buffer = manager.replay_buffer + + await manager.produce_batch(batch_size=self.BATCH_SIZE, rollout_step=1) + + remain_completed = await replay_buffer.count( + task_name="test_1_1", group_status=Status.COMPLETED + ) + remain_aborted = await replay_buffer.count( + task_name="test_1_1", group_status=Status.ABORTED + ) + + # Primary assertion: items remaining in buffer after produce_batch consumes + # BATCH_SIZE completed samples == INITIAL_DATA_CONCURRENCY - BATCH_SIZE + expected_remaining = self.INITIAL_DATA_CONCURRENCY - self.BATCH_SIZE + self.assertEqual( + remain_completed + remain_aborted, + expected_remaining, + msg=( + f"remain_completed={remain_completed}, remain_aborted={remain_aborted}, " + f"expected total={expected_remaining} " + f"(= INITIAL_DATA_CONCURRENCY {self.INITIAL_DATA_CONCURRENCY} " + f"- BATCH_SIZE {self.BATCH_SIZE})" + ), + ) + + async def test_1_2_second_rollout_samples_from_aborted_queue(self): + """1.2: Round 2's produce_batch re-samples exactly the oversampled items + left over from round 1. + + Key mechanism (enable_partial_rollout=False): + - After round 1, replay_buffer holds INITIAL_DATA_CONCURRENCY - BATCH_SIZE + leftover items (all COMPLETED, since all tasks finished). + - At the start of round 2, _process_leftover_samples() converts those + leftover COMPLETED items to ABORTED (because enable_partial_rollout=False). + - _async_sample() then draws from the ABORTED pool via + Sampler.sample(group_status=Status.ABORTED) before issuing new samples. + + We instrument Sampler.sample() to count how many times it returned items + that were already in ABORTED state (= drawn from replay buffer, not + freshly sampled from the dataloader). + + The expected count equals the number of leftover items from round 1, + i.e. INITIAL_DATA_CONCURRENCY - BATCH_SIZE. 
+ """ + manager = _build_agent_loop_manager( + self.rollout_ctl, + task_name="test_1_2", + over_sample_threshold=self.OVER_SAMPLE_THRESHOLD, + # enable_partial_rollout=False (default) so leftover COMPLETED are + # converted to ABORTED by _process_leftover_samples() in round 2. + ) + replay_buffer = manager.replay_buffer + original_sample = manager.data_sampler.sample + + sampled_from_aborted = 0 + + async def instrumented_sample(task_name, group_status=None, **kwargs): + nonlocal sampled_from_aborted + result = await original_sample( + task_name=task_name, group_status=group_status, **kwargs + ) + # Items fetched from the ABORTED pool still carry status==ABORTED. + if result and result[0].status == Status.ABORTED: + sampled_from_aborted += 1 + return result + + manager.data_sampler.sample = instrumented_sample + + # --- Round 1 --- + await manager.produce_batch(batch_size=self.BATCH_SIZE, rollout_step=1) + + # After round 1: produce_batch consumed BATCH_SIZE completed items. + # The leftover items (completed but not consumed) stay in the buffer. + round1_remain_completed = await replay_buffer.count( + task_name="test_1_2", group_status=Status.COMPLETED + ) + round1_remain_aborted = await replay_buffer.count( + task_name="test_1_2", group_status=Status.ABORTED + ) + # Total leftover == INITIAL_DATA_CONCURRENCY - BATCH_SIZE + expected_leftover = self.INITIAL_DATA_CONCURRENCY - self.BATCH_SIZE + self.assertEqual( + round1_remain_completed + round1_remain_aborted, + expected_leftover, + msg=( + f"Round 1 leftover: completed={round1_remain_completed}, " + f"aborted={round1_remain_aborted}, expected total={expected_leftover}" + ), + ) + # These leftover items are the ones round 2 must re-sample from the + # ABORTED pool (after _process_leftover_samples converts them). 
+ expected_resampled_from_aborted = round1_remain_completed + round1_remain_aborted + + # --- Round 2: reset counter then run --- + sampled_from_aborted = 0 + await manager.produce_batch(batch_size=self.BATCH_SIZE, rollout_step=2) + + self.assertEqual( + sampled_from_aborted, + expected_resampled_from_aborted, + msg=( + f"Round 2 should have re-sampled {expected_resampled_from_aborted} " + f"item(s) from the ABORTED queue (converted from round-1 leftovers), " + f"got sampled_from_aborted={sampled_from_aborted}" + ), + ) + + +class TestPartialRollout(unittest.IsolatedAsyncioTestCase): + """Partial-rollout tests. + + All tests inject pre-constructed ABORTED samples directly into the + replay buffer so that Sampler.sample(group_status=ABORTED) picks them + up without any mocking. The real AgentLoopManager.produce_batch() is + used throughout. + + Key configuration: + - over_sample_threshold=2.0 → data_concurrency = 3; guarantees concurrent + tasks so the genuine oversampling + partial- + rollout path is exercised (not just injected + into a single-task environment). + - enable_partial_rollout=True → ABORTED samples resume from existing + response_ids instead of starting over. + """ + + BATCH_SIZE = 1 + OVER_SAMPLE = 2.0 # data_concurrency = int((1+2.0)*1) = 3; genuine oversampling + # Short max_tokens for the max-exhausted short-circuit test; medium for multi-round. 
+ MAX_TOKENS_SHORT = 8 + MAX_TOKENS_MULTI = 32 + + @classmethod + def setUpClass(cls) -> None: + os.environ.setdefault("XTUNER_USE_FA3", "1") + os.environ.setdefault("LMD_SKIP_WARMUP", "1") + + def setUp(self): + ray.init(num_cpus=32, ignore_reinit_error=True) + self.rollout_ctl = _build_rollout_controller() + self.tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) + + def tearDown(self): + ray.shutdown() + + def _make_aborted_state(self, uid: int, prompt: str, response_ids: list[int], + response_rollout_steps: list[int] | None = None, + max_tokens: int = MAX_RESPONSE_LENGTH) -> "RolloutState": + """Helper: build an ABORTED RolloutState with given response_ids.""" + from xtuner.v1.data_proto import RolloutState, SampleParams, Status + prompt_ids = self.tokenizer(prompt, add_special_tokens=False)["input_ids"] + state = RolloutState( + uid=uid, + message=[{"role": "user", "content": prompt}], + prompt_ids=prompt_ids, + sample_params=SampleParams( + max_tokens=max_tokens, + temperature=1.0, + top_k=0, + top_p=1.0, + return_token_ids=True, + ), + status=Status.ABORTED, + response_ids=response_ids, + response="placeholder", + logprobs=[0.0] * len(response_ids), + response_mask=[1] * len(response_ids), + response_rollout_steps=response_rollout_steps if response_rollout_steps is not None else [0] * len(response_ids), + seq_staleness=0, + extra_fields={}, + ) + return state + + async def test_2_1_partial_rollout_response_ids_are_concatenated(self): + """2.1: Partial rollout 的 response_ids 前缀必须保持不变。 + + Setup: + - over_sample_threshold=2.0 → 3 个并发任务;注入的 ABORTED 样本与另外 + 2 个 dataloader 新样本同时运行,真实触发 oversampling + partial-rollout 路径。 + - 注入 uid=9001 的 ABORTED 样本,response_ids=[1000,1001,1002,1003]。 + - 由于多任务竞争,注入样本可能多次被 abort 并在后续轮次继续;每次 + preprocess 以 existing response_ids 为前缀,postprocess 拼接新内容。 + + 断言: 最终完成的 uid=9001 样本的 response_ids 以初始 4 token 为前缀, + 且长度 > 4(确实生成了新内容)。 + """ + from xtuner.v1.data_proto import Status + task_name = "test_2_1" + 
initial_response_ids = [1000, 1001, 1002, 1003] + injected_uid = 9001 + + manager = _build_agent_loop_manager( + self.rollout_ctl, + task_name=task_name, + over_sample_threshold=self.OVER_SAMPLE, + enable_partial_rollout=True, + max_tokens=MAX_RESPONSE_LENGTH, + ) + replay_buffer = manager.replay_buffer + + state = self._make_aborted_state( + uid=injected_uid, + prompt="Count from one.", + response_ids=initial_response_ids, + max_tokens=MAX_RESPONSE_LENGTH, + ) + await replay_buffer.put([state], task_name) + + # Loop: with oversampling the injected sample may be aborted multiple times + # before completing. Search by uid across rounds. + target_sample = None + for rollout_step in range(1, 15): + completed_groups = await manager.produce_batch( + batch_size=self.BATCH_SIZE, rollout_step=rollout_step + ) + for group in completed_groups.rollout_states: + for sample in group: + if sample.uid == injected_uid: + target_sample = sample + if target_sample is not None: + break + + self.assertIsNotNone( + target_sample, + msg=f"Injected sample (uid={injected_uid}) never completed within 14 rounds", + ) + final_response_ids = target_sample.response_ids + self.assertGreater( + len(final_response_ids), len(initial_response_ids), + msg="Partial rollout should have appended new tokens", + ) + self.assertEqual( + final_response_ids[: len(initial_response_ids)], + initial_response_ids, + msg="response_ids must start with the original injected prefix", + ) + + async def test_2_2_eos_in_response_skips_inference_engine(self): + """2.2: ABORTED 样本末尾为 EOS token → worker 短路,response_ids 不变。 + + EOS 短路不调用推理引擎,注入样本几乎瞬间完成,在 3 个并发任务中 + 必然最先完成,因此 completed_groups[0][0] 就是注入样本。 + + 断言: 返回样本的 response_ids 与注入时完全相同。 + """ + from xtuner.v1.data_proto import Status + from xtuner.v1.rl.rollout.worker import get_eos_token + + task_name = "test_2_2" + eos = get_eos_token(MODEL_PATH) + eos_id = eos[0] if isinstance(eos, list) else eos + initial_response_ids = [1000, 1001, eos_id] + injected_uid = 9002 
+ + manager = _build_agent_loop_manager( + self.rollout_ctl, + task_name=task_name, + over_sample_threshold=self.OVER_SAMPLE, + enable_partial_rollout=True, + max_tokens=MAX_RESPONSE_LENGTH, + ) + replay_buffer = manager.replay_buffer + + state = self._make_aborted_state( + uid=injected_uid, + prompt="Say hello.", + response_ids=initial_response_ids, + max_tokens=MAX_RESPONSE_LENGTH, + ) + await replay_buffer.put([state], task_name) + + # EOS short-circuit completes with no LLM call → always wins the race. + completed_groups = await manager.produce_batch( + batch_size=self.BATCH_SIZE, rollout_step=1 + ) + completed_groups = completed_groups.rollout_states + + self.assertEqual(len(completed_groups), self.BATCH_SIZE) + final = completed_groups[0][0] + self.assertEqual(final.uid, injected_uid, + msg="EOS short-circuit sample should be the first to complete") + self.assertEqual(final.status, Status.COMPLETED) + self.assertEqual( + final.response_ids, + initial_response_ids, + msg="EOS short-circuit: response_ids must be identical to the injected ones", + ) + + async def test_2_3_max_tokens_exhausted_skips_inference_engine(self): + """2.3: len(response_ids)==max_tokens → remaining_tokens==0 → worker 短路,response_ids 不变。 + + 与 test_2_2 同理,短路不调用推理引擎,注入样本在 3 个并发任务中必然最先完成。 + + 断言: 返回样本的 response_ids 与注入时完全相同。 + """ + from xtuner.v1.data_proto import Status + task_name = "test_2_3" + max_tokens = self.MAX_TOKENS_SHORT + initial_response_ids = list(range(1010, 1010 + max_tokens)) # len == max_tokens + injected_uid = 9003 + + manager = _build_agent_loop_manager( + self.rollout_ctl, + task_name=task_name, + over_sample_threshold=self.OVER_SAMPLE, + enable_partial_rollout=True, + max_tokens=max_tokens, + ) + replay_buffer = manager.replay_buffer + + state = self._make_aborted_state( + uid=injected_uid, + prompt="Say hello.", + response_ids=initial_response_ids, + max_tokens=max_tokens, + ) + await replay_buffer.put([state], task_name) + + # max_tokens short-circuit completes with 
no LLM call → always wins the race. + completed_groups = await manager.produce_batch( + batch_size=self.BATCH_SIZE, rollout_step=1 + ) + completed_groups = completed_groups.rollout_states + + self.assertEqual(len(completed_groups), self.BATCH_SIZE) + final = completed_groups[0][0] + self.assertEqual(final.uid, injected_uid, + msg="max_tokens short-circuit sample should be the first to complete") + self.assertEqual(final.status, Status.COMPLETED) + self.assertEqual( + final.response_ids, + initial_response_ids, + msg="max_tokens exhausted: response_ids must be identical to the injected ones", + ) + + async def test_2_4_multi_round_response_ids_never_exceed_max_tokens(self): + """2.4: 多轮 partial rollout 后 len(response_ids) <= max_tokens。 + + over_sample_threshold=2.0 → 每轮 3 个并发任务;注入样本可能经历多次 + abort + continue 才能完成。无论经历几轮,最终 response_ids 长度不超过 max_tokens。 + + 按 uid 搜索目标样本,最多跑 14 轮。 + """ + from xtuner.v1.data_proto import Status + task_name = "test_2_4" + max_tokens = self.MAX_TOKENS_MULTI + injected_uid = 9004 + + manager = _build_agent_loop_manager( + self.rollout_ctl, + task_name=task_name, + over_sample_threshold=self.OVER_SAMPLE, + enable_partial_rollout=True, + max_tokens=max_tokens, + ) + replay_buffer = manager.replay_buffer + + state = self._make_aborted_state( + uid=injected_uid, + prompt="Count from one.", + response_ids=[1020, 1021], # 2 tokens initially; max_tokens=32 + max_tokens=max_tokens, + ) + await replay_buffer.put([state], task_name) + + target_sample = None + for rollout_step in range(1, 15): + completed_groups = await manager.produce_batch( + batch_size=self.BATCH_SIZE, rollout_step=rollout_step + ) + for group in completed_groups.rollout_states : + for sample in group: + if sample.uid == injected_uid: + self.assertLessEqual( + len(sample.response_ids), + max_tokens, + msg=( + f"Step {rollout_step}: accumulated response_ids length " + f"{len(sample.response_ids)} exceeds max_tokens {max_tokens}" + ), + ) + target_sample = sample + if 
target_sample is not None: + break + + self.assertIsNotNone( + target_sample, + msg=f"Injected sample (uid={injected_uid}) never completed within 14 rounds", + ) + self.assertLessEqual( + len(target_sample.response_ids), + max_tokens, + msg=f"Final response_ids length {len(target_sample.response_ids)} > max_tokens {max_tokens}", + ) + + +class TestTailBatch(unittest.IsolatedAsyncioTestCase): + BATCH_SIZE = 2 + OVER_SAMPLE = 5.0 # data_concurrency = (1 + 5.0) * BATCH_SIZE = 12 + + @classmethod + def setUpClass(cls) -> None: + os.environ.setdefault("XTUNER_USE_FA3", "1") + os.environ.setdefault("LMD_SKIP_WARMUP", "1") + + def setUp(self): + ray.init(num_cpus=32, ignore_reinit_error=True) + self.rollout_ctl = _build_rollout_controller() + + def tearDown(self): + ray.shutdown() + + async def test_3_1_staleness_threshold_1_marks_expired(self): + """3.1a: tail_batch_stale_threshold=1 — 需要 3 轮才能在 buffer 中观察到 EXPIRED。 + + staleness 积累路径(enable_partial_rollout=True): + Round 1 (step=1): 6 个并发任务,2 个完成后其余被 abort。 + 被 abort 的样本携带 step=1 生成的分段 response,response_rollout_steps=[1,...]. 
+ Round 2 (step=2): round1 的 ABORTED 样本被续写,多数在 round2 内完成(COMPLETED)。 + postprocess 更新 seq_staleness = 2 - min([1,...]) = 1。 + 但 update_expired_status 只对 status==ABORTED 的样本触发,COMPLETED 不受影响。 + 这些 COMPLETED 样本(seq_staleness=1)留在 buffer 中。 + Round 3 (step=3): _process_leftover_samples 在 round3 开始时读取 buffer 中的 + COMPLETED 样本,检查 seq_staleness=1 >= threshold=1 → 标为 EXPIRED, + 放回 buffer。由于 trigger_size=0,EXPIRED 样本不在本轮被消费。 + + 断言: round3 结束后 buffer 中 expired > 0。 + """ + from xtuner.v1.data_proto import Status + + STALE_THRESHOLD = 1 + task_name = "test_3_1a" + + manager = _build_agent_loop_manager( + self.rollout_ctl, + task_name=task_name, + over_sample_threshold=self.OVER_SAMPLE, + enable_partial_rollout=True, + tail_batch_stale_threshold=STALE_THRESHOLD, + tail_batch_trigger_size=0, # 只测 EXPIRED 标记,不触发 tail-batch 模式 + max_tokens=8192, # 让 response 足够长,确保 staleness 能积累到 1(不被 max_tokens 短路) + ) + replay_buffer = manager.replay_buffer + + # 3 轮是让 staleness 自然积累并被 _process_leftover_samples 标记的最少轮数: + # round1 产生 ABORTED(step=1 tokens)→ round2 续写完成(COMPLETED, staleness=1) + # → round3 开头 _process_leftover_samples 标 EXPIRED(staleness=1 >= 1) + for rollout_step in range(1, 5): + await manager.produce_batch(batch_size=self.BATCH_SIZE, rollout_step=rollout_step) + + expired_count = await replay_buffer.count( + task_name=task_name, group_status=Status.EXPIRED + ) + aborted_count = await replay_buffer.count( + task_name=task_name, group_status=Status.ABORTED + ) + + self.assertGreater( + expired_count, 0, + msg=( + f"threshold=1: after 3 rounds (steps 1→3), leftover COMPLETED samples " + f"with seq_staleness=1 should be marked EXPIRED by _process_leftover_samples. 
" + f"expired={expired_count}, aborted={aborted_count}" + ), + ) + + async def test_3_2_tail_batch_mode_resets_staleness_to_zero(self): + """3.2: 真实多轮循环自然触发 tail-batch 模式,验证 seq_staleness 重置为 0。 + + 配置: + over_sample_threshold=2.0 → 每轮产生大量遗留样本 + tail_batch_stale_threshold=1 → staleness >= 1 即标 EXPIRED(一步即触发) + tail_batch_trigger_size = BATCH_SIZE // 2 = 1 → expired >= 1 即进入 tail-batch + + 流程 (最多 10 轮): + - 在调用 produce_batch 之前读取 expired_before。 + - 若 expired_before >= trigger_size,本轮由 strategy 进入 tail-batch 模式: + 从 EXPIRED 池取样 → preprocess 重置 response_ids=[], response_rollout_steps=[] + → 全新生成 → postprocess: response_rollout_steps=[rollout_step, ...] + → staleness = rollout_step - rollout_step = 0。 + - 取到第一个 tail-batch 完成样本后退出循环。 + + 断言: + 1. tail-batch 模式在 10 轮内被触发。 + 2. 该轮返回的 COMPLETED 样本 seq_staleness == 0。 + """ + from xtuner.v1.data_proto import Status + + STALE_THRESHOLD = 1 + TRIGGER_SIZE = self.BATCH_SIZE // 2 # = 1 + task_name = "test_3_2" + + manager = _build_agent_loop_manager( + self.rollout_ctl, + task_name=task_name, + over_sample_threshold=self.OVER_SAMPLE, + enable_partial_rollout=True, + tail_batch_stale_threshold=STALE_THRESHOLD, + tail_batch_trigger_size=TRIGGER_SIZE, + max_tokens=8192, + ) + replay_buffer = manager.replay_buffer + + tail_batch_triggered = False + completed_from_tail_batch = None + + for rollout_step in range(1, 11): + expired_before = await replay_buffer.count( + task_name=task_name, group_status=Status.EXPIRED + ) + + completed_groups = await manager.produce_batch( + batch_size=self.BATCH_SIZE, rollout_step=rollout_step + ) + completed_groups = completed_groups.rollout_states + + # 进入本轮前 expired >= trigger_size → 本轮就是 tail-batch 轮 + if expired_before >= TRIGGER_SIZE: + tail_batch_triggered = True + if completed_groups: + completed_from_tail_batch = completed_groups[0][0] + break + + self.assertTrue( + tail_batch_triggered, + msg="Tail-batch mode was never triggered within 10 rollout rounds.", + ) + self.assertIsNotNone( + 
completed_from_tail_batch, + msg="Tail-batch round produced no completed samples.", + ) + self.assertEqual( + completed_from_tail_batch.seq_staleness, 0, + msg=( + f"Tail-batch sample must have seq_staleness=0 (fresh generation), " + f"got seq_staleness={completed_from_tail_batch.seq_staleness}" + ), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/rl/test_auto.py b/tests/rl/test_auto.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/ray/test_cpu_pg.py b/tests/rl/test_cpu_pg.py similarity index 98% rename from tests/ray/test_cpu_pg.py rename to tests/rl/test_cpu_pg.py index c66a5e4954..773e4f3d5a 100644 --- a/tests/ray/test_cpu_pg.py +++ b/tests/rl/test_cpu_pg.py @@ -6,7 +6,7 @@ import httpx import ray -from xtuner.v1.ray.base import AutoCPUWorkers, BaseCPUWorker, CPUResourcesConfig +from xtuner.v1.rl.utils import AutoCPUWorkers, BaseCPUWorker, CPUResourcesConfig @ray.remote(num_cpus=1) diff --git a/tests/rl/test_gateway.py b/tests/rl/test_gateway.py new file mode 100644 index 0000000000..93bceab474 --- /dev/null +++ b/tests/rl/test_gateway.py @@ -0,0 +1,691 @@ +import json +import os +import socket +import subprocess +import tempfile +import threading +import time +import unittest +from pathlib import Path +from typing import Any +from uuid import uuid4 + +import httpx +import ray +import torch + +from xtuner.v1.rl.gateway.adapters import build_api_key_trace_key +from xtuner.v1.rl.gateway.config import GatewayConfig +from xtuner.v1.rl.gateway.server import build_local_gateway_app, serve_gateway +from xtuner.v1.rl.rollout import RolloutController +from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.rl.utils import AcceleratorResourcesConfig, AutoAcceleratorWorkers + + +MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"] +RESOURCE_MAP = { + "npu": "NPU", + "cuda": "GPU", +} + + +@unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") +class 
TestGatewayProtocolChain(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + os.environ["XTUNER_USE_FA3"] = "1" + os.environ["LMD_SKIP_WARMUP"] = "1" + + @classmethod + def tearDownClass(cls) -> None: + del os.environ["XTUNER_USE_FA3"] + del os.environ["LMD_SKIP_WARMUP"] + + def setUp(self): + ray.init(address="local", ignore_reinit_error=True) + self.temp_dir = tempfile.TemporaryDirectory() + self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs") + self.capture_output_path = Path(self.temp_dir.name) / "gateway_capture_output" + self.openai_body_output_path = Path(self.temp_dir.name) / "openai_body.json" + self.anthropic_body_output_path = Path(self.temp_dir.name) / "anthropic_body.json" + self.responses_body_output_path = Path(self.temp_dir.name) / "responses_body.json" + self.controller = None + self.placement_group = None + self.test_run_id = uuid4().hex[:8] + + def tearDown(self): + if self.controller is not None: + try: + ray.get(self.controller.shutdown.remote(), timeout=300) + except Exception: + pass + try: + ray.kill(self.controller, no_restart=True) + except Exception: + pass + if self.placement_group is not None: + ray.util.remove_placement_group(self.placement_group) + ray.shutdown() + self._cleanup_lmdeploy_ray_worker_wrapper() + self.temp_dir.cleanup() + + def _cleanup_lmdeploy_ray_worker_wrapper(self): + try: + subprocess.run( + ["pkill", "-f", "ray::RayWorkerWrapper*"], + capture_output=True, + text=True, + timeout=10, + check=False, + ) + except Exception: + return + + def _build_controller(self): + resource_config = AcceleratorResourcesConfig( + accelerator=RESOURCE_MAP[torch.accelerator.current_accelerator().type], + num_workers=4, + num_cpus_per_worker=16, + cpu_memory_per_worker=8 * 1024**3, + ) + self.placement_group = AutoAcceleratorWorkers.build_placement_group( + resource_config, + name=f"gateway_protocol_pg_{self.test_run_id}", + ) + rollout_config = RolloutConfig( + 
env=f"test_gateway_protocol_{self.test_run_id}", + model_path=MODEL_PATH, + model_name=os.path.basename(MODEL_PATH).lower(), + tokenizer_path=MODEL_PATH, + tool_call_parser="qwen3", + reasoning_parser="qwen3", + tensor_parallel_size=4, + expert_parallel_size=1, + context_length=1536, + worker_log_dir=os.path.join(self.worker_log_dir, "gateway"), + dist_port_base=42000, + api_host="127.0.0.1", + api_port=30080, + ) + return ray.remote(RolloutController).remote(rollout_config, self.placement_group) + + def _get_rollout_config(self) -> RolloutConfig: + rollout_metadata = ray.get(self.controller.get_rollout_metadata.remote()) + return rollout_metadata["rollout_config"] + + def _read_capture_records(self) -> list[dict]: + if not self.capture_output_path.exists(): + return [] + if self.capture_output_path.is_file(): + with self.capture_output_path.open("r", encoding="utf-8") as f: + return [json.loads(line) for line in f] + records = [] + for capture_file in sorted(self.capture_output_path.glob("*.jsonl")): + with capture_file.open("r", encoding="utf-8") as f: + records.extend(json.loads(line) for line in f) + return records + + def _write_json_output(self, path: Path, payload: dict) -> None: + path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8") + + def _capture_records_by_protocol(self, capture_records: list[dict]) -> dict[str, dict]: + return {record["protocol"]: record for record in capture_records} + + def _assert_trace_record_matches_capture( + self, + trace_record, + capture_record: dict, + *, + expected_request_field: str, + expected_request_role: str | None, + expected_response_field: str, + ) -> None: + self.assertIsNotNone(trace_record) + self.assertEqual(trace_record.request_id, capture_record["request_id"]) + self.assertEqual(trace_record.finish_reason, capture_record["rollout_finish_reason"] or capture_record["finish_reason"]) + self.assertEqual(trace_record.status.value, capture_record["status"]) + 
self.assertGreater(len(trace_record.prompt_ids), 0) + self.assertGreater(len(trace_record.response_ids), 0) + self.assertTrue(trace_record.input_text) + self.assertTrue(trace_record.output_text) + self.assertGreater(capture_record["prompt_tokens"], 0) + self.assertGreater(capture_record["completion_tokens"], 0) + self.assertTrue(capture_record["input_text"]) + self.assertIn(expected_request_field, trace_record.request_snapshot) + if expected_request_role is not None: + self.assertEqual(trace_record.request_snapshot[expected_request_field][0]["role"], expected_request_role) + self.assertIn(expected_response_field, trace_record.response_snapshot) + + def _find_free_port(self) -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + return int(sock.getsockname()[1]) + + def _wait_for_gateway_ready(self, base_url: str, *, timeout_seconds: float = 120.0) -> None: + deadline = time.time() + timeout_seconds + last_error = None + while time.time() < deadline: + try: + response = httpx.get(f"{base_url}/livez", timeout=5.0) + if response.status_code == 200: + return + except Exception as exc: + last_error = exc + time.sleep(1.0) + if last_error is not None: + raise AssertionError(f"Gateway did not become ready at {base_url}: {last_error}") from last_error + raise AssertionError(f"Gateway did not become ready at {base_url}") + + def _serve_gateway_blocking_in_background(self, app, config: GatewayConfig) -> tuple[str, threading.Thread]: + thread = threading.Thread( + target=serve_gateway, + args=(app, config), + daemon=True, + name=f"gateway-blocking-{config.port}", + ) + thread.start() + base_url = self._wait_for_gateway_ready_from_config(config) + return base_url, thread + + def _wait_for_gateway_ready_from_config(self, config: GatewayConfig, *, timeout_seconds: float = 120.0) -> str: + deadline = time.time() + timeout_seconds + last_error = None + while time.time() < deadline: + base_url = f"http://127.0.0.1:{config.port}" + 
try: + response = httpx.get(f"{base_url}/livez", timeout=5.0) + if response.status_code == 200: + return base_url + except Exception as exc: + last_error = exc + time.sleep(1.0) + if last_error is not None: + raise AssertionError(f"Gateway did not become ready for config port {config.port}: {last_error}") from last_error + raise AssertionError(f"Gateway did not become ready for config port {config.port}") + + def _post_json( + self, + base_url: str, + path: str, + payload: dict, + *, + api_key: str | None = None, + ) -> httpx.Response: + headers = {"Authorization": f"Bearer {api_key}"} if api_key else None + return httpx.post(f"{base_url}{path}", json=payload, headers=headers, timeout=120.0) + + def _get_json(self, base_url: str, path: str) -> httpx.Response: + return httpx.get(f"{base_url}{path}", timeout=30.0) + + def start_rollout_controller_and_gateway(self) -> tuple[RolloutConfig, GatewayConfig, str, Any]: + self.controller = self._build_controller() + rollout_config = self._get_rollout_config() + gateway_config = GatewayConfig(port=self._find_free_port(), capture_folder=str(self.capture_output_path)) + app = build_local_gateway_app(self.controller, config=gateway_config) + base_url, _ = self._serve_gateway_blocking_in_background(app, gateway_config) + return rollout_config, gateway_config, base_url, app + + def test_gateway_runtime_endpoints(self): + rollout_config, _, base_url, _ = self.start_rollout_controller_and_gateway() + + livez_response = self._get_json(base_url, "/livez") + self.assertEqual(livez_response.status_code, 200, livez_response.text) + self.assertEqual(livez_response.json(), {"status": "ok"}) + + readyz_response = self._get_json(base_url, "/readyz") + self.assertEqual(readyz_response.status_code, 200, readyz_response.text) + readyz_body = readyz_response.json() + self.assertTrue(readyz_body["ready"]) + self.assertEqual(readyz_body["status"], "ready") + self.assertIsInstance(readyz_body["details"], dict) + + capabilities_response = 
self._get_json(base_url, "/capabilities") + self.assertEqual(capabilities_response.status_code, 200, capabilities_response.text) + capabilities_body = capabilities_response.json() + self.assertEqual(capabilities_body["model"], rollout_config.model_name) + self.assertEqual(capabilities_body["backend"], rollout_config.rollout_backend) + self.assertEqual(capabilities_body["context_length"], rollout_config.context_length) + self.assertTrue(capabilities_body["supports_stream"]) + self.assertTrue(capabilities_body["supports_tools"]) + self.assertFalse(capabilities_body["supports_cancel"]) + self.assertTrue(capabilities_body["supports_parallel_tool_calls"]) + self.assertTrue(capabilities_body["supports_reasoning"]) + + def test_gateway_messages(self): + rollout_config, _, base_url, app = self.start_rollout_controller_and_gateway() + + openai_payload = { + "model": rollout_config.model_name, + "messages": [ + {"role": "user", "content": "你好,请用一句话介绍自己。"}, + ], + "max_tokens": 256, + } + anthropic_payload = { + "model": rollout_config.model_name, + "messages": [ + {"role": "user", "content": [{"type": "text", "text": "今天北京天气怎么样?"}]}, + ], + "tools": [ + { + "name": "get_weather", + "description": "查询指定城市的实时天气", + "input_schema": { + "type": "object", + "properties": {"city": {"type": "string", "description": "城市名称"}}, + "required": ["city"], + }, + } + ], + "tool_choice": {"type": "auto"}, + "max_tokens": 512, + } + responses_payload = { + "model": rollout_config.model_name, + "instructions": "你是一个数学助手,回答要简洁。", + "input": [ + { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": "1 + 1 等于几?"}], + }, + ], + "max_output_tokens": 1024, + } + + openai_adapter = app.state.gateway_openai_adapter + anthropic_adapter = app.state.gateway_anthropic_adapter + responses_adapter = app.state.gateway_responses_adapter + openai_api_key = "trace-openai" + anthropic_api_key = "trace-anthropic" + responses_api_key = "trace-responses" + openai_trace_key = 
build_api_key_trace_key(openai_api_key) + anthropic_trace_key = build_api_key_trace_key(anthropic_api_key) + responses_trace_key = build_api_key_trace_key(responses_api_key) + + openai_response = self._post_json( + base_url, + "/v1/chat/completions", + openai_payload, + api_key=openai_api_key, + ) + self.assertEqual(openai_response.status_code, 200, openai_response.text) + openai_body = openai_response.json() + self._write_json_output(self.openai_body_output_path, openai_body) + self.assertEqual(openai_body["model"], rollout_config.model_name) + self.assertEqual(openai_body["choices"][0]["message"]["role"], "assistant") + self.assertIn(openai_body["choices"][0]["finish_reason"], {"stop", "length"}) + self.assertGreater(openai_body["usage"]["prompt_tokens"], 0) + self.assertTrue(openai_body["choices"][0]["message"].get("content")) + + anthropic_response = self._post_json( + base_url, + "/v1/messages", + anthropic_payload, + api_key=anthropic_api_key, + ) + self.assertEqual(anthropic_response.status_code, 200, anthropic_response.text) + anthropic_body = anthropic_response.json() + self._write_json_output(self.anthropic_body_output_path, anthropic_body) + self.assertEqual(anthropic_body["type"], "message") + self.assertEqual(anthropic_body["role"], "assistant") + self.assertEqual(anthropic_body["model"], rollout_config.model_name) + self.assertGreater(anthropic_body["usage"]["input_tokens"], 0) + self.assertTrue(anthropic_body["content"]) + + responses_response = self._post_json( + base_url, + "/v1/responses", + responses_payload, + api_key=responses_api_key, + ) + self.assertEqual(responses_response.status_code, 200, responses_response.text) + responses_body = responses_response.json() + self._write_json_output(self.responses_body_output_path, responses_body) + self.assertEqual(responses_body["object"], "response") + self.assertEqual(responses_body["model"], rollout_config.model_name) + self.assertGreater(responses_body["usage"]["input_tokens"], 0) + 
self.assertTrue(responses_body["output"]) + + openai_traces = openai_adapter.get_trace_records(openai_trace_key) + anthropic_traces = anthropic_adapter.get_trace_records(anthropic_trace_key) + responses_traces = responses_adapter.get_trace_records(responses_trace_key) + self.assertEqual(len(openai_traces), 1) + self.assertEqual(len(anthropic_traces), 1) + self.assertEqual(len(responses_traces), 1) + openai_trace = openai_traces[0] + anthropic_trace = anthropic_traces[0] + responses_trace = responses_traces[0] + self.assertEqual(openai_trace.trace_key, openai_trace_key) + self.assertEqual(anthropic_trace.trace_key, anthropic_trace_key) + self.assertEqual(responses_trace.trace_key, responses_trace_key) + self.assertNotEqual(openai_trace.trace_key, openai_api_key) + self.assertNotEqual(anthropic_trace.trace_key, anthropic_api_key) + self.assertNotEqual(responses_trace.trace_key, responses_api_key) + self.assertEqual(openai_trace.sequence, 0) + self.assertEqual(anthropic_trace.sequence, 0) + self.assertEqual(responses_trace.sequence, 0) + self.assertGreater(openai_trace.created_at, 0.0) + self.assertGreater(anthropic_trace.created_at, 0.0) + self.assertGreater(responses_trace.created_at, 0.0) + + capture_records = self._read_capture_records() + self.assertGreaterEqual(len(capture_records), 3) + protocol_records = {record["protocol"]: record for record in capture_records[-3:]} + self.assertIn("OpenAIChatAdapter", protocol_records) + self.assertIn("AnthropicChatAdapter", protocol_records) + self.assertIn("OpenAIResponsesAdapter", protocol_records) + + openai_record = protocol_records["OpenAIChatAdapter"] + self.assertTrue(openai_record["internal_messages"]) + self.assertEqual(openai_record["request_id"], openai_trace.request_id) + self.assertEqual(openai_record["output_messages"][0]["role"], "assistant") + self.assertTrue(openai_record["input_text"]) + self._assert_trace_record_matches_capture( + openai_trace, + openai_record, + expected_request_field="messages", + 
expected_request_role="user", + expected_response_field="choices", + ) + self.assertEqual(openai_trace.response_snapshot["choices"][0]["message"]["role"], "assistant") + + anthropic_record = protocol_records["AnthropicChatAdapter"] + self.assertEqual(anthropic_record["request_id"], anthropic_trace.request_id) + self.assertTrue(anthropic_record["output_messages"][0]["content"]) + self._assert_trace_record_matches_capture( + anthropic_trace, + anthropic_record, + expected_request_field="messages", + expected_request_role="user", + expected_response_field="content", + ) + self.assertEqual(anthropic_trace.request_snapshot["messages"][0]["role"], "user") + self.assertEqual(anthropic_trace.response_snapshot["role"], "assistant") + + responses_record = protocol_records["OpenAIResponsesAdapter"] + self.assertTrue(responses_record["output_messages"]) + self.assertEqual(responses_record["request_id"], responses_trace.request_id) + self.assertTrue(responses_record["input_text"]) + self._assert_trace_record_matches_capture( + responses_trace, + responses_record, + expected_request_field="input", + expected_request_role=None, + expected_response_field="output", + ) + self.assertEqual(responses_trace.response_snapshot["status"], "completed") + + openai_trace_get_response = httpx.get( + f"{base_url}/trace_store", + headers={"Authorization": f"Bearer {openai_api_key}"}, + timeout=30.0, + ) + self.assertEqual(openai_trace_get_response.status_code, 200, openai_trace_get_response.text) + openai_trace_get_body = openai_trace_get_response.json() + self.assertEqual(openai_trace_get_body["trace_key"], openai_trace_key) + self.assertEqual(openai_trace_get_body["count"], 1) + self.assertEqual(openai_trace_get_body["records"][0]["request_id"], openai_trace.request_id) + self.assertEqual(openai_trace_get_body["records"][0]["status"], openai_trace.status.value) + self.assertEqual(openai_trace_get_body["records"][0]["sequence"], openai_trace.sequence) + + openai_trace_pop_response = 
httpx.post( + f"{base_url}/trace_store/pop", + headers={"Authorization": f"Bearer {openai_api_key}"}, + timeout=30.0, + ) + self.assertEqual(openai_trace_pop_response.status_code, 200, openai_trace_pop_response.text) + openai_trace_pop_body = openai_trace_pop_response.json() + self.assertEqual(openai_trace_pop_body["trace_key"], openai_trace_key) + self.assertEqual(openai_trace_pop_body["count"], 1) + self.assertEqual(openai_trace_pop_body["records"][0]["request_id"], openai_trace.request_id) + + anthropic_trace_pop_response = httpx.post( + f"{base_url}/trace_store/pop", + params={"trace_key": anthropic_trace_key}, + timeout=30.0, + ) + self.assertEqual(anthropic_trace_pop_response.status_code, 200, anthropic_trace_pop_response.text) + anthropic_trace_pop_body = anthropic_trace_pop_response.json() + self.assertEqual(anthropic_trace_pop_body["trace_key"], anthropic_trace_key) + self.assertEqual(anthropic_trace_pop_body["count"], 1) + self.assertEqual(anthropic_trace_pop_body["records"][0]["request_id"], anthropic_trace.request_id) + + responses_trace_clear_response = httpx.post( + f"{base_url}/trace_store/clear", + params={"trace_key": responses_trace_key}, + timeout=30.0, + ) + self.assertEqual(responses_trace_clear_response.status_code, 200, responses_trace_clear_response.text) + responses_trace_clear_body = responses_trace_clear_response.json() + self.assertEqual(responses_trace_clear_body["trace_key"], responses_trace_key) + self.assertTrue(responses_trace_clear_body["cleared"]) + + self.assertEqual(openai_adapter.get_trace_records(openai_trace_key), []) + self.assertEqual(anthropic_adapter.get_trace_records(anthropic_trace_key), []) + self.assertEqual(responses_adapter.get_trace_records(responses_trace_key), []) + + def test_gateway_ir_fallback_behavior(self): + self.controller = self._build_controller() + rollout_config = self._get_rollout_config() + gateway_config = GatewayConfig(port=self._find_free_port(), capture_folder=str(self.capture_output_path)) + app 
= build_local_gateway_app(self.controller, config=gateway_config) + base_url, _ = self._serve_gateway_blocking_in_background(app, gateway_config) + + openai_payload = { + "model": rollout_config.model_name, + "messages": [ + {"role": "user", "content": "Call the search tool if you need it."}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_bad_openai", + "type": "function", + "function": { + "name": "search", + "arguments": "not-json", + }, + } + ], + }, + { + "role": "tool", + "tool_call_id": "call_bad_openai", + "content": "Sunny, 26C", + }, + {"role": "user", "content": "Finish the answer in one sentence. DONE"}, + ], + "tools": [ + { + "type": "function", + "function": { + "name": "search", + "description": "Search the latest weather.", + "parameters": { + "type": "object", + "properties": {"q": {"type": "string"}}, + "required": ["q"], + }, + }, + } + ], + "tool_choice": { + "type": "function", + "function": {"name": "search"}, + }, + "temperature": 0.2, + "top_p": 0.9, + "presence_penalty": 0.6, + "frequency_penalty": 0.4, + "stop": ["DONE"], + "max_tokens": 32, + } + openai_invalid_n_payload = { + **openai_payload, + "n": 2, + } + anthropic_payload = { + "model": rollout_config.model_name, + "system": [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": "abc", + }, + } + ], + "messages": [ + {"role": "user", "content": [{"type": "text", "text": "hello"}]}, + ], + "max_tokens": 32, + } + responses_payload = { + "model": rollout_config.model_name, + "instructions": "Follow the system rule.", + "input": [ + { + "type": "message", + "role": "developer", + "content": [{"type": "input_text", "text": "Use concise answers."}], + }, + { + "type": "reasoning", + "summary": [{"type": "summary_text", "text": "Need private reasoning first."}], + }, + { + "type": "function_call", + "call_id": "call_bad_responses", + "name": "search", + "arguments": "not-json", + }, + { + "type": 
"function_call_output", + "call_id": "call_bad_responses", + "output": [{"type": "text", "text": "Sunny, 26C"}], + }, + { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": "Answer now."}], + }, + ], + "tools": [ + { + "type": "function", + "name": "search", + "description": "Search the latest weather.", + "parameters": { + "type": "object", + "properties": {"q": {"type": "string"}}, + "required": ["q"], + }, + }, + { + "type": "web_search_preview", + "name": "web_search_preview", + }, + ], + "tool_choice": {"type": "function", "name": "search"}, + "parallel_tool_calls": True, + "store": True, + "include": ["reasoning.encrypted_content"], + "reasoning": {"effort": "high"}, + "temperature": 0.1, + "top_p": 0.8, + "max_output_tokens": 32, + } + responses_invalid_content_payload = { + **responses_payload, + "input": [ + { + "type": "message", + "role": "developer", + "content": [ + {"type": "input_text", "text": "Use concise answers."}, + {"type": "image", "image_url": "https://example.com/ignored.png"}, + ], + } + ], + } + responses_stream_payload = { + **responses_payload, + "stream": True, + } + openai_response = self._post_json(base_url, "/v1/chat/completions", openai_payload) + self.assertEqual(openai_response.status_code, 200, openai_response.text) + + openai_invalid_n_response = self._post_json(base_url, "/v1/chat/completions", openai_invalid_n_payload) + self.assertEqual(openai_invalid_n_response.status_code, 400, openai_invalid_n_response.text) + openai_invalid_n_body = openai_invalid_n_response.json() + self.assertEqual(openai_invalid_n_body["error"]["type"], "invalid_request_error") + self.assertEqual(openai_invalid_n_body["error"]["code"], "n_not_supported") + + anthropic_response = self._post_json(base_url, "/v1/messages", anthropic_payload) + self.assertEqual(anthropic_response.status_code, 400, anthropic_response.text) + anthropic_error_body = anthropic_response.json() + self.assertEqual(anthropic_error_body["type"], 
"error") + self.assertEqual(anthropic_error_body["error"]["type"], "invalid_request_error") + self.assertIn("Unsupported Anthropic content block type(s) in system: image", anthropic_error_body["error"]["message"]) + + responses_response = self._post_json(base_url, "/v1/responses", responses_payload) + self.assertEqual(responses_response.status_code, 200, responses_response.text) + + responses_invalid_content_response = self._post_json(base_url, "/v1/responses", responses_invalid_content_payload) + self.assertEqual(responses_invalid_content_response.status_code, 400, responses_invalid_content_response.text) + responses_invalid_content_body = responses_invalid_content_response.json() + self.assertEqual(responses_invalid_content_body["error"]["type"], "invalid_request_error") + self.assertEqual(responses_invalid_content_body["error"]["code"], "unsupported_content_block") + + responses_stream_response = self._post_json(base_url, "/v1/responses", responses_stream_payload) + self.assertEqual(responses_stream_response.status_code, 200, responses_stream_response.text) + self.assertEqual( + responses_stream_response.headers.get("content-type"), + "text/event-stream; charset=utf-8", + ) + self.assertIn("event: response.created", responses_stream_response.text) + self.assertIn("event: response.completed", responses_stream_response.text) + + capture_records = self._read_capture_records() + protocol_records = self._capture_records_by_protocol(capture_records) + self.assertIn("OpenAIChatAdapter", protocol_records) + self.assertIn("OpenAIResponsesAdapter", protocol_records) + self.assertNotIn("AnthropicChatAdapter", protocol_records) + + openai_record = protocol_records["OpenAIChatAdapter"] + self.assertEqual(openai_record["rollout_tool_choice"], {"type": "function", "function": {"name": "search"}}) + self.assertEqual(len(openai_record["rollout_tools"]), 1) + self.assertEqual(openai_record["rollout_tools"][0]["function"]["name"], "search") + 
self.assertEqual(openai_record["rollout_sample_params"]["presence_penalty"], 0.6) + self.assertEqual(openai_record["rollout_sample_params"]["frequency_penalty"], 0.4) + self.assertEqual(openai_record["rollout_sample_params"]["temperature"], 0.2) + self.assertEqual(openai_record["rollout_sample_params"]["top_p"], 0.9) + self.assertEqual(openai_record["rollout_sample_params"]["stops"], ["DONE"]) + self.assertEqual( + openai_record["internal_messages"][1]["tool_calls"][0]["function"]["arguments"], + {"raw": "not-json"}, + ) + + responses_record = protocol_records["OpenAIResponsesAdapter"] + self.assertEqual(responses_record["rollout_tool_choice"], {"type": "function", "function": {"name": "search"}}) + self.assertEqual(len(responses_record["rollout_tools"]), 1) + self.assertEqual(responses_record["rollout_tools"][0]["function"]["name"], "search") + self.assertTrue(responses_record["rollout_sample_params"]["max_tokens"] <= 32) + self.assertEqual(responses_record["rollout_sample_params"]["temperature"], 0.1) + self.assertEqual(responses_record["rollout_sample_params"]["top_p"], 0.8) + self.assertNotIn("store", responses_record["rollout_sample_params"]) + self.assertNotIn("include", responses_record["rollout_sample_params"]) + self.assertEqual(responses_record["internal_messages"][0]["role"], "system") + self.assertEqual(responses_record["internal_messages"][0]["content"], "Follow the system rule.") + self.assertEqual(responses_record["internal_messages"][1]["role"], "system") + self.assertEqual(responses_record["internal_messages"][1]["content"], "Use concise answers.") + self.assertEqual( + responses_record["internal_messages"][3]["tool_calls"][0]["function"]["arguments"], + {"raw": "not-json"}, + ) + +if __name__ == "__main__": + unittest.main() diff --git a/tests/rl/test_judger.py b/tests/rl/test_judger.py new file mode 100644 index 0000000000..3261651b3f --- /dev/null +++ b/tests/rl/test_judger.py @@ -0,0 +1,216 @@ +import os +import json +import ray +import 
unittest +import tempfile +import numpy as np +import asyncio +from xtuner.v1.rl.utils import AutoCPUWorkers, CPUResourcesConfig +from xtuner.v1.data_proto.rl_data import RolloutState + +MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"] +DATA_PATH = os.environ["ROLLOUT_DATA_PATH"] +GEO_ROLLOUT_DATA_PATH = os.environ["GEO_ROLLOUT_DATA_PATH"] +VERL_ROLLOUT_DATA_PATH = os.environ["VERL_ROLLOUT_DATA_PATH"] +DAPO_DATA_PATH = os.environ.get("ROLLOUT_DAPO_DATA_PATH") +FAKE_JUDGER_INPUT_ITEM = RolloutState( + message=[{ + 'role': 'user', + 'content': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? Let\'s think step by step and output the final answer after "####"' + }], + reward_model={'ground_truth': '72', 'style': 'rule'}, + response="\nOkay, let's see. Natalia sold clips to 48 friends in April. Then in May, she sold half as many. So first, I need to figure out how many she sold in May. Half of 48 is 24, right? Because 48 divided by 2 is 24. So in May, she sold 24 clips.\n\nNow, to find the total number of clips sold in both months, I need to add the number from April and May together. That would be 48 (April) plus 24 (May). Let me do the addition: 48 + 24. Hmm, 40 + 20 is 60, and 8 + 4 is 12. So 60 + 12 is 72. So altogether, she sold 72 clips.\n\nWait, let me check that again. 48 plus 24. Yes, 48 + 20 is 68, then plus 4 more is 72. Yep, that seems right. So the total is 72.\n\n\nNatalia sold 48 clips in April. In May, she sold half as many, which is 48 ÷ 2 = 24 clips. Adding both months together: 48 + 24 = 72. 
\n\n#### 72<|im_end|>" +) + +def construct_gsm8k_judger_data(data_path) -> tuple[list[RolloutState], list[float]]: + states = [] + history_reward = [] + if not data_path or not os.path.exists(data_path): + return states, history_reward + with open(data_path, 'r', encoding='utf-8') as f: + for line in f: + item = json.loads(line.strip()) + prompt = item["input"][5:-11] + response = item["output"] + gt = item["gts"] + states.append( + RolloutState( + message=[{"role": "user", "content": prompt}], + response=response, + reward_model={"ground_truth": str(gt)} + ) + ) + history_reward.append(item["reward"]) + return states, history_reward + +def construct_geo3k_dapo_judger_data(data_path) -> tuple[list[RolloutState], list[float]]: + states = [] + history_reward = [] + if not data_path or not os.path.exists(data_path): + return states, history_reward + with open(data_path, 'r', encoding='utf-8') as f: + lines = f.readlines() + for i in range(0, len(lines), 7): + group = ''.join(lines[i:i + 7]).strip() + if not group: continue + item = json.loads(group) + states.append( + RolloutState( + message=[{"role": "user", "content": ""}], + response=item['response'], + reward_model={"ground_truth": str(item["label"])} + ) + ) + history_reward.append(item["reward"]) + return states, history_reward + +class TestJudgerController(unittest.TestCase): + + def setUp(self): + ray.init(num_cpus=80, ignore_reinit_error=True) + self.temp_dir = tempfile.TemporaryDirectory() + self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs") + + def tearDown(self): + ray.shutdown() + self.temp_dir.cleanup() + + async def _judger_batch(self, judger_router, states): + return await asyncio.gather(*(judger_router.judge(s) for s in states)) + + def test_gsm8k_judger(self): + from xtuner.v1.rl.judger.gsm8k import GSM8KJudgerConfig + + gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k", num_ray_actors=1) + # Test Case 1: NativeJudger + native_judger = GSM8KJudgerConfig(judger_name="openai/gsm8k").build() + res1 = 
asyncio.run(native_judger.judge(FAKE_JUDGER_INPUT_ITEM)) + self.assertEqual(res1.reward["score"], 1.0) + + # Test Case 2: remote judger with given pg + cpu_cfg = CPUResourcesConfig(num_workers=1, num_cpus_per_worker=1) + pg = AutoCPUWorkers.build_placement_group(cpu_cfg) + ray.get(pg.ready()) + native_judger_actors = gsm8k_judger_config.build(pg, 0) + res2 = asyncio.run(native_judger_actors.judge(FAKE_JUDGER_INPUT_ITEM)) + self.assertEqual(res2.reward["score"], 1.0) + del native_judger_actors + + # Test Case 3: JudgerPool + 一批数据的分数是否正确 + judger_router = gsm8k_judger_config.build(pg) + states, history_reward = construct_gsm8k_judger_data(VERL_ROLLOUT_DATA_PATH) + rollout_states = asyncio.run(self._judger_batch(judger_router, states)) + rewards = [s.reward["score"] for s in rollout_states] + expected_avg_score = np.mean(history_reward) + self.assertEqual(round(np.mean(rewards), 4), round(expected_avg_score, 4)) + + def test_dapo_batch_judge_score(self): + # 测试 dapo judger + 1 个实例池 的评判分数是否正确 + from xtuner.v1.rl.judger.dapo_math import DapoMathJudgerConfig + from xtuner.v1.utils.rl_test_utils import get_eos_token + from transformers import AutoTokenizer + # 构建数据 + states, history_reward = construct_geo3k_dapo_judger_data(DAPO_DATA_PATH) + tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) + eos_token = get_eos_token(MODEL_PATH) + eos_token_str = tokenizer.convert_ids_to_tokens(eos_token) + # 定义 Judger Config + config = DapoMathJudgerConfig( + judger_name="dapo_math", + num_ray_actors=1, + eos_token=eos_token_str, + enable_overlong_buffer=True, + max_response_len=32768, + overlong_buffer_len=4096, + overlong_penalty_factor=1.0, + tokenizer=tokenizer + ) + router = config.build() + rollout_states = asyncio.run(self._judger_batch(router, states)) + rewards = [s.reward["score"] for s in rollout_states] + expected_avg_score = np.mean(history_reward) + self.assertEqual(round(np.mean(rewards), 4), round(expected_avg_score, 4)) + + def 
test_geo_batch_judge_score(self): + # 测试 geo judger + 4 个实例池的评判分数是否正确 + from xtuner.v1.rl.judger.geo3k import GEO3KJudgerConfig + config = GEO3KJudgerConfig(judger_name="geo3k", num_ray_actors=4) + states, history_reward = construct_geo3k_dapo_judger_data(GEO_ROLLOUT_DATA_PATH) + router = config.build() + rollout_states = asyncio.run(self._judger_batch(router, states)) + rewards = [s.reward["score"] for s in rollout_states] + expected_avg_score = np.mean(history_reward) + self.assertEqual(round(np.mean(rewards), 4), round(expected_avg_score, 4)) + # 验证Router中确实有4个Worker实例在运行 + self.assertEqual(len(router.get_worker_status()), 4) + + def test_multi_judger_router(self): + import time + from xtuner.v1.rl.judger.gsm8k import GSM8KJudgerConfig + + gsm8k_config_1 = GSM8KJudgerConfig( + judger_name="openai/gsm8k_1", + num_ray_actors=2, + num_cpus_per_actor=1, + ) + gsm8k_config_2 = GSM8KJudgerConfig( + judger_name="openai/gsm8k_2", + num_ray_actors=8, + num_cpus_per_actor=2, + ) + + gsm8k_router_1 = gsm8k_config_1.build() + gsm8k_router_2 = gsm8k_config_2.build() + + states, history_reward = construct_gsm8k_judger_data(VERL_ROLLOUT_DATA_PATH) + gsm8k_results_1 = asyncio.run(self._judger_batch(gsm8k_router_1, states)) + gsm8k_results_2 = asyncio.run(self._judger_batch(gsm8k_router_2, states)) + + gsm8k_rewards_1 = [s.reward["score"] for s in gsm8k_results_1] + gsm8k_rewards_2 = [s.reward["score"] for s in gsm8k_results_2] + + expected_avg_score = np.mean(history_reward) + self.assertEqual(round(np.mean(gsm8k_rewards_1), 4), round(expected_avg_score, 4)) + self.assertEqual(round(np.mean(gsm8k_rewards_2), 4), round(expected_avg_score, 4)) + self.assertEqual(len(gsm8k_router_1.get_worker_status()), 2) + self.assertEqual(len(gsm8k_router_2.get_worker_status()), 8) + + def test_gsm8k_remote_judger(self): + # 测试输入remote_url时 + 1个实例 + 裸的NativeJudger的评判分数是否正确 + from xtuner.v1.utils.rl_test_utils import JudgerServer, GSM8KRemoteJudgerConfig + + server = JudgerServer(port=8018) + 
server.start() + try: + remote_judger_config = GSM8KRemoteJudgerConfig(judger_name="openai/gsm8k", reward_handler=server.url) + native_remote_judger = remote_judger_config.build() + res = asyncio.run(native_remote_judger.judge(FAKE_JUDGER_INPUT_ITEM)) + self.assertEqual(res.reward["score"], 1.0) + finally: + server.stop() + + def test_composed_judger_config(self): + from xtuner.v1.rl.judger import ComposedJudgerConfig, JudgerConfig + + def reward_a(response, label, extra_info): + return {"score": 1.0, "source": "a"} + + def reward_b(response, label, extra_info): + return {"score": 0.25, "source": "b"} + + judger_config = ComposedJudgerConfig( + branches={ + "correctness": JudgerConfig(judger_name="correctness", reward_handler=reward_a), + "format": JudgerConfig(judger_name="format", reward_handler=reward_b), + }, + select_fn=lambda state, branches: ["correctness", "format"], + ) + + judger = judger_config.build() + rollout_state = asyncio.run(judger.judge(FAKE_JUDGER_INPUT_ITEM.model_copy(deep=True))) + + self.assertEqual(rollout_state.reward["correctness"], 1.0) + self.assertEqual(rollout_state.reward["format"], 0.25) + +if __name__ == "__main__": + unittest.main() diff --git a/tests/rl/test_mock_rollout.py b/tests/rl/test_mock_rollout.py new file mode 100644 index 0000000000..4ac1ec4b39 --- /dev/null +++ b/tests/rl/test_mock_rollout.py @@ -0,0 +1,193 @@ +import os +import asyncio +import unittest +import ray +from transformers import AutoTokenizer +import torch +import tempfile +import httpx +from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.rl.rollout.lmdeploy import LMDeployWorker +from xtuner.v1.rl.utils import AcceleratorResourcesConfig, AutoAcceleratorWorkers +from xtuner.v1.data_proto.rl_data import RolloutState, Status +from xtuner.v1.rl.rollout.controller import RolloutController +from xtuner.v1.utils.httpx_utils import HttpRequestErrorType, HttpRequestResult + +TEST_TEXT_MESSAGES=[{"role": "user", "content": "Hello!"}] +MODEL_PATH = 
os.environ["ROLLOUT_MODEL_PATH"] +TRAIN_DATA_PATH = os.environ["ROLLOUT_DATA_PATH"] +resource_map = {"npu": "NPU", "cuda": "GPU"} + +class MockTimeoutRolloutWorker(LMDeployWorker): + async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult: + try: + raise httpx.TimeoutException("Mocked timeout error") + except Exception as e: + error_type = HttpRequestErrorType.from_exception(e) + result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload) + self.logger.info(f"Caught mocked timeout exception: {e.__class__.__name__}") + return result + + def _launch_server(self): + pass # Override + + +class MockRequestErrorRolloutWorker(LMDeployWorker): + async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult: + try: + req = httpx.Request("POST", url) + raise httpx.RequestError("Mocked httpx request error", request=req) + except Exception as e: + error_type = HttpRequestErrorType.from_exception(e) + result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload) + self.logger.info(f"Caught mocked request error exception: {e.__class__.__name__}") + return result + + def _launch_server(self): + pass # Override + + +class MockClientErrorRolloutWorker(LMDeployWorker): + async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult: + try: + req = httpx.Request("POST", url) + res = httpx.Response(400, request=req) + raise httpx.HTTPStatusError("Mocked client error", request=req, response=res) + except Exception as e: + error_type = HttpRequestErrorType.from_exception(e) + result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload) + self.logger.info(f"Caught mocked client exception: {e.__class__.__name__}") + return result + + def _launch_server(self): + pass # Override + + +class MockServerErrorRolloutWorker(LMDeployWorker): + async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult: + try: + req = httpx.Request("POST", 
url) + res = httpx.Response(500, request=req) + raise httpx.HTTPStatusError("Mocked server error", request=req, response=res) + except Exception as e: + error_type = HttpRequestErrorType.from_exception(e) + result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload) + self.logger.info(f"Caught mocked server exception: {e.__class__.__name__}") + return result + + def _launch_server(self): + pass # Override + +class MockInvalidResponseRolloutWorker(LMDeployWorker): + async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult: + mock_rollout_state = RolloutState(message=TEST_TEXT_MESSAGES, status=Status.FAILED) + result = HttpRequestResult(response=mock_rollout_state) + return result + + async def _safe_handle_response(self, rollout_state, http_response) -> RolloutState: + mock_rollout_state = RolloutState(message=TEST_TEXT_MESSAGES, status=Status.FAILED) + return mock_rollout_state + + def _launch_server(self): + pass # Override + +@ray.remote +class MockTimeoutRolloutController(RolloutController): + def _get_worker_cls(self): return ray.remote(MockTimeoutRolloutWorker) + +@ray.remote +class MockRequestErrorRolloutController(RolloutController): + def _get_worker_cls(self): return ray.remote(MockRequestErrorRolloutWorker) + +@ray.remote +class MockClientErrorRolloutController(RolloutController): + def _get_worker_cls(self): return ray.remote(MockClientErrorRolloutWorker) + +@ray.remote +class MockServerErrorRolloutController(RolloutController): + def _get_worker_cls(self): return ray.remote(MockServerErrorRolloutWorker) + +@ray.remote +class MockInvalidResponseRolloutController(RolloutController): + def _get_worker_cls(self): return ray.remote(MockInvalidResponseRolloutWorker) + +class TestMockRollout(unittest.TestCase): + @classmethod + def setUpClass(cls): + os.environ["XTUNER_USE_FA3"] = "1" + + @classmethod + def tearDownClass(cls): + del os.environ["XTUNER_USE_FA3"] + + def setUp(self): + current_dir = 
os.path.abspath(os.path.dirname(__file__)) + python_path = f"{current_dir}:{os.environ.get('PYTHONPATH', '')}" + + ray.init(num_cpus=80, ignore_reinit_error=True, runtime_env={"env_vars": {"PYTHONPATH": python_path}}) + self.global_batch_size = 3 + self.max_prompt_length = 4096 + self.max_response_length = 128 + self.max_concurrent = 3 + self.max_retry_times = 3 + self.temp_dir = tempfile.TemporaryDirectory() + self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs") + self.rollout_cfg = RolloutConfig( + env="test_mock_rollout", + model_path=MODEL_PATH, + model_name=os.path.basename(MODEL_PATH).lower(), + tokenizer_path=MODEL_PATH, + tensor_parallel_size=1, + context_length=self.max_prompt_length + self.max_response_length, + max_retry_per_worker=2, + max_retry_per_sample=3, + worker_log_dir=self.worker_log_dir, + ) + self.tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) + + def tearDown(self): + ray.shutdown() + self.temp_dir.cleanup() + + async def _run_mock_test(self, mock_controller_cls, error_name, pg): + rollout_controller = mock_controller_cls.remote(self.rollout_cfg, pg) + input_state = RolloutState(message=TEST_TEXT_MESSAGES) + result_state = await rollout_controller.generate.remote(rollout_state=input_state) + self.assertEqual(result_state.status, Status.FAILED, f"Expected rollout to fail due to {error_name}, but it succeeded.") + self.assertIsNotNone(result_state.error_msg, f"Expected an error message for {error_name} case, but got None.") + if error_name == "server_error": + self.assertIn("Server error", result_state.error_msg, f"Expected error message to indicate a server error for {error_name} case, but got: {result_state.error_msg}") + elif error_name == "client_error": + self.assertIn("Client error", result_state.error_msg, f"Expected error message to indicate a client error for {error_name} case, but got: {result_state.error_msg}") + elif error_name in ["request_error", "timeout"]: + self.assertIn("Request 
failed", result_state.error_msg, f"Expected error message to indicate a request error for {error_name} case, but got: {result_state.error_msg}") + self.assertIn(str(self.rollout_cfg.max_retry_per_sample), result_state.error_msg, f"Expected error message to include max retry times for {error_name} case, but got: {result_state.error_msg}") + elif error_name == "invalid_response": + self.assertIn("Invalid rollout response", result_state.error_msg, f"Expected error message to indicate an invalid response for {error_name} case, but got: {result_state.error_msg}") + self.assertIn(str(self.rollout_cfg.max_retry_per_sample), result_state.error_msg, f"Expected error message to include max retry times for {error_name} case, but got: {result_state.error_msg}") + + @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") + def test_parallel_mock_rollout(self): + async def run_parallel(): + res_cfg_small = AcceleratorResourcesConfig( + accelerator=resource_map[torch.accelerator.current_accelerator().type], + num_workers=1, + num_cpus_per_worker=2, + ) + + pgs = [AutoAcceleratorWorkers.build_placement_group(res_cfg_small, name=f"pg_{i}") for i in range(5)] + await asyncio.gather(*[pg.ready() for pg in pgs]) + + tasks = [ + self._run_mock_test(MockTimeoutRolloutController, "timeout", pgs[0]), + self._run_mock_test(MockRequestErrorRolloutController, "request_error", pgs[1]), + self._run_mock_test(MockClientErrorRolloutController, "client_error", pgs[2]), + self._run_mock_test(MockServerErrorRolloutController, "server_error", pgs[3]), + self._run_mock_test(MockInvalidResponseRolloutController, "invalid_response", pgs[4]), + ] + await asyncio.gather(*tasks) + + asyncio.run(run_parallel()) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/tests/rl/test_multi_task_agent_loop_manager.py b/tests/rl/test_multi_task_agent_loop_manager.py new file mode 100644 index 0000000000..931404ca5f --- /dev/null +++ 
b/tests/rl/test_multi_task_agent_loop_manager.py @@ -0,0 +1,188 @@ +import unittest +from unittest.mock import AsyncMock, MagicMock + +from xtuner.v1.rl.agent_loop.agent_loop_manager import ( + AgentLoopManager, + AgentLoopManagerConfig, + TaskSpecConfig, + _TaskRunner, +) +from xtuner.v1.rl.agent_loop.producer import ProducerTimings +from xtuner.v1.data_proto import Status + + +class _FakeSampler: + def __init__(self, size: int = 1): + self._size = size + + def __len__(self) -> int: + return self._size + + def save(self, checkpoint_path): + return None + + def resume(self, checkpoint_path): + return None + + +class _FakeProduceStrategy: + def __init__(self, generate_times_s: list[float]): + self.generate_times_s = generate_times_s + self.called_batch_sizes: list[int] = [] + + async def produce_batch( + self, + agent_loop, + sampler, + replay_buffer, + batch_size: int, + task_name: str, + rollout_step: int = 0, + ) -> ProducerTimings: + self.called_batch_sizes.append(batch_size) + return ProducerTimings(generate_times_s=self.generate_times_s) + + +class _FakeReplayBuffer: + def __init__(self, rollout_states_by_task: dict[str, list[list[str]]], leftover_counts: dict[tuple[str, Status], int]): + self._rollout_states_by_task = rollout_states_by_task + self._leftover_counts = leftover_counts + + async def get(self, batch_size: int, task_name: str, group_status: Status): + assert group_status == Status.COMPLETED + return self._rollout_states_by_task.get(task_name, []) + + async def count(self, task_name: str, group_status: Status): + return self._leftover_counts.get((task_name, group_status), 0) + +def _fake_agent_loop(): + rollout_ctl = MagicMock() + rollout_ctl.continue_generation.remote = AsyncMock() + rollout_ctl.pause_generation.remote = AsyncMock() + rollout_ctl.get_rollout_metadata.remote = AsyncMock(return_value={"server_url_dict": {}}) + agent_loop = MagicMock() + agent_loop.rollout_ctl = rollout_ctl + return agent_loop + + +class 
TestMultiTaskAgentLoopManager(unittest.IsolatedAsyncioTestCase): + def test_manager_config_accepts_single_task_spec(self): + task = TaskSpecConfig.model_construct( + task_name="single_task", + agent_loop_config=MagicMock(), + produce_strategy_config=MagicMock(), + sampler_config=MagicMock(), + weight=1.0, + ) + + manager_config = AgentLoopManagerConfig(tasks=task) + + self.assertEqual(manager_config.tasks.task_name, "single_task") + + async def test_produce_batch_allocates_by_weight_and_returns_task_sorted_results(self): + strategy_a = _FakeProduceStrategy(generate_times_s=[2.0, 2.0]) + strategy_b = _FakeProduceStrategy(generate_times_s=[1.0, 1.0, 1.0]) + strategy_c = _FakeProduceStrategy(generate_times_s=[]) + replay_buffer = _FakeReplayBuffer( + rollout_states_by_task={ + "task_a": [["a-0"], ["a-1"]], + "task_b": [["b-0"], ["b-1"], ["b-2"]], + "task_c": [], + }, + leftover_counts={ + ("task_a", Status.COMPLETED): 1, + ("task_b", Status.ABORTED): 2, + }, + ) + + multi_task_manager = AgentLoopManager( + task_runners=[ + _TaskRunner( + task_name="task_b", + agent_loop=_fake_agent_loop(), + produce_strategy=strategy_b, + sampler=_FakeSampler(), + weight=1.0, + order=0, + ), + _TaskRunner( + task_name="task_a", + agent_loop=_fake_agent_loop(), + produce_strategy=strategy_a, + sampler=_FakeSampler(), + weight=2.0, + order=1, + ), + _TaskRunner( + task_name="task_c", + agent_loop=_fake_agent_loop(), + produce_strategy=strategy_c, + sampler=_FakeSampler(), + weight=0.0, + order=2, + ), + ], + replay_buffer=replay_buffer, + ) + + result = await multi_task_manager.produce_batch(batch_size=7, rollout_step=3) + + self.assertEqual(result.task_batch_sizes, {"task_a": 5, "task_b": 2, "task_c": 0}) + self.assertEqual(strategy_a.called_batch_sizes, [5]) + self.assertEqual(strategy_b.called_batch_sizes, [2]) + self.assertEqual(strategy_c.called_batch_sizes, []) + self.assertEqual(result.rollout_states, [["a-0"], ["a-1"], ["b-0"], ["b-1"], ["b-2"]]) + 
self.assertEqual(result.leftover_completed, 1) + self.assertEqual(result.leftover_aborted, 2) + self.assertEqual(result.leftover_expired, 0) + self.assertEqual(result.group_gen_count, 5) + self.assertAlmostEqual(result.group_gen_mean_s, 1.4) + self.assertIn("task_a", result.task_results) + self.assertIn("task_b", result.task_results) + self.assertIn("task_c", result.task_results) + + async def test_custom_get_task_batch_sizes_can_disable_tasks(self): + strategy_a = _FakeProduceStrategy(generate_times_s=[2.0]) + strategy_b = _FakeProduceStrategy(generate_times_s=[1.0, 1.0]) + replay_buffer = _FakeReplayBuffer( + rollout_states_by_task={ + "task_a": [["a-0"]], + "task_b": [["b-0"], ["b-1"]], + }, + leftover_counts={}, + ) + + class _CustomBatchManager(AgentLoopManager): + def get_task_batch_sizes(self, global_batch_size: int, rollout_step: int) -> dict[str, int]: + self.observed_rollout_step = rollout_step + return {"task_a": 0, "task_b": global_batch_size} + + multi_task_manager = _CustomBatchManager( + task_runners=[ + _TaskRunner( + task_name="task_a", + agent_loop=_fake_agent_loop(), + produce_strategy=strategy_a, + sampler=_FakeSampler(), + weight=1.0, + order=0, + ), + _TaskRunner( + task_name="task_b", + agent_loop=_fake_agent_loop(), + produce_strategy=strategy_b, + sampler=_FakeSampler(), + weight=1.0, + order=1, + ), + ], + replay_buffer=replay_buffer, + ) + + result = await multi_task_manager.produce_batch(batch_size=2, rollout_step=9) + + self.assertEqual(multi_task_manager.observed_rollout_step, 9) + self.assertEqual(result.task_batch_sizes, {"task_a": 0, "task_b": 2}) + self.assertEqual(strategy_a.called_batch_sizes, []) + self.assertEqual(strategy_b.called_batch_sizes, [2]) + self.assertEqual(result.rollout_states, [["b-0"], ["b-1"]]) diff --git a/tests/rl/test_producer.py b/tests/rl/test_producer.py new file mode 100644 index 0000000000..a9c20dfa08 --- /dev/null +++ b/tests/rl/test_producer.py @@ -0,0 +1,117 @@ +import unittest +import asyncio +from 
unittest.mock import MagicMock, AsyncMock, patch +from xtuner.v1.rl.agent_loop import SamplerConfig, SyncProduceStrategyConfig, AsyncProduceStrategyConfig +from xtuner.v1.rl.replay_buffer import AsyncReplayBufferConfig +from xtuner.v1.data_proto.rl_data import RolloutState, Status + +class MockRolloutState: + def __init__(self, id, seq_staleness=1, status=Status.COMPLETED): + self.id = id + self.status = status + self.seq_staleness = seq_staleness + +class TestProducer(unittest.IsolatedAsyncioTestCase): + + def setUp(self): + # 1. 模拟 DataloaderConfig 和 Dataloader + self.mock_dataloader_cfg = MagicMock() + self.mock_dataloader = MagicMock() + # 模拟 next(dataloader_iter) 返回 [RolloutState] + self.mock_dataloader.__iter__.return_value = iter([[MockRolloutState(i)] for i in range(100)]) + self.mock_dataloader_cfg.build.return_value = self.mock_dataloader + + # 2. 模拟 Tokenizer + self.mock_tokenizer = MagicMock() + + # 3. 准备 ReplayBuffer + replay_buffer_cfg = AsyncReplayBufferConfig() + self.replay_buffer = replay_buffer_cfg.build() + + async def test_sampler_with_replay_buffer(self): + task_name = "test_task" + sampler_cfg = SamplerConfig.model_construct(dataloader_cfg=self.mock_dataloader_cfg) + sampler = sampler_cfg.build(self.mock_tokenizer, self.replay_buffer) + + # 场景 A: ReplayBuffer 为空,从 Dataloader 拿 + data = await sampler.sample(task_name) + self.assertEqual(data[0].id, 0) + + # 场景 B: ReplayBuffer 有 ABORTED 数据,优先拿 + aborted_item = MockRolloutState(999, status=Status.ABORTED) + await self.replay_buffer.put([aborted_item], task_name) + + data = await sampler.sample(task_name, group_status=Status.ABORTED) + self.assertEqual(data[0].id, 999) + + async def test_sync_produce_strategy(self): + task_name = "test_task" + mock_agent_loop = MagicMock() + mock_agent_loop.rollout_ctl.continue_generation.remote = AsyncMock(return_value=None) + mock_agent_loop.rollout_ctl.pause_generation.remote = AsyncMock(return_value=None) + 
mock_agent_loop.rollout_ctl.get_rollout_metadata.remote = AsyncMock(return_value={"server_url_dict": {}}) + + async def mock_gen(rs): + await asyncio.sleep(0.01 * rs[0].id) + for r in rs: + r.status = Status.COMPLETED + return rs + mock_agent_loop.generate_group = mock_gen + + sampler_cfg = SamplerConfig.model_construct(dataloader_cfg=self.mock_dataloader_cfg) + produce_strategy_cfg = SyncProduceStrategyConfig() + + sampler = sampler_cfg.build(self.mock_tokenizer, self.replay_buffer) + strategy = produce_strategy_cfg.build() + + # 执行:生产 batch_size 为 2 的数据 + await strategy.produce_batch(mock_agent_loop, sampler, self.replay_buffer, batch_size=2, task_name=task_name) + + # 验证:ReplayBuffer 中应该有 2 条 COMPLETED 数据 + final_data = await self.replay_buffer.get(10, task_name, Status.COMPLETED) + print(final_data[0][0].id, final_data[0][0].status) + print(final_data[1][0].id, final_data[1][0].status) + self.assertEqual(len(final_data), 2) + self.assertEqual(final_data[0][0].id, 0) + self.assertEqual(final_data[1][0].id, 1) + + async def test_async_produce_strategy(self): + # 这个async_produce_strategy的测试主要验证超发逻辑 + staleness 优先get的逻辑 + # 异步的其他功能如 partial_rollout, tail_batch不在这里进行验证 + mock_agent_loop = MagicMock() + mock_agent_loop.pause = AsyncMock() + mock_agent_loop.rollout_ctl.continue_generation.remote = AsyncMock(return_value=None) + task_name = "test_task" + call_count = 0 + async def mock_gen(rs, **kwargs): + nonlocal call_count + call_count += 1 + for r in rs: + if r.id == 999: + r.seq_staleness = 5 + else: + r.seq_staleness = call_count + r.status = Status.COMPLETED + print(r.id, r.seq_staleness, r.status) + return rs + + mock_agent_loop.generate_group = mock_gen + + sampler_cfg = SamplerConfig.model_construct(dataloader_cfg=self.mock_dataloader_cfg) + produce_strategy_cfg = AsyncProduceStrategyConfig(over_sample_threshold= 1) + sampler = sampler_cfg.build(self.mock_tokenizer, self.replay_buffer) + strategy = produce_strategy_cfg.build() + # 预处理 + aborted_item = 
MockRolloutState(999, status=Status.ABORTED) + await self.replay_buffer.put([aborted_item], task_name) + # 执行 + await strategy.produce_batch(mock_agent_loop, sampler, self.replay_buffer, batch_size=2, task_name=task_name) + + # 验证:ReplayBuffer 中应该有 4 条 COMPLETED 数据, + # NOTE(@duanyanhui): 目前还没实现暂停功能,所以4条都会推理完成,4条数据按照新鲜度顺序排列,999 是最旧的,0 是最新的 + final_data = await self.replay_buffer.get(10, task_name, Status.COMPLETED) + self.assertEqual(len(final_data), 4) + self.assertEqual(final_data[0][0].id, 999) + self.assertEqual(final_data[1][0].id, 2) + self.assertEqual(final_data[2][0].id, 1) + self.assertEqual(final_data[3][0].id, 0) diff --git a/tests/rl/test_replay_buffer.py b/tests/rl/test_replay_buffer.py new file mode 100644 index 0000000000..d04e532a79 --- /dev/null +++ b/tests/rl/test_replay_buffer.py @@ -0,0 +1,129 @@ +import unittest +from pathlib import Path +from tempfile import TemporaryDirectory + +from xtuner.v1.data_proto.rl_data import Status +from xtuner.v1.rl.replay_buffer import AsyncReplayBufferConfig, SyncReplayBufferConfig + + +class MockState: + def __init__(self, state_id, staleness=0, input_ids=None, status=Status.COMPLETED): + self.id = state_id + self.seq_staleness = staleness + self.status = status + self.input_ids = input_ids if input_ids is not None else [state_id] + + +class TestReplayBuffer(unittest.IsolatedAsyncioTestCase): + @staticmethod + def _get_sorted_input_ids(data_groups): + return sorted(tuple(state.input_ids) for group in data_groups for state in group) + + async def _run_roundtrip_input_ids_case(self, replay_buffer_cfg, put_groups, task_name, sample_size): + with TemporaryDirectory() as tmp_dir: + save_path = Path(tmp_dir) + original = replay_buffer_cfg.build() + for group in put_groups: + await original.put(group, task_name) + await original.save(save_path) + + old_sampled = await original.get(sample_size, task_name, Status.COMPLETED) + + resumed = replay_buffer_cfg.build() + await resumed.resume(save_path) + new_sampled = await 
resumed.get(sample_size, task_name, Status.COMPLETED) + + self.assertEqual(self._get_sorted_input_ids(old_sampled), self._get_sorted_input_ids(new_sampled)) + + async def test_basic_ordering_and_task_isolation(self): + fifo_cfg = SyncReplayBufferConfig() + fifo = fifo_cfg.build() + await fifo.put([MockState(1), MockState(2)], "task1") + await fifo.put([MockState(3)], "task1") + await fifo.put([MockState(200)], "task2") + + res_task1 = await fifo.get(2, "task1", Status.COMPLETED) + res_task2 = await fifo.get(1, "task2", Status.COMPLETED) + self.assertEqual([s.id for s in res_task1[0]], [1, 2]) + self.assertEqual([s.id for s in res_task1[1]], [3]) + self.assertEqual([s.id for s in res_task2[0]], [200]) + + staleness_cfg = AsyncReplayBufferConfig() + staleness = staleness_cfg.build() + await staleness.put([MockState("low", staleness=1)], "task") + await staleness.put([MockState("high", staleness=5)], "task") + sampled = await staleness.get(2, "task", Status.COMPLETED) + self.assertEqual(sampled[0][0].id, "high") + self.assertEqual(sampled[1][0].id, "low") + + async def test_save_resume_keeps_query_behavior_fifo(self): + replay_buffer_cfg = SyncReplayBufferConfig() + with TemporaryDirectory() as tmp_dir: + save_path = Path(tmp_dir) + buffer = replay_buffer_cfg.build() + await buffer.put([MockState("a1", status=Status.COMPLETED, input_ids=[11, 12])], "task_a") + await buffer.put([MockState("a2", status=Status.FAILED, input_ids=[21])], "task_a") + await buffer.put([MockState("b1", status=Status.COMPLETED, input_ids=[31])], "task_b") + await buffer.save(save_path) + + resumed = replay_buffer_cfg.build() + await resumed.resume(save_path) + + self.assertEqual(await resumed.count("task_a", Status.COMPLETED), 1) + self.assertEqual(await resumed.count("task_a", Status.FAILED), 1) + self.assertEqual(await resumed.count("task_b", Status.COMPLETED), 1) + self.assertEqual(await resumed.count("task_b", Status.FAILED), 0) + + completed = await resumed.get(5, "task_a", 
Status.COMPLETED) + failed = await resumed.get(5, "task_a", Status.FAILED) + self.assertEqual([s.id for s in completed[0]], ["a1"]) + self.assertEqual([s.id for s in failed[0]], ["a2"]) + + await resumed.put([MockState("a3", input_ids=[41])], "task_a") + next_completed = await resumed.get(1, "task_a", Status.COMPLETED) + self.assertEqual([s.id for s in next_completed[0]], ["a3"]) + + async def test_save_resume_keeps_query_behavior_staleness(self): + replay_buffer_cfg = AsyncReplayBufferConfig() + with TemporaryDirectory() as tmp_dir: + save_path = Path(tmp_dir) + buffer = replay_buffer_cfg.build() + await buffer.put([MockState("done_low", staleness=1, status=Status.COMPLETED, input_ids=[101])], "task") + await buffer.put([MockState("failed_high", staleness=10, status=Status.FAILED, input_ids=[201])], "task") + await buffer.put([MockState("done_mid", staleness=5, status=Status.COMPLETED, input_ids=[301, 302])], "task") + await buffer.save(save_path) + + resumed = replay_buffer_cfg.build() + await resumed.resume(save_path) + + self.assertEqual(await resumed.count("task", Status.COMPLETED), 2) + self.assertEqual(await resumed.count("task", Status.FAILED), 1) + + completed = await resumed.get(2, "task", Status.COMPLETED) + failed = await resumed.get(1, "task", Status.FAILED) + self.assertEqual(completed[0][0].id, "done_mid") + self.assertEqual(completed[1][0].id, "done_low") + self.assertEqual(failed[0][0].id, "failed_high") + + async def test_save_resume_sample_keeps_input_ids_fifo(self): + await self._run_roundtrip_input_ids_case( + replay_buffer_cfg=SyncReplayBufferConfig(), + put_groups=[ + [MockState(1, input_ids=[101, 102]), MockState(2, input_ids=[201])], + [MockState(3, input_ids=[301, 302, 303])], + ], + task_name="task", + sample_size=2, + ) + + async def test_save_resume_sample_keeps_input_ids_staleness(self): + await self._run_roundtrip_input_ids_case( + replay_buffer_cfg=AsyncReplayBufferConfig(), + put_groups=[ + [MockState("mid", staleness=3, 
input_ids=[301, 302])], + [MockState("high", staleness=5, input_ids=[501])], + [MockState("low", staleness=1, input_ids=[101, 102, 103])], + ], + task_name="task", + sample_size=3, + ) \ No newline at end of file diff --git a/tests/rl/test_rl_colocate_trainer_integration.py b/tests/rl/test_rl_colocate_trainer_integration.py new file mode 100644 index 0000000000..b47fc32f73 --- /dev/null +++ b/tests/rl/test_rl_colocate_trainer_integration.py @@ -0,0 +1,326 @@ +import os +import unittest +import shutil +import tempfile +import ray +from pathlib import Path + +from xtuner.v1.rl.utils import AcceleratorResourcesConfig +from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig +from xtuner.v1.model import get_model_config_from_hf +from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig +from xtuner.v1.datasets.rl_tokenize_fn import RLTextTokenizeFnConfig +from xtuner.v1.train.trainer import LoadCheckpointConfig +from xtuner.v1.train.rl_colocate_trainer import RLColocateTrainerConfig +from xtuner.v1.rl.trainer import WorkerConfig +from xtuner.v1.rl.loss import GRPOLossConfig +from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.rl.judger import GSM8KJudgerConfig +from xtuner.v1.loss import CELossConfig +from xtuner.v1.datasets.sft_tokenize_fn import OpenaiTokenizeFunctionConfig +from xtuner.v1.rl.replay_buffer import SyncReplayBufferConfig +from xtuner.v1.rl.agent_loop import ( + AgentLoopManagerConfig, + TaskSpecConfig, + SingleTurnAgentLoopConfig, + SyncProduceStrategyConfig, + SamplerConfig, +) +from xtuner.v1.rl.evaluator import EvaluatorConfig +from xtuner.v1.data_proto import SampleParams +from xtuner.v1.data_proto.sequence_context import SequenceContext +from transformers import AutoTokenizer +import torch + +QWEN3_PATH = os.environ["QWEN3_PATH"] +ALPACA_PATH = os.environ["ALPACA_PATH"] +ROLLOUT_DATA_PATH = os.environ["ROLLOUT_DATA_PATH"] + + +class TestRLColocateTrainerIntegration(unittest.TestCase): + """Integration test for 
RLColocateTrainer with checkpoint save/resume.""" + + def setUp(self): + ray.init(num_cpus=80, num_gpus=8, ignore_reinit_error=True) + self.temp_dir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.temp_dir, ignore_errors=True) + ray.shutdown() + + def build_trainer_config(self, work_dir, checkpoint_interval=1, checkpoint_maxkeep=2, auto_resume=False): + """Build RLColocateTrainerConfig for testing.""" + model_path = QWEN3_PATH + data_path = ALPACA_PATH + + # Resources + resources = AcceleratorResourcesConfig( + accelerator="GPU", + num_workers=8, + num_cpus_per_worker=4, + cpu_memory_per_worker=8 * 1024**3, + ) + + # Rollout config + rollout_config = RolloutConfig( + env="test_rl", + device="GPU", + model_path=model_path, + dtype="bfloat16", + tensor_parallel_size=1, + expert_parallel_size=1, + gpu_memory_utilization=0.5, + context_length=1536, + ) + + # Judger + judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k", num_ray_actors=1) + + # Train worker + lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) + fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) + model_cfg = get_model_config_from_hf(Path(model_path)) + if hasattr(model_cfg, "balancing_loss_cfg"): + model_cfg.balancing_loss_cfg = None + if hasattr(model_cfg, "z_loss_cfg"): + model_cfg.z_loss_cfg = None + + optim_cfg = AdamWConfig(lr=1e-6, foreach=False, weight_decay=0.1) + loss_cfg = GRPOLossConfig( + policy_loss_cfg=dict( + cliprange_high=0.28, + cliprange_low=0.2, + loss_type="vanilla", + clip_ratio_c=10.0, + log_prob_diff_min=-20.0, + log_prob_diff_max=20.0, + ), + ignore_idx=-100, + use_kl_loss=False, + kl_loss_coef=0.0, + kl_loss_type="low_var_kl", + mode="chunk", + chunk_size=512, + ) + + # SFT configs for WorkerConfig + sft_dataset_config = [{ + "dataset": DatasetConfig(name='alpaca', anno_path=data_path), + "tokenize_fn": OpenaiTokenizeFunctionConfig( + chat_template='qwen3', + max_length=32768 + ) + }] + sft_dataloader_cfg = 
DataloaderConfig( + dataset_config_list=sft_dataset_config, + pack_max_length=32768, + pack_to_max_length=True, + num_workers=0, + ) + sft_global_batch_size = 8 + sft_loss_cfg = CELossConfig(mode="chunk", chunk_size=1024, loss_reduction="square") + + train_worker_cfg = WorkerConfig( + model_cfg=model_cfg, + load_from=model_path, + optim_cfg=optim_cfg, + loss_cfg=loss_cfg, + lr_cfg=lr_cfg, + fsdp_cfg=fsdp_cfg, + sp_size=1, + optimizer_steps=1, + pack_max_length=2048, + sft_dataloader_cfg=sft_dataloader_cfg, + sft_global_batch_size=sft_global_batch_size, + sft_loss_cfg=sft_loss_cfg, + ) + + # Agent loop manager + train_dataset = DatasetConfig(name="test_rl", anno_path=ROLLOUT_DATA_PATH) + tokenizer_config = RLTextTokenizeFnConfig(max_length=512) + train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}] + dataloader_cfg = DataloaderConfig( + dataset_config_list=train_dataset_cfg, + pack_max_length=2048, + collator="fake_collator", + pack_level="none", + ) + sampler_config = SamplerConfig( + dataloader_cfg=dataloader_cfg, + prompt_repeat_k=2, + ) + training_sample_params = SampleParams( + max_tokens=512, + top_k=0, + top_p=1.0, + temperature=1.0, + min_tokens=0, + ) + agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=training_sample_params, + ) + produce_strategy_config = SyncProduceStrategyConfig() + agent_loop_manager_cfg = AgentLoopManagerConfig( + tasks=[ + TaskSpecConfig( + task_name="train_task", + agent_loop_config=agent_loop_config, + judger_config=judger_config, + produce_strategy_config=produce_strategy_config, + sampler_config=sampler_config, + ) + ], + ) + + # Eval agent loop manager (minimal) + eval_sampler_config = SamplerConfig( + dataloader_cfg=dataloader_cfg, + prompt_repeat_k=1, + ) + eval_agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=SampleParams(max_tokens=512, top_k=1, temperature=0.0), + ) + eval_agent_loop_manager_cfg = 
AgentLoopManagerConfig( + tasks=[ + TaskSpecConfig( + task_name="eval_task", + agent_loop_config=eval_agent_loop_config, + judger_config=judger_config, + sampler_config=eval_sampler_config, + ) + ], + ) + + # Evaluator + evaluator_config = EvaluatorConfig(compute_metric_func=None) + + return RLColocateTrainerConfig( + resources=resources, + train_worker_cfg=train_worker_cfg, + rollout_config=rollout_config, + tokenizer_path=model_path, + replay_buffer_config=SyncReplayBufferConfig(), + agent_loop_manager_cfg=agent_loop_manager_cfg, + eval_agent_loop_manager_cfg=eval_agent_loop_manager_cfg, + evaluator_config=evaluator_config, + load_from=model_path, + rollout_steps=2, + global_batch_size=4, + enable_evaluate=False, + enable_initial_evaluate=False, + work_dir=work_dir, + checkpoint_interval=checkpoint_interval, + checkpoint_maxkeep=checkpoint_maxkeep, + auto_resume=auto_resume, + seed=42, + debug_rollout=False, + ) + + def test_rl_train_with_sft(self): + """Test train_controller save/resume with efficient_attn_ratio verification.""" + work_dir = Path(self.temp_dir) / "work_dir_sft" + work_dir.mkdir(parents=True, exist_ok=True) + + # Build trainer to get train_controller + trainer_cfg = self.build_trainer_config( + work_dir=str(work_dir), + checkpoint_interval=1, + checkpoint_maxkeep=2, + auto_resume=False, + ) + trainer = trainer_cfg.build() + train_controller = trainer.train_controller + + # Prepare synthetic data batches + tokenizer = AutoTokenizer.from_pretrained(QWEN3_PATH, trust_remote_code=True) + + # Create simple prompts and responses + prompts = ["What is 2+2?", "What is the capital of France?"] + responses = [ + ["4", "Four", "2+2=4", "The answer is 4"], + ["Paris", "The capital is Paris", "Paris, France", "It's Paris"] + ] + + data_batches = [] + for prompt, response_list in zip(prompts, responses): + prompt_ids = tokenizer(prompt, return_tensors='pt')['input_ids'].flatten().tolist() + rewards = torch.tensor([1.0, 0.8, 0.9, 0.7], dtype=torch.float32) + 
advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8) + + for i, response in enumerate(response_list): + response_ids = tokenizer(response, return_tensors='pt')['input_ids'].flatten().tolist() + # Align with RLColocateTrainer._prepare_train_data(): + # - input_ids excludes last token (usually eos) of response_ids + # - shifted_labels aligns to input_ids length + input_ids = prompt_ids + response_ids[:-1] + shifted_labels = [-100] * (len(prompt_ids) - 1) + response_ids + input_ids_tensor = torch.tensor(input_ids, dtype=torch.int64).unsqueeze(0) + shifted_labels_tensor = torch.tensor(shifted_labels, dtype=torch.int64).unsqueeze(0) + + adv_val = advantages[i].item() + # Controller._packing expects `advantage` as a list and will flatten it. + # Keep the length consistent with shifted_labels/input_ids. + advantage_list = [adv_val] * (len(prompt_ids) - 1) + [adv_val] * len(response_ids) + + data_batches.append(dict( + seq_ctx=SequenceContext.from_input_ids((input_ids_tensor,), device="cpu"), + shifted_labels=shifted_labels_tensor, + advantage=advantage_list, + )) + + # RLColocateTrainer initializes by offloading train workers to CPU. + # Align with RLColocateTrainer.fit() which onloads before training. 
+ ray.get(train_controller.onload.remote(target="all")) + + # First fit and save + ray.get(train_controller.fit.remote(data_batches, pack_max_length=1024, rollout_idx=0)) + checkpoint_path = str(work_dir / "save_test") + ray.get(train_controller.save.remote(checkpoint_path, no_save_optimizer=True)) + + # Second fit and collect metrics + ray.get(train_controller.onload.remote(target="all")) + log_infos = ray.get(train_controller.fit.remote(data_batches, pack_max_length=1024, rollout_idx=1)) + efficient_attn_ratio_list = [] + for log_info in log_infos: + efficient_attn_ratio_list.append(log_info['sft_train_metrics']['efficient_attn_ratio']) + self.assertTrue(all([ratio > 0 for ratio in efficient_attn_ratio_list])) + + # Kill and rebuild + ray.kill(train_controller) + del trainer + ray.shutdown() + # Re-init Ray with enough resources for AcceleratorResourcesConfig(num_workers=8, num_cpus_per_worker=4). + ray.init(num_cpus=80, num_gpus=8, ignore_reinit_error=True) + + trainer_cfg = self.build_trainer_config( + work_dir=str(work_dir), + checkpoint_interval=1, + checkpoint_maxkeep=2, + auto_resume=False, + ) + trainer = trainer_cfg.build() + train_controller = trainer.train_controller + + # Resume and verify + load_checkpoint_cfg = LoadCheckpointConfig( + checkpoint_path=checkpoint_path, + load_optimizer_states=False, + load_optimizer_args=False + ) + ray.get(train_controller.resume.remote(load_checkpoint_cfg)) + + ray.get(train_controller.onload.remote(target="all")) + log_infos = ray.get(train_controller.fit.remote(data_batches, pack_max_length=1024, rollout_idx=1)) + new_efficient_attn_ratio_list = [] + for log_info in log_infos: + new_efficient_attn_ratio_list.append(log_info['sft_train_metrics']['efficient_attn_ratio']) + + efficient_attn_ratio_list.sort() + new_efficient_attn_ratio_list.sort() + self.assertEqual(efficient_attn_ratio_list, new_efficient_attn_ratio_list) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/rl/test_rollout.py 
b/tests/rl/test_rollout.py new file mode 100644 index 0000000000..31d9e8e5d4 --- /dev/null +++ b/tests/rl/test_rollout.py @@ -0,0 +1,159 @@ +import asyncio +import os +import subprocess +import unittest +import tempfile +import ray +import torch +from transformers import AutoTokenizer +import tempfile +from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.rl.utils import AcceleratorResourcesConfig, AutoAcceleratorWorkers +from xtuner.v1.data_proto.rl_data import Status, SampleParams, RolloutState +from xtuner.v1.rl.rollout import RolloutController + +TEST_TEXT_MESSAGES=[{"role": "user", "content": "Hello!"}] +MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"] +MOE_MODEL_PATH = os.environ["QWEN3_MOE_PATH"] +TRAIN_DATA_PATH = os.environ["ROLLOUT_DATA_PATH"] +TEST_DATA_PATH = os.environ["ROLLOUT_TEST_DATA_PATH"] +resource_map = { + "npu": "NPU", + "cuda": "GPU", +} +class TestRollout(unittest.IsolatedAsyncioTestCase): + + @classmethod + def setUpClass(cls) -> None: + os.environ["XTUNER_USE_FA3"] = "1" + os.environ["LMD_SKIP_WARMUP"] = "1" + + @classmethod + def tearDownClass(cls) -> None: + del os.environ["XTUNER_USE_FA3"] + del os.environ["LMD_SKIP_WARMUP"] + + def init_config(self): + self.resources_cfg = AcceleratorResourcesConfig( + accelerator=resource_map[torch.accelerator.current_accelerator().type], + num_workers=8, + num_cpus_per_worker=8, + cpu_memory_per_worker=16 * 1024**3, # 16 GB + ) + self.max_prompt_length = 512 + self.max_response_length = 1024 + self.context_length = self.max_prompt_length + self.max_response_length + self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True) + + def setUp(self): + ray.init(num_cpus=80, ignore_reinit_error=True) + self.data_path = TRAIN_DATA_PATH + self.model_path = MODEL_PATH + self.temp_dir = tempfile.TemporaryDirectory() + self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs") + self.init_config() + + def tearDown(self): + ray.shutdown() + # When lmdeploy enable 
ep>1, it uses deep_ep. Buffer implicit destroy would cause some ray actor stucked. + # Use pkill cleen up ray::WorkerWrapper process after close ray cluster connection as workaround. + # TODO(chenchiyu): add excplicit deep_ep destroy in lmdeploy. + self._cleanup_lmdeploy_ray_worker_wrapper() + self.temp_dir.cleanup() + + @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") + def test_parallel_rollout(self): + resource_config = AcceleratorResourcesConfig( + accelerator=resource_map[torch.accelerator.current_accelerator().type], + num_workers=4, + num_cpus_per_worker=4, + cpu_memory_per_worker=8 * 1024**3, # 8 GB + ) + pg1 = AutoAcceleratorWorkers.build_placement_group(resource_config, name="tp_pg") + pg2 = AutoAcceleratorWorkers.build_placement_group(resource_config, name="ep_pg") + dense_model_path = MODEL_PATH + moe_model_path = MOE_MODEL_PATH + dist_port_base = 38000 + async def run_both(): + return await asyncio.gather( + self._run_rollout(model_path=dense_model_path, tp_size=4, ep_size=1, pg=pg1, dist_port_base=dist_port_base), + self._run_rollout(model_path=moe_model_path, tp_size=1, ep_size=4, pg=pg2, dist_port_base=dist_port_base + 1024 * 4), + return_exceptions=False + ) + + asyncio.run(run_both()) + + def _cleanup_lmdeploy_ray_worker_wrapper(self): + try: + result = subprocess.run(["pkill", "-f", "ray::RayWorkerWrapper*"], capture_output=True, text=True, timeout=10) + if result.returncode != 0: + print(f"pkill command failed with return code {result.returncode}: {result.stderr}." 
+ " Maybe no lmdeploy ray::RayWorkerWrapper processes found.") + except Exception as e: + print(f"Error stopping ray::RayWorkerWrapper cluster: {e}") + + async def _run_rollout(self, model_path, tp_size, ep_size, pg, dist_port_base): + rollout_config = RolloutConfig( + env="test_rollout", + model_path=model_path, + model_name=os.path.basename(model_path).lower(), + tokenizer_path=model_path, + tensor_parallel_size=tp_size, + expert_parallel_size=ep_size, + context_length=self.context_length, + worker_log_dir=self.worker_log_dir, + dist_port_base=dist_port_base, + enable_return_routed_experts=ep_size > 1, # ep_size > 1 默认打开r3 + ) + rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg) + result_refs = [] + + # Test Case 1: 文本输入 + 文本输出 + # TODO(@duanyanhui): test prompt in and prompt out with v1/chat/completion api + # sample_params1 = SampleParams(return_token_ids=False) + # input1 = RolloutState(message=TEST_TEXT_MESSAGES, sample_params=sample_params1) + # result1_ref = rollout_controller.generate.remote(rollout_state=input1) + # result_refs.append(result1_ref) + + # Test Case 2: 文本输入 + Token 输出 + sample_params2 = SampleParams(return_token_ids=True) + input2 = RolloutState(message=TEST_TEXT_MESSAGES, sample_params=sample_params2) + result2_ref = rollout_controller.generate.remote(rollout_state=input2) + result_refs.append(result2_ref) + + # Test Case 3: Token 输入 + Token 输出 + text_prompt = self.tokenizer.apply_chat_template(TEST_TEXT_MESSAGES, tokenize=False, add_generation_prompt=True) + input_tokens = self.tokenizer(text_prompt, add_special_tokens=False)["input_ids"] + sample_params3 = SampleParams(return_token_ids=True) + input3 = RolloutState(message=TEST_TEXT_MESSAGES, tokens=input_tokens, sample_params=sample_params3) + result3_ref = rollout_controller.generate.remote(rollout_state=input3) + result_refs.append(result3_ref) + + try: + results = await asyncio.wait_for(asyncio.gather(*result_refs), timeout=300) + for i, result in 
enumerate(results): + case_id = f"Case {i+1}" + self.assertEqual(result.status, Status.COMPLETED, + msg=f"{case_id} failed: Expected status COMPLETED but got {result.status}") + self.assertEqual(result.finish_reason, 'stop', + msg=f"{case_id} failed: Expected finish_reason 'stop' but got {result.finish_reason}") + + if result.sample_params.return_token_ids: + self.assertGreater(len(result.response_ids), 0, + msg=f"{case_id} failed: response_ids should not be empty when return_token_ids is True") + + if result.sample_params.return_logprob: + self.assertEqual(len(result.logprobs), len(result.response_ids), + msg=f"{case_id} failed: logprobs length ({len(result.logprobs)}) " + f"does not match response_ids length ({len(result.response_ids)})") + + except asyncio.TimeoutError: + if tp_size > 1 and ep_size == 1: + self.fail("TP and Dense Rollout timed out!") + if ep_size > 1 and tp_size == 1: + self.fail("EP and MoE Rollout timed out!") + finally: + await asyncio.wait_for(rollout_controller.shutdown.remote(), timeout=300) + +if __name__ == "__main__": + unittest.main() diff --git a/tests/rl/test_rollout_utils.py b/tests/rl/test_rollout_utils.py new file mode 100644 index 0000000000..749dd16892 --- /dev/null +++ b/tests/rl/test_rollout_utils.py @@ -0,0 +1,98 @@ +import ray +import torch +import threading +import time +import unittest +import os +import tempfile +from types import SimpleNamespace +from unittest.mock import patch + +from xtuner.v1.data_proto.rl_data import Status, RolloutState, SampleParams +from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.rl.rollout.controller import RolloutController, WorkerInfo +from xtuner.v1.rl.rollout.utils import RolloutHealthChecker, SessionRouter +from xtuner.v1.rl.utils import AcceleratorResourcesConfig, AutoAcceleratorWorkers, asyncio_run + +MODEL_PATH = os.environ.get("ROLLOUT_MODEL_PATH", "") +RESOURCE_MAP = {"npu": "NPU", "cuda": "GPU"} +TEST_TEXT_MESSAGES=[{"role": "user", "content": "Hello!"}] + +class 
TestRolloutControllerRecover(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + os.environ["XTUNER_USE_FA3"] = "1" + os.environ["LMD_SKIP_WARMUP"] = "1" + + @classmethod + def tearDownClass(cls) -> None: + del os.environ["XTUNER_USE_FA3"] + del os.environ["LMD_SKIP_WARMUP"] + + def setUp(self): + ray.init(num_cpus=80, address="local", ignore_reinit_error=True) + self.model_path = MODEL_PATH + self.temp_dir = tempfile.TemporaryDirectory() + self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs") + + def tearDown(self): + ray.shutdown() + self.temp_dir.cleanup() + + def init_rollout_controller(self): + resource_cfg = AcceleratorResourcesConfig( + accelerator=RESOURCE_MAP[torch.accelerator.current_accelerator().type], + num_workers=1, + num_cpus_per_worker=4, + cpu_memory_per_worker=8 * 1024**3, + ) + pg = AutoAcceleratorWorkers.build_placement_group(resource_cfg, name="recover_test_pg") + rollout_cfg = RolloutConfig( + env="test_rollout_utils", + model_path=self.model_path, + model_name=os.path.basename(self.model_path).lower(), + tokenizer_path=self.model_path, + tensor_parallel_size=1, + expert_parallel_size=1, + worker_log_dir=self.temp_dir.name, + context_length=8192, + health_check_interval_seconds=10, + health_check_failure_threshold=1, + ) + controller = RolloutController(rollout_cfg, pg) + return controller + + @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") + def test_healthcheck_deactivate_and_recover(self): + controller = self.init_rollout_controller() + ranks = list(controller.rank2info.keys()) + rank0 = ranks[0] + actor0 = controller.rank2info[rank0].actor + ray.get(actor0.shutdown.remote()) + time.sleep(3) # wait for the actor to be fully killed + health_before_recover = ray.get(actor0.check_health.remote()) + url = controller.rank2info[rank0].url + self.assertFalse(health_before_recover) + + controller.health_checker.run_once() + + 
self.assertFalse(controller.rank2info[rank0].is_active) + rollout_state = RolloutState( + message=TEST_TEXT_MESSAGES, + sample_params=SampleParams(return_token_ids=True), + ) + out = asyncio_run(controller.generate(rollout_state)) + self.assertEqual(out.status, Status.FAILED) + + controller.recover_failed_workers() + + self.assertTrue(controller.rank2info[rank0].is_active) + self.assertEqual(url, controller.rank2info[rank0].url) + health_after_recover = ray.get(actor0.check_health.remote()) + self.assertTrue(health_after_recover) + out = asyncio_run(controller.generate(rollout_state)) + self.assertNotEqual(out.status, Status.FAILED) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/ray/test_update_weight.py b/tests/rl/test_update_weight.py similarity index 86% rename from tests/ray/test_update_weight.py rename to tests/rl/test_update_weight.py index fa008d3d71..d3b8a09480 100644 --- a/tests/ray/test_update_weight.py +++ b/tests/rl/test_update_weight.py @@ -3,17 +3,17 @@ import tempfile import ray -from xtuner.v1.ray.rollout import RolloutController -from xtuner.v1.data_proto.rl_data import SampleParams +from xtuner.v1.rl.rollout import RolloutController +from xtuner.v1.data_proto.rl_data import SampleParams, RolloutState from xtuner.v1.config import ( AdamWConfig, FSDPConfig, LRConfig, ) -from xtuner.v1.ray.config.worker import RolloutConfig -from xtuner.v1.ray.base import AcceleratorResourcesConfig, AutoAcceleratorWorkers -from xtuner.v1.rl.base import WorkerConfig, TrainingController, TrainingWorker as BaseTrainingWorker -from xtuner.v1.rl.grpo.loss import GRPOLossConfig as LossConfig +from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.rl.utils import AcceleratorResourcesConfig, AutoAcceleratorWorkers +from xtuner.v1.rl.trainer import WorkerConfig, TrainingController, TrainingWorker as BaseTrainingWorker +from xtuner.v1.rl.loss import GRPOLossConfig as LossConfig from xtuner.v1.model.compose.qwen3_vl import 
Qwen3VLDense4BConfig TEST_TEXT_MESSAGES = [{"role": "user", "content": "Hello!"}] @@ -121,10 +121,11 @@ def test_lmdeploy_update_weight_and_generate(self): self.pg, ) - res_baseline = ray.get(rollout_controller.rollout.remote(prompt=TEST_TEXT_MESSAGES, sample_params=sample_params)) + input_state = RolloutState(message=TEST_TEXT_MESSAGES, sample_params=sample_params) + res_baseline = ray.get(rollout_controller.generate.remote(rollout_state=input_state)) # start update weight test - info_dict = ray.get(rollout_controller.get_rollout_info.remote()) + info_dict = ray.get(rollout_controller.get_rollout_metadata.remote()) ray.get(train_controller.update_rollout_info.remote(info_dict)) # update weights @@ -136,7 +137,7 @@ def test_lmdeploy_update_weight_and_generate(self): ray.get(train_controller.offload.remote(["model"])) ray.get(rollout_controller.onload_kvcache.remote()) - res_update_weight = ray.get(rollout_controller.rollout.remote(prompt=TEST_TEXT_MESSAGES, sample_params=sample_params)) + res_update_weight = ray.get(rollout_controller.generate.remote(rollout_state=input_state)) self.assertEqual(res_update_weight.response, res_baseline.response) ray.get(rollout_controller.shutdown.remote(), timeout=60) diff --git a/tests/ray/test_utils.py b/tests/rl/test_utils.py similarity index 97% rename from tests/ray/test_utils.py rename to tests/rl/test_utils.py index ca516469b8..127198440e 100644 --- a/tests/ray/test_utils.py +++ b/tests/rl/test_utils.py @@ -5,7 +5,7 @@ -from xtuner.v1.ray.utils import find_master_addr_and_port, get_accelerator_ids, get_ray_accelerator +from xtuner.v1.rl.utils.ray_utils import find_master_addr_and_port, get_accelerator_ids, get_ray_accelerator import parametrize diff --git a/tests/rl/test_vl_rollout.py b/tests/rl/test_vl_rollout.py new file mode 100644 index 0000000000..951ea75037 --- /dev/null +++ b/tests/rl/test_vl_rollout.py @@ -0,0 +1,160 @@ +import os +import subprocess +import unittest +import tempfile +import ray +import torch +from 
transformers import AutoTokenizer +import tempfile +from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.rl.utils import AcceleratorResourcesConfig, AutoAcceleratorWorkers +from xtuner.v1.data_proto.rl_data import Status +from xtuner.v1.datasets.rl_tokenize_fn import RLQwen3VLTokenizeFnConfig +import asyncio +from xtuner.v1.rl.rollout import RolloutController + + +MODEL_PATH=os.getenv("QWEN3_VL_DENSE_PATH") +MOE_MODEL_PATH=os.getenv("QWEN3_VL_MOE_PATH") +MEDIA_ROOT=os.getenv("GEO3K_MEDIA_ROOT") + +resource_map = { + "npu": "NPU", + "cuda": "GPU", +} +class TestVLMRollout(unittest.IsolatedAsyncioTestCase): + + @classmethod + def setUpClass(cls) -> None: + os.environ["XTUNER_USE_FA3"] = "1" + os.environ["LMD_SKIP_WARMUP"] = "1" + + @classmethod + def tearDownClass(cls) -> None: + del os.environ["XTUNER_USE_FA3"] + del os.environ["LMD_SKIP_WARMUP"] + + def init_config(self): + self.resources_cfg = AcceleratorResourcesConfig( + accelerator=resource_map[torch.accelerator.current_accelerator().type], + num_workers=8, + num_cpus_per_worker=8, + cpu_memory_per_worker=16 * 1024**3, # 16 GB + ) + self.max_prompt_length = 1024 + self.max_response_length = 2048 + self.context_length = self.max_prompt_length + self.max_response_length + + tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True) + tokenize_fn = RLQwen3VLTokenizeFnConfig(processor_path=self.model_path, max_length=self.max_prompt_length) + self.tokenize_fn = tokenize_fn.build(tokenizer) + + def setUp(self): + ray.init(num_cpus=80, ignore_reinit_error=True) + self.model_path = MODEL_PATH + self.temp_dir = tempfile.TemporaryDirectory() + self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs") + self.init_config() + + def tearDown(self): + ray.shutdown() + # When lmdeploy enable ep>1, it uses deep_ep. Buffer implicit destroy would cause some ray actor stucked. + # Use pkill cleen up ray::WorkerWrapper process after close ray cluster connection as workaround. 
+ # TODO(chenchiyu): add excplicit deep_ep destroy in lmdeploy. + self._cleanup_lmdeploy_ray_worker_wrapper() + self.temp_dir.cleanup() + + def _cleanup_lmdeploy_ray_worker_wrapper(self): + try: + result = subprocess.run(["pkill", "-f", "ray::RayWorkerWrapper*"], capture_output=True, text=True, timeout=10) + if result.returncode != 0: + print(f"pkill command failed with return code {result.returncode}: {result.stderr}." + " Maybe no lmdeploy ray::RayWorkerWrapper processes found.") + except Exception as e: + print(f"Error stopping ray::RayWorkerWrapper cluster: {e}") + + + @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") + def test_parallel_rollout(self): + resource_config = AcceleratorResourcesConfig( + accelerator=resource_map[torch.accelerator.current_accelerator().type], + num_workers=4, + num_cpus_per_worker=4, + cpu_memory_per_worker=8 * 1024**3, # 8 GB + ) + pg1 = AutoAcceleratorWorkers.build_placement_group(resource_config, name="tp_pg") + pg2 = AutoAcceleratorWorkers.build_placement_group(resource_config, name="ep_pg") + dense_model_path = MODEL_PATH + moe_model_path = MOE_MODEL_PATH + dist_port_base = 38000 + async def run_both(): + return await asyncio.gather( + self._run_rollout(model_path=dense_model_path, tp_size=4, ep_size=1, pg=pg1, dist_port_base=dist_port_base), + # self._run_rollout(model_path=moe_model_path, tp_size=1, ep_size=4, pg=pg2, dist_port_base=dist_port_base + 1024 * 4), # TODO: lmdeploy 修复后启动 + return_exceptions=False + ) + + asyncio.run(run_both()) + + async def _run_rollout(self, model_path, tp_size, ep_size, pg, dist_port_base): + rollout_config = RolloutConfig( + env="test_rollout", + model_path=model_path, + model_name=os.path.basename(model_path).lower(), + tokenizer_path=model_path, + tensor_parallel_size=tp_size, + expert_parallel_size=ep_size, + context_length=self.context_length, + worker_log_dir=self.worker_log_dir, + dist_port_base=dist_port_base, + 
enable_return_routed_experts=ep_size > 1, # ep_size > 1 默认打开r3 + ) + rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg) + result_refs = [] + + # Test Case 1: 纯文本 + rollout_state = self.tokenize_fn({'prompt':[{"role": "user", "content": "Hello!"}]}) + result1_ref = rollout_controller.generate.remote(rollout_state=rollout_state) + result_refs.append(result1_ref) + + # Test Case 2: 图片 + input_data = {"prompt": [{"content": [{"image_url": {"image_wh": [297, 265], "url": "images/test_0.jpg"}, "type": "image_url"}, {"text": "Chords $\\overline{A C}$ and $\\overline{D F}$ are equidistant from the center. If the radius of $\\odot G$ is 26 find $A C$ You FIRST think about the reasoning process as an internal monologue and then provide the final answer. The reasoning process MUST BE enclosed within tags. The final answer MUST BE put in \\boxed{}.", "type": "text"}], "role": "user"}], "data_source": "hiyouga/geometry3k", "ability": "math", "reward_model": {"ground_truth": "48", "style": "rule"}} + rollout_state = self.tokenize_fn(input_data, media_root=MEDIA_ROOT) + rollout_state.tokens = rollout_state.prompt_ids + result2_ref = rollout_controller.generate.remote(rollout_state=rollout_state) + result_refs.append(result2_ref) + + try: + results = await asyncio.wait_for(asyncio.gather(*result_refs), timeout=300) + for i, result in enumerate(results): + case_id = f"Case {i+1}" + self.assertEqual(result.status, Status.COMPLETED, + msg=f"{case_id} failed: Expected status COMPLETED but got {result.status} and error_msg {result.error_msg}") + self.assertEqual(result.finish_reason, 'stop', + msg=f"{case_id} failed: Expected finish_reason 'stop' but got {result.finish_reason}") + + if result.sample_params.return_token_ids: + self.assertGreater(len(result.response_ids), 0, + msg=f"{case_id} failed: response_ids should not be empty when return_token_ids is True") + + if result.sample_params.return_logprob: + self.assertEqual(len(result.logprobs), 
len(result.response_ids), + msg=f"{case_id} failed: logprobs length ({len(result.logprobs)}) " + f"does not match response_ids length ({len(result.response_ids)})") + + except asyncio.TimeoutError: + if tp_size > 1 and ep_size == 1: + self.fail("TP and Dense Rollout timed out!") + if ep_size > 1 and tp_size == 1: + self.fail("EP and MoE Rollout timed out!") + finally: + await asyncio.wait_for(rollout_controller.shutdown.remote(), timeout=300) + + # @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") + # def test_vl_resume_with_partial_rollout(self): + # # TODO: 后续实现 + # pass + + +if __name__ == "__main__": + unittest.main() diff --git a/xtuner/v1/data_proto/__init__.py b/xtuner/v1/data_proto/__init__.py index c30af9de46..a372717ba7 100644 --- a/xtuner/v1/data_proto/__init__.py +++ b/xtuner/v1/data_proto/__init__.py @@ -1,6 +1,22 @@ +from .rl_data import ( + RolloutFunctionCall, + RolloutState, + RolloutToolCall, + SampleParams, + Status, + update_expired_status, + update_seq_staleness, +) from .sequence_context import SequenceContext __all__ = [ + "RolloutFunctionCall", "SequenceContext", + "RolloutState", + "RolloutToolCall", + "SampleParams", + "Status", + "update_seq_staleness", + "update_expired_status", ] diff --git a/xtuner/v1/data_proto/messages/chat.py b/xtuner/v1/data_proto/messages/chat.py index 69a098d252..fcc4adc64a 100644 --- a/xtuner/v1/data_proto/messages/chat.py +++ b/xtuner/v1/data_proto/messages/chat.py @@ -6,10 +6,9 @@ from pydantic import BaseModel, ConfigDict from transformers import PreTrainedTokenizer -from xtuner.utils import IGNORE_INDEX from xtuner.v1.data_proto.messages.base import BaseMessages from xtuner.v1.data_proto.templates import ChatTemplate, HybridChatTemplate -from xtuner.v1.utils import get_logger +from xtuner.v1.utils import IGNORE_INDEX, get_logger logger = get_logger() diff --git a/xtuner/v1/data_proto/rl_data.py b/xtuner/v1/data_proto/rl_data.py index 31caf06bae..77bd012517 
100644 --- a/xtuner/v1/data_proto/rl_data.py +++ b/xtuner/v1/data_proto/rl_data.py @@ -1,450 +1,289 @@ from __future__ import annotations -import copy -from typing import TYPE_CHECKING, Any, TypeAlias +import base64 +from enum import Enum +from typing import TYPE_CHECKING, Any, Literal, TypeAlias +import numpy as np import torch -from cyclopts import Parameter -from pydantic import BaseModel, ConfigDict, Field -from typing_extensions import Annotated, NotRequired, Self, TypedDict - -from xtuner.v1.utils import StrEnum +from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator +from typing_extensions import NotRequired, TypedDict # ==================================== # ====== DataFlow 数据流 ============== # ==================================== +from xtuner.v1.utils.cache import CacheObj from xtuner.v1.utils.logger import get_logger if TYPE_CHECKING: - import ray - - RayObjectRef = ray.ObjectRef + from ray import ObjectRef as RayObjectRef else: RayObjectRef: TypeAlias = Any logger = get_logger() -class RolloutState(StrEnum): - """ - - 1. State Transitions from finish_reason and RolloutState: - - A new task starts as `INIT`. - - A successful generation (finish_reason 'stop' or 'length') becomes `COMPLETED`. - - A generation stopped by the dataflow (e.g., for partial rollout) becomes `ABORTED`. - - A generation that fails due to an inference server error becomes `FAILED`. - - A generation skipped due to client errors or timeout errors (e.g., invalid input) becomes `SKIPPED`. - - Data used for training is marked as `ARCHIVED`. - - Old data (rollout for morn than expiration step) in the replay buffer is marked as `EXPIRED`. - - 2. Dataflow Handling Based on RolloutState: - - `INIT`: Data is in progress; no special handling. - - `COMPLETED`: Data is valid for filtering, replay buffer insertion and training. - - `ABORTED`: Data may be partially valid; It's valid for replay buffer insertion but not for filtering and training. 
- - `FAILED`: Data is invalid; not used for filtering, replay buffer or training. - - `SKIPPED`: Data is invalid; not used for filtering, replay buffer or training. - - `ARCHIVED`: Data is stored for historical purposes; not used for training. - - `EXPIRED`: Data is removed from the replay buffer; not used for training. - """ - +class SampleParams(BaseModel): + model_config = ConfigDict(extra="forbid") + n: int = 1 + top_k: int = 0 + top_p: float = 1.0 + temperature: float = 1.0 + repetition_penalty: float = 1.0 + presence_penalty: float = 0.0 + frequency_penalty: float = 0.0 + min_tokens: int = 0 + max_tokens: int = 2048 + stops: list[str] = [] + stop_token_ids: list[int] = [] + skip_special_tokens: bool = True + stream: bool = False + return_logprob: bool = True + top_logprobs: int = 1 + return_token_ids: bool = True + include_stop_str_in_output: bool = True + no_stop_trim: bool = True + spaces_between_special_tokens: bool = False + return_routed_experts: bool = False + + +class Status(Enum): INIT = "init" COMPLETED = "completed" ABORTED = "aborted" + EXPIRED = "expired" FAILED = "failed" + FILTERED = "filtered" + # 归档,这个状态还是要保留,用不用再说,用于表示这个数据已经用于一次训练了,但保留在数据库里以备查询 ARCHIVED = "archived" - EXPIRED = "expired" - SKIPPED = "skipped" - - @staticmethod - def from_str(state_str: str) -> RolloutState: - for state in RolloutState: - if state.value == state_str: - return state - raise ValueError(f"Unknown ReplayState string: {state_str}") -class RLUIDItem(BaseModel): - """A unique identifier for tracking data items within the dataflow. +class MultimodalInfo(TypedDict): + # 使用TypedDict给出pixel_values的类型提示 + pixel_values: NotRequired[np.ndarray | RayObjectRef | None] + image_grid_thw: NotRequired[torch.Tensor] - Attributes: - env (str): The environment name. - root_id (int): The root ID for grouping related data items. - action_id (int): The ID for a specific action in prompt. - observation_id (int): The ID for a specific observation in response. 
- version (int): The version number of the data item. - """ +class RolloutFunctionCall(BaseModel): model_config = ConfigDict(extra="forbid") - env: str = "" - root_id: int = -1 - action_id: int = -1 - observation_id: int = -1 - version: int = 0 + name: str + arguments: Any = Field(default_factory=dict) + raw_arguments_text: str | None = None -class MultimodalTrainInfo(TypedDict): - pixel_values: NotRequired[torch.Tensor | RayObjectRef | None] # type: ignore[valid-type] - image_grid_thw: NotRequired[torch.Tensor] - position_ids: NotRequired[torch.Tensor] +class RolloutToolCall(BaseModel): + model_config = ConfigDict(extra="forbid") -class RLDatasetItem(BaseModel): - """Represents the data structure output from the dataset. + id: str + type: Literal["function"] = "function" + function: RolloutFunctionCall - Attributes: - messages (Optional[List[Dict[str, Any]]]): The message list for the prompt. - input_ids (Optional[List[int]]): The tokenized input IDs. - num_tokens (Optional[int]): The number of tokens in the input. - ability (Optional[str]): The ability or category of the data. - reward_model (Optional[Dict[str, Any]]): Data required by the reward model, like ground truth. - data_source (Optional[Dict[str, Any]]): The source of the data, used for weighting rewards. - extra_info (Dict[str, Any]): Additional user-defined information. 
- """ +class RolloutState(CacheObj, BaseModel): model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) - messages: list[dict[str, Any]] | None = None - input_ids: list[int] | None = None - num_tokens: int | None = None - ability: str | None = None - reward_model: dict[str, Any] | None = None - data_source: dict[str, Any] | None = None - extra_info: dict[str, Any] = dict() - multimodal_train_info: MultimodalTrainInfo | None = None + # --- 数据 --- + message_uid: int | None = None # 通过计算原始的message的哈希值得到的id,一组的数据为同一个prompt_id + message: list[dict[str, Any]] # dataset输出,需要在AgentLoop中转换成input_ids + prompt_ids: list[int] | None = None # 原始 prompt的token ids + data_source: dict[str, Any] | str | None = None + mm_info: MultimodalInfo | None = None + reward_model: dict[str, Any] | None = None -class RolloutExtraInfo(TypedDict): - routed_experts: NotRequired[list[int] | RayObjectRef] # type: ignore[valid-type] - partial_rollout_input_ids: NotRequired[list[int]] - - -class RLRolloutResponseItem(BaseModel): - """Represents the data structure output from the rollout process. - - Attributes: - response (Optional[str]): The generated text response from the model. - response_ids (Optional[List[int]]): The token IDs of the generated response. - num_return_tokens (Optional[int]): The number of tokens in the response. - finish_reason (Optional[str]): The reason why the generation finished (e.g., 'stop', 'length'). - logprobs (Optional[List[float]]): The log probabilities of the generated tokens. - extra_info (Dict[str, Any]): Additional user-defined information. 
- """ + # --- InferEngine 输入 --- + session_uid: int | None = None + tokens: list[int] | None = None # 每一次推理引擎的实际输入 + tools: list | None = None + tool_choice: str | dict[str, Any] | None = None + sample_params: SampleParams = SampleParams() - model_config = ConfigDict(extra="forbid") + # --- InferEngine 输出 --- + # 每一次推理引擎的实际输出, 在rollout worker中被覆盖写 response: str | None = None + tool_calls: list[RolloutToolCall] | None = None response_ids: list[int] | None = None logprobs: list[float] | None = None - num_return_tokens: int | None = None - versioned_response: list[str] = Field(default_factory=list) - versioned_response_ids: list[list[int]] = Field(default_factory=list) - versioned_logprobs: list[list[float]] = Field(default_factory=list) - versioned_num_return_tokens: list[int] = Field(default_factory=list) - finish_reason: str | None = None # "stop", "length", "abort", "failed", "skipped" - extra_info: RolloutExtraInfo = Field(default_factory=dict) - state: RolloutState = RolloutState.INIT - - def _update_by_append(self, other: Self) -> None: - other_ids_copy = copy.deepcopy(other.response_ids) - other_logprobs_copy = copy.deepcopy(other.logprobs) - other_response_copy = copy.deepcopy(other.response) - if other_response_copy is not None: - assert self.response is not None, "response must not be None when updating partial data." - self.response += other_response_copy - self.versioned_response.append(other_response_copy) - - if other_ids_copy is not None: - assert self.response_ids is not None, "response_ids must not be None when updating partial data." - self.response_ids.extend(other_ids_copy.copy()) - self.versioned_response_ids.append(other_ids_copy) - self.versioned_num_return_tokens.append(len(other_ids_copy)) - - if other_logprobs_copy is not None: - assert self.logprobs is not None, "logprobs must not be None when updating partial data." 
- self.logprobs.extend(other_logprobs_copy.copy()) - self.versioned_logprobs.append(other_logprobs_copy) - - self.num_return_tokens = len(self.response_ids) if self.response_ids is not None else 0 - self.finish_reason = other.finish_reason - self.extra_info.update(other.extra_info) - self.state = other.state - return - - def update(self, other: Self) -> None: - """Updates this RLRolloutResponseItem with data from another one. - - If partial_rollout is True, concat other response to this RLRolloutResponseItem's response. - """ - if not isinstance(other, RLRolloutResponseItem): - raise TypeError("Can only update with another RLRolloutResponseItem instance.") - - if other.response_ids is None and other.logprobs is None and other.response is None: - self.finish_reason = other.finish_reason - self.state = other.state - self.extra_info.update(other.extra_info) - return - - if self.response_ids is None: - assert self.response is None and self.logprobs is None, ( - "Inconsistent state: if response_ids is None, response and logprobs must also be None." - ) - self.response = "" - self.response_ids = [] - self.logprobs = [] - self.num_return_tokens = 0 - else: - assert self.response is not None and self.logprobs is not None, ( - "Inconsistent state: if response_ids is not None, response and logprobs must also be not None." - ) - - self._update_by_append(other) - - -class RLJudgerResponseItem(BaseModel): - """Represents the data structure output from the judger. - - Attributes: - uid (Optional[int]): A unique ID to identify which input the result corresponds to. - reward (Dict[str, Any]): A dictionary of reward scores, e.g., {"judger_type": reward_score, "weighted_scores": score}. - extra_info (Dict[str, Any]): Additional user-defined information. 
- """ - - model_config = ConfigDict(extra="forbid") + routed_experts: list[int] | RayObjectRef | None = None + finish_reason: str | None = None + # response_mask: 记录response_ids中哪个token算loss, 与response_ids长度相同,每轮rollout在 agent_loop.generate 中覆盖写 + response_mask: list[int] | None = None + # response_rollout_steps:记录 response_ids 中每个 token 是在哪个 rollout_step 生成的,与 response_ids 长度相同,每轮rollout在agent_loop中后处理中覆盖写 + response_rollout_steps: list[int] | None = None + # 记录该样本过期程度,即最先生成的token与当前的训练步数的差值,数值越大表示越过期,在 agent_loop 中后处理中覆盖写 + seq_staleness: int = 0 + + # --- Judger 输出 --- + reward: dict[str, Any] | None = None + + # --- 状态 --- uid: int | None = None - reward: dict[str, Any] = Field(default_factory=lambda: {"score": 0.0, "val": 0.0}) - extra_info: dict[str, Any] = dict() - - -class RLAgentDataItem(BaseModel): - # todo: define agent output data structure - model_config = ConfigDict(extra="forbid") - extra_info: dict[str, Any] = dict() - - -class RLEnvDataItem(BaseModel): - """Contains the internal data structures of the environment, stored as an - observation. - - Attributes: - rollout (RLRolloutResponseItem): Data from the rollout stage. - judger (RLJudgerResponseItem): Data from the judger stage. - agent (RLAgentDataItem): Data from the agent stage. - extra_info (Dict[str, Any]): Additional user-defined information. - """ - - model_config = ConfigDict(extra="forbid") - rollout: RLRolloutResponseItem = RLRolloutResponseItem() - judger: RLJudgerResponseItem = RLJudgerResponseItem() - agent: RLAgentDataItem = RLAgentDataItem() - extra_info: dict[str, Any] = dict() - - -class RLExtraDataItem(BaseModel): - """Reserved for data that does not belong to a specific stage of the - dataflow. - - Attributes: - retry_times (int): The number of times the data processing has been retried. - extra_info (Dict[str, Any]): Additional user-defined information. 
- """ - - model_config = ConfigDict(extra="forbid") - retry_times: int = 0 - extra_info: dict[str, Any] = dict() - - -class RLDataFlowItem(BaseModel): - """The core data structure that flows through the dataflow and environment. - - It encapsulates all information related to a single data point, including its - unique ID, the original data, environment outputs, and extra metadata. - - Attributes: - uid (RLUIDItem): The unique identifier for the data item. - data (RLDatasetItem): The original data from the dataset. - env (RLEnvDataItem): The collected outputs from the environment stages. - extra_info (RLExtraDataItem): Additional reserved information. - """ - - model_config = ConfigDict(extra="forbid") - uid: RLUIDItem = RLUIDItem() - data: RLDatasetItem = RLDatasetItem() - env: RLEnvDataItem = RLEnvDataItem() - extra_info: RLExtraDataItem = RLExtraDataItem() - - -def is_valid_for_replaybuffer(group_data_items: list[RLDataFlowItem]) -> bool: - """Checks if a group of data items is valid for insertion into the replay - buffer. 
+ task_name: str | None = None + status: Status = Status.INIT + error_msg: str | None = None + position_ids: torch.Tensor | None = None + extra_fields: dict[str, Any] = {} + + @field_serializer("routed_experts") + def _serialize_routed_experts(self, value: list[int] | RayObjectRef | None) -> list[int] | str | None: + """序列化 routed_experts 字段: + + - None -> None + - list[int] -> list[int](原样保留) + - RayObjectRef -> base64 编码的字符串(通过 ray.cloudpickle 序列化) + """ + import ray + + if value is None: + return None + if isinstance(value, ray.ObjectRef): + data = ray.cloudpickle.dumps(value) + return base64.b64encode(data).decode("utf-8") + return value + + @field_validator("routed_experts", mode="before") + @classmethod + def _deserialize_routed_experts(cls, value: Any) -> list[int] | RayObjectRef | None: + """反序列化 routed_experts 字段: + + - None -> None + - list[int] -> list[int](原样保留) + - str(base64 编码)-> RayObjectRef(通过 ray.cloudpickle 反序列化) + - RayObjectRef -> RayObjectRef(原样保留) + """ + import ray + + if value is None: + return None + if isinstance(value, ray.ObjectRef): + return value + if isinstance(value, str): + data = base64.b64decode(value) + return ray.cloudpickle.loads(data) + if isinstance(value, list): + return value + return value + + @field_serializer("mm_info") + def _serialize_mm_info(self, value: MultimodalInfo | None) -> MultimodalInfo | None: + # TODO: Not currently needed + return None + + +def update_status_from_finish_reason(finish_reason: str | None) -> Status: + """Updates the internal status based on the inference engine's finish + reason. 
+ + State Transition Logic: + ------------------------------------------------------------- + | Finish Reason (Input) | Internal Status (Output) | + | :----------------------------- | :----------------------- | + | `stop`, `length`, `tool_calls` | `Status.COMPLETED` | + | `abort` | `Status.ABORTED` | + | `error` or `None` | `Status.FAILED` | + | *Others* | *Raises ValueError* | + ------------------------------------------------------------- Args: - group_data_items: A list of RLDataFlowItem objects. + finish_reason (str | None): The raw finish reason string returned by + the inference engine (e.g., vLLM, LMDeploy). - Returns: - True if the group is valid, False otherwise. - - NOTE: Why this check is needed: - - For system fault tolerance, this check is performed at rollout / dataflow - time, but we still do it here to ensure replay buffer data integrity. - - 'skipped' or 'failed' states indicate that the rollout process did not - complete successfully or was intentionally bypassed. - - 'aborted' states may still contain useful data for the replay buffer, - as the rollout was started but not finished. - - 'completed' states are valid and should be included in the replay buffer. + Raises: + ValueError: If the ``finish_reason`` is unknown and cannot be mapped. """ - is_skipped = any(item.env.rollout.state == RolloutState.SKIPPED for item in group_data_items) - is_failed = any(item.env.rollout.state == RolloutState.FAILED for item in group_data_items) - if is_skipped or is_failed: - logger.warning( - "Invalid dataflow group found during replay buffer insertion, skipped: {is_skipped}, failed: {is_failed}." - ) - return False - return True - - -def is_valid_for_training(group_data_items: list[RLDataFlowItem]) -> bool: - """Checks if a group of data items is valid for a training step. 
+ if finish_reason is None: + logger.error("finish_reason is None, setting status to FAILED.") + return Status.FAILED + + reason = finish_reason.lower() + if reason in ("stop", "length", "tool_calls"): + return Status.COMPLETED + elif reason == "abort": + return Status.ABORTED + elif reason == "error": + logger.warning("finish_reason is 'error', setting status to FAILED.") + return Status.FAILED + else: + logger.error(f"finish_reason '{finish_reason}' is unknown, setting status to FAILED.") + return Status.FAILED + + +def update_group_status(rollout_states: list[RolloutState]) -> Status: + """Updates the group status based on the individual rollout states. + + Group Status Logic: + ------------------------------------------------------------- + | Individual Rollout States | Group Status (Output) | + | :----------------------------- | :----------------------- | + | All `Status.COMPLETED` | `Status.COMPLETED` | + | Any `Status.FAILED` | `Status.FAILED` | + | Any `Status.ABORTED` | `Status.ABORTED` | + | Any `Status.EXPIRED` | `Status.EXPIRED` | + | Any `Status.FILTERED` | `Status.FILTERED` | + | *Others* | *Determined by priority*| + ------------------------------------------------------------- + + Priority Order (from highest to lowest): + 1. FAILED + 2. ABORTED + 3. EXPIRED + 4. FILTERED + 5. COMPLETED Args: - group_data_items: A list of RLDataFlowItem objects. + rollout_states (list[RolloutState]): A list of individual rollout states. Returns: - True if the group is valid, False otherwise. - - NOTE: Why this check is needed: - - For system fault tolerance, this check is performed at rollout / dataflow - time, but we still do it here to ensure training data integrity. - - 'skipped'/'failed': These items are fundamentally broken or incomplete and - should not be used for training. - - 'aborted': These items represent rollouts that were stopped - prematurely. 
Using such partial data could lead the model to learn - undesirable behaviors (e.g., stopping generation too early). - - Empty response/response_ids: The model's generated response is the core - of the training data for RL algorithms like PPO. If the response is - missing, there is nothing to compute rewards on or to train the model with. + Status: The aggregated group status based on the individual states. """ - is_abort = any(item.env.rollout.state == RolloutState.ABORTED for item in group_data_items) - is_skipped = any(item.env.rollout.state == RolloutState.SKIPPED for item in group_data_items) - is_failed = any(item.env.rollout.state == RolloutState.FAILED for item in group_data_items) - if is_skipped or is_failed or is_abort: - logger.debug( - f"Invalid dataflow group found during training, rollout state skipped: {is_skipped}, failed: {is_failed}, aborted: {is_abort}." - ) - return False - for item in group_data_items: - rollout_info = item.env.rollout - response_valid = True if rollout_info.response is not None and len(rollout_info.response) > 0 else False - ids_valid = True if rollout_info.response_ids is not None and len(rollout_info.response_ids) > 0 else False - if not ids_valid: - # NOTE: `response_ids` is the critical field for token-in-token-out mode, so we ensure it's not empty. - logger.warning( - "Invalid dataflow item found during training: no response or response_ids and skip this item." 
+ if all(state.status == Status.COMPLETED for state in rollout_states): + return Status.COMPLETED + elif any(state.status == Status.FAILED for state in rollout_states): + return Status.FAILED + elif any(state.status == Status.ABORTED for state in rollout_states): + return Status.ABORTED + elif any(state.status == Status.EXPIRED for state in rollout_states): + return Status.EXPIRED + elif any(state.status == Status.FILTERED for state in rollout_states): + return Status.FILTERED + else: + # If there are other statuses, we can determine the group status based on a defined priority order. + # For now, we will default to COMPLETED if none of the above conditions are met. + return Status.COMPLETED + + +def update_seq_staleness(rollout_state: RolloutState, rollout_step: int) -> RolloutState: + """计算 response_rollout_steps 列表,表示 rollout_state.response_ids 中的每个 token + 是在哪个 rollout_step 生成的。""" + response_len = len(rollout_state.response_ids or []) + response_rollout_steps = [rollout_step] * response_len + rollout_state.response_rollout_steps = (rollout_state.response_rollout_steps or []) + response_rollout_steps + + cur_rollout_steps = min(rollout_state.response_rollout_steps, default=rollout_step) + rollout_state.seq_staleness = rollout_step - cur_rollout_steps + return rollout_state + + +def update_expired_status(samples: list[RolloutState], tail_batch_stale_threshold: int = 0) -> list[RolloutState]: + if tail_batch_stale_threshold <= 0: + return samples + is_group_expired = False + + # 1. 检查组内是否存过期的样本 + for sample in samples: + if sample.status == Status.ABORTED and sample.seq_staleness >= tail_batch_stale_threshold: + logger.debug( + f"Sample {sample.uid} (seq_staleness: {sample.seq_staleness}) exceeded threshold ({tail_batch_stale_threshold}). Triggering group expiration." 
) - return False - if not response_valid: - # NOTE: check valid response string for judger inputs - logger.warning("Invalid dataflow item found during training: empty response string and skip this item.") - return False - return True - + is_group_expired = True + break # 一旦发现过期,直接跳出,无需检查剩余样本 -def update_rollout_item(group_data_items, target_value): - """Update a list of RLDataFlowItem objects by merging another - RLRolloutResponseItem into each item's env.rollout attribute. + # 2. 如果存在过期样本,将组内所有样本置为过期 + if is_group_expired: + # NOTE: 当一组数据中有一个样本被标记为过期后,这组数据中就可能出现未超过过期阈值但状态是 aborted 的样本。 + # 这些样本在后续的生成过程中也不应该被继续生成了,所以直接把它们都标记为过期, 才能在preprocess中将之前的response清掉。 + for sample in samples: + sample.status = Status.EXPIRED - Args: - group_data_items (List[RLDataFlowItem]): List of data items to update. - target_value (List[RLRolloutResponseItem]): The rollout response item to merge into each data item. - - Returns: - List[RLDataFlowItem]: The updated list of data items. - - Example: - >>> # Suppose you want to update the rollout response for each item - >>> items = [RLDataFlowItem(), RLDataFlowItem()] - >>> rollout_response = RLRolloutResponseItem(response="new response", response_ids=[1,2,3]) - >>> update_rollout_item(items, rollout_response) - # Now each item's env.rollout has been updated with the new response and response_ids - """ - - for idx, item in enumerate(group_data_items): - item.env.rollout.update(target_value[idx]) - - return group_data_items - - -def update_dataflow_item(group_data_items, target_key, target_value): - """Update a list of RLDataFlowItem objects by setting a nested attribute - for each item. - - Args: - group_data_items (List[RLDataFlowItem]): List of data items to update. - target_key (str): Dot-separated path to the attribute to update (e.g., 'env.rollout.response'). - target_value (List[Any]): List of values to set, one for each data item. - - Returns: - List[RLDataFlowItem]: The updated list of data items. 
- - Example: - >>> # Suppose you want to update the 'response' field in env.rollout for each item - >>> items = [RLDataFlowItem(), RLDataFlowItem()] - >>> responses = ["hello", "world"] - >>> update_dataflow_item(items, "env.rollout.response", responses) - # Now items[0].env.rollout.response == "hello", items[1].env.rollout.response == "world" - """ - - group_length = len(group_data_items) - assert group_length == len(target_value) - - keys = target_key.split(".") - - for i in range(group_length): - parent_obj = group_data_items[i] - for key in keys[:-1]: - parent_obj = getattr(parent_obj, key) - setattr(parent_obj, keys[-1], target_value[i]) - - return group_data_items - - -# ============================================== -# ====== Rollout API Server 数据流 ============== -# ============================================== - - -class SampleParams(BaseModel): - model_config = ConfigDict(extra="forbid") - n: Annotated[int, Parameter(help="Number of samples to generate.")] = 1 - top_k: Annotated[ - int, Parameter(help="The number of highest probability vocabulary tokens to keep for top-k-filtering.") - ] = 0 - top_p: Annotated[float, Parameter(help="The cumulative probability for nucleus sampling.")] = 1.0 - temperature: Annotated[float, Parameter(help="The value used to module the next token probabilities.")] = 1.0 - repetition_penalty: Annotated[float, Parameter(help="The parameter for repetition penalty.")] = 1.0 - presence_penalty: Annotated[float, Parameter(help="The parameter for presence penalty.")] = 0.0 - frequency_penalty: Annotated[float, Parameter(help="The parameter for frequency penalty.")] = 0.0 - min_tokens: Annotated[int, Parameter(help="Minimum number of tokens to generate.")] = 0 - max_tokens: Annotated[int, Parameter(help="Maximum number of tokens to generate.")] = 2048 - stops: Annotated[list[str], Parameter(help="List of stop sequences.")] = [] - stop_token_ids: Annotated[list[int], Parameter(help="List of stop token IDs.")] = [] - 
skip_special_tokens: Annotated[bool, Parameter(help="Whether to skip special tokens.")] = True - - -class RolloutExtraParams(TypedDict): - stream: bool - return_logprob: bool - top_logprobs: int - return_token_ids: bool - include_stop_str_in_output: bool - no_stop_trim: bool - skip_special_tokens: bool - spaces_between_special_tokens: bool - - -# 说明: 这里没定义API server情况数据格式,因为直接使用openai server的格式 -class RLRolloutRequestItem(BaseModel): - model_config = ConfigDict(extra="forbid") - messages: list[dict[str, Any]] - tools: list = Field(default_factory=list) - tool_choice: str = "auto" - sample_params: SampleParams = Field(default_factory=SampleParams) - extra_params: dict[str, Any] = Field(default_factory=dict) + return samples diff --git a/xtuner/v1/datasets/__init__.py b/xtuner/v1/datasets/__init__.py index 36ecf3ebc3..e7badef5f3 100644 --- a/xtuner/v1/datasets/__init__.py +++ b/xtuner/v1/datasets/__init__.py @@ -19,10 +19,10 @@ from .packing import ExpandSoftPackDataset, HardPackDataset, MLLMPretrainHybridPackDataset, _LegacySoftPackDataset from .pt_tokenize_fn import PretrainTokenizeFunction, PretrainTokenizeFunctionConfig from .resume import get_dataloader_state, load_dataloader_state -from .rl_tokenize_fn import RLTokenizeFnConfig +from .rl_tokenize_fn import RLTextTokenizeFnConfig from .sampler import LengthGroupedSampler, ParallelSampler from .sft_tokenize_fn import OpenaiTokenizeFunction, OpenaiTokenizeFunctionConfig -from .utils import CachableTokenizeFunction, CacheObj, calculate_file_sha256, calculate_xxhash, tokenizer_hash +from .utils import CachableTokenizeFunction, calculate_file_sha256, calculate_xxhash, tokenizer_hash from .vlm_jsonl import VLMJsonlDataset @@ -32,7 +32,6 @@ __all__ = [ "JsonlDataset", "CachableTokenizeFunction", - "CacheObj", "calculate_file_sha256", "calculate_xxhash", "tokenizer_hash", @@ -47,6 +46,7 @@ "build_datasets", "build_dataloader", "sft_llm_collator", + "fake_collator", "intern_s1_vl_sft_collator", "qwen3_vl_sft_collator", 
"FtdpTokenizeFunction", @@ -56,7 +56,6 @@ "VLMJsonlDataset", "FTDPTokenizeFnConfig", "InternS1VLTokenizeFnConfig", - "fake_collator", "RLTokenizeFnConfig", "get_dataloader_state", "load_dataloader_state", diff --git a/xtuner/v1/datasets/_hardcode_patch.py b/xtuner/v1/datasets/_hardcode_patch.py index de223e7a3d..5deb833303 100644 --- a/xtuner/v1/datasets/_hardcode_patch.py +++ b/xtuner/v1/datasets/_hardcode_patch.py @@ -26,9 +26,10 @@ from xtuner.v1.utils import get_logger from .ftdp import FtdpTokenizeFunction -from .mllm_tokenize_fn import Qwen3VLTokenizeFunction + +# from .rl_tokenize_fn.rl_tokenize_fn import InternS1VLTokenizeFunction +from .mllm_tokenize_fn import InternS1VLTokenizeFunction, Qwen3VLTokenizeFunction from .pt_tokenize_fn import PretrainTokenizeFunction -from .rl_tokenize_fn.rl_tokenize_fn import InternS1VLTokenizeFunction from .sft_tokenize_fn import OpenaiTokenizeFunction diff --git a/xtuner/v1/datasets/collator.py b/xtuner/v1/datasets/collator.py index b19d67f342..7b9df4ba78 100644 --- a/xtuner/v1/datasets/collator.py +++ b/xtuner/v1/datasets/collator.py @@ -4,6 +4,7 @@ from typing_extensions import TypedDict from xtuner.v1.data_proto import SequenceContext +from xtuner.v1.data_proto.rl_data import RolloutState from xtuner.v1.utils import IGNORE_INDEX, get_logger from xtuner.v1.utils.pad import pad_to_max_length @@ -18,7 +19,7 @@ class ColateItem(TypedDict): shifted_labels: torch.Tensor -def fake_collator(instances: list[DataItem], **kwargs): +def fake_collator(instances: list[RolloutState], **kwargs): return instances diff --git a/xtuner/v1/datasets/config.py b/xtuner/v1/datasets/config.py index 99da59b284..eeccf5b532 100644 --- a/xtuner/v1/datasets/config.py +++ b/xtuner/v1/datasets/config.py @@ -299,7 +299,7 @@ def build_collator(self): elif self.collator == "qwen3_vl_sft_collator": return qwen3_vl_sft_collator elif self.collator == "fake_collator": - return fake_collator # for RL + return fake_collator else: collator = 
pydoc.locate(self.collator) if collator is None: diff --git a/xtuner/v1/datasets/jsonl.py b/xtuner/v1/datasets/jsonl.py index f774e6c254..a0cc3f2b5f 100644 --- a/xtuner/v1/datasets/jsonl.py +++ b/xtuner/v1/datasets/jsonl.py @@ -23,9 +23,9 @@ from tqdm import tqdm from xtuner.v1.datasets.data_item import CacheItem -from xtuner.v1.utils import SharedMemory, get_logger +from xtuner.v1.utils import CacheDict, CacheObj, SharedMemory, get_logger -from .utils import CachableTokenizeFunction, CacheObj, calculate_xxhash +from .utils import CachableTokenizeFunction, calculate_xxhash T = TypeVar("T") @@ -439,11 +439,15 @@ def count_offsets(self, cache_dir=None): @staticmethod def _tokenize_by_offset( data: bytes, - tokenize_fn: Callable[[dict], CacheObj], + tokenize_fn: Callable[[dict], CacheDict | CacheObj], ) -> dict: line = data.decode() tokenized = tokenize_fn(json.loads(line)) - return {"num_tokens": tokenized["num_tokens"]} + if isinstance(tokenized, CacheObj): + num_tokens = tokenized.num_tokens + else: + num_tokens = tokenized["num_tokens"] + return {"num_tokens": num_tokens} def count_tokens(self, offsets, cache_dir=None): self.tokenize_fn.set_state("cache") diff --git a/xtuner/v1/datasets/rl_tokenize_fn/__init__.py b/xtuner/v1/datasets/rl_tokenize_fn/__init__.py index 2ecf6f3f61..83eb5f8b7f 100644 --- a/xtuner/v1/datasets/rl_tokenize_fn/__init__.py +++ b/xtuner/v1/datasets/rl_tokenize_fn/__init__.py @@ -1,6 +1,5 @@ -from .rl_tokenize_fn import RLTokenizeFnConfig +from .qwen3_vl_tokenize_fn import RLQwen3VLTokenizeFnConfig +from .text_tokenize_fn import RLTextTokenizeFnConfig -__all__ = [ - "RLTokenizeFnConfig", -] +__all__ = ["RLTextTokenizeFnConfig", "RLQwen3VLTokenizeFnConfig"] diff --git a/xtuner/v1/datasets/rl_tokenize_fn/qwen3_vl_tokenize_fn.py b/xtuner/v1/datasets/rl_tokenize_fn/qwen3_vl_tokenize_fn.py new file mode 100644 index 0000000000..7bdca2d71c --- /dev/null +++ b/xtuner/v1/datasets/rl_tokenize_fn/qwen3_vl_tokenize_fn.py @@ -0,0 +1,97 @@ +from typing 
import cast + +from xtuner.v1.data_proto import RolloutState + +from ...data_proto.rl_data import MultimodalInfo +from ..mllm_tokenize_fn.qwen3_vl_tokenize_fn import Qwen3VLTokenizeFnConfig, Qwen3VLTokenizeFunction, QwenVL3DataItem +from ..utils import replace_image_context_and_collect_media_data + + +def remove_consecutive_img_context_tokens(tokens: list[int], img_context_id: int) -> list[int]: + if not tokens: + return tokens + + new_tokens = [tokens[0]] + for i in range(1, len(tokens)): + if tokens[i] == img_context_id and tokens[i - 1] == img_context_id: + continue # 跳过连续的 img_context_id + else: + new_tokens.append(tokens[i]) + return new_tokens + + +class RLQwen3VLTokenizeFunction(Qwen3VLTokenizeFunction): + def __init__(self, *args, ignore_multimodal_info: bool = False, **kwargs): + self.ignore_multimodal_info = ignore_multimodal_info + super().__init__(*args, **kwargs) + + # TODO: tool call + def __call__(self, item: dict, media_root: str = "", **kwargs) -> RolloutState: + extra_info = item.get("extra_info", {}) + message = item["prompt"] + + data = super().__call__({"messages": message}, media_root=media_root) + + if self.state == "cache": + return RolloutState(message=message, num_tokens=data["num_tokens"]) + else: + data = cast(QwenVL3DataItem, data) + image_data, _ = replace_image_context_and_collect_media_data(message, media_root, True) + if image_data: + extra_info["image_data"] = image_data + + # 因为 sft tokenizer fn 可能并没有完全和 apply_chat_template 中的 jinja 模块对齐,特别是 system prompt + # 为了确保一致,必须要通过 tokenizer_fn 得到 prompt_token_ids + # raw_prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) + # prompt_token_ids = self.tokenizer(raw_prompt, add_special_tokens=False)["input_ids"] + prompt_token_ids = remove_consecutive_img_context_tokens(data["input_ids"], self.img_context_token_id) + raw_prompt = self.tokenizer.decode(prompt_token_ids) # Just for logging + extra_info["raw_prompt"] = raw_prompt + # 训练时的 prompt 
token ids,包含连续的 img_context_token_id + extra_info["train_prompt_ids"] = data["input_ids"] + + mm_info = None + if not self.ignore_multimodal_info: + mm_info = MultimodalInfo() + if "pixel_values" in data: + mm_info["pixel_values"] = data["pixel_values"].numpy() # for ray put into shared memory + if "image_grid_thw" in data: + mm_info["image_grid_thw"] = data["image_grid_thw"] + return RolloutState( + message=message, + num_tokens=data["num_tokens"], + prompt_ids=prompt_token_ids, + position_ids=data["position_ids"], + data_source=item.get("data_source", "default"), + reward_model=item.get("reward_model", {}), + mm_info=mm_info, + extra_fields=extra_info, + ) + + +class RLQwen3VLTokenizeFnConfig(Qwen3VLTokenizeFnConfig): + ignore_multimodal_info: bool = False # eval is True + + def build( + self, tokenizer, tokenizer_hash: str | None = None, anno_name: str = "", **kwargs + ) -> RLQwen3VLTokenizeFunction: + return RLQwen3VLTokenizeFunction( + tokenizer, + self.processor_path, + anno_name, + min_pixels=self.min_pixels, + max_pixels=self.max_pixels, + oss_loader_cfg=self.oss_loader_cfg, + video_min_total_pixels=self.video_min_total_pixels, + video_max_total_pixels=self.video_max_total_pixels, + video_min_frames=self.video_min_frames, + video_max_frames=self.video_max_frames, + rand_video_max_frames=self.rand_video_max_frames, + fps=self.fps, + enable_3d_rope=self.enable_3d_rope, + add_vision_id=self.add_vision_id, + max_length=self.max_length, + system_message=self.system_message, + tokenizer_hash=tokenizer_hash, + ignore_multimodal_info=self.ignore_multimodal_info, + ) diff --git a/xtuner/v1/datasets/rl_tokenize_fn/rl_tokenize_fn.py b/xtuner/v1/datasets/rl_tokenize_fn/rl_tokenize_fn.py deleted file mode 100644 index 47b3fc5ce7..0000000000 --- a/xtuner/v1/datasets/rl_tokenize_fn/rl_tokenize_fn.py +++ /dev/null @@ -1,172 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-from typing import cast - -from pydantic import BaseModel, ConfigDict - -from transformers import PreTrainedTokenizer -from xtuner.v1.data_proto.rl_data import RLDatasetItem -from xtuner.v1.utils import get_logger - -from ..data_item import OmniDataItem -from ..mllm_tokenize_fn.intern_s1_vl_tokenize_fn import InternS1VLTokenizeFunction -from ..mllm_tokenize_fn.qwen3_vl_tokenize_fn import Qwen3VLTokenizeFunction -from ..utils import CachableTokenizeFunction, replace_image_context_and_collect_media_data - - -logger = get_logger() - - -def remove_consecutive_twos(tokens, img_context_id): - if not tokens: - return tokens - - new_tokens = [tokens[0]] - for i in range(1, len(tokens)): - if tokens[i] == img_context_id and tokens[i - 1] == img_context_id: - continue # 跳过连续的 img_context_id - else: - new_tokens.append(tokens[i]) - return new_tokens - - -class RLTokenizeFn(CachableTokenizeFunction[RLDatasetItem]): - def __init__( - self, - tokenizer_fn: CachableTokenizeFunction | None, - tokenizer: PreTrainedTokenizer, - max_length: int | None = None, - ignore_multimodal_info: bool = False, - ): - super().__init__(tokenizer) - self.tokenizer_fn = tokenizer_fn - self.max_length = max_length - - self.img_context_id = None - self.ignore_multimodal_info = ignore_multimodal_info - self.model_name = "default" - if self.tokenizer_fn: - if isinstance(self.tokenizer_fn, Qwen3VLTokenizeFunction): - self.model_name = "qwen3_vl" - elif isinstance(self.tokenizer_fn, InternS1VLTokenizeFunction): - self.model_name = "intern_s1_vl" - else: - raise ValueError(f"Unsupported tokenizer_fn type: {type(self.tokenizer_fn)}") - self.img_context_id = tokenizer.convert_tokens_to_ids(self.tokenizer_fn.chat_template.image_context_token) - - def __call__(self, item: dict, **kwargs) -> RLDatasetItem: - """example: - item = { - "data_source": data_source, - "prompt": [ - { - "role": "user", - "content": question, - } - ], - "ability": "math", - "reward_model": {"style": "rule", "ground_truth": solution}, 
- "extra_info": { - "split": split, - "index": idx, - "answer": answer_raw, - "question": question_raw, - }, - } - """ - - extra_info = item.get("extra_info", {}) - messages = item["prompt"] - if self.tokenizer_fn is None: - # pure text - raw_prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) - data = self.tokenizer(raw_prompt, add_special_tokens=False) - prompt_token_ids = data["input_ids"] - num_tokens = len(data["input_ids"]) - else: - # mllm - self.tokenizer_fn.state = self.state - data = self.tokenizer_fn({"messages": messages}, **kwargs) - data = cast(OmniDataItem, data) - num_tokens = data["num_tokens"] - - media_root = kwargs.get("media_root", "") - if self.model_name == "qwen3_vl": - image_data, _ = replace_image_context_and_collect_media_data(messages, media_root, True) - elif self.model_name == "intern_s1_vl": - image_data, _ = replace_image_context_and_collect_media_data(messages, media_root, False) - else: - raise ValueError(f"Unsupported model_name: {self.model_name}") - if image_data: - extra_info["image_data"] = image_data - - # 不能用下面的逻辑得到 rollout 的 prompt_token_ids - # 因为 sft tokenizer fn 可能并没有完全和 apply_chat_template 中的 jinja 模块对齐,特别是 system prompt - # 为了确保一致,必须要通过 tokenizer_fn 得到 prompt_token_ids - # raw_prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) - # prompt_token_ids = self.tokenizer(raw_prompt, add_special_tokens=False)["input_ids"] - if self.state != "cache": - prompt_token_ids = remove_consecutive_twos(data["input_ids"], self.img_context_id) - else: - prompt_token_ids = [1] # Just a placeholder - raw_prompt = self.tokenizer.decode(prompt_token_ids) # Just for logging - - multimodal_train_info = {} - extra_info["raw_prompt"] = raw_prompt - - if self.state == "cache": - if self.max_length is not None and num_tokens > self.max_length: - num_tokens = 0 # will be filtered out by the dataset filter - else: - if self.max_length is not None: - assert 
num_tokens <= self.max_length, f"num_tokens {num_tokens} > max_length {self.max_length}" - if not self.ignore_multimodal_info: - if "pixel_values" in data: - multimodal_train_info["pixel_values"] = data["pixel_values"] - if "image_grid_thw" in data: - multimodal_train_info["image_grid_thw"] = data["image_grid_thw"] # qwen3-vl - if "position_ids" in data: - multimodal_train_info["position_ids"] = data["position_ids"] # qwen3-vl - - # 在多模态场景下,训练和 rollout 的 prompt ids 是不一样的 - # 为了统一训练处理逻辑,额外保存 train_prompt_ids - extra_info["train_prompt_ids"] = data["input_ids"] - - rl_out_data = { - "messages": messages, - "input_ids": prompt_token_ids, - "num_tokens": num_tokens, - "reward_model": item["reward_model"], - "ability": item.get("ability", None), - "data_source": {item.get("data_source"): 1.0}, - "extra_info": extra_info, - "multimodal_train_info": multimodal_train_info, - } - return rl_out_data # type: ignore - - def hash(self) -> str: - raise ValueError("不应该触发这个方法, 因为 RLTokenizeFn 不需要缓存。") - - -class RLTokenizeFnConfig(BaseModel): - model_config = ConfigDict(title="Base RL dataset config for xtuner", extra="forbid") - tokenize_fn_cfg: BaseModel | None = None - max_length: int | None = None - ignore_multimodal_info: bool = False # eval is True - - def build( - self, tokenizer: PreTrainedTokenizer, tokenizer_hash: str | None = None, anno_name: str | None = None, **kwargs - ) -> RLTokenizeFn: - tokenizer_fn = None - if self.tokenize_fn_cfg: - tokenizer_fn = self.tokenize_fn_cfg.build( - tokenizer=tokenizer, - tokenizer_hash=tokenizer_hash, - anno_name=anno_name, - **kwargs, - ) - return RLTokenizeFn( - tokenizer_fn, - tokenizer=tokenizer, - max_length=self.max_length, - ignore_multimodal_info=self.ignore_multimodal_info, - ) diff --git a/xtuner/v1/datasets/rl_tokenize_fn/text_tokenize_fn.py b/xtuner/v1/datasets/rl_tokenize_fn/text_tokenize_fn.py new file mode 100644 index 0000000000..dd4d0011c1 --- /dev/null +++ b/xtuner/v1/datasets/rl_tokenize_fn/text_tokenize_fn.py @@ 
-0,0 +1,84 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from pydantic import BaseModel, ConfigDict + +from transformers import PreTrainedTokenizer +from xtuner.v1.data_proto import RolloutState +from xtuner.v1.utils import get_logger + +from ..utils import CachableTokenizeFunction + + +logger = get_logger() + + +class RLTextTokenizeFn(CachableTokenizeFunction[RolloutState]): + def __init__( + self, + tokenizer: PreTrainedTokenizer, + max_length: int | None = None, + tools_schema: list | None = None, + ): + super().__init__(tokenizer) + self.max_length = max_length + self.tools_schema = tools_schema if tools_schema is not None else [] + + def __call__(self, item: dict, **kwargs) -> RolloutState: + """example: + item = { + "data_source": data_source, + "prompt": [ + { + "role": "user", + "content": question, + } + ], + "ability": "math", + "reward_model": {"style": "rule", "ground_truth": solution}, + "extra_info": { + "split": split, + "index": idx, + "answer": answer_raw, + "question": question_raw, + }, + } + """ + + extra_info = item.get("extra_info", {}) + message = item["prompt"] + + raw_prompt = self.tokenizer.apply_chat_template( + message, tools=self.tools_schema, add_generation_prompt=True, tokenize=False + ) + extra_info["raw_prompt"] = raw_prompt + data = self.tokenizer(raw_prompt, add_special_tokens=False) + prompt_token_ids = data["input_ids"] + num_tokens = len(data["input_ids"]) + + if self.state == "cache": + if self.max_length is not None and num_tokens > self.max_length: + num_tokens = 0 # will be filtered out by the dataset filter + else: + if self.max_length is not None: + assert num_tokens <= self.max_length, f"num_tokens {num_tokens} > max_length {self.max_length}" + + rollout_state = RolloutState( + prompt_ids=prompt_token_ids, + message=message, + data_source=item.get("data_source", "default"), + reward_model=item.get("reward_model", {}), + num_tokens=num_tokens, + extra_fields=extra_info, + ) + return rollout_state + + def hash(self) 
-> str: + raise ValueError("不应该触发这个方法, 因为 RLTokenizeFn 不需要缓存。") + + +class RLTextTokenizeFnConfig(BaseModel): + model_config = ConfigDict(title="Text RL dataset config for xtuner", extra="forbid") + max_length: int | None = None + tools_schema: list | None = None + + def build(self, tokenizer: PreTrainedTokenizer, **kwargs) -> RLTextTokenizeFn: + return RLTextTokenizeFn(tokenizer=tokenizer, max_length=self.max_length, tools_schema=self.tools_schema) diff --git a/xtuner/v1/datasets/utils.py b/xtuner/v1/datasets/utils.py index 4c3999a317..5d89e87b39 100644 --- a/xtuner/v1/datasets/utils.py +++ b/xtuner/v1/datasets/utils.py @@ -9,7 +9,6 @@ import numpy as np import xxhash from PIL import Image -from typing_extensions import TypedDict from .data_item import CacheItem @@ -20,10 +19,6 @@ from transformers import PreTrainedTokenizer -class CacheObj(TypedDict, total=False): - num_tokens: int - - class CachableTokenizeFunction(ABC, Generic[T]): def __init__(self, tokenizer, *args, **kwargs): self.tokenizer = tokenizer diff --git a/xtuner/v1/ray/__init__.py b/xtuner/v1/ray/__init__.py deleted file mode 100644 index 1d2ccfb3c9..0000000000 --- a/xtuner/v1/ray/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .base import AcceleratorResourcesConfig, AutoAcceleratorWorkers, SingleAcceleratorWorker -from .utils import ( - find_master_addr_and_port, - get_accelerator_ids, - get_ray_accelerator, - load_function, -) diff --git a/xtuner/v1/ray/base/__init__.py b/xtuner/v1/ray/base/__init__.py deleted file mode 100644 index c66e30c7fa..0000000000 --- a/xtuner/v1/ray/base/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .accelerator import AcceleratorResourcesConfig, AutoAcceleratorWorkers, SingleAcceleratorWorker -from .cpu import AutoCPUWorkers, BaseCPUWorker, CPUResourcesConfig diff --git a/xtuner/v1/ray/base/cpu.py b/xtuner/v1/ray/base/cpu.py deleted file mode 100644 index ca08f5d859..0000000000 --- a/xtuner/v1/ray/base/cpu.py +++ /dev/null @@ -1,191 +0,0 @@ -from typing import Any, 
Dict, TypeVar - -import ray -from cyclopts import Parameter -from pydantic import BaseModel, ConfigDict, field_validator -from ray.util.placement_group import VALID_PLACEMENT_GROUP_STRATEGIES, PlacementGroup, placement_group -from typing_extensions import Annotated - - -PG_READY_TIMEOUT = 30 # seconds -T = TypeVar("T") - - -class CPUResourcesConfig(BaseModel): - """Configuration for CPU resources in a placement group for XTuner. - - This class provide specific configuration options for CPU-based workers in Ray placement groups. - - Args: - num_cpus_per_worker (float): Number of CPUs to allocate per worker in the - placement group. Defaults to 8. - cpu_memory_per_worker (int): Amount of CPU memory (in bytes) to allocate - for each worker in the placement group. - num_workers (int): Total number of workers in the placement group. - """ - - model_config = ConfigDict(extra="forbid") - num_workers: Annotated[int, Parameter(help="Number of workers in the placement group.")] = 1 - num_cpus_per_worker: Annotated[float, Parameter(help="Number of CPUs to allocate for the placement group.")] = 1 - cpu_memory_per_worker: Annotated[ - int, Parameter(help="Amount of memory (in bytes) to allocate for the placement group.") - ] = 1024**3 # 1 GB - pg_pack_strategy: Annotated[ - str, - Parameter(help="Placement group packing strategy, options: " + ", ".join(VALID_PLACEMENT_GROUP_STRATEGIES)), - ] = "SPREAD" - - @field_validator("pg_pack_strategy") - @classmethod - def check_pg_pack_strategy(cls, v): - if v not in VALID_PLACEMENT_GROUP_STRATEGIES: - raise ValueError(f"pg_pack_strategy must be one of {VALID_PLACEMENT_GROUP_STRATEGIES}") - return v - - def model_post_init(self, __context: Any) -> None: - assert ray.is_initialized(), "Ray must be initialized before creating CPUResourcesConfig." 
- available_resources = ray.available_resources() - available_cpus = available_resources.get("CPU", 0) - available_memory = available_resources.get("memory", 0) - # TODO: manage single controller's cpu resource to replace "10" here - needed_cpus = (self.num_cpus_per_worker * self.num_workers) + 10 - assert needed_cpus <= available_cpus, ( - f"Not enough available CPUs in Ray cluster, available_cpus is {available_cpus} but xtuner needs {needed_cpus}." - ) - needed_memory = self.cpu_memory_per_worker * self.num_workers + 10 * 1024**3 - assert needed_memory <= available_memory, ( - f"Not enough available memory in Ray cluster, available_memory is {available_memory} but xtuner needs {needed_memory}." - ) - # TODO: check all resources sum in cluster to avoid over allocation - - @classmethod - def from_total( - cls, total_cpus: float | int, total_memory: int, num_workers: int, pg_pack_strategy: str = "SPREAD" - ): - """Create a CPUResourcesConfig from total CPU and memory resources. - - Args: - total_cpus (float | int): Total number of CPUs to allocate across all workers. - total_memory (int): Total amount of memory (in bytes) to allocate across all workers. - num_workers (int): Number of workers in the placement group. - - Returns: - CPUResourcesConfig: The created CPUResourcesConfig object. - """ - assert num_workers > 0, "Number of workers must be positive." - return cls( - num_workers=num_workers, - num_cpus_per_worker=total_cpus / num_workers, - cpu_memory_per_worker=total_memory / num_workers, - pg_pack_strategy=pg_pack_strategy, - ) - - -class BaseCPUWorker: - """The BaseCPUWorker class serves as a foundational structure for CPU-based - workers within the XTuner framework. - - This class is designed to be extended by specific CPU worker implementations. - It provides a constructor that accepts a configuration object, allowing - subclasses to initialize with custom settings. - - Args: - config: The configuration object for the CPU worker. 
- num_cpus (float | int): The number of CPUs allocated to this worker. - Defaults to 1. - """ - - def __init__(self, config, num_cpus: float | int = 1): - self.config = config - self.num_cpus = num_cpus - - -class AutoCPUWorkers: - """A utility class for automatically creating and managing cpu actors - within a Ray PlacementGroup.""" - - @staticmethod - def build_placement_group(resources_config: CPUResourcesConfig): - """Build a Ray PlacementGroup based on the provided resource - configuration. - - Args: - resources_config (CPUResourcesConfig): The configuration - specifying the resources for each worker bundle. - - Returns: - PlacementGroup: The created Ray PlacementGroup. - """ - bundles = [ - { - "CPU": resources_config.num_cpus_per_worker, - "memory": resources_config.cpu_memory_per_worker, - } - ] * resources_config.num_workers - - pg = placement_group(bundles=bundles, strategy=resources_config.pg_pack_strategy) - - ray.get(pg.ready(), timeout=PG_READY_TIMEOUT) - return pg - - @staticmethod - def get_pg_options(pg: PlacementGroup, num_cpus: int | float = -1) -> Dict: - """Provide a dictionary of resource requests for Ray tasks or actors - with specific cpu requirements. - - Args: - pg (PlacementGroup): The placement group to get options for. - num_cpus (float): The number of CPUs to request. If set to -1, - the default CPU allocation from the placement group bundle - will be used. Defaults to -1. - - Returns: - Dict: A dictionary of Ray resource options for `task.options()`. - """ - assert len(pg.bundle_specs) > 0, "Placement group has no bundles defined." - default_cpu = pg.bundle_specs[0].get("CPU", 1) - return {"num_cpus": num_cpus if num_cpus >= 0 else default_cpu} - - @classmethod - def from_config(cls, worker_cls, worker_config, cpu_config: CPUResourcesConfig): - """Create workers and a placement group from configuration objects. - - Args: - worker_cls: The class of the worker to instantiate. - worker_config: The configuration for each worker instance. 
- cpu_config (CPUResourcesConfig): The configuration - for the cpu resources. - - Returns: - List[T]: List of created worker instances. - """ - pg = AutoCPUWorkers.build_placement_group(cpu_config) - workers_list = cls.from_placement_group(worker_cls, worker_config, pg) - - return workers_list, pg - - @classmethod - def from_placement_group(cls, worker_cls, worker_config, pg: PlacementGroup, num_workers: int = -1): - """Create workers from an existing placement group. - - Args: - worker_cls: The class of the worker to instantiate. - worker_config: The configuration for each worker instance. - pg (PlacementGroup): The existing placement group to use. - num_workers (int): The number of workers to create. Defaults to -1, - the number of bundles in the placement group will be used. - - Returns: - List[T]: List of created worker instances. - """ - pg_options = cls.get_pg_options(pg) - - num_workers = num_workers if num_workers > 0 else pg.bundle_count - workers_list = [] - for _ in range(num_workers): - worker = worker_cls.options(placement_group=pg, **pg_options).remote( - worker_config, num_cpus=pg_options.get("num_cpus", 1) - ) # type: ignore[attr-defined] - workers_list.append(worker) - - return workers_list diff --git a/xtuner/v1/ray/config/__init__.py b/xtuner/v1/ray/config/__init__.py deleted file mode 100644 index 22f86fcb56..0000000000 --- a/xtuner/v1/ray/config/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .worker import ( - RolloutConfig, - TrainingWorkerConfig, -) diff --git a/xtuner/v1/ray/config/worker.py b/xtuner/v1/ray/config/worker.py deleted file mode 100644 index f394b019cd..0000000000 --- a/xtuner/v1/ray/config/worker.py +++ /dev/null @@ -1,331 +0,0 @@ -import json -import os -import socket -from pathlib import Path -from typing import Any, List, Literal, Optional, Union - -from cyclopts import Group, Parameter -from pydantic import BaseModel, ConfigDict, PrivateAttr -from typing_extensions import Annotated - -from xtuner.v1.utils import get_logger 
- - -worker_group = Group("worker", help="Types of workers available.") -train_group = Group("Training", sort_key=90, help="Training worker configuration.") -infer_group = Group("inference", help="Inference worker configuration.") - - -class TrainingWorkerConfig(BaseModel): - """Configuration for the TrainingWorker.""" - - model_config = ConfigDict(extra="forbid") - type: Literal["train"] = "train" - train_model_path: Annotated[str, Parameter(group=train_group, help="Path to the training model.")] - - -class RolloutConfig(BaseModel): - """Rollout worker configuration for XTuner. - - This class defines comprehensive configuration parameters for rollout workers in XTuner, - supporting multiple inference backends with distributed computing and optimization features. - - Args: - env (str): Environment variables for the rollout worker. Defaults to "". - backend (str): Backend framework ('vllm', 'lmdeploy', etc.). Defaults to "lmdeploy". - model_path (str | Path): Path to the inference model. - model_name (str): Model name for the backend engine. - tokenizer_path (str): Path to the model tokenizer. Defaults to "". - api_key (Optional[Union[List[str], str]]): API keys for rollout service. Supports single key or list of keys. Defaults to None. - api_port (Optional[int]): Port number for the rollout API server. If not set, it will find an available port starting from 8000. Defaults to 8000. - gpus_per_node (int): Number of GPUs per node. Defaults to 8. - dtype (str): Model data type ('bfloat16', 'float16', 'int8'). Defaults to "bfloat16". - gpu_memory_utilization (float): GPU memory utilization ratio. Defaults to 0.85. - random_seed (int): Random seed for reproducible generation. Defaults to 1024. - rollout_cross_node_comm (bool): Enable cross-node communication. Defaults to False. - rollout_max_batch_size_per_instance (int): Maximum batch size for the rollout worker. If not set, it will be determined automatically based on `context_length`. Defaults to 512. 
- allow_over_concurrency_ratio (float): Factor to allow over-concurrency in HTTP requests for the rollout worker to improve GPU utilization. Defaults to 1.2. - tensor_parallel_size (int): GPUs per inference engine (tensor parallelism). Defaults to 1. - expert_parallel_size (int): Experts per inference engine (expert parallelism). Defaults to 1. - enable_chunked_prefill (bool): Enable chunked prefill for memory efficiency. Defaults to False. - chunked_prefill_size (int): Chunk size for prefill operations. Defaults to 128. - skip_load_weights (bool): Skip weight loading for rollout worker. Defaults to False. - rollout_timeout (float): Timeout duration in seconds for rollout requests. Defaults to 3600.0. - context_length (int): Context length for the rollout worker. - launch_server_method (Literal["ray", "multiprocessing"]): Server launch method. Defaults to "ray". - system_prompt (Optional[str]): System prompt to guide generation behavior. Defaults to None. - extra_rollout_config (Optional[dict]): Backend-specific configurations using engine prefixes - (e.g., 'vllm_enable_chunked_prefill', 'lmdeploy_max_batch_size'). Defaults to empty dict. 
- - **Examples:** - - Example configuration with LMDeploy backend:: - - config = RolloutConfig( - env="test_env", - model_path="Qwen/Qwen3-8B", - model_name="Qwen3-8B", - tensor_parallel_size=2, - gpu_memory_utilization=0.6, - gpus_per_node=8, - backend="lmdeploy", - ) - """ - - model_config = ConfigDict(extra="forbid") - - # base config - env: Annotated[ - str, - Parameter(group=infer_group, help="Environment variables to set for the rollout."), - ] = "" - device: Annotated[str, Parameter(group=infer_group, help="Device to be used for the rollout worker.")] = "GPU" - model_path: Annotated[str | Path, Parameter(group=infer_group, help="Path to the SGLang model.")] - model_name: Annotated[ - str | None, Parameter(group=infer_group, help="Name of the model to be used in the LMDeploy.") - ] = None - tokenizer_path: Annotated[ - str | None, Parameter(group=infer_group, help="Path to the tokenizer for the model.") - ] = None - api_key: Annotated[ - Optional[Union[List[str], str]], - Parameter( - group=infer_group, - help="API keys for the rollout service. Can be a single key or a list of keys.", - ), - ] = None - api_port: Annotated[ - int, - Parameter(group=infer_group, help="Port number for the rollout API server. 
If not set, 8000 will be used."), - ] = 8000 - gpus_per_node: Annotated[int, Parameter(group=infer_group, help="Number of GPUs allocated per node.")] = 8 - dtype: Annotated[ - str, - Parameter(group=infer_group, help="Data type for the model, e.g., 'bfloat16', 'float16', 'int8'."), - ] = "bfloat16" - gpu_memory_utilization: Annotated[ - float, Parameter(group=infer_group, help="GPU memory utilization for the rollout worker.") - ] = 0.85 - random_seed: Annotated[int, Parameter(group=infer_group, help="Random seed for the rollout worker.")] = 1024 - # distributed config - rollout_cross_node_comm: Annotated[ - bool, - Parameter( - group=infer_group, - help="Whether to enable cross-node communication for the rollout worker.", - ), - ] = False - dist_port_base: Annotated[ - int, - Parameter( - group=infer_group, - help="Base port number for distributed communication among rollout workers.", - ), - ] = 35000 - rollout_max_batch_size_per_instance: Annotated[ - Optional[int], - Parameter( - group=infer_group, - help="Maximum batch size for the rollout worker. 
If not set, it will be determined automatically based on the model and GPU memory.", - ), - ] = None - allow_over_concurrency_ratio: Annotated[ - float, - Parameter( - group=infer_group, - help="Factor to allow over concurrency in the http request for rollout worker to improve GPU utilization.", - ), - ] = 1.2 - tensor_parallel_size: Annotated[ - int, - Parameter( - group=infer_group, - help="Number of GPUs allocated for each inference engine in the rollout worker.", - ), - ] = 1 - expert_parallel_size: Annotated[ - int, - Parameter( - group=infer_group, - help="Number of experts allocated for each inference engine in the rollout worker.", - ), - ] = 1 - # optimization config - enable_chunked_prefill: Annotated[ - bool, - Parameter( - group=infer_group, - help="Whether to enable chunked prefill for the rollout worker.", - ), - ] = False - chunked_prefill_size: Annotated[ - int, - Parameter( - group=infer_group, - help="Chunked prefill size for the rollout worker.", - ), - ] = 128 - skip_load_weights: Annotated[ - bool, - Parameter( - group=infer_group, - help="Whether to skip loading weights for the rollout worker.", - ), - ] = False - enable_return_routed_experts: Annotated[ - bool, - Parameter( - group=infer_group, - help="Whether to enable returning routed experts for the rollout worker.", - ), - ] = False - launch_server_method: Annotated[ - Literal["ray", "multiprocessing"], - Parameter( - group=infer_group, - help="Method to launch the rollout server, either 'ray' or 'multiprocessing'.", - ), - ] = "ray" - rollout_timeout: Annotated[ - float, - Parameter( - group=infer_group, - help="Timeout duration (in seconds) for rollout requests.", - ), - ] = 1200.0 - context_length: Annotated[ - Optional[int], - Parameter( - group=infer_group, - help="Context length for the rollout worker.", - ), - ] = None - enable_float8: Annotated[ - bool, - Parameter( - group=infer_group, - help="Whether to enable float8 quantization for the rollout worker.", - ), - ] = False - 
extra_rollout_config: Annotated[ - dict, - Parameter( - group=infer_group, - help='Extra configuration for different rollout worker. vllm parameters will start with prefix "vllm", etc.', - ), - ] = {} - max_retry_per_worker: Annotated[ - Optional[int], - Parameter( - group=infer_group, - help="Maximum number of retries per rollout worker before deactivation.", - ), - ] = None - max_retry_per_sample: Annotated[ - int, - Parameter( - group=infer_group, - help="Maximum number of retries per sample before marking it as failed.", - ), - ] = 1 - worker_log_dir: Annotated[Path, Parameter(help="Directory to save worker logs.")] = Path.cwd() / "work_dir" - _logged_server_urls_per_engine: bool = PrivateAttr(default=False) - - @property - def rollout_backend(self) -> str: - backend = "" - if os.environ.get("XTUNER_USE_SGLANG", "0") == "1": - backend = "sglang" - elif os.environ.get("XTUNER_USE_VLLM", "0") == "1": - backend = "vllm" - elif os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "1": - backend = "lmdeploy" - - assert backend in ["sglang", "vllm", "lmdeploy"], ( - f"Unsupported rollout backend: {backend}. Please set XTUNER_USE_SGLANG, XTUNER_USE_VLLM, or XTUNER_USE_LMDEPLOY to 1." - ) - return backend - - @property - def server_urls_per_engine(self) -> int: - # server_urls_per_engine is introduced for lmdeploy ep settings - # for now only lmdeploy pytorch backend with ep > 1 requires multiple server urls per engine - if self.rollout_backend == "lmdeploy" and self.expert_parallel_size > 1: - # when expert parallelism is used, lmdeploy requires `expert_parallel_size` server instances per engine - if not self._logged_server_urls_per_engine: - self._logged_server_urls_per_engine = True - get_logger().info( - f"Setting server_urls_per_engine={self.expert_parallel_size} due to expert parallelism in LMDeploy." 
- ) - return self.expert_parallel_size - else: - return 1 - - def model_post_init(self, __context: Any) -> None: - if self.model_name is None: - model_name_from_config = None - config_json_path = Path(self.model_path) / "config.json" - try: - with open(config_json_path, encoding="utf-8") as f: - config_data = json.load(f) - model_name_from_config = config_data.get("model_type") - except (json.JSONDecodeError, OSError): - pass - self.model_name = model_name_from_config or Path(self.model_path).name - - if self.tokenizer_path is None: - self.tokenizer_path = str(self.model_path) - - port = self.api_port - while True: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - try: - s.bind(("localhost", port)) - break - except OSError: - port += 1 - self.api_port = port - - if self.device == "NPU": - self.gpus_per_node = 16 - - if self.rollout_backend == "sglang": - self.launch_server_method = "multiprocessing" - self.rollout_cross_node_comm = False - else: - self.launch_server_method = "ray" - self.rollout_cross_node_comm = True - - if self.rollout_max_batch_size_per_instance is None: - assert self.context_length is not None, ( - "context_length must be set if rollout_max_batch_size_per_instance is not provided." - ) - # TODO(@duanyanhui): Provide better suggestions for different models/input-output lengths - if self.context_length <= 4096: - self.rollout_max_batch_size_per_instance = 1024 - elif self.context_length <= 8192: - self.rollout_max_batch_size_per_instance = 512 - else: - self.rollout_max_batch_size_per_instance = 128 - - if self.max_retry_per_worker is None: - self.max_retry_per_worker = self.rollout_max_batch_size_per_instance - - self.worker_log_dir.mkdir(parents=True, exist_ok=True) - - -if __name__ == "__main__": - from cyclopts import App, Group, Parameter - - app = App() - - @app.default - def test_command(*, config: RolloutConfig): - """A test command to verify the command line interface. - - Args: - config: The rollout configuration. 
- """ - print("This is a test command.") - - app() diff --git a/xtuner/v1/ray/dataflow/__init__.py b/xtuner/v1/ray/dataflow/__init__.py deleted file mode 100644 index 3f2fbaf630..0000000000 --- a/xtuner/v1/ray/dataflow/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .flow import DataFlow, DataFlowConfig, DataFlowProxy -from .replay_buffer import ReplayBuffer, ReplayBufferConfig diff --git a/xtuner/v1/ray/dataflow/flow.py b/xtuner/v1/ray/dataflow/flow.py deleted file mode 100644 index 6c7fb74039..0000000000 --- a/xtuner/v1/ray/dataflow/flow.py +++ /dev/null @@ -1,558 +0,0 @@ -import asyncio -import math -import time -from pathlib import Path -from typing import Any, Dict, List, Optional, TypedDict - -import httpx -import ray -from cyclopts import Parameter -from pydantic import BaseModel, ConfigDict -from ray.actor import ActorProxy -from tqdm.auto import tqdm -from typing_extensions import Annotated - -from xtuner.v1.data_proto.rl_data import MultimodalTrainInfo, RLDataFlowItem, RolloutState -from xtuner.v1.ray.environment import SingleTurnEnvironment -from xtuner.v1.ray.rollout.controller import SampleParams -from xtuner.v1.ray.utils import create_task -from xtuner.v1.utils import get_logger, ray_method - -from .replay_buffer import ReplayBuffer, ReplayBufferConfig, determine_group_state - - -class DataFlowResult(TypedDict): - data_groups: List[List[RLDataFlowItem]] - mm_train_infos: List[MultimodalTrainInfo] - metrics: Dict[str, Any] - - -class DataFlowConfig(BaseModel): - """Data flow configuration for XTuner. - - Simple configuration for managing concurrent data generation workflows - in reinforcement learning training. - - Args: - env (str): Environment identifier. Defaults to "". - max_concurrent (int): Maximum concurrent tasks. Defaults to 8. - prompt_repeat_k (int): Times to repeat each prompt. Defaults to 1. - global_batch_size (int): Target samples to collect. Defaults to 8. - max_retry_times (int): Maximum retry attempts. Defaults to 1. 
- enable_partial_rollout (int): Enable async mode (1) or disable (0). Defaults to 0. - sample_params (SampleParams): Model sampling parameters. Defaults to SampleParams(). - - **Examples:** - - Example configuration for dataflow:: - - config = DataFlowConfig( - env="test_env", - max_concurrent=256, - global_batch_size=1024, - prompt_repeat_k=8, - sample_params=SampleParams(max_tokens=2048), - ) - """ - - model_config = ConfigDict(extra="forbid") - - env: Annotated[ - str, - Parameter(help="Environment name to set for the dataflow."), - ] = "" - max_concurrent: Annotated[ - Optional[int], - Parameter(help="Maximum number of concurrent tasks."), - ] = None - max_retry_times: Annotated[ - int, - Parameter(help="Maximum number of retry task for failed samples."), - ] = 3 - prompt_repeat_k: Annotated[ - int, - Parameter(help="Number of times to repeat each prompt."), - ] = 1 - global_batch_size: Annotated[ - int, - Parameter(help="Target number of samples to collect before stopping."), - ] = 8 - sample_params: Annotated[SampleParams, Parameter(help="Parameters for sampling from the model.")] = SampleParams() - extra_params: Annotated[Dict, Parameter(help="Extra parameters for rollout.")] = {} - # async params - staleness_threshold: Annotated[ - float, - Parameter( - help="The maximum allowed threshold of stale (expired) samples in a training batch. Must be between 0.0 and 1.0." - ), - ] = 0.0 - enable_partial_rollout: Annotated[ - bool, - Parameter(help="Whether to enable partial rollout for asynchronous data generation."), - ] = False - tail_batch_candidate_steps: Annotated[ - int, - Parameter( - help="Number of rollout steps after which a sample becomes a candidate for the tail batch. Set to 0 to disable." - ), - ] = 0 - tail_batch_trigger_size: Annotated[ - Optional[int], - Parameter( - help="Number of candidate samples needed in the queue to trigger a tail batch operation. Set to 0 to disable." 
- ), - ] = None - worker_log_dir: Annotated[Path, Parameter(help="Directory to save worker logs.")] = Path.cwd() / "work_dir" - - def model_post_init(self, __context: Any) -> None: - self.worker_log_dir.mkdir(parents=True, exist_ok=True) - if self.tail_batch_trigger_size is None: - self.tail_batch_trigger_size = self.global_batch_size - - -class RawDataFlow: - """A Ray actor that manages the data flow for reinforcement learning. - - This class is responsible for sampling prompts, interacting with the environment or to generate responses, - processing the results, and storing them in a replay buffer. It orchestrates the asynchronous generation of - training data. - """ - - def __init__( - self, - env: str, - dataflow_cfg: DataFlowConfig, - replay_buffer_cfg: ReplayBufferConfig, - environment: SingleTurnEnvironment, - ): - """Initializes the DataFlow actor. - - Args: - env (str): The name of the environment. - dataflow_cfg (DataFlowConfig): Configuration for the data flow. - replay_buffer_cfg (ReplayBufferConfig): Configuration for the - replay buffer. - environment (EnvController): The environment controller actor. - postprocessor (Optional[Callable]): An optional function to - post-process the generated samples. 
- """ - self.logger = get_logger(log_dir=dataflow_cfg.worker_log_dir, tag="DataFlow") - self.env = env - self.config = dataflow_cfg - replay_buffer_cfg.worker_log_dir = self.config.worker_log_dir - self.replay_buffer = ReplayBuffer.remote(replay_buffer_cfg) # type: ignore[attr-defined] - self.replay_buffer.setup_storage_config.remote( # type: ignore[attr-defined] - enable_partial_rollout=self.config.enable_partial_rollout, - tail_batch_candidate_steps=self.config.tail_batch_candidate_steps, - tail_batch_trigger_size=self.config.tail_batch_trigger_size, - ) - self.staleness_threshold = self.config.staleness_threshold - self.env_controller = environment - self.finished_samples_count = 0 - self.skipped_sample_count = 0 - self.failed_sample_count = 0 - self.filtered_samples_count = 0 - self.tb_metrics: Dict[str, Any] = {} - self.target_batch_size = self.config.global_batch_size - rollout_info = ray.get(self.env_controller.get_rollout_info.remote()) # type: ignore[attr-defined] - self.worker_url_list = list(rollout_info["server_url_dict"].values()) - self.logger.info(f"DataFlow connected to active rollout workers url: {self.worker_url_list}") - rollout_config = rollout_info["rollout_config"] - max_concurrent = int( - ( - rollout_config.rollout_max_batch_size_per_instance - * len(self.worker_url_list) - / self.config.prompt_repeat_k - ) - * rollout_config.allow_over_concurrency_ratio - ) - - if self.config.max_concurrent is None: - self.config.max_concurrent = max_concurrent - self.logger.info( - f"Set Dataflow max_concurrent to {self.config.max_concurrent} based on rollout max batch size and number of workers." - ) - else: - self.logger.warning( - f"Dataflow max_concurrent is set to {self.config.max_concurrent}, we proposed to set max_concurrent to {max_concurrent} based on rollout_max_batch_size_per_instance." 
- ) - self.logger.info(f"DataFlowConfig:\n{self.config.model_dump_json(indent=2)}") - self.cleanup_task_time = 5 * 60 # 5 minutes - - def _reset_internal_states( - self, - global_batch_size: Optional[int] = None, - sample_params: Optional[SampleParams] = None, - extra_params: Optional[Dict] = None, - staleness_threshold: Optional[float] = None, - ): - """Resets all internal state variables of DataFlow.""" - self.skipped_sample_count = 0 - self.failed_sample_count = 0 - self.filtered_samples_count = 0 - self.tb_metrics = {} - if global_batch_size and global_batch_size > 0: - self.target_batch_size = global_batch_size - else: - self.target_batch_size = self.config.global_batch_size - - if staleness_threshold is not None: - self.staleness_threshold = staleness_threshold - else: - self.staleness_threshold = self.config.staleness_threshold - - self.sample_from_expired_storage, self.finished_samples_count = ray.get( - self.replay_buffer.get_prerun_state.remote() - ) - ray.get(self.env_controller.restart.remote()) # type: ignore[attr-defined] - self.sample_params = sample_params if sample_params else self.config.sample_params - self.extra_params = extra_params if extra_params else self.config.extra_params - logger_msg = ( - f"DataFlow states for new generations: target_batch_size={self.target_batch_size}, " - f"sample_params: {self.sample_params}, extra_params: {self.extra_params}, " - f"sample_from_expired_storage={self.sample_from_expired_storage}, finished_samples_count={self.finished_samples_count}, " - ) - self.logger.info(logger_msg) - - @ray_method - def get_train_dataset_length(self): - """Gets the length of the training dataset from the replay buffer.""" - return ray.get(self.replay_buffer.get_train_dataset_length.remote()) - - @ray_method - async def worker_task(self, group_samples_for_retry: Optional[List[RLDataFlowItem]] = None): - """A single worker task to generate and process a group of samples. - - This task performs the following steps: - 1. 
Samples a prompt from the replay buffer (or uses a sample for retry). - 2. Calls the environment controller or rollout controller to generate a response. - 3. Post-processes the generated samples use default postprocessor and custom postprocessor. - 4. Adds the filtered samples to the replay buffer. - - Args: - group_samples_for_retry (Optional[List[RLDataFlowItem]]): A group - of samples to retry if a previous attempt failed. Defaults to - None. - - Returns: - Optional[List[RLDataFlowItem]]: The group of samples if the task - fails and needs to be retried, otherwise None. - """ - task_start_time = time.perf_counter() - # step 1: sample - # TODO(@duanyanhui): More fine-grained control over group data generation: - # Pass n to the inference engine to ensure that the same data is processed by the same server, improving efficiency. - group_data_items = await self.replay_buffer.sample.remote( # type: ignore[attr-defined] - self.env, self.config.prompt_repeat_k - ) - assert len(group_data_items) > 0, "Sampled empty group data items from replay buffer." - action_id = group_data_items[0].uid.action_id - # step 2: env generate - group_data_items = await self.env_controller.run.remote( # type: ignore[attr-defined] - group_data_items, - sample_params=self.sample_params, - extra_params=self.extra_params, - ) - - # Step 3: Determine the sample's state and act accordingly. 
- group_state = determine_group_state(group_data_items) - self.logger.debug(f"Determined replay state for {action_id}: {group_state}") - if group_state == RolloutState.COMPLETED: - group_data_items = await self.replay_buffer.post_processor.remote(group_data_items) # type: ignore[attr-defined] - if len(group_data_items) > 0: - await self.replay_buffer.add.remote(group_data_items) # type: ignore[attr-defined] - else: - self.filtered_samples_count += 1 - self.logger.debug(f"Worker task completed successfully for {action_id}.") - elif group_state == RolloutState.ABORTED: - await self.replay_buffer.add.remote(group_data_items) # type: ignore[attr-defined] - self.logger.debug(f"Adding aborted sample {action_id} to aborted storage") - elif group_state == RolloutState.SKIPPED: - self.skipped_sample_count += 1 - self.logger.info(f"Total skipped samples count: {self.skipped_sample_count}") - elif group_state == RolloutState.FAILED: - self.failed_sample_count += 1 - self.logger.info(f"Total failed samples count: {self.failed_sample_count}") - else: - self.logger.error(f"Unexpected group state '{group_state}' for action_id {action_id}.") - - return time.perf_counter() - task_start_time - - async def concurrent_task_runner(self): - """Orchestrates the concurrent execution of worker tasks. - - This method manages a pool of asynchronous worker tasks to collect a - target number of samples (`self.target_batch_size`). It dynamically - adjusts the number of concurrent tasks based on progress and a - staleness threshold, ensuring efficient data generation. - - The process is as follows: - 1. Initializes a set of worker tasks, potentially over-provisioning - based on `self.config.staleness_threshold` to account for - variability in task completion times. - 2. Enters a main loop that continues until `target_batch_size` - samples are collected. - 3. 
Inside the loop, it periodically checks the number of pending - tasks and launches new ones if the current number is insufficient - to meet the target, maintaining a steady flow of data generation. - 4. Uses `asyncio.wait` with a short timeout to efficiently monitor - for completed tasks without blocking execution. - 5. A progress bar (`tqdm`) is updated as samples are collected. - 6. Once `target_batch_size` is reached, it sends a pause/abort - signal to all rollout workers to prevent them from starting new - computations. - 7. It then waits for any remaining in-flight tasks to complete, with - a configurable timeout to prevent indefinite hanging. Tasks that - do not finish within the timeout are forcefully cancelled. - """ - waiting_tasks = set() - dataflow_start_time = time.perf_counter() - task_completion_times = [] - with tqdm(total=self.target_batch_size, desc="rollout_controller for training samples", miniters=10) as pbar: - last_pbar_n = self.finished_samples_count - init_finished_samples_count = self.finished_samples_count - - if self.sample_from_expired_storage: - # 如果是从过期的存储中采样数据,需要禁用staleness_threshold - data_concurrency = self.target_batch_size - self.finished_samples_count - self.logger.info( - f"Sampling from expired storage, starting {data_concurrency} worker tasks from expired samples." 
- ) - else: - data_concurrency = math.ceil( - (1 + self.staleness_threshold) * (self.target_batch_size - self.finished_samples_count) - ) - self.logger.info( - f"Starting dataflow concurrent task runner with data_concurrency: {data_concurrency}, target_batch_size: {self.target_batch_size}, finished_samples_count: {self.finished_samples_count}, staleness_threshold: {self.staleness_threshold}" - ) - - for _ in range(data_concurrency): - task = create_task(self.worker_task()) - waiting_tasks.add(task) - - while ( - self.finished_samples_count < self.target_batch_size - and self.failed_sample_count < self.target_batch_size - and self.skipped_sample_count < self.target_batch_size - ): - done_tasks, pending_tasks = await asyncio.wait( - waiting_tasks, timeout=0.1, return_when=asyncio.FIRST_COMPLETED - ) - - for done_task in done_tasks: - task_time = done_task.result() - task_completion_times.append(task_time) - - self.finished_samples_count = await self.replay_buffer.get_completed_samples_count.remote() - pbar.update(self.finished_samples_count - last_pbar_n) - last_pbar_n = self.finished_samples_count - - waiting_tasks = pending_tasks - - while ( - len(waiting_tasks) + self.finished_samples_count < data_concurrency + init_finished_samples_count - ): - # 当存在被filter掉的样本时,需要补数据 - task = create_task(self.worker_task()) - waiting_tasks.add(task) - - pbar.n = self.finished_samples_count - pbar.refresh() - - if self.finished_samples_count >= self.target_batch_size: - self.logger.info( - f"Target batch size {self.target_batch_size} reached with finished_samples_count: {self.finished_samples_count}." - ) - elif self.skipped_sample_count >= self.target_batch_size: - self.logger.info( - f"Stopping data generation as skipped samples {self.skipped_sample_count} reached target batch size {self.target_batch_size}." 
- ) - elif self.failed_sample_count >= self.target_batch_size: - self.logger.info( - f"Stopping data generation as failed samples {self.failed_sample_count} reached target batch size {self.target_batch_size}." - ) - generation_time = time.perf_counter() - dataflow_start_time - pause_start_time = time.perf_counter() - - if len(waiting_tasks) > 0: - self.logger.info(f"Start pausing env controller for remaining worker tasks {len(waiting_tasks)}.") - await self.pause() - cleanup_start_time = time.perf_counter() - while len(waiting_tasks) > 0: - elapsed_time = time.perf_counter() - cleanup_start_time - if elapsed_time > self.cleanup_task_time: - self.logger.warning( - f"Cleanup timeout of {self.cleanup_task_time}s reached. " - f"Forcefully cancelling {len(waiting_tasks)} remaining tasks." - ) - for task in waiting_tasks: - task.cancel() - # Wait for cancellations to complete - await asyncio.gather(*waiting_tasks, return_exceptions=True) - break # Exit the cleanup loop - # NOTE: Keep sending pause requests because the inference engine only marks currently running requests as aborted. - # When a waiting request starts running, it still needs another pause request to be marked as aborted. - _, pending_tasks = await asyncio.wait(waiting_tasks, timeout=0.1, return_when=asyncio.FIRST_COMPLETED) - if len(pending_tasks) > 0: - await self.pause() - await asyncio.sleep(1) - self.logger.debug( - f"Waiting for {len(pending_tasks)} remaining worker tasks to complete after pausing env controller." 
- ) - waiting_tasks = pending_tasks - self.logger.info("All worker tasks have completed after pausing env controller.") - - pause_time = time.perf_counter() - pause_start_time - dataflow_time = time.perf_counter() - dataflow_start_time - self.logger.info( - f"dataflow task finished, generation_time: {generation_time:.2f}s, pause_time: {pause_time:.2f}s, total_time: {dataflow_time:.2f}s" - ) - self.tb_metrics["time/generation_time"] = generation_time - self.tb_metrics["time/pause_time"] = pause_time - - task_completion_dict = self._log_task_completion_stats(task_completion_times, "Task Completion Time Stats:\n") - for k, v in task_completion_dict.items(): - self.tb_metrics[f"task_time/{k}"] = v - - @ray_method - async def pause(self, timeout: float = 60.0): - """Asynchronously sends abort requests to all rollout workers.""" - if not self.worker_url_list: - self.logger.info("No active rollout workers to pause.") - return - - async with httpx.AsyncClient() as client: - tasks = [self._send_abort_request(client, url, timeout=timeout) for url in self.worker_url_list] - results = await asyncio.gather(*tasks) - - failed_workers = [url for url, success in results if not success] - succeeded_count = len(self.worker_url_list) - len(failed_workers) - - if failed_workers: - self.logger.warning( - f"Abort requests completed. Succeeded: {succeeded_count}, " - f"Failed: {len(failed_workers)}. Failed workers: {failed_workers}" - ) - else: - self.logger.info(f"All {succeeded_count} abort requests sent successfully.") - - @ray_method - async def run( - self, - num: Optional[int] = None, - sample_params: Optional[SampleParams] = None, - extra_params: Optional[Dict] = None, - staleness_threshold: Optional[float] = None, - ) -> DataFlowResult: - """Starts the data generation process. - - This method resets the internal state and runs the concurrent task - runner to collect a new batch of samples from the environment. 
- - Args: - num (Optional[int]): The target number of samples to collect for this run. - Overrides the existing global_batch_size in DataFlowConfig if provided. - sample_params (Optional[SampleParams]): Parameters for model sampling. - Overrides the existing sample_params in DataFlowConfig if provided. - extra_params (Optional[Dict]): Additional parameters for rollout. - Overrides the existing extra_params in DataFlowConfig if provided. - enable_partial_rollout (Optional[bool]): Whether to enable partial rollout mode. - This is primarily intended for unit testing, allowing the dataflow to pause - and resume partway through a rollout for checkpointing and recovery tests.Returns: - Returns: - List[RLDataFlowItem]: A list of collected training samples. - """ - self._reset_internal_states( - global_batch_size=num, - sample_params=sample_params, - extra_params=extra_params, - staleness_threshold=staleness_threshold, - ) - self.logging_replaybuffer_state("DataFlow run started. ") - await self.concurrent_task_runner() - self.logging_replaybuffer_state("DataFlow run completed. 
") - - get_start_time = time.perf_counter() - return_samples = await self.replay_buffer.get_samples.remote(self.target_batch_size) # type: ignore[attr-defined] - self.logger.info( - f"Getting {self.target_batch_size} samples from replay buffer took {time.perf_counter() - get_start_time:.2f}s" - ) - self.tb_metrics["time/get_samples_time"] = time.perf_counter() - get_start_time - dataflow_result = DataFlowResult( - data_groups=return_samples[0], - mm_train_infos=return_samples[1], - metrics=self.tb_metrics, - ) - return dataflow_result - - def logging_replaybuffer_state(self, logging_msg: Optional[str] = None): - status = self.get_replaybuffer_status() - logging_msg = logging_msg if logging_msg else "" - logging_msg += f"ReplayBuffer Status: {status}" - logging_msg += f", finished_samples_count: {self.finished_samples_count}, " - logging_msg += f"skipped_samples_count: {self.skipped_sample_count}, " - logging_msg += f"failed_samples_count: {self.failed_sample_count}, " - logging_msg += f"filtered_samples_count: {self.filtered_samples_count}, " - self.logger.info(logging_msg) - - def get_replaybuffer_status(self): - return ray.get(self.replay_buffer.status.remote()) - - async def _send_abort_request(self, client, url, timeout): - worker_url = f"{url}/abort_request" - try: - response = await client.post(worker_url, json={"abort_all": True}, timeout=timeout) - response.raise_for_status() - self.logger.debug(f"Successfully sent abort request to {url}") - return url, True - except Exception as e: - self.logger.error(f"Failed to send abort request to {url}: {e}") - return url, False - - def _log_task_completion_stats(self, task_times: List[float], logger_msg: Optional[str] = None): - if not task_times: - self.logger.info("No task completion times to report.") - return {} - - import numpy as np - - stats_dict = { - "p50": np.percentile(task_times, 50), - "p90": np.percentile(task_times, 90), - "p95": np.percentile(task_times, 95), - "p99": np.percentile(task_times, 99), - 
"max": np.max(task_times), - "avg": np.mean(task_times), - "std": np.std(task_times), - } - stats_dict["p99_p50_ratio"] = stats_dict["p99"] / stats_dict["p50"] if stats_dict["p50"] > 0 else float("inf") - - task_completions_report = ( - f" - Avg Time: {stats_dict['avg']:.2f}s, Std: {stats_dict['std']:.2f}s\n" - f" - P50 (Median): {stats_dict['p50']:.2f}s, P90: {stats_dict['p90']:.2f}s, P95: {stats_dict['p95']:.2f}s, P99: {stats_dict['p99']:.2f}s\n" - f" - Max Time: {stats_dict['max']:.2f}s, Ratio (P99 / P50): {stats_dict['p99_p50_ratio']:.2f}\n" - ) - logger_msg = logger_msg if logger_msg else "" - logger_msg += task_completions_report - self.logger.info(logger_msg) - return stats_dict - - def save(self, save_path: Path | str): - """Saves the replay buffer to the specified path. - - Args: - save_path (str): The path to the checkpoint file to save to. - """ - ray.get(self.replay_buffer.save.remote(save_path)) - - def resume(self, resume_path: Path | str): - """Resumes the replay buffer from the specified path. - - Args: - resume_path (str): The path to the checkpoint file to resume from. 
- """ - ray.get(self.replay_buffer.resume.remote(resume_path)) - - -DataFlow = ray.remote(RawDataFlow) -DataFlowProxy = ActorProxy[RawDataFlow] diff --git a/xtuner/v1/ray/dataflow/replay_buffer.py b/xtuner/v1/ray/dataflow/replay_buffer.py deleted file mode 100644 index 4f524f3ddf..0000000000 --- a/xtuner/v1/ray/dataflow/replay_buffer.py +++ /dev/null @@ -1,970 +0,0 @@ -import itertools -import time -from collections import defaultdict -from dataclasses import dataclass, field -from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple, Union -from uuid import uuid4 - -import numpy -import ray -import torch -from cyclopts import Parameter -from pydantic import BaseModel, ConfigDict, Field -from ray import ObjectRef -from typing_extensions import Annotated - -from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast -from xtuner.v1.data_proto.rl_data import ( - MultimodalTrainInfo, - RLDataFlowItem, - RLDatasetItem, - RLEnvDataItem, - RLExtraDataItem, - RLUIDItem, - RolloutState, - is_valid_for_replaybuffer, -) -from xtuner.v1.datasets.config import DataloaderConfig -from xtuner.v1.utils import get_logger -from xtuner.v1.utils.device import get_device - - -DEVICE = get_device() -logger = get_logger() - - -@dataclass -class ReplayMeta: - """ReplayMeta aggregates all versions of data related to a single prompt in - the replay buffer. - - Attributes: - env (str): Name or identifier of the environment. - root_id (int): Identifier for grouping related prompts (e.g., for GRPO or multi-turn scenarios). - action_id (int): Unique identifier for the prompt. If the prompt changes (such as in a multi-turn scenario), a new action_id is assigned. - action_ref (ObjectRef): Ray object reference to the prompt data (corresponds to RLDatasetItem in RLDataFlowItem). - observation_ids (List[int]): IDs for different responses to the same prompt. Each response has a unique observation_id. 
- observation_refs (List[ObjectRef]): Ray object references to environment data for each observation (corresponds to RLEnvDataItem in RLDataFlowItem). - observation_versions (List[int]): Version numbers for each observation, supporting async rollout. - state (str): Overall state of the prompt (e.g., "paused" for partial rollout, or other rollout states). - extra_info (Dict[str, Any]): Additional metadata or information. - """ - - env: str = "" - root_id: int = 0 - action_id: int = 0 # same prompt share the same action_id - action_ref: ObjectRef = None - observation_ids: List[int] = field(default_factory=list) - observation_refs: List[ObjectRef] = field(default_factory=list) - observation_versions: List[int] = field(default_factory=list) # 目前发数据为按组下发,暂时用不到 - state: RolloutState = RolloutState.INIT - version: int = 0 # version for partial rollout - extra_info: Dict[str, Any] = field(default_factory=dict) - - -def determine_group_state(group_data_items: List[RLDataFlowItem]) -> RolloutState: - """Determines the processing strategy for a group of rollout samples based - on their state.""" - # TODO(@duanyanhui): remove this function when send one request instead of group requests. - if not group_data_items: - return RolloutState.SKIPPED - group_states = {item.env.rollout.state for item in group_data_items} - if RolloutState.SKIPPED in group_states: - return RolloutState.SKIPPED - elif RolloutState.FAILED in group_states: - return RolloutState.FAILED - elif RolloutState.ABORTED in group_states: - return RolloutState.ABORTED - elif all(state == RolloutState.COMPLETED for state in group_states): - return RolloutState.COMPLETED - else: - raise ValueError(f"Unknown group states: {group_states}") - - -def mapping_dataitem_to_replaymeta(grouped_dataitem: List[RLDataFlowItem]) -> ReplayMeta: - assert len(grouped_dataitem) > 0 - - env_str = grouped_dataitem[0].uid.env - root_id = grouped_dataitem[0].uid.root_id - action_id = grouped_dataitem[0].uid.action_id - # !!! 
注意:这里放的是第一个dataitem的data,因为一组数据的data是一样的 !!! - data = grouped_dataitem[0].data - # 现在是按组发送,那么一组里的dataitem的version是一样的,如果一组中的数据在某次rollout step中没有生成的数据,version也还是会+1 - group_version = grouped_dataitem[0].uid.version - observation_ids = [] - observation_refs = [] - - for item in grouped_dataitem: - observation_ids.append(item.uid.observation_id) - observation_refs.append(ray.put(item.env)) - - group_state = determine_group_state(grouped_dataitem) - logger.debug( - f"Mapping data items to ReplayMeta {action_id} with group_state: {group_state}, group_version: {group_version}" - ) - - replay_meta = ReplayMeta( - env=env_str, - root_id=root_id, - action_id=action_id, - action_ref=ray.put(data), - observation_ids=observation_ids, - observation_refs=observation_refs, - state=group_state, - version=group_version, - extra_info={}, - ) - return replay_meta - - -def mapping_replaymeta_to_dataitem(replay_meta: ReplayMeta) -> List[RLDataFlowItem]: - env_str = replay_meta.env - root_id = replay_meta.root_id - action_id = replay_meta.action_id - data_ref = ray.get(replay_meta.action_ref) - group_data_item = [] - for obs_id, obs_ref in zip(replay_meta.observation_ids, replay_meta.observation_refs): - env_data = ray.get(obs_ref) - # NOTE: This mapping function used by both dump and get. ObjectRefs are kept during dump (for training continuity) - # but released during get (via del replaymeta) when no longer needed. So we do not free them manually here. - # ray._private.internal_api.free(obs_ref) - - item = RLDataFlowItem( - uid=RLUIDItem( - env=env_str, root_id=root_id, action_id=action_id, observation_id=obs_id, version=replay_meta.version - ), - data=data_ref, - env=env_data, - extra_info=RLExtraDataItem(), - ) - group_data_item.append(item) - return group_data_item - - -class ReplayBufferConfig(BaseModel): - """Replay buffer configuration for XTuner. 
- - This class defines configuration parameters for the replay buffer system in XTuner, - managing dataset handling, data loading, text processing, and post-processing - operations for reinforcement learning experience replay. - - Args: - dataset_cfg (List): Configuration for datasets used to sample initial prompts. - dataloader_cfg (DataloaderConfig): Configuration for the PyTorch DataLoader - that iterates over the dataset. - tokenizer (PreTrainedTokenizer | PreTrainedTokenizerFast): Tokenizer for - processing text data, including support for partial rollouts. - postprocessor_func (Optional[Callable]): Optional function to filter or - modify data groups after generation. Defaults to None. - replay_ratio (float): Ratio of samples to replay from the buffer versus - sampling new data. Defaults to 0. - replay_weights (dict): Weights for different states in the replay buffer - to control sampling priorities. Defaults to empty dict. - - **Examples:** - - Example configuration for ReplayBuffer with GSM8K dataset config and base dataloader config:: - - from transformers import AutoTokenizer - - config = ReplayBufferConfig( - dataset_cfg=[{ - "dataset": DatasetConfig(name="gsm8k", anno_path="path/to/data"), - "tokenize_fn": RLTokenizeFnConfig(max_length=512) - }], - dataloader_cfg=DataloaderConfig(collator='fake_collator'), - tokenizer=AutoTokenizer.from_pretrained("model_path"), - postprocessor_func=None, - ) - """ - - model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") - - dataset_cfg: Annotated[List, Parameter(help="The dataset object to sample initial prompts from.")] - - dataloader_cfg: Annotated[ - Optional[DataloaderConfig], Parameter(help="The PyTorch DataLoader for iterating over the dataset.") - ] = None - - tokenizer: Annotated[ - Union[PreTrainedTokenizer, PreTrainedTokenizerFast, str], - Field(exclude=True), - Parameter(help="The tokenizer for processing text data, e.g., for partial rollouts."), - ] - postprocessor_func: Annotated[ - 
Optional[Callable], - Field(exclude=True), - Parameter(help="An optional function to filter or modify data groups after they are generated."), - ] = None - replay_ratio: Annotated[ - float, - Parameter(help="Ratio of samples to replay from the buffer."), - ] = 0 - replay_weights: Annotated[ - dict, - Parameter(help="Weights for different states in the replay buffer."), - ] = {} - worker_log_dir: Annotated[Path, Parameter(help="Directory to save worker logs.")] = Path.cwd() / "work_dir" - - -class DatasetSampler: - """Sampler for drawing new prompts from the configured dataset. - - This class is responsible for building a dataloader from the provided dataset configurations and sampling fresh - data prompts upon request. - """ - - def __init__(self, dataset_cfg, dataloader_cfg, tokenizer): - """Initializes the DatasetSampler. - - Args: - dataset_cfg (List): Configuration for the datasets to sample from. - dataloader_cfg (Optional[DataloaderConfig]): Configuration for the - PyTorch DataLoader. - tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast, str]): - The tokenizer for processing text data. Can be a path or an object. 
- """ - self.tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] - if isinstance(tokenizer, str): - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True) - else: - self.tokenizer = tokenizer - if dataloader_cfg is not None: - self.dataloader_cfg = dataloader_cfg - self.dataloader_cfg.dataset_config_list = dataset_cfg - else: - self.dataloader_cfg = DataloaderConfig( - dataset_config_list=dataset_cfg, - collator="fake_collator", - pack_level="none", - num_workers=1, - ) - self.dataloader = self.dataloader_cfg.build( - tokenizer=self.tokenizer, dp_mesh=None, global_batch_size=1, micro_batch_size=1, seed=1 - ) - self.dataloader_iter = iter(self.dataloader) - self.cur_epoch = 0 - self.reduced_consumed_samples = 0 - self.logger = get_logger() - - def sample(self, env: str, prompt_repeat_k: int) -> List[RLDataFlowItem]: - """Samples a new prompt from the dataset and prepares it as a group. - - This method fetches the next item from the dataloader, assigns new - unique IDs (root_id, action_id, observation_id), and formats it into - a list of RLDataFlowItem objects, repeated `prompt_repeat_k` times. - - Args: - env (str): The environment name to be associated with the new samples. - prompt_repeat_k (int): The number of times to repeat the sampled - prompt in the returned group. - - Returns: - List[RLDataFlowItem]: A list of newly created data items for a rollout. 
- """ - root_id = uuid4().int - action_id = uuid4().int - group_data_item: List[RLDataFlowItem] = [RLDataFlowItem() for _ in range(prompt_repeat_k)] - try: - data = next(self.dataloader_iter)[0] - except StopIteration: - self.cur_epoch += 1 - self.dataloader.set_epoch(self.cur_epoch) - self.dataloader_iter = iter(self.dataloader) - data = next(self.dataloader_iter)[0] - self.reduced_consumed_samples += 1 - - multimodal_train_info = data.pop("multimodal_train_info", {}) - if "pixel_values" in multimodal_train_info: - multimodal_train_info["pixel_values"] = ray.put(multimodal_train_info["pixel_values"]) - data["multimodal_train_info"] = multimodal_train_info - - for data_item in group_data_item: - data_item.uid = RLUIDItem( - env=env, - root_id=root_id, - action_id=action_id, - observation_id=uuid4().int, - ) - data_item.data = RLDatasetItem(**data) - data_item.extra_info = RLExtraDataItem(retry_times=0) - self.logger.debug(f"Sampling new prompt with action_id: {action_id} in env: {env}") - return group_data_item - - def resume(self, dataloader_path): - dataloader_state = torch.load(dataloader_path, map_location=DEVICE) - self.dataloader.load_state_dict(dataloader_state) - self.dataloader_iter = iter(self.dataloader) - self.reduced_consumed_samples = dataloader_state["sampler"]["step"] - self.cur_epoch = dataloader_state["sampler"]["epoch"] - - -class ReplayBufferStorage: - """Handles the storage of experiences for the replay buffer.""" - - def __init__(self, replay_buffer_cfg): - """Initializes the data structures for storing replay data.""" - self.enable_partial_rollout: bool = False - self.tail_batch_candidate_steps: int = 0 - self.tail_batch_trigger_size: int = 0 - - self._completed_actions: Dict[int, List[int]] = defaultdict(list) - self._aborted_actions: Dict[int, List[int]] = defaultdict(list) - self._expired_actions: List[int] = [] - self._actions: Dict[int, ReplayMeta] = {} - self._root2actions: Dict[int, List[int]] = {} - self._observations: Dict[int, 
ReplayMeta] = {} - self._observations2states: Dict[int, str] = {} - self._states: Dict[str, List[int]] = defaultdict(list) - self._action2observations: Dict[int, List[int]] = defaultdict(list) - self._multimodal_train_infos: Dict[int, Dict[str, Any]] = {} - self.logger = get_logger(log_dir=replay_buffer_cfg.worker_log_dir, tag="ReplayBuffer") - self.sample_from_aborted_count = 0 - self.sample_from_expired_count = 0 - - def add(self, grouped_dataitem: List[RLDataFlowItem]): - """Adds a group of data items to the storage. - - Args: - grouped_dataitem (List[RLDataFlowItem]): A list of data items - belonging to the same group. - """ - if ( - grouped_dataitem is None - or len(grouped_dataitem) == 0 - or is_valid_for_replaybuffer(grouped_dataitem) is False - ): - return - - replay_meta = mapping_dataitem_to_replaymeta(grouped_dataitem) - root_id = replay_meta.root_id - action_id = replay_meta.action_id - - # 1. 更新版本 - if root_id in self._root2actions: - # TODO: 考虑到非共卡的情况,version是否更新需要根据是否update_weights来判断 - replay_meta.version += 1 - self._root2actions[root_id].append(action_id) - self.logger.debug( - f"Existing root_id: {root_id} with action_id {action_id} found. Incrementing version to {replay_meta.version}." - ) - else: - self._root2actions[root_id] = [action_id] - self._actions[action_id] = replay_meta - - # 2. 根据rollout_state更新completed/aborted/expired相关映射 - self._check_rollout_state_and_insert(replay_meta) - - # 3. 
更新observations相关映射 - for observation_id in replay_meta.observation_ids: - self._observations[observation_id] = replay_meta - self._observations2states[observation_id] = replay_meta.state - if observation_id not in self._action2observations[action_id]: - self._action2observations[action_id].append(observation_id) - if observation_id not in self._states[replay_meta.state]: - self._states[replay_meta.state].append(observation_id) - - def get(self, global_batch_size: int) -> Tuple[List[List[RLDataFlowItem]], List[MultimodalTrainInfo | None]]: - """Retrieves a batch of finished sample groups from the buffer. - - Args: - global_batch_size (int): The number of sample groups to retrieve. - - Raises: - ValueError: If there are not enough finished samples in the buffer - to meet the `global_batch_size`. - - Returns: - List[List[RLDataFlowItem]]: A list of sample groups. Each inner - list contains a group of data items that were generated from the - same initial prompt, repeated `repeat_prompt_k` times. - """ - samples = [] - multimodal_train_infos = [] - target_batch_size = min(global_batch_size, self.completed_samples_count) - self.logger.info(f"Retrieving {target_batch_size} completed samples from the replay buffer.") - task_time = [] - for _ in range(target_batch_size): - task_start_time = time.perf_counter() - action_id = self._pop_highest_version_action(self._completed_actions) - if action_id is None: - self.logger.warning("Get action_id None from completed_actions and skip this iteration.") - continue - replay_meta = self._actions.pop(action_id) - group_samples = mapping_replaymeta_to_dataitem(replay_meta) - # 将这条数据彻底清除,不用再记录root_id对应的action_ids了 - self._clear_meta_for_root(replay_meta) - multimodal_train_info = None - # TODO: 是否需要额外返回不重复的 multimodal_train_infos? 
- for data_item in group_samples: - if hasattr(data_item.data, "multimodal_train_info"): - multimodal_train_info = data_item.data.multimodal_train_info - del data_item.data.multimodal_train_info - if "partial_rollout_input_ids" in data_item.env.rollout.extra_info: - del data_item.env.rollout.extra_info["partial_rollout_input_ids"] - samples.append(group_samples) - multimodal_train_infos.append(multimodal_train_info) - task_end_time = time.perf_counter() - task_time.append(task_end_time - task_start_time) - # 检查completed_samples中是否还有剩余的数据,并且检查其是否过期 - avg_time = sum(task_time) / len(task_time) if len(task_time) > 0 else 0 - self.logger.info( - f"Remaining completed samples in buffer: {self.completed_samples_count}, task_time: {sum(task_time)}s, avg_time: {avg_time}s" - ) - self._check_completed_samples_expired() - self._check_completed_samples_aborted() - return samples, multimodal_train_infos - - def sample(self, sample_from_expired_states) -> List[RLDataFlowItem]: - if sample_from_expired_states and self.expired_samples_count > 0: - self.sample_from_expired_count += 1 - return self._sample_from_expired_storage() - if self.aborted_samples_count > 0: - self.sample_from_aborted_count += 1 - return self._sample_from_aborted_storage() - return [] - - def clear(self): - attrs_to_clear = [ - "_aborted_actions", - "_completed_actions", - "_expired_actions", - "_actions", - "_root2actions", - "_observations", - "_observations2states", - "_states", - "_action2observations", - ] - for attr in attrs_to_clear: - getattr(self, attr).clear() - self.sample_from_aborted_count = 0 - self.sample_from_expired_count = 0 - - def resolve_ray_objects(self, data_item: RLDataFlowItem): - """Resolves ray.ObjectRefs in a RLDataFlowItem to their actual values. - - Args: - data_item (RLDataFlowItem): The data item containing ray.ObjectRefs. - Returns: - RLDataFlowItem: The data item with ray.ObjectRefs resolved. 
- """ - - # Resolve data.multimodal_train_info - if hasattr(data_item.data, "multimodal_train_info"): - multimodal_info = data_item.data.multimodal_train_info - if multimodal_info and "pixel_values" in multimodal_info: - pixel_values_ref = multimodal_info["pixel_values"] - if isinstance(pixel_values_ref, ObjectRef): - multimodal_info["pixel_values"] = ray.get(pixel_values_ref) - data_item.data.multimodal_train_info = multimodal_info - # Resolve rollout.extra_info.router_experts - if "routed_experts" in data_item.env.rollout.extra_info: - if isinstance(data_item.env.rollout.extra_info["routed_experts"], ObjectRef): - data_item.env.rollout.extra_info["routed_experts"] = ray.get( - data_item.env.rollout.extra_info["routed_experts"] - ) - self.logger.info("Resolved routed_experts ObjectRef in rollout.extra_info") - - def convert_to_ray_objref(self, data_item: RLDataFlowItem): - """Converts large tensors in RLDataFlowItem to ray.ObjectRefs. - - Args: - data_item (RLDataFlowItem): The data item containing large tensors. - Returns: - RLDataFlowItem: The data item with large tensors converted to ray.ObjectRefs. 
- """ - # convert data.multimodal_train_info to ray.ObjectRef - if hasattr(data_item.data, "multimodal_train_info"): - multimodal_info = data_item.data.multimodal_train_info - if multimodal_info and "pixel_values" in multimodal_info: - # 一组数据共享同一个data_item.data,所以只需要转换一次 - if not isinstance(multimodal_info["pixel_values"], ray.ObjectRef): - pixel_values_ref = ray.put(multimodal_info["pixel_values"]) - del multimodal_info["pixel_values"] - data_item.data.multimodal_train_info["pixel_values"] = pixel_values_ref # type: ignore[index] - # convert rollout.extra_info.router_experts to ray.ObjectRef - if "routed_experts" in data_item.env.rollout.extra_info: - routed_experts_ref = ray.put(data_item.env.rollout.extra_info["routed_experts"]) - del data_item.env.rollout.extra_info["routed_experts"] - data_item.env.rollout.extra_info["routed_experts"] = routed_experts_ref - - def has_objectref(self, item: RLDataFlowItem) -> bool: - def check(obj): - if isinstance(obj, ray.ObjectRef): - return True - if isinstance(obj, BaseModel): - return any(check(getattr(obj, f)) for f in obj.model_fields) - if isinstance(obj, (list, tuple, set)): - return any(check(x) for x in obj) - if isinstance(obj, dict): - return any(check(v) for v in obj.values()) - if isinstance(obj, (str, int, float, bool, type(None), torch.Tensor, numpy.ndarray)): - return False - # 如果不满足以上类型,抛出错误,防止意想不到的问题 - raise TypeError( - f"Unsupported type: {type(obj)} in {obj} " - f"Expected ray.ObjectRef, BaseModel, list/tuple/set, dict, or primitive types." - ) - - return check(item) - - def dump(self, file_path: Path): - """Dumps the entire state of the replay buffer storage to a single - file, resolving all ray.ObjectRefs to their actual values. - - Args: - file_path (str): The path to the file where the state will be - saved. 
- """ - all_data_items = [mapping_replaymeta_to_dataitem(replay_meta) for replay_meta in self._actions.values()] - - for data_items in all_data_items: - for item in data_items: - self.resolve_ray_objects(item) - res = self.has_objectref(item) - assert not res, "ReplayBufferStorage.dump found unresolved ray.ObjectRef in RLDataFlowItem" - - state = { - "_completed_actions": self._completed_actions, - "_aborted_actions": self._aborted_actions, - "_expired_actions": self._expired_actions, - "_actions": all_data_items, - "_root2actions": dict(self._root2actions), - "_observations2states": self._observations2states, - "_states": dict(self._states), - "_action2observations": dict(self._action2observations), - } - - torch.save(state, file_path) - self.logger.info(f"ReplayBufferStorage state dumped to {file_path}") - - def resume(self, file_path: Path): - """Resumes the replay buffer storage from a single file. - - Args: - file_path (str): The path to the file from which to restore the - state. - """ - - self.clear() - - state = torch.load(file_path, map_location="cpu", weights_only=False) - - self._completed_actions = state["_completed_actions"] - self._aborted_actions = state["_aborted_actions"] - self._expired_actions = state["_expired_actions"] - self._root2actions = defaultdict(list, state["_root2actions"]) - self._observations2states = state["_observations2states"] - self._states = defaultdict(list, state["_states"]) - self._action2observations = defaultdict(list, state["_action2observations"]) - - dump_actions = state["_actions"] - # 重建 _actions 和 _observations: 与replaymeta相关 - for group_dataitem in dump_actions: - for data_item in group_dataitem: - self.convert_to_ray_objref(data_item) - replay_meta = mapping_dataitem_to_replaymeta(group_dataitem) - action_id = replay_meta.action_id - self._actions[action_id] = replay_meta - for observation_id in self._action2observations[action_id]: - self._observations[observation_id] = replay_meta - - 
self.logger.info(f"ReplayBufferStorage state successfully resumed from {file_path}") - self.logger.info( - f"ReplayBuffer Storage status: completed_samples_count={self.completed_samples_count}, aborted_samples_count={self.aborted_samples_count}, expired_samples_count={self.expired_samples_count}" - ) - - @property - def completed_samples_count(self) -> int: - return sum(len(bucket) for bucket in self._completed_actions.values()) - - @property - def aborted_samples_count(self): - return sum(len(bucket) for bucket in self._aborted_actions.values()) - - @property - def expired_samples_count(self): - return len(self._expired_actions) - - def _sample_from_expired_storage(self) -> List[RLDataFlowItem]: - """Samples an item from the expired storage for re-rollout. - - This method takes an action_id from the expired queue, retrieves its - original prompt data, cleans up all its previous rollout outputs, and - prepares it as a new sample group with a fresh action_id and reset - version (0) to be sent for a new generation attempt. - - Returns: - List[RLDataFlowItem]: A list of data items ready for a new rollout. - """ - assert len(self._expired_actions) > 0 - action_id = self._expired_actions.pop() - replay_meta = self._actions.pop(action_id) - group_samples = mapping_replaymeta_to_dataitem(replay_meta) - # 把这条数据上次的记录全部删掉,重新开始rollout,root2actions也要清除 - self._clear_meta_for_root(replay_meta) - - for sample in group_samples: - assert sample.data.input_ids and sample.data.num_tokens, "input_ids or num_tokens is empty!" 
- if "routed_experts" in sample.env.rollout.extra_info: - ray._private.internal_api.free(sample.env.rollout.extra_info["routed_experts"]) - del sample.env.rollout.extra_info["routed_experts"] - del sample.env - sample.env = RLEnvDataItem() # 重置env数据 - sample.uid.action_id = action_id - sample.uid.version = 0 - - self.logger.debug( - f"Sampling expired action_id: {action_id} from replay buffer, remain expired samples: {len(self._expired_actions)}" - ) - return group_samples - - def _sample_from_aborted_storage(self) -> List[RLDataFlowItem]: - """Samples an item from the aborted storage for re-rollout. - - This method retrieves an action with the highest version (oldest version) from the - aborted buckets. It then cleans up its previous (aborted) rollout - outputs and prepares it as a new sample group with a fresh action_id. - The original version number is preserved to track its retry history. - - Returns: - List[RLDataFlowItem]: A list of data items ready for a new rollout. - """ - assert self.aborted_samples_count > 0 - action_id = self._pop_highest_version_action(self._aborted_actions) - # 通过self.aborted_samples_count判断过这里不会返回None - replay_meta = self._actions.pop(action_id) # type: ignore[arg-type] - replay_meta_version = replay_meta.version - group_samples = mapping_replaymeta_to_dataitem(replay_meta) - # 把这条数据上次rollout产生的输出的记录都删掉,上次的数据已经记录在了RLEnvDataItem中了 - self._clear_meta_for_actions(replay_meta) - - sample_action_id = uuid4().int - for sample in group_samples: - assert sample.data.input_ids and sample.data.num_tokens, "input_ids or num_tokens is empty!" 
- if not self.enable_partial_rollout: - # 清除上次的response_ids等env数据 - if "routed_experts" in sample.env.rollout.extra_info: - ray._private.internal_api.free(sample.env.rollout.extra_info["routed_experts"]) - del sample.env.rollout.extra_info["routed_experts"] - del sample.env - sample.env = RLEnvDataItem() - sample.uid.version = 0 - sample.uid.action_id = action_id if action_id is not None else sample_action_id - else: - # 将异步的逻辑尽量放在replay buffer中处理,尽量不在env/rollout中进行处理 - history_response_ids = list(itertools.chain.from_iterable(sample.env.rollout.versioned_response_ids)) - sample.env.rollout.extra_info["partial_rollout_input_ids"] = ( - sample.data.input_ids + history_response_ids - ) - self.logger.debug( - f"partial rollout enabled, {sample_action_id} pass response_ids {len(history_response_ids)} to input_ids {len(sample.data.input_ids)} to data extra info when sampling." - ) - sample.uid.version = replay_meta_version - sample.uid.action_id = int(sample_action_id) - - self.logger.debug( - f"Sampling aborted action_id: {sample_action_id}, root_id: {group_samples[0].uid.root_id} from replay buffer, remain aborted samples: {self.aborted_samples_count}" - ) - return group_samples - - def _pop_highest_version_action(self, buckets: Dict[int, List[int]]) -> Optional[int]: - if not buckets: - return None - - highest_version = sorted(buckets.keys())[-1] - action_list = buckets[highest_version] - action_id = action_list.pop() - if not action_list: - del buckets[highest_version] - - return action_id - - def _check_completed_samples_expired(self): - """Moves samples from completed buckets to the expired list if they are - too old after get target completed samples from replay buffer. - - This method iterates through the `_completed_actions` buckets. If a - bucket's version index is greater than or equal to the configured - `tail_batch_candidate_steps`, all action_ids within that bucket are - moved to the `_expired_actions` list, marking them as expired. 
- """ - if self.tail_batch_candidate_steps <= 0: - return - - expired_versions = [v for v in self._completed_actions if v >= self.tail_batch_candidate_steps] - - for version in expired_versions: - bucket = self._completed_actions.pop(version) - self._expired_actions.extend(bucket) - self.logger.info( - f"Moved {len(bucket)} completed samples with version {version} to expired samples due to exceeding tail_batch_candidate_steps." - ) - - def _check_completed_samples_aborted(self): - if self.enable_partial_rollout: - return - - for version, bucket in self._completed_actions.items(): - self._aborted_actions[0].extend(bucket) - self.logger.info( - f"Moved {len(bucket)} completed samples with version {version} to aborted samples due to partial rollout disabled." - ) - self._completed_actions.clear() - - def _clear_meta_for_actions(self, replay_meta: ReplayMeta): - """Completely removes an action and all its associated data from the - storage. - - This is the single source of truth for deleting an action. - """ - action_id = replay_meta.action_id - - for observation_id in replay_meta.observation_ids: - self._observations.pop(observation_id, None) - state = self._observations2states.pop(observation_id, None) - if state and observation_id in self._states.get(state, []): - self._states[state].remove(observation_id) - - self._action2observations.pop(action_id, None) - del replay_meta - - def _clear_meta_for_root(self, replay_meta: ReplayMeta): - """Clears all actions and associated metadata linked to the same - root_id. - - This function is called after a sample group is successfully retrieved - for training. It ensures that all historical versions of a prompt - (linked by root_id) are purged from the storage to prevent them from - being re-sampled or replayed. - - Args: - replay_meta (ReplayMeta): The metadata of the action that was just - retrieved. The root_id from this object will be used to find - and clear all related actions. 
- """ - root_id = replay_meta.root_id - if root_id in self._root2actions: - for action_id in self._root2actions[root_id]: - new_replay_meta = self._actions.pop(action_id, None) - if new_replay_meta: - self._clear_meta_for_actions(new_replay_meta) - del self._root2actions[root_id] - del replay_meta - - def _check_rollout_state_and_insert(self, replay_meta: ReplayMeta): - """Checks the rollout state of a ReplayMeta object and inserts its - action_id into the appropriate state bucket. - - This method acts as a router, directing action_ids to different storage - lists (_aborted_actions, _completed_actions, _expired_actions) based on - their final rollout state and version. It also handles the logic for - when an aborted sample becomes expired due to too many retries. - - Args: - replay_meta (ReplayMeta): The metadata object containing the final - state and version of a rollout action. - """ - state = replay_meta.state - root_id = replay_meta.root_id - action_id = replay_meta.action_id - - if state == RolloutState.ABORTED: - if self.tail_batch_candidate_steps > 0 and replay_meta.version >= self.tail_batch_candidate_steps: - # 过期的数据需要重置状态 - self._expired_actions.append(action_id) - self.logger.debug( - f"Add expired sample with action_id: {action_id} to _expired_actions because version: {replay_meta.version} >= tail_batch_candidate_steps: {self.tail_batch_candidate_steps}." - ) - else: - self._aborted_actions[replay_meta.version].append(action_id) - self.logger.debug( - f"Add aborted sample with action_id: {action_id} version: {replay_meta.version} to _aborted_actions." - ) - elif state == RolloutState.COMPLETED: - self._completed_actions[replay_meta.version].append(action_id) - self.logger.debug(f"Add sample with root_id: {root_id}, action_id: {action_id} to finished_actions.") - else: - raise AssertionError( - f"Unsupported rollout state {state} for action_id {action_id} in ReplayBufferStorage." 
- ) - - -@ray.remote -class ReplayBuffer: - """A Ray actor that manages experience replay for reinforcement - learning.""" - - def __init__( - self, - config: ReplayBufferConfig, - ): - """Initializes the ReplayBuffer actor. - - Args: - config (ReplayBufferConfig): The configuration object. - """ - self.config = config - self.storage = ReplayBufferStorage(config) - self.sampler = DatasetSampler(config.dataset_cfg, config.dataloader_cfg, config.tokenizer) - self.post_processor_func = config.postprocessor_func - self.sample_from_expired_states = False - self.sample_from_dataset_count = 0 - self.logger = get_logger(log_dir=config.worker_log_dir, tag="ReplayBuffer") - - def setup_storage_config( - self, enable_partial_rollout: bool, tail_batch_candidate_steps: int, tail_batch_trigger_size: int - ): - """Sets up the storage configuration for the replay buffer. - - Args: - enable_partial_rollout (bool): Flag to enable partial rollouts. - tail_batch_candidate_steps (int): Number of steps to consider for - tail batch sampling. - tail_batch_trigger_size (int): Threshold size to trigger tail batch - sampling. - """ - self.storage.enable_partial_rollout = enable_partial_rollout - self.storage.tail_batch_candidate_steps = tail_batch_candidate_steps - self.storage.tail_batch_trigger_size = tail_batch_trigger_size - - def get_prerun_state(self) -> Tuple[bool, int]: - if ( - self.storage.tail_batch_trigger_size > 0 - and self.storage.expired_samples_count > self.storage.tail_batch_trigger_size - ): - self.sample_from_expired_states = True - self.logger.info( - f"Enable sampling from expired states. 
Expired samples: {self.storage.expired_samples_count}, threshold: {self.storage.tail_batch_trigger_size}" - ) - else: - self.sample_from_expired_states = False - return self.sample_from_expired_states, self.storage.completed_samples_count - - def get_train_dataset_length(self): - """Returns the length of the training dataloader.""" - return len(self.sampler.dataloader) - - def post_processor(self, group_samples): - """Applies a post-processing function to a group of samples. - - Args: - group_samples: A list of samples to process. - - Returns: - The processed group of samples. - """ - if self.post_processor_func: - group_samples = self.post_processor_func(group_samples) - return group_samples - return group_samples - - def sample(self, env, prompt_repeat_k) -> List[RLDataFlowItem]: - """Samples a batch of experiences from the replay buffer. - - Args: - env: The environment name. - enable_partial_rollout (int): Flag to enable partial rollouts. - prompt_repeat_k (int): Number of times to repeat a prompt. - - Returns: - A list of sampled data items. - """ - storage_samples = self.storage.sample(self.sample_from_expired_states) - if storage_samples: - return storage_samples - else: - self.sample_from_dataset_count += 1 - return self.sampler.sample(env, prompt_repeat_k) - - def get_samples( - self, - global_batch_size: int, - ): - """Gets a batch of finished samples from the storage. - - Args: - global_batch_size (int): The number of sample groups to retrieve. - - Returns: - A list of sample groups. - """ - return self.storage.get(global_batch_size) - - def add(self, grouped_dataitem: List[RLDataFlowItem]): - """Adds a group of data items to the replay buffer storage. - - Args: - grouped_dataitem (List[RLDataFlowItem]): A list of data items - from the same group. 
- """ - self.storage.add(grouped_dataitem) - - def status(self): - return { - "remain_completed_samples_count": self.storage.completed_samples_count, - "remain_aborted_samples_count": self.storage.aborted_samples_count, - "remain_expired_samples_count": self.storage.expired_samples_count, - "sample_from_dataset_count": self.sample_from_dataset_count, - "sample_from_aborted_count": self.storage.sample_from_aborted_count, - "sample_from_expired_count": self.storage.sample_from_expired_count, - } - - def save(self, file_path: Path | str): - """Saves the replay buffer's storage to a file. - - Args: - file_path (str): The path to the file for saving the data. - """ - if isinstance(file_path, str): - file_path = Path(file_path) - - # save dataloader - dataloader_path = file_path / "dataloader" - dataloader_state = self.sampler.dataloader.get_state_dict(self.sampler.reduced_consumed_samples) - torch.save(dataloader_state, dataloader_path) - - # save storage - rb_storage_path = file_path / "replay_buffer_storage.pth" - self.storage.dump(rb_storage_path) - - def resume(self, file_path: Path | str): - """Resumes the replay buffer's storage from a file. - - Args: - file_path (str): The path to the file from which to restore the - state. - """ - if isinstance(file_path, str): - file_path = Path(file_path) - dataloader_path = file_path / "dataloader" - if dataloader_path.exists(): - self.sampler.resume(dataloader_path) - self.sample_from_dataset_count = self.sampler.reduced_consumed_samples - self.logger.info( - f"Dataloader state successfully resumed from {dataloader_path} and set to epoch {self.sampler.cur_epoch} and step {self.sampler.reduced_consumed_samples}." - ) - else: - self.logger.warning(f"Dataloader state file {dataloader_path} does not exist. 
Skipping dataloader resume.") - # resume storage - rb_storage_path = file_path / "replay_buffer_storage.pth" - if rb_storage_path.exists(): - self.storage.resume(rb_storage_path) - else: - self.logger.warning( - f"ReplayBufferStorage state file {rb_storage_path} does not exist. Skipping storage resume." - ) - - def get_completed_samples_count(self) -> int: - """Returns the count of completed samples in the replay buffer. - - Returns: - int: The number of completed samples. - """ - return self.storage.completed_samples_count - - def clear(self): - """Clears the replay buffer storage.""" - self.storage.clear() diff --git a/xtuner/v1/ray/environment/__init__.py b/xtuner/v1/ray/environment/__init__.py deleted file mode 100644 index 02112e66a7..0000000000 --- a/xtuner/v1/ray/environment/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .base_env import BaseEnvironment -from .single_turn_env import SingleTurnEnvironment, SingleTurnEnvironmentProxy diff --git a/xtuner/v1/ray/environment/base_env.py b/xtuner/v1/ray/environment/base_env.py deleted file mode 100644 index c565f8e958..0000000000 --- a/xtuner/v1/ray/environment/base_env.py +++ /dev/null @@ -1,234 +0,0 @@ -import os -from abc import ABC, abstractmethod -from typing import Any, List - -import ray - -from xtuner.v1.data_proto.rl_data import RLDataFlowItem -from xtuner.v1.utils import ray_method - - -class BaseEnvironment(ABC): - """The BaseEnvironment class provides a foundational structure for managing - rollout and judger controllers for single-turn generation or multi-turn - generation. - - This class is responsible for initializing the necessary controllers based on the provided - configurations and placement group. It defines abstract methods for generation and - execution, which must be implemented by subclasses. - - Args: - environment (str): The name or identifier of the environment. - rollout_pg (Any): The placement group for scheduling rollout Ray actors. 
- rollout_cfg (Any, optional): The configuration for the rollout controller. Defaults to None. - judger_pg (Any): The placement group for scheduling judger Ray actors. - Defaults to None indicates using the rollout_pg. - judger_cfg (Any, optional): The configuration for the judger controller. Defaults to None. - """ - - def __init__( - self, - environment: str, - rollout_pg: Any, - rollout_cfg: Any, - judger_pg: Any = None, - judger_cfg: Any = None, - rollout_controller=None, - judger_controller=None, - ): - # judger_pg = judger_pg if judger_pg else rollout_pg - self.environment = environment - if rollout_controller: - self.rollout_controller = rollout_controller - else: - self.rollout_controller = self.init_rollout_controller(rollout_cfg, rollout_pg) - if judger_controller: - self.judger_controller = judger_controller - else: - self.judger_controller = self.init_judger_controller(judger_cfg, judger_pg) - - def init_rollout_controller(self, rollout_cfg: Any, placement_group: Any): - """Initializes the rollout controller with the appropriate worker - backend. - - Based on the `rollout_cfg`, this method selects and initializes the corresponding - rollout worker (e.g., `LMDeployWorker` or `vLLMWorker`). It then creates and - returns a `RolloutController` to manage these workers. - - Args: - rollout_cfg (Any): The configuration for the rollout controller. - placement_group (Any): The placement group for scheduling Ray actors. - - Returns: - The initialized rollout controller, or None if `rollout_cfg` is not provided. - - Raises: - NotImplementedError: If the specified rollout backend is not supported. 
- """ - - rollout_controller = None - if rollout_cfg is None: - return rollout_controller - - from xtuner.v1.ray.rollout.controller import RolloutController - - rollout_controller = ( - ray.remote(RolloutController) - .options(max_concurrency=int(os.environ.get("RAY_MAX_CONCURRENCY", 1000))) - .remote(rollout_cfg, placement_group) - ) # type: ignore[attr-defined] - return rollout_controller - - def init_judger_controller(self, judger_cfg: Any, placement_group: Any): - """Initializes the judger controller. - - If a `judger_cfg` is provided, this method creates and returns a `JudgerController` - to handle evaluation and judging tasks. - - Args: - judger_cfg (Any): The configuration for the judger controller. - placement_group (Any): The placement group for scheduling Ray actors. - - Returns: - The initialized judger controller, or None if `judger_cfg` is not provided. - """ - judger_controller = None - if judger_cfg: - from xtuner.v1.ray.judger.controller import JudgerController - - judger_controller = JudgerController.remote(judger_cfg, placement_group) # type: ignore[attr-defined] - return judger_controller - - @abstractmethod - @ray_method - async def generate( - self, data: List[RLDataFlowItem], sample_params: Any, extra_params: Any - ) -> List[RLDataFlowItem]: - """Generates responses from the model for the given data using the - inference engine. This method is primarily used for single-step - inference. - - Args: - data: The input data, which can be a single prompt, RLTextDataItem, or a list of RLTextDataItem. - sample_params: Sampling parameters for the generation process. - - Returns: - A list of generated samples, each populated with 'response_str' and 'state' - """ - pass - - @abstractmethod - @ray_method - async def run(self, data: List[RLDataFlowItem], sample_params: Any, extra_params: Any) -> List[RLDataFlowItem]: - """Executes a full cycle of generation and interpretation, such as - generating a response and then evaluating it with a judger. 
This method - can be extended to support complex interactions like multi-turn - dialogues. - - Args: - data: The input data for the generation process. - sample_params: Sampling parameters for generation. - - Returns: - A list of generated samples - """ - pass - - def _call_rollout_func(self, method_name: str, block: bool): - """A helper function to dynamically call a method on the rollout - controller. - - Args: - method_name (str): The name of the method to call. - block (bool): Whether to block until the call completes. - - Returns: - The result of the method call. - """ - import ray - - assert self.rollout_controller, "Rollout controller is not initialized." - if block: - return ray.get(getattr(self.rollout_controller, method_name).remote()) - return getattr(self.rollout_controller, method_name).remote() - - @ray_method - def pause(self, block=True) -> None: - """Pauses the rollout workers. - - Args: - block (bool): Whether to block until the operation completes. - """ - return self._call_rollout_func("pause", block) - - @ray_method - def shutdown(self, block=True) -> None: - """Shuts down the rollout workers. - - Args: - block (bool): Whether to block until the operation completes. - """ - return self._call_rollout_func("shutdown", block) - - @ray_method - def restart(self, block=True) -> None: - """Restarts the rollout workers. - - Args: - block (bool): Whether to block until the operation completes. - """ - return self._call_rollout_func("restart", block) - - @ray_method - def get_rollout_info(self, block=True) -> dict[str, Any]: - """Gets information about the rollout workers. - - Args: - block (bool): Whether to block until the operation completes. - """ - return self._call_rollout_func("get_rollout_info", block) - - @ray_method - def onload_weights(self, block=True) -> None: - """Loads weights onto the rollout workers. - - Args: - block (bool): Whether to block until the operation completes. 
- """ - return self._call_rollout_func("onload_weights", block) - - @ray_method - def onload_kvcache(self, block=True) -> str: - """Loads the KV cache onto the rollout workers. - - Args: - block (bool): Whether to block until the operation completes. - """ - return self._call_rollout_func("onload_kvcache", block) - - @ray_method - def offload(self, block=True) -> str: - """Offloads weights and the KV cache from the rollout workers. - - Args: - block (bool): Whether to block until the operation completes. - """ - return self._call_rollout_func("offload", block) - - @ray_method - def update_active_workers(self, block=True) -> None: - """Checks the status of active rollout workers. - - Args: - block (bool): Whether to block until the operation completes. - """ - return self._call_rollout_func("update_active_workers", block) - - @ray_method - def get_rollout_stats(self, block=True) -> dict[str, Any]: - """Gets statistics from the rollout workers. - - Args: - block (bool): Whether to block until the operation completes. - """ - return self._call_rollout_func("get_rollout_stats", block) diff --git a/xtuner/v1/ray/environment/single_turn_env.py b/xtuner/v1/ray/environment/single_turn_env.py deleted file mode 100644 index 56e0a14aee..0000000000 --- a/xtuner/v1/ray/environment/single_turn_env.py +++ /dev/null @@ -1,179 +0,0 @@ -import asyncio -import copy -import os -from pathlib import Path -from typing import List, cast - -import ray -from ray.actor import ActorClass, ActorProxy - -from xtuner.v1.data_proto.rl_data import ( - RLDataFlowItem, - RLJudgerResponseItem, - RLRolloutResponseItem, - is_valid_for_training, - update_dataflow_item, - update_rollout_item, -) -from xtuner.v1.ray.environment.base_env import BaseEnvironment -from xtuner.v1.utils import get_logger, ray_method - - -class RawSingleTurnEnvironment(BaseEnvironment): - """A single-turn environment for handling generation and evaluation tasks. 
- - This class extends `BaseEnvironment` to provide a concrete implementation for - single-turn interactions. It manages the rollout process for generating responses - and can coordinate with a judger for evaluation. - - Args: - environment (str): The name of the environment. - rollout_pg: The placement group for scheduling rollout Ray actors. - rollout_cfg (optional): Configuration for the rollout controller. Defaults to None. - judger_pg (Any): The placement group for scheduling judger Ray actors. - Defaults to None indicates using the rollout_pg. - judger_cfg (optional): Configuration for the judger controller. Defaults to None. - rollout_controller (optional): An instance of the rollout controller. Defaults to None. - judger_controller (optional): An instance of the judger controller. Defaults to None. - """ - - def __init__( - self, - environment: str, - rollout_pg, - rollout_cfg=None, - judger_pg=None, - judger_cfg=None, - rollout_controller=None, - judger_controller=None, - ): - super().__init__( - environment, rollout_pg, rollout_cfg, judger_pg, judger_cfg, rollout_controller, judger_controller - ) - if rollout_cfg: - worker_log_dir = rollout_cfg.worker_log_dir - elif judger_cfg: - worker_log_dir = judger_cfg.worker_log_dir - else: - worker_log_dir = Path.cwd() / "work_dir" - self.logger = get_logger(log_dir=worker_log_dir, tag="SingleTurnEnv") - if rollout_cfg and rollout_cfg.enable_return_routed_experts: - self.logger.info("!!! Enable `return routed experts` in rollout controller. !!!") - self.rollout_timeout = rollout_cfg.rollout_timeout if rollout_cfg else 1200.0 - self.judger_timeout = judger_cfg.judger_timeout if judger_cfg else 1200.0 - # The timeout for the environment to wait for the rollout controller's response. - # This should be longer than the controller's internal timeout (`rollout_timeout`) - # to account for potential queuing delays and other overheads. 
- self.timeout_multiplier = 2.0 - - async def generate( # type: ignore[override] - self, group_data_items: List[RLDataFlowItem], sample_params=None, extra_params=None - ) -> List[RLDataFlowItem]: - """Generate responses for a batch of RLTextDataItem using the rollout - controller. - - Each item in `group_data_items` will be sent to the rollout controller for response generation - with the provided sampling parameters. The generated response string and state will be - added to each RLTextDataItem in-place as `response_str` and `state` fields. - - Args: - group_data_items (List[RLTextDataItem]): - A list of RLTextDataItem objects containing the prompts/messages for generation. - sample_params: Sampling parameters for the generation process. The type should match - the rollout controller's expected sampling parameter type (e.g., SampleParams or dict). - - Returns: - List[RLTextDataItem]: - The same list of RLTextDataItem, with each item enriched with the generated response - and state from the rollout controller. - """ - if self.rollout_controller: - response_future = [] - for sample in group_data_items: - sample.data.extra_info["root_id"] = sample.uid.root_id - sample.data.extra_info["action_id"] = sample.uid.action_id - update_sample_params = sample_params - - if "partial_rollout_input_ids" in sample.env.rollout.extra_info: - input_ids_length = len(sample.data.input_ids) if sample.data.input_ids is not None else 0 - current_partial_length = len(sample.env.rollout.extra_info["partial_rollout_input_ids"]) - rollout_extra_info = copy.deepcopy(sample.data.extra_info) - rollout_extra_info["partial_rollout_input_ids"] = sample.env.rollout.extra_info[ - "partial_rollout_input_ids" - ] - assert sample_params is not None, "sample_params should not be None when using partial rollout." 
- update_sample_params = copy.deepcopy(sample_params) - update_sample_params.max_tokens = sample_params.max_tokens - ( - current_partial_length - input_ids_length - ) - self.logger.debug( - f"root_id: {sample.uid.root_id}, action_id {sample.uid.action_id} pass current_partial_length {current_partial_length}, input_ids_length {input_ids_length} to rollout and set max_tokens to {update_sample_params.max_tokens}" - ) - else: - rollout_extra_info = sample.data.extra_info - - if "routed_experts" in sample.env.rollout.extra_info: - rollout_extra_info["routed_experts"] = sample.env.rollout.extra_info["routed_experts"] - - fut = self.rollout_controller.rollout.remote( - prompt=sample.data.messages, - input_ids=sample.data.input_ids, - sample_params=update_sample_params, - extra_params=extra_params, - extra_info=rollout_extra_info, - ) - response_future.append(fut) - try: - rollout_responses = await asyncio.wait_for( - asyncio.gather(*response_future), timeout=self.rollout_timeout * self.timeout_multiplier - ) - except asyncio.TimeoutError: - self.logger.error("Get rollout controller response timeout and return the failed response.") - rollout_responses = [RLRolloutResponseItem(state="skipped") for _ in group_data_items] - group_data_items = update_rollout_item(group_data_items, rollout_responses) - return group_data_items - - @ray_method - async def run( # type: ignore[override] - self, group_data_items: List[RLDataFlowItem], sample_params=None, extra_params=None - ) -> List[RLDataFlowItem]: - """Runs a full generation and judger cycle. - - This method first generates responses using the `generate` method and then, - if a judger controller is available, it uses the judger to evaluate the - generated responses. - - Args: - data: The input data for the cycle. Can be a list of strings, - a single `RLTextDataItem`, or a list of `RLTextDataItem`. - sample_params: Sampling parameters for the generation process. 
- - Returns: - The data enriched with generated responses and evaluation results. - The format of the return value matches the format of the input `data`. - """ - group_data_items = await self.generate(group_data_items, sample_params, extra_params) # type: ignore[assignment] - continue_judger = is_valid_for_training(group_data_items) - if self.judger_controller and continue_judger: - try: - judger_responses: List[RLJudgerResponseItem] = await asyncio.wait_for( - self.judger_controller.run.remote(group_data_items), - timeout=self.judger_timeout * self.timeout_multiplier, - ) - except asyncio.TimeoutError: - self.logger.error("Get judger controller response timeout and return the failed response.") - judger_responses = [ - RLJudgerResponseItem( - extra_info={"state": "failed"}, - ) - for _ in group_data_items - ] - group_data_items = update_dataflow_item(group_data_items, "env.judger", judger_responses) - return group_data_items - - -SingleTurnEnvironment = cast( - ActorClass[RawSingleTurnEnvironment], - ray.remote(max_concurrency=int(os.environ.get("RAY_MAX_CONCURRENCY", 1000)))(RawSingleTurnEnvironment), -) -SingleTurnEnvironmentProxy = ActorProxy[RawSingleTurnEnvironment] diff --git a/xtuner/v1/ray/evaluator.py b/xtuner/v1/ray/evaluator.py deleted file mode 100644 index b793f89a51..0000000000 --- a/xtuner/v1/ray/evaluator.py +++ /dev/null @@ -1,286 +0,0 @@ -import asyncio -from pathlib import Path -from typing import Callable, List, Optional, Sized, TypeVar, Union -from uuid import uuid4 - -import ray -from cyclopts import Parameter -from pydantic import BaseModel, ConfigDict, Field -from ray.actor import ActorProxy -from tqdm.auto import tqdm -from typing_extensions import Annotated - -from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast -from xtuner.v1.data_proto.rl_data import RLDataFlowItem, RLDatasetItem, RLUIDItem, SampleParams -from xtuner.v1.datasets import build_dataloader, build_datasets -from xtuner.v1.datasets.config 
import DataloaderConfig, DatasetConfigList -from xtuner.v1.ray.environment import BaseEnvironment -from xtuner.v1.ray.utils import create_task -from xtuner.v1.utils import get_logger -from xtuner.v1.utils.type_helper import ray_method - - -T = TypeVar("T") -Ret = TypeVar("Ret") - - -class EvaluatorConfig(BaseModel): - """Configuration for the Evaluator in XTuner. - - This class defines the configuration parameters for model evaluation in XTuner, including four main aspects: - - - Dataset configuration: Specifies the evaluation dataset and tokenizer for text processing - - - Evaluator control logic: Manages concurrency levels and retry mechanisms for robust evaluation - - - Evaluation scheduling: Controls evaluation step intervals and sample size (either by ratio or absolute count) - - - Custom metric computation: Supports user-defined functions for specialized metric calculations - - Args: - dataset_cfg (DatasetConfigList): Configuration for the evaluation dataset. - tokenizer (PreTrainedTokenizer | PreTrainedTokenizerFast): Tokenizer used for text processing. - evaluate_step (int): Step interval for triggering evaluation. Defaults to 1. - eval_sample_ratio (float): Ratio of samples to evaluate from the generated samples. If > 0, overrides eval_sample_num. Defaults to 0 (use all samples). - eval_sample_num (int): Number of samples to evaluate from the generated samples. Used if eval_sample_ratio is 0. Defaults to 0 (use all samples). - max_concurrent (int): Maximum number of concurrent evaluation tasks. Defaults to 8. - max_retry_times (int): Maximum number of retry attempts for failed evaluation tasks. Defaults to 2. - compute_metric_func (Optional[Callable]): Optional function to compute or filter metrics for generated data groups. If None, uses default metric computation. 
- - **Examples:** - - Example configuration for evaluator with GSM8K dataset:: - - from transformers import AutoTokenizer - - config = EvaluatorConfig( - dataset_cfg=[{ - "dataset": DatasetConfig(name="gsm8k", anno_path="test_data.json"), - "tokenize_fn": RLTokenizeFnConfig(max_length=512) - }], - tokenizer=AutoTokenizer.from_pretrained("model_path"), - max_concurrent=32, - eval_sample_ratio=0.8, # Use 80% of samples - evaluate_step=10, - compute_metric_func=custom_accuracy_metric - ) - """ - - model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") - - enable_evaluate: Annotated[ - bool, - Parameter(help="Flag to enable or disable evaluation during training."), - ] = False - enable_initial_evaluate: Annotated[ - bool, - Parameter(help="Flag to enable or disable initial evaluation before training starts."), - ] = False - dataset_cfg: Annotated[ - DatasetConfigList, - Parameter(help="Configuration for the dataset."), - ] - dataloader_cfg: Annotated[ - Optional[DataloaderConfig], Parameter(help="The PyTorch DataLoader for iterating over the dataset.") - ] = None - - tokenizer: Annotated[ - Union[PreTrainedTokenizer, PreTrainedTokenizerFast, str], - Field(exclude=True), - Parameter(help="Tokenizer for text processing."), - ] - max_concurrent: Annotated[ - int, - Parameter(help="Maximum number of concurrent tasks."), - ] = 512 - eval_sample_ratio: Annotated[ - float, - Parameter(help="Ratio of samples to evaluate from the generated samples."), - ] = 0 - eval_sample_num: Annotated[ - int, - Parameter(help="Number of samples to evaluate from the generated samples."), - ] = 0 - max_retry_times: Annotated[int, Parameter(help="Maximum number of retry attempts for failed tasks.")] = 2 - evaluate_step: Annotated[int, Parameter(help="Step interval for evaluation.")] = 1 - compute_metric_func: Annotated[ - Optional[Callable], - Field(exclude=True), - Parameter(help="An optional function to filter or modify data groups after they are generated."), - ] = None - 
sample_params: Annotated[ - SampleParams, - Parameter(help="Sampling parameters for evaluation."), - ] = SampleParams() - worker_log_dir: Annotated[Path, Parameter(help="Directory to save worker logs.")] = Path.cwd() / "work_dir" - - -class RawEvaluator: - """A Ray actor for evaluating a model's performance on a given dataset. - - The Evaluator generates responses using an environment controller or rollout controller, then it use default or - custom computes metrics function to compute scores for generated samples. It returns the evaluation scores and - generated samples. - """ - - def __init__(self, config: EvaluatorConfig, env_controller: BaseEnvironment): - """Initialize the Evaluator. - - Args: - config (EvaluatorConfig): The configuration for the evaluator. - env_controller (EnvController): The environment controller used for - generating responses. - """ - self.config = config - self.sample_params = self.config.sample_params - self.dataset = ( - build_datasets(config.dataset_cfg, config.tokenizer) - if isinstance(config.tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)) - else build_datasets( - config.dataset_cfg, AutoTokenizer.from_pretrained(config.tokenizer, trust_remote_code=True) - ) - ) - - if config.dataloader_cfg is not None: - self.dataloader_cfg = config.dataloader_cfg - else: - self.dataloader_cfg = DataloaderConfig( - collator="fake_collator", - pack_level="none", - ) - self.dataloader = build_dataloader( - dataloader_config=self.dataloader_cfg, - datasets=self.dataset, - global_batch_size=1, - micro_batch_size=1, - seed=1, - ) - assert isinstance(self.dataloader, Sized) - - self.env_controller = env_controller - self.return_list: List[RLDataFlowItem] = [] - if self.config.eval_sample_ratio > 0: - self.eval_batch_size = int(len(self.dataloader) * self.config.eval_sample_ratio) - elif self.config.eval_sample_num > 0: - self.eval_batch_size = self.config.eval_sample_num - else: - self.eval_batch_size = len(self.dataloader) - if 
self.config.compute_metric_func is not None: - self.compute_metric = self.config.compute_metric_func - else: - self.compute_metric = self.default_compute_metric - self.logger = get_logger(log_dir=config.worker_log_dir, tag="Evaluator") - - def default_compute_metric(self, samples): - """Default metric computation function. - - Calculates accuracy based on whether the reward is positive. - - Args: - samples (list): A list of RLDataFlowItem samples. - - Returns: - dict: A dictionary containing the accuracy score. - """ - return {"accuracy": sum(s.env.judger.reward["score"] > 0 for s in samples) / len(samples)} - - async def eval_worker_task(self, sample: RLDataFlowItem): - """A single worker task to evaluate one sample. - - This task calls the environment controller to run the model on a - sample. If it fails, it returns the sample with an incremented - retry count. - - Args: - sample (RLDataFlowItem): The data item to evaluate. - - Returns: - RLDataFlowItem or None: The sample with retry information if it - failed, or None if it succeeded or failed without a sample. - """ - group_sample = await self.env_controller.run.remote([sample], sample_params=self.sample_params) # type: ignore[attr-defined] - self.return_list.append(group_sample[0]) - - async def concurrent_eval_task_runner(self): - """Runs evaluation tasks concurrently to generate a batch of samples. - - This method orchestrates the evaluation process by creating and managing - a pool of asynchronous worker tasks. It continuously fetches data from - the dataloader and submits evaluation tasks until the desired number of - samples (`self.eval_batch_size`) has been successfully processed. 
- """ - waiting_tasks = set() - self.logger.info(f"Start to generate {self.eval_batch_size} samples for evaluate") - self.logger.info(f"Evaluate sample parameters set to {self.sample_params}.") - data_iter = iter(self.dataloader) - with tqdm(total=self.eval_batch_size, desc="Rollout for eval samples") as pbar: - update_step = max(1, int(self.eval_batch_size * 0.1)) - next_update_threshold = update_step - while len(self.return_list) < self.eval_batch_size: - if len(self.return_list) >= next_update_threshold: - pbar.n = len(self.return_list) - pbar.refresh() - next_update_threshold += update_step - while len(waiting_tasks) < self.config.max_concurrent: - if len(self.return_list) + len(waiting_tasks) >= self.eval_batch_size: - break - try: - data = next(data_iter) - except StopIteration: - data_iter = iter(self.dataloader) - data = next(data_iter) - self.logger.warning("Restarting the evaluation dataset.") - uid = RLUIDItem(action_id=uuid4().int, observation_id=uuid4().int) - data_item = RLDataFlowItem(data=RLDatasetItem(**data[0]), uid=uid) - task = create_task(self.eval_worker_task(data_item)) - waiting_tasks.add(task) - - if len(waiting_tasks) == 0: - break - - _, pending_tasks = await asyncio.wait(waiting_tasks, timeout=0.1, return_when=asyncio.FIRST_COMPLETED) - waiting_tasks = pending_tasks - - pbar.n = len(self.return_list) - pbar.refresh() - - self.logger.info("Target batch size reached.") - if waiting_tasks: - await asyncio.wait_for(asyncio.gather(*waiting_tasks, return_exceptions=True), timeout=10) - - rollout_stats = await self.env_controller.get_rollout_stats.remote() # type: ignore[attr-defined] - self.logger.info(rollout_stats) - - @ray_method - async def run(self, return_samples=False): - """Run the full evaluation process. - - This method resets the state, runs the concurrent task runner, - computes the final metrics, and returns the results. - - Args: - sample_params (Optional[SampleParams]): Sampling parameters for - generation. 
Defaults to a greedy strategy. - return_samples (bool): Whether to return the generated samples - along with the scores. Defaults to False. - - Returns: - dict or Tuple[dict, list]: The evaluation scores, and optionally - the generated samples. - """ - self.return_list = [] - await self.env_controller.restart.remote() # type: ignore[attr-defined] - await self.concurrent_eval_task_runner() - if len(self.return_list) == 0: - self.logger.warning("No valid samples were generated during evaluation.") - return {} if not return_samples else ({}, []) - scores = self.compute_metric(self.return_list) - # To match the training format : each group's data is a list - self.eval_samples = [[sample] for sample in self.return_list] - if return_samples: - return scores, self.eval_samples - return scores - - -Evaluator = ray.remote(RawEvaluator) -EvaluatorProxy = ActorProxy[RawEvaluator] diff --git a/xtuner/v1/ray/judger/__init__.py b/xtuner/v1/ray/judger/__init__.py deleted file mode 100644 index 6a194b0245..0000000000 --- a/xtuner/v1/ray/judger/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .controller import JudgerConfig, JudgerController diff --git a/xtuner/v1/ray/judger/controller.py b/xtuner/v1/ray/judger/controller.py deleted file mode 100644 index 1cdcb64ab1..0000000000 --- a/xtuner/v1/ray/judger/controller.py +++ /dev/null @@ -1,270 +0,0 @@ -import asyncio -import random -from pathlib import Path -from typing import List, Optional - -import ray -from cyclopts import Parameter -from pydantic import BaseModel, ConfigDict, computed_field -from ray.util.placement_group import PlacementGroup, placement_group -from typing_extensions import Annotated - -from xtuner.v1.data_proto.rl_data import RLDataFlowItem, RLJudgerResponseItem - -from .native import NativeJudgerConfig - - -PG_READY_TIMEOUT = 30 - - -class JudgerConfig(BaseModel): - """Judger configuration for XTuner. 
- - Configuration for the judging system managing batch processing and custom judger - implementations for model evaluation and reward computation. - - Args: - enable_batch_reward (bool): Enable calculate reward within the data group of repeat_prompt_k. Defaults to False. - - reward_judger_configs (Dict[str, BaseModel]): Dictionary mapping judger names - to their configuration objects. We provided the example GSM8KJudgerConfig - for GSM8K mathematical reasoning tasks (see ``xtuner/v1/ray/judger/gsm8k.py``). Defaults to empty dict. - - **Examples:** - - Example configuration for single judger:: - - config = JudgerConfig( - enable_batch_reward=False, - reward_judger_configs={ - "gsm8k": GSM8KJudgerConfig(...) - } - ) - - Example configuration for multiple judgers:: - - config = JudgerConfig( - reward_judger_configs={ - "gsm8k": GSM8KJudgerConfig(...), - "math_qa": MathQAJudgerConfig(...), - "custom_eval": CustomJudgerConfig(...) - } - ) - - .. note:: - You should ensure each dataset item specifies data_source with dictionary mapping judger names to their weight ratios - - Example dataset item:: - - data_item = { - "data_source": {"gsm8k": 0.7, "math_qa": 0.3}, - "response_str": "...", - "reward_model": {"ground_truth": "..."} - } - """ - - model_config = ConfigDict(extra="forbid") - - enable_batch_reward: Annotated[ - bool, Parameter(help="Whether to enable batch reward calculation for multiple samples at once.") - ] = False - enable_weighted_judgers: Annotated[ - bool, Parameter(help="Whether to enable weighted reward calculation on multi judgers.") - ] = False - reward_judger_configs: Annotated[ - List[NativeJudgerConfig], - Parameter(help="A custom Python function for computing reward given model output and label."), - ] = [] - judger_timeout: Annotated[float, Parameter(help="Timeout for each judger request in seconds.")] = 1200.0 - worker_log_dir: Annotated[Path, Parameter(help="Directory to save worker logs.")] = Path.cwd() / "work_dir" - - @computed_field - def 
total_bundles_needed(self) -> list[dict]: - judger_total_bundles = [ - {"CPU": cfg.num_cpus_per_actor, "memory": cfg.num_cpus_per_actor * 1024**3} - for cfg in self.reward_judger_configs - for _ in range(cfg.num_ray_actors) - ] - return judger_total_bundles - - @computed_field - def total_cpus_needed(self) -> int: - judger_total_cpus = sum(cfg.num_cpus_per_actor * cfg.num_ray_actors for cfg in self.reward_judger_configs) - return judger_total_cpus - - @computed_field - def total_memory_needed(self) -> int: - judger_total_memory = sum( - cfg.num_cpus_per_actor * 1024**3 * cfg.num_ray_actors for cfg in self.reward_judger_configs - ) - return judger_total_memory - - -@ray.remote -class JudgerController: - """Controller for judging model outputs and calculating rewards.""" - - def __init__(self, judger_config: JudgerConfig, pg: Optional[PlacementGroup] = None): - """Initialize the JudgerController. - - Args: - judger_config (JudgerConfig): The configuration for the judger. - placement_group: The Ray placement group for resource allocation. - Defaults to None. - """ - self.judger_config = judger_config - # note: placement_group is used to control the placement of Ray tasks. - # It will be implemented when gpu judger is needed - if pg is None: - assert len(self.judger_config.reward_judger_configs) == 1, ( - "If no placement group is provided, there should be only one judger config." - ) - defaule_placement_group = placement_group(bundles=[{"CPU": 1, "memory": 1024**3}], strategy="PACK") - ray.get([defaule_placement_group.ready()], timeout=PG_READY_TIMEOUT) - self.pg = defaule_placement_group - else: - assert len(pg.bundle_specs) >= sum( - config.num_ray_actors for config in self.judger_config.reward_judger_configs - ), "The provided placement group does not have enough bundles for all judger actors." 
- self.pg = pg - self.reward_judger: List[List[ray.actor.ActorHandle]] = [] - self.reward_judger_names: List[str] = [] - self.judger_instance_count = 0 - - for idx, config in enumerate(self.judger_config.reward_judger_configs): - # start_bundle_idx用于指定从placement group的哪个bundle开始分配资源 - judger = config.build_actor(pg=self.pg, start_bundle_idx=self.judger_instance_count) - # 同一类judger可能会有多个实例(例如多个Ray actor),同一类的judger作为一行 - self.reward_judger.append(judger) - self.reward_judger_names.append(config.judger_name) - self.judger_instance_count += len(judger) - self.enable_weighted_judgers = ( - False if len(self.reward_judger) == 1 else self.judger_config.enable_weighted_judgers - ) - - async def _call_single_reward_judger( - self, judger: List[ray.actor.ActorHandle], group_data_item: List[RLDataFlowItem] - ): - """Call a single custom reward judger to calculate rewards. - - Args: - judger (NativeJudger): An instance of a custom judger. - responses (List[str]): A list of model-generated responses. - labels (List[str]): A list of ground-truth labels. - - Returns: - List[RLJudgerResponseItem]: A list of RLJudgerResponseItem containing - calculated rewards for each sample. - """ - tasks = [] - judger_input_data = ( - [group_data_item] if self.judger_config.enable_batch_reward else [[item] for item in group_data_item] - ) - - if self.judger_config.enable_batch_reward: - # Randomly pick a judger instance for batch evaluation to balance the load. - tasks.append(random.choice(judger).judge.remote(group_data_item)) - else: - tasks.extend([judger[idx % len(judger)].judge.remote(item) for idx, item in enumerate(judger_input_data)]) - return tasks - - async def _call_custom_reward_judger( - self, - active_judgers: List[List[ray.actor.ActorHandle]], - active_reward_judger_names: List[str], - group_data_item: List[RLDataFlowItem], - ) -> List[RLJudgerResponseItem]: - """Call custom reward judgers to calculate rewards. 
- - Args: - active_judgers (Dict[str, NativeJudger]): A dictionary of active - judgers. - responses (List[str]): A list of model-generated responses. - labels (List[str]): A list of ground-truth labels. - - Returns: - Dict[str, List[float]]: A dictionary where keys are judger names - and values are lists of calculated rewards for each sample. - """ - active_judgers_len = len(active_judgers) - task_len_list = [0] - all_tasks = [] - assert active_judgers_len == len(active_reward_judger_names), ( - f"Expected {active_judgers_len} active judgers, but got {len(active_reward_judger_names)}" - ) - for judger in active_judgers: - tasks = await self._call_single_reward_judger(judger, group_data_item) - all_tasks.extend(tasks) - task_len_list.append(task_len_list[-1] + len(tasks)) - - all_results = await asyncio.gather(*all_tasks) - - assert len(all_results) == len(group_data_item) * len(active_judgers), ( - f"Expected {len(group_data_item) * len(active_judgers)} results, but got {len(all_results)}" - ) - - active_judger_results = {} - for i in range(active_judgers_len): - active_judger_results[active_reward_judger_names[i]] = all_results[task_len_list[i] : task_len_list[i + 1]] - - # 为每个样本创建一个 RLJudgerResponseItem,不同judger的结果放在同一个item中 - uid_list = [item.uid.observation_id for item in group_data_item] - judger_response_items_dict = {uid: RLJudgerResponseItem(uid=uid) for uid in uid_list} - for judger_name, results in active_judger_results.items(): - for result in results: - for data in result: - return_uid = data.uid - judger_response_items_dict[return_uid].reward.update(data.reward) - judger_response_items_dict[return_uid].reward.update({judger_name: data.reward}) - judger_response_items_dict[return_uid].extra_info.update(data.extra_info) - return list(judger_response_items_dict.values()) - - async def run( - self, group_data_item: RLDataFlowItem | List[RLDataFlowItem] - ) -> RLJudgerResponseItem | List[RLJudgerResponseItem]: - """Run the judging process for a group of 
data items. - - Args: - group_data_item (List[RLTextDataItem]): A list of RLTextDataItem, - each containing the response and other relevant information. - - Returns: - List[float]: A list of final calculated rewards for each data item. - """ - input_type_is_list = True - if not isinstance(group_data_item, list): - input_type_is_list = False - group_data_item = [group_data_item] - - if self.enable_weighted_judgers: - data_source = group_data_item[0].data.data_source - # 如果要使用多个judger并且进行加权打分,则必须在数据集中指定data_source的分数 - assert data_source, "No data source found for the given datasets when multiple judgers are provided." - active_reward_judger = [] - active_reward_judger_names = [] - for idx, judger in enumerate(self.reward_judger): - judger_name = self.reward_judger_names[idx] - if judger_name in data_source: - active_reward_judger.append(judger) - active_reward_judger_names.append(judger_name) - assert active_reward_judger, ( - f"No active reward judger in {self.reward_judger_names} found for the given data source {data_source}." 
- ) - judger_response_item = await self._call_custom_reward_judger( - active_reward_judger, active_reward_judger_names, group_data_item - ) - - # NOTE: 只计算score的加权和 - for item in judger_response_item: - final_reward = 0 - for name, weight in data_source.items(): - if name in item.reward: - final_reward += item.reward[name]["score"] * weight - item.reward["weighted_score"] = final_reward - else: - judger_response_item = await self._call_custom_reward_judger( - self.reward_judger, self.reward_judger_names, group_data_item - ) - if input_type_is_list is False: - return judger_response_item[0] - return judger_response_item diff --git a/xtuner/v1/ray/judger/native.py b/xtuner/v1/ray/judger/native.py deleted file mode 100644 index 6bb4917041..0000000000 --- a/xtuner/v1/ray/judger/native.py +++ /dev/null @@ -1,266 +0,0 @@ -import inspect -from typing import Any, Callable, List, Optional - -import httpx -import ray -from pydantic import BaseModel, ConfigDict, Field -from ray.util.placement_group import PlacementGroup - -from xtuner.v1.data_proto.rl_data import RLDataFlowItem, RLJudgerResponseItem -from xtuner.v1.utils import get_logger - - -class NativeJudgerConfig(BaseModel): - """Configuration class for NativeJudger. - - This class defines the configuration options for initializing a NativeJudger, - including resource allocation (number of Ray actors and CPUs per actor), - reward function or remote judging service, optional pre/post-processing functions, - request timeout, and any extra information needed for judging. - - Attributes: - judger_name (str): Name identifier for the judger. - num_ray_actors (int): Number of Ray actor instances to launch. - num_cpus_per_actor (int): Number of CPUs allocated per actor. - reward_func (Optional[Callable]): Local reward function for judging. - Exactly one of reward_func or remote_url must be provided. - remote_url (Optional[str]): Remote service URL for judging. - Exactly one of reward_func or remote_url must be provided. 
- preprocess_func (Optional[Callable]): Function to preprocess input data before judging. - postprocess_func (Optional[Callable]): Function to postprocess the judging result. - request_timeout (float): Timeout (in seconds) for remote requests. - extra_info (dict): Additional information to be passed to the judger or reward function. - """ - - model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") - judger_name: str - num_ray_actors: int = 1 - num_cpus_per_actor: int = 1 - cpu_memory_per_actor: int = 1024**3 - reward_func: Optional[Callable] = Field(default=None, exclude=True) - remote_url: Optional[str] = None - preprocess_func: Optional[Callable] = Field(default=None, exclude=True) - postprocess_func: Optional[Callable] = Field(default=None, exclude=True) - request_timeout: float = 30.0 - extra_info: dict = Field(default={}, exclude=True) - - def build_actor(self, pg: PlacementGroup, start_bundle_idx: int) -> List[ray.actor.ActorClass]: - """Create and launch Ray actor instances for the GSM8K judger. - - This method instantiates multiple NativeJudger Ray actors according to `num_ray_actors`, - assigning each to a specific bundle in the provided placement group for resource isolation. - Each actor is initialized with the judger's configuration and reward function. - - Args: - pg: The Ray PlacementGroup used to allocate resources for the actors. - start_bundle_idx: The starting bundle index in the placement group for actor placement. - - Returns: - List[ActorClass]: A list of Ray actor handles representing the launched judger workers. - """ - workers_list = [] - for idx in range(self.num_ray_actors): - bundle_idx = start_bundle_idx + idx - pg_options = {"num_cpus": self.num_cpus_per_actor, "memory": self.cpu_memory_per_actor} - assert pg.bundle_specs[bundle_idx].get("CPU", 1) >= self.num_cpus_per_actor, ( - f"Placement group bundle {bundle_idx} does not have enough CPU resources." 
- ) - assert pg.bundle_specs[bundle_idx].get("memory", 0) >= self.cpu_memory_per_actor, ( - f"Placement group bundle {bundle_idx} does not have enough memory resources." - ) - worker = ( - ray.remote(NativeJudger) - .options( - placement_group=pg, - placement_group_bundle_index=bundle_idx, - **pg_options, - ) - .remote( - judger_name=self.judger_name, - reward_func=self.reward_func, - remote_url=self.remote_url, - preprocess_func=self.preprocess_func, - postprocess_func=self.postprocess_func, - request_timeout=self.request_timeout, - extra_info=self.extra_info, - ) - ) - workers_list.append(worker) - return workers_list - - -class NativeJudger: - """Base class for judgers, providing a standard interface for executing a - judging process, which can be either a local function or a remote service. - - The judger orchestrates a three-step pipeline: - 1. Pre-process the input data. - 2. Execute the core logic (local function or remote HTTP call). - 3. Post-process the result. - """ - - def __init__( - self, - judger_name: str = "native_judger", - reward_func: Optional[Callable] = None, - remote_url: Optional[str] = None, - preprocess_func: Optional[Callable] = None, - postprocess_func: Optional[Callable] = None, - request_timeout: float = 30.0, - extra_info: dict = {}, - ): - """Initialize the NativeJudger. - - Args: - reward_func (Optional[Callable]): A local function to compute the - reward. Exactly one of `reward_func` or `remote_url` must be - provided. Defaults to None. - remote_url (Optional[str]): The URL of a remote service for - judging. Exactly one of `reward_func` or `remote_url` must be - provided. Defaults to None. - preprocess_func (Optional[Callable]): A function to preprocess the - input data before judger execution. Defaults to None. - postprocess_func (Optional[Callable]): A function to postprocess - the judger result. Defaults to None. - request_timeout (float): Timeout for remote requests in seconds. - Defaults to 30.0. 
- extra_info (dict): Extra information to be passed to the reward - function. Defaults to {}. - - Raises: - ValueError: If both or neither of `reward_func` and `remote_url` - are provided. - """ - if (reward_func is None and remote_url is None) or (reward_func is not None and remote_url is not None): - raise ValueError("Exactly one of 'reward_func' or 'remote_url' must be provided.") - self.judger_name = judger_name - self.extra_info = extra_info - self.reward_func = reward_func - self.remote_url = remote_url - - self.preprocess_func = preprocess_func or self._default_preprocess - self.postprocess_func = postprocess_func or self._default_postprocess - - self.http_client = None - self.execute_func = None - - if self.reward_func: - self.execute_func = self._local_executor - elif self.remote_url: - self.http_client = httpx.AsyncClient(timeout=request_timeout) - self.execute_func = self._remote_executor - - def _default_preprocess(self, data_item: List[RLDataFlowItem], extra_info: dict) -> Any: - """Default preprocessing function. - - Args: - data_item (RLDataFlowItem | List[RLDataFlowItem]): The data item(s) to preprocess. - - Returns: - Any: A dictionary containing the responses, labels, and extra info. - """ - - assert len(data_item) == 1, "Default preprocess only supports single data item." - # TODO: Support batch reward calculation via API server - response = data_item[0].env.rollout.response - assert data_item[0].data.reward_model is not None - label = data_item[0].data.reward_model["ground_truth"] - return { - "response": response, - "label": label, - "extra_info": extra_info, - } - - def _default_postprocess(self, result: Any) -> List[RLJudgerResponseItem]: - ## 将结果包装成 RLJudgerResponseItem - """Default postprocessing function. - - Args: - result (Any): The result from the execution step. - - Returns: - Any: The result, unchanged. 
- """ - if not isinstance(result, list): - result = [result] - # todo: 支持多个judger结果的返回 - judger_response_item = [RLJudgerResponseItem(reward=result[i]) for i in range(len(result))] - return judger_response_item - - async def _local_executor(self, data_item: List[RLDataFlowItem]) -> List[RLJudgerResponseItem]: - """Executes the reward function locally. - - Args: - responses (str | List[str]): The model's response(s). - labels (str | List[str]): The ground-truth label(s). - - Returns: - Any: The postprocessed result of the reward function. - """ - assert self.reward_func is not None, "reward_func cannot be None for local execution." - # 记录每个judger请求的uid, 方便后续结果合并 - uid_list = [item.uid.observation_id for item in data_item] - kwargs = self.preprocess_func(data_item, self.extra_info) - if inspect.iscoroutinefunction(self.reward_func): - json_result = await self.reward_func(**kwargs) - else: - json_result = self.reward_func(**kwargs) - - # transform json to RLJudgerResponseItem - result = self.postprocess_func(json_result) - for i in range(len(result)): - result[i].uid = uid_list[i] - return result - - async def _remote_executor(self, data_item: List[RLDataFlowItem]) -> List[RLJudgerResponseItem]: - """Executes the reward function by calling a remote service. - - Args: - responses (str | List[str]): The model's response(s). - labels (str | List[str]): The ground-truth label(s). - - Returns: - Any: The postprocessed result from the remote service, or None if - an error occurs. - """ - assert self.remote_url is not None and self.http_client is not None, ( - "remote_url cannot be None for remote execution." 
- ) - payload = self.preprocess_func(data_item, self.extra_info) - try: - response = await self.http_client.post(self.remote_url, json=payload) - response.raise_for_status() - json_result = response.json() - # 重要,必须加 - json_result["uid"] = data_item[0].uid.observation_id - # transform json to RLJudgerResponseItem - return self.postprocess_func(json_result) - except httpx.RequestError as exc: - get_logger().error(f"An error occurred while requesting {exc.request.url}: {exc}") - return [] - - async def judge(self, data_item: List[RLDataFlowItem]) -> List[RLJudgerResponseItem]: - """The main public method to run the judging pipeline. - - Args: - responses (str | List[str]): The model's response(s) to be judged. - labels (str | List[str]): The ground-truth label(s). - - Returns: - Any: The final result after the full - preprocess-execute-postprocess pipeline. - - Raises: - RuntimeError: If the judger is not properly initialized. - """ - if self.execute_func is None: - raise RuntimeError("Judger is not properly initialized.") - return await self.execute_func(data_item) - - def get_judger_name(self) -> str: - """Get the name of the judger. - - Returns: - str: The name of the judger. 
- """ - return self.judger_name diff --git a/xtuner/v1/ray/rollout/controller.py b/xtuner/v1/ray/rollout/controller.py deleted file mode 100644 index 52075b36a0..0000000000 --- a/xtuner/v1/ray/rollout/controller.py +++ /dev/null @@ -1,589 +0,0 @@ -import asyncio -import os -import socket -import threading -import time -from collections import OrderedDict -from dataclasses import dataclass -from itertools import cycle -from typing import Any, Dict, List, Optional, Union -from uuid import uuid4 - -import ray -import uvicorn -from fastapi import FastAPI -from ray.util.placement_group import PlacementGroup - -from transformers import AutoTokenizer -from xtuner.v1.data_proto.rl_data import RLRolloutRequestItem, RLRolloutResponseItem, RolloutExtraParams, SampleParams -from xtuner.v1.ray.base import AutoAcceleratorWorkers -from xtuner.v1.ray.config.worker import RolloutConfig -from xtuner.v1.utils import get_logger - -from .worker import RolloutWorker - - -ROLLOUT_RAY_GET_TIMEOUT = os.getenv("XTUNER_ROLLOUT_RAY_GET_TIMEOUT", 5 * 3600) # default 5 hours - - -@dataclass -class WorkerInfo: - """A data class to hold all state information for a single worker.""" - - actor: RolloutWorker - rank: int = -1 - is_active: bool = True - failure_count: int = 0 - running_count: int = 0 - success_count: int = 0 - - -class SessionRouter: - def __init__( - self, - worker_status: Dict[Any, bool], # worker: worker_status - max_sessions: int = 10000, - max_idle_seconds: Optional[float] = 3600.0, - ): - assert len(worker_status) > 0 - self._workers = list(worker_status.items()) - self._max_sessions = max_sessions - self._max_idle = max_idle_seconds - - # OrderedDict: key=session_id -> value=(worker, last_used_ts) - self._map: OrderedDict[int, tuple[Any, float]] = OrderedDict() - self._worker_cycler = cycle(self._workers) - self._lock = asyncio.Lock() - self.logger = get_logger() - - def _now(self) -> float: - return time.time() - - def _evict_expired(self): - if self._max_idle is None: - 
return - now = self._now() - - to_delete = [] - for sid, (_, last_used) in self._map.items(): - if now - last_used > self._max_idle: - to_delete.append(sid) - else: - break - for sid in to_delete: - self._map.pop(sid, None) - - def _evict_lru_to_capacity(self): - while len(self._map) > self._max_sessions: - self._map.popitem(last=False) - - def update_active_workers(self, worker_status: Dict[Any, bool]): - self._workers = list(worker_status.items()) - self.logger.debug(f"SessionRouter update active workers: {self._workers}") - self._worker_cycler = cycle(self._workers) - - async def get_worker(self, session_id: int) -> Any: - async with self._lock: - self._evict_expired() - - if session_id in self._map: - worker, _ = self._map.pop(session_id) - self._map[session_id] = (worker, self._now()) - if worker[1]: # worker is healthy - return worker[0] - - worker = next(self._worker_cycler) - while worker[1] is False: - worker = next(self._worker_cycler) - self._map[session_id] = (worker, self._now()) - - self._evict_lru_to_capacity() - return worker[0] - - -class RolloutController: - """Controller for managing and coordinating multiple RolloutWorker - actors.""" - - def __init__( - self, - infer_config: RolloutConfig, - placement_group: PlacementGroup, - ): - """Initialize the RolloutController. - - Args: - infer_config (RolloutConfig): The configuration for the rollout. - placement_group (PlacementGroup): The placement group for the - RolloutWorker actors. 
- """ - self.config = infer_config - self.num_gpus_per_engine = ( - self.config.expert_parallel_size - if self.config.expert_parallel_size > 1 - else self.config.tensor_parallel_size - ) - self.logger = get_logger(log_dir=infer_config.worker_log_dir, tag="RolloutController") - self.num_workers = 0 - self.workers_info: Dict[str, WorkerInfo] = {} # url -> WorkerInfo - self.active_rollout_workers: List[RolloutWorker] = [] - self.tokenizer = AutoTokenizer.from_pretrained(infer_config.tokenizer_path, trust_remote_code=True) - self.workers, self.rank_bundle_idx_list = AutoAcceleratorWorkers.from_placement_group( - self._get_worker_cls(), infer_config, placement_group - ) - self.engine_rank_mesh_array, self.worker_server_urls_map = self.init_workers() - self.start_api_server() - # todo(@duanyanhui): add router to replace native round robin - self.router = SessionRouter(self._get_worker_status_for_router()) - self.sample_params = SampleParams().dict() - self.extra_params = dict( - RolloutExtraParams( - stream=False, - include_stop_str_in_output=True, - no_stop_trim=True, - return_logprob=True, - return_token_ids=True, - skip_special_tokens=False, - spaces_between_special_tokens=False, - top_logprobs=1, - ) - ) - self.print_params_flag = True - # The timeout for the environment to wait for the rollout controller's response. - # This should be longer than the controller's internal timeout (`rollout_timeout`) - # to account for potential queuing delays and other overheads. 
- self.timeout_multiplier = 2.0 - - def _get_worker_status_for_router(self) -> Dict[RolloutWorker, bool]: - """Helper to generate the status dict required by the SessionRouter.""" - return {info.actor: info.is_active for info in self.workers_info.values()} - - def _get_worker_cls(self): - if os.environ.get("XTUNER_USE_LMDEPLOY") == "1": - from .lmdeploy import LMDeployWorker - - return ray.remote(LMDeployWorker) - elif os.environ.get("XTUNER_USE_VLLM") == "1": - from .vllm import vLLMWorker - - return ray.remote(vLLMWorker) - elif os.environ.get("XTUNER_USE_SGLANG") == "1": - from .sglang import SGLangWorker - - return ray.remote(SGLangWorker) - else: - raise NotImplementedError( - "Rollout backend is not supported." - "Please set XTUNER_USE_LMDEPLOY or XTUNER_USE_VLLM" - " or XTUNER_USE_SGLANG environment variable." - ) - - def _get_active_worker_to_url_map(self): - """Get a mapping of active workers to their server URLs.""" - return {info.actor: url for url, info in self.workers_info.items()} - - def _is_port_in_use(self, host: str, port: int) -> bool: - """Check if a port is in use on the given host.""" - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - try: - s.bind((host, port)) - return False - except OSError: - return True - - def _update_active_workers_and_urls_map(self, active_rollout_workers, worker_server_urls_map): - """Update the list of active rollout workers and their server URLs. - - When the inference engine is launched across nodes (rollout_cross_node_comm=True), only the worker with - tp_rank=0 in each engine is responsible for receiving input data. Other tp_ranks do not accept input. - Therefore, this function updates active_rollout_workers and worker_server_urls_map to keep only the tp_rank=0 - workers and their corresponding URLs. 
- """ - if self.config.rollout_cross_node_comm or self.num_gpus_per_engine < self.config.gpus_per_node: - return active_rollout_workers, worker_server_urls_map - else: - active_worker_interval = self.num_gpus_per_engine // self.config.gpus_per_node - active_rank = list(worker_server_urls_map.keys())[::active_worker_interval] - active_worker_server_urls = list(worker_server_urls_map.values())[::active_worker_interval] - return active_rollout_workers[::active_worker_interval], dict(zip(active_rank, active_worker_server_urls)) - - def get_rollout_info(self): - """Get information about the current rollout setup. - - Returns: - dict: A dictionary containing the engine mesh list, server URL - dictionary, and the rollout configuration. - """ - worker_server_urls_status = {url: info.is_active for url, info in self.workers_info.items()} - return dict( - engine_rank_mesh_array=self.engine_rank_mesh_array, - server_url_dict=self.worker_server_urls_map, - rollout_config=self.config, - worker_server_urls_status=worker_server_urls_status, - ) - - def init_workers(self): - """Initializes and configures the pool of RolloutWorker actors. - - This method configures distributed inference engines by grouping - workers, where each group forms a tensor-parallel inference engine. It - determines the `active_workers` to act as the head of each engine, - constructs the `engine_rank_mesh_array` to define engine topology, acquires - necessary distributed communication ports, and finally launches servers - on the `active_workers` to get their addresses. - - Returns: - Tuple[List, Dict]: A tuple where the first element is - `engine_rank_mesh_array`, a list of lists containing the ranks of workers - in each engine, and the second element is `worker_server_urls_map`, - a dictionary mapping the ID of each active worker to its - corresponding server URL. 
- """ - active_servers_count, nodes_per_engine = self._get_active_servers_count(self.config, len(self.workers)) - interval = len(self.workers) // active_servers_count - active_rollout_workers = self.workers[::interval] - self.num_workers = len(active_rollout_workers) - server_urls_per_engine = self.config.server_urls_per_engine - - set_bundle_idxs_objectref = [] - engine_rank_mesh_array = [] - activate_worker_idx = 0 - for active_worker in active_rollout_workers: - head_rank, _ = self.rank_bundle_idx_list[activate_worker_idx] - engine_workers_meta = self.rank_bundle_idx_list[head_rank : head_rank + interval] - engine_bundle_idxs = [meta[1] for meta in engine_workers_meta] # meta: (rank, bundle_idx) - set_bundle_idxs_objectref.append(active_worker.set_engine_bundle_idxs.remote(engine_bundle_idxs)) # type: ignore[attr-defined] - engine_rank_mesh_array.append([meta[0] for meta in engine_workers_meta]) - activate_worker_idx += interval - ray.get(set_bundle_idxs_objectref) - # set engine mesh list for each worker - ray.get( - [worker.set_engine_rank_mesh_array.remote(engine_rank_mesh_array) for worker in active_rollout_workers] - ) # type: ignore[attr-defined] - # init dist_init_addr for each worker according to parallel settings - init_dist_init_addrs = ray.get([worker.init_dist_port.remote() for worker in active_rollout_workers]) # type: ignore[attr-defined] - dist_init_addrs = self._update_dist_init_addr( - nodes_per_engine, server_urls_per_engine, init_dist_init_addrs, self.num_gpus_per_engine - ) - # launch rollout servers - worker_server_urls_map = dict( # rank -> url - ray.get([worker.init.remote(dist_init_addrs[i]) for i, worker in enumerate(active_rollout_workers)]) - ) - active_rollout_workers, worker_server_urls_map = self._update_active_workers_and_urls_map( - active_rollout_workers, worker_server_urls_map - ) - self.workers_info = {} - for i in range(len(active_rollout_workers)): - rank = list(worker_server_urls_map.keys())[i] - url = 
worker_server_urls_map[rank] - self.workers_info[url] = WorkerInfo(rank=rank, actor=active_rollout_workers[i]) - self.logger.info(f"Rollout worker server URLs: {list(self.workers_info.keys())}") - return engine_rank_mesh_array, worker_server_urls_map - - def _deactivate_worker(self, url: str): - """A helper function to deactivate a worker, update all related states, - and shut it down.""" - worker_info = self.workers_info.get(url) - if not worker_info or not worker_info.is_active: - return - - self.logger.warning(f"Deactivating rollout worker {worker_info.actor} with URL {url} due to failures.") - worker_info.is_active = False - self.router.update_active_workers(self._get_worker_status_for_router()) - - ray.get(worker_info.actor.offload.remote(), timeout=ROLLOUT_RAY_GET_TIMEOUT) # type: ignore[attr-defined] - ray.get(worker_info.actor.shutdown.remote(), timeout=ROLLOUT_RAY_GET_TIMEOUT) # type: ignore[attr-defined] - - def update_active_workers(self): - """Check the health of all active rollout workers. - - Returns: - List[bool]: A list of booleans indicating the health status of - each active rollout worker. 
- """ - active_workers = [(url, info) for url, info in self.workers_info.items() if info.is_active] - if not active_workers: - return - - urls, infos = zip(*active_workers) - actors = [info.actor for info in infos] - - health_statuses = ray.get([actor.check_health.remote() for actor in actors], timeout=ROLLOUT_RAY_GET_TIMEOUT) - - for url, is_healthy in zip(urls, health_statuses): - if not is_healthy: - self.logger.warning(f"Rollout worker {url} is unhealthy.") - self._deactivate_worker(url) - - def deactivate_worker_by_url(self, url: str): - """Deactivates a worker identified by its URL after it exceeds the - maximum retry count.""" - worker_info = self.workers_info.get(url) - if not worker_info or not worker_info.is_active: - return - - worker_info.failure_count += 1 - if ( - self.config.max_retry_per_worker is not None - and worker_info.failure_count < self.config.max_retry_per_worker - ): - self.logger.warning( - f"Rollout worker {url} failed {worker_info.failure_count} times, but not deactivated yet." - ) - return - - self._deactivate_worker(url) - - async def rollout( - self, - prompt: Union[str, List[Dict[str, Any]]] | None = None, - input_ids: Optional[List[int]] | None = None, - tools: List = [], - tool_choice: str = "auto", - sample_params: Optional[SampleParams] = None, - extra_params: dict = dict(), - format: str = "openai", - session_id: Optional[int] = None, - extra_info: dict = dict(), - ) -> RLRolloutResponseItem: - # 这个函数接受标准的openapi chat create接口,所以不需要再额外定义输入的形式 - """Perform a rollout using one of the workers in a round-robin fashion. - - Args: - prompt (List[str]): The prompt to send to the model. - tools (List, optional): A list of tools the model can call. - Defaults to []. - tool_choice (str, optional): The tool choice strategy. - Defaults to "auto". - sample_params (Optional[SampleParams], optional): The sampling - parameters for generation. If None, the default `sample_params` - of the controller will be used. Defaults to None. 
- extra_params (dict, optional): Extra parameters for the worker. - Defaults to dict(). - format (str, optional): The format of the response. - Defaults to "openai". - - Returns: - The response from the rollout worker. - """ - session_id = session_id if session_id else uuid4().int - worker = await self.router.get_worker(session_id) - # update sample params and extra params - self.sample_params.update(sample_params.dict() if sample_params else {}) - self.extra_params.update(extra_params if extra_params else {}) - if self.print_params_flag: - self.logger.info(f"Rollout with sample params: {self.sample_params}, extra params: {self.extra_params}") - self.print_params_flag = False - assert prompt is not None or input_ids is not None, "Either prompt or input_ids must be provided." - active_worker_to_url_map = self._get_active_worker_to_url_map() - server_url = active_worker_to_url_map.get(worker) - self.workers_info[server_url].running_count += 1 - response_ref = worker.rollout.remote( # type: ignore[attr-defined] - prompt=prompt, - input_ids=input_ids, - tools=tools, - tool_choice=tool_choice, - sample_params=self.sample_params, - extra_params=self.extra_params, - format=format, - extra_info=extra_info, - ) - try: - selected_worker_info = self.workers_info[server_url] - response = await asyncio.wait_for( - response_ref, timeout=self.config.rollout_timeout * self.timeout_multiplier - ) - selected_worker_info.success_count += 1 - if response.state == "failed" or response.state == "skipped": - selected_worker_info.failure_count += 1 - self.logger.error(f"Get failed/skipped response from rollout worker {worker}, deactivate it.") - self.deactivate_worker_by_url(server_url) - return response - except asyncio.TimeoutError: - selected_worker_info.failure_count += 1 - self.logger.error(f"Get response from rollout worker {worker} timeout and return skip this sample.") - self.deactivate_worker_by_url(server_url) - return RLRolloutResponseItem(state="skipped") - - def 
get_rollout_stats(self) -> str: - """Get statistics about the rollout workers. - - Returns: - str: A formatted string containing statistics about each rollout - """ - log_parts = ["Rollout Worker Stats:"] - for url, info in self.workers_info.items(): - log_parts.append( - f" - URL: {url} | Rank: {info.rank} | Active: {info.is_active} | " - f"Running: {info.running_count} | Success: {info.success_count} | " - f"Failures: {info.failure_count}" - ) - log_msg = "\n".join(log_parts) - return log_msg - - def start_api_server(self, host: str = "0.0.0.0", port: int = 8000): - """Starts the API server to expose the rollout functionality.""" - app = FastAPI() - port = self.config.api_port if self.config.api_port else port - - original_port = port - while self._is_port_in_use(host, port): - self.logger.warning(f"Port {port} is in use, trying port {port + 1}") - port += 1 - - if original_port != port: - self.logger.info(f"API server will use port {port} instead of the originally configured {original_port}.") - - @app.post("/v1/chat/completions") - async def chat_completions(request: RLRolloutRequestItem) -> RLRolloutResponseItem: - response = await self.rollout( - prompt=request.messages, - tools=request.tools, - tool_choice=request.tool_choice, - sample_params=request.sample_params, - extra_params=request.extra_params, - ) - return response - - config = uvicorn.Config(app, host=host, port=port) - server = uvicorn.Server(config) - server_thread = threading.Thread(target=server.run, daemon=True) - server_thread.start() - - # internal functions - def _update_dist_init_addr(self, nodes_per_engine, server_urls_per_engine, dist_init_addrs, tp_size): - """Update the distributed initialization addresses for workers. - - This is used to group workers that belong to the same inference engine. - - Args: - nodes_per_engine (int): The number of nodes per inference engine. - server_urls_per_engine (int): The number of server urls per inference engine. 
- dist_init_addrs (list): The list of initial addresses. - tp_size (int): The tensor parallel size. - - Returns: - list: The updated list of distributed initialization addresses. - """ - # lmdeploy pytorch ep: server_urls_per_engine > 1 - # sglang cross node engine: nodes_per_engine > 1 - assert server_urls_per_engine == 1 or nodes_per_engine == 1 - if nodes_per_engine > 1: - index = list(range(0, self.num_workers + 1, tp_size)) + [self.num_workers] - for i in range(1, len(index)): - dist_init_addrs[index[i - 1] : index[i]] = [dist_init_addrs[index[i - 1]]] * (index[i] - index[i - 1]) - if server_urls_per_engine > 1: - activate_servers = len(dist_init_addrs) - for i in range(0, activate_servers, server_urls_per_engine): - dist_init_addrs[i : i + server_urls_per_engine] = [dist_init_addrs[i]] * server_urls_per_engine - return dist_init_addrs - - def _get_active_servers_count(self, infer_config: RolloutConfig, gpu_nums: int): - """Calculate the number of active servers and nodes per engine. - - This calculation depends on the inference backend and parallelism settings. - - Args: - infer_config (RolloutConfig): The rollout configuration. - gpu_nums (int): The total number of GPUs available. - - Returns: - Tuple[int, int]: A tuple containing the number of active servers - and the number of nodes per engine. - """ - # NOTE:Since different inference engines have different launch methods, - # the number of nodes contained in each engine is not consistent. - # For example: sglang requires starting an inference engine for each node, - # while lmdeploy and vllm does not. Therefore, we calculate the number - # of active servers based on the configuration. 
- support_cross_node_comm = infer_config.rollout_cross_node_comm - gpus_per_node = infer_config.gpus_per_node - nodes_per_engine = ( - 1 - if support_cross_node_comm or self.num_gpus_per_engine < gpus_per_node - else self.num_gpus_per_engine // gpus_per_node - ) - - active_servers_count = int( - (gpu_nums // self.num_gpus_per_engine) * nodes_per_engine * infer_config.server_urls_per_engine - ) - return active_servers_count, nodes_per_engine - - def _broadcast_to_active_workers(self, method_name: str, block: bool): - """Helper function to call a method on all active workers. - - Args: - method_name (str): The name of the method to call. - block (bool): Whether to block until the call completes. - - Returns: - A list of futures if `block` is False, otherwise a list of results. - """ - futures = [] - for info in self.workers_info.values(): - if info.is_active: - futures.append(getattr(info.actor, method_name).remote()) - else: - self.logger.warning(f"Skipping {method_name} for inactive worker {info.actor}.") - - if not block: - return futures - - results = ray.get(futures, timeout=ROLLOUT_RAY_GET_TIMEOUT) - return results - - def pause(self, block=True): - """Pauses all active rollout workers. - - Args: - block (bool): Whether to block until the operation completes. - """ - return self._broadcast_to_active_workers("pause", block) - - def restart(self, block=True): - """Restarts all active rollout workers. - - Args: - block (bool): Whether to block until the operation completes. - """ - return self._broadcast_to_active_workers("restart", block) - - def reset_prefix_cache(self, block=True): - """Resets the prefix cache on all active rollout workers. - - Args: - block (bool): Whether to block until the operation completes. - """ - return self._broadcast_to_active_workers("reset_prefix_cache", block) - - def offload(self, block=True): - """Offloads model weights and KV cache on all active rollout workers. 
- - Args: - block (bool): Whether to block until the operation completes. - """ - return self._broadcast_to_active_workers("offload", block) - - def onload_weights(self, block=True): - """Onloads model weights on all active rollout workers. - - Args: - block (bool): Whether to block until the operation completes. - """ - return self._broadcast_to_active_workers("onload_weights", block) - - def onload_kvcache(self, block=True): - """Onloads KV cache on all active rollout workers. - - Args: - block (bool): Whether to block until the operation completes. - """ - return self._broadcast_to_active_workers("onload_kvcache", block) - - def shutdown(self, block=True): - """Shuts down all active rollout workers. - - Args: - block (bool): Whether to block until the operation completes. - """ - return self._broadcast_to_active_workers("shutdown", block) diff --git a/xtuner/v1/ray/rollout/worker.py b/xtuner/v1/ray/rollout/worker.py deleted file mode 100644 index ed3324abda..0000000000 --- a/xtuner/v1/ray/rollout/worker.py +++ /dev/null @@ -1,854 +0,0 @@ -import asyncio -import copy -import json -import multiprocessing -import os -import time -import traceback -import uuid -from abc import abstractmethod -from typing import Any, Callable, Dict, List, Optional, Union - -import httpx -import numpy as np -import ray -import requests # type: ignore[import-untyped] -import torch -from packaging.version import Version -from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy - -from transformers import AutoTokenizer -from xtuner.v1.data_proto.rl_data import RLRolloutResponseItem, RolloutState -from xtuner.v1.ray import find_master_addr_and_port -from xtuner.v1.ray.base import AutoAcceleratorWorkers, SingleAcceleratorWorker -from xtuner.v1.ray.config import RolloutConfig -from xtuner.v1.utils import get_logger -from xtuner.v1.utils.httpx_utils import HttpRequestErrorType, HttpRequestResult - - -def get_eos_token(model_path: str) -> int | List[int]: - from 
xtuner.v1.utils.logger import get_logger - - logger = get_logger() - generation_config_path = os.path.join(model_path, "generation_config.json") - if not os.path.exists(generation_config_path): - logger.warning( - f"Config {generation_config_path} does not exist and thus cannot get eos_token. You must provide eos_token manually." - ) - return [] - with open(generation_config_path) as f: - generation_config = json.load(f) - eos_token_id = generation_config.get("eos_token_id") - return eos_token_id - - -class RolloutWorker(SingleAcceleratorWorker): - """Base class for a rollout worker that runs an inference server. - - This class manages the lifecycle of a distributed inference server, including initialization, launching, and - handling generation requests. It is designed to be subclassed for specific inference backends like LMDeploy, vLLM - or SGLang. - """ - - def __init__( - self, - config: RolloutConfig, - rank: int, - master_addr: str, - master_port: int, - world_size: int, - accelerator: str = "GPU", - ): - """Initialize the RolloutWorker. - - Args: - config (RolloutConfig): The configuration for the rollout. - rank (int): The rank of this worker in the distributed setup. - master_addr (str): The address of the Ray master node. - master_port (int): The port of the Ray master node. - world_size (int): The total number of workers. - accelerator (str): The type of accelerator to use. - Defaults to "GPU". 
- """ - self.config = config - self.rank = rank - self.master_addr = master_addr # ray master - self.master_port = master_port - self.world_size = world_size - self.accelerator = accelerator - self.server_func: Callable - self.endpoints: dict[str, str] = dict() - self.engine_rank_mesh_array: list[list[int]] - # http_concurrency is calculated based on the max batch size per engine and the total number of engines - assert config.rollout_max_batch_size_per_instance, ( - "rollout_max_batch_size_per_instance must be set in RolloutConfig" - ) - http_concurrency = config.rollout_max_batch_size_per_instance * config.allow_over_concurrency_ratio - limits = httpx.Limits(max_connections=http_concurrency, max_keepalive_connections=100) - self.client = httpx.AsyncClient(limits=limits, timeout=self.config.rollout_timeout) - self.paused = False - self.server_task = None - self.engine_bundle_idxs: list[int] = [] - self.server_process: Optional[multiprocessing.Process] = None - self.logger = get_logger(log_dir=config.worker_log_dir, tag="RolloutWorker") - self.tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer_path, trust_remote_code=True) - self.check_flag = True # only print once - self.enable_return_routed_experts = self.config.enable_return_routed_experts - if self.rank == 0: - self.logger.info(f"RolloutConfig:\n{self.config.model_dump_json(indent=2)}") - eos_token = get_eos_token(self.config.model_path) - self.logger.info(f"Using eos_token: {eos_token} for model at {self.config.model_path}") - self.eos_token: List[int] = [eos_token] if isinstance(eos_token, int) else eos_token - self.receive_abort_request = asyncio.Event() - self.abort_timeout = 5.0 - - def init_dist_port(self): - """Initialize distributed communication ports. - - This method acquires three free ports for the distributed setup: - one for the inference server, one for NCCL, and one for Ray's - distributed communication. - - Returns: - str: The distributed initialization address (host:port). 
- """ - scheduling_strategy = PlacementGroupSchedulingStrategy( - placement_group=ray.util.get_current_placement_group(), - placement_group_capture_child_tasks=True, - placement_group_bundle_index=self.engine_bundle_idxs[0], - ) - - local_rank = int(ray.get_runtime_context().get_accelerator_ids()[self.accelerator][0]) - interval = 1024 - start_port = self.config.dist_port_base + local_rank * interval - end_port = start_port + interval - self.host, self.ports = ray.get( - find_master_addr_and_port.options(scheduling_strategy=scheduling_strategy).remote( - nums=3, - start_port=start_port, - end_port=end_port, - ) - ) - - self.dist_port = self.ports[0] - self.server_port = self.ports[1] - self.nccl_port = self.ports[2] - self.dist_init_addr = f"{self.host}:{self.dist_port}" - self.server_url = f"http://{self.host}:{self.server_port}" - return self.dist_init_addr - - def init(self, dist_init_addr: str = ""): - """Initialize the worker and launch the server. - - Args: - dist_init_addr (str): The distributed initialization address. - If not provided, the one generated by `init_dist_port` is used. - - Returns: - Tuple[int, str]: A tuple containing the worker's rank and its - server URL. - """ - self.dist_init_addr = dist_init_addr if dist_init_addr else self.dist_init_addr - self.receive_abort_request.clear() - self.launch_server() - return (self.rank, self.server_url) - - def set_engine_rank_mesh_array(self, engine_rank_mesh_array: list[list[int]]): - self.engine_rank_mesh_array = engine_rank_mesh_array - - def set_engine_bundle_idxs(self, engine_bundle_idxs: list[int]): - """Set the bundle indices for the inference engine. - - This is used by some backends (like LMDeploy with Ray executor) to - know which bundles in the placement group belong to this engine. - - Args: - engine_bundle_idxs (list[int]): A list of bundle indices. 
- """ - self.engine_bundle_idxs = engine_bundle_idxs - - def launch_server(self): - """Launch the inference server as a separate process or Ray task. - - It waits for the server to become healthy before returning. - - Raises: - TimeoutError: If the server fails to start within the specified - timeout. - Exception: If the server task terminates unexpectedly. - """ - server_configs = self._transform_rollout_config_to_server_configs() - timeout = 3600.0 # Increased timeout to 5 minutes for downloading large models - start_time = time.perf_counter() - last_log_time = start_time - headers = { - "Content-Type": "application/json; charset=utf-8", - "Authorization": f"Bearer {server_configs.api_key}", - } - - self.logger.info(f"Launch server task on server_url: {self.server_url}") - - # note(@duanyanhui): launch server as multiprocessing for sglang temporarily - if self.config.launch_server_method == "multiprocessing": - mp_ctx = multiprocessing.get_context("spawn") - process = mp_ctx.Process(target=self.server_func, args=(server_configs,)) - process.start() - self.server_process = process - time.sleep(60) # Wait for the server to start - with requests.Session() as session: - while time.perf_counter() - start_time < timeout: - try: - response = session.get( - f"{self.server_url}/{self.endpoints['health_generate']}", headers=headers - ) - if response.status_code == 200: - return - except requests.RequestException as e: - self.logger.error( - f"can't connect to server url {self.server_url}/{self.endpoints['health_generate']} because {e}" - ) - - current_time = time.perf_counter() - if current_time - last_log_time >= 15: - self.logger.info( - f"Waiting for server to start, Elapsed time: {current_time - start_time:.2f}s" - ) - last_log_time = current_time - - time.sleep(5) - process.terminate() - raise TimeoutError("Server failed to start within the timeout period.") - else: - # launch the server as ray task - # so that the lmdeploy backend could get externl pg - current_pg = 
ray.util.get_current_placement_group() - scheduling_strategy = PlacementGroupSchedulingStrategy( - placement_group=current_pg, - placement_group_capture_child_tasks=True, - placement_group_bundle_index=self.engine_bundle_idxs[0], - ) - assert ray.is_initialized() - ray_kwargs = ( - {"runtime_env": server_configs.ray_runtime_env} if hasattr(server_configs, "ray_runtime_env") else {} - ) - self.server_task = ( - ray.remote(self.server_func) - .options( - scheduling_strategy=scheduling_strategy, - **AutoAcceleratorWorkers.get_pg_options(current_pg), - **ray_kwargs, - ) - .remote(server_configs) - ) - - with requests.Session() as session: - while time.perf_counter() - start_time < timeout: - try: - response = session.get( - f"{self.server_url}/{self.endpoints['health_generate']}", headers=headers - ) - if response.status_code == 200: - return - except requests.RequestException: - pass - - try: - ray.get(self.server_task, timeout=0.1) - raise Exception("Server task terminated unexpectedly.") - except ray.exceptions.GetTimeoutError: - pass - except Exception as e: - raise e - - current_time = time.perf_counter() - if current_time - last_log_time >= 15: - self.logger.info( - f"Waiting for server to start... Elapsed time: {current_time - start_time:.2f}s" - ) - last_log_time = current_time - - ray.cancel(self.server_task) - raise TimeoutError("Server failed to start within the timeout period.") - - def _adapt_input_to_openai_spec(self, prompts, tools, tool_choice): - openai_prompts = [] - openai_tools = [] - # transform claude spec to openai spec - # 1. transform system prompt: concat provided system_prompt to input prompt - system_prompt = self.config.system_prompt - if system_prompt: - system_prompt_json = {"role": "system", "content": f"{system_prompt}"} - prompts.insert(0, system_prompt_json) - # 2. 
transform multi-modal usage - for prompt in prompts: - content = prompt["content"] - openai_content = [] - for item in content: - if item["type"] == "image": - if item["source"]["type"] == "base64": - openai_url = f"data:{item['source']['media_type']};base64,{item['source']['data']}" - if item["source"]["type"] == "url": - openai_url = item["source"]["url"] - new_prompt = {"type": "image_url", "image_url": {"url": openai_url}} - openai_content.append(new_prompt) - elif item["type"] == "text": - openai_content.append(item) - new_prompt = copy.deepcopy(prompt) - new_prompt["content"] = openai_content - openai_prompts.append(new_prompt) - # 3. transform tool use - for tool in tools: - openai_tool = { - "type": "function", - "function": { - "name": tool["name"], - "description": tool["description"], - "parameters": tool["input_schema"], - }, - } - openai_tools.append(openai_tool) - return openai_prompts, openai_tools - - def _check_infer_engine_version(self, return_token_ids: bool): - # TODO(@duanyanhui): remove this check when all backends support return_token_ids - if self.check_flag: - if os.environ.get("XTUNER_USE_VLLM", "0") == "1": - if return_token_ids: - self.logger.error( - "VLLM backend does not support return_token_ids or generate with input_ids as input in Xtuner now" - ) - elif os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "1": - import lmdeploy - - lmdeploy_version = lmdeploy.__version__ - if return_token_ids and Version(lmdeploy_version) < Version("0.10.2"): - self.logger.error( - f"You should use lmdeploy >= v0.10.2 to support return_token_ids, but current version is {lmdeploy_version}" - ) - self.check_flag = False - - async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult: - try: - if self.receive_abort_request.is_set(): - self.logger.debug(f"Request to {url} was cancelled before sending due to an abort signal.") - return HttpRequestResult(error_type=HttpRequestErrorType.REQUEST_ABORTED, url=url, payload=payload) - req = 
self.client.build_request( - "POST", - url, - headers=headers, - json=payload, - ) - r = await self.client.send(req) - r.raise_for_status() - return HttpRequestResult(response=r) - - except Exception as e: - error_type = HttpRequestErrorType.from_exception(e) - result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload) - return result - - async def rollout_task( - self, - prompts: Union[str, List[Dict[str, Any]]] | None, - input_ids: List[int] | None, - tools: List, - tool_choice: str, - sample_params: dict, - extra_params: dict, - format: str, - extra_info: dict, - ) -> RLRolloutResponseItem: - uid = extra_info.get("action_id", str(uuid.uuid4())) - action_id = extra_info.get("action_id", str(uuid.uuid4())) - root_id = extra_info.get("action_id", str(uuid.uuid4())) - response = None - cur_retry_times = 0 - - if format == "openai": - openai_prompts, openai_tools = prompts, tools - else: - openai_prompts, openai_tools = self._adapt_input_to_openai_spec(prompts, tools, tool_choice) - - if "return_token_ids" in extra_params and extra_params["return_token_ids"]: - endpoint_url = f"{self.server_url}/{self.endpoints['generate']}" - else: - endpoint_url = f"{self.server_url}/{self.endpoints['v1/chat/completions']}" - - while True: - # 当拼接后的response_ids长度已经达到了max_tokens时,则不需要发送数据,直接返回 - if extra_info.get("partial_rollout_input_ids", None) is not None: - if sample_params["max_tokens"] == 0: - self.logger.info( - f"Request {uid} reached max context length {self.config.context_length}, no need to rollout more." 
- ) - return RLRolloutResponseItem( - response=None, - response_ids=None, - logprobs=None, - num_return_tokens=0, - finish_reason="length", - state=RolloutState.COMPLETED, - ) - if extra_info["partial_rollout_input_ids"][-1] in self.eos_token: - self.logger.info( - f"Request {uid} already ends with eos token {extra_info['partial_rollout_input_ids'][-1]}, no need to rollout more" - ) - return RLRolloutResponseItem( - response=None, - response_ids=None, - logprobs=None, - num_return_tokens=0, - finish_reason="stop", - state=RolloutState.COMPLETED, - ) - - http_result = await self._create_request( - endpoint_url, - openai_prompts, - input_ids, - openai_tools, - tool_choice, - sample_params=sample_params, - extra_params=extra_params, - extra_info=extra_info, - ) - # Case 1: Request was successful - if http_result.response is not None: # 推理完成:completed状态:finish_reason为abort/stop/length, 退出 - response = await self._handle_non_stream_response( - root_id, action_id, sample_params, extra_params, http_result.response, extra_info - ) - if response.state == RolloutState.SKIPPED: - # retry - cur_retry_times += 1 - if cur_retry_times < self.config.max_retry_per_sample: - self.logger.warning( - f"Invalid rollout response for request {uid}, retrying {cur_retry_times}/{self.config.max_retry_per_sample}." 
- ) - await asyncio.sleep(0.1) - continue - else: - return RLRolloutResponseItem(state=RolloutState.SKIPPED) - return response - - # Case2: Return aborted error if receive abort signal - if http_result.error_type == HttpRequestErrorType.REQUEST_ABORTED: - return RLRolloutResponseItem(finish_reason="abort", state=RolloutState.ABORTED) - - # Case 3: A fatal, non-retryable error occurred - elif http_result.is_unknown_error: - raise RuntimeError( - f"Unexpected error during rollout request {uid} to {http_result.url}: {http_result.exception}" - ) - - # Case 4: A retryable error occurred, and we still have retries left - elif http_result.is_retryable and cur_retry_times < self.config.max_retry_per_sample: - cur_retry_times += 1 - self.logger.warning( - f"Retrying rollout request {uid} to {http_result.url} due to {http_result.error_type} with {http_result.error_msg}. " - f"Retry {cur_retry_times}/{self.config.max_retry_per_sample}." - ) - await asyncio.sleep(0.1) - continue - - elif http_result.is_retryable and cur_retry_times >= self.config.max_retry_per_sample: - self.logger.warning( - f"rollout request {uid} to {http_result.url} was skipped due to max retries reached" - ) - return RLRolloutResponseItem(state=RolloutState.SKIPPED) - elif http_result.is_client_error: - self.logger.warning( - f"rollout request {uid} to {http_result.url} was skipped due to client error {http_result.error_type} with {http_result.error_msg}" - ) - return RLRolloutResponseItem(state=RolloutState.SKIPPED) - elif http_result.is_server_error: - self.logger.warning( - f"rollout request {uid} to {http_result.url} failed due to server error {http_result.error_type} with {http_result.error_msg}" - ) - return RLRolloutResponseItem(state=RolloutState.FAILED) - else: - raise RuntimeError( - f"Unhandled error case for rollout request {uid} to {http_result.url}: {http_result.exception}" - ) - - async def _handle_stream_response(self, uid, sample_params, extra_params, response) -> RLRolloutResponseItem: - 
last_trajectory = "" - last_token_ids = [] - last_logprobs = [] - finish_reason = "" - async for chunk in response.aiter_lines(): - if not chunk.startswith("data:"): - continue - try: - chunk_data_str = chunk[len("data:") :].strip() - if self.paused or chunk_data_str == "[DONE]": - finish_reason = "paused" if self.paused else finish_reason - break - if not (chunk_data_str.startswith("{") and chunk_data_str.endswith("}")): - continue - - chunk_data = json.loads(chunk_data_str) - - if "return_token_ids" in extra_params and extra_params["return_token_ids"]: - last_trajectory = last_trajectory + chunk_data.get("text", "") - finish_reason = chunk_data["meta_info"].get("finish_reason") - if finish_reason is not None: - finish_reason = finish_reason["type"] - - output_token_logprobs = chunk_data["meta_info"].get("output_token_logprobs") - if output_token_logprobs is not None: - for token_logprob in output_token_logprobs: - last_logprobs.append(token_logprob[0]) - last_token_ids.append(token_logprob[1]) - else: - delta_content = chunk_data["choices"][0]["delta"].get("content") - last_trajectory = last_trajectory + delta_content if delta_content else last_trajectory - last_token_id = chunk_data["choices"][0]["delta"].get("gen_tokens") - if last_token_id is not None: - last_token_ids.extend(last_token_id) - finish_reason = chunk_data["choices"][0].get("finish_reason") - logprobs_content = chunk_data["choices"][0]["logprobs"] - if logprobs_content is not None: - for content_item in logprobs_content["content"]: - last_logprobs.append(content_item["logprob"]) - - except json.JSONDecodeError as e: - self.logger.error(f"JSON decode error for chunk in request {uid}: {chunk}, error: {e}") - continue - except Exception as e: - self.logger.error(f"Error processing chunk for {uid}: {chunk}, error: {e}") - return RLRolloutResponseItem( - response="", - finish_reason="failed", - ) - - assert finish_reason in ["stop", "length", "tool_call", "abort"], f"Unexpected finish_reason: 
{finish_reason}" - rollout_response = RLRolloutResponseItem( - response=last_trajectory, - response_ids=last_token_ids if len(last_token_ids) > 0 else None, - num_return_tokens=len(last_token_ids) if len(last_token_ids) > 0 else None, - finish_reason=finish_reason, - logprobs=last_logprobs, - ) - return rollout_response - - async def _handle_non_stream_response( - self, root_id, action_id, sample_params, extra_params, response, input_extra_info - ) -> RLRolloutResponseItem: - response = response.json() - uid = action_id - if "return_token_ids" in extra_params and extra_params["return_token_ids"]: - last_logprobs: list[float] = [] - try: - extra_info = {} - finish_reason = response["meta_info"]["finish_reason"]["type"] - if finish_reason == "abort" and self.receive_abort_request.is_set() is False: - self.receive_abort_request.set() - self.logger.info(f"Setting receive_abort_request to True for rank {self.rank}") - if "output_token_logprobs" in response["meta_info"]: - if response["meta_info"]["output_token_logprobs"] is None: - last_token_ids = [] - last_logprobs = [] - else: - last_token_ids = [item[1] for item in response["meta_info"]["output_token_logprobs"]] - last_logprobs = [item[0] for item in response["meta_info"]["output_token_logprobs"]] - assert len(last_token_ids) <= sample_params["max_tokens"], ( - f"Generation length exceeds the limit: generated length is {len(last_token_ids)}, limit is {sample_params['max_tokens']}" - ) - else: - num_return_tokens = response["meta_info"].get("completion_tokens", 0) - last_token_ids = response["output_ids"][-num_return_tokens:] if num_return_tokens > 0 else [] - - if self.enable_return_routed_experts: - assert "routed_experts" in response["meta_info"], ( - "enable_return_routed_experts is True, but routed_experts is not in meta_info" - ) - exist_history_routed_experts = ( - "routed_experts" in input_extra_info and input_extra_info["routed_experts"] is not None - ) - routed_experts = 
response["meta_info"]["routed_experts"] # token[layer[expert]] - if routed_experts is not None and not exist_history_routed_experts: - # 不存在历史专家,先把当前专家存起来 - if isinstance(routed_experts, str): - import base64 - - data = base64.b64decode(routed_experts) - routed_experts = ray.cloudpickle.loads(data) - else: - routed_experts = torch.tensor(routed_experts) # n,layer,expert - routed_experts = ray.put(routed_experts) - extra_info["routed_experts"] = routed_experts - elif routed_experts is not None and exist_history_routed_experts: - # 存在历史专家,则不进行put 操作,直接进行concat - if isinstance(routed_experts, str): - import base64 - - data = base64.b64decode(routed_experts) - routed_experts = ray.cloudpickle.loads(data) - cur_routed_experts = await routed_experts # n,layer,expert - ray._private.internal_api.free(routed_experts) - else: - routed_experts = torch.tensor(routed_experts) # n,layer,expert - cur_routed_experts = routed_experts - - history_routed_experts = await input_extra_info["routed_experts"] # n, layer, expert - ray._private.internal_api.free(input_extra_info["routed_experts"]) - del input_extra_info["routed_experts"] - - assert (history_routed_experts.shape[0] - 1) > 0 and history_routed_experts.shape[ - 0 - ] - 1 <= cur_routed_experts.shape[0], ( - f"Existing routed_experts shape: {history_routed_experts.shape}, current routed_experts shape: {cur_routed_experts.shape}" - ) - init_cur_roued_experts = cur_routed_experts.shape[0] - cur_routed_experts = cur_routed_experts[history_routed_experts.shape[0] :, :, :] - concat_routed_experts = np.concatenate((history_routed_experts, cur_routed_experts), axis=0) - prompt_tokens = response["meta_info"].get("prompt_tokens", 0) - response_tokens = response["meta_info"].get("completion_tokens", 0) - assert concat_routed_experts.shape[0] == prompt_tokens + response_tokens - 1, ( - f"Routed experts shape {concat_routed_experts.shape[0]} does not match total tokens {prompt_tokens + response_tokens - 1}" - ) - self.logger.debug( - 
f"[{root_id}/{action_id}] Partial Rollout Stats: " - f"Tokens(prompt={prompt_tokens}, response={response_tokens}, total={prompt_tokens + response_tokens}) | " - f"Experts(exist={history_routed_experts.shape}, init_cur={init_cur_roued_experts}, cur={cur_routed_experts.shape}, concat={concat_routed_experts.shape})" - ) - extra_info["routed_experts"] = ray.put(concat_routed_experts) - else: - assert finish_reason == "abort", ( - f"routed_experts is None, but finish_reason is {finish_reason}, expected abort. response: {response}" - ) - # NOTE: When set return_token_ids = True, the response must contain valid token_ids/logprobs. - # If not, we consider it as an invalid response and retry it. - # NOTE: !!! When finish_reason is abort, some queries may not return token_ids or logprobs. !!! - if finish_reason != "abort" and (len(last_token_ids) == 0 or len(last_logprobs) == 0): - self.logger.error(f"Invalid rollout response for request {uid}: {response}") - return RLRolloutResponseItem(state=RolloutState.SKIPPED) - else: - rollout_response = RLRolloutResponseItem( - response=response["text"], - response_ids=last_token_ids, - num_return_tokens=len(last_token_ids), - finish_reason=finish_reason, - logprobs=last_logprobs, - extra_info=extra_info, - state=RolloutState.ABORTED if finish_reason == "abort" else RolloutState.COMPLETED, - ) - # self.logger.info(f"Rollout response for request {uid}: finish_reason={finish_reason}, num_return_tokens={len(last_token_ids)}") - return rollout_response - except KeyError as e: - error_msg = f"Missing expected key {e} in response {response} for {uid}" - raise RuntimeError(error_msg) - except IndexError as e: - error_msg = f"Index error {e} while processing response {response} for {uid}" - raise RuntimeError(error_msg) - except AssertionError as e: - error_msg = f"AssertionError: {e} when processing response {response} for {uid}" - raise RuntimeError(error_msg) - except json.JSONDecodeError as e: - error_msg = f"JSONDecodeError: {e} when 
processing response {response} for {uid}" - raise RuntimeError(error_msg) - except TypeError as e: - error_msg = f"TypeError: {e} when processing response {response} for {uid}" - raise RuntimeError(error_msg) - except Exception as e: - error_msg = f"Unexpected error: {e} when processing response {response} for {uid}\nTraceback: {traceback.format_exc()}" - raise RuntimeError(error_msg) - else: - # v1/chat/completions API response - try: - last_trajectory = response["choices"][0]["message"]["content"] - finish_reason = response["choices"][0]["finish_reason"] - rollout_response = RLRolloutResponseItem( - response=last_trajectory, - finish_reason=finish_reason, - num_return_tokens=response["usage"]["completion_tokens"], - ) - return rollout_response - except KeyError as e: - error_msg = f"Missing expected key {e} in response {response} for {uid}" - raise RuntimeError(error_msg) - except IndexError as e: - error_msg = f"Index error {e} while processing response {response} for {uid}" - raise RuntimeError(error_msg) - except AssertionError as e: - error_msg = f"AssertionError: {e} when processing response {response} for {uid}" - raise RuntimeError(error_msg) - except json.JSONDecodeError as e: - error_msg = f"JSONDecodeError: {e} when processing response {response} for {uid}" - raise RuntimeError(error_msg) - except TypeError as e: - error_msg = f"TypeError: {e} when processing response {response} for {uid}" - raise RuntimeError(error_msg) - except Exception as e: - error_msg = f"Unexpected error: {e} when processing response {response} for {uid}\nTraceback: {traceback.format_exc()}" - raise RuntimeError(error_msg) - - async def rollout( - self, - prompt: Union[str, List[Dict[str, Any]]] | None = None, - input_ids: Optional[List[int]] | None = None, - tools: List = [], - tool_choice: str = "auto", - sample_params: dict = dict(), - extra_params: dict = dict(), - format: str = "openai", - extra_info: dict = dict(), - ) -> RLRolloutResponseItem: - """Public method to 
initiate a rollout. - - Args: - prompt (str): The input prompt for generation. - sample_params (dict): Parameters for sampling. - - Returns: - The result of the `rollout_task`. - """ - return await self.rollout_task( - prompt, input_ids, tools, tool_choice, sample_params, extra_params, format=format, extra_info=extra_info - ) - - def pause(self): - """Pause the worker's generation process.""" - self.paused = True - - def restart(self): - """Resume the worker's generation process.""" - self.receive_abort_request.clear() - - def check_health(self) -> bool: - """Check the health of the worker's server. - - Returns: - bool: True if the server is healthy, False otherwise. - """ - try: - headers = { - "Content-Type": "application/json; charset=utf-8", - "Authorization": f"Bearer {self.config.api_key}", - } - response = requests.get( - f"{self.server_url}/{self.endpoints['health_generate']}", headers=headers, timeout=5.0 - ) - return response.status_code == 200 - except requests.RequestException as e: - self.logger.error(f"Health check failed for server {self.server_url}: {e}") - return False - - def shutdown(self): - """Shut down the worker, its server task, and any child processes.""" - if self.server_task is not None: - ray.cancel(self.server_task) - return - - if self.server_process is not None: - import psutil - - parent = psutil.Process(self.server_process.pid) - children = parent.children(recursive=True) - for child in children: - child.terminate() - gone, alive = psutil.wait_procs(children, timeout=5) - for child in alive: - child.kill() - parent.terminate() - parent.wait(timeout=5) - self.logger.debug(f"Worker {self.rank} server process and its children terminated.") - return - - @abstractmethod - async def _create_request( - self, - url: str, - prompt: Union[str, List[Dict[str, Any]]] | None, - input_ids: List[int] | None, - tools: List, - tool_choice: str, - sample_params: dict, - extra_params: dict, - extra_info: dict, - ): - """Abstract method to create a 
generation request. - - Must be implemented by subclasses. - """ - raise NotImplementedError("_create_request must be implemented in subclass") - - @abstractmethod - def _transform_rollout_config_to_server_configs(self): - """Abstract method to transform rollout config to server configs. - - Must be implemented by subclasses. - """ - raise NotImplementedError("_transform_rollout_config_to_server_configs must be implemented in subclass") - - @abstractmethod - def _transform_sample_params(self, sample_params: Dict): - """Abstract method to transform rollout config to server configs. - - Must be implemented by subclasses. - """ - raise NotImplementedError("_transform_rollout_config_to_server_configs must be implemented in subclass") - - @abstractmethod - def get_logprobs(self, input_ids, sampling_params): - """Abstract method to get log probabilities. - - Must be implemented by subclasses. - """ - raise NotImplementedError("get_logprobs must be implemented in subclass") - - @abstractmethod - def update_weights(self): - """Abstract method to update model weights. - - Must be implemented by subclasses. - """ - raise NotImplementedError("update_weights must be implemented in subclass") - - @abstractmethod - def reset_prefix_cache(self): - """Abstract method to reset the prefix cache. - - Must be implemented by subclasses. - """ - raise NotImplementedError("reset_prefix_cache must be implemented in subclass") - - @abstractmethod - def offload(self): - """Abstract method to offload the model and KVcache. - - Must be implemented by subclasses. - """ - raise NotImplementedError("reset_prefix_cache must be implemented in subclass") - - @abstractmethod - def onload_weights(self): - """Abstract method to onload the model weights. - - Must be implemented by subclasses. - """ - pass - - @abstractmethod - def onload_kvcache(self): - """Abstract method to onload the KV cache. - - Must be implemented by subclasses. 
- """ - pass - - @abstractmethod - def pause_generation(self): - """Abstract method to pause the generation process. - - Must be implemented by subclasses. - """ - raise NotImplementedError("pause_generation must be implemented in subclass") - - @abstractmethod - def continue_generation(self): - """Abstract method to continue the generation process. - - Must be implemented by subclasses. - """ - raise NotImplementedError("continue_generation must be implemented in subclass") diff --git a/xtuner/v1/rl/agent_loop/__init__.py b/xtuner/v1/rl/agent_loop/__init__.py new file mode 100644 index 0000000000..64c919ea3e --- /dev/null +++ b/xtuner/v1/rl/agent_loop/__init__.py @@ -0,0 +1,55 @@ +from xtuner.v1.rl.judger import ComposedJudgerConfig, Judger, JudgerConfig + +from .agent_loop import ( + AgentLoop, + AgentLoopActor, + AgentLoopConfig, + AgentLoopSpec, + RayAgentLoop, + RayAgentLoopProxy, + RouterAgentLoop, +) +from .agent_loop_manager import ( + AgentLoopManager, + AgentLoopManagerConfig, + ProduceBatchResult, + TaskSpecConfig, +) +from .producer import ( + AsyncProduceStrategy, + AsyncProduceStrategyConfig, + ProduceStrategy, + ProduceStrategyConfig, + SyncProduceStrategy, + SyncProduceStrategyConfig, +) +from .sampler import Sampler, SamplerConfig +from .single_turn_agent_loop import SingleTurnAgentLoop, SingleTurnAgentLoopConfig + + +__all__ = [ + "AgentLoopConfig", + "SingleTurnAgentLoopConfig", + "AgentLoop", + "AgentLoopSpec", + "AgentLoopActor", + "RouterAgentLoop", + "RayAgentLoop", + "RayAgentLoopProxy", + "SingleTurnAgentLoop", + "Judger", + "JudgerConfig", + "ComposedJudgerConfig", + "AgentLoopManagerConfig", + "AgentLoopManager", + "TaskSpecConfig", + "ProduceBatchResult", + "ProduceStrategyConfig", + "SyncProduceStrategyConfig", + "AsyncProduceStrategyConfig", + "ProduceStrategy", + "SyncProduceStrategy", + "AsyncProduceStrategy", + "SamplerConfig", + "Sampler", +] diff --git a/xtuner/v1/rl/agent_loop/agent_loop.py b/xtuner/v1/rl/agent_loop/agent_loop.py 
new file mode 100644 index 0000000000..97cad2bfab --- /dev/null +++ b/xtuner/v1/rl/agent_loop/agent_loop.py @@ -0,0 +1,248 @@ +from __future__ import annotations + +import asyncio +from abc import ABC, abstractmethod +from typing import TypeAlias, cast + +from pydantic import BaseModel, ConfigDict, Field, model_validator +from ray.actor import ActorClass, ActorProxy +from ray.util.placement_group import PlacementGroup + +from xtuner.v1.data_proto import RolloutState, SampleParams +from xtuner.v1.rl.judger import Judger +from xtuner.v1.rl.rollout import RolloutController +from xtuner.v1.rl.utils import CPUActorLauncher, create_task +from xtuner.v1.utils import get_logger, ray_method +from xtuner.v1.utils.processing_utils import load_processor, load_tokenizer + + +class AgentLoopConfig(ABC, BaseModel): + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + hf_checkpoint: str + sample_params: SampleParams + num_ray_actors: int = Field( + default=0, + ge=0, + description="Number of AgentLoop Ray actor instances. 
0 means local mode.", + ) + num_cpus: float = Field(default=1, gt=0, description="CPU cores required by the AgentLoop actor itself.") + cpu_memory: int = Field(default=1024**3, gt=0, description="CPU memory in bytes required by AgentLoop.") + + @model_validator(mode="after") + def _validate_ray_actor_config(self) -> AgentLoopConfig: + if self.num_ray_actors == 0 and (self.num_cpus != 1 or self.cpu_memory != 1024**3): + logger = get_logger() + logger.warning("num_cpus and cpu_memory are ignored when AgentLoop runs in local mode.") + return self + + def build(self, rollout_controller, judger: Judger | None = None, logger=None) -> AgentLoopSpec: + if self.num_ray_actors == 0: + return self.build_local( + rollout_controller=rollout_controller, + judger=judger, + logger=logger, + ) + if self.num_ray_actors > 1: + return self._build_router( + rollout_controller=rollout_controller, + judger=judger, + logger=logger, + ) + return self._build_ray_actor( + rollout_controller=rollout_controller, + judger=judger, + logger=logger, + ) + + @abstractmethod + def build_local( + self, + rollout_controller, + judger: Judger | None = None, + logger=None, + ) -> AgentLoop: ... 
+ + def _build_ray_actor( + self, + rollout_controller: RolloutController, + pg: PlacementGroup | None = None, + judger: Judger | None = None, + logger=None, + ) -> RayAgentLoopProxy: + return cast( + "RayAgentLoopProxy", + CPUActorLauncher.build_actor( + AgentLoopActor, + self, + rollout_controller, + judger, + logger, + pg=pg, + bundle_idx=0, + actor_num_cpus=self.num_cpus, + actor_memory=self.cpu_memory, + capture_child_tasks=True, + ), + ) + + def _build_ray_actors( + self, + rollout_controller: RolloutController, + num_actors: int, + pg: PlacementGroup | None = None, + judger: Judger | None = None, + logger=None, + start_bundle_idx: int = 0, + ) -> list[RayAgentLoopProxy]: + return cast( + list["RayAgentLoopProxy"], + CPUActorLauncher.build_actors( + AgentLoopActor, + self, + rollout_controller, + judger, + logger, + pg=pg, + start_bundle_idx=start_bundle_idx, + num_workers=num_actors, + actor_num_cpus_per_worker=self.num_cpus, + actor_memory_per_worker=self.cpu_memory, + capture_child_tasks=True, + ), + ) + + def _build_router( + self, + rollout_controller: RolloutController, + pg: PlacementGroup | None = None, + judger: Judger | None = None, + logger=None, + start_bundle_idx: int = 0, + ) -> RouterAgentLoop: + return RouterAgentLoop( + workers=self._build_ray_actors( + rollout_controller=rollout_controller, + num_actors=self.num_ray_actors, + pg=pg, + judger=judger, + logger=logger, + start_bundle_idx=start_bundle_idx, + ), + rollout_ctl=rollout_controller, + ) + + +class AgentLoop(ABC): + def __init__( + self, + rollout_ctl: RolloutController, + sample_params: SampleParams, + hf_checkpoint: str, + judger: Judger | None = None, + logger=None, + ) -> None: + self.rollout_ctl = rollout_ctl + self.hf_checkpoint = hf_checkpoint + self.tokenizer = load_tokenizer(hf_checkpoint, trust_remote_code=True) + self.processor = load_processor(hf_checkpoint, trust_remote_code=True) + self.sample_params = sample_params + self.judger = judger + if logger is None: + 
self.logger = get_logger() + else: + self.logger = logger + + @abstractmethod + async def generate_sample(self, rollout_state: RolloutState, **kwargs) -> RolloutState: ... + + async def generate_group(self, rollout_state: list[RolloutState], **kwargs) -> list[RolloutState]: + pending_tasks = [] + for state in rollout_state: + state.sample_params = self.sample_params + task = create_task(self.generate_sample(state, **kwargs)) + pending_tasks.append(task) + generated_samples = asyncio.gather(*pending_tasks) + group_samples = await generated_samples + return group_samples + + +class RouterAgentLoop: + def __init__(self, workers: list[RayAgentLoopProxy], rollout_ctl: RolloutController): + self.workers = workers + self.rollout_ctl = rollout_ctl + self._worker_loads = dict.fromkeys(workers, 0) + self._rr_index = 0 + self._lock = asyncio.Lock() + + async def _pick_worker(self) -> RayAgentLoopProxy: + async with self._lock: + min_load = min(self._worker_loads.values()) + candidates = [worker for worker in self.workers if self._worker_loads[worker] == min_load] + worker = candidates[self._rr_index % len(candidates)] + self._rr_index = (self._rr_index + 1) % len(self.workers) + self._worker_loads[worker] += 1 + return worker + + async def _release_worker(self, worker: RayAgentLoopProxy) -> None: + async with self._lock: + self._worker_loads[worker] -= 1 + + async def generate_sample(self, rollout_state: RolloutState, **kwargs) -> RolloutState: + worker = await self._pick_worker() + try: + return await worker.generate_sample.remote(rollout_state, **kwargs) + finally: + await self._release_worker(worker) + + async def generate_group(self, rollout_state: list[RolloutState], **kwargs) -> list[RolloutState]: + worker = await self._pick_worker() + try: + return await worker.generate_group.remote(rollout_state, **kwargs) + finally: + await self._release_worker(worker) + + def get_worker_status(self) -> dict[str, int]: + return {str(worker): load for worker, load in 
self._worker_loads.items()} + + +async def get_agent_loop_rollout_ctl(agent_loop: AgentLoopSpec) -> RolloutController: + rollout_ctl = getattr(agent_loop, "rollout_ctl", None) + if rollout_ctl is not None: + return rollout_ctl + + get_rollout_ctl = getattr(agent_loop, "get_rollout_ctl", None) + if get_rollout_ctl is None or not hasattr(get_rollout_ctl, "remote"): + raise AttributeError(f"Agent loop {type(agent_loop)} does not expose rollout_ctl or get_rollout_ctl().") + return await get_rollout_ctl.remote() + + +class AgentLoopActor: + def __init__( + self, + agent_loop_config: AgentLoopConfig, + rollout_controller: RolloutController, + judger: Judger | None = None, + logger=None, + ): + self.agent_loop = agent_loop_config.build_local( + rollout_controller=rollout_controller, + judger=judger, + logger=logger, + ) + + @ray_method + async def generate_sample(self, rollout_state: RolloutState, **kwargs) -> RolloutState: + return await self.agent_loop.generate_sample(rollout_state, **kwargs) + + @ray_method + async def generate_group(self, rollout_state: list[RolloutState], **kwargs) -> list[RolloutState]: + return await self.agent_loop.generate_group(rollout_state, **kwargs) + + @ray_method + async def get_rollout_ctl(self): + return self.agent_loop.rollout_ctl + + +RayAgentLoop = cast(ActorClass[AgentLoopActor], CPUActorLauncher.to_actor_class(AgentLoopActor)) +RayAgentLoopProxy: TypeAlias = ActorProxy[AgentLoopActor] +AgentLoopSpec: TypeAlias = AgentLoop | RayAgentLoopProxy | RouterAgentLoop diff --git a/xtuner/v1/rl/agent_loop/agent_loop_manager.py b/xtuner/v1/rl/agent_loop/agent_loop_manager.py new file mode 100644 index 0000000000..c091f16961 --- /dev/null +++ b/xtuner/v1/rl/agent_loop/agent_loop_manager.py @@ -0,0 +1,400 @@ +import asyncio +import math +import time +from dataclasses import dataclass +from pathlib import Path + +from pydantic import BaseModel, ConfigDict, Field + +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +from 
xtuner.v1.data_proto import RolloutState, Status +from xtuner.v1.rl.judger import ComposedJudgerConfig, JudgerConfig, build_judger +from xtuner.v1.rl.replay_buffer import ReplayBuffer +from xtuner.v1.rl.rollout import RolloutController, continue_generation, pause_generation +from xtuner.v1.rl.utils import asyncio_run +from xtuner.v1.utils import get_logger + +from .agent_loop import AgentLoopConfig, AgentLoopSpec, get_agent_loop_rollout_ctl +from .producer import ProducerTimings, ProduceStrategy, ProduceStrategyConfig, SyncProduceStrategyConfig +from .sampler import Sampler, SamplerConfig + + +@dataclass +class ProduceBatchResult: + """Result of a single ``produce_batch`` call. + + Args: + rollout_states (list[list[RolloutState]]): Completed rollout groups retrieved from the replay buffer for training. + group_gen_count (int | None): Number of generate-group calls finished in this batch (None if no generations ran). + group_gen_mean_s (float | None): Mean wall-clock time per generate-group call, in seconds. + group_gen_p50_s (float | None): Median (p50) generate-group time, in seconds. + group_gen_p99_s (float | None): 99th percentile generate-group time, in seconds. + group_gen_p99_p50_ratio (float | None): Ratio of p99 to p50, indicating tail-latency skew. + group_gen_pause_time_s (float | None): Time spent in pause/cleanup phase (async strategy only), in seconds. + leftover_completed (int): Number of completed groups remaining in the replay buffer after this batch. + leftover_aborted (int): Number of aborted groups remaining in the replay buffer. + leftover_expired (int): Number of expired groups remaining in the replay buffer. 
+ """ + + rollout_states: list[list[RolloutState]] + # per-group generation timing stats (all None if no generations occurred) + group_gen_count: int | None = None + group_gen_mean_s: float | None = None + group_gen_p50_s: float | None = None + group_gen_p99_s: float | None = None + group_gen_p99_p50_ratio: float | None = None + group_gen_pause_time_s: float | None = None + # leftover samples remaining in replay buffer after batch retrieval + leftover_completed: int = 0 + leftover_aborted: int = 0 + leftover_expired: int = 0 + task_batch_sizes: dict[str, int] | None = None + task_results: dict[str, "ProduceBatchResult"] | None = None + + +@dataclass(frozen=True) +class _TaskRunner: + task_name: str + agent_loop: AgentLoopSpec + produce_strategy: ProduceStrategy + sampler: Sampler + weight: float = 1.0 + order: int = 0 + + +class _TaskSamplerView: + def __init__(self, samplers: list[Sampler]): + self._samplers = samplers + + def __len__(self) -> int: + return sum(len(sampler) for sampler in self._samplers) + + +def _fill_produce_timing_stats(result: ProduceBatchResult, stats: ProducerTimings) -> None: + if not stats.generate_times_s: + return + sorted_times = sorted(stats.generate_times_s) + n = len(sorted_times) + mean_s = sum(sorted_times) / n + p50_s = sorted_times[n // 2] + p99_s = sorted_times[int(n * 0.99)] + ratio = p99_s / p50_s if p50_s > 0 else float("inf") + result.group_gen_count = n + result.group_gen_mean_s = mean_s + result.group_gen_p50_s = p50_s + result.group_gen_p99_s = p99_s + result.group_gen_p99_p50_ratio = ratio + result.group_gen_pause_time_s = stats.pause_time_s + + +async def _produce_single_task_batch( + task_runner: _TaskRunner, + replay_buffer: ReplayBuffer, + batch_size: int, + rollout_step: int, + logger, + manager_name: str, +) -> ProduceBatchResult: + start = time.perf_counter() + logger.info(f"[{manager_name}][{task_runner.task_name}] produce_batch start batch={batch_size}") + stats: ProducerTimings = await 
task_runner.produce_strategy.produce_batch( + task_runner.agent_loop, + task_runner.sampler, + replay_buffer, + batch_size, + task_runner.task_name, + rollout_step, + ) + logger.info( + f"[{manager_name}][{task_runner.task_name}] produce scheduler done elapsed={time.perf_counter() - start:.3f}, and start replay_buffer.get" + ) + + result = ProduceBatchResult(rollout_states=[]) + _fill_produce_timing_stats(result, stats) + + start = time.perf_counter() + batch_rollout_states: list[list[RolloutState]] = await replay_buffer.get( + batch_size, task_runner.task_name, Status.COMPLETED + ) + logger.info( + f"[{manager_name}][{task_runner.task_name}] replay_buffer.get done completed_groups={len(batch_rollout_states)} elapsed={time.perf_counter() - start:.3f}" + ) + result.rollout_states = batch_rollout_states + completed_sample_count, aborted_sample_count, expired_sample_count = await asyncio.gather( + replay_buffer.count(task_name=task_runner.task_name, group_status=Status.COMPLETED), + replay_buffer.count(task_name=task_runner.task_name, group_status=Status.ABORTED), + replay_buffer.count(task_name=task_runner.task_name, group_status=Status.EXPIRED), + ) + result.leftover_completed = completed_sample_count + result.leftover_aborted = aborted_sample_count + result.leftover_expired = expired_sample_count + return result + + +class TaskSpecConfig(BaseModel): + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + + task_name: str + weight: float = Field(default=1.0, ge=0.0) + agent_loop_config: AgentLoopConfig + judger_config: JudgerConfig | ComposedJudgerConfig | None = None + produce_strategy_config: ProduceStrategyConfig = SyncProduceStrategyConfig() + sampler_config: SamplerConfig + + +class AgentLoopManagerConfig(BaseModel): + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + + tasks: list[TaskSpecConfig] | TaskSpecConfig + + def build( + self, + rollout_controller: RolloutController, + tokenizer: PreTrainedTokenizer | 
PreTrainedTokenizerFast, + replay_buffer: ReplayBuffer, + logger=None, + ) -> "AgentLoopManager": + tasks = self.tasks if isinstance(self.tasks, list) else [self.tasks] + if not tasks: + raise ValueError("AgentLoopManagerConfig requires at least one task config.") + + seen_task_names: set[str] = set() + task_runners: list[_TaskRunner] = [] + for order, task_cfg in enumerate(tasks): + if task_cfg.task_name in seen_task_names: + raise ValueError(f"Duplicate task_name found in AgentLoopManagerConfig: {task_cfg.task_name}") + seen_task_names.add(task_cfg.task_name) + + agent_loop = task_cfg.agent_loop_config.build( + rollout_controller=rollout_controller, + judger=build_judger(task_cfg.judger_config) if task_cfg.judger_config is not None else None, + logger=logger, + ) + produce_strategy = task_cfg.produce_strategy_config.build() + sampler = task_cfg.sampler_config.build(tokenizer=tokenizer, replay_buffer=replay_buffer) + task_runners.append( + _TaskRunner( + task_name=task_cfg.task_name, + agent_loop=agent_loop, + produce_strategy=produce_strategy, + sampler=sampler, + weight=task_cfg.weight, + order=order, + ) + ) + + return AgentLoopManager( + task_runners=task_runners, + replay_buffer=replay_buffer, + logger=logger, + ) + + +class AgentLoopManager: + _TASK_CHECKPOINT_DIR = "tasks" + + def __init__( + self, + task_runners: list[_TaskRunner], + replay_buffer: ReplayBuffer, + logger=None, + ): + if not task_runners: + raise ValueError("AgentLoopManager requires at least one task runner.") + if sum(task.weight for task in task_runners) <= 0: + raise ValueError("At least one task weight must be positive for AgentLoopManager.") + + self.task_runners = task_runners + self.replay_buffer = replay_buffer + self.data_sampler = ( + task_runners[0].sampler + if len(task_runners) == 1 + else _TaskSamplerView([task.sampler for task in task_runners]) + ) + self.name = task_runners[0].task_name if len(task_runners) == 1 else "multi_task" + if logger is None: + self.logger = 
get_logger() + else: + self.logger = logger + + def get_task_batch_sizes(self, global_batch_size: int, rollout_step: int) -> dict[str, int]: + """Return the per-task batch sizes for the current rollout step. + + Subclasses may override this method to implement custom dynamic batch allocation policies. Returning 0 for a + task effectively disables that task for the current produce_batch call. + """ + if global_batch_size < 0: + raise ValueError(f"global_batch_size must be non-negative, got {global_batch_size}") + + total_weight = sum(task.weight for task in self.task_runners) + if total_weight <= 0: + raise ValueError("Sum of task weights must be positive.") + if global_batch_size == 0: + return {task.task_name: 0 for task in self.task_runners} + + raw_allocations = [global_batch_size * task.weight / total_weight for task in self.task_runners] + floor_allocations = [math.floor(raw) for raw in raw_allocations] + remaining = global_batch_size - sum(floor_allocations) + + task_batch_sizes = {task.task_name: floor_allocations[idx] for idx, task in enumerate(self.task_runners)} + if remaining <= 0: + return task_batch_sizes + + ranked_tasks = sorted( + enumerate(self.task_runners), + key=lambda item: ( + -(raw_allocations[item[0]] - floor_allocations[item[0]]), + item[1].order, + ), + ) + for idx, task in ranked_tasks[:remaining]: + task_batch_sizes[task.task_name] += 1 + return task_batch_sizes + + def _validate_task_batch_sizes(self, task_batch_sizes: dict[str, int], global_batch_size: int) -> None: + expected_task_names = {task.task_name for task in self.task_runners} + actual_task_names = set(task_batch_sizes.keys()) + if actual_task_names != expected_task_names: + missing_task_names = expected_task_names - actual_task_names + extra_task_names = actual_task_names - expected_task_names + raise ValueError( + "Invalid task batch sizes returned by get_task_batch_sizes: " + f"missing={sorted(missing_task_names)}, extra={sorted(extra_task_names)}" + ) + + 
negative_batch_sizes = { + task_name: task_batch_size + for task_name, task_batch_size in task_batch_sizes.items() + if task_batch_size < 0 + } + if negative_batch_sizes: + raise ValueError(f"Task batch sizes must be non-negative, got {negative_batch_sizes}") + + total_batch_size = sum(task_batch_sizes.values()) + if total_batch_size != global_batch_size: + raise ValueError( + "Task batch sizes must sum to the requested global batch size, " + f"got total={total_batch_size}, expected={global_batch_size}" + ) + + @staticmethod + def _aggregate_task_results( + ordered_tasks: list[_TaskRunner], task_results: dict[str, ProduceBatchResult] + ) -> ProduceBatchResult: + rollout_states: list[list[RolloutState]] = [] + leftover_completed = 0 + leftover_aborted = 0 + leftover_expired = 0 + total_group_count = 0 + weighted_group_mean_sum = 0.0 + weighted_group_p50_sum = 0.0 + weighted_group_p99_sum = 0.0 + weighted_group_ratio_sum = 0.0 + total_pause_time_s = 0.0 + + for task in ordered_tasks: + result = task_results[task.task_name] + rollout_states.extend(result.rollout_states) + leftover_completed += result.leftover_completed + leftover_aborted += result.leftover_aborted + leftover_expired += result.leftover_expired + if result.group_gen_count is not None and result.group_gen_mean_s is not None: + total_group_count += result.group_gen_count + weighted_group_mean_sum += result.group_gen_count * result.group_gen_mean_s + weighted_group_p50_sum += result.group_gen_count * (result.group_gen_p50_s or 0.0) + weighted_group_p99_sum += result.group_gen_count * (result.group_gen_p99_s or 0.0) + weighted_group_ratio_sum += result.group_gen_count * (result.group_gen_p99_p50_ratio or 0.0) + total_pause_time_s += result.group_gen_pause_time_s or 0.0 + + aggregated = ProduceBatchResult( + rollout_states=rollout_states, + leftover_completed=leftover_completed, + leftover_aborted=leftover_aborted, + leftover_expired=leftover_expired, + task_results={task.task_name: 
task_results[task.task_name] for task in ordered_tasks}, + ) + if total_group_count > 0: + aggregated.group_gen_count = total_group_count + aggregated.group_gen_mean_s = weighted_group_mean_sum / total_group_count + aggregated.group_gen_p50_s = weighted_group_p50_sum / total_group_count + aggregated.group_gen_p99_s = weighted_group_p99_sum / total_group_count + aggregated.group_gen_p99_p50_ratio = weighted_group_ratio_sum / total_group_count + aggregated.group_gen_pause_time_s = total_pause_time_s + return aggregated + + async def produce_batch(self, batch_size: int, rollout_step: int = 0) -> ProduceBatchResult: + start = time.perf_counter() + self.logger.info(f"[AgentLoopManager][{self.name}] produce_batch start batch={batch_size}") + + if len(self.task_runners) == 1: + task = self.task_runners[0] + rollout_ctl = await get_agent_loop_rollout_ctl(task.agent_loop) + await continue_generation(rollout_ctl) + try: + return await _produce_single_task_batch( + task_runner=task, + replay_buffer=self.replay_buffer, + batch_size=batch_size, + rollout_step=rollout_step, + logger=self.logger, + manager_name="AgentLoopManager", + ) + finally: + await pause_generation(rollout_ctl) + + task_batch_sizes = self.get_task_batch_sizes(batch_size, rollout_step) + self._validate_task_batch_sizes(task_batch_sizes, batch_size) + active_tasks = [task for task in self.task_runners if task_batch_sizes[task.task_name] > 0] + + results: list[ProduceBatchResult] = [] + if active_tasks: + rollout_ctl = await get_agent_loop_rollout_ctl(active_tasks[0].agent_loop) + await continue_generation(rollout_ctl) + try: + results = await asyncio.gather( + *[ + _produce_single_task_batch( + task_runner=task, + replay_buffer=self.replay_buffer, + batch_size=task_batch_sizes[task.task_name], + rollout_step=rollout_step, + logger=self.logger, + manager_name="AgentLoopManager", + ) + for task in active_tasks + ] + ) + finally: + await pause_generation(rollout_ctl) + + task_results = {task.task_name: result for 
task, result in zip(active_tasks, results)} + for task in self.task_runners: + if task.task_name not in task_results: + task_results[task.task_name] = ProduceBatchResult(rollout_states=[]) + + ordered_tasks = sorted(self.task_runners, key=lambda task: (task.task_name, task.order)) + aggregated = self._aggregate_task_results(ordered_tasks, task_results) + aggregated.task_batch_sizes = {task.task_name: task_batch_sizes[task.task_name] for task in ordered_tasks} + + self.logger.info( + f"[AgentLoopManager][{self.name}] produce_batch done elapsed={time.perf_counter() - start:.3f}, completed_groups={len(aggregated.rollout_states)}" + ) + return aggregated + + def _task_checkpoint_path(self, checkpoint_path: Path | str, task_name: str) -> Path: + checkpoint_path = Path(checkpoint_path) + return checkpoint_path / self._TASK_CHECKPOINT_DIR / task_name + + def save(self, checkpoint_path: Path | str) -> None: + """Save all task sampler states and the shared replay buffer.""" + for task in self.task_runners: + task_checkpoint_path = self._task_checkpoint_path(checkpoint_path, task.task_name) + task_checkpoint_path.mkdir(parents=True, exist_ok=True) + task.sampler.save(task_checkpoint_path) + asyncio_run(self.replay_buffer.save(checkpoint_path)) + + def resume(self, checkpoint_path: Path | str) -> None: + """Resume all task sampler states and the shared replay buffer.""" + for task in self.task_runners: + task.sampler.resume(self._task_checkpoint_path(checkpoint_path, task.task_name)) + asyncio_run(self.replay_buffer.resume(checkpoint_path)) diff --git a/xtuner/v1/rl/agent_loop/gsm8k_with_tool.py b/xtuner/v1/rl/agent_loop/gsm8k_with_tool.py new file mode 100644 index 0000000000..7774c3a711 --- /dev/null +++ b/xtuner/v1/rl/agent_loop/gsm8k_with_tool.py @@ -0,0 +1,157 @@ +import copy +import json +import re +from typing import cast + +from pydantic import BaseModel, ConfigDict + +from xtuner.v1.data_proto import RolloutState, SampleParams +from xtuner.v1.rl.agent_loop import 
AgentLoop, AgentLoopConfig +from xtuner.v1.rl.judger import Judger +from xtuner.v1.rl.rollout import RolloutController +from xtuner.v1.utils import get_logger + + +logger = get_logger() + + +class GSM8KToolAgentLoopConfig(AgentLoopConfig): + max_turns: int + + def build_local(self, rollout_controller, judger: Judger | None = None, logger=None) -> "GSM8KToolAgentLoop": + return GSM8KToolAgentLoop( + max_turns=self.max_turns, + rollout_ctl=rollout_controller, + hf_checkpoint=self.hf_checkpoint, + sample_params=self.sample_params, + judger=judger, + ) + + +class FunctionCall(BaseModel): + model_config = ConfigDict(extra="forbid") + + name: str + arguments: dict + + +class GSM8KToolAgentLoop(AgentLoop): + def __init__( + self, + max_turns: int, + rollout_ctl: RolloutController, + hf_checkpoint: str, + sample_params: SampleParams, + judger: Judger | None = None, + ): + super().__init__( + rollout_ctl=rollout_ctl, hf_checkpoint=hf_checkpoint, sample_params=sample_params, judger=judger + ) + self.max_turns = max_turns + self.tool_call_pattern = re.compile(r"\n*(.*?)", re.DOTALL) + self.tool_call_start_token: str = "" + self.tool_call_end_token: str = "" + + def calc_gsm8k_reward(self, answer: dict, ground_truth: str) -> float: + from xtuner.v1.rl.judger.gsm8k import compute_reward + + extra_info = {"score": 1.0, "format_score": 0} + actual_answer = answer.get("answer", "") + if not actual_answer.startswith("#### "): + actual_answer = "#### " + actual_answer + return compute_reward(actual_answer, ground_truth, extra_info) + + def extract_tool_calls(self, rollout_state: RolloutState) -> tuple[str, list[FunctionCall]]: + text = self.tokenizer.decode(rollout_state.response_ids) + if self.tool_call_start_token not in text or self.tool_call_end_token not in text: + return text, [] + + matches = self.tool_call_pattern.findall(text) + function_calls = [] + for match in matches: + try: + function_call = json.loads(match) + name, arguments = function_call["name"], 
function_call["arguments"] + function_calls.append(FunctionCall(name=name, arguments=arguments)) + except Exception as e: + logger.error(f"Error parsing tool call JSON: {e}") + continue + + content = self.tool_call_pattern.sub("", text) + return content, function_calls + + async def generate_sample(self, rollout_state: RolloutState, **kwargs) -> RolloutState: + # Respect state passed from preprocess for partial rollout continuation. + base_sample_params = copy.deepcopy(rollout_state.sample_params or self.sample_params) + final_response_mask: list[int] = [] + final_response_ids: list[int] = [] + final_logprobs: list[float] = [] + + max_len = base_sample_params.max_tokens + cur_turn_tokens = list(rollout_state.tokens or rollout_state.prompt_ids or []) + remaining_max_tokens = max_len - len(final_response_ids) + cur_turn = 0 + while True: + if cur_turn >= self.max_turns or len(final_response_ids) >= max_len or remaining_max_tokens <= 0: + break + + rollout_state.tokens = cur_turn_tokens + rollout_state.sample_params = copy.deepcopy(base_sample_params) + rollout_state.sample_params.max_tokens = remaining_max_tokens + + rollout_state = await self.rollout_ctl.generate.remote(rollout_state) # type: ignore[attr-defined] + cur_turn += 1 + response_ids = cast(list[int], rollout_state.response_ids) + cur_turn_tokens.extend(response_ids) + + # 处理 rollout_controller 的输出 + final_response_ids.extend(response_ids) + final_logprobs.extend(cast(list[float], rollout_state.logprobs)) + final_response_mask.extend([1] * len(response_ids)) + # TODO: 处理 routed_experts, 要注意这里涉及到是否要解引用的问题 + + if len(final_response_ids) >= max_len: + break + + _, function_calls = self.extract_tool_calls(rollout_state) + if not function_calls: + break + + tool_messages = [] + for function_call in function_calls: + tool_name = function_call.name + tool_args = function_call.arguments + if tool_name == "calc_gsm8k_reward": + answer = tool_args + ground_truth = cast(dict, 
rollout_state.reward_model).get("ground_truth", "") + function_results = self.calc_gsm8k_reward(answer, ground_truth) + tool_message = { + "role": "tool", + "content": json.dumps({"result": function_results}, ensure_ascii=False), + } + tool_messages.append(tool_message) + + # 处理工具调用的输出 + tools_response_ids = self.tokenizer.apply_chat_template(tool_messages, remove_system_prompt=True) + final_response_ids.extend(tools_response_ids) + final_logprobs.extend([0.0] * len(tools_response_ids)) + final_response_mask.extend([0] * len(tools_response_ids)) + + # 处理下一轮生成的输入 + cur_turn_tokens.extend(tools_response_ids) + remaining_max_tokens = max_len - len(final_response_ids) + + final_response_ids = final_response_ids[:max_len] + final_response_mask = final_response_mask[:max_len] + final_logprobs = final_logprobs[:max_len] + + rollout_state.response_ids = final_response_ids + rollout_state.response_mask = final_response_mask + rollout_state.logprobs = final_logprobs + rollout_state.response = self.tokenizer.decode(rollout_state.response_ids) + assert len(rollout_state.response_ids) == len(rollout_state.response_mask) == len(rollout_state.logprobs), ( + f"{len(rollout_state.response_ids)} vs {len(rollout_state.response_mask)} vs {len(rollout_state.logprobs)}" + ) + if self.judger is not None: + rollout_state = await self.judger.judge(rollout_state) + return rollout_state diff --git a/xtuner/v1/rl/agent_loop/producer.py b/xtuner/v1/rl/agent_loop/producer.py new file mode 100644 index 0000000000..ddf3e8e78f --- /dev/null +++ b/xtuner/v1/rl/agent_loop/producer.py @@ -0,0 +1,325 @@ +import asyncio +import time +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Protocol, runtime_checkable + +import ray +from pydantic import BaseModel, ConfigDict + +from xtuner.v1.data_proto.rl_data import RolloutState, Status, update_expired_status +from xtuner.v1.rl.replay_buffer import ReplayBuffer +from xtuner.v1.rl.rollout.utils import 
pause_generation +from xtuner.v1.rl.utils import create_task +from xtuner.v1.utils import get_logger + +from .agent_loop import AgentLoopSpec, get_agent_loop_rollout_ctl +from .sampler import Sampler + + +@dataclass +class ProducerTimings: + """记录一轮 batch 生成过程中每个 group 的生成耗时统计信息。 + + Args: + generate_times_s (list[float]): 每个 group 的生成耗时(秒),长度等于本轮生成 group 的数量。 + pause_time_s (float): 结束时等待所有 pending 任务收尾的总耗时(秒)。 + """ + + generate_times_s: list[float] = field(default_factory=list) + pause_time_s: float = 0.0 + + +logger = get_logger() + + +async def _timed_generate_group( + agent_loop: AgentLoopSpec, rollout_state: list[RolloutState], **kwargs +) -> tuple[list[RolloutState], float]: + start = time.perf_counter() + if isinstance(agent_loop, ray.actor.ActorHandle): + result = await agent_loop.generate_group.remote(rollout_state, **kwargs) + else: + result = await agent_loop.generate_group(rollout_state, **kwargs) + return result, time.perf_counter() - start + + +def default_is_valid_sample_fn(samples: list[RolloutState]) -> bool: + return all(sample.status == Status.COMPLETED for sample in samples) + + +def default_should_continue_fn(completed_count: int, batch_size: int, **kwargs) -> bool: + return completed_count < batch_size + + +@runtime_checkable +class IsValidSampleFn(Protocol): + def __call__(self, samples: list[RolloutState]) -> bool: ... + + +@runtime_checkable +class ShouldContinueFn(Protocol): + def __call__(self, completed_count: int, batch_size: int, **kwargs) -> bool: ... + + +class ProduceStrategyConfig(ABC, BaseModel): + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + is_valid_sample_fn: IsValidSampleFn = default_is_valid_sample_fn + should_continue_fn: ShouldContinueFn = default_should_continue_fn + + @abstractmethod + def build(self) -> "ProduceStrategy": ... 
+ + +class SyncProduceStrategyConfig(ProduceStrategyConfig): + def build(self) -> "SyncProduceStrategy": + return SyncProduceStrategy( + is_valid_sample_fn=self.is_valid_sample_fn, should_continue_fn=self.should_continue_fn + ) + + +class AsyncProduceStrategyConfig(ProduceStrategyConfig): + over_sample_threshold: float = 0.0 + enable_partial_rollout: bool = False + tail_batch_stale_threshold: int = 0 + tail_batch_trigger_size: int = 0 + + def build(self) -> "AsyncProduceStrategy": + return AsyncProduceStrategy( + over_sample_threshold=self.over_sample_threshold, + enable_partial_rollout=self.enable_partial_rollout, + tail_batch_stale_threshold=self.tail_batch_stale_threshold, + tail_batch_trigger_size=self.tail_batch_trigger_size, + is_valid_sample_fn=self.is_valid_sample_fn, + should_continue_fn=self.should_continue_fn, + ) + + +class ProduceStrategy(ABC): + def __init__( + self, + is_valid_sample_fn: IsValidSampleFn, + should_continue_fn: ShouldContinueFn, + ): + self.is_valid_sample_fn = is_valid_sample_fn + self.should_continue_fn = should_continue_fn + + @abstractmethod + async def produce_batch( + self, + agent_loop: AgentLoopSpec, + sampler: Sampler, + replay_buffer: ReplayBuffer, + batch_size: int, + task_name: str, + rollout_step: int = 0, + ) -> "ProducerTimings": ... + + +class SyncProduceStrategy(ProduceStrategy): + async def produce_batch( + self, + agent_loop: AgentLoopSpec, + sampler: Sampler, + replay_buffer: ReplayBuffer, + batch_size: int, + task_name: str, + rollout_step: int = 0, + ) -> ProducerTimings: + pending_tasks = set() + generate_times: list[float] = [] + completed_sample_count = await replay_buffer.count(task_name=task_name, group_status=Status.COMPLETED) + assert completed_sample_count == 0, "SyncProduceStrategy assumes no completed samples at the start." 
+ + for _ in range(batch_size): + rollout_state = await sampler.sample(task_name=task_name) + task = create_task(_timed_generate_group(agent_loop, rollout_state)) + pending_tasks.add(task) + + logger.info(f"Started {len(pending_tasks)} initial tasks for SyncProduceStrategy.") + + while self.should_continue_fn(completed_sample_count, batch_size): + if not pending_tasks: + logger.warning("All tasks are done but not enough samples collected.") + break + done_tasks, pending_tasks = await asyncio.wait( + pending_tasks, timeout=1, return_when=asyncio.FIRST_COMPLETED + ) + # 如果要过滤,在这个地方处理,然后加入到 replay buffer + # 如果被过滤的数据就放到 put_to_filtered pool 中 + for task in done_tasks: + items, elapsed = task.result() + generate_times.append(elapsed) + if self.is_valid_sample_fn(items): + completed_sample_count += 1 + logger.info(f"Collected {completed_sample_count}/{batch_size} valid samples for task {task_name}.") + await replay_buffer.put(items, task_name) + + while len(pending_tasks) + completed_sample_count < batch_size and self.should_continue_fn( + completed_sample_count, batch_size + ): + rollout_state = await sampler.sample(task_name=task_name) + task = create_task(_timed_generate_group(agent_loop, rollout_state)) + pending_tasks.add(task) + + return ProducerTimings(generate_times_s=generate_times, pause_time_s=0.0) + + +class AsyncProduceStrategy(ProduceStrategy): + def __init__( + self, + over_sample_threshold: float, + enable_partial_rollout: bool, + tail_batch_trigger_size: int, + tail_batch_stale_threshold: int, + is_valid_sample_fn: IsValidSampleFn, + should_continue_fn: ShouldContinueFn, + ): + super().__init__(is_valid_sample_fn, should_continue_fn) + self.over_sample_threshold = over_sample_threshold + self.enable_partial_rollout = enable_partial_rollout + self.tail_batch_stale_threshold = tail_batch_stale_threshold + self.tail_batch_trigger_size = tail_batch_trigger_size + + async def _process_leftover_samples(self, replay_buffer: ReplayBuffer, task_name: str): + 
previously_completed_count = await replay_buffer.count(task_name=task_name, group_status=Status.COMPLETED) + if (not self.enable_partial_rollout or self.tail_batch_stale_threshold > 0) and previously_completed_count > 0: + previously_completed = await replay_buffer.get( + batch_size=previously_completed_count, + task_name=task_name, + group_status=Status.COMPLETED, + ) + for group in previously_completed: + for sample in group: + if self.tail_batch_stale_threshold > 0 and sample.seq_staleness >= self.tail_batch_stale_threshold: + sample.status = Status.EXPIRED + elif not self.enable_partial_rollout: + sample.status = Status.ABORTED + await replay_buffer.put(group, task_name) + + async def _cleanup_pending_tasks( + self, pending_tasks: set, agent_loop: AgentLoopSpec, replay_buffer: ReplayBuffer, task_name: str + ) -> float: + pause_start = time.perf_counter() + rollout_ctl = await get_agent_loop_rollout_ctl(agent_loop) + await pause_generation(rollout_ctl) + while len(pending_tasks) > 0: + done_task, pending_tasks = await asyncio.wait( + pending_tasks, timeout=1, return_when=asyncio.FIRST_COMPLETED + ) + + for task in done_task: + paused_items, _ = task.result() + paused_items = update_expired_status( + paused_items, tail_batch_stale_threshold=self.tail_batch_stale_threshold + ) + for item in paused_items: + logger.debug( + f"[{self.__class__.__name__}] Task {task_name} | Collecting aborted sample (uid: {item.uid}, status: {item.status}, length: {len(item.response_ids or [])}) after pausing generation." + ) + await replay_buffer.put(paused_items, task_name) + if len(pending_tasks) > 0: + await pause_generation(rollout_ctl) + await asyncio.sleep(1) + return time.perf_counter() - pause_start + + async def produce_batch( + self, + agent_loop: AgentLoopSpec, + sampler: Sampler, + replay_buffer: ReplayBuffer, + batch_size: int, + task_name: str, + rollout_step: int = 0, + ) -> ProducerTimings: + # 1. 
处理上一轮遗留的 completed 样本 + await self._process_leftover_samples(replay_buffer, task_name) + + # 2. 计算当前并发需求 + previously_completed_count = await replay_buffer.count(task_name=task_name, group_status=Status.COMPLETED) + data_concurrency = int((1 + self.over_sample_threshold) * batch_size) - previously_completed_count + expired_sample_count = await replay_buffer.count(task_name=task_name, group_status=Status.EXPIRED) + sample_from_expired_storage = False + + if self.tail_batch_trigger_size > 0 and expired_sample_count >= self.tail_batch_trigger_size: + logger.info( + f"Tail batch trigger condition met: {expired_sample_count} expired samples (threshold: {self.tail_batch_trigger_size}). Enabling tail batch mode." + ) + sample_from_expired_storage = True + data_concurrency = batch_size - previously_completed_count + + logger.info( + f"[{self.__class__.__name__}] Task {task_name} | Starting produce: data_concurrency: {data_concurrency}, previously_completed: {previously_completed_count}, expired_sample_count: {expired_sample_count}, rollout_step: {rollout_step}" + ) + + # 3. 初始下发任务 + pending_tasks = set() + generate_times: list[float] = [] + for _ in range(data_concurrency): + if sample_from_expired_storage and expired_sample_count > 0: + group_status = Status.EXPIRED + expired_sample_count -= 1 + else: + group_status = Status.ABORTED + rollout_state = await sampler.sample(task_name=task_name, group_status=group_status) + task = create_task( + _timed_generate_group( + agent_loop, + rollout_state, + enable_partial_rollout=self.enable_partial_rollout, + rollout_step=rollout_step, + ) + ) + pending_tasks.add(task) + + # 4. 
循环收集样本 + completed_sample_count = previously_completed_count + while self.should_continue_fn(completed_sample_count, batch_size): + if not pending_tasks: + print("All tasks are done but not enough samples collected.") + break + done_tasks, pending_tasks = await asyncio.wait( + pending_tasks, timeout=1, return_when=asyncio.FIRST_COMPLETED + ) + + for task in done_tasks: + running_items, elapsed = task.result() + generate_times.append(elapsed) + if self.is_valid_sample_fn(running_items): + completed_sample_count += 1 + running_items = update_expired_status( + running_items, tail_batch_stale_threshold=self.tail_batch_stale_threshold + ) + await replay_buffer.put(running_items, task_name) + logger.debug( + f"[{self.__class__.__name__}] Task {task_name} | Collected {completed_sample_count}/{batch_size} valid samples." + ) + + # 动态补充任务 + while len( + pending_tasks + ) + completed_sample_count < data_concurrency + previously_completed_count and self.should_continue_fn( + completed_sample_count, batch_size + ): + if sample_from_expired_storage and expired_sample_count > 0: + group_status = Status.EXPIRED + expired_sample_count -= 1 + else: + group_status = Status.ABORTED + rollout_state = await sampler.sample(task_name=task_name, group_status=group_status) + task = create_task( + _timed_generate_group( + agent_loop, + rollout_state, + enable_partial_rollout=self.enable_partial_rollout, + rollout_step=rollout_step, + ) + ) + pending_tasks.add(task) + + # 5. 
清理正在执行的任务 + pause_time_s = 0.0 + if len(pending_tasks) > 0: + pause_time_s = await self._cleanup_pending_tasks(pending_tasks, agent_loop, replay_buffer, task_name) + + return ProducerTimings(generate_times_s=generate_times, pause_time_s=pause_time_s) diff --git a/xtuner/v1/rl/agent_loop/sampler.py b/xtuner/v1/rl/agent_loop/sampler.py new file mode 100644 index 0000000000..181fc320df --- /dev/null +++ b/xtuner/v1/rl/agent_loop/sampler.py @@ -0,0 +1,121 @@ +import copy +from pathlib import Path +from typing import Iterator, Optional, cast +from uuid import uuid4 + +import ray +import torch +from pydantic import BaseModel, ConfigDict + +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast +from xtuner.v1.data_proto.rl_data import RolloutState, Status +from xtuner.v1.datasets.config import DataloaderConfig +from xtuner.v1.datasets.dataloader import Dataloader +from xtuner.v1.rl.replay_buffer import ReplayBuffer +from xtuner.v1.utils.logger import get_logger + + +logger = get_logger(__name__) + + +class SamplerConfig(BaseModel): + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + dataloader_cfg: DataloaderConfig + prompt_repeat_k: int = 1 + + def build( + self, tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast | str, replay_buffer: ReplayBuffer + ) -> "Sampler": + if isinstance(tokenizer, str): + tokenizer_obj = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True) + else: + tokenizer_obj = tokenizer + dataloader = self.dataloader_cfg.build( + tokenizer=tokenizer_obj, dp_mesh=None, global_batch_size=1, micro_batch_size=1, seed=1 + ) + return Sampler(dataloader=dataloader, prompt_repeat_k=self.prompt_repeat_k, replay_buffer=replay_buffer) + + +# TODO: The best solution is to put it in the fake_collator, +# but it will cause a deadlock problem, so it is temporarily placed here. +# The best solution should be to start the dataloader using spawn. 
+def put_to_ray(data: RolloutState) -> RolloutState: + if hasattr(data, "mm_info") and data.mm_info is not None: + pixel_values = data.mm_info.get("pixel_values", None) + if pixel_values is not None: + data.mm_info["pixel_values"] = ray.put(pixel_values) + return data + + +class _DatasetSampler: + def __init__(self, dataloader: Dataloader, prompt_repeat_k: int): + self.dataloader = dataloader + self.dataloader_iter: Optional[Iterator] = None + self.cur_epoch = 0 + self.prompt_repeat_k = prompt_repeat_k + self._consumed_samples: int = 0 + + def __len__(self) -> int: + return len(self.dataloader) + + def sample_from_dataloader(self) -> list[RolloutState]: + if self.dataloader_iter is None: + self.dataloader_iter = iter(self.dataloader) + assert self.dataloader_iter is not None + try: + data = next(self.dataloader_iter)[0] + data = put_to_ray(data) + + except StopIteration: + self.cur_epoch += 1 + self.dataloader.set_epoch(self.cur_epoch) + self.dataloader_iter = iter(self.dataloader) + data = next(self.dataloader_iter)[0] + data = put_to_ray(data) + + group_data = [] + for _ in range(self.prompt_repeat_k): + new_data = copy.deepcopy(data) + new_data.uid = uuid4().int + group_data.append(new_data) + self._consumed_samples += 1 + return cast(list[RolloutState], group_data) + + +class Sampler(_DatasetSampler): + _DATALOADER_FILE = "dataloader" + + def __init__( + self, + dataloader: Dataloader, + prompt_repeat_k: int, + replay_buffer: ReplayBuffer, + ): + super().__init__(dataloader, prompt_repeat_k) + self.replay_buffer = replay_buffer + + async def sample(self, task_name: str, group_status: Status | None = None) -> list[RolloutState]: + if group_status is not None: + buffer_data = await self.replay_buffer.get(1, task_name=task_name, group_status=group_status) + if buffer_data: + return buffer_data[0] + return self.sample_from_dataloader() + + def save(self, checkpoint_path: Path | str) -> None: + """Save the sampler's dataloader state to checkpoint.""" + 
checkpoint_path = Path(checkpoint_path) + dataloader_state = self.dataloader.get_state_dict(self._consumed_samples) + torch.save(dataloader_state, checkpoint_path / self._DATALOADER_FILE) + + def resume(self, checkpoint_path: Path | str) -> None: + """Resume the sampler's dataloader state from checkpoint.""" + checkpoint_path = Path(checkpoint_path) + dataloader_path = checkpoint_path / self._DATALOADER_FILE + if not dataloader_path.exists(): + logger.warning(f"Dataloader state {dataloader_path} not found, skipping resume.") + return + state = torch.load(dataloader_path, map_location="cpu") + self.dataloader.load_state_dict(state) + self.dataloader_iter = iter(self.dataloader) + self._consumed_samples = state["sampler"]["step"] + self.cur_epoch = state["sampler"]["epoch"] diff --git a/xtuner/v1/rl/agent_loop/single_turn_agent_loop.py b/xtuner/v1/rl/agent_loop/single_turn_agent_loop.py new file mode 100644 index 0000000000..db24c35f91 --- /dev/null +++ b/xtuner/v1/rl/agent_loop/single_turn_agent_loop.py @@ -0,0 +1,51 @@ +from xtuner.v1.data_proto import RolloutState, SampleParams, Status +from xtuner.v1.rl.judger import Judger +from xtuner.v1.rl.rollout import RolloutController + +from .agent_loop import AgentLoop, AgentLoopConfig +from .utils import PartialRolloutHandler + + +class SingleTurnAgentLoopConfig(AgentLoopConfig): + def build_local(self, rollout_controller, judger: Judger | None = None, logger=None) -> "SingleTurnAgentLoop": + return SingleTurnAgentLoop( + rollout_ctl=rollout_controller, + sample_params=self.sample_params, + hf_checkpoint=self.hf_checkpoint, + judger=judger, + logger=logger, + ) + + +class SingleTurnAgentLoop(AgentLoop): + def __init__( + self, + rollout_ctl: RolloutController, + sample_params: SampleParams, + hf_checkpoint: str, + judger: Judger | None = None, + logger=None, + ): + super().__init__(rollout_ctl, sample_params, hf_checkpoint, judger, logger) + self.max_tokens = self.sample_params.max_tokens + self.partial_rollout_handler 
= PartialRolloutHandler(max_tokens=self.max_tokens) + + async def generate_sample(self, rollout_state: RolloutState, **kwargs) -> RolloutState: + enable_partial_rollout = kwargs.get("enable_partial_rollout", False) + rollout_step = kwargs.get("rollout_step", 0) + + # rollout state 预处理, enable_partial_rollout = True 会在这里拼接 token 和修正 max_token + rollout_state = self.partial_rollout_handler.preprocess(rollout_state, enable_partial_rollout) + if not rollout_state.tokens: + rollout_state.tokens = rollout_state.prompt_ids + + # 推理引擎generate, 生成的结果会覆盖到 rollout_state.response_ids 上 + rollout_state = await self.rollout_ctl.generate.remote(rollout_state) # type: ignore[attr-defined] + # rollout state 后处理: 合并 partial rollout 的历史上下文, 更新 seq_staleness + rollout_state = self.partial_rollout_handler.postprocess(rollout_state, rollout_step) + # 非 COMPLETED 状态(如被截断、放弃等)直接早退,不触发打分 + if rollout_state.status != Status.COMPLETED: + return rollout_state + if self.judger is not None: + rollout_state = await self.judger.judge(rollout_state) + return rollout_state diff --git a/xtuner/v1/rl/agent_loop/utils.py b/xtuner/v1/rl/agent_loop/utils.py new file mode 100644 index 0000000000..9e5dbfe8a2 --- /dev/null +++ b/xtuner/v1/rl/agent_loop/utils.py @@ -0,0 +1,116 @@ +import time + +import ray + +from xtuner.v1.data_proto import RolloutState, Status, update_seq_staleness +from xtuner.v1.utils import get_logger + + +logger = get_logger() + + +def _resolve_routed_experts(routed_experts: list[int] | ray.ObjectRef) -> list[int]: + if isinstance(routed_experts, ray.ObjectRef): + routed_experts = ray.get(routed_experts) + if hasattr(routed_experts, "tolist"): + routed_experts = routed_experts.tolist() + assert isinstance(routed_experts, list), f"Unexpected routed_experts type: {type(routed_experts)}" + return routed_experts + + +class PartialRolloutHandler: + """Handle preprocessing and postprocessing for partial rollout + continuation.""" + + def __init__(self, max_tokens: int) -> None: + 
self.max_tokens = max_tokens + + def preprocess(self, rollout_state: RolloutState, enable_partial_rollout: bool = False) -> RolloutState: + # for partial rollout + if not enable_partial_rollout or not rollout_state.response_ids or rollout_state.status == Status.COMPLETED: + return rollout_state + + # If status is EXPIRED, reset tokens, sample_params and responses for fresh generation + if rollout_state.status == Status.EXPIRED: + rollout_state.tokens = rollout_state.prompt_ids + rollout_state.sample_params = rollout_state.sample_params.copy(update={"max_tokens": self.max_tokens}) + rollout_state.response_ids = [] + rollout_state.response = "" + rollout_state.logprobs = [] + rollout_state.response_mask = [] + rollout_state.response_rollout_steps = [] + return rollout_state + + # Set up token and length variable + response_ids = rollout_state.response_ids + prompt_ids = list(rollout_state.prompt_ids or []) + response_len = len(response_ids) + prompt_len = len(prompt_ids) + + rollout_state.tokens = prompt_ids + response_ids # concatenate for partial rollout continuation + remaining_tokens = self.max_tokens - response_len # compute remaining max_tokens budget + rollout_state.sample_params = rollout_state.sample_params.copy(update={"max_tokens": remaining_tokens}) + + logger.debug( + f"[PartialRolloutHandler] Sample {rollout_state.uid} continue rollout | Remaining tokens allowed: {remaining_tokens} | Status: {rollout_state.status} | Prompt len: {prompt_len} | Response len: {response_len} | Staleness: {rollout_state.seq_staleness} | Total tokens: {len(rollout_state.tokens)}" + ) + # TODO: handle routed_experts + rollout_state.extra_fields["history_response_dict"] = { + "response_ids": rollout_state.tokens[prompt_len:] if rollout_state.tokens else [], + "response": rollout_state.response or "", + "logprobs": rollout_state.logprobs or [], + "response_mask": rollout_state.response_mask or [], + "routed_experts": rollout_state.routed_experts, + } + return rollout_state + + 
def postprocess(self, rollout_state: RolloutState, rollout_step: int) -> RolloutState: + # Update seq_staleness + rollout_state = update_seq_staleness(rollout_state, rollout_step) + + # Concatenate history response fields + history_dict = rollout_state.extra_fields.pop("history_response_dict", None) + if not history_dict: + return rollout_state + + rollout_state.response_ids = history_dict.get("response_ids", []) + (rollout_state.response_ids or []) + rollout_state.response = history_dict.get("response", "") + (rollout_state.response or "") + rollout_state.logprobs = history_dict.get("logprobs", []) + (rollout_state.logprobs or []) + rollout_state.response_mask = history_dict.get("response_mask", []) + (rollout_state.response_mask or []) + history_routed_experts_ref = history_dict.get("routed_experts") + cur_routed_experts_ref = rollout_state.routed_experts + if history_routed_experts_ref is not None and cur_routed_experts_ref is not None: + start_time = time.time() + history_routed_experts = _resolve_routed_experts(history_routed_experts_ref) + cur_routed_experts = _resolve_routed_experts(cur_routed_experts_ref) + cur_routed_experts_len = len(cur_routed_experts) + history_routed_experts_len = len(history_routed_experts) + assert history_routed_experts_len - 1 <= cur_routed_experts_len, ( + f"Existing routed_experts len: {history_routed_experts_len}, current routed_experts len: {cur_routed_experts_len}" + ) + cur_routed_experts = cur_routed_experts[history_routed_experts_len:] + concat_routed_experts = history_routed_experts + cur_routed_experts + + prompt_ids = rollout_state.prompt_ids or [] + response_ids = rollout_state.response_ids or [] + expect_tokens_num = len(prompt_ids) + len(response_ids) - 1 + assert len(concat_routed_experts) == expect_tokens_num, ( + f"After concatenation, routed_experts len: {len(concat_routed_experts)}, expected tokens num: {expect_tokens_num}" + ) + logger.info( + f"[PartialRolloutHandler] Postprocess rollout {rollout_state.uid}: " 
+ f"concat routed_experts len={len(concat_routed_experts)} " + f"(history={history_routed_experts_len}, new={cur_routed_experts_len}), " + f"prompt={len(prompt_ids)}, response={len(response_ids)}" + ) + rollout_state.routed_experts = ray.put(concat_routed_experts) + end_time = time.time() + logger.info( + f"[PartialRolloutHandler] Postprocess routed_experts concatenation time: {end_time - start_time:.4f} seconds" + ) + elif history_routed_experts_ref is None and cur_routed_experts_ref is not None: + rollout_state.routed_experts = cur_routed_experts_ref + elif history_routed_experts_ref is not None and cur_routed_experts_ref is None: + rollout_state.routed_experts = history_routed_experts_ref + + return rollout_state diff --git a/xtuner/v1/rl/base.py b/xtuner/v1/rl/base.py deleted file mode 100644 index a299434804..0000000000 --- a/xtuner/v1/rl/base.py +++ /dev/null @@ -1,12 +0,0 @@ -from typing import Literal - -from cyclopts import Parameter -from pydantic import BaseModel, Field -from typing_extensions import Annotated - - -class BaseTrainerConfig(BaseModel): - type: Annotated[ - Literal["xtuner", "lmdeploy", "sglang", "vllm"], - Parameter(group="Worker Types", description="Type of the worker."), - ] = Field(..., discriminator="type") diff --git a/xtuner/v1/rl/base/__init__.py b/xtuner/v1/rl/base/__init__.py deleted file mode 100644 index 7141d58260..0000000000 --- a/xtuner/v1/rl/base/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -from .controller import TrainingController, TrainingControllerProxy -from .loss import BaseRLLossConfig, BaseRLLossContext, BaseRLLossKwargs, compute_kl_loss_weight -from .worker import TrainingWorker, TrainingWorkerClass, TrainingWorkerProxy, WorkerConfig, WorkerLogItem - - -__all__ = [ - "TrainingController", - "TrainingControllerProxy", - "TrainingWorkerClass", - "TrainingWorkerProxy", - "TrainingWorker", - "WorkerConfig", - "BaseRLLossConfig", - "BaseRLLossKwargs", - "BaseRLLossContext", - "compute_kl_loss_weight", - "WorkerLogItem", 
-] diff --git a/xtuner/v1/rl/config/__init__.py b/xtuner/v1/rl/config/__init__.py deleted file mode 100644 index 8313e21fc7..0000000000 --- a/xtuner/v1/rl/config/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .trainer import GRPOTrainerConfig diff --git a/xtuner/v1/rl/config/loss.py b/xtuner/v1/rl/config/loss.py deleted file mode 100644 index 1cf4af1230..0000000000 --- a/xtuner/v1/rl/config/loss.py +++ /dev/null @@ -1,15 +0,0 @@ -from typing import Literal - -from cyclopts import Parameter -from pydantic import BaseModel, ConfigDict -from typing_extensions import Annotated - - -class BaseLossConfig(BaseModel): - """Base configuration for loss function.""" - - model_config = ConfigDict(extra="forbid") - type: Annotated[ - Literal["grpo", "ppo"], - Parameter(group="Loss Types", help="Type of the loss function."), - ] diff --git a/xtuner/v1/rl/config/trainer.py b/xtuner/v1/rl/config/trainer.py deleted file mode 100644 index 058dacf0a0..0000000000 --- a/xtuner/v1/rl/config/trainer.py +++ /dev/null @@ -1,64 +0,0 @@ -from typing import Optional - -from cyclopts import Group, Parameter -from pydantic import BaseModel, ConfigDict, Field -from typing_extensions import Annotated - -from xtuner.v1.engine.config import EngineConfig -from xtuner.v1.ray.base import AcceleratorResourcesConfig -from xtuner.v1.ray.config.worker import RolloutConfig - - -grpo_group = Group("GRPO", sort_key=1, help="GRPO Trainer Configuration") - -actor_worker_group = Group("Actor Workers", sort_key=90, help="Configuration for the rollout worker.") -actor_resources_group = Group("Actor Resources", sort_key=90, help="Configuration for the actor resources.") -rollout_worker_group = Group("Rollout Workers", sort_key=90, help="Configuration for the rollout worker.") -rollout_resources_group = Group("Rollout Resources", sort_key=90, help="Configuration for the rollout resources.") - - -class GRPOTrainerConfig(BaseModel): - """Configuration for the GRPO Ray Trainer.""" - - model_config = 
ConfigDict(extra="forbid") - actor: Annotated[ - EngineConfig, - Parameter(group=actor_worker_group, help="Configuration for the rollout worker."), - ] - - critic: Annotated[ - EngineConfig, - Parameter(group=actor_worker_group, help="Configuration for the rollout worker."), - ] - - actor_resources: Annotated[ - AcceleratorResourcesConfig, Parameter(group=actor_resources_group, help="Resources allocated for the actor.") - ] - - rollout: Annotated[ - RolloutConfig, - Parameter(group=rollout_worker_group, help="Configuration for the rollout worker."), - # Discriminator('type') - ] - rollout_resources: Annotated[ - Optional[AcceleratorResourcesConfig], - Parameter(group=rollout_resources_group, help="Resources allocated for the rollout."), - ] = None - - enrionment: Annotated[str, Parameter(group=grpo_group, help="Environment for the GRPO training.")] = "default" - - global_batch_size: Annotated[int, Parameter(group=grpo_group, help="Batch size for training.")] = Field( - 32, help="Batch size for training." - ) - - micro_batch_size: Annotated[int, Parameter(group=grpo_group, help="Micro batch size for training.")] = Field( - 8, help="Micro batch size for training." - ) - - num_mini_batches: Annotated[int, Parameter(group=grpo_group, help="Number of mini-batches for training.")] = Field( - 4, help="Number of mini-batches for training." - ) - - total_steps: Annotated[int, Parameter(group=grpo_group, help="Total number of training steps.")] = Field( - 100000, help="Total number of training steps." 
- ) diff --git a/xtuner/v1/rl/evaluator.py b/xtuner/v1/rl/evaluator.py new file mode 100644 index 0000000000..870a993030 --- /dev/null +++ b/xtuner/v1/rl/evaluator.py @@ -0,0 +1,80 @@ +from collections.abc import Mapping +from typing import Annotated, Protocol, cast, runtime_checkable + +from cyclopts import Parameter +from pydantic import BaseModel, ConfigDict, Field + +from xtuner.v1.data_proto import RolloutState + + +@runtime_checkable +class ComputeMetricProtocol(Protocol): + def __call__(self, samples: list[RolloutState]) -> dict[str, float]: ... + + +def default_compute_metric_func(samples: list[RolloutState]) -> dict[str, float]: + if not samples: + return {"accuracy": 0.0} + + positive = 0 + for s in samples: + reward = s.reward + assert isinstance(reward, Mapping) + score = reward["score"] + if score > 0: + positive += 1 + return {"accuracy": positive / len(samples)} + + +class Evaluator: + def __init__( + self, + compute_metric_func: ComputeMetricProtocol | None = None, + eval_batch_size: int = 0, + ): + self.compute_metric_func = compute_metric_func or default_compute_metric_func + self.eval_batch_size = eval_batch_size + + def run(self, samples: list[RolloutState] | list[list[RolloutState]]) -> dict[str, float]: + # 将 list[list[RolloutState]] 转换为 list[RolloutState] + if samples and isinstance(samples[0], list): + flat_samples = [sample for batch in cast(list[list[RolloutState]], samples) for sample in batch] + else: + flat_samples = cast(list[RolloutState], samples) + return self.compute_metric_func(flat_samples) + + +class EvaluatorConfig(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") + + eval_sample_ratio: Annotated[ + float, + Parameter(help="Ratio of samples to evaluate from the generated samples."), + ] = 0 + eval_sample_num: Annotated[ + int, + Parameter(help="Number of samples to evaluate from the generated samples."), + ] = 0 + + compute_metric_func: Annotated[ + ComputeMetricProtocol | None, + 
Field(exclude=True), + Parameter(help="An optional metric computation function."), + ] = None + + def build(self, total_eval_samples: int = 0) -> "Evaluator": + if self.eval_sample_num > 0: + eval_batch_size = self.eval_sample_num + else: + assert total_eval_samples > 0, ( + "Total eval samples must be greater than 0 if eval sample num is not provided" + ) + if self.eval_sample_ratio > 0: + eval_batch_size = int(total_eval_samples * self.eval_sample_ratio) + else: + eval_batch_size = total_eval_samples + + return Evaluator( + compute_metric_func=self.compute_metric_func, + eval_batch_size=eval_batch_size, + ) diff --git a/xtuner/v1/rl/gateway/__init__.py b/xtuner/v1/rl/gateway/__init__.py new file mode 100644 index 0000000000..26779e3a9e --- /dev/null +++ b/xtuner/v1/rl/gateway/__init__.py @@ -0,0 +1,13 @@ +from .backend.local_backend import LocalRolloutBackend +from .config import GatewayConfig +from .server import build_gateway_app, build_local_gateway_app, serve_gateway, serve_gateway_in_thread + + +__all__ = [ + "GatewayConfig", + "LocalRolloutBackend", + "build_gateway_app", + "build_local_gateway_app", + "serve_gateway", + "serve_gateway_in_thread", +] diff --git a/xtuner/v1/rl/gateway/adapters/__init__.py b/xtuner/v1/rl/gateway/adapters/__init__.py new file mode 100644 index 0000000000..4c68e142e8 --- /dev/null +++ b/xtuner/v1/rl/gateway/adapters/__init__.py @@ -0,0 +1,43 @@ +from .anthropic import ( + AnthropicChatAdapter, + AnthropicChatAdapterError, + AnthropicCountTokensRequest, + AnthropicCountTokensResponse, + AnthropicMessagesRequest, + AnthropicMessagesResponse, +) +from .base import BaseChatAPIAdapter +from .openai import ( + ChatCompletionRequest, + ChatCompletionResponse, + OpenAIChatAdapter, + OpenAIChatAdapterError, +) +from .responses import ResponsesRequest, ResponsesResponse +from .trace import ( + DEFAULT_CHAT_TRACE_KEY, + ChatTraceRecord, + ChatTraceStore, + build_api_key_trace_key, +) + + +__all__ = [ + "AnthropicChatAdapter", + 
"AnthropicChatAdapterError", + "AnthropicCountTokensRequest", + "AnthropicCountTokensResponse", + "AnthropicMessagesRequest", + "AnthropicMessagesResponse", + "ChatCompletionRequest", + "ChatCompletionResponse", + "OpenAIChatAdapter", + "OpenAIChatAdapterError", + "ResponsesRequest", + "ResponsesResponse", + "BaseChatAPIAdapter", + "DEFAULT_CHAT_TRACE_KEY", + "ChatTraceRecord", + "ChatTraceStore", + "build_api_key_trace_key", +] diff --git a/xtuner/v1/rl/gateway/adapters/anthropic.py b/xtuner/v1/rl/gateway/adapters/anthropic.py new file mode 100644 index 0000000000..aa478543e1 --- /dev/null +++ b/xtuner/v1/rl/gateway/adapters/anthropic.py @@ -0,0 +1,598 @@ +import json +from collections.abc import AsyncIterator +from typing import Any, Literal +from uuid import uuid4 + +from fastapi.responses import StreamingResponse +from pydantic import BaseModel, ConfigDict + +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast + +from ..core.models import ( + CanonicalGenerateRequest, + CanonicalGenerateResponse, + CanonicalMessage, + CanonicalReasoning, + CanonicalReasoningBlock, + CanonicalReasoningStep, + CanonicalTextBlock, + CanonicalToolCall, + CanonicalToolCallBlock, + CanonicalToolChoice, + CanonicalToolDefinition, + CanonicalToolResult, + CanonicalToolResultBlock, +) +from .base import BaseChatAPIAdapter +from .streaming import build_sse_response, encode_sse_event +from .trace import ChatTraceStore, normalize_trace_payload + + +class AnthropicTextContent(BaseModel): + model_config = ConfigDict(extra="allow") + + type: str = "text" + text: str + + +AnthropicContentBlock = dict[str, Any] + + +class AnthropicMessage(BaseModel): + model_config = ConfigDict(extra="allow") + + role: Literal["user", "assistant"] + content: str | list[AnthropicContentBlock] + + +class AnthropicMessagesRequest(BaseModel): + model_config = ConfigDict(extra="allow") + + session_uid: int | None = None + model: str | None = None + system: str | list[dict[str, Any]] 
| None = None + messages: list[AnthropicMessage] + max_tokens: int + stream: bool = False + temperature: float | None = None + top_p: float | None = None + stop_sequences: list[str] | None = None + tools: list[dict[str, Any]] | None = None + tool_choice: str | dict[str, Any] | None = None + + +class AnthropicCountTokensRequest(BaseModel): + model_config = ConfigDict(extra="allow") + + model: str | None = None + system: str | list[dict[str, Any]] | None = None + messages: list[AnthropicMessage] + tools: list[dict[str, Any]] | None = None + + +class AnthropicCountTokensResponse(BaseModel): + model_config = ConfigDict(extra="allow") + + input_tokens: int + + +class AnthropicUsage(BaseModel): + model_config = ConfigDict(extra="allow") + + input_tokens: int + output_tokens: int + + +class AnthropicMessagesResponse(BaseModel): + model_config = ConfigDict(extra="allow") + + id: str + type: Literal["message"] = "message" + role: Literal["assistant"] = "assistant" + content: list[dict[str, Any]] + model: str + stop_reason: str | None = None + stop_sequence: str | None = None + usage: AnthropicUsage + + +class AnthropicChatAdapterError(RuntimeError): + def __init__(self, message: str, error_type: str, request_id: str | None = None): + super().__init__(message) + self.message = message + self.error_type = error_type + self.request_id = request_id + + +class AnthropicChatAdapter(BaseChatAPIAdapter[AnthropicMessagesRequest, AnthropicMessagesResponse]): + def __init__( + self, + generate_handler, + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast | str | None, + default_model_name: str | None = None, + context_length: int | None = None, + capture_folder: str | None = None, + trace_store: ChatTraceStore | None = None, + ): + if isinstance(tokenizer, str): + tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True) + super().__init__(generate_handler, tokenizer=tokenizer, capture_folder=capture_folder, trace_store=trace_store) + self._default_model_name 
= default_model_name + self._context_length = context_length + + async def messages( + self, + request: AnthropicMessagesRequest, + *, + api_key: str | None = None, + ) -> AnthropicMessagesResponse | StreamingResponse: + if request.stream: + response = await self.handle_request(request, api_key=api_key) + return build_sse_response(self.iter_stream_events(response)) + return await self.handle_request(request, api_key=api_key) + + async def count_tokens(self, request: AnthropicCountTokensRequest) -> AnthropicCountTokensResponse: + internal_messages = self._build_internal_messages(request) + tokenizer_tools = self._normalize_tools_for_backend(request.tools) + if self._tokenizer is None: + return AnthropicCountTokensResponse(input_tokens=0) + raw_prompt_ids = self._tokenizer.apply_chat_template( + internal_messages, + tools=tokenizer_tools, + tokenize=True, + add_generation_prompt=True, + ) + prompt_ids = raw_prompt_ids.get("input_ids") if hasattr(raw_prompt_ids, "get") else list(raw_prompt_ids) + return AnthropicCountTokensResponse(input_tokens=len(prompt_ids)) + + def validate_request(self, request: AnthropicMessagesRequest) -> None: + return None + + def request_to_canonical_request(self, request: AnthropicMessagesRequest) -> CanonicalGenerateRequest: + messages: list[CanonicalMessage] = [] + if request.system: + messages.append(self._anthropic_system_to_canonical_message(request.system)) + messages.extend(self._anthropic_messages_to_canonical_messages(request.messages)) + return CanonicalGenerateRequest( + request_id=f"anthropic_req_{uuid4().hex}", + model=request.model or self._default_model_name or "rollout-controller", + messages=messages, + tools=self._anthropic_tools_to_canonical(request.tools), + tool_choice=self._anthropic_tool_choice_to_canonical(request.tool_choice), + temperature=request.temperature, + top_p=request.top_p, + max_tokens=request.max_tokens, + stop=list(request.stop_sequences or []), + stream=False, + metadata={ + key: value + for key, value 
in { + "source_protocol": "anthropic_messages", + "client_stream": bool(request.stream), + "session_uid": request.session_uid, + }.items() + if value is not None + }, + ) + + def normalize_request(self, request: AnthropicMessagesRequest) -> dict[str, Any]: + return normalize_trace_payload(request.model_dump(mode="python", exclude_none=True)) + + def normalize_response(self, response: AnthropicMessagesResponse) -> dict[str, Any]: + return normalize_trace_payload(response.model_dump(mode="python", exclude_none=True)) + + async def iter_stream_events( + self, + response: AnthropicMessagesResponse, + ) -> AsyncIterator[str]: + yield encode_sse_event( + { + "type": "message_start", + "message": { + "id": response.id, + "type": response.type, + "role": response.role, + "content": [], + "model": response.model, + "stop_reason": None, + "stop_sequence": None, + "usage": { + "input_tokens": response.usage.input_tokens, + "output_tokens": 0, + }, + }, + }, + event="message_start", + ) + + for index, block in enumerate(response.content): + block_type = block.get("type") + start_block: dict[str, Any] + delta: dict[str, Any] + if block_type == "reasoning": + start_block = {"type": "thinking", "thinking": ""} + delta = {"type": "thinking_delta", "thinking": str(block.get("text", ""))} + elif block_type == "tool_use": + start_block = { + "type": "tool_use", + "id": block.get("id"), + "name": block.get("name"), + "input": {}, + } + delta = { + "type": "input_json_delta", + "partial_json": json.dumps(block.get("input", {}), ensure_ascii=False), + } + else: + start_block = {"type": "text", "text": ""} + delta = {"type": "text_delta", "text": str(block.get("text", ""))} + + yield encode_sse_event( + { + "type": "content_block_start", + "index": index, + "content_block": start_block, + }, + event="content_block_start", + ) + yield encode_sse_event( + { + "type": "content_block_delta", + "index": index, + "delta": delta, + }, + event="content_block_delta", + ) + yield encode_sse_event( 
+ { + "type": "content_block_stop", + "index": index, + }, + event="content_block_stop", + ) + + yield encode_sse_event( + { + "type": "message_delta", + "delta": { + "stop_reason": self._stream_stop_reason(response.stop_reason), + "stop_sequence": response.stop_sequence, + }, + "usage": { + "output_tokens": response.usage.output_tokens, + }, + }, + event="message_delta", + ) + yield encode_sse_event({"type": "message_stop"}, event="message_stop") + + def canonical_response_to_protocol_response( + self, + canonical_response: CanonicalGenerateResponse, + request: AnthropicMessagesRequest, + ) -> AnthropicMessagesResponse: + content = self._canonical_response_to_anthropic_blocks(canonical_response) + stop_reason = canonical_response.finish_reason or "stop" + if any(block.get("type") == "tool_use" for block in content): + stop_reason = "tool_use" + return AnthropicMessagesResponse( + id=f"msg_{canonical_response.request_id}", + content=content, + model=canonical_response.model or self._default_model_name or "rollout-controller", + stop_reason=stop_reason, + usage=AnthropicUsage( + input_tokens=canonical_response.usage.prompt_tokens, + output_tokens=canonical_response.usage.completion_tokens, + ), + ) + + def _build_internal_messages(self, request: AnthropicCountTokensRequest) -> list[dict[str, Any]]: + messages: list[dict[str, Any]] = [] + if request.system: + if isinstance(request.system, str): + system_text = request.system + else: + system_text = self._join_text_blocks(request.system, context="system") + messages.append({"role": "system", "content": system_text}) + + for message in request.messages: + if isinstance(message.content, str): + messages.append({"role": message.role, "content": message.content}) + else: + messages.extend(self._convert_content_blocks_to_backend_messages(message.role, message.content)) + return messages + + def _join_text_blocks(self, blocks: list[dict[str, Any]], context: str) -> str: + unsupported_types = [str(block.get("type")) for 
block in blocks if block.get("type") != "text"] + if unsupported_types: + unsupported_str = ", ".join(sorted(set(unsupported_types))) + raise AnthropicChatAdapterError( + f"Unsupported Anthropic content block type(s) in {context}: {unsupported_str}", + "invalid_request_error", + ) + return "\n".join(str(block.get("text", "")) for block in blocks) + + def _convert_content_blocks_to_backend_messages( + self, + role: str, + blocks: list[dict[str, Any]], + ) -> list[dict[str, Any]]: + backend_messages: list[dict[str, Any]] = [] + text_chunks: list[str] = [] + tool_calls: list[dict[str, Any]] = [] + + def flush_text_chunks() -> None: + if text_chunks: + backend_messages.append({"role": role, "content": "\n".join(text_chunks)}) + text_chunks.clear() + + for block in blocks: + block_type = block.get("type") + if block_type == "text": + text_value = str(block.get("text", "")) + if role == "assistant": + text_value = self._sanitize_assistant_text(text_value) + text_chunks.append(text_value) + elif block_type == "tool_use": + tool_calls.append( + { + "id": block.get("id") or f"toolu_{uuid4().hex}", + "type": "function", + "function": { + "name": str(block.get("name", "")), + "arguments": normalize_trace_payload(block.get("input", {})), + }, + } + ) + elif block_type == "tool_result": + flush_text_chunks() + backend_messages.append( + { + "role": "tool", + "content": self._serialize_tool_result_content(block.get("content")), + "tool_call_id": block.get("tool_use_id"), + } + ) + else: + raise AnthropicChatAdapterError( + f"Unsupported Anthropic content block type in messages[{role}]: {block_type}", + "invalid_request_error", + ) + + if tool_calls: + backend_messages.append( + { + "role": role, + "content": "\n".join(text_chunks), + "tool_calls": tool_calls, + } + ) + text_chunks.clear() + flush_text_chunks() + return backend_messages + + def _serialize_tool_result_content(self, content: Any) -> str: + if content is None: + return "" + if isinstance(content, str): + return 
+            content
+        if isinstance(content, list):
+            # A pure list of text blocks collapses to newline-joined text;
+            # anything else is preserved verbatim as JSON.
+            if all(isinstance(item, dict) and item.get("type") == "text" for item in content):
+                return "\n".join(str(item.get("text", "")) for item in content)
+            return json.dumps(content, ensure_ascii=False)
+        if isinstance(content, dict):
+            return json.dumps(content, ensure_ascii=False)
+        return str(content)
+
+    def _normalize_tools_for_backend(self, tools: list[dict[str, Any]] | None) -> list[dict[str, Any]] | None:
+        """Convert Anthropic-native tool specs into OpenAI function-tool dicts for the backend."""
+        if not tools:
+            return None
+        normalized_tools = []
+        for tool in tools:
+            if tool.get("type") == "function":
+                # Already OpenAI-shaped; only normalize the payload.
+                normalized_tools.append(normalize_trace_payload(tool))
+            else:
+                normalized_tools.append(
+                    {
+                        "type": "function",
+                        "function": {
+                            "name": tool["name"],
+                            "description": tool.get("description", ""),
+                            "parameters": tool["input_schema"],
+                        },
+                    }
+                )
+        return normalize_trace_payload(normalized_tools)
+
+    def _sanitize_assistant_text(self, text: str) -> str:
+        """Strip chat-template control tokens from raw assistant text."""
+        # NOTE(review): the two .replace("", "") calls below are no-ops as
+        # written -- the marker strings appear to have been lost in transit;
+        # confirm the intended tokens (likely angle-bracketed think markers).
+        cleaned = text.replace("<|im_end|>", "")
+        cleaned = cleaned.replace("", "")
+        cleaned = cleaned.replace("", "")
+        return cleaned.strip()
+
+    def _anthropic_system_to_canonical_message(
+        self,
+        system: str | list[dict[str, Any]],
+    ) -> CanonicalMessage:
+        """Map the Anthropic ``system`` field (string or text blocks) to a canonical system message.
+
+        Raises AnthropicChatAdapterError for non-text system blocks.
+        """
+        if isinstance(system, str):
+            content = [CanonicalTextBlock(text=system)] if system else []
+        else:
+            content = []
+            for block in system:
+                if block.get("type") != "text":
+                    raise AnthropicChatAdapterError(
+                        f"Unsupported Anthropic content block type(s) in system: {block.get('type')}",
+                        "invalid_request_error",
+                    )
+                text = str(block.get("text", ""))
+                if text:
+                    content.append(CanonicalTextBlock(text=text))
+        return CanonicalMessage(
+            role="system",
+            content=content,
+            metadata={"source_protocol": "anthropic_messages"},
+        )
+
+    def _anthropic_messages_to_canonical_messages(
+        self,
+        messages: list[AnthropicMessage],
+    ) -> list[CanonicalMessage]:
+        """Map Anthropic messages (string or block content) onto canonical messages."""
+        canonical_messages = []
+        for message in messages:
+            if isinstance(message.content, str):
+                content_blocks = [CanonicalTextBlock(text=message.content)] if message.content else []
+            else:
+                content_blocks = self._anthropic_content_blocks_to_canonical(message.content)
+            canonical_messages.append(
+                CanonicalMessage(
+                    role=message.role,
+                    content=content_blocks,
+                    metadata={"source_protocol": "anthropic_messages"},
+                )
+            )
+        return canonical_messages
+
+    def _anthropic_content_blocks_to_canonical(
+        self,
+        blocks: list[dict[str, Any]],
+    ) -> list[Any]:
+        """Map Anthropic content blocks to canonical blocks; raises on unknown block types."""
+        canonical_blocks: list[Any] = []
+        for block in blocks:
+            block_type = block.get("type")
+            if block_type == "text":
+                canonical_blocks.append(CanonicalTextBlock(text=str(block.get("text", ""))))
+            elif block_type == "tool_use":
+                canonical_blocks.append(
+                    CanonicalToolCallBlock(
+                        tool_call=CanonicalToolCall(
+                            id=str(block.get("id") or f"toolu_{uuid4().hex}"),
+                            name=str(block.get("name", "")),
+                            arguments=normalize_trace_payload(block.get("input", {})),
+                            metadata={"source_protocol": "anthropic_messages"},
+                        )
+                    )
+                )
+            elif block_type == "tool_result":
+                content = block.get("content")
+                canonical_blocks.append(
+                    CanonicalToolResultBlock(
+                        tool_result=CanonicalToolResult(
+                            tool_call_id=str(block.get("tool_use_id") or ""),
+                            output=content,
+                            output_text=self._serialize_tool_result_content(content),
+                            is_error=bool(block.get("is_error", False)),
+                            metadata={"source_protocol": "anthropic_messages"},
+                        )
+                    )
+                )
+            elif block_type in {"reasoning", "thinking"}:
+                # Fix: Anthropic "thinking" blocks carry their text under the
+                # "thinking" key -- this adapter itself emits
+                # {"type": "thinking", "thinking": ...} in
+                # _canonical_response_to_anthropic_blocks -- so reading only
+                # "text" silently dropped reasoning echoed back by the client.
+                # Accept "thinking" first, with "text" as a fallback.
+                reasoning_text = str(block.get("thinking") or block.get("text", ""))
+                canonical_blocks.append(
+                    CanonicalReasoningBlock(
+                        reasoning=CanonicalReasoning(
+                            steps=[CanonicalReasoningStep(text=reasoning_text)] if reasoning_text else [],
+                            metadata={"source_protocol": "anthropic_messages"},
+                        )
+                    )
+                )
+            else:
+                raise AnthropicChatAdapterError(
+                    f"Unsupported Anthropic content block type in canonical mapping: {block_type}",
+                    "invalid_request_error",
+                )
+        return canonical_blocks
+
+    def _anthropic_tools_to_canonical(
+        self,
+        tools: list[dict[str, Any]] | None,
+    ) -> list[CanonicalToolDefinition]:
+        """Accept both Anthropic-native and OpenAI-style tool specs."""
+        if not tools:
+            return []
+        canonical_tools = []
+        for tool in tools:
+            if tool.get("type") == "function":
+                function_spec = tool.get("function", {})
+                name = function_spec.get("name")
+                description = function_spec.get("description")
+                parameters = function_spec.get("parameters", {})
+            else:
+                name = tool.get("name")
+                description = tool.get("description")
+                parameters = tool.get("input_schema", {})
+            canonical_tools.append(
+                CanonicalToolDefinition(
+                    name=str(name or ""),
+                    description=description,
+                    parameters_json_schema=parameters,
+                    metadata={"source_protocol": "anthropic_messages"},
+                )
+            )
+        return canonical_tools
+
+    def _anthropic_tool_choice_to_canonical(
+        self,
+        tool_choice: str | dict[str, Any] | None,
+    ) -> CanonicalToolChoice | None:
+        """Map Anthropic tool_choice: "any" -> "required", {"type": "tool"} -> specific tool."""
+        if tool_choice is None:
+            return None
+        if isinstance(tool_choice, str):
+            mapped_type = "required" if tool_choice == "any" else tool_choice
+            return CanonicalToolChoice(type=mapped_type)
+        choice_type = tool_choice.get("type")
+        if choice_type == "tool":
+            return CanonicalToolChoice(
+                type="specific",
+                tool_name=tool_choice.get("name"),
+                metadata={"source_protocol": "anthropic_messages"},
+            )
+        mapped_type = "required" if choice_type == "any" else str(choice_type or "auto")
+        return CanonicalToolChoice(
+            type=mapped_type,
+            metadata={"source_protocol": "anthropic_messages"},
+        )
+
+    def _canonical_response_to_anthropic_blocks(
+        self,
+        response: CanonicalGenerateResponse,
+    ) -> list[dict[str, Any]]:
+        """Render a canonical response as Anthropic content blocks (never returns an empty list)."""
+        blocks: list[dict[str, Any]] = []
+        for block in response.output.content:
+            if isinstance(block, CanonicalTextBlock):
+                if block.text:
+                    blocks.append({"type": "text", "text": block.text})
+            elif isinstance(block, CanonicalToolCallBlock):
+                blocks.append(
+                    {
+                        "type": "tool_use",
+                        "id": block.tool_call.id,
+                        "name": block.tool_call.name,
+                        "input": block.tool_call.arguments if block.tool_call.arguments is not None else {},
+                    }
+                )
+            elif isinstance(block, CanonicalToolResultBlock):
+                tool_result_content: Any = block.tool_result.output
+                if tool_result_content is None:
+                    tool_result_content = block.tool_result.output_text or ""
+                blocks.append(
+                    {
+                        "type": "tool_result",
+                        "tool_use_id": block.tool_result.tool_call_id,
+                        "content": tool_result_content,
+                        "is_error": block.tool_result.is_error,
+                    }
+                )
+            elif isinstance(block, CanonicalReasoningBlock):
+                reasoning_text = self._reasoning_to_text(block.reasoning)
+                if reasoning_text:
+                    blocks.append({"type": "thinking", "thinking": reasoning_text})
+        return blocks or [{"type": "text", "text": ""}]
+
+    def _reasoning_to_text(self, reasoning: CanonicalReasoning) -> str:
+        """Join non-empty reasoning steps into one newline-separated string."""
+        return "\n".join(step.text for step in reasoning.steps if step.text).strip()
+
+    def _stream_stop_reason(self, stop_reason: str | None) -> str | None:
+        """Translate OpenAI-style finish reasons into Anthropic stop reasons."""
+        if stop_reason == "stop":
+            return "end_turn"
+        if stop_reason == "length":
+            return "max_tokens"
+        return stop_reason
diff --git a/xtuner/v1/rl/gateway/adapters/base.py b/xtuner/v1/rl/gateway/adapters/base.py
new file mode 100644
index 0000000000..f0ff991438
--- /dev/null
+++ b/xtuner/v1/rl/gateway/adapters/base.py
@@ -0,0 +1,246 @@
+from __future__ import annotations
+
+import json
+import logging
+from abc import ABC, abstractmethod
+from collections.abc import Awaitable, Callable
+from typing import Any, Generic, TypeVar
+
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+from xtuner.v1.data_proto.rl_data import Status
+
+from ..core.models import (
+    CanonicalGenerateRequest,
+    CanonicalGenerateResponse,
+    CanonicalReasoningBlock,
+    CanonicalTextBlock,
+    CanonicalToolCall,
+    CanonicalToolCallBlock,
+    CanonicalToolResultBlock,
+)
+from .capture import append_gateway_capture_record, render_blocks_as_text
+from .trace import (
+    ChatTraceRecord,
+    ChatTraceStore,
+    build_api_key_trace_key,
+    normalize_trace_payload,
+    snapshot_routed_experts,
+)
+
+
+GenerateHandler = Callable[[CanonicalGenerateRequest],
Awaitable[CanonicalGenerateResponse]]
+RequestT = TypeVar("RequestT")
+ResponseT = TypeVar("ResponseT")
+logger = logging.getLogger(__name__)
+
+
+def coerce_content_to_text(content: Any) -> str | None:
+    """Coerce arbitrary content (str, list of blocks, dict) to a plain
+    string."""
+    if content is None:
+        return None
+    if isinstance(content, str):
+        return content
+    if isinstance(content, list):
+        # Only recognized text-bearing part types contribute; others are dropped.
+        text_chunks = []
+        for item in content:
+            if isinstance(item, dict) and item.get("type") in {"text", "input_text", "output_text"}:
+                text_chunks.append(str(item.get("text", "")))
+        joined = "\n".join(chunk for chunk in text_chunks if chunk)
+        return joined or None
+    if isinstance(content, dict) and "text" in content:
+        return str(content["text"])
+    # Fallback: repr-style stringification for anything unrecognized.
+    return str(content)
+
+
+def stringify_tool_arguments(tool_call: CanonicalToolCall) -> str:
+    """Render tool-call arguments as a string, preferring the raw client-sent text."""
+    if tool_call.raw_arguments_text is not None:
+        return tool_call.raw_arguments_text
+    if isinstance(tool_call.arguments, str):
+        return tool_call.arguments
+    return json.dumps(tool_call.arguments if tool_call.arguments is not None else {}, ensure_ascii=False)
+
+
+class BaseChatAPIAdapter(ABC, Generic[RequestT, ResponseT]):
+    """Protocol-agnostic request pipeline: validate -> canonicalize -> generate
+    -> convert back, with per-api-key trace recording and optional capture-file
+    output. Subclasses implement the protocol-specific conversions."""
+
+    def __init__(
+        self,
+        generate_handler: GenerateHandler,
+        tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast | None,
+        *,
+        capture_folder: str | None = None,
+        trace_store: ChatTraceStore | None = None,
+        trace_store_max_entries: int = 10000,
+    ):
+        self._generate_handler = generate_handler
+        self._tokenizer = tokenizer
+        self._capture_folder = capture_folder
+        # Fall back to a private in-memory store when none is shared in.
+        self._trace_store = trace_store or ChatTraceStore(max_entries=trace_store_max_entries)
+
+    async def handle_request(self, request: RequestT, *, api_key: str | None = None) -> ResponseT:
+        """Run one request through the full pipeline and record trace/capture data."""
+        self.validate_request(request)
+        canonical_request = self.request_to_canonical_request(request)
+        canonical_response = await self._generate_handler(canonical_request)
+        response = self.canonical_response_to_protocol_response(canonical_response, request)
+        # Trace key is derived from the caller's API key (or a shared default).
+        record_trace_key = build_api_key_trace_key(api_key)
+        self._trace_store.append(
+            self._build_trace_record(
+                record_trace_key,
+                request,
+                response,
+                canonical_response,
+            )
+        )
+        # Capture write is best-effort and must not fail the request.
+        self._write_capture_record(
+            request=request,
+            response=response,
+            canonical_response=canonical_response,
+            api_key=api_key,
+        )
+        return response
+
+    def get_trace_records(self, trace_key: str) -> list[ChatTraceRecord]:
+        """Return (without removing) the trace records stored under trace_key."""
+        return self._trace_store.get(trace_key)
+
+    def pop_trace_records(self, trace_key: str) -> list[ChatTraceRecord]:
+        """Remove and return the trace records stored under trace_key."""
+        return self._trace_store.pop(trace_key)
+
+    def clear_trace_records(self, trace_key: str) -> None:
+        """Drop all trace records stored under trace_key."""
+        self._trace_store.clear(trace_key)
+
+    def _build_trace_record(
+        self,
+        trace_key: str,
+        request: RequestT,
+        response: ResponseT,
+        canonical_response: CanonicalGenerateResponse,
+    ) -> ChatTraceRecord:
+        """Assemble one ChatTraceRecord from the rollout trace metadata plus
+        protocol-level request/response snapshots."""
+        request_snapshot = self.normalize_request(request)
+        response_snapshot = self.normalize_response(response)
+        rollout_trace = self._get_rollout_trace(canonical_response)
+        status = rollout_trace.get("status", Status.COMPLETED.value)
+        # Prefer the backend's own output text; fall back to re-rendering blocks.
+        output_text = rollout_trace.get("output_text") or render_blocks_as_text(
+            self._build_output_message_list(canonical_response)
+        )
+        return ChatTraceRecord(
+            trace_key=trace_key,
+            request_snapshot=request_snapshot,
+            response_snapshot=response_snapshot,
+            prompt_ids=list(rollout_trace.get("prompt_ids") or []),
+            response_ids=list(rollout_trace.get("response_ids") or []),
+            input_text=rollout_trace.get("input_text", ""),
+            output_text=output_text,
+            logprobs=rollout_trace.get("logprobs"),
+            routed_experts=snapshot_routed_experts(rollout_trace.get("routed_experts")),
+            finish_reason=rollout_trace.get("rollout_finish_reason") or canonical_response.finish_reason,
+            status=Status(status) if isinstance(status, str) else status,
+            request_id=canonical_response.request_id,
+        )
+
+    def _write_capture_record(
+        self,
+        request: RequestT,
+        response: ResponseT,
+        canonical_response: CanonicalGenerateResponse,
+        api_key: str |
None = None,
+    ) -> None:
+        """Best-effort append of a debug capture record; failures are logged, never raised."""
+        if self._capture_folder is None:
+            return
+        rollout_trace = self._get_rollout_trace(canonical_response)
+        try:
+            response_snapshot = self.normalize_response(response)
+            # "stop_reason" covers Anthropic-shaped snapshots, "finish_reason" OpenAI-shaped ones.
+            response_finish_reason = (
+                response_snapshot.get("stop_reason")
+                or response_snapshot.get("finish_reason")
+                or canonical_response.finish_reason
+            )
+            output_messages = self._build_output_message_list(canonical_response)
+            append_gateway_capture_record(
+                self._capture_folder,
+                {
+                    "protocol": self.__class__.__name__,
+                    "request_id": canonical_response.request_id,
+                    "session_uid": rollout_trace.get("session_uid"),
+                    "status": rollout_trace.get("status", Status.COMPLETED.value),
+                    "finish_reason": response_finish_reason,
+                    "rollout_finish_reason": rollout_trace.get("rollout_finish_reason"),
+                    "prompt_tokens": canonical_response.usage.prompt_tokens,
+                    "completion_tokens": canonical_response.usage.completion_tokens,
+                    "request": self.normalize_request(request),
+                    "response": response_snapshot,
+                    "internal_messages": rollout_trace.get("internal_messages"),
+                    "rollout_tools": rollout_trace.get("rollout_tools"),
+                    "rollout_tool_choice": rollout_trace.get("rollout_tool_choice"),
+                    "rollout_sample_params": rollout_trace.get("rollout_sample_params"),
+                    "output_messages": output_messages,
+                    "input_text": rollout_trace.get("input_text", ""),
+                    "output_text": render_blocks_as_text(output_messages),
+                },
+                api_key=api_key,
+            )
+        except Exception:
+            # Deliberate broad catch: capture is diagnostics only.
+            logger.warning(f"Failed to write gateway capture record to {self._capture_folder}", exc_info=True)
+            return
+
+    def _get_rollout_trace(self, canonical_response: CanonicalGenerateResponse) -> dict[str, Any]:
+        """Return the backend's rollout_trace metadata dict, or {} when absent/malformed."""
+        trace_payload = canonical_response.metadata.get("rollout_trace", {})
+        if not isinstance(trace_payload, dict):
+            return {}
+        return trace_payload
+
+    def _build_output_message_list(
+        self,
+        canonical_response: CanonicalGenerateResponse,
+    ) -> list[dict[str, Any]]:
+        """Render the canonical output as a single assistant message with typed content parts."""
+        content: list[dict[str, Any]] = []
+        for block in canonical_response.output.content:
+            if isinstance(block, CanonicalTextBlock):
+                content.append({"type": "text", "text": block.text})
+            elif isinstance(block, CanonicalReasoningBlock):
+                reasoning_text = "\n".join(step.text for step in block.reasoning.steps if step.text).strip()
+                if reasoning_text:
+                    content.append({"type": "reasoning", "text": reasoning_text})
+            elif isinstance(block, CanonicalToolCallBlock):
+                content.append(
+                    {
+                        "type": "tool_use",
+                        "id": block.tool_call.id,
+                        "name": block.tool_call.name,
+                        "input": normalize_trace_payload(block.tool_call.arguments),
+                    }
+                )
+            elif isinstance(block, CanonicalToolResultBlock):
+                tool_result_content = block.tool_result.output
+                if tool_result_content is None:
+                    tool_result_content = block.tool_result.output_text or ""
+                content.append(
+                    {
+                        "type": "tool_result",
+                        "tool_use_id": block.tool_result.tool_call_id,
+                        "content": normalize_trace_payload(tool_result_content),
+                    }
+                )
+        return [{"role": "assistant", "content": content or ""}]
+
+    @abstractmethod
+    def validate_request(self, request: RequestT) -> None:
+        # Raise a protocol-specific error for unsupported request shapes.
+        raise NotImplementedError
+
+    @abstractmethod
+    def request_to_canonical_request(self, request: RequestT) -> CanonicalGenerateRequest:
+        # Map the protocol request onto the canonical generation request.
+        raise NotImplementedError
+
+    @abstractmethod
+    def normalize_request(self, request: RequestT) -> dict[str, Any]:
+        # JSON-safe snapshot of the request for tracing/capture.
+        raise NotImplementedError
+
+    @abstractmethod
+    def normalize_response(self, response: ResponseT) -> dict[str, Any]:
+        # JSON-safe snapshot of the response for tracing/capture.
+        raise NotImplementedError
+
+    @abstractmethod
+    def canonical_response_to_protocol_response(
+        self,
+        canonical_response: CanonicalGenerateResponse,
+        request: RequestT,
+    ) -> ResponseT:
+        # Map the canonical response back onto the protocol response type.
+        raise NotImplementedError
diff --git a/xtuner/v1/rl/gateway/adapters/capture.py b/xtuner/v1/rl/gateway/adapters/capture.py
new file mode 100644
index 0000000000..00f6817415
--- /dev/null
+++ b/xtuner/v1/rl/gateway/adapters/capture.py
@@ -0,0 +1,58 @@
+from __future__ import annotations
+
+import json
+import threading
+from datetime import
datetime, timezone +from hashlib import sha256 +from pathlib import Path +from typing import Any + + +_CAPTURE_LOCK = threading.RLock() +_NO_API_KEY_CAPTURE_FILE_NAME = "api_key_none.jsonl" + + +def resolve_capture_output_path(folder: str | Path, api_key: str | None = None) -> Path: + if not api_key: + return Path(folder) / _NO_API_KEY_CAPTURE_FILE_NAME + api_key_hash = sha256(api_key.encode("utf-8")).hexdigest()[:16] + return Path(folder) / f"api_key_{api_key_hash}.jsonl" + + +def append_gateway_capture_record(folder: str | Path, record: dict[str, Any], api_key: str | None = None) -> None: + capture_path = resolve_capture_output_path(folder, api_key=api_key) + capture_path.parent.mkdir(parents=True, exist_ok=True) + payload = { + "type": "gateway_turn", + "timestamp": datetime.now(timezone.utc).isoformat(), + **record, + } + with _CAPTURE_LOCK: + with capture_path.open("a", encoding="utf-8") as f: + f.write(json.dumps(payload, ensure_ascii=False) + "\n") + + +def render_blocks_as_text(value: Any) -> str: + if value is None: + return "" + if isinstance(value, str): + return value + if isinstance(value, list): + rendered_parts = [render_blocks_as_text(item) for item in value] + return "\n".join(part for part in rendered_parts if part) + if isinstance(value, dict): + block_type = value.get("type") + if block_type == "text": + return str(value.get("text", "")) + if block_type == "tool_use": + name = value.get("name", "") + input_payload = json.dumps(value.get("input", {}), ensure_ascii=False, sort_keys=True) + return f"{input_payload}" + if block_type == "tool_result": + tool_use_id = value.get("tool_use_id", "") + content = render_blocks_as_text(value.get("content")) + return f"{content}" + if "content" in value: + return render_blocks_as_text(value["content"]) + return json.dumps(value, ensure_ascii=False, sort_keys=True) + return str(value) diff --git a/xtuner/v1/rl/gateway/adapters/openai.py b/xtuner/v1/rl/gateway/adapters/openai.py new file mode 100644 index 
0000000000..502d26bb99
--- /dev/null
+++ b/xtuner/v1/rl/gateway/adapters/openai.py
@@ -0,0 +1,407 @@
+import json
+import time
+from collections.abc import AsyncIterator
+from typing import Any
+from uuid import uuid4
+
+from fastapi.responses import StreamingResponse
+from lmdeploy.serve.openai.protocol import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ChatCompletionResponseChoice,
+    ChatCompletionResponseStreamChoice,
+    ChatCompletionStreamResponse,
+    ChatMessage,
+    DeltaMessage,
+    UsageInfo,
+)
+from pydantic import BaseModel
+
+from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
+
+from ..core.models import (
+    CanonicalGenerateRequest,
+    CanonicalGenerateResponse,
+    CanonicalMessage,
+    CanonicalReasoningBlock,
+    CanonicalTextBlock,
+    CanonicalToolCall,
+    CanonicalToolCallBlock,
+    CanonicalToolChoice,
+    CanonicalToolDefinition,
+    CanonicalToolResult,
+    CanonicalToolResultBlock,
+)
+from .base import BaseChatAPIAdapter, coerce_content_to_text, stringify_tool_arguments
+from .streaming import build_sse_response, encode_sse_event
+from .trace import ChatTraceStore, normalize_trace_payload
+
+
+class OpenAIChatAdapterError(RuntimeError):
+    """Protocol-level error carrying OpenAI-style error type/code metadata."""
+
+    def __init__(
+        self,
+        message: str,
+        error_type: str,
+        code: str,
+        request_id: str | None = None,
+    ):
+        super().__init__(message)
+        self.message = message
+        self.error_type = error_type
+        self.code = code
+        self.request_id = request_id
+
+
+class OpenAIChatAdapter(BaseChatAPIAdapter[ChatCompletionRequest, ChatCompletionResponse]):
+    """Adapter for the OpenAI Chat Completions protocol.
+
+    Streaming is emulated: the full response is generated first, then replayed
+    as SSE chunks (see iter_stream_events).
+    """
+
+    def __init__(
+        self,
+        generate_handler,
+        tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast | str,
+        default_model_name: str | None = None,
+        context_length: int | None = None,
+        capture_folder: str | None = None,
+        trace_store: ChatTraceStore | None = None,
+        trace_store_max_entries: int = 10000,
+    ):
+        # A string tokenizer argument is treated as a HF path/name and loaded here.
+        if isinstance(tokenizer, str):
+            tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True)
+        super().__init__(
+            generate_handler,
+            tokenizer=tokenizer,
+            capture_folder=capture_folder,
+            trace_store=trace_store,
+            trace_store_max_entries=trace_store_max_entries,
+        )
+        self._default_model_name = default_model_name
+        self._context_length = context_length
+
+    async def chat(
+        self,
+        request: ChatCompletionRequest,
+        *,
+        api_key: str | None = None,
+    ) -> ChatCompletionResponse | StreamingResponse:
+        """Entry point: full response for non-stream requests, SSE replay for stream=True."""
+        if request.stream:
+            response = await self.handle_request(request, api_key=api_key)
+            return build_sse_response(self.iter_stream_events(response, request))
+        return await self.handle_request(request, api_key=api_key)
+
+    def validate_request(self, request: ChatCompletionRequest) -> None:
+        # Only single-choice generation is supported.
+        if request.n not in (None, 1):
+            raise OpenAIChatAdapterError(
+                "n>1 is not supported yet",
+                "invalid_request_error",
+                "n_not_supported",
+            )
+
+    def request_to_canonical_request(self, request: ChatCompletionRequest) -> CanonicalGenerateRequest:
+        """Map a Chat Completions request onto the canonical generation request."""
+        normalized_messages = normalize_trace_payload(request.messages)
+        normalized_tools = normalize_trace_payload(request.tools)
+        normalized_tool_choice = normalize_trace_payload(request.tool_choice)
+        # stop may be absent, a single string, or a list of strings.
+        stop = [] if request.stop is None else [request.stop] if isinstance(request.stop, str) else list(request.stop)
+        return CanonicalGenerateRequest(
+            request_id=f"chatcmpl_req_{uuid4().hex}",
+            model=request.model or self._default_model_name or "rollout-controller",
+            messages=[self._openai_message_to_canonical_message(message) for message in normalized_messages],
+            tools=self._openai_tools_to_canonical(normalized_tools),
+            tool_choice=self._openai_tool_choice_to_canonical(normalized_tool_choice),
+            temperature=request.temperature,
+            top_p=request.top_p,
+            # max_completion_tokens (newer field) wins over legacy max_tokens.
+            max_tokens=request.max_completion_tokens
+            if request.max_completion_tokens is not None
+            else request.max_tokens,
+            stop=stop,
+            stream=False,
+            metadata={
+                key: value
+                for key, value in {
+                    "source_protocol": "openai_chat_completions",
+                    "client_stream": bool(request.stream),
+                    "session_uid": getattr(request, "session_uid", getattr(request, "session_id", None)),
+                    "n": request.n,
+                    "presence_penalty": request.presence_penalty,
+                    "frequency_penalty": request.frequency_penalty,
+                }.items()
+                if value is not None
+            },
+        )
+
+    def canonical_response_to_chat_completion_response(
+        self,
+        response: CanonicalGenerateResponse,
+    ) -> ChatCompletionResponse:
+        """Render the canonical response as a single-choice ChatCompletionResponse."""
+        message_content = self._render_openai_response_text(response)
+        reasoning_content = self._render_openai_reasoning_text(response)
+        tool_calls = self._canonical_tool_calls_to_openai(response)
+        finish_reason = response.finish_reason or ("tool_calls" if tool_calls else "stop")
+        return ChatCompletionResponse(
+            id=response.request_id,
+            created=int(time.time()),
+            model=response.model or self._default_model_name or "rollout-controller",
+            choices=[
+                ChatCompletionResponseChoice(
+                    index=0,
+                    message=ChatMessage(
+                        role="assistant",
+                        # Pure tool-call turns carry content=None, per OpenAI convention.
+                        content=None if tool_calls and not message_content else message_content,
+                        reasoning_content=reasoning_content,
+                        tool_calls=tool_calls or None,
+                    ),
+                    finish_reason=finish_reason,
+                )
+            ],
+            usage=UsageInfo(
+                prompt_tokens=response.usage.prompt_tokens,
+                completion_tokens=response.usage.completion_tokens,
+                total_tokens=response.usage.total_tokens,
+            ),
+        )
+
+    def canonical_response_to_protocol_response(
+        self,
+        canonical_response: CanonicalGenerateResponse,
+        request: ChatCompletionRequest,
+    ) -> ChatCompletionResponse:
+        return self.canonical_response_to_chat_completion_response(canonical_response)
+
+    def normalize_request(self, request: ChatCompletionRequest) -> dict[str, Any]:
+        """JSON-safe snapshot of the fields relevant for tracing."""
+        return normalize_trace_payload(
+            {
+                "messages": request.messages,
+                "tools": request.tools,
+                "tool_choice": request.tool_choice,
+            }
+        )
+
+    def normalize_response(self, response: ChatCompletionResponse) -> dict[str, Any]:
+        """JSON-safe snapshot of the choices for tracing."""
+        normalized_choices = []
+        for choice in response.choices:
+            normalized_choices.append(
+                {
+                    # The getattr/lambda fallback tolerates non-pydantic message
+                    # objects by passing them through unchanged.
+                    "message": getattr(choice.message, "model_dump", lambda **_: choice.message)(
+                        mode="python",
+                        exclude_none=True,
+                    )
+                    if choice.message is not None
+                    else None,
+                    "finish_reason": choice.finish_reason,
+                }
+            )
+        return normalize_trace_payload({"choices": normalized_choices})
+
+    async def iter_stream_events(
+        self,
+        response: ChatCompletionResponse,
+        request: ChatCompletionRequest,
+    ) -> AsyncIterator[str]:
+        """Replay a fully-generated response as an OpenAI SSE chunk sequence:
+        role -> reasoning -> content -> tool calls -> finish (+usage) -> [DONE]."""
+        choice = response.choices[0]
+        # NOTE(review): getattr misses the key if stream_options arrives as a
+        # plain dict rather than a pydantic model -- confirm the request type
+        # always parses it into a model.
+        include_usage = bool(getattr(request.stream_options, "include_usage", False))
+
+        initial_chunk = ChatCompletionStreamResponse(
+            id=response.id,
+            created=response.created,
+            model=response.model,
+            choices=[
+                ChatCompletionResponseStreamChoice(
+                    index=0,
+                    delta=DeltaMessage(role="assistant"),
+                )
+            ],
+        )
+        yield encode_sse_event(initial_chunk.model_dump(mode="json", exclude_none=True))
+
+        if choice.message.reasoning_content:
+            reasoning_chunk = ChatCompletionStreamResponse(
+                id=response.id,
+                created=response.created,
+                model=response.model,
+                choices=[
+                    ChatCompletionResponseStreamChoice(
+                        index=0,
+                        delta=DeltaMessage(reasoning_content=choice.message.reasoning_content),
+                    )
+                ],
+            )
+            yield encode_sse_event(reasoning_chunk.model_dump(mode="json", exclude_none=True))
+
+        if choice.message.content:
+            content_chunk = ChatCompletionStreamResponse(
+                id=response.id,
+                created=response.created,
+                model=response.model,
+                choices=[
+                    ChatCompletionResponseStreamChoice(
+                        index=0,
+                        delta=DeltaMessage(content=choice.message.content),
+                    )
+                ],
+            )
+            yield encode_sse_event(content_chunk.model_dump(mode="json", exclude_none=True))
+
+        for index, tool_call in enumerate(choice.message.tool_calls or []):
+            # Tool calls may be dicts or pydantic objects; read both shapes.
+            tool_call_id = tool_call.get("id") if isinstance(tool_call, dict) else getattr(tool_call, "id", None)
+            tool_call_type = (
+                tool_call.get("type", "function")
+                if isinstance(tool_call, dict)
+                else getattr(tool_call, "type", "function")
+            )
+            function_payload = (
+                tool_call.get("function") if isinstance(tool_call, dict) else getattr(tool_call, "function", None)
+            )
+            if isinstance(function_payload, BaseModel):
+                function_payload = function_payload.model_dump(mode="json", exclude_none=True)
+            tool_call_chunk = ChatCompletionStreamResponse(
+                id=response.id,
+                created=response.created,
+                model=response.model,
+                choices=[
+                    ChatCompletionResponseStreamChoice(
+                        index=0,
+                        delta=DeltaMessage(
+                            tool_calls=[
+                                {
+                                    "index": index,
+                                    "id": tool_call_id,
+                                    "type": tool_call_type,
+                                    "function": function_payload,
+                                }
+                            ]
+                        ),
+                    )
+                ],
+            )
+            yield encode_sse_event(tool_call_chunk.model_dump(mode="json", exclude_none=True))
+
+        final_chunk = ChatCompletionStreamResponse(
+            id=response.id,
+            created=response.created,
+            model=response.model,
+            choices=[
+                ChatCompletionResponseStreamChoice(
+                    index=0,
+                    delta=DeltaMessage(),
+                    finish_reason=choice.finish_reason,
+                )
+            ],
+            usage=response.usage if include_usage else None,
+        )
+        yield encode_sse_event(final_chunk.model_dump(mode="json", exclude_none=True))
+        yield encode_sse_event("[DONE]")
+
+    def _openai_message_to_canonical_message(self, message: dict[str, Any]) -> CanonicalMessage:
+        """Map one OpenAI chat message (incl. tool role and tool_calls) to canonical form."""
+        role = str(message.get("role", "user"))
+        content_blocks: list[Any] = []
+        if role == "tool":
+            content_blocks.append(
+                CanonicalToolResultBlock(
+                    tool_result=CanonicalToolResult(
+                        tool_call_id=str(message.get("tool_call_id") or ""),
+                        name=message.get("name"),
+                        output=message.get("content"),
+                        output_text=coerce_content_to_text(message.get("content")),
+                        metadata={"source_protocol": "openai_chat_completions"},
+                    )
+                )
+            )
+        else:
+            content_text = coerce_content_to_text(message.get("content"))
+            if content_text:
+                content_blocks.append(CanonicalTextBlock(text=content_text))
+            for tool_call in message.get("tool_calls") or []:
+                content_blocks.append(CanonicalToolCallBlock(tool_call=self._openai_tool_call_to_canonical(tool_call)))
+        return CanonicalMessage(
+            # Unknown roles are coerced to "user".
+            role=role if role in {"system", "user", "assistant", "tool"} else "user",
+            content=content_blocks,
+            name=message.get("name"),
+            metadata={
+                key: value
+                for key, value in {
+                    "source_protocol": "openai_chat_completions",
+                    "tool_call_id": message.get("tool_call_id"),
+                }.items()
+                if value is not None
+            },
+        )
+
+    def _openai_tools_to_canonical(self, tools: list[dict[str, Any]] | None) -> list[CanonicalToolDefinition]:
+        """Map OpenAI function-tool specs (wrapped or flat) to canonical tool definitions."""
+        if not tools:
+            return []
+        canonical_tools = []
+        for tool in tools:
+            function_spec = tool.get("function", tool)
+            canonical_tools.append(
+                CanonicalToolDefinition(
+                    name=str(function_spec.get("name", "")),
+                    description=function_spec.get("description"),
+                    parameters_json_schema=function_spec.get("parameters", {}),
+                    metadata={"source_protocol": "openai_chat_completions"},
+                )
+            )
+        return canonical_tools
+
+    def _openai_tool_choice_to_canonical(self, tool_choice: Any) -> CanonicalToolChoice | None:
+        """Map tool_choice: bare strings pass through; dicts select a specific tool."""
+        if tool_choice is None:
+            return None
+        if isinstance(tool_choice, str):
+            return CanonicalToolChoice(type=tool_choice)
+        function_spec = tool_choice.get("function") or {}
+        return CanonicalToolChoice(
+            type="specific",
+            tool_name=function_spec.get("name"),
+            metadata={"source_protocol": "openai_chat_completions"},
+        )
+
+    def _openai_tool_call_to_canonical(self, tool_call: dict[str, Any]) -> CanonicalToolCall:
+        """Map an OpenAI tool call, keeping the raw argument text and flagging parse failures."""
+        function_spec = tool_call.get("function") or {}
+        raw_arguments = function_spec.get("arguments")
+        parsed_arguments = self._parse_tool_arguments(raw_arguments)
+        metadata: dict[str, Any] = {"source_protocol": "openai_chat_completions"}
+        # _parse_tool_arguments signals failure via a sentinel key; pop it and
+        # record the failure in metadata instead.
+        if isinstance(parsed_arguments, dict) and parsed_arguments.pop("__parse_error__", False):
+            metadata["arguments_parse_error"] = True
+        return CanonicalToolCall(
+            id=str(tool_call.get("id") or f"call_{uuid4().hex}"),
+            name=str(function_spec.get("name", "")),
+            arguments=parsed_arguments,
+            raw_arguments_text=raw_arguments if isinstance(raw_arguments, str) else None,
+            metadata=metadata,
+        )
+
+    def _canonical_tool_calls_to_openai(self, response: CanonicalGenerateResponse) -> list[dict[str, Any]]:
+        """Collect canonical tool-call blocks as OpenAI tool_calls dicts."""
+        tool_calls = []
+        for block in response.output.content:
+            if isinstance(block, CanonicalToolCallBlock):
+                tool_calls.append(
+                    {
+                        "id": block.tool_call.id,
+                        "type": "function",
+                        "function": {
+                            "name": block.tool_call.name,
+                            "arguments": stringify_tool_arguments(block.tool_call),
+                        },
+                    }
+                )
+        return tool_calls
+
+    def _render_openai_response_text(self, response: CanonicalGenerateResponse) -> str | None:
+        """Concatenate all text blocks; returns None when the result is empty."""
+        text_chunks = []
+        for block in response.output.content:
+            if isinstance(block, CanonicalTextBlock):
+                text_chunks.append(block.text)
+        joined = "".join(text_chunks).strip()
+        return joined or None
+
+    def _render_openai_reasoning_text(self, response: CanonicalGenerateResponse) -> str | None:
+        """Join all non-empty reasoning steps; returns None when there are none."""
+        reasoning_chunks: list[str] = []
+        for block in response.output.content:
+            if isinstance(block, CanonicalReasoningBlock):
+                reasoning_chunks.extend(step.text for step in block.reasoning.steps if step.text)
+        joined = "\n".join(chunk for chunk in reasoning_chunks if chunk).strip()
+        return joined or None
+
+    def _parse_tool_arguments(self, raw_arguments: Any) -> Any:
+        """Parse a JSON argument string; on failure, return a sentinel dict
+        (consumed by _openai_tool_call_to_canonical) instead of raising."""
+        if not isinstance(raw_arguments, str):
+            return raw_arguments
+        try:
+            return json.loads(raw_arguments)
+        except Exception:
+            return {"__parse_error__": True, "raw": raw_arguments}
diff --git a/xtuner/v1/rl/gateway/adapters/responses.py b/xtuner/v1/rl/gateway/adapters/responses.py
new file mode 100644
index 0000000000..e0c1e74495
--- /dev/null
+++ b/xtuner/v1/rl/gateway/adapters/responses.py
@@ -0,0 +1,587 @@
+from __future__ import annotations
+
+import json
+import re
+import time
+from collections.abc import AsyncIterator
+from typing import Any, Literal
+from uuid import uuid4
+
+from fastapi.responses import StreamingResponse
+from pydantic import BaseModel, ConfigDict
+
+from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
+
+from ..core.models import (
+    CanonicalGenerateRequest,
+    CanonicalGenerateResponse,
+    CanonicalMessage,
+    CanonicalReasoning,
+    CanonicalReasoningBlock,
+    CanonicalReasoningStep,
+    CanonicalTextBlock,
+    CanonicalToolCall,
+    CanonicalToolCallBlock,
+    CanonicalToolChoice,
+    CanonicalToolDefinition,
+    CanonicalToolResult,
+    CanonicalToolResultBlock,
+)
+from .base import BaseChatAPIAdapter, stringify_tool_arguments
+from .openai import OpenAIChatAdapterError
+from .streaming import build_sse_response, encode_sse_event
+from .trace import ChatTraceStore, normalize_trace_payload
+
+
+class ResponsesRequest(BaseModel):
+    """Loose model of an OpenAI Responses API request (unknown fields allowed)."""
+
+    model_config = ConfigDict(extra="allow")
+
+    session_uid: int | None = None
+    model: str | None = None
+    instructions: str | None = None
+    input: str | list[dict[str, Any]] | None = None
+    tools: list[dict[str, Any]] | None = None
+    tool_choice: str | dict[str, Any] | None = None
+    stream: bool = False
+    store: bool = False
+    parallel_tool_calls: bool | None = None
+    include: list[Any] | None = None
+    reasoning: dict[str, Any] | None = None
+    max_output_tokens: int | None = None
+    temperature: float | None = None
+    top_p: float | None = None
+
+
+class ResponsesUsage(BaseModel):
+    """Token usage block of a Responses API response."""
+
+    model_config = ConfigDict(extra="allow")
+
+    input_tokens: int
+    output_tokens: int
+    total_tokens: int
+
+
+class ResponsesResponse(BaseModel):
+    """Minimal Responses API response; always reports status "completed"."""
+
+    model_config = ConfigDict(extra="allow")
+
+    id: str
+    object: Literal["response"] = "response"
+    created_at: int
+    status: Literal["completed"] = "completed"
+    model: str
+    output: list[dict[str, Any]]
+    output_text: str = ""
+    parallel_tool_calls: bool = False
+    store: bool = False
+    text: dict[str, Any] = {"format": {"type": "text"}}
+    usage: ResponsesUsage
+
+
+class OpenAIResponsesAdapter(BaseChatAPIAdapter[ResponsesRequest, ResponsesResponse]):
+    """Adapter for the OpenAI Responses protocol; streaming replays the
+    finished response as the standard Responses SSE event sequence."""
+
+    # Client-side tools the gateway refuses to forward to the backend.
+    _disabled_tool_names = {
+        "list_mcp_resources",
+        "list_mcp_resource_templates",
+        "read_mcp_resource",
+        "request_user_input",
+    }
+
+    def __init__(
+        self,
+        generate_handler,
+        tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast | str | None,
+        default_model_name: str | None = None,
+        context_length: int | None = None,
+        capture_folder: str | None = None,
+        trace_store: ChatTraceStore | None = None,
+    ):
+        # A string tokenizer argument is treated as a HF path/name and loaded here.
+        if isinstance(tokenizer, str):
+            tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True)
+        super().__init__(generate_handler, tokenizer=tokenizer, capture_folder=capture_folder, trace_store=trace_store)
+        self._default_model_name = default_model_name
+        self._context_length = context_length
+
+    async def responses(
+        self,
+        request: ResponsesRequest,
+        *,
+        api_key: str | None = None,
+    ) -> ResponsesResponse | StreamingResponse:
+        """Entry point: full response for non-stream requests, SSE replay for stream=True."""
+        if request.stream:
+            response = await self.handle_request(request, api_key=api_key)
+            return build_sse_response(self.iter_stream_events(response))
+        return await self.handle_request(request, api_key=api_key)
+
+    def validate_request(self, request: ResponsesRequest) -> None:
+        # All request shapes are accepted for this protocol.
+        return None
+
+    def request_to_canonical_request(self, request: ResponsesRequest) -> CanonicalGenerateRequest:
+        """Map a Responses request onto the canonical generation request."""
+        return CanonicalGenerateRequest(
+            request_id=f"responses_req_{uuid4().hex}",
+            model=request.model or self._default_model_name or "rollout-controller",
+            messages=self._responses_input_to_canonical_messages(request),
+            tools=self._responses_tools_to_canonical(request.tools),
+            tool_choice=self._responses_tool_choice_to_canonical(request.tool_choice),
+            parallel_tool_calls=request.parallel_tool_calls,
+            temperature=request.temperature,
+            top_p=request.top_p,
+            max_tokens=request.max_output_tokens,
+            stream=False,
+            metadata={
+                key: value
+                for key, value in {
+                    "source_protocol": "openai_responses",
+                    "client_stream": bool(request.stream),
+                    "session_uid": request.session_uid,
+                    "store": request.store,
+                    "include": request.include,
+                    "reasoning": request.reasoning,
+                }.items()
+                if value is not None
+            },
+        )
+
+    def normalize_request(self, request: ResponsesRequest) -> dict[str, Any]:
+        """JSON-safe snapshot of the request for tracing/capture."""
+        return normalize_trace_payload(request.model_dump(mode="python", exclude_none=True))
+
+    def normalize_response(self, response: ResponsesResponse) -> dict[str, Any]:
+        """JSON-safe snapshot of the response for tracing/capture."""
+        return normalize_trace_payload(response.model_dump(mode="python", exclude_none=True))
+
+    async def iter_stream_events(
+        self,
+        response: ResponsesResponse,
+    ) -> AsyncIterator[str]:
+        """Replay a finished response as the Responses SSE event sequence:
+        created/in_progress, then per-item added/delta/done events, then completed."""
+        created_response = response.model_dump(mode="json", exclude_none=True)
+        # The replayed lifecycle starts in "in_progress" even though generation is done.
+        created_response["status"] = "in_progress"
+
+        yield encode_sse_event(
+            {
+                "type": "response.created",
+                "response": created_response,
+            },
+            event="response.created",
+        )
+        yield encode_sse_event(
+            {
+                "type": "response.in_progress",
+                "response": created_response,
+            },
+            event="response.in_progress",
+        )
+
+        for output_index, item in enumerate(response.output):
+            yield encode_sse_event(
+                {
+                    "type": "response.output_item.added",
+                    "output_index": output_index,
+                    "item": item,
+                },
+                event="response.output_item.added",
+            )
+
+            if item.get("type") == "message":
+                # Each text part is emitted as one delta followed by its done events.
+                for content_index, part in enumerate(item.get("content", [])):
+                    yield encode_sse_event(
+                        {
+                            "type": "response.content_part.added",
+                            "output_index": output_index,
+                            "content_index": content_index,
+                            "item_id": item.get("id"),
+                            "part": part,
+                        },
+                        event="response.content_part.added",
+                    )
+                    if part.get("type") == "output_text":
+                        yield encode_sse_event(
+                            {
+                                "type": "response.output_text.delta",
+                                "output_index": output_index,
+                                "content_index": content_index,
+                                "item_id": item.get("id"),
+                                "delta": part.get("text", ""),
+                            },
+                            event="response.output_text.delta",
+                        )
+                        yield encode_sse_event(
+                            {
+                                "type": "response.output_text.done",
+                                "output_index": output_index,
+                                "content_index": content_index,
+                                "item_id": item.get("id"),
+                                "text": part.get("text", ""),
+                            },
+                            event="response.output_text.done",
+                        )
+                    yield encode_sse_event(
+                        {
+                            "type": "response.content_part.done",
+                            "output_index": output_index,
+                            "content_index": content_index,
+                            "item_id": item.get("id"),
+                            "part": part,
+                        },
+                        event="response.content_part.done",
+                    )
+
+            if item.get("type") == "function_call":
+                # Arguments are replayed as a single delta then a done event.
+                yield encode_sse_event(
+                    {
+                        "type": "response.function_call_arguments.delta",
+                        "output_index": output_index,
+                        "item_id": item.get("id"),
+                        "delta": item.get("arguments", ""),
+                    },
+                    event="response.function_call_arguments.delta",
+                )
+                yield encode_sse_event(
+                    {
+                        "type": "response.function_call_arguments.done",
+                        "output_index": output_index,
+                        "item_id": item.get("id"),
+                        "arguments": item.get("arguments", ""),
+                    },
+                    event="response.function_call_arguments.done",
+                )
+
+            yield encode_sse_event(
+                {
+                    "type": "response.output_item.done",
+                    "output_index": output_index,
+                    "item": item,
+                },
+                event="response.output_item.done",
+            )
+
+        yield encode_sse_event(
+            {
+                "type": "response.completed",
+                "response": response.model_dump(mode="json", exclude_none=True),
+            },
+            event="response.completed",
+        )
+
+    def canonical_response_to_protocol_response(
+        self,
+        canonical_response: CanonicalGenerateResponse,
+        request: ResponsesRequest,
+    ) -> ResponsesResponse:
+        """Render the canonical response as a completed ResponsesResponse."""
+        output_items = self._canonical_response_to_responses_output_items(canonical_response)
+        output_text = "".join(
+            block.text for block in canonical_response.output.content if isinstance(block, CanonicalTextBlock)
+        ).strip()
+        return ResponsesResponse(
+            id=f"resp_{canonical_response.request_id}",
+            created_at=int(time.time()),
+            model=canonical_response.model or self._default_model_name or "rollout-controller",
+            output=output_items,
+            output_text=output_text,
+            parallel_tool_calls=bool(
+                request.parallel_tool_calls
+                if request is not None
+                else canonical_response.metadata.get("parallel_tool_calls")
+            ),
+            store=bool(request.store) if request is not None else False,
+            usage=ResponsesUsage(
+                input_tokens=canonical_response.usage.prompt_tokens,
+                output_tokens=canonical_response.usage.completion_tokens,
+                total_tokens=canonical_response.usage.total_tokens,
+            ),
+        )
+
+    def _normalize_input_role(self, role: Any) -> str:
+        """Collapse Responses roles onto the canonical set ("developer" -> "system")."""
+        if role in {"developer", "system"}:
+            return "system"
+        if role in {"assistant", "tool"}:
+            return str(role)
+        return
"user" + + def _extract_message_item_text(self, content: Any) -> str: + if isinstance(content, str): + return content + if not isinstance(content, list): + return str(content) + text_chunks: list[str] = [] + for part in content: + part_type = part.get("type") + if part_type in {"input_text", "output_text", "text", "summary_text", "reasoning_text"}: + text_chunks.append(str(part.get("text", ""))) + return "\n".join(chunk for chunk in text_chunks if chunk) + + def _serialize_tool_output(self, output: Any, tool_name: str | None = None) -> str: + if output is None: + return "" + if isinstance(output, str): + return self._sanitize_tool_output_text(output, tool_name=tool_name) + if isinstance(output, list): + text_chunks = [str(part.get("text", "")) for part in output if isinstance(part, dict) and "text" in part] + if text_chunks: + return self._sanitize_tool_output_text("\n".join(text_chunks), tool_name=tool_name) + return json.dumps(output, ensure_ascii=False) + if isinstance(output, dict): + return json.dumps(output, ensure_ascii=False) + return str(output) + + def _sanitize_tool_output_text(self, text: str, tool_name: str | None = None) -> str: + if tool_name not in {"exec_command", "write_stdin"}: + return text + marker = "\nOutput:\n" + if marker in text: + prefix, body = text.split(marker, 1) + exit_code = self._extract_exec_exit_code(prefix) + body = body.strip() + if exit_code is None: + return body + if body: + return f"[exit_code={exit_code}]\n{body}" + return f"[exit_code={exit_code}]" + return text + + def _extract_exec_exit_code(self, text: str) -> int | None: + match = re.search(r"Process exited with code (\d+)", text) + if match is not None: + return int(match.group(1)) + return None + + def _responses_input_to_canonical_messages(self, request: ResponsesRequest) -> list[CanonicalMessage]: + messages: list[CanonicalMessage] = [] + if request.instructions: + messages.append( + CanonicalMessage( + role="system", + 
content=[CanonicalTextBlock(text=request.instructions)], + metadata={"source_protocol": "openai_responses"}, + ) + ) + if request.input is None: + return messages + if isinstance(request.input, str): + messages.append( + CanonicalMessage( + role="user", + content=[CanonicalTextBlock(text=request.input)], + metadata={"source_protocol": "openai_responses"}, + ) + ) + return messages + + tool_name_by_call_id: dict[str, str] = {} + for item in request.input: + item_type = item.get("type", "message") + if item_type == "message": + role = self._normalize_input_role(item.get("role")) + content_blocks = self._responses_message_content_to_canonical(item.get("content")) + messages.append( + CanonicalMessage( + role=role if role in {"system", "user", "assistant", "tool"} else "user", + content=content_blocks, + metadata={"source_protocol": "openai_responses"}, + ) + ) + elif item_type == "function_call": + call_id = str(item.get("call_id") or f"call_{uuid4().hex}") + tool_name = str(item.get("name", "")) + tool_name_by_call_id[call_id] = tool_name + messages.append( + CanonicalMessage( + role="assistant", + content=[ + CanonicalToolCallBlock( + tool_call=CanonicalToolCall( + id=call_id, + name=tool_name, + arguments=self._parse_json_string_or_mapping(item.get("arguments")), + raw_arguments_text=item.get("arguments") + if isinstance(item.get("arguments"), str) + else None, + metadata={"source_protocol": "openai_responses"}, + ) + ) + ], + metadata={"source_protocol": "openai_responses"}, + ) + ) + elif item_type == "function_call_output": + call_id = str(item.get("call_id") or "") + output = item.get("output") + messages.append( + CanonicalMessage( + role="tool", + content=[ + CanonicalToolResultBlock( + tool_result=CanonicalToolResult( + tool_call_id=call_id, + name=tool_name_by_call_id.get(call_id), + output=output, + output_text=self._serialize_tool_output( + output, tool_name=tool_name_by_call_id.get(call_id) + ), + metadata={"source_protocol": "openai_responses"}, + ) + ) 
+ ], + metadata={"source_protocol": "openai_responses"}, + ) + ) + elif item_type == "reasoning": + reasoning_text = self._responses_reasoning_item_to_text(item) + messages.append( + CanonicalMessage( + role="assistant", + content=[ + CanonicalReasoningBlock( + reasoning=CanonicalReasoning( + steps=[CanonicalReasoningStep(text=reasoning_text)] if reasoning_text else [], + metadata={"source_protocol": "openai_responses"}, + ) + ) + ], + metadata={"source_protocol": "openai_responses"}, + ) + ) + return messages + + def _responses_message_content_to_canonical(self, content: Any) -> list[Any]: + if isinstance(content, str): + return [CanonicalTextBlock(text=content)] if content else [] + if not isinstance(content, list): + return [CanonicalTextBlock(text=str(content))] + + blocks: list[Any] = [] + unsupported_types: list[str] = [] + for part in content: + part_type = part.get("type") + if part_type in {"input_text", "output_text", "text"}: + text = str(part.get("text", "")) + if text: + blocks.append(CanonicalTextBlock(text=text)) + elif part_type in {"summary_text", "reasoning_text"}: + reasoning_text = str(part.get("text", "")) + if reasoning_text: + blocks.append( + CanonicalReasoningBlock( + reasoning=CanonicalReasoning( + steps=[CanonicalReasoningStep(text=reasoning_text)], + metadata={"source_protocol": "openai_responses"}, + ) + ) + ) + else: + unsupported_types.append(str(part_type)) + if unsupported_types: + unsupported_str = ", ".join(sorted(set(unsupported_types))) + raise OpenAIChatAdapterError( + f"Unsupported Responses content block type(s): {unsupported_str}", + "invalid_request_error", + "unsupported_content_block", + ) + return blocks + + def _responses_reasoning_item_to_text(self, item: dict[str, Any]) -> str: + content = item.get("content") + if isinstance(content, list): + chunks = [] + for part in content: + if isinstance(part, dict) and part.get("type") in {"reasoning_text", "summary_text", "text"}: + chunks.append(str(part.get("text", ""))) + if 
chunks: + return "\n".join(chunk for chunk in chunks if chunk) + summary = item.get("summary") + if isinstance(summary, list): + chunks = [str(part.get("text", "")) for part in summary if isinstance(part, dict)] + if chunks: + return "\n".join(chunk for chunk in chunks if chunk) + return str(item.get("text", "")) + + def _responses_tools_to_canonical(self, tools: list[dict[str, Any]] | None) -> list[CanonicalToolDefinition]: + if not tools: + return [] + canonical_tools = [] + for tool in tools: + if tool.get("type") != "function": + continue + tool_name = str(tool.get("name", "")) + if tool_name in self._disabled_tool_names: + continue + canonical_tools.append( + CanonicalToolDefinition( + name=tool_name, + description=tool.get("description"), + parameters_json_schema=tool.get("parameters", {}), + metadata={"source_protocol": "openai_responses"}, + ) + ) + return canonical_tools + + def _responses_tool_choice_to_canonical( + self, tool_choice: str | dict[str, Any] | None + ) -> CanonicalToolChoice | None: + if tool_choice is None: + return None + if isinstance(tool_choice, str): + return CanonicalToolChoice(type=tool_choice) + if tool_choice.get("type") == "function": + return CanonicalToolChoice( + type="specific", + tool_name=tool_choice.get("name"), + metadata={"source_protocol": "openai_responses"}, + ) + return CanonicalToolChoice( + type=str(tool_choice.get("type", "auto")), + metadata={"source_protocol": "openai_responses"}, + ) + + def _canonical_response_to_responses_output_items( + self, + response: CanonicalGenerateResponse, + ) -> list[dict[str, Any]]: + output_items: list[dict[str, Any]] = [] + for block in response.output.content: + if isinstance(block, CanonicalTextBlock): + output_items.append( + { + "id": f"msg_{uuid4().hex}", + "type": "message", + "status": "completed", + "role": "assistant", + "content": [{"type": "output_text", "text": block.text, "annotations": []}], + } + ) + elif isinstance(block, CanonicalToolCallBlock): + 
output_items.append( + { + "id": f"fc_{uuid4().hex}", + "type": "function_call", + "status": "completed", + "call_id": block.tool_call.id, + "name": block.tool_call.name, + "arguments": stringify_tool_arguments(block.tool_call), + } + ) + elif isinstance(block, CanonicalToolResultBlock): + output_items.append( + { + "id": f"fco_{uuid4().hex}", + "type": "function_call_output", + "call_id": block.tool_result.tool_call_id, + "output": block.tool_result.output + if block.tool_result.output is not None + else block.tool_result.output_text, + } + ) + elif isinstance(block, CanonicalReasoningBlock): + reasoning_text = "\n".join(step.text for step in block.reasoning.steps if step.text).strip() + if reasoning_text: + output_items.append( + { + "id": f"rs_{uuid4().hex}", + "type": "reasoning", + "summary": [{"type": "summary_text", "text": reasoning_text}], + } + ) + return output_items + + def _parse_json_string_or_mapping(self, value: Any) -> Any: + if isinstance(value, str): + try: + return json.loads(value) + except Exception: + return {"raw": value} + return value or {} diff --git a/xtuner/v1/rl/gateway/adapters/streaming.py b/xtuner/v1/rl/gateway/adapters/streaming.py new file mode 100644 index 0000000000..41fad73cb7 --- /dev/null +++ b/xtuner/v1/rl/gateway/adapters/streaming.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +import json +from typing import Any + +from fastapi.responses import StreamingResponse + + +def encode_sse_event(data: Any, *, event: str | None = None) -> str: + if isinstance(data, str): + payload = data + else: + payload = json.dumps(data, ensure_ascii=False) + + lines: list[str] = [] + if event is not None: + lines.append(f"event: {event}") + if payload: + lines.extend(f"data: {line}" for line in payload.splitlines()) + else: + lines.append("data:") + return "\n".join(lines) + "\n\n" + + +def build_sse_response(event_iterator) -> StreamingResponse: + return StreamingResponse( + event_iterator, + media_type="text/event-stream", + 
headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) diff --git a/xtuner/v1/rl/gateway/adapters/trace.py b/xtuner/v1/rl/gateway/adapters/trace.py new file mode 100644 index 0000000000..dca5ff4e3b --- /dev/null +++ b/xtuner/v1/rl/gateway/adapters/trace.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +import hashlib +import threading +import time +from collections import OrderedDict +from collections.abc import Sequence +from copy import deepcopy +from dataclasses import dataclass +from typing import Any + +from pydantic import BaseModel + +from xtuner.v1.data_proto.rl_data import Status + + +DEFAULT_CHAT_TRACE_KEY = "__default__" + + +def build_api_key_trace_key(api_key: str | None) -> str: + if not api_key: + return DEFAULT_CHAT_TRACE_KEY + api_key_hash = hashlib.sha256(api_key.encode("utf-8")).hexdigest()[:16] + return f"api_key_{api_key_hash}" + + +def normalize_trace_payload(value: Any) -> Any: + if isinstance(value, BaseModel): + return normalize_trace_payload(value.model_dump(mode="python", exclude_none=True)) + if isinstance(value, dict): + return { + str(key): normalize_trace_payload(val) + for key, val in sorted(value.items(), key=lambda item: str(item[0])) + if val is not None + } + if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)): + return [normalize_trace_payload(item) for item in value] + return value + + +def snapshot_routed_experts(routed_experts: Any) -> Any: + if routed_experts is None: + return None + try: + import ray + + if isinstance(routed_experts, ray.ObjectRef): + return routed_experts + except Exception: + pass + return deepcopy(routed_experts) + + +@dataclass +class ChatTraceRecord: + trace_key: str + request_snapshot: dict[str, Any] + response_snapshot: dict[str, Any] + prompt_ids: list[int] + response_ids: list[int] + input_text: str + output_text: str + logprobs: list[float] | None + routed_experts: Any + finish_reason: str | None + 
status: Status + sequence: int = -1 + created_at: float = 0.0 + request_id: str | None = None + + +class ChatTraceStore: + def __init__(self, max_entries: int = 10000): + self._max_entries = max_entries + self._records: OrderedDict[str, OrderedDict[int, ChatTraceRecord]] = OrderedDict() + self._record_order: OrderedDict[tuple[str, int], None] = OrderedDict() + self._next_sequence: dict[str, int] = {} + self._lock = threading.RLock() + + def append(self, record: ChatTraceRecord) -> ChatTraceRecord: + with self._lock: + sequence = self._next_sequence.get(record.trace_key, 0) + self._next_sequence[record.trace_key] = sequence + 1 + record.sequence = sequence + record.created_at = time.time() + records = self._records.setdefault(record.trace_key, OrderedDict()) + records[sequence] = record + self._record_order[(record.trace_key, sequence)] = None + self._evict_if_needed() + return record + + def get(self, trace_key: str) -> list[ChatTraceRecord]: + with self._lock: + records = self._records.get(trace_key) + if records is None: + return [] + return list(records.values()) + + def pop(self, trace_key: str) -> list[ChatTraceRecord]: + with self._lock: + records = self._records.pop(trace_key, None) + self._next_sequence.pop(trace_key, None) + if records is None: + return [] + for sequence in records: + self._record_order.pop((trace_key, sequence), None) + return list(records.values()) + + def clear(self, trace_key: str) -> None: + with self._lock: + records = self._records.pop(trace_key, None) + self._next_sequence.pop(trace_key, None) + if records is None: + return + for sequence in records: + self._record_order.pop((trace_key, sequence), None) + + def _evict_if_needed(self) -> None: + while len(self._record_order) > self._max_entries: + (trace_key, sequence), _ = self._record_order.popitem(last=False) + records = self._records.get(trace_key) + if records is None: + continue + records.pop(sequence, None) + if not records: + self._records.pop(trace_key, None) + 
self._next_sequence.pop(trace_key, None) diff --git a/xtuner/v1/rl/gateway/backend/__init__.py b/xtuner/v1/rl/gateway/backend/__init__.py new file mode 100644 index 0000000000..00a867413f --- /dev/null +++ b/xtuner/v1/rl/gateway/backend/__init__.py @@ -0,0 +1,8 @@ +from .local_backend import LocalRolloutBackend +from .protocol import GatewayBackend + + +__all__ = [ + "GatewayBackend", + "LocalRolloutBackend", +] diff --git a/xtuner/v1/rl/gateway/backend/local_backend.py b/xtuner/v1/rl/gateway/backend/local_backend.py new file mode 100644 index 0000000000..a2bab2c946 --- /dev/null +++ b/xtuner/v1/rl/gateway/backend/local_backend.py @@ -0,0 +1,419 @@ +from __future__ import annotations + +import json +from typing import Any +from uuid import uuid4 + +import ray +from ray.actor import ActorHandle + +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast +from xtuner.v1.data_proto import RolloutState, RolloutToolCall, SampleParams, Status +from xtuner.v1.rl.rollout.parser.factory import build_tool_call_parser +from xtuner.v1.rl.rollout.worker import RolloutConfig + +from ..adapters.base import coerce_content_to_text +from ..adapters.trace import normalize_trace_payload +from ..core.exceptions import ContextLengthExceededError, ToolCallParseError +from ..core.models import ( + BackendHealth, + CanonicalAssistantTurn, + CanonicalGenerateRequest, + CanonicalGenerateResponse, + CanonicalReasoning, + CanonicalReasoningBlock, + CanonicalReasoningStep, + CanonicalTextBlock, + CanonicalToolCall, + CanonicalToolCallBlock, + CanonicalToolChoice, + CanonicalToolDefinition, + CanonicalToolResultBlock, + CanonicalUsage, + ModelCapabilities, + ModelCard, +) + + +class LocalRolloutBackend: + def __init__( + self, + controller: ActorHandle, + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast | str | None = None, + ): + self._controller = controller + self._config = self._resolve_rollout_config(controller) + if isinstance(tokenizer, str): + 
tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True) + resolved_tokenizer = tokenizer + if resolved_tokenizer is None: + resolved_tokenizer = AutoTokenizer.from_pretrained( + self._config.tokenizer_path, + trust_remote_code=True, + ) + self._tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast = resolved_tokenizer + self._tool_call_parser = build_tool_call_parser(self._config.tool_call_parser) + + async def generate(self, request: CanonicalGenerateRequest) -> CanonicalGenerateResponse: + rollout_state = self._canonical_request_to_rollout_state(request) + rollout_state = await self._controller.generate.remote(rollout_state) + self._raise_for_failed_rollout(rollout_state, request_id=str(rollout_state.uid)) + return self._rollout_state_to_canonical_response(rollout_state, request) + + async def health(self) -> BackendHealth: + ready, details = await self._controller.get_ready_status.remote() + return BackendHealth( + ready=ready, + status="ready" if ready else "unavailable", + details=details, + ) + + async def list_models(self) -> list[ModelCard]: + return [ + ModelCard( + id=self._model_name, + backend=self._config.rollout_backend, + context_length=self._config.context_length, + ) + ] + + async def get_capabilities(self) -> ModelCapabilities: + return ModelCapabilities( + model=self._model_name, + backend=self._config.rollout_backend, + context_length=self._config.context_length, + supports_stream=True, + supports_tools=True, + supports_cancel=False, + supports_parallel_tool_calls=True, + supports_reasoning=True, + ) + + async def cancel(self, request_id: str) -> dict[str, Any]: + return { + "request_id": request_id, + "cancelled": False, + "status": "not_supported", + } + + @property + def _model_name(self) -> str: + return self._config.model_name or "rollout-controller" + + def _resolve_rollout_config(self, controller: ActorHandle) -> RolloutConfig: + rollout_metadata = ray.get(controller.get_rollout_metadata.remote()) + return 
rollout_metadata["rollout_config"] + + def _canonical_request_to_rollout_state(self, canonical_request: CanonicalGenerateRequest) -> RolloutState: + internal_messages = self._canonical_messages_to_backend_messages(canonical_request.messages) + rollout_tools = self._canonical_tools_to_backend(canonical_request.tools) + rollout_tool_choice = self._canonical_tool_choice_to_backend(canonical_request.tool_choice) + prompt_ids = self._render_prompt_ids(internal_messages, rollout_tools) + max_tokens = self._fit_max_tokens_to_context(prompt_ids, canonical_request.max_tokens) + return RolloutState( + uid=uuid4().int, + message=internal_messages, + prompt_ids=prompt_ids, + tokens=prompt_ids, + session_uid=canonical_request.metadata.get("session_uid"), + tools=rollout_tools, + tool_choice=rollout_tool_choice, + sample_params=self._build_sample_params(canonical_request, max_tokens=max_tokens), + ) + + def _raise_for_failed_rollout(self, rollout_state: RolloutState, request_id: str) -> None: + if rollout_state.status == Status.FAILED: + raise RuntimeError(rollout_state.error_msg or f"Rollout generation failed for request {request_id}") + + def _rollout_state_to_canonical_response( + self, + rollout_state: RolloutState, + canonical_request: CanonicalGenerateRequest, + ) -> CanonicalGenerateResponse: + request_id = str(rollout_state.uid) + normal_text = rollout_state.response + tool_calls = [ + self._rollout_tool_call_to_canonical(tool_call) for tool_call in (rollout_state.tool_calls or []) + ] + self._raise_for_unparsed_tool_call_markup( + canonical_request=canonical_request, + normal_text=normal_text, + tool_calls=tool_calls, + ) + reasoning_text = None + if isinstance(rollout_state.extra_fields.get("reasoning_text"), str): + reasoning_text = rollout_state.extra_fields.get("reasoning_text") + content_blocks: list[Any] = [] + if reasoning_text: + content_blocks.append( + CanonicalReasoningBlock( + reasoning=CanonicalReasoning( + 
steps=[CanonicalReasoningStep(text=reasoning_text)], + metadata={"source_backend": "local_rollout"}, + ) + ) + ) + if normal_text: + content_blocks.append(CanonicalTextBlock(text=normal_text)) + for tool_call in tool_calls: + content_blocks.append(CanonicalToolCallBlock(tool_call=tool_call)) + + finish_reason = rollout_state.finish_reason or "stop" + if tool_calls and finish_reason == "stop": + finish_reason = "tool_calls" + + prompt_tokens = len(rollout_state.prompt_ids or []) + completion_tokens = self._count_completion_tokens(rollout_state) + metadata = { + "rollout_trace": self._build_rollout_trace_snapshot(rollout_state), + "parallel_tool_calls": canonical_request.parallel_tool_calls, + "source_backend": "local_rollout", + } + return CanonicalGenerateResponse( + request_id=request_id, + model=canonical_request.model or self._model_name, + output=CanonicalAssistantTurn(content=content_blocks), + finish_reason=finish_reason, + usage=CanonicalUsage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ), + metadata=metadata, + ) + + def _raise_for_unparsed_tool_call_markup( + self, + *, + canonical_request: CanonicalGenerateRequest, + normal_text: str | None, + tool_calls: list[CanonicalToolCall], + ) -> None: + if self._tool_call_parser is None: + return + if self._tool_call_parser.should_reject_unparsed_markup( + has_tools=bool(canonical_request.tools), + text=normal_text, + parsed_tool_calls=tool_calls, + ): + raise ToolCallParseError( + "Tool-enabled generation returned tool-call markup that could not be parsed into structured " + "tool calls." 
+ ) + + def _canonical_messages_to_backend_messages(self, messages: list[Any]) -> list[dict[str, Any]]: + backend_messages: list[dict[str, Any]] = [] + for message in messages: + if message.role == "tool": + for block in message.content: + if isinstance(block, CanonicalToolResultBlock): + backend_messages.append( + { + "role": "tool", + "content": block.tool_result.output_text + if block.tool_result.output_text is not None + else coerce_content_to_text(block.tool_result.output), + "tool_call_id": block.tool_result.tool_call_id, + } + ) + continue + + text_chunks: list[str] = [] + tool_calls: list[dict[str, Any]] = [] + for block in message.content: + if isinstance(block, CanonicalTextBlock): + if block.text: + text_chunks.append(block.text) + elif isinstance(block, CanonicalReasoningBlock): + reasoning_text = "\n".join(step.text for step in block.reasoning.steps if step.text).strip() + if reasoning_text: + text_chunks.append(reasoning_text) + elif isinstance(block, CanonicalToolCallBlock): + tool_calls.append( + { + "id": block.tool_call.id, + "type": "function", + "function": { + "name": block.tool_call.name, + "arguments": self._render_tool_arguments_for_template(block.tool_call), + }, + } + ) + payload: dict[str, Any] = {"role": message.role, "content": "\n".join(text_chunks)} + if message.name: + payload["name"] = message.name + if tool_calls: + payload["tool_calls"] = tool_calls + backend_messages.append(self._normalize_backend_message(payload)) + return backend_messages + + def _canonical_tools_to_backend(self, tools: list[CanonicalToolDefinition]) -> list[dict[str, Any]] | None: + if not tools: + return None + return normalize_trace_payload( + [ + { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description or "", + "parameters": tool.parameters_json_schema, + }, + } + for tool in tools + ] + ) + + def _canonical_tool_choice_to_backend(self, tool_choice: CanonicalToolChoice | None) -> Any: + if tool_choice is None: + return 
None + if tool_choice.type == "specific": + return { + "type": "function", + "function": {"name": tool_choice.tool_name}, + } + return tool_choice.type + + def _render_prompt_ids( + self, + internal_messages: list[dict[str, Any]], + rollout_tools: list[dict[str, Any]] | None, + ) -> list[int] | None: + raw_prompt_ids = self._tokenizer.apply_chat_template( + internal_messages, + tools=rollout_tools, + tokenize=True, + add_generation_prompt=True, + ) + if hasattr(raw_prompt_ids, "get"): + return raw_prompt_ids.get("input_ids") + return list(raw_prompt_ids) + + def _build_sample_params( + self, + canonical_request: CanonicalGenerateRequest, + *, + max_tokens: int | None, + ) -> SampleParams: + kwargs = { + "return_token_ids": True, + "return_logprob": False, + "stream": canonical_request.stream, + "stops": canonical_request.stop, + **{ + key: value + for key, value in { + "n": canonical_request.metadata.get("n"), + "max_tokens": max_tokens if max_tokens is not None else canonical_request.max_tokens, + "temperature": canonical_request.temperature, + "top_p": canonical_request.top_p, + "presence_penalty": canonical_request.metadata.get("presence_penalty"), + "frequency_penalty": canonical_request.metadata.get("frequency_penalty"), + }.items() + if value is not None + }, + } + return SampleParams(**kwargs) + + def _fit_max_tokens_to_context( + self, + prompt_ids: list[int] | None, + requested_max_tokens: int | None, + ) -> int | None: + context_length = self._config.context_length + if context_length is None or prompt_ids is None or requested_max_tokens is None: + return requested_max_tokens + prompt_tokens = len(prompt_ids) + available_completion_tokens = context_length - prompt_tokens + if available_completion_tokens <= 0: + raise ContextLengthExceededError(prompt_tokens=prompt_tokens, context_length=context_length) + return min(requested_max_tokens, available_completion_tokens) + + def _count_completion_tokens(self, rollout_state: RolloutState) -> int: + if 
rollout_state.response_ids is not None: + return len(rollout_state.response_ids) + if rollout_state.response: + return len(self._tokenizer(rollout_state.response, add_special_tokens=False)["input_ids"]) + return 0 + + def _rollout_tool_call_to_canonical(self, tool_call: RolloutToolCall) -> CanonicalToolCall: + return CanonicalToolCall( + id=tool_call.id, + name=tool_call.function.name, + arguments=tool_call.function.arguments, + raw_arguments_text=tool_call.function.raw_arguments_text, + ) + + def _build_rollout_trace_snapshot(self, rollout_state: RolloutState) -> dict[str, Any]: + return { + "session_uid": rollout_state.session_uid, + "status": rollout_state.status.value, + "rollout_finish_reason": rollout_state.finish_reason, + "prompt_ids": list(rollout_state.prompt_ids or []), + "response_ids": list(rollout_state.response_ids or []), + "logprobs": None if rollout_state.logprobs is None else list(rollout_state.logprobs), + "routed_experts": normalize_trace_payload(rollout_state.routed_experts), + "internal_messages": normalize_trace_payload(rollout_state.message), + "rollout_tools": normalize_trace_payload(rollout_state.tools), + "rollout_tool_choice": normalize_trace_payload(rollout_state.tool_choice), + "rollout_sample_params": normalize_trace_payload( + rollout_state.sample_params.model_dump(mode="python", exclude_none=True) + ), + "input_text": self._decode_prompt_ids(rollout_state), + "output_text": self._render_rollout_output_text(rollout_state), + } + + def _render_rollout_output_text(self, rollout_state: RolloutState) -> str: + parts = [] + if rollout_state.response: + parts.append(rollout_state.response) + for rollout_tool_call in rollout_state.tool_calls or []: + tool_call = self._rollout_tool_call_to_canonical(rollout_tool_call) + arguments = self._stringify_tool_arguments(tool_call) + parts.append(f"{arguments}") + return "\n".join(parts) + + def _decode_prompt_ids(self, rollout_state: RolloutState) -> str: + """Decode prompt token IDs to text 
without re-running the chat + template.""" + try: + return self._tokenizer.decode(rollout_state.prompt_ids or [], skip_special_tokens=False) + except Exception: + return "" + + def _stringify_tool_arguments(self, tool_call: CanonicalToolCall) -> str: + if tool_call.raw_arguments_text is not None: + return tool_call.raw_arguments_text + if isinstance(tool_call.arguments, str): + return tool_call.arguments + return json.dumps(tool_call.arguments if tool_call.arguments is not None else {}, ensure_ascii=False) + + def _render_tool_arguments_for_template(self, tool_call: CanonicalToolCall) -> dict[str, Any]: + arguments = tool_call.arguments + if isinstance(arguments, dict): + return arguments + if tool_call.raw_arguments_text is not None: + try: + decoded = json.loads(tool_call.raw_arguments_text) + except Exception: + return {"raw": tool_call.raw_arguments_text} + if isinstance(decoded, dict): + return decoded + return {"value": decoded} + if arguments is None: + return {} + if isinstance(arguments, str): + try: + decoded = json.loads(arguments) + except Exception: + return {"raw": arguments} + if isinstance(decoded, dict): + return decoded + return {"value": decoded} + return {"value": arguments} + + def _normalize_backend_message(self, payload: dict[str, Any]) -> dict[str, Any]: + """Normalize a backend message dict: remove None values and sort keys.""" + return { + str(key): val for key, val in sorted(payload.items(), key=lambda item: str(item[0])) if val is not None + } diff --git a/xtuner/v1/rl/gateway/backend/protocol.py b/xtuner/v1/rl/gateway/backend/protocol.py new file mode 100644 index 0000000000..5fbb2f1446 --- /dev/null +++ b/xtuner/v1/rl/gateway/backend/protocol.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from typing import Any, Protocol + +from ..core.models import ( + BackendHealth, + CanonicalGenerateRequest, + CanonicalGenerateResponse, + ModelCapabilities, + ModelCard, +) + + +class GatewayBackend(Protocol): + async def generate(self, 
request: CanonicalGenerateRequest) -> CanonicalGenerateResponse: ...
+
+    async def health(self) -> BackendHealth: ...
+
+    async def list_models(self) -> list[ModelCard]: ...
+
+    async def get_capabilities(self) -> ModelCapabilities: ...
+
+    async def cancel(self, request_id: str) -> dict[str, Any]: ...
diff --git a/xtuner/v1/rl/gateway/config.py b/xtuner/v1/rl/gateway/config.py
new file mode 100644
index 0000000000..235e56b8d2
--- /dev/null
+++ b/xtuner/v1/rl/gateway/config.py
@@ -0,0 +1,59 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from pathlib import Path
+
+
+@dataclass
+class GatewayConfig:
+    """Configuration for the XTuner gateway HTTP server.
+
+    Examples::
+
+        # Auto-start with RolloutController:
+        cfg = GatewayConfig(port=8080)
+
+        # Opt-out of auto-start (start manually later):
+        cfg = GatewayConfig(port=8080, auto_start=False)
+
+        # With request capture (writes one JSONL file per API key):
+        cfg = GatewayConfig(port=8080, capture_folder="/tmp/gateway_captures")
+    """
+
+    _CAPTURE_PATH_FOLDER = "gateway_captures"
+    port: int
+    """TCP port to bind the server on."""
+
+    host: str = "0.0.0.0"
+    """Interface to bind the server on."""
+
+    auto_start: bool = True
+    """Whether to start the gateway automatically when the RolloutController
+    initialises.
+
+    Set to False if you want to start the gateway manually via
+    :func:`~xtuner.v1.rl.gateway.serve_gateway` or
+    :meth:`~xtuner.v1.rl.rollout.controller.RolloutController.start_gateway`.
+    """
+
+    capture_folder: str | None = None
+    """Optional folder for writing per-request trace records.
+
+    The gateway writes one JSONL file per API key inside this folder. If
+    omitted, this resolves to ``./worker_dirs/gateway_captures``; when started
+    by :class:`~xtuner.v1.rl.rollout.controller.RolloutController`, an omitted
+    value resolves relative to ``RolloutConfig.worker_log_dir`` instead.
+ """ + title: str = "XTuner Gateway" + """FastAPI application title shown in /docs.""" + + version: str = "0.1.0" + """FastAPI application version string.""" + + log_level: str = "warning" + """Uvicorn log level (debug/info/warning/error/critical).""" + + def __post_init__(self) -> None: + if self.capture_folder is None: + self.capture_folder = str(Path.cwd() / "worker_dirs" / self._CAPTURE_PATH_FOLDER) + print(f"GatewayConfig.capture_folder is not specified, use default capture_folder: {self.capture_folder}") diff --git a/xtuner/v1/rl/gateway/core/__init__.py b/xtuner/v1/rl/gateway/core/__init__.py new file mode 100644 index 0000000000..7e0210c672 --- /dev/null +++ b/xtuner/v1/rl/gateway/core/__init__.py @@ -0,0 +1,49 @@ +from .exceptions import ContextLengthExceededError, GatewayError, GatewayStateError, ModelNotFoundError +from .models import ( + BackendHealth, + CanonicalAssistantTurn, + CanonicalContentBlock, + CanonicalGenerateRequest, + CanonicalGenerateResponse, + CanonicalMessage, + CanonicalReasoning, + CanonicalReasoningBlock, + CanonicalReasoningStep, + CanonicalTextBlock, + CanonicalToolCall, + CanonicalToolCallBlock, + CanonicalToolChoice, + CanonicalToolDefinition, + CanonicalToolResult, + CanonicalToolResultBlock, + CanonicalUsage, + ModelCapabilities, + ModelCard, +) + + +__all__ = [ + "BackendHealth", + "CanonicalAssistantTurn", + "CanonicalContentBlock", + "CanonicalGenerateRequest", + "CanonicalGenerateResponse", + "CanonicalMessage", + "CanonicalReasoning", + "CanonicalReasoningBlock", + "CanonicalReasoningStep", + "CanonicalTextBlock", + "CanonicalToolCall", + "CanonicalToolCallBlock", + "CanonicalToolChoice", + "CanonicalToolDefinition", + "CanonicalToolResult", + "CanonicalToolResultBlock", + "CanonicalUsage", + "ContextLengthExceededError", + "GatewayError", + "GatewayStateError", + "ModelCapabilities", + "ModelCard", + "ModelNotFoundError", +] diff --git a/xtuner/v1/rl/gateway/core/exceptions.py b/xtuner/v1/rl/gateway/core/exceptions.py 
new file mode 100644 index 0000000000..7d52f2e626 --- /dev/null +++ b/xtuner/v1/rl/gateway/core/exceptions.py @@ -0,0 +1,28 @@ +class GatewayError(RuntimeError): + """Base exception for gateway failures.""" + + +class GatewayStateError(GatewayError): + """Raised when the gateway app is missing required runtime state.""" + + +class ModelNotFoundError(GatewayError): + """Raised when a requested model is not exposed by the backend.""" + + def __init__(self, model: str): + super().__init__(f"Model '{model}' is not available.") + self.model = model + + +class ContextLengthExceededError(GatewayError): + """Raised when the prompt is too long for the model's context window.""" + + def __init__(self, prompt_tokens: int, context_length: int): + super().__init__(f"Input is too long: prompt_tokens={prompt_tokens}, context_length={context_length}.") + self.prompt_tokens = prompt_tokens + self.context_length = context_length + + +class ToolCallParseError(GatewayError): + """Raised when a tool-enabled response contains tool-call markup that could + not be parsed into structured tool calls.""" diff --git a/xtuner/v1/rl/gateway/core/models.py b/xtuner/v1/rl/gateway/core/models.py new file mode 100644 index 0000000000..976011752a --- /dev/null +++ b/xtuner/v1/rl/gateway/core/models.py @@ -0,0 +1,153 @@ +from __future__ import annotations + +from typing import Annotated, Any, Literal, TypeAlias + +from pydantic import BaseModel, ConfigDict, Field, model_validator + + +class GatewayCoreModel(BaseModel): + model_config = ConfigDict(extra="forbid") + + +class CanonicalToolDefinition(GatewayCoreModel): + name: str + description: str | None = None + parameters_json_schema: dict[str, Any] = Field(default_factory=dict) + metadata: dict[str, Any] = Field(default_factory=dict) + + +class CanonicalToolChoice(GatewayCoreModel): + type: Literal["auto", "none", "required", "specific"] = "auto" + tool_name: str | None = None + metadata: dict[str, Any] = Field(default_factory=dict) + + 
@model_validator(mode="after") + def validate_specific_choice(self) -> CanonicalToolChoice: + if self.type == "specific" and not self.tool_name: + raise ValueError("tool_name is required when tool choice type is 'specific'.") + return self + + +class CanonicalToolCall(GatewayCoreModel): + id: str + name: str + arguments: Any = None + raw_arguments_text: str | None = None + metadata: dict[str, Any] = Field(default_factory=dict) + + +class CanonicalToolResult(GatewayCoreModel): + tool_call_id: str + name: str | None = None + output: Any = None + output_text: str | None = None + is_error: bool = False + metadata: dict[str, Any] = Field(default_factory=dict) + + +class CanonicalReasoningStep(GatewayCoreModel): + text: str + metadata: dict[str, Any] = Field(default_factory=dict) + + +class CanonicalReasoning(GatewayCoreModel): + steps: list[CanonicalReasoningStep] = Field(default_factory=list) + metadata: dict[str, Any] = Field(default_factory=dict) + + +class CanonicalTextBlock(GatewayCoreModel): + type: Literal["text"] = "text" + text: str + metadata: dict[str, Any] = Field(default_factory=dict) + + +class CanonicalToolCallBlock(GatewayCoreModel): + type: Literal["tool_call"] = "tool_call" + tool_call: CanonicalToolCall + + +class CanonicalToolResultBlock(GatewayCoreModel): + type: Literal["tool_result"] = "tool_result" + tool_result: CanonicalToolResult + + +class CanonicalReasoningBlock(GatewayCoreModel): + type: Literal["reasoning"] = "reasoning" + reasoning: CanonicalReasoning + + +CanonicalContentBlock: TypeAlias = Annotated[ + CanonicalTextBlock | CanonicalToolCallBlock | CanonicalToolResultBlock | CanonicalReasoningBlock, + Field(discriminator="type"), +] + + +class CanonicalMessage(GatewayCoreModel): + role: Literal["system", "user", "assistant", "tool"] + content: list[CanonicalContentBlock] = Field(default_factory=list) + name: str | None = None + metadata: dict[str, Any] = Field(default_factory=dict) + + +class CanonicalGenerateRequest(GatewayCoreModel): + 
request_id: str + model: str + messages: list[CanonicalMessage] = Field(default_factory=list) + tools: list[CanonicalToolDefinition] = Field(default_factory=list) + tool_choice: CanonicalToolChoice | None = None + parallel_tool_calls: bool | None = None + temperature: float | None = None + top_p: float | None = None + max_tokens: int | None = None + stop: list[str] = Field(default_factory=list) + stream: bool = False + metadata: dict[str, Any] = Field(default_factory=dict) + + +class CanonicalUsage(GatewayCoreModel): + prompt_tokens: int = 0 + completion_tokens: int = 0 + total_tokens: int = 0 + + +class CanonicalAssistantTurn(GatewayCoreModel): + role: Literal["assistant"] = "assistant" + content: list[CanonicalContentBlock] = Field(default_factory=list) + metadata: dict[str, Any] = Field(default_factory=dict) + + +class CanonicalGenerateResponse(GatewayCoreModel): + request_id: str + model: str + output: CanonicalAssistantTurn + finish_reason: str = "stop" + usage: CanonicalUsage = Field(default_factory=CanonicalUsage) + metadata: dict[str, Any] = Field(default_factory=dict) + + +class ModelCard(GatewayCoreModel): + id: str + backend: str + context_length: int | None = None + owned_by: str | None = None + metadata: dict[str, Any] = Field(default_factory=dict) + + +class ModelCapabilities(GatewayCoreModel): + model: str + backend: str + context_length: int | None = None + supports_stream: bool = False + supports_tools: bool = False + supports_cancel: bool = False + supports_parallel_tool_calls: bool = False + supports_reasoning: bool = False + metadata: dict[str, Any] = Field(default_factory=dict) + + +class BackendHealth(GatewayCoreModel): + ready: bool + status: Literal["ready", "degraded", "unavailable"] + details: dict[str, Any] = Field(default_factory=dict) + reason: str | None = None + metadata: dict[str, Any] = Field(default_factory=dict) diff --git a/xtuner/v1/rl/gateway/server/__init__.py b/xtuner/v1/rl/gateway/server/__init__.py new file mode 100644 
index 0000000000..f4f09a5377 --- /dev/null +++ b/xtuner/v1/rl/gateway/server/__init__.py @@ -0,0 +1,4 @@ +from .app import build_gateway_app, build_local_gateway_app, serve_gateway, serve_gateway_in_thread + + +__all__ = ["build_gateway_app", "build_local_gateway_app", "serve_gateway", "serve_gateway_in_thread"] diff --git a/xtuner/v1/rl/gateway/server/app.py b/xtuner/v1/rl/gateway/server/app.py new file mode 100644 index 0000000000..56cb3ea1ad --- /dev/null +++ b/xtuner/v1/rl/gateway/server/app.py @@ -0,0 +1,254 @@ +from __future__ import annotations + +import socket +import threading +from typing import Union + +import ray +import uvicorn +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse +from ray.actor import ActorHandle + +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast +from xtuner.v1.rl.rollout.worker import RolloutConfig + +from ..adapters import AnthropicChatAdapter, ChatTraceStore, OpenAIChatAdapter +from ..adapters.responses import OpenAIResponsesAdapter +from ..backend.local_backend import LocalRolloutBackend +from ..backend.protocol import GatewayBackend +from ..config import GatewayConfig +from ..core.exceptions import ContextLengthExceededError, GatewayError, ToolCallParseError +from .routes import ( + build_anthropic_router, + build_openai_router, + build_responses_router, + build_runtime_router, + build_trace_store_router, +) + + +TokenizerArg = Union[PreTrainedTokenizer, PreTrainedTokenizerFast, str] + + +# --------------------------------------------------------------------------- +# Internal base builder +# --------------------------------------------------------------------------- + + +def _create_base_gateway_app( + backend: GatewayBackend, + *, + title: str = "XTuner Gateway", + version: str = "0.1.0", +) -> FastAPI: + """Create the base FastAPI app with runtime routes and global error + handlers. + + This is an internal builder used by higher-level factory functions. 
The returned app exposes /livez, /readyz, and + /capabilities but no protocol-specific endpoints. + """ + app = FastAPI(title=title, version=version) + app.state.gateway_backend = backend + app.include_router(build_runtime_router()) + + @app.exception_handler(ContextLengthExceededError) + async def context_length_error_handler(request: Request, exc: ContextLengthExceededError) -> JSONResponse: + return JSONResponse( + status_code=400, + content={"error": {"message": str(exc), "type": "context_length_exceeded", "code": "context_too_long"}}, + ) + + @app.exception_handler(GatewayError) + async def gateway_error_handler(request: Request, exc: GatewayError) -> JSONResponse: + return JSONResponse( + status_code=500, + content={"error": {"message": str(exc), "type": type(exc).__name__, "code": "gateway_error"}}, + ) + + @app.exception_handler(ToolCallParseError) + async def tool_call_parse_error_handler(request: Request, exc: ToolCallParseError) -> JSONResponse: + return JSONResponse( + status_code=400, + content={"error": {"message": str(exc), "type": "tool_call_parse_error", "code": "tool_call_parse_error"}}, + ) + + @app.exception_handler(Exception) + async def generic_error_handler(request: Request, exc: Exception) -> JSONResponse: + return JSONResponse( + status_code=500, + content={"error": {"message": str(exc), "type": "internal_error", "code": "internal_server_error"}}, + ) + + return app + + +# --------------------------------------------------------------------------- +# Generic public factory (works with any GatewayBackend) +# --------------------------------------------------------------------------- + + +def build_gateway_app( + backend: GatewayBackend, + *, + tokenizer: TokenizerArg, + model_name: str, + context_length: int, + config: GatewayConfig | None = None, +) -> FastAPI: + """Build a gateway FastAPI app wired to *any* :class:`GatewayBackend`. + + This is the lowest-level public factory. Use this when you have a custom + backend (e.g. 
a future ``RemoteRolloutBackend``) and want to wire it into + the full gateway stack (OpenAI / Anthropic / Responses endpoints). + + Args: + backend: An object that satisfies the :class:`~xtuner.v1.rl.gateway.backend.protocol.GatewayBackend` protocol. + tokenizer: Tokenizer used for prompt encoding and token-count helpers. + Accepts a :class:`~transformers.PreTrainedTokenizer`, + :class:`~transformers.PreTrainedTokenizerFast`, or a **string** + path/identifier which will be loaded via + :func:`~transformers.AutoTokenizer.from_pretrained`. + model_name: Default model name reported by the ``/capabilities`` endpoint. + context_length: Maximum context length enforced by the gateway. + config: Gateway configuration (title, version, capture_folder, ...). + Defaults to a bare :class:`~xtuner.v1.rl.gateway.config.GatewayConfig` + with ``port=8080`` when not provided. + + Returns: + A fully-configured :class:`fastapi.FastAPI` instance ready to serve. + """ + if isinstance(tokenizer, str): + tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True) + cfg = config or GatewayConfig(port=8080) + app = _create_base_gateway_app(backend, title=cfg.title, version=cfg.version) + app.state.gateway_trace_store = ChatTraceStore() + adapter_kwargs = { + "generate_handler": backend.generate, + "tokenizer": tokenizer, + "default_model_name": model_name, + "context_length": context_length, + "capture_folder": cfg.capture_folder, + "trace_store": app.state.gateway_trace_store, + } + app.state.gateway_openai_adapter = OpenAIChatAdapter(**adapter_kwargs) + app.state.gateway_anthropic_adapter = AnthropicChatAdapter(**adapter_kwargs) + app.state.gateway_responses_adapter = OpenAIResponsesAdapter(**adapter_kwargs) + app.include_router(build_openai_router()) + app.include_router(build_anthropic_router()) + app.include_router(build_responses_router()) + app.include_router(build_trace_store_router()) + return app + + +# 
--------------------------------------------------------------------------- +# LocalRolloutBackend convenience factory +# --------------------------------------------------------------------------- + + +def build_local_gateway_app( + controller: ActorHandle, + config: GatewayConfig | None = None, +) -> FastAPI: + """Build a gateway app backed by a Ray-actor RolloutController.""" + cfg = config or GatewayConfig(port=8080) + rollout_metadata = ray.get(controller.get_rollout_metadata.remote()) + rollout_config: RolloutConfig = rollout_metadata["rollout_config"] + tokenizer = AutoTokenizer.from_pretrained(rollout_config.tokenizer_path, trust_remote_code=True) + + model_name = rollout_config.model_name + if model_name is None: + raise ValueError("controller.config.model_name must be set when building a local gateway app") + context_length = rollout_config.context_length + if context_length is None: + raise ValueError("controller.config.context_length must be set when building a local gateway app") + + backend = LocalRolloutBackend(controller, tokenizer=tokenizer) + return build_gateway_app( + backend, + tokenizer=tokenizer, + model_name=model_name, + context_length=context_length, + config=cfg, + ) + + +# --------------------------------------------------------------------------- +# Standalone serve helpers +# --------------------------------------------------------------------------- + + +def serve_gateway(app: FastAPI, config: GatewayConfig) -> None: + """Start the gateway server in the **current thread** (blocking). 
+ + Use this for a fully standalone gateway process:: + + from xtuner.v1.rl.gateway import ( + GatewayConfig, build_local_gateway_app, serve_gateway + ) + + config = GatewayConfig(port=8080, auto_start=False) + app = build_local_gateway_app(controller, config) + serve_gateway(app, config) # blocks until interrupted + + For a custom backend:: + + from xtuner.v1.rl.gateway import ( + GatewayConfig, build_gateway_app, serve_gateway + ) + + config = GatewayConfig(port=8080) + app = build_gateway_app( + my_backend, + tokenizer=tokenizer, + model_name="my-model", + context_length=32768, + config=config, + ) + serve_gateway(app, config) + + Args: + app: A FastAPI application previously built by :func:`build_gateway_app` + or :func:`build_local_gateway_app`. + config: Gateway configuration supplying ``host``, ``port``, and + ``log_level``. + """ + _ensure_gateway_port_available(config) + uvicorn.run(app, host=config.host, port=config.port, log_level=config.log_level) + + +def serve_gateway_in_thread(app: FastAPI, config: GatewayConfig) -> threading.Thread: + """Start the gateway server in a **daemon thread** (non-blocking). + + Returns the :class:`threading.Thread` that is running uvicorn so callers + can monitor it if needed. The thread is daemonised so it will not prevent + the process from exiting. + + Args: + app: A FastAPI application previously built by :func:`build_gateway_app` + or :func:`build_local_gateway_app`. + config: Gateway configuration supplying ``host``, ``port``, and + ``log_level``. + + Returns: + The started daemon thread. 
+ """ + thread = threading.Thread( + target=serve_gateway, + args=(app, config), + daemon=True, + name="gateway-server", + ) + thread.start() + return thread + + +def _ensure_gateway_port_available(config: GatewayConfig) -> None: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind((config.host, config.port)) + return + except OSError: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind((config.host, 0)) + config.port = int(sock.getsockname()[1]) diff --git a/xtuner/v1/rl/gateway/server/routes.py b/xtuner/v1/rl/gateway/server/routes.py new file mode 100644 index 0000000000..b80217170f --- /dev/null +++ b/xtuner/v1/rl/gateway/server/routes.py @@ -0,0 +1,287 @@ +from __future__ import annotations + +from enum import Enum +from typing import Any, cast + +from fastapi import APIRouter, Depends, Query, Request +from fastapi.responses import JSONResponse +from pydantic import BaseModel + +from ..adapters import ( + AnthropicChatAdapter, + AnthropicChatAdapterError, + AnthropicMessagesRequest, + ChatCompletionRequest, + ChatTraceRecord, + ChatTraceStore, + OpenAIChatAdapter, + OpenAIChatAdapterError, + ResponsesRequest, + build_api_key_trace_key, +) +from ..adapters.responses import OpenAIResponsesAdapter +from ..backend.protocol import GatewayBackend +from ..core.exceptions import GatewayStateError + + +def get_openai_adapter(request: Request) -> OpenAIChatAdapter: + adapter = getattr(request.app.state, "gateway_openai_adapter", None) + if adapter is None: + raise GatewayStateError("Gateway OpenAI adapter is not configured.") + return cast(OpenAIChatAdapter, adapter) + + +def get_anthropic_adapter(request: Request) -> AnthropicChatAdapter: + adapter = getattr(request.app.state, "gateway_anthropic_adapter", None) + if adapter is None: + raise GatewayStateError("Gateway Anthropic adapter is not configured.") + return cast(AnthropicChatAdapter, adapter) + + +def get_responses_adapter(request: Request) -> 
OpenAIResponsesAdapter: + adapter = getattr(request.app.state, "gateway_responses_adapter", None) + if adapter is None: + raise GatewayStateError("Gateway Responses adapter is not configured.") + return cast(OpenAIResponsesAdapter, adapter) + + +def extract_api_key(request: Request) -> str | None: + authorization = request.headers.get("authorization") + if authorization: + scheme, _, credentials = authorization.partition(" ") + if scheme.lower() == "bearer" and credentials.strip(): + return credentials.strip() + if authorization.strip(): + return authorization.strip() + + api_key = request.headers.get("x-api-key") or request.headers.get("api-key") + if api_key and api_key.strip(): + return api_key.strip() + return None + + +# --------------------------------------------------------------------------- +# Runtime router (/livez, /readyz, /capabilities) +# --------------------------------------------------------------------------- + + +def build_runtime_router() -> APIRouter: + router = APIRouter() + + @router.get("/livez") + async def livez() -> dict[str, str]: + return {"status": "ok"} + + @router.get("/readyz") + async def readyz(request: Request): + backend = _get_backend(request) + health = await backend.health() + payload = health.model_dump(mode="json") + if health.ready: + return payload + return JSONResponse(status_code=503, content=payload) + + @router.get("/capabilities") + async def get_capabilities(request: Request): + backend = _get_backend(request) + capabilities = await backend.get_capabilities() + return capabilities.model_dump(mode="json") + + return router + + +def _get_backend(request: Request) -> GatewayBackend: + backend = getattr(request.app.state, "gateway_backend", None) + if backend is None: + raise GatewayStateError("Gateway backend is not configured.") + return cast(GatewayBackend, backend) + + +# --------------------------------------------------------------------------- +# Trace store router (/trace_store) +# 
--------------------------------------------------------------------------- + + +def build_trace_store_router() -> APIRouter: + router = APIRouter() + + @router.get("/trace_store") + async def get_trace_records( + request: Request, + trace_key: str | None = Query(default=None), + ) -> dict: + trace_store = _get_trace_store(request) + resolved_trace_key = _resolve_trace_key(request, trace_key) + records = trace_store.get(resolved_trace_key) + return _build_trace_store_response(resolved_trace_key, records) + + @router.post("/trace_store/pop") + async def pop_trace_records( + request: Request, + trace_key: str | None = Query(default=None), + ) -> dict: + trace_store = _get_trace_store(request) + resolved_trace_key = _resolve_trace_key(request, trace_key) + records = trace_store.pop(resolved_trace_key) + return _build_trace_store_response(resolved_trace_key, records) + + @router.post("/trace_store/clear") + async def clear_trace_records( + request: Request, + trace_key: str | None = Query(default=None), + ) -> dict: + trace_store = _get_trace_store(request) + resolved_trace_key = _resolve_trace_key(request, trace_key) + trace_store.clear(resolved_trace_key) + return { + "trace_key": resolved_trace_key, + "cleared": True, + } + + return router + + +def _get_trace_store(request: Request) -> ChatTraceStore: + trace_store = getattr(request.app.state, "gateway_trace_store", None) + if trace_store is None: + raise GatewayStateError("Gateway trace store is not configured.") + return cast(ChatTraceStore, trace_store) + + +def _resolve_trace_key(request: Request, trace_key: str | None) -> str: + if trace_key: + return trace_key + return build_api_key_trace_key(extract_api_key(request)) + + +def _build_trace_store_response(trace_key: str, records: list[ChatTraceRecord]) -> dict[str, Any]: + return { + "trace_key": trace_key, + "count": len(records), + "records": [_serialize_trace_record(record) for record in records], + } + + +def _serialize_trace_record(record: ChatTraceRecord) 
-> dict[str, Any]: + return { + "trace_key": record.trace_key, + "request_snapshot": _serialize_trace_value(record.request_snapshot), + "response_snapshot": _serialize_trace_value(record.response_snapshot), + "prompt_ids": list(record.prompt_ids), + "response_ids": list(record.response_ids), + "input_text": record.input_text, + "output_text": record.output_text, + "logprobs": _serialize_trace_value(record.logprobs), + "routed_experts": _serialize_trace_value(record.routed_experts), + "finish_reason": record.finish_reason, + "status": _serialize_trace_value(record.status), + "sequence": record.sequence, + "created_at": record.created_at, + "request_id": record.request_id, + } + + +def _serialize_trace_value(value: Any) -> Any: + if value is None or isinstance(value, (str, int, float, bool)): + return value + if isinstance(value, BaseModel): + try: + return _serialize_trace_value(value.model_dump(mode="json", exclude_none=True)) + except Exception: + return _serialize_trace_value(value.model_dump(mode="python", exclude_none=True)) + if isinstance(value, Enum): + return value.value + if isinstance(value, dict): + return {str(key): _serialize_trace_value(val) for key, val in value.items() if val is not None} + if isinstance(value, (list, tuple, set)): + return [_serialize_trace_value(item) for item in value] + if isinstance(value, bytes): + return value.decode("utf-8", errors="replace") + try: + import ray + + if isinstance(value, ray.ObjectRef): + return str(value) + except Exception: + pass + if hasattr(value, "tolist"): + try: + return _serialize_trace_value(value.tolist()) + except Exception: + pass + return str(value) + + +# --------------------------------------------------------------------------- +# OpenAI Chat Completions router (/v1/chat/completions) +# --------------------------------------------------------------------------- + + +def build_openai_router() -> APIRouter: + router = APIRouter() + + @router.post("/v1/chat/completions") + async def 
chat_completions( + request_body: ChatCompletionRequest, + request: Request, + adapter: OpenAIChatAdapter = Depends(get_openai_adapter), + ): + try: + return await adapter.chat(request_body, api_key=extract_api_key(request)) + except OpenAIChatAdapterError as exc: + return JSONResponse( + status_code=400 if exc.error_type == "invalid_request_error" else 500, + content={"error": {"message": exc.message, "type": exc.error_type, "code": exc.code}}, + ) + + return router + + +# --------------------------------------------------------------------------- +# Anthropic Messages router (/v1/messages) +# --------------------------------------------------------------------------- + + +def build_anthropic_router() -> APIRouter: + router = APIRouter() + + @router.post("/v1/messages") + async def messages( + request_body: AnthropicMessagesRequest, + request: Request, + adapter: AnthropicChatAdapter = Depends(get_anthropic_adapter), + ): + try: + return await adapter.messages(request_body, api_key=extract_api_key(request)) + except AnthropicChatAdapterError as exc: + return JSONResponse( + status_code=400 if exc.error_type == "invalid_request_error" else 500, + content={"type": "error", "error": {"type": exc.error_type, "message": exc.message}}, + ) + + return router + + +# --------------------------------------------------------------------------- +# OpenAI Responses router (/v1/responses) +# --------------------------------------------------------------------------- + + +def build_responses_router() -> APIRouter: + router = APIRouter() + + @router.post("/v1/responses") + async def responses( + request_body: ResponsesRequest, + request: Request, + adapter: OpenAIResponsesAdapter = Depends(get_responses_adapter), + ): + try: + return await adapter.responses(request_body, api_key=extract_api_key(request)) + except OpenAIChatAdapterError as exc: + return JSONResponse( + status_code=400 if exc.error_type == "invalid_request_error" else 500, + content={"error": {"message": 
exc.message, "type": exc.error_type, "code": exc.code}}, + ) + + return router diff --git a/xtuner/v1/rl/grpo/__init__.py b/xtuner/v1/rl/grpo/__init__.py deleted file mode 100644 index cdcff8f252..0000000000 --- a/xtuner/v1/rl/grpo/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .loss import GRPOLossConfig, GRPOLossContext - - -__all__ = [ - "GRPOLossConfig", - "GRPOLossContext", -] diff --git a/xtuner/v1/rl/judger/__init__.py b/xtuner/v1/rl/judger/__init__.py new file mode 100644 index 0000000000..1f260f132c --- /dev/null +++ b/xtuner/v1/rl/judger/__init__.py @@ -0,0 +1,21 @@ +from .composed import ( + ComposedJudger, + ComposedJudgerConfig, + default_merge_fn, + default_select_fn, +) +from .dapo_math import DapoMathJudgerConfig +from .factory import ( + build_judger, +) +from .geo3k import GEO3KJudgerConfig +from .gsm8k import GSM8KJudgerConfig +from .native import ( + Judger, + JudgerConfig, + JudgerPool, + NativeJudger, + RayJudger, + RayJudgerProxy, + RemoteJudger, +) diff --git a/xtuner/v1/rl/judger/composed.py b/xtuner/v1/rl/judger/composed.py new file mode 100644 index 0000000000..027afc8ddc --- /dev/null +++ b/xtuner/v1/rl/judger/composed.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +from typing import Callable, TypeAlias + +from pydantic import BaseModel, ConfigDict, Field +from ray.util.placement_group import PlacementGroup + +from xtuner.v1.data_proto import RolloutState + +from .native import Judger, JudgerConfig + + +SelectedJudgerKeys: TypeAlias = str | list[str] | None +JudgerSelectFn: TypeAlias = Callable[[RolloutState, dict[str, Judger]], SelectedJudgerKeys] +JudgerMergeFn: TypeAlias = Callable[[RolloutState, dict[str, RolloutState]], RolloutState] + + +def default_select_fn(rollout_state: RolloutState, branches: dict[str, Judger]) -> SelectedJudgerKeys: + """Default branch selector for ``ComposedJudgerConfig``. + + Selection order is intentionally simple: + 1. 
def default_merge_fn(original: RolloutState, judged: dict[str, RolloutState]) -> RolloutState:
    """Default merger for ``ComposedJudgerConfig``.

    This merger intentionally does not combine multiple judger scores into a single aggregated value.
    It writes the merged reward as ``{branch_name: score}``, where ``branch_name`` is the selected
    key from ``ComposedJudgerConfig.branches`` and ``score`` is taken from each child judger's
    ``reward["score"]``.

    Users who need weighted sums, richer reward payloads, or custom post-processing should provide
    their own ``merge_fn``.

    Raises:
        KeyError: If any judged branch is missing ``reward["score"]``.
    """
    merged = original.model_copy(deep=True)
    merged.reward = {}

    for name, state in judged.items():
        if state.reward is None or "score" not in state.reward:
            raise KeyError(f"Default merge_fn requires reward['score'] for branch {name!r}.")
        merged.reward[name] = state.reward["score"]

    return merged


class ComposedJudger(Judger):
    """Route each rollout sample to one or more child judgers and merge their rewards."""

    def __init__(
        self,
        branches: dict[str, Judger],
        select_fn: JudgerSelectFn = default_select_fn,
        merge_fn: JudgerMergeFn = default_merge_fn,
        default_key: str | None = "default",
    ):
        if not branches:
            raise ValueError("ComposedJudger requires at least one branch.")
        self.branches = branches
        self.select_fn = select_fn
        self.merge_fn = merge_fn
        self.default_key = default_key

    def _resolve_selected_keys(self, rollout_state: RolloutState) -> list[str]:
        """Normalise the select_fn result to a de-duplicated, non-empty key list.

        Fallback order when select_fn yields nothing: ``default_key`` if it
        names a real branch, then the sole branch if there is exactly one,
        otherwise a KeyError describing the sample.
        """
        selected = self.select_fn(rollout_state, self.branches)

        if selected is None:
            selected_keys: list[str] = []
        elif isinstance(selected, str):
            selected_keys = [selected]
        else:
            # dict.fromkeys de-duplicates while preserving order.
            selected_keys = list(dict.fromkeys(selected))

        if not selected_keys:
            if self.default_key is not None and self.default_key in self.branches:
                return [self.default_key]
            if len(self.branches) == 1:
                return [next(iter(self.branches))]
            raise KeyError(
                f"ComposedJudger could not select a branch for task_name={rollout_state.task_name!r}, "
                f"data_source={rollout_state.data_source!r}, available={sorted(self.branches)}"
            )
        return selected_keys

    async def judge(self, rollout_state: RolloutState) -> RolloutState:
        """Run the selected branches (sequentially) on deep copies and merge results.

        Each branch receives its own deep copy so branches cannot observe each
        other's mutations. NOTE(review): branches run one after another; gather
        them concurrently only after confirming child judgers are re-entrant.
        """
        selected_keys = self._resolve_selected_keys(rollout_state)

        judged: dict[str, RolloutState] = {}
        for key in selected_keys:
            if key not in self.branches:
                raise KeyError(f"Unknown judger branch: {key}, available={sorted(self.branches)}")
            judged[key] = await self.branches[key].judge(rollout_state.model_copy(deep=True))
        return self.merge_fn(rollout_state, judged)


class ComposedJudgerConfig(BaseModel):
    """Config describing a named set of judger branches plus routing/merging hooks."""

    model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")

    branches: dict[str, JudgerConfigLike]
    # ``select_fn`` chooses which branch keys should be executed for one sample.
    # Return a single string for single-judger routing, a list of strings for multi-judger execution,
    # or ``None`` to fall back to ``default_key`` / single-branch implicit fallback.
    select_fn: JudgerSelectFn = Field(default=default_select_fn, exclude=True)
    # ``merge_fn`` merges the judged rollout states back into one rollout state.
    # The default implementation does not aggregate scores; it writes ``{branch_name: score}``.
    merge_fn: JudgerMergeFn | None = Field(default=None, exclude=True)
    default_key: str | None = "default"

    def get_num_placement_group_bundles(self) -> int:
        """Total Ray placement-group bundles required across all branches."""
        return sum(branch.get_num_placement_group_bundles() for branch in self.branches.values())

    def build(self, pg: PlacementGroup | None = None, start_bundle_idx: int = 0) -> Judger:
        """Materialise the composed judger via the factory (local import avoids a cycle)."""
        from .factory import build_judger

        return build_judger(self, pg=pg, start_bundle_idx=start_bundle_idx)


JudgerConfigLike: TypeAlias = JudgerConfig | ComposedJudgerConfig

# Resolve the forward reference to JudgerConfigLike used in ``branches`` above.
ComposedJudgerConfig.model_rebuild()


# --- file: xtuner/v1/rl/judger/dapo_math.py (renamed from
# xtuner/v1/ray/judger/dapo_math.py; only the changed regions are shown —
# compute_reward and the other helpers between these hunks are unchanged) -----
import re
from typing import Any, Callable, List, Optional, Tuple

from pydantic import Field, model_validator

from .native import JudgerConfig


# ... (compute_reward and scoring helpers unchanged; not visible in this chunk) ...


class DapoMathJudgerConfig(JudgerConfig):
    """DAPO math judger config; packs reward-function kwargs into ``extra_info``.

    Replaces the former hand-written ``__init__`` with a pydantic
    ``model_validator`` so field validation and ``extra_info`` packing run on
    every construction path.
    """

    eos_token: List[str] | str
    enable_overlong_buffer: bool
    score: int = 1
    # ... (fields between ``score`` and ``overlong_buffer_len`` — e.g.
    # format_score, max_response_len — are hidden hunk context; unchanged) ...
    overlong_buffer_len: Optional[int] = None
    overlong_penalty_factor: Optional[float] = None
    tokenizer: Any = Field(default=None, exclude=True)
    reward_handler: Callable | str = Field(default=compute_reward, exclude=True)
    extra_info: dict = Field(default_factory=dict, exclude=True)

    @model_validator(mode="after")
    def _pack_extra_info(self) -> "DapoMathJudgerConfig":
        """Validate ``eos_token`` / overlong-buffer fields and fill ``extra_info``.

        Uses explicit raises instead of ``assert`` so validation still runs
        under ``python -O``; pydantic reports ValueError/TypeError from an
        after-validator as a ValidationError either way.
        """
        if isinstance(self.eos_token, str):
            if not self.eos_token.strip():
                raise ValueError("eos_token string must not be empty")
        elif isinstance(self.eos_token, list):
            if not self.eos_token:
                raise ValueError("eos_token list must not be empty")
            if not all(isinstance(e, str) and e.strip() != "" for e in self.eos_token):
                raise ValueError("All eos_token list elements must be non-empty strings")
        else:
            raise TypeError("eos_token must be a non-empty string or a non-empty list of strings")

        self.extra_info.update(  # type: ignore[attr-defined]
            {
                "eos_token": self.eos_token,
                "score": self.score,
                "format_score": self.format_score,
            }
        )
        if self.enable_overlong_buffer:
            # The overlong penalty needs every one of these to be configured.
            for required in ("max_response_len", "overlong_buffer_len", "overlong_penalty_factor", "tokenizer"):
                if getattr(self, required) is None:
                    raise ValueError(f"{required} is required when enable_overlong_buffer=True.")
            self.extra_info.update(  # type: ignore[attr-defined]
                {
                    "enable_overlong_buffer": self.enable_overlong_buffer,
                    "max_response_len": self.max_response_len,
                    "overlong_buffer_len": self.overlong_buffer_len,
                    "overlong_penalty_factor": self.overlong_penalty_factor,
                    "tokenizer": self.tokenizer,
                }
            )

        return self


# --- file: xtuner/v1/rl/judger/factory.py (new; continues past this chunk) ---
from ray.util.placement_group import PlacementGroup

from .composed import ComposedJudger, ComposedJudgerConfig, JudgerConfigLike, default_merge_fn
from .native import Judger, JudgerConfig, JudgerPool


#
# Use ``JudgerConfig`` when one sample only needs one concrete judger implementation:
# one reward handler, one judger_name, and one execution mode (local or Ray actors).
+# +# Use ``ComposedJudgerConfig`` when one sample may need to be routed to different child +# judgers by ``select_fn``, or when you want to run multiple child judgers and merge their +# outputs with ``merge_fn``. +# +def build_judger(config: JudgerConfigLike, pg: PlacementGroup | None = None, start_bundle_idx: int = 0) -> Judger: + if isinstance(config, ComposedJudgerConfig): + return _build_composite_judger(config, pg=pg, start_bundle_idx=start_bundle_idx) + return _build_replicated_judger(config, pg=pg, start_bundle_idx=start_bundle_idx) + + +def _build_replicated_judger(config: JudgerConfig, pg: PlacementGroup | None, start_bundle_idx: int) -> Judger: + if config.num_ray_actors == 0: + return config.build_local() + if config.num_ray_actors == 1: + return config._build_remote_judger(pg=pg, bundle_idx=start_bundle_idx) + return JudgerPool( + replicas=config._build_remote_judgers(pg=pg, start_bundle_idx=start_bundle_idx), + judger_name=config.judger_name, + ) + + +def _build_composite_judger( + config: ComposedJudgerConfig, + pg: PlacementGroup | None, + start_bundle_idx: int, +) -> Judger: + branches: dict[str, Judger] = {} + bundle_idx = start_bundle_idx + for key, branch_config in config.branches.items(): + branches[key] = build_judger(branch_config, pg=pg, start_bundle_idx=bundle_idx) + bundle_idx += branch_config.get_num_placement_group_bundles() + return ComposedJudger( + branches=branches, + select_fn=config.select_fn, + merge_fn=config.merge_fn or default_merge_fn, + default_key=config.default_key, + ) diff --git a/xtuner/v1/ray/judger/geo3k.py b/xtuner/v1/rl/judger/geo3k.py similarity index 90% rename from xtuner/v1/ray/judger/geo3k.py rename to xtuner/v1/rl/judger/geo3k.py index 71e5dd2592..3449c8852b 100644 --- a/xtuner/v1/ray/judger/geo3k.py +++ b/xtuner/v1/rl/judger/geo3k.py @@ -8,7 +8,7 @@ extract_boxed_content = None grade_answer = None -from .native import NativeJudgerConfig +from .native import JudgerConfig def format_reward(predict_str: str) -> 
float: @@ -35,9 +35,9 @@ def compute_reward(response, label, extra_info) -> dict: return {"score": score, "acc": acc} -class GEO3KJudgerConfig(NativeJudgerConfig): +class GEO3KJudgerConfig(JudgerConfig): """Configuration for the GEO3K judger.""" judger_name: str = "hiyouga/geometry3k" extra_info: dict = {"format_score": 0.1, "use_boxed": True} - reward_func: Callable = compute_reward + reward_handler: Callable | str = compute_reward diff --git a/xtuner/v1/ray/judger/gsm8k.py b/xtuner/v1/rl/judger/gsm8k.py similarity index 95% rename from xtuner/v1/ray/judger/gsm8k.py rename to xtuner/v1/rl/judger/gsm8k.py index 3a22d83783..c125c3f610 100644 --- a/xtuner/v1/ray/judger/gsm8k.py +++ b/xtuner/v1/rl/judger/gsm8k.py @@ -1,7 +1,7 @@ import re from typing import Callable -from .native import NativeJudgerConfig +from .native import JudgerConfig _SOLUTION_CLIP_CHARS = 300 @@ -77,9 +77,9 @@ def compute_reward(response, label, extra_info): return {"score": extra_info["format_score"]} -class GSM8KJudgerConfig(NativeJudgerConfig): +class GSM8KJudgerConfig(JudgerConfig): """Configuration for the GSM8K judger.""" judger_name: str = "openai/gsm8k" extra_info: dict = {"score": 1, "format_score": 0} - reward_func: Callable = compute_reward + reward_handler: Callable | str = compute_reward diff --git a/xtuner/v1/rl/judger/native.py b/xtuner/v1/rl/judger/native.py new file mode 100644 index 0000000000..d0b3699333 --- /dev/null +++ b/xtuner/v1/rl/judger/native.py @@ -0,0 +1,235 @@ +from __future__ import annotations + +import asyncio +import inspect +from abc import ABC, abstractmethod +from typing import Callable, TypeAlias, cast + +import httpx +from pydantic import BaseModel, ConfigDict, Field, model_validator +from ray.actor import ActorClass, ActorProxy +from ray.util.placement_group import PlacementGroup + +from xtuner.v1.data_proto.rl_data import RolloutState +from xtuner.v1.rl.utils import CPUActorLauncher +from xtuner.v1.utils.logger import get_logger +from 
xtuner.v1.utils.type_helper import ray_method + + +logger = get_logger() + + +class Judger(ABC): + @abstractmethod + async def judge(self, rollout_state: RolloutState) -> RolloutState: ... + + +class NativeJudger(Judger): + """Local judger implementation backed by a Python callable or HTTP + endpoint.""" + + def __init__( + self, + judger_name: str = "native_judger", + reward_handler: Callable | str | None = None, + extra_info: dict | None = None, + request_timeout: float = 30.0, + ): + self._judger_name = judger_name + self.extra_info = extra_info or {} + self.reward_handler = reward_handler + self.request_timeout = request_timeout + + @ray_method + async def judge(self, rollout_state: RolloutState) -> RolloutState: # type: ignore[override] + assert rollout_state.response is not None, ( + "RolloutState must have a response for judging. You should detokenize the response_ids in AgentLoop" + ) + assert rollout_state.reward_model is not None and "ground_truth" in rollout_state.reward_model, ( + "RolloutState must have reward_model with 'ground_truth' for judging. You should set reward_model in AgentLoop" + ) + + input_kwargs = { + "response": rollout_state.response, + "label": rollout_state.reward_model["ground_truth"], + "extra_info": {**self.extra_info}, + } + + judger_response = None + if isinstance(self.reward_handler, str): + async with httpx.AsyncClient(timeout=self.request_timeout) as client: + response = await client.post(self.reward_handler, json=input_kwargs) + response.raise_for_status() + judger_response = response.json() + elif callable(self.reward_handler): + if inspect.iscoroutinefunction(self.reward_handler): + judger_response = await self.reward_handler(**input_kwargs) + else: + judger_response = self.reward_handler(**input_kwargs) + + assert judger_response is not None, "Reward handler did not return a response." + assert isinstance(judger_response, dict), ( + f"Reward handler must return a dict, but got {type(judger_response)}." 
+ ) + rollout_state.reward = judger_response + return rollout_state + + def get_judger_name(self) -> str: + return self._judger_name + + +class RemoteJudger(Judger): + def __init__(self, actor: RayJudgerProxy, judger_name: str): + self.actor = actor + self._judger_name = judger_name + + @ray_method + async def judge(self, rollout_state: RolloutState) -> RolloutState: # type: ignore[override] + return await self.actor.judge.remote(rollout_state) + + def get_judger_name(self) -> str: + return self._judger_name + + +class JudgerPool(Judger): + """Round-robin dispatch across replicas of the same judger type.""" + + def __init__(self, replicas: list[Judger], judger_name: str): + if not replicas: + raise ValueError("JudgerPool requires at least one replica.") + self.replicas = replicas + self._judger_name = judger_name + self._rr_index = 0 + self._lock = asyncio.Lock() + self._worker_loads = dict.fromkeys(range(len(replicas)), 0) + + async def _pick_replica(self) -> tuple[int, Judger]: + async with self._lock: + replica_idx = self._rr_index % len(self.replicas) + self._rr_index = (self._rr_index + 1) % len(self.replicas) + self._worker_loads[replica_idx] += 1 + return replica_idx, self.replicas[replica_idx] + + async def _release_replica(self, replica_idx: int) -> None: + async with self._lock: + self._worker_loads[replica_idx] -= 1 + + @ray_method + async def judge(self, rollout_state: RolloutState) -> RolloutState: # type: ignore[override] + replica_idx, replica = await self._pick_replica() + try: + return await replica.judge(rollout_state) + finally: + await self._release_replica(replica_idx) + + def get_worker_status(self) -> dict[str, int]: + return {f"{self._judger_name}[{idx}]": load for idx, load in self._worker_loads.items()} + + def get_judger_name(self) -> str: + return self._judger_name + + +class JudgerConfig(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") + + judger_name: str + reward_handler: Callable | str | None = 
Field(default=None, exclude=True) + request_timeout: float = 30.0 + extra_info: dict = Field(default_factory=dict, exclude=True) + num_ray_actors: int = Field(default=0, ge=0, description="0 means local mode, >0 means remote Ray actors.") + num_cpus_per_actor: int = Field(default=1, gt=0, description="CPU cores per remote judger actor.") + cpu_memory_per_actor: int = Field( + default=1024**3, gt=0, description="CPU memory in bytes per remote judger actor." + ) + + @model_validator(mode="after") + def _validate_ray_actor_config(self) -> JudgerConfig: + if self.num_ray_actors == 0: + if self.num_cpus_per_actor != 1 or self.cpu_memory_per_actor != 1024**3: + logger.warning( + "num_cpus_per_actor and cpu_memory_per_actor are ignored when Judger runs in local mode." + ) + return self + + def get_num_placement_group_bundles(self) -> int: + return self.num_ray_actors + + def get_cpu_bundles(self) -> list[dict[str, float | int]]: + return [ + { + "CPU": self.num_cpus_per_actor, + "memory": self.cpu_memory_per_actor, + } + for _ in range(self.get_num_placement_group_bundles()) + ] + + def build_local(self) -> Judger: + return NativeJudger( + judger_name=self.judger_name, + reward_handler=self.reward_handler, + request_timeout=self.request_timeout, + extra_info=self.extra_info, + ) + + def _build_remote_actor(self, pg: PlacementGroup | None = None, bundle_idx: int = 0) -> RayJudgerProxy: + return CPUActorLauncher.build_actor( + JudgerActor, + self, + pg=pg, + bundle_idx=bundle_idx, + actor_num_cpus=self.num_cpus_per_actor, + actor_memory=self.cpu_memory_per_actor, + ) + + def _build_remote_actors( + self, + pg: PlacementGroup | None = None, + start_bundle_idx: int = 0, + num_ray_actors: int | None = None, + ) -> list[RayJudgerProxy]: + actor_count = self.num_ray_actors if num_ray_actors is None else num_ray_actors + return CPUActorLauncher.build_actors( + JudgerActor, + self, + pg=pg, + start_bundle_idx=start_bundle_idx, + num_workers=actor_count, + 
actor_num_cpus_per_worker=self.num_cpus_per_actor, + actor_memory_per_worker=self.cpu_memory_per_actor, + ) + + def _build_remote_judger(self, pg: PlacementGroup | None = None, bundle_idx: int = 0) -> Judger: + return RemoteJudger(self._build_remote_actor(pg=pg, bundle_idx=bundle_idx), judger_name=self.judger_name) + + def _build_remote_judgers( + self, + pg: PlacementGroup | None = None, + start_bundle_idx: int = 0, + num_ray_actors: int | None = None, + ) -> list[Judger]: + return [ + RemoteJudger(actor, judger_name=self.judger_name) + for actor in self._build_remote_actors( + pg=pg, + start_bundle_idx=start_bundle_idx, + num_ray_actors=num_ray_actors, + ) + ] + + def build(self, pg: PlacementGroup | None = None, start_bundle_idx: int = 0) -> Judger: + from .factory import build_judger + + return build_judger(self, pg=pg, start_bundle_idx=start_bundle_idx) + + +class JudgerActor: + def __init__(self, judger_config: JudgerConfig): + self.judger = judger_config.build_local() + + @ray_method + async def judge(self, rollout_state: RolloutState) -> RolloutState: + return await self.judger.judge(rollout_state) + + +RayJudger = cast(ActorClass[JudgerActor], CPUActorLauncher.to_actor_class(JudgerActor)) +RayJudgerProxy: TypeAlias = ActorProxy[JudgerActor] diff --git a/xtuner/v1/rl/loss/__init__.py b/xtuner/v1/rl/loss/__init__.py new file mode 100644 index 0000000000..550338e9e0 --- /dev/null +++ b/xtuner/v1/rl/loss/__init__.py @@ -0,0 +1,4 @@ +from .base_loss import BaseRLLossConfig, BaseRLLossContext, BaseRLLossKwargs, compute_kl_loss_weight +from .grpo_loss import GRPOLossConfig, GRPOLossContext, GRPOLossKwargs +from .loss_fn import check_config, get_policy_loss_fn, kl_penalty, pg_loss_fn, register_policy_loss, sft_loss_fn +from .oreal_loss import OrealLossConfig, OrealLossContext, OrealLossKwargs diff --git a/xtuner/v1/rl/base/loss.py b/xtuner/v1/rl/loss/base_loss.py similarity index 99% rename from xtuner/v1/rl/base/loss.py rename to xtuner/v1/rl/loss/base_loss.py 
index 8164538b9d..16849ce1d5 100644 --- a/xtuner/v1/rl/base/loss.py +++ b/xtuner/v1/rl/loss/base_loss.py @@ -7,10 +7,10 @@ from xtuner.v1.loss import BaseLossConfig, BaseLossKwargs from xtuner.v1.loss.base_loss_ctx import BaseLossContext from xtuner.v1.loss.utils import sp_gather, sp_split -from xtuner.v1.utils.device import get_device # from ..utils import sp_split -from .rollout_is import RolloutImportanceSampling +from xtuner.v1.rl.trainer.rollout_is import RolloutImportanceSampling +from xtuner.v1.utils.device import get_device DEVICE = get_device() diff --git a/xtuner/v1/rl/grpo/loss.py b/xtuner/v1/rl/loss/grpo_loss.py similarity index 98% rename from xtuner/v1/rl/grpo/loss.py rename to xtuner/v1/rl/loss/grpo_loss.py index ec34c0b156..e9aa607b55 100644 --- a/xtuner/v1/rl/grpo/loss.py +++ b/xtuner/v1/rl/loss/grpo_loss.py @@ -7,14 +7,14 @@ from xtuner.v1.utils import get_logger -from ..base import ( +from ..utils import gather_logprobs +from .base_loss import ( BaseRLLossConfig, BaseRLLossContext, BaseRLLossKwargs, compute_kl_loss_weight, ) -from ..loss_fn import get_policy_loss_fn, kl_penalty -from ..utils import gather_logprobs +from .loss_fn import get_policy_loss_fn, kl_penalty logger = get_logger() diff --git a/xtuner/v1/rl/loss_fn.py b/xtuner/v1/rl/loss/loss_fn.py similarity index 100% rename from xtuner/v1/rl/loss_fn.py rename to xtuner/v1/rl/loss/loss_fn.py diff --git a/xtuner/v1/rl/oreal/loss.py b/xtuner/v1/rl/loss/oreal_loss.py similarity index 98% rename from xtuner/v1/rl/oreal/loss.py rename to xtuner/v1/rl/loss/oreal_loss.py index 4d39ccf0cf..f115e7762b 100644 --- a/xtuner/v1/rl/oreal/loss.py +++ b/xtuner/v1/rl/loss/oreal_loss.py @@ -5,14 +5,14 @@ import torch.distributed as dist import torch.nn.functional as F -from ..base import ( +from ..utils import gather_logprobs +from .base_loss import ( BaseRLLossConfig, BaseRLLossContext, BaseRLLossKwargs, compute_kl_loss_weight, ) -from ..loss_fn import get_policy_loss_fn, kl_penalty, sft_loss_fn -from 
..utils import gather_logprobs +from .loss_fn import get_policy_loss_fn, kl_penalty, sft_loss_fn class OrealLossConfig(BaseRLLossConfig): diff --git a/xtuner/v1/rl/oreal/__init__.py b/xtuner/v1/rl/oreal/__init__.py deleted file mode 100644 index 1a15ec3944..0000000000 --- a/xtuner/v1/rl/oreal/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .loss import OrealLossConfig, OrealLossContext - - -__all__ = [ - "OrealLossConfig", - "OrealLossContext", -] diff --git a/xtuner/v1/rl/replay_buffer.py b/xtuner/v1/rl/replay_buffer.py new file mode 100644 index 0000000000..6dabce99dd --- /dev/null +++ b/xtuner/v1/rl/replay_buffer.py @@ -0,0 +1,424 @@ +import asyncio +from abc import ABC, abstractmethod +from dataclasses import dataclass, fields, replace +from itertools import count +from pathlib import Path +from typing import Any, List, TypeAlias, Union + +import pandas as pd +import torch +from pydantic import BaseModel, ConfigDict + +from xtuner.v1.data_proto.rl_data import RolloutState, Status, update_group_status +from xtuner.v1.rl.utils import ( + BetweenNode, + ConditionNode, + LogicNode, + LogicOperator, + Operators, + QueryNode, + ScalarNode, + SetNode, + parse_query, +) +from xtuner.v1.utils import get_logger + + +logger = get_logger(__name__) + + +@dataclass +class StorageItem: + # 存储类型 + item: List[RolloutState] + uid: int + timestamp_id: int + task_name: str + status: Status + staleness: int + + +QUERY_KEYS = [f.name for f in fields(StorageItem)] +QueryKey = Union[str, LogicOperator] # str 是 StorageItem 的字段名,LogicOperator 是 "$and", "$or" 等逻辑操作符 + +# 查询类型: +QueryDict: TypeAlias = dict[ + QueryKey, + Union[ + Any, # 直接匹配值,例如: {"task_name": "math"} + dict[Operators, Any], # 操作符匹配,例如: {"uid": {"$gt": 10}} + List["QueryDict"], # 逻辑组合,例如: {"$and": [{"a": 1}, {"b": 2}]} + ], +] +QueryType = Union[QueryDict, QueryNode] + + +class StorageBackend(ABC): + @abstractmethod + async def put(self, item: StorageItem) -> int: ... 
+ + @abstractmethod + async def get(self, query: QueryType) -> List[StorageItem]: ... + + @abstractmethod + async def count(self, query: QueryType) -> int: ... + + @abstractmethod + async def delete(self, uids: list[int]) -> None: ... + + @abstractmethod + def __len__(self) -> int: ... + + @abstractmethod + def state_dict(self) -> dict[str, Any]: ... + + @abstractmethod + def load_state_dict(self, state: dict[str, Any]) -> None: ... + + +class ReplayPolicy(ABC): + @abstractmethod + async def put(self, item: StorageItem, storage_backend: StorageBackend) -> None: ... + + @abstractmethod + async def get(self, count: int, query: QueryType, storage_backend: StorageBackend) -> list[list[RolloutState]]: ... + + async def count(self, query: QueryType, storage_backend: StorageBackend) -> int: + return await storage_backend.count(query) + + +class NaiveStorage(StorageBackend): + def __init__(self): + self._uid_gen = count(1) + self._timestamp_id_gen = count(1) + self._items: dict[int, StorageItem] = {} + + async def put(self, item: StorageItem) -> int: + uid = next(self._uid_gen) + stored = replace(item, uid=uid, timestamp_id=next(self._timestamp_id_gen)) + self._items[uid] = stored + return uid + + def _evaluate(self, item: StorageItem, query_node: QueryNode) -> bool: + """NaiveStorage 实现的原生 Python 对象过滤树遍历.""" + if isinstance(query_node, LogicNode): + if not query_node.conditions: + return query_node.relation == "$and" + + if query_node.relation == "$and": + return all(self._evaluate(item, child) for child in query_node.conditions) + else: + return any(self._evaluate(item, child) for child in query_node.conditions) + + elif isinstance(query_node, ConditionNode): + if query_node.field not in QUERY_KEYS: + raise ValueError(f"查询字段错误: 找不到属性 '{query_node.field}'。可用属性为: {QUERY_KEYS}") + val = getattr(item, query_node.field) + + if isinstance(query_node, ScalarNode): + if query_node.op == "$eq": + return val == query_node.value + if query_node.op == "$ne": + return val != 
query_node.value + if query_node.op == "$gt": + return val > query_node.value + if query_node.op == "$gte": + return val >= query_node.value + if query_node.op == "$lt": + return val < query_node.value + if query_node.op == "$lte": + return val <= query_node.value + + elif isinstance(query_node, SetNode): + if query_node.op == "$in": + return val in query_node.value + if query_node.op == "$not_in": + return val not in query_node.value + + elif isinstance(query_node, BetweenNode): + return query_node.lower <= val <= query_node.upper + + return False + + async def get(self, query: QueryType) -> list[StorageItem]: + ast = parse_query(query) + return [item for item in self._items.values() if self._evaluate(item, ast)] + + async def count(self, query: QueryType) -> int: + ast = parse_query(query) + return sum(1 for item in self._items.values() if self._evaluate(item, ast)) + + async def delete(self, uids: list[int]) -> None: + if not uids: + return + for uid in uids: + self._items.pop(uid, None) + + def __len__(self) -> int: + return len(self._items) + + def state_dict(self) -> dict[str, Any]: + max_uid = max(self._items, default=0) + max_timestamp_id = max((item.timestamp_id for item in self._items.values()), default=0) + return { + "items": list(self._items.values()), + "next_uid": max_uid + 1, + "next_timestamp_id": max_timestamp_id + 1, + } + + def load_state_dict(self, state: dict[str, Any]) -> None: + items: list[StorageItem] = state["items"] + self._items = {item.uid: item for item in items} + self._uid_gen = count(state["next_uid"]) + self._timestamp_id_gen = count(state["next_timestamp_id"]) + + +class PandasStorage(StorageBackend): + def __init__(self): + self._uid_gen = count(1) + self._timestamp_id_gen = count(1) + self._df = pd.DataFrame(columns=["uid", "timestamp_id", "task_name", "status", "staleness", "item"]) + self._buffer: list[dict] = [] + + def _flush_buffer(self): + if self._buffer: + new_df = pd.DataFrame(self._buffer) + self._df = new_df if 
self._df.empty else pd.concat([self._df, new_df], ignore_index=True) + self._buffer.clear() + + async def put(self, item: StorageItem) -> int: + uid = next(self._uid_gen) + row = { + "uid": uid, + "timestamp_id": next(self._timestamp_id_gen), + "task_name": item.task_name, + "status": item.status, + "staleness": item.staleness, + "item": item.item, + } + self._buffer.append(row) + return uid + + def _evaluate(self, query_node: QueryNode, df: pd.DataFrame) -> pd.Series: + """PandasStorage 实现的向量化 DataFrame 过滤树遍历.""" + if isinstance(query_node, LogicNode): + if not query_node.conditions: + return ( + pd.Series(True, index=df.index) + if query_node.relation == "$and" + else pd.Series(False, index=df.index) + ) + + mask = self._evaluate(query_node.conditions[0], df) + for child in query_node.conditions[1:]: + child_mask = self._evaluate(child, df) + if query_node.relation == "$and": + mask = mask & child_mask + else: + mask = mask | child_mask + return mask + + elif isinstance(query_node, ConditionNode): + field = query_node.field + if field not in QUERY_KEYS: + raise ValueError(f"查询字段错误: 找不到属性 '{query_node.field}'。可用属性为: {QUERY_KEYS}") + series = df[query_node.field] + + if isinstance(query_node, ScalarNode): + if query_node.op == "$eq": + return series == query_node.value + if query_node.op == "$ne": + return series != query_node.value + if query_node.op == "$gt": + return series > query_node.value + if query_node.op == "$gte": + return series >= query_node.value + if query_node.op == "$lt": + return series < query_node.value + if query_node.op == "$lte": + return series <= query_node.value + + elif isinstance(query_node, SetNode): + if query_node.op == "$in": + return series.isin(query_node.value) + if query_node.op == "$not_in": + return ~series.isin(query_node.value) + + elif isinstance(query_node, BetweenNode): + return series.between(query_node.lower, query_node.upper) + else: + raise ValueError(f"不支持的查询节点类型: {type(query_node)}") + + async def get(self, query: 
QueryType) -> list[StorageItem]: + self._flush_buffer() + if self._df.empty: + return [] + + ast = parse_query(query) + filtered_df = self._df[self._evaluate(ast, self._df)] + return [ + StorageItem( + item=row["item"], + uid=row["uid"], + timestamp_id=row["timestamp_id"], + task_name=row["task_name"], + status=row["status"], + staleness=row["staleness"], + ) + for _, row in filtered_df.iterrows() + ] + + async def count(self, query: QueryType) -> int: + self._flush_buffer() + if self._df.empty: + return 0 + ast = parse_query(query) + return int(self._evaluate(ast, self._df).sum()) + + async def delete(self, uids: list[int]) -> None: + self._flush_buffer() + if not uids or self._df.empty: + return + self._df = self._df[~self._df["uid"].isin(uids)] + + def __len__(self) -> int: + return len(self._df) + len(self._buffer) + + def state_dict(self) -> dict[str, Any]: + self._flush_buffer() + max_uid = int(self._df["uid"].max()) if not self._df.empty else 0 + max_timestamp_id = int(self._df["timestamp_id"].max()) if not self._df.empty else 0 + return { + "df": self._df.copy(deep=True), + "next_uid": max_uid + 1, + "next_timestamp_id": max_timestamp_id + 1, + } + + def load_state_dict(self, state: dict[str, Any]) -> None: + self._df = state["df"].copy(deep=True) + self._buffer = [] + self._uid_gen = count(state["next_uid"]) + self._timestamp_id_gen = count(state["next_timestamp_id"]) + + +class FIFOReplayPolicy(ReplayPolicy): + async def put(self, item: StorageItem, storage_backend: StorageBackend) -> None: + if not item.item: + return + await storage_backend.put(item) + + async def get(self, count: int, query: QueryType, storage_backend: StorageBackend) -> list[list[RolloutState]]: + if count <= 0: + return [] + records = await storage_backend.get(query) + records.sort(key=lambda r: r.timestamp_id) + selected = records[:count] + if selected: + await storage_backend.delete([record.uid for record in selected]) + return [record.item for record in selected] + + +class 
StalenessReplayPolicy(ReplayPolicy): + async def put(self, item: StorageItem, storage_backend: StorageBackend) -> None: + if not item.item: + return + await storage_backend.put(item) + + async def get(self, count: int, query: QueryType, storage_backend: StorageBackend) -> list[list[RolloutState]]: + if count <= 0: + return [] + + records = await storage_backend.get(query) + records.sort(key=lambda r: (-r.staleness, r.timestamp_id)) + selected = records[:count] + if selected: + await storage_backend.delete([record.uid for record in selected]) + return [record.item for record in selected] + + async def count(self, query: QueryType, storage_backend: StorageBackend) -> int: + return await storage_backend.count(query) + + +class ReplayBuffer: + _SAVE_PATH = "replay_buffer.pth" + + def __init__(self, policy: ReplayPolicy, storage_backend: StorageBackend): + self._policy = policy + self._storage = storage_backend + self._lock = asyncio.Lock() + + async def put(self, items: list[RolloutState], task_name: str) -> None: + if not items: + return + storage_item = StorageItem( + item=items, + uid=0, + timestamp_id=0, + task_name=task_name, + status=update_group_status(items), + staleness=max(item.seq_staleness for item in items), + ) + async with self._lock: + await self._policy.put(storage_item, self._storage) + + async def get(self, batch_size: int, task_name: str, group_status: Status) -> list[list[RolloutState]]: + # 使用 DSL 字典进行查询 + query_dsl: QueryDict = {"$and": [{"task_name": task_name}, {"status": group_status}]} + async with self._lock: + return await self._policy.get(batch_size, query_dsl, self._storage) + + async def count(self, task_name: str, group_status: Status) -> int: + # 使用 DSL 字典进行查询 + query_dsl: QueryDict = {"$and": [{"task_name": task_name}, {"status": group_status}]} + async with self._lock: + return await self._policy.count(query_dsl, self._storage) + + def __len__(self) -> int: + return len(self._storage) + + async def save(self, path: str | Path) -> 
None: + file_path = Path(path) + file_path.parent.mkdir(parents=True, exist_ok=True) + replay_buffer_path = file_path / self._SAVE_PATH + async with self._lock: + state = { + "policy": type(self._policy).__name__, + "storage": type(self._storage).__name__, + "storage_state": self._storage.state_dict(), + } + await asyncio.to_thread(torch.save, state, replay_buffer_path) + logger.info(f"Replay buffer saved to {replay_buffer_path}") + + async def resume(self, path: str | Path) -> None: + if len(self._storage) > 0: + raise RuntimeError("Cannot resume into a non-empty buffer") + + file_path = Path(path) + replay_buffer_path = file_path / self._SAVE_PATH + state = await asyncio.to_thread(torch.load, replay_buffer_path, map_location="cpu", weights_only=False) + if state["policy"] != type(self._policy).__name__: + raise ValueError(f"Replay policy mismatch: expected {type(self._policy).__name__}, got {state['policy']}") + + if state["storage"] != type(self._storage).__name__: + raise ValueError( + f"Storage backend mismatch: expected {type(self._storage).__name__}, got {state['storage']}" + ) + + async with self._lock: + self._storage.load_state_dict(state["storage_state"]) + logger.info(f"Replay buffer resumed from {replay_buffer_path}") + + +class SyncReplayBufferConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + + def build(self): + return ReplayBuffer(policy=FIFOReplayPolicy(), storage_backend=NaiveStorage()) + + +class AsyncReplayBufferConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + + def build(self): + policy = StalenessReplayPolicy() + return ReplayBuffer(policy=policy, storage_backend=NaiveStorage()) diff --git a/xtuner/v1/ray/rollout/__init__.py b/xtuner/v1/rl/rollout/__init__.py similarity index 75% rename from xtuner/v1/ray/rollout/__init__.py rename to xtuner/v1/rl/rollout/__init__.py index f09429134a..349cd2fad7 100644 --- a/xtuner/v1/ray/rollout/__init__.py +++ b/xtuner/v1/rl/rollout/__init__.py @@ -1,6 +1,6 @@ import os -from 
.controller import RolloutController, SampleParams +from .controller import RolloutController from .worker import RolloutWorker @@ -10,3 +10,5 @@ from .vllm import vLLMWorker if os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "1": from .lmdeploy import LMDeployWorker + +from .utils import continue_generation, pause_generation diff --git a/xtuner/v1/rl/rollout/controller.py b/xtuner/v1/rl/rollout/controller.py new file mode 100644 index 0000000000..df00323449 --- /dev/null +++ b/xtuner/v1/rl/rollout/controller.py @@ -0,0 +1,468 @@ +import asyncio +import os +import threading +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeAlias, TypedDict +from uuid import uuid4 + +import ray +from ray.actor import ActorProxy +from ray.util.placement_group import PlacementGroup + +from transformers import AutoTokenizer +from xtuner.v1.data_proto.rl_data import RolloutState, Status +from xtuner.v1.rl.utils import AutoAcceleratorWorkers +from xtuner.v1.utils import get_logger + +from .parser.factory import build_reasoning_parser, build_tool_call_parser +from .parser.reasoning_parser import ReasoningParser +from .parser.tool_parser import ToolCallParser +from .utils import ROLLOUT_RAY_GET_TIMEOUT, RolloutHealthChecker, SessionRouter +from .worker import RolloutConfig, RolloutWorker + + +if TYPE_CHECKING: + from xtuner.v1.rl.gateway.config import GatewayConfig + + +@dataclass +class WorkerInfo: + """A data class to hold all state information for a single worker.""" + + actor: RolloutWorker + url: str + is_active: bool = True + + +class RolloutWorkerMetadata(TypedDict): + """Metadata for rollout workers and their configuration. + + This data structure encapsulates all necessary information about the rollout worker infrastructure, including + engine topology, server addresses, and worker status. Used for communication between training processes and rollout + workers. 
+ """ + + # 推理引擎的拓扑结构,每个子列表代表一个推理引擎包含的所有 worker ranks + # 例如:[[0, 1, 2, 3], [4, 5, 6, 7]] 表示有 2 个推理引擎,每个引擎包含 4 个 workers + # 用于确定分布式推理的并行组划分 + engine_rank_mesh_array: List[List[int]] + + # worker rank 到服务器 URL 的映射字典,用于训练进程与 rollout workers 通信 + # 键:worker 的 rank ID(字符串形式的整数) + # 值:对应的服务器地址列表(通常每个 rank 对应一个 URL) + server_url_dict: Dict[str, List[str]] + + # Rollout 配置对象,包含推理引擎的所有配置参数 + # 包括:并行策略(TP/EP)、超时设置、后端类型(LMDeploy/vLLM/SGLang)等 + rollout_config: RolloutConfig + + # 每个 worker 服务器 URL 的当前活跃状态 + # 键:服务器 URL 字符串 + # 值:布尔值,True 表示该 worker 处于活跃状态,False 表示已失效或停用 + worker_server_urls_status: Dict[str, bool] + + # Gateway HTTP server URL (e.g. "http://1.2.3.4:8080"). + # Set after start_gateway() is called; None if the gateway has not been started. + api_server_url: Optional[str] + + +class RolloutController: + """Controller for managing and coordinating multiple RolloutWorker + actors.""" + + def __init__( + self, + infer_config: RolloutConfig, + placement_group: PlacementGroup, + ): + """Initialize the RolloutController. + + Args: + infer_config (RolloutConfig): The configuration for the rollout. + placement_group (PlacementGroup): The placement group for the + RolloutWorker actors. + """ + self.config = infer_config + self.num_gpus_per_engine = ( + self.config.expert_parallel_size + if self.config.expert_parallel_size > 1 + else self.config.tensor_parallel_size + ) + self.logger = get_logger(log_dir=infer_config.worker_log_dir, tag="RolloutController") + self.engine_rank_mesh_array: List[List[int]] = [] + self.worker_server_urls_map: dict[str, List[str]] = {} + self.rank2info: dict[int, WorkerInfo] = {} + self.engine_rank_mesh_array, self.worker_server_urls_map, self.rank2info = self._init_workers(placement_group) + self.num_active_workers = len(self.rank2info) + self.worker_info_lock = threading.RLock() + # The timeout for the environment to wait for the rollout controller's response. 
+ # This should be longer than the controller's internal timeout (`rollout_timeout`) + # to account for potential queuing delays and other overheads. + self.timeout_multiplier = 2.0 + self.router = SessionRouter(self.rank2info, worker_infos_lock=self.worker_info_lock) + self.health_checker = RolloutHealthChecker( + config=self.config, + workers_info=self.rank2info, + worker_infos_lock=self.worker_info_lock, + ) + self.health_checker.start() + self._tool_call_parser, self._reasoning_parser = self._build_output_parsers() + self._gateway_url: str | None = None + + def start_gateway(self, config: "GatewayConfig") -> str: + """Start the gateway HTTP server in a daemon thread and return its URL. + + The gateway exposes OpenAI-compatible endpoints that forward requests to + this controller via :class:`~xtuner.v1.rl.gateway.backend.local_backend.LocalRolloutBackend`. + Agent loops (e.g. CamelAgentLoop) discover the URL via :meth:`get_rollout_metadata`. + + Args: + config: Gateway configuration. ``port`` and ``host`` control where + the server binds; ``capture_folder`` enables per-request trace files. + + Returns: + The base URL of the gateway, e.g. ``"http://1.2.3.4:8080"``. + """ + from xtuner.v1.rl.gateway import build_local_gateway_app, serve_gateway_in_thread + + config.capture_folder = str(Path(self.config.worker_log_dir) / config._CAPTURE_PATH_FOLDER) + app = build_local_gateway_app(ray.get_runtime_context().current_actor, config=config) + serve_gateway_in_thread(app, config) + node_ip = ray.util.get_node_ip_address() + url = f"http://{node_ip}:{config.port}" + self._gateway_url = url + self.logger.info(f"Gateway server started at {url}, capture_folder: {config.capture_folder}") + return url + + def get_rollout_metadata(self) -> RolloutWorkerMetadata: + """Get information about the current rollout setup. + + Returns: + dict: A dictionary containing the engine mesh list, server URL + dictionary, and the rollout configuration. 
+ """ + with self.worker_info_lock: + worker_server_urls_status = {info.url: info.is_active for info in self.rank2info.values()} + rollout_metadata: RolloutWorkerMetadata = { + "engine_rank_mesh_array": self.engine_rank_mesh_array, + "server_url_dict": self.worker_server_urls_map, + "rollout_config": self.config, + "worker_server_urls_status": worker_server_urls_status, + "api_server_url": self._gateway_url, + } + return rollout_metadata + + def _build_output_parsers(self) -> tuple[ToolCallParser | None, ReasoningParser | None]: + tool_call_parser = None + reasoning_parser = None + + if self.config.tool_call_parser != "none": + tool_call_parser = build_tool_call_parser(self.config.tool_call_parser) + + if self.config.reasoning_parser != "none": + tokenizer_path = self.config.tokenizer_path or self.config.model_path + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True) + reasoning_parser = build_reasoning_parser(self.config.reasoning_parser, tokenizer) + + return tool_call_parser, reasoning_parser + + def get_ready_status(self) -> tuple[bool, dict[str, Any]]: + with self.worker_info_lock: + active_workers = sum(1 for info in self.rank2info.values() if info.is_active) + total_workers = len(self.rank2info) + return active_workers > 0, { + "active_workers": active_workers, + "total_workers": total_workers, + } + + async def generate(self, rollout_state: RolloutState) -> RolloutState: + session_id = rollout_state.session_uid if rollout_state.session_uid else uuid4().int + worker = await self.router.get_worker(session_id) + if worker is None: + rollout_state.status = Status.FAILED + rollout_state.error_msg = "No active rollout worker available." 
+ return rollout_state + + response_ref = worker.generate.remote(rollout_state=rollout_state) # type: ignore[attr-defined] + try: + response_rollout_state = await asyncio.wait_for( + response_ref, timeout=self.config.rollout_timeout * self.timeout_multiplier + ) + self._apply_output_parsers(response_rollout_state) + return response_rollout_state + except asyncio.TimeoutError: + self.logger.error(f"Rollout timeout for worker {worker}. Skipping sample.") + rollout_state.status = Status.FAILED + rollout_state.error_msg = ( + f"Rollout request timed out after {self.config.rollout_timeout * self.timeout_multiplier} seconds." + ) + return rollout_state + + def _apply_output_parsers(self, rollout_state: RolloutState) -> None: + """Apply tool-call and reasoning parsers to the rollout state in- + place.""" + if self._tool_call_parser is not None: + parsed = self._tool_call_parser.parse(rollout_state) + rollout_state.tool_calls = parsed.tool_calls + rollout_state.response = parsed.remaining_text or None + if self._reasoning_parser is not None: + parsed_reasoning = self._reasoning_parser.parse(rollout_state) + rollout_state.response = parsed_reasoning.remaining_text + if parsed_reasoning.reasoning_text: + rollout_state.extra_fields["reasoning_text"] = parsed_reasoning.reasoning_text + else: + rollout_state.extra_fields.pop("reasoning_text", None) + + def pause_generation(self): + self.health_checker.pause() + + def continue_generation(self): + self.health_checker.resume() + self._broadcast_to_active_workers("continue_generation") + + def offload(self): + self._broadcast_to_active_workers("offload") + + def onload(self): + self._broadcast_to_active_workers("onload_weights") + self._broadcast_to_active_workers("onload_kvcache") + + def onload_weights(self): + self._broadcast_to_active_workers("onload_weights") + + def onload_kvcache(self): + self._broadcast_to_active_workers("onload_kvcache") + + def shutdown(self): + """Shuts down all active rollout workers. 
+ + Args: + block (bool): Whether to block until the operation completes. + """ + self.health_checker.stop() + self._broadcast_to_active_workers("shutdown") + + def recover_failed_workers(self): + """Recovers from worker failures by restarting failed workers and + reinitializing the rollout setup.""" + self.health_checker.pause() + with self.worker_info_lock: + failed_workers = [info for info in self.rank2info.values() if not info.is_active] + if not failed_workers: + self.logger.info("No failed workers detected during recovery.") + return + + self.logger.warning(f"Detected {len(failed_workers)} failed workers. Initiating recovery process.") + for worker in failed_workers: + if self._restart_failed_workers(worker.actor): + with self.worker_info_lock: + rank = self._get_rank_by_actor(worker.actor) + if rank is not None: + self.rank2info[rank].is_active = True + self.health_checker.resume() + + def _restart_failed_workers(self, worker: RolloutWorker) -> bool: + try: + dist_init_addr = ray.get(worker.init_dist_port.remote(), timeout=ROLLOUT_RAY_GET_TIMEOUT) # type: ignore[attr-defined] + _, url = ray.get(worker.init.remote(dist_init_addr), timeout=ROLLOUT_RAY_GET_TIMEOUT) # type: ignore[attr-defined] + is_healthy = ray.get(worker.check_health.remote(), timeout=ROLLOUT_RAY_GET_TIMEOUT) # type: ignore[attr-defined] + if is_healthy: + self.logger.info(f"Successfully restarted worker {worker} with URL {url}.") + return True + else: + self.logger.error(f"Worker {worker} is still unhealthy after restart.") + return False + except Exception as e: + self.logger.error(f"Failed to restart worker: {e}") + return False + + def _update_dist_init_addr(self, nodes_per_engine, server_urls_per_engine, dist_init_addrs, tp_size): + """Update the distributed initialization addresses for workers. + + This is used to group workers that belong to the same inference engine. + + Args: + nodes_per_engine (int): The number of nodes per inference engine. 
+ server_urls_per_engine (int): The number of server urls per inference engine. + dist_init_addrs (list): The list of initial addresses. + tp_size (int): The tensor parallel size. + + Returns: + list: The updated list of distributed initialization addresses. + """ + # lmdeploy pytorch ep: server_urls_per_engine > 1 + # sglang cross node engine: nodes_per_engine > 1 + assert server_urls_per_engine == 1 or nodes_per_engine == 1 + if nodes_per_engine > 1: + index = list(range(0, self.num_active_workers + 1, tp_size)) + [self.num_active_workers] + for i in range(1, len(index)): + dist_init_addrs[index[i - 1] : index[i]] = [dist_init_addrs[index[i - 1]]] * (index[i] - index[i - 1]) + if server_urls_per_engine > 1: + activate_servers = len(dist_init_addrs) + for i in range(0, activate_servers, server_urls_per_engine): + dist_init_addrs[i : i + server_urls_per_engine] = [dist_init_addrs[i]] * server_urls_per_engine + return dist_init_addrs + + def _get_active_servers_count(self, infer_config: RolloutConfig, gpu_nums: int): + """Calculate the number of active servers and nodes per engine. + + This calculation depends on the inference backend and parallelism settings. + + Args: + infer_config (RolloutConfig): The rollout configuration. + gpu_nums (int): The total number of GPUs available. + + Returns: + Tuple[int, int]: A tuple containing the number of active servers + and the number of nodes per engine. + """ + # NOTE:Since different inference engines have different launch methods, + # the number of nodes contained in each engine is not consistent. + # For example: sglang requires starting an inference engine for each node, + # while lmdeploy and vllm does not. Therefore, we calculate the number + # of active servers based on the configuration. 
+ support_cross_node_comm = infer_config.rollout_cross_node_comm + gpus_per_node = infer_config.gpus_per_node + nodes_per_engine = ( + 1 + if support_cross_node_comm or self.num_gpus_per_engine < gpus_per_node + else self.num_gpus_per_engine // gpus_per_node + ) + + active_servers_count = int( + (gpu_nums // self.num_gpus_per_engine) * nodes_per_engine * infer_config.server_urls_per_engine + ) + return active_servers_count, nodes_per_engine + + def _broadcast_to_active_workers(self, method_name: str): + """Helper function to call a method on all active workers. + + Args: + method_name (str): The name of the method to call. + block (bool): Whether to block until the call completes. + + Returns: + A list of futures if `block` is False, otherwise a list of results. + """ + futures = [] + with self.worker_info_lock: + active_actors = [info.actor for info in self.rank2info.values() if info.is_active] + futures = [getattr(actor, method_name).remote() for actor in active_actors] + results = ray.get(futures, timeout=ROLLOUT_RAY_GET_TIMEOUT) + return results + + def _get_worker_cls(self): + if os.environ.get("XTUNER_USE_LMDEPLOY") == "1": + from .lmdeploy import LMDeployWorker + + return ray.remote(LMDeployWorker) + elif os.environ.get("XTUNER_USE_VLLM") == "1": + from .vllm import vLLMWorker + + return ray.remote(vLLMWorker) + elif os.environ.get("XTUNER_USE_SGLANG") == "1": + from .sglang import SGLangWorker + + return ray.remote(SGLangWorker) + else: + raise NotImplementedError( + "Rollout backend is not supported." + "Please set XTUNER_USE_LMDEPLOY or XTUNER_USE_VLLM" + " or XTUNER_USE_SGLANG environment variable." + ) + + def _get_rank_by_actor(self, actor: RolloutWorker) -> Optional[int]: + """Get rank by actor object. + + Args: + actor: The RolloutWorker actor object. + + Returns: + The rank of the worker, or None if not found. 
+ """ + for rank, info in self.rank2info.items(): + if info.actor == actor: + return rank + return None + + def _update_active_workers_and_urls_map(self, active_rollout_workers, worker_server_urls_map): + """Update the list of active rollout workers and their server URLs. + + When the inference engine is launched across nodes (rollout_cross_node_comm=True), only the worker with + tp_rank=0 in each engine is responsible for receiving input data. Other tp_ranks do not accept input. + Therefore, this function updates active_rollout_workers and worker_server_urls_map to keep only the tp_rank=0 + workers and their corresponding URLs. + """ + if self.config.rollout_cross_node_comm or self.num_gpus_per_engine < self.config.gpus_per_node: + return active_rollout_workers, worker_server_urls_map + else: + active_worker_interval = self.num_gpus_per_engine // self.config.gpus_per_node + active_rank = list(worker_server_urls_map.keys())[::active_worker_interval] + active_worker_server_urls = list(worker_server_urls_map.values())[::active_worker_interval] + return active_rollout_workers[::active_worker_interval], dict(zip(active_rank, active_worker_server_urls)) + + def _init_workers(self, placement_group: PlacementGroup): + """Initializes and configures the pool of RolloutWorker actors. + + This method creates workers from the placement group, configures distributed + inference engines by grouping workers, where each group forms a tensor-parallel + inference engine. It determines the `active_workers` to act as the head of each + engine, constructs the `engine_rank_mesh_array` to define engine topology, + acquires necessary distributed communication ports, and finally launches servers + on the `active_workers` to get their addresses. 
+ + Returns: + Tuple[List, Dict]: A tuple where the first element is + `engine_rank_mesh_array`, a list of lists containing the ranks of workers + in each engine, and the second element is `worker_server_urls_map`, + a dictionary mapping the rank of each active worker to its + corresponding server URL. + """ + # Create workers from placement group + workers, rank_bundle_idx_list = AutoAcceleratorWorkers.from_placement_group( + self._get_worker_cls(), self.config, placement_group + ) + active_servers_count, nodes_per_engine = self._get_active_servers_count(self.config, len(workers)) + interval = len(workers) // active_servers_count + active_rollout_workers = workers[::interval] + server_urls_per_engine = self.config.server_urls_per_engine + + set_bundle_idxs_objectref = [] + engine_rank_mesh_array = [] + activate_worker_idx = 0 + for active_worker in active_rollout_workers: + head_rank, _ = rank_bundle_idx_list[activate_worker_idx] + engine_workers_meta = rank_bundle_idx_list[head_rank : head_rank + interval] + engine_bundle_idxs = [meta[1] for meta in engine_workers_meta] # meta: (rank, bundle_idx) + set_bundle_idxs_objectref.append(active_worker._set_engine_bundle_idxs.remote(engine_bundle_idxs)) # type: ignore[attr-defined] + engine_rank_mesh_array.append([meta[0] for meta in engine_workers_meta]) + activate_worker_idx += interval + ray.get(set_bundle_idxs_objectref) + # set engine mesh list for each worker + ray.get( + [worker._set_engine_rank_mesh_array.remote(engine_rank_mesh_array) for worker in active_rollout_workers] + ) # type: ignore[attr-defined] + # init dist_init_addr for each worker according to parallel settings + init_dist_init_addrs = ray.get([worker.init_dist_port.remote() for worker in active_rollout_workers]) # type: ignore[attr-defined] + dist_init_addrs = self._update_dist_init_addr( + nodes_per_engine, server_urls_per_engine, init_dist_init_addrs, self.num_gpus_per_engine + ) + # launch rollout servers + worker_server_urls_map = dict( # rank 
-> url + ray.get([worker.init.remote(dist_init_addrs[i]) for i, worker in enumerate(active_rollout_workers)]) + ) + active_rollout_workers, worker_server_urls_map = self._update_active_workers_and_urls_map( + active_rollout_workers, worker_server_urls_map + ) + workers_info = {} + for i in range(len(active_rollout_workers)): + rank = list(worker_server_urls_map.keys())[i] + url = worker_server_urls_map[rank] + workers_info[rank] = WorkerInfo(actor=active_rollout_workers[i], url=url) + self.logger.info(f"Rollout worker server URLs: {[info.url for info in workers_info.values()]}") + return engine_rank_mesh_array, worker_server_urls_map, workers_info + + +RayRolloutController = ray.remote(RolloutController) +RolloutControllerProxy: TypeAlias = ActorProxy[RayRolloutController] diff --git a/xtuner/v1/ray/rollout/lmdeploy.py b/xtuner/v1/rl/rollout/lmdeploy.py similarity index 75% rename from xtuner/v1/ray/rollout/lmdeploy.py rename to xtuner/v1/rl/rollout/lmdeploy.py index 27bfb8282f..3d8acd3115 100644 --- a/xtuner/v1/ray/rollout/lmdeploy.py +++ b/xtuner/v1/rl/rollout/lmdeploy.py @@ -1,17 +1,16 @@ -import copy import os from argparse import Namespace from itertools import chain -from typing import Any, Dict, List, Union +from typing import List import ray import requests from ray.util.placement_group import placement_group_table from transformers import AutoTokenizer -from xtuner.v1.ray.config import RolloutConfig +from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams -from .worker import RolloutWorker +from .worker import RolloutConfig, RolloutWorker def run_lmdeploy_server_wrapper(lmdeploy_config_namespace: Namespace): @@ -75,81 +74,69 @@ def __init__( self.model_name = self.config.model_name self.enable_return_routed_experts = self.config.enable_return_routed_experts - async def _create_request( - self, - url: str, - prompt: Union[str, List[Dict[str, Any]]] | None, - input_ids: List[int] | None, - tools: List, # reserved for agent tool use - tool_choice: 
str, # reserved for agent tool use - sample_params: dict, - extra_params: dict, - extra_info: dict, - ): - """Create and send a streaming generation request to the server. + def offload(self): + """Offloads the model weights and KV cache.""" + return self._sleep(level=2) - Args: - url (str): The URL of the generation endpoint. - prompt (List[Dict[str, str]]): The input prompt for generation, - formatted as a list of messages. - tools (List): A list of tools the model can call. - tool_choice (str): The tool choice strategy. - sample_params (dict): Parameters for sampling. Defaults to {}. - extra_params (dict): Extra parameters for the request. - Defaults to {}. + def onload_weights(self): + """Onloads the model weights by waking up the model.""" + return self._wake_up(tags=["weights"]) - Returns: - An httpx.Response object for streaming the response. - """ - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {self.api_keys}", # 如果需要鉴权 - } - payload = { - "model": self.model_name, - "tools": tools if len(tools) > 0 else None, - "tool_choice": tool_choice if tool_choice else None, - } - if "return_token_ids" in extra_params and extra_params["return_token_ids"]: - if "image_data" in extra_info: - assert input_ids is not None, "input_ids is required when image_data is provided." 
- - if input_ids is not None: - payload["input_ids"] = input_ids - if "image_data" in extra_info: - payload["image_data"] = extra_info["image_data"] + def onload_kvcache(self): + """Onloads the KV cache by waking up the model.""" + return self._wake_up(tags=["kv_cache"]) + + def _get_request_payload(self, rollout_state: RolloutState) -> dict: + tools = rollout_state.tools + tool_choice = rollout_state.tool_choice + sample_params = rollout_state.sample_params + message = rollout_state.message + input_tokens = rollout_state.tokens + + optional_fields: dict[str, object] = {} + if tools is not None: + optional_fields["tools"] = tools + if tool_choice is not None: + optional_fields["tool_choice"] = tool_choice + + if sample_params.return_token_ids: + payload = {"model": self.model_name, **optional_fields} + + if "image_data" in rollout_state.extra_fields: + assert input_tokens is not None, "input_tokens is required when image_data is provided." + payload["image_data"] = rollout_state.extra_fields["image_data"] + + if input_tokens is not None: + payload["input_ids"] = input_tokens else: - text_prompt = self.tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True) + text_prompt = self.tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) prompt_token_ids = self.tokenizer(text_prompt, add_special_tokens=False)["input_ids"] payload["input_ids"] = prompt_token_ids + sample_params.return_routed_experts = True if self.enable_return_routed_experts else False + lmdeploy_sample_params = self._transform_sample_params(sample_params) + payload.update(lmdeploy_sample_params) else: - payload["messages"] = prompt - - if "partial_rollout_input_ids" in extra_info: - assert "return_token_ids" in extra_params and extra_params["return_token_ids"], ( - "concat response_ids and input_ids is only compatible with return_token_ids=True." 
- ) - payload["input_ids"] = extra_info["partial_rollout_input_ids"] - assert len(payload["input_ids"]) <= self.config.context_length, ( - f"Total input length {len(payload['input_ids'])} exceeds context length {self.config.context_length}." - ) - - if self.enable_return_routed_experts: - extra_params["return_routed_experts"] = True - - lmdeploy_sample_params = self._transform_sample_params(sample_params, extra_params) - payload.update(lmdeploy_sample_params) - return await self._safe_post_request(url, headers, payload) - - def get_logprobs(self, input_ids, sampling_params): - """This method will be implemented for the LMDeploy worker in the - future.""" - pass - - def generate(self, input_ids, sampling_params): - """This method will be implemented for the LMDeploy worker in the - future.""" - pass + payload = { + "model": self.model_name, + "messages": rollout_state.message, + **optional_fields, + } + lmdeploy_sample_params = { + "temperature": sample_params.temperature, + "top_p": sample_params.top_p, + "n": sample_params.n, + "stream": sample_params.stream, + "max_tokens": sample_params.max_tokens, + "repetition_penalty": sample_params.repetition_penalty, + "top_k": sample_params.top_k, + "skip_special_tokens": sample_params.skip_special_tokens, + } + if sample_params.stops: + lmdeploy_sample_params["stop"] = sample_params.stops + if sample_params.min_tokens > 0: + lmdeploy_sample_params["min_new_tokens"] = sample_params.min_tokens + payload.update(lmdeploy_sample_params) + return payload def _sleep(self, level: int = 1): """Put the model into a sleep state to save resources. @@ -167,11 +154,7 @@ def _sleep(self, level: int = 1): assert response.status_code == 200, response.status_code return response.text - def offload(self): - """Offloads the model weights and KV cache.""" - return self._sleep(level=2) - - def wake_up(self, tags: List[str] | None = None): + def _wake_up(self, tags: List[str] | None = None): """Wakes up the model from a sleep state. 
Args: @@ -188,26 +171,6 @@ def wake_up(self, tags: List[str] | None = None): assert response.status_code == 200, response.status_code return response.text - def onload_weights(self): - """Onloads the model weights by waking up the model.""" - return self.wake_up(tags=["weights"]) - - def onload_kvcache(self): - """Onloads the KV cache by waking up the model.""" - return self.wake_up(tags=["kv_cache"]) - - def pause_generation(self): - """It will implemented for LMDeploy worker in the future.""" - pass - - def continue_generation(self): - """It will implemented for LMDeploy worker in the future.""" - pass - - def reset_prefix_cache(self): - """It will implemented for LMDeploy worker in the future.""" - pass - def _transform_rollout_config_to_server_configs(self) -> Namespace: """Transform the RolloutConfig into a Namespace suitable for the LMDeploy server. @@ -238,7 +201,7 @@ def _transform_rollout_config_to_server_configs(self) -> Namespace: # Therefore, each server only needs to handle 1 / dp_size of the total requests max_batch_size = self.config.rollout_max_batch_size_per_instance // dp_size distributed_executor_backend = lmdeploy_config_kwargs.get("distributed_executor_backend", "ray") - lmdeploy_config_kwargs["log_level"] = lmdeploy_config_kwargs.pop("log_level", "WARNING") + lmdeploy_config_kwargs["log_level"] = lmdeploy_config_kwargs.pop("log_level", "ERROR") lmdeploy_config_kwargs["uvicorn_log_level"] = lmdeploy_config_kwargs.pop("uvicorn_log_level", "ERROR") lmdeploy_config_kwargs["tm_log_level"] = lmdeploy_config_kwargs.pop("tm_log_level", "ERROR") @@ -368,8 +331,5 @@ def _transform_rollout_config_to_server_configs(self) -> Namespace: **lmdeploy_config_kwargs, ) - def _transform_sample_params(self, sample_params: Dict, extra_params: Dict = {}): - lmdeploy_sample_params = copy.deepcopy(sample_params) - if extra_params: - lmdeploy_sample_params.update(extra_params) - return lmdeploy_sample_params + def _transform_sample_params(self, sample_params: 
SampleParams) -> dict: + return sample_params.model_dump(exclude_none=True) diff --git a/xtuner/v1/rl/rollout/parser/__init__.py b/xtuner/v1/rl/rollout/parser/__init__.py new file mode 100644 index 0000000000..a7c3cdf594 --- /dev/null +++ b/xtuner/v1/rl/rollout/parser/__init__.py @@ -0,0 +1,19 @@ +from .factory import build_reasoning_parser, build_tool_call_parser +from .qwen3_reasoning_parser import Qwen3ReasoningParser +from .qwen3_tool_parser import Qwen3ToolCallParser +from .qwen3p5_tool_parser import Qwen3p5ToolCallParser +from .reasoning_parser import ParsedReasoningResult, ReasoningParser +from .tool_parser import ParsedToolCallResult, ToolCallParser + + +__all__ = [ + "ParsedReasoningResult", + "ParsedToolCallResult", + "Qwen3ReasoningParser", + "Qwen3p5ToolCallParser", + "Qwen3ToolCallParser", + "ReasoningParser", + "ToolCallParser", + "build_reasoning_parser", + "build_tool_call_parser", +] diff --git a/xtuner/v1/rl/rollout/parser/factory.py b/xtuner/v1/rl/rollout/parser/factory.py new file mode 100644 index 0000000000..86cf37e4ce --- /dev/null +++ b/xtuner/v1/rl/rollout/parser/factory.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from typing import Literal + +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast + +from .qwen3_reasoning_parser import Qwen3ReasoningParser, extract_qwen3_reasoning_strip_tokens +from .qwen3_tool_parser import Qwen3ToolCallParser +from .qwen3p5_tool_parser import Qwen3p5ToolCallParser +from .reasoning_parser import ReasoningParser +from .tool_parser import ToolCallParser + + +ToolCallParserName = Literal["none", "qwen3", "qwen3p5"] +ReasoningParserName = Literal["none", "qwen3"] + + +def build_tool_call_parser(parser_name: ToolCallParserName) -> ToolCallParser | None: + if parser_name == "none": + return None + if parser_name == "qwen3": + return Qwen3ToolCallParser() + if parser_name == "qwen3p5": + return Qwen3p5ToolCallParser() + raise ValueError(f"Unsupported tool_call_parser: {parser_name}") 
+ + +def build_reasoning_parser( + parser_name: ReasoningParserName, + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, +) -> ReasoningParser | None: + if parser_name == "none": + return None + if parser_name == "qwen3": + return Qwen3ReasoningParser(strip_tokens=extract_qwen3_reasoning_strip_tokens(tokenizer)) + raise ValueError(f"Unsupported reasoning_parser: {parser_name}") diff --git a/xtuner/v1/rl/rollout/parser/qwen3_reasoning_parser.py b/xtuner/v1/rl/rollout/parser/qwen3_reasoning_parser.py new file mode 100644 index 0000000000..bfa4b61472 --- /dev/null +++ b/xtuner/v1/rl/rollout/parser/qwen3_reasoning_parser.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +import re + +from xtuner.v1.data_proto import RolloutState + +from .reasoning_parser import ParsedReasoningResult, ReasoningParser + + +class Qwen3ReasoningParser(ReasoningParser): + _reasoning_pattern = re.compile(r"\s*(.*?)\s*", re.DOTALL) + + def __init__(self, strip_tokens: list[str] | None = None): + self._strip_tokens = strip_tokens or [] + + def parse(self, rollout_state: RolloutState) -> ParsedReasoningResult: + text = rollout_state.response or "" + if not text: + return ParsedReasoningResult() + cleaned = text + for token in self._strip_tokens: + cleaned = cleaned.replace(token, "") + reasoning_chunks = [ + match.group(1).strip() for match in self._reasoning_pattern.finditer(cleaned) if match.group(1).strip() + ] + content = self._reasoning_pattern.sub("", cleaned).strip() + if not reasoning_chunks and "" in cleaned: + prefix, suffix = cleaned.split("", 1) + content = prefix.strip() + truncated_reasoning = suffix.replace("", "").strip() + if truncated_reasoning: + reasoning_chunks.append(truncated_reasoning) + elif not reasoning_chunks and "" in cleaned: + reasoning_text, content = cleaned.split("", 1) + reasoning_text = reasoning_text.strip() + if reasoning_text: + reasoning_chunks.append(reasoning_text) + content = content.strip() + reasoning = 
"\n".join(reasoning_chunks).strip() or None + return ParsedReasoningResult(reasoning_text=reasoning, remaining_text=content or None) + + +def extract_qwen3_reasoning_strip_tokens( + tokenizer, +) -> list[str]: + strip_tokens: list[str] = [] + + eos_token = getattr(tokenizer, "eos_token", None) + if isinstance(eos_token, str) and eos_token: + strip_tokens.append(eos_token) + + for token in getattr(tokenizer, "additional_special_tokens", []) or []: + if not isinstance(token, str): + continue + lowered = token.lower() + if any(marker in lowered for marker in ("im_end", "eot", "end_of_turn", "turn_end")): + strip_tokens.append(token) + + return list(dict.fromkeys(strip_tokens)) diff --git a/xtuner/v1/rl/rollout/parser/qwen3_tool_parser.py b/xtuner/v1/rl/rollout/parser/qwen3_tool_parser.py new file mode 100644 index 0000000000..1e0d4663a0 --- /dev/null +++ b/xtuner/v1/rl/rollout/parser/qwen3_tool_parser.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +import re +from typing import Any +from uuid import uuid4 + +from xtuner.v1.data_proto import RolloutToolCall + +from .tool_parser import ( + ParsedToolCallResult, + ToolCallParser, + build_rollout_tool_call, + coerce_parameter_value, + parse_json_or_python_mapping, +) + + +class Qwen3ToolCallParser(ToolCallParser): + _tool_call_pattern = re.compile(r"\n*(.*?)", re.DOTALL) + _qwen_function_pattern = re.compile(r"\n]+)>(.*?)", re.DOTALL) + _qwen_parameter_pattern = re.compile(r"\n]+)>(.*?)", re.DOTALL) + _xml_tag_pattern = re.compile(r"<([a-zA-Z_][^>\n/]*)>(.*?)", re.DOTALL) + + def parse_text(self, text: str) -> ParsedToolCallResult: + if not text: + return ParsedToolCallResult() + cleaned_text, tool_calls = self._extract_tool_call_tags(text) + cleaned_text, qwen_tool_calls = self._extract_qwen_function_calls(cleaned_text) + tool_calls.extend(qwen_tool_calls) + return ParsedToolCallResult(remaining_text=cleaned_text.strip(), tool_calls=tool_calls) + + def should_reject_unparsed_markup( + self, + *, + has_tools: 
bool, + text: str | None, + parsed_tool_calls: list[Any] | None, + ) -> bool: + if not has_tools: + return False + if parsed_tool_calls: + return False + if not text: + return False + return any(marker in text for marker in ("", "", " tuple[str, list[RolloutToolCall]]: + tool_calls: list[RolloutToolCall] = [] + text_parts: list[str] = [] + last_end = 0 + for match in self._qwen_function_pattern.finditer(text): + if match.start() > last_end: + text_parts.append(text[last_end : match.start()]) + parsed_tool_call = self._parse_qwen_function_call(match.group(1).strip(), match.group(2)) + if parsed_tool_call is None: + text_parts.append(match.group(0)) + else: + tool_calls.append(parsed_tool_call) + last_end = match.end() + if last_end < len(text): + text_parts.append(text[last_end:]) + return "".join(text_parts), tool_calls + + def _parse_single_textual_tool_call(self, raw_payload: str) -> RolloutToolCall | None: + payload = parse_json_or_python_mapping(raw_payload) + if isinstance(payload, dict) and payload.get("name"): + arguments = payload.get("arguments", payload.get("parameters", {})) + return build_rollout_tool_call( + name=str(payload["name"]), + arguments=arguments, + call_id=str(payload.get("id") or f"call_{uuid4().hex}"), + ) + function_match = self._qwen_function_pattern.search(raw_payload) + if function_match is None: + return None + return self._parse_qwen_function_call(function_match.group(1).strip(), function_match.group(2)) + + def _parse_qwen_function_call(self, function_name: str, function_body: str) -> RolloutToolCall | None: + arguments: dict[str, Any] = {} + for parameter_match in self._qwen_parameter_pattern.finditer(function_body): + param_name = parameter_match.group(1).strip() + param_value = parameter_match.group(2).strip() + arguments[param_name] = coerce_parameter_value(param_value) + if not arguments: + for tag_match in self._xml_tag_pattern.finditer(function_body): + tag_name = tag_match.group(1).strip() + if 
tag_name.startswith("function="): + continue + tag_value = tag_match.group(2).strip() + if tag_name in {"path", "file_path"}: + arguments[tag_name] = tag_value + else: + arguments[tag_name] = coerce_parameter_value(tag_value) + return build_rollout_tool_call( + name=function_name, + arguments=arguments, + call_id=f"call_{uuid4().hex}", + ) diff --git a/xtuner/v1/rl/rollout/parser/qwen3p5_tool_parser.py b/xtuner/v1/rl/rollout/parser/qwen3p5_tool_parser.py new file mode 100644 index 0000000000..098077cbe8 --- /dev/null +++ b/xtuner/v1/rl/rollout/parser/qwen3p5_tool_parser.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +import re +from typing import Any +from uuid import uuid4 + +from xtuner.v1.data_proto import RolloutToolCall + +from .tool_parser import ParsedToolCallResult, ToolCallParser, build_rollout_tool_call, coerce_parameter_value + + +class Qwen3p5ToolCallParser(ToolCallParser): + _tool_call_pattern = re.compile(r"\n*(.*?)", re.DOTALL) + _parameter_pattern = re.compile(r"\n]+)>(.*?)", re.DOTALL) + + def parse_text(self, text: str) -> ParsedToolCallResult: + if not text: + return ParsedToolCallResult() + cleaned_text, tool_calls = self._extract_tool_call_tags(text) + return ParsedToolCallResult(remaining_text=cleaned_text.strip(), tool_calls=tool_calls) + + def should_reject_unparsed_markup( + self, + *, + has_tools: bool, + text: str | None, + parsed_tool_calls: list[Any] | None, + ) -> bool: + if not has_tools: + return False + if parsed_tool_calls: + return False + if not text: + return False + return any(marker in text for marker in ("", "", " RolloutToolCall | None: + function_name = self._extract_function_name(raw_payload) + if not function_name: + return None + + arguments: dict[str, Any] = {} + for parameter_match in self._parameter_pattern.finditer(raw_payload): + parameter_name = parameter_match.group(1).strip() + parameter_value = parameter_match.group(2).strip() + arguments[parameter_name] = coerce_parameter_value(parameter_value) + + 
return build_rollout_tool_call( + name=function_name, + arguments=arguments, + call_id=f"call_{uuid4().hex}", + ) + + def _extract_function_name(self, raw_payload: str) -> str | None: + function_start = raw_payload.find("", name_start), raw_payload.find("\n", name_start)) if index != -1 + ] + if not terminators: + return None + + function_name = raw_payload[name_start : min(terminators)].strip() + return function_name or None diff --git a/xtuner/v1/rl/rollout/parser/reasoning_parser.py b/xtuner/v1/rl/rollout/parser/reasoning_parser.py new file mode 100644 index 0000000000..51916e8ac7 --- /dev/null +++ b/xtuner/v1/rl/rollout/parser/reasoning_parser.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod + +from pydantic import BaseModel, ConfigDict + +from xtuner.v1.data_proto import RolloutState + + +class ParsedReasoningResult(BaseModel): + model_config = ConfigDict(extra="forbid") + + reasoning_text: str | None = None + remaining_text: str | None = None + + +class ReasoningParser(ABC): + @abstractmethod + def parse(self, rollout_state: RolloutState) -> ParsedReasoningResult: + """Return parsed reasoning and remaining text for a rollout + response.""" diff --git a/xtuner/v1/rl/rollout/parser/tool_parser.py b/xtuner/v1/rl/rollout/parser/tool_parser.py new file mode 100644 index 0000000000..31510617c9 --- /dev/null +++ b/xtuner/v1/rl/rollout/parser/tool_parser.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +import ast +import json +from abc import ABC, abstractmethod +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field + +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +from xtuner.v1.data_proto import RolloutFunctionCall, RolloutState, RolloutToolCall + + +class ParsedToolCallResult(BaseModel): + model_config = ConfigDict(extra="forbid") + + tool_calls: list[RolloutToolCall] = Field(default_factory=list) + remaining_text: str = "" + + +class ToolCallParser(ABC): + def 
parse(self, rollout_state: RolloutState) -> ParsedToolCallResult: + return self.parse_text(rollout_state.response or "") + + def should_reject_unparsed_markup( + self, + *, + has_tools: bool, + text: str | None, + parsed_tool_calls: list[Any] | None, + ) -> bool: + """Whether the remaining assistant text should be rejected as a + malformed tool call. + + Most parsers do not use textual tool-call markup, so the default behavior is to accept the text. Parsers with + format-specific markup can override this and reject outputs that still contain unparsed tool-call fragments. + """ + return False + + @abstractmethod + def parse_text(self, text: str) -> ParsedToolCallResult: + raise NotImplementedError + + +def extract_tokenizer_token_contents( + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast | Any, +) -> set[str]: + token_contents: set[str] = set() + + for token in getattr(tokenizer, "additional_special_tokens", []) or []: + if isinstance(token, str): + token_contents.add(token) + + added_tokens_decoder = getattr(tokenizer, "added_tokens_decoder", None) + if isinstance(added_tokens_decoder, dict): + for token_info in added_tokens_decoder.values(): + if isinstance(token_info, str): + token_contents.add(token_info) + elif isinstance(token_info, dict): + content = token_info.get("content") + if isinstance(content, str): + token_contents.add(content) + else: + content = getattr(token_info, "content", None) + if isinstance(content, str): + token_contents.add(content) + + get_vocab = getattr(tokenizer, "get_vocab", None) + if callable(get_vocab): + try: + vocab = get_vocab() + except Exception: + vocab = None + if isinstance(vocab, dict): + token_contents.update(token for token in vocab if isinstance(token, str)) + + return token_contents + + +def parse_json_or_python_mapping(raw_payload: str) -> Any: + try: + return json.loads(raw_payload) + except Exception: + try: + return ast.literal_eval(raw_payload) + except Exception: + return None + + +def 
coerce_parameter_value(value: str) -> Any: + stripped = value.strip() + if not stripped: + return "" + try: + return json.loads(stripped) + except Exception: + try: + return ast.literal_eval(stripped) + except Exception: + return stripped + + +def build_rollout_tool_call( + *, + name: str, + arguments: Any, + call_id: str, +) -> RolloutToolCall: + raw_arguments_text = arguments if isinstance(arguments, str) else None + parsed_arguments = arguments + if isinstance(arguments, str): + decoded = parse_json_or_python_mapping(arguments) + parsed_arguments = decoded if decoded is not None else {"raw": arguments} + return RolloutToolCall( + id=call_id, + function=RolloutFunctionCall( + name=name, + arguments=parsed_arguments, + raw_arguments_text=raw_arguments_text, + ), + ) diff --git a/xtuner/v1/ray/rollout/sglang.py b/xtuner/v1/rl/rollout/sglang.py similarity index 99% rename from xtuner/v1/ray/rollout/sglang.py rename to xtuner/v1/rl/rollout/sglang.py index 9d0b7db980..9d5c919cd1 100644 --- a/xtuner/v1/ray/rollout/sglang.py +++ b/xtuner/v1/rl/rollout/sglang.py @@ -5,9 +5,8 @@ from urllib3.exceptions import NewConnectionError from transformers import AutoTokenizer -from xtuner.v1.ray.config import RolloutConfig -from .worker import RolloutWorker +from .worker import RolloutConfig, RolloutWorker class SGLangWorker(RolloutWorker): diff --git a/xtuner/v1/rl/rollout/utils.py b/xtuner/v1/rl/rollout/utils.py new file mode 100644 index 0000000000..e19bb28649 --- /dev/null +++ b/xtuner/v1/rl/rollout/utils.py @@ -0,0 +1,279 @@ +import asyncio +import os +import threading +import time +from collections import OrderedDict +from itertools import cycle +from typing import TYPE_CHECKING, Any, Optional + +import httpx +import ray + +from xtuner.v1.rl.utils import asyncio_run +from xtuner.v1.utils import get_logger + + +if TYPE_CHECKING: + from .controller import RolloutControllerProxy, WorkerInfo + from .worker import RolloutConfig, RolloutWorker + +ROLLOUT_RAY_GET_TIMEOUT = 
int(os.getenv("XTUNER_ROLLOUT_RAY_GET_TIMEOUT", str(5 * 3600))) # default 5 hours +logger = get_logger() + + +class SessionRouter: + def __init__( + self, + worker_infos: dict[int, "WorkerInfo"], # worker: worker_status + worker_infos_lock: Optional[threading.RLock] = None, + max_sessions: int = 10000, + max_idle_seconds: Optional[float] = 3600.0, + ): + self._worker_infos = worker_infos + self._worker_infos_lock = worker_infos_lock + self._max_sessions = max_sessions + self._max_idle = max_idle_seconds + + # OrderedDict: key=session_id -> value=(worker_rank, last_used_ts) + self._map: OrderedDict[int, tuple[int, float]] = OrderedDict() + + self._worker_cycler = cycle(worker_infos.keys()) + self._lock = asyncio.Lock() + self.logger = get_logger() + + def _now(self) -> float: + return time.time() + + def _evict_expired(self): + if self._max_idle is None: + return + now = self._now() + + to_delete = [] + for sid, (_, last_used) in self._map.items(): + if now - last_used > self._max_idle: + to_delete.append(sid) + else: + break + for sid in to_delete: + self._map.pop(sid, None) + + def _evict_lru_to_capacity(self): + while len(self._map) > self._max_sessions: + self._map.popitem(last=False) + + def _choose_next_active_worker(self) -> tuple[int, Any]: + n = len(self._worker_infos) + for _ in range(n): + rank = next(self._worker_cycler) + if self._worker_infos_lock is None: + info = self._worker_infos[rank] + if info and info.is_active: + return rank, info.actor + else: + with self._worker_infos_lock: + info = self._worker_infos[rank] + if info and info.is_active: + return rank, info.actor + return -1, None + + async def get_worker(self, session_id: int) -> Optional[Any]: + async with self._lock: + self._evict_expired() + + if session_id in self._map: + worker_rank, _ = self._map.pop(session_id) + if self._worker_infos_lock is None: + info = self._worker_infos.get(worker_rank) + else: + with self._worker_infos_lock: + info = self._worker_infos.get(worker_rank) + if info 
and info.is_active: + self._map[session_id] = (worker_rank, self._now()) + return info.actor + + rank, worker = self._choose_next_active_worker() + if rank == -1: + return None + self._map[session_id] = (rank, self._now()) + self._evict_lru_to_capacity() + return worker + + +class RolloutHealthChecker: + def __init__( + self, + config: "RolloutConfig", + workers_info: dict[int, "WorkerInfo"], + worker_infos_lock: Optional[threading.RLock] = None, + ): + self._workers_info = workers_info + self._worker_infos_lock = worker_infos_lock + self._check_interval = config.health_check_interval_seconds + self._check_failure_threshold = config.health_check_failure_threshold + self._stop_event: Optional[threading.Event] = None + self._pause_event: Optional[threading.Event] = None + self._thread: Optional[threading.Thread] = None + + def start(self) -> None: + if self._thread and self._thread.is_alive(): + return + + self._stop_event = threading.Event() + self._pause_event = threading.Event() + self._pause_event.set() # 启动时设置为暂停状态,开始generation后再调用restart方法恢复 + self._thread = threading.Thread(target=self._run_loop, daemon=True) + self._thread.start() + logger.info("RolloutHealthChecker started.") + + def stop(self) -> None: + if not self._thread: + return + + assert self._stop_event is not None + self._stop_event.set() + if self._pause_event: + self._pause_event.clear() + self._thread.join(timeout=5) + self._thread = None + self._stop_event = None + logger.info("RolloutHealthChecker stopped.") + + def pause(self) -> None: + if self._pause_event is None: + return + self._pause_event.set() + logger.info("RolloutHealthChecker paused.") + + def resume(self) -> None: + if self._pause_event is None: + return + self._pause_event.clear() + logger.info("RolloutHealthChecker restarted.") + + def run_once(self) -> None: + logger.info("RolloutHealthChecker running health checks for all workers.") + if self._worker_infos_lock is None: + workers_snapshot = { + rank: (info.actor, info.url, 
info.is_active) for rank, info in self._workers_info.items() + } + else: + with self._worker_infos_lock: + workers_snapshot = { + rank: (info.actor, info.url, info.is_active) for rank, info in self._workers_info.items() + } + + tasks = [ + check_worker_health( + actor, + rank, + url, + is_active, + self._check_failure_threshold, + ) + for rank, (actor, url, is_active) in workers_snapshot.items() + ] + + async def _run_checks() -> list[bool]: + return await asyncio.gather(*tasks) + + check_results = asyncio_run(_run_checks()) + inactive_workers = [] + for rank, is_healthy in zip(workers_snapshot.keys(), check_results): + if not is_healthy: + logger.warning(f"Worker {rank} failed health check. Marking as inactive.") + if self._worker_infos_lock is None: + self._workers_info[rank].is_active = False + inactive_worker = self._workers_info[rank].actor + else: + with self._worker_infos_lock: + self._workers_info[rank].is_active = False + inactive_worker = self._workers_info[rank].actor + if inactive_worker is None: + logger.error(f"[RolloutHealthChecker] Worker {rank} has no actor reference. 
Skipping shutdown.") + continue + inactive_workers.append((rank, inactive_worker)) + else: + logger.info(f"[RolloutHealthChecker] Worker {rank} passed health check.") + + for rank, inactive_worker in inactive_workers: + try: + ray.get(inactive_worker.offload.remote(), timeout=ROLLOUT_RAY_GET_TIMEOUT) # type: ignore[attr-defined] + ray.get(inactive_worker.shutdown.remote(), timeout=ROLLOUT_RAY_GET_TIMEOUT) # type: ignore[attr-defined] + except Exception as e: + logger.error(f"Exception while shutting down worker {rank}: {e}") + + def _run_loop(self) -> None: + assert self._stop_event is not None and self._pause_event is not None + logger.info("RolloutHealthChecker loop started.") + + while not self._stop_event.is_set(): + while self._pause_event.is_set() and not self._stop_event.is_set(): + self._stop_event.wait(timeout=0.5) + + if self._stop_event.is_set(): + break + + if not self._pause_event.is_set() and not self._stop_event.is_set(): + self.run_once() + + if self._stop_event.wait(self._check_interval): + break + + +async def send_abort_request(client: httpx.AsyncClient, url: str, timeout: float = 60.0) -> tuple[str, bool]: + worker_url = f"{url}/abort_request" + try: + response = await client.post(worker_url, json={"abort_all": True}, timeout=timeout) + response.raise_for_status() + logger.debug(f"Successfully sent abort request to {url}") + return url, True + except Exception as e: + logger.error(f"Failed to send abort request to {url}: {e}") + return url, False + + +async def pause_generation(rollout_ctl: "RolloutControllerProxy", pause_time_out: float = 60.0) -> None: + await rollout_ctl.pause_generation.remote() # type: ignore[attr-defined] + rollout_ctl_metadata = await rollout_ctl.get_rollout_metadata.remote() # type: ignore[attr-defined] + infer_server_url = list(rollout_ctl_metadata["server_url_dict"].values()) + async with httpx.AsyncClient() as client: + tasks = [send_abort_request(client, url, timeout=pause_time_out) for url in infer_server_url] + 
results = await asyncio.gather(*tasks) + + failed_workers = [url for url, success in results if not success] + succeeded_count = len(infer_server_url) - len(failed_workers) + + if failed_workers: + logger.warning( + f"Abort requests completed. Succeeded: {succeeded_count}, " + f"Failed: {len(failed_workers)}. Failed workers: {failed_workers}" + ) + else: + logger.info(f"All {succeeded_count} abort requests sent successfully.") + + +async def continue_generation(rollout_ctl: "RolloutControllerProxy") -> None: + return await rollout_ctl.continue_generation.remote() # type: ignore[attr-defined] + + +async def check_worker_health( + worker: "RolloutWorker", rank: int, url: str, is_active: bool, failure_threshold: int = 3 +) -> bool: + if worker is None or not is_active: + logger.warning("Worker has no actor reference or is marked inactive.") + return False + failing_count = 0 + while failing_count < failure_threshold: + try: + health_status = await worker.check_health.remote() # type: ignore[attr-defined] + if health_status: + return True + failing_count += 1 + logger.warning(f"Health check failed for worker {rank} at {url}. Failure count: {failing_count}") + except Exception as e: + failing_count += 1 + logger.error( + f"Exception during health check for worker {rank} at {url}: {e}. 
Failure count: {failing_count}" + ) + return False diff --git a/xtuner/v1/ray/rollout/vllm.py b/xtuner/v1/rl/rollout/vllm.py similarity index 98% rename from xtuner/v1/ray/rollout/vllm.py rename to xtuner/v1/rl/rollout/vllm.py index 400db51ae2..39d01dd0cc 100644 --- a/xtuner/v1/ray/rollout/vllm.py +++ b/xtuner/v1/rl/rollout/vllm.py @@ -6,9 +6,7 @@ from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.utils import FlexibleArgumentParser -from xtuner.v1.ray.config import RolloutConfig - -from .worker import RolloutWorker +from .worker import RolloutConfig, RolloutWorker def run_vllm_server_wrapper(server_args): diff --git a/xtuner/v1/rl/rollout/worker.py b/xtuner/v1/rl/rollout/worker.py new file mode 100644 index 0000000000..ad86c36333 --- /dev/null +++ b/xtuner/v1/rl/rollout/worker.py @@ -0,0 +1,1050 @@ +import asyncio +import copy +import json +import multiprocessing +import os +import socket +import time +import traceback +from abc import abstractmethod +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, List, Literal, Optional, Union + +import httpx +import ray +import requests # type: ignore[import-untyped] +import torch +from cyclopts import Group, Parameter +from packaging.version import Version +from pydantic import BaseModel, ConfigDict, PrivateAttr +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy +from typing_extensions import Annotated + +from transformers import AutoTokenizer +from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams, Status, update_status_from_finish_reason +from xtuner.v1.rl.utils import ( + AutoAcceleratorWorkers, + SingleAcceleratorWorker, + find_master_addr_and_port, + get_eos_token, +) +from xtuner.v1.utils import get_logger +from xtuner.v1.utils.httpx_utils import HttpRequestErrorType, HttpRequestResult + + +if TYPE_CHECKING: + from ray.util.placement_group import PlacementGroup + + +infer_group = Group("inference", help="Inference worker 
configuration.") + + +class RolloutConfig(BaseModel): + """Rollout worker configuration for XTuner. + + This class defines comprehensive configuration parameters for rollout workers in XTuner, + supporting multiple inference backends with distributed computing and optimization features. + + Args: + env (str): Environment variables for the rollout worker. Defaults to "". + backend (str): Backend framework ('vllm', 'lmdeploy', etc.). Defaults to "lmdeploy". + model_path (str | Path): Path to the inference model. + model_name (str): Model name for the backend engine. + tokenizer_path (str): Path to the model tokenizer. Defaults to "". + api_key (Optional[Union[List[str], str]]): API keys for rollout service. Supports single key or + list of keys. Defaults to None. + api_port (Optional[int]): Port number for the rollout API server. If not set, it will find an + available port starting from 8000. Defaults to 8000. + gpus_per_node (int): Number of GPUs per node. Defaults to 8. + dtype (str): Model data type ('bfloat16', 'float16', 'int8'). Defaults to "bfloat16". + gpu_memory_utilization (float): GPU memory utilization ratio. Defaults to 0.85. + random_seed (int): Random seed for reproducible generation. Defaults to 1024. + rollout_cross_node_comm (bool): Enable cross-node communication. Defaults to False. + rollout_max_batch_size_per_instance (int): Maximum batch size for the rollout worker. If not set, it + will be determined automatically based on `context_length`. Defaults to 512. + allow_over_concurrency_ratio (float): Factor to allow over-concurrency in HTTP requests for the + rollout worker to improve GPU utilization. Defaults to 1.2. + tensor_parallel_size (int): GPUs per inference engine (tensor parallelism). Defaults to 1. + expert_parallel_size (int): Experts per inference engine (expert parallelism). Defaults to 1. + enable_chunked_prefill (bool): Enable chunked prefill for memory efficiency. Defaults to False. 
+ chunked_prefill_size (int): Chunk size for prefill operations. Defaults to 128. + skip_load_weights (bool): Skip weight loading for rollout worker. Defaults to False. + rollout_timeout (float): Timeout duration in seconds for rollout requests. Defaults to 3600.0. + context_length (int): Context length for the rollout worker. + launch_server_method (Literal["ray", "multiprocessing"]): Server launch method. Defaults to "ray". + system_prompt (Optional[str]): System prompt to guide generation behavior. Defaults to None. + extra_rollout_config (Optional[dict]): Backend-specific configurations using engine prefixes + (e.g., 'vllm_enable_chunked_prefill', 'lmdeploy_max_batch_size'). Defaults to empty dict. + + **Examples:** + + Example configuration with LMDeploy backend:: + + config = RolloutConfig( + env="test_env", + model_path="Qwen/Qwen3-8B", + model_name="Qwen3-8B", + tensor_parallel_size=2, + gpu_memory_utilization=0.6, + gpus_per_node=8, + backend="lmdeploy", + ) + """ + + model_config = ConfigDict(extra="forbid") + + # base config + env: Annotated[ + str, + Parameter(group=infer_group, help="Environment variables to set for the rollout."), + ] = "" + device: Annotated[str, Parameter(group=infer_group, help="Device to be used for the rollout worker.")] = "GPU" + model_path: Annotated[str | Path, Parameter(group=infer_group, help="Path to the SGLang model.")] + model_name: Annotated[ + str | None, Parameter(group=infer_group, help="Name of the model to be used in the LMDeploy.") + ] = None + tokenizer_path: Annotated[ + str | None, Parameter(group=infer_group, help="Path to the tokenizer for the model.") + ] = None + api_key: Annotated[ + Optional[Union[List[str], str]], + Parameter( + group=infer_group, + help="API keys for the rollout service. Can be a single key or a list of keys.", + ), + ] = None + api_port: Annotated[ + int, + Parameter(group=infer_group, help="Port number for the rollout API server. 
If not set, 8000 will be used."), + ] = 8000 + api_host: Annotated[ + str, + Parameter(group=infer_group, help="Host for the rollout API server."), + ] = "0.0.0.0" + gpus_per_node: Annotated[int, Parameter(group=infer_group, help="Number of GPUs allocated per node.")] = 8 + dtype: Annotated[ + str, + Parameter(group=infer_group, help="Data type for the model, e.g., 'bfloat16', 'float16', 'int8'."), + ] = "bfloat16" + gpu_memory_utilization: Annotated[ + float, Parameter(group=infer_group, help="GPU memory utilization for the rollout worker.") + ] = 0.85 + random_seed: Annotated[int, Parameter(group=infer_group, help="Random seed for the rollout worker.")] = 1024 + # distributed config + rollout_cross_node_comm: Annotated[ + bool, + Parameter( + group=infer_group, + help="Whether to enable cross-node communication for the rollout worker.", + ), + ] = False + dist_port_base: Annotated[ + int, + Parameter( + group=infer_group, + help="Base port number for distributed communication among rollout workers.", + ), + ] = 35000 + rollout_max_batch_size_per_instance: Annotated[ + Optional[int], + Parameter( + group=infer_group, + help="Maximum batch size for the rollout worker. 
If not set, it will be determined automatically based on the model and GPU memory.", + ), + ] = None + allow_over_concurrency_ratio: Annotated[ + float, + Parameter( + group=infer_group, + help="Factor to allow over concurrency in the http request for rollout worker to improve GPU utilization.", + ), + ] = 1.2 + tensor_parallel_size: Annotated[ + int, + Parameter( + group=infer_group, + help="Number of GPUs allocated for each inference engine in the rollout worker.", + ), + ] = 1 + expert_parallel_size: Annotated[ + int, + Parameter( + group=infer_group, + help="Number of experts allocated for each inference engine in the rollout worker.", + ), + ] = 1 + # optimization config + enable_chunked_prefill: Annotated[ + bool, + Parameter( + group=infer_group, + help="Whether to enable chunked prefill for the rollout worker.", + ), + ] = False + chunked_prefill_size: Annotated[ + int, + Parameter( + group=infer_group, + help="Chunked prefill size for the rollout worker.", + ), + ] = 128 + skip_load_weights: Annotated[ + bool, + Parameter( + group=infer_group, + help="Whether to skip loading weights for the rollout worker.", + ), + ] = False + enable_return_routed_experts: Annotated[ + bool, + Parameter( + group=infer_group, + help="Whether to enable returning routed experts for the rollout worker.", + ), + ] = False + launch_server_method: Annotated[ + Literal["ray", "multiprocessing"], + Parameter( + group=infer_group, + help="Method to launch the rollout server, either 'ray' or 'multiprocessing'.", + ), + ] = "ray" + rollout_timeout: Annotated[ + float, + Parameter( + group=infer_group, + help="Timeout duration (in seconds) for rollout requests.", + ), + ] = 1200.0 + context_length: Annotated[ + Optional[int], + Parameter( + group=infer_group, + help="Context length for the rollout worker.", + ), + ] = None + tool_call_parser: Annotated[ + Literal["none", "qwen3", "qwen3p5"], + Parameter( + group=infer_group, + help='Structured tool-call parser to apply to rollout 
output. Use "none" to disable parsing, "qwen3" to enable Qwen3 tool-call parsing, or "qwen3p5" to enable Qwen3.5 coder-style tool-call parsing.', + ), + ] = "none" + reasoning_parser: Annotated[ + Literal["none", "qwen3"], + Parameter( + group=infer_group, + help='Reasoning parser to apply to rollout output. Use "none" to disable parsing or "qwen3" to enable Qwen3 parsing.', + ), + ] = "none" + enable_float8: Annotated[ + bool, + Parameter( + group=infer_group, + help="Whether to enable float8 quantization for the rollout worker.", + ), + ] = False + extra_rollout_config: Annotated[ + dict, + Parameter( + group=infer_group, + help='Extra configuration for different rollout worker. vllm parameters will start with prefix "vllm", etc.', + ), + ] = {} + max_retry_per_worker: Annotated[ + Optional[int], + Parameter( + group=infer_group, + help="Maximum number of retries per rollout worker before deactivation.", + ), + ] = None + max_retry_per_sample: Annotated[ + int, + Parameter( + group=infer_group, + help="Maximum number of retries per sample before marking it as failed.", + ), + ] = 1 + worker_log_dir: Annotated[Path, Parameter(help="Directory to save worker logs.")] = Path.cwd() / "work_dir" + health_check_interval_seconds: Annotated[ + float, + Parameter( + group=infer_group, + help="Interval in seconds between rollout worker health checks.", + ), + ] = 30.0 + health_check_failure_threshold: Annotated[ + int, + Parameter( + group=infer_group, + help="Number of consecutive health check failures required before marking a worker inactive.", + ), + ] = 3 + _logged_server_urls_per_engine: bool = PrivateAttr(default=False) + + @property + def rollout_backend(self) -> str: + backend = "" + if os.environ.get("XTUNER_USE_SGLANG", "0") == "1": + backend = "sglang" + elif os.environ.get("XTUNER_USE_VLLM", "0") == "1": + backend = "vllm" + elif os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "1": + backend = "lmdeploy" + + assert backend in ["sglang", "vllm", "lmdeploy"], ( + 
f"Unsupported rollout backend: {backend}. Please set XTUNER_USE_SGLANG, XTUNER_USE_VLLM, or XTUNER_USE_LMDEPLOY to 1." + ) + return backend + + @property + def server_urls_per_engine(self) -> int: + # server_urls_per_engine is introduced for lmdeploy ep settings + # for now only lmdeploy pytorch backend with ep > 1 requires multiple server urls per engine + if self.rollout_backend == "lmdeploy" and self.expert_parallel_size > 1: + # when expert parallelism is used, lmdeploy requires `expert_parallel_size` server instances per engine + if not self._logged_server_urls_per_engine: + self._logged_server_urls_per_engine = True + get_logger().info( + f"Setting server_urls_per_engine={self.expert_parallel_size} due to expert parallelism in LMDeploy." + ) + return self.expert_parallel_size + else: + return 1 + + def model_post_init(self, __context: Any) -> None: + if self.model_name is None: + model_name_from_config = None + config_json_path = Path(self.model_path) / "config.json" + try: + with open(config_json_path, encoding="utf-8") as f: + config_data = json.load(f) + model_name_from_config = config_data.get("model_type") + except (json.JSONDecodeError, OSError): + pass + self.model_name = model_name_from_config or Path(self.model_path).name + + if self.tokenizer_path is None: + self.tokenizer_path = str(self.model_path) + + port = self.api_port + while True: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + try: + s.bind((self.api_host if self.api_host != "0.0.0.0" else "localhost", port)) + break + except OSError: + port += 1 + self.api_port = port + + if self.device == "NPU": + self.gpus_per_node = 16 + + if self.rollout_backend == "sglang": + self.launch_server_method = "multiprocessing" + self.rollout_cross_node_comm = False + else: + self.launch_server_method = "ray" + self.rollout_cross_node_comm = True + + if self.rollout_max_batch_size_per_instance is None: + assert self.context_length is not None, ( + "context_length must be set if 
rollout_max_batch_size_per_instance is not provided." + ) + # TODO(@duanyanhui): Provide better suggestions for different models/input-output lengths + if self.context_length <= 4096: + self.rollout_max_batch_size_per_instance = 1024 + elif self.context_length <= 8192: + self.rollout_max_batch_size_per_instance = 512 + else: + self.rollout_max_batch_size_per_instance = 128 + + if self.max_retry_per_worker is None: + self.max_retry_per_worker = self.rollout_max_batch_size_per_instance + + self.worker_log_dir.mkdir(parents=True, exist_ok=True) + + def build(self, placement_group: "PlacementGroup"): + """Build and return a Ray remote RolloutController from this config. + + Args: + placement_group: The placement group for scheduling RolloutWorker actors. + + Returns: + A Ray actor handle (proxy) of RolloutController. + """ + import ray + + from xtuner.v1.rl.rollout.controller import RolloutController + + return ( + ray.remote(RolloutController) + .options(max_concurrency=int(os.environ.get("RAY_MAX_CONCURRENCY", 1000))) + .remote(self, placement_group) + ) + + +class RolloutWorker(SingleAcceleratorWorker): + """Base class for a rollout worker that runs an inference server. + + This class manages the lifecycle of a distributed inference server, including initialization, launching, and + handling generation requests. It is designed to be subclassed for specific inference backends like LMDeploy, vLLM + or SGLang. + """ + + def __init__( + self, + config: RolloutConfig, + rank: int, + master_addr: str, + master_port: int, + world_size: int, + accelerator: str = "GPU", + ): + """Initialize the RolloutWorker. + + Args: + config (RolloutConfig): The configuration for the rollout. + rank (int): The rank of this worker in the distributed setup. + master_addr (str): The address of the Ray master node. + master_port (int): The port of the Ray master node. + world_size (int): The total number of workers. + accelerator (str): The type of accelerator to use. + Defaults to "GPU". 
+ """ + self.config = config + self.rank = rank + self.master_addr = master_addr # ray master + self.master_port = master_port + self.world_size = world_size + self.accelerator = accelerator + self.server_func: Callable + self.endpoints: dict[str, str] = dict() + self.engine_rank_mesh_array: list[list[int]] + # http_concurrency is calculated based on the max batch size per engine and the total number of engines + assert config.rollout_max_batch_size_per_instance, ( + "rollout_max_batch_size_per_instance must be set in RolloutConfig" + ) + http_concurrency = config.rollout_max_batch_size_per_instance * config.allow_over_concurrency_ratio + limits = httpx.Limits(max_connections=http_concurrency, max_keepalive_connections=100) + self.client = httpx.AsyncClient(limits=limits, timeout=self.config.rollout_timeout) + self.paused = False + self.server_task = None + self.engine_bundle_idxs: list[int] = [] + self.server_process: Optional[multiprocessing.Process] = None + self.logger = get_logger(log_dir=config.worker_log_dir, tag="RolloutWorker") + self.tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer_path, trust_remote_code=True) + self.check_flag = True # only print once + self.enable_return_routed_experts = self.config.enable_return_routed_experts + if self.rank == 0: + self.logger.info(f"RolloutConfig:\n{self.config.model_dump_json(indent=2)}") + eos_token = get_eos_token(self.config.model_path) + self.logger.info(f"Using eos_token: {eos_token} for model at {self.config.model_path}") + self.eos_token: List[int] = [eos_token] if isinstance(eos_token, int) else eos_token + self.receive_abort_request = asyncio.Event() + self.abort_timeout = 5.0 + self.dist_init_addr: str = "" + self.serverl_url: str = "" + + def init(self, dist_init_addr: str) -> tuple[int, str]: + """Initialize the worker and launch the server. + + Args: + dist_init_addr (str): The distributed initialization address. + If not provided, the one generated by `init_dist_port` is used. 
+ + Returns: + Tuple[int, str]: A tuple containing the worker's rank and its + server URL. + """ + self.dist_init_addr = dist_init_addr if dist_init_addr else self.dist_init_addr + self.receive_abort_request.clear() + self._launch_server() + return (self.rank, self.server_url) + + def init_dist_port(self) -> str: + """Initialize distributed communication ports. + + This method acquires three free ports for the distributed setup: + one for the inference server, one for NCCL, and one for Ray's + distributed communication. + + Returns: + str: The distributed initialization address (host:port). + """ + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=ray.util.get_current_placement_group(), + placement_group_capture_child_tasks=True, + placement_group_bundle_index=self.engine_bundle_idxs[0], + ) + + local_rank = int(ray.get_runtime_context().get_accelerator_ids()[self.accelerator][0]) + interval = 1024 + start_port = self.config.dist_port_base + local_rank * interval + end_port = start_port + interval + self.host, self.ports = ray.get( + find_master_addr_and_port.options(scheduling_strategy=scheduling_strategy).remote( + nums=3, + start_port=start_port, + end_port=end_port, + ) + ) + + self.dist_port = self.ports[0] + self.server_port = self.ports[1] + self.nccl_port = self.ports[2] + self.dist_init_addr = f"{self.host}:{self.dist_port}" + self.server_url = f"http://{self.host}:{self.server_port}" + return self.dist_init_addr + + def shutdown(self): + """Shut down the worker, its server task, and any child processes.""" + if self.server_task is not None: + ray.cancel(self.server_task, force=True) + return + + if self.server_process is not None: + import psutil + + parent = psutil.Process(self.server_process.pid) + children = parent.children(recursive=True) + for child in children: + child.terminate() + gone, alive = psutil.wait_procs(children, timeout=5) + for child in alive: + child.kill() + parent.terminate() + parent.wait(timeout=5) + 
self.logger.debug(f"Worker {self.rank} server process and its children terminated.") + return + + def pause_generation(self): + """Pause the worker's generation process.""" + self.paused = True + + def continue_generation(self): + """Resume the worker's generation process.""" + self.receive_abort_request.clear() + + def check_health(self) -> bool: + """Check the health of the worker's server. + + Returns: + bool: True if the server is healthy, False otherwise. + """ + try: + headers = { + "Content-Type": "application/json; charset=utf-8", + "Authorization": f"Bearer {self.config.api_key}", + } + response = requests.get( + f"{self.server_url}/{self.endpoints['health_generate']}", headers=headers, timeout=5.0 + ) + return response.status_code == 200 + except requests.RequestException as e: + self.logger.error(f"Health check failed for server {self.server_url}: {e}") + return False + + async def generate(self, rollout_state: RolloutState) -> RolloutState: + # TODO(@duanyanhui): + # 1. support claude format input + # 2. 需要看下新的输入输出(RolloutState)怎么适配PartialRollout的逻辑,先跑起来 + # 3. 
对于流式返回的response先删掉,目前还用不上,等需要的时候再加上 + + uid = rollout_state.uid + sample_params: SampleParams = rollout_state.sample_params + + if sample_params.return_token_ids: + endpoint_url = f"{self.server_url}/{self.endpoints['generate']}" + else: + endpoint_url = f"{self.server_url}/{self.endpoints['v1/chat/completions']}" + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.config.api_key}", + } + + max_retries = self.config.max_retry_per_sample + payload = self._get_request_payload(rollout_state) + + # 早退逻辑 1:检查是否已被标记为完成 + if rollout_state.status == Status.COMPLETED: + self.logger.debug(f"Request {uid} is already marked as COMPLETED, skipping generation.") + return rollout_state + + # 早退逻辑 2:检测输入是否还需要 generation (安全获取变量) + input_ids = payload.get("input_ids", []) + max_tokens = payload.get("max_tokens") + + last_id = input_ids[-1] if len(input_ids) > 0 else "None" + is_max_tokens_zero = max_tokens is not None and max_tokens <= 0 + is_eos_reached = len(input_ids) > 0 and input_ids[-1] in self.eos_token + + if is_max_tokens_zero or is_eos_reached: + self.logger.debug( + f"No generation needed for request {uid}: max_tokens={max_tokens} or last input_id={last_id} is in eos_token." 
+ ) + rollout_state.status = Status.COMPLETED + rollout_state.response_ids = [] + rollout_state.response = "" + rollout_state.logprobs = [] + rollout_state.response_mask = [] + rollout_state.response_rollout_steps = [] + rollout_state.finish_reason = "stop" if is_eos_reached else "length" + return rollout_state + + for attempt in range(max_retries + 1): + is_last_attempt = attempt == max_retries + http_result = await self._safe_post_request(endpoint_url, headers=headers, payload=payload) + + # Case 1: HTTP Request is Successful + if http_result.response: + # Case 1.1: Valid rollout response + rollout_state = await self._safe_handle_response(rollout_state, http_result.response) + if rollout_state.status in [Status.COMPLETED, Status.ABORTED]: + return rollout_state + + if is_last_attempt: + # Case 1.2: Invalid rollout response and no retries left, so we return FAILED + self.logger.warning( + f"Invalid rollout response for request {uid} after {max_retries} attempts, marking as FAILED." + ) + rollout_state.status = Status.FAILED + rollout_state.error_msg = f"Invalid rollout response after {max_retries} attempts." + return rollout_state + + # Case 1.3: Invalid rollout response but we have retries left + self.logger.warning( + f"Invalid rollout response for request {uid}, retrying {attempt + 1}/{max_retries}." 
+ ) + await asyncio.sleep(0.1) + continue + + # Case 2: Error occurred during HTTP Request + if http_result.error_type == HttpRequestErrorType.REQUEST_ABORTED: + # Case 2.1: The request was aborted due to an signal set by `receive_abort_request` + rollout_state.finish_reason = "abort" + rollout_state.status = update_status_from_finish_reason("abort") + return rollout_state + + if http_result.is_client_error: + # Case 2.2: A non-retryable client error occurred (such as 4xx HTTP status) + self.logger.warning( + f"rollout request {uid} to {http_result.url} was skipped due to client error {http_result.error_type} with {http_result.error_msg}" + ) + rollout_state.error_msg = ( + f"Client error {http_result.error_type} with message: {http_result.error_msg}" + ) + rollout_state.status = Status.FAILED + return rollout_state + + if http_result.is_server_error: + # Case 2.3: A non-retryable server error occurred (such as 5xx HTTP status) + self.logger.warning( + f"rollout request {uid} to {http_result.url} failed due to server error {http_result.error_type} with {http_result.error_msg}" + ) + rollout_state.error_msg = ( + f"Server error {http_result.error_type} with message: {http_result.error_msg}" + ) + rollout_state.status = Status.FAILED + return rollout_state + + # Case 3: Retryable error occurred during HTTP Request + if http_result.is_retryable: + if is_last_attempt: + self.logger.warning( + f"rollout request {uid} to {http_result.url} failed after {max_retries} attempts due to retryable error {http_result.error_type} with {http_result.error_msg}" + ) + rollout_state.error_msg = f"Request failed after {max_retries} attempts due to retryable error {http_result.error_type} with message: {http_result.error_msg}" + rollout_state.status = Status.FAILED + return rollout_state + + self.logger.warning( + f"rollout request {uid} to {http_result.url} failed due to retryable error {http_result.error_type} with {http_result.error_msg}, retrying {attempt + 1}/{max_retries}." 
+ ) + await asyncio.sleep(0.1) + continue + + # Case 4: Unknown error occurred during HTTP Request and stop the rollout + if http_result.is_unknown_error: + raise RuntimeError( + f"Unexpected error during rollout request {uid} to {http_result.url}: {http_result.exception}" + ) + return rollout_state + + def _launch_server(self): + """Launch the inference server as a separate process or Ray task. + + It waits for the server to become healthy before returning. + + Raises: + TimeoutError: If the server fails to start within the specified + timeout. + Exception: If the server task terminates unexpectedly. + """ + server_configs = self._transform_rollout_config_to_server_configs() + timeout = 3600.0 # Increased timeout to 5 minutes for downloading large models + start_time = time.perf_counter() + last_log_time = start_time + headers = { + "Content-Type": "application/json; charset=utf-8", + "Authorization": f"Bearer {server_configs.api_key}", + } + + self.logger.info(f"Launch server task on server_url: {self.server_url}") + + # note(@duanyanhui): launch server as multiprocessing for sglang temporarily + if self.config.launch_server_method == "multiprocessing": + mp_ctx = multiprocessing.get_context("spawn") + process = mp_ctx.Process(target=self.server_func, args=(server_configs,)) + process.start() + self.server_process = process + time.sleep(60) # Wait for the server to start + with requests.Session() as session: + while time.perf_counter() - start_time < timeout: + try: + response = session.get( + f"{self.server_url}/{self.endpoints['health_generate']}", headers=headers + ) + if response.status_code == 200: + return + except requests.RequestException as e: + self.logger.error( + f"can't connect to server url {self.server_url}/{self.endpoints['health_generate']} because {e}" + ) + + current_time = time.perf_counter() + if current_time - last_log_time >= 15: + self.logger.info( + f"Waiting for server to start, Elapsed time: {current_time - start_time:.2f}s" + ) + 
last_log_time = current_time + + time.sleep(5) + process.terminate() + raise TimeoutError("Server failed to start within the timeout period.") + else: + # launch the server as ray task + # so that the lmdeploy backend could get externl pg + current_pg = ray.util.get_current_placement_group() + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=current_pg, + placement_group_capture_child_tasks=True, + placement_group_bundle_index=self.engine_bundle_idxs[0], + ) + assert ray.is_initialized() + ray_kwargs = ( + {"runtime_env": server_configs.ray_runtime_env} if hasattr(server_configs, "ray_runtime_env") else {} + ) + self.server_task = ( + ray.remote(self.server_func) + .options( + scheduling_strategy=scheduling_strategy, + **AutoAcceleratorWorkers.get_pg_options(current_pg), + **ray_kwargs, + ) + .remote(server_configs) + ) + + with requests.Session() as session: + while time.perf_counter() - start_time < timeout: + try: + response = session.get( + f"{self.server_url}/{self.endpoints['health_generate']}", headers=headers + ) + if response.status_code == 200: + return + except requests.RequestException: + pass + + try: + ray.get(self.server_task, timeout=0.1) + raise Exception("Server task terminated unexpectedly.") + except ray.exceptions.GetTimeoutError: + pass + except Exception as e: + raise e + + current_time = time.perf_counter() + if current_time - last_log_time >= 15: + self.logger.info( + f"Waiting for server to start... 
Elapsed time: {current_time - start_time:.2f}s" + ) + last_log_time = current_time + + ray.cancel(self.server_task) + raise TimeoutError("Server failed to start within the timeout period.") + + async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult: + try: + if self.receive_abort_request.is_set(): + self.logger.debug(f"Request to {url} was cancelled before sending due to an abort signal.") + return HttpRequestResult(error_type=HttpRequestErrorType.REQUEST_ABORTED, url=url, payload=payload) + req = self.client.build_request( + "POST", + url, + headers=headers, + json=payload, + ) + r = await self.client.send(req) + r.raise_for_status() + return HttpRequestResult(response=r) + + except Exception as e: + error_type = HttpRequestErrorType.from_exception(e) + result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload) + return result + + async def _safe_handle_response(self, rollout_state: RolloutState, http_response: httpx.Response) -> RolloutState: + uid = rollout_state.message_uid + sample_params = rollout_state.sample_params + is_token_out = sample_params.return_token_ids + response = http_response.json() + if is_token_out: + response_ids: list[int] = [] + logprobs: list[float] = [] + routed_experts = None + returned_response = "" + finish_reason = response["meta_info"]["finish_reason"]["type"] + if finish_reason == "abort" and self.receive_abort_request.is_set() is False: + self.receive_abort_request.set() + self.logger.info(f"Setting receive_abort_request to True for rank {self.rank}") + try: + returned_response = response.get("text", "") + # 获取response_ids && respoonse_ids + if ( + "output_token_logprobs" in response["meta_info"] + and response["meta_info"]["output_token_logprobs"] is not None + ): + response_ids = [item[1] for item in response["meta_info"]["output_token_logprobs"]] + logprobs = [item[0] for item in response["meta_info"]["output_token_logprobs"]] + else: + num_return_tokens = 
response["meta_info"].get("completion_tokens", 0) + response_ids = response["output_ids"][-num_return_tokens:] if num_return_tokens > 0 else [] + + # 获取 routed_experts + if self.enable_return_routed_experts: + assert "routed_experts" in response["meta_info"], ( + "enable_return_routed_experts is True, but routed_experts is not in meta_info" + ) + routed_experts = response["meta_info"]["routed_experts"] # token[layer[expert]] + if routed_experts is not None: + if isinstance(routed_experts, str): + import base64 + + data = base64.b64decode(routed_experts) + routed_experts = ray.cloudpickle.loads(data) + else: + routed_experts = torch.tensor(routed_experts) # n,layer,expert + routed_experts = ray.put(routed_experts) + + # 获取 status + rollout_status = update_status_from_finish_reason(finish_reason) + + # 检查输出结果 + if rollout_status == Status.COMPLETED: + validation_errors = [] + + if not response_ids: + validation_errors.append("empty response_ids") + + if not response: + validation_errors.append("empty response text") + + if sample_params.return_logprob and not logprobs: + validation_errors.append("missing logprobs") + + if self.enable_return_routed_experts and routed_experts is None: + validation_errors.append("missing routed_experts") + + if validation_errors: + error_msg = f"Incomplete rollout data for msg {uid}: {', '.join(validation_errors)}" + self.logger.error(error_msg) + rollout_state.status = Status.FAILED + rollout_state.error_msg = error_msg + return rollout_state + elif rollout_status == Status.FAILED: + error_msg = f"Rollout failed for msg {uid} with finish_reason {finish_reason}" + self.logger.error(error_msg) + rollout_state.status = Status.FAILED + rollout_state.error_msg = error_msg + return rollout_state + + rollout_state.response = returned_response + rollout_state.response_ids = response_ids + rollout_state.logprobs = logprobs + rollout_state.routed_experts = routed_experts + rollout_state.finish_reason = finish_reason + rollout_state.status = 
rollout_status + return rollout_state + except KeyError as e: + error_msg = f"Missing expected key {e} in response {response} for {uid}" + raise RuntimeError(error_msg) + except IndexError as e: + error_msg = f"Index error {e} while processing response {response} for {uid}" + raise RuntimeError(error_msg) + except AssertionError as e: + error_msg = f"AssertionError: {e} when processing response {response} for {uid}" + raise RuntimeError(error_msg) + except json.JSONDecodeError as e: + error_msg = f"JSONDecodeError: {e} when processing response {response} for {uid}" + raise RuntimeError(error_msg) + except TypeError as e: + error_msg = f"TypeError: {e} when processing response {response} for {uid}" + raise RuntimeError(error_msg) + except Exception as e: + error_msg = f"Unexpected error: {e} when processing response {response} for {uid}\nTraceback: {traceback.format_exc()}" + raise RuntimeError(error_msg) + else: + # v1/chat/completions API response + try: + returned_response = response["choices"][0]["message"]["content"] + finish_reason = response["choices"][0]["finish_reason"] + rollout_status = update_status_from_finish_reason(finish_reason) + if rollout_status == Status.COMPLETED and not returned_response: + self.logger.error(f"Empty response text for msg {uid} with finish_reason {finish_reason}") + rollout_state.status = Status.FAILED + rollout_state.error_msg = "Empty response text" + return rollout_state + + rollout_state.response = returned_response + rollout_state.finish_reason = finish_reason + rollout_state.status = rollout_status + return rollout_state + except KeyError as e: + error_msg = f"Missing expected key {e} in response {response} for {uid}" + raise RuntimeError(error_msg) + except IndexError as e: + error_msg = f"Index error {e} while processing response {response} for {uid}" + raise RuntimeError(error_msg) + except AssertionError as e: + error_msg = f"AssertionError: {e} when processing response {response} for {uid}" + raise 
RuntimeError(error_msg) + except json.JSONDecodeError as e: + error_msg = f"JSONDecodeError: {e} when processing response {response} for {uid}" + raise RuntimeError(error_msg) + except TypeError as e: + error_msg = f"TypeError: {e} when processing response {response} for {uid}" + raise RuntimeError(error_msg) + except Exception as e: + error_msg = f"Unexpected error: {e} when processing response {response} for {uid}\nTraceback: {traceback.format_exc()}" + raise RuntimeError(error_msg) + + def _adapt_input_to_openai_spec(self, prompts, tools, tool_choice): + openai_prompts = [] + openai_tools = [] + # transform claude spec to openai spec + # 1. transform system prompt: concat provided system_prompt to input prompt + system_prompt = self.config.system_prompt + if system_prompt: + system_prompt_json = {"role": "system", "content": f"{system_prompt}"} + prompts.insert(0, system_prompt_json) + # 2. transform multi-modal usage + for prompt in prompts: + content = prompt["content"] + openai_content = [] + for item in content: + if item["type"] == "image": + if item["source"]["type"] == "base64": + openai_url = f"data:{item['source']['media_type']};base64,{item['source']['data']}" + if item["source"]["type"] == "url": + openai_url = item["source"]["url"] + new_prompt = {"type": "image_url", "image_url": {"url": openai_url}} + openai_content.append(new_prompt) + elif item["type"] == "text": + openai_content.append(item) + new_prompt = copy.deepcopy(prompt) + new_prompt["content"] = openai_content + openai_prompts.append(new_prompt) + # 3. 
transform tool use + for tool in tools: + openai_tool = { + "type": "function", + "function": { + "name": tool["name"], + "description": tool["description"], + "parameters": tool["input_schema"], + }, + } + openai_tools.append(openai_tool) + return openai_prompts, openai_tools + + def _check_infer_engine_version(self, return_token_ids: bool): + # TODO(@duanyanhui): remove this check when all backends support return_token_ids + if self.check_flag: + if os.environ.get("XTUNER_USE_VLLM", "0") == "1": + if return_token_ids: + self.logger.error( + "VLLM backend does not support return_token_ids or generate with input_ids as input in Xtuner now" + ) + elif os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "1": + import lmdeploy + + lmdeploy_version = lmdeploy.__version__ + if return_token_ids and Version(lmdeploy_version) < Version("0.10.2"): + self.logger.error( + f"You should use lmdeploy >= v0.10.2 to support return_token_ids, but current version is {lmdeploy_version}" + ) + self.check_flag = False + + def _set_engine_rank_mesh_array(self, engine_rank_mesh_array: list[list[int]]): + self.engine_rank_mesh_array = engine_rank_mesh_array + + def _set_engine_bundle_idxs(self, engine_bundle_idxs: list[int]): + """Set the bundle indices for the inference engine. + + This is used by some backends (like LMDeploy with Ray executor) to + know which bundles in the placement group belong to this engine. + + Args: + engine_bundle_idxs (list[int]): A list of bundle indices. + """ + self.engine_bundle_idxs = engine_bundle_idxs + + @abstractmethod + def _get_request_payload(self, rollout_state: RolloutState) -> dict: + """Abstract method to create a generation request. + + Must be implemented by subclasses. + """ + raise NotImplementedError("_create_request must be implemented in subclass") + + @abstractmethod + def _transform_rollout_config_to_server_configs(self): + """Abstract method to transform rollout config to server configs. + + Must be implemented by subclasses. 
+ """ + raise NotImplementedError("_transform_rollout_config_to_server_configs must be implemented in subclass") + + @abstractmethod + def _transform_sample_params(self, sample_params: SampleParams) -> dict: + """Abstract method to transform rollout config to server configs. + + Must be implemented by subclasses. + """ + raise NotImplementedError("_transform_rollout_config_to_server_configs must be implemented in subclass") + + @abstractmethod + def offload(self): + """Abstract method to offload the model and KVcache. + + Must be implemented by subclasses. + """ + raise NotImplementedError("reset_prefix_cache must be implemented in subclass") + + @abstractmethod + def onload_weights(self): + """Abstract method to onload the model weights. + + Must be implemented by subclasses. + """ + pass + + @abstractmethod + def onload_kvcache(self): + """Abstract method to onload the KV cache. + + Must be implemented by subclasses. + """ + pass diff --git a/xtuner/v1/rl/trainer/__init__.py b/xtuner/v1/rl/trainer/__init__.py new file mode 100644 index 0000000000..6d40472947 --- /dev/null +++ b/xtuner/v1/rl/trainer/__init__.py @@ -0,0 +1,28 @@ +from .controller import ColateItem, RawTrainingController, TrainingController, TrainingControllerProxy +from .rollout_is import ( + RolloutImportanceSampling, + compute_is_metrics, + compute_mismatch_metrics, + compute_rollout_importance_weights, + merge_rollout_is_metrics, +) +from .worker import RLOtherLog, TrainingWorker, WorkerConfig, WorkerInputItem, WorkerLogItem, WorkerTrainLogItem + + +__all__ = [ + "ColateItem", + "RawTrainingController", + "TrainingController", + "TrainingControllerProxy", + "RolloutImportanceSampling", + "compute_rollout_importance_weights", + "compute_is_metrics", + "compute_mismatch_metrics", + "merge_rollout_is_metrics", + "WorkerConfig", + "WorkerInputItem", + "RLOtherLog", + "WorkerTrainLogItem", + "WorkerLogItem", + "TrainingWorker", +] diff --git a/xtuner/v1/rl/base/controller.py 
b/xtuner/v1/rl/trainer/controller.py similarity index 96% rename from xtuner/v1/rl/base/controller.py rename to xtuner/v1/rl/trainer/controller.py index b500b53e46..55643af2a4 100644 --- a/xtuner/v1/rl/base/controller.py +++ b/xtuner/v1/rl/trainer/controller.py @@ -9,7 +9,7 @@ from xtuner.v1.data_proto.sequence_context import SequenceContext from xtuner.v1.model.compose.base import BaseComposeConfig from xtuner.v1.train.trainer import LoadCheckpointConfig -from xtuner.v1.utils import ray_method +from xtuner.v1.utils import get_logger, ray_method from .worker import TrainingWorker, WorkerLogItem @@ -27,6 +27,7 @@ class ColateItem(TypedDict): class RawTrainingController: def __init__(self, workers: list[TrainingWorker]) -> None: self.workers = workers + self.logger = get_logger() # TODO(hha): 这个逻辑不够通用,应该复用 sft 函数,从而支持 expand soft pack def _get_pack_infos(self, dataset, num_tokens, target, random=None): @@ -115,6 +116,7 @@ def _packing(self, data_batches, pack_max_length, language_cfg): dtype=data_batches[0]["shifted_labels"].dtype, device=data_batches[0]["shifted_labels"].device, ) + pad_advantages = [-100] * pad_len if is_qwen3_vl: _position_ids_list = [] for pad_token in pad_tokens: @@ -128,10 +130,7 @@ def _packing(self, data_batches, pack_max_length, language_cfg): seq_ctx_list.append(pad_seq_ctx) label_list.append(pad_labels) - advantage_list.extend( - [-100] * math.ceil(pad_len / 1024) - ) # can be any number, pad tokens are excluded from the calculation of the loss function. 
- + advantage_list.append(pad_advantages) if rollout_logprobs_list is not None: pad_rollout_logprobs = torch.zeros( 1, @@ -143,10 +142,8 @@ def _packing(self, data_batches, pack_max_length, language_cfg): seq_ctx = SequenceContext.cat(seq_ctx_list) shifted_labels = torch.cat(label_list, dim=1) # (1, max_len) - advantages = torch.tensor(advantage_list).float().unsqueeze(0) # (1, num_samples) - cu_seq_lens_q = seq_ctx.cu_seq_lens_q - num_tokens = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1] - advantages = torch.repeat_interleave(advantages, num_tokens, dim=1) # (1, max_len) + advantage_flat = [item for sublist in advantage_list for item in sublist] + advantages = torch.tensor(advantage_flat, dtype=torch.float32).unsqueeze(0) rollout_logprobs = None if rollout_logprobs_list is not None: diff --git a/xtuner/v1/rl/base/rollout_is.py b/xtuner/v1/rl/trainer/rollout_is.py similarity index 100% rename from xtuner/v1/rl/base/rollout_is.py rename to xtuner/v1/rl/trainer/rollout_is.py diff --git a/xtuner/v1/rl/base/worker.py b/xtuner/v1/rl/trainer/worker.py similarity index 97% rename from xtuner/v1/rl/base/worker.py rename to xtuner/v1/rl/trainer/worker.py index 7ba83e21be..581c88bcf9 100644 --- a/xtuner/v1/rl/base/worker.py +++ b/xtuner/v1/rl/trainer/worker.py @@ -4,8 +4,13 @@ import time from itertools import chain from pathlib import Path -from typing import Dict, Iterable, List, Sequence, TypeAlias, TypedDict, cast +from typing import TYPE_CHECKING, Dict, Iterable, List, Sequence, TypeAlias, TypedDict, cast + +if TYPE_CHECKING: + from ray.util.placement_group import PlacementGroup + +import numpy as np import ray import requests import torch @@ -35,10 +40,9 @@ from xtuner.v1.model.base import ModelItem, TransformerConfig from xtuner.v1.model.compose.base import BaseComposeConfig, BaseComposeModel from xtuner.v1.model.compose.qwen3_vl import Qwen3VLForConditionalGeneration -from xtuner.v1.ray.base import SingleAcceleratorWorker -from xtuner.v1.ray.config import RolloutConfig 
-from xtuner.v1.rl.base.loss import BaseRLLossContext -from xtuner.v1.rl.utils import gather_logprobs +from xtuner.v1.rl.loss import BaseRLLossConfig, BaseRLLossContext, kl_penalty +from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.rl.utils import SingleAcceleratorWorker, gather_logprobs from xtuner.v1.train.trainer import LoadCheckpointConfig from xtuner.v1.utils import ( XTUNER_DETERMINISTIC, @@ -51,8 +55,6 @@ ) from xtuner.v1.utils.load_spec import LoadEnum -from ..loss_fn import kl_penalty -from .loss import BaseRLLossConfig from .rollout_is import merge_rollout_is_metrics @@ -169,6 +171,26 @@ class WorkerConfig(BaseModel): rollout_steps_per_sft: int = 1 sft_loss_cfg: CELossConfig = CELossConfig() + def build(self, placement_group: "PlacementGroup"): + """Build training workers and controller from this config and placement + group.""" + # import here to avoid circular import + from xtuner.v1.rl.trainer.controller import TrainingController + from xtuner.v1.rl.utils import AutoAcceleratorWorkers + + TrainingWorkerCls = ray.remote( + runtime_env={ + "env_vars": { + "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1", + "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES": "1", + "HCCL_NPU_SOCKET_PORT_RANGE": "auto", + } + } + )(TrainingWorker) + train_workers, _ = AutoAcceleratorWorkers.from_placement_group(TrainingWorkerCls, self, placement_group) + ray.wait([w.ready.remote() for w in train_workers]) + return TrainingController.remote(workers=train_workers) + class WorkerInputItem(TypedDict): seq_ctx: SequenceContext @@ -517,13 +539,16 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int) -> WorkerLo seq_ctx = data["seq_ctx"] pixel_values = seq_ctx.pixel_values if pixel_values is not None: - if not isinstance(pixel_values, torch.Tensor): + if not isinstance(pixel_values, np.ndarray): assert isinstance(pixel_values, list), ( f"pixel_values should be list of tensor, got {type(pixel_values)}" ) pixel_values = [ray.get(pixel_obf) 
for pixel_obf in pixel_values] + pixel_values = [torch.as_tensor(pixel_value) for pixel_value in pixel_values] pixel_values = torch.cat(pixel_values, dim=0) seq_ctx.pixel_values = pixel_values + else: + raise NotImplementedError("The case where pixel_values is a numpy array is not implemented yet.") rollout_routed_experts = seq_ctx.rollout_routed_experts if rollout_routed_experts is not None: diff --git a/xtuner/v1/rl/utils.py b/xtuner/v1/rl/utils.py deleted file mode 100644 index 8da313a360..0000000000 --- a/xtuner/v1/rl/utils.py +++ /dev/null @@ -1,60 +0,0 @@ -import atexit -import signal -import subprocess - -import torch.nn.functional as F - -from xtuner.v1.utils.logger import get_logger - - -def gather_logprobs(logits, shifted_labels): - logprobs = F.log_softmax(logits, dim=-1) - logprobs = logprobs.gather(dim=-1, index=shifted_labels.clip(min=0).unsqueeze(-1)).squeeze(-1) - return logprobs - - -logger = get_logger() - - -def close_ray(): - """Clean up the ray resource.""" - import ray - - # 1. Shutdown ray if initialized - try: - if ray.is_initialized(): - ray.shutdown() - logger.info("Ray shutdown successfully") - except Exception as e: - logger.warning(f"Error during ray.shutdown(): {e}") - - # 2. 
Stop ray launched by CLI - try: - result = subprocess.run(["ray", "stop", "--force"], capture_output=True, text=True, timeout=10) - if result.returncode != 0: - logger.warning(f"Ray stop failed: {result.stderr}") - except Exception as e: - logger.warning(f"Error stopping ray cluster: {e}") - - -def register_cleanup(): - """Register cleanup handlers for Ray on exit and signals.""" - _cleaned = False - - def cleanup_once(): - nonlocal _cleaned - if not _cleaned: - _cleaned = True - close_ray() - - def signal_handler(signum, frame): - logger.info(f"Received signal {signum}, cleaning up...") - cleanup_once() - import sys - - sys.exit(128 + signum) - - signal.signal(signal.SIGINT, signal_handler) - signal.signal(signal.SIGTERM, signal_handler) - - atexit.register(cleanup_once) diff --git a/xtuner/v1/rl/utils/__init__.py b/xtuner/v1/rl/utils/__init__.py new file mode 100644 index 0000000000..d87ef41407 --- /dev/null +++ b/xtuner/v1/rl/utils/__init__.py @@ -0,0 +1,69 @@ +from .async_utils import asyncio_run, create_task, handle_task_exception +from .misc import ( + BetweenNode, + BetweenOperator, + ConditionNode, + LogicNode, + LogicOperator, + Operators, + QueryNode, + ScalarNode, + ScalarOperator, + SetNode, + SetOperator, + gather_logprobs, + get_eos_token, + load_function, + parse_query, +) +from .ray_utils import ( + bind_train_rollout, + close_ray, + find_master_addr_and_port, + get_accelerator_ids, + get_ray_accelerator, + register_cleanup, +) +from .ray_worker import ( + AcceleratorResourcesConfig, + AutoAcceleratorWorkers, + AutoCPUWorkers, + BaseCPUWorker, + CPUActorLauncher, + CPUResourcesConfig, + SingleAcceleratorWorker, +) + + +__all__ = [ + "AcceleratorResourcesConfig", + "SingleAcceleratorWorker", + "AutoAcceleratorWorkers", + "CPUResourcesConfig", + "CPUActorLauncher", + "BaseCPUWorker", + "AutoCPUWorkers", + "get_ray_accelerator", + "load_function", + "find_master_addr_and_port", + "get_accelerator_ids", + "bind_train_rollout", + "handle_task_exception", 
+ "create_task", + "QueryNode", + "ConditionNode", + "ScalarNode", + "SetNode", + "BetweenNode", + "LogicNode", + "parse_query", + "gather_logprobs", + "close_ray", + "register_cleanup", + "ScalarOperator", + "SetOperator", + "BetweenOperator", + "LogicOperator", + "Operators", + "get_eos_token", +] diff --git a/xtuner/v1/rl/utils/async_utils.py b/xtuner/v1/rl/utils/async_utils.py new file mode 100644 index 0000000000..ca6b25d4ee --- /dev/null +++ b/xtuner/v1/rl/utils/async_utils.py @@ -0,0 +1,109 @@ +import asyncio +from asyncio import AbstractEventLoop, Task +from typing import Any, Callable, Coroutine, List, Optional + + +_ASYNCIO_RUN_LOOP: AbstractEventLoop | None = None + + +def handle_task_exception(task: Task): + """Handles exceptions from an asyncio Task. + + This function checks if a task has raised an exception and, if so, + re-raises it. It ignores `asyncio.CancelledError`. + + Args: + task (Task): The asyncio task to check for exceptions. + + Raises: + Exception: The exception raised by the task. + """ + try: + exc = task.exception() + if exc is not None: + raise exc + except asyncio.CancelledError: + pass # Task was cancelled, ignore + + +def create_task( + coro: Coroutine, + loop: Optional[AbstractEventLoop] = None, + done_callbacks: Optional[List[Callable[[Task], object]]] = None, +) -> asyncio.tasks.Task: + """Creates and configures an asyncio Task. + + This function creates a task from a coroutine and attaches specified + done callbacks. By default, it includes a callback to handle exceptions. + + Args: + coro (Coroutine): The coroutine to wrap in a task. + loop (Optional[AbstractEventLoop], optional): The event loop to run + the task in. If None, the current event loop is used. + Defaults to None. + done_callbacks (Optional[List[Callable[[Task], object]]], optional): + A list of callbacks to add to the task. If None, a default + exception handler is used. Defaults to None. + + Returns: + asyncio.tasks.Task: The created asyncio task. 
+ """ + if loop is None: + loop = asyncio.get_event_loop() + if done_callbacks is None: + done_callbacks = [handle_task_exception] + task = loop.create_task(coro) + for callback in done_callbacks: + task.add_done_callback(callback) + return task + + +def _get_default_asyncio_loop() -> AbstractEventLoop: + """Get a module-level event loop reused by ``asyncio_run``.""" + global _ASYNCIO_RUN_LOOP + if _ASYNCIO_RUN_LOOP is not None and not _ASYNCIO_RUN_LOOP.is_closed(): + return _ASYNCIO_RUN_LOOP + + try: + loop = asyncio.get_running_loop() + except RuntimeError: + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + _ASYNCIO_RUN_LOOP = loop + return _ASYNCIO_RUN_LOOP + + +def asyncio_run(coro: Coroutine, loop: Optional[AbstractEventLoop] = None) -> Any: + """Synchronously run a coroutine on a shared/explicit event loop. + + This helper is used by `RLColocateTrainer.fit` for rollout collection: + 1) Trainer runs in sync code and repeatedly calls: + - self.eval_agent_loop_manager.produce_batch(...) + - self.agent_loop_manager.produce_batch(...) + 2) `produce_batch` is async, and internally runs `ProduceStrategy.produce_batch`, + which launches many nested async tasks (`create_task`) and ultimately calls + `AgentLoop.generate_group -> generate_sample`. + 3) In `VerlToolAgentLoop`, `generate_sample` awaits `self.verl_tool_agent_loop.run()`, + where the tool loop stays on the same loop. + + In this pattern, if sync code uses `asyncio.run` every call, each invocation + creates/closes a fresh loop, but `VerlToolAgentLoop` keeps internal work + on one loop, the wrapped `generate_sample -> run -> Ray futures` chain can see + mismatched loop ownership and trigger: + ``Future attached to a different loop``. + + `asyncio_run` keeps calls bound to a stable loop instance so nested task/future + chains stay compatible across repeated rollout phases. 
+ + This helper is for sync-to-async boundaries only and should not be used from + within an already running event loop. + """ + if loop is None: + loop = _get_default_asyncio_loop() + if loop.is_running(): + raise RuntimeError("asyncio_run does not support being called from a running event loop.") + return loop.run_until_complete(coro) diff --git a/xtuner/v1/rl/utils/misc.py b/xtuner/v1/rl/utils/misc.py new file mode 100644 index 0000000000..7868fc913f --- /dev/null +++ b/xtuner/v1/rl/utils/misc.py @@ -0,0 +1,148 @@ +import importlib +import json +import socket +import typing +from abc import ABC +from pathlib import Path +from typing import Any, List, Literal, Union + +import torch.nn.functional as F + +from xtuner.v1.utils.logger import get_logger + + +logger = get_logger() +ScalarOperator = Literal["$eq", "$ne", "$gt", "$gte", "$lt", "$lte"] +SetOperator = Literal["$in", "$not_in"] +BetweenOperator = Literal["$between"] +Operators = Union[ScalarOperator, SetOperator, BetweenOperator] +LogicOperator = Literal["$and", "$or"] + + +class QueryNode(ABC): + """查询语法树的基类,仅作数据结构标记.""" + + pass + + +class ConditionNode(QueryNode): + """代表一个具体的查询条件.""" + + field: str + + +class ScalarNode(ConditionNode): + def __init__(self, field: str, op: ScalarOperator, value: Any): + self.field = field + self.op = op + self.value = value + + +class SetNode(ConditionNode): + def __init__(self, field: str, op: SetOperator, value: list[Any] | tuple[Any]): + self.field = field + self.op = op + self.value = value + + +class BetweenNode(ConditionNode): + def __init__(self, field: str, lower: Any, upper: Any): + if lower > upper: + raise ValueError("lower bound must be less than or equal to upper bound") + self.field = field + self.op = "$between" + self.lower = lower + self.upper = upper + + +class LogicNode(QueryNode): + """复合逻辑组.""" + + def __init__(self, relation: LogicOperator, conditions: List[QueryNode]): + self.relation = relation + self.conditions = conditions + + +def 
parse_query(expr: Union[dict, QueryNode]) -> QueryNode: + """将基于字典的 DSL 解析为纯粹的 AST 节点树 (ConditionNode, LogicNode)""" + if isinstance(expr, QueryNode): + return expr + + if isinstance(expr, dict): + conditions: list[QueryNode] = [] + for key, value in expr.items(): + if key in ("$and", "$or"): + if isinstance(value, list): + sub_asts = [parse_query(sub_expr) for sub_expr in value] + conditions.append(LogicNode(key, sub_asts)) # type: ignore + else: + raise ValueError(f"逻辑操作符 {key} 的值必须是一个列表") + else: + if isinstance(value, dict): + # 例如: {"staleness": {"$lt": 5, "$gt": 0}} + for op, op_val in value.items(): + if op in typing.get_args(ScalarOperator): + conditions.append(ScalarNode(field=key, op=op, value=op_val)) + elif op in typing.get_args(SetOperator): + if not isinstance(op_val, (list, tuple)): + raise ValueError(f"操作符 '{op}' 需要传入一个列表或元组") + conditions.append(SetNode(field=key, op=op, value=op_val)) + elif op == "$between": + if not isinstance(op_val, (list, tuple)) or len(op_val) != 2: + raise ValueError("操作符 '$between' 需要传入包含2个元素的列表或元组") + conditions.append(BetweenNode(field=key, lower=op_val[0], upper=op_val[1])) + else: + raise ValueError(f"不支持的操作符: {op}") + else: + # 隐式等值,例如: {"task_name": "math"} -> "$eq" + conditions.append(ScalarNode(field=key, op="$eq", value=value)) + + if len(conditions) > 1: + # 默认多个条件之间是 AND 关系,例如: {"uid": "123", "status": {"$in": ["pending", "running]}}} + return LogicNode("$and", conditions) # type: ignore + return conditions[0] if conditions else LogicNode("$and", []) + + raise ValueError(f"不支持的查询表达式格式: {expr}") + + +def gather_logprobs(logits, shifted_labels): + logprobs = F.log_softmax(logits, dim=-1) + logprobs = logprobs.gather(dim=-1, index=shifted_labels.clip(min=0).unsqueeze(-1)).squeeze(-1) + return logprobs + + +def load_function(path): + """Load a function from a module. + + :param path: The path to the function, e.g. "module.submodule.function". + :return: The function object. 
+ """ + module_path, _, attr = path.rpartition(".") + module = importlib.import_module(module_path) + return getattr(module, attr) + + +def _is_port_available(check_socket: socket.socket, port: int) -> bool: + try: + check_socket.bind(("", port)) + check_socket.listen(1) + return True + except OSError: + return False + + +def get_eos_token(model_path: str) -> int | List[int]: + generation_config_path = Path(model_path) / "generation_config.json" + if not generation_config_path.exists(): + logger.warning( + f"Config {generation_config_path} does not exist and thus cannot get eos_token. You must provide eos_token manually." + ) + return [] + with open(generation_config_path) as f: + generation_config = json.load(f) + eos_token_id = generation_config.get("eos_token_id") + if eos_token_id is None: + raise ValueError( + f"eos_token_id is not found in {generation_config_path}. You must provide eos_token manually." + ) + return eos_token_id diff --git a/xtuner/v1/ray/utils.py b/xtuner/v1/rl/utils/ray_utils.py similarity index 64% rename from xtuner/v1/ray/utils.py rename to xtuner/v1/rl/utils/ray_utils.py index 4eea3b2280..14d94323d7 100644 --- a/xtuner/v1/ray/utils.py +++ b/xtuner/v1/rl/utils/ray_utils.py @@ -1,72 +1,21 @@ -import asyncio -import importlib +import atexit +import signal import socket -from asyncio import AbstractEventLoop, Task -from typing import TYPE_CHECKING, Callable, Coroutine, List, Optional, cast +import subprocess +from typing import TYPE_CHECKING, cast import ray +from xtuner.v1.utils.logger import get_logger -if TYPE_CHECKING: - import ray.actor - - from xtuner.v1.ray.base.accelerator import AcceleratorType - - -def get_ray_accelerator() -> "AcceleratorType": - from xtuner.v1.utils.device import get_device - - """Get the type of accelerator available in the Ray environment. - - This function checks for the availability of CUDA and NPU devices and - returns the corresponding accelerator type. 
- - Returns: - AcceleratorType: The type of accelerator ("GPU" or "NPU"). - - Raises: - NotImplementedError: If neither CUDA nor NPU is available. - """ - accelerator = None - if get_device() == "cuda": - accelerator = "GPU" - return "GPU" - else: - try: - import torch_npu # noqa: F401 - - accelerator = "NPU" - except ImportError: - pass - - if accelerator is None: - raise NotImplementedError( - "Supports only CUDA or NPU. If your device is CUDA or NPU, " - "please make sure that your environmental settings are " - "configured correctly." - ) - - return cast("AcceleratorType", accelerator) +from .misc import _is_port_available -def load_function(path): - """Load a function from a module. +if TYPE_CHECKING: + from .ray_worker import AcceleratorType - :param path: The path to the function, e.g. "module.submodule.function". - :return: The function object. - """ - module_path, _, attr = path.rpartition(".") - module = importlib.import_module(module_path) - return getattr(module, attr) - -def _is_port_available(check_socket: socket.socket, port: int) -> bool: - try: - check_socket.bind(("", port)) - check_socket.listen(1) - return True - except OSError: - return False +logger = get_logger() @ray.remote @@ -139,6 +88,84 @@ def get_accelerator_ids(accelerator: str) -> list: return ray.get_runtime_context().get_accelerator_ids()[accelerator] +def get_ray_accelerator() -> "AcceleratorType": + from xtuner.v1.utils.device import get_device + + """Get the type of accelerator available in the Ray environment. + + This function checks for the availability of CUDA and NPU devices and + returns the corresponding accelerator type. + + Returns: + AcceleratorType: The type of accelerator ("GPU" or "NPU"). + + Raises: + NotImplementedError: If neither CUDA nor NPU is available. 
+ """ + accelerator = None + if get_device() == "cuda": + accelerator = "GPU" + return "GPU" + else: + try: + import torch_npu # noqa: F401 + + accelerator = "NPU" + except ImportError: + pass + + if accelerator is None: + raise NotImplementedError( + "Supports only CUDA or NPU. If your device is CUDA or NPU, " + "please make sure that your environmental settings are " + "configured correctly." + ) + + return cast("AcceleratorType", accelerator) + + +def close_ray(): + """Clean up the ray resource.""" + # 1. Shutdown ray if initialized + try: + if ray.is_initialized(): + ray.shutdown() + logger.info("Ray shutdown successfully") + except Exception as e: + logger.warning(f"Error during ray.shutdown(): {e}") + + # 2. Stop ray launched by CLI + try: + result = subprocess.run(["ray", "stop", "--force"], capture_output=True, text=True, timeout=10) + if result.returncode != 0: + logger.warning(f"Ray stop failed: {result.stderr}") + except Exception as e: + logger.warning(f"Error stopping ray cluster: {e}") + + +def register_cleanup(): + """Register cleanup handlers for Ray on exit and signals.""" + _cleaned = False + + def cleanup_once(): + nonlocal _cleaned + if not _cleaned: + _cleaned = True + close_ray() + + def signal_handler(signum, frame): + logger.info(f"Received signal {signum}, cleaning up...") + cleanup_once() + import sys + + sys.exit(128 + signum) + + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + atexit.register(cleanup_once) + + def bind_train_rollout( train_workers, rollout_controller, @@ -153,58 +180,6 @@ def bind_train_rollout( train_workers: A list of training worker actors. rollout_controller: The rollout controller actor. 
""" - info_dict = ray.get(rollout_controller.get_rollout_info.remote()) # type: ignore[attr-defined] + info_dict = ray.get(rollout_controller.get_rollout_metadata.remote()) # type: ignore[attr-defined] ray.get([worker.update_rollout_info.remote(**info_dict) for worker in train_workers]) # type: ignore[attr-defined] return - - -def handle_task_exception(task: Task): - """Handles exceptions from an asyncio Task. - - This function checks if a task has raised an exception and, if so, - re-raises it. It ignores `asyncio.CancelledError`. - - Args: - task (Task): The asyncio task to check for exceptions. - - Raises: - Exception: The exception raised by the task. - """ - try: - exc = task.exception() - if exc is not None: - raise exc - except asyncio.CancelledError: - pass # Task was cancelled, ignore - - -def create_task( - coro: Coroutine, - loop: Optional[AbstractEventLoop] = None, - done_callbacks: Optional[List[Callable[[Task], object]]] = None, -) -> asyncio.tasks.Task: - """Creates and configures an asyncio Task. - - This function creates a task from a coroutine and attaches specified - done callbacks. By default, it includes a callback to handle exceptions. - - Args: - coro (Coroutine): The coroutine to wrap in a task. - loop (Optional[AbstractEventLoop], optional): The event loop to run - the task in. If None, the current event loop is used. - Defaults to None. - done_callbacks (Optional[List[Callable[[Task], object]]], optional): - A list of callbacks to add to the task. If None, a default - exception handler is used. Defaults to None. - - Returns: - asyncio.tasks.Task: The created asyncio task. 
- """ - if loop is None: - loop = asyncio.get_event_loop() - if done_callbacks is None: - done_callbacks = [handle_task_exception] - task = loop.create_task(coro) - for callback in done_callbacks: - task.add_done_callback(callback) - return task diff --git a/xtuner/v1/ray/base/accelerator.py b/xtuner/v1/rl/utils/ray_worker.py similarity index 53% rename from xtuner/v1/ray/base/accelerator.py rename to xtuner/v1/rl/utils/ray_worker.py index df41feebf0..859d6cc649 100644 --- a/xtuner/v1/ray/base/accelerator.py +++ b/xtuner/v1/rl/utils/ray_worker.py @@ -1,4 +1,5 @@ import os +import threading from typing import Any, Dict, List, Literal, Tuple, TypeVar import ray @@ -13,9 +14,10 @@ placement_group, placement_group_table, ) +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from typing_extensions import Annotated -from ..utils import find_master_addr_and_port, get_accelerator_ids +from .ray_utils import find_master_addr_and_port, get_accelerator_ids PG_READY_TIMEOUT = os.getenv("XTUNER_PG_READY_TIMEOUT", 30) # default 30 seconds @@ -23,6 +25,84 @@ T = TypeVar("T") +class CPUResourcesConfig(BaseModel): + """Configuration for CPU resources in a placement group for XTuner. + + This class provide specific configuration options for CPU-based workers in Ray placement groups. + + Args: + num_cpus_per_worker (float): Number of CPUs to allocate per worker in the + placement group. Defaults to 8. + cpu_memory_per_worker (int): Amount of CPU memory (in bytes) to allocate + for each worker in the placement group. + num_workers (int): Total number of workers in the placement group. 
+ """ + + model_config = ConfigDict(extra="forbid") + num_workers: Annotated[int, Parameter(help="Number of workers in the placement group.")] = 1 + num_cpus_per_worker: Annotated[float, Parameter(help="Number of CPUs to allocate for the placement group.")] = 1 + cpu_memory_per_worker: Annotated[ + int, Parameter(help="Amount of memory (in bytes) to allocate for the placement group.") + ] = 1024**3 # 1 GB + pg_pack_strategy: Annotated[ + str, + Parameter(help="Placement group packing strategy, options: " + ", ".join(VALID_PLACEMENT_GROUP_STRATEGIES)), + ] = "SPREAD" + + @field_validator("pg_pack_strategy") + @classmethod + def check_pg_pack_strategy(cls, v): + if v not in VALID_PLACEMENT_GROUP_STRATEGIES: + raise ValueError(f"pg_pack_strategy must be one of {VALID_PLACEMENT_GROUP_STRATEGIES}") + return v + + def model_post_init(self, __context: Any) -> None: + assert ray.is_initialized(), "Ray must be initialized before creating CPUResourcesConfig." + available_resources = ray.available_resources() + available_cpus = available_resources.get("CPU", 0) + available_memory = available_resources.get("memory", 0) + # TODO: manage single controller's cpu resource to replace "10" here + needed_cpus = (self.num_cpus_per_worker * self.num_workers) + 10 + assert needed_cpus <= available_cpus, ( + f"Not enough available CPUs in Ray cluster, available_cpus is {available_cpus} but xtuner needs {needed_cpus}." + ) + needed_memory = self.cpu_memory_per_worker * self.num_workers + 10 * 1024**3 + assert needed_memory <= available_memory, ( + f"Not enough available memory in Ray cluster, available_memory is {available_memory} but xtuner needs {needed_memory}." + ) + # TODO: check all resources sum in cluster to avoid over allocation + + @classmethod + def from_total( + cls, total_cpus: float | int, total_memory: int, num_workers: int, pg_pack_strategy: str = "SPREAD" + ): + """Create a CPUResourcesConfig from total CPU and memory resources. 
+ + Args: + total_cpus (float | int): Total number of CPUs to allocate across all workers. + total_memory (int): Total amount of memory (in bytes) to allocate across all workers. + num_workers (int): Number of workers in the placement group. + + Returns: + CPUResourcesConfig: The created CPUResourcesConfig object. + """ + assert num_workers > 0, "Number of workers must be positive." + return cls( + num_workers=num_workers, + num_cpus_per_worker=total_cpus / num_workers, + cpu_memory_per_worker=total_memory / num_workers, + pg_pack_strategy=pg_pack_strategy, + ) + + def build_placement_group(self) -> PlacementGroup: + """Build a Ray PlacementGroup based on this resource configuration. + + Returns: + PlacementGroup: The created Ray PlacementGroup. + """ + return CPUActorLauncher.build_placement_group(self) + + class AcceleratorResourcesConfig(BaseModel): """Configuration for accelerator resources in a placement group for XTuner. @@ -196,6 +276,25 @@ def device_visible_env_name(self): else: raise ValueError(f"Unsupported accelerator type: {self.accelerator}") + def get_logical_local_rank(self) -> int: + """Resolve the assigned accelerator id to the logical local rank. + + Ray reports accelerator ids in the physical numbering space. Torch selects devices from the current visible- + device list, which is indexed logically from zero after applying visibility masks. + """ + accelerator_id = str(ray.get_runtime_context().get_accelerator_ids()[self.accelerator][0]) + visible_devices = os.environ.get(self.device_visible_env_name) + if visible_devices is None: + return int(accelerator_id) + + visible_device_ids = [device_id.strip() for device_id in visible_devices.split(",") if device_id.strip()] + if accelerator_id not in visible_device_ids: + raise ValueError( + f"Assigned accelerator id {accelerator_id} is not present in " + f"{self.device_visible_env_name}={visible_devices}." 
+ ) + return visible_device_ids.index(accelerator_id) + def setup_distributed(self, rank: int, master_addr: str, master_port: int, world_size: int): """Set up the distributed environment for the worker. @@ -215,7 +314,7 @@ def setup_distributed(self, rank: int, master_addr: str, master_port: int, world os.environ["MASTER_PORT"] = str(master_port) os.environ["RANK"] = str(rank) os.environ["WORLD_SIZE"] = str(world_size) - os.environ["LOCAL_RANK"] = str(ray.get_runtime_context().get_accelerator_ids()[self.accelerator][0]) + os.environ["LOCAL_RANK"] = str(self.get_logical_local_rank()) # backend 参数是指定通信后端,不是从环境变量获取 # - 'nccl': NVIDIA GPU 间通信(推荐用于 GPU) @@ -444,3 +543,291 @@ def from_placement_group( rank_bundle_idx_list.append((rank, bundle_idx)) return workers_list, rank_bundle_idx_list + + +class BaseCPUWorker: + """The BaseCPUWorker class serves as a foundational structure for CPU-based + workers within the XTuner framework. + + This class is designed to be extended by specific CPU worker implementations. + It provides a constructor that accepts a configuration object, allowing + subclasses to initialize with custom settings. + + Args: + config: The configuration object for the CPU worker. + num_cpus (float | int): The number of CPUs allocated to this worker. + Defaults to 1. + """ + + def __init__(self, config, num_cpus: float | int = 1): + self.config = config + self.num_cpus = num_cpus + + +class CPUActorLauncher: + """Infrastructure for launching CPU Ray actors from plain Python classes. + + This class owns the generic actorization flow for CPU-only components: + building homogeneous CPU placement groups, converting plain classes into + Ray actor classes, validating bundle resources, and launching one or more + actors on specific bundles. + """ + + _ACTOR_CLASS_CACHE: dict[type, ActorClass] = {} + + @staticmethod + def build_placement_group(resources_config: CPUResourcesConfig): + """Build a Ray PlacementGroup based on the provided resource + configuration. 
+ + Args: + resources_config (CPUResourcesConfig): The configuration + specifying the resources for each worker bundle. + + Returns: + PlacementGroup: The created Ray PlacementGroup. + """ + bundles = [ + { + "CPU": resources_config.num_cpus_per_worker, + "memory": resources_config.cpu_memory_per_worker, + } + ] * resources_config.num_workers + + pg = placement_group(bundles=bundles, strategy=resources_config.pg_pack_strategy) + + ray.get(pg.ready(), timeout=PG_READY_TIMEOUT) + return pg + + @staticmethod + def get_pg_options(pg: PlacementGroup, num_cpus: int | float = -1) -> Dict: + """Provide a dictionary of resource requests for Ray tasks or actors + with specific cpu requirements. + + Args: + pg (PlacementGroup): The placement group to get options for. + num_cpus (float): The number of CPUs to request. If set to -1, + the default CPU allocation from the placement group bundle + will be used. Defaults to -1. + + Returns: + Dict: A dictionary of Ray resource options for `task.options()`. + """ + assert len(pg.bundle_specs) > 0, "Placement group has no bundles defined." + default_cpu = pg.bundle_specs[0].get("CPU", 1) + return {"num_cpus": num_cpus if num_cpus >= 0 else default_cpu} + + @classmethod + def to_actor_class(cls, worker_cls): + """Convert a plain Python class into a Ray actor class. + + If ``worker_cls`` is already a Ray actor class, it is returned as-is. + """ + if hasattr(worker_cls, "remote") and hasattr(worker_cls, "options"): + return worker_cls + + if worker_cls not in cls._ACTOR_CLASS_CACHE: + cls._ACTOR_CLASS_CACHE[worker_cls] = ray.remote(worker_cls) + return cls._ACTOR_CLASS_CACHE[worker_cls] + + @staticmethod + def _get_bundle_resources(pg: PlacementGroup, bundle_idx: int) -> dict[str, float | int]: + assert len(pg.bundle_specs) > bundle_idx, f"Placement group does not have bundle index {bundle_idx}." 
+ return pg.bundle_specs[bundle_idx] + + @classmethod + def _resolve_actor_resources( + cls, + pg: PlacementGroup, + bundle_idx: int, + actor_num_cpus: int | float | None = None, + actor_memory: int | None = None, + ) -> tuple[float | int, int]: + bundle = cls._get_bundle_resources(pg, bundle_idx) + resolved_num_cpus = actor_num_cpus if actor_num_cpus is not None else bundle.get("CPU", 1) + resolved_memory = actor_memory if actor_memory is not None else int(bundle.get("memory", 0)) + assert bundle.get("CPU", 1) >= resolved_num_cpus, ( + f"Placement group bundle {bundle_idx} does not have enough CPU resources." + ) + assert bundle.get("memory", 0) >= resolved_memory, ( + f"Placement group bundle {bundle_idx} does not have enough memory resources." + ) + return resolved_num_cpus, resolved_memory + + @classmethod + def build_actor( + cls, + worker_cls, + *init_args, + pg: PlacementGroup | None = None, + bundle_idx: int = 0, + actor_num_cpus: int | float | None = None, + actor_memory: int | None = None, + capture_child_tasks: bool = False, + **init_kwargs, + ): + """Build a single CPU actor from a plain class or Ray actor class.""" + resolved_num_cpus = 1 if actor_num_cpus is None else actor_num_cpus + resolved_memory = actor_memory + + actor_cls = cls.to_actor_class(worker_cls) + actor_options = { + "num_cpus": resolved_num_cpus, + } + if resolved_memory is not None and resolved_memory > 0: + actor_options["memory"] = resolved_memory + + if pg is None: + return actor_cls.options(**actor_options).remote(*init_args, **init_kwargs) + + resolved_num_cpus, resolved_memory = cls._resolve_actor_resources( + pg=pg, + bundle_idx=bundle_idx, + actor_num_cpus=actor_num_cpus, + actor_memory=actor_memory, + ) + actor_options["num_cpus"] = resolved_num_cpus + if resolved_memory > 0: + actor_options["memory"] = resolved_memory + actor_options["scheduling_strategy"] = PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=bundle_idx, + 
placement_group_capture_child_tasks=capture_child_tasks, + ) + return actor_cls.options(**actor_options).remote(*init_args, **init_kwargs) + + @classmethod + def build_actors( + cls, + worker_cls, + *init_args, + pg: PlacementGroup | None = None, + start_bundle_idx: int = 0, + num_workers: int = 1, + actor_num_cpus_per_worker: int | float | None = None, + actor_memory_per_worker: int | None = None, + capture_child_tasks: bool = False, + **init_kwargs, + ): + """Build multiple homogeneous CPU actors from a plain class or Ray + actor class.""" + workers_list = [] + for idx in range(num_workers): + workers_list.append( + cls.build_actor( + worker_cls, + *init_args, + pg=pg, + bundle_idx=start_bundle_idx + idx, + actor_num_cpus=actor_num_cpus_per_worker, + actor_memory=actor_memory_per_worker, + capture_child_tasks=capture_child_tasks, + **init_kwargs, + ) + ) + return workers_list + + +class AutoCPUWorkers(CPUActorLauncher): + """Convenience wrapper for BaseCPUWorker-style homogeneous worker pools. + + `CPUActorLauncher` is the generic actorization layer. `AutoCPUWorkers` + keeps the legacy worker-centric API that instantiates one worker per bundle + using the conventional `(worker_config, num_cpus=...)` constructor shape. + """ + + _PG_NEXT_BUNDLE_INDEX: dict[str, int] = {} + _PG_NEXT_BUNDLE_INDEX_LOCK = threading.Lock() + + @staticmethod + def _get_pg_key(pg: PlacementGroup) -> str: + """Build a stable placement-group identifier for local bundle + tracking.""" + return str(pg.id) + + @classmethod + def _reserve_bundle_range( + cls, + pg: PlacementGroup, + num_workers: int, + start_bundle_idx: int | None, + ) -> tuple[int, int]: + """Reserve a contiguous bundle range for worker creation. + + When ``start_bundle_idx`` is omitted, the next unconsumed bundle range + in this process is used. Explicit bundle reservations still advance the + local cursor so later auto-allocation does not reuse the same bundles. 
+ """ + pg_key = cls._get_pg_key(pg) + + with cls._PG_NEXT_BUNDLE_INDEX_LOCK: + current_cursor = cls._PG_NEXT_BUNDLE_INDEX.get(pg_key, 0) + resolved_start_bundle_idx = current_cursor if start_bundle_idx is None else start_bundle_idx + resolved_num_workers = num_workers if num_workers > 0 else pg.bundle_count - resolved_start_bundle_idx + + assert resolved_num_workers > 0, "At least one worker must be created from the placement group." + assert resolved_start_bundle_idx >= 0, "start_bundle_idx must be non-negative." + assert resolved_start_bundle_idx + resolved_num_workers <= pg.bundle_count, ( + "Placement group does not have enough remaining bundles for the requested CPU workers." + ) + + cls._PG_NEXT_BUNDLE_INDEX[pg_key] = max(current_cursor, resolved_start_bundle_idx + resolved_num_workers) + + return resolved_start_bundle_idx, resolved_num_workers + + @classmethod + def from_config(cls, worker_cls, worker_config, cpu_config: CPUResourcesConfig): + """Create workers and a placement group from configuration objects. + + Args: + worker_cls: The class of the worker to instantiate. + worker_config: The configuration for each worker instance. + cpu_config (CPUResourcesConfig): The configuration + for the cpu resources. + + Returns: + List[T]: List of created worker instances. + """ + pg = cls.build_placement_group(cpu_config) + workers_list = cls.from_placement_group(worker_cls, worker_config, pg) + + return workers_list, pg + + @classmethod + def from_placement_group( + cls, + worker_cls, + worker_config, + pg: PlacementGroup, + num_workers: int = -1, + start_bundle_idx: int | None = None, + ): + """Create workers from an existing placement group. + + Args: + worker_cls: The class of the worker to instantiate. + worker_config: The configuration for each worker instance. + pg (PlacementGroup): The existing placement group to use. + num_workers (int): The number of workers to create. Defaults to -1, + the remaining bundles in the placement group will be used. 
+ start_bundle_idx (int | None): Bundle index to start from. If + omitted, the next unconsumed local bundle range for this + placement group will be used. + + Returns: + List[T]: List of created worker instances. + """ + start_bundle_idx, num_workers = cls._reserve_bundle_range( + pg=pg, num_workers=num_workers, start_bundle_idx=start_bundle_idx + ) + default_cpu = cls._get_bundle_resources(pg, start_bundle_idx).get("CPU", 1) + return cls.build_actors( + worker_cls, + worker_config, + num_cpus=default_cpu, + pg=pg, + start_bundle_idx=start_bundle_idx, + num_workers=num_workers, + actor_num_cpus_per_worker=default_cpu, + actor_memory_per_worker=None, + ) diff --git a/xtuner/v1/train/cli/rl.py b/xtuner/v1/train/cli/rl.py index 0a91ee1edb..cb5efd5a83 100644 --- a/xtuner/v1/train/cli/rl.py +++ b/xtuner/v1/train/cli/rl.py @@ -10,7 +10,6 @@ from cyclopts.group import Group from xtuner.v1.rl.utils import register_cleanup -from xtuner.v1.train.rl_trainer import RLTrainer from xtuner.v1.utils import Config from xtuner.v1.utils.track_rl_mem import monitor_actor_memory @@ -56,7 +55,8 @@ def main( track_thread.start() trainer_cfg = Config.fromfile(config)["trainer"] - trainer = RLTrainer.from_config(trainer_cfg) + # trainer = RLTrainer.from_config(trainer_cfg) + trainer = trainer_cfg.build() trainer.fit() if dist.is_initialized(): diff --git a/xtuner/v1/train/rl_colocate_trainer.py b/xtuner/v1/train/rl_colocate_trainer.py new file mode 100644 index 0000000000..a01bfced7c --- /dev/null +++ b/xtuner/v1/train/rl_colocate_trainer.py @@ -0,0 +1,880 @@ +import json +import os +import random +from pathlib import Path +from shutil import rmtree +from typing import Any, List, Union, cast + +import ray +import torch +from mmengine.dist import get_rank +from mmengine.runner import set_random_seed +from pydantic import BaseModel, ConfigDict +from typing_extensions import Literal, TypedDict + +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast +from 
from xtuner.v1._writer import get_writer
from xtuner.v1.data_proto import RolloutState, Status
from xtuner.v1.data_proto.sequence_context import SequenceContext
from xtuner.v1.patch import patch_default_save_plan
from xtuner.v1.rl.agent_loop import AgentLoopManagerConfig, ProduceBatchResult
from xtuner.v1.rl.evaluator import EvaluatorConfig
from xtuner.v1.rl.gateway.config import GatewayConfig
from xtuner.v1.rl.replay_buffer import AsyncReplayBufferConfig, SyncReplayBufferConfig
from xtuner.v1.rl.rollout.controller import RolloutControllerProxy
from xtuner.v1.rl.rollout.worker import RolloutConfig
from xtuner.v1.rl.trainer.controller import TrainingControllerProxy
from xtuner.v1.rl.trainer.worker import WorkerConfig, WorkerLogItem
from xtuner.v1.rl.utils import AcceleratorResourcesConfig, AutoAcceleratorWorkers, asyncio_run
from xtuner.v1.train.trainer import LoadCheckpointConfig, XTunerMeta
from xtuner.v1.utils import get_logger, is_hf_model_path, set_deterministic, timer
from xtuner.v1.utils.device import get_device, get_torch_device_module


# TODO: Move DEVICE to `xtuner.utils.device`
PG_READY_TIMEOUT = 30
TRAINER_RAY_GET_TIMEOUT = 5 * 3600  # 5 hour
DEVICE = get_device()
DEVICE_MODULE = get_torch_device_module()


def check_fa3():
    """Fail fast if XTUNER_USE_FA3=1 but flash-attention v3 is unusable.

    No-op when the env flag is unset or not "1".
    """
    if os.environ.get("XTUNER_USE_FA3", "0") != "1":
        return

    try:
        from xtuner.v1.ops.flash_attn import get_flash_attn_varlen

        get_flash_attn_varlen()
    except RuntimeError as e:
        raise RuntimeError(f"Flash attention v3 runtime error {e}, Please install it first or set XTUNER_USE_FA3=0.")


def force_set_tokenize_workers(logger):
    """Force XTUNER_TOKENIZE_WORKERS=1 to keep dataloader tokenization safe.

    Warns if the user had requested more than one worker; logs info otherwise.
    """
    # To avoid segmentation faults when setting num_workers for the dataloader
    # The root cause is the incompatibility between fork start method and ray's grpc.
    # The most fundamental solution is that all processes started in ray should
    # use spawn start method.
    tokenize_workers = os.environ.get("XTUNER_TOKENIZE_WORKERS", None)
    os.environ["XTUNER_TOKENIZE_WORKERS"] = "1"
    if tokenize_workers is not None and int(tokenize_workers) > 1:
        logger.warning(
            f"XTUNER_TOKENIZE_WORKERS is set to {tokenize_workers}, which may cause segmentation faults. Force set XTUNER_TOKENIZE_WORKERS to 1 to avoid this."
        )
    else:
        logger.info(
            f"Set XTUNER_TOKENIZE_WORKERS to {os.environ['XTUNER_TOKENIZE_WORKERS']} for safe tokenization in dataloader workers."
        )


def bind_train_rollout(
    train_controller: TrainingControllerProxy,
    rollout_controller: RolloutControllerProxy,
) -> None:
    """Bind the training and rollout workers for update weights."""
    # Fetch rollout-side metadata and hand it to the train controller so it
    # knows where to push updated weights.
    info_dict = ray.get(rollout_controller.get_rollout_metadata.remote())  # type: ignore[attr-defined]
    ray.get(train_controller.update_rollout_info.remote(info_dict))
    return


class TrainInfo(TypedDict, total=False):
    # Scalar statistics about the prepared training batch (rewards, lengths, ...)
    data_info: dict[str, float]
    # Per-worker training logs returned by the training controller
    workers_log_item: list[WorkerLogItem]


def get_train_seq_ctx(
    input_ids: torch.LongTensor,
    position_ids: torch.Tensor | None = None,
    multimodal_train_info: dict | None = None,
    len_response_ids: int = 0,
):
    """Build a SequenceContext for training from token ids.

    Args:
        input_ids: Token ids of prompt + response, shape (1, seq_len) here
            (callers pass an unsqueezed tensor — TODO confirm).
        position_ids: Optional precomputed position ids; only a 3-D tensor
            (qwen3vl-style) triggers the extension logic below.
        multimodal_train_info: Optional dict with "pixel_values" /
            "image_grid_thw" to attach to the context.
        len_response_ids: Number of response positions to append position ids
            for when extending a 3-D ``position_ids``.
    """
    seq_ctx = SequenceContext.from_input_ids((input_ids,), device="cpu")
    if position_ids is not None and len(position_ids.shape) == 3:
        # qwen3vl needs special handling; other models need no extra processing
        max_value = position_ids.max(dim=-1).values  # (3,1)
        # Continue each position track monotonically over the response span.
        response_position_ids = max_value.unsqueeze(-1).expand(-1, -1, len_response_ids) + torch.arange(
            1, len_response_ids + 1, device=max_value.device
        )
        position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
        seq_ctx.position_ids = position_ids  # type: ignore[assignment]
        assert position_ids.size(-1) == input_ids.size(-1)

    if multimodal_train_info:
        seq_ctx.pixel_values = multimodal_train_info.get("pixel_values")
        seq_ctx.image_grid_thw = multimodal_train_info.get("image_grid_thw")
    return seq_ctx
def is_valid_for_training(group_data_items: list[RolloutState], logger) -> bool:
    """Checks if a group of rollout states is valid for a training step.

    Args:
        group_data_items: A list of RolloutState objects.
        logger: Logger used to report why a group was rejected.

    Returns:
        True if the group is valid, False otherwise.

    NOTE: Why this check is needed:
    - For system fault tolerance, this check is performed at rollout / dataflow
      time, but we still do it here to ensure training data integrity.
    - 'filtered'/'failed': These items are fundamentally broken or incomplete and
      should not be used for training.
    - 'aborted': These items represent rollouts that were stopped
      prematurely. Using such partial data could lead the model to learn
      undesirable behaviors (e.g., stopping generation too early).
    - Empty response/response_ids: The model's generated response is the core
      of the training data for RL algorithms like PPO. If the response is
      missing, there is nothing to compute rewards on or to train the model with.
    """
    is_abort = any(item.status == Status.ABORTED for item in group_data_items)
    is_filtered = any(item.status == Status.FILTERED for item in group_data_items)
    is_failed = any(item.status == Status.FAILED for item in group_data_items)
    if is_filtered or is_failed or is_abort:
        logger.warning(
            f"Invalid dataflow group found during training, rollout state filtered: {is_filtered}, failed: {is_failed}, aborted: {is_abort}."
        )
        return False
    for item in group_data_items:
        response_valid = item.response is not None and len(item.response) > 0
        ids_valid = item.response_ids is not None and len(item.response_ids) > 0
        if not ids_valid:
            # NOTE: `response_ids` is the critical field for token-in-token-out mode, so we ensure it's not empty.
            logger.warning(
                "Invalid dataflow item found during training: no response or response_ids and skip this item."
            )
            return False
        if not response_valid:
            # NOTE: check valid response string for judger inputs
            logger.warning("Invalid dataflow item found during training: empty response string and skip this item.")
            return False
    return True


class RLColocateTrainerConfig(BaseModel):
    """Pydantic config that materializes an :class:`RLColocateTrainer`.

    All fields map one-to-one onto ``RLColocateTrainer.__init__`` keyword
    arguments; :meth:`build` performs the forwarding.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")

    resources: AcceleratorResourcesConfig
    train_worker_cfg: WorkerConfig
    rollout_config: RolloutConfig
    tokenizer_path: Union[str, Path]
    replay_buffer_config: SyncReplayBufferConfig | AsyncReplayBufferConfig = SyncReplayBufferConfig()
    agent_loop_manager_cfg: AgentLoopManagerConfig
    eval_agent_loop_manager_cfg: AgentLoopManagerConfig
    evaluator_config: EvaluatorConfig
    load_from: Union[str, Path]
    rollout_steps: int
    global_batch_size: int

    enable_evaluate: bool = True
    enable_initial_evaluate: bool = False
    evaluate_step: int = 1
    work_dir: Union[Path, str, None] = None
    auto_resume: bool = False
    load_checkpoint_cfg: LoadCheckpointConfig = LoadCheckpointConfig()
    checkpoint_interval: int | None = -1
    checkpoint_maxkeep: int | None = -1
    hf_interval: int | None = -1
    hf_max_keep: int | None = -1
    checkpoint_no_save_optimizer: bool = False
    log_dir: Union[Path, str, None] = None
    seed: int = 66
    debug_rollout: bool = False
    skip_checkpoint_validation: bool = False
    exp_tracker: Literal["tensorboard", "jsonl"] = "tensorboard"
    # NOTE(fix): RLColocateTrainer accepts `gateway_config`, but the config
    # previously could not express it, making gateway auto-start unreachable
    # from a config file. Optional with default None, so existing configs
    # remain valid even under extra="forbid".
    gateway_config: GatewayConfig | None = None

    def build(self) -> "RLColocateTrainer":
        """Instantiate the trainer by forwarding every field as a kwarg."""
        return RLColocateTrainer(
            resources=self.resources,
            train_worker_cfg=self.train_worker_cfg,
            rollout_config=self.rollout_config,
            tokenizer_path=self.tokenizer_path,
            replay_buffer_config=self.replay_buffer_config,
            agent_loop_manager_cfg=self.agent_loop_manager_cfg,
            eval_agent_loop_manager_cfg=self.eval_agent_loop_manager_cfg,
            evaluator_config=self.evaluator_config,
            enable_evaluate=self.enable_evaluate,
            enable_initial_evaluate=self.enable_initial_evaluate,
            evaluate_step=self.evaluate_step,
            work_dir=self.work_dir,
            auto_resume=self.auto_resume,
            load_checkpoint_cfg=self.load_checkpoint_cfg,
            checkpoint_interval=self.checkpoint_interval,
            checkpoint_maxkeep=self.checkpoint_maxkeep,
            checkpoint_no_save_optimizer=self.checkpoint_no_save_optimizer,
            hf_interval=self.hf_interval,
            hf_max_keep=self.hf_max_keep,
            load_from=self.load_from,
            log_dir=self.log_dir,
            seed=self.seed,
            debug_rollout=self.debug_rollout,
            skip_checkpoint_validation=self.skip_checkpoint_validation,
            rollout_steps=self.rollout_steps,
            global_batch_size=self.global_batch_size,
            exp_tracker=self.exp_tracker,
            gateway_config=self.gateway_config,
        )
class RLColocateTrainer:
    """RL trainer that colocates training and rollout workers on one placement group.

    Orchestrates the rollout -> prepare-data -> train -> sync-weights ->
    evaluate loop, alternating GPU residency between the rollout backend and
    the training workers via onload/offload calls.
    """

    _META_PATH = ".xtuner_rl_colocate_trainer"
    _EXP_TRACKING_PATH = "exp_tracking"
    _CHECKPOINT_DIR = "checkpoints"
    _HF_DIR = "hf"
    _SAVE_TRAIN_STATE_PATH = "train_state.json"

    # Keep the Trainer thin: as little code as possible here, organized via
    # components. Goal: like torch, let users write their own init and train
    # loop themselves; we only provide the components.
    def __init__(
        self,
        *,
        resources: AcceleratorResourcesConfig,
        train_worker_cfg: WorkerConfig,
        rollout_config: RolloutConfig,
        # Sampler config
        # sampler_config: SamplerConfig,
        tokenizer_path: str | Path,
        replay_buffer_config: SyncReplayBufferConfig | AsyncReplayBufferConfig,
        # agent loop config
        # agent_loop_config: AgentLoopConfig,
        # agent loop manager config
        # produce_strategy_config: ProduceStrategyConfig,
        agent_loop_manager_cfg: AgentLoopManagerConfig,
        # eval configs
        eval_agent_loop_manager_cfg: AgentLoopManagerConfig,
        evaluator_config: EvaluatorConfig,
        enable_evaluate: bool = True,
        enable_initial_evaluate: bool = False,
        evaluate_step: int = 1,
        # work_dir and resume
        work_dir: Path | str | None = None,
        auto_resume: bool = False,
        load_checkpoint_cfg: LoadCheckpointConfig = LoadCheckpointConfig(),
        checkpoint_interval: int | None = -1,
        checkpoint_maxkeep: int | None = -1,
        checkpoint_no_save_optimizer: bool = False,
        hf_interval: int | None = None,
        hf_max_keep: int | None = None,
        # others
        load_from: str | Path,
        log_dir: Path | str | None = None,
        seed: int = 66,
        debug_rollout: bool = False,
        skip_checkpoint_validation: bool = False,  # Suggest enabled if fsdp_size is larger than 512
        # steps
        rollout_steps: int,
        global_batch_size: int,
        # exp tracker
        exp_tracker: Literal["tensorboard", "jsonl"] = "tensorboard",
        # gateway
        gateway_config: GatewayConfig | None = None,
    ):
        check_fa3()

        # work_dir
        work_dir = Path(work_dir) if work_dir else Path.cwd() / "work_dirs"
        if get_rank() == 0:
            work_dir.mkdir(parents=True, exist_ok=True)
        self._meta = XTunerMeta.build(work_dir, self._META_PATH, auto_resume)

        # hf checkpoint config
        self._load_from = Path(load_from) if isinstance(load_from, str) else load_from
        is_hf_path, error_info = is_hf_model_path(load_from) if load_from is not None else (False, "")
        self._load_from_hf = is_hf_path

        # Only HF-format starting checkpoints are supported by this trainer.
        if not self._load_from_hf:
            raise NotImplementedError(error_info)
        self._hf_max_keep = hf_max_keep
        self._hf_interval = hf_interval

        # checkpoint config
        self._checkpoint_interval = checkpoint_interval
        self._checkpoint_maxkeep = checkpoint_maxkeep
        self._checkpoint_no_save_optimizer = checkpoint_no_save_optimizer
        self._load_checkpoint_cfg = self._resolve_load_checkpoint_cfg(auto_resume, load_checkpoint_cfg)

        # log
        log_dir = self.exp_dir / "logs"
        self.logger = get_logger(log_dir=log_dir, tag="RLTrainer")

        force_set_tokenize_workers(self.logger)

        if skip_checkpoint_validation:
            patch_default_save_plan()

        # steps
        self._rollout_steps = rollout_steps
        # self._total_epochs = total_epochs  # TODO
        self._cur_step = 0
        self._global_train_step = 0
        self._seed = seed
        set_deterministic()
        set_random_seed(seed)
        self.global_batch_size = global_batch_size

        # main components
        self._pg = AutoAcceleratorWorkers.build_placement_group(resources)

        # override train worker config
        if train_worker_cfg.seed is None:
            self.logger.warning(f"RLTrainer seed {seed} is used as train worker seed.")
            train_worker_cfg.seed = seed
        train_worker_cfg.load_from = load_from
        train_worker_cfg.log_dir = log_dir
        self._train_worker_cfg = train_worker_cfg

        # override rollout config
        rollout_config.worker_log_dir = log_dir

        # If resuming from checkpoint, skip loading weights in rollout workers
        if self._load_checkpoint_cfg.checkpoint_path is not None:
            rollout_config.skip_load_weights = True
            self.logger.info(
                f"Skip load rollout weights due to resume from checkpoint {self._load_checkpoint_cfg.checkpoint_path}"
            )

        # build train controller and rollout controller
        self.train_controller = train_worker_cfg.build(self._pg)

        # Resume train worker if checkpoint exists
        if self._load_checkpoint_cfg.checkpoint_path is not None:
            ray.get(self.train_controller.resume.remote(self._load_checkpoint_cfg))
            train_state_path = Path(self._load_checkpoint_cfg.checkpoint_path) / self._SAVE_TRAIN_STATE_PATH
            with train_state_path.open("r") as f:
                train_state = json.load(f)
            self._cur_step = train_state["cur_step"]

        self.rollout_controller = rollout_config.build(self._pg)

        # Optionally expose the rollout backend through an HTTP gateway.
        if gateway_config is not None and gateway_config.auto_start:
            ray.get(self.rollout_controller.start_gateway.remote(gateway_config))

        replay_buffer = replay_buffer_config.build()
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
        # build agnet_loop_manager
        self.agent_loop_manager = agent_loop_manager_cfg.build(
            rollout_controller=self.rollout_controller,
            tokenizer=self.tokenizer,
            replay_buffer=replay_buffer,
            logger=self.logger,
        )

        # build eval agent loop manager
        self.eval_agent_loop_manager = eval_agent_loop_manager_cfg.build(
            rollout_controller=self.rollout_controller,
            tokenizer=self.tokenizer,
            replay_buffer=replay_buffer,
            logger=self.logger,
        )

        self._enable_evaluate = enable_evaluate
        self._enable_initial_evaluate = enable_initial_evaluate
        self._evaluate_step = evaluate_step

        # build evaluator
        total_eval_samples = len(self.eval_agent_loop_manager.data_sampler)
        self.evaluator = evaluator_config.build(total_eval_samples=total_eval_samples)

        # Resume sampler and sync weights if checkpoint exists
        if self._load_checkpoint_cfg.checkpoint_path is not None:
            self.logger.info(f"Resume sampler from {self._load_checkpoint_cfg.checkpoint_path}")
            self.agent_loop_manager.resume(self._load_checkpoint_cfg.checkpoint_path)

            # Rollout workers skipped their own weight load above, so push the
            # resumed train-worker weights to them now, keeping GPU residency
            # alternating: offload optimizer -> load rollout weights -> push ->
            # offload model -> restore kv-cache.
            bind_train_rollout(train_controller=self.train_controller, rollout_controller=self.rollout_controller)
            self.logger.info("Rollout workers skip load weights, update weights from train workers.")
            ray.get(self.train_controller.offload.remote(target="optimizer"))
            ray.get(self.rollout_controller.offload.remote())
            ray.get(self.rollout_controller.onload_weights.remote())
            ray.get(self.train_controller.update_weights.remote())
            ray.get(self.train_controller.offload.remote(target="model"))
            ray.get(self.rollout_controller.onload_kvcache.remote())
            self.logger.info("Rollout workers updated weights from train workers.")
        else:
            ray.get(self.train_controller.offload.remote(target="all"))

        # others
        if debug_rollout:
            self.logger.warning("Debug rollout mode is enabled, rollout will not be offloaded.")
        self._debug_rollout = debug_rollout
        self._exp_tracker = get_writer(writer_type=exp_tracker, log_dir=log_dir / self._EXP_TRACKING_PATH)
        self._display_all_workers_log = False

    @property
    def exp_dir(self) -> Path:
        # Directory of the latest experiment recorded in the meta file.
        return Path(self._meta.latest_exp.exp_dir)

    def _resolve_load_checkpoint_cfg(
        self, auto_resume: bool, load_checkpoint_cfg: LoadCheckpointConfig
    ) -> LoadCheckpointConfig:
        """Resolve checkpoint path for auto-resume."""
        latest_checkpoint = self._meta.latest_exp.latest_checkpoint
        if latest_checkpoint is not None and auto_resume:
            load_checkpoint_cfg.checkpoint_path = Path(latest_checkpoint)
        return load_checkpoint_cfg

    def _maybe_save_checkpoint(self, cur_step: int) -> None:
        """Save checkpoint if interval condition is met."""
        ckp_interval = self._checkpoint_interval
        # None or -1 disables checkpointing entirely.
        if ckp_interval is None or ckp_interval == -1:
            return
        if cur_step % ckp_interval != 0:
            return

        checkpoint_path = self.exp_dir / self._CHECKPOINT_DIR / f"ckpt-step-{cur_step}"
        checkpoint_path.mkdir(parents=True, exist_ok=True)

        # 1. Save sampler (dataloader) state
        self.logger.info(f"Saving sampler state to {checkpoint_path}")
        self.agent_loop_manager.save(checkpoint_path)

        # 2. Save DCP checkpoint (model + optimizer)
        self.logger.info(f"Saving DCP checkpoint to {checkpoint_path}")
        ray.get(self.train_controller.save.remote(str(checkpoint_path), self._checkpoint_no_save_optimizer))

        # 3. Save train state JSON
        train_state_path = checkpoint_path / self._SAVE_TRAIN_STATE_PATH
        with train_state_path.open("w") as f:
            json.dump({"cur_step": cur_step}, f)

        # 4. Update meta
        current_exp = self._meta.latest_exp
        current_exp.checkpoint_list.append(str(checkpoint_path))

        # 5. Prune old checkpoints
        ckp_maxkeep = self._checkpoint_maxkeep
        ckp_list = current_exp.checkpoint_list
        if ckp_maxkeep is not None and ckp_maxkeep > 0 and len(ckp_list) > ckp_maxkeep:
            for deleted in ckp_list[:-ckp_maxkeep]:
                if Path(deleted).exists():
                    rmtree(deleted, ignore_errors=True)
            current_exp.checkpoint_list = ckp_list[-ckp_maxkeep:]

        # 6. Persist meta to disk
        meta_path = self.exp_dir.parent / self._META_PATH
        with meta_path.open("w") as f:
            f.write(self._meta.model_dump_json(indent=2))

    def _maybe_save_hf(self, cur_step: int):
        """Save a Huggingface-format checkpoint when the interval (or the final step) is hit."""
        if self._hf_interval is None or self._hf_interval == -1:
            return

        if not self._load_from_hf:
            raise RuntimeError(
                "Only support saving to Huggingface format when loading from Huggingface! "
                "You meet this error means `load_from` of trainer is not a Huggingface model path."
            )

        # Always save at the final rollout step regardless of the interval.
        if cur_step % self._hf_interval != 0 and cur_step != self._rollout_steps:
            return

        save_hf_path = self.exp_dir / self._HF_DIR / f"hf-step-{cur_step}"
        save_hf_path.mkdir(parents=True, exist_ok=True)

        # update meta
        current_exp = self._meta.latest_exp
        current_exp.hf_checkpoint_list.append(str(save_hf_path))

        # save hf (prune oldest entries beyond hf_max_keep first)
        self.logger.info(f"Saving Huggingface checkpoint to {save_hf_path}")
        hf_list = self._meta.latest_exp.hf_checkpoint_list
        if self._hf_max_keep is not None and self._hf_max_keep > 0 and len(hf_list) > self._hf_max_keep:
            for deleted in hf_list[: -self._hf_max_keep]:
                if Path(deleted).exists():
                    rmtree(deleted, ignore_errors=True)
            current_exp.hf_checkpoint_list = hf_list[-self._hf_max_keep :]
        ray.get(self.train_controller.save_hf.remote(str(save_hf_path)), timeout=TRAINER_RAY_GET_TIMEOUT)

        # save tokenizer
        if isinstance(self.tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
            self.tokenizer.save_pretrained(str(save_hf_path))

    def fit(self) -> None:
        """Run the full rollout -> train -> sync -> evaluate loop for `rollout_steps` steps."""
        self.logger.info("Start RL training")
        if self._cur_step >= self._rollout_steps:
            self.logger.info(f"Rollout steps {self._rollout_steps} reached, stop training")
            return

        if self._enable_initial_evaluate and not self._debug_rollout:
            eval_produce_result = asyncio_run(
                self.eval_agent_loop_manager.produce_batch(self.evaluator.eval_batch_size, rollout_step=0)
            )
            eval_metrics = self.evaluator.run(eval_produce_result.rollout_states)
            self.logger.info(f"Initial rollout evaluate scores {eval_metrics} and start training")

            tb_scores = {f"eval/{k}": v for k, v in eval_metrics.items()}
            self._exp_tracker.add_scalars(
                tag_scalar_dict=tb_scores,
                global_step=0,
            )

        for rollout_idx in range(self._cur_step + 1, self._rollout_steps + 1):
            self.logger.info(f"Rollout {rollout_idx}/{self._rollout_steps} start")
            step_timer_dict: dict[str, float] = {}
            with timer("step", step_timer_dict):
                # 1. Rollout to generate experience
                self.logger.info("start to generate rollout experience for training")
                produce_result: ProduceBatchResult = asyncio_run(
                    self.agent_loop_manager.produce_batch(self.global_batch_size, rollout_step=rollout_idx)
                )
                train_batch = produce_result.rollout_states
                self.logger.info(f"generate {len(train_batch) * len(train_batch[0])} samples for training")
                train_trajectory_dir = self.exp_dir / "train_rollout"
                train_trajectory_dir.mkdir(parents=True, exist_ok=True)
                train_trajectory_path = train_trajectory_dir / f"train_rollout_{rollout_idx}.jsonl"
                self._save_trajectories(train_batch, train_trajectory_path)
                self.logger.info(f"Rollout {rollout_idx} train trajectories saved to {train_trajectory_path}")
                if not self._debug_rollout:
                    ray.get(self.rollout_controller.offload.remote())

                # In debug_rollout mode only the rollout is exercised; the
                # training/eval phases below are skipped entirely.
                if not self._debug_rollout:
                    with timer("onload", step_timer_dict):
                        ray.get(self.train_controller.onload.remote(target="all"))
                        self.logger.info("Training controller loaded")

                    # 2. Train on the generated experience
                    # TODO: simplify with Packer.pack_pad_dispatch()
                    # train_batch = Packer.pack_pad_dispatch(train_batch)
                    with timer("prepare_data", step_timer_dict):
                        data_batches, data_info = self._prepare_train_data(
                            train_batch, self._train_worker_cfg.pack_max_length
                        )
                        self.logger.info(f"Prepared {len(data_batches)} training data batches")

                    with timer("training", step_timer_dict):
                        workers_log_item: list[WorkerLogItem] = ray.get(
                            self.train_controller.fit.remote(
                                data_batches,
                                pack_max_length=self._train_worker_cfg.pack_max_length,
                                rollout_idx=rollout_idx,
                            )
                        )
                    train_log_info: TrainInfo = {
                        "data_info": data_info,
                        "workers_log_item": workers_log_item,
                    }

                    # 3. Synchronize weights and save checkpoints
                    self._sync_weights_and_save(rollout_idx, step_timer_dict)

                    # 4. Evaluate model performance
                    eval_log_info = {}
                    if self._enable_evaluate and rollout_idx % self._evaluate_step == 0:
                        with timer("evaluation", step_timer_dict):
                            eval_produce_result = asyncio_run(
                                self.eval_agent_loop_manager.produce_batch(
                                    self.evaluator.eval_batch_size, rollout_step=rollout_idx
                                )
                            )
                            eval_batch = eval_produce_result.rollout_states
                            eval_metrics = self.evaluator.run(eval_batch)
                            eval_trajectory_dir = self.exp_dir / "eval_rollout"
                            eval_trajectory_dir.mkdir(parents=True, exist_ok=True)
                            eval_trajectory_path = eval_trajectory_dir / f"eval_rollout_{rollout_idx}.jsonl"
                            self._save_trajectories(eval_batch, eval_trajectory_path)
                            self.logger.info(
                                f"Rollout {rollout_idx} eval trajectories saved to {eval_trajectory_path}"
                            )
                            eval_log_info.update(eval_metrics)
                else:
                    train_log_info = {}
                    eval_log_info = {}

            self._log_step(rollout_idx, step_timer_dict, produce_result, train_log_info, eval_log_info)
            self._cur_step = rollout_idx

    # TODO: simplify with Packer.pack_pad_dispatch()
    def _prepare_train_data(self, data_groups: list[list[RolloutState]], pack_max_length: int):
        """Convert rollout groups into packed, shuffled training batches.

        Returns:
            tuple: (data_batches, info_dict) — the per-sample training dicts
            and scalar statistics (rewards, advantages, lengths).
        """
        rewards_list = []
        advantages_list = []
        prompt_len_list = []
        response_len_list = []

        data_batches = []

        for j, group in enumerate(data_groups):
            if not is_valid_for_training(group, self.logger):
                self.logger.error(f"Skip one data group {group} due to rollout failed or empty response.")
                continue

            is_vlm_model = "train_prompt_ids" in group[0].extra_fields
            if is_vlm_model:
                # TODO(hha): VLM — poor design, remove later
                prompt_ids = group[0].extra_fields["train_prompt_ids"]
            else:
                prompt_ids = group[0].prompt_ids
            assert prompt_ids is not None and len(prompt_ids) > 0, (
                f"Prompt ids cannot be None or empty in data: {group[0]}"
            )
            rewards = []
            for data in group:
                assert data.reward is not None and "score" in data.reward, (
                    f"Reward is missing or does not contain 'score' key in data: {data}"
                )
                rewards.append(data.reward["score"])

            rewards_list.extend(rewards)
            rewards_tensor = torch.tensor(rewards, dtype=torch.float32)
            # Group-relative advantage (GRPO-style): normalize within the group.
            advantages = (rewards_tensor - rewards_tensor.mean(0)) / (rewards_tensor.std(0) + 1e-8)

            prompt_repeat_k = len(group)
            for i in range(prompt_repeat_k):
                item = group[i].response
                logprobs: list[float] | None = None

                response_ids: List[int] = []
                if group[i].response_ids is not None:
                    resp_ids_raw = group[i].response_ids
                    if isinstance(resp_ids_raw, torch.Tensor):
                        response_ids = resp_ids_raw.flatten().tolist()
                    else:
                        response_ids = cast(List[int], resp_ids_raw)

                    logprobs = group[i].logprobs
                    if logprobs is not None:
                        assert len(logprobs) == len(response_ids), (
                            f"{len(logprobs)} vs {len(response_ids)}, data: {group[i]}"
                        )
                        # only the response part has logprobs; prepend padding for the prompt
                        logprobs = [0.0] * (len(prompt_ids) - 1) + logprobs  # type: ignore[arg-type]
                else:
                    assert item is not None, "response item cannot be None"
                    response_ids = self.tokenizer(item, return_tensors="pt")["input_ids"].flatten().tolist()

                # returned routed_experts excludes the eos value (not actually needed), so drop the last id
                # TODO: does the verl tool agent loop need this?
                input_ids = prompt_ids + response_ids[:-1]

                prompt_len_list.append(len(prompt_ids))
                response_len_list.append(len(response_ids))

                # derive the shifted_labels for response_ids from response_mask
                if not group[i].response_mask:
                    response_mask = [1] * len(response_ids)
                    response_labels = response_ids
                else:
                    assert len(group[i].response_mask) == len(response_ids), (  # type: ignore[arg-type]
                        f"{len(group[i].response_mask)} vs {len(response_ids)}"  # type: ignore[arg-type]
                    )
                    response_mask = cast(list[int], group[i].response_mask)
                    response_labels = [
                        response_id if mask_id != 0 else -100
                        for response_id, mask_id in zip(response_ids, response_mask)
                    ]
                shifted_labels = [-100] * (len(prompt_ids) - 1) + response_labels
                shifted_labels_t = torch.tensor(shifted_labels, dtype=torch.int64).unsqueeze(0)

                # recompute advantages according to response_mask
                advatnages_val = advantages[i].item()
                actual_advantages = [advatnages_val] * len(prompt_ids) + [
                    0.0 if mask == 0 else advatnages_val for mask in response_mask
                ]
                advantages_list.extend(actual_advantages[:-1])

                assert len(input_ids) <= pack_max_length, f"{len(input_ids)} vs {pack_max_length}"
                input_ids_t = torch.tensor(input_ids, dtype=torch.int64).unsqueeze(0)

                if logprobs is not None:
                    rollout_logprobs = torch.tensor(logprobs, dtype=torch.float32).unsqueeze(0)
                    assert rollout_logprobs.size() == shifted_labels_t.size(), (
                        f"{rollout_logprobs.size()} vs {shifted_labels_t.size()}"
                    )
                else:
                    rollout_logprobs = None

                position_ids = group[i].position_ids
                multimodal_train_info = group[i].mm_info
                multi_info_cast = cast(dict | None, multimodal_train_info)
                seq_ctx = get_train_seq_ctx(input_ids_t, position_ids, multi_info_cast, len(response_ids) - 1)  # type: ignore[arg-type]

                data_dict = {
                    "seq_ctx": seq_ctx,
                    "shifted_labels": shifted_labels_t,
                    "advantage": actual_advantages,
                    "rollout_logprobs": rollout_logprobs,
                }

                seq_ctx.rollout_routed_experts = group[i].routed_experts  # (n, layer*expert)

                data_batches.append(data_dict)
        random.shuffle(data_batches)

        rewards_t = torch.tensor(rewards_list).float() if rewards_list else torch.tensor([0.0]).float()
        advantages_t = torch.tensor(advantages_list).float() if advantages_list else torch.tensor([0.0]).float()
        prompt_len_t = torch.tensor(prompt_len_list).float() if prompt_len_list else torch.tensor([0.0]).float()
        response_len_t = torch.tensor(response_len_list).float() if response_len_list else torch.tensor([0.0]).float()

        info_dict = {
            "batch_size": len(rewards_list),
            "rewards/mean": rewards_t.mean().item(),
            "rewards/min": rewards_t.min().item(),
            "rewards/max": rewards_t.max().item(),
            "advantages/mean": advantages_t.mean().item(),
            "advantages/min": advantages_t.min().item(),
            "advantages/max": advantages_t.max().item(),
            "response_len/mean": response_len_t.mean().item(),
            "response_len/min": response_len_t.min().item(),
            "response_len/max": response_len_t.max().item(),
            "response_len/std": response_len_t.std().item(),
            "prompt_len/mean": prompt_len_t.mean().item(),
            "prompt_len/min": prompt_len_t.min().item(),
            "prompt_len/max": prompt_len_t.max().item(),
        }
        return data_batches, info_dict

    def _sync_weights_and_save(self, rollout_idx: int, step_timer_dict: dict):
        """Synchronizes weights and saves checkpoints."""
        with timer("save_ckpt", step_timer_dict):
            ray.get(self.train_controller.offload.remote(target="optimizer"))
            self._maybe_save_checkpoint(rollout_idx)
            self._maybe_save_hf(rollout_idx)

        ray.get(self.rollout_controller.recover_failed_workers.remote())
        with timer("sync_weight", step_timer_dict):
            # Re-bind before every sync: rollout workers may have been recovered.
            bind_train_rollout(train_controller=self.train_controller, rollout_controller=self.rollout_controller)
            ray.get(self.rollout_controller.onload_weights.remote())
            ray.get(self.train_controller.update_weights.remote())
            self.logger.info("Model weights synchronized successfully.")
            ray.get(self.train_controller.offload.remote(target="model"))
            ray.get(self.rollout_controller.onload_kvcache.remote())

    def _log_step(
        self,
        rollout_idx: int,
        step_timer_dict: dict,
        produce_result: ProduceBatchResult,
        train_info: TrainInfo,
        eval_info: dict[str, float],
    ):
        """Aggregate step timings, rollout/train/eval metrics and emit them to logger + tracker."""
        all_scalars = {}
        log_time_str = ""
        trajectory_str = ""
        eval_str = ""
        if step_timer_dict:
            all_scalars.update({f"time/{k}": v for k, v in step_timer_dict.items()})
            log_time_str = f"\nRollout {rollout_idx} finished and timing listed:\n"
            log_time_str += "\n".join([f" - {k:<25}: {v:.2f}s" for k, v in step_timer_dict.items()])

        if produce_result.group_gen_count is not None:
            all_scalars["timing/task_n"] = produce_result.group_gen_count
            all_scalars["timing/task_mean_s"] = produce_result.group_gen_mean_s
            all_scalars["timing/task_p50_s"] = produce_result.group_gen_p50_s
            all_scalars["timing/task_p99_s"] = produce_result.group_gen_p99_s
            all_scalars["timing/task_p99_p50_ratio"] = produce_result.group_gen_p99_p50_ratio
            all_scalars["timing/pause_s"] = produce_result.group_gen_pause_time_s
            all_scalars["async/completed_samples"] = produce_result.leftover_completed
            all_scalars["async/aborted_samples"] = produce_result.leftover_aborted
            all_scalars["async/expired_samples"] = produce_result.leftover_expired

        if train_info:
            all_scalars.update({f"response/{k}": v for k, v in train_info.get("data_info", {}).items()})
            trajectory_str = f"\nRollout {rollout_idx} data statistics:\n"
            trajectory_str += "\n".join([f"- {k:<25}: {v:.4f}" for k, v in train_info.get("data_info", {}).items()])
            # Rank-0 worker stands in for the whole job for these metrics.
            rank0_log_item = train_info["workers_log_item"][0]
            rank0_rollout_is_metrics = rank0_log_item.get("rollout_is_metrics", {})
            rank0_mismatch_metrics = rank0_log_item.get("mismatch_metrics", {})
            rank0_rollout_entropy = rank0_log_item.get("rollout_entropy", 0.0)
            all_scalars.update({f"rollout_is/{k}": v for k, v in rank0_rollout_is_metrics.items()})
            all_scalars.update({f"{k}": v for k, v in rank0_mismatch_metrics.items()})
            all_scalars.update({"entropy/rollout": rank0_rollout_entropy})
            all_scalars.update({"entropy/train": rank0_log_item["train_entropy"]})
            for worker_idx, log_item in enumerate(train_info["workers_log_item"]):
                if not self._display_all_workers_log and worker_idx > 0:
                    break
                mini_batch_metrics: dict[str, List[float]] = {}
                for mini_batch_log in log_item["train_metrics"]:
                    rl_worker_log = mini_batch_log["loss_log"] | mini_batch_log["rl_other_log"]
                    for k, v in rl_worker_log.items():
                        mini_batch_metrics.setdefault(k, []).append(cast(float, v))

                for key, value in mini_batch_metrics.items():
                    avg_value = sum(value) / len(value)
                    all_scalars.update({f"train_metrics/worker_{worker_idx}/step_avg_{key}": avg_value})

                rank_sft_log = log_item["sft_train_metrics"]
                for k, v in rank_sft_log.items():
                    all_scalars.update({f"sft_train_metrics/worker_{worker_idx}/{k}": v})

            self._log_mini_batch_metrics(train_info["workers_log_item"])

        if eval_info:
            all_scalars.update({f"eval/{k}": v for k, v in eval_info.items()})
            eval_str = " ".join([f"{k}: {v:.4f}" for k, v in eval_info.items()])

        self.logger.info(f"Rollout {rollout_idx}/{self._rollout_steps}{log_time_str} {trajectory_str} ")
        if eval_str:
            self.logger.info(f"Eval: {eval_str}")
        self._exp_tracker.add_scalars(tag_scalar_dict=all_scalars, global_step=rollout_idx)

    def _save_trajectories(self, data_groups: list[list[RolloutState]], save_path: Path) -> None:
        """Dump a summary line plus one JSON record per valid trajectory to `save_path`."""
        rewards = []
        response_len_list = []

        for group in data_groups:
            if not is_valid_for_training(group, self.logger):
                continue
            for data in group:
                assert data.reward is not None
                rewards.append(data.reward["score"])
                if data.response_ids is not None:
                    if isinstance(data.response_ids, torch.Tensor):
                        response_ids = data.response_ids.flatten().tolist()
                    else:
                        response_ids = data.response_ids
                    response_len_list.append(len(response_ids))
                elif data.response is not None:
                    response_ids = self.tokenizer.encode(data.response, add_special_tokens=False)
                    response_len_list.append(len(response_ids))

        rewards_tensor = torch.tensor(rewards).float() if rewards else torch.tensor([0.0]).float()
        response_lens = torch.tensor(response_len_list).float() if response_len_list else torch.tensor([0.0]).float()

        _count = 0
        with open(save_path, "w", encoding="utf-8") as f:
            # First record is an aggregate summary over all valid trajectories.
            summary = {
                "reward_mean": rewards_tensor.mean().item(),
                "reward_std": rewards_tensor.std().item(),
                "reward_max": rewards_tensor.max().item(),
                "reward_min": rewards_tensor.min().item(),
                "response_len_mean": response_lens.mean().item(),
                "response_len_std": response_lens.std().item(),
                "response_len_max": response_lens.max().item(),
                "response_len_min": response_lens.min().item(),
                "total_len": len(rewards),
            }
            json.dump(summary, f, ensure_ascii=False, indent=2)
            f.write("\n")
            for group in data_groups:
                if not is_valid_for_training(group, self.logger):
                    continue
                for data in group:
                    assert data.reward is not None
                    ground_truth = None
                    if data.reward_model is not None:
                        ground_truth = data.reward_model.get("ground_truth")
                    item = {
                        "prompt": data.message,
                        "raw_prompt": data.extra_fields.get("raw_prompt", None),
                        "response": data.response,
                        "response_len": response_len_list[_count],
                        "label": ground_truth,
                        "reward": data.reward["score"],
                        "finish_reason": data.finish_reason,
                    }
                    json.dump(item, f, ensure_ascii=False, indent=2)
                    f.write("\n")
                    _count += 1

    def _log_mini_batch_metrics(self, workers_log_item: List[WorkerLogItem]):
        """Emit per-mini-batch training metrics at a global train-step granularity."""
        train_start_step = self._global_train_step + 1
        for worker_idx, log_item in enumerate(workers_log_item):
            for step_idx, mini_batch_log in enumerate(log_item["train_metrics"]):
                if not self._display_all_workers_log and worker_idx > 0:
                    break
                current_global_step = train_start_step + step_idx

                metrics: dict[str, Any] = dict(mini_batch_log["loss_log"])
                metrics.update(mini_batch_log["rl_other_log"])

                self._exp_tracker.add_scalars(
                    tag_scalar_dict={f"train_metrics/worker_{worker_idx}/{k}": float(v) for k, v in metrics.items()},
                    global_step=current_global_step,
                )
        # All workers log the same number of mini-batches; advance by worker 0's count.
        self._global_train_step += len(workers_log_item[0]["train_metrics"])
get_logger, is_hf_model_path, record_git_info, timer -from xtuner.v1.utils.device import get_device, get_torch_device_module -from xtuner.v1.utils.env_check import get_rollout_engine_version - -from .trainer import ExpHistory, ExpInfo, GitInfo, LoadCheckpointConfig, XTunerMeta - - -# TODO: Move DEVICE to `xtuner.utils.device` -PG_READY_TIMEOUT = 30 -TRAINER_RAY_GET_TIMEOUT = 5 * 3600 # 5 hour -DEVICE = get_device() -DEVICE_MODULE = get_torch_device_module() - - -def bind_train_rollout( - train_controller, - env_controller, -) -> None: - """Bind the training and rollout workers for update weights.""" - info_dict = ray.get(env_controller.get_rollout_info.remote()) # type: ignore[attr-defined] - ray.get(train_controller.update_rollout_info.remote(info_dict)) - return - - -class RolloutInfo(TypedDict): - data_groups: list[list[RLDataFlowItem]] - multimodal_train_infos: list[MultimodalTrainInfo] - task_time: dict[str, float] - replay_buffer_info: dict[str, float] - - -class TrainInfo(TypedDict): - data_info: dict[str, float] - workers_log_item: list[WorkerLogItem] - - -class RLTrainerConfig(BaseModel): - model_config = ConfigDict(extra="forbid") - load_from: str | Path - resources: AcceleratorResourcesConfig - cpu_resources: CPUResourcesConfig | None = None - rollout_config: RolloutConfig - dataflow_config: DataFlowConfig - judger_config: JudgerConfig - replay_buffer_config: ReplayBufferConfig - train_worker_config: WorkerConfig - evaluator_config: EvaluatorConfig | None = None - tokenizer_path: str | Path - work_dir: Path | str | None = None - log_dir: Path | str | None = None - total_epochs: int - resume_config: ResumeConfig | None = None - auto_resume: bool = False - load_checkpoint_cfg: LoadCheckpointConfig = LoadCheckpointConfig() - strict_load: bool = True - checkpoint_interval: int | None = -1 - checkpoint_maxkeep: int | None = -1 - checkpoint_no_save_optimizer: bool = False - skip_checkpoint_validation: bool = False # Suggest enabled if fsdp_size is larger than 
512 - hf_interval: int | None = None - hf_max_keep: int | None = None - seed: int = 42 - debug: bool = False - debug_rollout: bool = False - rollout_steps: int | None = None - display_all_workers_log: bool = False - exp_tracker: Literal["tensorboard", "jsonl"] = "tensorboard" - - @model_validator(mode="after") - def _convert_work_dir(self): - if isinstance(self.work_dir, str): - self.work_dir = Path(self.work_dir) - elif self.work_dir is None: - self.work_dir = Path.cwd() - return self - - -def get_train_seq_ctx( - input_ids: torch.LongTensor, multimodal_train_info: dict | None = None, len_response_ids: int = 0 -): - seq_ctx = SequenceContext.from_input_ids((input_ids,), device="cpu") - if multimodal_train_info and len(multimodal_train_info) > 0: - position_ids = multimodal_train_info.get("position_ids") # (1,n) or (3,1,n) - if position_ids is not None and len(position_ids.shape) == 3: - # qwen3vl 需要特殊处理,其余的不需要额外处理 - max_value = position_ids.max(dim=-1).values # (3,1) - response_position_ids = max_value.unsqueeze(-1).expand(-1, -1, len_response_ids) + torch.arange( - 1, len_response_ids + 1, device=max_value.device - ) - position_ids = torch.cat([position_ids, response_position_ids], dim=-1) - seq_ctx.position_ids = position_ids # type: ignore[assignment] - assert position_ids.size(-1) == input_ids.size(-1) - seq_ctx.pixel_values = multimodal_train_info.get("pixel_values") - seq_ctx.image_grid_thw = multimodal_train_info.get("image_grid_thw") - return seq_ctx - - -class RLTrainer: - """Universal Reinforcement Learning Trainer for XTuner. - - A flexible RL training orchestrator that supports multiple RL algorithms - through pluggable training workers and controllers. Manages the complete - RL training workflow including rollout generation, policy updates, - evaluation, and checkpoint management. - - **Training Workflow:** - 1. Initialize distributed workers and rollout environment - 2. Generate experiences using current policy - 3. 
Update policy using algorithm-specific training logic - 4. Synchronize weights between training and rollout workers - 5. Evaluate model performance and save checkpoints - - Args: - load_from (str | Path): Path to the base model to load. Should be a HuggingFace - model path (e.g., "meta-llama/Llama-2-7b-hf") or local model directory. - resources (AcceleratorResourcesConfig): Configuration for distributed computing - resources including number of workers, GPU allocation, and placement groups. - rollout_config (RolloutConfig): Configuration for rollout workers that generate - experiences by interacting with the environment. - dataflow_config (DataFlowConfig): Data orchestration configuration controlling - experience collection, batch formation, and data distribution across workers. - judger_config (JudgerConfig): Configuration for the reward model or scoring system - that evaluates generated responses and provides training signals. - replay_buffer_config (ReplayBufferConfig): Settings for experience replay buffer - including capacity, sampling strategy, and data retention policies. - evaluator_config (EvaluatorConfig | None): Evaluation configuration specifying metrics, - evaluation datasets, and assessment frequency for monitoring training progress. Defaults to None. - train_worker_cfg (WorkerConfig): Configuration for distributed training workers - including model architecture, optimizer settings, loss functions, and parallelism. - tokenizer_path (str | Path): Path to the tokenizer for text preprocessing. - Should be compatible with the base model specified in load_from. - work_dir (Path | str | None): Working directory for experiment outputs, - checkpoints, and logs. Defaults to None. - log_dir (Path | str | None): Directory for training logs and monitoring outputs. - Defaults to None. - total_epochs (int): Total number of training epochs to execute. - enable_evaluate (bool): Whether to perform periodic evaluation during training. 
- resume_config (ResumeConfig | None): Configuration for resuming training from - a previous checkpoint. Defaults to None. - auto_resume (bool): Whether to automatically resume training. Defaults to False. - load_checkpoint_cfg (LoadCheckpointConfig): Configuration for loading checkpoints. - strict_load (bool): Whether to strictly enforce checkpoint loading compatibility. - Defaults to True. - hf_interval (int | None): Interval (in epochs) for saving HuggingFace format - checkpoints. Defaults to None. - hf_max_keep (int | None): Maximum number of HuggingFace checkpoints to retain. - Defaults to None. - seed (int): Random seed for reproducible training. Defaults to 42. - debug (bool): Enable debug mode with additional logging. Defaults to False. - debug_rollout (bool): Enable debug mode for rollout workers. Defaults to False. - rollout_steps (int | None): Total number of rollout steps to perform. - If specified, overrides total_epochs. Defaults to None. - display_all_workers_log (bool): Whether to display logs from all workers. Defaults to False. - exp_tracker (Literal["tensorboard", "jsonl"]): Type of experiment tracker to use. - Options are "tensorboard" or "jsonl". Defaults to "tensorboard". 
- - **Examples:** - - Example configuration for GRPO RL training setup:: - - trainer = RLTrainer( - load_from="Qwen3-8B", - resources=resources_config, - rollout_config=rollout_cfg, - dataflow_config=dataflow_cfg, - judger_config=judger_cfg, - replay_buffer_config=buffer_cfg, - evaluator_config=eval_cfg, - train_worker_cfg=worker_cfg, - tokenizer_path="Qwen3-8B", - total_epochs=10, - enable_evaluate=True - ) - trainer.fit() - """ - - META_PATH = ".xtuner_grpo" - _EXP_TRACKING_PATH = "exp_tracking" - _CHECKPOINT_DIR = "checkpoints" - _SAVE_TRAIN_STATE_PATH = "train_state.json" - - def __init__( - self, - *, - load_from: str | Path, # Huggingface model path or saved trainer_path - resources: AcceleratorResourcesConfig, - cpu_resources: CPUResourcesConfig | None = None, - rollout_config: RolloutConfig, - dataflow_config: DataFlowConfig, - judger_config: JudgerConfig, - replay_buffer_config: ReplayBufferConfig, - train_worker_cfg: WorkerConfig, - evaluator_config: EvaluatorConfig | None = None, - tokenizer_path: str | Path, - work_dir: Path | str | None = None, - log_dir: Path | str | None = None, - total_epochs: int, - auto_resume: bool = False, - load_checkpoint_cfg: LoadCheckpointConfig = LoadCheckpointConfig(), - strict_load: bool = True, - checkpoint_interval: int | None = -1, - checkpoint_maxkeep: int | None = -1, - checkpoint_no_save_optimizer: bool = False, - skip_checkpoint_validation: bool = False, # Suggest enabled if fsdp_size is larger than 512 - hf_interval: int | None = None, - hf_max_keep: int | None = None, - seed: int = 42, - debug: bool = False, - debug_rollout: bool = False, - rollout_steps: int | None = None, - exp_tracker: Literal["tensorboard", "jsonl"] = "tensorboard", - display_all_workers_log: bool = False, - trainer_cfg: RLTrainerConfig | None = None, - ): - """Initialize the RL training system.""" - if os.environ.get("XTUNER_USE_FA3", "0") == "1": - try: - from xtuner.v1.ops.flash_attn import get_flash_attn_varlen - - get_flash_attn_varlen() 
- except RuntimeError as e: - raise RuntimeError( - f"Flash attention v3 runtime error {e}, Please install it first or set XTUNER_USE_FA3=0." - ) - train_worker_cfg.load_from = load_from - - self._total_epochs = total_epochs - self._cur_step = 0 - self._global_train_step = 1 - - if skip_checkpoint_validation: - patch_default_save_plan() - - self._rl_trainer_cfg = trainer_cfg - self._load_from = Path(load_from) if isinstance(load_from, str) else load_from - - is_hf_path, error_info = is_hf_model_path(load_from) if load_from is not None else False, "" - self._load_from_hf = is_hf_path - - if not self._load_from_hf: - raise NotImplementedError(error_info) - - self._hf_max_keep = hf_max_keep - self._hf_interval = hf_interval - self._checkpoint_interval = checkpoint_interval - self._checkpoint_maxkeep = checkpoint_maxkeep - self._checkpoint_no_save_optimizer = checkpoint_no_save_optimizer - - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True) - - self._debug = debug - self._debug_rollout = debug_rollout - self._seed = seed - self._set_deterministic() - self._set_random_seed(seed) - - if work_dir is None: - work_dir = Path.cwd() / "work_dir" - - if isinstance(work_dir, str): - work_dir = Path(work_dir) - - if get_rank() == 0: - work_dir.mkdir(parents=True, exist_ok=True) - - self._work_dir = work_dir - self._auto_resume = auto_resume - self._meta = self._init_xtuner_meta(work_dir, self._auto_resume) - - if log_dir is None: - log_dir = self.exp_dir - if isinstance(log_dir, str): - log_dir = Path(log_dir) - - self.logger = self._init_logger(log_dir) - - self._load_checkpoint_cfg = self._resolve_load_checkpoint_cfg(self._auto_resume, load_checkpoint_cfg) - - if train_worker_cfg.seed is None: - self.logger.warning(f"RLTrainer seed {seed} is used as train worker seed.") - train_worker_cfg.seed = seed - - train_worker_cfg.log_dir = log_dir - dataflow_config.worker_log_dir = log_dir - rollout_config.worker_log_dir = log_dir - 
self._enable_evaluate = False - self._enable_initial_evaluate = False - if evaluator_config: - evaluator_config.worker_log_dir = log_dir - self._enable_evaluate = evaluator_config.enable_evaluate - self._enable_initial_evaluate = evaluator_config.enable_initial_evaluate - self._pg = AutoAcceleratorWorkers.build_placement_group(resources) - - if cpu_resources is not None: - # NOTE: Here we only check CPU and memory for judger actors because only judger actors use CPU resources currently. - assert judger_config.total_cpus_needed <= cpu_resources.num_cpus_per_worker * cpu_resources.num_workers, ( - f"Not enough CPU resources for judger actors, " - f"required {judger_config.total_cpus_needed}, but got {cpu_resources.num_cpus_per_worker * cpu_resources.num_workers}." - ) - assert ( - judger_config.total_memory_needed <= cpu_resources.cpu_memory_per_worker * cpu_resources.num_workers - ), ( - f"Not enough memory resources for judger actors, " - f"required {judger_config.total_memory_needed}, but got {cpu_resources.cpu_memory_per_worker * cpu_resources.num_workers}." - ) - - self._judger_cpu_pg = placement_group(bundles=judger_config.total_bundles_needed, strategy="SPREAD") - ray.get(self._judger_cpu_pg.ready(), timeout=PG_READY_TIMEOUT) - - # We need to build train controller first, and then build rollout dataflow to make - # inference engines know how much memory they can utilize. 
- self._train_controller = self._build_train_controller(train_worker_cfg) - - if self._load_checkpoint_cfg.checkpoint_path is not None: - rollout_config.skip_load_weights = True - self.logger.info( - f"Skip load rollout weights due to resume from checkpoint {self._load_checkpoint_cfg.checkpoint_path}" - ) - - # resume train worker - ray.get(self._train_controller.resume.remote(self._load_checkpoint_cfg)) - - train_state_path = Path(self._load_checkpoint_cfg.checkpoint_path) / self._SAVE_TRAIN_STATE_PATH - with train_state_path.open("r") as f: - train_state = json.load(f) - self._cur_step = train_state["cur_step"] - - self._rollout_env_controller, self._rollout_dataflow = self._build_rollout_dataflow( - dataflow_cfg=dataflow_config, - rollout_cfg=rollout_config, - judger_cfg=judger_config, - replay_buffer_config=replay_buffer_config, - ) - self._dataflow_partial_rollout_step = dataflow_config.tail_batch_candidate_steps - - if self._load_checkpoint_cfg.checkpoint_path is not None: - # resume rollout dataflow - self.logger.info(f"Resume rollout dataflow from checkpoint {self._load_checkpoint_cfg.checkpoint_path}") - ray.get(self._rollout_dataflow.resume.remote(self._load_checkpoint_cfg.checkpoint_path)) - - if self._enable_evaluate and evaluator_config: - self._evaluator = Evaluator.remote(evaluator_config, self._rollout_env_controller) # type: ignore[attr-defined] - self._eval_step = evaluator_config.evaluate_step - else: - pass - - self._global_batch_size = dataflow_config.global_batch_size - self._rollout_steps = ( - ray.get(self._rollout_dataflow.get_train_dataset_length.remote()) # type: ignore[attr-defined] - // dataflow_config.global_batch_size - * total_epochs - ) - if rollout_steps is not None: - self._rollout_steps = rollout_steps - self.logger.info(f"Set rollout steps to {self._rollout_steps} according to rollout_steps arg") - - bind_train_rollout(train_controller=self._train_controller, env_controller=self._rollout_env_controller) - # update weights if 
rollout_config.skip_load_weights == True - if rollout_config.skip_load_weights: - self.logger.info("Rollout workers skip load weights, update weights from train workers.") - ray.get(self._train_controller.offload.remote(target="optimizer")) - ray.get(self._rollout_env_controller.offload.remote()) - ray.get(self._rollout_env_controller.onload_weights.remote()) - ray.get(self._train_controller.update_weights.remote()) - ray.get(self._train_controller.offload.remote(target="model")) - ray.get(self._rollout_env_controller.onload_kvcache.remote()) - self.logger.info("Rollout workers has updated weights from train workers.") - else: - ray.get(self._train_controller.offload.remote(target="all")) - - self._train_worker_cfg = train_worker_cfg - - if self._rl_trainer_cfg is not None and get_rank() == 0: - config_path = log_dir / "rl_trainer_config.json" - with config_path.open("w") as f: - f.write(self._rl_trainer_cfg.model_dump_json(indent=2)) - - env_path = log_dir / "env.json" - environment_variables = dict(os.environ) - infer_engine_version = get_rollout_engine_version() - environment_variables.update(infer_engine_version) - with env_path.open("w") as f: - json.dump(environment_variables, f, indent=2) - - self._ray_get_timeout = max( - TRAINER_RAY_GET_TIMEOUT, rollout_config.rollout_timeout, judger_config.judger_timeout - ) - self._exp_tracker = self._init_tracker(exp_tracker, log_dir / self._EXP_TRACKING_PATH) - self._display_all_workers_log = display_all_workers_log - - def _resolve_load_checkpoint_cfg( - self, auto_resume: bool, load_checkpoint_cfg: LoadCheckpointConfig - ) -> LoadCheckpointConfig: - # auto_resume优先级高,如果有latest ckp,则说明走auto_resume逻辑 - # 此时,覆盖load checkpoint path - latest_checkpoint = self.meta.latest_exp.latest_checkpoint - if latest_checkpoint is not None and auto_resume: - load_checkpoint_cfg.checkpoint_path = Path(latest_checkpoint) - return load_checkpoint_cfg - - def _init_tracker(self, exp_tracker: Literal["tensorboard", "jsonl"], work_dir: 
Path): - writer = get_writer(writer_type=exp_tracker, log_dir=work_dir) - return writer - - @classmethod - def from_config(cls, config: RLTrainerConfig) -> Self: - """Create a Trainer instance from a TrainerConfig. - - Args: - config (TrainerConfig): TrainerConfig instance containing all configuration parameters. - - Returns: - Self: Trainer instance initialized with the provided config. - """ - self = cls( - load_from=config.load_from, - resources=config.resources, - cpu_resources=config.cpu_resources, - rollout_config=config.rollout_config, - dataflow_config=config.dataflow_config, - judger_config=config.judger_config, - replay_buffer_config=config.replay_buffer_config, - train_worker_cfg=config.train_worker_config, - evaluator_config=config.evaluator_config, - tokenizer_path=config.tokenizer_path, - work_dir=config.work_dir, - log_dir=config.log_dir, - total_epochs=config.total_epochs, - auto_resume=config.auto_resume, - load_checkpoint_cfg=config.load_checkpoint_cfg, - strict_load=config.strict_load, - checkpoint_interval=config.checkpoint_interval, - checkpoint_maxkeep=config.checkpoint_maxkeep, - checkpoint_no_save_optimizer=config.checkpoint_no_save_optimizer, - hf_interval=config.hf_interval, - hf_max_keep=config.hf_max_keep, - skip_checkpoint_validation=config.skip_checkpoint_validation, - seed=config.seed, - debug=config.debug, - debug_rollout=config.debug_rollout, - rollout_steps=config.rollout_steps, - exp_tracker=config.exp_tracker, - trainer_cfg=config, - ) - return self - - def _build_rollout_dataflow( - self, - dataflow_cfg: DataFlowConfig, - rollout_cfg: RolloutConfig, - judger_cfg: JudgerConfig, - replay_buffer_config: ReplayBufferConfig, - ) -> tuple[SingleTurnEnvironmentProxy, DataFlowProxy]: - env = SingleTurnEnvironment.remote("grpo", self._pg, rollout_cfg, self._judger_cpu_pg, judger_cfg) - flow = DataFlow.remote("grpo", dataflow_cfg, replay_buffer_config, env) - return env, flow - - def _build_train_controller(self, train_worker_cfg: 
WorkerConfig) -> TrainingControllerProxy: - TrainingWorker = cast( - TrainingWorkerClass, - ray.remote( - runtime_env={ - "env_vars": { - "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1", - "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES": "1", - "HCCL_NPU_SOCKET_PORT_RANGE": "auto", - } - }, - )(BaseTrainingWorker), - ) - train_workers: list[TrainingWorkerProxy] - train_workers, _ = AutoAcceleratorWorkers.from_placement_group(TrainingWorker, train_worker_cfg, self._pg) - ray.wait([worker.ready.remote() for worker in train_workers]) - train_controller = TrainingController.remote(workers=train_workers) - return train_controller - - def _initial_evaluate(self): - """Performs an initial evaluation before the training loop starts.""" - if self._debug_rollout: - return - if self._enable_initial_evaluate and self._enable_evaluate and self._evaluator: - ray.get(self._rollout_env_controller.update_active_workers.remote()) - scores, eval_data_groups = ray.get(self._evaluator.run.remote(return_samples=True)) - trajectory_save_path = self.exp_dir / "eval_0_trajectory.jsonl" - self._save_trajectories(eval_data_groups, trajectory_save_path) - self.logger.info(f"Initial rollout evaluate scores {scores} and start training") - tb_scores = {f"eval/{k}": v for k, v in scores.items()} - self._exp_tracker.add_scalars( - tag_scalar_dict=tb_scores, - global_step=0, - ) - - def _rollout_step(self, rollout_idx: int, step_timer_dict: dict) -> RolloutInfo: - """Performs a single rollout step to generate experience.""" - with timer("generation", step_timer_dict): - ray.get(self._rollout_env_controller.update_active_workers.remote()) - dataflow_result = ray.get(self._rollout_dataflow.run.remote()) - - with timer("save_trajectory", step_timer_dict): - trajectory_save_path = self.exp_dir / f"rollout_idx_{rollout_idx}_trajectory.jsonl" - self._save_trajectories(dataflow_result["data_groups"], trajectory_save_path) - self.logger.info(f"Rollout_idx {rollout_idx} finished, saved trajectories to 
{trajectory_save_path}") - - if not self._debug_rollout: - with timer("rollout_offload", step_timer_dict): - ray.get(self._rollout_dataflow.pause.remote()) - ray.get(self._rollout_env_controller.offload.remote()) - - rollout_info: RolloutInfo = { - "data_groups": dataflow_result["data_groups"], - "multimodal_train_infos": dataflow_result.get("multimodal_train_infos", None), - "task_time": dataflow_result.get("metrics", {}), - "replay_buffer_info": ray.get(self._rollout_dataflow.get_replaybuffer_status.remote()), - } - return rollout_info - - def _train_step(self, rollout_idx: int, data_groups, multimodal_train_infos, step_timer_dict: dict) -> TrainInfo: - """Performs a single training step on the generated experience.""" - with timer("onload", step_timer_dict): - ray.get(self._train_controller.onload.remote(target="all")) - self.logger.info("Training controller loaded") - - with timer("prepare_data", step_timer_dict): - data_batches, data_info = self._prepare_train_data( - data_groups, self._train_worker_cfg.pack_max_length, multimodal_train_infos - ) - self.logger.info(f"Prepared {len(data_batches)} training data batches") - - with timer("training", step_timer_dict): - workers_log_item: List[WorkerLogItem] = ray.get( - self._train_controller.fit.remote( - data_batches, pack_max_length=self._train_worker_cfg.pack_max_length, rollout_idx=rollout_idx - ) - ) - train_log_info: TrainInfo = { - "data_info": data_info, - "workers_log_item": workers_log_item, - } - return train_log_info - - def _sync_weights_and_save(self, rollout_idx: int, step_timer_dict: dict): - """Synchronizes weights and saves checkpoints.""" - with timer("save_ckpt", step_timer_dict): - ray.get(self._train_controller.offload.remote(target="optimizer")) - self._maybe_save_hf() - self._maybe_save_checkpoint() - - with timer("sync_weight", step_timer_dict): - bind_train_rollout(train_controller=self._train_controller, env_controller=self._rollout_env_controller) - 
ray.get(self._rollout_env_controller.onload_weights.remote()) - ray.get(self._train_controller.update_weights.remote()) - self.logger.info("Model weights synchronized successfully.") - ray.get(self._train_controller.offload.remote(target="model")) - ray.get(self._rollout_env_controller.onload_kvcache.remote()) - - def _evaluate_step(self, rollout_idx: int, step_timer_dict: dict) -> dict[str, float]: - """Performs an evaluation step.""" - eval_log_info = {} - if self._enable_evaluate and self._evaluator and rollout_idx % self._eval_step == 0: - with timer("evaluation", step_timer_dict): - scores, eval_data_groups = ray.get(self._evaluator.run.remote(return_samples=True)) - trajectory_save_path = self.exp_dir / f"eval_{rollout_idx}_trajectory.jsonl" - self._save_trajectories(eval_data_groups, trajectory_save_path) - eval_log_info.update(scores) - return eval_log_info - - def fit(self): - """Run the RL training loop. - - This method executes the main rl training loop, iterating generating through the dataset and performing - training steps. It handles rollout, prepare training data, update policy , synchronize model weights, and - evaluation. - """ - self.logger.info("Start RL training") - if self._cur_step >= self._rollout_steps: - self.logger.info(f"Rollout steps {self._rollout_steps} reached, stop training") - return - - self._initial_evaluate() - - for rollout_idx in range(self._cur_step + 1, self._rollout_steps + 1): - self.logger.info(f"Rollout {rollout_idx}/{self._rollout_steps} start") - step_timer_dict = {} - with timer("step", step_timer_dict): - # 1. Rollout to generate experience - rollout_info = self._rollout_step(rollout_idx, step_timer_dict) - - if not self._debug_rollout: - # 2. Train on the generated experience - train_log_info = self._train_step( - rollout_idx, - rollout_info["data_groups"], - rollout_info["multimodal_train_infos"], - step_timer_dict, - ) - - # 3. 
Synchronize weights and save checkpoints - self._sync_weights_and_save(rollout_idx, step_timer_dict) - - # 4. Evaluate model performance - eval_log_info = self._evaluate_step(rollout_idx, step_timer_dict) - - self._log_step(rollout_idx, step_timer_dict, rollout_info, train_log_info, eval_log_info) - self._cur_step = rollout_idx - - self._exp_tracker.close() - - def _log_step( - self, - rollout_idx: int, - step_timer_dict: dict, - rollout_info: RolloutInfo, - train_info: TrainInfo, - eval_info: dict[str, float], - ): - all_scalars = {} - log_time_str = "" - trajectory_str = "" - eval_str = "" - if step_timer_dict: - all_scalars.update({f"time/{k}": v for k, v in step_timer_dict.items()}) - log_time_str = f"\nRollout {rollout_idx} finished and timing listed:\n" - log_time_str += "\n".join([f" - {k:<25}: {v:.2f}s" for k, v in step_timer_dict.items()]) - - if rollout_info: - all_scalars.update(rollout_info.get("task_time", {})) - all_scalars.update({f"async/{k}": v for k, v in rollout_info.get("replay_buffer_info", {}).items()}) - - if train_info: - all_scalars.update({f"response/{k}": v for k, v in train_info.get("data_info", {}).items()}) - trajectory_str = f"\nRollout {rollout_idx} data statistics:\n" - trajectory_str += "\n".join([f"- {k:<25}: {v:.4f}" for k, v in train_info.get("data_info", {}).items()]) - rank0_log_item = train_info["workers_log_item"][0] - rank0_rollout_is_metrics = rank0_log_item.get("rollout_is_metrics", {}) - rank0_mismatch_metrics = rank0_log_item.get("mismatch_metrics", {}) - rank0_rollout_entropy = rank0_log_item.get("rollout_entropy", 0.0) - all_scalars.update({f"rollout_is/{k}": v for k, v in rank0_rollout_is_metrics.items()}) - all_scalars.update({f"{k}": v for k, v in rank0_mismatch_metrics.items()}) - all_scalars.update({"entropy/rollout": rank0_rollout_entropy}) - all_scalars.update({"entropy/train": rank0_log_item["train_entropy"]}) - for worker_idx, log_item in enumerate(train_info["workers_log_item"]): - if not 
self._display_all_workers_log and worker_idx > 0: - break - mini_batch_metrics: dict[str, List[float]] = {} - for mini_batch_log in log_item["train_metrics"]: - rl_worker_log = mini_batch_log["loss_log"] | mini_batch_log["rl_other_log"] - for k, v in rl_worker_log.items(): - mini_batch_metrics.setdefault(k, []).append(cast(float, v)) - - for key, value in mini_batch_metrics.items(): - avg_value = sum(value) / len(value) - all_scalars.update({f"train_metrics/worker_{worker_idx}/step_avg_{key}": avg_value}) - - rank_sft_log = log_item["sft_train_metrics"] - for k, v in rank_sft_log.items(): - all_scalars.update({f"sft_train_metrics/worker_{worker_idx}/{k}": v}) - - if eval_info: - all_scalars.update({f"eval/{k}": v for k, v in eval_info.items()}) - eval_str = " ".join([f"{k}: {v:.4f}" for k, v in eval_info.items()]) - - self.logger.info(f"Rollout {rollout_idx}/{self._rollout_steps}{log_time_str} {trajectory_str} ") - if eval_str: - self.logger.info(f"Eval: {eval_str}") - self._exp_tracker.add_scalars(tag_scalar_dict=all_scalars, global_step=rollout_idx) - - def _log_mini_batch_metrics(self, workers_log_item: List[WorkerLogItem]): - train_start_step = self._global_train_step - for worker_idx, log_item in enumerate(workers_log_item): - for step_idx, mini_batch_log in enumerate(log_item["train_metrics"]): - if not self._display_all_workers_log and worker_idx > 0: - break - current_global_step = train_start_step + step_idx - metrics = mini_batch_log["loss_log"] | mini_batch_log["rl_other_log"] - - self._exp_tracker.add_scalars( - tag_scalar_dict={f"train_metrics/worker_{worker_idx}/{k}": v for k, v in metrics.items()}, - global_step=current_global_step, - ) - self._global_train_step += len(workers_log_item[0]["train_metrics"]) - - # TODO: advantage 是在 DataFlow 里算好,还是在 train controller 里算? 
- # 因为可能有根据 advantage 来判断数据能否进 rl 训练的需求。暂时先放在这 - def _prepare_train_data(self, data_groups, pack_max_length, multimodal_train_infos=None): - rewards_list = [] - advantages_list = [] - prompt_len_list = [] - response_len_list = [] - - data_batches = [] - is_multimodal = False - if multimodal_train_infos and len(multimodal_train_infos) > 0: - assert len(multimodal_train_infos) == len(data_groups), ( - f"{len(multimodal_train_infos)} vs {len(data_groups)}" - ) - is_multimodal = True - - for j, group in enumerate(data_groups): - if not is_valid_for_training(group): - self.logger.error(f"Skip one data group {group} due to rollout failed or empty response.") - continue - if is_multimodal: - multimodal_train_info = multimodal_train_infos[j] - else: - multimodal_train_info = None - - prompt_ids = group[0].data.extra_info["train_prompt_ids"] - rewards = [data.env.judger.reward["score"] for data in group] - rewards_list.extend(rewards) - rewards = torch.tensor(rewards, dtype=torch.float32) - advantages = (rewards - rewards.mean(0)) / (rewards.std(0) + 1e-8) - - prompt_repeat_k = len(group) - for i in range(prompt_repeat_k): - item = group[i].env.rollout.response - logprobs = None - if group[i].env.rollout.response_ids is not None: - response_ids = group[i].env.rollout.response_ids - if isinstance(response_ids, torch.Tensor): - response_ids = response_ids.flatten().tolist() - logprobs = group[i].env.rollout.logprobs - assert len(logprobs) == len(response_ids), f"{len(logprobs)} vs {len(response_ids)}" - # 只有 response 部分有 logprobs, 需要前面追加 - logprobs = [0] * (len(prompt_ids) - 1) + logprobs - else: - response_ids = self.tokenizer(item, return_tensors="pt")["input_ids"].flatten().tolist() - # 返回的 routed_experts 不包括 eos 的值,实际上也不需要,需要减一 - input_ids = prompt_ids + response_ids[:-1] - - prompt_len_list.append(len(prompt_ids)) - response_len_list.append(len(response_ids)) - advantages_list.extend([advantages[i]] * len(response_ids)) - - shifted_labels = [-100] * (len(prompt_ids) - 1) 
+ response_ids - assert len(input_ids) <= pack_max_length, f"{len(input_ids)} vs {pack_max_length}" - input_ids = torch.tensor(input_ids, dtype=torch.int64).unsqueeze(0) - shifted_labels = torch.tensor(shifted_labels, dtype=torch.int64).unsqueeze(0) - - if logprobs is not None: - rollout_logprobs = torch.tensor(logprobs, dtype=torch.float32).unsqueeze(0) - assert rollout_logprobs.size() == shifted_labels.size(), ( - f"{rollout_logprobs.size()} vs {shifted_labels.size()}" - ) - else: - rollout_logprobs = None - - seq_ctx = get_train_seq_ctx(input_ids, multimodal_train_info, len(response_ids) - 1) - data_dict = { - "seq_ctx": seq_ctx, - "shifted_labels": shifted_labels, - "advantage": advantages[i].item(), - "rollout_logprobs": rollout_logprobs, - } - - if "routed_experts" in group[i].env.rollout.extra_info: - routed_experts = group[i].env.rollout.extra_info["routed_experts"] # n,layer*expert - seq_ctx.rollout_routed_experts = routed_experts # n,layer,expert - - data_batches.append(data_dict) - random.shuffle(data_batches) - - rewards_t = torch.tensor(rewards_list).float() if rewards_list else torch.tensor([0.0]).float() - advantages_t = torch.tensor(advantages_list).float() if advantages_list else torch.tensor([0.0]).float() - prompt_len_t = torch.tensor(prompt_len_list).float() if prompt_len_list else torch.tensor([0.0]).float() - response_len_t = torch.tensor(response_len_list).float() if response_len_list else torch.tensor([0.0]).float() - - info_dict = { - "batch_size": len(rewards_list), - "rewards/mean": rewards_t.mean().item(), - "rewards/min": rewards_t.min().item(), - "rewards/max": rewards_t.max().item(), - "advantages/mean": advantages_t.mean().item(), - "advantages/min": advantages_t.min().item(), - "advantages/max": advantages_t.max().item(), - "response_len/mean": response_len_t.mean().item(), - "response_len/min": response_len_t.min().item(), - "response_len/max": response_len_t.max().item(), - "response_len/std": response_len_t.std().item(), - 
"prompt_len/mean": prompt_len_t.mean().item(), - "prompt_len/min": prompt_len_t.min().item(), - "prompt_len/max": prompt_len_t.max().item(), - } - return data_batches, info_dict - - def _save_trajectories(self, data_groups, save_path): - rewards = [] - - rollout_response_len_list = [] - version_dict = {i: 0 for i in range(self._dataflow_partial_rollout_step + 1)} - - # NOTE: Since we currently default to token-in token-out, the code for checking whether response_ids have Retokenization Drift is commented out. - # If you need to debug, you can uncomment it. - # mismatch_token_ids_count = 0 - # response_len_list = [] - for group in data_groups: - if not is_valid_for_training(group): - self.logger.error(f"Skip one data group {group} due to rollout failed or empty response.") - continue - for data in group: - rewards.append(data.env.judger.reward["score"]) - if data.env.rollout.response_ids is not None: - if isinstance(data.env.rollout.response_ids, torch.Tensor): - response_ids = data.env.rollout.response_ids.flatten().tolist() - else: - response_ids = data.env.rollout.response_ids - rollout_response_len_list.append(len(response_ids)) - # response_str = self.tokenizer.decode(response_ids, skip_special_tokens=False) - # revert_encode_response_ids = self.tokenizer.encode(response_str, add_special_tokens=False) - - # response_str_to_ids = self.tokenizer.encode(data.env.rollout.response, add_special_tokens=False) - # response_len_list.append(len(response_str_to_ids)) - - # if response_ids != revert_encode_response_ids or response_ids != response_str_to_ids: - # mismatch_token_ids_count += 1 - else: - response_ids = self.tokenizer.encode(data.env.rollout.response, add_special_tokens=False) - rollout_response_len_list.append(len(response_ids)) - - version = data.uid.version - if version not in version_dict: - version_dict[version] = 0 - version_dict[version] += 1 - - rewards_tensor = torch.tensor(rewards).float() - rollout_response_lens: torch.Tensor = 
torch.tensor([0.0]).float() - if len(rollout_response_len_list) > 0: - rollout_response_lens = torch.tensor(rollout_response_len_list).float() - - _count = 0 - with open(save_path, "w", encoding="utf-8") as f: - item = { - "reward_mean": rewards_tensor.mean().item(), - "reward_std": rewards_tensor.std().item(), - "reward_max": rewards_tensor.max().item(), - "reward_min": rewards_tensor.min().item(), - "response_len_mean": rollout_response_lens.mean().item(), - "response_len_std": rollout_response_lens.std().item(), - "response_len_max": rollout_response_lens.max().item(), - "response_len_min": rollout_response_lens.min().item(), - "total_len": len(rewards), - "versions": version_dict, - # "mismatch_token_ids_count": mismatch_token_ids_count, - } - json.dump(item, f, ensure_ascii=False, indent=2) - f.write("\n") - for group in data_groups: - if not is_valid_for_training(group): - self.logger.error(f"Skip one data group {group} due to rollout failed or empty response.") - continue - for data in group: - item = { - "action_id": data.uid.action_id, - "prompt": data.data.extra_info["raw_prompt"], - "response": data.env.rollout.response, - "versioned_response": data.env.rollout.versioned_response, - # "response_ids": str(data.env.rollout.response_ids), - # "versioned_response_ids": str(data.env.rollout.versioned_response_ids), - "response_len": rollout_response_len_list[_count], - "versioned_response_len": data.env.rollout.versioned_num_return_tokens, - "label": data.data.reward_model["ground_truth"], - "reward": data.env.judger.reward["score"], - "version": data.uid.version, - "finish_reason": data.env.rollout.finish_reason, - } - json.dump(item, f, ensure_ascii=False, indent=2) - f.write("\n") - _count += 1 - - def _load_trajectories(self, save_path): - data_groups = [] - with open(save_path) as f: - for line in f: - item = json.loads(line) - messages = item["messages"] - responses = item["response"] - rewards = item["reward"] - group = [] - for response, reward in 
zip(responses, rewards): - group.append( - { - "messages": messages, - "response_str": response, - "reward": reward, - } - ) - data_groups.append(group) - return data_groups - - def _compute_metrics(self, data_groups): - correctness = [1 if data[0]["reward"] > 0 else 0 for data in data_groups] - acc = sum(correctness) / len(correctness) - return acc - - def _maybe_save_hf(self): - if self._hf_interval is None: - return - - assert self._load_from_hf, ( - "Only support saving to Huggingface format when loading from Huggingface! " - "You meet this error means `load_from` of trainer is not a Huggingface model path." - ) - - if (self.cur_step + 1) % self._hf_interval != 0 and (self.cur_step + 1) != self._rollout_steps: - return - - save_hf_path = self.exp_dir / f"hf-{self.cur_step + 1}" - self.logger.info(f"Saving step {self.cur_step + 1} hf checkpoints to: {save_hf_path}") - self.meta.latest_exp.hf_checkpoint_list.append(str(save_hf_path)) - - if self._hf_max_keep is not None and len(self.meta.latest_exp.hf_checkpoint_list) > self._hf_max_keep: - deleted_hf_checkpoints = self.meta.latest_exp.hf_checkpoint_list[: -self._hf_max_keep] - self.meta.latest_exp.hf_checkpoint_list = self.meta.latest_exp.hf_checkpoint_list[-self._hf_max_keep :] - for hf_dir in deleted_hf_checkpoints: - rmtree(hf_dir) - - ray.get(self._train_controller.save_hf.remote(str(save_hf_path)), timeout=self._ray_get_timeout) - if isinstance(self.tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)): - self.tokenizer.save_pretrained(str(save_hf_path)) - - def _maybe_save_checkpoint(self): - ckp_interval = self._checkpoint_interval - if ckp_interval is None: - return - - if ckp_interval == -1: - return - else: - if (self.cur_step + 1) % ckp_interval != 0 or (self.cur_step + 1) == self._rollout_steps: - return - - checkpoint_path = self.exp_dir / self._CHECKPOINT_DIR / f"ckpt-step-{self.cur_step + 1}" - checkpoint_path.mkdir(parents=True, exist_ok=True) - - self.logger.info(f"Saving step 
{self.cur_step + 1} rollout dataflow to: {checkpoint_path}") - ray.get(self._rollout_dataflow.save.remote(str(checkpoint_path)), timeout=self._ray_get_timeout) - self.logger.info(f"Saving step {self.cur_step + 1} dcp checkpoints to: {checkpoint_path}") - ray.get( - self._train_controller.save.remote(str(checkpoint_path), self._checkpoint_no_save_optimizer), - timeout=self._ray_get_timeout, - ) - - # Update meta - current_exp = self.meta.latest_exp - ckp_list = current_exp.checkpoint_list - ckp_list.append(str(checkpoint_path)) - current_exp.cur_step = self.cur_step + 1 - current_exp.history[-1]["end"] = self.cur_step + 1 - - train_state_path = checkpoint_path / self._SAVE_TRAIN_STATE_PATH - with train_state_path.open("w") as f: - f.write( - json.dumps( - { - "cur_step": self.cur_step + 1, - } - ) - ) - - # Delete checkpoints and update meta's checkpoint_list - ckp_maxkeep = self._checkpoint_maxkeep - if ckp_maxkeep is not None and ckp_maxkeep > 0 and len(ckp_list) > ckp_maxkeep: - ckp_pop_num = len(ckp_list) - ckp_maxkeep - for _ in range(ckp_pop_num): - deleted_ckp = ckp_list.pop(0) - if Path(deleted_ckp).exists(): - rmtree(deleted_ckp, ignore_errors=True) - - meta_path = self.work_dir / self.META_PATH - with meta_path.open("w") as f: - f.write(self.meta.model_dump_json(indent=2)) - - def _init_logger(self, work_dir: Path): - # Logging system maybe need better design - logger = get_logger(log_dir=work_dir, tag="RLTrainer") - return logger - - def _set_deterministic(self): - if XTUNER_DETERMINISTIC: - os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" - torch.use_deterministic_algorithms(True, warn_only=True) - - def _set_random_seed(self, seed: int): - set_random_seed(seed) - - def _init_xtuner_meta(self, work_dir: Path, resume: bool) -> XTunerMeta: - if not work_dir.exists(): - work_dir.mkdir(parents=True, exist_ok=True) - - meta_path = work_dir / self.META_PATH - if not meta_path.exists(): - meta = XTunerMeta(exps=[]) - with open(meta_path, "w") as f: - 
f.write(meta.model_dump_json(indent=2)) - - meta = cast(XTunerMeta, XTunerMeta.model_validate(load(meta_path, file_format="json"))) - - resume = resume and bool(meta.exps) - - if resume and meta.exps: - latest_exp = meta.exps[-1] - latest_exp_history = latest_exp.history[-1] - - begin = cast(int, latest_exp_history.get("end") or latest_exp_history["begin"]) - exp_dir = Path(latest_exp.exp_dir) - git_dir = exp_dir / f"git-info-begin-{begin}" - - if not git_dir: - git_dir.mkdir(parents=True, exist_ok=True) - - staged_path, unstaged_path = git_dir / "staged.diff", git_dir / "unstaged.diff" - - if not git_dir.exists(): - git_dir.mkdir(parents=True, exist_ok=True) - commit = record_git_info(staged_path, unstaged_path) - git_info = GitInfo( - commit=commit, - staged=str(staged_path), - unstaged=str(unstaged_path), - ) - - timestamp = datetime.now().strftime("%Y%m%d%H%M%S") - new_exp_history = ExpHistory( - begin=begin, - timestamp=timestamp, - git_info=git_info, - ) - latest_exp.history.append(new_exp_history) - else: - timestamp = datetime.now().strftime("%Y%m%d%H%M%S") - exp_dir = work_dir / timestamp - git_dir = Path(f"{exp_dir}/git-info-begin-{0}") - - if not git_dir.exists(): - git_dir.mkdir(parents=True, exist_ok=True) - - staged_path, unstaged_path = git_dir / "staged.diff", git_dir / "unstaged.diff" - commit = record_git_info(staged_path, unstaged_path) - git_info = GitInfo( - commit=commit, - staged=str(staged_path), - unstaged=str(unstaged_path), - ) - - new_history = ExpHistory( - begin=0, - timestamp=timestamp, - git_info=git_info, - ) - new_exp = ExpInfo(history=[new_history], exp_dir=str(exp_dir)) - meta.exps.append(new_exp) - return meta - - @property - def work_dir(self) -> Path: - return self._work_dir - - @property - def exp_dir(self) -> Path: - return Path(self._meta.latest_exp.exp_dir) - - @property - def meta(self) -> XTunerMeta: - return self._meta - - @property - def cur_step(self): - return self._cur_step - - @property - def total_epoch(self): - 
return self._total_epochs - - @property - def rollout_steps(self): - return self._rollout_steps diff --git a/xtuner/v1/train/trainer.py b/xtuner/v1/train/trainer.py index 8f55f8c523..a30ab71032 100644 --- a/xtuner/v1/train/trainer.py +++ b/xtuner/v1/train/trainer.py @@ -162,6 +162,70 @@ def get_exp_by_checkpoint(self, checkpoint: str) -> ExpInfo | None: return exp return None + @classmethod + def build(cls, work_dir: Path, meta_filename: str, resume: bool) -> "XTunerMeta": + """Create or load meta from work_dir and optionally start a new exp or + resume. + + Single-process helper (e.g. for rl_colocate_trainer). For distributed training use the trainer's + _init_xtuner_meta. + """ + if not work_dir.exists(): + work_dir.mkdir(parents=True, exist_ok=True) + + meta_path = work_dir / meta_filename + if not meta_path.exists(): + meta = cls(exps=[]) + with open(meta_path, "w") as f: + f.write(meta.model_dump_json(indent=2)) + + meta = cast(XTunerMeta, cls.model_validate(load(meta_path, file_format="json"))) + resume = resume and bool(meta.exps) + + if resume and meta.exps: + latest_exp = meta.exps[-1] + latest_exp_history = latest_exp.history[-1] + begin = cast(int, latest_exp_history.get("end") or latest_exp_history["begin"]) + exp_dir = Path(latest_exp.exp_dir) + git_dir = exp_dir / f"git-info-begin-{begin}" + if not git_dir.exists(): + git_dir.mkdir(parents=True, exist_ok=True) + staged_path, unstaged_path = git_dir / "staged.diff", git_dir / "unstaged.diff" + commit = record_git_info(staged_path, unstaged_path) + git_info = GitInfo( + commit=commit, + staged=str(staged_path), + unstaged=str(unstaged_path), + ) + timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + new_exp_history = ExpHistory( + begin=begin, + timestamp=timestamp, + git_info=git_info, + ) + latest_exp.history.append(new_exp_history) + else: + timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + exp_dir = work_dir / timestamp + git_dir = Path(f"{exp_dir}/git-info-begin-0") + if not git_dir.exists(): 
+ git_dir.mkdir(parents=True, exist_ok=True) + staged_path, unstaged_path = git_dir / "staged.diff", git_dir / "unstaged.diff" + commit = record_git_info(staged_path, unstaged_path) + git_info = GitInfo( + commit=commit, + staged=str(staged_path), + unstaged=str(unstaged_path), + ) + new_history = ExpHistory( + begin=0, + timestamp=timestamp, + git_info=git_info, + ) + new_exp = ExpInfo(history=[new_history], exp_dir=str(exp_dir)) + meta.exps.append(new_exp) + return meta + class ResumeConfig(BaseModel): model_config = ConfigDict(extra="forbid") @@ -1332,6 +1396,7 @@ def _init_dist(self, backend: str | None = None): dist.all_reduce(warmup_tensor) def _init_xtuner_meta(self, work_dir: Path, auto_resume: bool) -> XTunerMeta: + # TODO: simplify with XTunerMeta.build() of dist version if not work_dir.exists(): if self.rank == 0: work_dir.mkdir(parents=True, exist_ok=True) diff --git a/xtuner/v1/utils/__init__.py b/xtuner/v1/utils/__init__.py index 6918937839..1704607a2b 100644 --- a/xtuner/v1/utils/__init__.py +++ b/xtuner/v1/utils/__init__.py @@ -1,3 +1,4 @@ +from .cache import CacheDict, CacheObj from .compile import maybe_compile from .config import Config from .device import get_device, get_torch_device_module @@ -16,6 +17,7 @@ get_padding_length, is_hf_model_path, record_git_info, + set_deterministic, ) from .pad import pad_to_max_length, pad_to_multiple_of from .profile import profile_time, profile_time_and_memory, timer, timer_logger @@ -59,4 +61,7 @@ "ray_method", "profile_time", "clean_param_name", + "CacheDict", + "CacheObj", + "set_deterministic", ] diff --git a/xtuner/v1/utils/cache.py b/xtuner/v1/utils/cache.py new file mode 100644 index 0000000000..2106d33172 --- /dev/null +++ b/xtuner/v1/utils/cache.py @@ -0,0 +1,13 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+"""Minimal cache types shared by data_proto and datasets to avoid circular +imports.""" + +from typing_extensions import TypedDict + + +class CacheDict(TypedDict, total=False): + num_tokens: int + + +class CacheObj: + num_tokens: int | None = None diff --git a/xtuner/v1/utils/convert_gsm8k_with_tool.py b/xtuner/v1/utils/convert_gsm8k_with_tool.py new file mode 100644 index 0000000000..d5ba33f403 --- /dev/null +++ b/xtuner/v1/utils/convert_gsm8k_with_tool.py @@ -0,0 +1,87 @@ +"""Preprocess the GSM8k dataset to parquet format.""" + +import argparse +import os +import re + +import datasets + + +def extract_solution(solution_str): + solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str) + assert solution is not None + final_solution = solution.group(0) + final_solution = final_solution.split("#### ")[1].replace(",", "") + return final_solution + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input-dir", default="openai/gsm8k") + parser.add_argument("--out-dir") + + args = parser.parse_args() + + dataset = datasets.load_dataset(args.input_dir, "default") + + train_dataset = dataset["train"] + test_dataset = dataset["test"] + + instruction_following = 'Let\'s think step by step and output the final answer after "####".' + + # add a row to each data item that represents a unique id + # Adapted from https://github.com/verl-project/verl/blob/c37d4d53850906aced4c071666340ec26966d707/examples/data_preprocess/gsm8k_tool_agent_loop.py#L62 + def make_map_fn(split): + def process_fn(example, idx): + question_raw = example.pop("question") + + question = question_raw + " " + instruction_following + + answer_raw = example.pop("answer") + solution = extract_solution(answer_raw) + data = { + "data_source": "openai/gsm8k", + "agent_name": "tool_agent", + "prompt": [ + { + "role": "system", + "content": ( + "You are a math expert. You are given a question and you need to solve it step by step. 
" + "Reasoning step by step before any tool call. " + "You should use the `calc_gsm8k_reward` tool after step by step solving the question, " + "before generate final answer at least once and refine your answer if necessary. " + "Put your final answer in the format of `#### `." + ), + }, + { + "role": "user", + "content": question, + }, + ], + "ability": "math", + "reward_model": {"style": "rule", "ground_truth": solution}, + "extra_info": { + "split": split, + "index": idx, + "answer": answer_raw, + "question": question_raw, + "need_tools_kwargs": True, + "tools_kwargs": { + "calc_gsm8k_reward": { + "create_kwargs": {"ground_truth": solution}, + }, + }, + }, + } + return data + + return process_fn + + train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True) + test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True) + + out_dir = args.out_dir + + os.makedirs(out_dir, exist_ok=True) + train_dataset.to_json(os.path.join(out_dir, "train.jsonl"), orient="records", lines=True) + test_dataset.to_json(os.path.join(out_dir, "test.jsonl"), orient="records", lines=True) diff --git a/xtuner/v1/utils/misc.py b/xtuner/v1/utils/misc.py index e9aaf82bb5..44f7339a2a 100644 --- a/xtuner/v1/utils/misc.py +++ b/xtuner/v1/utils/misc.py @@ -9,6 +9,7 @@ from types import FunctionType from typing import Annotated +import torch from huggingface_hub import constants from mmengine import is_installed @@ -24,6 +25,13 @@ logger = get_logger() XTUNER_DETERMINISTIC = os.getenv("XTUNER_DETERMINISTIC") == "true" + +def set_deterministic(): + if XTUNER_DETERMINISTIC: + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" + torch.use_deterministic_algorithms(True, warn_only=True) + + # https://github.com/python/cpython/issues/82300#issuecomment-2169035092 if sys.version_info >= (3, 13): SharedMemory = _mpshm.SharedMemory diff --git a/xtuner/v1/utils/processing_utils.py b/xtuner/v1/utils/processing_utils.py new file mode 100644 index 
0000000000..4e378be9ca --- /dev/null +++ b/xtuner/v1/utils/processing_utils.py @@ -0,0 +1,23 @@ +from transformers import AutoProcessor, AutoTokenizer, PreTrainedTokenizerBase, ProcessorMixin + +from .logger import get_logger + + +logger = get_logger() + + +def load_tokenizer(name_or_path: str, **kwargs): + return AutoTokenizer.from_pretrained(name_or_path, **kwargs) + + +def load_processor(name_or_path: str, **kwargs): + try: + proc = AutoProcessor.from_pretrained(name_or_path, **kwargs) + except (OSError, ValueError) as e: + logger.warning(f"Failed to load processor from {name_or_path}: {e}") + proc = None + + if isinstance(proc, PreTrainedTokenizerBase) or not isinstance(proc, ProcessorMixin): + proc = None + + return proc diff --git a/xtuner/v1/utils/rl_test_utils.py b/xtuner/v1/utils/rl_test_utils.py index b4677b5638..d6fec9c82a 100644 --- a/xtuner/v1/utils/rl_test_utils.py +++ b/xtuner/v1/utils/rl_test_utils.py @@ -2,82 +2,14 @@ import multiprocessing import os import time -from typing import Any, Callable, Dict, List +from typing import Any, Dict, List -import httpx import requests import uvicorn from fastapi import FastAPI from pydantic import BaseModel, ConfigDict, Field -from xtuner.v1.ray.judger.native import NativeJudgerConfig - -# try: -from xtuner.v1.ray.rollout.lmdeploy import LMDeployWorker -from xtuner.v1.utils.httpx_utils import HttpRequestErrorType, HttpRequestResult - - -# except ImportError: -# LMDeployWorker = object -class MockTimeoutRolloutWorker(LMDeployWorker): - async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult: - try: - raise httpx.TimeoutException("Mocked timeout error") - except Exception as e: - error_type = HttpRequestErrorType.from_exception(e) - result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload) - self.logger.info(f"Caught mocked timeout exception: {e.__class__.__name__}") - return result - - def launch_server(self): - pass # Override - - -class 
MockRequestErrorRolloutWorker(LMDeployWorker): - async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult: - try: - req = httpx.Request("POST", url) - raise httpx.RequestError("Mocked httpx request error", request=req) - except Exception as e: - error_type = HttpRequestErrorType.from_exception(e) - result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload) - self.logger.info(f"Caught mocked request error exception: {e.__class__.__name__}") - return result - - def launch_server(self): - pass # Override - - -class MockClientErrorRolloutWorker(LMDeployWorker): - async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult: - try: - req = httpx.Request("POST", url) - res = httpx.Response(400, request=req) - raise httpx.HTTPStatusError("Mocked client error", request=req, response=res) - except Exception as e: - error_type = HttpRequestErrorType.from_exception(e) - result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload) - self.logger.info(f"Caught mocked client exception: {e.__class__.__name__}") - return result - - def launch_server(self): - pass # Override - - -class MockServerErrorRolloutWorker(LMDeployWorker): - async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult: - try: - req = httpx.Request("POST", url) - res = httpx.Response(500, request=req) - raise httpx.HTTPStatusError("Mocked server error", request=req, response=res) - except Exception as e: - error_type = HttpRequestErrorType.from_exception(e) - result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload) - self.logger.info(f"Caught mocked server exception: {e.__class__.__name__}") - return result - - def launch_server(self): - pass # Override +from xtuner.v1.rl.judger.native import JudgerConfig app = FastAPI() @@ -113,7 +45,7 @@ class JudgeResponse(BaseModel): @app.post("/judge", response_model=JudgeResponse) async def judge(request: JudgeRequest): - 
from xtuner.v1.ray.judger.gsm8k import compute_reward + from xtuner.v1.rl.judger.gsm8k import compute_reward """Endpoint to compute reward for a given response and label.""" # The compute_reward function returns a float, we wrap it in a dict @@ -158,17 +90,7 @@ def stop(self): print("Server stopped.") -def custom_postprocessor_for_gsm8k(result): - from xtuner.v1.data_proto.rl_data import RLJudgerResponseItem - - if not isinstance(result, list): - result = [result] - judger_response_item = [RLJudgerResponseItem(uid=result[i]["uid"], reward=result[i]) for i in range(len(result))] - return judger_response_item - - -class GSM8KRemoteJudgerConfig(NativeJudgerConfig): +class GSM8KRemoteJudgerConfig(JudgerConfig): judger_name: str - remote_url: str - extra_info: dict = {"score": 1, "format_score": 0} - postprocess_func: Callable = custom_postprocessor_for_gsm8k + reward_handler: str + extra_info: dict = Field(default_factory=lambda: {"score": 1, "format_score": 0}) diff --git a/xtuner/v1/utils/type_helper.py b/xtuner/v1/utils/type_helper.py index cca7fcf7de..3b86ebc3fd 100644 --- a/xtuner/v1/utils/type_helper.py +++ b/xtuner/v1/utils/type_helper.py @@ -40,7 +40,21 @@ def ray_method(f: Callable[Concatenate[C, P], Awaitable[T]]) -> RemoteMethod[P, def ray_method(f: Callable[Concatenate[C, P], T]) -> RemoteMethod[P, T]: ... -def ray_method(f): +def ray_method(f=None, *, num_returns=1, concurrency_group=None): + """Decorator for Ray actor methods. + + Compatible with Ray versions that require at least one of num_returns or concurrency_group. Ray.method() must be + called with keyword args only, then applied to the function: ray.method(num_returns=1)(f). + """ import ray - return ray.method(f) # type: ignore[ret-type] + kwargs = {"num_returns": num_returns} + if concurrency_group is not None: + kwargs["concurrency_group"] = concurrency_group + + if f is None: + # Called as @ray_method(num_returns=...) or @ray_method(concurrency_group=...) 
+ return lambda fn: ray.method(**kwargs)(fn) + + # Called as @ray_method + return ray.method(**kwargs)(f) # type: ignore[ret-type]