Skip to content

Commit 16d7531

Browse files
committed
Update _config.py
1 parent d5c338e commit 16d7531

1 file changed

Lines changed: 85 additions & 67 deletions

File tree

deeptrack/backend/_config.py

Lines changed: 85 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,26 @@
11
from __future__ import annotations
22

3-
__all__ = ["config"]
4-
53
import importlib
6-
import warnings
7-
import numpy as np
8-
import array_api_compat as apc
9-
from array_api_compat import numpy as apcnumpy
10-
import array
4+
import sys
5+
import types
6+
from typing import Any, Literal
117

12-
import types, sys
13-
from typing import *
8+
from array_api_compat import numpy as apcnumpy
149
import array_api_strict
1510

1611

12+
__all__ = ["config"]
13+
14+
1715
class _Proxy(types.ModuleType):
18-
"""Object to keep track of the current backend, and forward calls to the correct backend.
16+
"""Keep track of current backend and forward calls to the correct backend.
1917
20-
An instance of this object will be treated as the module `xp`. It acts like a
21-
shallow wrapper around the actual backend (for example `numpy` or `torch`),
22-
and forwards calls to the correct backend.
18+
An instance of this object will be treated as the module `xp`. It acts like
19+
a shallow wrapper around the actual backend (for example `numpy` or
20+
`torch`), and forwards calls to the correct backend.
2321
24-
This is especially useful for array creation functions, to ensure that the correct
25-
array type is created.
22+
This is especially useful for array creation functions, to ensure that the
23+
correct array type is created.
2624
2725
Parameters
2826
----------
@@ -31,31 +29,69 @@ class _Proxy(types.ModuleType):
3129
3230
Attributes
3331
----------
34-
_backend : backend modukle
32+
_backend : backend module
3533
The actual backend module.
3634
__name__ : str
3735
The name of the proxy object.
36+
3837
"""
3938

40-
_backend: array_api_strict # types.ModuleType
39+
_backend: types.ModuleType # array_api_strict
4140
__name__: str
4241

43-
def __init__(self, name: str):
42+
def __init__(self: _Proxy, name: str) -> None:
43+
"""Initialize the _Proxy object.
44+
45+
Parameters
46+
----------
47+
name : str
48+
Name of the proxy object. This is used when printing the object.
49+
50+
"""
51+
4452
self._backend = apcnumpy
4553
self.__name__ = name
4654

47-
def __getattr__(self, attribute):
55+
def __getattr__(self: _Proxy, attribute: str) -> Any:
56+
"""Forward attribute access to the current backend.
57+
58+
Parameters
59+
----------
60+
attribute : str
61+
The attribute name to retrieve from the backend.
62+
63+
Returns
64+
-------
65+
Any
66+
The attribute from the current backend module.
67+
68+
"""
69+
4870
return getattr(self._backend, attribute)
4971

50-
def __dir__(self):
72+
def __dir__(self: _Proxy) -> list[str]:
73+
"""List attributes of the current backend.
74+
75+
Returns
76+
-------
77+
list
78+
List of attribute names in the current backend module.
79+
80+
"""
81+
5182
return dir(self._backend)
5283

5384

54-
# TODO: once intersection types are available, use them here
85+
# TODO: Once intersection types are available, use them here.
86+
# Intersection types are in the pipeline for python 3.13 or 3.14. They let you
87+
# define types that are the combination of many subtypes. So Intersection[A, B]
88+
# would have all the properties of A and B. Here, it would let us define
89+
# exactly the type of xp as Intersection[_Proxy, apcnumpy, apctorch].
90+
5591

5692
# This creates the xp object, which we will use a module.
57-
# We assign the type to be `array_api_strict` to make IDEs see this as if it were
58-
# an array api module, instead of the wrapper _Proxy object.
93+
# We assign the type to be `array_api_strict` to make IDEs see this as if it
94+
# were an array API module, instead of the wrapper _Proxy object.
5995
xp: array_api_strict = _Proxy(__name__ + ".xp")
6096

6197
# This registers the xp object as a module. This should make import statements
@@ -66,35 +102,27 @@ def __dir__(self):
66102
class NullContext:
67103
"""A context manager that does nothing.
68104
69-
Used when no context is needed, but the output expects
70-
a context manager."""
71-
72-
def __enter__(self):
73-
pass
74-
75-
def __exit__(self, *args):
76-
pass
77-
105+
Used when no context is needed, but the output expects a context manager.
78106
79-
class ImageWrapperContext:
80-
"""A context manager that enables the image wrapper.
107+
Examples
108+
--------
109+
>>> with NullContext():
110+
... print("No special context is active.")
81111
82-
Example
83-
-------
84-
>>> pipeline = dt.Value(1)
85-
>>> normal_result = pipeline()
86-
>>> with ImageWrapperContext():
87-
... wrapped_result = pipeline()
88-
...
89-
>>> print(normal_result) # 1
90-
>>> print(wrapped_result) # Image(1)
91112
"""
92113

93-
def __enter__(self, config: Config):
94-
config.enable_image_wrapper()
114+
def __enter__(self: NullContext) -> None:
115+
"""Enter the runtime context related to this object."""
116+
pass
95117

96-
def __exit__(self, *args):
97-
config.disable_image_wrapper()
118+
def __exit__(
119+
self: NullContext,
120+
exc_type: type[BaseException] | None,
121+
exc_value: BaseException | None,
122+
traceback: types.TracebackType | None,
123+
) -> None:
124+
"""Exit the runtime context related to this object."""
125+
pass
98126

99127

100128
class Config:
@@ -108,29 +136,33 @@ def __init__(self):
108136
self.set_backend_numpy()
109137
self.disable_image_wrapper()
110138

111-
112-
def set_device(self, device):
139+
def set_device(self: Config, device) -> None:
113140
"""Set the device to use.
114141
115-
Can be ["cpu", "gpu", "cuda", "mps", torch.device],
116-
but needs to be used with a compatible backend. Can only be "cpu"
117-
if using numpy backend.
142+
Can be "cpu", "gpu", "cuda", "mps", torch.device, but needs to be
143+
used with a compatible backend.
144+
145+
It can only be "cpu" if using NumPy backend.
118146
119147
Parameters
120148
----------
121149
device : str
122150
The device to use.
151+
123152
"""
153+
124154
self.device = device
125155

126-
def get_device(self):
156+
def get_device(self: Config) -> str:
127157
"""Get the device to use.
128158
129159
Returns
130160
-------
131161
str
132162
The device to use.
163+
133164
"""
165+
134166
return self.device
135167

136168
def set_backend_numpy(self):
@@ -176,20 +208,6 @@ def enable_image_wrapper(self):
176208
This will ensure that `Image` objects are used."""
177209
self.image_wrapper = True
178210

179-
def wrapper_enabled_context(self):
180-
"""Return a context manager that enables the image wrapper.
181-
182-
This will ensure that `Image` objects are used.
183-
184-
Examples
185-
--------
186-
>>> pipeline = dt.Value(1)
187-
>>> with config.wrapper_enabled_context():
188-
... result = pipeline()
189-
>>> print(result) # Image(1)
190-
"""
191-
return ImageWrapperContext(self) if not self.image_wrapper else NullContext()
192-
193211
def with_backend(self, backend: Literal["numpy", "torch"]):
194212
"""Return a context manager that changes the backend."""
195213

0 commit comments

Comments
 (0)