Skip to content

Commit 52945f1

Browse files
authored
Exclude prompts of disabled users from prompt lottery (#1748)
Co-authored-by: --show <--show>
1 parent 7f8163b commit 52945f1

2 files changed

Lines changed: 35 additions & 1 deletion

File tree

backend/main.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from oasst_backend.models import message_tree_state
2222
from oasst_backend.prompt_repository import PromptRepository, UserRepository
2323
from oasst_backend.task_repository import TaskRepository, delete_expired_tasks
24-
from oasst_backend.tree_manager import TreeManager
24+
from oasst_backend.tree_manager import TreeManager, halt_prompts_of_disabled_users
2525
from oasst_backend.user_repository import User
2626
from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame
2727
from oasst_backend.utils.database_utils import CommitMode, managed_tx_function
@@ -333,6 +333,7 @@ def update_user_streak(session: Session) -> None:
333333
@managed_tx_function(auto_commit=CommitMode.COMMIT)
334334
def cronjob_delete_expired_tasks(session: Session) -> None:
335335
delete_expired_tasks(session)
336+
halt_prompts_of_disabled_users(session)
336337

337338

338339
@app.on_event("startup")

backend/oasst_backend/tree_manager.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,33 @@ class TreeManagerStats(pydantic.BaseModel):
120120
message_counts: list[TreeMessageCountStats]
121121

122122

123+
def halt_prompts_of_disabled_users(db: Session):
124+
_sql_halt_prompts_of_disabled_users = """
125+
-- remove prompts of disabled & deleted users from prompt lottery
126+
WITH cte AS (
127+
SELECT mts.message_tree_id
128+
FROM message_tree_state mts
129+
JOIN message m ON mts.message_tree_id = m.id
130+
JOIN "user" u ON m.user_id = u.id
131+
WHERE state = :prompt_lottery_waiting_state AND (NOT u.enabled OR u.deleted)
132+
)
133+
UPDATE message_tree_state mts2
134+
SET active=false, state=:halted_by_moderator_state
135+
FROM cte
136+
WHERE mts2.message_tree_id = cte.message_tree_id;
137+
"""
138+
139+
r = db.execute(
140+
text(_sql_halt_prompts_of_disabled_users),
141+
{
142+
"prompt_lottery_waiting_state": message_tree_state.State.PROMPT_LOTTERY_WAITING,
143+
"halted_by_moderator_state": message_tree_state.State.HALTED_BY_MODERATOR,
144+
},
145+
)
146+
if r.rowcount > 0:
147+
logger.info(f"Halted {r.rowcount} prompts of disabled users.")
148+
149+
123150
class TreeManager:
124151
def __init__(
125152
self,
@@ -240,16 +267,20 @@ def _prompt_lottery(self, lang: str, max_activate: int = 1) -> int:
240267

241268
@managed_tx_function(CommitMode.COMMIT)
242269
def activate_one(db: Session) -> int:
270+
243271
# select among distinct users
244272
authors_qry = (
245273
db.query(Message.user_id)
246274
.select_from(MessageTreeState)
247275
.join(Message, MessageTreeState.message_tree_id == Message.id)
276+
.join(User, Message.user_id == User.id)
248277
.filter(
249278
MessageTreeState.state == message_tree_state.State.PROMPT_LOTTERY_WAITING,
250279
Message.lang == lang,
251280
not_(Message.deleted),
252281
Message.review_result,
282+
User.enabled,
283+
not_(User.deleted),
253284
)
254285
.distinct(Message.user_id)
255286
)
@@ -1309,6 +1340,8 @@ def ensure_tree_states(self) -> None:
13091340
logger.info(f"Inserting missing message tree state for message: {id} ({tree_size=}, {state=:s})")
13101341
self._insert_default_state(id, state=state)
13111342

1343+
halt_prompts_of_disabled_users(self.db)
1344+
13121345
# check tree state transitions (maybe variables haves changes): prompt review -> growing -> ranking -> scoring
13131346
prompt_review_trees: list[MessageTreeState] = (
13141347
self.db.query(MessageTreeState)

0 commit comments

Comments
 (0)