docs/en/developer_guide/debug.md
Specifically, slime currently provides the following parameters for separate debugging:
When enabled, data will be loaded from `args.load_debug_rollout_data.format(rollout_id=rollout_id)`, and SGLang will not be initialized (automatically setting `debug_train_only=True`). This method allows you to fix the input for the training part to tune it, for example, by switching between different parallelization strategies.
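The path template is resolved with Python's `str.format`; a minimal sketch (the template string below is a hypothetical example, not a slime default):

```python
# Hypothetical value for args.load_debug_rollout_data; the {rollout_id}
# placeholder is substituted once per rollout.
load_debug_rollout_data = "/tmp/debug/rollout_data_{rollout_id}.pt"

def resolve_debug_path(template: str, rollout_id: int) -> str:
    # Mirrors how the templated path is resolved for a given rollout.
    return template.format(rollout_id=rollout_id)

print(resolve_debug_path(load_debug_rollout_data, 7))
# -> /tmp/debug/rollout_data_7.pt
```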
When using INT4-quantized models (e.g., `compressed-tensors` with `W4A16`), the checkpoint's `config.json` contains a `quantization_config.ignore` list that specifies which parameters should **not** be quantized. During online weight updates (Megatron → SGLang), slime also reads this ignore list to decide which parameters to INT4-quantize. An incorrect ignore list can cause silent errors:
1. **MoE router weights (`mlp.gate.weight`) become all zeros**
The MoE router weight (`mlp.gate.weight`, shape `[num_experts, hidden_size]`) is a plain 2D weight tensor, but it is **not** a Linear layer weight. If it is not in the ignore list, the online quantizer will INT4-quantize it into `weight_packed`, `weight_scale`, `weight_zero_point`, etc. However, SGLang does not expect quantized names for the router, so these parameters are silently skipped during `load_weights`, resulting in all-zero gate weights.
**Fix**: Ensure `config.json` contains `"re:.*mlp\\.gate\\..*"` in the ignore list.
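For reference, the relevant part of a `compressed-tensors` checkpoint's `config.json` might look like the fragment below (field values are illustrative; the `ignore` list is the point here):

```json
{
  "quantization_config": {
    "quant_method": "compressed-tensors",
    "ignore": [
      "lm_head",
      "re:.*mlp\\.gate\\..*"
    ]
  }
}
```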
2. **Other non-Linear 2D weights**
Similar issues can occur with any 2D `.weight` tensor that is not a true Linear layer, such as `model.embed_tokens.weight`. Always verify the ignore list covers all non-Linear weights.
3. **Incomplete converted checkpoints**

Conversion tools may occasionally produce an incomplete checkpoint (e.g., a missing `model-00010-of-00093.safetensors`). After conversion, always verify:
- The number of `.safetensors` files matches the expected count.
- The `model.safetensors.index.json` contains entries for every layer.
- Spot-check that critical layers (e.g., the first MoE layer) have the expected number of keys.
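The first two checks can be scripted; a minimal sketch, assuming the standard Hugging Face sharded-checkpoint layout (a `model.safetensors.index.json` whose `weight_map` maps parameter names to shard files):

```python
import json
import os

def find_missing_shards(ckpt_dir: str) -> set:
    """Return shard files referenced by the index but absent on disk."""
    with open(os.path.join(ckpt_dir, "model.safetensors.index.json")) as f:
        index = json.load(f)
    expected = set(index["weight_map"].values())  # every shard the index references
    on_disk = {n for n in os.listdir(ckpt_dir) if n.endswith(".safetensors")}
    return expected - on_disk
```

An empty result means every shard listed in the index exists on disk; per-layer key counts still need a separate spot check (e.g., by opening individual shards with `safetensors`).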
4. **How to diagnose**
- Use `--check-weight-update-equal` to verify that weights after a Megatron → SGLang sync match the expected values. If a parameter shows all zeros on the SGLang side, it was likely incorrectly quantized or missing from the checkpoint.
- Use `--debug-rollout-only` with a small number of GPUs to quickly test whether SGLang can generate coherent text from the quantized checkpoint alone.
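To see which parameters an ignore list actually covers, you can match it against parameter names offline. A minimal sketch; the matching logic is an approximation of the `re:`-prefixed pattern convention described above, and the parameter names are examples:

```python
import re

# Example ignore list, as it would appear under quantization_config.ignore.
IGNORE = ["lm_head", "re:.*mlp\\.gate\\..*"]

def is_ignored(param_name: str, ignore: list) -> bool:
    """True if the parameter should be left unquantized."""
    for entry in ignore:
        if entry.startswith("re:"):
            if re.fullmatch(entry[3:], param_name):
                return True
        elif param_name == entry:
            return True
    return False

print(is_ignored("model.layers.0.mlp.gate.weight", IGNORE))
# -> True: the MoE router stays unquantized
print(is_ignored("model.layers.0.mlp.experts.0.gate_proj.weight", IGNORE))
# -> False: expert Linear weights are still quantized
```

Running such a check over every name in `model.safetensors.index.json` quickly reveals non-Linear weights the ignore list fails to cover.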
## Debug sglang illegal memory access (IMA)
When running large-scale RL, we occasionally encounter illegal memory access (IMA) errors in SGLang. Here are some debugging suggestions based on our experience: