Skip to content

Commit 097dd71

Browse files
authored
Add TreeManagerConfiguration.max_prompt_lottery_waiting parameter (#1889)
- add `lang` column to `message_tree_state`, create multi-column index on `state`, `lang`
- stop generating initial prompt tasks for languages with more than `max_prompt_lottery_waiting` prompts in `prompt_lottery_waiting` state
- add `MAX_PROMPT_LOTTERY_WAITING` workflow variable
1 parent 8eda31a commit 097dd71

6 files changed

Lines changed: 73 additions & 15 deletions

File tree

.github/workflows/deploy-to-node.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ jobs:
4646
P_LONELY_CHILD_EXTENSION: ${{ vars.P_LONELY_CHILD_EXTENSION }}
4747
P_ACTIVATE_BACKLOG_TREE: ${{ vars.P_ACTIVATE_BACKLOG_TREE }}
4848
MIN_ACTIVE_RANKINGS_PER_LANG: ${{ vars.MIN_ACTIVE_RANKINGS_PER_LANG }}
49+
MAX_PROMPT_LOTTERY_WAITING: ${{ vars.MAX_PROMPT_LOTTERY_WAITING }}
4950
GOAL_TREE_SIZE: ${{ vars.GOAL_TREE_SIZE }}
5051
MESSAGE_SIZE_LIMIT: ${{ vars.MESSAGE_SIZE_LIMIT }}
5152
SKIP_TOXICITY_CALCULATION: ${{ vars.SKIP_TOXICITY_CALCULATION }}

ansible/deploy-to-node.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,9 @@
139139
TREE_MANAGER__MIN_ACTIVE_RANKINGS_PER_LANG:
140140
"{{ lookup('ansible.builtin.env', 'MIN_ACTIVE_RANKINGS_PER_LANG') |
141141
default('0', true) }}"
142+
TREE_MANAGER__MAX_PROMPT_LOTTERY_WAITING:
143+
"{{ lookup('ansible.builtin.env', 'MAX_PROMPT_LOTTERY_WAITING') |
144+
default('250', true) }}"
142145
MESSAGE_SIZE_LIMIT:
143146
"{{ lookup('ansible.builtin.env', 'MESSAGE_SIZE_LIMIT') |
144147
default('2000', true) }}"
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
"""add lang to message_tree_state
2+
3+
Revision ID: 9db92d504f64
4+
Revises: 8cd0c34d0c3c
5+
Create Date: 2023-02-26 00:52:40.624843
6+
7+
"""
8+
import sqlalchemy as sa
9+
from alembic import op
10+
11+
# revision identifiers, used by Alembic.
12+
revision = "9db92d504f64"
13+
down_revision = "8cd0c34d0c3c"
14+
branch_labels = None
15+
depends_on = None
16+
17+
18+
def upgrade() -> None:
    """Add a NOT NULL ``lang`` column to ``message_tree_state`` and index it together with ``state``.

    The column is first created as nullable, then back-filled from the root
    message of each tree (the ``message`` row whose ``parent_id`` is NULL and
    whose ``id`` equals ``message_tree_id``), and only afterwards switched to
    NOT NULL. Finally the single-column index on ``state`` is replaced by a
    composite index.
    """
    table = "message_tree_state"

    # Adjacent literals are concatenated at compile time into one SQL statement.
    backfill_lang_sql = (
        "WITH msg AS (SELECT id, lang FROM message WHERE parent_id is NULL) "
        "UPDATE message_tree_state mts SET lang = msg.lang "
        "FROM msg WHERE mts.message_tree_id = msg.id"
    )

    # Nullable first: existing rows have no value until the back-fill has run.
    lang_column = sa.Column("lang", sa.String(length=32), nullable=True)
    op.add_column(table, lang_column)
    op.execute(backfill_lang_sql)
    op.alter_column(table, "lang", nullable=False)

    # Note: despite the "__lang__state" suffix in its name, the composite index
    # lists the columns in the order (state, lang).
    op.drop_index("ix_message_tree_state_state", table_name=table)
    op.create_index("ix_message_tree_state__lang__state", table, ["state", "lang"], unique=False)
28+
29+
30+
def downgrade() -> None:
    """Undo ``upgrade``: restore the ``state``-only index and drop the ``lang`` column."""
    table = "message_tree_state"

    # Swap the composite (state, lang) index back to the single-column index on state.
    op.drop_index("ix_message_tree_state__lang__state", table_name=table)
    op.create_index("ix_message_tree_state_state", table, ["state"], unique=False)

    # The index referencing "lang" is gone, so the column can now be removed.
    op.drop_column(table, "lang")

