diff --git a/inference/huggingface/text-generation/inference-test.py b/inference/huggingface/text-generation/inference-test.py
index 5baad0b91..a93a25e9a 100644
--- a/inference/huggingface/text-generation/inference-test.py
+++ b/inference/huggingface/text-generation/inference-test.py
@@ -7,26 +7,34 @@
 from utils import DSPipeline

+def bool_arg(x):
+    return x.lower() == 'true'
+
 parser = ArgumentParser()

 parser.add_argument("--name", required=True, type=str, help="model_name")
 parser.add_argument("--batch_size", default=1, type=int, help="batch size")
 parser.add_argument("--dtype", default="float16", type=str, choices=["float32", "float16", "int8"], help="data-type")
-parser.add_argument("--ds_inference", default=True, type=bool, help="enable ds-inference")
+parser.add_argument("--ds_inference", default=True, type=lambda x: bool_arg(x), help="enable ds-inference")
 parser.add_argument("--max_tokens", default=1024, type=int, help="maximum tokens used for the text-generation KV-cache")
 parser.add_argument("--max_new_tokens", default=50, type=int, help="maximum new tokens to generate")
-parser.add_argument("--greedy", default=False, type=bool, help="greedy generation mode")
-parser.add_argument("--use_meta_tensor", default=False, type=bool, help="use the meta tensors to initialize model")
-parser.add_argument("--use_cache", default=True, type=bool, help="use cache for generation")
+parser.add_argument("--greedy", default=False, type=lambda x: bool_arg(x), help="greedy generation mode")
+parser.add_argument("--use_meta_tensor", default=False, type=lambda x: bool_arg(x), help="use the meta tensors to initialize model")
+parser.add_argument("--hf_low_cpu_mem_usage", default=False, type=lambda x: bool_arg(x), help="use the low_cpu_mem_usage flag in huggingface to initialize model")
+parser.add_argument("--use_cache", default=True, type=lambda x: bool_arg(x), help="use cache for generation")
 parser.add_argument("--local_rank", type=int, default=0, help="local rank")
 args = parser.parse_args()

 world_size = int(os.getenv('WORLD_SIZE', '1'))

+if args.use_meta_tensor and args.hf_low_cpu_mem_usage:
+    raise ValueError("Cannot use both use_meta_tensor and hf_low_cpu_mem_usage")
+
 data_type = getattr(torch, args.dtype)

 pipe = DSPipeline(model_name=args.name,
                   dtype=data_type,
                   is_meta=args.use_meta_tensor,
+                  is_hf_low_cpu_mem_usage=args.hf_low_cpu_mem_usage,
                   device=args.local_rank)

 if args.use_meta_tensor:
diff --git a/inference/huggingface/text-generation/utils.py b/inference/huggingface/text-generation/utils.py
index 39ddbd9a3..06dfe3f5c 100644
--- a/inference/huggingface/text-generation/utils.py
+++ b/inference/huggingface/text-generation/utils.py
@@ -20,6 +20,7 @@ def __init__(self,
                  model_name='bigscience/bloom-3b',
                  dtype=torch.float16,
                  is_meta=True,
+                 is_hf_low_cpu_mem_usage=False,
                  device=-1
                  ):
         self.model_name = model_name
@@ -48,7 +49,7 @@ def __init__(self,
             with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
                 self.model = AutoModelForCausalLM.from_config(self.config)
         else:
-            self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
+            self.model = AutoModelForCausalLM.from_pretrained(self.model_name, low_cpu_mem_usage=is_hf_low_cpu_mem_usage)

         self.model.eval()
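
The switch from `type=bool` to the `bool_arg` helper matters because argparse applies `bool()` to the raw command-line string, and any non-empty string, including `"false"`, is truthy. A minimal standalone sketch of the behavior this change fixes (the `--broken` and `--fixed` flag names are made up for illustration):

```python
from argparse import ArgumentParser

def bool_arg(x):
    # Same helper as in the diff: only the literal string "true" (case-insensitive) maps to True.
    return x.lower() == 'true'

parser = ArgumentParser()
parser.add_argument("--broken", default=False, type=bool)      # old pattern: bool() on the raw string
parser.add_argument("--fixed", default=False, type=bool_arg)   # new pattern: parse the string explicitly

args = parser.parse_args(["--broken", "false", "--fixed", "false"])
print(args.broken)  # True  -- bool("false") is truthy, which is the bug
print(args.fixed)   # False -- bool_arg parses the string as intended
```

With the helper in place, invocations such as `--ds_inference false` or `--use_cache false` now disable the corresponding feature as the user would expect.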
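
The new `--hf_low_cpu_mem_usage` flag is forwarded to `from_pretrained(..., low_cpu_mem_usage=...)`, which loads checkpoint weights into an empty model skeleton so peak host memory stays near one copy of the weights rather than two; it is mutually exclusive with `--use_meta_tensor`, which instead builds the model from its config on the meta device. A hedged sketch of the loading call the flag enables (the model name is purely illustrative):

```python
from transformers import AutoModelForCausalLM

# Load with Hugging Face's low-CPU-memory path: weights are streamed into an
# uninitialized model instead of first materializing a full randomly-initialized copy.
model = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom-3b",
    low_cpu_mem_usage=True,
)
```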