Skip to content

functional base persona creation #1792

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""add nullable to persona id in Chat Session

Revision ID: c99d76fcd298
Revises: bceb1e139447
Create Date: 2024-07-09 19:27:01.579697

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "c99d76fcd298"
down_revision = "bceb1e139447"
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Allow chat_session.persona_id to be NULL.

    Relaxes the NOT NULL constraint on the persona_id column so a chat
    session can exist without an associated persona.
    """
    op.alter_column(
        "chat_session",
        "persona_id",
        existing_type=sa.INTEGER(),
        nullable=True,
    )


def downgrade() -> None:
    """Restore the NOT NULL constraint on chat_session.persona_id.

    Reverses upgrade(); will fail if any rows currently hold a NULL
    persona_id, since those rows violate the restored constraint.
    """
    op.alter_column(
        "chat_session", "persona_id", existing_type=sa.INTEGER(), nullable=False
    )
30 changes: 17 additions & 13 deletions backend/danswer/chat/process_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,9 +663,11 @@ def stream_chat_message_objects(
db_session=db_session,
selected_search_docs=selected_db_search_docs,
# Deduping happens at the last step to avoid harming quality by dropping content early on
dedupe_docs=retrieval_options.dedupe_docs
if retrieval_options
else False,
dedupe_docs=(
retrieval_options.dedupe_docs
if retrieval_options
else False
),
)
yield qa_docs_response
elif packet.id == SECTION_RELEVANCE_LIST_ID:
Expand Down Expand Up @@ -767,16 +769,18 @@ def stream_chat_message_objects(
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
citations=db_citations,
error=None,
tool_calls=[
ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments=tool_result.tool_args,
tool_result=tool_result.tool_result,
)
]
if tool_result
else [],
tool_calls=(
[
ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments=tool_result.tool_args,
tool_result=tool_result.tool_result,
)
]
if tool_result
else []
),
)

logger.debug("Committing messages")
Expand Down
17 changes: 12 additions & 5 deletions backend/danswer/connectors/confluence/rate_limit_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,24 @@ def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
try:
return confluence_call(*args, **kwargs)
except HTTPError as e:
# Check if the response or headers are None to avoid potential AttributeError
if e.response is None or e.response.headers is None:
logger.warning("HTTPError with `None` as response or as headers")
raise e

retry_after_header = e.response.headers.get("Retry-After")
if (
e.response.status_code == 429
or RATE_LIMIT_MESSAGE_LOWERCASE in e.response.text.lower()
):
retry_after = None
try:
retry_after = int(e.response.headers.get("Retry-After"))
except (ValueError, TypeError):
pass
if retry_after_header is not None:
try:
retry_after = int(retry_after_header)
except ValueError:
pass

if retry_after:
if retry_after is not None:
logger.warning(
f"Rate limit hit. Retrying after {retry_after} seconds..."
)
Expand Down
3 changes: 2 additions & 1 deletion backend/danswer/connectors/web/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ def check_internet_connection(url: str) -> None:
response = requests.get(url, timeout=3)
response.raise_for_status()
except requests.exceptions.HTTPError as e:
status_code = e.response.status_code
# Extract status code from the response, defaulting to -1 if response is None
status_code = e.response.status_code if e.response is not None else -1
error_msg = {
400: "Bad Request",
401: "Unauthorized",
Expand Down
2 changes: 1 addition & 1 deletion backend/danswer/db/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def get_default_admin_user_emails() -> list[str]:
get_default_admin_user_emails_fn: Callable[
[], list[str]
] = fetch_versioned_implementation_with_fallback(
"danswer.auth.users", "get_default_admin_user_emails_", lambda: []
"danswer.auth.users", "get_default_admin_user_emails_", lambda: list[str]()
)
return get_default_admin_user_emails_fn()

Expand Down
17 changes: 17 additions & 0 deletions backend/danswer/db/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
from sqlalchemy.orm import Session

from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
from danswer.db.models import DocumentSet
from danswer.db.models import LLMProvider as LLMProviderModel
from danswer.db.models import LLMProvider__UserGroup
from danswer.db.models import SearchSettings
from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from danswer.db.models import Tool as ToolModel
from danswer.server.manage.llm.models import FullLLMProvider
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
from shared_configs.enums import EmbeddingProvider
Expand Down Expand Up @@ -103,6 +105,21 @@ def fetch_existing_embedding_providers(
return list(db_session.scalars(select(CloudEmbeddingProviderModel)).all())



def fetch_existing_doc_sets(
    db_session: Session, doc_ids: list[int]
) -> list[DocumentSet]:
    """Fetch the DocumentSet rows whose ids appear in doc_ids.

    Returns only the sets that exist; ids with no matching row are
    silently omitted (an empty doc_ids yields an empty list).
    """
    stmt = select(DocumentSet).where(DocumentSet.id.in_(doc_ids))
    return list(db_session.scalars(stmt).all())


def fetch_existing_tools(db_session: Session, tool_ids: list[int]) -> list[ToolModel]:
    """Fetch the Tool rows whose ids appear in tool_ids.

    Ids without a matching row are silently omitted; an empty tool_ids
    yields an empty list.
    """
    stmt = select(ToolModel).where(ToolModel.id.in_(tool_ids))
    return list(db_session.scalars(stmt).all())


def fetch_existing_llm_providers(
db_session: Session,
user: User | None = None,
Expand Down
9 changes: 6 additions & 3 deletions backend/danswer/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,7 +840,9 @@ class ChatSession(Base):

id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"))
persona_id: Mapped[int | None] = mapped_column(
ForeignKey("persona.id"), nullable=True
)
description: Mapped[str] = mapped_column(Text)
# One-shot direct answering, currently the two types of chats are not mixed
one_shot: Mapped[bool] = mapped_column(Boolean, default=False)
Expand Down Expand Up @@ -874,7 +876,6 @@ class ChatSession(Base):
prompt_override: Mapped[PromptOverride | None] = mapped_column(
PydanticType(PromptOverride), nullable=True
)

time_updated: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
Expand All @@ -883,7 +884,6 @@ class ChatSession(Base):
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)

user: Mapped[User] = relationship("User", back_populates="chat_sessions")
folder: Mapped["ChatFolder"] = relationship(
"ChatFolder", back_populates="chat_sessions"
Expand All @@ -893,6 +893,9 @@ class ChatSession(Base):
)
persona: Mapped["Persona"] = relationship("Persona")

# def get_persona():
# return


class ChatMessage(Base):
"""Note, the first message in a chain has no contents, it's a workaround to allow edits
Expand Down
6 changes: 5 additions & 1 deletion backend/danswer/file_store/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from collections.abc import Callable
from io import BytesIO
from typing import Any
from typing import cast
from uuid import uuid4

