Skip to content

Commit 370c04c

Browse files
[model_free_ptq] Enhance to work with previously quantized checkpoints like nvidia/DeepSeek-R1-NVFP4 (#2228)
Prerequisites (tests will fail until merged):
- [x] vllm-project/compressed-tensors#607

SUMMARY:

This PR enhances the `model_free_ptq` entrypoint to work with previously quantized checkpoints. The added example extends the `nvidia/DeepSeek-R1-NVFP4` nvfp4-quantized checkpoint to:
- [x] convert modelopt's NVFP4 format to create compressed-tensors (CT) format, for the corresponding mlp/expert layers.
- [x] quantize all compatible linear self_attn layers to FP8_BLOCK, including ones with a shape not exactly divisible by block_size[1].
- [x] merge the two quantization configs into a single compressed-tensors config under the "quantization_config" key of `config.json`.

Changes to src:
- [x] removes the "targets must be Linear" constraint from model_free_ptq, as it is no longer an issue in vllm.
- [x] imports the Converter abstraction from the compressed-tensors convert_checkpoint entrypoint, so that conversion from modelopt NVFP4 to CT format can happen at the same time as converting layers to some compressed form.
- [x] moves some helper code to compressed-tensors and imports it here instead.

TEST PLAN:
- [x] Checkpoint (and script to run) available at https://huggingface.co/bdellabe/DeepSeek-R1-NVFP4-FP8-BLOCK. Works in vllm 0.15.1.
- [x] Confirmed the checkpoint is equivalent when running `convert_checkpoint(..., converter=...) + model_free_ptq(..., converter=None)` vs. `model_free_ptq(..., converter=...)`.
- [x] Updated model_free_ptq tests to ensure they work when stacked (equivalently, when a CT checkpoint is used as the input model to model_free_ptq) and that the resulting quantization config is correct.

---------

Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>
Signed-off-by: Brian Dellabetta <brian-dellabetta@users.noreply.github.com>
Co-authored-by: Dipika Sikka <dipikasikka1@gmail.com>
1 parent aebbc93 commit 370c04c

File tree

14 files changed

+395
-170
lines changed

14 files changed

+395
-170
lines changed
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
"""
Example: extend the nvfp4-quantized nvidia/DeepSeek-R1-NVFP4 checkpoint.

Converts modelopt's NVFP4 format into the compressed-tensors format and
applies FP8-Block quantization to the model's compatible self_attn
Linear layers. The quantized checkpoint is written to SAVE_DIR.
"""

from compressed_tensors.entrypoints.convert import ModelOptNvfp4Converter
from compressed_tensors.quantization import QuantizationScheme
from compressed_tensors.quantization.quant_scheme import FP8_BLOCK

from llmcompressor import model_free_ptq

MODEL_ID = "nvidia/DeepSeek-R1-NVFP4"
# Derive the output directory from the last path segment of the stub,
# e.g. "DeepSeek-R1-NVFP4-FP8-BLOCK"
SAVE_DIR = MODEL_ID.rstrip("/").rsplit("/", 1)[-1] + "-FP8-BLOCK"


# Convert modelopt NVFP4 format to compressed-tensors format and
# apply FP8-Block to the model's compatible self_attn Linear layers.
# Once quantized, the model is saved to SAVE_DIR.
model_free_ptq(
    model_stub=MODEL_ID,
    save_directory=SAVE_DIR,
    scheme=QuantizationScheme(
        **FP8_BLOCK,
        targets=[
            # Fused layers must share one quant config.
            # Shape 576x7168 is compatible with block size 128x128:
            # - self_attn.kv_a_proj_with_mqa
            # - self_attn.q_a_proj
            "re:.*self_attn.(kv_a_proj_with_mqa|q_a_proj)$",
            # self_attn.kv_b_proj is skipped (already dequantized by MLA).
            # Remaining self_attn layers to quantize:
            # - self_attn.o_proj
            # - self_attn.q_b_proj
            "re:.*self_attn.(o_proj|q_b_proj).*",
        ],
    ),
    max_workers=8,
    device="cuda:0",
    converter=ModelOptNvfp4Converter(
        targets=[
            # nvidia/DeepSeek-R1-NVFP4's nvfp4-quantized layers, found by
            # inspection of the checkpoint:
            # - model.layers.0.mlp.{down,gate,up}_proj.weight
            # - model.layers.3.mlp.shared_experts.{down,gate,up}_proj.weight
            # - model.layers.3.mlp.experts.0.{down,gate,up}_proj.weight
            # NOTE: gate_up_proj must also be targeted; gate/up are fused
            "re:.*mlp.*(gate_up|gate|up|down)_proj$"
        ]
    ),
)
Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,30 @@
11
import os
22
import shutil
3-
from concurrent.futures import ThreadPoolExecutor, as_completed
43
from pathlib import Path
54
from typing import Iterable, Optional
65

76
import torch
8-
import tqdm
7+
from compressed_tensors.entrypoints.convert import (
8+
Converter,
9+
exec_jobs,
10+
get_checkpoint_files,
11+
is_weights_file,
12+
update_safetensors_index,
13+
)
914
from compressed_tensors.quantization import QuantizationScheme
1015
from loguru import logger
1116

1217
from llmcompressor.entrypoints.model_free.helpers import gpu_if_available
1318
from llmcompressor.entrypoints.model_free.microscale import (
1419
is_microscale_scheme,
1520
)
16-
from llmcompressor.entrypoints.model_free.model_utils import (
17-
get_checkpoint_files,
18-
is_weights_file,
19-
)
2021
from llmcompressor.entrypoints.model_free.process import (
2122
process_file,
2223
process_file_microscale_scheme,
2324
validate_file,
2425
)
2526
from llmcompressor.entrypoints.model_free.save_utils import (
2627
update_config,
27-
update_safetensors_index,
2828
)
2929
from llmcompressor.entrypoints.model_free.validate import (
3030
validate_safetensors_index,
@@ -41,6 +41,7 @@ def model_free_ptq(
4141
ignore: Iterable[str] = tuple(),
4242
max_workers: int = 1,
4343
device: Optional[torch.device | str] = None,
44+
converter: Converter | None = None,
4445
):
4546
"""
4647
Quantize a model without the need for a model definition. This function operates on
@@ -52,6 +53,10 @@ def model_free_ptq(
5253
ignored
5354
:param max_workers: number of worker threads to process files with
5455
:param device: gpu device to accelerate quantization with
56+
:param converter: optional converter to apply to the checkpoint to convert it to
57+
compressed-tensors format before running model-free PTQ
58+
e.g. conversion of some layers from modelopt format to compressed-tensors
59+
See compressed-tensors convert_checkpoint entrypoint for more information
5560
"""
5661
# validate arguments
5762
model_files = get_checkpoint_files(model_stub)
@@ -70,7 +75,9 @@ def model_free_ptq(
7075
save_path = Path(save_directory) / file_path
7176

7277
if file_path.endswith("safetensors"):
73-
jobs.append((job_fn, resolved_path, save_path, scheme, ignore, device))
78+
jobs.append(
79+
(job_fn, resolved_path, save_path, scheme, ignore, device, converter)
80+
)
7481

7582
else:
7683
if is_weights_file(file_path):
@@ -79,25 +86,19 @@ def model_free_ptq(
7986
logger.info(f"Copying {file_path} {save_path}")
8087
shutil.copyfile(resolved_path, save_path)
8188

82-
with ThreadPoolExecutor(max_workers) as executor:
83-
# 1. validate quantizable tensors fail fast before long-running quantization
84-
futures = [executor.submit(validate_file, *job[1:]) for job in jobs]
85-
for future in tqdm.tqdm(
86-
as_completed(futures), total=len(futures), desc="Validating"
87-
):
88-
future.result()
89+
# 1. validate quantizable tensors fail fast before long-running quantization
90+
exec_jobs(
91+
[(validate_file, *job[1:]) for job in jobs], max_workers, desc="Validating"
92+
)
8993

90-
# 2-5. quantize and compress weights
91-
total_size = 0
92-
weight_map = dict()
93-
futures = [executor.submit(*job) for job in jobs]
94-
for future in tqdm.tqdm(
95-
as_completed(futures), total=len(futures), desc="Quantizing"
96-
):
97-
_total_size, _weight_map = future.result()
98-
total_size += _total_size
99-
weight_map.update(_weight_map)
94+
# 2-5. quantize and compress weights
95+
total_size = 0
96+
weight_map = dict()
97+
quantize_results = exec_jobs(jobs, max_workers, desc="Quantizing")
98+
for _total_size, _weight_map in quantize_results:
99+
total_size += _total_size
100+
weight_map.update(_weight_map)
100101

101102
# 5. update config and safetensors index
102-
update_config(save_directory, scheme_name, scheme, ignore)
103+
update_config(save_directory, scheme_name, scheme, ignore, converter)
103104
update_safetensors_index(save_directory, total_size, weight_map)

src/llmcompressor/entrypoints/model_free/helpers.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,13 @@
1-
import os
21
import re
32
from collections import defaultdict
43
from typing import Mapping, TypeVar
54

65
import torch
76
from compressed_tensors.utils.match import match_name
87
from loguru import logger
9-
from transformers.file_utils import CONFIG_NAME
108

119
__all__ = [
1210
"gpu_if_available",
13-
"find_safetensors_index_path",
14-
"find_config_path",
1511
"find_safetensors_index_file",
1612
"match_names_set_eager",
1713
"MatchedNamesSet",
@@ -43,22 +39,6 @@ def gpu_if_available(device: torch.device | str | None) -> torch.device:
4339
return torch.device("cpu")
4440

4541

46-
def find_safetensors_index_path(save_directory: str | os.PathLike) -> str | None:
47-
for file_name in os.listdir(save_directory):
48-
if file_name.endswith("safetensors.index.json"):
49-
return os.path.join(save_directory, file_name)
50-
51-
return None
52-
53-
54-
def find_config_path(save_directory: str | os.PathLike) -> str | None:
55-
for file_name in os.listdir(save_directory):
56-
if file_name in (CONFIG_NAME, "params.json"):
57-
return os.path.join(save_directory, file_name)
58-
59-
return None
60-
61-
6242
def find_safetensors_index_file(model_files: dict[str, str]) -> str | None:
6343
for file_path, resolved_path in model_files.items():
6444
if file_path.endswith("safetensors.index.json"):

src/llmcompressor/entrypoints/model_free/model_utils.py

Lines changed: 0 additions & 48 deletions
This file was deleted.

src/llmcompressor/entrypoints/model_free/process.py

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import os
22
from collections import defaultdict
3-
from collections.abc import Iterator, Mapping
43
from typing import Iterable
54

65
import torch
6+
from compressed_tensors.entrypoints.convert import Converter
77
from compressed_tensors.quantization import QuantizationScheme
8-
from compressed_tensors.utils.match import match_name
8+
from compressed_tensors.utils import match_quantizable_tensors
99
from safetensors.torch import load_file, save_file
1010
from torch.nn import Module
1111

@@ -21,21 +21,11 @@
2121
is_microscale_scheme,
2222
)
2323

24-
__all__ = ["validate_file", "process_file", "process_file_microscale_scheme"]
25-
26-
27-
def iter_quantizable_tensors(
28-
tensors: Mapping[str, torch.Tensor],
29-
ignore: Iterable[str],
30-
) -> Iterator[tuple[str, str]]:
31-
for name in list(tensors.keys()):
32-
module_name, param_name = name.rsplit(".", 1)
33-
is_linear_weight = param_name == "weight" and not module_name.endswith("norm")
34-
is_ignored = any(match_name(module_name, ign) for ign in ignore)
35-
if not is_linear_weight or is_ignored:
36-
continue
37-
38-
yield module_name, name
24+
__all__ = [
25+
"validate_file",
26+
"process_file",
27+
"process_file_microscale_scheme",
28+
]
3929

4030

4131
def validate_file(
@@ -44,6 +34,7 @@ def validate_file(
4434
scheme: QuantizationScheme,
4535
ignore: Iterable[str],
4636
device: str | torch.device,
37+
converter: Converter | None = None,
4738
):
4839
"""
4940
Validate that each quantizable tensor in a safetensors file can be quantized.
@@ -52,10 +43,15 @@ def validate_file(
5243
:param scheme: quantization scheme to apply to tensors
5344
:param ignore: modules to ignore. Modules ending with "norm" are automatically
5445
ignored
46+
:param converter: optional converter to apply to the checkpoint,
47+
e.g. conversion of some layers from some format to compressed-tensors
5548
"""
5649
tensors = load_file(file_path)
5750

58-
for _, name in iter_quantizable_tensors(tensors, ignore):
51+
if converter is not None:
52+
converter.validate(tensors)
53+
54+
for _, name in match_quantizable_tensors(tensors, ignore, scheme.targets):
5955
validate_weight_for_quantization(tensors[name], scheme, name)
6056

6157

@@ -65,6 +61,7 @@ def process_file(
6561
scheme: QuantizationScheme,
6662
ignore: Iterable[str],
6763
device: str | torch.device,
64+
converter: Converter | None = None,
6865
) -> tuple[int, dict[str, str]]:
6966
"""
7067
Quantize and compress tensors in a given safetensors file
@@ -75,11 +72,16 @@ def process_file(
7572
:param ignore: modules to ignore. Modules ending with "norm" are automatically
7673
ignored
7774
:param device: device used to quantize and compress weights
75+
:param converter: optional converter to apply to the checkpoint,
76+
e.g. conversion of some layers from some format to compressed-tensors
7877
"""
7978
assert not is_microscale_scheme(scheme), "Use `_process_file_microscale_scheme`"
8079
tensors = load_file(file_path)
8180

82-
for module_name, name in iter_quantizable_tensors(tensors, ignore):
81+
if converter is not None:
82+
converter.process(tensors)
83+
84+
for module_name, name in match_quantizable_tensors(tensors, ignore, scheme.targets):
8385
validate_weight_for_quantization(tensors[name], scheme, name)
8486

8587
# 1. initialize module with qparams (on device)
@@ -109,6 +111,7 @@ def process_file_microscale_scheme(
109111
scheme: QuantizationScheme,
110112
ignore: Iterable[str],
111113
device: str | torch.device,
114+
converter: Converter | None = None,
112115
) -> tuple[int, dict[str, str]]:
113116
"""
114117
Quantize and compress tensors in a given safetensors file
@@ -119,9 +122,15 @@ def process_file_microscale_scheme(
119122
:param ignore: modules to ignore. Modules ending with "norm" are automatically
120123
ignored
121124
:param device: device used to quantize and compress weights
125+
:param converter: optional converter to apply to the checkpoint,
126+
e.g. conversion of some layers from some format to compressed-tensors
122127
"""
123128
assert is_microscale_scheme(scheme), "Use `_process_file` for non-microscale scheme"
124129
tensors = load_file(file_path)
130+
131+
if converter is not None:
132+
converter.process(tensors)
133+
125134
fused_sets, unmatched_sets = get_fused_names(tensors)
126135
assert len(unmatched_sets) <= 0 # should be caught by `validate_safetensors_index`
127136

@@ -135,7 +144,7 @@ def process_file_microscale_scheme(
135144
}
136145
fused_modules = defaultdict(dict)
137146

138-
for module_name, name in iter_quantizable_tensors(tensors, ignore):
147+
for module_name, name in match_quantizable_tensors(tensors, ignore, scheme.targets):
139148
validate_weight_for_quantization(tensors[name], scheme, name)
140149

141150
# 1. initialize module with qparams (on device)

src/llmcompressor/entrypoints/model_free/reindex_fused_weights.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@
77

88
import torch
99
import tqdm
10+
from compressed_tensors.entrypoints.convert import (
11+
get_checkpoint_files,
12+
is_weights_file,
13+
update_safetensors_index,
14+
)
1015
from loguru import logger
1116
from safetensors.torch import load_file, save_file
1217

@@ -15,11 +20,6 @@
1520
invert_mapping,
1621
)
1722
from llmcompressor.entrypoints.model_free.microscale import get_fused_names
18-
from llmcompressor.entrypoints.model_free.model_utils import (
19-
get_checkpoint_files,
20-
is_weights_file,
21-
)
22-
from llmcompressor.entrypoints.model_free.save_utils import update_safetensors_index
2323

2424

2525
def parse_args():

0 commit comments

Comments
 (0)