
Commit da7d608

move index query to models, tests
1 parent 56c0680 commit da7d608

File tree

3 files changed: +226 -60 lines changed

poliloom/poliloom/cli.py
poliloom/poliloom/models/wikidata.py
poliloom/tests/models/test_wikidata.py

poliloom/poliloom/cli.py

Lines changed: 7 additions & 60 deletions
@@ -14,7 +14,7 @@
 from poliloom.database import get_engine
 from poliloom.logging import setup_logging
 from sqlalchemy.orm import Session
-from sqlalchemy import exists, text
+from sqlalchemy import exists, func, select
 from poliloom.models import (
     Country,
     CurrentImportEntity,
@@ -899,24 +899,6 @@ def clean_properties(dry_run):
         raise SystemExit(1)
 
 
-def _get_search_indexed_models():
-    """Get all models that use WikidataEntityMixin (which provides search indexing)."""
-    from poliloom.models.wikidata import WikidataEntityMixin
-    import poliloom.models as models_module
-
-    search_indexed_models = []
-    for name in dir(models_module):
-        obj = getattr(models_module, name)
-        if (
-            isinstance(obj, type)
-            and issubclass(obj, WikidataEntityMixin)
-            and obj is not WikidataEntityMixin
-            and hasattr(obj, "__tablename__")
-        ):
-            search_indexed_models.append(obj)
-    return search_indexed_models
-
-
 @main.command("index-create")
 def index_create():
     """Create the Meilisearch entities index.
@@ -1001,48 +983,16 @@ def index_build(batch_size, rebuild):
     if search_service.ensure_index():
         click.echo(f" Created index '{INDEX_NAME}'")
 
-    # Get all searchable models
-    models = _get_search_indexed_models()
-    click.echo(f" Types: {', '.join(m.__name__ for m in models)}")
-
-    # Build dynamic SQL query
-    # LEFT JOIN each model table and build types array from which ones match
-    left_joins = []
-    case_statements = []
-    group_by_columns = ["we.wikidata_id"]
-
-    for model in models:
-        table_name = model.__tablename__
-        left_joins.append(
-            f"LEFT JOIN {table_name} ON we.wikidata_id = {table_name}.wikidata_id"
-        )
-        case_statements.append(
-            f"CASE WHEN {table_name}.wikidata_id IS NOT NULL THEN '{model.__name__}' END"
-        )
-        group_by_columns.append(f"{table_name}.wikidata_id")
-
-    array_expr = f"array_remove(ARRAY[{', '.join(case_statements)}], NULL)"
-
-    base_sql = f"""
-        SELECT
-            we.wikidata_id,
-            array_agg(DISTINCT wel.label) as labels,
-            {array_expr} as types
-        FROM wikidata_entities we
-        JOIN wikidata_entity_labels wel ON we.wikidata_id = wel.entity_id
-        {chr(10).join(left_joins)}
-        WHERE we.deleted_at IS NULL
-        GROUP BY {", ".join(group_by_columns)}
-        HAVING array_length({array_expr}, 1) > 0
-    """
+    # Build query for search index documents
+    query = WikidataEntity.search_index_query()
 
     total_indexed = 0
     task_uids = []
 
     with Session(get_engine()) as session:
         # Count total
-        count_result = session.execute(text(f"SELECT COUNT(*) FROM ({base_sql}) subq"))
-        total = count_result.scalar()
+        count_query = select(func.count()).select_from(query.subquery())
+        total = session.execute(count_query).scalar()
 
         if total == 0:
             click.echo(" No entities to index")
@@ -1053,11 +1003,8 @@ def index_build(batch_size, rebuild):
         # Process in batches
         offset_val = 0
         while offset_val < total:
-            paginated_sql = f"{base_sql} OFFSET :offset LIMIT :limit"
-            rows = session.execute(
-                text(paginated_sql),
-                {"offset": offset_val, "limit": batch_size},
-            ).fetchall()
+            paginated_query = query.offset(offset_val).limit(batch_size)
+            rows = session.execute(paginated_query).fetchall()
 
             if not rows:
                 break
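
Semantically the two formulations should agree: the old raw SQL LEFT JOINed every model table and filtered with HAVING array_length(...) > 0, while search_index_query() UNIONs the model tables and inner-joins back to wikidata_entities, so only entities present in at least one model table survive and no HAVING clause is needed. To eyeball the SQL the new query compiles to during review, one option (a local debugging sketch, not part of the commit):

    from sqlalchemy.dialects import postgresql
    from poliloom.models import WikidataEntity

    query = WikidataEntity.search_index_query()
    print(query.compile(dialect=postgresql.dialect(),
                        compile_kwargs={"literal_binds": True}))
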

poliloom/poliloom/models/wikidata.py

Lines changed: 47 additions & 0 deletions
@@ -19,10 +19,12 @@
     delete,
     exists,
     func,
+    literal,
     literal_column,
     or_,
     select,
     text,
+    union_all,
     update,
 )
 from sqlalchemy.dialects.postgresql import UUID
@@ -623,6 +625,51 @@ def cleanup_orphaned(cls, session: Session) -> int:
 
         return result.rowcount
 
+    @classmethod
+    def search_index_query(cls):
+        """Build query for search index documents.
+
+        Creates a query that returns all searchable entities with their
+        aggregated types and labels. Only includes non-deleted entities.
+
+        Returns:
+            SQLAlchemy select query with columns: wikidata_id, types, labels
+        """
+        models = WikidataEntityMixin.__subclasses__()
+
+        # Build UNION of all model tables with their type names
+        entity_unions = union_all(
+            *[
+                select(
+                    model.wikidata_id.label("wikidata_id"),
+                    literal(model.__name__).label("type"),
+                )
+                for model in models
+            ]
+        ).subquery("entity_types")
+
+        # Main query: aggregate types and labels per entity
+        return (
+            select(
+                entity_unions.c.wikidata_id,
+                func.array_agg(func.distinct(entity_unions.c.type)).label("types"),
+                func.array_agg(func.distinct(WikidataEntityLabel.label)).label(
+                    "labels"
+                ),
+            )
+            .select_from(entity_unions)
+            .join(
+                WikidataEntityLabel,
+                entity_unions.c.wikidata_id == WikidataEntityLabel.entity_id,
+            )
+            .join(
+                cls,
+                entity_unions.c.wikidata_id == cls.wikidata_id,
+            )
+            .where(cls.deleted_at.is_(None))
+            .group_by(entity_unions.c.wikidata_id)
+        )
+
 
 class WikidataEntityLabel(Base, TimestampMixin, UpsertMixin):
     """Normalized label storage for wikidata entities."""

poliloom/tests/models/test_wikidata.py

Lines changed: 172 additions & 0 deletions
@@ -987,3 +987,175 @@ def test_dry_run_does_not_call_delete_documents(self, db_session):
 
         # Dry run reports what would be removed
         assert stats["entities_removed"] == 1
+
+
+class TestSearchIndexQuery:
+    """Test WikidataEntity.search_index_query functionality."""
+
+    def _create_entity_with_labels(self, db_session, wikidata_id, name, labels):
+        """Helper to create a WikidataEntity with labels."""
+        from poliloom.models import WikidataEntityLabel
+
+        stmt = insert(WikidataEntity).values(
+            [{"wikidata_id": wikidata_id, "name": name}]
+        )
+        stmt = stmt.on_conflict_do_nothing(index_elements=["wikidata_id"])
+        db_session.execute(stmt)
+
+        for label in labels:
+            stmt = insert(WikidataEntityLabel).values(
+                [{"entity_id": wikidata_id, "label": label}]
+            )
+            stmt = stmt.on_conflict_do_nothing(index_elements=["entity_id", "label"])
+            db_session.execute(stmt)
+
+        db_session.flush()
+
+    def _create_location(self, db_session, wikidata_id, name, labels):
+        """Helper to create a Location entity with labels."""
+        from poliloom.models import Location
+
+        self._create_entity_with_labels(db_session, wikidata_id, name, labels)
+
+        stmt = insert(Location.__table__).values([{"wikidata_id": wikidata_id}])
+        stmt = stmt.on_conflict_do_nothing(index_elements=["wikidata_id"])
+        db_session.execute(stmt)
+        db_session.flush()
+
+    def _create_country(self, db_session, wikidata_id, name, labels):
+        """Helper to create a Country entity with labels."""
+        from poliloom.models import Country
+
+        self._create_entity_with_labels(db_session, wikidata_id, name, labels)
+
+        stmt = insert(Country.__table__).values([{"wikidata_id": wikidata_id}])
+        stmt = stmt.on_conflict_do_nothing(index_elements=["wikidata_id"])
+        db_session.execute(stmt)
+        db_session.flush()
+
+    def _create_position(self, db_session, wikidata_id, name, labels):
+        """Helper to create a Position entity with labels."""
+        from poliloom.models import Position
+
+        self._create_entity_with_labels(db_session, wikidata_id, name, labels)
+
+        stmt = insert(Position.__table__).values([{"wikidata_id": wikidata_id}])
+        stmt = stmt.on_conflict_do_nothing(index_elements=["wikidata_id"])
+        db_session.execute(stmt)
+        db_session.flush()
+
+    def test_returns_entity_with_single_type(self, db_session):
+        """Test query returns entity with single type."""
+        self._create_location(db_session, "Q60", "New York City", ["New York", "NYC"])
+
+        query = WikidataEntity.search_index_query()
+        results = db_session.execute(query).fetchall()
+
+        assert len(results) == 1
+        row = results[0]
+        assert row.wikidata_id == "Q60"
+        assert "Location" in row.types
+        assert set(row.labels) == {"New York", "NYC"}
+
+    def test_returns_entity_with_multiple_types(self, db_session):
+        """Test query aggregates multiple types for same entity."""
+        from poliloom.models import Location, Country
+
+        # Germany is both a Location and a Country
+        self._create_entity_with_labels(
+            db_session, "Q183", "Germany", ["Germany", "Deutschland"]
+        )
+
+        # Add to both tables
+        stmt = insert(Location.__table__).values([{"wikidata_id": "Q183"}])
+        stmt = stmt.on_conflict_do_nothing(index_elements=["wikidata_id"])
+        db_session.execute(stmt)
+
+        stmt = insert(Country.__table__).values([{"wikidata_id": "Q183"}])
+        stmt = stmt.on_conflict_do_nothing(index_elements=["wikidata_id"])
+        db_session.execute(stmt)
+        db_session.flush()
+
+        query = WikidataEntity.search_index_query()
+        results = db_session.execute(query).fetchall()
+
+        assert len(results) == 1
+        row = results[0]
+        assert row.wikidata_id == "Q183"
+        assert "Location" in row.types
+        assert "Country" in row.types
+        assert set(row.labels) == {"Germany", "Deutschland"}
+
+    def test_returns_multiple_entities(self, db_session):
+        """Test query returns multiple entities."""
+        self._create_location(db_session, "Q60", "New York City", ["NYC"])
+        self._create_position(db_session, "Q30185", "Mayor", ["Mayor", "Bürgermeister"])
+
+        query = WikidataEntity.search_index_query()
+        results = db_session.execute(query).fetchall()
+
+        assert len(results) == 2
+        results_by_id = {r.wikidata_id: r for r in results}
+
+        assert "Location" in results_by_id["Q60"].types
+        assert set(results_by_id["Q60"].labels) == {"NYC"}
+
+        assert "Position" in results_by_id["Q30185"].types
+        assert set(results_by_id["Q30185"].labels) == {"Mayor", "Bürgermeister"}
+
+    def test_excludes_soft_deleted_entities(self, db_session):
+        """Test query excludes soft-deleted entities."""
+        from datetime import datetime, timezone
+
+        self._create_location(db_session, "Q60", "New York City", ["NYC"])
+        self._create_location(db_session, "Q84", "London", ["London"])
+
+        # Soft-delete London
+        db_session.execute(
+            WikidataEntity.__table__.update()
+            .where(WikidataEntity.wikidata_id == "Q84")
+            .values(deleted_at=datetime.now(timezone.utc))
+        )
+        db_session.flush()
+
+        query = WikidataEntity.search_index_query()
+        results = db_session.execute(query).fetchall()
+
+        assert len(results) == 1
+        assert results[0].wikidata_id == "Q60"
+
+    def test_excludes_entities_not_in_model_tables(self, db_session):
+        """Test query only returns entities that exist in model tables."""
+        # Create entity with labels but NOT in any model table
+        self._create_entity_with_labels(db_session, "Q999", "Orphan Entity", ["Orphan"])
+
+        # Create proper location
+        self._create_location(db_session, "Q60", "New York City", ["NYC"])
+
+        query = WikidataEntity.search_index_query()
+        results = db_session.execute(query).fetchall()
+
+        # Only the location should be returned
+        assert len(results) == 1
+        assert results[0].wikidata_id == "Q60"
+
+    def test_pagination_with_offset_and_limit(self, db_session):
+        """Test query supports pagination with offset and limit."""
+        # Create multiple locations
+        for i in range(5):
+            self._create_location(db_session, f"Q{i}", f"Location {i}", [f"Label {i}"])
+
+        query = WikidataEntity.search_index_query()
+
+        # Get first 2
+        results_page1 = db_session.execute(query.limit(2)).fetchall()
+        assert len(results_page1) == 2
+
+        # Get next 2
+        results_page2 = db_session.execute(query.offset(2).limit(2)).fetchall()
+        assert len(results_page2) == 2
+
+        # Verify no overlap
+        page1_ids = {r.wikidata_id for r in results_page1}
+        page2_ids = {r.wikidata_id for r in results_page2}
+        assert page1_ids.isdisjoint(page2_ids)
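
One assumption worth flagging for reviewers: the helpers call insert(...).on_conflict_do_nothing(...), which exists only on the PostgreSQL dialect's insert construct, so the module-level import (not shown in this hunk) is presumably:

    from sqlalchemy.dialects.postgresql import insert

which also means these tests require a PostgreSQL-backed db_session fixture.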
