@@ -96,6 +96,23 @@ def main(
9696 logger .info (f"Downloading model from { path_gcs_inference } ..." )
9797 subprocess .run (["gcloud" , "storage" , "rsync" , "-r" , path_gcs_inference , dir_tmp ], check = True )
9898
99+ # ONNX model
100+ logger .info ("Loading ONNX model (first load triggers export, can take a couple minutes)" )
101+ start = time .monotonic ()
102+ model_onnx = gt .utils .SentenceTransformer (
103+ dir_tmp ,
104+ backend = "onnx" ,
105+ trust_remote_code = True ,
106+ model_kwargs = {"provider" : "CUDAExecutionProvider" },
107+ text_prefix = text_prefix ,
108+ )
109+ _ = model_onnx .encode ("warm up" )
110+ logger .info (f"ONNX model ready in { time .monotonic () - start :.1f} s" )
111+
112+ times_onnx = _encode_timed (model_onnx , texts , desc = "onnx" )
113+ (model_onnx ,) = release_memory (model_onnx )
114+
115+ # Compiled model
99116 logger .info ("Loading compiled model" )
100117 start = time .monotonic ()
101118 model_compiled = gt .compiled .SentenceTransformer (
@@ -109,6 +126,7 @@ def main(
109126 times_compiled = _encode_timed (model_compiled , texts , desc = "compiled" )
110127 (model_compiled ,) = release_memory (model_compiled )
111128
129+ # Base model
112130 logger .info ("Loading base model" )
113131 start = time .monotonic ()
114132 model_base = gt .utils .SentenceTransformer (
@@ -120,23 +138,6 @@ def main(
120138 times_base = _encode_timed (model_base , texts , desc = "base" )
121139 (model_base ,) = release_memory (model_base )
122140
123- # ONNX export ignores dtype/attn_implementation, so we run it fp32 here. A post-export fp16
124- # optimization pass is the next step if this looks promising.
125- logger .info ("Loading ONNX model (first load triggers export, can take a couple minutes)" )
126- start = time .monotonic ()
127- model_onnx = gt .utils .SentenceTransformer (
128- dir_tmp ,
129- backend = "onnx" ,
130- trust_remote_code = True ,
131- model_kwargs = {"provider" : "CUDAExecutionProvider" },
132- text_prefix = text_prefix ,
133- )
134- _ = model_onnx .encode ("warm up" )
135- logger .info (f"ONNX model ready in { time .monotonic () - start :.1f} s" )
136-
137- times_onnx = _encode_timed (model_onnx , texts , desc = "onnx" )
138- (model_onnx ,) = release_memory (model_onnx )
139-
140141 df_out = pl .DataFrame (
141142 {
142143 "query_stacktrace_string" : texts ,
0 commit comments