Id not set in checkpoint2 #4468

Merged: 5 commits, Apr 8, 2025
@@ -117,16 +117,16 @@ def generate_initial_answer(

consolidated_context_docs = structured_subquestion_docs.cited_documents
counter = 0
for original_doc_number, original_doc in enumerate(
orig_question_retrieval_documents
):
if original_doc_number not in structured_subquestion_docs.cited_documents:
if (
counter <= AGENT_MIN_ORIG_QUESTION_DOCS
or len(consolidated_context_docs) < AGENT_MAX_ANSWER_CONTEXT_DOCS
):
consolidated_context_docs.append(original_doc)
counter += 1
for original_doc in orig_question_retrieval_documents:
if original_doc in structured_subquestion_docs.cited_documents:
Contributor: what's going on here?

Contributor (Author): typing fixes

continue

if (
counter <= AGENT_MIN_ORIG_QUESTION_DOCS
or len(consolidated_context_docs) < AGENT_MAX_ANSWER_CONTEXT_DOCS
):
consolidated_context_docs.append(original_doc)
counter += 1

# sort docs by their scores - though the scores refer to different questions
relevant_docs = dedup_inference_section_list(consolidated_context_docs)
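A hypothetical, simplified illustration of what the restructured loop changes (stand-in string values; the real code works with retrieved document objects): the old version tested whether the enumerate index appeared in the cited-documents list, a membership check that can never match, so already-cited documents could be appended again (later deduplication aside); the new version tests the document itself.

cited_docs = ["doc-b", "doc-d"]                 # stand-ins for cited documents
retrieved_docs = ["doc-a", "doc-b", "doc-c"]    # stand-ins for retrieved documents

# Old behavior: membership test on the enumerate index (an int against a list
# of documents), which never matches, so cited documents slip through.
old_extra = [doc for i, doc in enumerate(retrieved_docs) if i not in cited_docs]
assert old_extra == ["doc-a", "doc-b", "doc-c"]

# New behavior: membership test on the document itself.
new_extra = [doc for doc in retrieved_docs if doc not in cited_docs]
assert new_extra == ["doc-a", "doc-c"]
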
@@ -146,10 +146,8 @@ def generate_validate_refined_answer(
consolidated_context_docs = structured_subquestion_docs.cited_documents

counter = 0
for original_doc_number, original_doc in enumerate(
original_question_verified_documents
):
if original_doc_number not in structured_subquestion_docs.cited_documents:
for original_doc in original_question_verified_documents:
if original_doc not in structured_subquestion_docs.cited_documents:
if (
counter <= AGENT_MIN_ORIG_QUESTION_DOCS
or len(consolidated_context_docs)
@@ -57,7 +57,7 @@ def choose_tool(
not agent_config.behavior.use_agentic_search
and agent_config.tooling.search_tool is not None
and (
not force_use_tool.force_use or force_use_tool.tool_name == SearchTool.name
not force_use_tool.force_use or force_use_tool.tool_name == SearchTool._NAME
)
):
override_kwargs = SearchToolOverrideKwargs()
@@ -98,9 +98,6 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
return False

if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
Contributor: why remove?

Contributor (Author): Was never True, see lines above

return False

# If the last sync is None, it has never been run so we run the sync
last_perm_sync = cc_pair.last_time_perm_sync
if last_perm_sync is None:
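A minimal standalone sketch of the control flow behind that reply (hypothetical names; Status stands in for ConnectorCredentialPairStatus): the earlier guard already returns for every non-ACTIVE status, so a later DELETING check can never fire.

from enum import Enum


class Status(Enum):
    ACTIVE = "active"
    PAUSED = "paused"
    DELETING = "deleting"


def is_sync_due(status: Status) -> bool:
    if status != Status.ACTIVE:
        return False  # DELETING (like any non-ACTIVE status) already exits here

    # A `status == Status.DELETING` check placed here is unreachable:
    # status is guaranteed to be ACTIVE at this point.
    return True


assert is_sync_due(Status.DELETING) is False  # handled by the first guard
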
42 changes: 28 additions & 14 deletions backend/onyx/connectors/google_drive/connector.py
@@ -420,7 +420,7 @@ def _impersonate_user_for_retrieval(
is_slim: bool,
checkpoint: GoogleDriveCheckpoint,
concurrent_drive_itr: Callable[[str], Iterator[str]],
filtered_folder_ids: set[str],
sorted_filtered_folder_ids: list[str],
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[RetrievedDriveFile]:
@@ -509,6 +509,7 @@ def _yield_from_drive(
yield from _yield_from_drive(drive_id, start)
curr_stage.stage = DriveRetrievalStage.FOLDER_FILES
resuming = False # we are starting the next stage for the first time

if curr_stage.stage == DriveRetrievalStage.FOLDER_FILES:

def _yield_from_folder_crawl(
@@ -526,16 +527,28 @@ def _yield_from_folder_crawl(
)

# resume from a checkpoint
last_processed_folder = None
if resuming:
folder_id = curr_stage.completed_until_parent_id
assert folder_id is not None, "folder id not set in checkpoint"
resume_start = curr_stage.completed_until
yield from _yield_from_folder_crawl(folder_id, resume_start)
last_processed_folder = folder_id

skipping_seen_folders = last_processed_folder is not None
for folder_id in sorted_filtered_folder_ids:
if skipping_seen_folders:
skipping_seen_folders = folder_id != last_processed_folder
continue

if folder_id in self._retrieved_ids:
continue

remaining_folders = filtered_folder_ids - self._retrieved_ids
for folder_id in remaining_folders:
curr_stage.completed_until = 0
curr_stage.completed_until_parent_id = folder_id
logger.info(f"Getting files in folder '{folder_id}' as '{user_email}'")
yield from _yield_from_folder_crawl(folder_id, start)

curr_stage.stage = DriveRetrievalStage.DONE

def _manage_service_account_retrieval(
@@ -584,11 +597,13 @@ def _manage_service_account_retrieval(
drive_ids_to_retrieve, checkpoint
)

sorted_filtered_folder_ids = sorted(folder_ids_to_retrieve)

# only process emails that we haven't already completed retrieval for
non_completed_org_emails = [
user_email
for user_email, stage in checkpoint.completion_map.items()
if stage != DriveRetrievalStage.DONE
for user_email, stage_completion in checkpoint.completion_map.items()
if stage_completion.stage != DriveRetrievalStage.DONE
]

# don't process too many emails before returning a checkpoint. This is
Expand All @@ -609,7 +624,7 @@ def _manage_service_account_retrieval(
is_slim,
checkpoint,
drive_id_iterator,
folder_ids_to_retrieve,
sorted_filtered_folder_ids,
start,
end,
)
@@ -837,14 +852,13 @@ def _checkpointed_retrieval(
return

for file in drive_files:
if file.error is None:
checkpoint.completion_map[file.user_email].update(
stage=file.completion_stage,
completed_until=datetime.fromisoformat(
file.drive_file[GoogleFields.MODIFIED_TIME.value]
).timestamp(),
completed_until_parent_id=file.parent_id,
)
checkpoint.completion_map[file.user_email].update(
stage=file.completion_stage,
completed_until=datetime.fromisoformat(
file.drive_file[GoogleFields.MODIFIED_TIME.value]
).timestamp(),
completed_until_parent_id=file.parent_id,
)
yield file

def _manage_oauth_retrieval(
7 changes: 4 additions & 3 deletions backend/onyx/connectors/google_drive/file_retrieval.py
@@ -122,14 +122,17 @@ def crawl_folders_for_files(
start=start,
end=end,
):
found_files = True
logger.info(f"Found file: {file['name']}, user email: {user_email}")
found_files = True
yield RetrievedDriveFile(
drive_file=file,
user_email=user_email,
parent_id=parent_id,
completion_stage=DriveRetrievalStage.FOLDER_FILES,
)
# Only mark a folder as done if it was fully traversed without errors
if found_files:
Contributor: why the move? If it's important, would prefer to add a comment as to why. If not / purely stylistic, ignore

Contributor (Author): previously we were marking folders as traversed if at least one file from the folder was retrieved without an error; now it will only be marked as done if ALL files from it are retrieved. With the new system for tracking folder completion (sorting and continuing from the last SEEN folder rather than last retrieved), this shouldn't cause us to get stuck and should let us handle pathological cases like a bunch of different users having individual access to files in a "shared folder" that isn't actually fully shared due to permission revoking.

update_traversed_ids_func(parent_id)
except Exception as e:
logger.error(f"Error getting files in parent {parent_id}: {e}")
yield RetrievedDriveFile(
Expand All @@ -139,8 +142,6 @@ def crawl_folders_for_files(
completion_stage=DriveRetrievalStage.FOLDER_FILES,
error=e,
)
if found_files:
update_traversed_ids_func(parent_id)
else:
logger.info(f"Skipping subfolder files since already traversed: {parent_id}")

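A minimal sketch of that resume strategy with hypothetical names (simplified from the connector change above): folders are visited in a fixed sorted order, the checkpoint records the last folder that was seen, and a restart skips forward past that folder before continuing.

# Hypothetical sketch: skip forward past the last folder recorded in the
# checkpoint, then continue crawling the remaining folders in sorted order.
from collections.abc import Iterator


def crawl_folders(
    sorted_folder_ids: list[str],
    last_processed_folder: str | None,
    retrieved_ids: set[str],
) -> Iterator[str]:
    skipping_seen_folders = last_processed_folder is not None
    for folder_id in sorted_folder_ids:
        if skipping_seen_folders:
            # Keep skipping until the checkpointed folder has been passed.
            skipping_seen_folders = folder_id != last_processed_folder
            continue
        if folder_id in retrieved_ids:
            continue
        yield folder_id  # file retrieval for this folder would happen here


# Resuming a run whose checkpoint recorded "folder-b" as the last folder seen:
assert list(
    crawl_folders(["folder-a", "folder-b", "folder-c"], "folder-b", set())
) == ["folder-c"]

In the actual change, the checkpoint also records a completed_until timestamp, so the last seen folder is first resumed from where it left off before the loop skips past it.
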
20 changes: 4 additions & 16 deletions backend/onyx/context/search/postprocessing/postprocessing.py
@@ -374,14 +374,6 @@ def filter_sections(
if query.evaluation_type == LLMEvaluationType.SKIP:
return []

# Additional safeguard: Log a warning if this function is ever called with SKIP evaluation type
# This should never happen if our fast paths are working correctly
if query.evaluation_type == LLMEvaluationType.SKIP:
logger.warning(
"WARNING: filter_sections called with SKIP evaluation_type. This should never happen!"
)
return []

sections_to_filter = sections_to_filter[: query.max_llm_filter_sections]

contents = [
@@ -461,12 +453,10 @@ def search_postprocessing(

llm_filter_task_id = None
# Only add LLM filtering if not in SKIP mode and if LLM doc relevance is not disabled
if (
search_query.evaluation_type not in [LLMEvaluationType.SKIP]
and not DISABLE_LLM_DOC_RELEVANCE
and search_query.evaluation_type
in [LLMEvaluationType.BASIC, LLMEvaluationType.UNSPECIFIED]
):
if not DISABLE_LLM_DOC_RELEVANCE and search_query.evaluation_type in [
LLMEvaluationType.BASIC,
LLMEvaluationType.UNSPECIFIED,
]:
logger.info("Adding LLM filtering task for document relevance evaluation")
post_processing_tasks.append(
FunctionCall(
Expand All @@ -479,8 +469,6 @@ def search_postprocessing(
)
)
llm_filter_task_id = post_processing_tasks[-1].result_id
elif search_query.evaluation_type == LLMEvaluationType.SKIP:
logger.info("Fast path: Skipping LLM filtering task for ordering-only mode")
elif DISABLE_LLM_DOC_RELEVANCE:
logger.info("Skipping LLM filtering task because LLM doc relevance is disabled")

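The removed guard was redundant: an evaluation type that is in [BASIC, UNSPECIFIED] can never also be SKIP, so dropping the extra clause and the duplicate early return does not change behavior. A standalone sketch of the equivalence (the enum stands in for LLMEvaluationType):

from enum import Enum, auto


class EvalType(Enum):
    BASIC = auto()
    UNSPECIFIED = auto()
    SKIP = auto()


def should_filter_old(eval_type: EvalType, relevance_disabled: bool) -> bool:
    return (
        eval_type not in [EvalType.SKIP]
        and not relevance_disabled
        and eval_type in [EvalType.BASIC, EvalType.UNSPECIFIED]
    )


def should_filter_new(eval_type: EvalType, relevance_disabled: bool) -> bool:
    return not relevance_disabled and eval_type in [
        EvalType.BASIC,
        EvalType.UNSPECIFIED,
    ]


# The two predicates agree for every input combination.
assert all(
    should_filter_old(e, d) == should_filter_new(e, d)
    for e in EvalType
    for d in (True, False)
)
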
2 changes: 1 addition & 1 deletion backend/onyx/server/features/tool/api.py
@@ -152,6 +152,6 @@ def list_tools(
return [
ToolSnapshot.from_model(tool)
for tool in tools
if tool.in_code_tool_id != ImageGenerationTool.name
if tool.in_code_tool_id != ImageGenerationTool._NAME
or is_image_generation_available(db_session=db_session)
]
1 change: 1 addition & 0 deletions backend/pyproject.toml
@@ -4,6 +4,7 @@ mypy_path = "$MYPY_CONFIG_FILE_DIR"
explicit_package_bases = true
disallow_untyped_defs = true
enable_error_code = ["possibly-undefined"]
strict_equality = true

[[tool.mypy.overrides]]
module = "alembic.versions.*"
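strict_equality makes mypy report comparisons and membership checks between types that cannot overlap, which is what several changes in this PR accommodate. A small standalone illustration (not from the repository), mirroring the doc-rank fix below:

def rank_is_missing(new_doc_rank: int) -> bool:
    # mypy --strict-equality: non-overlapping equality check between
    # "int" and a string literal  [comparison-overlap]
    return new_doc_rank == "not_ranked"


def rank_is_missing_fixed(new_doc_rank: int | str) -> bool:
    # Widening the parameter and narrowing with isinstance satisfies the check.
    return isinstance(new_doc_rank, str) and new_doc_rank == "not_ranked"
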
6 changes: 4 additions & 2 deletions backend/scripts/sources_selection_analysis.py
@@ -186,15 +186,17 @@ def _identify_diff(self, content_key: str) -> list[dict]:
)
return changes

def check_config_changes(self, previous_doc_rank: int, new_doc_rank: int) -> None:
def check_config_changes(
self, previous_doc_rank: int | str, new_doc_rank: int
) -> None:
"""Try to identify possible reasons why a change has been detected by
checking the latest document update date or the boost value.

Args:
previous_doc_rank (int): The document rank for the previous analysis
new_doc_rank (int): The document rank for the new analysis
"""
if new_doc_rank == "not_ranked":
if isinstance(new_doc_rank, str) and new_doc_rank == "not_ranked":
color_output(
(
"NOTE: The document is missing in the 'current' analysis file. "
@@ -3,6 +3,7 @@
import time
from pathlib import Path
from typing import Any
from typing import cast

import pytest

@@ -24,7 +25,7 @@ def extract_key_value_pairs_to_set(

def load_test_data(
file_name: str = "test_salesforce_data.json",
) -> dict[str, list[str] | dict[str, Any]]:
) -> dict[str, str | list[str] | dict[str, Any] | list[dict[str, Any]]]:
current_dir = Path(__file__).parent
with open(current_dir / file_name, "r") as f:
return json.load(f)
@@ -90,7 +91,7 @@ def test_salesforce_connector_basic(salesforce_connector: SalesforceConnector) -
if not isinstance(expected_text, list):
raise ValueError("Expected text is not a list")

unparsed_expected_key_value_pairs: list[str] = expected_text
unparsed_expected_key_value_pairs: list[str] = cast(list[str], expected_text)
received_key_value_pairs = extract_key_value_pairs_to_set(received_text)
expected_key_value_pairs = extract_key_value_pairs_to_set(
unparsed_expected_key_value_pairs
@@ -110,7 +111,12 @@ def test_salesforce_connector_basic(salesforce_connector: SalesforceConnector) -
assert primary_owner.first_name == expected_primary_owner["first_name"]
assert primary_owner.last_name == expected_primary_owner["last_name"]

assert target_test_doc.secondary_owners == test_data["secondary_owners"]
secondary_owners = (
[owner.model_dump() for owner in target_test_doc.secondary_owners]
if target_test_doc.secondary_owners
else None
)
assert secondary_owners == test_data["secondary_owners"]
assert target_test_doc.title == test_data["title"]


@@ -243,7 +243,8 @@ def verify(
and set(user.email for user in fetched_persona.users)
== set(persona.users)
and set(fetched_persona.groups) == set(persona.groups)
and set(fetched_persona.labels) == set(persona.label_ids)
and {label.id for label in fetched_persona.labels}
== set(persona.label_ids)
)
return False

@@ -192,7 +192,7 @@ def test_fuzzy_match_quotes_to_docs() -> None:
results = match_quotes_to_docs(
test_quotes, [test_chunk_0, test_chunk_1], fuzzy_search=True
)
assert results == {
assert results.model_dump() == {
"a doc with some": {"document": "test doc 0", "link": "doc 0 base"},
"a doc with some LINK": {
"document": "test doc 0",
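Several of the test updates above compare dumped model data (or extracted ids) against JSON fixtures rather than comparing model instances to plain dicts; a pydantic model instance does not compare equal to a dict, so the dumped form is the comparable one. A standalone sketch with a hypothetical model (pydantic v2 API):

from pydantic import BaseModel


class Owner(BaseModel):
    first_name: str
    last_name: str


expected = {"first_name": "Ada", "last_name": "Lovelace"}  # e.g. loaded from JSON
owner = Owner(first_name="Ada", last_name="Lovelace")

# A model instance never equals a plain dict...
assert owner != expected

# ...but its dumped representation does.
assert owner.model_dump() == expected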