@@ -214,7 +214,7 @@ def _determine_task_availability_internal(
214214 ) -> dict [protocol_schema .TaskRequestType , int ]:
215215 task_count_by_type : dict [protocol_schema .TaskRequestType , int ] = {t : 0 for t in protocol_schema .TaskRequestType }
216216
217- task_count_by_type [protocol_schema .TaskRequestType .initial_prompt ] = num_missing_prompts
217+ task_count_by_type [protocol_schema .TaskRequestType .initial_prompt ] = max ( 0 , num_missing_prompts )
218218
219219 task_count_by_type [protocol_schema .TaskRequestType .prompter_reply ] = len (
220220 list (filter (lambda x : x .parent_role == "assistant" , extendible_parents ))
@@ -256,13 +256,14 @@ def _prompt_lottery(self, lang: str, max_activate: int = 1) -> int:
256256
257257 while True :
258258 stats = self .tree_counts_by_state_stats (lang = lang , only_active = True )
259-
259+ prompt_lottery_waiting = self .query_prompt_lottery_waiting (lang = lang )
260+ remaining_lottery_entries = max (0 , self .cfg .max_prompt_lottery_waiting - prompt_lottery_waiting )
260261 remaining_prompt_review = max (0 , self .cfg .max_initial_prompt_review - stats .initial_prompt_review )
261262 num_missing_growing = max (0 , self .cfg .max_active_trees - stats .growing )
262263 logger .info (f"_prompt_lottery { remaining_prompt_review = } , { num_missing_growing = } " )
263264
264265 if num_missing_growing == 0 or activated >= max_activate :
265- return num_missing_growing + remaining_prompt_review
266+ return min ( num_missing_growing + remaining_prompt_review , remaining_lottery_entries )
266267
267268 @managed_tx_function (CommitMode .COMMIT )
268269 def activate_one (db : Session ) -> int :
@@ -330,7 +331,7 @@ def activate_one(db: Session) -> int:
330331 return True
331332
332333 if not activate_one ():
333- return num_missing_growing + remaining_prompt_review
334+ return min ( num_missing_growing + remaining_prompt_review , remaining_lottery_entries )
334335
335336 activated += 1
336337
@@ -714,8 +715,10 @@ async def handle_interaction(self, interaction: protocol_schema.AnyInteraction)
714715 )
715716
716717 if not message .parent_id :
717- logger .info (f"TreeManager: Inserting new tree state for initial prompt { message .id = } " )
718- self ._insert_default_state (message .id )
718+ logger .info (
719+ f"TreeManager: Inserting new tree state for initial prompt { message .id = } [{ message .lang } ]"
720+ )
721+ self ._insert_default_state (message .id , message .lang )
719722
720723 if not settings .DEBUG_SKIP_EMBEDDING_COMPUTATION :
721724 try :
@@ -1261,10 +1264,10 @@ def query_tree_size(self, message_tree_id: UUID) -> ActiveTreeSizeRow:
12611264
12621265 return ActiveTreeSizeRow .from_orm (qry .one ())
12631266
1264- def query_misssing_tree_states (self ) -> list [UUID ]:
1267+ def query_misssing_tree_states (self ) -> list [Tuple [ UUID , str ] ]:
12651268 """Find all initial prompt messages that have no associated message tree state"""
12661269 qry_missing_tree_states = (
1267- self .db .query (Message .id )
1270+ self .db .query (Message .id , Message . lang )
12681271 .outerjoin (MessageTreeState , Message .message_tree_id == MessageTreeState .message_tree_id )
12691272 .filter (
12701273 Message .parent_id .is_ (None ),
@@ -1273,7 +1276,7 @@ def query_misssing_tree_states(self) -> list[UUID]:
12731276 )
12741277 )
12751278
1276- return [m .id for m in qry_missing_tree_states .all ()]
1279+ return [( m .id , m . lang ) for m in qry_missing_tree_states .all ()]
12771280
12781281 _sql_find_tree_ranking_results = """
12791282-- get all ranking results of completed tasks for all parents with >= 2 children
@@ -1326,13 +1329,13 @@ def ensure_tree_states(self) -> None:
13261329 """Add message tree state rows for all root nodes (initial prompt messages)."""
13271330
13281331 missing_tree_ids = self .query_misssing_tree_states ()
1329- for id in missing_tree_ids :
1332+ for id , lang in missing_tree_ids :
13301333 tree_size = self .db .query (func .count (Message .id )).filter (Message .message_tree_id == id ).scalar ()
13311334 state = message_tree_state .State .INITIAL_PROMPT_REVIEW
13321335 if tree_size > 1 :
13331336 state = message_tree_state .State .GROWING
13341337 logger .info (f"Inserting missing message tree state for message: { id } ({ tree_size = } , { state = :s} )" )
1335- self ._insert_default_state (id , state = state )
1338+ self ._insert_default_state (id , lang = lang , state = state )
13361339
13371340 halt_prompts_of_disabled_users (self .db )
13381341
@@ -1388,6 +1391,12 @@ def query_num_growing_trees(self, lang: str) -> int:
13881391 )
13891392 return query .scalar ()
13901393
1394+ def query_prompt_lottery_waiting (self , lang : str ) -> int :
1395+ query = self .db .query (func .count (MessageTreeState .message_tree_id )).filter (
1396+ MessageTreeState .state == message_tree_state .State .PROMPT_LOTTERY_WAITING , MessageTreeState .lang == lang
1397+ )
1398+ return query .scalar ()
1399+
13911400 def query_num_active_trees (
13921401 self , lang : str , exclude_ranking : bool = True , exclude_prompt_review : bool = True
13931402 ) -> int :
@@ -1451,6 +1460,7 @@ def _insert_tree_state(
14511460 max_depth : int ,
14521461 max_children_count : int ,
14531462 active : bool ,
1463+ lang : str ,
14541464 state : message_tree_state .State = message_tree_state .State .INITIAL_PROMPT_REVIEW ,
14551465 ) -> MessageTreeState :
14561466 model = MessageTreeState (
@@ -1460,6 +1470,7 @@ def _insert_tree_state(
14601470 max_children_count = max_children_count ,
14611471 state = state .value ,
14621472 active = active ,
1473+ lang = lang ,
14631474 )
14641475
14651476 self .db .add (model )
@@ -1469,6 +1480,7 @@ def _insert_tree_state(
14691480 def _insert_default_state (
14701481 self ,
14711482 root_message_id : UUID ,
1483+ lang : str ,
14721484 state : message_tree_state .State = message_tree_state .State .INITIAL_PROMPT_REVIEW ,
14731485 * ,
14741486 goal_tree_size : int = None ,
@@ -1483,8 +1495,9 @@ def _insert_default_state(
14831495 goal_tree_size = goal_tree_size ,
14841496 max_depth = self .cfg .max_tree_depth ,
14851497 max_children_count = self .cfg .max_children_count ,
1486- state = state ,
14871498 active = True ,
1499+ lang = lang ,
1500+ state = state ,
14881501 )
14891502
14901503 def tree_counts_by_state (self , lang : str = None , only_active : bool = False ) -> dict [str , int ]:
0 commit comments