-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathquery_utils.py
More file actions
128 lines (103 loc) · 3.54 KB
/
query_utils.py
File metadata and controls
128 lines (103 loc) · 3.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from __future__ import annotations
import re
from typing import List
FORMAT_TOKEN_MAP = {
"tif": "TIFF",
"tiff": "TIFF",
"nii": "NIfTI",
"nii.gz": "NIfTI",
"dcm": "DICOM",
"dicom": "DICOM",
"nrrd": "NRRD",
"png": "PNG",
"jpg": "JPEG",
"jpeg": "JPEG",
}
_REPO_DRIFT_TERMS = {
"github",
"repository",
"repo",
"official",
"readme",
"docs",
"documentation",
"source",
"sourcecode",
}
_LOW_SIGNAL_TERMS = _REPO_DRIFT_TERMS | {
"tool",
"tools",
"project",
"framework",
}
def _tokenize_query(query: str) -> List[str]:
return [t for t in re.findall(r"[a-z0-9_+-]+", (query or "").lower()) if t]
def normalize_formats(formats: List[str] | None) -> List[str]:
    """Return *formats* as unique, lower-case, dot-less file extensions.

    Each entry is stripped of surrounding whitespace, lower-cased and stripped
    of leading dots, so ``".TIF"``, ``" tif "`` and ``"tif"`` all collapse to
    ``"tif"``. Blank or ``None`` entries are dropped. First-seen order is
    preserved.

    Args:
        formats: Extensions to normalize; ``None`` is treated as empty.

    Returns:
        A new list of normalized, de-duplicated extensions.
    """
    seen: set[str] = set()
    out: List[str] = []
    for ext in formats or ():
        # Leading dots (".tif", ".nii.gz") would otherwise never match the
        # dot-less keys of FORMAT_TOKEN_MAP and would surface as "format:.TIF".
        norm = (ext or "").strip().lower().lstrip(".")
        if norm and norm not in seen:
            seen.add(norm)
            out.append(norm)
    return out
def append_format_tokens(query: str, formats: List[str]) -> str:
    """Append ``format:<LABEL>`` tokens for *formats* to *query*.

    Extensions are normalized with :func:`normalize_formats` and mapped to a
    canonical label through ``FORMAT_TOKEN_MAP``; extensions not in the table
    fall back to their upper-cased form. Labels are de-duplicated in
    first-seen order, so ``["tif", "tiff"]`` yields a single ``format:TIFF``.

    Args:
        query: Base query text; ``None`` is treated as ``""``.
        formats: Extensions to advertise; ``None`` is treated as empty.

    Returns:
        The stripped query followed by the space-separated format tokens, or
        just the stripped query when no formats remain.
    """
    base = (query or "").strip()
    # dict.fromkeys de-duplicates while keeping first-seen order.
    labels = list(
        dict.fromkeys(
            FORMAT_TOKEN_MAP.get(ext, ext.upper())
            for ext in normalize_formats(formats or [])
        )
    )
    if not labels:
        return base
    format_part = " ".join(f"format:{label}" for label in labels)
    # strip() drops the separator when *base* is empty.
    return f"{base} {format_part}".strip()
def strip_legacy_original_formats_line(query: str) -> tuple[str, List[str]]:
    """Parse and remove legacy ``OriginalFormats:`` lines from query text.

    Any line whose stripped text starts (case-insensitively) with
    ``OriginalFormats:`` is removed. The rest of that line is read as a list
    of extensions separated by whitespace and/or commas. The remaining
    non-blank lines are stripped and joined with single spaces.

    Args:
        query: Raw, possibly multi-line query text; ``None`` is treated as ``""``.

    Returns:
        ``(base_query, formats)``, where ``formats`` is the normalized,
        de-duplicated list of extensions collected from all legacy lines.
    """
    original_formats: List[str] = []
    clean_lines: List[str] = []
    for line in (query or "").splitlines():
        stripped = line.strip()
        # Match on the stripped line so an indented legacy line is still
        # removed instead of leaking "OriginalFormats: ..." into the query.
        if stripped.lower().startswith("originalformats:"):
            payload = stripped.split(":", 1)[1]
            # Accept "tif png" as well as "tif, png" without keeping commas.
            original_formats.extend(p for p in re.split(r"[,\s]+", payload) if p)
            continue
        if stripped:
            clean_lines.append(stripped)
    return " ".join(clean_lines), normalize_formats(original_formats)
def sanitize_retrieval_query(
    query: str,
    known_tool_names: List[str] | None = None,
    fallback_query: str | None = None,
) -> str:
    """
    Sanitize LLM-generated retrieval queries by removing repository drift terms.

    URLs are removed and the query is re-tokenized with ``_tokenize_query``.
    The result is therefore lower-cased, punctuation is dropped, and tokens in
    ``_REPO_DRIFT_TERMS`` are removed. The fallback query replaces the
    sanitized query when either of these holds:

    * the remaining tokens are (a subset of) a known tool name and number at
      most three, e.g. "dhsegment official github repository";
    * every remaining token is in ``_LOW_SIGNAL_TERMS``.

    It also replaces a query that shrank to two tokens or fewer.

    Args:
        query: The raw (typically LLM-produced) query.
        known_tool_names: Optional tool names used to detect name-only drift.
        fallback_query: The previous task-centric query. ``None``, empty and
            whitespace-only values all mean "no fallback".

    Returns:
        The sanitized query, the stripped fallback, or the stripped original
        query when nothing usable survives sanitization and no fallback exists.
    """
    # Normalize once: a whitespace-only fallback is truthy but useless, and
    # must not replace a usable sanitized query with an empty string.
    fallback = (fallback_query or "").strip()
    raw = (query or "").strip()
    if not raw:
        return fallback

    # Remove URLs; remaining punctuation is dropped by the tokenizer.
    text = re.sub(r"https?://\S+", " ", raw, flags=re.IGNORECASE)
    text = re.sub(r"www\.\S+", " ", text, flags=re.IGNORECASE)

    tokens = _tokenize_query(text)
    if not tokens:
        return fallback or raw
    filtered = [t for t in tokens if t not in _REPO_DRIFT_TERMS]
    if not filtered:
        return fallback or raw
    cleaned = " ".join(filtered)

    # Detect tool-name-only drift (e.g., "dhsegment official github repository").
    # token_set is non-empty here, so it can never be a subset of an empty
    # name-token set.
    if known_tool_names:
        token_set = set(filtered)
        if len(token_set) <= 3 and any(
            token_set <= set(_tokenize_query(name)) for name in known_tool_names
        ):
            return fallback or cleaned

    if all(t in _LOW_SIGNAL_TERMS for t in filtered):
        return fallback or cleaned

    # If it became too short and a previous task query exists, prefer that.
    if len(filtered) <= 2 and fallback:
        return fallback
    return cleaned