Skip to content

Commit 90c371c

Browse files
committed
Better type checking; better error messages
1 parent c29a153 commit 90c371c

File tree

6 files changed

+965
-43
lines changed

6 files changed

+965
-43
lines changed

gramps/plugins/db/dbapi/query_builder.py

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -894,6 +894,119 @@ def get_sql_query(self, what, where, order_by, page=None, page_size=None):
894894
if where:
895895
where_condition = self.parser.parse_expression(where)
896896

897+
# Check for == None or != None (should use 'is' or 'is not' instead)
898+
from .query_model import CompareExpression, ConstantExpression
899+
900+
if isinstance(where_condition, CompareExpression):
901+
# Check if any comparator is None with == or !=
902+
for op, comparator in zip(
903+
where_condition.operators, where_condition.comparators
904+
):
905+
if (
906+
isinstance(comparator, ConstantExpression)
907+
and comparator.value is None
908+
):
909+
if op == "==":
910+
raise ValueError(
911+
f"Invalid WHERE clause: Cannot use '==' to check for None.\n"
912+
f" You wrote: {where}\n"
913+
f" Problem: In SQL, 'value = NULL' is always false (NULL = NULL is false)\n"
914+
f" Use instead: {where.replace('== None', 'is None')}\n"
915+
f"\n"
916+
f"The 'is None' operator generates 'IS NULL' which correctly checks for NULL values."
917+
)
918+
elif op == "!=":
919+
raise ValueError(
920+
f"Invalid WHERE clause: Cannot use '!=' to check for None.\n"
921+
f" You wrote: {where}\n"
922+
f" Problem: In SQL, 'value != NULL' is always false (NULL != NULL is false)\n"
923+
f" Use instead: {where.replace('!= None', 'is not None')}\n"
924+
f"\n"
925+
f"The 'is not None' operator generates 'IS NOT NULL' which correctly checks for non-NULL values."
926+
)
927+
# Also check the left side
928+
if (
929+
isinstance(where_condition.left, ConstantExpression)
930+
and where_condition.left.value is None
931+
):
932+
if where_condition.operators[0] == "==":
933+
raise ValueError(
934+
f"Invalid WHERE clause: Cannot use '==' to check for None.\n"
935+
f" You wrote: {where}\n"
936+
f" Use instead: {where.replace('None ==', 'None is')}\n"
937+
f"\n"
938+
f"In SQL, 'NULL = value' is always false. Use 'is None' instead."
939+
)
940+
elif where_condition.operators[0] == "!=":
941+
raise ValueError(
942+
f"Invalid WHERE clause: Cannot use '!=' to check for None.\n"
943+
f" You wrote: {where}\n"
944+
f" Use instead: {where.replace('None !=', 'None is not')}\n"
945+
f"\n"
946+
f"In SQL, 'NULL != value' is always false. Use 'is not None' instead."
947+
)
948+
949+
# Validate that WHERE clause is not a bare list comprehension
950+
# List comprehensions evaluate to lists, not booleans
951+
from .query_model import ListComprehensionExpression
952+
953+
if isinstance(where_condition, ListComprehensionExpression):
954+
raise ValueError(
955+
f"Invalid WHERE clause: List comprehensions are not valid boolean expressions.\n"
956+
f" You wrote: {where}\n"
957+
f" Did you mean: any({where})\n"
958+
f"\n"
959+
f"A WHERE clause must evaluate to True/False, but a list comprehension returns a list.\n"
960+
f"Use any([...]) to check if any element matches the condition, or\n"
961+
f"use 'item in array and condition' for array expansion patterns."
962+
)
963+
964+
# Validate that WHERE clause returns a boolean type
965+
# Use type inference to check the return type
966+
# Note: We only reject clear non-boolean cases (lists, strings at top level)
967+
# We allow None (unknown types) and object types (can be NULL checked)
968+
if self.enable_type_validation:
969+
where_type = self.type_inference.visit(where_condition)
970+
# Only validate if we have type information
971+
# Reject lists and strings that aren't in a comparison context
972+
from .query_model import AttributeExpression, ArrayAccessExpression
973+
from typing import get_origin
974+
975+
is_simple_attribute = isinstance(
976+
where_condition, (AttributeExpression, ArrayAccessExpression)
977+
)
978+
979+
if where_type is not None and is_simple_attribute:
980+
origin = get_origin(where_type)
981+
982+
# Reject plain lists - they should use len() or any()
983+
if origin is list or where_type is list:
984+
raise TypeError(
985+
f"Invalid WHERE clause: Expression does not return a boolean.\n"
986+
f" You wrote: {where}\n"
987+
f" This returns: list\n"
988+
f"\n"
989+
f"A WHERE clause must evaluate to True/False (boolean), not list.\n"
990+
f" Use len({where}) > 0 to check if the list is non-empty, or\n"
991+
f" use any([...]) to check if any element matches a condition."
992+
)
993+
994+
# Reject plain strings - they should use comparisons
995+
elif where_type is str:
996+
raise TypeError(
997+
f"Invalid WHERE clause: Expression does not return a boolean.\n"
998+
f" You wrote: {where}\n"
999+
f" This returns: string\n"
1000+
f"\n"
1001+
f"A WHERE clause must evaluate to True/False (boolean), not string.\n"
1002+
f" Use a comparison like {where} == 'value' or 'substring' in {where}."
1003+
)
1004+
1005+
# Note: We allow int, float, and object types because:
1006+
# - In SQL, numbers can be used as booleans (0=false, non-zero=true)
1007+
# - Object types can be NULL checked (NULL=false, non-NULL=true)
1008+
# This matches SQL semantics where WHERE column is valid
1009+
8971010
# Parse order_by clause
8981011
order_by_list = self._parse_order_by(order_by)
8991012

