-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexport_bundle.py
More file actions
72 lines (58 loc) · 2.19 KB
/
export_bundle.py
File metadata and controls
72 lines (58 loc) · 2.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from __future__ import annotations
import argparse
from pathlib import Path
import mlx.core as mx
from mlx_lm import load
# Fallback value for --sample-prompt; main() tokenizes it to build the example
# input_ids tensor used for the export.
DEFAULT_SAMPLE_PROMPT = "Explain why MLX is useful for local inference on Apple Silicon."
def extract_logits(output):
    """Pull the logits out of whatever a model call returned.

    Resolution order:

    1. If ``output`` has a ``logits`` attribute, that attribute is returned.
    2. Otherwise, if ``output`` is a tuple, its first element is returned.
    3. Otherwise ``output`` itself is returned unchanged.
    """
    if hasattr(output, "logits"):
        return output.logits
    # Only tuples are unpacked; any other container is passed through as-is.
    return output[0] if isinstance(output, tuple) else output
def main(argv: list[str] | None = None) -> None:
    """Export a shapeless next-token function plus its sample inputs.

    Parses the command line, loads the model and tokenizer from
    ``--snapshot-dir`` via ``mlx_lm.load``, tokenizes the sample prompt, and
    writes two files into ``--output-dir`` (created if missing):

    * ``function.mlxfn`` -- the exported ``forward`` closure, exported with
      ``shapeless=True``.
    * ``inputs.safetensors`` -- the sample ``input_ids`` tensor that was used
      as the example input for the export.

    Any existing files with those names are removed before exporting.

    Args:
        argv: Command-line arguments to parse. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, matching a plain ``main()`` call.

    Raises:
        SystemExit: On invalid arguments, or when the sample prompt encodes
            to zero tokens (reported through ``parser.error``).
    """
    parser = argparse.ArgumentParser(
        description="Export a shapeless next-token MLX function for a local mlx-lm snapshot.",
    )
    parser.add_argument("--snapshot-dir", required=True)
    parser.add_argument("--output-dir", required=True)
    parser.add_argument(
        "--sample-prompt",
        default=DEFAULT_SAMPLE_PROMPT,
        help=(
            "Example text used to generate sample input_ids for export. "
            "This only seeds the example input tensor; it does not define the "
            "runtime prompt format for your app."
        ),
    )
    parser.add_argument(
        "--sample-prompt-file",
        help="Optional text file whose contents override --sample-prompt.",
    )
    args = parser.parse_args(argv)

    snapshot_dir = Path(args.snapshot_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # A prompt file, when given, takes precedence over --sample-prompt.
    sample_prompt = (
        Path(args.sample_prompt_file).read_text(encoding="utf-8")
        if args.sample_prompt_file
        else args.sample_prompt
    )

    # NOTE(review): lazy=False presumably forces the weights to be loaded
    # before the export traces the model -- confirm against mlx-lm docs.
    model, tokenizer = load(str(snapshot_dir), lazy=False)

    token_ids = tokenizer.encode(sample_prompt)
    if len(token_ids) == 0:
        # An empty sequence would give a (1, 0) input, and forward() below
        # indexes position -1 on it; stop with a clear message instead.
        parser.error(
            "sample prompt produced no tokens; "
            "provide non-empty text via --sample-prompt or --sample-prompt-file"
        )
    # Wrapping token_ids in a list adds a leading axis of size 1.
    tokens = mx.array([token_ids], dtype=mx.int32)

    def forward(input_ids):
        """Return the float32 logits at the last index of axis 1."""
        output = extract_logits(model(input_ids))
        return output[:, -1, :].astype(mx.float32)

    export_path = output_dir / "function.mlxfn"
    sample_inputs_path = output_dir / "inputs.safetensors"

    # Remove artifacts left by an earlier run. missing_ok=True avoids the
    # check-then-delete race of a separate exists() test.
    for stale_path in (export_path, sample_inputs_path):
        stale_path.unlink(missing_ok=True)

    mx.export_function(str(export_path), forward, tokens, shapeless=True)
    mx.save_safetensors(str(sample_inputs_path), {"input_ids": tokens})

    print(f"exported={export_path}")
    print(f"inputs={sample_inputs_path}")
# Run the exporter only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()