Skip to content

Commit ea5cdc4

Browse files
adelavegajdkent
andauthored
ENH: Add custom tsquery from websearch function and related tests (#838)
* Add custom tsquery from websearch function and related tests * update tests and logic * fix style issues --------- Co-authored-by: James Kent <[email protected]>
1 parent efdf3d9 commit ea5cdc4

File tree

5 files changed

+292
-2
lines changed

5 files changed

+292
-2
lines changed

store/neurostore/resources/base.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from flask import abort, request, current_app # jsonify
99
from flask.views import MethodView
1010

11+
from psycopg2 import errors
12+
1113
import sqlalchemy as sa
1214
import sqlalchemy.sql.expression as sae
1315
from sqlalchemy.orm import (
@@ -21,7 +23,7 @@
2123

2224
from ..core import cache
2325
from ..database import db
24-
from .utils import get_current_user
26+
from .utils import get_current_user, validate_search_query, pubmed_to_tsquery
2527
from ..models import (
2628
StudysetStudy,
2729
AnnotationAnalysis,
@@ -613,7 +615,11 @@ def search(self):
613615
if s is not None and s.isdigit():
614616
q = q.filter_by(pmid=s)
615617
elif s is not None and self._fulltext_fields:
616-
tsquery = sa.func.websearch_to_tsquery("english", s)
618+
try:
619+
validate_search_query(s)
620+
except errors.SyntaxError as e:
621+
abort(400, description=e.args[0])
622+
tsquery = pubmed_to_tsquery(s)
617623
q = q.filter(m._ts_vector.op("@@")(tsquery))
618624

619625
# Alternatively (or in addition), search on individual fields.

store/neurostore/resources/utils.py

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import re
66

77
from connexion.context import context
8+
from psycopg2 import errors
89

910
from .. import models
1011
from .. import schemas
@@ -44,3 +45,183 @@ class ClassView(cls):
4445
ClassView.__name__ = cls.__name__
4546

4647
return ClassView
48+
49+
50+
def validate_search_query(query: str) -> bool:
51+
"""
52+
Validate a search query string.
53+
54+
Args:
55+
query (str): The query string to validate.
56+
57+
Returns:
58+
bool: True if the query is valid, False otherwise.
59+
"""
60+
# Check for valid parentheses
61+
if not validate_parentheses(query):
62+
raise errors.SyntaxError("Unmatched parentheses")
63+
64+
# Check for valid query end
65+
if not validate_query_end(query):
66+
raise errors.SyntaxError("Query cannot end with an operator")
67+
68+
return True
69+
70+
71+
def validate_parentheses(query: str) -> bool:
72+
"""
73+
Validate the parentheses in a query string.
74+
75+
Args:
76+
query (str): The query string to validate.
77+
78+
Returns:
79+
bool: True if parentheses are valid, False otherwise.
80+
"""
81+
stack = []
82+
for char in query:
83+
if char == "(":
84+
stack.append(char)
85+
elif char == ")":
86+
if not stack:
87+
return False # Unmatched closing parenthesis
88+
stack.pop()
89+
return not stack # Ensure all opening parentheses are closed
90+
91+
92+
def validate_query_end(query: str) -> bool:
93+
"""Query should not end with an operator"""
94+
operators = ("AND", "OR", "NOT")
95+
96+
if query.strip().split(" ")[-1] in operators:
97+
return False
98+
return True
99+
100+
101+
def count_chars(target, query: str) -> int:
102+
"""Count the number of chars in a query string.
103+
Excluding those in quoted phrases."""
104+
count = 0
105+
in_quotes = False
106+
for char in query:
107+
if char == '"':
108+
in_quotes = not in_quotes
109+
if char == target and not in_quotes:
110+
count += 1
111+
return count
112+
113+
114+
def pubmed_to_tsquery(query: str) -> str:
115+
"""
116+
Convert a PubMed-like search query to PostgreSQL tsquery format,
117+
grouping both single-quoted and double-quoted text with the <-> operator
118+
for proximity search.
119+
120+
Additionally, automatically adds & between non-explicitly connected terms
121+
and handles NOT terms.
122+
123+
Args:
124+
query (str): The search query.
125+
126+
Returns:
127+
str: The PostgreSQL tsquery equivalent.
128+
"""
129+
130+
query = query.upper() # Ensure uniformity
131+
132+
# Step 1: Split into tokens (preserving quoted phrases)
133+
# Regex pattern: match quoted phrases or non-space sequences
134+
tokens = re.findall(r'"[^"]*"|\'[^\']*\'|\S+', query)
135+
136+
# Step 2: Combine tokens in parantheses into single tokens
137+
def combine_parentheses(tokens: list) -> list:
138+
"""
139+
Combine tokens within parentheses into a single token.
140+
141+
Args:
142+
tokens (list): List of tokens to process.
143+
144+
Returns:
145+
list: Processed list with tokens inside parentheses combined.
146+
"""
147+
combined_tokens = []
148+
buffer = []
149+
paren_count = 0
150+
for token in tokens:
151+
# If buffer is not empty, we are inside parentheses
152+
if len(buffer) > 0:
153+
buffer.append(token)
154+
155+
# Adjust the count of parentheses
156+
paren_count += count_chars("(", token) - count_chars(")", token)
157+
158+
if paren_count < 1:
159+
# Combine all tokens in parentheses
160+
combined_tokens.append(" ".join(buffer))
161+
buffer = [] # Clear the buffer
162+
paren_count = 0
163+
164+
else:
165+
n_paren = count_chars("(", token) - count_chars(")", token)
166+
# If not in parentheses, but token contains opening parentheses
167+
# Start capturing tokens inside parentheses
168+
if token[0] == "(" and n_paren > 0:
169+
paren_count += n_paren
170+
buffer.append(token) # Start capturing tokens in parens
171+
print(buffer)
172+
else:
173+
combined_tokens.append(token)
174+
175+
# If the list ends without a closing parenthesis (invalid input)
176+
# append buffer contents (fallback)
177+
if buffer:
178+
combined_tokens.append(" ".join(buffer))
179+
180+
return combined_tokens
181+
182+
tokens = combine_parentheses(tokens)
183+
print(tokens)
184+
for i, token in enumerate(tokens):
185+
if token[0] == "(" and token[-1] == ")":
186+
# RECURSIVE: Process the contents of the parentheses
187+
token_res = pubmed_to_tsquery(token[1:-1])
188+
token = "(" + token_res + ")"
189+
tokens[i] = token
190+
191+
# Step 4: Handle both single-quoted and double-quoted phrases,
192+
# grouping them with <-> (proximity operator)
193+
elif token[0] in ('"', "'"):
194+
# Split quoted text into individual words and join with <-> for
195+
# proximity search
196+
words = re.findall(r"\w+", token)
197+
tokens[i] = "<->".join(words)
198+
199+
# Step 3: Replace logical operators AND, OR, NOT
200+
else:
201+
if token == "AND":
202+
tokens[i] = "&"
203+
elif token == "OR":
204+
tokens[i] = "|"
205+
elif token == "NOT":
206+
tokens[i] = "&!"
207+
208+
processed_tokens = []
209+
last_token = None
210+
for token in tokens:
211+
# Step 5: Add & between consecutive terms that aren't already
212+
# connected by an operator
213+
stripped_token = token.strip()
214+
if stripped_token not in ("&", "|", "!", "&!"):
215+
stripped_token = re.sub(r"[\[\],;:!?@#]", "", stripped_token)
216+
if stripped_token == "":
217+
continue # Ignore empty tokens from splitting
218+
219+
if last_token and last_token not in ("&", "|", "!", "&!"):
220+
if stripped_token not in ("&", "|", "!", "&!"):
221+
# Insert an implicit AND (&) between two non-operator tokens
222+
processed_tokens.append("&")
223+
224+
processed_tokens.append(stripped_token)
225+
last_token = stripped_token
226+
227+
return " ".join(processed_tokens)

store/neurostore/tests/api/test_query_params.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22
from ...models import Study
33
from ...schemas.data import StudysetSchema, StudySchema, AnalysisSchema, StringOrNested
4+
from ..conftest import valid_queries, invalid_queries
45

56

67
@pytest.mark.parametrize("nested", ["true", "false"])
@@ -99,3 +100,17 @@ def test_multiword_queries(auth_client, ingest_neurosynth, session):
99100

100101
multi_word_search = auth_client.get(f"/api/studies/?search={multiple_words}")
101102
assert multi_word_search.status_code == 200
103+
104+
105+
@pytest.mark.parametrize("query, expected", valid_queries)
106+
def test_valid_pubmed_queries(query, expected, auth_client, ingest_neurosynth, session):
107+
search = auth_client.get(f"/api/studies/?search={query}")
108+
assert search.status_code == 200
109+
110+
111+
@pytest.mark.parametrize("query, expected", invalid_queries)
112+
def test_invalid_pubmed_queries(
113+
query, expected, auth_client, ingest_neurosynth, session
114+
):
115+
search = auth_client.get(f"/api/studies/?search={query}")
116+
assert search.status_code == 400

store/neurostore/tests/conftest.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -586,3 +586,68 @@ def simple_neurosynth_annotation(session, ingest_neurosynth):
586586
session.commit()
587587

588588
return smol_annot
589+
590+
591+
"""
592+
Queries for testing
593+
"""
594+
invalid_queries = [
595+
(
596+
'("autism" OR "ASD" OR "autistic") AND (("decision*" OR "choice*" ',
597+
"Unmatched parentheses",
598+
),
599+
('"autism" OR "ASD" OR "autistic" OR ', "Query cannot end with an operator"),
600+
(
601+
'(("Autism Spectrum Disorder" OR "autism spectrum disorder") OR ("Autism" OR "autism") '
602+
'OR ("ASD")) AND (("decision*" OR "Dec',
603+
"Unmatched parentheses",
604+
),
605+
]
606+
607+
valid_queries = [
608+
(
609+
'"Mild Cognitive Impairment" or "Early Cognitive Decline" or "Pre-Dementia" or '
610+
'"Mild Neurocognitive Disorder"',
611+
"MILD<->COGNITIVE<->IMPAIRMENT | EARLY<->COGNITIVE<->DECLINE | PRE<->DEMENTIA | "
612+
"MILD<->NEUROCOGNITIVE<->DISORDER",
613+
),
614+
(
615+
'("autism" OR "ASD" OR "autistic") AND ("decision" OR "choice")',
616+
"(AUTISM | ASD | AUTISTIC) & (DECISION | CHOICE)",
617+
),
618+
(
619+
"stroop and depression or back and depression or go",
620+
"STROOP & DEPRESSION | BACK & DEPRESSION | GO",
621+
),
622+
(
623+
'("autism" OR "ASD" OR "autistic") AND (("decision" OR "decision-making" OR "choice" OR '
624+
'"selection" OR "option" OR "value") OR ("feedback" OR "feedback-related" OR "reward" OR '
625+
'"error" OR "outcome" OR "punishment" OR "reinforcement"))',
626+
"(AUTISM | ASD | AUTISTIC) & ((DECISION | DECISION<->MAKING | CHOICE | SELECTION | OPTION "
627+
"| VALUE) | (FEEDBACK | FEEDBACK<->RELATED | REWARD | ERROR | OUTCOME | PUNISHMENT | "
628+
"REINFORCEMENT))",
629+
),
630+
(
631+
'"dyslexia" or "Reading Disorder" or "Language-Based Learning Disability" or '
632+
'"Phonological Processing Disorder" or "Word Blindness"',
633+
"DYSLEXIA | READING<->DISORDER | LANGUAGE<->BASED<->LEARNING<->DISABILITY | "
634+
"PHONOLOGICAL<->PROCESSING<->DISORDER | WORD<->BLINDNESS",
635+
),
636+
("emotion and pain -physical -touch", "EMOTION & PAIN & -PHYSICAL & -TOUCH"),
637+
(
638+
'("Schizophrenia"[Mesh] OR schizophrenia )',
639+
"(SCHIZOPHRENIA & MESH | SCHIZOPHRENIA)",
640+
),
641+
("Bipolar Disorder", "BIPOLAR & DISORDER"),
642+
('"quchi" or "LI11"', "QUCHI | LI11"),
643+
('"rubber hand illusion"', "RUBBER<->HAND<->ILLUSION"),
644+
]
645+
646+
weird_queries = [
647+
(
648+
"[Major Depressive Disorder (MDD)] or [Clinical Depression] or [Unipolar Depression]",
649+
"MAJOR & DEPRESSIVE & DISORDER & (MDD) | CLINICAL & DEPRESSION | UNIPOLAR & DEPRESSION",
650+
),
651+
]
652+
653+
validate_queries = invalid_queries + [(q, True) for q, _ in valid_queries]
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import pytest
2+
3+
from ..resources.utils import pubmed_to_tsquery, validate_search_query
4+
from .conftest import valid_queries, validate_queries, weird_queries
5+
6+
7+
@pytest.mark.parametrize("query, expected", valid_queries)
8+
def test_pubmed_to_tsquery(query, expected):
9+
assert pubmed_to_tsquery(query) == expected
10+
11+
12+
@pytest.mark.parametrize("query, expected", validate_queries)
13+
def test_validate_search_query(query, expected):
14+
if expected is True:
15+
assert validate_search_query(query) == expected
16+
else:
17+
with pytest.raises(Exception):
18+
validate_search_query(query)
19+
20+
21+
@pytest.mark.parametrize("query, expected", weird_queries)
22+
def test_pubmed_to_tsquery_weird(query, expected):
23+
assert pubmed_to_tsquery(query) == expected

0 commit comments

Comments
 (0)