Skip to content

Commit b0718af

Browse files
committed
update tests and logic
1 parent 243c150 commit b0718af

File tree

5 files changed

+128
-83
lines changed

5 files changed

+128
-83
lines changed

store/neurostore/resources/base.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from flask import abort, request, current_app # jsonify
99
from flask.views import MethodView
1010

11+
from psycopg2 import errors
12+
1113
import sqlalchemy as sa
1214
import sqlalchemy.sql.expression as sae
1315
from sqlalchemy.orm import (
@@ -21,7 +23,7 @@
2123

2224
from ..core import cache
2325
from ..database import db
24-
from .utils import get_current_user, validate_search_query, search_to_tsquery
26+
from .utils import get_current_user, validate_search_query, pubmed_to_tsquery
2527
from ..models import (
2628
StudysetStudy,
2729
AnnotationAnalysis,
@@ -613,10 +615,11 @@ def search(self):
613615
if s is not None and s.isdigit():
614616
q = q.filter_by(pmid=s)
615617
elif s is not None and self._fulltext_fields:
616-
valid = validate_search_query(s)
617-
if not valid:
618-
abort(400, description=valid)
619-
tsquery = search_to_tsquery(s)
618+
try:
619+
validate_search_query(s)
620+
except errors.SyntaxError as e:
621+
abort(400, description=e.args[0])
622+
tsquery = pubmed_to_tsquery(s)
620623
q = q.filter(m._ts_vector.op("@@")(tsquery))
621624

622625
# Alternatively (or in addition), search on individual fields.

store/neurostore/resources/utils.py

Lines changed: 35 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import re
66

77
from connexion.context import context
8+
from psycopg2 import errors
89

910
from .. import models
1011
from .. import schemas
@@ -58,11 +59,11 @@ def validate_search_query(query: str) -> bool:
5859
"""
5960
# Check for valid parentheses
6061
if not validate_parentheses(query):
61-
return 'Unmatched parentheses'
62+
raise errors.SyntaxError("Unmatched parentheses")
6263

6364
# Check for valid query end
6465
if not validate_query_end(query):
65-
return 'Query cannot end with an operator'
66+
raise errors.SyntaxError("Query cannot end with an operator")
6667

6768
return True
6869

@@ -79,27 +80,27 @@ def validate_parentheses(query: str) -> bool:
7980
"""
8081
stack = []
8182
for char in query:
82-
if char == '(':
83+
if char == "(":
8384
stack.append(char)
84-
elif char == ')':
85+
elif char == ")":
8586
if not stack:
8687
return False # Unmatched closing parenthesis
8788
stack.pop()
8889
return not stack # Ensure all opening parentheses are closed
8990

9091

9192
def validate_query_end(query: str) -> bool:
92-
""" Query should not end with an operator """
93-
operators = ('AND', 'OR', 'NOT')
93+
"""Query should not end with an operator"""
94+
operators = ("AND", "OR", "NOT")
9495

95-
if query.strip().split(' ')[-1] in operators:
96+
if query.strip().split(" ")[-1] in operators:
9697
return False
9798
return True
9899

99100

100101
def count_chars(target, query: str) -> int:
101-
""" Count the number of chars in a query string.
102-
Excluding those in quoted phrases."""
102+
"""Count the number of chars in a query string.
103+
Excluding those in quoted phrases."""
103104
count = 0
104105
in_quotes = False
105106
for char in query:
@@ -112,11 +113,11 @@ def count_chars(target, query: str) -> int:
112113

113114
def pubmed_to_tsquery(query: str) -> str:
114115
"""
115-
Convert a PubMed-like search query to PostgreSQL tsquery format,
116-
grouping both single-quoted and double-quoted text with the <-> operator
116+
Convert a PubMed-like search query to PostgreSQL tsquery format,
117+
grouping both single-quoted and double-quoted text with the <-> operator
117118
for proximity search.
118119
119-
Additionally, automatically adds & between non-explicitly connected terms
120+
Additionally, automatically adds & between non-explicitly connected terms
120121
and handles NOT terms.
121122
122123
Args:
@@ -130,7 +131,7 @@ def pubmed_to_tsquery(query: str) -> str:
130131

131132
# Step 1: Split into tokens (preserving quoted phrases)
132133
# Regex pattern: match quoted phrases or non-space sequences
133-
tokens = re.findall( r'"[^"]*"|\'[^\']*\'|\S+', query)
134+
tokens = re.findall(r'"[^"]*"|\'[^\']*\'|\S+', query)
134135

135136
# Step 2: Combine tokens in parentheses into single tokens
136137
def combine_parentheses(tokens: list) -> list:
@@ -152,19 +153,19 @@ def combine_parentheses(tokens: list) -> list:
152153
buffer.append(token)
153154

154155
# Adjust the count of parentheses
155-
paren_count += count_chars('(', token) - count_chars(')', token)
156+
paren_count += count_chars("(", token) - count_chars(")", token)
156157

157158
if paren_count < 1:
158159
# Combine all tokens in parentheses
159-
combined_tokens.append(' '.join(buffer))
160+
combined_tokens.append(" ".join(buffer))
160161
buffer = [] # Clear the buffer
161162
paren_count = 0
162163

163164
else:
164-
n_paren = count_chars('(', token) - count_chars(')', token)
165+
n_paren = count_chars("(", token) - count_chars(")", token)
165166
# If not in parentheses, but token contains opening parentheses
166167
# Start capturing tokens inside parentheses
167-
if token[0] == '(' and n_paren > 0:
168+
if token[0] == "(" and n_paren > 0:
168169
paren_count += n_paren
169170
buffer.append(token) # Start capturing tokens in parens
170171
print(buffer)
@@ -174,7 +175,7 @@ def combine_parentheses(tokens: list) -> list:
174175
# If the list ends without a closing parenthesis (invalid input)
175176
# append buffer contents (fallback)
176177
if buffer:
177-
combined_tokens.append(' '.join(buffer))
178+
combined_tokens.append(" ".join(buffer))
178179

179180
return combined_tokens
180181

@@ -184,42 +185,43 @@ def combine_parentheses(tokens: list) -> list:
184185
if token[0] == "(" and token[-1] == ")":
185186
# RECURSIVE: Process the contents of the parentheses
186187
token_res = pubmed_to_tsquery(token[1:-1])
187-
token = '(' + token_res + ')'
188+
token = "(" + token_res + ")"
188189
tokens[i] = token
189190

190191
# Step 4: Handle both single-quoted and double-quoted phrases,
191192
# grouping them with <-> (proximity operator)
192193
elif token[0] in ('"', "'"):
193194
# Split quoted text into individual words and join with <-> for
194195
# proximity search
195-
words = re.findall(r'\w+', token)
196-
tokens[i] = '<->'.join(words)
196+
words = re.findall(r"\w+", token)
197+
tokens[i] = "<->".join(words)
197198

198199
# Step 3: Replace logical operators AND, OR, NOT
199200
else:
200-
if token == 'AND':
201-
tokens[i] = '&'
202-
elif token == 'OR':
203-
tokens[i] = '|'
204-
elif token == 'NOT':
205-
tokens[i] = '&!'
201+
if token == "AND":
202+
tokens[i] = "&"
203+
elif token == "OR":
204+
tokens[i] = "|"
205+
elif token == "NOT":
206+
tokens[i] = "&!"
206207

207208
processed_tokens = []
208209
last_token = None
209210
for token in tokens:
210211
# Step 5: Add & between consecutive terms that aren't already
211212
# connected by an operator
212213
stripped_token = token.strip()
213-
214-
if stripped_token == '':
214+
if stripped_token not in ("&", "|", "!", "&!"):
215+
stripped_token = re.sub(r"[\[\],;:!?@#]", "", stripped_token)
216+
if stripped_token == "":
215217
continue # Ignore empty tokens from splitting
216218

217-
if last_token and last_token not in ('&', '|', '!', '&!'):
218-
if stripped_token not in ('&', '|', '!', '&!'):
219+
if last_token and last_token not in ("&", "|", "!", "&!"):
220+
if stripped_token not in ("&", "|", "!", "&!"):
219221
# Insert an implicit AND (&) between two non-operator tokens
220-
processed_tokens.append('&')
222+
processed_tokens.append("&")
221223

222224
processed_tokens.append(stripped_token)
223225
last_token = stripped_token
224226

225-
return ' '.join(processed_tokens)
227+
return " ".join(processed_tokens)

store/neurostore/tests/api/test_query_params.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22
from ...models import Study
33
from ...schemas.data import StudysetSchema, StudySchema, AnalysisSchema, StringOrNested
4+
from ..conftest import valid_queries, invalid_queries
45

56

67
@pytest.mark.parametrize("nested", ["true", "false"])
@@ -99,3 +100,17 @@ def test_multiword_queries(auth_client, ingest_neurosynth, session):
99100

100101
multi_word_search = auth_client.get(f"/api/studies/?search={multiple_words}")
101102
assert multi_word_search.status_code == 200
103+
104+
105+
@pytest.mark.parametrize("query, expected", valid_queries)
106+
def test_valid_pubmed_queries(query, expected, auth_client, ingest_neurosynth, session):
107+
search = auth_client.get(f"/api/studies/?search={query}")
108+
assert search.status_code == 200
109+
110+
111+
@pytest.mark.parametrize("query, expected", invalid_queries)
112+
def test_invalid_pubmed_queries(
113+
query, expected, auth_client, ingest_neurosynth, session
114+
):
115+
search = auth_client.get(f"/api/studies/?search={query}")
116+
assert search.status_code == 400

store/neurostore/tests/conftest.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -586,3 +586,59 @@ def simple_neurosynth_annotation(session, ingest_neurosynth):
586586
session.commit()
587587

588588
return smol_annot
589+
590+
591+
"""
592+
Queries for testing
593+
"""
594+
invalid_queries = [
595+
(
596+
'("autism" OR "ASD" OR "autistic") AND (("decision*" OR "choice*" ',
597+
"Unmatched parentheses",
598+
),
599+
('"autism" OR "ASD" OR "autistic" OR ', "Query cannot end with an operator"),
600+
(
601+
'(("Autism Spectrum Disorder" OR "autism spectrum disorder") OR ("Autism" OR "autism") OR ("ASD")) AND (("decision*" OR "Dec',
602+
"Unmatched parentheses",
603+
),
604+
]
605+
606+
valid_queries = [
607+
(
608+
'"Mild Cognitive Impairment" or "Early Cognitive Decline" or "Pre-Dementia" or "Mild Neurocognitive Disorder"',
609+
"MILD<->COGNITIVE<->IMPAIRMENT | EARLY<->COGNITIVE<->DECLINE | PRE<->DEMENTIA | MILD<->NEUROCOGNITIVE<->DISORDER",
610+
),
611+
(
612+
'("autism" OR "ASD" OR "autistic") AND ("decision" OR "choice")',
613+
"(AUTISM | ASD | AUTISTIC) & (DECISION | CHOICE)",
614+
),
615+
(
616+
"stroop and depression or back and depression or go",
617+
"STROOP & DEPRESSION | BACK & DEPRESSION | GO",
618+
),
619+
(
620+
'("autism" OR "ASD" OR "autistic") AND (("decision" OR "decision-making" OR "choice" OR "selection" OR "option" OR "value") OR ("feedback" OR "feedback-related" OR "reward" OR "error" OR "outcome" OR "punishment" OR "reinforcement"))',
621+
"(AUTISM | ASD | AUTISTIC) & ((DECISION | DECISION<->MAKING | CHOICE | SELECTION | OPTION | VALUE) | (FEEDBACK | FEEDBACK<->RELATED | REWARD | ERROR | OUTCOME | PUNISHMENT | REINFORCEMENT))",
622+
),
623+
(
624+
'"dyslexia" or "Reading Disorder" or "Language-Based Learning Disability" or "Phonological Processing Disorder" or "Word Blindness"',
625+
"DYSLEXIA | READING<->DISORDER | LANGUAGE<->BASED<->LEARNING<->DISABILITY | PHONOLOGICAL<->PROCESSING<->DISORDER | WORD<->BLINDNESS",
626+
),
627+
("emotion and pain -physical -touch", "EMOTION & PAIN & -PHYSICAL & -TOUCH"),
628+
(
629+
'("Schizophrenia"[Mesh] OR schizophrenia )',
630+
"(SCHIZOPHRENIA & MESH | SCHIZOPHRENIA)",
631+
),
632+
("Bipolar Disorder", "BIPOLAR & DISORDER"),
633+
('"quchi" or "LI11"', "QUCHI | LI11"),
634+
('"rubber hand illusion"', "RUBBER<->HAND<->ILLUSION"),
635+
]
636+
637+
weird_queries = [
638+
(
639+
"[Major Depressive Disorder (MDD)] or [Clinical Depression] or [Unipolar Depression]",
640+
"MAJOR & DEPRESSIVE & DISORDER & (MDD) | CLINICAL & DEPRESSION | UNIPOLAR & DEPRESSION",
641+
),
642+
]
643+
644+
validate_queries = invalid_queries + [(q, True) for q, _ in valid_queries]
Lines changed: 14 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,23 @@
11
import pytest
22

3-
from ..utils import search_to_tsquery, validate_search_query
4-
5-
6-
invalid_queries = [
7-
('("autism" OR "ASD" OR "autistic") AND (("decision*" OR "choice*" ', 'Unmatched parentheses'),
8-
('"autism" OR "ASD" OR "autistic" OR ', 'Query cannot end with an operator'),
9-
('(("Autism Spectrum Disorder" OR "autism spectrum disorder") OR ("Autism" OR "autism") OR ("ASD")) AND (("decision*" OR "Dec', 'Unmatched parentheses')
10-
]
11-
12-
valid_queries = [
13-
('"Mild Cognitive Impairment" or "Early Cognitive Decline" or "Pre-Dementia" or "Mild Neurocognitive Disorder"',
14-
'MILD<->COGNITIVE<->IMPAIRMENT | EARLY<->COGNITIVE<->DECLINE | PRE<->DEMENTIA | MILD<->NEUROCOGNITIVE<->DISORDER'),
15-
('("autism" OR "ASD" OR "autistic") AND ("decision" OR "choice")',
16-
'(AUTISM | ASD | AUTISTIC) & (DECISION | CHOICE)'),
17-
('stroop and depression or back and depression or go',
18-
'STROOP & DEPRESSION | BACK & DEPRESSION | GO'),
19-
('("autism" OR "ASD" OR "autistic") AND (("decision" OR "decision-making" OR "choice" OR "selection" OR "option" OR "value") OR ("feedback" OR "feedback-related" OR "reward" OR "error" OR "outcome" OR "punishment" OR "reinforcement"))',
20-
'(AUTISM | ASD | AUTISTIC) & ((DECISION | DECISION<->MAKING | CHOICE | SELECTION | OPTION | VALUE) | (FEEDBACK | FEEDBACK<->RELATED | REWARD | ERROR | OUTCOME | PUNISHMENT | REINFORCEMENT))'),
21-
('"dyslexia" or "Reading Disorder" or "Language-Based Learning Disability" or "Phonological Processing Disorder" or "Word Blindness"',
22-
'DYSLEXIA | READING<->DISORDER | LANGUAGE<->BASED<->LEARNING<->DISABILITY | PHONOLOGICAL<->PROCESSING<->DISORDER | WORD<->BLINDNESS'),
23-
('emotion and pain -physical -touch',
24-
'EMOTION & PAIN & -PHYSICAL & -TOUCH'),
25-
('("Schizophrenia"[Mesh] OR schizophrenia )',
26-
'(SCHIZOPHRENIA & [MESH] | SCHIZOPHRENIA)')
27-
('Bipolar Disorder',
28-
'BIPOLAR & DISORDER'),
29-
('"quchi" or "LI11"',
30-
'QUCHI | LI11'),
31-
('"rubber hand illusion"',
32-
'RUBBER<->HAND<->ILLUSION'),
33-
]
34-
35-
error_queries = [
36-
"[Major Depressive Disorder (MDD)] or [Clinical Depression] or [Unipolar Depression]"
37-
]
38-
39-
validate_queries = invalid_queries + [(q, True) for q, _ in valid_queries]
3+
from ..resources.utils import pubmed_to_tsquery, validate_search_query
4+
from .conftest import valid_queries, validate_queries, weird_queries
405

416

427
@pytest.mark.parametrize("query, expected", valid_queries)
43-
def test_search_to_tsquery(query, expected):
44-
assert search_to_tsquery(query) == expected
8+
def test_pubmed_to_tsquery(query, expected):
9+
assert pubmed_to_tsquery(query) == expected
4510

4611

47-
@pytest.mark.parametrize("query, expected", invalid_queries)
12+
@pytest.mark.parametrize("query, expected", validate_queries)
4813
def test_validate_search_query(query, expected):
49-
assert validate_search_query(query) == expected
14+
if expected is True:
15+
assert validate_search_query(query) == expected
16+
else:
17+
with pytest.raises(Exception):
18+
validate_search_query(query)
19+
5020

51-
@pytest.mark.parametrize("query", error_queries)
52-
def test_search_to_tsquery_error(query):
53-
with pytest.raises(ValueError):
54-
search_to_tsquery(query)
21+
@pytest.mark.parametrize("query, expected", weird_queries)
22+
def test_pubmed_to_tsquery_weird(query, expected):
23+
assert pubmed_to_tsquery(query) == expected

0 commit comments

Comments
 (0)