feat: adding QueryExpander, MultiQueryEmbeddingRetriever and MultiQueryTextRetriever #10126
Open: davidsbatista wants to merge 15 commits into main from feat/adding-QueryExpander-MultiQueryRetriever
+1,594 −1
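For orientation, the sketch below shows how the new pieces are presumably meant to fit together: QueryExpander (whose code is in the diff below) generates semantically similar variants of a query, and the new multi-query retrievers are intended to run all of those variants against a document store. Only the QueryExpander part is grounded in this PR's code; the multi-query retriever wiring is left as comments because those files are not among the commits shown here, so their exact API is an assumption.

```python
# Sketch only: QueryExpander matches the code shown in this PR; the multi-query
# retriever steps are commented out because their API is assumed, not shown here.
from haystack import Pipeline
from haystack.components.generators.chat.openai import OpenAIChatGenerator
from haystack.components.query import QueryExpander

pipeline = Pipeline()
pipeline.add_component(
    "expander",
    QueryExpander(
        chat_generator=OpenAIChatGenerator(model="gpt-4.1-mini"),  # needs OPENAI_API_KEY
        n_expansions=3,
    ),
)
# Hypothetical next step (API assumed): feed the expanded queries into one of the
# new multi-query retrievers added by this PR, e.g.
#   pipeline.add_component("retriever", MultiQueryTextRetriever(retriever=...))
#   pipeline.connect("expander.queries", "retriever.queries")

result = pipeline.run(data={"expander": {"query": "green energy sources"}})
print(result["expander"]["queries"])  # expanded variants plus the original query
```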
Changes shown are from the first 8 of 15 commits.
Commits (all by davidsbatista):

93fe08a  importing files from experimental
555b218  linting + tests
1cf6c1b  fixing integrations tests
a8e5a59  adding release notes
4ba3bfe  fixing imports
c1b4043  adding query component
8f570a6  adding docs to docusaurus
c91acbb  Merge branch 'main' into feat/adding-QueryExpander-MultiQueryRetriever
a85313d  Update docs/pydoc/config_docusaurus/query_api.yml
637f3ed  Update haystack/components/query/query_expander.py
05f5784  Update releasenotes/notes/adding-QueryExpander-MultiQueryRetriever-88…
4ea74f9  Merge branch 'main' into feat/adding-QueryExpander-MultiQueryRetriever
63685af  fixing code examples
2a99da5  adding extra unit tests to assert deduplication is working
7537bd2  fixing and increasing QueryExpander tets
docs/pydoc/config_docusaurus/query_api.yml (new file, 28 lines)

```yaml
loaders:
  - ignore_when_discovered:
      - __init__
    modules:
      - haystack.components.query.query_expander
    search_path:
      - ../../../
    type: haystack_pydoc_tools.loaders.CustomPythonLoader
processors:
  - do_not_filter_modules: false
    documented_only: true
    expression: null
    skip_empty_modules: true
    type: filter
  - type: smart
  - type: crossref
renderer:
  description: Components for query processing and expansion.
  id: experimental-query-api
  markdown:
    add_member_class_prefix: false
    add_method_class_prefix: true
    classdef_code_block: false
    descriptive_class_title: false
    descriptive_module_title: true
    filename: experimental_query_api.md
  title: Query
  type: haystack_pydoc_tools.renderers.DocusaurusRenderer
```
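In short, this pydoc configuration registers the new haystack.components.query.query_expander module with the Docusaurus renderer, so the QueryExpander API reference is generated into experimental_query_api.md under the title "Query".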
haystack/components/query/__init__.py (new file, 16 lines)

```python
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

import sys
from typing import TYPE_CHECKING

from lazy_imports import LazyImporter

_import_structure = {"query_expander": ["QueryExpander"]}

if TYPE_CHECKING:
    from .query_expander import QueryExpander

else:
    sys.modules[__name__] = LazyImporter(name=__name__, module_file=__file__, import_structure=_import_structure)
```
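For context, this mirrors the lazy-import pattern used by the other component packages: LazyImporter defers loading the query_expander submodule (and its generator dependencies) until QueryExpander is actually accessed, while the TYPE_CHECKING branch keeps static analysis and IDEs working. A minimal sketch of the user-facing effect, assuming an OPENAI_API_KEY is set so the default generator can be constructed:

```python
# Importing the package is cheap; the query_expander submodule is only loaded
# when the QueryExpander attribute is first accessed.
from haystack.components.query import QueryExpander

expander = QueryExpander()  # default OpenAIChatGenerator, assumes OPENAI_API_KEY is set
print(expander.n_expansions)  # 4 by default
```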
haystack/components/query/query_expander.py (new file, 290 lines)

````python
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

import json
from typing import Any, Optional

from haystack import default_from_dict, default_to_dict, logging
from haystack.components.builders.prompt_builder import PromptBuilder
from haystack.components.generators.chat.openai import OpenAIChatGenerator
from haystack.components.generators.chat.types import ChatGenerator
from haystack.core.component import component
from haystack.core.serialization import component_to_dict
from haystack.dataclasses.chat_message import ChatMessage
from haystack.utils.deserialization import deserialize_chatgenerator_inplace

logger = logging.getLogger(__name__)


DEFAULT_PROMPT_TEMPLATE = """
You are part of an information system that processes user queries for retrieval.
You have to expand a given query into {{ n_expansions }} queries that are
semantically similar to improve retrieval recall.

Structure:
Follow the structure shown below in examples to generate expanded queries.

Examples:
1. Query: "climate change effects"
{"queries": ["impact of climate change", "consequences of global warming", "effects of environmental changes"]}

2. Query: "machine learning algorithms"
{"queries": ["neural networks", "clustering techniques", "supervised learning methods", "deep learning models"]}

3. Query: "open source NLP frameworks"
{"queries": ["natural language processing tools", "free nlp libraries", "open-source NLP platforms"]}

Guidelines:
- Generate queries that use different words and phrasings
- Include synonyms and related terms
- Maintain the same core meaning and intent
- Make queries that are likely to retrieve relevant information the original might miss
- Focus on variations that would work well with keyword-based search
- Respond in the same language as the input query

Your Task:
Query: "{{ query }}"

You *must* respond with a JSON object containing a "queries" array with the expanded queries.
Example: {"queries": ["query1", "query2", "query3"]}"""


@component
class QueryExpander:
    """
    A component that returns a list of semantically similar queries to improve retrieval recall in RAG systems.

    The component uses a chat generator to expand queries. The chat generator is expected to return a JSON response
    with the following structure:
    ```json
    {"queries": ["expanded query 1", "expanded query 2", "expanded query 3"]}
    ```

    ### Usage example

    ```python
    from haystack.components.generators.chat.openai import OpenAIChatGenerator
    from haystack.components.query import QueryExpander

    expander = QueryExpander(
        chat_generator=OpenAIChatGenerator(model="gpt-4.1-mini"),
        n_expansions=3
    )

    result = expander.run(query="green energy sources")
    print(result["queries"])
    # Output: ['alternative query 1', 'alternative query 2', 'alternative query 3', 'green energy sources']
    # Note: Up to 3 additional queries + 1 original query (if include_original_query=True)

    # To control total number of queries:
    expander = QueryExpander(n_expansions=2, include_original_query=True)  # Up to 3 total
    # or
    expander = QueryExpander(n_expansions=3, include_original_query=False)  # Exactly 3 total
    ```
    """

    def __init__(
        self,
        *,
        chat_generator: Optional[ChatGenerator] = None,
        prompt_template: Optional[str] = None,
        n_expansions: int = 4,
        include_original_query: bool = True,
    ) -> None:
        """
        Initialize the QueryExpander component.

        :param chat_generator: The chat generator component to use for query expansion.
            If None, a default OpenAIChatGenerator with the gpt-4.1-mini model is used.
        :param prompt_template: Custom [PromptBuilder](https://docs.haystack.deepset.ai/docs/promptbuilder)
            template for query expansion. The template should instruct the LLM to return a JSON response with the
            structure: `{"queries": ["query1", "query2", "query3"]}`. The template should include 'query' and
            'n_expansions' variables.
        :param n_expansions: Number of alternative queries to generate (default: 4).
        :param include_original_query: Whether to include the original query in the output.
        """
        if n_expansions <= 0:
            raise ValueError("n_expansions must be positive")

        self.n_expansions = n_expansions
        self.include_original_query = include_original_query

        if chat_generator is None:
            self.chat_generator: ChatGenerator = OpenAIChatGenerator(
                model="gpt-4.1-mini",
                generation_kwargs={
                    "temperature": 0.7,
                    "response_format": {
                        "type": "json_schema",
                        "json_schema": {
                            "name": "query_expansion",
                            "schema": {
                                "type": "object",
                                "properties": {"queries": {"type": "array", "items": {"type": "string"}}},
                                "required": ["queries"],
                                "additionalProperties": False,
                            },
                        },
                    },
                    "seed": 42,
                },
            )
        else:
            self.chat_generator = chat_generator

        self._is_warmed_up = False
        self.prompt_template = prompt_template or DEFAULT_PROMPT_TEMPLATE

        # Check if required variables are present in the template
        if "query" not in self.prompt_template:
            logger.warning(
                "The prompt template does not contain the 'query' variable. This may cause issues during execution."
            )
        if "n_expansions" not in self.prompt_template:
            logger.warning(
                "The prompt template does not contain the 'n_expansions' variable. "
                "This may cause issues during execution."
            )

        self._prompt_builder = PromptBuilder(
            template=self.prompt_template, required_variables=["n_expansions", "query"]
        )

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :return: Dictionary with serialized data.
        """
        return default_to_dict(
            self,
            chat_generator=component_to_dict(self.chat_generator, name="chat_generator"),
            prompt_template=self.prompt_template,
            n_expansions=self.n_expansions,
            include_original_query=self.include_original_query,
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "QueryExpander":
        """
        Deserializes the component from a dictionary.

        :param data: Dictionary with serialized data.
        :return: Deserialized component.
        """
        init_params = data.get("init_parameters", {})

        deserialize_chatgenerator_inplace(init_params, key="chat_generator")

        return default_from_dict(cls, data)

    @component.output_types(queries=list[str])
    def run(self, query: str, n_expansions: Optional[int] = None) -> dict[str, list[str]]:
        """
        Expand the input query into multiple semantically similar queries.

        The language of the original query is preserved in the expanded queries.

        :param query: The original query to expand.
        :param n_expansions: Number of additional queries to generate (not including the original).
            If None, uses the value from initialization. Must be a positive integer.
        :return: Dictionary with "queries" key containing the list of expanded queries.
            If include_original_query=True, the original query will be included in addition
            to the n_expansions alternative queries.
        :raises ValueError: If n_expansions is not positive (less than or equal to 0).
        """

        if not self._is_warmed_up:
            self.warm_up()

        response = {"queries": [query] if self.include_original_query else []}

        if not query.strip():
            logger.warning("Empty query provided to QueryExpander")
            return response

        expansion_count = n_expansions if n_expansions is not None else self.n_expansions
        if expansion_count <= 0:
            raise ValueError("n_expansions must be positive")

        try:
            prompt_result = self._prompt_builder.run(query=query.strip(), n_expansions=expansion_count)
            generator_result = self.chat_generator.run(messages=[ChatMessage.from_user(prompt_result["prompt"])])

            if not generator_result.get("replies") or len(generator_result["replies"]) == 0:
                logger.warning("ChatGenerator returned no replies for query: {query}", query=query)
                return response

            expanded_text = generator_result["replies"][0].text.strip()
            expanded_queries = self._parse_expanded_queries(expanded_text)

            # Limit the number of expanded queries to the requested amount
            if len(expanded_queries) > expansion_count:
                logger.warning(
                    "Generated {generated_count} queries but only {requested_count} were requested. "
                    "Truncating to the first {requested_count} queries.",
                    generated_count=len(expanded_queries),
                    requested_count=expansion_count,
                )
                expanded_queries = expanded_queries[:expansion_count]

            # Add original query if requested and remove duplicates
            if self.include_original_query:
                expanded_queries_lower = [q.lower() for q in expanded_queries]
                if query.lower() not in expanded_queries_lower:
                    expanded_queries.append(query)

            response["queries"] = expanded_queries
            return response

        except Exception as e:
            # Fallback: return original query to maintain pipeline functionality
            logger.error("Failed to expand query {query}: {error}", query=query, error=str(e))
            return response

    def warm_up(self):
        """
        Warm up the LLM provider component.
        """
        if not self._is_warmed_up:
            if hasattr(self.chat_generator, "warm_up"):
                self.chat_generator.warm_up()
            self._is_warmed_up = True

    @staticmethod
    def _parse_expanded_queries(generator_response: str) -> list[str]:
        """
        Parse the generator response to extract individual expanded queries.

        :param generator_response: The raw text response from the generator.
        :return: List of parsed expanded queries.
        """
        if not generator_response.strip():
            return []

        try:
            parsed = json.loads(generator_response)
            if not isinstance(parsed, dict) or "queries" not in parsed:
                logger.warning(
                    "Generator response is not a JSON object containing a 'queries' array: {response}",
                    response=generator_response[:100],
                )
                return []

            queries = []
            for item in parsed["queries"]:
                if isinstance(item, str) and item.strip():
                    queries.append(item.strip())
                else:
                    logger.warning("Skipping non-string or empty query in response: {item}", item=item)

            return queries

        except json.JSONDecodeError as e:
            logger.warning(
                "Failed to parse JSON response: {error}. Response: {response}",
                error=str(e),
                response=generator_response[:100],
            )
            return []
````
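Two behaviours defined above may be worth exercising during review: passing a custom prompt template (the component warns unless both 'query' and 'n_expansions' appear in it, and the template must make the LLM answer with a `{"queries": [...]}` object) and the to_dict()/from_dict() round trip. The sketch below is illustrative only; the template text is made up, and constructing the default OpenAIChatGenerator assumes an OPENAI_API_KEY in the environment.

```python
from haystack.components.query import QueryExpander

# 1) Custom prompt template: contains both required variables and asks for the
#    JSON shape that _parse_expanded_queries expects.
template = """
Rewrite the query "{{ query }}" into {{ n_expansions }} alternative search queries.
Respond only with JSON: {"queries": ["...", "..."]}
"""
expander = QueryExpander(prompt_template=template, n_expansions=2, include_original_query=False)

# 2) Serialization round trip: to_dict() embeds the chat generator's own config,
#    and from_dict() rebuilds it via deserialize_chatgenerator_inplace.
data = expander.to_dict()
restored = QueryExpander.from_dict(data)
assert restored.prompt_template == expander.prompt_template
assert restored.n_expansions == 2 and restored.include_original_query is False
```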