Merged
26 changes: 25 additions & 1 deletion docker/patch/latest/sglang.patch
@@ -973,7 +973,7 @@ index 00bd68755..5a3ca8a67 100644

def get_routed_experts(
diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py
-index 4cbfed6f9..cd6c825f6 100644
+index 4cbfed6f9..88b452744 100644
--- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -499,7 +499,7 @@ class CompressedTensorsConfig(QuantizationConfig):
@@ -985,6 +985,16 @@ index 4cbfed6f9..cd6c825f6 100644

def _is_mxint4a16(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool:
input_quant_none = input_quant is None
@@ -968,6 +968,9 @@ class CompressedTensorsFusedMoEMethod(FusedMoEMethodBase):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.scheme.process_weights_after_loading(layer)

+ def restore_weights_before_loading(self, layer: torch.nn.Module) -> None:
+ layer.scheme.restore_weights_before_loading(layer)
+
def create_weights(
self,
layer: torch.nn.Module,
diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16_moe.py
index 6264f36d0..bef31a374 100644
--- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16_moe.py
@@ -2438,3 +2448,17 @@ index 4636128fa..a9b61df39 100644
}

DENY_CLASSES = {
diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py
index 3be16446e..1b2371c83 100644
--- a/python/sglang/srt/utils/weight_checker.py
+++ b/python/sglang/srt/utils/weight_checker.py
@@ -69,6 +69,9 @@ def _check_tensors(
actual_should_compare,
actual,
) in zip(expect_tensors, actual_tensors, strict=True):
+ if ".cos_sin_cache" in expect_name:
+ # Skip the cos/sin cache: it is deterministic given shape and dtype, and its shape may differ across implementations.
+ continue
assert expect_name == actual_name, f"{expect_name=} {actual_name=}"
assert (
expect_should_compare == actual_should_compare
42 changes: 42 additions & 0 deletions docs/en/developer_guide/debug.md
@@ -50,6 +50,48 @@ Specifically, slime currently provides the following parameters for separate debugging

When enabled, data will be loaded from `args.load_debug_rollout_data.format(rollout_id=rollout_id)`, and SGLang will not be initialized (automatically setting `debug_train_only=True`). This method allows you to fix the input for the training part to tune it, for example, by switching between different parallelization strategies.

## INT4 / Compressed-Tensors Quantization Checkpoint Issues

When using INT4-quantized models (e.g., `compressed-tensors` with `W4A16`), the checkpoint's `config.json` contains a `quantization_config.ignore` list that specifies which parameters should **not** be quantized. During online weight updates (Megatron → SGLang), slime also reads this ignore list to decide which parameters to INT4-quantize. An incorrect ignore list can cause silent errors:

1. **MoE router weights (`mlp.gate.weight`) become all zeros**

The MoE router weight (`mlp.gate.weight`, shape `[num_experts, hidden_size]`) is a plain 2D weight tensor, but it is **not** a Linear layer weight. If it is not in the ignore list, the online quantizer will INT4-quantize it into `weight_packed`, `weight_scale`, `weight_zero_point`, etc. However, SGLang does not expect quantized names for the router, so these parameters are silently skipped during `load_weights`, resulting in all-zero gate weights.

**Fix**: Ensure `config.json` contains `"re:.*mlp\\.gate\\..*"` in the ignore list.

2. **Other non-Linear 2D weights**

Similar issues can occur with any 2D `.weight` tensor that is not a true Linear layer, such as `model.embed_tokens.weight`. Always verify the ignore list covers all non-Linear weights.

**Recommended ignore patterns** (for GLM-style MoE models):
```json
"ignore": [
"lm_head",
"model.embed_tokens.weight",
"re:.*self_attn.*",
"re:.*mlp\\.shared_experts.*",
"re:.*mlp\\.gate_up_proj.*",
"re:.*mlp\\.gate_proj.*",
"re:.*mlp\\.up_proj.*",
"re:.*mlp\\.down_proj.*",
"re:.*eh_proj.*",
"re:.*mlp\\.gate\\..*"
]
```
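To see how such an ignore list is typically matched against parameter names, here is a minimal sketch (assumptions: entries prefixed with `re:` are treated as regular expressions, following the compressed-tensors convention; the helper name `should_skip_quantization` is ours for illustration, not slime's API):

```python
import re

def should_skip_quantization(param_name: str, ignore_list: list[str]) -> bool:
    """Return True if param_name matches an ignore entry and must stay unquantized."""
    for entry in ignore_list:
        if entry.startswith("re:"):
            # "re:" entries are regular expressions matched against the full name.
            if re.match(entry[3:], param_name):
                return True
        # Plain entries match the exact name or a module prefix.
        elif param_name == entry or param_name.startswith(entry + "."):
            return True
    return False

ignore = ["lm_head", "model.embed_tokens.weight", r"re:.*mlp\.gate\..*"]
# The escaped dot matters: "mlp\.gate\." matches the router but not "mlp.gate_up_proj".
print(should_skip_quantization("model.layers.3.mlp.gate.weight", ignore))          # True
print(should_skip_quantization("model.layers.3.mlp.gate_up_proj.weight", ignore))  # False
```

Note that an unescaped pattern such as `re:.*mlp.gate.*` would also match `mlp.gate_up_proj` and accidentally keep true Linear layers unquantized, which is why the recommended patterns above escape the dots.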

3. **Missing safetensors shards**

Conversion tools may occasionally produce an incomplete checkpoint (e.g., a missing `model-00010-of-00093.safetensors`). After conversion, always verify:
- The number of `.safetensors` files matches the expected count.
- The `model.safetensors.index.json` contains entries for every layer.
- Spot-check that critical layers (e.g., the first MoE layer) have the expected number of keys.
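These checks can be scripted. Below is a rough sketch (the function name is ours, and it assumes the standard sharded-safetensors layout with a `model.safetensors.index.json` whose `weight_map` points at `model-NNNNN-of-MMMMM.safetensors` shards):

```python
import json
import os
import re

def check_safetensors_complete(ckpt_dir: str) -> list[str]:
    """Return a list of problems found in a sharded safetensors checkpoint."""
    problems = []
    index_path = os.path.join(ckpt_dir, "model.safetensors.index.json")
    with open(index_path) as f:
        index = json.load(f)
    # Every shard referenced by the index must exist on disk.
    referenced = set(index["weight_map"].values())
    for shard in sorted(referenced):
        if not os.path.exists(os.path.join(ckpt_dir, shard)):
            problems.append(f"missing shard: {shard}")
    # The "-of-MMMMM" suffix tells us how many shards to expect in total.
    m = re.search(r"-of-(\d+)\.safetensors$", next(iter(referenced)))
    if m:
        expected = int(m.group(1))
        on_disk = [f for f in os.listdir(ckpt_dir) if f.endswith(".safetensors")]
        if len(on_disk) != expected:
            problems.append(f"expected {expected} shards, found {len(on_disk)}")
    return problems
```

Spot-checking a critical layer then reduces to counting the `weight_map` keys that start with that layer's prefix.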

4. **How to diagnose**

- Use `--check-weight-update-equal` to verify that weights after a Megatron → SGLang sync match the expected values. If a parameter shows all zeros on the SGLang side, it was likely incorrectly quantized or missing from the checkpoint.
- Use `--debug-rollout-only` with a small number of GPUs to quickly test whether SGLang can generate coherent text from the quantized checkpoint alone.

## Debug sglang illegal memory access (IMA)

When running large-scale RL, you may occasionally encounter illegal memory access (IMA) errors in SGLang. Here are some debugging suggestions based on our experience:
42 changes: 42 additions & 0 deletions docs/zh/developer_guide/debug.md
@@ -48,6 +48,48 @@ slime supports debugging the training and inference parts separately, enabling:

When enabled, data will be loaded from `args.load_debug_rollout_data.format(rollout_id=rollout_id)`, and SGLang will not be initialized (automatically setting `debug_train_only=True`). This allows you to fix the input to the training part and tune it, for example by switching between parallelization strategies.

## INT4 / Compressed-Tensors Quantization Checkpoint Issues

When using INT4-quantized models (e.g., `compressed-tensors` with `W4A16`), the checkpoint's `config.json` contains a `quantization_config.ignore` list that specifies which parameters should **not** be quantized. During online weight updates (Megatron → SGLang), slime also reads this ignore list to decide which parameters to INT4-quantize. An incorrect ignore list can cause silent errors:

1. **MoE router weights (`mlp.gate.weight`) become all zeros**

The MoE router weight (`mlp.gate.weight`, shape `[num_experts, hidden_size]`) is a plain 2D weight tensor, but it is **not** a Linear layer weight. If it is not in the ignore list, the online quantizer will INT4-quantize it into `weight_packed`, `weight_scale`, `weight_zero_point`, etc. However, SGLang does not load the router weight under quantized names, so these parameters are silently skipped during `load_weights`, resulting in all-zero gate weights.

**Fix**: Ensure the ignore list in `config.json` contains `"re:.*mlp\\.gate\\..*"`.

2. **Other non-Linear 2D weights**

Similar issues can occur with any 2D `.weight` tensor that is not a true Linear layer, such as `model.embed_tokens.weight`. Always verify that the ignore list covers all non-Linear weights.

**Recommended ignore patterns** (for GLM-style MoE models):
```json
"ignore": [
"lm_head",
"model.embed_tokens.weight",
"re:.*self_attn.*",
"re:.*mlp\\.shared_experts.*",
"re:.*mlp\\.gate_up_proj.*",
"re:.*mlp\\.gate_proj.*",
"re:.*mlp\\.up_proj.*",
"re:.*mlp\\.down_proj.*",
"re:.*eh_proj.*",
"re:.*mlp\\.gate\\..*"
]
```

3. **Missing safetensors shards**

Conversion tools may occasionally produce an incomplete checkpoint (e.g., a missing `model-00010-of-00093.safetensors`). After conversion, always verify:
- The number of `.safetensors` files matches the expected count.
- The `model.safetensors.index.json` contains entries for every layer.
- Spot-check that critical layers (e.g., the first MoE layer) have the expected number of keys.

4. **How to diagnose**

- Use `--check-weight-update-equal` to verify that weights after a Megatron → SGLang sync match the expected values. If a parameter shows all zeros on the SGLang side, it was likely incorrectly quantized or missing from the checkpoint.
- Use `--debug-rollout-only` with a small number of GPUs to quickly test whether SGLang can generate coherent text from the quantized checkpoint alone.

## Debug sglang illegal memory access (IMA)

When running large-scale RL, you may occasionally encounter illegal memory access (IMA) errors in SGLang. Here are some debugging suggestions based on our experience:
Empty file.
73 changes: 0 additions & 73 deletions slime/utils/debug_utils/display_debug_rollout_data.py

This file was deleted.

50 changes: 0 additions & 50 deletions slime/utils/debug_utils/replay_reward_fn.py

This file was deleted.

58 changes: 0 additions & 58 deletions slime/utils/debug_utils/send_to_sglang.py

This file was deleted.

2 changes: 1 addition & 1 deletion tools/convert_hf_to_int4_direct.py
@@ -283,7 +283,7 @@ def parse_args():
parser.add_argument("--model-dir", type=str, required=True, help="local BF16 path")
parser.add_argument("--save-dir", type=str, required=True)
parser.add_argument("--group-size", type=int, default=32, help="Group Size")
-parser.add_argument("--is-symmetric", type=bool, default=True, help="Is Symmetric")
+parser.add_argument("--is-symmetric", action="store_true", help="Whether to use symmetric quantization")
parser.add_argument(
"--ignore-rules",
nargs="+",
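For context on the `--is-symmetric` change: `argparse` with `type=bool` calls `bool()` on the raw string, and any non-empty string (including `"False"`) is truthy, so the flag could never actually be turned off from the command line. A `store_true` action is the idiomatic fix. A small demonstration (argument names are illustrative):

```python
import argparse

parser = argparse.ArgumentParser()
# Broken: bool("False") is True, so passing "False" has no effect.
parser.add_argument("--broken", type=bool, default=True)
# Correct: the flag is False unless explicitly passed on the command line.
parser.add_argument("--is-symmetric", action="store_true")

args = parser.parse_args(["--broken", "False"])
print(args.broken)        # True — not what the user intended

args = parser.parse_args(["--is-symmetric"])
print(args.is_symmetric)  # True only because the flag was passed
```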