Skip to content

Commit ad5f8a5

Browse files
committed
Merge branch 'main' of github.com:LAION-AI/Open-Chat-GPT
2 parents 5ab2e2d + 3ab2e01 commit ad5f8a5

34 files changed

Lines changed: 895 additions & 242 deletions

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,4 +60,4 @@ repos:
6060
types_or: [javascript, jsx, ts, tsx]
6161
language: system
6262
pass_filenames: false
63-
entry: bash -c 'cd website && npm install && npm run lint'
63+
entry: bash -c 'cd website && npm ci && npm run lint'
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# -*- coding: utf-8 -*-
2+
"""Added lang column for ISO-639-1 codes
3+
4+
Revision ID: ef0b52902560
5+
Revises: 3358eb6834e6
6+
Create Date: 2022-12-28 18:24:21.393973
7+
8+
"""
9+
import sqlalchemy as sa
10+
import sqlmodel
11+
from alembic import op
12+
13+
# revision identifiers, used by Alembic.
14+
revision = "ef0b52902560"
15+
down_revision = "3358eb6834e6"
16+
branch_labels = None
17+
depends_on = None
18+
19+
20+
def upgrade() -> None:
21+
# ### commands auto generated by Alembic - please adjust! ###
22+
op.add_column(
23+
"post", sa.Column("lang", sqlmodel.sql.sqltypes.AutoString(length=200), nullable=False, default="en-US")
24+
)
25+
# ### end Alembic commands ###
26+
27+
28+
def downgrade() -> None:
29+
# ### commands auto generated by Alembic - please adjust! ###
30+
op.drop_column("post", "lang")
31+
# ### end Alembic commands ###
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# -*- coding: utf-8 -*-
2+
"""add collective flag to task
3+
4+
Revision ID: 464ec4667aae
5+
Revises: d24b37426857
6+
Create Date: 2022-12-29 21:03:06.841962
7+
8+
"""
9+
import sqlalchemy as sa
10+
from alembic import op
11+
12+
# revision identifiers, used by Alembic.
13+
revision = "464ec4667aae"
14+
down_revision = "d24b37426857"
15+
branch_labels = None
16+
depends_on = None
17+
18+
19+
def upgrade() -> None:
20+
# ### commands auto generated by Alembic - please adjust! ###
21+
op.add_column(
22+
"work_package", sa.Column("collective", sa.Boolean(), server_default=sa.text("false"), nullable=False)
23+
)
24+
# ### end Alembic commands ###
25+
26+
27+
def downgrade() -> None:
28+
# ### commands auto generated by Alembic - please adjust! ###
29+
op.drop_column("work_package", "collective")
30+
# ### end Alembic commands ###

backend/oasst_backend/api/v1/tasks.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def request_task(
139139
try:
140140
pr = PromptRepository(db, api_client, request.user)
141141
task, thread_id, parent_post_id = generate_task(request, pr)
142-
pr.store_task(task, thread_id, parent_post_id)
142+
pr.store_task(task, thread_id, parent_post_id, request.collective)
143143

144144
except OasstError:
145145
raise
@@ -252,3 +252,15 @@ def post_interaction(
252252
except Exception:
253253
logger.exception("Interaction request failed.")
254254
raise OasstError("Interaction request failed.", OasstErrorCode.TASK_INTERACTION_REQUEST_FAILED)
255+
256+
257+
@router.post("/close")
258+
def close_collective_task(
259+
close_task_request: protocol_schema.TaskClose,
260+
db: Session = Depends(deps.get_db),
261+
api_key: APIKey = Depends(deps.get_api_key),
262+
):
263+
api_client = deps.api_auth(api_key, db)
264+
pr = PromptRepository(db, api_client, user=None)
265+
pr.close_task(close_task_request.post_id)
266+
return protocol_schema.TaskDone()

backend/oasst_backend/exceptions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ class OasstErrorCode(IntEnum):
4141
WORK_PACKAGE_ALREADY_UPDATED = 2103
4242
WORK_PACKAGE_NOT_ACK = 2104
4343
WORK_PACKAGE_ALREADY_DONE = 2105
44+
WORK_PACKAGE_NOT_COLLECTIVE = 2106
4445

4546

4647
class OasstError(Exception):

backend/oasst_backend/models/post.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,5 +31,6 @@ class Post(SQLModel, table=True):
3131
)
3232
payload_type: str = Field(nullable=False, max_length=200)
3333
payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=True))
34+
lang: str = Field(nullable=False, max_length=200, default="en-US")
3435
depth: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
3536
children_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))

