Skip to content

Commit 9f50b2f

Browse files
authored
feat: sharepoint scalability4 (#8551)
1 parent 14416cc commit 9f50b2f

File tree

4 files changed

+419
-54
lines changed

4 files changed

+419
-54
lines changed

backend/ee/onyx/external_permissions/sharepoint/group_sync.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
from collections.abc import Generator
22

3+
from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped]
4+
35
from ee.onyx.db.external_perm import ExternalUserGroup
46
from ee.onyx.external_permissions.sharepoint.permission_utils import (
57
get_sharepoint_external_groups,
68
)
9+
from onyx.configs.app_configs import SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION
10+
from onyx.connectors.sharepoint.connector import acquire_token_for_rest
711
from onyx.connectors.sharepoint.connector import SharepointConnector
812
from onyx.db.models import ConnectorCredentialPair
913
from onyx.utils.logger import setup_logger
@@ -43,14 +47,25 @@ def sharepoint_group_sync(
4347

4448
logger.info(f"Processing {len(site_descriptors)} sites for group sync")
4549

46-
# Process each site
50+
enumerate_all = connector_config.get(
51+
"exhaustive_ad_enumeration", SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION
52+
)
53+
54+
msal_app = connector.msal_app
55+
sp_tenant_domain = connector.sp_tenant_domain
4756
for site_descriptor in site_descriptors:
4857
logger.debug(f"Processing site: {site_descriptor.url}")
4958

50-
ctx = connector._create_rest_client_context(site_descriptor.url)
59+
ctx = ClientContext(site_descriptor.url).with_access_token(
60+
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain)
61+
)
5162

52-
# Get external groups for this site
53-
external_groups = get_sharepoint_external_groups(ctx, connector.graph_client)
63+
external_groups = get_sharepoint_external_groups(
64+
ctx,
65+
connector.graph_client,
66+
get_access_token=connector._get_graph_access_token,
67+
enumerate_all_ad_groups=enumerate_all,
68+
)
5469

5570
# Yield each group
5671
for group in external_groups:

backend/ee/onyx/external_permissions/sharepoint/permission_utils.py

Lines changed: 140 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
import re
2+
import time
23
from collections import deque
4+
from collections.abc import Callable
5+
from collections.abc import Generator
36
from typing import Any
47
from urllib.parse import unquote
58
from urllib.parse import urlparse
69

10+
import requests as _requests
711
from office365.graph_client import GraphClient # type: ignore[import-untyped]
812
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore[import-untyped]
913
from office365.runtime.client_request import ClientRequestException # type: ignore
@@ -14,7 +18,11 @@
1418
from ee.onyx.db.external_perm import ExternalUserGroup
1519
from onyx.access.models import ExternalAccess
1620
from onyx.access.utils import build_ext_group_name_for_onyx
21+
from onyx.configs.app_configs import REQUEST_TIMEOUT_SECONDS
1722
from onyx.configs.constants import DocumentSource
23+
from onyx.connectors.sharepoint.connector import GRAPH_API_BASE
24+
from onyx.connectors.sharepoint.connector import GRAPH_API_MAX_RETRIES
25+
from onyx.connectors.sharepoint.connector import GRAPH_API_RETRYABLE_STATUSES
1826
from onyx.connectors.sharepoint.connector import SHARED_DOCUMENTS_MAP_REVERSE
1927
from onyx.connectors.sharepoint.connector import sleep_and_retry
2028
from onyx.utils.logger import setup_logger
@@ -33,6 +41,70 @@
3341
LIMITED_ACCESS_ROLE_NAMES = ["Limited Access", "Web-Only Limited Access"]
3442

3543

