Skip to content
53 changes: 53 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
Release type: minor

Remove deprecated `strawberry.scalar(cls, ...)` wrapper pattern and `ScalarWrapper`, deprecated since [0.288.0](https://github.com/strawberry-graphql/strawberry/releases/tag/0.288.0).

You can run `strawberry upgrade replace-scalar-wrappers <path>` to automatically replace built-in scalar wrapper imports.

### Migration guide

**Before (deprecated):**
```python
import strawberry
from datetime import datetime

EpochDateTime = strawberry.scalar(
datetime,
serialize=lambda v: int(v.timestamp()),
parse_value=lambda v: datetime.fromtimestamp(v),
)


@strawberry.type
class Query:
created: EpochDateTime
```

**After:**
```python
import strawberry
from typing import NewType
from datetime import datetime
from strawberry.schema.config import StrawberryConfig

EpochDateTime = NewType("EpochDateTime", datetime)


@strawberry.type
class Query:
created: datetime


schema = strawberry.Schema(
query=Query,
config=StrawberryConfig(
scalar_map={
EpochDateTime: strawberry.scalar(
name="EpochDateTime",
serialize=lambda v: int(v.timestamp()),
parse_value=lambda v: datetime.fromtimestamp(v),
)
}
),
)
```
2 changes: 1 addition & 1 deletion docs/types/schema.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ schema = strawberry.Schema(Query, types=[Individual, Company])

List of [extensions](/docs/extensions) to add to your Schema.

#### `scalar_overrides: Optional[Dict[object, ScalarWrapper]] = None`
#### `scalar_overrides: Optional[Dict[object, ScalarDefinition]] = None`

Override the implementation of the built in scalars.
[More information](/docs/types/scalars#overriding-built-in-scalars).
Expand Down
22 changes: 13 additions & 9 deletions strawberry/codegen/query_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
)
from strawberry.types.enum import StrawberryEnumDefinition
from strawberry.types.lazy_type import LazyType
from strawberry.types.scalar import ScalarDefinition, ScalarWrapper
from strawberry.types.scalar import ScalarDefinition
from strawberry.types.union import StrawberryUnion
from strawberry.types.unset import UNSET
from strawberry.utils.str_converters import capitalize_first, to_camel_case
Expand Down Expand Up @@ -541,14 +541,18 @@ def _get_field_type(
not isinstance(field_type, StrawberryType)
and field_type in self.schema.schema_converter.scalar_registry
):
field_type = self.schema.schema_converter.scalar_registry[field_type] # type: ignore

if isinstance(field_type, ScalarWrapper):
python_type = field_type.wrap
if hasattr(python_type, "__supertype__"):
python_type = python_type.__supertype__

return self._collect_scalar(field_type._scalar_definition, python_type) # type: ignore
# Store the original Python type (could be a type or NewType)
# before replacing with the ScalarDefinition
original_python_type = field_type
# For NewTypes, get the underlying type for the codegen
if hasattr(original_python_type, "__supertype__"):
python_type = original_python_type.__supertype__
elif isinstance(original_python_type, type):
python_type = original_python_type
else:
python_type = None
field_type = self.schema.schema_converter.scalar_registry[field_type]
return self._collect_scalar(field_type, python_type)

if isinstance(field_type, ScalarDefinition):
return self._collect_scalar(field_type, None)
Expand Down
5 changes: 1 addition & 4 deletions strawberry/exceptions/invalid_union_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def __init__(
union_definition: StrawberryUnion | None = None,
) -> None:
from strawberry.types.base import StrawberryList
from strawberry.types.scalar import ScalarWrapper

self.union_name = union_name
self.invalid_type = invalid_type
Expand All @@ -37,9 +36,7 @@ def __init__(
# one is our code checking for invalid types, the other is the caller
self.frame = getframeinfo(stack()[2][0])

if isinstance(invalid_type, ScalarWrapper):
type_name = invalid_type.wrap.__name__
elif isinstance(invalid_type, StrawberryList):
if isinstance(invalid_type, StrawberryList):
type_name = "list[...]"
else:
try:
Expand Down
96 changes: 46 additions & 50 deletions strawberry/federation/scalar.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from collections.abc import Callable, Iterable
from typing import (
Any,
NewType,
TypeVar,
overload,
)

from strawberry.types.scalar import ScalarWrapper, _process_scalar
from strawberry.types.scalar import ScalarDefinition

_T = TypeVar("_T", bound=type | NewType)
_T = TypeVar("_T", bound=type)


def identity(x: _T) -> _T: # pragma: no cover
Expand All @@ -18,7 +17,7 @@ def identity(x: _T) -> _T: # pragma: no cover
@overload
def scalar(
*,
name: str | None = None,
name: str,
description: str | None = None,
specified_by_url: str | None = None,
serialize: Callable = identity,
Expand All @@ -30,14 +29,13 @@ def scalar(
policy: list[list[str]] | None = None,
requires_scopes: list[list[str]] | None = None,
tags: Iterable[str] | None = (),
) -> Callable[[_T], _T]: ...
) -> ScalarDefinition: ...


@overload
def scalar(
cls: _T,
*,
name: str | None = None,
name: None = None,
description: str | None = None,
specified_by_url: str | None = None,
serialize: Callable = identity,
Expand All @@ -49,11 +47,10 @@ def scalar(
policy: list[list[str]] | None = None,
requires_scopes: list[list[str]] | None = None,
tags: Iterable[str] | None = (),
) -> _T: ...
) -> Callable[[_T], _T]: ...


def scalar(
cls: _T | None = None,
*,
name: str | None = None,
description: str | None = None,
Expand All @@ -68,11 +65,10 @@ def scalar(
requires_scopes: list[list[str]] | None = None,
tags: Iterable[str] | None = (),
) -> Any:
"""Annotates a class or type as a GraphQL custom scalar.
"""Creates a GraphQL custom scalar definition with federation support.

Args:
cls: The class or type to annotate
name: The GraphQL name of the scalar
name: The GraphQL name of the scalar (required for ScalarDefinition)
description: The description of the scalar
specified_by_url: The URL of the specification
serialize: The function to serialize the scalar
Expand All @@ -86,31 +82,33 @@ def scalar(
tags: The list of tags to add to the @tag directive

Returns:
The decorated class or type
A ScalarDefinition when called with `name`, or a decorator function
when called without `name`.

Example usages:
Example usage:

```python
strawberry.federation.scalar(
datetime.date,
serialize=lambda value: value.isoformat(),
parse_value=datetime.parse_date,
)

Base64Encoded = strawberry.federation.scalar(
NewType("Base64Encoded", bytes),
serialize=base64.b64encode,
parse_value=base64.b64decode,
from typing import NewType
import strawberry
from strawberry.schema.config import StrawberryConfig

# Define the type
Base64 = NewType("Base64", bytes)

# Configure the scalar with federation directives
schema = strawberry.federation.Schema(
query=Query,
config=StrawberryConfig(
scalar_map={
Base64: strawberry.federation.scalar(
name="Base64",
serialize=lambda v: base64.b64encode(v).decode(),
parse_value=lambda v: base64.b64decode(v),
authenticated=True,
)
}
),
)


@strawberry.federation.scalar(
serialize=lambda value: ",".join(value.items),
parse_value=lambda value: CustomList(value.split(",")),
)
class CustomList:
def __init__(self, items):
self.items = items
```
"""
from strawberry.federation.schema_directives import (
Expand All @@ -121,42 +119,40 @@ def __init__(self, items):
Tag,
)

if parse_value is None:
parse_value = cls

directives = list(directives)
all_directives = list(directives)

if authenticated:
directives.append(Authenticated())
all_directives.append(Authenticated())

if inaccessible:
directives.append(Inaccessible())
all_directives.append(Inaccessible())

if policy:
directives.append(Policy(policies=policy))
all_directives.append(Policy(policies=policy))

if requires_scopes:
directives.append(RequiresScopes(scopes=requires_scopes))
all_directives.append(RequiresScopes(scopes=requires_scopes))

if tags:
directives.extend(Tag(name=tag) for tag in tags)
all_directives.extend(Tag(name=tag) for tag in tags)

def wrap(cls: _T) -> ScalarWrapper:
return _process_scalar(
cls,
if name is not None:
return ScalarDefinition(
name=name,
description=description,
specified_by_url=specified_by_url,
serialize=serialize,
parse_value=parse_value,
parse_literal=parse_literal,
directives=directives,
parse_value=parse_value,
directives=tuple(all_directives),
origin=None,
)

if cls is None:
return wrap
# Decorator pattern for type hinting purposes only
def wrap(cls: _T) -> _T:
return cls

return wrap(cls)
return wrap


__all__ = ["scalar"]
9 changes: 3 additions & 6 deletions strawberry/federation/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
get_object_definition,
)
from strawberry.types.info import Info
from strawberry.types.scalar import ScalarDefinition, ScalarWrapper, scalar
from strawberry.types.scalar import ScalarDefinition, scalar
from strawberry.types.union import StrawberryUnion
from strawberry.utils.inspect import get_func_args

Expand Down Expand Up @@ -56,8 +56,7 @@ def __init__(
extensions: Iterable[Union[type["SchemaExtension"], "SchemaExtension"]] = (),
execution_context_class: type["GraphQLExecutionContext"] | None = None,
config: Optional["StrawberryConfig"] = None,
scalar_overrides: dict[object, Union[type, "ScalarWrapper", "ScalarDefinition"]]
| None = None,
scalar_overrides: dict[object, Union[type, "ScalarDefinition"]] | None = None,
schema_directives: Iterable[object] = (),
federation_version: Literal[
"2.0",
Expand All @@ -83,9 +82,7 @@ def __init__(
types = [*types, FederationAny]

# Add federation scalars to scalar_overrides so they can be recognized
federation_scalar_overrides: dict[
object, type | ScalarDefinition | ScalarWrapper
] = {
federation_scalar_overrides: dict[object, type | ScalarDefinition] = {
FederationAny: scalar(
name="_Any", serialize=lambda v: v, parse_value=lambda v: v
),
Expand Down
7 changes: 1 addition & 6 deletions strawberry/federation/union.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from collections.abc import Collection, Iterable
from typing import Any
from collections.abc import Iterable

from strawberry.types.union import StrawberryUnion
from strawberry.types.union import union as base_union


def union(
name: str,
types: Collection[type[Any]] | None = None,
*,
description: str | None = None,
directives: Iterable[object] = (),
Expand All @@ -18,8 +16,6 @@ def union(

Args:
name: The GraphQL name of the Union type.
types: The types that the Union can be.
(Deprecated, use `Annotated[U, strawberry.federation.union("Name")]` instead)
description: The GraphQL description of the Union type.
directives: The directives to attach to the Union type.
inaccessible: Whether the Union type is inaccessible.
Expand Down Expand Up @@ -53,7 +49,6 @@ class B:

return base_union(
name,
types,
description=description,
directives=directives,
)
Expand Down
3 changes: 0 additions & 3 deletions strawberry/printer/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
has_object_definition,
)
from strawberry.types.enum import StrawberryEnumDefinition
from strawberry.types.scalar import ScalarWrapper
from strawberry.types.unset import UNSET

from .ast_from_value import ast_from_value
Expand Down Expand Up @@ -627,8 +626,6 @@ def print_schema(schema: BaseSchema) -> str:
def _name_getter(type_: Any) -> str:
if hasattr(type_, "name"):
return type_.name
if isinstance(type_, ScalarWrapper):
return type_._scalar_definition.name
return type_.__name__

return "\n\n".join(
Expand Down
4 changes: 2 additions & 2 deletions strawberry/scalars.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
if TYPE_CHECKING:
from collections.abc import Mapping

from strawberry.types.scalar import ScalarDefinition, ScalarWrapper
from strawberry.types.scalar import ScalarDefinition


ID = NewType("ID", str)
Expand All @@ -26,7 +26,7 @@

def is_scalar(
annotation: Any,
scalar_registry: Mapping[object, ScalarWrapper | ScalarDefinition],
scalar_registry: Mapping[object, ScalarDefinition],
) -> bool:
if annotation in scalar_registry:
return True
Expand Down
Loading
Loading