Skip to content

Commit 2d9ee84

Browse files
authored
Update errors.py
1 parent 82bead9 commit 2d9ee84

File tree

1 file changed

+27
-16
lines changed

1 file changed

+27
-16
lines changed

llama_cpp/server/errors.py

+27-16
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,8 @@
33
import sys
44
import time
55
import traceback
6-
from collections.abc import Callable, Coroutine
76
from re import Match, Pattern, compile
8-
from typing import Dict, Optional, Tuple, Union
7+
from typing import Callable, Coroutine, Dict, Optional, Tuple, Union
98

109
from fastapi import (
1110
HTTPException,
@@ -28,8 +27,8 @@ class ErrorResponse(TypedDict):
2827

2928
message: str
3029
type: str
31-
param: str | None
32-
code: str | None
30+
param: Optional[str]
31+
code: Optional[str]
3332

3433

3534
class ErrorResponseFormatters:
@@ -46,9 +45,9 @@ class ErrorResponseFormatters:
4645

4746
@staticmethod
4847
def context_length_exceeded(
49-
request: CreateCompletionRequest | CreateChatCompletionRequest,
48+
request: Union[CreateCompletionRequest, CreateChatCompletionRequest],
5049
match, # type: Match[str] # type: ignore
51-
) -> tuple[int, ErrorResponse]:
50+
) -> Tuple[int, ErrorResponse]:
5251
"""Formatter for context length exceeded error"""
5352

5453
context_window = int(match.group(2))
@@ -84,9 +83,9 @@ def context_length_exceeded(
8483

8584
@staticmethod
8685
def model_not_found(
87-
request: CreateCompletionRequest | CreateChatCompletionRequest,
86+
request: Union[CreateCompletionRequest, CreateChatCompletionRequest],
8887
match, # type: Match[str] # type: ignore
89-
) -> tuple[int, ErrorResponse]:
88+
) -> Tuple[int, ErrorResponse]:
9089
"""Formatter for model_not_found error"""
9190

9291
model_path = str(match.group(1))
@@ -104,29 +103,35 @@ class RouteErrorHandler(APIRoute):
104103

105104
# key: regex pattern for original error message from llama_cpp
106105
# value: formatter function
107-
pattern_and_formatters: dict[
106+
pattern_and_formatters: Dict[
108107
Pattern[str],
109108
Callable[
110109
[
111-
CreateCompletionRequest | CreateChatCompletionRequest,
110+
Union[CreateCompletionRequest, CreateChatCompletionRequest],
112111
Match[str],
113112
],
114-
tuple[int, ErrorResponse],
113+
Tuple[int, ErrorResponse],
115114
],
116115
] = {
117116
compile(
118-
r"Requested tokens \((\d+)\) exceed context window of (\d+)",
117+
r"Requested tokens \((\d+)\) exceed context window of (\d+)"
119118
): ErrorResponseFormatters.context_length_exceeded,
120119
compile(
121-
r"Model path does not exist: (.+)",
120+
r"Model path does not exist: (.+)"
122121
): ErrorResponseFormatters.model_not_found,
123122
}
124123

125124
def error_message_wrapper(
126125
self,
127126
error: Exception,
128-
body: CreateChatCompletionRequest | CreateCompletionRequest | CreateEmbeddingRequest | None = None,
129-
) -> tuple[int, ErrorResponse]:
127+
body: Optional[
128+
Union[
129+
CreateChatCompletionRequest,
130+
CreateCompletionRequest,
131+
CreateEmbeddingRequest,
132+
]
133+
] = None,
134+
) -> Tuple[int, ErrorResponse]:
130135
"""Wraps error message in OpenAI style error response"""
131136
print(f"Exception: {error!s}", file=sys.stderr)
132137
traceback.print_exc(file=sys.stderr)
@@ -174,7 +179,13 @@ async def custom_route_handler(request: Request) -> Response:
174179
try:
175180
if "messages" in json_body:
176181
# Chat completion
177-
body: CreateChatCompletionRequest | CreateCompletionRequest | CreateEmbeddingRequest | None = CreateChatCompletionRequest(**json_body)
182+
body: Optional[
183+
Union[
184+
CreateChatCompletionRequest,
185+
CreateCompletionRequest,
186+
CreateEmbeddingRequest,
187+
]
188+
] = CreateChatCompletionRequest(**json_body)
178189
elif "prompt" in json_body:
179190
# Text completion
180191
body = CreateCompletionRequest(**json_body)

0 commit comments

Comments (0)