Commit e15563f

Update app.py
1 parent 9a13636 commit e15563f

1 file changed: llama_cpp/server/app.py (+40 -34 lines)
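
The diff below swaps most of the PEP 604 / PEP 585 annotation syntax (X | None, list[...], dict[...]) for the typing-module spellings (Optional[X], List[...], Dict[...]) and imports Iterator from typing instead of collections.abc, presumably for compatibility with older Python versions and with tooling that evaluates annotations at runtime. It also drops some trailing commas and makes the non-streaming return paths explicit with else: branches. A minimal sketch of the annotation pattern; the function below is illustrative only and not part of app.py:

from typing import Dict, List, Optional

# Python 3.10+ only (the spelling this commit moves away from):
#   def token_scores(bias: dict[str, float], default: float | None = None) -> list[float] | None: ...

# Portable spelling (the one this commit switches to); also valid on Python 3.8:
def token_scores(bias: Dict[str, float], default: Optional[float] = None) -> Optional[List[float]]:
    # Return the bias values, falling back to the default when the mapping is empty.
    if not bias:
        return None if default is None else [default]
    return list(bias.values())
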
@@ -4,10 +4,9 @@
 import json
 import os
 import typing
-from collections.abc import Iterator
 from functools import partial
 from threading import Lock
-from typing import Dict, List, Optional, Union
+from typing import Dict, Iterator, List, Optional, Union
 
 import anyio
 from anyio.streams.memory import MemoryObjectSendStream
@@ -45,7 +44,7 @@
 
 router = APIRouter(route_class=RouteErrorHandler)
 
-_server_settings: ServerSettings | None = None
+_server_settings: Optional[ServerSettings] = None
 
 
 def set_server_settings(server_settings: ServerSettings):
@@ -57,13 +56,13 @@ def get_server_settings():
     yield _server_settings
 
 
-_llama_proxy: LlamaProxy | None = None
+_llama_proxy: Optional[LlamaProxy] = None
 
 llama_outer_lock = Lock()
 llama_inner_lock = Lock()
 
 
-def set_llama_proxy(model_settings: list[ModelSettings]):
+def set_llama_proxy(model_settings: List[ModelSettings]):
     global _llama_proxy
     _llama_proxy = LlamaProxy(models=model_settings)
 
@@ -87,7 +86,7 @@ def get_llama_proxy():
             llama_outer_lock.release()
 
 
-_ping_message_factory: typing.Callable[[], bytes] | None = None
+_ping_message_factory: typing.Optional[typing.Callable[[], bytes]] = None
 
 
 def set_ping_message_factory(factory: typing.Callable[[], bytes]):
@@ -98,7 +97,7 @@ def set_ping_message_factory(factory: typing.Callable[[], bytes]):
 def create_app(
     settings: Settings | None = None,
     server_settings: ServerSettings | None = None,
-    model_settings: list[ModelSettings] | None = None,
+    model_settings: List[ModelSettings] | None = None,
 ):
     config_file = os.environ.get("CONFIG_FILE", None)
     if config_file is not None:
@@ -110,7 +109,7 @@ def create_app(
                 import yaml
 
                 config_file_settings = ConfigFileSettings.model_validate_json(
-                    json.dumps(yaml.safe_load(f)),
+                    json.dumps(yaml.safe_load(f))
                 )
             else:
                 config_file_settings = ConfigFileSettings.model_validate_json(f.read())
@@ -157,7 +156,7 @@ async def get_event_publisher(
     request: Request,
     inner_send_chan: MemoryObjectSendStream[typing.Any],
     iterator: Iterator[typing.Any],
-    on_complete: typing.Callable[[], None] | None = None,
+    on_complete: typing.Optional[typing.Callable[[], None]] = None,
 ):
     server_settings = next(get_server_settings())
     interrupt_requests = (
@@ -185,9 +184,9 @@ async def get_event_publisher(
 
 def _logit_bias_tokens_to_input_ids(
     llama: llama_cpp.Llama,
-    logit_bias: dict[str, float],
-) -> dict[str, float]:
-    to_bias: dict[str, float] = {}
+    logit_bias: Dict[str, float],
+) -> Dict[str, float]:
+    to_bias: Dict[str, float] = {}
     for token, score in logit_bias.items():
         token = token.encode("utf-8")
         for input_id in llama.tokenize(token, add_bos=False, special=True):
@@ -201,7 +200,7 @@ def _logit_bias_tokens_to_input_ids(
 
 async def authenticate(
     settings: Settings = Depends(get_server_settings),
-    authorization: str | None = Depends(bearer_scheme),
+    authorization: Optional[str] = Depends(bearer_scheme),
 ):
     # Skip API key check if it's not set in settings
     if settings.api_key is None:
@@ -237,21 +236,21 @@ async def authenticate
                 "application/json": {
                     "schema": {
                         "anyOf": [
-                            {"$ref": "#/components/schemas/CreateCompletionResponse"},
+                            {"$ref": "#/components/schemas/CreateCompletionResponse"}
                         ],
                         "title": "Completion response, when stream=False",
-                    },
+                    }
                 },
                 "text/event-stream": {
                     "schema": {
                         "type": "string",
                         "title": "Server Side Streaming response, when stream=True. "
                         + "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format",
                         "example": """data: {... see CreateCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]""",
-                    },
+                    }
                 },
             },
-        },
+        }
     },
     tags=[openai_v1_tag],
 )
@@ -267,7 +266,7 @@ async def create_completion(
 ) -> llama_cpp.Completion:
     exit_stack = contextlib.ExitStack()
     llama_proxy = await run_in_threadpool(
-        lambda: exit_stack.enter_context(contextlib.contextmanager(get_llama_proxy)()),
+        lambda: exit_stack.enter_context(contextlib.contextmanager(get_llama_proxy)())
     )
     if llama_proxy is None:
         raise HTTPException(
@@ -281,7 +280,7 @@
     llama = llama_proxy(
         body.model
         if request.url.path != "/v1/engines/copilot-codex/completions"
-        else "copilot-codex",
+        else "copilot-codex"
     )
 
     exclude = {
@@ -305,14 +304,17 @@ async def create_completion(
 
     if body.min_tokens > 0:
         _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
-            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())],
+            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
         )
         if "logits_processor" not in kwargs:
             kwargs["logits_processor"] = _min_tokens_logits_processor
         else:
             kwargs["logits_processor"].extend(_min_tokens_logits_processor)
 
-    iterator_or_completion: llama_cpp.CreateCompletionResponse | Iterator[llama_cpp.CreateCompletionStreamResponse] = await run_in_threadpool(llama, **kwargs)
+    iterator_or_completion: Union[
+        llama_cpp.CreateCompletionResponse,
+        Iterator[llama_cpp.CreateCompletionStreamResponse],
+    ] = await run_in_threadpool(llama, **kwargs)
 
     if isinstance(iterator_or_completion, Iterator):
         # EAFP: It's easier to ask for forgiveness than permission
@@ -338,7 +340,8 @@ def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]:
             sep="\n",
             ping_message_factory=_ping_message_factory,
         )
-    return iterator_or_completion
+    else:
+        return iterator_or_completion
 
 
 @router.post(
@@ -370,22 +373,22 @@ async def create_embedding(
                 "schema": {
                     "anyOf": [
                         {
-                            "$ref": "#/components/schemas/CreateChatCompletionResponse",
-                        },
+                            "$ref": "#/components/schemas/CreateChatCompletionResponse"
+                        }
                     ],
                     "title": "Completion response, when stream=False",
-                },
+                }
             },
             "text/event-stream": {
                 "schema": {
                     "type": "string",
                     "title": "Server Side Streaming response, when stream=True"
                     + "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format",
                     "example": """data: {... see CreateChatCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]""",
-                },
+                }
             },
         },
-    },
+    }
 },
 tags=[openai_v1_tag],
 )
@@ -437,7 +440,7 @@ async def create_chat_completion(
                                     "required": ["name", "age"],
                                 },
                             },
-                        },
+                        }
                     ],
                     "tool_choice": {
                         "type": "function",
@@ -459,7 +462,7 @@ async def create_chat_completion(
                     "top_logprobs": 10,
                 },
             },
-        },
+        }
     ),
 ) -> llama_cpp.ChatCompletion:
     # This is a workaround for an issue in FastAPI dependencies
@@ -468,7 +471,7 @@ async def create_chat_completion(
     # https://github.com/tiangolo/fastapi/issues/11143
     exit_stack = contextlib.ExitStack()
     llama_proxy = await run_in_threadpool(
-        lambda: exit_stack.enter_context(contextlib.contextmanager(get_llama_proxy)()),
+        lambda: exit_stack.enter_context(contextlib.contextmanager(get_llama_proxy)())
     )
     if llama_proxy is None:
         raise HTTPException(
@@ -495,14 +498,16 @@ async def create_chat_completion(
 
     if body.min_tokens > 0:
         _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
-            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())],
+            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
         )
         if "logits_processor" not in kwargs:
             kwargs["logits_processor"] = _min_tokens_logits_processor
         else:
             kwargs["logits_processor"].extend(_min_tokens_logits_processor)
 
-    iterator_or_completion: llama_cpp.ChatCompletion | Iterator[llama_cpp.ChatCompletionChunk] = await run_in_threadpool(llama.create_chat_completion, **kwargs)
+    iterator_or_completion: Union[
+        llama_cpp.ChatCompletion, Iterator[llama_cpp.ChatCompletionChunk]
+    ] = await run_in_threadpool(llama.create_chat_completion, **kwargs)
 
     if isinstance(iterator_or_completion, Iterator):
         # EAFP: It's easier to ask for forgiveness than permission
@@ -528,8 +533,9 @@ def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
             sep="\n",
             ping_message_factory=_ping_message_factory,
         )
-    exit_stack.close()
-    return iterator_or_completion
+    else:
+        exit_stack.close()
+        return iterator_or_completion
 
 
 @router.get(
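
The else: branches added at the end of create_completion and create_chat_completion make the resource lifetime explicit: for a plain (non-streaming) completion the ExitStack holding the model proxy can be closed before returning, while for a streaming iterator it has to stay open until the event stream is consumed (get_event_publisher's on_complete callback appears to be how that deferred cleanup is triggered). A rough, self-contained sketch of the pattern; respond and make_result are hypothetical names, not the server's API:

import contextlib
from typing import Callable, Iterator, Union


def respond(make_result: Callable[[], Union[str, Iterator[str]]]) -> Union[str, Iterator[str]]:
    # Stand-in for entering the llama proxy context on an ExitStack.
    exit_stack = contextlib.ExitStack()
    exit_stack.callback(print, "resource released")

    result = make_result()
    if isinstance(result, Iterator):
        # Streaming: the resource must outlive this call, so the stack is
        # closed only after the stream has been fully consumed.
        def stream() -> Iterator[str]:
            try:
                yield from result
            finally:
                exit_stack.close()

        return stream()
    else:
        # Non-streaming: the response is fully materialized, so release the
        # resource right away (the behavior the new explicit branch spells out).
        exit_stack.close()
        return result


# A non-streaming call releases immediately; a streaming call releases only
# once the iterator is exhausted.
print(respond(lambda: "full completion"))
for chunk in respond(lambda: iter(["a", "b", "c"])):
    print(chunk)
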
