Commit 7cbf2a3

mikekgfb and mike94043 authored
Add export --output-snapshot-path snap.tc, and --snapshot-path snap.tc (#1465)
* support model snapshots to save quantized models
* import set_backend
---------
Co-authored-by: Michael Gschwind <[email protected]>
1 parent 4356b4c commit 7cbf2a3

3 files changed: +94, -2 lines changed


torchchat/cli/builder.py (+33)
@@ -56,6 +56,7 @@ class BuilderArgs:
     gguf_kwargs: Optional[Dict[str, Any]] = None
     dso_path: Optional[Union[Path, str]] = None
     aoti_package_path: Optional[Union[Path, str]] = None
+    snapshot_path: Optional[Union[Path, str]] = None
     pte_path: Optional[Union[Path, str]] = None
     device: Optional[str] = None
     precision: torch.dtype = torch.float32
@@ -87,6 +88,7 @@ def __post_init__(self):
             or (self.dso_path and Path(self.dso_path).is_file())
             or (self.aoti_package_path and Path(self.aoti_package_path).is_file())
             or (self.pte_path and Path(self.pte_path).is_file())
+            or (self.snapshot_path and Path(self.snapshot_path).is_file())
         ):
             raise RuntimeError(
                 "need to specify a valid checkpoint path, checkpoint dir, gguf path, DSO path, AOTI PACKAGE or PTE path"
@@ -142,6 +144,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         dso_path = getattr(args, "dso_path", None)
         pte_path = getattr(args, "pte_path", None)
         aoti_package_path = getattr(args, "aoti_package_path", None)
+        snapshot_path = getattr(args, "snapshot_path", None)

         is_chat_model = False
         if args.is_chat_model:
@@ -169,6 +172,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         output_pte_path = getattr(args, "output_pte_path", None)
         output_aoti_package_path = getattr(args, "output_aoti_package_path", None)
         output_dso_path = getattr(args, "output_dso_path", None)
+        output_snapshot_path = getattr(args, "output_snapshot_path", None)
         if output_pte_path and args.dtype.startswith("fast"):
             if args.dtype == "fast":
                 # As per Kimish, float32 should be faster on ET XNNPACK
@@ -206,6 +210,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
             dso_path=dso_path,
             aoti_package_path=aoti_package_path,
             pte_path=pte_path,
+            snapshot_path=snapshot_path,
             device=args.device,
             precision=dtype,
             setup_caches=(
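Reviewer note: the hunks above are plumbing that threads `snapshot_path` from the argparse namespace into `BuilderArgs` and its validation. A toy sketch of the pattern, with a hypothetical `SimpleArgs` standing in for the much larger real dataclass:

```python
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union

@dataclass
class SimpleArgs:  # hypothetical stand-in for BuilderArgs
    checkpoint_path: Optional[Union[Path, str]] = None
    snapshot_path: Optional[Union[Path, str]] = None

    def __post_init__(self):
        # Mirrors the new validation: a snapshot file alone is now a
        # sufficient model source.
        if not (
            (self.checkpoint_path and Path(self.checkpoint_path).is_file())
            or (self.snapshot_path and Path(self.snapshot_path).is_file())
        ):
            raise RuntimeError("need a valid checkpoint or snapshot path")
```

With that change, `SimpleArgs(snapshot_path="snap.tc")` validates as long as the file exists, which is exactly what lets generate run from a snapshot alone.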
@@ -631,6 +636,34 @@ def do_nothing(max_batch_size, max_seq_length):
             model = PTEModel(config, builder_args.pte_path)
         except Exception:
             raise RuntimeError(f"Failed to load ET compiled {builder_args.pte_path}")
+    elif builder_args.snapshot_path:
+        # Resolve the ModelArgs used to validate the loaded snapshot.
+        # If a manual params_path is provided, use that.
+        if builder_args.params_path:
+            config: ModelArgs = ModelArgs.from_params(builder_args.params_path)
+        else:
+            # TODO: Instead of loading the whole model, refactor to call a
+            # helper that generates just model.config
+            with measure_time("Time to load model: {time:.02f} seconds"):
+                model = _load_model(builder_args)
+                device_sync(device=builder_args.device)
+            config = model.config
+            model = None
+        try:
+            model = torch.load(builder_args.snapshot_path, weights_only=False)
+        except Exception:
+            raise RuntimeError(f"Failed to load torchchat snapshot {builder_args.snapshot_path}")
+        # _active_backend() does not allow both DSO and AOTI to be true;
+        # choose one.
+        from torchchat.utils.build_utils import set_backend
+        set_backend(dso=True, pte=False, aoti_package=False)
+        if model.config != config:
+            raise RuntimeError("loaded model architecture mismatch")
+        ##
+        ## Import all libraries with custom kernels and custom operators
+        ## that quantize may be pulling in.
+        ##
+
     elif builder_args.distributed:
         pp_degree = builder_args.pp
         tp_degree = builder_args.tp
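Reviewer note: the new branch is a plain `torch.load` of a pickled module followed by an architecture check against a freshly resolved config. A minimal sketch of that round trip; `SimpleConfig` and `TinyModel` are hypothetical stand-ins for `ModelArgs` and torchchat's model class:

```python
import torch
import torch.nn as nn
from dataclasses import dataclass

@dataclass
class SimpleConfig:  # hypothetical stand-in for ModelArgs
    dim: int

class TinyModel(nn.Module):  # hypothetical stand-in for torchchat's Model
    def __init__(self, config: SimpleConfig):
        super().__init__()
        self.config = config
        self.proj = nn.Linear(config.dim, config.dim)

# Export side: pickle the whole module, not just its state_dict.
torch.save(TinyModel(SimpleConfig(dim=16)), "snap.tc")

# Load side: weights_only=False because the file holds arbitrary objects,
# then guard against running with the wrong architecture, as the diff does.
expected = SimpleConfig(dim=16)
loaded = torch.load("snap.tc", weights_only=False)
if loaded.config != expected:
    raise RuntimeError("loaded model architecture mismatch")
```

Because the snapshot pickles the whole module, `weights_only=False` is required and the defining classes (including any quantized-tensor types) must be importable at load time, which is why the branch ends with the note about importing quantization libraries.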

torchchat/cli/cli.py (+13, -1)
@@ -207,6 +207,12 @@ def _add_export_output_path_args(parser) -> None:
         default=None,
         help="Output to the specified AOT Inductor .dso model file",
     )
+    exclusive_parser.add_argument(
+        "--output-snapshot-path",
+        type=str,
+        default=None,
+        help="Output to the specified PyTorch model and sha256 file",
+    )
     exclusive_parser.add_argument(
         "--output-aoti-package-path",
         type=str,
@@ -254,7 +260,13 @@ def _add_exported_input_path_args(parser) -> None:
         default=None,
         help="Use the specified ExecuTorch .pte model file",
     )
-
+    exclusive_parser.add_argument(
+        "--snapshot-path",
+        type=Path,
+        default=None,
+        help="Use the specified torchchat snapshot .tc model file",
+    )
+

 # Add CLI Args related to JIT downloading of model artifacts
 def _add_jit_downloading_args(parser) -> None:
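Reviewer note: judging by its name, `exclusive_parser` is an argparse mutually exclusive group (its creation sits outside this hunk, so that is an assumption), which is what keeps `--snapshot-path` from being combined with the other exported-input flags. A self-contained sketch of the pattern, using only stdlib argparse:

```python
import argparse
from pathlib import Path

parser = argparse.ArgumentParser()
# Only one exported-input flag may be supplied per invocation.
exclusive_parser = parser.add_mutually_exclusive_group()
exclusive_parser.add_argument("--pte-path", type=Path, default=None)
exclusive_parser.add_argument("--snapshot-path", type=Path, default=None)

args = parser.parse_args(["--snapshot-path", "snap.tc"])
print(args.snapshot_path)  # snap.tc

# Supplying both flags would make argparse exit with
# "argument --snapshot-path: not allowed with argument --pte-path".
```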

torchchat/export.py (+48, -1)
@@ -28,6 +28,31 @@
 default_device = "cpu"


+"""
+Export Snapshot
+"""
+
+
+def export_snapshot(
+    model: nn.Module,
+    device: Optional[str] = None,
+    output_path: str = "model-snapshot.tc",
+) -> str:
+    """
+    Export the model as a snapshot.
+
+    Args:
+        model: The model to be exported.
+        device: The device to run the model on.
+        output_path: The path to save the exported model.
+    Returns:
+        The path to the exported model.
+    """
+    assert output_path.endswith(".tc"), "use .tc extension for snapshots"
+    torch.save(model, output_path)
+    return output_path
+
+
 """
 Export for Server
 """
@@ -72,6 +97,7 @@ def export_for_server(
         "aot_inductor.package": package,
         "aot_inductor.metadata": metadata or {},
     }
+
     if not package:
         options = {"aot_inductor.output_path": output_path}
@@ -373,14 +399,15 @@ def main(args):

     output_pte_path = args.output_pte_path
     output_dso_path = args.output_dso_path
+    output_snapshot_path = args.output_snapshot_path
     output_aoti_package_path = args.output_aoti_package_path

     if output_pte_path and builder_args.device != "cpu":
         print(
             f"Warning! ExecuTorch export target is controlled by export recipe, not device setting. Ignoring device={builder_args.device} setting."
         )
         builder_args.device = "cpu"
-    elif "mps" in builder_args.device:
+    elif (output_pte_path or output_dso_path or output_aoti_package_path) and "mps" in builder_args.device:
         print("Warning! Device MPS not supported for export. Exporting for device CPU.")
         builder_args.device = "cpu"
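Reviewer note: the guard change above is the behavioral point of this hunk. Previously any export on an MPS device fell back to CPU; now only the PTE/DSO/AOTI paths do, so a snapshot can be taken on the requested device. A hedged restatement of the logic; `resolve_export_device` is a hypothetical helper, since `main()` applies this inline:

```python
from typing import Optional

def resolve_export_device(
    device: str,
    output_pte_path: Optional[str] = None,
    output_dso_path: Optional[str] = None,
    output_aoti_package_path: Optional[str] = None,
) -> str:
    if output_pte_path and device != "cpu":
        return "cpu"  # ExecuTorch recipe controls the target, not --device
    if (output_pte_path or output_dso_path or output_aoti_package_path) and "mps" in device:
        return "cpu"  # AOTI/ExecuTorch export paths do not support MPS
    return device  # snapshot-only export keeps the requested device

assert resolve_export_device("mps", output_dso_path="model.so") == "cpu"
assert resolve_export_device("mps") == "mps"  # new behavior: snapshots keep MPS
```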

@@ -417,6 +444,7 @@ def main(args):
         model_to_pte = model
         model_to_dso = model
         model_to_aoti_package = model
+        model_to_snapshot = model
     else:
         if output_pte_path:
             _set_gguf_kwargs(builder_args, is_et=True, context="export")
@@ -436,6 +464,15 @@ def main(args):
             model_to_dso = model_to_aoti_package
             _unset_gguf_kwargs(builder_args)

+        if output_snapshot_path:
+            _set_gguf_kwargs(builder_args, is_et=False, context="export")
+            model_to_snapshot = _initialize_model(
+                builder_args,
+                quantize,
+                support_tensor_subclass=False,
+            )
+            _unset_gguf_kwargs(builder_args)
+
     with torch.no_grad():
         if output_pte_path:
             output_pte_path = str(os.path.abspath(output_pte_path))
@@ -483,3 +520,13 @@ def main(args):
             package=True,
             metadata=metadata,
         )
+
+    if output_snapshot_path:
+        output_snapshot_path = str(os.path.abspath(output_snapshot_path))
+        print(f"Exporting model using Snapshot to {output_snapshot_path}")
+        export_snapshot(
+            model_to_snapshot,
+            builder_args.device,
+            output_snapshot_path,
+        )
+
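Reviewer note: putting the pieces together, the point of the feature (per the commit message, "support model snapshots to save quantized models") is to pay the quantization cost once at export time and make later startups a single load. A hedged end-to-end sketch, with dynamic quantization standing in for torchchat's `--quantize` recipes:

```python
import torch
import torch.nn as nn

# Stand-in model; torchchat would build and quantize a full transformer here.
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8))
quantized = torch.ao.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)

# Export step: serialize the already-quantized module once.
torch.save(quantized, "snap.tc")

# Generate step: startup is now a single load, no re-quantization pass.
restored = torch.load("snap.tc", weights_only=False)
print(restored(torch.randn(1, 64)).shape)  # torch.Size([1, 8])
```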
