|
3 | 3 | from collections.abc import Callable, Generator, Iterable |
4 | 4 | from datetime import timedelta |
5 | 5 |
|
| 6 | +from databricks.sdk.credentials_provider import OAuthCredentialsProvider, OauthCredentialsStrategy |
| 7 | +from databricks.sdk.oauth import ClientCredentials, Token |
6 | 8 | from pytest import fixture |
7 | 9 | from databricks.labs.lsql import Row |
8 | 10 | from databricks.labs.lsql.backends import StatementExecutionBackend, SqlBackend |
@@ -313,17 +315,39 @@ def create(*, account_groups: list[str] | None = None): |
313 | 315 | ) |
314 | 316 | permissions = [WorkspacePermission.USER] |
315 | 317 | acc.workspace_assignment.update(workspace_id, service_principal_id, permissions=permissions) |
316 | | - ws_as_spn = WorkspaceClient( |
317 | | - host=ws.config.host, |
318 | | - auth_type='oauth-m2m', |
319 | | - client_id=service_principal.application_id, |
320 | | - client_secret=created_secret.secret, |
321 | | - ) |
| 318 | + ws_as_spn = _make_workspace_client(created_secret, service_principal) |
322 | 319 |
|
323 | 320 | log_account_link('account service principal', f'users/serviceprincipals/{service_principal_id}') |
324 | 321 |
|
325 | 322 | return RunAs(service_principal, ws_as_spn, env_or_skip) |
326 | 323 |
|
| 324 | + def _make_workspace_client(created_secret, service_principal): |
| 325 | + oidc = ws.config.oidc_endpoints |
| 326 | + assert oidc is not None, 'OIDC is required' |
| 327 | + application_id = service_principal.application_id |
| 328 | + secret_value = created_secret.secret |
| 329 | + assert application_id is not None |
| 330 | + assert secret_value is not None |
| 331 | + token_source = ClientCredentials( |
| 332 | + client_id=application_id, |
| 333 | + client_secret=secret_value, |
| 334 | + token_url=oidc.token_endpoint, |
| 335 | + scopes=["all-apis"], |
| 336 | + use_header=True, |
| 337 | + ) |
| 338 | + |
| 339 | + def inner() -> dict[str, str]: |
| 340 | + inner_token = token_source.token() |
| 341 | + return {'Authorization': f'{inner_token.token_type} {inner_token.access_token}'} |
| 342 | + |
| 343 | + def token() -> Token: |
| 344 | + return token_source.token() |
| 345 | + |
| 346 | + credentials_provider = OAuthCredentialsProvider(inner, token) |
| 347 | + credentials_strategy = OauthCredentialsStrategy('oauth-m2m', lambda _: credentials_provider) |
| 348 | + ws_as_spn = WorkspaceClient(host=ws.config.host, credentials_strategy=credentials_strategy) |
| 349 | + return ws_as_spn |
| 350 | + |
327 | 351 | def remove(run_as: RunAs): |
328 | 352 | service_principal_id = run_as._service_principal.id # pylint: disable=protected-access |
329 | 353 | assert service_principal_id is not None |
|
0 commit comments