Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion nilai-api/src/nilai_api/config/web_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ class WebSearchSettings(BaseModel):
max_concurrent_requests: int = Field(
default=20, description="Maximum concurrent requests"
)
rps: int = Field(default=20, description="Requests per second limit")
rps: Optional[int] = Field(default=20, description="Requests per second limit")
64 changes: 46 additions & 18 deletions nilai-api/src/nilai_api/handlers/web_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
from functools import lru_cache
from typing import List, Dict, Any

from fastapi import HTTPException, status, Request
from nilai_api.rate_limiting import RateLimit

import httpx
import trafilatura
from fastapi import HTTPException, status

from nilai_api.config import CONFIG
from nilai_common.api_models import (
Expand Down Expand Up @@ -90,11 +92,12 @@ def _get_http_client() -> httpx.AsyncClient:
)


async def _make_brave_api_request(query: str) -> Dict[str, Any]:
async def _make_brave_api_request(query: str, request: Request) -> Dict[str, Any]:
"""Make an API request to the Brave Search API.

Args:
query: The search query string to execute
request: FastAPI request object for rate limiting

Returns:
Dict containing the raw API response data
Expand All @@ -108,6 +111,8 @@ async def _make_brave_api_request(query: str) -> Dict[str, Any]:
detail="Missing BRAVE_SEARCH_API key in environment",
)

await RateLimit.check_brave_rps(request)

q = " ".join(query.split())

params = {**_BRAVE_API_PARAMS_BASE, "q": q}
Expand All @@ -125,8 +130,15 @@ async def _make_brave_api_request(query: str) -> Dict[str, Any]:
params.get("lang"),
params.get("count"),
)

resp = await client.get(CONFIG.web_search.api_path, headers=headers, params=params)

if resp.status_code == 429:
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Web search rate limit exceeded",
)

if resp.status_code >= 400:
logger.error("Brave API error: %s - %s", resp.status_code, resp.text)
raise HTTPException(
Expand Down Expand Up @@ -225,12 +237,22 @@ async def _fetch_and_extract_page_content(
return None


async def perform_web_search_async(query: str) -> WebSearchContext:
async def perform_web_search_async(query: str, request: Request) -> WebSearchContext:
"""Perform an asynchronous web search using the Brave Search API.

Fetches only the exact page for each Brave URL and extracts its
main content with trafilatura. If extraction fails, falls back to
the Brave snippet.

Args:
query: The search query string to execute
request: FastAPI request object for rate limiting

Returns:
WebSearchContext with formatted search results and source information

Raises:
HTTPException: If no results are found (404) or if the API request fails
"""
if not (query and query.strip()):
logger.warning("Empty or invalid query provided for web search")
Expand All @@ -240,7 +262,7 @@ async def perform_web_search_async(query: str) -> WebSearchContext:
logger.debug("Web search query: %s", query)

try:
data = await _make_brave_api_request(query)
data = await _make_brave_api_request(query, request)
initial_results = _parse_brave_results(data)
except HTTPException:
logger.exception("Brave API request failed")
Expand Down Expand Up @@ -358,36 +380,38 @@ async def _generate_topic_query(
return None


async def _perform_search(query: str) -> WebSearchContext:
async def _perform_search(query: str, request: Request) -> WebSearchContext:
"""Execute a web search with error handling.

Args:
query: Search query string
request: FastAPI request object for rate limiting

Returns:
WebSearchContext with results, or empty context if search fails
"""
try:
return await perform_web_search_async(query)
return await perform_web_search_async(query, request)
except Exception:
logger.exception("Search failed for query '%s'", query)
return WebSearchContext(prompt="", sources=[])


async def enhance_messages_with_web_search(
req: ChatRequest, query: str
req: ChatRequest, query: str, request: Request
) -> WebSearchEnhancedMessages:
"""Enhance chat messages with web search context for a single query.

Args:
req: ChatRequest containing conversation messages
query: Search query to retrieve web search results for
request: FastAPI request object for rate limiting

