|
1 | 1 | import logging |
2 | 2 | import warnings |
3 | | -from collections.abc import Callable, Generator |
| 3 | +from collections.abc import Callable, Generator, Iterable |
4 | 4 | from datetime import timedelta |
5 | 5 |
|
| 6 | +import pytest |
6 | 7 | from pytest import fixture |
7 | | -from databricks.sdk import AccountGroupsAPI, GroupsAPI, WorkspaceClient |
| 8 | +from databricks.sdk.credentials_provider import OAuthCredentialsProvider, OauthCredentialsStrategy |
| 9 | +from databricks.sdk.oauth import ClientCredentials, Token |
| 10 | +from databricks.sdk.service.oauth2 import CreateServicePrincipalSecretResponse |
| 11 | +from databricks.labs.lsql import Row |
| 12 | +from databricks.labs.lsql.backends import StatementExecutionBackend, SqlBackend |
| 13 | +from databricks.sdk import AccountGroupsAPI, GroupsAPI, WorkspaceClient, AccountClient |
8 | 14 | from databricks.sdk.config import Config |
9 | 15 | from databricks.sdk.errors import ResourceConflict, NotFound |
10 | 16 | from databricks.sdk.retries import retried |
11 | 17 | from databricks.sdk.service import iam |
12 | | -from databricks.sdk.service.iam import User, Group |
| 18 | +from databricks.sdk.service.iam import ( |
| 19 | + User, |
| 20 | + Group, |
| 21 | + ServicePrincipal, |
| 22 | + Patch, |
| 23 | + PatchOp, |
| 24 | + ComplexValue, |
| 25 | + PatchSchema, |
| 26 | + WorkspacePermission, |
| 27 | +) |
13 | 28 |
|
14 | 29 | from databricks.labs.pytester.fixtures.baseline import factory |
15 | 30 |
|
@@ -183,3 +198,177 @@ def create( |
183 | 198 | return group |
184 | 199 |
|
185 | 200 | yield from factory(name, create, lambda item: interface.delete(item.id)) |
| 201 | + |
| 202 | + |
| 203 | +class RunAs: |
| 204 | + def __init__(self, service_principal: ServicePrincipal, workspace_client: WorkspaceClient, env_or_skip): |
| 205 | + self._service_principal = service_principal |
| 206 | + self._workspace_client = workspace_client |
| 207 | + self._env_or_skip = env_or_skip |
| 208 | + |
| 209 | + @property |
| 210 | + def ws(self): |
| 211 | + return self._workspace_client |
| 212 | + |
| 213 | + @property |
| 214 | + def sql_backend(self) -> SqlBackend: |
| 215 | + # TODO: Switch to `__getattr__` + `SubRequest` to get a generic way of initializing all workspace fixtures. |
| 216 | + # This will allow us to remove the `sql_backend` fixture and make the `RunAs` class more generic. |
| 217 | + # It turns out to be more complicated than it first appears, because we don't get these at pytest.collect phase. |
| 218 | + warehouse_id = self._env_or_skip("DATABRICKS_WAREHOUSE_ID") |
| 219 | + return StatementExecutionBackend(self._workspace_client, warehouse_id) |
| 220 | + |
| 221 | + def sql_exec(self, statement: str) -> None: |
| 222 | + return self.sql_backend.execute(statement) |
| 223 | + |
| 224 | + def sql_fetch_all(self, statement: str) -> Iterable[Row]: |
| 225 | + return self.sql_backend.fetch(statement) |
| 226 | + |
| 227 | + def __getattr__(self, item: str): |
| 228 | + if item in self.__dict__: |
| 229 | + return self.__dict__[item] |
| 230 | + fixture_value = self._request.getfixturevalue(item) |
| 231 | + return fixture_value |
| 232 | + |
| 233 | + @property |
| 234 | + def display_name(self) -> str: |
| 235 | + assert self._service_principal.display_name is not None |
| 236 | + return self._service_principal.display_name |
| 237 | + |
| 238 | + @property |
| 239 | + def application_id(self) -> str: |
| 240 | + assert self._service_principal.application_id is not None |
| 241 | + return self._service_principal.application_id |
| 242 | + |
| 243 | + def __repr__(self): |
| 244 | + return f'RunAs({self.display_name})' |
| 245 | + |
| 246 | + |
| 247 | +def _make_workspace_client( |
| 248 | + ws: WorkspaceClient, |
| 249 | + created_secret: CreateServicePrincipalSecretResponse, |
| 250 | + service_principal: ServicePrincipal, |
| 251 | +) -> WorkspaceClient: |
| 252 | + oidc = ws.config.oidc_endpoints |
| 253 | + assert oidc is not None, 'OIDC is required' |
| 254 | + application_id = service_principal.application_id |
| 255 | + secret_value = created_secret.secret |
| 256 | + assert application_id is not None |
| 257 | + assert secret_value is not None |
| 258 | + |
| 259 | + token_source = ClientCredentials( |
| 260 | + client_id=application_id, |
| 261 | + client_secret=secret_value, |
| 262 | + token_url=oidc.token_endpoint, |
| 263 | + scopes=["all-apis"], |
| 264 | + use_header=True, |
| 265 | + ) |
| 266 | + |
| 267 | + def inner() -> dict[str, str]: |
| 268 | + inner_token = token_source.token() |
| 269 | + return {'Authorization': f'{inner_token.token_type} {inner_token.access_token}'} |
| 270 | + |
| 271 | + def token() -> Token: |
| 272 | + return token_source.token() |
| 273 | + |
| 274 | + credentials_provider = OAuthCredentialsProvider(inner, token) |
| 275 | + credentials_strategy = OauthCredentialsStrategy('oauth-m2m', lambda _: credentials_provider) |
| 276 | + ws_as_spn = WorkspaceClient(host=ws.config.host, credentials_strategy=credentials_strategy) |
| 277 | + return ws_as_spn |
| 278 | + |
| 279 | + |
| 280 | +@fixture |
| 281 | +def make_run_as(acc: AccountClient, ws: WorkspaceClient, make_random, env_or_skip, log_account_link, is_in_debug): |
| 282 | + """ |
| 283 | + This fixture provides a function to create an account service principal via [`acc` fixture](#acc-fixture) and |
| 284 | + assign it to a workspace. The service principal is removed after the test is complete. The service principal is |
| 285 | + created with a random display name and assigned to the workspace with the default permissions. |
| 286 | +
|
| 287 | + Use the `account_groups` argument to assign the service principal to account groups, which have the required |
| 288 | + permissions to perform a specific action. |
| 289 | +
|
| 290 | + Example: |
| 291 | +
|
| 292 | + ```python |
| 293 | + def test_run_as_lower_privilege_user(make_run_as, ws): |
| 294 | + run_as = make_run_as(account_groups=['account.group.name']) |
| 295 | + through_query = next(run_as.sql_fetch_all("SELECT CURRENT_USER() AS my_name")) |
| 296 | + me = ws.current_user.me() |
| 297 | + assert me.user_name != through_query.my_name |
| 298 | + ``` |
| 299 | +
|
| 300 | + Returned object has the following properties: |
| 301 | + * `ws`: Workspace client that is authenticated as the ephemeral service principal. |
| 302 | + * `sql_backend`: SQL backend that is authenticated as the ephemeral service principal. |
| 303 | + * `sql_exec`: Function to execute a SQL statement on behalf of the ephemeral service principal. |
| 304 | + * `sql_fetch_all`: Function to fetch all rows from a SQL statement on behalf of the ephemeral service principal. |
| 305 | + * `display_name`: Display name of the ephemeral service principal. |
| 306 | + * `application_id`: Application ID of the ephemeral service principal. |
| 307 | + * if you want to have other fixtures available in the context of the ephemeral service principal, you can override |
| 308 | + the [`ws` fixture](#ws-fixture) on the file level, which would make all workspace fixtures provided by this |
| 309 | + plugin to run as lower privilege ephemeral service principal. You cannot combine it with the account-admin-level |
| 310 | + principal you're using to create the ephemeral principal. |
| 311 | +
|
| 312 | + Example: |
| 313 | +
|
| 314 | + ```python |
| 315 | + from pytest import fixture |
| 316 | +
|
| 317 | + @fixture |
| 318 | + def ws(make_run_as): |
| 319 | + run_as = make_run_as(account_groups=['account.group.used.for.all.tests.in.this.file']) |
| 320 | + return run_as.ws |
| 321 | +
|
| 322 | + def test_creating_notebook_on_behalf_of_ephemeral_principal(make_notebook): |
| 323 | + notebook = make_notebook() |
| 324 | + assert notebook.exists() |
| 325 | + ``` |
| 326 | +
|
| 327 | + This fixture currently doesn't work with Databricks Metadata Service authentication on Azure Databricks. |
| 328 | + """ |
| 329 | + |
| 330 | + if ws.config.auth_type == 'metadata-service' and ws.config.is_azure: |
| 331 | + # TODO: fix `invalid_scope: AADSTS1002012: The provided value for scope all-apis is not valid.` error |
| 332 | + # |
| 333 | + # We're having issues with the Azure Metadata Service and service principals. The error message is: |
| 334 | + # Client credential flows must have a scope value with /.default suffixed to the resource identifier |
| 335 | + # (application ID URI) |
| 336 | + pytest.skip('Azure Metadata Service does not support service principals') |
| 337 | + |
| 338 | + def create(*, account_groups: list[str] | None = None): |
| 339 | + workspace_id = ws.get_workspace_id() |
| 340 | + service_principal = acc.service_principals.create(display_name=f'spn-{make_random()}') |
| 341 | + assert service_principal.id is not None |
| 342 | + service_principal_id = int(service_principal.id) |
| 343 | + created_secret = acc.service_principal_secrets.create(service_principal_id) |
| 344 | + if account_groups: |
| 345 | + group_mapping = {} |
| 346 | + for group in acc.groups.list(attributes='id,displayName'): |
| 347 | + if group.id is None: |
| 348 | + continue |
| 349 | + group_mapping[group.display_name] = group.id |
| 350 | + for group_name in account_groups: |
| 351 | + if group_name not in group_mapping: |
| 352 | + raise ValueError(f'Group {group_name} does not exist') |
| 353 | + group_id = group_mapping[group_name] |
| 354 | + acc.groups.patch( |
| 355 | + group_id, |
| 356 | + operations=[ |
| 357 | + Patch(PatchOp.ADD, 'members', [ComplexValue(value=str(service_principal_id)).as_dict()]), |
| 358 | + ], |
| 359 | + schemas=[PatchSchema.URN_IETF_PARAMS_SCIM_API_MESSAGES_2_0_PATCH_OP], |
| 360 | + ) |
| 361 | + permissions = [WorkspacePermission.USER] |
| 362 | + acc.workspace_assignment.update(workspace_id, service_principal_id, permissions=permissions) |
| 363 | + ws_as_spn = _make_workspace_client(ws, created_secret, service_principal) |
| 364 | + |
| 365 | + log_account_link('account service principal', f'users/serviceprincipals/{service_principal_id}') |
| 366 | + |
| 367 | + return RunAs(service_principal, ws_as_spn, env_or_skip) |
| 368 | + |
| 369 | + def remove(run_as: RunAs): |
| 370 | + service_principal_id = run_as._service_principal.id # pylint: disable=protected-access |
| 371 | + assert service_principal_id is not None |
| 372 | + acc.service_principals.delete(service_principal_id) |
| 373 | + |
| 374 | + yield from factory("service principal", create, remove) |
0 commit comments