@@ -4,10 +4,9 @@
import json
import os
import typing
-from collections.abc import Iterator
from functools import partial
from threading import Lock
-from typing import Dict, List, Optional, Union
+from typing import Dict, Iterator, List, Optional, Union

import anyio
from anyio.streams.memory import MemoryObjectSendStream
@@ -45,7 +44,7 @@

router = APIRouter(route_class=RouteErrorHandler)

-_server_settings: ServerSettings | None = None
+_server_settings: Optional[ServerSettings] = None


def set_server_settings(server_settings: ServerSettings):
@@ -57,13 +56,13 @@ def get_server_settings():
    yield _server_settings


-_llama_proxy: LlamaProxy | None = None
+_llama_proxy: Optional[LlamaProxy] = None

llama_outer_lock = Lock()
llama_inner_lock = Lock()


-def set_llama_proxy(model_settings: list[ModelSettings]):
+def set_llama_proxy(model_settings: List[ModelSettings]):
    global _llama_proxy
    _llama_proxy = LlamaProxy(models=model_settings)

@@ -87,7 +86,7 @@ def get_llama_proxy():
            llama_outer_lock.release()


-_ping_message_factory: typing.Callable[[], bytes] | None = None
+_ping_message_factory: typing.Optional[typing.Callable[[], bytes]] = None


def set_ping_message_factory(factory: typing.Callable[[], bytes]):
@@ -98,7 +97,7 @@ def set_ping_message_factory(factory: typing.Callable[[], bytes]):
def create_app(
    settings: Settings | None = None,
    server_settings: ServerSettings | None = None,
-    model_settings: list[ModelSettings] | None = None,
+    model_settings: List[ModelSettings] | None = None,
):
    config_file = os.environ.get("CONFIG_FILE", None)
    if config_file is not None:
@@ -110,7 +109,7 @@ def create_app(
                import yaml

                config_file_settings = ConfigFileSettings.model_validate_json(
-                    json.dumps(yaml.safe_load(f)),
+                    json.dumps(yaml.safe_load(f))
                )
            else:
                config_file_settings = ConfigFileSettings.model_validate_json(f.read())
@@ -157,7 +156,7 @@ async def get_event_publisher(
    request: Request,
    inner_send_chan: MemoryObjectSendStream[typing.Any],
    iterator: Iterator[typing.Any],
-    on_complete: typing.Callable[[], None] | None = None,
+    on_complete: typing.Optional[typing.Callable[[], None]] = None,
):
    server_settings = next(get_server_settings())
    interrupt_requests = (
@@ -185,9 +184,9 @@ async def get_event_publisher(

def _logit_bias_tokens_to_input_ids(
    llama: llama_cpp.Llama,
-    logit_bias: dict[str, float],
-) -> dict[str, float]:
-    to_bias: dict[str, float] = {}
+    logit_bias: Dict[str, float],
+) -> Dict[str, float]:
+    to_bias: Dict[str, float] = {}
    for token, score in logit_bias.items():
        token = token.encode("utf-8")
        for input_id in llama.tokenize(token, add_bos=False, special=True):
@@ -201,7 +200,7 @@ def _logit_bias_tokens_to_input_ids(

async def authenticate(
    settings: Settings = Depends(get_server_settings),
-    authorization: str | None = Depends(bearer_scheme),
+    authorization: Optional[str] = Depends(bearer_scheme),
):
    # Skip API key check if it's not set in settings
    if settings.api_key is None:
@@ -237,21 +236,21 @@ async def authenticate(
                "application/json": {
                    "schema": {
                        "anyOf": [
-                            {"$ref": "#/components/schemas/CreateCompletionResponse"},
+                            {"$ref": "#/components/schemas/CreateCompletionResponse"}
                        ],
                        "title": "Completion response, when stream=False",
-                    },
+                    }
                },
                "text/event-stream": {
                    "schema": {
                        "type": "string",
                        "title": "Server Side Streaming response, when stream=True. "
                        + "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format",
                        "example": """data: {... see CreateCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]""",
-                    },
+                    }
                },
            },
-        },
+        }
    },
    tags=[openai_v1_tag],
)
@@ -267,7 +266,7 @@ async def create_completion(
) -> llama_cpp.Completion:
    exit_stack = contextlib.ExitStack()
    llama_proxy = await run_in_threadpool(
-        lambda: exit_stack.enter_context(contextlib.contextmanager(get_llama_proxy)()),
+        lambda: exit_stack.enter_context(contextlib.contextmanager(get_llama_proxy)())
    )
    if llama_proxy is None:
        raise HTTPException(
@@ -281,7 +280,7 @@ async def create_completion(
    llama = llama_proxy(
        body.model
        if request.url.path != "/v1/engines/copilot-codex/completions"
-        else "copilot-codex",
+        else "copilot-codex"
    )

    exclude = {
@@ -305,14 +304,17 @@ async def create_completion(

    if body.min_tokens > 0:
        _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
-            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())],
+            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
        )
        if "logits_processor" not in kwargs:
            kwargs["logits_processor"] = _min_tokens_logits_processor
        else:
            kwargs["logits_processor"].extend(_min_tokens_logits_processor)

-    iterator_or_completion: llama_cpp.CreateCompletionResponse | Iterator[llama_cpp.CreateCompletionStreamResponse] = await run_in_threadpool(llama, **kwargs)
+    iterator_or_completion: Union[
+        llama_cpp.CreateCompletionResponse,
+        Iterator[llama_cpp.CreateCompletionStreamResponse],
+    ] = await run_in_threadpool(llama, **kwargs)

    if isinstance(iterator_or_completion, Iterator):
        # EAFP: It's easier to ask for forgiveness than permission
@@ -338,7 +340,8 @@ def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]:
            sep="\n",
            ping_message_factory=_ping_message_factory,
        )