Expand Down Expand Up @@ -73,5 +75,7 @@ def save_file_from_url(url: str) -> str:


def save_files_from_urls(urls: list[str]) -> list[str]:
funcs = [(save_file_from_url, (url,)) for url in urls]
funcs: list[tuple[Callable[..., Any], tuple[Any, ...]]] = [
(save_file_from_url, (url,)) for url in urls
]
return run_functions_tuples_in_parallel(funcs)
52 changes: 35 additions & 17 deletions backend/danswer/one_shot_answer/answer_question.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
from danswer.db.chat import update_search_docs_table_with_relevance
from danswer.db.engine import get_session_context_manager
from danswer.db.models import Persona
from danswer.db.models import User
from danswer.db.persona import get_prompt_by_id
from danswer.llm.answering.answer import Answer
Expand Down Expand Up @@ -60,7 +61,7 @@
from danswer.tools.tool_runner import ToolCallKickoff
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_generator_function_time

from ee.danswer.server.query_and_chat.utils import create_temporary_persona

logger = setup_logger()

Expand Down Expand Up @@ -97,6 +98,7 @@ def stream_answer_objects(
retrieval_metrics_callback: (
Callable[[RetrievalMetricsContainer], None] | None
) = None,
temporary_persona: Persona | None = None,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> AnswerObjectIterator:
"""Streams in order:
Expand All @@ -114,11 +116,14 @@ def stream_answer_objects(
db_session=db_session,
description="", # One shot queries don't need naming as it's never displayed
user_id=user_id,
# TODO fix this - should not store? Or add a new default value (-3) for custom?
persona_id=query_req.persona_id,
one_shot=True,
danswerbot_flow=danswerbot_flow,
)
llm, fast_llm = get_llms_for_persona(persona=chat_session.persona)

persona = temporary_persona if temporary_persona else chat_session.persona
llm, fast_llm = get_llms_for_persona(persona=persona)

llm_tokenizer = get_tokenizer(
model_name=llm.config.model_name,
Expand Down Expand Up @@ -153,11 +158,11 @@ def stream_answer_objects(
prompt_id=query_req.prompt_id, user=None, db_session=db_session
)
if prompt is None:
if not chat_session.persona.prompts:
if not persona.prompts:
raise RuntimeError(
"Persona does not have any prompts - this should never happen"
)
prompt = chat_session.persona.prompts[0]
prompt = persona.prompts[0]

# Create the first User query message
new_user_message = create_new_chat_message(
Expand All @@ -174,29 +179,32 @@ def stream_answer_objects(
prompt_config = PromptConfig.from_model(prompt)
document_pruning_config = DocumentPruningConfig(
max_chunks=int(
chat_session.persona.num_chunks
if chat_session.persona.num_chunks is not None
else default_num_chunks
persona.num_chunks if persona.num_chunks is not None else default_num_chunks
),
max_tokens=max_document_tokens,
)

if temporary_persona:
for tool in temporary_persona.tools:
if tool.in_code_tool_id == "SearchTool":
pass

search_tool = SearchTool(
db_session=db_session,
user=user,
evaluation_type=LLMEvaluationType.SKIP
if DISABLE_LLM_DOC_RELEVANCE
else query_req.evaluation_type,
persona=chat_session.persona,
persona=persona,
retrieval_options=query_req.retrieval_options,
prompt_config=prompt_config,
llm=llm,
fast_llm=fast_llm,
pruning_config=document_pruning_config,
bypass_acl=bypass_acl,
chunks_above=query_req.chunks_above,
chunks_below=query_req.chunks_below,
full_doc=query_req.full_doc,
bypass_acl=bypass_acl,
)

answer_config = AnswerStyleConfig(
Expand All @@ -209,24 +217,25 @@ def stream_answer_objects(
question=query_msg.message,
answer_style_config=answer_config,
prompt_config=PromptConfig.from_model(prompt),
llm=get_main_llm_from_tuple(get_llms_for_persona(persona=chat_session.persona)),
llm=get_main_llm_from_tuple(get_llms_for_persona(persona=persona)),
single_message_history=history_str,
tools=[search_tool],
force_use_tool=ForceUseTool(
force_use=True,
tool_name=search_tool.name,
args={"query": rephrased_query},
tools=[search_tool] if search_tool else [],
force_use_tool=(
ForceUseTool(
tool_name=search_tool.name,
args={"query": rephrased_query},
force_use=True,
)
),
# for now, don't use tool calling for this flow, as we haven't
# tested quotes with tool calling too much yet
skip_explicit_tool_calling=True,
return_contexts=query_req.return_contexts,
skip_gen_ai_answer_generation=query_req.skip_gen_ai_answer_generation,
)

# won't be any ImageGenerationDisplay responses since that tool is never passed in

for packet in cast(AnswerObjectIterator, answer.processed_streamed_output):
print(packet)
# for one-shot flow, don't currently do anything with these
if isinstance(packet, ToolResponse):
# (likely fine that it comes after the initial creation of the search docs)
Expand Down Expand Up @@ -261,6 +270,7 @@ def stream_answer_objects(
applied_time_cutoff=search_response_summary.final_filters.time_cutoff,
recency_bias_multiplier=search_response_summary.recency_bias_multiplier,
)

yield initial_response

elif packet.id == SEARCH_DOC_CONTENT_ID:
Expand Down Expand Up @@ -348,7 +358,15 @@ def get_search_answer(
"""Collects the streamed one shot answer responses into a single object"""
qa_response = OneShotQAResponse()

temporary_persona: Persona | None = None
if query_req.persona_config is not None:
new_persona = create_temporary_persona(
db_session=db_session, persona_config=query_req.persona_config
)
temporary_persona = new_persona

results = stream_answer_objects(
temporary_persona=temporary_persona,
query_req=query_req,
user=user,
max_document_tokens=max_document_tokens,
Expand Down
Loading
Loading