|
1 | | -"""Merge a PEFT/LoRA checkpoint into its base causal LM. |
| 1 | +"""Merge a LoRA checkpoint into a standalone model. |
2 | 2 |
|
3 | | -Example: |
| 3 | +Uses Unsloth for merging, as in training, so its internal patches are |
| 4 | +handled correctly, and reproduces the exact tokenizer setup from training. |
| 5 | +
|
| 6 | +Usage: |
| 7 | + python scripts/merge_lora.py \ |
| 8 | + --checkpoint output/squeez_qwen/checkpoint-500 \ |
| 9 | + --output output/squeez_qwen_merged |
| 10 | +
|
| 11 | + # With explicit base model (if not auto-detected from adapter_config): |
4 | 12 | python scripts/merge_lora.py \ |
5 | | - --base-model Qwen/Qwen3.5-2B \ |
6 | | - --adapter-path output/squeez_qwen/checkpoint-800 \ |
7 | | - --output-dir output/squeez_qwen_merged |
| 13 | + --checkpoint output/squeez_qwen \ |
| 14 | + --output output/squeez_qwen_merged \ |
| 15 | + --base-model Qwen/Qwen3.5-2B |
8 | 16 | """ |
9 | 17 |
|
10 | 18 | from __future__ import annotations |
11 | 19 |
|
12 | 20 | import argparse |
| 21 | +import json |
13 | 22 | import logging |
| 23 | +from pathlib import Path |
14 | 24 |
|
15 | 25 | logger = logging.getLogger(__name__) |
16 | 26 |
|
17 | 27 |
|
18 | | -def build_parser() -> argparse.ArgumentParser: |
19 | | - parser = argparse.ArgumentParser(description="Merge a LoRA adapter into its base model") |
20 | | - parser.add_argument("--base-model", required=True, help="Base model name or path") |
21 | | - parser.add_argument("--adapter-path", required=True, help="Path to LoRA checkpoint") |
22 | | - parser.add_argument("--output-dir", required=True, help="Directory to save merged model") |
23 | | - parser.add_argument( |
24 | | - "--dtype", |
25 | | - choices=["auto", "bf16", "fp16", "fp32"], |
26 | | - default="auto", |
27 | | - help="Torch dtype to load the base model with before merging", |
28 | | - ) |
29 | | - return parser |
30 | | - |
31 | | - |
32 | | -def _resolve_dtype(dtype_name: str): |
33 | | - import torch |
34 | | - |
35 | | - if dtype_name == "bf16": |
36 | | - return torch.bfloat16 |
37 | | - if dtype_name == "fp16": |
38 | | - return torch.float16 |
39 | | - if dtype_name == "fp32": |
40 | | - return torch.float32 |
41 | | - if torch.cuda.is_available() and torch.cuda.is_bf16_supported(): |
42 | | - return torch.bfloat16 |
43 | | - if torch.cuda.is_available(): |
44 | | - return torch.float16 |
45 | | - return torch.float32 |
46 | | - |
47 | | - |
48 | 28 | def main(argv: list[str] | None = None) -> int: |
49 | | - args = build_parser().parse_args(argv) |
50 | | - |
51 | | - logging.basicConfig( |
52 | | - level=logging.INFO, |
53 | | - format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", |
| 29 | + parser = argparse.ArgumentParser(description="Merge LoRA checkpoint into standalone model") |
| 30 | + parser.add_argument("--checkpoint", required=True, help="Path to LoRA checkpoint") |
| 31 | + parser.add_argument("--output", required=True, help="Output path for merged model") |
| 32 | + parser.add_argument("--base-model", default=None, help="Base model (auto-detected if omitted)") |
| 33 | + parser.add_argument("--config", default=None, help="YAML config file") |
| 34 | + args = parser.parse_args(argv) |
| 35 | + |
| 36 | + logging.basicConfig(level=logging.INFO, format="%(message)s") |
| 37 | + |
| 38 | + from unsloth import FastLanguageModel |
| 39 | + |
| 40 | + from squeez.training.train import _prepare_text_tokenizer, load_config |
| 41 | + |
| 42 | + config = load_config(args.config) |
| 43 | + max_length = config.get("max_length", 16384) |
| 44 | + |
| 45 | + # Detect base model from adapter config if not provided |
| 46 | + base_model_name = args.base_model |
| 47 | + adapter_config_path = Path(args.checkpoint) / "adapter_config.json" |
| 48 | + if not base_model_name and adapter_config_path.exists(): |
| 49 | + with open(adapter_config_path) as f: |
| 50 | + base_model_name = json.load(f).get("base_model_name_or_path", "") |
| 51 | + if not base_model_name: |
| 52 | + base_model_name = config.get("model", "Qwen/Qwen3.5-2B") |
| 53 | + |
| 54 | + logger.info(f"Loading checkpoint from {args.checkpoint} (base: {base_model_name})") |
| 55 | + model, tokenizer = FastLanguageModel.from_pretrained( |
| 56 | + args.checkpoint, |
| 57 | + max_seq_length=max_length, |
| 58 | + load_in_4bit=False, |
| 59 | + load_in_16bit=True, |
54 | 60 | ) |
55 | 61 |
|
56 | | - from peft import PeftModel |
57 | | - from transformers import AutoModelForCausalLM, AutoTokenizer |
58 | | - |
59 | | - dtype = _resolve_dtype(args.dtype) |
60 | | - |
61 | | - logger.info("Loading tokenizer from %s", args.base_model) |
62 | | - tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True) |
| 62 | + # Reproduce the same tokenizer patches as in training |
| 63 | + tokenizer = _prepare_text_tokenizer(base_model_name, tokenizer) |
63 | 64 |
|
64 | | - logger.info("Loading base model from %s", args.base_model) |
65 | | - model = AutoModelForCausalLM.from_pretrained( |
66 | | - args.base_model, |
67 | | - torch_dtype=dtype, |
68 | | - trust_remote_code=True, |
69 | | - device_map="auto", |
| 65 | + logger.info(f"Merging and saving to {args.output}") |
| 66 | + model.save_pretrained_merged( |
| 67 | + args.output, |
| 68 | + tokenizer, |
| 69 | + save_method="merged_16bit", |
70 | 70 | ) |
71 | | - |
72 | | - logger.info("Loading adapter from %s", args.adapter_path) |
73 | | - model = PeftModel.from_pretrained(model, args.adapter_path) |
74 | | - |
75 | | - logger.info("Merging adapter into base model") |
76 | | - model = model.merge_and_unload() |
77 | | - |
78 | | - logger.info("Saving merged model to %s", args.output_dir) |
79 | | - model.save_pretrained(args.output_dir, safe_serialization=True) |
80 | | - tokenizer.save_pretrained(args.output_dir) |
81 | | - |
82 | | - logger.info("Merge complete") |
| 71 | + logger.info(f"Done. Merged model saved to {args.output}") |
83 | 72 | return 0 |
84 | 73 |
|
85 | 74 |
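A quick sanity check for the merged output (a minimal sketch, not part of this diff): save_pretrained_merged with save_method="merged_16bit" should write a standard, self-contained Hugging Face checkpoint, so the result ought to load with vanilla transformers, with no Unsloth or PEFT required. The merged_dir path and prompt below are illustrative only.

    # Hypothetical post-merge check; assumes the merged directory is a
    # standard Hugging Face checkpoint (this file is not part of the PR).
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    merged_dir = "output/squeez_qwen_merged"  # the --output path used above
    tokenizer = AutoTokenizer.from_pretrained(merged_dir)
    model = AutoModelForCausalLM.from_pretrained(merged_dir, torch_dtype=torch.bfloat16)

    inputs = tokenizer("Hello", return_tensors="pt")
    out = model.generate(**inputs, max_new_tokens=20)
    print(tokenizer.decode(out[0], skip_special_tokens=True))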
|
|