-    return iterator_or_completion
+    else:
+        return iterator_or_completion


@router.post(
@@ -370,22 +373,22 @@ async def create_embedding(
                    "schema": {
                        "anyOf": [
                            {
-                                "$ref": "#/components/schemas/CreateChatCompletionResponse",
-                            },
+                                "$ref": "#/components/schemas/CreateChatCompletionResponse"
+                            }
                        ],
                        "title": "Completion response, when stream=False",
-                    },
+                    }
                },
                "text/event-stream": {
                    "schema": {
                        "type": "string",
                        "title": "Server Side Streaming response, when stream=True"
                        + "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format",
                        "example": """data: {... see CreateChatCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]""",
-                    },
+                    }
                },
            },
-        },
+        }
    },
    tags=[openai_v1_tag],
)
@@ -437,7 +440,7 @@ async def create_chat_completion(
                                    "required": ["name", "age"],
                                },
                            },
-                        },
+                        }
                    ],
                    "tool_choice": {
                        "type": "function",
@@ -459,7 +462,7 @@ async def create_chat_completion(
                    "top_logprobs": 10,
                },
            },
-        },
+        }
    ),
) -> llama_cpp.ChatCompletion:
    # This is a workaround for an issue in FastAPI dependencies
@@ -468,7 +471,7 @@ async def create_chat_completion(
    # https://github.com/tiangolo/fastapi/issues/11143
    exit_stack = contextlib.ExitStack()
    llama_proxy = await run_in_threadpool(
-        lambda: exit_stack.enter_context(contextlib.contextmanager(get_llama_proxy)()),
+        lambda: exit_stack.enter_context(contextlib.contextmanager(get_llama_proxy)())
    )
    if llama_proxy is None:
        raise HTTPException(
@@ -495,14 +498,16 @@ async def create_chat_completion(

    if body.min_tokens > 0:
        _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
-            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())],
+            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
        )
        if "logits_processor" not in kwargs:
            kwargs["logits_processor"] = _min_tokens_logits_processor
        else:
            kwargs["logits_processor"].extend(_min_tokens_logits_processor)

-    iterator_or_completion: llama_cpp.ChatCompletion | Iterator[llama_cpp.ChatCompletionChunk] = await run_in_threadpool(llama.create_chat_completion, **kwargs)
+    iterator_or_completion: Union[
+        llama_cpp.ChatCompletion, Iterator[llama_cpp.ChatCompletionChunk]
+    ] = await run_in_threadpool(llama.create_chat_completion, **kwargs)

    if isinstance(iterator_or_completion, Iterator):
        # EAFP: It's easier to ask for forgiveness than permission
@@ -528,8 +533,9 @@ def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
            sep="\n",
            ping_message_factory=_ping_message_factory,
        )
-    exit_stack.close()
-    return iterator_or_completion
+    else:
+        exit_stack.close()
+        return iterator_or_completion


@router.get(