Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions src/flare_ai_consensus/api/routes/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,22 @@ class ChatRouter:
def __init__(
self,
router: APIRouter,
provider: AsyncOpenRouterProvider,
api_key: str,
base_url: str = "https://openrouter.ai/api/v1",
consensus_config: ConsensusConfig | None = None,
) -> None:
"""
Initialize the ChatRouter.

Args:
router (APIRouter): FastAPI router to attach endpoints.
provider: instance of an async OpenRouter client.
api_key: API key for OpenRouter.
base_url: Base URL for OpenRouter API.
consensus_config: config for running the consensus algorithm.
"""
self._router = router
self.provider = provider
self.api_key = api_key
self.base_url = base_url
if consensus_config:
self.consensus_config = consensus_config
self.logger = logger.bind(router="chat")
Expand All @@ -59,6 +62,11 @@ async def chat(message: ChatMessage) -> dict[str, str] | None: # pyright: ignor
Process a chat message through the CL pipeline.
Returns an aggregated response after a number of iterations.
"""
# Create a new provider for each request
provider = AsyncOpenRouterProvider(
api_key=self.api_key, base_url=self.base_url
)

try:
self.logger.debug("Received chat message", message=message.user_message)
# Build initial conversation
Expand All @@ -69,13 +77,15 @@ async def chat(message: ChatMessage) -> dict[str, str] | None: # pyright: ignor

# Run consensus algorithm
answer = await run_consensus(
self.provider,
provider,
self.consensus_config,
initial_conversation,
)

except Exception as e:
self.logger.exception("Chat processing failed", error=str(e))
# Make sure to close the provider even if an exception occurs
await provider.close()
raise HTTPException(status_code=500, detail=str(e)) from e
else:
self.logger.info("Response generated", answer=answer)
Expand Down
12 changes: 4 additions & 8 deletions src/flare_ai_consensus/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from fastapi.middleware.cors import CORSMiddleware

from flare_ai_consensus.api import ChatRouter
from flare_ai_consensus.router import AsyncOpenRouterProvider
from flare_ai_consensus.settings import settings
from flare_ai_consensus.utils import load_json

Expand All @@ -18,7 +17,7 @@ def create_app() -> FastAPI:
This function:
1. Creates a new FastAPI instance with optional CORS middleware.
2. Loads configuration.
3. Sets up the OpenRouter client.
3. Sets up the chat router with API key information.
4. Initializes a ChatRouter that wraps the RAG pipeline.
5. Registers the chat endpoint under the /chat prefix.

Expand All @@ -42,15 +41,12 @@ def create_app() -> FastAPI:
config_json = load_json(settings.input_path / "input.json")
settings.load_consensus_config(config_json)

# Initialize the OpenRouter provider.
provider = AsyncOpenRouterProvider(
api_key=settings.open_router_api_key, base_url=settings.open_router_base_url
)

# Create an APIRouter for chat endpoints and initialize ChatRouter.
# Instead of passing a provider, pass the API key and base URL
chat_router = ChatRouter(
router=APIRouter(),
provider=provider,
api_key=settings.open_router_api_key,
base_url=settings.open_router_base_url,
consensus_config=settings.consensus_config,
)
app.include_router(chat_router.router, prefix="/api/routes/chat", tags=["chat"])
Expand Down