3
3
import sys
4
4
import time
5
5
import traceback
6
- from collections .abc import Callable , Coroutine
7
6
from re import Match , Pattern , compile
8
- from typing import Dict , Optional , Tuple , Union
7
+ from typing import Callable , Coroutine , Dict , Optional , Tuple , Union
9
8
10
9
from fastapi import (
11
10
HTTPException ,
@@ -28,8 +27,8 @@ class ErrorResponse(TypedDict):
28
27
29
28
message : str
30
29
type : str
31
- param : str | None
32
- code : str | None
30
+ param : Optional [ str ]
31
+ code : Optional [ str ]
33
32
34
33
35
34
class ErrorResponseFormatters :
@@ -46,9 +45,9 @@ class ErrorResponseFormatters:
46
45
47
46
@staticmethod
48
47
def context_length_exceeded (
49
- request : CreateCompletionRequest | CreateChatCompletionRequest ,
48
+ request : Union [ CreateCompletionRequest , CreateChatCompletionRequest ] ,
50
49
match , # type: Match[str] # type: ignore
51
- ) -> tuple [int , ErrorResponse ]:
50
+ ) -> Tuple [int , ErrorResponse ]:
52
51
"""Formatter for context length exceeded error"""
53
52
54
53
context_window = int (match .group (2 ))
@@ -84,9 +83,9 @@ def context_length_exceeded(
84
83
85
84
@staticmethod
86
85
def model_not_found (
87
- request : CreateCompletionRequest | CreateChatCompletionRequest ,
86
+ request : Union [ CreateCompletionRequest , CreateChatCompletionRequest ] ,
88
87
match , # type: Match[str] # type: ignore
89
- ) -> tuple [int , ErrorResponse ]:
88
+ ) -> Tuple [int , ErrorResponse ]:
90
89
"""Formatter for model_not_found error"""
91
90
92
91
model_path = str (match .group (1 ))
@@ -104,29 +103,35 @@ class RouteErrorHandler(APIRoute):
104
103
105
104
# Maps a compiled regex (matched against the llama_cpp error text) to the
# formatter that turns a match into an (HTTP status, ErrorResponse) pair.
#   key:   regex pattern for the original error message from llama_cpp
#   value: formatter function on ErrorResponseFormatters
pattern_and_formatters: Dict[
    Pattern[str],
    Callable[
        [
            Union[CreateCompletionRequest, CreateChatCompletionRequest],
            Match[str],
        ],
        Tuple[int, ErrorResponse],
    ],
] = {
    compile(
        r"Requested tokens \((\d+)\) exceed context window of (\d+)"
    ): ErrorResponseFormatters.context_length_exceeded,
    compile(
        r"Model path does not exist: (.+)"
    ): ErrorResponseFormatters.model_not_found,
}
124
123
125
124
def error_message_wrapper (
126
125
self ,
127
126
error : Exception ,
128
- body : CreateChatCompletionRequest | CreateCompletionRequest | CreateEmbeddingRequest | None = None ,
129
- ) -> tuple [int , ErrorResponse ]:
127
+ body : Optional [
128
+ Union [
129
+ CreateChatCompletionRequest ,
130
+ CreateCompletionRequest ,
131
+ CreateEmbeddingRequest ,
132
+ ]
133
+ ] = None ,
134
+ ) -> Tuple [int , ErrorResponse ]:
130
135
"""Wraps error message in OpenAI style error response"""
131
136
print (f"Exception: { error !s} " , file = sys .stderr )
132
137
traceback .print_exc (file = sys .stderr )
@@ -174,7 +179,13 @@ async def custom_route_handler(request: Request) -> Response:
174
179
try :
175
180
if "messages" in json_body :
176
181
# Chat completion
177
- body : CreateChatCompletionRequest | CreateCompletionRequest | CreateEmbeddingRequest | None = CreateChatCompletionRequest (** json_body )
182
+ body : Optional [
183
+ Union [
184
+ CreateChatCompletionRequest ,
185
+ CreateCompletionRequest ,
186
+ CreateEmbeddingRequest ,
187
+ ]
188
+ ] = CreateChatCompletionRequest (** json_body )
178
189
elif "prompt" in json_body :
179
190
# Text completion
180
191
body = CreateCompletionRequest (** json_body )
0 commit comments