@@ -1378,6 +1491,18 @@ def _handle_list_comprehensions(self, query: SelectQuery, what) -> SelectQuery:
13781491
)
13791492
left_query.union_queries = [right_query]
13801493
return left_query
1494+
elif listcomp.array_info.get("type") == "nested_listcomp":
1495+
# Check if the inner list comprehension is concatenated
1496+
inner_listcomp = listcomp.array_info["inner_listcomp"]
1497+
if inner_listcomp.array_info.get("type") == "concatenated":
1498+
# Need to generate UNION for nested+concatenated case
1499+
left_query, right_query = (
1500+
self._build_nested_concatenated_array_queries(
1501+
query, listcomp, inner_listcomp
1502+
)
1503+
)
1504+
left_query.union_queries = [right_query]
1505+
return left_query
13811506

13821507
return query
13831508

@@ -1453,6 +1578,103 @@ def _build_concatenated_array_queries(
14531578

14541579
return left_query, right_query
14551580

1581+
def _build_nested_concatenated_array_queries(
1582+
self,
1583+
query: SelectQuery,
1584+
outer_listcomp: ListComprehensionExpression,
1585+
inner_listcomp: ListComprehensionExpression,
1586+
) -> tuple:
1587+
"""
1588+
Build left and right queries for nested list comprehensions with concatenated inner arrays.
1589+
1590+
Example: [s.surname for s in [name.surname_list for name in [primary_name] + alternate_names]]
1591+
- Outer: iterates over surname_list results
1592+
- Inner: iterates over [primary_name] + alternate_names (concatenated)
1593+
"""
1594+
from .query_model import (
1595+
AttributeExpression,
1596+
CallExpression,
1597+
ConstantExpression,
1598+
)
1599+
1600+
# Create modified inner list comprehensions for left and right sides
1601+
# Left side: [name.surname_list for name in [primary_name]]
1602+
# The left side expression from the original concatenated array
1603+
left_array_attr = inner_listcomp.array_info["left"]
1604+
1605+
# For the left side, we need to mark that it should be wrapped in json_array()
1606+
# We'll use a special marker in array_info
1607+
left_inner_listcomp = ListComprehensionExpression(
1608+
expression=inner_listcomp.expression,
1609+
item_var=inner_listcomp.item_var,
1610+
array_info={
1611+
"type": "single",
1612+
"path": "primary_name",
1613+
"wrap_in_json_array": True, # Signal that this needs json_array wrapping
1614+
},
1615+
condition=inner_listcomp.condition,
1616+
)
1617+
1618+
# Right side: [name.surname_list for name in alternate_names]
1619+
right_path = inner_listcomp.array_info["right_path"]
1620+
right_inner_listcomp = ListComprehensionExpression(
1621+
expression=inner_listcomp.expression,
1622+
item_var=inner_listcomp.item_var,
1623+
array_info={
1624+
"type": "single",
1625+
"path": right_path,
1626+
},
1627+
condition=inner_listcomp.condition,
1628+
)
1629+
1630+
# Create outer list comprehensions with modified inner ones
1631+
left_outer_listcomp = ListComprehensionExpression(
1632+
expression=outer_listcomp.expression,
1633+
item_var=outer_listcomp.item_var,
1634+
array_info={
1635+
"type": "nested_listcomp",
1636+
"inner_listcomp": left_inner_listcomp,
1637+
},
1638+
condition=outer_listcomp.condition,
1639+
)
1640+
1641+
right_outer_listcomp = ListComprehensionExpression(
1642+
expression=outer_listcomp.expression,
1643+
item_var=outer_listcomp.item_var,
1644+
array_info={
1645+
"type": "nested_listcomp",
1646+
"inner_listcomp": right_inner_listcomp,
1647+
},
1648+
condition=outer_listcomp.condition,
1649+
)
1650+
1651+
# Build left and right queries
1652+
left_select_expr = SelectExpression(expression=left_outer_listcomp)
1653+
left_query = SelectQuery(
1654+
base_table=self.table_name,
1655+
select_expressions=[left_select_expr],
1656+
where_condition=query.where_condition,
1657+
joins=query.joins,
1658+
order_by=query.order_by,
1659+
array_expansion=None,
1660+
limit=query.limit,
1661+
offset=query.offset,
1662+
)
1663+
1664+
right_select_expr = SelectExpression(expression=right_outer_listcomp)
1665+
right_query = SelectQuery(
1666+
base_table=self.table_name,
1667+
select_expressions=[right_select_expr],
1668+
where_condition=query.where_condition,
1669+
joins=query.joins,
1670+
order_by=query.order_by,
1671+
array_expansion=None,
1672+
limit=query.limit,
1673+
offset=query.offset,
1674+
)
1675+
1676+
return left_query, right_query
1677+
14561678
def normalize_expression(self, expr: Expression) -> Optional[Expression]:
14571679
"""
14581680
Normalize an expression tree by flattening nested AND/OR and optimizing comparisons.

gramps/plugins/db/dbapi/query_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,9 @@ class ListComprehensionExpression(Expression):
164164
item_var: str # Iteration variable name
165165
array_info: dict # {'type': 'single'/'concatenated', 'path': str, etc.}
166166
condition: Optional[Expression] = None # Optional filter condition
167+
item_type: Optional[type] = (
168+
None # Type of items being iterated (e.g., Name, Surname)
169+
)
167170

168171
def __repr__(self):
169172
return f"ListComp({self.item_var} in {self.array_info})"

gramps/plugins/db/dbapi/query_parser.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -522,6 +522,9 @@ def _parse_listcomp(self, node: ast.ListComp) -> ListComprehensionExpression:
522522
# Extract array info
523523
array_info = self._extract_array_info(iter_node)
524524

525+
# Infer the item type from the array being iterated
526+
item_type = self._infer_list_item_type(array_info)
527+
525528
# Extract expression with item_var context
526529
# Temporarily set item_var so that the expression can reference it
527530
old_item_var = self.item_var
@@ -544,6 +547,7 @@ def _parse_listcomp(self, node: ast.ListComp) -> ListComprehensionExpression:
544547
item_var=item_var,
545548
array_info=array_info,
546549
condition=condition,
550+
item_type=item_type,
547551
)
548552

549553
def _extract_array_info(self, iter_node: ast.AST) -> dict:
@@ -603,8 +607,83 @@ def _extract_array_info(self, iter_node: ast.AST) -> dict:
603607
"outer_item_var": self.item_var,
604608
}
605609

610+
# Check for nested list comprehension
611+
if isinstance(iter_node, ast.ListComp):
612+
# Parse the inner list comprehension recursively
613+
inner_listcomp = self._parse_listcomp(iter_node)
614+
return {
615+
"type": "nested_listcomp",
616+
"inner_listcomp": inner_listcomp,
617+
}
618+
606619
raise ValueError(f"Could not extract array info from: {iter_node}")
607620

621+
def _infer_list_item_type(self, array_info: dict) -> Optional[type]:
622+
"""
623+
Infer the type of items in a list comprehension from the array info.
624+
625+
Args:
626+
array_info: Dictionary containing array information
627+
628+
Returns:
629+
Type of items, or None if cannot be determined
630+
"""
631+
array_type = array_info.get("type")
632+
633+
if array_type == "single":
634+
# Simple array like person.alternate_names
635+
array_path = array_info.get("path", "")
636+
return self.type_inference.infer_array_item_type(
637+
self.table_name, array_path
638+
)
639+
640+
elif array_type == "concatenated":
641+
# Concatenated array like [person.primary_name] + person.alternate_names
642+
# Both sides should have the same item type
643+
right_path = array_info.get("right_path", "")
644+
return self.type_inference.infer_array_item_type(
645+
self.table_name, right_path
646+
)
647+
648+
elif array_type == "nested":
649+
# Nested iteration like "for surname in name.surname_list"
650+
# We need to look up the type from the outer item variable's type
651+
# For now, we don't have the outer item's type stored, so return None
652+
# This could be enhanced later
653+
return None
654+
655+
elif array_type == "nested_listcomp":
656+
# Nested list comprehension - the item type is the type of the inner comprehension's expression
657+
# For example: [name.surname_list for name in ...] returns List[Surname]
658+
# So the outer iteration "for s in [name.surname_list ...]" iterates over Surname items
659+
inner_listcomp = array_info.get("inner_listcomp")
660+
if (
661+
inner_listcomp
662+
and inner_listcomp.expression
663+
and inner_listcomp.item_type
664+
):
665+
# Push the inner item type onto the stack so we can infer the expression type
666+
self.type_inference._listcomp_item_type_stack.append(
667+
inner_listcomp.item_type
668+
)
669+
try:
670+
# Infer the type of the inner expression (e.g., name.surname_list)
671+
expr_type = self.type_inference.visit(inner_listcomp.expression)
672+
if expr_type:
673+
# If the expression is a list, extract the item type
674+
item_type = self.type_inference._extract_list_item_type(
675+
expr_type
676+
)
677+
if item_type:
678+
return item_type
679+
# If it's not a list type, return the expression type itself
680+
return expr_type
681+
finally:
682+
# Pop the item type from the stack
683+
self.type_inference._listcomp_item_type_stack.pop()
684+
685+
return None
686+
608687
def detect_array_expansion(self, expr_str: str) -> Optional[ArrayExpansion]:
609688
"""Detect array expansion pattern in expression."""
610689
try:

0 commit comments

Comments
 (0)