Commit aa18836: Inference and merging
1 parent: 7d79da7

4 files changed: 79 additions & 99 deletions

scripts/merge_lora.py

Lines changed: 54 additions & 65 deletions
@@ -1,85 +1,74 @@
-"""Merge a PEFT/LoRA checkpoint into its base causal LM.
+"""Merge a LoRA checkpoint into a standalone model.
 
-Example:
+Uses Unsloth for merging (same as training) to handle its internal patches
+correctly, and reproduces the exact tokenizer setup from training.
+
+Usage:
+    python scripts/merge_lora.py \
+        --checkpoint output/squeez_qwen/checkpoint-500 \
+        --output output/squeez_qwen_merged
+
+    # With explicit base model (if not auto-detected from adapter_config):
     python scripts/merge_lora.py \
-        --base-model Qwen/Qwen3.5-2B \
-        --adapter-path output/squeez_qwen/checkpoint-800 \
-        --output-dir output/squeez_qwen_merged
+        --checkpoint output/squeez_qwen \
+        --output output/squeez_qwen_merged \
+        --base-model Qwen/Qwen3.5-2B
 """
 
 from __future__ import annotations
 
 import argparse
+import json
 import logging
+from pathlib import Path
 
 logger = logging.getLogger(__name__)
 
 
-def build_parser() -> argparse.ArgumentParser:
-    parser = argparse.ArgumentParser(description="Merge a LoRA adapter into its base model")
-    parser.add_argument("--base-model", required=True, help="Base model name or path")
-    parser.add_argument("--adapter-path", required=True, help="Path to LoRA checkpoint")
-    parser.add_argument("--output-dir", required=True, help="Directory to save merged model")
-    parser.add_argument(
-        "--dtype",
-        choices=["auto", "bf16", "fp16", "fp32"],
-        default="auto",
-        help="Torch dtype to load the base model with before merging",
-    )
-    return parser
-
-
-def _resolve_dtype(dtype_name: str):
-    import torch
-
-    if dtype_name == "bf16":
-        return torch.bfloat16
-    if dtype_name == "fp16":
-        return torch.float16
-    if dtype_name == "fp32":
-        return torch.float32
-    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
-        return torch.bfloat16
-    if torch.cuda.is_available():
-        return torch.float16
-    return torch.float32
-
-
 def main(argv: list[str] | None = None) -> int:
-    args = build_parser().parse_args(argv)
-
-    logging.basicConfig(
-        level=logging.INFO,
-        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
+    parser = argparse.ArgumentParser(description="Merge LoRA checkpoint into standalone model")
+    parser.add_argument("--checkpoint", required=True, help="Path to LoRA checkpoint")
+    parser.add_argument("--output", required=True, help="Output path for merged model")
+    parser.add_argument("--base-model", default=None, help="Base model (auto-detected if omitted)")
+    parser.add_argument("--config", default=None, help="YAML config file")
+    args = parser.parse_args(argv)
+
+    logging.basicConfig(level=logging.INFO, format="%(message)s")
+
+    from unsloth import FastLanguageModel
+
+    from squeez.training.train import _prepare_text_tokenizer, load_config
+
+    config = load_config(args.config)
+    max_length = config.get("max_length", 16384)
+
+    # Detect base model from adapter config if not provided
+    base_model_name = args.base_model
+    adapter_config_path = Path(args.checkpoint) / "adapter_config.json"
+    if not base_model_name and adapter_config_path.exists():
+        with open(adapter_config_path) as f:
+            base_model_name = json.load(f).get("base_model_name_or_path", "")
+    if not base_model_name:
+        base_model_name = config.get("model", "Qwen/Qwen3.5-2B")
+
+    logger.info(f"Loading checkpoint from {args.checkpoint} (base: {base_model_name})")
+    model, tokenizer = FastLanguageModel.from_pretrained(
+        args.checkpoint,
+        max_seq_length=max_length,
+        load_in_4bit=False,
+        load_in_16bit=True,
     )
 
-    from peft import PeftModel
-    from transformers import AutoModelForCausalLM, AutoTokenizer
-
-    dtype = _resolve_dtype(args.dtype)
-
-    logger.info("Loading tokenizer from %s", args.base_model)
-    tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True)
+    # Reproduce the same tokenizer patches as training
+    tokenizer = _prepare_text_tokenizer(base_model_name, tokenizer)
 
-    logger.info("Loading base model from %s", args.base_model)
-    model = AutoModelForCausalLM.from_pretrained(
-        args.base_model,
-        torch_dtype=dtype,
-        trust_remote_code=True,
-        device_map="auto",
+    logger.info(f"Merging and saving to {args.output}")
+    model.save_pretrained_merged(
+        args.output,
+        tokenizer,
+        save_method="merged_16bit",
     )
-
-    logger.info("Loading adapter from %s", args.adapter_path)
-    model = PeftModel.from_pretrained(model, args.adapter_path)
-
-    logger.info("Merging adapter into base model")
-    model = model.merge_and_unload()
-
-    logger.info("Saving merged model to %s", args.output_dir)
-    model.save_pretrained(args.output_dir, safe_serialization=True)
-    tokenizer.save_pretrained(args.output_dir)
-
-    logger.info("Merge complete")
+    logger.info(f"Done. Merged model saved to {args.output}")
 
     return 0
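Since the merge now produces a standalone directory (weights plus the patched tokenizer), the output should load with stock transformers, no Unsloth or PEFT required. A minimal smoke-test sketch, assuming the merged path from the Usage example above; the prompt text is illustrative, not part of this commit:

    # Sketch: verify the merged model is self-contained by loading it with
    # plain transformers. The directory comes from the Usage example above.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    merged_dir = "output/squeez_qwen_merged"
    tokenizer = AutoTokenizer.from_pretrained(merged_dir)
    model = AutoModelForCausalLM.from_pretrained(merged_dir, torch_dtype="auto")

    inputs = tokenizer("Hello", return_tensors="pt")
    out = model.generate(**inputs, max_new_tokens=8)
    print(tokenizer.decode(out[0], skip_special_tokens=True))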

squeez/inference/extractor.py

Lines changed: 4 additions & 11 deletions
@@ -79,16 +79,6 @@ def _build_messages(task: str, tool_output: str) -> list[dict]:
     ]
 
 
-def _format_prompt(task: str, tool_output: str) -> str:
-    """Format the input prompt using the ChatML template for local generation."""
-    messages = _build_messages(task, tool_output)
-    return (
-        f"<|im_start|>system\n{messages[0]['content']}<|im_end|>\n"
-        f"<|im_start|>user\n{messages[1]['content']}<|im_end|>\n"
-        f"<|im_start|>assistant\n"
-    )
-
-
 def _is_encoder_model(model_path: str) -> bool:
     """Check if a model path contains a squeez-encoder model."""
     import json
@@ -385,7 +375,10 @@ def _extract_transformers(
         """Extract using local transformers model."""
         import torch
 
-        prompt = _format_prompt(task, tool_output)
+        messages = _build_messages(task, tool_output)
+        prompt = self._tokenizer.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
 
         inputs = self._tokenizer(
             prompt,
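The deleted _format_prompt hard-coded the ChatML markers; apply_chat_template renders the prompt from the template bundled with the tokenizer, so the formatting tracks the model instead of the code. A rough sketch of the equivalence, assuming any ChatML-family chat tokenizer; the checkpoint name and message contents below are examples, not taken from this commit:

    # Sketch: apply_chat_template reproduces the structure the removed
    # _format_prompt built by hand, driven by the tokenizer's own template.
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # example model
    messages = [
        {"role": "system", "content": "You are a code summarizer."},  # stand-in for SYSTEM_PROMPT
        {"role": "user", "content": "<query>\nFix the bug\n</query>"},
    ]
    prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # For ChatML models this yields <|im_start|>role ... <|im_end|> blocks,
    # ending with an open <|im_start|>assistant turn.
    print(prompt)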

squeez/training/train.py

Lines changed: 5 additions & 4 deletions
@@ -204,11 +204,12 @@ def train(args: argparse.Namespace):
     logger.info("Starting training...")
     trainer.train()
 
-    # 7. Save
-    logger.info(f"Saving model to {output_dir}")
-    trainer.save_model(output_dir)
+    # 7. Save merged model (LoRA weights folded into base — standalone, no adapter needed)
+    logger.info(f"Merging LoRA and saving full model to {output_dir}")
+    merged_model = model.merge_and_unload()
+    merged_model.save_pretrained(output_dir)
     tokenizer.save_pretrained(output_dir)
-    logger.info("Training complete!")
+    logger.info("Training complete! Saved merged model.")
 
 
 def build_parser(parser: argparse.ArgumentParser | None = None) -> argparse.ArgumentParser:
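merge_and_unload() folds each LoRA pair into its frozen base weight as W + (alpha / r) * B @ A and returns an ordinary model, which is why the saved checkpoint needs no adapter at load time. A tiny plain-torch illustration of that identity (not the PEFT internals; shapes and scaling follow the standard LoRA convention):

    # The identity behind merging: the folded weight reproduces
    # base-plus-adapter applied separately.
    import torch

    d_out, d_in, r, alpha = 8, 8, 4, 16
    W = torch.randn(d_out, d_in)        # frozen base weight
    A = torch.randn(r, d_in) * 0.1      # lora_A projection
    B = torch.randn(d_out, r) * 0.1     # lora_B projection

    W_merged = W + (alpha / r) * (B @ A)

    x = torch.randn(d_in)
    y_separate = W @ x + (alpha / r) * (B @ (A @ x))
    assert torch.allclose(W_merged @ x, y_separate, atol=1e-5)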

tests/test_extractor.py

Lines changed: 16 additions & 19 deletions
@@ -1,33 +1,30 @@
 """Tests for squeez core functionality."""
 
 from squeez.data.config import SYSTEM_PROMPT
-from squeez.inference.extractor import _format_prompt, _load_config
+from squeez.inference.extractor import _build_messages, _load_config
 
 
-def test_format_prompt_basic():
-    prompt = _format_prompt("Fix the bug", "class Foo:\n pass")
-    assert "Fix the bug" in prompt
-    assert "class Foo:" in prompt
-    assert SYSTEM_PROMPT in prompt
-    assert "<|im_start|>system" in prompt
-    assert "<|im_start|>user" in prompt
-    assert "<|im_start|>assistant" in prompt
-    assert "<|im_end|>" in prompt
+def test_build_messages_basic():
+    messages = _build_messages("Fix the bug", "class Foo:\n pass")
+    assert len(messages) == 2
+    assert messages[0]["role"] == "system"
+    assert messages[0]["content"] == SYSTEM_PROMPT
+    assert messages[1]["role"] == "user"
+    assert "Fix the bug" in messages[1]["content"]
+    assert "class Foo:" in messages[1]["content"]
 
 
-def test_format_prompt_truncates_long_task():
+def test_build_messages_truncates_long_task():
     long_task = "x" * 5000
-    prompt = _format_prompt(long_task, "output")
-    assert len(long_task) > 3000
-    assert "..." in prompt
-    task_section = prompt.split("<query>\n", 1)[1].split("\n</query>", 1)[0]
+    messages = _build_messages(long_task, "output")
+    task_section = messages[1]["content"].split("<query>\n", 1)[1].split("\n</query>", 1)[0]
     assert len(task_section) == 3003  # 3000 + "..."
 
 
-def test_format_prompt_empty_task():
-    prompt = _format_prompt("", "some output")
-    assert "<query>" not in prompt
-    assert "some output" in prompt
+def test_build_messages_empty_task():
+    messages = _build_messages("", "some output")
+    assert "<query>" not in messages[1]["content"]
+    assert "some output" in messages[1]["content"]
 
 
 def test_system_prompt_has_relevant_lines_format():
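The 3003-character assertion pins down the truncation contract: tasks longer than 3000 characters are cut and suffixed with "...". A hypothetical restatement of that rule, for orientation only; the real logic lives inside squeez.inference.extractor._build_messages, and the helper name and constant here are assumptions:

    # Hypothetical sketch of the truncation the tests assert; not the
    # actual implementation from _build_messages.
    MAX_TASK_CHARS = 3000  # assumed limit implied by the 3003-char check

    def truncate_task(task: str) -> str:
        if len(task) > MAX_TASK_CHARS:
            return task[:MAX_TASK_CHARS] + "..."
        return task

    assert len(truncate_task("x" * 5000)) == 3003  # 3000 + "..."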
