|
14 | 14 | from dataclasses import dataclass, is_dataclass |
15 | 15 | from typing import TypeVar, Literal, Generic, Optional, Any, Union |
16 | 16 | from collections.abc import Callable, Iterable, Iterator |
| 17 | + |
17 | 18 | from beartype import beartype, BeartypeConf |
18 | 19 | from beartype.door import is_bearable |
19 | 20 | from typing_extensions import dataclass_transform |
|
75 | 76 | union_func_name, |
76 | 77 | GLOBAL_CLASS_SERIALIZER, |
77 | 78 | ) |
78 | | -from .numpy import ( |
79 | | - is_numpy_array, |
80 | | - is_numpy_jaxtyping, |
81 | | - is_numpy_datetime, |
82 | | - is_numpy_scalar, |
83 | | - serialize_numpy_array, |
84 | | - serialize_numpy_datetime, |
85 | | - serialize_numpy_scalar, |
86 | | -) |
| 79 | + |
| 80 | +# Lazy numpy imports to improve startup time |
87 | 81 |
|
88 | 82 | __all__ = ["serialize", "is_serializable", "to_dict", "to_tuple"] |
89 | 83 |
|
90 | 84 |
|
| 85 | +# Lazy numpy import wrappers to improve startup time |
| 86 | +def _is_numpy_array(typ: Any) -> bool: |
| 87 | + from .numpy import is_numpy_array |
| 88 | + |
| 89 | + return is_numpy_array(typ) |
| 90 | + |
| 91 | + |
| 92 | +def _is_numpy_scalar(typ: Any) -> bool: |
| 93 | + from .numpy import is_numpy_scalar |
| 94 | + |
| 95 | + return is_numpy_scalar(typ) |
| 96 | + |
| 97 | + |
| 98 | +def _is_numpy_jaxtyping(typ: Any) -> bool: |
| 99 | + from .numpy import is_numpy_jaxtyping |
| 100 | + |
| 101 | + return is_numpy_jaxtyping(typ) |
| 102 | + |
| 103 | + |
| 104 | +def _is_numpy_datetime(typ: Any) -> bool: |
| 105 | + from .numpy import is_numpy_datetime |
| 106 | + |
| 107 | + return is_numpy_datetime(typ) |
| 108 | + |
| 109 | + |
91 | 110 | SerializeFunc = Callable[[type[Any], Any], Any] |
92 | 111 | """ Interface of Custom serialize function. """ |
93 | 112 |
|
@@ -775,13 +794,21 @@ def render(self, arg: SeField[Any]) -> str: |
775 | 794 | res = self.tuple(arg) |
776 | 795 | elif is_enum(arg.type): |
777 | 796 | res = self.enum(arg) |
778 | | - elif is_numpy_datetime(arg.type): |
| 797 | + elif _is_numpy_datetime(arg.type): |
| 798 | + from .numpy import serialize_numpy_datetime |
| 799 | + |
779 | 800 | res = serialize_numpy_datetime(arg) |
780 | | - elif is_numpy_scalar(arg.type): |
| 801 | + elif _is_numpy_scalar(arg.type): |
| 802 | + from .numpy import serialize_numpy_scalar |
| 803 | + |
781 | 804 | res = serialize_numpy_scalar(arg) |
782 | | - elif is_numpy_array(arg.type): |
| 805 | + elif _is_numpy_array(arg.type): |
| 806 | + from .numpy import serialize_numpy_array |
| 807 | + |
783 | 808 | res = serialize_numpy_array(arg) |
784 | | - elif is_numpy_jaxtyping(arg.type): |
| 809 | + elif _is_numpy_jaxtyping(arg.type): |
| 810 | + from .numpy import serialize_numpy_array |
| 811 | + |
785 | 812 | res = serialize_numpy_array(arg) |
786 | 813 | elif is_primitive(arg.type): |
787 | 814 | res = self.primitive(arg) |
|
0 commit comments