Skip to content
80 changes: 77 additions & 3 deletions chord_metadata_service/discovery/api_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from bento_lib.responses import errors
from collections import defaultdict
from django.core.exceptions import FieldError, ValidationError
from django.db.models import QuerySet
from django.db.models import QuerySet, Count
from drf_spectacular.utils import extend_schema, inline_serializer
from functools import partial, wraps
from operator import is_not
Expand All @@ -19,6 +19,7 @@
from structlog.stdlib import BoundLogger
from typing import Any, Awaitable, Callable, Literal, overload

from chord_metadata_service.authz.helpers import get_data_type_query_permissions
from chord_metadata_service.authz.middleware import authz_middleware
from chord_metadata_service.authz.permissions import BentoAllowAny, BentoDeferToHandler
from chord_metadata_service.authz.types import DataPermissions, DataTypeDiscoveryPermissions
Expand All @@ -30,8 +31,13 @@
from chord_metadata_service.utils import build_id_set

from . import responses as dres
from .censorship import get_rules, thresholded_count, censor_entity_counts
from .constants import DISCOVERY_ENTITIES
from .censorship import (
get_rules, thresholded_count,
censor_entity_counts,
censor_entity_counts_by_dataset,
aggregate_counts_from_censored_by_dataset,
)
from .constants import DISCOVERY_ENTITIES, ENTITY_TO_DATASET_GROUP_BY
from .exceptions import DiscoveryScopeException
from .fields import get_field_options, get_range_stats, get_categorical_stats, get_date_stats
from .fields_utils import normalize_field_path_true_model
Expand Down Expand Up @@ -350,6 +356,40 @@ async def _get_entity_count(ee: DiscoveryEntity) -> int:
}


async def discovery_queryset_entity_counts_by_dataset(
    qqs: QueryQuerysetsCache,
) -> dict[str, EntityCounts]:
    """
    Computes, for the given scope/query context, the count of each discovery entity grouped by dataset
    identifier, returning a mapping of {dataset_id: {entity: count}}.
    """

    async def _counts_for_entity(entity: DiscoveryEntity) -> dict[str, int]:
        # Build the scoped queryset for this entity, then aggregate distinct IDs per dataset in the database.
        entity_qs, _ = await qqs.get_query_queryset_and_queried_entities(entity, validate_field=False)
        dataset_path = ENTITY_TO_DATASET_GROUP_BY[entity]
        rows = await sync_to_async(list)(
            entity_qs.values(dataset_path).annotate(count=Count("id", distinct=True))
        )
        # Rows whose dataset identifier is null are skipped (entity not linked to any dataset).
        return {str(row[dataset_path]): row["count"] for row in rows if row[dataset_path] is not None}

    entities = tuple(DISCOVERY_ENTITIES)
    per_entity_counts = await asyncio.gather(*map(_counts_for_entity, entities))

    # Union of every dataset identifier seen for any entity type.
    dataset_ids: set[str] = set().union(*(pec.keys() for pec in per_entity_counts))

    # Invert the mapping: dataset -> {entity -> count}, defaulting to 0 where an entity had no rows.
    return {
        ds_id: {entity: counts.get(ds_id, 0) for entity, counts in zip(entities, per_entity_counts)}
        for ds_id in dataset_ids
    }


@overload
async def get_censored_entity_counts(
scope: ValidatedDiscoveryScope,
Expand Down Expand Up @@ -491,6 +531,39 @@ async def discovery_endpoint(
scope, dt_permissions, lg=lg, query=query, return_raw_counts=True
)

# -- Per-dataset counts (permissions-dependent) -------------------------------------------------------------------

counts_by_dataset_res: dict[str, EntityCountOrBoolResponse] = {}

ds_level_permissions = await get_data_type_query_permissions(
request,
list(DISCOVERY_ENTITY_NAMES_TO_DATA_TYPE.values()),
dataset_level=True,
)
has_ds_level_counts_permission = any(p.counts for p in ds_level_permissions.values())

if has_ds_level_counts_permission:
# Raw per-dataset counts: dataset_id -> {entity -> count}
counts_by_dataset_raw: dict[str, EntityCounts] = await discovery_queryset_entity_counts_by_dataset(
qqs
)

# Censor per-dataset counts using the same logic as for top-level counts
counts_by_dataset_res = await censor_entity_counts_by_dataset(
scope,
counts_by_dataset_raw,
dt_permissions,
lg,
)

# When we expose per-dataset counts, recompute the top-level counts from the
# censored per-dataset view to avoid residual disclosure (total - sum(others)).
if counts_by_dataset_res:
count_or_bools_res = aggregate_counts_from_censored_by_dataset(
counts_by_dataset_res,
count_or_bools_res,
)

if (
not count_or_bools_res[queryset_entity]
and not dt_permissions[DISCOVERY_ENTITY_NAMES_TO_DATA_TYPE[queryset_entity]].data
Expand All @@ -512,6 +585,7 @@ async def discovery_endpoint(
message=message,
# permissions-dependent: dictionary of {entity: counts or True if above threshold, 0/False otherwise}:
counts=count_or_bools_res,
counts_by_dataset=counts_by_dataset_res,
)
)

Expand Down
52 changes: 52 additions & 0 deletions chord_metadata_service/discovery/censorship.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
DiscoveryConfigRules,
RULES_NO_PERMISSIONS,
RULES_FULL_PERMISSIONS,
DiscoveryEntity,
)
from structlog.stdlib import BoundLogger
from typing import TypeAlias
Expand All @@ -21,6 +22,8 @@
"get_max_query_parameters",
"get_rules",
"censor_entity_counts",
"censor_entity_counts_by_dataset",
"aggregate_counts_from_censored_by_dataset",
]

