19
19
import sys
20
20
import sphinx_rtd_theme # noqa
21
21
import warnings
22
- from typing import ForwardRef
22
+
23
+ import jaxtyping
23
24
24
25
25
26
def read (* names , ** kwargs ):
@@ -112,7 +113,8 @@ def find_version(*file_paths):
112
113
intersphinx_mapping = {
113
114
"python" : ("https://docs.python.org/3/" , None ),
114
115
"torch" : ("https://pytorch.org/docs/stable/" , None ),
115
- "linear_operator" : ("https://linear-operator.readthedocs.io/en/stable/" , None ),
116
+ "linear_operator" : ("https://linear-operator.readthedocs.io/en/stable/" , "linear_operator_objects.inv" ),
117
+ # The local mapping here is temporary until we get a new release of linear_operator
116
118
}
117
119
118
120
# Disable docstring inheritance
@@ -237,41 +239,81 @@ def find_version(*file_paths):
237
239
]
238
240
239
241
240
- # -- Function to format typehints ----------------------------------------------
242
+ # -- Functions to format typehints ----------------------------------------------
241
243
# Adapted from
242
244
# https://github.com/cornellius-gp/linear_operator/blob/2b33b9f83b45f0cb8cb3490fc5f254cc59393c25/docs/source/conf.py
245
+
246
+
247
+ # Helper function
248
+ # Convert any class (i.e. torch.Tensor, LinearOperator, etc.) into appropriate strings
249
+ # For external classes, the format will be e.g. "torch.Tensor"
250
+ # For any internal class, the format will be e.g. "~linear_operator.operators.TriangularLinearOperator"
251
+ def _convert_internal_and_external_class_to_strings (annotation ):
252
+ module = annotation .__module__ + "."
253
+ if module .split ("." )[0 ] == "gpytorch" :
254
+ module = "~" + module
255
+ elif module == "torch." :
256
+ module = "~torch."
257
+ elif module == "linear_operator.operators._linear_operator." :
258
+ module = "~linear_operator."
259
+ elif module == "builtins." :
260
+ module = ""
261
+ res = f"{ module } { annotation .__name__ } "
262
+ return res
263
+
264
+
265
+ # Convert jaxtyping dimensions into strings
266
+ def _dim_to_str (dim ):
267
+ if isinstance (dim , jaxtyping ._array_types ._NamedVariadicDim ):
268
+ return "..."
269
+ elif isinstance (dim , jaxtyping ._array_types ._FixedDim ):
270
+ res = str (dim .size )
271
+ if dim .broadcastable :
272
+ res = "#" + res
273
+ return res
274
+ elif isinstance (dim , jaxtyping ._array_types ._SymbolicDim ):
275
+ expr = dim .elem
276
+ return f"({ expr } )"
277
+ elif "jaxtyping" not in str (dim .__class__ ): # Probably the case that we have an ellipsis
278
+ return "..."
279
+ else :
280
+ res = str (dim .name )
281
+ if dim .broadcastable :
282
+ res = "#" + res
283
+ return res
284
+
285
+
286
+ # Function to format type hints
243
287
def _process (annotation , config ):
244
288
"""
245
289
A function to convert a type/rtype typehint annotation into a :type:/:rtype: string.
246
290
This function is a bit hacky, and specific to the type annotations we use most frequently.
291
+
247
292
This function is recursive.
248
293
"""
249
294
# Simple/base case: any string annotation is ready to go
250
295
if type (annotation ) == str :
251
296
return annotation
252
297
298
+ # Jaxtyping: shaped tensors or linear operator
299
+ elif hasattr (annotation , "__module__" ) and "jaxtyping" == annotation .__module__ :
300
+ cls_annotation = _convert_internal_and_external_class_to_strings (annotation .array_type )
301
+ shape = " x " .join ([_dim_to_str (dim ) for dim in annotation .dims ])
302
+ return f"{ cls_annotation } ({ shape } )"
303
+
253
304
# Convert Ellipsis into "..."
254
305
elif annotation == Ellipsis :
255
306
return "..."
256
307
257
- # Convert any class (i.e. torch.Tensor, LinearOperator, gpytorch, etc.) into appropriate strings
258
- # For external classes, the format will be e.g. "torch.Tensor"
259
- # For any linear_operator class, the format will be e.g. "~linear_operator.operators.TriangularLinearOperator"
260
- # For any internal class, the format will be e.g. "~gpytorch.kernels.RBFKernel"
308
+ # Convert any class (i.e. torch.Tensor, LinearOperator, etc.) into appropriate strings
261
309
elif hasattr (annotation , "__name__" ):
262
- module = annotation .__module__ + "."
263
- if module .split ("." )[0 ] == "linear_operator" :
264
- if annotation .__name__ .endswith ("LinearOperator" ):
265
- module = "~linear_operator."
266
- elif annotation .__name__ .endswith ("LinearOperator" ):
267
- module = "~linear_operator.operators."
268
- else :
269
- module = "~" + module
270
- elif module .split ("." )[0 ] == "gpytorch" :
271
- module = "~" + module
272
- elif module == "builtins." :
273
- module = ""
274
- res = f"{ module } { annotation .__name__ } "
310
+ res = _convert_internal_and_external_class_to_strings (annotation )
311
+
312
+ elif str (annotation ).startswith ("typing.Callable" ):
313
+ if len (annotation .__args__ ) == 2 :
314
+ res = f"Callable[{ _process (annotation .__args__ [0 ], config )} -> { _process (annotation .__args__ [1 ], config )} ]"
315
+ else :
316
+ res = "Callable"
275
317
276
318
# Convert any Union[*A*, *B*, *C*] into "*A* or *B* or *C*"
277
319
# Also, convert any Optional[*A*] into "*A*, optional"
@@ -291,33 +333,14 @@ def _process(annotation, config):
291
333
args = list (annotation .__args__ )
292
334
res = "(" + ", " .join (_process (arg , config ) for arg in args ) + ")"
293
335
294
- # Convert any List[*A*] into "list(*A*)"
295
- elif str (annotation ).startswith ("typing.List" ):
296
- arg = annotation .__args__ [0 ]
297
- res = "list(" + _process (arg , config ) + ")"
298
-
299
- # Convert any List[*A*] into "list(*A*)"
300
- elif str (annotation ).startswith ("typing.Dict" ):
301
- res = str (annotation )
302
-
303
- # Convert any Iterable[*A*] into "iterable(*A*)"
304
- elif str (annotation ).startswith ("typing.Iterable" ):
305
- arg = annotation .__args__ [0 ]
306
- res = "iterable(" + _process (arg , config ) + ")"
307
-
308
- # Handle "Callable"
309
- elif str (annotation ).startswith ("typing.Callable" ):
310
- res = "callable"
311
-
312
- # Handle "Any"
313
- elif str (annotation ).startswith ("typing.Any" ):
314
- res = ""
336
+ # Convert any List[*A*] or Iterable[*A*] into "[*A*, ...]"
337
+ elif str (annotation ).startswith ("typing.Iterable" ) or str (annotation ).startswith ("typing.List" ):
338
+ arg = list (annotation .__args__ )[0 ]
339
+ res = f"[{ _process (arg , config )} , ...]"
315
340
316
- # Special cases for forward references.
317
- # This is brittle, as it only contains case for a select few forward refs
318
- # All others that aren't caught by this are handled by the default case
319
- elif isinstance (annotation , ForwardRef ):
320
- res = str (annotation .__forward_arg__ )
341
+ # Callable typing annotation
342
+ elif str (annotation ).startswith ("typing." ):
343
+ return str (annotation )[7 :]
321
344
322
345
# For everything we didn't catch: use the simplist string representation
323
346
else :
0 commit comments