diff --git a/conversion/README.md b/conversion/README.md
new file mode 100644
index 0000000..9a2b8aa
--- /dev/null
+++ b/conversion/README.md
@@ -0,0 +1,48 @@
+# Conversion
+
+The following example demonstrates how to convert a GPTQ model from the HF hub to Marlin format.
+
+### Install
+
+In addition to Marlin and PyTorch, install the following:
+
+```bash
+pip install -U transformers accelerate auto-gptq optimum
+```
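+
+Both the conversion script and Marlin inference require a CUDA GPU, so it can help to confirm that PyTorch sees one before starting (a quick sanity check; it is not part of the conversion scripts below):
+
+```python
+import torch
+
+# convert.py packs weights on the GPU and load.py runs the Marlin kernels on it.
+assert torch.cuda.is_available(), "No CUDA device found"
+print(torch.cuda.get_device_name(0))
+```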
+
+### Convert GPTQ Model to Marlin Format
+
+The following converts the model from GPTQ to Marlin format. Note that the source GPTQ model must have been quantized with the settings below (a quick way to verify them is shown after the command):
+- `sym=true`
+- `group_size=128`
+- `desc_act=false`
+
+```bash
+python3 convert.py --model-id "TheBloke/Llama-2-7B-Chat-GPTQ" --save-path "./marlin-model" --do-generation
+```
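+
+If you are unsure whether a checkpoint meets these requirements, you can inspect its quantization config before converting. This is a minimal sketch that assumes the checkpoint stores a standard GPTQ `quantization_config` in its `config.json` (as `TheBloke/Llama-2-7B-Chat-GPTQ` does); `convert.py` performs the same checks and raises a descriptive error if any of them fail:
+
+```python
+from transformers import AutoConfig
+
+# Substitute the GPTQ checkpoint you want to convert.
+qcfg = AutoConfig.from_pretrained("TheBloke/Llama-2-7B-Chat-GPTQ").quantization_config
+
+# Marlin conversion requires 4-bit, symmetric quantization with group_size=128 and no act-order.
+assert qcfg["bits"] == 4 and qcfg["group_size"] == 128
+assert qcfg["sym"] and not qcfg["desc_act"]
+```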
+
+### Load Marlin Model
+
+The following loads the Marlin model from disk and runs a short generation to confirm it works.
+
+```python
+from load import load_model
+from transformers import AutoTokenizer
+
+# Load model from disk.
+model_path = "./marlin-model"
+model = load_model(model_path).to("cuda")
+tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+
+# Run inference to confirm it is working.
+inputs = tokenizer("My favorite song is", return_tensors="pt")
+inputs = {k: v.to("cuda") for k, v in inputs.items()}
+outputs = model.generate(**inputs, max_new_tokens=50, do_sample=False)
+print(tokenizer.batch_decode(outputs)[0])
+```
+
+Output:
+```bash
+ My favorite song is "Bohemian Rhapsody" by Queen. I love the operatic vocals, the guitar solo, and the way the song builds from a slow ballad to a full-on rock anthem. I've been listening to it
+```
\ No newline at end of file
diff --git a/conversion/convert.py b/conversion/convert.py
new file mode 100644
index 0000000..ae35539
--- /dev/null
+++ b/conversion/convert.py
@@ -0,0 +1,161 @@
+import torch, argparse, copy
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from auto_gptq.nn_modules.qlinear.qlinear_exllama import QuantLinear
+from marlin import Layer as MarlinLayer
+import gc
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--model-id", type=str)
+parser.add_argument("--save-path", type=str)
+parser.add_argument("--do-generation", action="store_true")
+
+def _validate_compatibility(model):
+    if not hasattr(model.config, "quantization_config"):
+        raise ValueError("Must be a quantized model to convert to Marlin Format")
+    quantization_config = model.config.quantization_config
+    if quantization_config.quant_method != "gptq":
+        raise ValueError(f"Only GPTQ models can be converted to Marlin format. You passed a model with quant_method={quantization_config.quant_method}")
+    if quantization_config.bits != 4:
+        raise ValueError(f"Only 4 bit quantized models can be converted to Marlin format. You passed a model with bits={quantization_config.bits}")
+    if quantization_config.group_size != 128:
+        raise ValueError(f"Only group size 128 models can be converted to Marlin format. You passed a model with group_size={quantization_config.group_size}")
+    if not quantization_config.sym:
+        raise ValueError(f"Only models with symmetric quantization can be converted to Marlin Format. You passed a model with sym={quantization_config.sym}")
+    if quantization_config.desc_act:
+        raise ValueError(f"Models with act order quantization cannot be converted to Marlin Format. You passed a model with desc_act={quantization_config.desc_act}")
+
+@torch.no_grad()
+def unpack_4bit_to_32bit_signed(qweight, qzeros):
+    # Unpack the 4-bit values (8 per int32) into int8 tensors.
+    unpacked_weights = torch.zeros((qweight.shape[0]*8, qweight.shape[1]), dtype=torch.int8, device=qweight.device, requires_grad=False)
+    unpacked_zeros = torch.zeros((qzeros.shape[0], qzeros.shape[1]*8), dtype=torch.int8, device=qzeros.device, requires_grad=False)
+
+    for row in range(unpacked_weights.shape[0]):
+        i = row % 8
+        unpacked_weights[row, :] = (qweight[row // 8, :] >> (4 * i)) & 0xF
+
+    for col in range(unpacked_zeros.shape[1]):
+        i = col % 8
+        unpacked_zeros[:, col] = (qzeros[:, col // 8] >> (4 * i)) & 0xF
+
+    return unpacked_weights, unpacked_zeros + 1
+
+@torch.no_grad()
+def dequantize_weight(layer):
+    # Dequantize as w = (q - z) * s, broadcasting scales and zeros over each group.
+    qweight, qzeros, scales = layer.qweight, layer.qzeros, layer.scales
+    unpacked_qweight, unpacked_qzeros = unpack_4bit_to_32bit_signed(qweight, qzeros)
+    group_size = unpacked_qweight.shape[0] // scales.shape[0]
+    scales = scales.repeat_interleave(group_size, dim=0)
+    unpacked_qzeros = unpacked_qzeros.repeat_interleave(group_size, dim=0)
+    unpacked_qweight = (unpacked_qweight - unpacked_qzeros) * scales
+
+    return unpacked_qweight.T
+
+@torch.no_grad()
+def convert_model(model, verbose=True):
+    for name, module in model.named_modules():
+        if not isinstance(module, QuantLinear):
+            continue
+
+        if verbose:
+            print(f"--- Converting Module: {name}")
+        parent_name = ".".join(name.split(".")[:-1])
+        layer_name = name[len(parent_name) + 1:]
+
+        # Dequantize the weight.
+        dequantized_weight = dequantize_weight(module).to(torch.float16)
+        linear_module = torch.nn.Linear(
+            in_features=dequantized_weight.shape[1],
+            out_features=dequantized_weight.shape[0],
+            bias=False,
+            dtype=torch.float16,
+            device="cuda")
+        linear_module.weight.data.copy_(dequantized_weight)
+
+        # Create a new Marlin layer and pack the dequantized weight into it.
+        new_module = MarlinLayer(
+            infeatures=linear_module.in_features,
+            outfeatures=linear_module.out_features,
+            groupsize=model.config.quantization_config.group_size)
+        new_module.pack(linear_module, scales=copy.deepcopy(module.scales.data.t()))
+
+        # Save to parent.
+        parent_module = model.get_submodule(parent_name)
+        setattr(parent_module, layer_name, new_module)
+
+        # Free cuda memory.
+        del dequantized_weight, module
+        torch.cuda.empty_cache()
+        gc.collect()
+
+    return model
+
+@torch.no_grad()
+def dequantize_model(model, verbose=True):
+    for name, module in model.named_modules():
+        if not isinstance(module, QuantLinear):
+            continue
+
+        if verbose:
+            print(f"--- Dequantizing Module: {name}")
+        parent_name = ".".join(name.split(".")[:-1])
+        layer_name = name[len(parent_name) + 1:]
+
+        # Dequantize the weight.
+        dequantized_weight = dequantize_weight(module)
+        dequantized_weight_cpu = dequantized_weight.to("cpu")
+
+        # Create a new fp16 Linear module and copy in the dequantized weight.
+        new_module = torch.nn.Linear(
+            in_features=dequantized_weight_cpu.shape[1],
+            out_features=dequantized_weight_cpu.shape[0],
+            bias=False,
+            dtype=torch.float16)
+        new_module.weight.data.copy_(dequantized_weight_cpu)
+        new_module.scales = torch.nn.Parameter(copy.deepcopy(module.scales.data))
+
+        # Save to parent.
+        parent_module = model.get_submodule(parent_name)
+        setattr(parent_module, layer_name, new_module)
+
+        # Free cuda memory.
+        del dequantized_weight, dequantized_weight_cpu, module
+        torch.cuda.empty_cache()
+
+    return model
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+    model_id = args.model_id
+    save_path = args.save_path
+    do_generation = args.do_generation
+
+    print("Loading gptq model...")
+    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+    # Validate that this model is compatible with Marlin.
+    print("Validating compatibility...")
+    _validate_compatibility(model)
+
+    # Convert the model to Marlin format.
+    print("Converting model...")
+    model = convert_model(model).to("cpu")
+
+    # Save after updating quantization config.
+    print("Saving marlin model...")
+    model.config.quantization_config = {
+        "group_size": model.config.quantization_config.group_size,
+        "quant_method": "marlin"
+    }
+    model.save_pretrained(save_path)
+    tokenizer.save_pretrained(save_path)
+
+    if do_generation:
+        print("Generating sample text...")
+        model.to("cuda")
+        prompt = "My favorite song is"
+        inputs = tokenizer(prompt, return_tensors="pt")
+        inputs = {k: v.to("cuda") for k, v in inputs.items()}
+        outputs = model.generate(**inputs, max_new_tokens=50, do_sample=False)
+        print(tokenizer.batch_decode(outputs)[0])
diff --git a/conversion/load.py b/conversion/load.py
new file mode 100644
index 0000000..834bf3c
--- /dev/null
+++ b/conversion/load.py
@@ -0,0 +1,172 @@
+import torch
+import numpy as np
+from huggingface_hub import snapshot_download
+from safetensors.torch import safe_open
+from typing import Optional, Tuple, List, Iterator
+import os, filelock, json, glob
+from accelerate import init_empty_weights
+from transformers import AutoModelForCausalLM, AutoConfig
+import marlin
+
+# Adapted from https://github.com/vllm-project/vllm/blob/14cc317ba48229d93ee2417822d96ccb8db56abe/vllm/model_executor/weight_utils.py#L191
+
+def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
+    lock_dir = cache_dir if cache_dir is not None else "/tmp"
+    lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
+    lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name))
+    return lock
+
+def prepare_hf_model_weights(
+    model_name_or_path: str,
+    cache_dir: Optional[str] = None,
+    load_format: str = "auto",
+    fall_back_to_pt: bool = True,
+    revision: Optional[str] = None,
+) -> Tuple[str, List[str], bool]:
+    # Download model weights from huggingface.
+    is_local = os.path.isdir(model_name_or_path)
+    use_safetensors = False
+    # Some quantized models use .pt files for storing the weights.
+    if load_format == "auto":
+        allow_patterns = ["*.safetensors", "*.bin"]
+    elif load_format == "safetensors":
+        use_safetensors = True
+        allow_patterns = ["*.safetensors"]
+    elif load_format == "pt":
+        allow_patterns = ["*.pt"]
+    elif load_format == "npcache":
+        allow_patterns = ["*.bin"]
+    else:
+        raise ValueError(f"Unknown load_format: {load_format}")
+
+    if fall_back_to_pt:
+        allow_patterns += ["*.pt"]
+
+    if not is_local:
+        # Use file lock to prevent multiple processes from
+        # downloading the same model weights at the same time.
+        with get_lock(model_name_or_path, cache_dir):
+            hf_folder = snapshot_download(model_name_or_path,
+                                          allow_patterns=allow_patterns,
+                                          cache_dir=cache_dir,
+                                          revision=revision)
+    else:
+        hf_folder = model_name_or_path
+    hf_weights_files: List[str] = []
+    for pattern in allow_patterns:
+        hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
+        if len(hf_weights_files) > 0:
+            if pattern == "*.safetensors":
+                use_safetensors = True
+            break
+    if not use_safetensors:
+        # Exclude files that are not needed for inference.
+        # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
+        blacklist = [
+            "training_args.bin",
+            "optimizer.bin",
+            "optimizer.pt",
+            "scheduler.pt",
+            "scaler.pt",
+        ]
+        hf_weights_files = [
+            f for f in hf_weights_files
+            if not any(f.endswith(x) for x in blacklist)
+        ]
+
+    if len(hf_weights_files) == 0:
+        raise RuntimeError(
+            f"Cannot find any model weights with `{model_name_or_path}`")
+
+    return hf_folder, hf_weights_files, use_safetensors
+
+def hf_model_weights_iterator(
+    model_name_or_path: str,
+    cache_dir: Optional[str] = None,
+    load_format: str = "auto",
+    revision: Optional[str] = None,
+    fall_back_to_pt: Optional[bool] = True,
+) -> Iterator[Tuple[str, torch.Tensor]]:
+    hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
+        model_name_or_path,
+        cache_dir=cache_dir,
+        load_format=load_format,
+        fall_back_to_pt=fall_back_to_pt,
+        revision=revision)
+
+    if load_format == "npcache":
+        # Currently np_cache only supports *.bin checkpoints
+        assert use_safetensors is False
+
+        # Convert the model weights from torch tensors to numpy arrays for
+        # faster loading.
+        np_folder = os.path.join(hf_folder, "np")
+        os.makedirs(np_folder, exist_ok=True)
+        weight_names_file = os.path.join(np_folder, "weight_names.json")
+        # Use file lock to prevent multiple processes from
+        # dumping the same model weights to numpy at the same time.
+        with get_lock(model_name_or_path, cache_dir):
+            if not os.path.exists(weight_names_file):
+                weight_names = []
+                for bin_file in hf_weights_files:
+                    state = torch.load(bin_file, map_location="cpu")
+                    for name, param in state.items():
+                        param_path = os.path.join(np_folder, name)
+                        with open(param_path, "wb") as f:
+                            np.save(f, param.cpu().detach().numpy())
+                        weight_names.append(name)
+                with open(weight_names_file, "w") as f:
+                    json.dump(weight_names, f)
+
+        with open(weight_names_file, "r") as f:
+            weight_names = json.load(f)
+
+        for name in weight_names:
+            param_path = os.path.join(np_folder, name)
+            with open(param_path, "rb") as f:
+                param = np.load(f)
+            yield name, torch.from_numpy(param)
+    elif use_safetensors:
+        for st_file in hf_weights_files:
+            with safe_open(st_file, framework="pt") as f:
+                for name in f.keys():  # noqa: SIM118
+                    param = f.get_tensor(name)
+                    yield name, param
+    else:
+        for bin_file in hf_weights_files:
+            state = torch.load(bin_file, map_location="cpu")
+            for name, param in state.items():
+                yield name, param
+            del state
+            torch.cuda.empty_cache()
+
+@torch.no_grad()
+def load_model(model_path):
+    with init_empty_weights():
+        config = AutoConfig.from_pretrained(model_path)
+
+        if not hasattr(config, "quantization_config"):
+            raise ValueError("Must be a Marlin quantized model, but your config has no quantization config.")
+        if "quant_method" not in config.quantization_config:
+            raise ValueError("Must be a Marlin quantized model, but your quantization config has no quant_method.")
+        if config.quantization_config["quant_method"] != "marlin":
+            raise ValueError(f"Must be a Marlin model, but you passed a model with quant_method = {config.quantization_config['quant_method']}")
+
+        model = AutoModelForCausalLM.from_config(config)
+        marlin.replace_linear(
+            model.model,
+            groupsize=config.quantization_config["group_size"]
+        )
+
+    module_dict = dict(model.named_modules())
+    for name, loaded_weight in hf_model_weights_iterator(model_path):
+        module_name = ".".join(name.split(".")[:-1])
+        param_name = name[len(module_name) + 1:]
+        module = module_dict[module_name]
+
+        if not hasattr(module, param_name):
+            raise ValueError("Key mismatch.")
+
+        setattr(module, param_name, torch.nn.Parameter(loaded_weight, requires_grad=False))
+
+    return model
\ No newline at end of file