Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions serde/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,14 @@
import typing_inspect
from typing_extensions import TypeGuard, ParamSpec, TypeAliasType

from .sqlalchemy import is_sqlalchemy_inspectable
# Lazy SQLAlchemy imports to improve startup time
Comment thread
yukinarit marked this conversation as resolved.


# Lazy SQLAlchemy import wrapper to improve startup time
def _is_sqlalchemy_inspectable(subject: Any) -> bool:
from .sqlalchemy import is_sqlalchemy_inspectable

Comment thread
yukinarit marked this conversation as resolved.
return is_sqlalchemy_inspectable(subject)


def get_np_origin(tp: type[Any]) -> Optional[Any]:
Expand Down Expand Up @@ -294,7 +301,7 @@ def dataclass_fields(cls: type[Any]) -> Iterator[dataclasses.Field]: # type: ig
real_type = resolved_hints.get(f.name)
if real_type is not None:
f.type = real_type
if is_generic(real_type) and is_sqlalchemy_inspectable(cls):
if is_generic(real_type) and _is_sqlalchemy_inspectable(cls):
f.type = get_args(real_type)[0]

return iter(raw_fields)
Expand Down
1 change: 1 addition & 0 deletions serde/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import re
import casefy
from dataclasses import dataclass

from beartype.door import is_bearable
from collections.abc import Mapping, Sequence, Callable
from typing import (
Expand Down
47 changes: 34 additions & 13 deletions serde/de.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import typing
import jinja2
from collections.abc import Callable, Sequence, Iterable

from beartype import beartype, BeartypeConf
from beartype.door import is_bearable
from beartype.roar import BeartypeCallHintParamViolation
Comment thread
yukinarit marked this conversation as resolved.
Expand Down Expand Up @@ -86,19 +87,31 @@
raise_unsupported_type,
union_func_name,
)
from .numpy import (
deserialize_numpy_array,
deserialize_numpy_scalar,
deserialize_numpy_array_direct,
deserialize_numpy_jaxtyping_array,
is_numpy_array,
is_numpy_jaxtyping,
is_numpy_scalar,
)

# Lazy numpy imports to improve startup time

__all__ = ["deserialize", "is_deserializable", "from_dict", "from_tuple"]


# Lazy numpy import wrappers to improve startup time
def _is_numpy_array(typ: Any) -> bool:
from .numpy import is_numpy_array

return is_numpy_array(typ)


def _is_numpy_scalar(typ: Any) -> bool:
from .numpy import is_numpy_scalar

return is_numpy_scalar(typ)


def _is_numpy_jaxtyping(typ: Any) -> bool:
from .numpy import is_numpy_jaxtyping

return is_numpy_jaxtyping(typ)


DeserializeFunc = Callable[[type[Any], Any], Any]
""" Interface of Custom deserialize function. """

Expand Down Expand Up @@ -490,7 +503,9 @@ def deserializable_to_obj(cls: type[T]) -> T:
res = {
thisfunc(type_args(c)[0], k): thisfunc(type_args(c)[1], v) for k, v in o.items()
}
elif is_numpy_array(c):
elif _is_numpy_array(c):
from .numpy import deserialize_numpy_array_direct

res = deserialize_numpy_array_direct(c, o)
elif is_datetime(c):
res = c.fromisoformat(o)
Expand Down Expand Up @@ -752,13 +767,19 @@ def render(self, arg: DeField[Any]) -> str:
res = self.tuple(arg)
elif is_enum(arg.type):
res = self.enum(arg)
elif is_numpy_scalar(arg.type):
elif _is_numpy_scalar(arg.type):
from .numpy import deserialize_numpy_scalar

self.import_numpy = True
res = deserialize_numpy_scalar(arg)
elif is_numpy_array(arg.type):
elif _is_numpy_array(arg.type):
from .numpy import deserialize_numpy_array

self.import_numpy = True
res = deserialize_numpy_array(arg)
elif is_numpy_jaxtyping(arg.type):
elif _is_numpy_jaxtyping(arg.type):
from .numpy import deserialize_numpy_jaxtyping_array

