@@ -56,6 +56,7 @@ class BuilderArgs:
56
56
gguf_kwargs : Optional [Dict [str , Any ]] = None
57
57
dso_path : Optional [Union [Path , str ]] = None
58
58
aoti_package_path : Optional [Union [Path , str ]] = None
59
+ snapshot_path : Optional [Union [Path , str ]] = None
59
60
pte_path : Optional [Union [Path , str ]] = None
60
61
device : Optional [str ] = None
61
62
precision : torch .dtype = torch .float32
@@ -87,6 +88,7 @@ def __post_init__(self):
87
88
or (self .dso_path and Path (self .dso_path ).is_file ())
88
89
or (self .aoti_package_path and Path (self .aoti_package_path ).is_file ())
89
90
or (self .pte_path and Path (self .pte_path ).is_file ())
91
+ or (self .snapshot_path and Path (self .snapshot_path ).is_file ())
90
92
):
91
93
raise RuntimeError (
92
94
"need to specify a valid checkpoint path, checkpoint dir, gguf path, DSO path, AOTI PACKAGE or PTE path"
@@ -142,6 +144,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
142
144
dso_path = getattr (args , "dso_path" , None )
143
145
pte_path = getattr (args , "pte_path" , None )
144
146
aoti_package_path = getattr (args , "aoti_package_path" , None )
147
+ snapshot_path = getattr (args , "snapshot_path" , None )
145
148
146
149
is_chat_model = False
147
150
if args .is_chat_model :
@@ -169,6 +172,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
169
172
output_pte_path = getattr (args , "output_pte_path" , None )
170
173
output_aoti_package_path = getattr (args , "output_aoti_package_path" , None )
171
174
output_dso_path = getattr (args , "output_dso_path" , None )
175
+ output_snapshot_path = getattr (args , "output_snapshot_path" , None )
172
176
if output_pte_path and args .dtype .startswith ("fast" ):
173
177
if args .dtype == "fast" :
174
178
# As per Kimish, float32 should be faster on ET XNNPACK
@@ -206,6 +210,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
206
210
dso_path = dso_path ,
207
211
aoti_package_path = aoti_package_path ,
208
212
pte_path = pte_path ,
213
+ snapshot_path = snapshot_path ,
209
214
device = args .device ,
210
215
precision = dtype ,
211
216
setup_caches = (
@@ -631,6 +636,34 @@ def do_nothing(max_batch_size, max_seq_length):
631
636
model = PTEModel (config , builder_args .pte_path )
632
637
except Exception :
633
638
raise RuntimeError (f"Failed to load ET compiled { builder_args .pte_path } " )
639
+ elif builder_args .snapshot_path :
640
+ # Resolve ModelArgs for constructing the PTEModel
641
+ # If a manual params_path is provided, use that
642
+ if builder_args .params_path :
643
+ config : ModelArgs = ModelArgs .from_params (builder_args .params_path )
644
+ else :
645
+ # TODO: Instead of loading the whole model, refactor to call a
646
+ # helper that generates just model.config
647
+ with measure_time ("Time to load model: {time:.02f} seconds" ):
648
+ model = _load_model (builder_args )
649
+ device_sync (device = builder_args .device )
650
+ config = model .config
651
+ model = None
652
+ try :
653
+ model = torch .load (builder_args .snapshot_path , weights_only = False )
654
+ except Exception :
655
+ raise RuntimeError (f"Failed to load torchchat snapshot { builder_args .snapshot_path } " )
656
+ # _active_backend() does not allow DSO & AOTI to be true.
657
+ # Choose either.
658
+ from torchchat .utils .build_utils import set_backend
659
+ set_backend (dso = True , pte = False , aoti_package = False )
660
+ if (model .config != config ):
661
+ raise RuntimeError ("loaded model architecture mismatch" )
662
+ ##
663
+ ## import all libraries with custom kernels and custom operators
664
+ ## that quantize may be pulling in
665
+ ##
666
+
634
667
elif builder_args .distributed :
635
668
pp_degree = builder_args .pp
636
669
tp_degree = builder_args .tp
0 commit comments