
feat: Support quantization before weight resharding#92

Open
hy2826 wants to merge 3 commits into ISEEKYAN:main from wucong25:low_precision_resharding

Conversation

@hy2826 hy2826 commented Mar 8, 2026

Motivation

In many current RL framework implementations, such as VeRL, using low-bit quantization in rollout means the synchronization of weights between the Trainer worker and the Rollout worker follows a "Gather-then-Quantize" pattern.

The current workflow:
1. Gather: high-precision weights ($BF16/FP16$) are collected from the TP/PP ranks via all_gather or broadcast.
2. Quantize: the target inference rank receives the full high-precision weight and then performs quantization locally.

[figure: gather-then-quantize weight-sync diagram]

This creates communication overhead: transferring $BF16$ data requires $2\times$ the bandwidth of $FP8$. As model sizes grow, this synchronization becomes a significant latency floor for each rollout iteration.
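As a back-of-envelope check of the bandwidth claim (illustrative numbers, not measurements from this PR), the gather volume for a ~30B-parameter model is:

```python
# BF16 stores 2 bytes/param, FP8 stores 1 byte/param.
params = 30e9                      # e.g. a ~30B-parameter model (assumption)
bf16_gb = params * 2 / 1e9         # GB moved when gathering BF16 weights
fp8_gb = params * 1 / 1e9          # GB moved when gathering FP8 weights
# Per-block FP32 scales add a little back: 4 bytes per (1, 32) block.
scale_gb = params / 32 * 4 / 1e9
print(bf16_gb, fp8_gb, scale_gb)   # 60.0 30.0 3.75
```

So the payload is halved, minus a modest overhead for the scales (which shrinks further with larger block sizes such as (1, 128)).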

Our Design

For blockwise quantization, we propose shifting the quantization responsibility to the source ranks. By quantizing the weight shards locally before they enter the communication collective, we can reduce the data volume by $50\%$ for FP8 quantization.

Technical Workflow
Taking FP8 quantization as an example, the proposed "Quantize-then-Gather" approach involves three steps:
1. Local quantization: each rank $i$ takes its local shard $W_i$ (in $BF16$) and computes $W_{i, fp8}$ (the quantized shard) and $S_i$ (the corresponding scaling factor).
2. Low-precision communication: perform all_gather on the $FP8$ tensors and the associated scales. Since $FP8$ is 1 byte, the communication volume is halved.
3. Shard assembly: the target rank directly concatenates the $FP8$ shards. Since quantization was done per shard, the metadata (scales) are managed alongside the data to ensure numerical consistency.
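The per-shard quantization in step 1 can be sketched in NumPy (a simulation, not this PR's implementation: int8 with qmax 127 stands in for FP8 E4M3, whose max is 448, since NumPy has no FP8 dtype; all names below are hypothetical):

```python
import numpy as np

def blockwise_quantize(w: np.ndarray, block: tuple[int, int] = (1, 32)):
    """Quantize a 2-D weight shard block by block.

    Returns the low-bit tensor plus one scale per block (the S_i metadata
    that travels with the shard through the all_gather).
    """
    br, bc = block
    rows, cols = w.shape
    assert rows % br == 0 and cols % bc == 0, "block must tile the shard"
    # View the matrix as a grid of (br, bc) blocks: axes (1, 3) index inside a block.
    blocks = w.reshape(rows // br, br, cols // bc, bc)
    # One symmetric scale per block: amax / qmax maps values into [-127, 127].
    amax = np.abs(blocks).max(axis=(1, 3), keepdims=True)
    scales = amax / 127.0
    q = np.round(blocks / np.where(scales == 0, 1.0, scales)).astype(np.int8)
    return q.reshape(rows, cols), scales.squeeze((1, 3))

w = np.random.default_rng(0).standard_normal((4, 64)).astype(np.float32)
q, s = blockwise_quantize(w, (1, 32))
print(q.dtype, s.shape)  # int8 (4, 2)
```

Each rank would run this on its own $W_i$ and then contribute both `q` and `s` to the collective.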

[figure: quantize-then-gather weight-sync diagram]

Mathematical Representation

Instead of:
$$
W_{global} = \text{Gather}(W_0, W_1, \dots, W_n) \implies W_{fp8} = \text{Quantize}(W_{global})
$$
We implement:
$$
[W_{i, fp8}, S_i] = \text{Quantize}(W_i) \implies W_{fp8} = \text{Gather}(W_{0, fp8}, \dots, W_{n, fp8})
$$

Because blockwise quantization operates on each block independently, and the shard boundaries do not cut across quantization blocks, the result of our design and the original pipeline should be bitwise-identical.
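This equivalence can be checked with a small NumPy simulation (int8 as a stand-in for FP8; the shards are split along rows, so the (1, 32) blocks are never cut by a shard boundary):

```python
import numpy as np

def quantize_rowblocks(w, bc=32):
    # Per-(1, bc)-block symmetric quantization; int8 stands in for FP8.
    blocks = w.reshape(w.shape[0], w.shape[1] // bc, bc)
    scales = np.abs(blocks).max(axis=-1, keepdims=True) / 127.0
    q = np.round(blocks / np.where(scales == 0, 1.0, scales)).astype(np.int8)
    return q.reshape(w.shape), scales.squeeze(-1)

rng = np.random.default_rng(0)
shards = [rng.standard_normal((2, 64)).astype(np.float32) for _ in range(4)]  # W_0..W_3

# Gather-then-quantize (baseline): concatenate the shards, then quantize once.
q_ref, s_ref = quantize_rowblocks(np.concatenate(shards, axis=0))

# Quantize-then-gather (this PR): quantize each shard locally, gather int8 + scales.
parts = [quantize_rowblocks(w) for w in shards]
q_new = np.concatenate([q for q, _ in parts], axis=0)
s_new = np.concatenate([s for _, s in parts], axis=0)

print(np.array_equal(q_ref, q_new), np.array_equal(s_ref, s_new))  # True True
```

Each block sees exactly the same floating-point operations in both orders, so both the quantized values and the scales match exactly.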

Usage

We introduce a new function named export_weights_quant. Compared with export_weights, it takes three additional arguments:

  1. weight_block_size: a tuple giving the quantization block size, e.g. (1, 32).
  2. should_quantize_param_megatron: a predicate that takes a Megatron weight name (a string) and decides whether the corresponding weight should be quantized.
  3. quant_fn: a function that quantizes a weight; it takes an additional quantization-block-size parameter.
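A hedged sketch of what the three arguments might look like (the exact export_weights_quant signature lives in this PR's code and is not reproduced here; the predicate policy and the int8 stand-in quantizer below are illustrative assumptions, not the PR's defaults):

```python
import numpy as np

weight_block_size = (1, 32)  # quantization block size, as in the description above

def should_quantize_param_megatron(name: str) -> bool:
    # Illustrative policy (an assumption, not from the PR): quantize linear-layer
    # weights, keep norms and embeddings in high precision.
    return name.endswith(".weight") and "norm" not in name and "embedding" not in name

def quant_fn(weight: np.ndarray, block_size: tuple[int, int]):
    # Stand-in blockwise quantizer (int8 instead of FP8; a real implementation
    # would emit an FP8 tensor): returns the quantized shard and per-block scales.
    br, bc = block_size
    blocks = weight.reshape(weight.shape[0] // br, br, weight.shape[1] // bc, bc)
    scales = np.abs(blocks).max(axis=(1, 3), keepdims=True) / 127.0
    q = np.round(blocks / np.where(scales == 0, 1.0, scales)).astype(np.int8)
    return q.reshape(weight.shape), scales.squeeze((1, 3))

print(should_quantize_param_megatron("decoder.layers.0.self_attention.linear_qkv.weight"))
print(should_quantize_param_megatron("decoder.final_layernorm.weight"))

# The call itself would then look roughly like (shape assumed from the text):
# export_weights_quant(..., weight_block_size=weight_block_size,
#                      should_quantize_param_megatron=should_quantize_param_megatron,
#                      quant_fn=quant_fn)
```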

Experiment Results: reduced wall-clock time

For Qwen3-30B-A3B with FP8 quantization on 8× H100, after one warm-up step, our design reduced the weight-sync wall-clock time from 25.456 s to 18.216 s, a 28.4% reduction compared with the original "gather-then-quantize" pipeline.

@hy2826 hy2826 changed the title Support quantization before weight resharding feat: Support quantization before weight resharding Mar 8, 2026