44+
AD_GROUP_ENUMERATION_THRESHOLD = 100_000
45+
46+
47+
def _graph_api_get(
48+
url: str,
49+
get_access_token: Callable[[], str],
50+
params: dict[str, str] | None = None,
51+
) -> dict[str, Any]:
52+
"""Authenticated Graph API GET with retry on transient errors."""
53+
for attempt in range(GRAPH_API_MAX_RETRIES + 1):
54+
access_token = get_access_token()
55+
headers = {"Authorization": f"Bearer {access_token}"}
56+
try:
57+
resp = _requests.get(
58+
url, headers=headers, params=params, timeout=REQUEST_TIMEOUT_SECONDS
59+
)
60+
if (
61+
resp.status_code in GRAPH_API_RETRYABLE_STATUSES
62+
and attempt < GRAPH_API_MAX_RETRIES
63+
):
64+
wait = min(int(resp.headers.get("Retry-After", str(2**attempt))), 60)
65+
logger.warning(
66+
f"Graph API {resp.status_code} on attempt {attempt + 1}, "
67+
f"retrying in {wait}s: {url}"
68+
)
69+
time.sleep(wait)
70+
continue
71+
resp.raise_for_status()
72+
return resp.json()
73+
except (_requests.ConnectionError, _requests.Timeout, _requests.HTTPError):
74+
if attempt < GRAPH_API_MAX_RETRIES:
75+
wait = min(2**attempt, 60)
76+
logger.warning(
77+
f"Graph API connection error on attempt {attempt + 1}, "
78+
f"retrying in {wait}s: {url}"
79+
)
80+
time.sleep(wait)
81+
continue
82+
raise
83+
raise RuntimeError(
84+
f"Graph API request failed after {GRAPH_API_MAX_RETRIES + 1} attempts: {url}"
85+
)
86+
87+
88+
def _iter_graph_collection(
89+
initial_url: str,
90+
get_access_token: Callable[[], str],
91+
params: dict[str, str] | None = None,
92+
) -> Generator[dict[str, Any], None, None]:
93+
"""Paginate through a Graph API collection, yielding items one at a time."""
94+
url: str | None = initial_url
95+
while url:
96+
data = _graph_api_get(url, get_access_token, params)
97+
params = None
98+
yield from data.get("value", [])
99+
url = data.get("@odata.nextLink")
100+
101+
102+
def _normalize_email(email: str) -> str:
103+
if MICROSOFT_DOMAIN in email:
104+
return email.replace(MICROSOFT_DOMAIN, "")
105+
return email
106+
107+
36108
class SharepointGroup(BaseModel):
37109
model_config = {"frozen": True}
38110

@@ -572,8 +644,63 @@ def add_user_and_group_to_sets(
572644
)
573645

574646

647+
def _enumerate_ad_groups_paginated(
648+
get_access_token: Callable[[], str],
649+
already_resolved: set[str],
650+
) -> Generator[ExternalUserGroup, None, None]:
651+
"""Paginate through all Azure AD groups and yield ExternalUserGroup for each.
652+
653+
Skips groups whose suffixed name is already in *already_resolved*.
654+
Stops early if the number of groups exceeds AD_GROUP_ENUMERATION_THRESHOLD.
655+
"""
656+
groups_url = f"{GRAPH_API_BASE}/groups"
657+
groups_params: dict[str, str] = {"$select": "id,displayName", "$top": "999"}
658+
total_groups = 0
659+
660+
for group_json in _iter_graph_collection(
661+
groups_url, get_access_token, groups_params
662+
):
663+
group_id: str = group_json.get("id", "")
664+
display_name: str = group_json.get("displayName", "")
665+
if not group_id or not display_name:
666+
continue
667+
668+
total_groups += 1
669+
if total_groups > AD_GROUP_ENUMERATION_THRESHOLD:
670+
logger.warning(
671+
f"Azure AD group enumeration exceeded {AD_GROUP_ENUMERATION_THRESHOLD} "
672+
"groups — stopping to avoid excessive memory/API usage. "
673+
"Remaining groups will be resolved from role assignments only."
674+
)
675+
return
676+
677+
name = f"{display_name}_{group_id}"
678+
if name in already_resolved:
679+
continue
680+
681+
member_emails: list[str] = []
682+
members_url = f"{GRAPH_API_BASE}/groups/{group_id}/members"
683+
members_params: dict[str, str] = {
684+
"$select": "userPrincipalName,mail",
685+
"$top": "999",
686+
}
687+
for member_json in _iter_graph_collection(
688+
members_url, get_access_token, members_params
689+
):
690+
email = member_json.get("userPrincipalName") or member_json.get("mail")
691+
if email:
692+
member_emails.append(_normalize_email(email))
693+
694+
yield ExternalUserGroup(id=name, user_emails=member_emails)
695+
696+
logger.info(f"Enumerated {total_groups} Azure AD groups via paginated Graph API")
697+
698+
575699
def get_sharepoint_external_groups(
576-
client_context: ClientContext, graph_client: GraphClient
700+
client_context: ClientContext,
701+
graph_client: GraphClient,
702+
get_access_token: Callable[[], str] | None = None,
703+
enumerate_all_ad_groups: bool = False,
577704
) -> list[ExternalUserGroup]:
578705

