Skip to content

Commit ba4ad25

Browse files
committed
Implement lazy imports for numpy and sqlalchemy to improve startup time
1 parent d0e3c1c commit ba4ad25

6 files changed

Lines changed: 92 additions & 30 deletions

File tree

serde/compat.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,14 @@
2222
import typing_inspect
2323
from typing_extensions import TypeGuard, ParamSpec, TypeAliasType
2424

25-
from .sqlalchemy import is_sqlalchemy_inspectable
25+
# Lazy SQLAlchemy imports to improve startup time
26+
27+
28+
# Lazy SQLAlchemy import wrapper to improve startup time
29+
def _is_sqlalchemy_inspectable(subject: Any) -> bool:
30+
from .sqlalchemy import is_sqlalchemy_inspectable
31+
32+
return is_sqlalchemy_inspectable(subject)
2633

2734

2835
def get_np_origin(tp: type[Any]) -> Optional[Any]:
@@ -294,7 +301,7 @@ def dataclass_fields(cls: type[Any]) -> Iterator[dataclasses.Field]: # type: ig
294301
real_type = resolved_hints.get(f.name)
295302
if real_type is not None:
296303
f.type = real_type
297-
if is_generic(real_type) and is_sqlalchemy_inspectable(cls):
304+
if is_generic(real_type) and _is_sqlalchemy_inspectable(cls):
298305
f.type = get_args(real_type)[0]
299306

300307
return iter(raw_fields)

serde/core.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import re
1212
import casefy
1313
from dataclasses import dataclass
14+
1415
from beartype.door import is_bearable
1516
from collections.abc import Mapping, Sequence, Callable
1617
from typing import (

serde/de.py

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import typing
1313
import jinja2
1414
from collections.abc import Callable, Sequence, Iterable
15+
1516
from beartype import beartype, BeartypeConf
1617
from beartype.door import is_bearable
1718
from beartype.roar import BeartypeCallHintParamViolation
@@ -86,19 +87,31 @@
8687
raise_unsupported_type,
8788
union_func_name,
8889
)
89-
from .numpy import (
90-
deserialize_numpy_array,
91-
deserialize_numpy_scalar,
92-
deserialize_numpy_array_direct,
93-
deserialize_numpy_jaxtyping_array,
94-
is_numpy_array,
95-
is_numpy_jaxtyping,
96-
is_numpy_scalar,
97-
)
90+
91+
# Lazy numpy imports to improve startup time
9892

9993
__all__ = ["deserialize", "is_deserializable", "from_dict", "from_tuple"]
10094

10195

96+
# Lazy numpy import wrappers to improve startup time
97+
def _is_numpy_array(typ: Any) -> bool:
98+
from .numpy import is_numpy_array
99+
100+
return is_numpy_array(typ)
101+
102+
103+
def _is_numpy_scalar(typ: Any) -> bool:
104+
from .numpy import is_numpy_scalar
105+
106+
return is_numpy_scalar(typ)
107+
108+
109+
def _is_numpy_jaxtyping(typ: Any) -> bool:
110+
from .numpy import is_numpy_jaxtyping
111+
112+
return is_numpy_jaxtyping(typ)
113+
114+
102115
DeserializeFunc = Callable[[type[Any], Any], Any]
103116
""" Interface of Custom deserialize function. """
104117

@@ -490,7 +503,9 @@ def deserializable_to_obj(cls: type[T]) -> T:
490503
res = {
491504
thisfunc(type_args(c)[0], k): thisfunc(type_args(c)[1], v) for k, v in o.items()
492505
}
493-
elif is_numpy_array(c):
506+
elif _is_numpy_array(c):
507+
from .numpy import deserialize_numpy_array_direct
508+
494509
res = deserialize_numpy_array_direct(c, o)
495510
elif is_datetime(c):
496511
res = c.fromisoformat(o)
@@ -752,13 +767,19 @@ def render(self, arg: DeField[Any]) -> str:
752767
res = self.tuple(arg)
753768
elif is_enum(arg.type):
754769
res = self.enum(arg)
755-
elif is_numpy_scalar(arg.type):
770+
elif _is_numpy_scalar(arg.type):
771+
from .numpy import deserialize_numpy_scalar
772+
756773
self.import_numpy = True
757774
res = deserialize_numpy_scalar(arg)
758-
elif is_numpy_array(arg.type):
775+
elif _is_numpy_array(arg.type):
776+
from .numpy import deserialize_numpy_array
777+
759778
self.import_numpy = True
760779
res = deserialize_numpy_array(arg)
761-
elif is_numpy_jaxtyping(arg.type):
780+
elif _is_numpy_jaxtyping(arg.type):
781+
from .numpy import deserialize_numpy_jaxtyping_array
782+
762783
self.import_numpy = True
763784
res = deserialize_numpy_jaxtyping_array(arg)
764785
elif is_union(arg.type):

