Skip to content

Commit 65c2554

Browse files
committed
Move the ONNX benchmark up so it runs before the compiled and base models
1 parent faf1ceb commit 65c2554

2 files changed

Lines changed: 19 additions & 18 deletions

File tree

benchmark/run.py

Lines changed: 18 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -96,6 +96,23 @@ def main(
96 96
logger.info(f"Downloading model from {path_gcs_inference} ...")
97 97
subprocess.run(["gcloud", "storage", "rsync", "-r", path_gcs_inference, dir_tmp], check=True)
98 98

99+
# ONNX model
100+
logger.info("Loading ONNX model (first load triggers export, can take a couple minutes)")
101+
start = time.monotonic()
102+
model_onnx = gt.utils.SentenceTransformer(
103+
dir_tmp,
104+
backend="onnx",
105+
trust_remote_code=True,
106+
model_kwargs={"provider": "CUDAExecutionProvider"},
107+
text_prefix=text_prefix,
108+
)
109+
_ = model_onnx.encode("warm up")
110+
logger.info(f"ONNX model ready in {time.monotonic() - start:.1f}s")
111+
112+
times_onnx = _encode_timed(model_onnx, texts, desc="onnx")
113+
(model_onnx,) = release_memory(model_onnx)
114+
115+
# Compiled model
99 116
logger.info("Loading compiled model")
100 117
start = time.monotonic()
101 118
model_compiled = gt.compiled.SentenceTransformer(
@@ -109,6 +126,7 @@ def main(
109 126
times_compiled = _encode_timed(model_compiled, texts, desc="compiled")
110 127
(model_compiled,) = release_memory(model_compiled)
111 128

129+
# Base model
112 130
logger.info("Loading base model")
113 131
start = time.monotonic()
114 132
model_base = gt.utils.SentenceTransformer(
@@ -120,23 +138,6 @@ def main(
120 138
times_base = _encode_timed(model_base, texts, desc="base")
121 139
(model_base,) = release_memory(model_base)
122 140

123-
# ONNX export ignores dtype/attn_implementation, so we run it fp32 here. A post-export fp16
124-
# optimization pass is the next step if this looks promising.
125-
logger.info("Loading ONNX model (first load triggers export, can take a couple minutes)")
126-
start = time.monotonic()
127-
model_onnx = gt.utils.SentenceTransformer(
128-
dir_tmp,
129-
backend="onnx",
130-
trust_remote_code=True,
131-
model_kwargs={"provider": "CUDAExecutionProvider"},
132-
text_prefix=text_prefix,
133-
)
134-
_ = model_onnx.encode("warm up")
135-
logger.info(f"ONNX model ready in {time.monotonic() - start:.1f}s")
136-
137-
times_onnx = _encode_timed(model_onnx, texts, desc="onnx")
138-
(model_onnx,) = release_memory(model_onnx)
139-
140 141
df_out = pl.DataFrame(
141 142
{
142 143
"query_stacktrace_string": texts,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@ license = { file = "LICENSE" }
11 11
dependencies = [
12 12
"accelerate==1.12.0",
13 13
"datasets==4.4.1",
14-
"onnxruntime-gpu>=1.22", # placeholder floor; pin once we know what resolves on cu128
14+
"onnxruntime-gpu==1.25.0",
15 15
"optimum==1.27.0",
16 16
"polars==1.32.0", # cudf lol
17 17
"pydantic==2.11.9",

0 commit comments

Comments
 (0)