|
1 | 1 | import re |
| 2 | +import time |
2 | 3 | from collections import deque |
| 4 | +from collections.abc import Callable |
| 5 | +from collections.abc import Generator |
3 | 6 | from typing import Any |
4 | 7 | from urllib.parse import unquote |
5 | 8 | from urllib.parse import urlparse |
6 | 9 |
|
| 10 | +import requests as _requests |
7 | 11 | from office365.graph_client import GraphClient # type: ignore[import-untyped] |
8 | 12 | from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore[import-untyped] |
9 | 13 | from office365.runtime.client_request import ClientRequestException # type: ignore |
|
14 | 18 | from ee.onyx.db.external_perm import ExternalUserGroup |
15 | 19 | from onyx.access.models import ExternalAccess |
16 | 20 | from onyx.access.utils import build_ext_group_name_for_onyx |
| 21 | +from onyx.configs.app_configs import REQUEST_TIMEOUT_SECONDS |
17 | 22 | from onyx.configs.constants import DocumentSource |
| 23 | +from onyx.connectors.sharepoint.connector import GRAPH_API_BASE |
| 24 | +from onyx.connectors.sharepoint.connector import GRAPH_API_MAX_RETRIES |
| 25 | +from onyx.connectors.sharepoint.connector import GRAPH_API_RETRYABLE_STATUSES |
18 | 26 | from onyx.connectors.sharepoint.connector import SHARED_DOCUMENTS_MAP_REVERSE |
19 | 27 | from onyx.connectors.sharepoint.connector import sleep_and_retry |
20 | 28 | from onyx.utils.logger import setup_logger |
|
33 | 41 | LIMITED_ACCESS_ROLE_NAMES = ["Limited Access", "Web-Only Limited Access"] |
34 | 42 |
|
35 | 43 |
|
| 44 | +AD_GROUP_ENUMERATION_THRESHOLD = 100_000 |
| 45 | + |
| 46 | + |
| 47 | +def _graph_api_get( |
| 48 | + url: str, |
| 49 | + get_access_token: Callable[[], str], |
| 50 | + params: dict[str, str] | None = None, |
| 51 | +) -> dict[str, Any]: |
| 52 | + """Authenticated Graph API GET with retry on transient errors.""" |
| 53 | + for attempt in range(GRAPH_API_MAX_RETRIES + 1): |
| 54 | + access_token = get_access_token() |
| 55 | + headers = {"Authorization": f"Bearer {access_token}"} |
| 56 | + try: |
| 57 | + resp = _requests.get( |
| 58 | + url, headers=headers, params=params, timeout=REQUEST_TIMEOUT_SECONDS |
| 59 | + ) |
| 60 | + if ( |
| 61 | + resp.status_code in GRAPH_API_RETRYABLE_STATUSES |
| 62 | + and attempt < GRAPH_API_MAX_RETRIES |
| 63 | + ): |
| 64 | + wait = min(int(resp.headers.get("Retry-After", str(2**attempt))), 60) |
| 65 | + logger.warning( |
| 66 | + f"Graph API {resp.status_code} on attempt {attempt + 1}, " |
| 67 | + f"retrying in {wait}s: {url}" |
| 68 | + ) |
| 69 | + time.sleep(wait) |
| 70 | + continue |
| 71 | + resp.raise_for_status() |
| 72 | + return resp.json() |
| 73 | + except (_requests.ConnectionError, _requests.Timeout, _requests.HTTPError): |
| 74 | + if attempt < GRAPH_API_MAX_RETRIES: |
| 75 | + wait = min(2**attempt, 60) |
| 76 | + logger.warning( |
| 77 | + f"Graph API connection error on attempt {attempt + 1}, " |
| 78 | + f"retrying in {wait}s: {url}" |
| 79 | + ) |
| 80 | + time.sleep(wait) |
| 81 | + continue |
| 82 | + raise |
| 83 | + raise RuntimeError( |
| 84 | + f"Graph API request failed after {GRAPH_API_MAX_RETRIES + 1} attempts: {url}" |
| 85 | + ) |
| 86 | + |
| 87 | + |
| 88 | +def _iter_graph_collection( |
| 89 | + initial_url: str, |
| 90 | + get_access_token: Callable[[], str], |
| 91 | + params: dict[str, str] | None = None, |
| 92 | +) -> Generator[dict[str, Any], None, None]: |
| 93 | + """Paginate through a Graph API collection, yielding items one at a time.""" |
| 94 | + url: str | None = initial_url |
| 95 | + while url: |
| 96 | + data = _graph_api_get(url, get_access_token, params) |
| 97 | + params = None |
| 98 | + yield from data.get("value", []) |
| 99 | + url = data.get("@odata.nextLink") |
| 100 | + |
| 101 | + |
| 102 | +def _normalize_email(email: str) -> str: |
| 103 | + if MICROSOFT_DOMAIN in email: |
| 104 | + return email.replace(MICROSOFT_DOMAIN, "") |
| 105 | + return email |
| 106 | + |
| 107 | + |
36 | 108 | class SharepointGroup(BaseModel): |
37 | 109 | model_config = {"frozen": True} |
38 | 110 |
|
@@ -572,8 +644,63 @@ def add_user_and_group_to_sets( |
572 | 644 | ) |
573 | 645 |
|
574 | 646 |
|
| 647 | +def _enumerate_ad_groups_paginated( |
| 648 | + get_access_token: Callable[[], str], |
| 649 | + already_resolved: set[str], |
| 650 | +) -> Generator[ExternalUserGroup, None, None]: |
| 651 | + """Paginate through all Azure AD groups and yield ExternalUserGroup for each. |
| 652 | +
|
| 653 | + Skips groups whose suffixed name is already in *already_resolved*. |
| 654 | + Stops early if the number of groups exceeds AD_GROUP_ENUMERATION_THRESHOLD. |
| 655 | + """ |
| 656 | + groups_url = f"{GRAPH_API_BASE}/groups" |
| 657 | + groups_params: dict[str, str] = {"$select": "id,displayName", "$top": "999"} |
| 658 | + total_groups = 0 |
| 659 | + |
| 660 | + for group_json in _iter_graph_collection( |
| 661 | + groups_url, get_access_token, groups_params |
| 662 | + ): |
| 663 | + group_id: str = group_json.get("id", "") |
| 664 | + display_name: str = group_json.get("displayName", "") |
| 665 | + if not group_id or not display_name: |
| 666 | + continue |
| 667 | + |
| 668 | + total_groups += 1 |
| 669 | + if total_groups > AD_GROUP_ENUMERATION_THRESHOLD: |
| 670 | + logger.warning( |
| 671 | + f"Azure AD group enumeration exceeded {AD_GROUP_ENUMERATION_THRESHOLD} " |
| 672 | + "groups — stopping to avoid excessive memory/API usage. " |
| 673 | + "Remaining groups will be resolved from role assignments only." |
| 674 | + ) |
| 675 | + return |
| 676 | + |
| 677 | + name = f"{display_name}_{group_id}" |
| 678 | + if name in already_resolved: |
| 679 | + continue |
| 680 | + |
| 681 | + member_emails: list[str] = [] |
| 682 | + members_url = f"{GRAPH_API_BASE}/groups/{group_id}/members" |
| 683 | + members_params: dict[str, str] = { |
| 684 | + "$select": "userPrincipalName,mail", |
| 685 | + "$top": "999", |
| 686 | + } |
| 687 | + for member_json in _iter_graph_collection( |
| 688 | + members_url, get_access_token, members_params |
| 689 | + ): |
| 690 | + email = member_json.get("userPrincipalName") or member_json.get("mail") |
| 691 | + if email: |
| 692 | + member_emails.append(_normalize_email(email)) |
| 693 | + |
| 694 | + yield ExternalUserGroup(id=name, user_emails=member_emails) |
| 695 | + |
| 696 | + logger.info(f"Enumerated {total_groups} Azure AD groups via paginated Graph API") |
| 697 | + |
| 698 | + |
575 | 699 | def get_sharepoint_external_groups( |
576 | | - client_context: ClientContext, graph_client: GraphClient |
| 700 | + client_context: ClientContext, |
| 701 | + graph_client: GraphClient, |
| 702 | + get_access_token: Callable[[], str] | None = None, |
| 703 | + enumerate_all_ad_groups: bool = False, |
577 | 704 | ) -> list[ExternalUserGroup]: |
578 | 705 |
|
579 | 706 | groups: set[SharepointGroup] = set() |
@@ -629,57 +756,20 @@ def add_group_to_sets(role_assignments: RoleAssignmentCollection) -> None: |
629 | 756 | client_context, graph_client, groups, is_group_sync=True |
630 | 757 | ) |
631 | 758 |
|
632 | | - # get all Azure AD groups because if any group is assigned to the drive item, we don't want to miss them |
633 | | - # We can't assign sharepoint groups to drive items or drives, so we don't need to get all sharepoint groups |
634 | | - azure_ad_groups = sleep_and_retry( |
635 | | - graph_client.groups.get_all(page_loaded=lambda _: None), |
636 | | - "get_sharepoint_external_groups:get_azure_ad_groups", |
637 | | - ) |
638 | | - logger.info(f"Azure AD Groups: {len(azure_ad_groups)}") |
639 | | - identified_groups: set[str] = set(groups_and_members.groups_to_emails.keys()) |
640 | | - ad_groups_to_emails: dict[str, set[str]] = {} |
641 | | - for group in azure_ad_groups: |
642 | | - # If the group is already identified, we don't need to get the members |
643 | | - if group.display_name in identified_groups: |
644 | | - continue |
645 | | - # AD groups allows same display name for multiple groups, so we need to add the GUID to the name |
646 | | - name = group.display_name |
647 | | - name = _get_group_name_with_suffix(group.id, name, graph_client) |
648 | | - |
649 | | - members = sleep_and_retry( |
650 | | - group.members.get_all(page_loaded=lambda _: None), |
651 | | - "get_sharepoint_external_groups:get_azure_ad_groups:get_members", |
652 | | - ) |
653 | | - for member in members: |
654 | | - member_data = member.to_json() |
655 | | - user_principal_name = member_data.get("userPrincipalName") |
656 | | - mail = member_data.get("mail") |
657 | | - if not ad_groups_to_emails.get(name): |
658 | | - ad_groups_to_emails[name] = set() |
659 | | - if user_principal_name: |
660 | | - if MICROSOFT_DOMAIN in user_principal_name: |
661 | | - user_principal_name = user_principal_name.replace( |
662 | | - MICROSOFT_DOMAIN, "" |
663 | | - ) |
664 | | - ad_groups_to_emails[name].add(user_principal_name) |
665 | | - elif mail: |
666 | | - if MICROSOFT_DOMAIN in mail: |
667 | | - mail = mail.replace(MICROSOFT_DOMAIN, "") |
668 | | - ad_groups_to_emails[name].add(mail) |
| 759 | + external_user_groups: list[ExternalUserGroup] = [ |
| 760 | + ExternalUserGroup(id=group_name, user_emails=list(emails)) |
| 761 | + for group_name, emails in groups_and_members.groups_to_emails.items() |
| 762 | + ] |
669 | 763 |
|
670 | | - external_user_groups: list[ExternalUserGroup] = [] |
671 | | - for group_name, emails in groups_and_members.groups_to_emails.items(): |
672 | | - external_user_group = ExternalUserGroup( |
673 | | - id=group_name, |
674 | | - user_emails=list(emails), |
| 764 | + if not enumerate_all_ad_groups or get_access_token is None: |
| 765 | + logger.info( |
| 766 | + "Skipping exhaustive Azure AD group enumeration. " |
| 767 | + "Only groups found in site role assignments are included." |
675 | 768 | ) |
676 | | - external_user_groups.append(external_user_group) |
| 769 | + return external_user_groups |
677 | 770 |
|
678 | | - for group_name, emails in ad_groups_to_emails.items(): |
679 | | - external_user_group = ExternalUserGroup( |
680 | | - id=group_name, |
681 | | - user_emails=list(emails), |
682 | | - ) |
683 | | - external_user_groups.append(external_user_group) |
| 771 | + already_resolved = set(groups_and_members.groups_to_emails.keys()) |
| 772 | + for group in _enumerate_ad_groups_paginated(get_access_token, already_resolved): |
| 773 | + external_user_groups.append(group) |
684 | 774 |
|
685 | 775 | return external_user_groups |
0 commit comments