Skip to content

Commit 11daa72

Browse files
committed
utils
1 parent 8b1ccf6 commit 11daa72

1 file changed

Lines changed: 240 additions & 0 deletions

File tree

papers2code_app2/utils/__init__.py

Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
import logging
2+
from typing import Optional, Any, Dict, List
3+
from bson import ObjectId
4+
from bson.errors import InvalidId
5+
from pymongo import DESCENDING
6+
7+
# Assuming database.py is in the same directory
8+
from ..database import get_user_actions_collection_async # Changed from sync
9+
# Assuming shared.py is in the same directory and contains IMPL_STATUS_VOTING
10+
from ..shared import (
11+
IMPL_STATUS_VOTING,
12+
IMPL_STATUS_COMMUNITY_IMPLEMENTABLE,
13+
IMPL_STATUS_COMMUNITY_NOT_IMPLEMENTABLE,
14+
IMPL_STATUS_ADMIN_IMPLEMENTABLE,
15+
IMPL_STATUS_ADMIN_NOT_IMPLEMENTABLE
16+
)
17+
from ..cache import paper_cache
18+
19+
logger = logging.getLogger(__name__)
20+
21+
# --- Helper functions for transform_paper_async (previously _sync) ---
22+
# Helper function to transform author data
23+
def _transform_authors(authors_data: Any) -> List[str]:
24+
if authors_data and isinstance(authors_data, list):
25+
if all(isinstance(author, dict) for author in authors_data):
26+
return [author.get("name") for author in authors_data if author.get("name")]
27+
elif all(isinstance(author, str) for author in authors_data):
28+
return authors_data
29+
return []
30+
31+
# Helper function to transform URL strings
32+
def _transform_url(url_value: Any) -> Optional[str]:
33+
if url_value and isinstance(url_value, str) and url_value.strip():
34+
return str(url_value)
35+
return None
36+
37+
# Async helper function to get user-specific paper data (OPTIMIZED)
38+
async def _get_user_specific_paper_data_async(paper_obj_id: ObjectId, current_user_id_str: Optional[str]) -> Dict[str, Any]:
39+
user_data = {
40+
"current_user_vote": None,
41+
"current_user_implementability_vote": None,
42+
}
43+
if not current_user_id_str:
44+
return user_data
45+
46+
user_actions_collection = await get_user_actions_collection_async()
47+
48+
try:
49+
user_obj_id = ObjectId(current_user_id_str)
50+
51+
# OPTIMIZATION: Single query instead of multiple count queries
52+
from ..schemas.user_activity import LoggedActionTypes
53+
54+
# Get all user actions for this paper in one query
55+
user_actions = await user_actions_collection.find(
56+
{"userId": user_obj_id, "paperId": paper_obj_id},
57+
{"actionType": 1, "timestamp": 1}
58+
).sort([("timestamp", DESCENDING)]).to_list(length=None)
59+
60+
# Process actions
61+
has_upvote = False
62+
latest_implementability_action = None
63+
64+
implementability_action_types_map = {
65+
LoggedActionTypes.ADMIN_IMPLEMENTABLE.value: "up",
66+
LoggedActionTypes.COMMUNITY_IMPLEMENTABLE.value: "up",
67+
LoggedActionTypes.COMMUNITY_NOT_IMPLEMENTABLE.value: "down",
68+
LoggedActionTypes.ADMIN_NOT_IMPLEMENTABLE.value: "down",
69+
}
70+
71+
for action in user_actions:
72+
action_type = action.get("actionType")
73+
74+
# Check for upvote
75+
if action_type == LoggedActionTypes.UPVOTE.value:
76+
has_upvote = True
77+
78+
# Check for implementability vote (get the latest one)
79+
if action_type in implementability_action_types_map and latest_implementability_action is None:
80+
latest_implementability_action = action_type
81+
82+
if has_upvote:
83+
user_data["current_user_vote"] = "up"
84+
85+
if latest_implementability_action:
86+
user_data["current_user_implementability_vote"] = implementability_action_types_map[latest_implementability_action]
87+
88+
except InvalidId:
89+
logger.warning(f"Invalid ObjectId for current_user_id_str: {current_user_id_str} when fetching user-specific paper data. Skipping.")
90+
pass
91+
except Exception as e:
92+
logger.error(f"Error fetching user-specific data for paper {paper_obj_id} and user {current_user_id_str}: {e}", exc_info=True)
93+
return user_data
94+
95+
# Async helper function to get aggregate vote counts (OPTIMIZED)
96+
async def _get_aggregate_vote_counts_async(paper_obj_id: ObjectId) -> Dict[str, int]:
97+
from ..schemas.user_activity import LoggedActionTypes
98+
counts = {
99+
"not_implementable_votes": 0,
100+
"implementable_votes": 0,
101+
}
102+
user_actions_collection = await get_user_actions_collection_async()
103+
try:
104+
# OPTIMIZATION: Single aggregation query instead of multiple count queries
105+
pipeline = [
106+
{"$match": {"paperId": paper_obj_id}},
107+
{"$group": {
108+
"_id": "$actionType",
109+
"count": {"$sum": 1}
110+
}}
111+
]
112+
113+
cursor = await user_actions_collection.aggregate(pipeline)
114+
results = [doc async for doc in cursor]
115+
116+
not_implementable_action_types = {
117+
LoggedActionTypes.COMMUNITY_NOT_IMPLEMENTABLE.value,
118+
LoggedActionTypes.ADMIN_NOT_IMPLEMENTABLE.value
119+
}
120+
121+
implementable_action_types = {
122+
LoggedActionTypes.COMMUNITY_IMPLEMENTABLE.value,
123+
LoggedActionTypes.ADMIN_IMPLEMENTABLE.value
124+
}
125+
126+
for result in results:
127+
action_type = result["_id"]
128+
count = result["count"]
129+
130+
if action_type in not_implementable_action_types:
131+
counts["not_implementable_votes"] += count
132+
elif action_type in implementable_action_types:
133+
counts["implementable_votes"] += count
134+
135+
except Exception as e:
136+
logger.error(f"Error fetching aggregate vote counts for paper {paper_obj_id}: {e}", exc_info=True)
137+
return counts
138+
139+
# --- Main Asynchronous Transformation Function ---
140+
async def transform_paper_async(
141+
paper_doc: Dict[str, Any],
142+
current_user_id_str: Optional[str] = None,
143+
detail_level: str = "full"
144+
) -> Optional[Dict[str, Any]]:
145+
if not paper_doc:
146+
return None
147+
148+
paper_id = str(paper_doc["_id"]) if "_id" in paper_doc else None
149+
if not paper_id:
150+
logger.warning("Paper document missing _id field.")
151+
return None
152+
153+
paper_obj_id_val = paper_doc.get("_id")
154+
if not isinstance(paper_obj_id_val, ObjectId):
155+
try:
156+
paper_obj_id = ObjectId(paper_obj_id_val)
157+
except InvalidId:
158+
logger.error(f"Invalid ObjectId format for paper_doc._id: {paper_obj_id_val}")
159+
return None
160+
else:
161+
paper_obj_id = paper_obj_id_val
162+
163+
# Synchronous parts of transformation remain the same
164+
transformed_data = {
165+
"id": paper_id,
166+
"title": paper_doc.get("title"),
167+
"authors": _transform_authors(paper_doc.get("authors", [])),
168+
"publication_date": paper_doc.get("publicationDate"),
169+
"upvote_count": paper_doc.get("upvoteCount", 0),
170+
"status": paper_doc.get("status", "Not Started"),
171+
"url_github": _transform_url(paper_doc.get("urlGithub")), # Include in base for all detail levels
172+
"url_abs": _transform_url(paper_doc.get("urlAbs")), # Include in base for paper list icons
173+
"url_pdf": _transform_url(paper_doc.get("urlPdf")), # Include in base for paper list icons
174+
"has_code": paper_doc.get("hasCode", False), # Include in base for all detail levels
175+
}
176+
177+
# Preserve implementationProgress if it exists in the input paper_doc
178+
# This key is added in paper_view_service.py before calling this transform function
179+
if "implementationProgress" in paper_doc and paper_doc["implementationProgress"] is not None:
180+
transformed_data["implementationProgress"] = paper_doc["implementationProgress"]
181+
182+
# Add fields for summary level (used in dashboard and profile)
183+
if detail_level in ["summary", "full"]:
184+
raw_implementability_status = paper_doc.get("implementabilityStatus")
185+
current_implementability_status = IMPL_STATUS_VOTING
186+
if raw_implementability_status:
187+
if raw_implementability_status.lower() == "voting":
188+
current_implementability_status = IMPL_STATUS_VOTING
189+
elif raw_implementability_status in [
190+
IMPL_STATUS_COMMUNITY_IMPLEMENTABLE,
191+
IMPL_STATUS_COMMUNITY_NOT_IMPLEMENTABLE,
192+
IMPL_STATUS_ADMIN_IMPLEMENTABLE,
193+
IMPL_STATUS_ADMIN_NOT_IMPLEMENTABLE
194+
]:
195+
current_implementability_status = raw_implementability_status
196+
197+
# Truncate abstract for summary level to reduce payload size
198+
abstract = paper_doc.get("abstract", "")
199+
if detail_level == "summary" and abstract:
200+
abstract = abstract[:300] if len(abstract) > 300 else abstract
201+
202+
transformed_data.update({
203+
"abstract": abstract,
204+
"venue": paper_doc.get("venue"), # Also known as "proceeding"
205+
"tags": paper_doc.get("tasks", []), # DB field is "tasks", Pydantic field "tags" has alias "tasks"
206+
"implementability_status": current_implementability_status,
207+
})
208+
209+
# Add additional fields only for full detail level
210+
if detail_level == "full":
211+
transformed_data.update({
212+
"pwc_url": _transform_url(paper_doc.get("pwcUrl")),
213+
"arxiv_id": paper_doc.get("arxivId"),
214+
})
215+
216+
try:
217+
# Asynchronous calls to helper functions
218+
# For summary level, we skip user-specific data to improve performance
219+
if detail_level != "summary":
220+
user_specific_data = await _get_user_specific_paper_data_async(paper_obj_id, current_user_id_str)
221+
transformed_data.update(user_specific_data)
222+
223+
if detail_level == "full":
224+
aggregate_votes = await _get_aggregate_vote_counts_async(paper_obj_id)
225+
transformed_data.update(aggregate_votes)
226+
227+
except Exception as e:
228+
logger.error(f"Error during async data transformation for paper {paper_id}: {e}", exc_info=True)
229+
return None
230+
231+
return transformed_data
232+
233+
234+
__all__ = [
235+
'transform_paper_async',
236+
'_transform_authors',
237+
'_transform_url',
238+
'_get_user_specific_paper_data_async',
239+
'_get_aggregate_vote_counts_async',
240+
]

0 commit comments

Comments
 (0)