Returns:
WebSearchEnhancedMessages with web search context added to system messages
and source information
"""
ctx = await perform_web_search_async(query)
ctx = await perform_web_search_async(query, request)
query_source = Source(source=WEB_SEARCH_QUERY_SOURCE, content=query)

web_search_content = _build_single_search_content(query, ctx.prompt)
Expand Down Expand Up @@ -469,7 +493,7 @@ async def generate_search_query_from_llm(


async def _execute_web_search_workflow(
user_query: str, model_name: str, client: Any
user_query: str, model_name: str, client: Any, request: Request
) -> tuple[List[TopicQuery], List[WebSearchContext]] | tuple[None, None]:
"""Execute the complete multi-topic web search workflow.

Expand All @@ -480,6 +504,7 @@ async def _execute_web_search_workflow(
user_query: User's query to analyze and search for
model_name: Name of the LLM model to use for topic analysis and query generation
client: LLM client instance for API calls
request: FastAPI request object for rate limiting

Returns:
Tuple of (topic_queries, contexts) if successful, or (None, None) if no topics
Expand Down Expand Up @@ -508,7 +533,7 @@ async def _execute_web_search_workflow(
)
return None, None

search_tasks = [_perform_search(tq.query) for tq in topic_queries]
search_tasks = [_perform_search(tq.query, request) for tq in topic_queries]
contexts = await asyncio.gather(*search_tasks)

return topic_queries, contexts
Expand All @@ -519,7 +544,7 @@ async def _execute_web_search_workflow(


async def handle_web_search(
req_messages: ChatRequest, model_name: str, client: Any
req_messages: ChatRequest, model_name: str, client: Any, request: Request
) -> WebSearchEnhancedMessages:
Copy link

Copilot AI Nov 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function is missing a docstring. Please add comprehensive documentation that includes the new request parameter:

"""Handle web search enhancement for chat requests.

Analyzes the user's message to identify topics that require web search,
generates optimized search queries for each topic using an LLM, and
enhances the request with relevant web search results. Falls back to
single-query search if topic analysis fails or no topics need search.

Args:
    req_messages: ChatRequest containing conversation messages
    model_name: Name of the LLM model to use for query generation
    client: LLM client instance for making API calls
    request: FastAPI request object for rate limiting

Returns:
    WebSearchEnhancedMessages with web search context added, or original
    messages if no user query is found or search fails
"""
Suggested change
) -> WebSearchEnhancedMessages:
) -> WebSearchEnhancedMessages:
"""
Handle web search enhancement for chat requests.
Analyzes the user's message to identify topics that require web search,
generates optimized search queries for each topic using an LLM, and
enhances the request with relevant web search results. Falls back to
single-query search if topic analysis fails or no topics need search.
Args:
req_messages: ChatRequest containing conversation messages
model_name: Name of the LLM model to use for query generation
client: LLM client instance for making API calls
request: FastAPI request object for rate limiting
Returns:
WebSearchEnhancedMessages with web search context added, or original
messages if no user query is found or search fails
"""

