|
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa
 from torchao.quantization.quant_api import (
     int4_weight_only,
+    int8_weight_only,
     Int4WeightOnlyQuantizer,
     Int8DynActInt4WeightQuantizer,
     quantize_,
@@ -110,12 +111,20 @@ def quantize_model(
         if quantizer not in quantizer_class_dict:
             raise RuntimeError(f"unknown quantizer {quantizer} specified")
         else:
+            ao_quant = True
             # Use tensor subclass API for int4 weight only.
             if device == "cuda" and quantizer == "linear:int4":
                 quantize_(model, int4_weight_only(q_kwargs["groupsize"]))
+            elif quantizer == "linear:int8":
+                print("quantizer is linear int8")
+                quantize_(model, int8_weight_only())
+            else:
+                ao_quant = False
+            if ao_quant:
                 if not support_tensor_subclass:
                     unwrap_tensor_subclass(model)
                 continue
+

         if quantizer in ["linear:a8wxdq", "embedding:wx"]:
             # These quantizers require float32 input weights. Note that after quantization,
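For context, a minimal sketch (not part of the diff) of the torchao tensor-subclass flow the new `linear:int8` branch relies on; the import paths are an assumption for a recent torchao release that exposes `quantize_`, `int8_weight_only`, and `unwrap_tensor_subclass`:

import torch
import torch.nn as nn
from torchao.quantization.quant_api import int8_weight_only, quantize_
from torchao.utils import unwrap_tensor_subclass  # assumed location of the helper

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8))

# quantize_ mutates the model in place, replacing each Linear weight with an
# int8 weight-only tensor subclass; no module swap is required.
quantize_(model, int8_weight_only())

# When the downstream backend cannot handle tensor subclasses (the
# support_tensor_subclass=False case above), they are flattened back out.
unwrap_tensor_subclass(model)

out = model(torch.randn(2, 64))

This is why the new branch only needs the early `continue`: the quantization is already applied in place by torchao, so no handler class is instantiated afterwards.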
|
@@ -529,147 +538,6 @@ def linear_int8_et(input, weight, scales):
     )


-class WeightOnlyInt8Linear(nn.Module):
-    __constants__ = ["in_features", "out_features"]
-    in_features: int
-    out_features: int
-    weight: torch.Tensor
-    scales: torch.Tensor
-
-    def __init__(
-        self,
-        in_features,
-        out_features,
-        bias=None,
-        device=None,
-        dtype=None,
-        *,
-        weight: Optional[torch.Tensor] = None,
-        scales: Optional[torch.Tensor] = None,
-        groupsize: Optional[int] = None,
-    ):
-        super().__init__()
-        if dtype is None:
-            dtype = torch.get_default_dtype()
-
-        if device is None:
-            device = "cpu"
-
-        assert not bias, "Bias is not supported by LinearInt8"
-        self.in_features = in_features
-        self.out_features = out_features
-
-        assert (weight is None) == bool(
-            scales is None
-        ), "must specify both weights and scales, or neither"
-        if weight is None:
-            weight = torch.empty(
-                (out_features, in_features),
-                dtype=torch.int8,
-                device=device,
-            )
-            if groupsize is None or (groupsize == 0):
-                scales = torch.empty(out_features, dtype=dtype, device=device)
-            else:
-                n_groups = (in_features + groupsize - 1) // groupsize
-                scales = torch.empty(out_features, n_groups, dtype=dtype, device=device)
-
-        self.register_buffer("weight", weight.to(device))
-        self.register_buffer("scales", scales.to(device))
-
-        if use_et_backend():
-            self.forward = self.et_forward
-        else:
-            self.forward = self.aoti_forward
-
-    def aoti_forward(self, input: torch.Tensor) -> torch.Tensor:
-        return linear_int8_aoti(input, self.weight, self.scales)
-
-    def et_forward(self, input: torch.Tensor) -> torch.Tensor:
-        return linear_int8_et(input, self.weight, self.scales)
-
-
-class WeightOnlyInt8QuantHandler(QuantHandler):
-    def __init__(
-        self,
-        model: Optional[nn.Module] = None,
-        device = None,
-        precision=None,
-        tokenizer=None,
-        *,
-        node_type: str = "*",
-        bitwidth: Optional[int] = None,
-        groupsize: Optional[int] = None,
-    ):
-        self.model_ = model
-        self.device = device
-        self.groupsize = groupsize
-        self.node_type = node_type
-        if bitwidth is None:
-            self.bitwidth = 8
-        else:
-            self.bitwidth = bitwidth
-
-    @torch.no_grad()
-    def quantize(self, module):
-        # cur_state_dict = state_dict_device(self.model_.state_dict())
-        # dict_device = "cpu" # self.device
-
-        if self.bitwidth == 4:
-            range_min = -8
-            range_max = 7
-        elif self.bitwidth == 8:
-            range_min = -128
-            range_max = 127
-        else:
-            raise ValueError(f"Unsupported bitwidth {self.bitwidth}")
-
-        for name, child in module.named_children():
-            # print(f"name: {name}")
-            if isinstance(child, nn.Linear):
-                if (
-                    (self.node_type == "*")
-                    or (self.node_type == "output" and name == "output")
-                    or (self.node_type == "!output" and name != "output")
-                ):
-                    # print(f"{name, child}")
-                    input_weight = child.weight.float()
-                    # print(f"{name, child}")
-                    # print(f"in_features: {child.in_features}")
-                    # print(f"out_features: {child.out_features}")
-
-                    # print(f"expanded weight shape {input_weight.shape}")
-                    weight, scales, _ = dynamically_quantize_per_channel(
-                        input_weight,
-                        range_min,
-                        range_max,
-                        torch.int8,
-                        self.groupsize,
-                        scales_dtype=child.weight.dtype,
-                    )
-
-                    setattr(
-                        module,
-                        name,
-                        WeightOnlyInt8Linear(
-                            in_features=child.in_features,
-                            out_features=child.out_features,
-                            device=self.device,
-                            # update variables from quantization
-                            weight=weight,
-                            scales=scales,
-                            groupsize=self.groupsize,
-                        ),
-                    )
-            else:
-                self.quantize(child)
-
-        return module
-
-    def quantized_model(self) -> nn.Module:
-        return self.quantize(self.model_)
-
-
 #########################################################################
 ##### embedding table quantization ######
 ### (unify with torchao in future) ###
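To make clear what functionality moves into torchao here: the deleted handler performed symmetric per-channel (optionally grouped) int8 weight quantization via dynamically_quantize_per_channel, storing only int8 weights plus per-channel scales. A rough standalone sketch of that math, written for this note and not taken from either codebase (grouped scales omitted):

import torch

def int8_per_channel(w: torch.Tensor):
    # One scale per output row, symmetric range [-128, 127], matching the
    # removed handler's range_min/range_max for bitwidth == 8.
    scales = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(w / scales), -128, 127).to(torch.int8)
    return q, scales.squeeze(1)

w = torch.randn(8, 16)
q, s = int8_per_channel(w)
w_hat = q.float() * s.unsqueeze(1)   # dequantize
print((w - w_hat).abs().max())       # reconstruction error stays small

torchao's int8_weight_only applies the same per-channel weight-only scheme through its tensor-subclass machinery, which is what lets the custom WeightOnlyInt8Linear module swap be dropped.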
|
@@ -886,10 +754,10 @@ def quantized_model(self) -> nn.Module:
 # class references
 quantizer_class_dict = {
     "embedding": EmbeddingOnlyQuantHandler,
-    "linear:int8": WeightOnlyInt8QuantHandler,
     "precision": PrecisionHandler,
     "executor": ExecutorHandler,
     "linear:int4": Int4WeightOnlyQuantizer,
+    "linear:int8": int8_weight_only,
     "linear:a8w4dq": Int8DynActInt4WeightQuantizer,
 }

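Note the design choice this entry encodes: "linear:int8" now maps to torchao's int8_weight_only config constructor rather than a QuantHandler subclass. The entry only has to satisfy the membership check at the top of quantize_model, because the torchao branch quantizes in place and `continue`s before quantizer_class_dict[quantizer] is ever instantiated. A hypothetical invocation, with argument names guessed from the diff context rather than the exact signature:

# Hypothetical call; quantize_model's real parameter names may differ.
quantize_model(
    model,
    device="cpu",
    quantize_options={"linear:int8": {}},  # routed to the torchao quantize_ branch
    support_tensor_subclass=True,
)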
|
@@ -917,6 +785,7 @@ def quantized_model(self) -> nn.Module:
         IntxWeightEmbeddingQuantizer,
     )

+
     quantizer_class_dict["linear:a8wxdq"] = Int8DynActIntxWeightLinearQuantizer
     quantizer_class_dict["embedding:wx"] = IntxWeightEmbeddingQuantizer

|
|