@@ -1,22 +1,24 @@
+# Standard
 import argparse
+from functools import partial
 import itertools
 import json
 import os
+from pathlib import Path
 import random
-import sys
 import time
-from pathlib import Path
 
+# Third Party
 from aiu_fms_testing_utils.utils import aiu_setup
 from aiu_fms_testing_utils.utils.aiu_setup import dprint, rank, local_rank, world_size
 import numpy as np
 import torch
-import torch._inductor.config
+from torch import distributed as dist
 from fms.models import get_model, register_model
 from fms.models.llama import LLaMAConfig, _llama_factory_factory
-from fms.utils import fusion, generation, tokenizers
+from fms.utils import generation, tokenizers
 from fms.utils.generation import generate, pad_input_ids
-from torch import distributed as dist
+
 
 # This example script validates the LLaMA implementation by running inference on a couple of prompts.
 #
@@ -59,10 +61,27 @@
 parser.add_argument(
     "--quantization",
     type=str,
-    choices=["gptq"],
+    choices=["gptq", "int8"],
     default=None,
     help="Type of quantization of the model checkpoint",
 )
+parser.add_argument(
+    "--int8_weight_per_channel",
+    action="store_true",
+    help="Enable per-channel weight quantization in INT8 quantized model",
+)
+parser.add_argument(
+    "--int8_activ_quant_type",
+    default="per_token",
+    choices=["per_token", "per_tensor_symm", "per_tensor_asymm"],
+    type=str,
+    help="Define strategy for activation quantization in INT8 quantized model",
+)
+parser.add_argument(
+    "--int8_smoothquant",
+    action="store_true",
+    help="Enable smoothquant in INT8 quantized model",
+)
 parser.add_argument(
     "--tokenizer",
     type=str,
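For context (not part of the diff), a minimal sketch of how the new INT8 flags behave when parsed; the argument names, choices, and defaults mirror the `add_argument` calls above, while the command-line values are only illustrative:

```python
# Illustrative parse of the INT8 flags defined above; the values are made up.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--quantization", type=str, choices=["gptq", "int8"], default=None)
parser.add_argument("--int8_weight_per_channel", action="store_true")
parser.add_argument(
    "--int8_activ_quant_type",
    default="per_token",
    choices=["per_token", "per_tensor_symm", "per_tensor_asymm"],
    type=str,
)
parser.add_argument("--int8_smoothquant", action="store_true")

args = parser.parse_args(
    ["--quantization", "int8", "--int8_smoothquant", "--int8_activ_quant_type", "per_tensor_symm"]
)
assert args.quantization == "int8"
assert args.int8_smoothquant is True
assert args.int8_weight_per_channel is False  # store_true flags default to False
```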
@@ -196,19 +215,18 @@
 args = parser.parse_args()
 
 if args.quantization == "gptq":
-    GPTQ_ENABLED = True
-    try:
-        if "aiu" in args.device_type:
+    if "aiu" in args.device_type:
+        try:
             from fms_mo.aiu_addons.gptq import gptq_aiu_adapter, gptq_aiu_linear
             print("Loaded `aiu_addons` functionalities")
-        elif args.device_type != "cpu":
-            raise ValueError(f"Device {args.device_type} unsupported for GPTQ run")
-    except ImportError as e:
-        print(f"Failed to import addon packages: {e}")
-        GPTQ_ENABLED = False
-
-    if not GPTQ_ENABLED:
-        raise Exception("GPTQ not enabled")
+        except ImportError:
+            raise ImportError("Failed to import GPTQ addons from fms-mo.")
+elif args.quantization == "int8":
+    try:
+        from fms_mo.aiu_addons.i8i8 import i8i8_aiu_adapter, i8i8_aiu_linear
+        print("Loaded `aiu_addons` functionalities")
+    except ImportError:
+        raise ImportError("Failed to import INT8 addons from fms-mo.")
 
 # this is a test model config
 config = LLaMAConfig(
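The imported addon names are not referenced again in the script, which suggests these imports exist for their side effects: registering the AIU-aware GPTQ/INT8 adapters and linear modules with FMS. A hedged sketch of that pattern, assuming a plain module import triggers the same registration:

```python
# Sketch only: assumes the fms_mo addon modules register their adapters and
# linear implementations at import time, so importing the module is enough.
import importlib


def load_quantization_addons(quantization):
    """Import the fms-mo addon module matching the requested quantization mode."""
    addon_modules = {
        "gptq": "fms_mo.aiu_addons.gptq",
        "int8": "fms_mo.aiu_addons.i8i8",
    }
    if quantization not in addon_modules:
        return  # nothing to load for unquantized runs
    try:
        importlib.import_module(addon_modules[quantization])
    except ImportError as exc:
        raise ImportError(f"Failed to import {quantization} addons from fms-mo.") from exc
```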
@@ -319,6 +337,13 @@
 
 fused_weights = not args.unfuse_weights
 if args.quantization == "gptq":
+    if fused_weights and is_aiu_backend:
+        raise ValueError("GPTQ checkpoints on AIU must always run with --unfuse_weights")
+    if default_dtype is not None:
+        raise ValueError(
+            "GPTQ default_dtype must be None to preserve the checkpoint data types."
+        )
+
     if "aiu" in args.device_type:
         linear_type = "gptq_aiu"
     elif args.device_type == "cpu":
@@ -352,12 +377,51 @@
         "group_size": group_size,
         "desc_act": desc_act,
     }
-    # [ATTENTION] for GPTQ on AIU, we must always instantiate an unfused
-    # model, the adapter will take care of converting key/values from
-    # ckpt into the appropriate form for the model
-    if fused_weights:
-        raise ValueError("GPTQ checkpoints on AIU must always run with --unfuse_weights")
-    default_dtype = None  # GPTQ dtype always comes from ckpt, can't be enforced
+elif args.quantization == "int8":
+    if fused_weights and is_aiu_backend:
+        raise ValueError("INT8 checkpoints on AIU must always run with --unfuse_weights")
+    if default_dtype is not None:
+        raise ValueError(
+            "INT8 default_dtype must be None to preserve the checkpoint data types."
+        )
+
+    def select_int8_module(
+        module_name: str | None = None,
+        smoothquant: bool = True,
+        smoothquant_layers: list[str] | None = None,
+    ):
+        if module_name is None:
+            return "int8_aiu"
+        smoothquant_on_module = (
+            any([m in module_name for m in smoothquant_layers])
+            if smoothquant_layers is not None
+            else True
+        )
+        use_smoothquant = smoothquant and smoothquant_on_module
+        return "int8_smoothquant_aiu" if use_smoothquant else "int8_aiu"
+
+    if args.int8_smoothquant:
+        # TODO: consider saving this info into config during quantization
+        if any("granite" in p.lower() for p in [args.model_path, args.architecture]):
+            smoothquant_layers = ["key", "value", "w1", "wg"]
+        elif any("roberta" in p.lower() for p in [args.model_path, args.architecture]):
+            smoothquant_layers = ["query", "key", "value", "w1"]
+        else:
+            raise NotImplementedError(
+                "INT8 architecture does not support smoothquant."
+            )
+    else:
+        smoothquant_layers = []
+
+    linear_config = {
+        "linear_type": partial(
+            select_int8_module,
+            smoothquant=args.int8_smoothquant,
+            smoothquant_layers=smoothquant_layers,
+        ),
+        "weight_per_channel": args.int8_weight_per_channel,
+        "activ_quant_type": args.int8_activ_quant_type,
+    }
 else:
     linear_config = {"linear_type": "torch_linear"}
 
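To see how the `partial`-wrapped selector above resolves per module, here is a self-contained check of the same logic; the module names are hypothetical examples, not taken from a real checkpoint:

```python
# Standalone check of the select_int8_module logic; module names are made up.
from functools import partial


def select_int8_module(module_name=None, smoothquant=True, smoothquant_layers=None):
    if module_name is None:
        return "int8_aiu"
    smoothquant_on_module = (
        any(m in module_name for m in smoothquant_layers)
        if smoothquant_layers is not None
        else True
    )
    return "int8_smoothquant_aiu" if smoothquant and smoothquant_on_module else "int8_aiu"


# Mirrors linear_config["linear_type"] for a Granite-style smoothquant run.
linear_type = partial(
    select_int8_module, smoothquant=True, smoothquant_layers=["key", "value", "w1", "wg"]
)

print(linear_type("layers.0.attn.key"))    # int8_smoothquant_aiu
print(linear_type("layers.0.attn.dense"))  # int8_aiu
print(linear_type())                       # int8_aiu (no module name given)
```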
@@ -381,13 +445,13 @@
     fused_weights=fused_weights,
 )
 
-if args.quantization == "gptq":
+if args.quantization in ["gptq", "int8"]:
     if rank == 0 and args.verbose > 0:
         dprint("PARAMS:\n" + "\n".join(f"{k:60} {str(v.dtype):15} {str(v.device):10} {list(v.size())}" for k,v in model.named_parameters()))
         dprint("BUFFERS:\n" + "\n".join(f"{k:60} {str(v.dtype):15} {str(v.device):10} {list(v.size())}" for k,v in model.named_buffers()))
         dprint("="*60 + "\n")
         if args.architecture == "llama":
-            dprint("[NOTE] It's OK for unused keys to contain bias and rotary embeddings, in GPTQ LLaMA models")
+            dprint("[NOTE] In Llama models, it's OK for bias and rotary embeddings to be marked as unused keys.")
         dprint(model)
         dprint("="*60 + "\n")
 
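The verbose dump above relies only on standard `named_parameters()` / `named_buffers()`; a toy reproduction of the same format string, assuming nothing beyond stock PyTorch:

```python
# Toy reproduction of the PARAMS dump format; uses a small nn.Linear stand-in.
import torch
from torch import nn

toy = nn.Linear(4, 2)
print(
    "PARAMS:\n"
    + "\n".join(
        f"{k:60} {str(v.dtype):15} {str(v.device):10} {list(v.size())}"
        for k, v in toy.named_parameters()
    )
)
# Prints one padded line per parameter, e.g. "weight ... torch.float32 ... cpu ... [2, 4]".
```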
@@ -522,6 +586,8 @@ def truncate_prompts_to_max_length(prompts, max_len, max_allowed_length):
     ids, extra_generation_kwargs = pad_input_ids(prompts, min_pad_length=padding_length)
 else:
     ids = prompts
+    if isinstance(ids, list) and len(ids) == 1:
+        ids = ids[0].unsqueeze(0)
     extra_generation_kwargs = None
 
 
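The two added lines cover the case where `prompts` is a list holding a single 1-D tensor of token ids and `pad_input_ids` is skipped; a quick shape check of that path (the token ids are arbitrary):

```python
# Shape check for the single-prompt path added above; token ids are arbitrary.
import torch

prompts = [torch.tensor([101, 2023, 2003, 102])]  # one un-padded prompt, shape [4]

ids = prompts
if isinstance(ids, list) and len(ids) == 1:
    ids = ids[0].unsqueeze(0)  # add a batch dimension -> shape [1, 4]

assert ids.shape == (1, 4)
```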