
Commit 243c150

Add custom tsquery from websearch function and related tests
1 parent e72d9c9 commit 243c150

3 files changed: +238 -2 lines changed


store/neurostore/resources/base.py

Lines changed: 5 additions & 2 deletions
@@ -21,7 +21,7 @@

 from ..core import cache
 from ..database import db
-from .utils import get_current_user
+from .utils import get_current_user, validate_search_query, search_to_tsquery
 from ..models import (
     StudysetStudy,
     AnnotationAnalysis,
@@ -613,7 +613,10 @@ def search(self):
         if s is not None and s.isdigit():
             q = q.filter_by(pmid=s)
         elif s is not None and self._fulltext_fields:
-            tsquery = sa.func.websearch_to_tsquery("english", s)
+            valid = validate_search_query(s)
+            if valid is not True:
+                abort(400, description=valid)
+            tsquery = search_to_tsquery(s)
             q = q.filter(m._ts_vector.op("@@")(tsquery))

         # Alternatively (or in addition), search on individual fields.
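
To see how the pieces fit together, here is a minimal sketch of the new search path (not part of the commit). The helper name apply_fulltext_filter and the import path neurostore.resources.utils are assumptions, as is the SQLAlchemy model exposing a _ts_vector column; wrapping the converted string in sa.func.to_tsquery is likewise only an illustration, since the diff hands the string straight to the @@ operator.

import sqlalchemy as sa
from flask import abort

from neurostore.resources.utils import validate_search_query, search_to_tsquery


def apply_fulltext_filter(q, model, s):
    """Reject malformed search strings with a 400, otherwise filter on the tsquery."""
    valid = validate_search_query(s)
    if valid is not True:
        # valid holds the error message, e.g. 'Unmatched parentheses'
        abort(400, description=valid)
    tsquery = search_to_tsquery(s)
    # Wrapping in to_tsquery is an assumption about how the raw string
    # reaches PostgreSQL; the committed code passes the string directly.
    return q.filter(model._ts_vector.op("@@")(sa.func.to_tsquery("english", tsquery)))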

store/neurostore/resources/utils.py

Lines changed: 179 additions & 0 deletions
@@ -44,3 +44,182 @@ class ClassView(cls):
     ClassView.__name__ = cls.__name__

     return ClassView
+
+
+def validate_search_query(query: str):
+    """
+    Validate a search query string.
+
+    Args:
+        query (str): The query string to validate.
+
+    Returns:
+        bool | str: True if the query is valid, otherwise an error message.
+    """
+    # Check for valid parentheses
+    if not validate_parentheses(query):
+        return 'Unmatched parentheses'
+
+    # Check for valid query end
+    if not validate_query_end(query):
+        return 'Query cannot end with an operator'
+
+    return True
+
+
+def validate_parentheses(query: str) -> bool:
+    """
+    Validate the parentheses in a query string.
+
+    Args:
+        query (str): The query string to validate.
+
+    Returns:
+        bool: True if parentheses are valid, False otherwise.
+    """
+    stack = []
+    for char in query:
+        if char == '(':
+            stack.append(char)
+        elif char == ')':
+            if not stack:
+                return False  # Unmatched closing parenthesis
+            stack.pop()
+    return not stack  # Ensure all opening parentheses are closed
+
+
+def validate_query_end(query: str) -> bool:
+    """Query should not end with an operator."""
+    operators = ('AND', 'OR', 'NOT')
+
+    if query.strip().split(' ')[-1] in operators:
+        return False
+    return True
+
+
+def count_chars(target: str, query: str) -> int:
+    """Count occurrences of a character in a query string,
+    excluding those inside quoted phrases."""
+    count = 0
+    in_quotes = False
+    for char in query:
+        if char == '"':
+            in_quotes = not in_quotes
+        if char == target and not in_quotes:
+            count += 1
+    return count
+
+
+def pubmed_to_tsquery(query: str) -> str:
+    """
+    Convert a PubMed-like search query to PostgreSQL tsquery format,
+    grouping both single-quoted and double-quoted text with the <-> operator
+    for proximity search.
+
+    Additionally, automatically inserts & between terms that are not
+    explicitly connected, and handles NOT terms.
+
+    Args:
+        query (str): The search query.
+
+    Returns:
+        str: The PostgreSQL tsquery equivalent.
+    """
+
+    query = query.upper()  # Ensure uniformity
+
+    # Step 1: Split into tokens (preserving quoted phrases)
+    # Regex pattern: match quoted phrases or non-space sequences
+    tokens = re.findall(r'"[^"]*"|\'[^\']*\'|\S+', query)
+
+    # Step 2: Combine tokens in parentheses into single tokens
+    def combine_parentheses(tokens: list) -> list:
+        """
+        Combine tokens within parentheses into a single token.
+
+        Args:
+            tokens (list): List of tokens to process.
+
+        Returns:
+            list: Processed list with tokens inside parentheses combined.
+        """
+        combined_tokens = []
+        buffer = []
+        paren_count = 0
+        for token in tokens:
+            # If buffer is not empty, we are inside parentheses
+            if len(buffer) > 0:
+                buffer.append(token)
+
+                # Adjust the count of parentheses
+                paren_count += count_chars('(', token) - count_chars(')', token)
+
+                if paren_count < 1:
+                    # Combine all tokens in parentheses
+                    combined_tokens.append(' '.join(buffer))
+                    buffer = []  # Clear the buffer
+                    paren_count = 0
+
+            else:
+                n_paren = count_chars('(', token) - count_chars(')', token)
+                # If not in parentheses, but the token contains an opening
+                # parenthesis, start capturing tokens inside parentheses
+                if token[0] == '(' and n_paren > 0:
+                    paren_count += n_paren
+                    buffer.append(token)  # Start capturing tokens in parens
+                    print(buffer)
+                else:
+                    combined_tokens.append(token)
+
+        # If the list ends without a closing parenthesis (invalid input),
+        # append buffer contents (fallback)
+        if buffer:
+            combined_tokens.append(' '.join(buffer))

+        return combined_tokens
+
+    tokens = combine_parentheses(tokens)
+    print(tokens)
+    for i, token in enumerate(tokens):
+        if token[0] == "(" and token[-1] == ")":
+            # Step 3 (recursive): Process the contents of the parentheses
+            token_res = pubmed_to_tsquery(token[1:-1])
+            token = '(' + token_res + ')'
+            tokens[i] = token
+
+        # Step 4: Handle both single-quoted and double-quoted phrases,
+        # grouping them with <-> (proximity operator)
+        elif token[0] in ('"', "'"):
+            # Split quoted text into individual words and join with <-> for
+            # proximity search
+            words = re.findall(r'\w+', token)
+            tokens[i] = '<->'.join(words)
+
+        # Step 5: Replace logical operators AND, OR, NOT
+        else:
+            if token == 'AND':
+                tokens[i] = '&'
+            elif token == 'OR':
+                tokens[i] = '|'
+            elif token == 'NOT':
+                tokens[i] = '&!'
+
+    processed_tokens = []
+    last_token = None
+    for token in tokens:
+        # Step 6: Add & between consecutive terms that aren't already
+        # connected by an operator
+        stripped_token = token.strip()
+
+        if stripped_token == '':
+            continue  # Ignore empty tokens from splitting
+
+        if last_token and last_token not in ('&', '|', '!', '&!'):
+            if stripped_token not in ('&', '|', '!', '&!'):
+                # Insert an implicit AND (&) between two non-operator tokens
+                processed_tokens.append('&')
+
+        processed_tokens.append(stripped_token)
+        last_token = stripped_token
+
+    return ' '.join(processed_tokens)
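
Illustrative behavior of the helpers above (a sketch, not part of the commit; the import path is assumed, and the expected values follow from tracing the code):

from neurostore.resources.utils import validate_search_query, pubmed_to_tsquery

validate_search_query('(brain AND pain')    # -> 'Unmatched parentheses'
validate_search_query('brain AND pain OR')  # -> 'Query cannot end with an operator'
validate_search_query('brain AND pain')     # -> True

# Quoted phrases become <-> proximity chains, parenthesized groups are
# processed recursively, and plain adjacency becomes an implicit &.
pubmed_to_tsquery('"working memory" AND (fmri OR meg)')
# -> 'WORKING<->MEMORY & (FMRI | MEG)'
pubmed_to_tsquery('emotion pain')
# -> 'EMOTION & PAIN'
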
Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+import pytest
+
+from ..utils import search_to_tsquery, validate_search_query
+
+
+invalid_queries = [
+    ('("autism" OR "ASD" OR "autistic") AND (("decision*" OR "choice*" ', 'Unmatched parentheses'),
+    ('"autism" OR "ASD" OR "autistic" OR ', 'Query cannot end with an operator'),
+    ('(("Autism Spectrum Disorder" OR "autism spectrum disorder") OR ("Autism" OR "autism") OR ("ASD")) AND (("decision*" OR "Dec', 'Unmatched parentheses')
+]
+
+valid_queries = [
+    ('"Mild Cognitive Impairment" or "Early Cognitive Decline" or "Pre-Dementia" or "Mild Neurocognitive Disorder"',
+     'MILD<->COGNITIVE<->IMPAIRMENT | EARLY<->COGNITIVE<->DECLINE | PRE<->DEMENTIA | MILD<->NEUROCOGNITIVE<->DISORDER'),
+    ('("autism" OR "ASD" OR "autistic") AND ("decision" OR "choice")',
+     '(AUTISM | ASD | AUTISTIC) & (DECISION | CHOICE)'),
+    ('stroop and depression or back and depression or go',
+     'STROOP & DEPRESSION | BACK & DEPRESSION | GO'),
+    ('("autism" OR "ASD" OR "autistic") AND (("decision" OR "decision-making" OR "choice" OR "selection" OR "option" OR "value") OR ("feedback" OR "feedback-related" OR "reward" OR "error" OR "outcome" OR "punishment" OR "reinforcement"))',
+     '(AUTISM | ASD | AUTISTIC) & ((DECISION | DECISION<->MAKING | CHOICE | SELECTION | OPTION | VALUE) | (FEEDBACK | FEEDBACK<->RELATED | REWARD | ERROR | OUTCOME | PUNISHMENT | REINFORCEMENT))'),
+    ('"dyslexia" or "Reading Disorder" or "Language-Based Learning Disability" or "Phonological Processing Disorder" or "Word Blindness"',
+     'DYSLEXIA | READING<->DISORDER | LANGUAGE<->BASED<->LEARNING<->DISABILITY | PHONOLOGICAL<->PROCESSING<->DISORDER | WORD<->BLINDNESS'),
+    ('emotion and pain -physical -touch',
+     'EMOTION & PAIN & -PHYSICAL & -TOUCH'),
+    ('("Schizophrenia"[Mesh] OR schizophrenia )',
+     '(SCHIZOPHRENIA & [MESH] | SCHIZOPHRENIA)'),
+    ('Bipolar Disorder',
+     'BIPOLAR & DISORDER'),
+    ('"quchi" or "LI11"',
+     'QUCHI | LI11'),
+    ('"rubber hand illusion"',
+     'RUBBER<->HAND<->ILLUSION'),
+]
+
+error_queries = [
+    "[Major Depressive Disorder (MDD)] or [Clinical Depression] or [Unipolar Depression]"
+]
+
+validate_queries = invalid_queries + [(q, True) for q, _ in valid_queries]
+
+
+@pytest.mark.parametrize("query, expected", valid_queries)
+def test_search_to_tsquery(query, expected):
+    assert search_to_tsquery(query) == expected
+
+
+@pytest.mark.parametrize("query, expected", validate_queries)
+def test_validate_search_query(query, expected):
+    assert validate_search_query(query) == expected
+
+@pytest.mark.parametrize("query", error_queries)
+def test_search_to_tsquery_error(query):
+    with pytest.raises(ValueError):
+        search_to_tsquery(query)
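
For an end-to-end sanity check outside the unit tests, the generated strings can be evaluated against a live PostgreSQL instance. The sketch below assumes a reachable database; the DSN is a placeholder and the matches helper is hypothetical, not part of the commit.

import sqlalchemy as sa

engine = sa.create_engine("postgresql:///neurostore")  # placeholder DSN


def matches(document: str, tsquery: str) -> bool:
    """Return True if the document matches the generated tsquery string."""
    stmt = sa.select(
        sa.func.to_tsvector("english", document).op("@@")(
            sa.func.to_tsquery("english", tsquery)
        )
    )
    with engine.connect() as conn:
        return bool(conn.execute(stmt).scalar())


# Expected True for the 'rubber hand illusion' case in valid_queries:
# matches("the rubber hand illusion task", "RUBBER<->HAND<->ILLUSION")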
