Skip to content

Shore up multi tenant tests #4484

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions backend/tests/integration/common_utils/managers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,19 @@ def send_message(
use_existing_user_message=use_existing_user_message,
)

headers = (
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
)
cookies = user_performing_action.cookies if user_performing_action else None

response = requests.post(
f"{API_SERVER_URL}/chat/send-message",
json=chat_message_req.model_dump(),
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=headers,
stream=True,
cookies=cookies,
)

return ChatSessionManager.analyze_response(response)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
INVITED_BASIC_USER_EMAIL = "[email protected]"


def test_user_invitation_flow(reset_multitenant: None) -> None:
def test_admin_can_invite_users(reset_multitenant: None) -> None:
"""Test that an admin can invite both registered and non-registered users."""
# Create first user (admin)
admin_user: DATestUser = UserManager.create(name="admin")
assert UserManager.is_role(admin_user, UserRole.ADMIN)
Expand All @@ -19,16 +20,44 @@ def test_user_invitation_flow(reset_multitenant: None) -> None:
UserManager.invite_user(invited_user.email, admin_user)
UserManager.invite_user(INVITED_BASIC_USER_EMAIL, admin_user)

# Verify users are in the invited users list
invited_users = UserManager.get_invited_users(admin_user)
assert invited_user.email in [
user.email for user in invited_users
], f"User {invited_user.email} not found in invited users list"


def test_non_registered_user_gets_basic_role(reset_multitenant: None) -> None:
"""Test that a non-registered user gets a BASIC role when they register after being invited."""
# Create admin user
admin_user: DATestUser = UserManager.create(name="admin")
assert UserManager.is_role(admin_user, UserRole.ADMIN)

# Admin user invites a non-registered user
UserManager.invite_user(INVITED_BASIC_USER_EMAIL, admin_user)

# Non-registered user registers
invited_basic_user: DATestUser = UserManager.create(
name=INVITED_BASIC_USER, email=INVITED_BASIC_USER_EMAIL
)
assert UserManager.is_role(invited_basic_user, UserRole.BASIC)

# Verify the user is in the invited users list
invited_users = UserManager.get_invited_users(admin_user)
assert invited_user.email in [
user.email for user in invited_users
], f"User {invited_user.email} not found in invited users list"

def test_user_can_accept_invitation(reset_multitenant: None) -> None:
"""Test that a user can accept an invitation and join the organization with BASIC role."""
# Create admin user
admin_user: DATestUser = UserManager.create(name="admin")
assert UserManager.is_role(admin_user, UserRole.ADMIN)

# Create a user to be invited
invited_user_email = "[email protected]"

# User registers with the same email as the invitation
invited_user: DATestUser = UserManager.create(
name="invited_user", email=invited_user_email
)
# Admin user invites the user
UserManager.invite_user(invited_user_email, admin_user)

# Get user info to check tenant information
user_info = UserManager.get_user_info(invited_user)
Expand All @@ -41,16 +70,17 @@ def test_user_invitation_flow(reset_multitenant: None) -> None:
)
assert invited_tenant_id is not None, "Expected to find an invitation tenant_id"

# User accepts invitation
UserManager.accept_invitation(invited_tenant_id, invited_user)

# Get updated user info after accepting invitation
updated_user_info = UserManager.get_user_info(invited_user)
# User needs to reauthenticate after accepting invitation
# Simulate this by creating a new user instance with the same credentials
authenticated_user: DATestUser = UserManager.create(
name="invited_user", email=invited_user_email
)

# Verify the user is no longer in the invited users list
updated_invited_users = UserManager.get_invited_users(admin_user)
assert invited_user.email not in [
user.email for user in updated_invited_users
], f"User {invited_user.email} should not be in invited users list after accepting"
# Get updated user info after accepting invitation and reauthenticating
updated_user_info = UserManager.get_user_info(authenticated_user)

