Skip to content

Commit 85f62f7

Browse files
committed
Update _config.py
1 parent d3f5b6c commit 85f62f7

1 file changed

Lines changed: 12 additions & 72 deletions

File tree

deeptrack/backend/_config.py

Lines changed: 12 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -246,22 +246,9 @@ def get_float_dtype(
246246
"""
247247

248248
if dtype == "default":
249-
# Robust to "real floating" and "real" as key in default_dtypes()
250-
default_dtypes = self._backend_info.default_dtypes()
251-
if "real floating" in default_dtypes:
252-
return default_dtypes["real floating"]
253-
if "real" in default_dtypes:
254-
return default_dtypes["real"]
255-
raise KeyError(
256-
"No default real floating dtype found in backend. "
257-
"Looked for 'real floating' and 'real' in default_dtypes()."
258-
)
259-
260-
# Support both flat and nested dictionaries in dtypes()
261-
dtypes = self._backend_info.dtypes()
262-
if "real floating" in dtypes:
263-
return dtypes["real floating"][dtype]
264-
return dtypes[dtype]
249+
return self._backend_info.default_dtypes()["real floating"]
250+
251+
return self._backend_info.dtypes(kind="real floating")[dtype]
265252

266253
def get_int_dtype(
267254
self: _Proxy,
@@ -311,24 +298,9 @@ def get_int_dtype(
311298
"""
312299

313300
if dtype == "default":
314-
# Robust to "integer" and "integral" as key in default_dtypes()
315-
default_dtypes = self._backend_info.default_dtypes()
316-
if "integer" in default_dtypes:
317-
return default_dtypes["integer"]
318-
if "integral" in default_dtypes:
319-
return default_dtypes["integral"]
320-
raise KeyError(
321-
"No default integer dtype found in backend. "
322-
"Looked for 'integer' and 'integral' in default_dtypes()."
323-
)
324-
325-
# Support both flat and nested dictionaries in dtypes()
326-
dtypes = self._backend_info.dtypes()
327-
if "integer" in dtypes:
328-
return dtypes["integer"][dtype]
329-
if "integral" in dtypes:
330-
return dtypes["integral"][dtype]
331-
return dtypes[dtype]
301+
return self._backend_info.default_dtypes()["integral"]
302+
303+
return self._backend_info.dtypes(kind="integral")[dtype]
332304

333305
def get_complex_dtype(
334306
self: _Proxy,
@@ -378,26 +350,9 @@ def get_complex_dtype(
378350
"""
379351

380352
if dtype == "default":
381-
# Robust to "complex floating" and "complex" as key in
382-
# default_dtypes()
383-
default_dtypes = self._backend_info.default_dtypes()
384-
if "complex floating" in default_dtypes:
385-
return default_dtypes["complex floating"]
386-
if "complex" in default_dtypes:
387-
return default_dtypes["complex"]
388-
raise KeyError(
389-
"No default complex dtype found in backend. "
390-
"Looked for 'complex floating' and 'complex' in"
391-
" default_dtypes()."
392-
)
393-
394-
# Support both flat and nested dictionaries in dtypes()
395-
dtypes = self._backend_info.dtypes()
396-
if "complex floating" in dtypes:
397-
return dtypes["complex floating"][dtype]
398-
if "complex" in dtypes:
399-
return dtypes["complex"][dtype]
400-
return dtypes[dtype]
353+
return self._backend_info.default_dtypes()["complex floating"]
354+
355+
return self._backend_info.dtypes(kind="complex floating")[dtype]
401356

402357
def get_bool_dtype(
403358
self: _Proxy,
@@ -447,24 +402,9 @@ def get_bool_dtype(
447402
"""
448403

449404
if dtype == "default":
450-
# Robust to "bool" and "boolean" as key in default_dtypes()
451-
default_dtypes = self._backend_info.default_dtypes()
452-
if "bool" in default_dtypes:
453-
return default_dtypes["bool"]
454-
if "boolean" in default_dtypes:
455-
return default_dtypes["boolean"]
456-
raise KeyError(
457-
"No default bool dtype found in backend. "
458-
"Looked for 'bool' and 'boolean' in default_dtypes()."
459-
)
460-
461-
# Support both flat and nested dictionaries in dtypes()
462-
dtypes = self._backend_info.dtypes()
463-
if "bool" in dtypes and isinstance(dtypes["bool"], dict):
464-
return dtypes["bool"][dtype]
465-
if "boolean" in dtypes and isinstance(dtypes["boolean"], dict):
466-
return dtypes["boolean"][dtype]
467-
return dtypes[dtype]
405+
dtype = "bool"
406+
407+
return self._backend_info.dtypes(kind="bool")[dtype]
468408

469409
def __getattr__(
470410
self: _Proxy,

0 commit comments

Comments
 (0)