backend/oasst_backend/config.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ class TreeManagerConfiguration(BaseModel):
5454
p_full_labeling_review_prompt: float = 1.0
5555
"""Probability of full text-labeling (instead of mandatory only) for initial prompts."""
5656

57-
p_full_labeling_review_reply_assistant: float = 0.5
57+
p_full_labeling_review_reply_assistant: float = 1.0
5858
"""Probability of full text-labeling (instead of mandatory only) for assistant replies."""
5959

6060
p_full_labeling_review_reply_prompter: float = 0.25
@@ -145,6 +145,10 @@ class TreeManagerConfiguration(BaseModel):
145145
"""Maximum number of pending tasks (neither canceled nor completed) by a single user within
146146
the time span defined by `recent_tasks_span_sec`."""
147147

148+
max_prompt_lottery_waiting: int = 250
149+
"""Maximum number of prompts in prompt_lottery_waiting state per language. If this value
150+
is exceeded no new initial prompt tasks for that language are generated."""
151+
148152

149153
class Settings(BaseSettings):
150154
PROJECT_NAME: str = "open-assistant backend"

backend/oasst_backend/models/message_tree_state.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import sqlalchemy as sa
77
import sqlalchemy.dialects.postgresql as pg
8-
from sqlmodel import Field, SQLModel
8+
from sqlmodel import Field, Index, SQLModel
99

1010

1111
class State(str, Enum):
@@ -74,14 +74,16 @@ class State(str, Enum):
7474

7575
class MessageTreeState(SQLModel, table=True):
7676
__tablename__ = "message_tree_state"
77+
__table_args__ = (Index("ix_message_tree_state__lang__state", "state", "lang", unique=False),)
7778

7879
message_tree_id: UUID = Field(
7980
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("message.id"), primary_key=True)
8081
)
8182
goal_tree_size: int = Field(nullable=False)
8283
max_depth: int = Field(nullable=False)
8384
max_children_count: int = Field(nullable=False)
84-
state: str = Field(nullable=False, max_length=128, index=True)
85+
state: str = Field(nullable=False, max_length=128)
8586
active: bool = Field(nullable=False, index=True)
8687
origin: str = Field(sa_column=sa.Column(sa.String(1024), nullable=True))
8788
won_prompt_lottery_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True))
89+
lang: str = Field(sa_column=sa.Column(sa.String(32), nullable=False))

backend/oasst_backend/tree_manager.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def _determine_task_availability_internal(
214214
) -> dict[protocol_schema.TaskRequestType, int]:
215215
task_count_by_type: dict[protocol_schema.TaskRequestType, int] = {t: 0 for t in protocol_schema.TaskRequestType}
216216

217-
task_count_by_type[protocol_schema.TaskRequestType.initial_prompt] = num_missing_prompts
217+
task_count_by_type[protocol_schema.TaskRequestType.initial_prompt] = max(0, num_missing_prompts)
218218

219219
task_count_by_type[protocol_schema.TaskRequestType.prompter_reply] = len(
220220
list(filter(lambda x: x.parent_role == "assistant", extendible_parents))
@@ -256,13 +256,14 @@ def _prompt_lottery(self, lang: str, max_activate: int = 1) -> int:
256256

257257
while True:
258258
stats = self.tree_counts_by_state_stats(lang=lang, only_active=True)
259-
259+
prompt_lottery_waiting = self.query_prompt_lottery_waiting(lang=lang)
260+
remaining_lottery_entries = max(0, self.cfg.max_prompt_lottery_waiting - prompt_lottery_waiting)
260261
remaining_prompt_review = max(0, self.cfg.max_initial_prompt_review - stats.initial_prompt_review)
261262
num_missing_growing = max(0, self.cfg.max_active_trees - stats.growing)
262263
logger.info(f"_prompt_lottery {remaining_prompt_review=}, {num_missing_growing=}")
263264

264265
if num_missing_growing == 0 or activated >= max_activate:
265-
return num_missing_growing + remaining_prompt_review
266+
return min(num_missing_growing + remaining_prompt_review, remaining_lottery_entries)
266267

267268
@managed_tx_function(CommitMode.COMMIT)
268269
def activate_one(db: Session) -> int:
@@ -330,7 +331,7 @@ def activate_one(db: Session) -> int:
330331
return True
331332

332333
if not activate_one():
333-
return num_missing_growing + remaining_prompt_review
334+
return min(num_missing_growing + remaining_prompt_review, remaining_lottery_entries)
334335

335336
activated += 1
336337

@@ -714,8 +715,10 @@ async def handle_interaction(self, interaction: protocol_schema.AnyInteraction)
714715
)
715716

716717
if not message.parent_id:
717-
logger.info(f"TreeManager: Inserting new tree state for initial prompt {message.id=}")
718-
self._insert_default_state(message.id)
718+
logger.info(
719+
f"TreeManager: Inserting new tree state for initial prompt {message.id=} [{message.lang}]"
720+
)
721+
self._insert_default_state(message.id, message.lang)
719722