Copilot uses AI. Check for mistakes.
logger.info("Handle web search start")
logger.debug(
Expand All @@ -534,14 +559,16 @@ async def handle_web_search(

try:
topic_queries, contexts = await _execute_web_search_workflow(
user_query, model_name, client
user_query, model_name, client, request
)

if topic_queries is None or contexts is None:
concise_query = await generate_search_query_from_llm(
user_query, model_name, client
)
return await enhance_messages_with_web_search(req_messages, concise_query)
return await enhance_messages_with_web_search(
req_messages, concise_query, request
)

return await enhance_messages_with_multi_web_search(
req_messages, topic_queries, contexts
Expand Down Expand Up @@ -628,7 +655,7 @@ async def enhance_messages_with_multi_web_search(


async def enhance_input_with_web_search(
req: ResponseRequest, query: str
req: ResponseRequest, query: str, request: Request
) -> WebSearchEnhancedInput:
"""Enhance response input with web search context for a single query.

Expand All @@ -640,7 +667,7 @@ async def enhance_input_with_web_search(
WebSearchEnhancedInput with web search context added to instructions
and source information
"""
ctx = await perform_web_search_async(query)
ctx = await perform_web_search_async(query, request)
query_source = Source(source=WEB_SEARCH_QUERY_SOURCE, content=query)

web_search_instructions = _build_single_search_content(query, ctx.prompt)
Expand Down Expand Up @@ -692,7 +719,7 @@ async def enhance_input_with_multi_web_search(


async def handle_web_search_for_responses(
req: ResponseRequest, model_name: str, client: Any
req: ResponseRequest, model_name: str, client: Any, request: Request
) -> WebSearchEnhancedInput:
"""Handle web search enhancement for response requests.

Expand All @@ -705,6 +732,7 @@ async def handle_web_search_for_responses(
req: ResponseRequest containing input to process
model_name: Name of the LLM model to use for query generation
client: LLM client instance for making API calls
request: FastAPI request object for rate limiting

Returns:
WebSearchEnhancedInput with web search context added, or original
Expand All @@ -724,14 +752,14 @@ async def handle_web_search_for_responses(

try:
topic_queries, contexts = await _execute_web_search_workflow(
user_query, model_name, client
user_query, model_name, client, request
)

if topic_queries is None or contexts is None:
concise_query = await generate_search_query_from_llm(
user_query, model_name, client
)
return await enhance_input_with_web_search(req, concise_query)
return await enhance_input_with_web_search(req, concise_query, request)

return await enhance_input_with_multi_web_search(req, topic_queries, contexts)

Expand Down
27 changes: 20 additions & 7 deletions nilai-api/src/nilai_api/rate_limiting.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,13 +177,6 @@ async def __call__(
user_limits.rate_limits.web_search_rate_limit_day,
DAY_MS,
)
await self.check_bucket(
redis,
redis_rate_limit_command,
"web_search_rps",
CONFIG.web_search.rps,
1000,
)
await self.check_bucket(
redis,
redis_rate_limit_command,
Expand Down Expand Up @@ -241,6 +234,26 @@ async def check_concurrent_and_increment(
)
return key

@staticmethod
async def check_brave_rps(request: Request) -> None:
"""
Global RPS limit for Brave API calls, across all users.
"""
redis = request.state.redis
redis_rate_limit_command = request.state.redis_rate_limit_command

limit = CONFIG.web_search.rps
if not limit or limit <= 0:
return

await RateLimit.check_bucket(
redis,
redis_rate_limit_command,
"brave_rps_global",
limit,
1000,
)

@staticmethod
async def concurrent_decrement(redis: Redis, key: str | None):
if key is None:
Expand Down
3 changes: 2 additions & 1 deletion nilai-api/src/nilai_api/routers/endpoints/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ async def chat_completion_web_search_rate_limit(request: Request) -> bool:

@chat_completion_router.post("/v1/chat/completions", tags=["Chat"], response_model=None)
async def chat_completion(
request: Request,
req: ChatRequest = Body(
ChatRequest(
model="meta-llama/Llama-3.2-1B-Instruct",
Expand Down Expand Up @@ -188,7 +189,7 @@ async def chat_completion(
if req.web_search:
logger.info(f"[chat] web_search start request_id={request_id}")
t_ws = time.monotonic()
web_search_result = await handle_web_search(req, model_name, client)
web_search_result = await handle_web_search(req, model_name, client, request)
messages = web_search_result.messages
sources = web_search_result.sources
logger.info(
Expand Down
3 changes: 2 additions & 1 deletion nilai-api/src/nilai_api/routers/endpoints/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ async def responses_web_search_rate_limit(request: Request) -> bool:
"/v1/responses", tags=["Responses"], response_model=SignedResponse
)
async def create_response(
request: Request,
req: ResponseRequest = Body(
{
"model": "openai/gpt-oss-20b",
Expand Down Expand Up @@ -171,7 +172,7 @@ async def create_response(
logger.info(f"[responses] web_search start request_id={request_id}")
t_ws = time.monotonic()
web_search_result = await handle_web_search_for_responses(
req, model_name, client
req, model_name, client, request
)
input_items = web_search_result.input
instructions = web_search_result.instructions
Expand Down
Loading
Loading