Skip to content

Commit 17c2746

Browse files
committed
...
1 parent 5a352da commit 17c2746

File tree

1 file changed

+30
-6
lines changed
  • src/databricks/labs/pytester/fixtures

1 file changed

+30
-6
lines changed

src/databricks/labs/pytester/fixtures/iam.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from collections.abc import Callable, Generator, Iterable
44
from datetime import timedelta
55

6+
from databricks.sdk.credentials_provider import OAuthCredentialsProvider, OauthCredentialsStrategy
7+
from databricks.sdk.oauth import ClientCredentials, Token
68
from pytest import fixture
79
from databricks.labs.lsql import Row
810
from databricks.labs.lsql.backends import StatementExecutionBackend, SqlBackend
@@ -313,17 +315,39 @@ def create(*, account_groups: list[str] | None = None):
313315
)
314316
permissions = [WorkspacePermission.USER]
315317
acc.workspace_assignment.update(workspace_id, service_principal_id, permissions=permissions)
316-
ws_as_spn = WorkspaceClient(
317-
host=ws.config.host,
318-
auth_type='oauth-m2m',
319-
client_id=service_principal.application_id,
320-
client_secret=created_secret.secret,
321-
)
318+
ws_as_spn = _make_workspace_client(created_secret, service_principal)
322319

323320
log_account_link('account service principal', f'users/serviceprincipals/{service_principal_id}')
324321

325322
return RunAs(service_principal, ws_as_spn, env_or_skip)
326323

324+
def _make_workspace_client(created_secret, service_principal):
325+
oidc = ws.config.oidc_endpoints
326+
assert oidc is not None, 'OIDC is required'
327+
application_id = service_principal.application_id
328+
secret_value = created_secret.secret
329+
assert application_id is not None
330+
assert secret_value is not None
331+
token_source = ClientCredentials(
332+
client_id=application_id,
333+
client_secret=secret_value,
334+
token_url=oidc.token_endpoint,
335+
scopes=["all-apis"],
336+
use_header=True,
337+
)
338+
339+
def inner() -> dict[str, str]:
340+
inner_token = token_source.token()
341+
return {'Authorization': f'{inner_token.token_type} {inner_token.access_token}'}
342+
343+
def token() -> Token:
344+
return token_source.token()
345+
346+
credentials_provider = OAuthCredentialsProvider(inner, token)
347+
credentials_strategy = OauthCredentialsStrategy('oauth-m2m', lambda _: credentials_provider)
348+
ws_as_spn = WorkspaceClient(host=ws.config.host, credentials_strategy=credentials_strategy)
349+
return ws_as_spn
350+
327351
def remove(run_as: RunAs):
328352
service_principal_id = run_as._service_principal.id # pylint: disable=protected-access
329353
assert service_principal_id is not None

0 commit comments

Comments
 (0)