Skip to content

Commit 1309f45

Browse files
authored
Merge pull request #635 from michaelfeil/rw/allow-onnx-selection
expose additional params to cli to allow for better onnx selection
2 parents 4355fef + f52db34 commit 1309f45

File tree

8 files changed

+38
-17
lines changed

8 files changed

+38
-17
lines changed

.github/workflows/test.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ jobs:
4747
android: true
4848
dotnet: true
4949
haskell: true
50-
large-packages: false
50+
large-packages: true
5151
docker-images: false
5252
swap-storage: false
5353
- name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }}

libs/infinity_emb/infinity_emb/args.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ class EngineArgs:
6868
lengths_via_tokenize: bool = MANAGER.lengths_via_tokenize[0]
6969
embedding_dtype: EmbeddingDtype = EmbeddingDtype[MANAGER.embedding_dtype[0]]
7070
served_model_name: str = MANAGER.served_model_name[0]
71+
onnx_disable_optimize: bool = MANAGER.onnx_disable_optimize[0]
72+
onnx_do_not_prefer_quantized: bool = MANAGER.onnx_do_not_prefer_quantized[0]
7173

7274
_loading_strategy: Optional[LoadingStrategy] = None
7375

@@ -160,8 +162,10 @@ def from_env(cls) -> list["EngineArgs"]:
160162
lengths_via_tokenize=lengths_via_tokenize,
161163
embedding_dtype=embedding_dtype,
162164
served_model_name=served_model_name,
165+
onnx_disable_optimize=onnx_disable_optimize,
166+
onnx_do_not_prefer_quantized=onnx_do_not_prefer_quantized
163167
)
164-
for model_name_or_path, batch_size, revision, trust_remote_code, engine, model_warmup, device, compile, bettertransformer, dtype, pooling_method, lengths_via_tokenize, embedding_dtype, served_model_name in zip_longest(
168+
for model_name_or_path, batch_size, revision, trust_remote_code, engine, model_warmup, device, compile, bettertransformer, dtype, pooling_method, lengths_via_tokenize, embedding_dtype, served_model_name,onnx_disable_optimize,onnx_do_not_prefer_quantized in zip_longest(
165169
MANAGER.model_id,
166170
MANAGER.batch_size,
167171
MANAGER.revision,
@@ -176,5 +180,7 @@ def from_env(cls) -> list["EngineArgs"]:
176180
MANAGER.lengths_via_tokenize,
177181
MANAGER.embedding_dtype,
178182
MANAGER.served_model_name,
183+
MANAGER.onnx_disable_optimize,
184+
MANAGER.onnx_do_not_prefer_quantized
179185
)
180186
]

libs/infinity_emb/infinity_emb/cli.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,14 @@ def v2(
270270
**_construct("proxy_root_path"),
271271
help="Proxy prefix for the application. See: https://fastapi.tiangolo.com/advanced/behind-a-proxy/",
272272
),
273+
onnx_disable_optimize: list[bool] = typer.Option(
274+
**_construct("onnx_disable_optimize"),
275+
help="Disable onnx optimization",
276+
),
277+
onnx_do_not_prefer_quantized: list[bool] = typer.Option(
278+
**_construct("onnx_do_not_prefer_quantized"),
279+
help="Do not use quantized onnx models by default if available",
280+
),
273281
):
274282
"""Infinity API ♾️ cli v2. MIT License. Copyright (c) 2023-now Michael Feil \n
275283
\n
@@ -309,6 +317,8 @@ def v2(
309317
permissive_cors, bool: add permissive CORS headers to enable consumption from a browser. Defaults to False.
310318
api_key, str: optional Bearer token for authentication. Defaults to "", which disables authentication.
311319
proxy_root_path, str: optional Proxy prefix for the application. See: https://fastapi.tiangolo.com/advanced/behind-a-proxy/
320+
onnx_disable_optimize, bool: disable onnx optimization
321+
onnx_do_not_prefer_quantized, bool: do not prefer quantized onnx model if it's available
312322
"""
313323
logger.setLevel(log_level.to_int())
314324
device_id_typed = [DeviceID(d) for d in typer_option_resolve(device_id)]
@@ -330,6 +340,8 @@ def v2(
330340
compile=compile,
331341
bettertransformer=bettertransformer,
332342
served_model_name=served_model_name,
343+
onnx_disable_optimize=onnx_disable_optimize,
344+
onnx_do_not_prefer_quantized=onnx_do_not_prefer_quantized
333345
)
334346

335347
engine_args = []

libs/infinity_emb/infinity_emb/env.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,5 +260,15 @@ def device_id(self):
260260
def embedding_dtype(self) -> list[str]:
261261
return self._typed_multiple("embedding_dtype", EmbeddingDtype)
262262

