Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 17 additions & 13 deletions src/routers/user_router.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import datetime
from typing import List

from fastapi import APIRouter, Depends
Expand Down Expand Up @@ -76,16 +75,9 @@ def get_versioned_resources_for_user(
for r in versioned_routers.get(version, [])
}

def sort_function(asset):
value = getattr(asset.aiod_entry, sorting.sort)
direction = -1 if sorting.direction == SortDirection.DESC else 1
return direction * datetime.datetime.timestamp(value)

# Assets are already sorted by _get_resources_for_user, just transform them
return {
asset_name: sorted(
(orm_to_read[asset_name](asset) for asset in assets),
key=sort_function,
)
asset_name: [orm_to_read[asset_name](asset) for asset in assets]
for asset_name, assets in resources.items()
}

Expand Down Expand Up @@ -117,20 +109,32 @@ def _get_resources_for_user(
if limit:
stmt = stmt.limit(limit)
entries = session.scalars(stmt).all()
assets_to_fetch = [entry.identifier for entry in entries]

# Create a mapping of entry_identifier -> sort_order to preserve database ordering
entry_order = {entry.identifier: idx for idx, entry in enumerate(entries)}
assets_to_fetch = list(entry_order.keys())

# We have AIoD entries, but want their respective asset information (e.g. publication).
# We lack the information about what the type of the asset is, so unfortunately we
# have to check all tables:
asset_types = list(non_abstract_subclasses(AIoDConcept))
found_assets: dict[str, list[AIoDConcept]] = {type_.__tablename__: [] for type_ in asset_types}

for asset_type in asset_types:
query = (
select(asset_type)
.where(asset_type.aiod_entry_identifier.in_(assets_to_fetch))
.where(asset_type.date_deleted.is_(None))
)
assets = session.scalars(query).all()
found_assets[asset_type.__tablename__] = list(assets)

# Sort assets by the original database order before adding to results
sorted_assets = sorted(
assets, key=lambda asset: entry_order.get(asset.aiod_entry_identifier, float("inf"))
)
found_assets[asset_type.__tablename__] = sorted_assets

if sum(map(len, found_assets.values())) == len(assets_to_fetch):
break
return found_assets # minor optimization since queries may be expensive

return found_assets