Skip to content

Commit 983c471

Browse files
authored
feat: sharepoint scalability5 (#8631)
1 parent fb7e7e4 commit 983c471

File tree

7 files changed

+137
-28
lines changed

7 files changed

+137
-28
lines changed

backend/ee/onyx/external_permissions/sharepoint/group_sync.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,16 +53,18 @@ def sharepoint_group_sync(
5353

5454
msal_app = connector.msal_app
5555
sp_tenant_domain = connector.sp_tenant_domain
56+
sp_domain_suffix = connector.sharepoint_domain_suffix
5657
for site_descriptor in site_descriptors:
5758
logger.debug(f"Processing site: {site_descriptor.url}")
5859

5960
ctx = ClientContext(site_descriptor.url).with_access_token(
60-
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain)
61+
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain, sp_domain_suffix)
6162
)
6263

6364
external_groups = get_sharepoint_external_groups(
6465
ctx,
6566
connector.graph_client,
67+
graph_api_base=connector.graph_api_base,
6668
get_access_token=connector._get_graph_access_token,
6769
enumerate_all_ad_groups=enumerate_all,
6870
)

backend/ee/onyx/external_permissions/sharepoint/permission_utils.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from onyx.access.utils import build_ext_group_name_for_onyx
2121
from onyx.configs.app_configs import REQUEST_TIMEOUT_SECONDS
2222
from onyx.configs.constants import DocumentSource
23-
from onyx.connectors.sharepoint.connector import GRAPH_API_BASE
2423
from onyx.connectors.sharepoint.connector import GRAPH_API_MAX_RETRIES
2524
from onyx.connectors.sharepoint.connector import GRAPH_API_RETRYABLE_STATUSES
2625
from onyx.connectors.sharepoint.connector import SHARED_DOCUMENTS_MAP_REVERSE
@@ -647,13 +646,14 @@ def add_user_and_group_to_sets(
647646
def _enumerate_ad_groups_paginated(
648647
get_access_token: Callable[[], str],
649648
already_resolved: set[str],
649+
graph_api_base: str,
650650
) -> Generator[ExternalUserGroup, None, None]:
651651
"""Paginate through all Azure AD groups and yield ExternalUserGroup for each.
652652
653653
Skips groups whose suffixed name is already in *already_resolved*.
654654
Stops early if the number of groups exceeds AD_GROUP_ENUMERATION_THRESHOLD.
655655
"""
656-
groups_url = f"{GRAPH_API_BASE}/groups"
656+
groups_url = f"{graph_api_base}/groups"
657657
groups_params: dict[str, str] = {"$select": "id,displayName", "$top": "999"}
658658
total_groups = 0
659659

@@ -679,7 +679,7 @@ def _enumerate_ad_groups_paginated(
679679
continue
680680

681681
member_emails: list[str] = []
682-
members_url = f"{GRAPH_API_BASE}/groups/{group_id}/members"
682+
members_url = f"{graph_api_base}/groups/{group_id}/members"
683683
members_params: dict[str, str] = {
684684
"$select": "userPrincipalName,mail",
685685
"$top": "999",
@@ -699,6 +699,7 @@ def _enumerate_ad_groups_paginated(
699699
def get_sharepoint_external_groups(
700700
client_context: ClientContext,
701701
graph_client: GraphClient,
702+
graph_api_base: str,
702703
get_access_token: Callable[[], str] | None = None,
703704
enumerate_all_ad_groups: bool = False,
704705
) -> list[ExternalUserGroup]:
@@ -769,7 +770,9 @@ def add_group_to_sets(role_assignments: RoleAssignmentCollection) -> None:
769770
return external_user_groups
770771

771772
already_resolved = set(groups_and_members.groups_to_emails.keys())
772-
for group in _enumerate_ad_groups_paginated(get_access_token, already_resolved):
773+
for group in _enumerate_ad_groups_paginated(
774+
get_access_token, already_resolved, graph_api_base
775+
):
773776
external_user_groups.append(group)
774777

775778
return external_user_groups

backend/onyx/connectors/sharepoint/connector.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,11 @@
8383

8484
ASPX_EXTENSION = ".aspx"
8585

86-
GRAPH_API_BASE = "https://graph.microsoft.com/v1.0"
86+
DEFAULT_AUTHORITY_HOST = "https://login.microsoftonline.com"
87+
DEFAULT_GRAPH_API_HOST = "https://graph.microsoft.com"
88+
DEFAULT_SHAREPOINT_DOMAIN_SUFFIX = "sharepoint.com"
89+
90+
GRAPH_API_BASE = f"{DEFAULT_GRAPH_API_HOST}/v1.0"
8791
GRAPH_API_MAX_RETRIES = 5
8892
GRAPH_API_RETRYABLE_STATUSES = frozenset({429, 500, 502, 503, 504})
8993

@@ -285,10 +289,12 @@ def load_certificate_from_pfx(pfx_data: bytes, password: str) -> CertificateData
285289

286290

287291
def acquire_token_for_rest(
288-
msal_app: msal.ConfidentialClientApplication, sp_tenant_domain: str
292+
msal_app: msal.ConfidentialClientApplication,
293+
sp_tenant_domain: str,
294+
sharepoint_domain_suffix: str,
289295
) -> TokenResponse:
290296
token = msal_app.acquire_token_for_client(
291-
scopes=[f"https://{sp_tenant_domain}.sharepoint.com/.default"]
297+
scopes=[f"https://{sp_tenant_domain}.{sharepoint_domain_suffix}/.default"]
292298
)
293299
return TokenResponse.from_json(token)
294300

@@ -403,12 +409,13 @@ def _download_via_graph_api(
403409
drive_id: str,
404410
item_id: str,
405411
bytes_allowed: int,
412+
graph_api_base: str,
406413
) -> bytes:
407414
"""Download a drive item via the Graph API /content endpoint with a byte cap.
408415
409416
Raises SizeCapExceeded if the cap is exceeded.
410417
"""
411-
url = f"{GRAPH_API_BASE}/drives/{drive_id}/items/{item_id}/content"
418+
url = f"{graph_api_base}/drives/{drive_id}/items/{item_id}/content"
412419
headers = {"Authorization": f"Bearer {access_token}"}
413420
with requests.get(
414421
url, headers=headers, stream=True, timeout=REQUEST_TIMEOUT_SECONDS
@@ -429,6 +436,7 @@ def _convert_driveitem_to_document_with_permissions(
429436
drive_name: str,
430437
ctx: ClientContext | None,
431438
graph_client: GraphClient,
439+
graph_api_base: str,
432440
include_permissions: bool = False,
433441
parent_hierarchy_raw_node_id: str | None = None,
434442
access_token: str | None = None,
@@ -485,6 +493,7 @@ def _convert_driveitem_to_document_with_permissions(
485493
driveitem.drive_id,
486494
driveitem.id,
487495
SHAREPOINT_CONNECTOR_SIZE_THRESHOLD,
496+
graph_api_base=graph_api_base,
488497
)
489498
except SizeCapExceeded:
490499
logger.warning(
@@ -804,6 +813,9 @@ def __init__(
804813
sites: list[str] = [],
805814
include_site_pages: bool = True,
806815
include_site_documents: bool = True,
816+
authority_host: str = DEFAULT_AUTHORITY_HOST,
817+
graph_api_host: str = DEFAULT_GRAPH_API_HOST,
818+
sharepoint_domain_suffix: str = DEFAULT_SHAREPOINT_DOMAIN_SUFFIX,
807819
) -> None:
808820
self.batch_size = batch_size
809821
self.sites = list(sites)
@@ -819,6 +831,10 @@ def __init__(
819831
self._cached_rest_ctx: ClientContext | None = None
820832
self._cached_rest_ctx_url: str | None = None
821833
self._cached_rest_ctx_created_at: float = 0.0
834+
self.authority_host = authority_host.rstrip("/")
835+
self.graph_api_host = graph_api_host.rstrip("/")
836+
self.graph_api_base = f"{self.graph_api_host}/v1.0"
837+
self.sharepoint_domain_suffix = sharepoint_domain_suffix
822838

823839
def validate_connector_settings(self) -> None:
824840
# Validate that at least one content type is enabled
@@ -875,8 +891,9 @@ def _create_rest_client_context(self, site_url: str) -> ClientContext:
875891

876892
msal_app = self.msal_app
877893
sp_tenant_domain = self.sp_tenant_domain
894+
sp_domain_suffix = self.sharepoint_domain_suffix
878895
self._cached_rest_ctx = ClientContext(site_url).with_access_token(
879-
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain)
896+
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain, sp_domain_suffix)
880897
)
881898
self._cached_rest_ctx_url = site_url
882899
self._cached_rest_ctx_created_at = time.monotonic()
@@ -1148,7 +1165,7 @@ def _fetch_site_pages(
11481165
site_id = site.id
11491166

11501167
page_url: str | None = (
1151-
f"{GRAPH_API_BASE}/sites/{site_id}/pages/microsoft.graph.sitePage"
1168+
f"{self.graph_api_base}/sites/{site_id}" f"/pages/microsoft.graph.sitePage"
11521169
)
11531170
params: dict[str, str] | None = {"$expand": "canvasLayout"}
11541171
total_yielded = 0
@@ -1175,7 +1192,7 @@ def _acquire_token(self) -> dict[str, Any]:
11751192
raise RuntimeError("MSAL app is not initialized")
11761193

11771194
token = self.msal_app.acquire_token_for_client(
1178-
scopes=["https://graph.microsoft.com/.default"]
1195+
scopes=[f"{self.graph_api_host}/.default"]
11791196
)
11801197
return token
11811198

@@ -1248,7 +1265,7 @@ def _iter_drive_items_paged(
12481265
Performs BFS folder traversal manually, fetching one page of children
12491266
at a time so that memory usage stays bounded regardless of drive size.
12501267
"""
1251-
base = f"{GRAPH_API_BASE}/drives/{drive_id}"
1268+
base = f"{self.graph_api_base}/drives/{drive_id}"
12521269
if folder_path:
12531270
start_url = f"{base}/root:/{folder_path}:/children"
12541271
else:
@@ -1308,7 +1325,7 @@ def _iter_drive_items_delta(
13081325
"""
13091326
use_timestamp_token = start is not None and start > _EPOCH
13101327

1311-
initial_url = f"{GRAPH_API_BASE}/drives/{drive_id}/root/delta"
1328+
initial_url = f"{self.graph_api_base}/drives/{drive_id}/root/delta"
13121329
if use_timestamp_token:
13131330
assert start is not None # mypy
13141331
token = quote(start.isoformat(timespec="seconds"))
@@ -1354,7 +1371,7 @@ def _iter_delta_pages(
13541371
drive_id,
13551372
)
13561373
yield from self._iter_delta_pages(
1357-
initial_url=f"{GRAPH_API_BASE}/drives/{drive_id}/root/delta",
1374+
initial_url=f"{self.graph_api_base}/drives/{drive_id}/root/delta",
13581375
drive_id=drive_id,
13591376
start=start,
13601377
end=end,
@@ -1471,7 +1488,7 @@ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None
14711488
sp_private_key = credentials.get("sp_private_key")
14721489
sp_certificate_password = credentials.get("sp_certificate_password")
14731490

1474-
authority_url = f"https://login.microsoftonline.com/{sp_directory_id}"
1491+
authority_url = f"{self.authority_host}/{sp_directory_id}"
14751492

14761493
if auth_method == SharepointAuthMethod.CERTIFICATE.value:
14771494
logger.info("Using certificate authentication")
@@ -1512,7 +1529,7 @@ def _acquire_token_for_graph() -> dict[str, Any]:
15121529
raise ConnectorValidationError("MSAL app is not initialized")
15131530

15141531
token = self.msal_app.acquire_token_for_client(
1515-
scopes=["https://graph.microsoft.com/.default"]
1532+
scopes=[f"{self.graph_api_host}/.default"]
15161533
)
15171534
if token is None:
15181535
raise ConnectorValidationError("Failed to acquire token for graph")
@@ -1941,6 +1958,7 @@ def _load_from_checkpoint(
19411958
self.graph_client,
19421959
include_permissions=include_permissions,
19431960
parent_hierarchy_raw_node_id=parent_hierarchy_url,
1961+
graph_api_base=self.graph_api_base,
19441962
access_token=access_token,
19451963
)
19461964

backend/onyx/connectors/teams/connector.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,24 +50,31 @@ class TeamsCheckpoint(ConnectorCheckpoint):
5050
todo_team_ids: list[str] | None = None
5151

5252

53+
DEFAULT_AUTHORITY_HOST = "https://login.microsoftonline.com"
54+
DEFAULT_GRAPH_API_HOST = "https://graph.microsoft.com"
55+
56+
5357
class TeamsConnector(
5458
CheckpointedConnectorWithPermSync[TeamsCheckpoint],
5559
SlimConnectorWithPermSync,
5660
):
5761
MAX_WORKERS = 10
58-
AUTHORITY_URL_PREFIX = "https://login.microsoftonline.com/"
5962

6063
def __init__(
6164
self,
6265
# TODO: (chris) move from "Display Names" to IDs, since display names
6366
# are not necessarily guaranteed to be unique
6467
teams: list[str] = [],
6568
max_workers: int = MAX_WORKERS,
69+
authority_host: str = DEFAULT_AUTHORITY_HOST,
70+
graph_api_host: str = DEFAULT_GRAPH_API_HOST,
6671
) -> None:
6772
self.graph_client: GraphClient | None = None
6873
self.msal_app: msal.ConfidentialClientApplication | None = None
6974
self.max_workers = max_workers
7075
self.requested_team_list: list[str] = teams
76+
self.authority_host = authority_host.rstrip("/")
77+
self.graph_api_host = graph_api_host.rstrip("/")
7178

7279
# impls for BaseConnector
7380

@@ -76,7 +83,7 @@ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None
7683
teams_client_secret = credentials["teams_client_secret"]
7784
teams_directory_id = credentials["teams_directory_id"]
7885

79-
authority_url = f"{TeamsConnector.AUTHORITY_URL_PREFIX}{teams_directory_id}"
86+
authority_url = f"{self.authority_host}/{teams_directory_id}"
8087
self.msal_app = msal.ConfidentialClientApplication(
8188
authority=authority_url,
8289
client_id=teams_client_id,
@@ -91,7 +98,7 @@ def _acquire_token_func() -> dict[str, Any]:
9198
raise RuntimeError("MSAL app is not initialized")
9299

93100
token = self.msal_app.acquire_token_for_client(
94-
scopes=["https://graph.microsoft.com/.default"]
101+
scopes=[f"{self.graph_api_host}/.default"]
95102
)
96103

97104
if not isinstance(token, dict):

backend/tests/unit/ee/onyx/external_permissions/sharepoint/test_permission_utils.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323

2424
MODULE = "ee.onyx.external_permissions.sharepoint.permission_utils"
25+
GRAPH_API_BASE = "https://graph.microsoft.com/v1.0"
2526

2627

2728
# ---------------------------------------------------------------------------
@@ -125,7 +126,11 @@ def test_enumerate_ad_groups_yields_groups(mock_get: MagicMock) -> None:
125126
}
126127
mock_get.side_effect = _mock_graph_get_for_enumeration(groups, members)
127128

128-
results = list(_enumerate_ad_groups_paginated(_fake_token, already_resolved=set()))
129+
results = list(
130+
_enumerate_ad_groups_paginated(
131+
_fake_token, already_resolved=set(), graph_api_base=GRAPH_API_BASE
132+
)
133+
)
129134

130135
assert len(results) == 2
131136
eng = next(r for r in results if r.id == "Engineering_g1")
@@ -140,7 +145,11 @@ def test_enumerate_ad_groups_skips_already_resolved(mock_get: MagicMock) -> None
140145
mock_get.side_effect = _mock_graph_get_for_enumeration(groups, {})
141146

142147
results = list(
143-
_enumerate_ad_groups_paginated(_fake_token, already_resolved={"Engineering_g1"})
148+
_enumerate_ad_groups_paginated(
149+
_fake_token,
150+
already_resolved={"Engineering_g1"},
151+
graph_api_base=GRAPH_API_BASE,
152+
)
144153
)
145154
assert results == []
146155

@@ -152,7 +161,11 @@ def test_enumerate_ad_groups_circuit_breaker(mock_get: MagicMock) -> None:
152161
groups = [{"id": f"g{i}", "displayName": f"Group{i}"} for i in range(over_limit)]
153162
mock_get.side_effect = _mock_graph_get_for_enumeration(groups, {})
154163

155-
results = list(_enumerate_ad_groups_paginated(_fake_token, already_resolved=set()))
164+
results = list(
165+
_enumerate_ad_groups_paginated(
166+
_fake_token, already_resolved=set(), graph_api_base=GRAPH_API_BASE
167+
)
168+
)
156169
assert len(results) <= AD_GROUP_ENUMERATION_THRESHOLD
157170

158171

@@ -189,6 +202,7 @@ def test_default_skips_ad_enumeration(
189202
results = get_sharepoint_external_groups(
190203
client_context=MagicMock(),
191204
graph_client=MagicMock(),
205+
graph_api_base=GRAPH_API_BASE,
192206
)
193207

194208
assert len(results) == 1
@@ -219,6 +233,7 @@ def test_enumerate_all_includes_ad_groups(
219233
graph_client=MagicMock(),
220234
get_access_token=_fake_token,
221235
enumerate_all_ad_groups=True,
236+
graph_api_base=GRAPH_API_BASE,
222237
)
223238

224239
assert len(results) == 2
@@ -246,6 +261,7 @@ def test_enumerate_all_without_token_skips(
246261
graph_client=MagicMock(),
247262
get_access_token=None,
248263
enumerate_all_ad_groups=True,
264+
graph_api_base=GRAPH_API_BASE,
249265
)
250266

251267
assert results == []

backend/tests/unit/onyx/connectors/sharepoint/test_drive_matching.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@ def fake_convert(
206206
drive_name: str,
207207
ctx: Any, # noqa: ARG001
208208
graph_client: Any, # noqa: ARG001
209+
graph_api_base: str, # noqa: ARG001
209210
include_permissions: bool, # noqa: ARG001
210211
parent_hierarchy_raw_node_id: str | None = None, # noqa: ARG001
211212
access_token: str | None = None, # noqa: ARG001

0 commit comments

Comments
 (0)