55import re
66import sys
77
8-
98import infinity_emb
109from infinity_emb ._optional_imports import CHECK_TYPER , CHECK_UVICORN
1110from infinity_emb .args import EngineArgs
@@ -107,40 +106,41 @@ def _construct(name: str):
107106
108107 tp = typer .Typer ()
109108
109+
110110 @tp .command ("v1" )
111111 def v1 (
112- # v1 is deprecated. Please do no longer modify it.
113- model_name_or_path : str = MANAGER .model_id [0 ],
114- served_model_name : str = MANAGER .served_model_name [0 ],
115- batch_size : int = MANAGER .batch_size [0 ],
116- revision : str = MANAGER .revision [0 ],
117- trust_remote_code : bool = MANAGER .trust_remote_code [0 ],
118- redirect_slash : str = MANAGER .redirect_slash ,
119- engine : "InferenceEngine" = MANAGER .engine [0 ], # type: ignore # noqa
120- model_warmup : bool = MANAGER .model_warmup [0 ],
121- vector_disk_cache : bool = MANAGER .vector_disk_cache [0 ],
122- device : "Device" = MANAGER .device [0 ], # type: ignore
123- lengths_via_tokenize : bool = MANAGER .lengths_via_tokenize [0 ],
124- dtype : Dtype = MANAGER .dtype [0 ], # type: ignore
125- embedding_dtype : "EmbeddingDtype" = EmbeddingDtype .default_value (), # type: ignore
126- pooling_method : "PoolingMethod" = MANAGER .pooling_method [0 ], # type: ignore
127- compile : bool = MANAGER .compile [0 ],
128- bettertransformer : bool = MANAGER .bettertransformer [0 ],
129- preload_only : bool = MANAGER .preload_only ,
130- permissive_cors : bool = MANAGER .permissive_cors ,
131- api_key : str = MANAGER .api_key ,
132- url_prefix : str = MANAGER .url_prefix ,
133- host : str = MANAGER .host ,
134- port : int = MANAGER .port ,
135- log_level : "UVICORN_LOG_LEVELS" = MANAGER .log_level , # type: ignore
112+ # v1 is deprecated. Please do no longer modify it.
113+ model_name_or_path : str = MANAGER .model_id [0 ],
114+ served_model_name : str = MANAGER .served_model_name [0 ],
115+ batch_size : int = MANAGER .batch_size [0 ],
116+ revision : str = MANAGER .revision [0 ],
117+ trust_remote_code : bool = MANAGER .trust_remote_code [0 ],
118+ redirect_slash : str = MANAGER .redirect_slash ,
119+ engine : "InferenceEngine" = MANAGER .engine [0 ], # type: ignore # noqa
120+ model_warmup : bool = MANAGER .model_warmup [0 ],
121+ vector_disk_cache : bool = MANAGER .vector_disk_cache [0 ],
122+ device : "Device" = MANAGER .device [0 ], # type: ignore
123+ lengths_via_tokenize : bool = MANAGER .lengths_via_tokenize [0 ],
124+ dtype : Dtype = MANAGER .dtype [0 ], # type: ignore
125+ embedding_dtype : "EmbeddingDtype" = EmbeddingDtype .default_value (), # type: ignore
126+ pooling_method : "PoolingMethod" = MANAGER .pooling_method [0 ], # type: ignore
127+ compile : bool = MANAGER .compile [0 ],
128+ bettertransformer : bool = MANAGER .bettertransformer [0 ],
129+ preload_only : bool = MANAGER .preload_only ,
130+ permissive_cors : bool = MANAGER .permissive_cors ,
131+ api_key : str = MANAGER .api_key ,
132+ url_prefix : str = MANAGER .url_prefix ,
133+ host : str = MANAGER .host ,
134+ port : int = MANAGER .port ,
135+ log_level : "UVICORN_LOG_LEVELS" = MANAGER .log_level , # type: ignore
136136 ):
137137 """Infinity API ♾️ cli v1 - deprecated, consider use cli v2 via `infinity_emb v2`."""
138138 if api_key :
139139 # encourage switch to v2
140140 raise ValueError ("api_key is not supported in `v1`. Please migrate to `v2`." )
141141 if not (
142- embedding_dtype == EmbeddingDtype .float32
143- or embedding_dtype == EmbeddingDtype .default_value ()
142+ embedding_dtype == EmbeddingDtype .float32
143+ or embedding_dtype == EmbeddingDtype .default_value ()
144144 ):
145145 # encourage switch to v2
146146 raise ValueError (
@@ -177,107 +177,108 @@ def v1(
177177 proxy_root_path = "" , # set as empty string
178178 )
179179
180+
180181 @tp .command ("v2" )
181182 def v2 (
182- # t
183- # arguments for engine
184- model_id : list [str ] = typer .Option (
185- ** _construct ("model_id" ),
186- help = "Huggingface model repo id. Subset of possible models: https://huggingface.co/models?other=text-embeddings-inference&" ,
187- ),
188- served_model_name : list [str ] = typer .Option (
189- ** _construct ("served_model_name" ),
190- help = "the nickname for the API, under which the model_id can be selected" ,
191- ),
192- batch_size : list [int ] = typer .Option (
193- ** _construct ("batch_size" ), help = "maximum batch size for inference"
194- ),
195- revision : list [str ] = typer .Option (
196- ** _construct ("revision" ), help = "huggingface model repo revision."
197- ),
198- trust_remote_code : list [bool ] = typer .Option (
199- ** _construct ("trust_remote_code" ),
200- help = "if potential remote modeling code from huggingface repo is trusted." ,
201- ),
202- engine : list [InferenceEngine ] = typer .Option (
203- ** _construct ("engine" ),
204- help = "Which backend to use. `torch` uses Pytorch GPU/CPU, optimum uses ONNX on GPU/CPU/NVIDIA-TensorRT, `CTranslate2` uses torch+ctranslate2 on CPU/GPU." ,
205- ),
206- model_warmup : list [bool ] = typer .Option (
207- ** _construct ("model_warmup" ),
208- help = "if model should be warmed up after startup, and before ready." ,
209- ),
210- vector_disk_cache : list [bool ] = typer .Option (
211- ** _construct ("vector_disk_cache" ),
212- help = "If hash(request)/results should be cached to SQLite for latency improvement." ,
213- ),
214- device : list [Device ] = typer .Option (
215- ** _construct ("device" ),
216- help = "device to use for computing the model forward pass." ,
217- ),
218- device_id : list [str ] = typer .Option (
219- ** _construct ("device_id" ),
220- help = "device id defines the model placement. e.g. `0,1` will place the model on MPS/CUDA/GPU 0 and 1 each" ,
221- ),
222- lengths_via_tokenize : list [bool ] = typer .Option (
223- ** _construct ("lengths_via_tokenize" ),
224- help = "if True, returned tokens is based on actual tokenizer count. If false, uses len(input) as proxy." ,
225- ),
226- dtype : list [Dtype ] = typer .Option (
227- ** _construct ("dtype" ), help = "dtype for the model weights."
228- ),
229- embedding_dtype : list [EmbeddingDtype ] = typer .Option (
230- ** _construct ("embedding_dtype" ),
231- help = "dtype post-forward pass. If != `float32`, using Post-Forward Static quantization." ,
232- ),
233- pooling_method : list [PoolingMethod ] = typer .Option (
234- ** _construct ("pooling_method" ),
235- help = "overwrite the pooling method if inferred incorrectly." ,
236- ),
237- compile : list [bool ] = typer .Option (
238- ** _construct ("compile" ),
239- help = "Enable usage of `torch.compile(dynamic=True)` if engine relies on it." ,
240- ),
241- bettertransformer : list [bool ] = typer .Option (
242- ** _construct ("bettertransformer" ),
243- help = "Enables varlen flash-attention-2 via the `BetterTransformer` implementation. If available for this model." ,
244- ),
245- # arguments for uvicorn / server
246- preload_only : bool = typer .Option (
247- ** _construct ("preload_only" ),
248- help = "If true, only downloads models and verifies setup, then exit. Recommended for pre-caching the download in a Dockerfile." ,
249- ),
250- host : str = typer .Option (** _construct ("host" ), help = "host for the FastAPI uvicorn server" ),
251- port : int = typer .Option (** _construct ("port" ), help = "port for the FastAPI uvicorn server" ),
252- url_prefix : str = typer .Option (
253- ** _construct ("url_prefix" ),
254- callback = validate_url ,
255- help = "prefix for all routes of the FastAPI uvicorn server. Useful if you run behind a proxy / cascaded API." ,
256- ),
257- redirect_slash : str = typer .Option (
258- ** _construct ("redirect_slash" ), help = "where to redirect `/` requests to."
259- ),
260- log_level : "UVICORN_LOG_LEVELS" = typer .Option (
261- ** _construct ("log_level" ), help = "console log level."
262- ), # type: ignore
263- permissive_cors : bool = typer .Option (
264- ** _construct ("permissive_cors" ), help = "whether to allow permissive cors."
265- ),
266- api_key : str = typer .Option (
267- ** _construct ("api_key" ), help = "api_key used for authentication headers."
268- ),
269- proxy_root_path : str = typer .Option (
270- ** _construct ("proxy_root_path" ),
271- help = "Proxy prefix for the application. See: https://fastapi.tiangolo.com/advanced/behind-a-proxy/" ,
272- ),
273- onnx_disable_optimize : list [bool ] = typer .Option (
274- ** _construct ("onnx_disable_optimize" ),
275- help = "Disable onnx optimization" ,
276- ),
277- onnx_do_not_prefer_quantized : list [bool ] = typer .Option (
278- ** _construct ("onnx_do_not_prefer_quantized" ),
279- help = "Do not use quantized onnx models by default if available" ,
280- ),
183+ # t
184+ # arguments for engine
185+ model_id : list [str ] = typer .Option (
186+ ** _construct ("model_id" ),
187+ help = "Huggingface model repo id. Subset of possible models: https://huggingface.co/models?other=text-embeddings-inference&" ,
188+ ),
189+ served_model_name : list [str ] = typer .Option (
190+ ** _construct ("served_model_name" ),
191+ help = "the nickname for the API, under which the model_id can be selected" ,
192+ ),
193+ batch_size : list [int ] = typer .Option (
194+ ** _construct ("batch_size" ), help = "maximum batch size for inference"
195+ ),
196+ revision : list [str ] = typer .Option (
197+ ** _construct ("revision" ), help = "huggingface model repo revision."
198+ ),
199+ trust_remote_code : list [bool ] = typer .Option (
200+ ** _construct ("trust_remote_code" ),
201+ help = "if potential remote modeling code from huggingface repo is trusted." ,
202+ ),
203+ engine : list [InferenceEngine ] = typer .Option (
204+ ** _construct ("engine" ),
205+ help = "Which backend to use. `torch` uses Pytorch GPU/CPU, optimum uses ONNX on GPU/CPU/NVIDIA-TensorRT, `CTranslate2` uses torch+ctranslate2 on CPU/GPU." ,
206+ ),
207+ model_warmup : list [bool ] = typer .Option (
208+ ** _construct ("model_warmup" ),
209+ help = "if model should be warmed up after startup, and before ready." ,
210+ ),
211+ vector_disk_cache : list [bool ] = typer .Option (
212+ ** _construct ("vector_disk_cache" ),
213+ help = "If hash(request)/results should be cached to SQLite for latency improvement." ,
214+ ),
215+ device : list [Device ] = typer .Option (
216+ ** _construct ("device" ),
217+ help = "device to use for computing the model forward pass." ,
218+ ),
219+ device_id : list [str ] = typer .Option (
220+ ** _construct ("device_id" ),
221+ help = "device id defines the model placement. e.g. `0,1` will place the model on MPS/CUDA/GPU 0 and 1 each" ,
222+ ),
223+ lengths_via_tokenize : list [bool ] = typer .Option (
224+ ** _construct ("lengths_via_tokenize" ),
225+ help = "if True, returned tokens is based on actual tokenizer count. If false, uses len(input) as proxy." ,
226+ ),
227+ dtype : list [Dtype ] = typer .Option (
228+ ** _construct ("dtype" ), help = "dtype for the model weights."
229+ ),
230+ embedding_dtype : list [EmbeddingDtype ] = typer .Option (
231+ ** _construct ("embedding_dtype" ),
232+ help = "dtype post-forward pass. If != `float32`, using Post-Forward Static quantization." ,
233+ ),
234+ pooling_method : list [PoolingMethod ] = typer .Option (
235+ ** _construct ("pooling_method" ),
236+ help = "overwrite the pooling method if inferred incorrectly." ,
237+ ),
238+ compile : list [bool ] = typer .Option (
239+ ** _construct ("compile" ),
240+ help = "Enable usage of `torch.compile(dynamic=True)` if engine relies on it." ,
241+ ),
242+ bettertransformer : list [bool ] = typer .Option (
243+ ** _construct ("bettertransformer" ),
244+ help = "Enables varlen flash-attention-2 via the `BetterTransformer` implementation. If available for this model." ,
245+ ),
246+ # arguments for uvicorn / server
247+ preload_only : bool = typer .Option (
248+ ** _construct ("preload_only" ),
249+ help = "If true, only downloads models and verifies setup, then exit. Recommended for pre-caching the download in a Dockerfile." ,
250+ ),
251+ host : str = typer .Option (** _construct ("host" ), help = "host for the FastAPI uvicorn server" ),
252+ port : int = typer .Option (** _construct ("port" ), help = "port for the FastAPI uvicorn server" ),
253+ url_prefix : str = typer .Option (
254+ ** _construct ("url_prefix" ),
255+ callback = validate_url ,
256+ help = "prefix for all routes of the FastAPI uvicorn server. Useful if you run behind a proxy / cascaded API." ,
257+ ),
258+ redirect_slash : str = typer .Option (
259+ ** _construct ("redirect_slash" ), help = "where to redirect `/` requests to."
260+ ),
261+ log_level : "UVICORN_LOG_LEVELS" = typer .Option (
262+ ** _construct ("log_level" ), help = "console log level."
263+ ), # type: ignore
264+ permissive_cors : bool = typer .Option (
265+ ** _construct ("permissive_cors" ), help = "whether to allow permissive cors."
266+ ),
267+ api_key : str = typer .Option (
268+ ** _construct ("api_key" ), help = "api_key used for authentication headers."
269+ ),
270+ proxy_root_path : str = typer .Option (
271+ ** _construct ("proxy_root_path" ),
272+ help = "Proxy prefix for the application. See: https://fastapi.tiangolo.com/advanced/behind-a-proxy/" ,
273+ ),
274+ onnx_disable_optimize : list [bool ] = typer .Option (
275+ ** _construct ("onnx_disable_optimize" ),
276+ help = "Disable onnx optimization" ,
277+ ),
278+ onnx_do_not_prefer_quantized : list [bool ] = typer .Option (
279+ ** _construct ("onnx_do_not_prefer_quantized" ),
280+ help = "Do not use quantized onnx models by default if available" ,
281+ ),
281282 ):
282283 """Infinity API ♾️ cli v2. MIT License. Copyright (c) 2023-now Michael Feil \n
283284 \n
@@ -380,6 +381,8 @@ def v2(
380381 api_key = api_key ,
381382 proxy_root_path = proxy_root_path ,
382383 )
384+ # Update logging configs
385+ set_uvicorn_logging_configs ()
383386
384387 uvicorn .run (
385388 app ,
@@ -391,6 +394,47 @@ def v2(
391394 )
392395
393396
def set_uvicorn_logging_configs():
    """Apply Infinity's log formats to uvicorn's default logging config.

    The default, access, and date format strings can each be overridden via
    an environment variable; the *names* of those environment variables are
    taken from ``MANAGER.uvicorn_default_format``,
    ``MANAGER.uvicorn_access_format`` and ``MANAGER.uvicorn_date_format``.

    Mutates ``uvicorn.config.LOGGING_CONFIG`` in place, so it must be called
    before ``uvicorn.run(...)`` for the formats to take effect.
    """
    import os

    # Local import: uvicorn is an optional dependency of this module
    # (see CHECK_UVICORN), so it is only pulled in when actually needed.
    from uvicorn.config import LOGGING_CONFIG

    # NOTE: the MANAGER attributes hold environment-variable *names*
    # (they are passed to os.getenv as the key), not format values.
    default_fmt = os.getenv(
        MANAGER.uvicorn_default_format,
        "%(asctime)s %(levelprefix)s %(message)s",
    )
    access_fmt = os.getenv(
        MANAGER.uvicorn_access_format,
        '%(asctime)s %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s',
    )
    # One date format shared by both formatters so timestamps line up.
    date_fmt = os.getenv(
        MANAGER.uvicorn_date_format,
        "%Y-%m-%d %H:%M:%S",
    )

    for formatter_name, fmt in (("default", default_fmt), ("access", access_fmt)):
        LOGGING_CONFIG["formatters"][formatter_name]["fmt"] = fmt
        LOGGING_CONFIG["formatters"][formatter_name]["datefmt"] = date_fmt
436+
437+
394438def cli ():
395439 CHECK_TYPER .mark_required ()
396440 if len (sys .argv ) == 1 or sys .argv [1 ] not in [
0 commit comments