feat: Support quantization before weight resharding #92
Open
hy2826 wants to merge 3 commits into ISEEKYAN:main from
Motivation
In many current RL framework implementations, such as VeRL, using low-bit quantization in rollout means that the synchronization of weights between the Trainer worker and the Rollout worker follows a "Gather-then-Quantize" pattern.
The Current Workflow:
1. Gather: High-precision weights ($BF16/FP16$) are collected from different TP/PP ranks via all_gather or broadcast.
2. Quantize: The target inference rank receives the full high-precision weight and then performs quantization locally.
This may create communication overhead: for example, transferring $BF16$ data requires $2\times$ the bandwidth compared to $FP8$. As model sizes grow, this sync becomes a significant latency floor for each rollout iteration.
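As a back-of-envelope check of the $2\times$ figure (every number below is an illustrative assumption, including the 30B parameter count and a per-128-element FP32 scale common in blockwise FP8 schemes; none of it is taken from this PR):

```python
# Bytes moved when syncing all weights, BF16 vs. blockwise FP8 + scales.
num_params = 30_000_000_000            # illustrative 30B-parameter model
bf16_bytes = num_params * 2            # BF16: 2 bytes per parameter
fp8_bytes = num_params * 1             # FP8: 1 byte per parameter
scale_bytes = (num_params // 128) * 4  # one FP32 scale per 128-wide block

ratio = bf16_bytes / (fp8_bytes + scale_bytes)
print(f"effective volume reduction: {ratio:.2f}x")  # ~1.94x, close to 2x
```

The per-block scales add roughly 3% overhead at this block size, which is why the effective reduction is slightly under $2\times$.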
Our Design
For blockwise quantization, we propose shifting the quantization responsibility to the source ranks. By quantizing the weight shards locally before they enter the communication collective, we can reduce the data volume by 50% for FP8 quantization.
Technical Workflow
Take FP8 quantization as an example; the proposed "Quantize-then-Gather" approach involves three steps:
1. Local Quantization: Each rank $i$ takes its local shard $W_i$ (in $BF16$) and computes:
   - $W_{i, fp8}$: the quantized shard.
   - $S_i$: the corresponding scaling factor (Scale).
2. Low-Precision Communication: Perform all_gather on the $FP8$ tensors and the associated Scales. Since $FP8$ is 1 byte, the communication volume is halved.
3. Shard Assembly: The target rank directly concatenates the $FP8$ shards. Since the quantization was done per-shard, the metadata (Scales) is managed alongside the data to ensure numerical consistency.
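The three steps can be sketched in a single process as follows. The block size, shard shapes, and the rounding-based stand-in for a real FP8 cast are all illustrative assumptions, not code from this PR:

```python
import numpy as np

FP8_MAX = 448.0  # max normal value of FP8 E4M3
BLOCK = 4        # illustrative block size (real kernels often use 128)

def quantize_blockwise(shard):
    """Step 1 (Local Quantization): per-block scale S_i and quantized W_i."""
    blocks = shard.reshape(-1, BLOCK)
    scales = np.abs(blocks).max(axis=1, keepdims=True) / FP8_MAX
    scales = np.where(scales == 0.0, 1.0, scales)   # avoid division by zero
    q = np.round(blocks / scales)                   # stand-in for the FP8 cast
    return q, scales

rng = np.random.default_rng(0)
# Two "TP ranks", each holding one shard of a weight tensor.
shards = [rng.standard_normal(8).astype(np.float32) for _ in range(2)]

# Step 1: each source rank quantizes its own shard locally.
local = [quantize_blockwise(s) for s in shards]

# Step 2 (Low-Precision Communication): all_gather, modeled here as
# simply collecting the per-rank results into lists.
gathered_q = [q for q, _ in local]
gathered_s = [s for _, s in local]

# Step 3 (Shard Assembly): concatenate the FP8 shards, keeping the
# per-shard scales alongside the data.
w_fp8 = np.concatenate([q.ravel() for q in gathered_q])
all_scales = np.concatenate([s.ravel() for s in gathered_s])
print(w_fp8.shape, all_scales.shape)  # (16,) (4,)
```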
Mathematical Representation
Instead of:
$$
W_{global} = \text{Gather}(W_0, W_1, \dots, W_n) \implies W_{fp8} = \text{Quantize}(W_{global})
$$
we implement:
$$
[W_{i, fp8}, S_i] = \text{Quantize}(W_i) \implies W_{fp8} = \text{Gather}(W_{0, fp8}, \dots, W_{n, fp8})
$$
Since the quantization is blockwise, as long as blocks do not straddle shard boundaries, each block is quantized from exactly the same values in either order; thus, the result of our design and the original pipeline should be bitwise-equal.
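A small self-contained check of this equivalence, under the same assumption that blocks never straddle shard boundaries (block size, shard sizes, and the simulated FP8 cast are illustrative):

```python
import numpy as np

FP8_MAX, BLOCK = 448.0, 4  # illustrative FP8 E4M3 max and block size

def quantize_blockwise(w):
    """Per-block max-abs scaling with a rounding stand-in for the FP8 cast."""
    blocks = w.reshape(-1, BLOCK)
    scales = np.abs(blocks).max(axis=1, keepdims=True) / FP8_MAX
    scales = np.where(scales == 0.0, 1.0, scales)
    return np.round(blocks / scales), scales

rng = np.random.default_rng(1)
shards = [rng.standard_normal(8).astype(np.float32) for _ in range(4)]

# Gather-then-Quantize (original pipeline): quantize the full tensor.
q_ref, s_ref = quantize_blockwise(np.concatenate(shards))

# Quantize-then-Gather (this PR): quantize per shard, then concatenate.
per_rank = [quantize_blockwise(s) for s in shards]
q_new = np.concatenate([q for q, _ in per_rank])
s_new = np.concatenate([s for _, s in per_rank])

print(np.array_equal(q_ref, q_new), np.array_equal(s_ref, s_new))  # True True
```

Because shard lengths here are multiples of the block size, both pipelines apply identical floating-point operations to identical blocks, so the outputs match exactly.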
Usage
We introduce a new function named export_weights_quant. Compared with export_weights, it takes three additional arguments:
Experiment Results: reduced wall-clock time
For Qwen3-30B-A3B with FP8 quantization on 8×H100, after one step of warm-up, our design reduced the wall-clock time from 25.456 s to 18.216 s, a 28.4% reduction compared with the original "Gather-then-Quantize" pipeline.