720723
if not settings.DEBUG_SKIP_EMBEDDING_COMPUTATION:
721724
try:
@@ -1261,10 +1264,10 @@ def query_tree_size(self, message_tree_id: UUID) -> ActiveTreeSizeRow:
12611264

12621265
return ActiveTreeSizeRow.from_orm(qry.one())
12631266

1264-
def query_misssing_tree_states(self) -> list[UUID]:
1267+
def query_misssing_tree_states(self) -> list[Tuple[UUID, str]]:
12651268
"""Find all initial prompt messages that have no associated message tree state"""
12661269
qry_missing_tree_states = (
1267-
self.db.query(Message.id)
1270+
self.db.query(Message.id, Message.lang)
12681271
.outerjoin(MessageTreeState, Message.message_tree_id == MessageTreeState.message_tree_id)
12691272
.filter(
12701273
Message.parent_id.is_(None),
@@ -1273,7 +1276,7 @@ def query_misssing_tree_states(self) -> list[UUID]:
12731276
)
12741277
)
12751278

1276-
return [m.id for m in qry_missing_tree_states.all()]
1279+
return [(m.id, m.lang) for m in qry_missing_tree_states.all()]
12771280

12781281
_sql_find_tree_ranking_results = """
12791282
-- get all ranking results of completed tasks for all parents with >= 2 children
@@ -1326,13 +1329,13 @@ def ensure_tree_states(self) -> None:
13261329
"""Add message tree state rows for all root nodes (initial prompt messages)."""
13271330

13281331
missing_tree_ids = self.query_misssing_tree_states()
1329-
for id in missing_tree_ids:
1332+
for id, lang in missing_tree_ids:
13301333
tree_size = self.db.query(func.count(Message.id)).filter(Message.message_tree_id == id).scalar()
13311334
state = message_tree_state.State.INITIAL_PROMPT_REVIEW
13321335
if tree_size > 1:
13331336
state = message_tree_state.State.GROWING
13341337
logger.info(f"Inserting missing message tree state for message: {id} ({tree_size=}, {state=:s})")
1335-
self._insert_default_state(id, state=state)
1338+
self._insert_default_state(id, lang=lang, state=state)
13361339

13371340
halt_prompts_of_disabled_users(self.db)
13381341

@@ -1388,6 +1391,12 @@ def query_num_growing_trees(self, lang: str) -> int:
13881391
)
13891392
return query.scalar()
13901393

1394+
def query_prompt_lottery_waiting(self, lang: str) -> int:
    """Return the number of message tree states for `lang` that are in the
    `prompt_lottery_waiting` state."""
    waiting_state = message_tree_state.State.PROMPT_LOTTERY_WAITING

    count_query = self.db.query(func.count(MessageTreeState.message_tree_id))
    count_query = count_query.filter(MessageTreeState.state == waiting_state, MessageTreeState.lang == lang)
    return count_query.scalar()
1399+
13911400
def query_num_active_trees(
13921401
self, lang: str, exclude_ranking: bool = True, exclude_prompt_review: bool = True
13931402
) -> int:
@@ -1451,6 +1460,7 @@ def _insert_tree_state(
14511460
max_depth: int,
14521461
max_children_count: int,
14531462
active: bool,
1463+
lang: str,
14541464
state: message_tree_state.State = message_tree_state.State.INITIAL_PROMPT_REVIEW,
14551465
) -> MessageTreeState:
14561466
model = MessageTreeState(
@@ -1460,6 +1470,7 @@ def _insert_tree_state(
14601470
max_children_count=max_children_count,
14611471
state=state.value,
14621472
active=active,
1473+
lang=lang,
14631474
)
14641475

14651476
self.db.add(model)
@@ -1469,6 +1480,7 @@ def _insert_tree_state(
14691480
def _insert_default_state(
14701481
self,
14711482
root_message_id: UUID,
1483+
lang: str,
14721484
state: message_tree_state.State = message_tree_state.State.INITIAL_PROMPT_REVIEW,
14731485
*,
14741486
goal_tree_size: int = None,
@@ -1483,8 +1495,9 @@ def _insert_default_state(
14831495
goal_tree_size=goal_tree_size,
14841496
max_depth=self.cfg.max_tree_depth,
14851497
max_children_count=self.cfg.max_children_count,
1486-
state=state,
14871498
active=True,
1499+
lang=lang,
1500+
state=state,
14881501
)
14891502

14901503
def tree_counts_by_state(self, lang: str = None, only_active: bool = False) -> dict[str, int]:

0 commit comments

Comments
 (0)