263-
263+
@cached_property
264+
def onnx_disable_optimize(self):
265+
return self._to_bool_multiple(
266+
self._optional_infinity_var_multiple("onnx_disable_optimize", default=["false"])
267+
)
268+
269+
@cached_property
270+
def onnx_do_not_prefer_quantized(self):
271+
return self._to_bool_multiple(
272+
self._optional_infinity_var_multiple("onnx_do_not_prefer_quantized", default=["false"])
273+
)
264274
MANAGER = __Infinity_EnvManager()

libs/infinity_emb/infinity_emb/transformer/classifier/optimum.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
# Copyright (c) 2023-now michaelfeil
33

44
import copy
5-
import os
65

76
from infinity_emb._optional_imports import CHECK_ONNXRUNTIME, CHECK_TRANSFORMERS
87
from infinity_emb.args import EngineArgs
@@ -36,7 +35,7 @@ def __init__(self, *, engine_args: EngineArgs):
3635
model_name_or_path=engine_args.model_name_or_path,
3736
revision=engine_args.revision,
3837
use_auth_token=True,
39-
prefer_quantized=("cpu" in provider.lower() or "openvino" in provider.lower()),
38+
prefer_quantized=("cpu" in provider.lower() or "openvino" in provider.lower()) and not engine_args.onnx_do_not_prefer_quantized,
4039
)
4140

4241
model = optimize_model(
@@ -46,7 +45,7 @@ def __init__(self, *, engine_args: EngineArgs):
4645
trust_remote_code=engine_args.trust_remote_code,
4746
execution_provider=provider,
4847
file_name=onnx_file.as_posix(),
49-
optimize_model=not os.environ.get("INFINITY_ONNX_DISABLE_OPTIMIZE", False),
48+
optimize_model=not engine_args.onnx_disable_optimize
5049
)
5150
model.use_io_binding = False
5251

libs/infinity_emb/infinity_emb/transformer/crossencoder/optimum.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
# Copyright (c) 2023-now michaelfeil
33

44
import copy
5-
import os
65

76
import numpy as np
87

@@ -34,16 +33,14 @@ def __init__(self, *, engine_args: EngineArgs):
3433
model_name_or_path=engine_args.model_name_or_path,
3534
revision=engine_args.revision,
3635
use_auth_token=True,
37-
prefer_quantized=("cpu" in provider.lower() or "openvino" in provider.lower()),
36+
prefer_quantized=("cpu" in provider.lower() or "openvino" in provider.lower()) and not engine_args.onnx_do_not_prefer_quantized,
3837
)
3938

4039
self.model = optimize_model(
4140
engine_args.model_name_or_path,
4241
execution_provider=provider,
4342
file_name=onnx_file.as_posix(),
44-
optimize_model=not os.environ.get(
45-
"INFINITY_ONNX_DISABLE_OPTIMIZE", False
46-
), # TODO: make this env variable public
43+
optimize_model=not engine_args.onnx_disable_optimize,
4744
model_class=ORTModelForSequenceClassification,
4845
revision=engine_args.revision,
4946
trust_remote_code=engine_args.trust_remote_code,

libs/infinity_emb/infinity_emb/transformer/embedder/optimum.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
# Copyright (c) 2023-now michaelfeil
33

44
import copy
5-
import os
65

76
import numpy as np
87

@@ -42,7 +41,7 @@ def __init__(self, *, engine_args: EngineArgs):
4241
model_name_or_path=engine_args.model_name_or_path,
4342
revision=engine_args.revision,
4443
use_auth_token=True,
45-
prefer_quantized=("cpu" in provider.lower() or "openvino" in provider.lower()),
44+
prefer_quantized=("cpu" in provider.lower() or "openvino" in provider.lower()) and not engine_args.onnx_do_not_prefer_quantized,
4645
)
4746

4847
self.pooling = (
@@ -55,9 +54,7 @@ def __init__(self, *, engine_args: EngineArgs):
5554
trust_remote_code=engine_args.trust_remote_code,
5655
execution_provider=provider,
5756
file_name=onnx_file.as_posix(),
58-
optimize_model=not os.environ.get(
59-
"INFINITY_ONNX_DISABLE_OPTIMIZE", False
60-
), # TODO: make this env variable public
57+
optimize_model=not engine_args.onnx_disable_optimize,
6158
model_class=ORTModelForFeatureExtraction,
6259
)
6360
self.model.use_io_binding = False

libs/infinity_emb/infinity_emb/transformer/utils_optimum.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def optimize_model(
144144
files_optimized = list(path_folder.glob(f"**/*{OPTIMIZED_SUFFIX}"))
145145

146146
logger.info(f"files_optimized: {files_optimized}")
147-
if files_optimized:
147+
if files_optimized and optimize_model:
148148
file_optimized = files_optimized[-1]
149149
logger.info(f"Optimized model found at {file_optimized}, skipping optimization")
150150
return model_class.from_pretrained(

0 commit comments

Comments (0)