# If only we had interfaces...
Expand Down Expand Up @@ -121,3 +124,52 @@ async def censor_entity_counts(
count_or_bools_res[ee] = 0 if ee_perms.counts else False

return count_or_bools_res


async def censor_entity_counts_by_dataset(
    scope: ValidatedDiscoveryScope,
    counts_by_dataset: dict[str, EntityCounts],
    dt_permissions: DataTypeDiscoveryPermissions,
    lg: BoundLogger,
) -> dict[str, EntityCountOrBoolResponse]:
    """
    Applies the standard entity-count censorship rules (censor_entity_counts) to each dataset's counts
    independently. Datasets whose censored result comes back empty are omitted from the returned mapping.
    """
    censored_by_dataset: dict[str, EntityCountOrBoolResponse] = {}

    for ds_id, raw_counts in counts_by_dataset.items():
        censored_counts = await censor_entity_counts(scope, raw_counts, dt_permissions, lg)
        if not censored_counts:
            # Nothing visible for this dataset under the current permissions/rules; drop it entirely.
            continue
        censored_by_dataset[ds_id] = censored_counts

    return censored_by_dataset


def aggregate_counts_from_censored_by_dataset(
    counts_by_dataset: dict[str, EntityCountOrBoolResponse],
    base_counts: EntityCountOrBoolResponse,
) -> EntityCountOrBoolResponse:
    """
    Recomputes top-level entity counts from the already-censored per-dataset view, so the totals cannot be
    combined with the per-dataset values to reveal censored information (e.g. total - sum(others)).

    Entities that appear in no dataset mapping keep their value from base_counts unchanged.
    """
    result: EntityCountOrBoolResponse = dict(base_counts)

    # Group the censored per-dataset values by entity, for each entity that appears in any dataset.
    per_entity_values: dict[DiscoveryEntity, list[int | bool]] = {}
    for dataset_counts in counts_by_dataset.values():
        for entity, value in dataset_counts.items():
            per_entity_values.setdefault(entity, []).append(value)

    for entity, values in per_entity_values.items():
        if any(isinstance(v, bool) for v in values):
            # A boolean means "above/below threshold"; combining with OR preserves that semantics.
            result[entity] = any(values)
        else:
            # All numeric values: the aggregate is their plain sum.
            result[entity] = sum(values)

    return result
8 changes: 8 additions & 0 deletions chord_metadata_service/discovery/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,11 @@
"experiment": ("experiment_result",),
"experiment_result": (),
}

# Django ORM lookup path (usable as a values()/group-by key) from each discovery entity's model to the
# identifier of the dataset it belongs to. Entities not directly attached to a dataset are traversed
# through their phenopacket relation chain, as spelled out by each lookup string below.
ENTITY_TO_DATASET_GROUP_BY = {
    "phenopacket": "dataset_id__identifier",
    "individual": "phenopackets__dataset_id__identifier",
    "biosample": "individual__phenopackets__dataset_id__identifier",
    "experiment": "biosample__individual__phenopackets__dataset_id__identifier",
    "experiment_result": "experiments__biosample__individual__phenopackets__dataset_id__identifier",
}
1 change: 1 addition & 0 deletions chord_metadata_service/discovery/pydantic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ class DiscoveryResponse(BaseModel):
dict[str, EntityCountOrBoolResponse] |
dict[str, dict[str, EntityCountOrBoolResponse]]
)
counts_by_dataset: dict[str, EntityCountOrBoolResponse] = Field(default_factory=dict)


class DiscoveryMatchesPaginatedResponse(BaseModel):
Expand Down
Loading