Commit 6e7e360

Merge pull request #21 from andrea-fasoli/int8_llm
INT8 LLM support
2 parents d9d130b + e19b055 commit 6e7e360

File tree

1 file changed: +91 -25 lines changed


scripts/inference.py

@@ -1,22 +1,24 @@
+# Standard
 import argparse
+from functools import partial
 import itertools
 import json
 import os
+from pathlib import Path
 import random
-import sys
 import time
-from pathlib import Path

+# Third Party
 from aiu_fms_testing_utils.utils import aiu_setup
 from aiu_fms_testing_utils.utils.aiu_setup import dprint, rank, local_rank, world_size
 import numpy as np
 import torch
-import torch._inductor.config
+from torch import distributed as dist
 from fms.models import get_model, register_model
 from fms.models.llama import LLaMAConfig, _llama_factory_factory
-from fms.utils import fusion, generation, tokenizers
+from fms.utils import generation, tokenizers
 from fms.utils.generation import generate, pad_input_ids
-from torch import distributed as dist
+

 # This example script validates the LLaMA implementation by running inference on a couple of prompts.
 #
@@ -59,10 +61,27 @@
 parser.add_argument(
     "--quantization",
     type=str,
-    choices=["gptq"],
+    choices=["gptq", "int8"],
     default=None,
     help="Type of quantization of the model checkpoint",
 )
+parser.add_argument(
+    "--int8_weight_per_channel",
+    action="store_true",
+    help="Enable per-channel weight quantization in INT8 quantized model",
+)
+parser.add_argument(
+    "--int8_activ_quant_type",
+    default="per_token",
+    choices=["per_token", "per_tensor_symm", "per_tensor_asymm"],
+    type=str,
+    help="Define strategy for activation quantization in INT8 quantized model",
+)
+parser.add_argument(
+    "--int8_smoothquant",
+    action="store_true",
+    help="Enable smoothquant in INT8 quantized model",
+)
 parser.add_argument(
     "--tokenizer",
     type=str,
@@ -196,19 +215,18 @@
 args = parser.parse_args()

 if args.quantization == "gptq":
-    GPTQ_ENABLED = True
-    try:
-        if "aiu" in args.device_type:
+    if "aiu" in args.device_type:
+        try:
             from fms_mo.aiu_addons.gptq import gptq_aiu_adapter, gptq_aiu_linear
             print("Loaded `aiu_addons` functionalities")
-        elif args.device_type != "cpu":
-            raise ValueError(f"Device {args.device_type} unsupported for GPTQ run")
-    except ImportError as e:
-        print(f"Failed to import addon packages: {e}")
-        GPTQ_ENABLED = False
-
-    if not GPTQ_ENABLED:
-        raise Exception("GPTQ not enabled")
+        except:
+            raise ImportError("Failed to import GPTQ addons from fms-mo.")
+elif args.quantization == "int8":
+    try:
+        from fms_mo.aiu_addons.i8i8 import i8i8_aiu_adapter, i8i8_aiu_linear
+        print("Loaded `aiu_addons` functionalities")
+    except:
+        raise ImportError("Failed to import INT8 addons from fms-mo.")

 # this is a test model config
 config = LLaMAConfig(
@@ -319,6 +337,13 @@

 fused_weights = not args.unfuse_weights
 if args.quantization == "gptq":
+    if fused_weights and is_aiu_backend:
+        raise ValueError("GPTQ checkpoints on AIU must always run with --unfuse_weights")
+    if default_dtype is not None:
+        raise ValueError(
+            "GPTQ default_dtype must be None to preserve the checkpoint data types."
+        )
+
     if "aiu" in args.device_type:
         linear_type = "gptq_aiu"
     elif args.device_type == "cpu":
@@ -352,12 +377,51 @@
         "group_size": group_size,
         "desc_act": desc_act,
     }
-    # [ATTENTION] for GPTQ on AIU, we must always instantiate an unfused
-    # model, the adapter will take care of converting key/values from
-    # ckpt into the appropriate form for the model
-    if fused_weights:
-        raise ValueError("GPTQ checkpoints on AIU must always run with --unfuse_weights")
-    default_dtype = None # GPTQ dtype always comes from ckpt, can't be enforced
+elif args.quantization == "int8":
+    if fused_weights and is_aiu_backend:
+        raise ValueError("INT8 checkpoints on AIU must always run with --unfuse_weights")
+    if default_dtype is not None:
+        raise ValueError(
+            "INT8 default_dtype must be None to preserve the checkpoint data types."
+        )
+
+    def select_int8_module(
+        module_name: str | None = None,
+        smoothquant: bool = True,
+        smoothquant_layers: list[str] | None = None,
+    ):
+        if module_name is None:
+            return "int8_aiu"
+        smoothquant_on_module = (
+            any([m in module_name for m in smoothquant_layers])
+            if smoothquant_layers is not None
+            else True
+        )
+        use_smoothquant = smoothquant and smoothquant_on_module
+        return "int8_smoothquant_aiu" if use_smoothquant else "int8_aiu"
+
+    if args.int8_smoothquant:
+        # TODO: consider saving this info into config during quantization
+        if any("granite" in p.lower() for p in [args.model_path, args.architecture]):
+            smoothquant_layers = ["key", "value", "w1", "wg"]
+        elif any("roberta" in p.lower() for p in [args.model_path, args.architecture]):
+            smoothquant_layers = ["query", "key", "value", "w1"]
+        else:
+            raise NotImplementedError(
+                "INT8 architecture does not support smoothquant."
+            )
+    else:
+        smoothquant_layers = []
+
+    linear_config = {
+        "linear_type": partial(
+            select_int8_module,
+            smoothquant = args.int8_smoothquant,
+            smoothquant_layers = smoothquant_layers,
+        ),
+        "weight_per_channel": args.int8_weight_per_channel,
+        "activ_quant_type": args.int8_activ_quant_type,
+    }
 else:
     linear_config = {"linear_type": "torch_linear"}

@@ -381,13 +445,13 @@
     fused_weights=fused_weights,
 )

-if args.quantization == "gptq":
+if args.quantization in ["gptq", "int8"]:
     if rank == 0 and args.verbose > 0:
         dprint("PARAMS:\n" + "\n".join(f"{k:60} {str(v.dtype):15} {str(v.device):10} {list(v.size())}" for k,v in model.named_parameters()))
         dprint("BUFFERS:\n" + "\n".join(f"{k:60} {str(v.dtype):15} {str(v.device):10} {list(v.size())}" for k,v in model.named_buffers()))
         dprint("="*60 + "\n")
         if args.architecture == "llama":
-            dprint("[NOTE] It's OK for unused keys to contain bias and rotary embeddings, in GPTQ LLaMA models")
+            dprint("[NOTE] In Llama models, it's OK for bias and rotary embeddings to be marked as unused keys.")
         dprint(model)
         dprint("="*60 + "\n")

@@ -522,6 +586,8 @@ def truncate_prompts_to_max_length(prompts, max_len, max_allowed_length):
     ids, extra_generation_kwargs = pad_input_ids(prompts, min_pad_length=padding_length)
 else:
     ids = prompts
+    if isinstance(ids, list) and len(ids) == 1:
+        ids = ids[0].unsqueeze(0)
     extra_generation_kwargs = None

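For readers skimming the diff: in the INT8 path, the linear_config entry "linear_type" is a partial-bound callable, so each linear module can be routed to either a plain INT8 kernel or a smoothquant one based on its name. Below is a minimal standalone sketch that mirrors the select_int8_module helper added above; the module names passed in are illustrative only, not taken from a real checkpoint.

from functools import partial

# Standalone mirror of the select_int8_module logic from this commit, so the
# routing can be exercised without loading a model.
def select_int8_module(module_name=None, smoothquant=True, smoothquant_layers=None):
    if module_name is None:
        return "int8_aiu"
    smoothquant_on_module = (
        any(m in module_name for m in smoothquant_layers)
        if smoothquant_layers is not None
        else True
    )
    return "int8_smoothquant_aiu" if smoothquant and smoothquant_on_module else "int8_aiu"

# Granite-style layer list from the commit; "dense" is a made-up non-matching name.
linear_type = partial(
    select_int8_module,
    smoothquant=True,
    smoothquant_layers=["key", "value", "w1", "wg"],
)
for name in ["layers.0.attn.key", "layers.0.attn.dense", "layers.0.ff.w1"]:
    print(name, "->", linear_type(name))
# layers.0.attn.key   -> int8_smoothquant_aiu
# layers.0.attn.dense -> int8_aiu
# layers.0.ff.w1      -> int8_smoothquant_aiu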
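The last hunk also covers the case where a single un-padded prompt reaches generation as a one-element list. A minimal illustration of that branch, using dummy token ids, presumably so downstream generation sees a [batch, seq_len] tensor:

import torch

# One tokenized prompt, no padding applied; values are dummy token ids.
ids = [torch.tensor([1, 15043, 3186])]
if isinstance(ids, list) and len(ids) == 1:
    ids = ids[0].unsqueeze(0)  # add a batch dimension -> shape [1, 3]
print(ids.shape)  # torch.Size([1, 3])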