backend/oasst_backend/models/work_package.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class WorkPackage(SQLModel, table=True):
3232
frontend_ref_post_id: Optional[str] = None
3333
thread_id: Optional[UUID] = None
3434
parent_post_id: Optional[UUID] = None
35+
collective: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))
3536

3637
@property
3738
def expired(self) -> bool:

backend/oasst_backend/prompt_repository.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,9 @@ def store_text_reply(self, text: str, post_id: str, user_post_id: str, role: str
160160
payload=db_payload.PostPayload(text=text),
161161
depth=depth,
162162
)
163-
wp.done = True
164-
self.db.add(wp)
163+
if not wp.collective:
164+
wp.done = True
165+
self.db.add(wp)
165166
self.db.commit()
166167
self.journal.log_text_reply(work_package=wp, post_id=new_post_id, role=role, length=len(text))
167168
return user_post
@@ -186,15 +187,20 @@ def store_rating(self, rating: protocol_schema.PostRating) -> PostReaction:
186187
# store reaction to post
187188
reaction_payload = db_payload.RatingReactionPayload(rating=rating.rating)
188189
reaction = self.insert_reaction(post.id, reaction_payload)
190+
if not work_package.collective:
191+
work_package.done = True
192+
self.db.add(work_package)
193+
189194
self.journal.log_rating(work_package, post_id=post.id, rating=rating.rating)
190195
logger.info(f"Ranking {rating.rating} stored for work_package {work_package.id}.")
191196
return reaction
192197

193198
def store_ranking(self, ranking: protocol_schema.PostRanking) -> PostReaction:
194199
# fetch work_package
195200
work_package = self.fetch_workpackage_by_postid(ranking.post_id)
196-
work_package.done = True
197-
self.db.add(work_package)
201+
if not work_package.collective:
202+
work_package.done = True
203+
self.db.add(work_package)
198204

