Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions auto_round/compressors/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,14 +864,15 @@ def get_fp_layer_names(model: torch.nn.Module, ignore_layers: str):
list: A list of layer names that match the specified FP layers or are
subcomponents of those layers.
"""
from auto_round.utils import SUPPORTED_LAYER_TYPES
from auto_round.utils import INNER_SUPPORTED_LAYER_TYPES, SUPPORTED_LAYER_TYPES

if not ignore_layers:
return []
ignore_layers = ignore_layers.replace(" ", "").split(",")
all_layer_names = []
for n, m in model.named_modules():
if type(m) in SUPPORTED_LAYER_TYPES:
# if type(m) in SUPPORTED_LAYER_TYPES:
if type(m) in SUPPORTED_LAYER_TYPES or m.__class__.__name__ in INNER_SUPPORTED_LAYER_TYPES:
all_layer_names.append(n)
not_to_quantized_layers = []

Expand Down
8 changes: 7 additions & 1 deletion auto_round/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -998,11 +998,17 @@ def check_seqlen_compatible(input_seqlen, tokenizer=None, model=None):
)


from transformers.modeling_utils import no_init_weights as skip_weights_initialize


def convert_fp8_layer_to_linear(layer, dtype=torch.bfloat16, device: str = "cpu"):
""" """
from auto_round.schemes import QuantizationScheme

new_layer = torch.nn.Linear(layer.in_features, layer.out_features, bias=layer.bias is not None, dtype=dtype)
# if "indexer" in getattr(layer, "tmp_name", ""):
# return layer # skip indexer layer
with skip_weights_initialize():
new_layer = torch.nn.Linear(layer.in_features, layer.out_features, bias=layer.bias is not None, dtype=dtype)
if layer.bias is not None:
new_layer.bias.data.copy_(layer.bias.data.to(dtype=dtype))
scheme_keys = (f.name for f in fields(QuantizationScheme))
Expand Down
45 changes: 45 additions & 0 deletions ds32/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
### Prerequisite
- https://github.com/yiliu30/transformers/tree/457-ds32
- https://github.com/intel/auto-round/tree/ds-v32

### Quantize
```bash
export MODEL_NAME=/storage/yiliu7/deepseek-ai/DeepSeek-V3.2/
export OUTPUT_DIR=/storage/yiliu7/deepseek-ai/DeepSeek-V3.2-W416
python quant_ds_v32.py --model_name $MODEL_NAME --output_dir $OUTPUT_DIR
```

### Eval

- https://docs.vllm.ai/projects/recipes/en/latest/DeepSeek/DeepSeek-V3_2.html#launching-deepseek-v32
```
uv venv
source .venv/bin/activate
uv pip install git+https://github.com/deepseek-ai/DeepGEMM.git@v2.1.1.post3 --no-build-isolation
git clone https://github.com/vllm-project/vllm.git
git checkout 773d7073a
VLLM_USE_PRECOMPILED=1 uv pip install --editable .
```

```bash
VLLM_ALLREDUCE_USE_SYMM_MEM=0 NCCL_NVLS_ENABLE=0 VLLM_USE_FUSED_MOE_GROUPED_TOPK=0 \
vllm serve /storage/yiliu7/ds-v32-exp/ \
--tensor-parallel-size 4 \
--tokenizer-mode deepseek_v32 \
--tool-call-parser deepseek_v32 \
--enable-auto-tool-choice \
--reasoning-parser deepseek_v3
```

```bash
lm_eval --model local-completions \
--model_args "model=/storage/yiliu7/ds-v32-exp/,base_url=http://0.0.0.0:8000/v1/completions,max_length=8192,tokenized_requests=False,tokenizer_backend=None,num_concurrent=32" \
--tasks gsm8k \
--num_fewshot 5


VLLM_ALLREDUCE_USE_SYMM_MEM=0 NCCL_NVLS_ENABLE=0 VLLM_USE_FUSED_MOE_GROUPED_TOPK=0 vllm serve /storage/yiliu7/ds-v32-exp/ --tensor-parallel-size 4
lm_eval --model local-completions --model_args "model=/storage/yiliu7/ds-v32-exp/"
# lm-eval --model local-completions --tasks gsm8k --model_args model=/storage/yiliu7/ds-v32-exp/,base_url=http://127.0.0.1:8000/v1/completions,num_concurrent=100,max_retries=3,tokenized_requests=False
# lm-eval --model local-completions --tasks gsm8k --model_args model=deepseek-ai/DeepSeek-V3.2-Exp,base_url=http://127.0.0.1:8000/v1/completions,num_concurrent=100,max_retries=3,tokenized_requests=False
```
Loading
Loading