diff --git a/docker/patch/latest/sglang.patch b/docker/patch/latest/sglang.patch index b0fefdb7f..357b0b0d6 100644 --- a/docker/patch/latest/sglang.patch +++ b/docker/patch/latest/sglang.patch @@ -973,7 +973,7 @@ index 00bd68755..5a3ca8a67 100644 def get_routed_experts( diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py -index 4cbfed6f9..cd6c825f6 100644 +index 4cbfed6f9..88b452744 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py @@ -499,7 +499,7 @@ class CompressedTensorsConfig(QuantizationConfig): @@ -985,6 +985,16 @@ index 4cbfed6f9..cd6c825f6 100644 def _is_mxint4a16(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool: input_quant_none = input_quant is None +@@ -968,6 +968,9 @@ class CompressedTensorsFusedMoEMethod(FusedMoEMethodBase): + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.scheme.process_weights_after_loading(layer) + ++ def restore_weights_before_loading(self, layer: torch.nn.Module) -> None: ++ layer.scheme.restore_weights_before_loading(layer) ++ + def create_weights( + self, + layer: torch.nn.Module, diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16_moe.py index 6264f36d0..bef31a374 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16_moe.py @@ -2438,3 +2448,17 @@ index 4636128fa..a9b61df39 100644 } DENY_CLASSES = { +diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py +index 3be16446e..1b2371c83 100644 +--- a/python/sglang/srt/utils/weight_checker.py ++++ b/python/sglang/srt/utils/weight_checker.py +@@ -69,6 +69,9 @@ def _check_tensors( + actual_should_compare, + actual, + ) in zip(expect_tensors, actual_tensors, strict=True): ++ if ".cos_sin_cache" in expect_name: ++ # skip cos/sin cache which is deterministic from shape and dtype and may have different shapes due to different implementations. ++ continue + assert expect_name == actual_name, f"{expect_name=} {actual_name=}" + assert ( + expect_should_compare == actual_should_compare diff --git a/docs/en/developer_guide/debug.md b/docs/en/developer_guide/debug.md index 535a36d0c..4f8218a17 100644 --- a/docs/en/developer_guide/debug.md +++ b/docs/en/developer_guide/debug.md @@ -50,6 +50,48 @@ Specifically, slime currently provides the following parameters for separate deb When enabled, data will be loaded from `args.load_debug_rollout_data.format(rollout_id=rollout_id)`, and SGLang will not be initialized (automatically setting `debug_train_only=True`). This method allows you to fix the input for the training part to tune it, for example, by switching between different parallelization strategies. +## INT4 / Compressed-Tensors Quantization Checkpoint Issues + +When using INT4-quantized models (e.g., `compressed-tensors` with `W4A16`), the checkpoint's `config.json` contains a `quantization_config.ignore` list that specifies which parameters should **not** be quantized. During online weight updates (Megatron → SGLang), slime also reads this ignore list to decide which parameters to INT4-quantize. An incorrect ignore list can cause silent errors: + +1. **MoE router weights (`mlp.gate.weight`) become all zeros** + + The MoE router weight (`mlp.gate.weight`, shape `[num_experts, hidden_size]`) is a plain 2D weight tensor, but it is **not** a Linear layer weight. If it is not in the ignore list, the online quantizer will INT4-quantize it into `weight_packed`, `weight_scale`, `weight_zero_point`, etc. However, SGLang does not expect quantized names for the router, so these parameters are silently skipped during `load_weights`, resulting in all-zero gate weights. + + **Fix**: Ensure `config.json` contains `"re:.*mlp\\.gate\\..*"` in the ignore list. + +2. **Other non-Linear 2D weights** + + Similar issues can occur with any 2D `.weight` tensor that is not a true Linear layer, such as `model.embed_tokens.weight`. Always verify the ignore list covers all non-Linear weights. + + **Recommended ignore patterns** (for GLM-style MoE models): + ```json + "ignore": [ + "lm_head", + "model.embed_tokens.weight", + "re:.*self_attn.*", + "re:.*mlp\\.shared_experts.*", + "re:.*mlp\\.gate_up_proj.*", + "re:.*mlp\\.gate_proj.*", + "re:.*mlp\\.up_proj.*", + "re:.*mlp\\.down_proj.*", + "re:.*eh_proj.*", + "re:.*mlp\\.gate\\..*" + ] + ``` + +3. **Missing safetensors shards** + + Conversion tools may occasionally produce an incomplete checkpoint (e.g., a missing `model-00010-of-00093.safetensors`). After conversion, always verify: + - The number of `.safetensors` files matches the expected count. + - The `model.safetensors.index.json` contains entries for every layer. + - Spot-check that critical layers (e.g., the first MoE layer) have the expected number of keys. + +4. **How to diagnose** + + - Use `--check-weight-update-equal` to verify that weights after a Megatron → SGLang sync match the expected values. If a parameter shows all zeros on the SGLang side, it was likely incorrectly quantized or missing from the checkpoint. + - Use `--debug-rollout-only` with a small number of GPUs to quickly test whether SGLang can generate coherent text from the quantized checkpoint alone. + ## Debug sglang illegal memory access (IMA) When running large scale RL, we will occationally meet the IMA in SGLang, there are some debug suggestions based on our experience: diff --git a/docs/zh/developer_guide/debug.md b/docs/zh/developer_guide/debug.md index f243d054b..d964a4e82 100644 --- a/docs/zh/developer_guide/debug.md +++ b/docs/zh/developer_guide/debug.md @@ -48,6 +48,48 @@ slime 支持将训练部分和推理部分分开进行调试,从而实现: 开启后,会从 `args.load_debug_rollout_data.format(rollout_id=rollout_id)` 来加载数据,并且不会初始化 sglang(自动设置 `debug_train_only=True`)。可以以这种方式来固定训练部分的输入,对训练部分进行调优,例如切换各种并行。 +## INT4 / Compressed-Tensors 量化 Checkpoint 问题 + +使用 INT4 量化模型(如 `compressed-tensors` 的 `W4A16`)时,checkpoint 的 `config.json` 中有一个 `quantization_config.ignore` 列表,指定哪些参数**不**做量化。在线权重更新(Megatron → SGLang)时,slime 也会读取这个 ignore list 来决定哪些参数需要 INT4 量化。ignore list 不正确会导致静默错误: + +1. **MoE 路由权重(`mlp.gate.weight`)变成全零** + + MoE 的路由权重(`mlp.gate.weight`,shape `[num_experts, hidden_size]`)是一个普通的 2D weight tensor,但它**不是** Linear 层的权重。如果它不在 ignore list 中,在线量化器会把它 INT4 量化为 `weight_packed`、`weight_scale`、`weight_zero_point` 等。然而 SGLang 不会以量化名称来加载路由权重,因此这些参数在 `load_weights` 时被静默跳过,导致 gate 权重全零。 + + **修复方法**:确保 `config.json` 的 ignore list 中包含 `"re:.*mlp\\.gate\\..*"`。 + +2. **其他非 Linear 的 2D 权重** + + 类似问题可能出现在任何不是真正 Linear 层的 2D `.weight` tensor 上,例如 `model.embed_tokens.weight`。务必检查 ignore list 覆盖了所有非 Linear 权重。 + + **推荐的 ignore 配置**(以 GLM 系 MoE 模型为例): + ```json + "ignore": [ + "lm_head", + "model.embed_tokens.weight", + "re:.*self_attn.*", + "re:.*mlp\\.shared_experts.*", + "re:.*mlp\\.gate_up_proj.*", + "re:.*mlp\\.gate_proj.*", + "re:.*mlp\\.up_proj.*", + "re:.*mlp\\.down_proj.*", + "re:.*eh_proj.*", + "re:.*mlp\\.gate\\..*" + ] + ``` + +3. **safetensors 分片缺失** + + 转换工具偶尔可能产出不完整的 checkpoint(例如缺少 `model-00010-of-00093.safetensors`)。转换完成后,务必检查: + - `.safetensors` 文件数量是否与预期一致。 + - `model.safetensors.index.json` 中是否包含所有 layer 的条目。 + - 抽查关键 layer(如第一个 MoE layer)的 key 数量是否正确。 + +4. **如何排查** + + - 使用 `--check-weight-update-equal` 验证 Megatron → SGLang 权重同步后的值是否正确。如果某个参数在 SGLang 侧全为零,说明它可能被错误量化或在 checkpoint 中缺失。 + - 使用 `--debug-rollout-only` 配合少量 GPU,快速测试 SGLang 能否从量化 checkpoint 正常生成文本。 + ## Debug sglang illegal memory access (IMA) 在进行大规模 RL 时,不时会遇到 SGLang IMA 的问题,以下是我们的一些 debug 建议: diff --git a/slime/utils/debug_utils/__init__.py b/slime/utils/debug_utils/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/slime/utils/debug_utils/display_debug_rollout_data.py b/slime/utils/debug_utils/display_debug_rollout_data.py deleted file mode 100644 index a954e679a..000000000 --- a/slime/utils/debug_utils/display_debug_rollout_data.py +++ /dev/null @@ -1,73 +0,0 @@ -import json -from pathlib import Path -from types import SimpleNamespace -from typing import Annotated - -import torch -import typer - -from slime.ray.rollout import compute_perf_metrics_from_samples -from slime.utils.types import Sample - -_WHITELIST_KEYS = [ - "group_index", - "index", - "prompt", - "response", - "response_length", - "label", - "reward", - "status", - "metadata", -] - - -def main( - # Deliberately make this name consistent with main training arguments - load_debug_rollout_data: Annotated[str, typer.Option()], - show_metrics: bool = True, - show_samples: bool = True, - category: list[str] = None, -): - if category is None: - category = ["train", "eval"] - for rollout_id, path in _get_rollout_dump_paths(load_debug_rollout_data, category): - print("-" * 80) - print(f"{rollout_id=} {path=}") - print("-" * 80) - - pack = torch.load(path) - sample_dicts = pack["samples"] - - if show_metrics: - # TODO read these configs from dumps - args = SimpleNamespace( - advantage_estimator="grpo", - reward_key=None, - log_reward_category=None, - ) - sample_objects = [Sample.from_dict(s) for s in sample_dicts] - metrics = compute_perf_metrics_from_samples(args, sample_objects) - print("metrics", metrics) - - if show_samples: - for sample in sample_dicts: - print(json.dumps({k: v for k, v in sample.items() if k in _WHITELIST_KEYS})) - - -def _get_rollout_dump_paths(load_debug_rollout_data: str, categories: list[str]): - # may improve later - for rollout_id in range(1000): - for category in categories: - prefix = { - "train": "", - "eval": "eval_", - }[category] - path = Path(load_debug_rollout_data.format(rollout_id=f"{prefix}{rollout_id}")) - if path.exists(): - yield rollout_id, path - - -if __name__ == "__main__": - """python -m slime.utils.debug_utils.display_debug_rollout_data --load-debug-rollout-data ...""" - typer.run(main) diff --git a/slime/utils/debug_utils/replay_reward_fn.py b/slime/utils/debug_utils/replay_reward_fn.py deleted file mode 100644 index 4fd2a860e..000000000 --- a/slime/utils/debug_utils/replay_reward_fn.py +++ /dev/null @@ -1,50 +0,0 @@ -import asyncio -from typing import Annotated - -import ray -import torch -import typer - -from slime.utils.misc import load_function -from slime.utils.types import Sample - - -def _truncate(text, max_len=200): - """Truncate text and add ellipsis if too long.""" - if text is None: - return None - text = str(text).replace("\n", "\\n") - if len(text) > max_len: - return text[:max_len] + "..." - return text - - -def main( - rollout_data_path: Annotated[str, typer.Option()], - custom_rm_path: Annotated[str, typer.Option()], -): - if not ray.is_initialized(): - ray.init() - - pack = torch.load(rollout_data_path) - samples = [Sample.from_dict(s) for s in pack["samples"]] - asyncio.run(_main_async(samples=samples, custom_rm_path=custom_rm_path)) - - -async def _main_async(samples, custom_rm_path): - rm_function = load_function(custom_rm_path) - rewards = await asyncio.gather(*[rm_function(None, sample) for sample in samples]) - - for i, (sample, reward) in enumerate(zip(samples, rewards, strict=True)): - print("-" * 60) - print(f"Sample {i + 1}/{len(samples)}") - print(f" Index: {sample.index}") - print(f" Status: {sample.status}") - print(f" Reward: {reward}") - print(f" Prompt: {_truncate(sample.prompt, 200)}") - print(f" Response: {_truncate(sample.response, 200)}") - print("-" * 60) - - -if __name__ == "__main__": - typer.run(main) diff --git a/slime/utils/debug_utils/send_to_sglang.py b/slime/utils/debug_utils/send_to_sglang.py deleted file mode 100644 index 454e03ef1..000000000 --- a/slime/utils/debug_utils/send_to_sglang.py +++ /dev/null @@ -1,58 +0,0 @@ -import asyncio -import json -from typing import Annotated - -import typer -from openai import AsyncOpenAI - -from slime.utils.data import read_file - - -# can unify w/ sglang_rollout.py later, e.g. add RM, if needed -def main( - prompt_data: Annotated[str, typer.Option()], - url: Annotated[str, typer.Option()] = "http://localhost:30000/v1", - input_key: Annotated[str, typer.Option()] = "input", - n_samples_per_prompt: Annotated[int, typer.Option()] = 1, - rollout_max_response_len: Annotated[int, typer.Option()] = 1024, - rollout_temperature: Annotated[float, typer.Option()] = 1.0, - rollout_top_p: Annotated[float, typer.Option()] = 1.0, -): - """ - Minimally send prompts to SGLang using OpenAI endpoints with arguments in the same format as main Slime. - - Example usage: - python -m slime.utils.debug_utils.send_to_sglang --prompt-data /root/datasets/aime-2024/aime-2024.jsonl --input-key prompt --n-samples-per-prompt 16 --rollout-max-response-len 32768 --rollout-temperature 1 --rollout-top-p 1 - """ - - async def _main_async(): - tasks = [ - asyncio.create_task(_run_one(row, row_index=row_index, repeat_index=repeat_index)) - for row_index, row in enumerate(read_file(prompt_data)) - for repeat_index in range(n_samples_per_prompt) - ] - outputs = await asyncio.gather(*tasks) - for output in outputs: - print(json.dumps(output)) - - async def _run_one(row, row_index: int, repeat_index: int): - resp = await client.chat.completions.create( - messages=row[input_key], - model="dummy_model", - max_tokens=rollout_max_response_len, - temperature=rollout_temperature, - top_p=rollout_top_p, - ) - return dict( - row_index=row_index, - repeat_index=repeat_index, - **row, - response=resp.choices[0].message.content, - ) - - client = AsyncOpenAI(api_key="dummy_key", base_url=url) - asyncio.run(_main_async()) - - -if __name__ == "__main__": - typer.run(main) diff --git a/tools/convert_hf_to_int4_direct.py b/tools/convert_hf_to_int4_direct.py index e741b802d..613f65595 100644 --- a/tools/convert_hf_to_int4_direct.py +++ b/tools/convert_hf_to_int4_direct.py @@ -283,7 +283,7 @@ def parse_args(): parser.add_argument("--model-dir", type=str, required=True, help="local BF16 path") parser.add_argument("--save-dir", type=str, required=True) parser.add_argument("--group-size", type=int, default=32, help="Group Size") - parser.add_argument("--is-symmetric", type=bool, default=True, help="Is Symmetric") + parser.add_argument("--is-symmetric", action="store_true", help="Whether to use symmetric quantization") parser.add_argument( "--ignore-rules", nargs="+",