# Verify the user has BASIC role in the organization
assert (
Expand All @@ -64,7 +94,7 @@ def test_user_invitation_flow(reset_multitenant: None) -> None:

# Check if the invited user is in the list of users with BASIC role
invited_user_emails = [user.email for user in user_page.items]
assert invited_user.email in invited_user_emails, (
f"User {invited_user.email} not found in the list of basic users "
assert invited_user_email in invited_user_emails, (
f"User {invited_user_email} not found in the list of basic users "
f"in the organization. Available users: {invited_user_emails}"
)
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any

from onyx.db.models import UserRole
from tests.integration.common_utils.managers.api_key import APIKeyManager
from tests.integration.common_utils.managers.cc_pair import CCPairManager
Expand All @@ -11,12 +13,12 @@
from tests.integration.common_utils.test_models import DATestUser


def test_multi_tenant_access_control(reset_multitenant: None) -> None:
# Creating an admin user (first user created is automatically an admin and also proviions the tenant
def setup_test_tenants(reset_multitenant: None) -> dict[str, Any]:
"""Helper function to set up test tenants with documents and users."""
# Creating an admin user for Tenant 1
admin_user1: DATestUser = UserManager.create(
email="[email protected]",
)

assert UserManager.is_role(admin_user1, UserRole.ADMIN)

# Create Tenant 2 and its Admin User
Expand All @@ -35,6 +37,16 @@ def test_multi_tenant_access_control(reset_multitenant: None) -> None:
api_key_1.headers.update(admin_user1.headers)
LLMProviderManager.create(user_performing_action=admin_user1)

# Create connectors for Tenant 2
cc_pair_2: DATestCCPair = CCPairManager.create_from_scratch(
user_performing_action=admin_user2,
)
api_key_2: DATestAPIKey = APIKeyManager.create(
user_performing_action=admin_user2,
)
api_key_2.headers.update(admin_user2.headers)
LLMProviderManager.create(user_performing_action=admin_user2)

# Seed documents for Tenant 1
cc_pair_1.documents = []
doc1_tenant1 = DocumentManager.seed_doc_with_content(
Expand All @@ -49,16 +61,6 @@ def test_multi_tenant_access_control(reset_multitenant: None) -> None:
)
cc_pair_1.documents.extend([doc1_tenant1, doc2_tenant1])

# Create connectors for Tenant 2
cc_pair_2: DATestCCPair = CCPairManager.create_from_scratch(
user_performing_action=admin_user2,
)
api_key_2: DATestAPIKey = APIKeyManager.create(
user_performing_action=admin_user2,
)
api_key_2.headers.update(admin_user2.headers)
LLMProviderManager.create(user_performing_action=admin_user2)

# Seed documents for Tenant 2
cc_pair_2.documents = []
doc1_tenant2 = DocumentManager.seed_doc_with_content(
Expand All @@ -84,21 +86,36 @@ def test_multi_tenant_access_control(reset_multitenant: None) -> None:
user_performing_action=admin_user2
)

return {
"admin_user1": admin_user1,
"admin_user2": admin_user2,
"chat_session1": chat_session1,
"chat_session2": chat_session2,
"tenant1_doc_ids": tenant1_doc_ids,
"tenant2_doc_ids": tenant2_doc_ids,
}


def test_tenant1_can_access_own_documents(reset_multitenant: None) -> None:
"""Test that Tenant 1 can access its own documents but not Tenant 2's."""
test_data = setup_test_tenants(reset_multitenant)

# User 1 sends a message and gets a response
response1 = ChatSessionManager.send_message(
chat_session_id=chat_session1.id,
chat_session_id=test_data["chat_session1"].id,
message="What is in Tenant 1's documents?",
user_performing_action=admin_user1,
user_performing_action=test_data["admin_user1"],
)

# Assert that the search tool was used
assert response1.tool_name == "run_search"

response_doc_ids = {doc["document_id"] for doc in response1.tool_result or []}
assert tenant1_doc_ids.issubset(
assert test_data["tenant1_doc_ids"].issubset(
response_doc_ids
), "Not all Tenant 1 document IDs are in the response"
assert not response_doc_ids.intersection(
tenant2_doc_ids
test_data["tenant2_doc_ids"]
), "Tenant 2 document IDs should not be in the response"

# Assert that the contents are correct
Expand All @@ -107,21 +124,28 @@ def test_multi_tenant_access_control(reset_multitenant: None) -> None:
for doc in response1.tool_result or []
), "Tenant 1 Document Content not found in any document"


def test_tenant2_can_access_own_documents(reset_multitenant: None) -> None:
"""Test that Tenant 2 can access its own documents but not Tenant 1's."""
test_data = setup_test_tenants(reset_multitenant)

# User 2 sends a message and gets a response
response2 = ChatSessionManager.send_message(
chat_session_id=chat_session2.id,
chat_session_id=test_data["chat_session2"].id,
message="What is in Tenant 2's documents?",
user_performing_action=admin_user2,
user_performing_action=test_data["admin_user2"],
)

# Assert that the search tool was used
assert response2.tool_name == "run_search"

# Assert that the tool_result contains Tenant 2's documents
response_doc_ids = {doc["document_id"] for doc in response2.tool_result or []}
assert tenant2_doc_ids.issubset(
assert test_data["tenant2_doc_ids"].issubset(
response_doc_ids
), "Not all Tenant 2 document IDs are in the response"
assert not response_doc_ids.intersection(
tenant1_doc_ids
test_data["tenant1_doc_ids"]
), "Tenant 1 document IDs should not be in the response"

# Assert that the contents are correct
Expand All @@ -130,28 +154,91 @@ def test_multi_tenant_access_control(reset_multitenant: None) -> None:
for doc in response2.tool_result or []
), "Tenant 2 Document Content not found in any document"


def test_tenant1_cannot_access_tenant2_documents(reset_multitenant: None) -> None:
"""Test that Tenant 1 cannot access Tenant 2's documents."""
test_data = setup_test_tenants(reset_multitenant)

# User 1 tries to access Tenant 2's documents
response_cross = ChatSessionManager.send_message(
chat_session_id=chat_session1.id,
chat_session_id=test_data["chat_session1"].id,
message="What is in Tenant 2's documents?",
user_performing_action=admin_user1,
user_performing_action=test_data["admin_user1"],
)

# Assert that the search tool was used
assert response_cross.tool_name == "run_search"

# Assert that the tool_result is empty or does not contain Tenant 2's documents
response_doc_ids = {doc["document_id"] for doc in response_cross.tool_result or []}

# Ensure none of Tenant 2's document IDs are in the response
assert not response_doc_ids.intersection(tenant2_doc_ids)
assert not response_doc_ids.intersection(test_data["tenant2_doc_ids"])


def test_tenant2_cannot_access_tenant1_documents(reset_multitenant: None) -> None:
"""Test that Tenant 2 cannot access Tenant 1's documents."""
test_data = setup_test_tenants(reset_multitenant)

# User 2 tries to access Tenant 1's documents
response_cross2 = ChatSessionManager.send_message(
chat_session_id=chat_session2.id,
chat_session_id=test_data["chat_session2"].id,
message="What is in Tenant 1's documents?",
user_performing_action=admin_user2,
user_performing_action=test_data["admin_user2"],
)

# Assert that the search tool was used
assert response_cross2.tool_name == "run_search"

# Assert that the tool_result is empty or does not contain Tenant 1's documents
response_doc_ids = {doc["document_id"] for doc in response_cross2.tool_result or []}

# Ensure none of Tenant 1's document IDs are in the response
assert not response_doc_ids.intersection(tenant1_doc_ids)
assert not response_doc_ids.intersection(test_data["tenant1_doc_ids"])


def test_multi_tenant_access_control(reset_multitenant: None) -> None:
"""Legacy test for multi-tenant access control."""
test_data = setup_test_tenants(reset_multitenant)

# User 1 sends a message and gets a response with only Tenant 1's documents
response1 = ChatSessionManager.send_message(
chat_session_id=test_data["chat_session1"].id,
message="What is in Tenant 1's documents?",
user_performing_action=test_data["admin_user1"],
)
assert response1.tool_name == "run_search"
response_doc_ids = {doc["document_id"] for doc in response1.tool_result or []}
assert test_data["tenant1_doc_ids"].issubset(response_doc_ids)
assert not response_doc_ids.intersection(test_data["tenant2_doc_ids"])

# User 2 sends a message and gets a response with only Tenant 2's documents
response2 = ChatSessionManager.send_message(
chat_session_id=test_data["chat_session2"].id,
message="What is in Tenant 2's documents?",
user_performing_action=test_data["admin_user2"],
)
assert response2.tool_name == "run_search"
response_doc_ids = {doc["document_id"] for doc in response2.tool_result or []}
assert test_data["tenant2_doc_ids"].issubset(response_doc_ids)
assert not response_doc_ids.intersection(test_data["tenant1_doc_ids"])

# User 1 tries to access Tenant 2's documents and fails
response_cross = ChatSessionManager.send_message(
chat_session_id=test_data["chat_session1"].id,
message="What is in Tenant 2's documents?",
user_performing_action=test_data["admin_user1"],
)
assert response_cross.tool_name == "run_search"
response_doc_ids = {doc["document_id"] for doc in response_cross.tool_result or []}
assert not response_doc_ids.intersection(test_data["tenant2_doc_ids"])

# User 2 tries to access Tenant 1's documents and fails
response_cross2 = ChatSessionManager.send_message(
chat_session_id=test_data["chat_session2"].id,
message="What is in Tenant 1's documents?",
user_performing_action=test_data["admin_user2"],
)
assert response_cross2.tool_name == "run_search"
response_doc_ids = {doc["document_id"] for doc in response_cross2.tool_result or []}
assert not response_doc_ids.intersection(test_data["tenant1_doc_ids"])
Loading
Loading