Skip to content

Commit 60f4f6d

Browse files
authored
Added make_run_as fixture (#82)
### `make_run_as` fixture This fixture provides a function to create an account service principal via [`acc` fixture](#acc-fixture) and assign it to a workspace. The service principal is removed after the test is complete. The service principal is created with a random display name and assigned to the workspace with the default permissions. Use the `account_groups` argument to assign the service principal to account groups, which have the required permissions to perform a specific action. Example: ```python def test_run_as_lower_privilege_user(make_run_as, ws): run_as = make_run_as(account_groups=['account.group.name']) through_query = next(run_as.sql_fetch_all("SELECT CURRENT_USER() AS my_name")) me = ws.current_user.me() assert me.user_name != through_query.my_name ``` Returned object has the following properties: * `ws`: Workspace client that is authenticated as the ephemeral service principal. * `sql_backend`: SQL backend that is authenticated as the ephemeral service principal. * `sql_exec`: Function to execute a SQL statement on behalf of the ephemeral service principal. * `sql_fetch_all`: Function to fetch all rows from a SQL statement on behalf of the ephemeral service principal. * `display_name`: Display name of the ephemeral service principal. * `application_id`: Application ID of the ephemeral service principal. * if you want to have other fixtures available in the context of the ephemeral service principal, you can override the [`ws` fixture](#ws-fixture) on the file level, which would make all workspace fixtures provided by this plugin to run as lower privilege ephemeral service principal. You cannot combine it with the account-admin-level principal you're using to create the ephemeral principal. Example: ```python from pytest import fixture @fixture def ws(make_run_as): run_as = make_run_as(account_groups=['account.group.used.for.all.tests.in.this.file']) return run_as.ws def test_creating_notebook_on_behalf_of_ephemeral_principal(make_notebook): notebook = make_notebook() assert notebook.exists() ``` See also [`acc`](#acc-fixture), [`ws`](#ws-fixture), [`make_random`](#make_random-fixture), [`env_or_skip`](#env_or_skip-fixture), [`log_account_link`](#log_account_link-fixture).
1 parent 52d074f commit 60f4f6d

File tree

11 files changed

+342
-35
lines changed

11 files changed

+342
-35
lines changed

README.md

Lines changed: 92 additions & 22 deletions
Large diffs are not rendered by default.

scripts/gen-readme.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def discover_fixtures() -> list[Fixture]:
7070
upstreams = []
7171
sig = inspect.signature(fn)
7272
for param in sig.parameters.values():
73-
if param.name in {'fresh_local_wheel_file', 'monkeypatch', 'log_workspace_link'}:
73+
if param.name in {'fresh_local_wheel_file', 'monkeypatch', 'log_workspace_link', 'request'}:
7474
continue
7575
upstreams.append(param.name)
7676
see_also[param.name].add(fixture)

src/databricks/labs/pytester/fixtures/baseline.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,3 +222,15 @@ def inner(name: str, path: str, *, anchor: bool = True):
222222
_LOG.info(f'Created {name}: {url}')
223223

224224
return inner
225+
226+
227+
@fixture
228+
def log_account_link(acc):
229+
"""Returns a function to log an account link."""
230+
231+
def inner(name: str, path: str, *, anchor: bool = False):
232+
a = '#' if anchor else ''
233+
url = f'https://{acc.config.hostname}/{a}{path}'
234+
_LOG.info(f'Created {name}: {url}')
235+
236+
return inner

src/databricks/labs/pytester/fixtures/catalog.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
logger = logging.getLogger(__name__)
2626

2727

28+
# TODO: replace with LSQL implementation
2829
def escape_sql_identifier(path: str, *, maxsplit: int = 2) -> str:
2930
"""
3031
Escapes the path components to make them SQL safe.

src/databricks/labs/pytester/fixtures/iam.py

Lines changed: 192 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,30 @@
11
import logging
22
import warnings
3-
from collections.abc import Callable, Generator
3+
from collections.abc import Callable, Generator, Iterable
44
from datetime import timedelta
55

6+
import pytest
67
from pytest import fixture
7-
from databricks.sdk import AccountGroupsAPI, GroupsAPI, WorkspaceClient
8+
from databricks.sdk.credentials_provider import OAuthCredentialsProvider, OauthCredentialsStrategy
9+
from databricks.sdk.oauth import ClientCredentials, Token
10+
from databricks.sdk.service.oauth2 import CreateServicePrincipalSecretResponse
11+
from databricks.labs.lsql import Row
12+
from databricks.labs.lsql.backends import StatementExecutionBackend, SqlBackend
13+
from databricks.sdk import AccountGroupsAPI, GroupsAPI, WorkspaceClient, AccountClient
814
from databricks.sdk.config import Config
915
from databricks.sdk.errors import ResourceConflict, NotFound
1016
from databricks.sdk.retries import retried
1117
from databricks.sdk.service import iam
12-
from databricks.sdk.service.iam import User, Group
18+
from databricks.sdk.service.iam import (
19+
User,
20+
Group,
21+
ServicePrincipal,
22+
Patch,
23+
PatchOp,
24+
ComplexValue,
25+
PatchSchema,
26+
WorkspacePermission,
27+
)
1328

1429
from databricks.labs.pytester.fixtures.baseline import factory
1530

@@ -183,3 +198,177 @@ def create(
183198
return group
184199

185200
yield from factory(name, create, lambda item: interface.delete(item.id))
201+
202+
203+
class RunAs:
204+
def __init__(self, service_principal: ServicePrincipal, workspace_client: WorkspaceClient, env_or_skip):
205+
self._service_principal = service_principal
206+
self._workspace_client = workspace_client
207+
self._env_or_skip = env_or_skip
208+
209+
@property
210+
def ws(self):
211+
return self._workspace_client
212+
213+
@property
214+
def sql_backend(self) -> SqlBackend:
215+
# TODO: Switch to `__getattr__` + `SubRequest` to get a generic way of initializing all workspace fixtures.
216+
# This will allow us to remove the `sql_backend` fixture and make the `RunAs` class more generic.
217+
# It turns out to be more complicated than it first appears, because we don't get these at pytest.collect phase.
218+
warehouse_id = self._env_or_skip("DATABRICKS_WAREHOUSE_ID")
219+
return StatementExecutionBackend(self._workspace_client, warehouse_id)
220+
221+
def sql_exec(self, statement: str) -> None:
222+
return self.sql_backend.execute(statement)
223+
224+
def sql_fetch_all(self, statement: str) -> Iterable[Row]:
225+
return self.sql_backend.fetch(statement)
226+
227+
def __getattr__(self, item: str):
228+
if item in self.__dict__:
229+
return self.__dict__[item]
230+
fixture_value = self._request.getfixturevalue(item)
231+
return fixture_value
232+
233+
@property
234+
def display_name(self) -> str:
235+
assert self._service_principal.display_name is not None
236+
return self._service_principal.display_name
237+
238+
@property
239+
def application_id(self) -> str:
240+
assert self._service_principal.application_id is not None
241+
return self._service_principal.application_id
242+
243+
def __repr__(self):
244+
return f'RunAs({self.display_name})'
245+
246+
247+
def _make_workspace_client(
248+
ws: WorkspaceClient,
249+
created_secret: CreateServicePrincipalSecretResponse,
250+
service_principal: ServicePrincipal,
251+
) -> WorkspaceClient:
252+
oidc = ws.config.oidc_endpoints
253+
assert oidc is not None, 'OIDC is required'
254+
application_id = service_principal.application_id
255+
secret_value = created_secret.secret
256+
assert application_id is not None
257+
assert secret_value is not None
258+
259+
token_source = ClientCredentials(
260+
client_id=application_id,
261+
client_secret=secret_value,
262+
token_url=oidc.token_endpoint,
263+
scopes=["all-apis"],
264+
use_header=True,
265+
)
266+
267+
def inner() -> dict[str, str]:
268+
inner_token = token_source.token()
269+
return {'Authorization': f'{inner_token.token_type} {inner_token.access_token}'}
270+
271+
def token() -> Token:
272+
return token_source.token()
273+
274+
credentials_provider = OAuthCredentialsProvider(inner, token)
275+
credentials_strategy = OauthCredentialsStrategy('oauth-m2m', lambda _: credentials_provider)
276+
ws_as_spn = WorkspaceClient(host=ws.config.host, credentials_strategy=credentials_strategy)
277+
return ws_as_spn
278+
279+
280+
@fixture
281+
def make_run_as(acc: AccountClient, ws: WorkspaceClient, make_random, env_or_skip, log_account_link, is_in_debug):
282+
"""
283+
This fixture provides a function to create an account service principal via [`acc` fixture](#acc-fixture) and
284+
assign it to a workspace. The service principal is removed after the test is complete. The service principal is
285+
created with a random display name and assigned to the workspace with the default permissions.
286+
287+
Use the `account_groups` argument to assign the service principal to account groups, which have the required
288+
permissions to perform a specific action.
289+
290+
Example:
291+
292+
```python
293+
def test_run_as_lower_privilege_user(make_run_as, ws):
294+
run_as = make_run_as(account_groups=['account.group.name'])
295+
through_query = next(run_as.sql_fetch_all("SELECT CURRENT_USER() AS my_name"))
296+
me = ws.current_user.me()
297+
assert me.user_name != through_query.my_name
298+
```
299+
300+
Returned object has the following properties:
301+
* `ws`: Workspace client that is authenticated as the ephemeral service principal.
302+
* `sql_backend`: SQL backend that is authenticated as the ephemeral service principal.
303+
* `sql_exec`: Function to execute a SQL statement on behalf of the ephemeral service principal.
304+
* `sql_fetch_all`: Function to fetch all rows from a SQL statement on behalf of the ephemeral service principal.
305+
* `display_name`: Display name of the ephemeral service principal.
306+
* `application_id`: Application ID of the ephemeral service principal.
307+
* if you want to have other fixtures available in the context of the ephemeral service principal, you can override
308+
the [`ws` fixture](#ws-fixture) on the file level, which would make all workspace fixtures provided by this
309+
plugin to run as lower privilege ephemeral service principal. You cannot combine it with the account-admin-level
310+
principal you're using to create the ephemeral principal.
311+
312+
Example:
313+
314+
```python
315+
from pytest import fixture
316+
317+
@fixture
318+
def ws(make_run_as):
319+
run_as = make_run_as(account_groups=['account.group.used.for.all.tests.in.this.file'])
320+
return run_as.ws
321+
322+
def test_creating_notebook_on_behalf_of_ephemeral_principal(make_notebook):
323+
notebook = make_notebook()
324+
assert notebook.exists()
325+
```
326+
327+
This fixture currently doesn't work with Databricks Metadata Service authentication on Azure Databricks.
328+
"""
329+
330+
if ws.config.auth_type == 'metadata-service' and ws.config.is_azure:
331+
# TODO: fix `invalid_scope: AADSTS1002012: The provided value for scope all-apis is not valid.` error
332+
#
333+
# We're having issues with the Azure Metadata Service and service principals. The error message is:
334+
# Client credential flows must have a scope value with /.default suffixed to the resource identifier
335+
# (application ID URI)
336+
pytest.skip('Azure Metadata Service does not support service principals')
337+
338+
def create(*, account_groups: list[str] | None = None):
339+
workspace_id = ws.get_workspace_id()
340+
service_principal = acc.service_principals.create(display_name=f'spn-{make_random()}')
341+
assert service_principal.id is not None
342+
service_principal_id = int(service_principal.id)
343+
created_secret = acc.service_principal_secrets.create(service_principal_id)
344+
if account_groups:
345+
group_mapping = {}
346+
for group in acc.groups.list(attributes='id,displayName'):
347+
if group.id is None:
348+
continue
349+
group_mapping[group.display_name] = group.id
350+
for group_name in account_groups:
351+
if group_name not in group_mapping:
352+
raise ValueError(f'Group {group_name} does not exist')
353+
group_id = group_mapping[group_name]
354+
acc.groups.patch(
355+
group_id,
356+
operations=[
357+
Patch(PatchOp.ADD, 'members', [ComplexValue(value=str(service_principal_id)).as_dict()]),
358+
],
359+
schemas=[PatchSchema.URN_IETF_PARAMS_SCIM_API_MESSAGES_2_0_PATCH_OP],
360+
)
361+
permissions = [WorkspacePermission.USER]
362+
acc.workspace_assignment.update(workspace_id, service_principal_id, permissions=permissions)
363+
ws_as_spn = _make_workspace_client(ws, created_secret, service_principal)
364+
365+
log_account_link('account service principal', f'users/serviceprincipals/{service_principal_id}')
366+
367+
return RunAs(service_principal, ws_as_spn, env_or_skip)
368+
369+
def remove(run_as: RunAs):
370+
service_principal_id = run_as._service_principal.id # pylint: disable=protected-access
371+
assert service_principal_id is not None
372+
acc.service_principals.delete(service_principal_id)
373+
374+
yield from factory("service principal", create, remove)

src/databricks/labs/pytester/fixtures/notebooks.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,18 @@ def create(
5656
default_content = "SELECT 1"
5757
else:
5858
raise ValueError(f"Unsupported language: {language}")
59-
path = path or f"/Users/{ws.current_user.me().user_name}/dummy-{make_random(4)}-{watchdog_purge_suffix}"
59+
current_user = ws.current_user.me()
60+
path = path or f"/Users/{current_user.user_name}/dummy-{make_random(4)}-{watchdog_purge_suffix}"
61+
workspace_path = WorkspacePath(ws, path)
62+
if '@' not in current_user.user_name:
63+
# If current user is a service principal added with `make_run_as`, there might be no home folder
64+
workspace_path.parent.mkdir(exist_ok=True)
6065
content = content or default_content
6166
if isinstance(content, str):
6267
content = io.BytesIO(content.encode(encoding))
6368
if isinstance(ws, Mock): # For testing
6469
ws.workspace.download.return_value = content if isinstance(content, io.BytesIO) else io.BytesIO(content)
6570
ws.workspace.upload(path, content, language=language, format=format, overwrite=overwrite)
66-
workspace_path = WorkspacePath(ws, path)
6771
logger.info(f"Created notebook: {workspace_path.as_uri()}")
6872
return workspace_path
6973

@@ -110,10 +114,14 @@ def create(
110114
suffix = ".sql"
111115
else:
112116
raise ValueError(f"Unsupported language: {language}")
113-
path = path or f"/Users/{ws.current_user.me().user_name}/dummy-{make_random(4)}-{watchdog_purge_suffix}{suffix}"
117+
current_user = ws.current_user.me()
118+
path = path or f"/Users/{current_user.user_name}/dummy-{make_random(4)}-{watchdog_purge_suffix}{suffix}"
114119
content = content or default_content
115120
encoding = encoding or _DEFAULT_ENCODING
116121
workspace_path = WorkspacePath(ws, path)
122+
if '@' not in current_user.user_name:
123+
# If current user is a service principal added with `make_run_as`, there might be no home folder
124+
workspace_path.parent.mkdir(exist_ok=True)
117125
if isinstance(content, bytes):
118126
workspace_path.write_bytes(content)
119127
else:

src/databricks/labs/pytester/fixtures/plugin.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
make_random,
77
product_info,
88
log_workspace_link,
9+
log_account_link,
910
)
1011
from databricks.labs.pytester.fixtures.sql import sql_backend, sql_exec, sql_fetch_all
1112
from databricks.labs.pytester.fixtures.compute import (
@@ -16,7 +17,7 @@
1617
make_pipeline,
1718
make_warehouse,
1819
)
19-
from databricks.labs.pytester.fixtures.iam import make_group, make_acc_group, make_user
20+
from databricks.labs.pytester.fixtures.iam import make_group, make_acc_group, make_user, make_run_as
2021
from databricks.labs.pytester.fixtures.catalog import (
2122
make_udf,
2223
make_catalog,
@@ -60,6 +61,7 @@
6061
'debug_env',
6162
'env_or_skip',
6263
'ws',
64+
'make_run_as',
6365
'acc',
6466
'spark',
6567
'sql_backend',
@@ -105,6 +107,7 @@
105107
'make_warehouse_permissions',
106108
'make_lakeview_dashboard_permissions',
107109
'log_workspace_link',
110+
'log_account_link',
108111
'make_dashboard_permissions',
109112
'make_alert_permissions',
110113
'make_query',

tests/integration/fixtures/test_catalog.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def test_schema_fixture(make_schema):
1717
logger.info(f"Created new schema: {make_schema()}")
1818

1919

20+
@pytest.mark.skip("Invalid configuration value detected for fs.azure.account.key")
2021
def test_managed_schema_fixture(make_schema, make_random, env_or_skip):
2122
schema_name = f"dummy_s{make_random(4)}".lower()
2223
schema_location = f"{env_or_skip('TEST_MOUNT_CONTAINER')}/a/{schema_name}"

tests/integration/fixtures/test_iam.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,10 @@ def test_new_account_group(make_acc_group, acc):
1919
group = make_acc_group()
2020
loaded = acc.groups.get(group.id)
2121
assert group.display_name == loaded.display_name
22+
23+
24+
def test_run_as_lower_privilege_user(make_run_as, ws):
25+
run_as = make_run_as(account_groups=['role.labs.lsql.write'])
26+
through_query = next(run_as.sql_fetch_all("SELECT CURRENT_USER() AS my_name"))
27+
current_user = ws.current_user.me()
28+
assert current_user.user_name != through_query.my_name
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from pytest import fixture
2+
3+
4+
@fixture
5+
def ws(make_run_as):
6+
run_as = make_run_as(account_groups=['role.labs.lsql.write'])
7+
return run_as.ws
8+
9+
10+
def test_creating_notebook_on_behalf_of_ephemeral_principal(make_notebook):
11+
notebook = make_notebook()
12+
assert notebook.exists()

0 commit comments

Comments
 (0)