579706
groups: set[SharepointGroup] = set()
@@ -629,57 +756,20 @@ def add_group_to_sets(role_assignments: RoleAssignmentCollection) -> None:
629756
client_context, graph_client, groups, is_group_sync=True
630757
)
631758

632-
# get all Azure AD groups because if any group is assigned to the drive item, we don't want to miss them
633-
# We can't assign sharepoint groups to drive items or drives, so we don't need to get all sharepoint groups
634-
azure_ad_groups = sleep_and_retry(
635-
graph_client.groups.get_all(page_loaded=lambda _: None),
636-
"get_sharepoint_external_groups:get_azure_ad_groups",
637-
)
638-
logger.info(f"Azure AD Groups: {len(azure_ad_groups)}")
639-
identified_groups: set[str] = set(groups_and_members.groups_to_emails.keys())
640-
ad_groups_to_emails: dict[str, set[str]] = {}
641-
for group in azure_ad_groups:
642-
# If the group is already identified, we don't need to get the members
643-
if group.display_name in identified_groups:
644-
continue
645-
# AD groups allows same display name for multiple groups, so we need to add the GUID to the name
646-
name = group.display_name
647-
name = _get_group_name_with_suffix(group.id, name, graph_client)
648-
649-
members = sleep_and_retry(
650-
group.members.get_all(page_loaded=lambda _: None),
651-
"get_sharepoint_external_groups:get_azure_ad_groups:get_members",
652-
)
653-
for member in members:
654-
member_data = member.to_json()
655-
user_principal_name = member_data.get("userPrincipalName")
656-
mail = member_data.get("mail")
657-
if not ad_groups_to_emails.get(name):
658-
ad_groups_to_emails[name] = set()
659-
if user_principal_name:
660-
if MICROSOFT_DOMAIN in user_principal_name:
661-
user_principal_name = user_principal_name.replace(
662-
MICROSOFT_DOMAIN, ""
663-
)
664-
ad_groups_to_emails[name].add(user_principal_name)
665-
elif mail:
666-
if MICROSOFT_DOMAIN in mail:
667-
mail = mail.replace(MICROSOFT_DOMAIN, "")
668-
ad_groups_to_emails[name].add(mail)
759+
external_user_groups: list[ExternalUserGroup] = [
760+
ExternalUserGroup(id=group_name, user_emails=list(emails))
761+
for group_name, emails in groups_and_members.groups_to_emails.items()
762+
]
669763

670-
external_user_groups: list[ExternalUserGroup] = []
671-
for group_name, emails in groups_and_members.groups_to_emails.items():
672-
external_user_group = ExternalUserGroup(
673-
id=group_name,
674-
user_emails=list(emails),
764+
if not enumerate_all_ad_groups or get_access_token is None:
765+
logger.info(
766+
"Skipping exhaustive Azure AD group enumeration. "
767+
"Only groups found in site role assignments are included."
675768
)
676-
external_user_groups.append(external_user_group)
769+
return external_user_groups
677770

678-
for group_name, emails in ad_groups_to_emails.items():
679-
external_user_group = ExternalUserGroup(
680-
id=group_name,
681-
user_emails=list(emails),
682-
)
683-
external_user_groups.append(external_user_group)
771+
already_resolved = set(groups_and_members.groups_to_emails.keys())
772+
for group in _enumerate_ad_groups_paginated(get_access_token, already_resolved):
773+
external_user_groups.append(group)
684774

685775
return external_user_groups

backend/onyx/configs/app_configs.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -637,6 +637,14 @@ def get_current_tz_offset() -> int:
637637
os.environ.get("SHAREPOINT_CONNECTOR_SIZE_THRESHOLD", 20 * 1024 * 1024)
638638
)
639639

640+
# When True, group sync enumerates every Azure AD group in the tenant (expensive).
641+
# When False (default), only groups found in site role assignments are synced.
642+
# Can be overridden per-connector via the "exhaustive_ad_enumeration" key in
643+
# connector_specific_config.
644+
SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION = (
645+
os.environ.get("SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION", "").lower() == "true"
646+
)
647+
640648
BLOB_STORAGE_SIZE_THRESHOLD = int(
641649
os.environ.get("BLOB_STORAGE_SIZE_THRESHOLD", 20 * 1024 * 1024)
642650
)

0 commit comments

Comments
 (0)