-
Notifications
You must be signed in to change notification settings - Fork 43
/
Copy pathproxy.py
101 lines (83 loc) · 3.25 KB
/
proxy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from fastapi import APIRouter, Response, Request, Path, Query
from fastapi.responses import JSONResponse
# from libs.chains import (
# Context,
# ProviderSelectionHandler,
# ImageMessageHandler,
# ToolExtractionHandler,
# ToolResponseHandler,
# DefaultCompletionHandler,
# FallbackHandler,
# )
from libs import (
Context,
ProviderSelectionHandler,
ImageMessageHandler,
ToolExtractionHandler,
ToolResponseHandler,
DefaultCompletionHandler,
FallbackHandler,
)
from typing import Optional
# Router on which the endpoints in this module are registered via the
# @router.get / @router.post decorators.
router = APIRouter()
# GET endpoint for /{provider}/v1: replies with a JSON message naming the provider.
@router.get("/{provider}/v1")
async def get_openai_v1(
    response: Response,
    provider: str = Path(..., title="Provider"),
) -> JSONResponse:
    """Acknowledge a GET request on a provider's ``v1`` root.

    Args:
        response: FastAPI response object; not used by this handler.
        provider: Provider name taken from the URL path.

    Returns:
        A JSON response whose ``message`` field echoes the provider name.
    """
    payload = {"message": f"GET request to {provider} v1"}
    return JSONResponse(content=payload)
@router.post("/groqchain/{provider}/v1/chat/completions")
async def post_groq_chat_completions(
    request: Request,
    provider: str = Path(..., title="Provider"),
) -> JSONResponse:
    """Handle chat completions posted under the ``/groqchain`` prefix.

    The ``provider`` path segment is accepted but ignored: the request is
    always delegated to ``post_chat_completions`` with the provider fixed
    to ``"groq"``.

    Args:
        request: Incoming HTTP request, passed through unchanged.
        provider: Provider segment from the URL path (unused).

    Returns:
        Whatever ``post_chat_completions`` returns for the "groq" provider.
    """
    groq_result = await post_chat_completions(request, provider="groq")
    return groq_result
def _extract_bearer_token(authorization: Optional[str]) -> Optional[str]:
    """Return the token from a ``Bearer <token>`` header value, or None.

    None is returned when the header is absent, uses another scheme, or
    carries no token. The scheme name is compared case-insensitively.
    """
    if not authorization:
        return None
    scheme, _, token = authorization.partition(" ")
    token = token.strip()
    if scheme.lower() != "bearer" or not token:
        return None
    return token


@router.post("/{provider}/v1/chat/completions")
async def post_chat_completions(
    request: Request,
    provider: str = Path(..., title="Provider"),
) -> JSONResponse:
    """Run a chat-completion request through the handler chain.

    Args:
        request: Incoming HTTP request; must carry an
            ``Authorization: Bearer <token>`` header and a JSON body.
        provider: Provider name from the URL path; an empty value falls
            back to ``"openai"``.

    Returns:
        The response produced by the handler chain, or a JSON error response:
        400 for an unknown provider or an unparsable JSON body, 401 for a
        missing or malformed Authorization header, and 500 for any other
        failure.
    """
    try:
        if not provider:
            provider = "openai"
        if not ProviderSelectionHandler.provider_exists(provider):
            return JSONResponse(content={"error": "Invalid provider"}, status_code=400)

        # A missing or non-Bearer header is a client error, not a server fault:
        # answer 401 instead of letting it surface as a 500 below.
        api_token = _extract_bearer_token(request.headers.get("Authorization"))
        if api_token is None:
            return JSONResponse(
                content={"error": "Missing or malformed Authorization header"},
                status_code=401,
                headers={"WWW-Authenticate": "Bearer"},
            )

        # JSON decoding errors are ValueError subclasses; report them as 400.
        try:
            body = await request.json()
        except ValueError:
            return JSONResponse(
                content={"error": "Invalid JSON body"}, status_code=400
            )

        # Initialize the context with request details; the token is attached
        # so that handlers can use it.
        context = Context(request, provider, body)
        context.api_token = api_token

        # Handlers in the order they are linked; the first one is the entry
        # point of the chain.
        chain = [
            ProviderSelectionHandler(),
            ImageMessageHandler(),
            ToolExtractionHandler(),
            ToolResponseHandler(),
            DefaultCompletionHandler(),
            FallbackHandler(),
        ]
        for current, following in zip(chain, chain[1:]):
            current.set_next(following)

        # Execute the chain and return whatever the handlers produce.
        return await chain[0].handle(context)
    except Exception as e:
        # Top-level boundary: report the failure and reply with a generic 500.
        print(f"Error processing the request: {e}")
        return JSONResponse(
            content={"error": "An unexpected error occurred"}, status_code=500
        )