From 9773cf9f962a274b4e032d9a7365bd819a7d05c3 Mon Sep 17 00:00:00 2001
From: Alex Kogan
Date: Wed, 10 May 2023 16:08:11 -0400
Subject: [PATCH 1/2] allow controlling the number of weight quantization
 groups when calling deepspeed.init_inference

The number of weight quantization groups can now be set by passing the
'quantize_groups' argument to deepspeed.init_inference.

---
 deepspeed/__init__.py                     | 10 +++++++++-
 deepspeed/module_inject/replace_module.py |  5 +++--
 2 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py
index 255dacdccf6e..9fe3a0bcb211 100755
--- a/deepspeed/__init__.py
+++ b/deepspeed/__init__.py
@@ -21,7 +21,7 @@
 from .runtime.hybrid_engine import DeepSpeedHybridEngine
 from .runtime.pipe.engine import PipelineEngine
 from .inference.engine import InferenceEngine
-from .inference.config import DeepSpeedInferenceConfig
+from .inference.config import DeepSpeedInferenceConfig, QuantizationConfig
 from .runtime.lr_schedules import add_tuning_arguments
 from .runtime.config import DeepSpeedConfig, DeepSpeedConfigError
 from .runtime.activation_checkpointing import checkpointing
@@ -328,6 +328,14 @@ def init_inference(model, config=None, **kwargs):
             raise ValueError(f"Conflicting argument '{key}' in 'config':{config_dict[key]} and kwargs:{kwargs[key]}")
 
     config_dict.update(kwargs)
+    # Set the number of weight quantization groups if an optional 'quantize_groups' argument is given
+    if "quantize_groups" in config_dict:
+        if not ("dtype", torch.int8) in config_dict.items():
+            raise ValueError(f"'dtype' argument expected int8 when 'quantize_groups' argument is provided")
+        quant = QuantizationConfig()
+        quant.weight.q_groups = config_dict.pop("quantize_groups")
+        config_dict["quant"] = quant
+
     ds_inference_config = DeepSpeedInferenceConfig(**config_dict)
 
     engine = InferenceEngine(model, config=ds_inference_config)

diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py
index 11bea696deda..87e26904135e 100644
--- a/deepspeed/module_inject/replace_module.py
+++ b/deepspeed/module_inject/replace_module.py
@@ -291,6 +291,7 @@ def replace_transformer_layer(orig_layer_impl, model, checkpoint_dict, config, m
     # defining globals as internally defined functions inherit these everywhere
     fp16 = (config.dtype == torch.float16 or config.dtype == torch.int8)
     quantize = (config.dtype == torch.int8)
+    quantize_groups = config.quant.weight.q_groups if quantize else 0
     # todo: Refactor later. In future, let's minimize the style used above and use config.** instead
 
     linear_layer_setting = None
@@ -336,7 +337,7 @@ def replace_with_policy(child, policy_cls, triangular_masking, inference=False,
             _container.convert_to_required_dtype(dtype=torch.half)
 
         # 5. Set the quantization config
-        quantizer = GroupQuantizer(q_int8=quantize)
+        quantizer = GroupQuantizer(q_int8=quantize, num_groups=quantize_groups)
         _container.set_quantization_config(quantize, quantizer)
 
         # 6. create a DS Inference config object
@@ -500,7 +501,7 @@ def replace_fn(child, _policy, layer_id=0):
                                     replace_fn=replace_fn,
                                     _replace_policy=config.injection_policy_tuple)
 
-    quantizer = GroupQuantizer(q_int8=quantize)
+    quantizer = GroupQuantizer(q_int8=quantize, num_groups=quantize_groups)
     world_size = dist.get_world_size() if dist.is_initialized() else 1
     rank = dist.get_rank() if dist.is_initialized() else 0
    if checkpoint_dict is not None:
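
Usage sketch (illustrative, not part of the diff): with this patch applied,
the new argument could be used as below. The HuggingFace model and the group
count of 64 are placeholders; per the check added above, 'quantize_groups'
takes effect only when 'dtype' is passed explicitly as torch.int8.

    import torch
    import deepspeed
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("gpt2")

    # Quantize weights to int8, splitting each weight tensor
    # into 64 quantization groups.
    engine = deepspeed.init_inference(model,
                                      dtype=torch.int8,
                                      replace_with_kernel_inject=True,
                                      quantize_groups=64)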
From 6eb0fc4600ae7492daf6e4dace4831a46e791d7e Mon Sep 17 00:00:00 2001
From: Alex Kogan
Date: Fri, 12 May 2023 13:35:20 -0400
Subject: [PATCH 2/2] fix error message formatting

---
 deepspeed/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py
index 9fe3a0bcb211..8458c083360e 100755
--- a/deepspeed/__init__.py
+++ b/deepspeed/__init__.py
@@ -331,7 +331,7 @@ def init_inference(model, config=None, **kwargs):
     # Set the number of weight quantization groups if an optional 'quantize_groups' argument is given
     if "quantize_groups" in config_dict:
         if not ("dtype", torch.int8) in config_dict.items():
-            raise ValueError(f"'dtype' argument expected int8 when 'quantize_groups' argument is provided")
+            raise ValueError("'dtype' argument expected int8 when 'quantize_groups' argument is provided")
         quant = QuantizationConfig()
         quant.weight.q_groups = config_dict.pop("quantize_groups")
         config_dict["quant"] = quant
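
For reference, a sketch of the explicit-config form that the 'quantize_groups'
shortcut expands to internally (the field names are the ones exercised by the
patch itself; the model and the group count of 64 remain placeholders):

    import torch
    import deepspeed
    from deepspeed.inference.config import QuantizationConfig
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("gpt2")

    # Build the quantization config by hand; this mirrors what
    # init_inference does when it pops 'quantize_groups'.
    quant = QuantizationConfig()
    quant.weight.q_groups = 64

    engine = deepspeed.init_inference(model,
                                      dtype=torch.int8,
                                      replace_with_kernel_inject=True,
                                      quant=quant)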