11from __future__ import annotations
22
3- __all__ = ["config" ]
4-
53import importlib
6- import warnings
7- import numpy as np
8- import array_api_compat as apc
9- from array_api_compat import numpy as apcnumpy
10- import array
4+ import sys
5+ import types
6+ from typing import Any , Literal
117
12- import types , sys
13- from typing import *
8+ from array_api_compat import numpy as apcnumpy
149import array_api_strict
1510
1611
12+ __all__ = ["config" ]
13+
14+
1715class _Proxy (types .ModuleType ):
18- """Object to keep track of the current backend, and forward calls to the correct backend.
16+ """Keep track of current backend and forward calls to the correct backend.
1917
20- An instance of this object will be treated as the module `xp`. It acts like a
21- shallow wrapper around the actual backend (for example `numpy` or `torch`),
22- and forwards calls to the correct backend.
18+ An instance of this object will be treated as the module `xp`. It acts like
19+ a shallow wrapper around the actual backend (for example `numpy` or
20+ `torch`), and forwards calls to the correct backend.
2321
24- This is especially useful for array creation functions, to ensure that the correct
25- array type is created.
22+ This is especially useful for array creation functions, to ensure that the
23+ correct array type is created.
2624
2725 Parameters
2826 ----------
@@ -31,31 +29,69 @@ class _Proxy(types.ModuleType):
3129
3230 Attributes
3331 ----------
34- _backend : backend modukle
32+ _backend : backend module
3533 The actual backend module.
3634 __name__ : str
3735 The name of the proxy object.
36+
3837 """
3938
40- _backend : array_api_strict # types.ModuleType
39+ _backend : types . ModuleType # array_api_strict
4140 __name__ : str
4241
43- def __init__ (self , name : str ):
42+ def __init__ (self : _Proxy , name : str ) -> None :
43+ """Initialize the _Proxy object.
44+
45+ Parameters
46+ ----------
47+ name : str
48+ Name of the proxy object. This is used when printing the object.
49+
50+ """
51+
4452 self ._backend = apcnumpy
4553 self .__name__ = name
4654
47- def __getattr__ (self , attribute ):
55+ def __getattr__ (self : _Proxy , attribute : str ) -> Any :
56+ """Forward attribute access to the current backend.
57+
58+ Parameters
59+ ----------
60+ attribute : str
61+ The attribute name to retrieve from the backend.
62+
63+ Returns
64+ -------
65+ Any
66+ The attribute from the current backend module.
67+
68+ """
69+
4870 return getattr (self ._backend , attribute )
4971
50- def __dir__ (self ):
72+ def __dir__ (self : _Proxy ) -> list [str ]:
73+ """List attributes of the current backend.
74+
75+ Returns
76+ -------
77+ list
78+ List of attribute names in the current backend module.
79+
80+ """
81+
5182 return dir (self ._backend )
5283
5384
54- # TODO: once intersection types are available, use them here
85+ # TODO: Once intersection types are available, use them here.
86+ # Intersection types are in the pipeline for python 3.13 or 3.14. They let you
87+ # define types that are the combination of many subtypes. So Intersection[A, B]
88+ # would have all the properties of A and B. Here, it would let us define
89+ # exactly the type of xp as Intersection[_Proxy, apcnumpy, apctorch].
90+
5591
5692# This creates the xp object, which we will use a module.
57- # We assign the type to be `array_api_strict` to make IDEs see this as if it were
58- # an array api module, instead of the wrapper _Proxy object.
93+ # We assign the type to be `array_api_strict` to make IDEs see this as if it
94+ # were an array API module, instead of the wrapper _Proxy object.
5995xp : array_api_strict = _Proxy (__name__ + ".xp" )
6096
6197# This registers the xp object as a module. This should make import statements
@@ -66,35 +102,27 @@ def __dir__(self):
66102class NullContext :
67103 """A context manager that does nothing.
68104
69- Used when no context is needed, but the output expects
70- a context manager."""
71-
72- def __enter__ (self ):
73- pass
74-
75- def __exit__ (self , * args ):
76- pass
77-
105+ Used when no context is needed, but the output expects a context manager.
78106
79- class ImageWrapperContext :
80- """A context manager that enables the image wrapper.
107+ Examples
108+ --------
109+ >>> with NullContext():
110+ ... print("No special context is active.")
81111
82- Example
83- -------
84- >>> pipeline = dt.Value(1)
85- >>> normal_result = pipeline()
86- >>> with ImageWrapperContext():
87- ... wrapped_result = pipeline()
88- ...
89- >>> print(normal_result) # 1
90- >>> print(wrapped_result) # Image(1)
91112 """
92113
93- def __enter__ (self , config : Config ):
94- config .enable_image_wrapper ()
114+ def __enter__ (self : NullContext ) -> None :
115+ """Enter the runtime context related to this object."""
116+ pass
95117
96- def __exit__ (self , * args ):
97- config .disable_image_wrapper ()
118+ def __exit__ (
119+ self : NullContext ,
120+ exc_type : type [BaseException ] | None ,
121+ exc_value : BaseException | None ,
122+ traceback : types .TracebackType | None ,
123+ ) -> None :
124+ """Exit the runtime context related to this object."""
125+ pass
98126
99127
100128class Config :
@@ -108,29 +136,33 @@ def __init__(self):
108136 self .set_backend_numpy ()
109137 self .disable_image_wrapper ()
110138
111-
112- def set_device (self , device ):
139+ def set_device (self : Config , device ) -> None :
113140 """Set the device to use.
114141
115- Can be ["cpu", "gpu", "cuda", "mps", torch.device],
116- but needs to be used with a compatible backend. Can only be "cpu"
117- if using numpy backend.
142+ Can be "cpu", "gpu", "cuda", "mps", torch.device, but needs to be
143+ used with a compatible backend.
144+
145+ It can only be "cpu" if using NumPy backend.
118146
119147 Parameters
120148 ----------
121149 device : str
122150 The device to use.
151+
123152 """
153+
124154 self .device = device
125155
126- def get_device (self ) :
156+ def get_device (self : Config ) -> str :
127157 """Get the device to use.
128158
129159 Returns
130160 -------
131161 str
132162 The device to use.
163+
133164 """
165+
134166 return self .device
135167
136168 def set_backend_numpy (self ):
@@ -176,20 +208,6 @@ def enable_image_wrapper(self):
176208 This will ensure that `Image` objects are used."""
177209 self .image_wrapper = True
178210
179- def wrapper_enabled_context (self ):
180- """Return a context manager that enables the image wrapper.
181-
182- This will ensure that `Image` objects are used.
183-
184- Examples
185- --------
186- >>> pipeline = dt.Value(1)
187- >>> with config.wrapper_enabled_context():
188- ... result = pipeline()
189- >>> print(result) # Image(1)
190- """
191- return ImageWrapperContext (self ) if not self .image_wrapper else NullContext ()
192-
193211 def with_backend (self , backend : Literal ["numpy" , "torch" ]):
194212 """Return a context manager that changes the backend."""
195213
0 commit comments