Skip to content

Commit 8ea7a40

Browse files
authored
Optimize websearch pipeline (#259)
* feat: replace snippet truncation with token-aware accumulation in websearch * refactor: make torch optional for API-only websearch * chore: apply review changes * fix: improve generate summary and evaluate subquery relevance prompts * fix: correct all extra in pyproject * fix: use LLM tokenizer to retrieve exact token count and use it when truncating * feat: improve and standardize prompts * feat: prevent duplicated results * chore: refactor and apply review changes * chore: apply review changes * chore: apply last review changes * feat: add first version of websearch tests * fix: improve tests and remove useless ones * fix: add websearch extra to CI * fix: apply review changes * chore: clean tests * chore: add back heuristic 4 chars/token to count tokens * fix: add warning in case no local tokenizer found * fix: use margin in no-tokenizer fallback when truncating * chore: apply copilot feedback * feat: add fast_tokenizer parameter in config file * chore: update websearch doc * chore: apply changes following copilot review * fix: apply review changes * tests: apply fixes to make tests independent from prompts in src code * fix: apply review changes
1 parent be5c552 commit 8ea7a40

9 files changed

Lines changed: 843 additions & 98 deletions

File tree

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ jobs:
3131
- name: Install dependencies (using uv)
3232
run: |
3333
source .venv/bin/activate
34-
uv pip install -e ".[process,index,rag,api,cpu,dev]"
34+
uv pip install -e ".[process,index,rag,api,cpu,dev,websearch]"
3535
3636
- name: Show installed cohere and langchain-cohere versions
3737
run: |

docs/websearch.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ Users can adjust the pipeline according to their [requirements](/examples/websea
4848
- `use_summary`: Activates summarization of retrieved web snippets.
4949
- `n_loops`: Defines the number of search iterations to refine results.
5050
- `n_subqueries`: Specifies the number of subqueries generated for each input query.
51+
- `max_context_tokens`: Maximum token budget for prompts (default: 2048).
52+
- `fast_tokenizer`: If true, estimates tokens as ~4 chars/token instead of using the LLM tokenizer; this is faster but approximate (default: false).
5153

5254
### Workflow
5355

@@ -56,7 +58,7 @@ Users can adjust the pipeline according to their [requirements](/examples/websea
5658
1. **Input Query Processing:**
5759
- The pipeline processes the user query and generates subqueries for web searches in order to complete the current knowledge.
5860
2. **WebSearch Execution:**
59-
- DuckDuckGo searches are performed for each subquery
61+
- Web searches are performed for each subquery using the configured provider.
6062
3. **Summarization:**
6163
- Retrieved web snippets are summarized using an LLM if `use_summary` is enabled.
6264
4. **Integration with RAG (if use_rag):**

examples/websearchRAG/config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ websearch:
1111
search_provider: duckduckgo
1212
max_retries: 3
1313
max_context_tokens: 2048
14+
fast_tokenizer: false
1415
mode: local
1516
llm_config:
1617
llm_name: OpenMeditron/meditron3-8b # Qwen/Qwen2-0.5B

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ api = [
126126
# --- Composite + variant extras ---
127127

128128
all = [
129-
"mmore[process,rag,api]",
129+
"mmore[process,rag,api,websearch]",
130130
]
131131

132132
cpu = [

src/mmore/rag/llm.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44
# from getpass import getpass
55
from typing import ClassVar, Optional, cast
66

7-
import torch
7+
try:
8+
import torch
9+
except ImportError:
10+
torch = None
11+
812
from langchain_anthropic import ChatAnthropic
913
from langchain_cohere import ChatCohere
1014
from langchain_core.language_models.chat_models import BaseChatModel
@@ -146,9 +150,16 @@ class LLM(BaseChatModel):
146150
"""Class parsing the model name and arguments to load the correct LangChain model"""
147151

148152
device_count: ClassVar[int] = 0
149-
nb_devices: ClassVar[int] = (
150-
torch.cuda.device_count() if torch.cuda.is_available() else 1
151-
)
153+
nb_devices: ClassVar[Optional[int]] = None
154+
155+
@classmethod
156+
def _get_nb_devices(cls) -> int:
157+
if cls.nb_devices is None:
158+
if torch is not None and torch.cuda.is_available():
159+
cls.nb_devices = torch.cuda.device_count()
160+
else:
161+
cls.nb_devices = 1
162+
return cls.nb_devices
152163

153164
@staticmethod
154165
def _check_key(provider):
@@ -165,6 +176,11 @@ def from_config(cls, config: str | LLMConfig) -> BaseChatModel:
165176
config = load_config(config, LLMConfig)
166177

167178
if config.provider == "HF":
179+
if torch is None:
180+
raise ImportError(
181+
"torch is required for HuggingFace models. "
182+
"Install it with: uv pip install 'mmore[cpu]' or uv pip install 'mmore[cu126]'"
183+
)
168184
if torch.backends.mps.is_available():
169185
return ChatHuggingFace(
170186
llm=HuggingFacePipeline.from_model_id(
@@ -176,7 +192,7 @@ def from_config(cls, config: str | LLMConfig) -> BaseChatModel:
176192
)
177193
if torch.cuda.is_available():
178194
current_device = cls.device_count
179-
cls.device_count = (cls.device_count + 1) % cls.nb_devices
195+
cls.device_count = (cls.device_count + 1) % cls._get_nb_devices()
180196
else:
181197
current_device = -1
182198

src/mmore/run_websearch.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55
from dataclasses import dataclass
66
from typing import Optional, Union
77

8-
import torch
8+
try:
9+
import torch
10+
except ImportError:
11+
torch = None
12+
913
import uvicorn
1014
from dotenv import load_dotenv
1115
from fastapi import FastAPI
@@ -24,9 +28,10 @@
2428

2529

2630
# CUDA tweaks for best perf
27-
torch.backends.cuda.enable_mem_efficient_sdp(False)
28-
torch.backends.cuda.enable_flash_sdp(False)
29-
torch.backends.cuda.enable_math_sdp(True)
31+
if torch is not None and torch.cuda.is_available():
32+
torch.backends.cuda.enable_mem_efficient_sdp(False)
33+
torch.backends.cuda.enable_flash_sdp(False)
34+
torch.backends.cuda.enable_math_sdp(True)
3035

3136

3237
@dataclass

src/mmore/websearchRAG/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class WebsearchConfig:
2626
max_retries: (int) Max retries for search on rate limit (default: 3).
2727
search_provider: (str) Search provider: 'duckduckgo' (default, free) or 'tavily' (requires TAVILY_API_KEY, pip install "mmore[rag,websearch]").
2828
max_context_tokens: (int) Maximum number of context tokens for constructing prompts (default: 2048).
29+
fast_tokenizer: (bool) If True, use a fast heuristic (~4 chars/token) instead of the LLM tokenizer (default: False).
2930
llm_config: (dict) Passed to rag.llm.LLMConfig (keys: llm_name, max_new_tokens, temperature, etc.)
3031
mode: (str) Mode of operation ("local" or "api").
3132
"""
@@ -43,6 +44,7 @@ class WebsearchConfig:
4344
max_retries: int = 3
4445
search_provider: Literal["duckduckgo", "tavily"] = "duckduckgo"
4546
max_context_tokens: int = 2048
47+
fast_tokenizer: bool = False
4648

4749
llm_config: LLMConfig = field(
4850
default_factory=lambda: LLMConfig(

0 commit comments

Comments
 (0)