diff --git a/src/flare_ai_consensus/api/routes/chat.py b/src/flare_ai_consensus/api/routes/chat.py index bcd9a28..1e1650f 100644 --- a/src/flare_ai_consensus/api/routes/chat.py +++ b/src/flare_ai_consensus/api/routes/chat.py @@ -30,7 +30,8 @@ class ChatRouter: def __init__( self, router: APIRouter, - provider: AsyncOpenRouterProvider, + api_key: str, + base_url: str = "https://openrouter.ai/api/v1", consensus_config: ConsensusConfig | None = None, ) -> None: """ @@ -38,11 +39,13 @@ def __init__( Args: router (APIRouter): FastAPI router to attach endpoints. - provider: instance of an async OpenRouter client. + api_key: API key for OpenRouter. + base_url: Base URL for OpenRouter API. consensus_config: config for running the consensus algorithm. """ self._router = router - self.provider = provider + self.api_key = api_key + self.base_url = base_url if consensus_config: self.consensus_config = consensus_config self.logger = logger.bind(router="chat") @@ -59,6 +62,11 @@ async def chat(message: ChatMessage) -> dict[str, str] | None: # pyright: ignor Process a chat message through the CL pipeline. Returns an aggregated response after a number of iterations. """ + # Create a new provider for each request + provider = AsyncOpenRouterProvider( + api_key=self.api_key, base_url=self.base_url + ) + try: self.logger.debug("Received chat message", message=message.user_message) # Build initial conversation @@ -69,13 +77,15 @@ async def chat(message: ChatMessage) -> dict[str, str] | None: # pyright: ignor # Run consensus algorithm answer = await run_consensus( - self.provider, + provider, self.consensus_config, initial_conversation, ) except Exception as e: self.logger.exception("Chat processing failed", error=str(e)) + # Make sure to close the provider even if an exception occurs + await provider.close() raise HTTPException(status_code=500, detail=str(e)) from e else: self.logger.info("Response generated", answer=answer) diff --git a/src/flare_ai_consensus/main.py b/src/flare_ai_consensus/main.py index 3397b66..4456ce4 100644 --- a/src/flare_ai_consensus/main.py +++ b/src/flare_ai_consensus/main.py @@ -4,7 +4,6 @@ from fastapi.middleware.cors import CORSMiddleware from flare_ai_consensus.api import ChatRouter -from flare_ai_consensus.router import AsyncOpenRouterProvider from flare_ai_consensus.settings import settings from flare_ai_consensus.utils import load_json @@ -18,7 +17,7 @@ def create_app() -> FastAPI: This function: 1. Creates a new FastAPI instance with optional CORS middleware. 2. Loads configuration. - 3. Sets up the OpenRouter client. + 3. Sets up the chat router with API key information. 4. Initializes a ChatRouter that wraps the RAG pipeline. 5. Registers the chat endpoint under the /chat prefix. @@ -42,15 +41,12 @@ def create_app() -> FastAPI: config_json = load_json(settings.input_path / "input.json") settings.load_consensus_config(config_json) - # Initialize the OpenRouter provider. - provider = AsyncOpenRouterProvider( - api_key=settings.open_router_api_key, base_url=settings.open_router_base_url - ) - # Create an APIRouter for chat endpoints and initialize ChatRouter. + # Instead of passing a provider, pass the API key and base URL chat_router = ChatRouter( router=APIRouter(), - provider=provider, + api_key=settings.open_router_api_key, + base_url=settings.open_router_base_url, consensus_config=settings.consensus_config, ) app.include_router(chat_router.router, prefix="/api/routes/chat", tags=["chat"])