serde/json.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
from .compat import T
88
from .de import Deserializer, from_dict
99
from .se import Serializer, to_dict
10-
from .numpy import encode_numpy
10+
11+
# Lazy numpy imports to improve startup time
1112

1213
try: # pragma: no cover
1314
import orjson
@@ -25,6 +26,8 @@ def json_loads(s: Union[str, bytes], **opts: Any) -> Any:
2526

2627
def json_dumps(obj: Any, **opts: Any) -> str:
2728
if "default" not in opts:
29+
from .numpy import encode_numpy
30+
2831
opts["default"] = encode_numpy
2932
# compact output
3033
ensure_ascii = opts.pop("ensure_ascii", False)

serde/msgpack.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
from .compat import T
1111
from .compat import SerdeError
1212
from .de import Deserializer, from_dict, from_tuple
13-
from .numpy import encode_numpy
13+
14+
# Lazy numpy imports to improve startup time
1415
from .se import Serializer, to_dict, to_tuple
1516

1617
__all__ = ["from_msgpack", "to_msgpack"]
@@ -22,6 +23,8 @@ def serialize(
2223
cls, obj: Any, use_bin_type: bool = True, ext_type_code: Optional[int] = None, **opts: Any
2324
) -> bytes:
2425
if "default" not in opts:
26+
from .numpy import encode_numpy
27+
2528
opts["default"] = encode_numpy
2629
if ext_type_code is not None:
2730
obj_bytes = msgpack.packb(obj, use_bin_type=use_bin_type, **opts)

serde/se.py

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from dataclasses import dataclass, is_dataclass
1515
from typing import TypeVar, Literal, Generic, Optional, Any, Union
1616
from collections.abc import Callable, Iterable, Iterator
17+
1718
from beartype import beartype, BeartypeConf
1819
from beartype.door import is_bearable
1920
from typing_extensions import dataclass_transform
@@ -75,19 +76,37 @@
7576
union_func_name,
7677
GLOBAL_CLASS_SERIALIZER,
7778
)
78-
from .numpy import (
79-
is_numpy_array,
80-
is_numpy_jaxtyping,
81-
is_numpy_datetime,
82-
is_numpy_scalar,
83-
serialize_numpy_array,
84-
serialize_numpy_datetime,
85-
serialize_numpy_scalar,
86-
)
79+
80+
# Lazy numpy imports to improve startup time
8781

8882
__all__ = ["serialize", "is_serializable", "to_dict", "to_tuple"]
8983

9084

85+
# Lazy numpy import wrappers to improve startup time
86+
def _is_numpy_array(typ: Any) -> bool:
87+
from .numpy import is_numpy_array
88+
89+
return is_numpy_array(typ)
90+
91+
92+
def _is_numpy_scalar(typ: Any) -> bool:
93+
from .numpy import is_numpy_scalar
94+
95+
return is_numpy_scalar(typ)
96+
97+
98+
def _is_numpy_jaxtyping(typ: Any) -> bool:
99+
from .numpy import is_numpy_jaxtyping
100+
101+
return is_numpy_jaxtyping(typ)
102+
103+
104+
def _is_numpy_datetime(typ: Any) -> bool:
105+
from .numpy import is_numpy_datetime
106+
107+
return is_numpy_datetime(typ)
108+
109+
91110
SerializeFunc = Callable[[type[Any], Any], Any]
92111
""" Interface of Custom serialize function. """
93112

@@ -775,13 +794,21 @@ def render(self, arg: SeField[Any]) -> str:
775794
res = self.tuple(arg)
776795
elif is_enum(arg.type):
777796
res = self.enum(arg)
778-
elif is_numpy_datetime(arg.type):
797+
elif _is_numpy_datetime(arg.type):
798+
from .numpy import serialize_numpy_datetime
799+
779800
res = serialize_numpy_datetime(arg)
780-
elif is_numpy_scalar(arg.type):
801+
elif _is_numpy_scalar(arg.type):
802+
from .numpy import serialize_numpy_scalar
803+
781804
res = serialize_numpy_scalar(arg)
782-
elif is_numpy_array(arg.type):
805+
elif _is_numpy_array(arg.type):
806+
from .numpy import serialize_numpy_array
807+
783808
res = serialize_numpy_array(arg)
784-
elif is_numpy_jaxtyping(arg.type):
809+
elif _is_numpy_jaxtyping(arg.type):
810+
from .numpy import serialize_numpy_array
811+
785812
res = serialize_numpy_array(arg)
786813
elif is_primitive(arg.type):
787814
res = self.primitive(arg)

0 commit comments

Comments
 (0)