Skip to content

Commit 969f514

Browse files
authored
Fix prompt template bugs: build template ignored and runtime override not wired (#173)
* Fix prompt template bugs in build and search

  Bug 1: Build template ignored in new format
  - Updated compute_embeddings_openai() to read build_prompt_template or prompt_template
  - Updated compute_embeddings_ollama() with the same fix
  - Maintains backward compatibility with the old single-template format

  Bug 2: Runtime override not wired up
  - Wired CLI search to pass provider_options to searcher.search()
  - Enables runtime template override during search via --embedding-prompt-template

  All 42 prompt template tests passing.

  Fixes #155

* Fix: Prevent embedding server from applying templates during search

  - Filter out all prompt templates (build_prompt_template, query_prompt_template,
    prompt_template) from provider_options when launching the embedding server during search
  - Templates are already applied in compute_query_embedding() before the server call
  - Prevents double-templating and ensures the runtime override works correctly

  This fixes the issue where --embedding-prompt-template during search was ignored
  because the server was applying build_prompt_template instead.

* Format code with ruff
1 parent 1ef9cba commit 969f514

File tree

3 files changed

+24
-3
lines changed

3 files changed

+24
-3
lines changed

packages/leann-core/src/leann/cli.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1545,6 +1545,11 @@ async def search_documents(self, args):
15451545
print("Invalid input. Aborting search.")
15461546
return
15471547

1548+
# Build provider_options for runtime override
1549+
provider_options = {}
1550+
if args.embedding_prompt_template:
1551+
provider_options["prompt_template"] = args.embedding_prompt_template
1552+
15481553
searcher = LeannSearcher(index_path=index_path)
15491554
results = searcher.search(
15501555
query,
@@ -1554,6 +1559,7 @@ async def search_documents(self, args):
15541559
prune_ratio=args.prune_ratio,
15551560
recompute_embeddings=args.recompute_embeddings,
15561561
pruning_strategy=args.pruning_strategy,
1562+
provider_options=provider_options if provider_options else None,
15571563
)
15581564

15591565
print(f"Search results for '{query}' (top {len(results)}):")

packages/leann-core/src/leann/embedding_compute.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -740,7 +740,10 @@ def compute_embeddings_openai(
740740
print(f"len of texts: {len(texts)}")
741741

742742
# Apply prompt template if provided
743-
prompt_template = provider_options.get("prompt_template")
743+
# Priority: build_prompt_template (new format) > prompt_template (old format)
744+
prompt_template = provider_options.get("build_prompt_template") or provider_options.get(
745+
"prompt_template"
746+
)
744747

745748
if prompt_template:
746749
logger.warning(f"Applying prompt template: '{prompt_template}'")
@@ -1031,7 +1034,10 @@ def compute_embeddings_ollama(
10311034

10321035
# Apply prompt template if provided
10331036
provider_options = provider_options or {}
1034-
prompt_template = provider_options.get("prompt_template")
1037+
# Priority: build_prompt_template (new format) > prompt_template (old format)
1038+
prompt_template = provider_options.get("build_prompt_template") or provider_options.get(
1039+
"prompt_template"
1040+
)
10351041

10361042
if prompt_template:
10371043
logger.warning(f"Applying prompt template: '{prompt_template}'")

packages/leann-core/src/leann/searcher_base.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,14 +71,23 @@ def _ensure_server_running(self, passages_source_file: str, port: int, **kwargs)
7171
or "mips"
7272
)
7373

74+
# Filter out ALL prompt templates from provider_options during search
75+
# Templates are applied in compute_query_embedding (line 109-110) BEFORE server call
76+
# The server should never apply templates during search to avoid double-templating
77+
search_provider_options = {
78+
k: v
79+
for k, v in self.embedding_options.items()
80+
if k not in ("build_prompt_template", "query_prompt_template", "prompt_template")
81+
}
82+
7483
server_started, actual_port = self.embedding_server_manager.start_server(
7584
port=port,
7685
model_name=self.embedding_model,
7786
embedding_mode=self.embedding_mode,
7887
passages_file=passages_source_file,
7988
distance_metric=distance_metric,
8089
enable_warmup=kwargs.get("enable_warmup", False),
81-
provider_options=self.embedding_options,
90+
provider_options=search_provider_options,
8291
)
8392
if not server_started:
8493
raise RuntimeError(f"Failed to start embedding server on port {actual_port}")

0 commit comments

Comments (0)