Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -542,16 +542,15 @@ See also [`ws`](#ws-fixture), [`make_random`](#make_random-fixture), [`watchdog_
[[back to top](#python-testing-for-databricks)]

### `make_group` fixture
This fixture provides a function to manage Databricks workspace groups. Groups can be created with
specified members and roles, and they will be deleted after the test is complete. Deals with eventual
consistency issues by retrying the creation process for 30 seconds and allowing up to two minutes
for group to be provisioned. Returns an instance of [`Group`](https://databricks-sdk-py.readthedocs.io/en/latest/dbdataclasses/iam.html#databricks.sdk.service.iam.Group).
This fixture provides a function to manage Databricks workspace groups. Groups can be created with specified
members and roles, and they will be deleted after the test is complete. Deals with eventual consistency issues by
retrying the creation process for 30 seconds and then waiting for up to 3 minutes for the group to be provisioned.
Returns an instance of [`Group`](https://databricks-sdk-py.readthedocs.io/en/latest/dbdataclasses/iam.html#databricks.sdk.service.iam.Group).

Keyword arguments:
* `members` (list of strings): A list of user IDs to add to the group.
* `roles` (list of strings): A list of roles to assign to the group.
* `display_name` (str): The display name of the group.
* `wait_for_provisioning` (bool): If `True`, the function will wait for the group to be provisioned.
* `entitlements` (list of strings): A list of entitlements to assign to the group.

The following example creates a group with a single member and independently verifies that the group was created:
Expand Down
78 changes: 64 additions & 14 deletions src/databricks/labs/pytester/fixtures/iam.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import logging
import warnings
from collections.abc import Generator
from datetime import timedelta

from pytest import fixture
from databricks.sdk import AccountGroupsAPI, GroupsAPI, WorkspaceClient
from databricks.sdk.config import Config
from databricks.sdk.errors import ResourceConflict, NotFound
from databricks.sdk.retries import retried
from databricks.sdk.service.iam import User, Group
from databricks.sdk import WorkspaceClient
from databricks.sdk.service import iam
from databricks.sdk.service.iam import User, Group

from databricks.labs.pytester.fixtures.baseline import factory

Expand Down Expand Up @@ -44,16 +45,15 @@ def create(**kwargs) -> User:
@fixture
def make_group(ws: WorkspaceClient, make_random, watchdog_purge_suffix):
"""
This fixture provides a function to manage Databricks workspace groups. Groups can be created with
specified members and roles, and they will be deleted after the test is complete. Deals with eventual
consistency issues by retrying the creation process for 30 seconds and allowing up to two minutes
for group to be provisioned. Returns an instance of `databricks.sdk.service.iam.Group`.
This fixture provides a function to manage Databricks workspace groups. Groups can be created with specified
members and roles, and they will be deleted after the test is complete. Deals with eventual consistency issues by
retrying the creation process for 30 seconds and then waiting for up to 3 minutes for the group to be provisioned.
Returns an instance of `databricks.sdk.service.iam.Group`.

Keyword arguments:
* `members` (list of strings): A list of user IDs to add to the group.
* `roles` (list of strings): A list of roles to assign to the group.
* `display_name` (str): The display name of the group.
* `wait_for_provisioning` (bool): If `True`, the function will wait for the group to be provisioned.
* `entitlements` (list of strings): A list of entitlements to assign to the group.

The following example creates a group with a single member and independently verifies that the group was created:
Expand Down Expand Up @@ -94,15 +94,63 @@ def _scim_values(ids: list[str]) -> list[iam.ComplexValue]:
return [iam.ComplexValue(value=x) for x in ids]


def _wait_group_provisioned(interface: AccountGroupsAPI | GroupsAPI, group: Group) -> None:
"""Wait for a group to be visible via the supplied group interface.

Due to consistency issues in the group-management APIs, new groups are not always visible in a consistent manner
after being created or modified. This method can be used to mitigate against this by checking that a group:

- Is visible via the `.get()` interface;
- Is visible via the `.list()` interface that enumerates groups.

Visibility is assumed when 2 calls in a row return the expected results.

Args:
interface: the group-management interface to use for checking whether the groups are visible.
group: the group whose visibility should be verified.
Raises:
NotFound: this is thrown if it takes longer than 90 seconds for the group to become visible via the
management interface.
"""
# Use double-checking to try and compensate for the lack of monotonic consistency with the group-management
# interfaces: two subsequent calls need to succeed for us to proceed. (This is probabilistic, and not a guarantee.)
# The REST API internals cache things for up to 60s, and we see times close to this during tests. The retry timeout
# reflects this: if it's taking much longer then something else is wrong.
group_id = group.id
assert group_id is not None

@retried(on=[NotFound], timeout=timedelta(seconds=90))
def _double_get_group() -> None:
interface.get(group_id)
interface.get(group_id)

def _check_group_in_listing() -> None:
found_groups = interface.list(attributes="id", filter=f'id eq "{group_id}"')
found_ids = {found_group.id for found_group in found_groups}
if group_id not in found_ids:
msg = f"Group id not (yet) found in group listing: {group_id}"
raise NotFound(msg)

@retried(on=[NotFound], timeout=timedelta(seconds=90))
def _double_check_group_in_listing() -> None:
_check_group_in_listing()
_check_group_in_listing()

_double_get_group()
_double_check_group_in_listing()


def _make_group(name: str, cfg: Config, interface, make_random, watchdog_purge_suffix) -> Generator[Group, None, None]:
_not_specified = object()

@retried(on=[ResourceConflict], timeout=timedelta(seconds=30))
def create(
*,
members: list[str] | None = None,
roles: list[str] | None = None,
entitlements: list[str] | None = None,
display_name: str | None = None,
wait_for_provisioning: bool = False,
wait_for_provisioning: bool | object = _not_specified,
**kwargs,
):
kwargs["display_name"] = (
Expand All @@ -114,19 +162,21 @@ def create(
kwargs["roles"] = _scim_values(roles)
if entitlements is not None:
kwargs["entitlements"] = _scim_values(entitlements)
if wait_for_provisioning is not _not_specified:
warnings.warn(
"Specifying wait_for_provisioning when making a group is deprecated; we always wait.",
DeprecationWarning,
# Call stack is: create()[iam.py], wrapper()[retries.py], inner()[baseline.py], client_code
stacklevel=4,
)
# TODO: REQUEST_LIMIT_EXCEEDED: GetUserPermissionsRequest RPC token bucket limit has been exceeded.
group = interface.create(**kwargs)
if cfg.is_account_client:
logger.info(f"Account group {group.display_name}: {cfg.host}/users/groups/{group.id}/members")
else:
logger.info(f"Workspace group {group.display_name}: {cfg.host}#setting/accounts/groups/{group.id}")

@retried(on=[NotFound], timeout=timedelta(minutes=2))
def _wait_for_provisioning() -> None:
interface.get(group.id)

if wait_for_provisioning:
_wait_for_provisioning()
_wait_group_provisioned(interface, group)

return group

Expand Down
19 changes: 18 additions & 1 deletion src/databricks/labs/pytester/fixtures/unwrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import inspect
from collections.abc import Callable, Generator
from typing import TypeVar
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, TypeVar
from unittest.mock import MagicMock

from databricks.labs.lsql.backends import MockBackend
Expand All @@ -21,6 +23,20 @@ def call_fixture(fixture_fn: Callable[..., T], *args, **kwargs) -> T:
return wrapped.obj(*args, **kwargs)


_FIXTURES: ContextVar[dict[str, Callable[..., T]]] = ContextVar('fixtures')


@contextmanager
def fixtures(**kwargs: Callable[..., Any]) -> Generator[None, None, None]:
prior_fixtures = _FIXTURES.get({})
updated_fixtures = {**prior_fixtures, **kwargs}
token = _FIXTURES.set(updated_fixtures)
try:
yield
finally:
_FIXTURES.reset(token)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this modifies shared global state and will definitely break when running in multiple threads. why is this necessary? revert.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need this because the default auto-mocks don't allow the double-check waiting inside make_group() to complete: the return values of .create() and .list() on the (account-)groups REST APIs need to be controlled.

At the moment I don't see a test-specific way to pre-configure the mocks that call_stateful() will set up as fixtures for the fixture being tested. (I didn't find any other unit tests that need this, but may have missed it or there's another path to the same thing.)

Given that pre-configuring mocks for something under test is a fairly common thing to do, I came up with this. It's specifically intended to not only be thread-safe but also async-safe: I don't understand why it will break when running in multiple threads?

The main alternatives that spring to mind are:

  • Use a specific keyword argument to call_stateful() for this purpose that isn't passed through to the fixture under test.
  • Modify the global fixture (inside CallContext) that all users of call_stateful() have available so that it includes the pre-configured mocks that these group-creation tests use. This would be in the same way as make_random (for example) is handled.

Do you have a preference, or is there something else I've overlooked?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do it without changing a global variable - modify call_context right above the result = ctx[some.__name__](**kwargs) (the first alternative)

def apply(ctx: CallContext):
    mock_group = Group(id="an_id")
    ctx['ws].groups.create.return_value = mock_group
    ctx['ws].groups.list.return_value = [mock_group]

ctx, group = call_stateful(make_group, call_context_callback=apply)



class CallContext:
def __init__(self):
self._fixtures = {
Expand All @@ -29,6 +45,7 @@ def __init__(self):
'env_or_skip': self.env_or_skip,
'watchdog_remove_after': '2024091313',
'watchdog_purge_suffix': 'XXXXX',
**_FIXTURES.get({}),
}

def __getitem__(self, name: str):
Expand Down
88 changes: 77 additions & 11 deletions tests/unit/fixtures/test_iam.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
from databricks.labs.pytester.fixtures.iam import make_user, make_group, make_acc_group
from databricks.labs.pytester.fixtures.unwrap import call_stateful
import warnings
import sys
from unittest.mock import call, create_autospec

import pytest
from databricks.sdk import AccountClient, WorkspaceClient

from databricks.labs.pytester.fixtures.iam import make_user, make_group, make_acc_group, Group
from databricks.labs.pytester.fixtures.unwrap import call_stateful, fixtures


def test_make_user_no_args():
Expand All @@ -11,16 +18,75 @@ def test_make_user_no_args():


def test_make_group_no_args():
ctx, group = call_stateful(make_group)
assert ctx is not None
assert group is not None
ctx['ws'].groups.create.assert_called_once()
# Ensure the wait-for-provisioning logic can complete.
ws = create_autospec(WorkspaceClient)
mock_group = Group(id="an_id")
ws.groups.create.return_value = mock_group
ws.groups.list.return_value = [mock_group]

# Perform the test.
with fixtures(ws=ws):
ctx, group = call_stateful(make_group)

# Verify the fixture for this test.
assert ctx is not None and ctx['ws'] is ws
assert group is mock_group

# Verify the fixture under test performed the expected actions.
ws.groups.create.assert_called_once()
assert ws.groups.get.call_args_list == [call("an_id"), call("an_id")]
assert ws.groups.list.call_args_list == [
call(attributes="id", filter='id eq "an_id"'),
call(attributes="id", filter='id eq "an_id"'),
]
ws.groups.delete.assert_called_once()
ctx['ws'].groups.delete.assert_called_once()


def test_make_acc_group_no_args():
ctx, group = call_stateful(make_acc_group)
assert ctx is not None
assert group is not None
ctx['acc'].groups.create.assert_called_once()
ctx['acc'].groups.delete.assert_called_once()
# Ensure the wait-for-provisioning logic can complete.
acc = create_autospec(AccountClient)
mock_group = Group(id="an_id")
acc.groups.create.return_value = mock_group
acc.groups.list.return_value = [mock_group]

# Perform the test.
with fixtures(acc=acc):
ctx, group = call_stateful(make_acc_group)

# Verify the fixture for this test.
assert ctx is not None and ctx['acc'] is acc
assert group is mock_group

# Verify the fixture under test performed the expected actions.
acc.groups.create.assert_called_once()
assert acc.groups.get.call_args_list == [call("an_id"), call("an_id")]
assert acc.groups.list.call_args_list == [
call(attributes="id", filter='id eq "an_id"'),
call(attributes="id", filter='id eq "an_id"'),
]
acc.groups.delete.assert_called_once()


@pytest.mark.parametrize(
"make_group_fixture, client_fixture_name, client_class",
[(make_group, "ws", WorkspaceClient), (make_acc_group, "acc", AccountClient)],
)
def test_make_group_deprecated_arg(make_group_fixture, client_fixture_name, client_class) -> None:
# Ensure the wait-for-provisioning logic can complete.
client = create_autospec(client_class)
mock_group = Group(id="an_id")
client.groups.create.return_value = mock_group
client.groups.list.return_value = [mock_group]

with fixtures(**{client_fixture_name: client}), warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")

# Verify the fixture that we're testing.
call_stateful(make_group_fixture, wait_for_provisioning=True)

# Check that the expected warning was emitted and attributed to the caller.
(the_warning,) = w
assert issubclass(the_warning.category, DeprecationWarning)
assert "wait_for_provisioning when making a group is deprecated" in str(the_warning.message)
assert the_warning.filename == sys.modules[call_stateful.__module__].__file__
Loading