Skip to content

Commit 9453aaf

Browse files
authored
feat: passing status_code via code field in DIALException (#155)
1 parent d6b32ad commit 9453aaf

7 files changed

Lines changed: 77 additions & 115 deletions

File tree

aidial_sdk/_errors.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
from fastapi.responses import JSONResponse
33

44
from aidial_sdk.exceptions import HTTPException as DIALException
5+
from aidial_sdk.exceptions import invalid_request_error
56
from aidial_sdk.pydantic_v1 import ValidationError
6-
from aidial_sdk.utils.errors import json_error
77

88

99
def pydantic_validation_exception_handler(
@@ -14,10 +14,8 @@ def pydantic_validation_exception_handler(
1414
error = exc.errors()[0]
1515
path = ".".join(map(str, error["loc"]))
1616
message = f"Your request contained invalid structure on path {path}. {error['msg']}"
17-
return JSONResponse(
18-
status_code=400,
19-
content=json_error(message=message, type="invalid_request_error"),
20-
)
17+
18+
return invalid_request_error(message).to_fastapi_response()
2119

2220

2321
def fastapi_exception_handler(request: Request, exc: Exception) -> JSONResponse:
@@ -30,13 +28,4 @@ def fastapi_exception_handler(request: Request, exc: Exception) -> JSONResponse:
3028

3129
def dial_exception_handler(request: Request, exc: Exception) -> JSONResponse:
3230
assert isinstance(exc, DIALException)
33-
return JSONResponse(
34-
status_code=exc.status_code,
35-
content=json_error(
36-
message=exc.message,
37-
type=exc.type,
38-
param=exc.param,
39-
code=exc.code,
40-
display_message=exc.display_message,
41-
),
42-
)
31+
return exc.to_fastapi_response()

aidial_sdk/chat_completion/response.py

Lines changed: 23 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111
)
1212
from uuid import uuid4
1313

14-
from fastapi import HTTPException
15-
1614
from aidial_sdk.chat_completion.choice import Choice
1715
from aidial_sdk.chat_completion.chunks import (
1816
BaseChunk,
@@ -23,9 +21,9 @@
2321
UsagePerModelChunk,
2422
)
2523
from aidial_sdk.chat_completion.request import Request
26-
from aidial_sdk.exceptions import HTTPException as DialHttpException
27-
from aidial_sdk.exceptions import request_validation_error
28-
from aidial_sdk.utils.errors import json_error, runtime_error
24+
from aidial_sdk.exceptions import HTTPException as DIALException
25+
from aidial_sdk.exceptions import request_validation_error, runtime_server_error
26+
from aidial_sdk.utils.errors import RUNTIME_ERROR_MESSAGE, runtime_error
2927
from aidial_sdk.utils.logging import log_error, log_exception
3028
from aidial_sdk.utils.merge_chunks import merge
3129
from aidial_sdk.utils.streaming import DONE_MARKER, format_chunk
@@ -101,35 +99,22 @@ async def _generate_stream(
10199
end_chunk_generated = False
102100
try:
103101
self.user_task.result()
104-
except DialHttpException as e:
102+
except DIALException as e:
105103
if self.request.stream:
106104
self._queue.put_nowait(EndChunk(e))
107105
end_chunk_generated = True
108106
else:
109-
raise HTTPException(
110-
status_code=e.status_code,
111-
detail=json_error(
112-
message=e.message,
113-
type=e.type,
114-
param=e.param,
115-
code=e.code,
116-
display_message=e.display_message,
117-
),
118-
)
107+
raise e.to_fastapi_exception()
119108
except Exception as e:
120-
log_exception("Error during processing the request")
109+
log_exception(RUNTIME_ERROR_MESSAGE)
121110

122111
if self.request.stream:
123112
self._queue.put_nowait(EndChunk(e))
124113
end_chunk_generated = True
125114
else:
126-
raise HTTPException(
127-
status_code=500,
128-
detail=json_error(
129-
message="Error during processing the request",
130-
type="runtime_error",
131-
),
132-
)
115+
raise runtime_server_error(
116+
RUNTIME_ERROR_MESSAGE
117+
).to_fastapi_exception()
133118

134119
if not end_chunk_generated:
135120
self._queue.put_nowait(EndChunk())
@@ -168,41 +153,26 @@ async def _generate_stream(
168153
yield chunk
169154

170155
if item.exc:
171-
if isinstance(item.exc, DialHttpException):
172-
formatted_chunk = format_chunk(
173-
json_error(
174-
message=item.exc.message,
175-
type=item.exc.type,
176-
param=item.exc.param,
177-
code=item.exc.code,
178-
display_message=item.exc.display_message,
179-
)
180-
)
156+
if isinstance(item.exc, DIALException):
157+
formatted_chunk = format_chunk(item.exc.json_error())
181158
else:
182159
formatted_chunk = format_chunk(
183-
json_error(
184-
message="Error during processing the request",
185-
type="runtime_error",
186-
)
160+
runtime_server_error(
161+
RUNTIME_ERROR_MESSAGE
162+
).json_error()
187163
)
188164
yield formatted_chunk
189165
else:
190166
if self._last_choice_index != (self.request.n or 1):
191167
log_error("Not all choices were generated")
192168

193-
error = json_error(
194-
message="Error during processing the request",
195-
type="runtime_error",
196-
)
169+
error = runtime_server_error(RUNTIME_ERROR_MESSAGE)
197170

198171
if self.request.stream:
199-
formatted_chunk = format_chunk(error)
172+
formatted_chunk = format_chunk(error.json_error())
200173
yield formatted_chunk
201174
else:
202-
raise HTTPException(
203-
status_code=500,
204-
detail=error,
205-
)
175+
raise error.to_fastapi_exception()
206176

207177
if self.request.stream:
208178
yield format_chunk(DONE_MARKER)
@@ -229,26 +199,13 @@ async def _generator(
229199
if self.user_task in done:
230200
try:
231201
self.user_task.result()
232-
except DialHttpException as e:
233-
raise HTTPException(
234-
status_code=e.status_code,
235-
detail=json_error(
236-
message=e.message,
237-
type=e.type,
238-
param=e.param,
239-
code=e.code,
240-
display_message=e.display_message,
241-
),
242-
)
202+
except DIALException as e:
203+
raise e.to_fastapi_exception()
243204
except Exception:
244-
log_exception("Error during processing the request")
245-
raise HTTPException(
246-
status_code=500,
247-
detail=json_error(
248-
message="Error during processing the request",
249-
type="runtime_error",
250-
),
251-
)
205+
log_exception(RUNTIME_ERROR_MESSAGE)
206+
raise runtime_server_error(
207+
RUNTIME_ERROR_MESSAGE
208+
).to_fastapi_exception()
252209

253210
return get_task.result() if get_task in done else await get_task
254211

aidial_sdk/exceptions.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
from http import HTTPStatus
22
from typing import Optional
33

4+
from fastapi import HTTPException as FastAPIException
5+
from fastapi.responses import JSONResponse
6+
7+
from aidial_sdk.utils.json import remove_nones
8+
49

510
class HTTPException(Exception):
611
def __init__(
@@ -12,11 +17,13 @@ def __init__(
1217
code: Optional[str] = None,
1318
display_message: Optional[str] = None,
1419
) -> None:
20+
status_code = int(status_code)
21+
1522
self.message = message
1623
self.status_code = status_code
1724
self.type = type
1825
self.param = param
19-
self.code = code
26+
self.code = code or str(status_code)
2027
self.display_message = display_message
2128

2229
def __repr__(self):
@@ -33,6 +40,31 @@ def __repr__(self):
3340
)
3441
)
3542

43+
def json_error(self) -> dict:
44+
return {
45+
"error": remove_nones(
46+
{
47+
"message": self.message,
48+
"type": self.type,
49+
"param": self.param,
50+
"code": self.code,
51+
"display_message": self.display_message,
52+
}
53+
)
54+
}
55+
56+
def to_fastapi_response(self) -> JSONResponse:
57+
return JSONResponse(
58+
status_code=self.status_code,
59+
content=self.json_error(),
60+
)
61+
62+
def to_fastapi_exception(self) -> FastAPIException:
63+
return FastAPIException(
64+
status_code=self.status_code,
65+
detail=self.json_error(),
66+
)
67+
3668

3769
def resource_not_found_error(message: str, **kwargs) -> HTTPException:
3870
"""

aidial_sdk/utils/errors.py

Lines changed: 3 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,9 @@
1-
from typing import Optional
2-
31
from aidial_sdk.exceptions import runtime_server_error
4-
from aidial_sdk.utils.json import remove_nones
52
from aidial_sdk.utils.logging import log_error
63

4+
RUNTIME_ERROR_MESSAGE = "Error during processing the request"
5+
76

87
def runtime_error(reason: str):
98
log_error(reason)
10-
return runtime_server_error("Error during processing the request")
11-
12-
13-
def json_error(
14-
message: Optional[str] = None,
15-
type: Optional[str] = None,
16-
param: Optional[str] = None,
17-
code: Optional[str] = None,
18-
display_message: Optional[str] = None,
19-
):
20-
return {
21-
"error": remove_nones(
22-
{
23-
"message": message,
24-
"type": type,
25-
"param": param,
26-
"code": code,
27-
"display_message": display_message,
28-
}
29-
)
30-
}
9+
return runtime_server_error(RUNTIME_ERROR_MESSAGE)

tests/applications/broken_immediately.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,22 @@
1-
from fastapi import HTTPException as FastapiHTTPException
1+
from fastapi import HTTPException as FastAPIException
22

3-
from aidial_sdk import HTTPException
3+
from aidial_sdk import HTTPException as DIALException
44
from aidial_sdk.chat_completion import ChatCompletion, Request, Response
55

66

77
def raise_exception(exception_type: str):
88
if exception_type == "sdk_exception":
9-
raise HTTPException("Test error", 503)
9+
raise DIALException("Test error", 503)
1010
elif exception_type == "fastapi_exception":
11-
raise FastapiHTTPException(504, detail="Test detail")
11+
raise FastAPIException(504, detail="Test detail")
1212
elif exception_type == "value_error_exception":
1313
raise ValueError("Test value error")
1414
elif exception_type == "zero_division_exception":
1515
return 1 / 0
1616
elif exception_type == "sdk_exception_with_display_message":
17-
raise HTTPException("Test error", 503, display_message="I'm broken")
17+
raise DIALException("Test error", 503, display_message="I'm broken")
1818
else:
19-
raise HTTPException("Unexpected error")
19+
raise DIALException("Unexpected error")
2020

2121

2222
class BrokenApplication(ChatCompletion):

tests/test_errors.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,30 +12,33 @@
1212
"error": {
1313
"message": "Error during processing the request",
1414
"type": "runtime_error",
15+
"code": "500",
1516
}
1617
}
1718

1819
API_KEY_IS_MISSING = {
1920
"error": {
2021
"message": "Api-Key header is required",
2122
"type": "invalid_request_error",
23+
"code": "400",
2224
}
2325
}
2426

2527
error_testdata = [
28+
("fastapi_exception", 500, DEFAULT_RUNTIME_ERROR),
29+
("value_error_exception", 500, DEFAULT_RUNTIME_ERROR),
30+
("zero_division_exception", 500, DEFAULT_RUNTIME_ERROR),
2631
(
2732
"sdk_exception",
2833
503,
2934
{
3035
"error": {
3136
"message": "Test error",
3237
"type": "runtime_error",
38+
"code": "503",
3339
}
3440
},
3541
),
36-
("fastapi_exception", 500, DEFAULT_RUNTIME_ERROR),
37-
("value_error_exception", 500, DEFAULT_RUNTIME_ERROR),
38-
("zero_division_exception", 500, DEFAULT_RUNTIME_ERROR),
3942
(
4043
"sdk_exception_with_display_message",
4144
503,
@@ -44,6 +47,7 @@
4447
"message": "Test error",
4548
"type": "runtime_error",
4649
"display_message": "I'm broken",
50+
"code": "503",
4751
}
4852
},
4953
),

tests/utils/errors.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def invalid_request_error(path: str, message: str) -> Error:
1313
"error": {
1414
"message": f"Your request contained invalid structure on path {path}. {message}",
1515
"type": "invalid_request_error",
16+
"code": "400",
1617
}
1718
},
1819
)

0 commit comments

Comments
 (0)