-
Notifications
You must be signed in to change notification settings - Fork 2k
Expand file tree
/
Copy pathcompile.py
More file actions
152 lines (143 loc) · 4.65 KB
/
compile.py
File metadata and controls
152 lines (143 loc) · 4.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""Command line entrypoint of compilation."""
import argparse
import json
import re
from functools import partial
from pathlib import Path
from typing import Union
from mlc_llm.interface.compile import ( # pylint: disable=redefined-builtin
ModelConfigOverride,
OptimizationFlags,
compile,
)
from mlc_llm.interface.help import HELP
from mlc_llm.model import MODELS
from mlc_llm.quantization import QUANTIZATION
from mlc_llm.support.argparse import ArgumentParser
from mlc_llm.support.auto_config import (
detect_mlc_chat_config,
detect_model_type,
detect_quantization,
)
from mlc_llm.support.auto_target import detect_system_lib_prefix, detect_target_and_host
def main(argv):
    """Parse command line arguments and call `mlc_llm.interface.compile.compile`.

    Parameters
    ----------
    argv : list of str
        Command line arguments (without the program name), passed straight to
        ``parser.parse_args``.
    """

    def _parse_output(path: Union[str, Path]) -> Path:
        # The output must be a file path (not a directory) whose parent
        # directory already exists.
        path = Path(path)
        if path.is_dir():
            raise argparse.ArgumentTypeError(f"Output cannot be a directory: {path}")
        parent = path.parent
        if not parent.is_dir():
            raise argparse.ArgumentTypeError(f"Directory does not exist: {parent}")
        return path

    def _parse_dir(path: Union[str, Path], auto_create: bool = False) -> Path:
        # Return `path` as a directory; create it (with parents) only when
        # `auto_create` is set, otherwise require that it already exists.
        path = Path(path)
        if path.is_dir():
            return path
        if not auto_create:
            raise argparse.ArgumentTypeError(f"Directory does not exist: {path}")
        try:
            # exist_ok tolerates the directory appearing between is_dir() and mkdir().
            path.mkdir(parents=True, exist_ok=True)
        except OSError as err:
            # argparse only reports ArgumentTypeError/TypeError/ValueError from a
            # `type=` callable as a usage error; an OSError (e.g. FileExistsError
            # when `path` is a regular file) would otherwise escape as a traceback.
            raise argparse.ArgumentTypeError(
                f"Cannot create directory {path}: {err}"
            ) from err
        return path

    def _check_system_lib_prefix(prefix: str) -> str:
        # Accept the empty string, or an identifier-like token: a letter or
        # underscore followed by letters, digits or underscores.
        # `fullmatch` is used because `re.match` with `$` would also accept a
        # trailing newline.
        if prefix == "" or re.fullmatch(r"[a-zA-Z_][a-zA-Z0-9_]*", prefix):
            return prefix
        raise argparse.ArgumentTypeError(
            f"Invalid prefix {prefix!r}. It should only consist of "
            "numbers (0-9), alphabets (A-Z, a-z) and underscore (_), "
            "and must not start with a number."
        )

    parser = ArgumentParser("mlc_llm compile")
    parser.add_argument(
        "model",
        type=detect_mlc_chat_config,
        help=HELP["model"] + " (required)",
    )
    parser.add_argument(
        "--quantization",
        type=str,
        choices=list(QUANTIZATION.keys()),
        help=HELP["quantization"]
        + " (default: look up mlc-chat-config.json, choices: %(choices)s)",
    )
    parser.add_argument(
        "--model-type",
        type=str,
        default="auto",
        choices=["auto"] + list(MODELS.keys()),
        help=HELP["model_type"] + ' (default: "%(default)s")',
    )
    parser.add_argument(
        "--device",
        type=str,
        default="auto",
        help=HELP["device_compile"] + ' (default: "%(default)s")',
    )
    parser.add_argument(
        "--host",
        type=str,
        default="auto",
        help=HELP["host"] + ' (default: "%(default)s")',
    )
    parser.add_argument(
        "--enable-subgroups",
        action="store_true",
        help=HELP["enable_subgroups"],
    )
    parser.add_argument(
        "--opt",
        type=OptimizationFlags.from_str,
        default="O2",
        help=HELP["opt"] + ' (default: "%(default)s")',
    )
    parser.add_argument(
        "--system-lib-prefix",
        # Previously `type=str`, leaving the validator above unused. The
        # default "auto" is a valid identifier, so it still passes through.
        type=_check_system_lib_prefix,
        default="auto",
        help=HELP["system_lib_prefix"] + ' (default: "%(default)s")',
    )
    parser.add_argument(
        "--output",
        "-o",
        type=_parse_output,
        required=True,
        help=HELP["output_compile"] + " (required)",
    )
    parser.add_argument(
        "--overrides",
        type=ModelConfigOverride.from_str,
        default="",
        help=HELP["overrides"] + ' (default: "%(default)s")',
    )
    parser.add_argument(
        "--debug-dump",
        # The dump directory is created at parse time if it does not exist.
        type=partial(_parse_dir, auto_create=True),
        default=None,
        help=HELP["debug_dump"] + " (default: %(default)s)",
    )
    parsed = parser.parse_args(argv)

    # Resolve "auto" placeholders into concrete values before compiling.
    target, build_func = detect_target_and_host(
        parsed.device,
        parsed.host,
        enable_subgroups=parsed.enable_subgroups,
    )
    parsed.model_type = detect_model_type(parsed.model_type, parsed.model)
    parsed.quantization = detect_quantization(parsed.quantization, parsed.model)
    parsed.system_lib_prefix = detect_system_lib_prefix(
        parsed.device,
        parsed.system_lib_prefix,
        parsed.model_type.name,
        parsed.quantization.name,
    )

    # `parsed.model` is whatever detect_mlc_chat_config returned; it is opened
    # here as a JSON file.
    with open(parsed.model, "r", encoding="utf-8") as config_file:
        config = json.load(config_file)
    compile(
        config=config,
        quantization=parsed.quantization,
        model_type=parsed.model_type,
        target=target,
        opt=parsed.opt,
        build_func=build_func,
        system_lib_prefix=parsed.system_lib_prefix,
        output=parsed.output,
        overrides=parsed.overrides,
        debug_dump=parsed.debug_dump,
    )