199205
work_payload: db_payload.RankConversationRepliesPayload | db_payload.RankInitialPromptsPayload = (
200206
work_package.payload.payload
@@ -250,6 +256,7 @@ def store_task(
250256
task: protocol_schema.Task,
251257
thread_id: UUID = None,
252258
parent_post_id: UUID = None,
259+
collective: bool = False,
253260
) -> WorkPackage:
254261
payload: db_payload.TaskPayload
255262
match type(task):
@@ -287,10 +294,7 @@ def store_task(
287294
raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE)
288295

289296
wp = self.insert_work_package(
290-
payload=payload,
291-
id=task.id,
292-
thread_id=thread_id,
293-
parent_post_id=parent_post_id,
297+
payload=payload, id=task.id, thread_id=thread_id, parent_post_id=parent_post_id, collective=collective
294298
)
295299
assert wp.id == task.id
296300
return wp
@@ -301,6 +305,7 @@ def insert_work_package(
301305
id: UUID = None,
302306
thread_id: UUID = None,
303307
parent_post_id: UUID = None,
308+
collective: bool = False,
304309
) -> WorkPackage:
305310
c = PayloadContainer(payload=payload)
306311
wp = WorkPackage(
@@ -311,6 +316,7 @@ def insert_work_package(
311316
api_client_id=self.api_client.id,
312317
thread_id=thread_id,
313318
parent_post_id=parent_post_id,
319+
collective=collective,
314320
)
315321
self.db.add(wp)
316322
self.db.commit()
@@ -397,7 +403,7 @@ def fetch_random_thread(self, require_role: str = None) -> list[Post]:
397403
distinct_threads = distinct_threads.filter(Post.role == require_role)
398404
distinct_threads = distinct_threads.subquery()
399405

400-
random_thread = self.db.query(distinct_threads).order_by(func.random()).limit(1).subquery()
406+
random_thread = self.db.query(distinct_threads).order_by(func.random()).limit(1)
401407
thread_posts = self.db.query(Post).filter(Post.thread_id.in_(random_thread)).all()
402408
return thread_posts
403409

@@ -443,7 +449,7 @@ def fetch_multiple_random_replies(self, max_size: int = 5, post_role: str = None
443449
if post_role:
444450
parent = parent.filter(Post.role == post_role)
445451

446-
parent = parent.order_by(func.random()).limit(1).subquery()
452+
parent = parent.order_by(func.random()).limit(1)
447453
replies = self.db.query(Post).filter(Post.parent_id.in_(parent)).order_by(func.random()).limit(max_size).all()
448454
if not replies:
449455
raise OasstError("No replies found", OasstErrorCode.NO_REPLIES_FOUND)
@@ -465,3 +471,20 @@ def fetch_multiple_random_replies(self, max_size: int = 5, post_role: str = None
465471

466472
def fetch_post(self, post_id: UUID) -> Optional[Post]:
467473
return self.db.query(Post).filter(Post.id == post_id).one()
474+
475+
def close_task(self, post_id: str, allow_personal_tasks: bool = False):
476+
self.validate_post_id(post_id)
477+
wp = self.fetch_workpackage_by_postid(post_id)
478+
479+
if not wp:
480+
raise OasstError("Work package not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND)
481+
if wp.expired:
482+
raise OasstError("Work package expired", OasstErrorCode.WORK_PACKAGE_EXPIRED)
483+
if not allow_personal_tasks and not wp.collective:
484+
raise OasstError("This is not a collective task", OasstErrorCode.WORK_PACKAGE_NOT_COLLECTIVE)
485+
if wp.done:
486+
raise OasstError("Allready closed", OasstErrorCode.WORK_PACKAGE_ALREADY_DONE)
487+
488+
wp.done = True
489+
self.db.add(wp)
490+
self.db.commit()

discord-bot/api_client.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,14 +52,19 @@ def _parse_task(self, data: dict) -> protocol_schema.Task:
5252
return self.task_models_map[task_type].parse_obj(data)
5353

5454
def fetch_task(
55-
self, task_type: protocol_schema.TaskRequestType, user: Optional[protocol_schema.User] = None
55+
self,
56+
task_type: protocol_schema.TaskRequestType,
57+
user: Optional[protocol_schema.User] = None,
58+
collective: bool = False,
5659
) -> protocol_schema.Task:
57-
req = protocol_schema.TaskRequest(type=task_type, user=user)
60+
req = protocol_schema.TaskRequest(type=task_type, user=user, collective=collective)
5861
data = self.post("/api/v1/tasks/", req.dict())
5962
return self._parse_task(data)
6063

61-
def fetch_random_task(self, user: Optional[protocol_schema.User] = None) -> protocol_schema.Task:
62-
return self.fetch_task(protocol_schema.TaskRequestType.random, user)
64+
def fetch_random_task(
65+
self, user: Optional[protocol_schema.User] = None, collective: bool = False
66+
) -> protocol_schema.Task:
67+
return self.fetch_task(protocol_schema.TaskRequestType.random, user, collective=collective)
6368

6469
def ack_task(self, task_id: str, post_id: str) -> None:
6570
req = protocol_schema.TaskAck(post_id=post_id)

oasst-shared/oasst_shared/schemas/protocol.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class TaskRequest(BaseModel):
4343

4444
type: TaskRequestType = TaskRequestType.random
4545
user: Optional[User] = None
46+
collective: bool = False
4647

4748

4849
class TaskAck(BaseModel):
@@ -57,6 +58,12 @@ class TaskNAck(BaseModel):
5758
reason: str
5859

5960

61+
class TaskClose(BaseModel):
62+
"""The frontend asks to mark task as done"""
63+
64+
post_id: str
65+
66+
6067
class Task(BaseModel):
6168
"""A task is a unit of work that the backend gives to the frontend."""
6269

0 commit comments

Comments
 (0)