Skip to content

Commit 93886f0

Browse files
authored
Assistant Prompt length + client side (#4433)
1 parent 8c3a953 commit 93886f0

File tree

16 files changed

+211
-360
lines changed

16 files changed

+211
-360
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
"""update prompt length
2+
3+
Revision ID: 4794bc13e484
4+
Revises: f7505c5b0284
5+
Create Date: 2025-04-02 11:26:36.180328
6+
7+
"""
8+
from alembic import op
9+
import sqlalchemy as sa
10+
11+
12+
# revision identifiers, used by Alembic.
13+
revision = "4794bc13e484"
14+
down_revision = "f7505c5b0284"
15+
branch_labels = None
16+
depends_on = None
17+
18+
19+
def upgrade() -> None:
20+
op.alter_column(
21+
"prompt",
22+
"system_prompt",
23+
existing_type=sa.TEXT(),
24+
type_=sa.String(length=5000000),
25+
existing_nullable=False,
26+
)
27+
op.alter_column(
28+
"prompt",
29+
"task_prompt",
30+
existing_type=sa.TEXT(),
31+
type_=sa.String(length=5000000),
32+
existing_nullable=False,
33+
)
34+
35+
36+
def downgrade() -> None:
37+
op.alter_column(
38+
"prompt",
39+
"system_prompt",
40+
existing_type=sa.String(length=5000000),
41+
type_=sa.TEXT(),
42+
existing_nullable=False,
43+
)
44+
op.alter_column(
45+
"prompt",
46+
"task_prompt",
47+
existing_type=sa.String(length=5000000),
48+
type_=sa.TEXT(),
49+
existing_nullable=False,
50+
)

Diff for: backend/alembic/versions/f71470ba9274_add_prompt_length_limit.py

+30-30
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
Create Date: 2025-04-01 15:07:14.977435
66
77
"""
8-
from alembic import op
9-
import sqlalchemy as sa
108

119

1210
# revision identifiers, used by Alembic.
@@ -17,34 +15,36 @@
1715

1816

1917
def upgrade() -> None:
20-
op.alter_column(
21-
"prompt",
22-
"system_prompt",
23-
existing_type=sa.TEXT(),
24-
type_=sa.String(length=8000),
25-
existing_nullable=False,
26-
)
27-
op.alter_column(
28-
"prompt",
29-
"task_prompt",
30-
existing_type=sa.TEXT(),
31-
type_=sa.String(length=8000),
32-
existing_nullable=False,
33-
)
18+
# op.alter_column(
19+
# "prompt",
20+
# "system_prompt",
21+
# existing_type=sa.TEXT(),
22+
# type_=sa.String(length=8000),
23+
# existing_nullable=False,
24+
# )
25+
# op.alter_column(
26+
# "prompt",
27+
# "task_prompt",
28+
# existing_type=sa.TEXT(),
29+
# type_=sa.String(length=8000),
30+
# existing_nullable=False,
31+
# )
32+
pass
3433

3534

3635
def downgrade() -> None:
37-
op.alter_column(
38-
"prompt",
39-
"system_prompt",
40-
existing_type=sa.String(length=8000),
41-
type_=sa.TEXT(),
42-
existing_nullable=False,
43-
)
44-
op.alter_column(
45-
"prompt",
46-
"task_prompt",
47-
existing_type=sa.String(length=8000),
48-
type_=sa.TEXT(),
49-
existing_nullable=False,
50-
)
36+
# op.alter_column(
37+
# "prompt",
38+
# "system_prompt",
39+
# existing_type=sa.String(length=8000),
40+
# type_=sa.TEXT(),
41+
# existing_nullable=False,
42+
# )
43+
# op.alter_column(
44+
# "prompt",
45+
# "task_prompt",
46+
# existing_type=sa.String(length=8000),
47+
# type_=sa.TEXT(),
48+
# existing_nullable=False,
49+
# )
50+
pass

Diff for: backend/onyx/db/persona.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@
3737
from onyx.db.models import UserFolder
3838
from onyx.db.models import UserGroup
3939
from onyx.db.notification import create_notification
40+
from onyx.server.features.persona.models import FullPersonaSnapshot
4041
from onyx.server.features.persona.models import PersonaSharedNotificationData
41-
from onyx.server.features.persona.models import PersonaSnapshot
4242
from onyx.server.features.persona.models import PersonaUpsertRequest
4343
from onyx.utils.logger import setup_logger
4444
from onyx.utils.variable_functionality import fetch_versioned_implementation
@@ -201,7 +201,7 @@ def create_update_persona(
201201
create_persona_request: PersonaUpsertRequest,
202202
user: User | None,
203203
db_session: Session,
204-
) -> PersonaSnapshot:
204+
) -> FullPersonaSnapshot:
205205
"""Higher level function than upsert_persona, although either is valid to use."""
206206
# Permission to actually use these is checked later
207207

@@ -271,7 +271,7 @@ def create_update_persona(
271271
logger.exception("Failed to create persona")
272272
raise HTTPException(status_code=400, detail=str(e))
273273

274-
return PersonaSnapshot.from_model(persona)
274+
return FullPersonaSnapshot.from_model(persona)
275275

276276

277277
def update_persona_shared_users(

Diff for: backend/onyx/server/features/persona/api.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from onyx.secondary_llm_flows.starter_message_creation import (
4444
generate_starter_messages,
4545
)
46+
from onyx.server.features.persona.models import FullPersonaSnapshot
4647
from onyx.server.features.persona.models import GenerateStarterMessageRequest
4748
from onyx.server.features.persona.models import ImageGenerationToolStatus
4849
from onyx.server.features.persona.models import PersonaLabelCreate
@@ -424,8 +425,8 @@ def get_persona(
424425
persona_id: int,
425426
user: User | None = Depends(current_limited_user),
426427
db_session: Session = Depends(get_session),
427-
) -> PersonaSnapshot:
428-
return PersonaSnapshot.from_model(
428+
) -> FullPersonaSnapshot:
429+
return FullPersonaSnapshot.from_model(
429430
get_persona_by_id(
430431
persona_id=persona_id,
431432
user=user,

Diff for: backend/onyx/server/features/persona/models.py

+83-52
Original file line numberDiff line numberDiff line change
@@ -91,82 +91,113 @@ class PersonaUpsertRequest(BaseModel):
9191

9292
class PersonaSnapshot(BaseModel):
9393
id: int
94-
owner: MinimalUserSnapshot | None
9594
name: str
96-
is_visible: bool
97-
is_public: bool
98-
display_priority: int | None
9995
description: str
100-
num_chunks: float | None
101-
llm_relevance_filter: bool
102-
llm_filter_extraction: bool
103-
llm_model_provider_override: str | None
104-
llm_model_version_override: str | None
105-
starter_messages: list[StarterMessage] | None
106-
builtin_persona: bool
107-
prompts: list[PromptSnapshot]
108-
tools: list[ToolSnapshot]
109-
document_sets: list[DocumentSet]
110-
users: list[MinimalUserSnapshot]
111-
groups: list[int]
112-
icon_color: str | None
113-
icon_shape: int | None
96+
is_public: bool
97+
is_visible: bool
98+
icon_shape: int | None = None
99+
icon_color: str | None = None
114100
uploaded_image_id: str | None = None
115-
is_default_persona: bool
101+
user_file_ids: list[int] = Field(default_factory=list)
102+
user_folder_ids: list[int] = Field(default_factory=list)
103+
display_priority: int | None = None
104+
is_default_persona: bool = False
105+
builtin_persona: bool = False
106+
starter_messages: list[StarterMessage] | None = None
107+
tools: list[ToolSnapshot] = Field(default_factory=list)
108+
labels: list["PersonaLabelSnapshot"] = Field(default_factory=list)
109+
owner: MinimalUserSnapshot | None = None
110+
users: list[MinimalUserSnapshot] = Field(default_factory=list)
111+
groups: list[int] = Field(default_factory=list)
112+
document_sets: list[DocumentSet] = Field(default_factory=list)
113+
llm_model_provider_override: str | None = None
114+
llm_model_version_override: str | None = None
115+
num_chunks: float | None = None
116+
117+
@classmethod
118+
def from_model(cls, persona: Persona) -> "PersonaSnapshot":
119+
return PersonaSnapshot(
120+
id=persona.id,
121+
name=persona.name,
122+
description=persona.description,
123+
is_public=persona.is_public,
124+
is_visible=persona.is_visible,
125+
icon_shape=persona.icon_shape,
126+
icon_color=persona.icon_color,
127+
uploaded_image_id=persona.uploaded_image_id,
128+
user_file_ids=[file.id for file in persona.user_files],
129+
user_folder_ids=[folder.id for folder in persona.user_folders],
130+
display_priority=persona.display_priority,
131+
is_default_persona=persona.is_default_persona,
132+
builtin_persona=persona.builtin_persona,
133+
starter_messages=persona.starter_messages,
134+
tools=[ToolSnapshot.from_model(tool) for tool in persona.tools],
135+
labels=[PersonaLabelSnapshot.from_model(label) for label in persona.labels],
136+
owner=(
137+
MinimalUserSnapshot(id=persona.user.id, email=persona.user.email)
138+
if persona.user
139+
else None
140+
),
141+
users=[
142+
MinimalUserSnapshot(id=user.id, email=user.email)
143+
for user in persona.users
144+
],
145+
groups=[user_group.id for user_group in persona.groups],
146+
document_sets=[
147+
DocumentSet.from_model(document_set_model)
148+
for document_set_model in persona.document_sets
149+
],
150+
llm_model_provider_override=persona.llm_model_provider_override,
151+
llm_model_version_override=persona.llm_model_version_override,
152+
num_chunks=persona.num_chunks,
153+
)
154+
155+
156+
# Model with full context on perona's internal settings
157+
# This is used for flows which need to know all settings
158+
class FullPersonaSnapshot(PersonaSnapshot):
116159
search_start_date: datetime | None = None
117-
labels: list["PersonaLabelSnapshot"] = []
118-
user_file_ids: list[int] | None = None
119-
user_folder_ids: list[int] | None = None
160+
prompts: list[PromptSnapshot] = Field(default_factory=list)
161+
llm_relevance_filter: bool = False
162+
llm_filter_extraction: bool = False
120163

121164
@classmethod
122165
def from_model(
123166
cls, persona: Persona, allow_deleted: bool = False
124-
) -> "PersonaSnapshot":
167+
) -> "FullPersonaSnapshot":
125168
if persona.deleted:
126169
error_msg = f"Persona with ID {persona.id} has been deleted"
127170
if not allow_deleted:
128171
raise ValueError(error_msg)
129172
else:
130173
logger.warning(error_msg)
131174

132-
return PersonaSnapshot(
175+
return FullPersonaSnapshot(
133176
id=persona.id,
134177
name=persona.name,
178+
description=persona.description,
179+
is_public=persona.is_public,
180+
is_visible=persona.is_visible,
181+
icon_shape=persona.icon_shape,
182+
icon_color=persona.icon_color,
183+
uploaded_image_id=persona.uploaded_image_id,
184+
user_file_ids=[file.id for file in persona.user_files],
185+
user_folder_ids=[folder.id for folder in persona.user_folders],
186+
display_priority=persona.display_priority,
187+
is_default_persona=persona.is_default_persona,
188+
builtin_persona=persona.builtin_persona,
189+
starter_messages=persona.starter_messages,
190+
tools=[ToolSnapshot.from_model(tool) for tool in persona.tools],
191+
labels=[PersonaLabelSnapshot.from_model(label) for label in persona.labels],
135192
owner=(
136193
MinimalUserSnapshot(id=persona.user.id, email=persona.user.email)
137194
if persona.user
138195
else None
139196
),
140-
is_visible=persona.is_visible,
141-
is_public=persona.is_public,
142-
display_priority=persona.display_priority,
143-
description=persona.description,
144-
num_chunks=persona.num_chunks,
197+
search_start_date=persona.search_start_date,
198+
prompts=[PromptSnapshot.from_model(prompt) for prompt in persona.prompts],
145199
llm_relevance_filter=persona.llm_relevance_filter,
146200
llm_filter_extraction=persona.llm_filter_extraction,
147-
llm_model_provider_override=persona.llm_model_provider_override,
148-
llm_model_version_override=persona.llm_model_version_override,
149-
starter_messages=persona.starter_messages,
150-
builtin_persona=persona.builtin_persona,
151-
is_default_persona=persona.is_default_persona,
152-
prompts=[PromptSnapshot.from_model(prompt) for prompt in persona.prompts],
153-
tools=[ToolSnapshot.from_model(tool) for tool in persona.tools],
154-
document_sets=[
155-
DocumentSet.from_model(document_set_model)
156-
for document_set_model in persona.document_sets
157-
],
158-
users=[
159-
MinimalUserSnapshot(id=user.id, email=user.email)
160-
for user in persona.users
161-
],
162-
groups=[user_group.id for user_group in persona.groups],
163-
icon_color=persona.icon_color,
164-
icon_shape=persona.icon_shape,
165-
uploaded_image_id=persona.uploaded_image_id,
166-
search_start_date=persona.search_start_date,
167-
labels=[PersonaLabelSnapshot.from_model(label) for label in persona.labels],
168-
user_file_ids=[file.id for file in persona.user_files],
169-
user_folder_ids=[folder.id for folder in persona.user_folders],
170201
)
171202

172203

Diff for: backend/onyx/server/manage/models.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from onyx.db.models import SlackChannelConfig as SlackChannelConfigModel
2020
from onyx.db.models import User
2121
from onyx.onyxbot.slack.config import VALID_SLACK_FILTERS
22+
from onyx.server.features.persona.models import FullPersonaSnapshot
2223
from onyx.server.features.persona.models import PersonaSnapshot
2324
from onyx.server.models import FullUserSnapshot
2425
from onyx.server.models import InvitedUserSnapshot
@@ -245,7 +246,7 @@ def from_model(
245246
id=slack_channel_config_model.id,
246247
slack_bot_id=slack_channel_config_model.slack_bot_id,
247248
persona=(
248-
PersonaSnapshot.from_model(
249+
FullPersonaSnapshot.from_model(
249250
slack_channel_config_model.persona, allow_deleted=True
250251
)
251252
if slack_channel_config_model.persona

Diff for: backend/tests/integration/common_utils/managers/persona.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import requests
55

66
from onyx.context.search.enums import RecencyBiasSetting
7-
from onyx.server.features.persona.models import PersonaSnapshot
7+
from onyx.server.features.persona.models import FullPersonaSnapshot
88
from onyx.server.features.persona.models import PersonaUpsertRequest
99
from tests.integration.common_utils.constants import API_SERVER_URL
1010
from tests.integration.common_utils.constants import GENERAL_HEADERS
@@ -181,29 +181,29 @@ def edit(
181181
@staticmethod
182182
def get_all(
183183
user_performing_action: DATestUser | None = None,
184-
) -> list[PersonaSnapshot]:
184+
) -> list[FullPersonaSnapshot]:
185185
response = requests.get(
186186
f"{API_SERVER_URL}/admin/persona",
187187
headers=user_performing_action.headers
188188
if user_performing_action
189189
else GENERAL_HEADERS,
190190
)
191191
response.raise_for_status()
192-
return [PersonaSnapshot(**persona) for persona in response.json()]
192+
return [FullPersonaSnapshot(**persona) for persona in response.json()]
193193

194194
@staticmethod
195195
def get_one(
196196
persona_id: int,
197197
user_performing_action: DATestUser | None = None,
198-
) -> list[PersonaSnapshot]:
198+
) -> list[FullPersonaSnapshot]:
199199
response = requests.get(
200200
f"{API_SERVER_URL}/persona/{persona_id}",
201201
headers=user_performing_action.headers
202202
if user_performing_action
203203
else GENERAL_HEADERS,
204204
)
205205
response.raise_for_status()
206-
return [PersonaSnapshot(**response.json())]
206+
return [FullPersonaSnapshot(**response.json())]
207207

208208
@staticmethod
209209
def verify(

0 commit comments

Comments
 (0)