self.import_numpy = True
res = deserialize_numpy_jaxtyping_array(arg)
elif is_union(arg.type):
Expand Down
5 changes: 4 additions & 1 deletion serde/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from .compat import T
from .de import Deserializer, from_dict
from .se import Serializer, to_dict
from .numpy import encode_numpy

# Lazy numpy imports to improve startup time

try: # pragma: no cover
import orjson
Expand All @@ -25,6 +26,8 @@ def json_loads(s: Union[str, bytes], **opts: Any) -> Any:

def json_dumps(obj: Any, **opts: Any) -> str:
if "default" not in opts:
from .numpy import encode_numpy

opts["default"] = encode_numpy
# compact output
ensure_ascii = opts.pop("ensure_ascii", False)
Expand Down
5 changes: 4 additions & 1 deletion serde/msgpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from .compat import T
from .compat import SerdeError
from .de import Deserializer, from_dict, from_tuple
from .numpy import encode_numpy

# Lazy numpy imports to improve startup time
from .se import Serializer, to_dict, to_tuple

__all__ = ["from_msgpack", "to_msgpack"]
Expand All @@ -22,6 +23,8 @@ def serialize(
cls, obj: Any, use_bin_type: bool = True, ext_type_code: Optional[int] = None, **opts: Any
) -> bytes:
if "default" not in opts:
from .numpy import encode_numpy

opts["default"] = encode_numpy
if ext_type_code is not None:
obj_bytes = msgpack.packb(obj, use_bin_type=use_bin_type, **opts)
Expand Down
53 changes: 40 additions & 13 deletions serde/se.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from dataclasses import dataclass, is_dataclass
from typing import TypeVar, Literal, Generic, Optional, Any, Union
from collections.abc import Callable, Iterable, Iterator

from beartype import beartype, BeartypeConf
from beartype.door import is_bearable
Comment thread
yukinarit marked this conversation as resolved.
from typing_extensions import dataclass_transform
Expand Down Expand Up @@ -75,19 +76,37 @@
union_func_name,
GLOBAL_CLASS_SERIALIZER,
)
from .numpy import (
is_numpy_array,
is_numpy_jaxtyping,
is_numpy_datetime,
is_numpy_scalar,
serialize_numpy_array,
serialize_numpy_datetime,
serialize_numpy_scalar,
)

# Lazy numpy imports to improve startup time

__all__ = ["serialize", "is_serializable", "to_dict", "to_tuple"]


# Lazy numpy import wrappers to improve startup time
def _is_numpy_array(typ: Any) -> bool:
from .numpy import is_numpy_array

return is_numpy_array(typ)


def _is_numpy_scalar(typ: Any) -> bool:
from .numpy import is_numpy_scalar

return is_numpy_scalar(typ)


def _is_numpy_jaxtyping(typ: Any) -> bool:
from .numpy import is_numpy_jaxtyping

return is_numpy_jaxtyping(typ)


def _is_numpy_datetime(typ: Any) -> bool:
from .numpy import is_numpy_datetime

return is_numpy_datetime(typ)


SerializeFunc = Callable[[type[Any], Any], Any]
""" Interface of Custom serialize function. """

Expand Down Expand Up @@ -775,13 +794,21 @@ def render(self, arg: SeField[Any]) -> str:
res = self.tuple(arg)
elif is_enum(arg.type):
res = self.enum(arg)
elif is_numpy_datetime(arg.type):
elif _is_numpy_datetime(arg.type):
from .numpy import serialize_numpy_datetime

res = serialize_numpy_datetime(arg)
elif is_numpy_scalar(arg.type):
elif _is_numpy_scalar(arg.type):
from .numpy import serialize_numpy_scalar

res = serialize_numpy_scalar(arg)
elif is_numpy_array(arg.type):
elif _is_numpy_array(arg.type):
from .numpy import serialize_numpy_array

res = serialize_numpy_array(arg)
elif is_numpy_jaxtyping(arg.type):
elif _is_numpy_jaxtyping(arg.type):
from .numpy import serialize_numpy_array

res = serialize_numpy_array(arg)
elif is_primitive(arg.type):
res = self.primitive(arg)
Expand Down