Skip to content

feat: support union operator on BasePermission #3315

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
Release type: minor

Permissions can now be combined with the `|` operator to require the user pass _either_ of the `has_permission` checks:

```python
import strawberry
from strawberry.permission import PermissionExtension


@strawberry.type
class User:
@strawberry.field(
extensions=[
PermissionExtension(
# require auth AND (node is current user OR current user is staff)
permissions=[IsAuthenticated(), IsExactUser() | IsStaff()]
)
]
)
def ssn(self) -> str:
return "555-55-5555"
```
23 changes: 23 additions & 0 deletions docs/guides/permissions.md
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,29 @@ To customize the error handling, the `on_unauthorized` method on
the `BasePermission` class can be used. Further changes can be implemented by
subclassing the `PermissionExtension` class.

## `|` operator support

Permissions can be combined with the `|` operator to require the user pass _either_ of the `has_permission` checks:

```python
import strawberry
from strawberry.permission import PermissionExtension


@strawberry.type
class User:
@strawberry.field(
extensions=[
PermissionExtension(
# require auth AND (node is current user OR current user is staff)
permissions=[IsAuthenticated(), IsExactUser() | IsStaff()]
)
]
)
def ssn(self) -> str:
return "555-55-5555"
```

## Schema Directives

Permissions will automatically be added as schema directives to the schema. This
Expand Down
59 changes: 59 additions & 0 deletions strawberry/permission.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,17 @@
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
Union,
)

from asgiref.sync import async_to_sync

from strawberry.exceptions import StrawberryGraphQLError
from strawberry.exceptions.permission_fail_silently_requires_optional import (
PermissionFailSilentlyRequiresOptionalError,
Expand Down Expand Up @@ -47,6 +51,9 @@ class BasePermission(abc.ABC):

_schema_directive: Optional[object] = None

def __or__(self, other: BasePermission):
return OrPermission((self, other))

@abc.abstractmethod
def has_permission(
self, source: Any, info: Info, **kwargs: Any
Expand Down Expand Up @@ -89,6 +96,58 @@ class AutoDirective:
return self._schema_directive


class OrPermission(BasePermission):
failed_permission: Optional[BasePermission] = None

def __init__(self, permissions: Tuple[BasePermission, ...]):
if not permissions:
raise ValueError("At least one permission is required")

self.permissions = permissions

messages = [p.message for p in permissions if p.message]
if messages:
self.message = ", or ".join(messages)

def has_permission(self, source: Any, info: Info, **kwargs: Any) -> bool:
for permission in self.permissions:
has_permission = (
async_to_sync(permission.has_permission)
if iscoroutinefunction(permission.has_permission)
else permission.has_permission
)
if has_permission(source, info, **kwargs):
return True

self.failed_permission = permission
self.message = permission.message
return False

@property
def on_unauthorized(self) -> Callable[[], None]:
if not self.failed_permission:
return lambda: None

return self.failed_permission.on_unauthorized

@property
def schema_directive(self) -> object:
if not self._schema_directive:
directive_name = "Or".join([p.__class__.__name__ for p in self.permissions])

class AutoDirective:
__strawberry_directive__ = StrawberrySchemaDirective(
directive_name,
directive_name,
[Location.FIELD_DEFINITION],
[],
)

self._schema_directive = AutoDirective()

return self._schema_directive


class PermissionExtension(FieldExtension):
"""
Handles permissions for a field
Expand Down
35 changes: 34 additions & 1 deletion tests/schema/test_permission.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ def name(self) -> User: # pragma: no cover
strawberry.Schema(query=Query)


def test_permission_directives_added():
def test_base_permission_directives_added():
class IsAuthorized(BasePermission):
message = "User is not authorized"

Expand All @@ -537,6 +537,39 @@ def name(self) -> str: # pragma: no cover
assert print_schema(schema) == textwrap.dedent(expected_output).strip()


def test_or_permission_directives_added():
class AllowedPermission(BasePermission):
message = "Allowed"

def has_permission(self, source, info, **kwargs: typing.Any):
return True

class DeniedPermission(BasePermission):
message = "Denied"

def has_permission(self, source, info, **kwargs: typing.Any):
return False

@strawberry.type
class Query:
@strawberry.field(
extensions=[PermissionExtension([AllowedPermission() | DeniedPermission()])]
)
def name(self) -> str: # pragma: no cover
return "ABC"

schema = strawberry.Schema(query=Query)

expected_output = """
directive @allowedPermissionOrDeniedPermission on FIELD_DEFINITION

type Query {
name: String! @allowedPermissionOrDeniedPermission
}
"""
assert print_schema(schema) == textwrap.dedent(expected_output).strip()


def test_permission_directives_not_added_on_field():
class IsAuthorized(BasePermission):
message = "User is not authorized"
Expand Down
74 changes: 74 additions & 0 deletions tests/test_permission.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from typing import Any

from strawberry.permission import BasePermission
from strawberry.types import Info


def test_or_permission_condition():
class AllowedPermission(BasePermission):
message = "Allowed"

def has_permission(self, source: Any, info: Info, **kwargs: Any):
return True

class DeniedPermission(BasePermission):
message = "Denied"

def has_permission(self, source: Any, info: Info, **kwargs: Any):
return False

class IsAuthenticated(BasePermission):
message = "User is not authenticated"

def has_permission(self, source: Any, info: Info, **kwargs: Any):
return False

# cases with all `has_permission(...) == False` should deny on the last
denied_permission = DeniedPermission()
not_a_or_not_b = IsAuthenticated() | denied_permission
assert not_a_or_not_b.has_permission(None, None) is False
assert not_a_or_not_b.message == denied_permission.message
assert not_a_or_not_b.on_unauthorized == denied_permission.on_unauthorized

# cases with any true should allow
not_a_or_b = DeniedPermission() | AllowedPermission()
a_or_not_b = AllowedPermission() | DeniedPermission()
a_or_b = AllowedPermission() | AllowedPermission()
assert not_a_or_b.has_permission(None, None) is True
assert a_or_not_b.has_permission(None, None) is True
assert a_or_b.has_permission(None, None) is True


def test_or_permission_async_to_sync():
class AllowedPermission(BasePermission):
message = "Allowed"

def has_permission(self, source: Any, info: Info, **kwargs: Any):
return True

class DeniedPermission(BasePermission):
message = "Denied"

async def has_permission(self, source: Any, info: Info, **kwargs: Any):
return False

class IsAuthenticated(BasePermission):
message = "User is not authenticated"

def has_permission(self, source: Any, info: Info, **kwargs: Any):
return False

# cases with all `has_permission(...) == False` should deny on the last
denied_permission = DeniedPermission()
not_a_or_not_b = IsAuthenticated() | denied_permission
assert not_a_or_not_b.has_permission(None, None) is False
assert not_a_or_not_b.message == denied_permission.message
assert not_a_or_not_b.on_unauthorized == denied_permission.on_unauthorized

# cases with any true should allow
not_a_or_b = DeniedPermission() | AllowedPermission()
a_or_not_b = AllowedPermission() | DeniedPermission()
a_or_b = AllowedPermission() | AllowedPermission()
assert not_a_or_b.has_permission(None, None) is True
assert a_or_not_b.has_permission(None, None) is True
assert a_or_b.has_permission(None, None) is True