118 changes: 111 additions & 7 deletions verl/utils/vllm/vllm_fp8_utils.py
@@ -35,6 +35,12 @@
"weight_block_size": [128, 128],
}

MXFP8_BLOCK_QUANT_KWARGS = {
"activation_scheme": "dynamic",
"fmt": "e4m3",
"quant_method": "mxfp8",
"weight_block_size": [1, 32],
}

# Ref: https://github.com/NVIDIA-NeMo/RL/commit/bc24887c72a6e1b2699a228bc87c588546dfe6b7
@dataclass()
@@ -49,11 +55,28 @@ class FP8State:
fp8_state: FP8State = FP8State()


def is_mxfp8_vllm_ascend(quant_config):
try:
from vllm_ascend.quantization.quant_config import AscendQuantConfig
if isinstance(quant_config, AscendQuantConfig):
# Check if the specific quantization method is MXFP8
# AscendQuantConfig stores config in quant_description
quant_method = quant_config.quant_description.get("quant_method")
return quant_method in ["W8A8_MXFP8", "mxfp8"]
except ImportError:
pass

return False


def is_fp8_model(vllm_config):
from vllm.model_executor.layers.quantization.fp8 import Fp8Config

if hasattr(vllm_config, "quant_config") and isinstance(vllm_config.quant_config, Fp8Config):
return True
if hasattr(vllm_config, "quant_config"):
if isinstance(vllm_config.quant_config, Fp8Config):
return True
if is_mxfp8_vllm_ascend(vllm_config.quant_config):
return True

return False

@@ -158,19 +181,100 @@ def scaled_fp8_blockwise(
return fp_data, descale_fp


def npu_scaled_mxfp8_blockwise(
data_hp,
weight_block_size,
):
assert data_hp.dim() == 2, "Only 2D tensors supported (M, N)"

block_size = weight_block_size[1]

# Constants for MXFP8 / NPU
FP32_MIN_NORMAL = torch.finfo(torch.float32).tiny
MAX_NORM = torch.finfo(torch.float8_e4m3fn).max # 2 ** 8 * 1.75
EMAX = 8 # 2 ** (4 - 1)
SCALE_EMAX = 127 # 2 ** (8 - 1) - 1

data_hp = data_hp.float()
original_shape = data_hp.shape
M, N = original_shape
assert N % block_size == 0, f"Last dimension {N} must be divisible by block_size {block_size}"

# Reshape to (M, N // block_size, block_size)
num_blocks_n = N // block_size
data_blocked = data_hp.reshape(M, num_blocks_n, block_size)

# Calculate max absolute value per block
max_val = torch.amax(torch.abs(data_blocked), dim=-1)

# Shared exponent calculation
# Handle zero/tiny values to avoid log2(0) -> -inf
max_val_safe = torch.where(max_val == 0, FP32_MIN_NORMAL, max_val)
shared_exp = torch.floor(torch.log2(max_val_safe)) - EMAX

shared_exp[shared_exp > SCALE_EMAX] = float("NaN")

shared_exp_expanded = shared_exp.unsqueeze(-1)
scale_factor = torch.pow(2.0, shared_exp_expanded)
data_normalized = data_blocked / scale_factor
abs_norm = torch.abs(data_normalized)
private_exp = torch.floor(torch.log2(abs_norm + (abs_norm == 0).float()))
min_exp = -6
private_exp = private_exp.clamp(min=min_exp)

mantissa_scale = 8.0 # 2 ** (5 - 2)

scale_private = torch.pow(2.0, private_exp)
scaled = data_normalized / scale_private * mantissa_scale

# Round half away from zero: sign * floor(abs + 0.5)
data_quant = torch.sign(scaled) * torch.floor(torch.abs(scaled) + 0.5)
data_quant = data_quant / mantissa_scale * scale_private
data_quant = torch.clamp(data_quant, min=-MAX_NORM, max=MAX_NORM)

# Restore Inf/NaN
data_quant = torch.where(torch.isinf(data_normalized), data_normalized, data_quant)
data_quant = torch.where(torch.isnan(data_normalized), data_normalized, data_quant)

fp_data = data_quant.reshape(original_shape).to(torch.float8_e4m3fn)

# Encode scale/exponent for NPU (uint8)
shared_exp_fixed = torch.nan_to_num(shared_exp, nan=-127.0)
descale_fp = torch.clamp(shared_exp_fixed + 127, 0, 255).round().to(torch.uint8)

return fp_data, descale_fp


def quant_weights(weights, model, quant_config, dtype=torch.bfloat16):
weights_quantized = []

# Determine block size
weight_block_size = None
is_mxfp8_npu = is_mxfp8_vllm_ascend(quant_config)

if hasattr(quant_config, "weight_block_size"):
weight_block_size = quant_config.weight_block_size
elif is_mxfp8_npu:
weight_block_size = MXFP8_BLOCK_QUANT_KWARGS["weight_block_size"]

for k, v in weights:
if not is_fp8_weight(k, model):
weights_quantized.append((k, v))
continue
# Cast the weight into fp8 and its scale factor
if quant_config.weight_block_size is not None:
if weight_block_size is not None:
logger.info("Using blockwise quantization")
param_lp, param_scale = scaled_fp8_blockwise(
v.to(dtype),
weight_block_size=quant_config.weight_block_size,
)
if is_mxfp8_npu:
param_lp, param_scale = npu_scaled_mxfp8_blockwise(
v.to(dtype),
weight_block_size=weight_block_size,
)
else:
param_lp, param_scale = scaled_fp8_blockwise(
v.to(dtype),
weight_block_size=weight_block_size,
)

param_scale = param_scale.squeeze(-1)
weights_quantized.append([k, param_lp])
if vllm.__version__ >= "0.11.0":
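Below is a minimal usage sketch (editor's illustration, not part of the patch) for the npu_scaled_mxfp8_blockwise helper added above. It assumes a PyTorch build with torch.float8_e4m3fn support and that the helper and MXFP8_BLOCK_QUANT_KWARGS are importable from verl.utils.vllm.vllm_fp8_utils; the tensor shape is illustrative only.

import torch

from verl.utils.vllm.vllm_fp8_utils import MXFP8_BLOCK_QUANT_KWARGS, npu_scaled_mxfp8_blockwise

# The last dimension must be divisible by the MXFP8 block size (32).
weight = torch.randn(4, 64, dtype=torch.bfloat16)
fp_data, descale = npu_scaled_mxfp8_blockwise(
    weight,
    weight_block_size=MXFP8_BLOCK_QUANT_KWARGS["weight_block_size"],  # [1, 32]
)

assert fp_data.shape == (4, 64) and fp_data.dtype == torch.float8_e4m3fn
assert descale.shape == (4, 2) and descale.dtype == torch.uint8  # one biased exponent per 32-wide block

# Round trip: the uint8 code is the shared exponent biased by 127, so the per-block
# scale is 2 ** (code - 127) and fp_data * scale approximates the original weight.
scale = torch.pow(2.0, descale.float() - 127.0)
recon = fp_data.float().reshape(4, 2, 32) * scale.unsqueeze(-1)
print((recon.reshape(4, 64) - weight.float()).abs().max())  # small quantization error expected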
16 changes: 8 additions & 8 deletions verl/workers/rollout/vllm_rollout/vllm_async_server.py
@@ -265,27 +265,27 @@ async def launch_server(self, master_address: str = None, master_port: int = None
quantization = self.config.quantization

if quantization is not None:
_SUPPORTED_QUANTIZATION = ["fp8", "torchao"]
_SUPPORTED_QUANTIZATION = ["fp8", "mxfp8", "torchao"]
if quantization not in _SUPPORTED_QUANTIZATION:
raise ValueError(f"Currently only support {_SUPPORTED_QUANTIZATION} quantization, got: {quantization}")

if quantization == "fp8":
FP8_BLOCK_QUANT_KWARGS = {
"activation_scheme": "dynamic",
"fmt": "e4m3",
"quant_method": "fp8",
"weight_block_size": [128, 128],
}
from verl.utils.vllm.vllm_fp8_utils import FP8_BLOCK_QUANT_KWARGS
fp8_block_quant_kwargs = dict(FP8_BLOCK_QUANT_KWARGS)
# Apply vllm fp8 patches
# Will remove the patch after vllm support on-the-fly quant for rollout natively.
apply_vllm_fp8_patches()
elif quantization == "mxfp8":
from verl.utils.vllm.vllm_fp8_utils import MXFP8_BLOCK_QUANT_KWARGS
fp8_block_quant_kwargs = dict(MXFP8_BLOCK_QUANT_KWARGS)
# TODO(slightwindsec): apply MXFP8 patches?
pass

hf_overrides = {}
if quantization is not None and self.config.quantization_config_file is not None:
hf_overrides["quantization_config_file"] = self.config.quantization_config_file

if quantization == "fp8":
if quantization == "fp8" or quantization == "mxfp8":
hf_overrides["quantization_config"] = fp8_block_quant_kwargs

args = {
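For reference (editor's sketch, not part of the diff): when quantization="mxfp8" is selected, the override passed to vLLM is just the MXFP8_BLOCK_QUANT_KWARGS dict imported from vllm_fp8_utils.py, so hf_overrides ends up as below; quantization_config_file, if configured, is forwarded alongside it.

hf_overrides = {
    "quantization_config": {
        "activation_scheme": "dynamic",
        "fmt": "e4m3",
        "quant_method": "mxfp8",
        "weight_block_size": [1, 32],
    }
}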
5 changes: 4 additions & 1 deletion verl/workers/rollout/vllm_rollout/vllm_rollout.py
@@ -191,7 +191,7 @@ def _init_worker(self, all_kwargs: list[dict[str, Any]]):
lora_dtype = getattr(torch, self.config.dtype)
self.vllm_config.lora_config = LoRAConfig(lora_dtype=lora_dtype, **self.lora_config)
if self.config.quantization is not None:
_SUPPORTED_QUANTIZATION = ["fp8", "torchao"]
_SUPPORTED_QUANTIZATION = ["fp8", "mxfp8", "torchao"]
if self.config.quantization not in _SUPPORTED_QUANTIZATION:
raise ValueError(
f"Currently only support {_SUPPORTED_QUANTIZATION} quantization, got: {self.config.quantization}"
@@ -201,6 +201,9 @@ def _init_worker(self, all_kwargs: list[dict[str, Any]]):
# Apply vllm fp8 patches
# Will remove the patch after vllm support on-the-fly quant for rollout natively.
apply_vllm_fp8_patches()
elif self.config.quantization == "mxfp8":
# TODO(slightwindsec): apply MXFP8 patches?
pass

self.inference_engine = WorkerWrapperBase(vllm_config=self.vllm_config)
self.inference_engine.init_worker(all_kwargs)
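A small sketch (editor's illustration, not part of the patch) of the validation behaviour this change enables; check_quantization is a hypothetical stand-in that mirrors the guard in _init_worker and launch_server.

_SUPPORTED_QUANTIZATION = ["fp8", "mxfp8", "torchao"]

def check_quantization(quantization):
    # Mirrors the guard in _init_worker / launch_server: "mxfp8" is now accepted.
    if quantization is not None and quantization not in _SUPPORTED_QUANTIZATION:
        raise ValueError(f"Currently only support {_SUPPORTED_QUANTIZATION} quantization, got: {quantization}")

check_quantization("mxfp8")  # accepted after this change; previously raised ValueError
check_quantization("fp8")    # unchanged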