Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions .all-contributorsrc
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@
"name": "Khush Agrawal",
"profile": "https://github.com/Khushmagrawal",
"contributions": [
"code",
"code"
]
},
{
Expand All @@ -257,5 +257,15 @@
"bug"
]
},
{
"login": "sun-9545sunoj",
"name": "Sunoj",
"avatar_url": "https://avatars.githubusercontent.com/u/156280523?v=4",
"profile": "https://github.com/sun-9545sunoj",
"contributions": [
"code",
"bug"
]
}
]
}
}
13 changes: 10 additions & 3 deletions skpro/datatypes/_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from skpro.datatypes._base import BaseDatatype
from skpro.datatypes._common import _metadata_requested, _ret
from skpro.datatypes._proba import check_dict_Proba
from skpro.datatypes._registry import AMBIGUOUS_MTYPES, SCITYPE_LIST, mtype_to_scitype
from skpro.datatypes._registry import AMBIGUOUS_MTYPES, mtype_to_scitype


def get_check_dict(soft_deps="present"):
Expand Down Expand Up @@ -545,14 +545,16 @@ def check_is_error_msg(msg, var_name="obj", allowed_msg=None, raise_exception=Fa
raise raise_exception(msg_invalid_input)


def scitype(obj, candidate_scitypes=SCITYPE_LIST, exclude_mtypes=AMBIGUOUS_MTYPES):
def scitype(obj, candidate_scitypes=None, exclude_mtypes=AMBIGUOUS_MTYPES):
"""Infer the scitype of an object.

Parameters
----------
obj : object to infer type of - any type, should comply with some mtype spec
if as_scitype is provided, this must be mtype belonging to scitype
candidate_scitypes: str or list of str, scitypes to pick from
candidate_scitypes: str, list of str, or None, optional, default=None
scitypes to pick from. If None, it defaults to all valid scitypes dynamically
resolved via lazy import (i.e., datatypes.SCITYPE_LIST).
valid scitype strings are in datatypes.SCITYPE_REGISTER
exclude_mtypes : list of str, default = AMBIGUOUS_MTYPES
which mtypes to ignore in inferring mtype, default = ambiguous ones
Expand All @@ -568,6 +570,11 @@ def scitype(obj, candidate_scitypes=SCITYPE_LIST, exclude_mtypes=AMBIGUOUS_MTYPE
------
TypeError if no type can be identified, or more than one type is identified
"""
if candidate_scitypes is None:
from skpro.datatypes._registry import SCITYPE_LIST

candidate_scitypes = SCITYPE_LIST

candidate_scitypes = _coerce_list_of_str(
candidate_scitypes, var_name="candidate_scitypes"
)
Expand Down
6 changes: 3 additions & 3 deletions skpro/datatypes/_proba/_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@


MTYPE_REGISTER_PROBA = [
("pred_interval", "Proba", "predictive intervals"),
("pred_quantiles", "Proba", "quantile predictions"),
("pred_var", "Proba", "variance predictions"),
("pred_interval", "Proba", "predictive intervals", None),
("pred_quantiles", "Proba", "quantile predictions", None),
("pred_var", "Proba", "variance predictions", None),
# ("pred_dost", "Proba", "full distribution predictions, tensorflow-probability"),
]

Expand Down
134 changes: 113 additions & 21 deletions skpro/datatypes/_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
each tuple corresponds to an mtype, elements as follows:
0 : string - name of the mtype as used throughout skpro and in datatypes
1 : string - name of the scitype the mtype is for, must be in SCITYPE_REGISTER
2 : string - plain English description of the scitype
2 : string - plain English description of the mtype
3 : string or list of strings - soft dependencies of the mtype, or None if no soft deps

---

Expand All @@ -39,38 +40,127 @@
---
"""

import threading

from skpro.datatypes._proba._registry import MTYPE_LIST_PROBA, MTYPE_REGISTER_PROBA
from skpro.datatypes._table._registry import MTYPE_LIST_TABLE, MTYPE_REGISTER_TABLE

MTYPE_REGISTER = []
MTYPE_REGISTER += MTYPE_REGISTER_TABLE
MTYPE_REGISTER += MTYPE_REGISTER_PROBA

MTYPE_SOFT_DEPS = {
"polars_eager_table": "polars",
"polars_lazy_table": "polars",
}


# mtypes to exclude in checking since they are ambiguous and rare
AMBIGUOUS_MTYPES = []


__all__ = [
"MTYPE_REGISTER",
"MTYPE_LIST_TABLE",
"MTYPE_LIST_PROBA",
"MTYPE_SOFT_DEPS",
"SCITYPE_REGISTER",
"MTYPE_SOFT_DEPS", # noqa: F822 - dynamically exported via __getattr__
"SCITYPE_REGISTER", # noqa: F822 - dynamically exported via __getattr__
"SCITYPE_LIST", # noqa: F822 - dynamically exported via __getattr__
]

_SCITYPE_DESCRIPTIONS = {
"Table": "data table with primitive column types",
"Proba": "probability distribution or distribution statistics, return types",
}

SCITYPE_REGISTER = [
("Table", "data table with primitive column types"),
("Proba", "probability distribution or distribution statistics, return types"),
]

SCITYPE_LIST = [x[0] for x in SCITYPE_REGISTER]
# MTYPE_SOFT_DEPS, SCITYPE_REGISTER and SCITYPE_LIST are generated lazily and
# exposed via the module-level __getattr__ for programmatic lookup
_CACHE = {}
_REGISTRY_LOCK = threading.RLock()


def _get_registry(name):
    """Return a lazily generated registry object by name.

    Serves the names ``MTYPE_SOFT_DEPS``, ``SCITYPE_REGISTER`` and
    ``SCITYPE_LIST``. All three are generated together on first access, from
    the checker classes returned by ``get_check_dict(soft_deps="all")`` and
    from the tuples in ``MTYPE_REGISTER``, and are then kept in ``_CACHE``.

    Parameters
    ----------
    name : str
        name of the registry object to return

    Returns
    -------
    dict, if ``name`` is ``"MTYPE_SOFT_DEPS"``
        mtype name -> soft dependencies (str or list of str)
    list of tuple (scitype, description), if ``name`` is ``"SCITYPE_REGISTER"``
    list of str, if ``name`` is ``"SCITYPE_LIST"``

    Raises
    ------
    AttributeError
        if ``name`` is not one of the three lazily generated names
    ValueError
        if a collected scitype has no entry in ``_SCITYPE_DESCRIPTIONS``
    """
    lazy_names = ("MTYPE_SOFT_DEPS", "SCITYPE_REGISTER", "SCITYPE_LIST")
    if name not in lazy_names:
        # reject unrelated attribute lookups without touching the lock
        raise AttributeError(f"module {__name__} has no attribute {name}")

    def _normalize(deps):
        """Store tuple-valued dependency specs as lists, others unchanged."""
        return list(deps) if isinstance(deps, tuple) else deps

    with _REGISTRY_LOCK:
        # The cache is only read while holding the lock: placeholders are put
        # into _CACHE before generation has finished, so an unlocked read from
        # another thread could return a still-empty registry. The lock is an
        # RLock, so the generating thread can still re-enter.
        if name in _CACHE:
            return _CACHE[name]

        # Pre-seed _CACHE to prevent infinite recursion during generation.
        # Re-entrant calls (e.g., from inspect) will receive these references,
        # which are filled in place once generation is complete.
        _CACHE["MTYPE_SOFT_DEPS"] = {}
        _CACHE["SCITYPE_REGISTER"] = []
        _CACHE["SCITYPE_LIST"] = []

        try:
            from skpro.datatypes._check import get_check_dict

            check_dict = get_check_dict(soft_deps="all")

            soft_deps = {}
            scitypes = set()

            # pass 1: checker classes (tag based), or plain (mtype, scitype) keys
            for k, cls in check_dict.items():
                if hasattr(cls, "get_class_tag"):
                    mtype = cls.get_class_tag("name")
                    scitype = cls.get_class_tag("scitype")
                    deps = cls.get_class_tag("python_dependencies", None)
                else:
                    mtype = k[0]
                    scitype = k[1]
                    deps = None

                if deps is not None:
                    soft_deps[mtype] = _normalize(deps)
                if scitype is not None:
                    scitypes.add(scitype)

            # pass 2: static register tuples; a non-None 4th element overrides
            # the soft dependencies collected from the checker classes
            for mtype_tuple in MTYPE_REGISTER:
                mtype = mtype_tuple[0]
                scitypes.add(mtype_tuple[1])
                if len(mtype_tuple) >= 4 and mtype_tuple[3] is not None:
                    soft_deps[mtype] = _normalize(mtype_tuple[3])

            scitype_register = []
            for sci in sorted(scitypes):
                desc = _SCITYPE_DESCRIPTIONS.get(sci)
                if desc is None:
                    raise ValueError(
                        f"scitype '{sci}' is missing a description in "
                        "`_SCITYPE_DESCRIPTIONS`. Please add it to "
                        "`skpro.datatypes._registry._SCITYPE_DESCRIPTIONS`."
                    )
                scitype_register.append((sci, desc))

            scitype_list = [x[0] for x in scitype_register]

            # Update the pre-seeded objects in-place, so that references
            # handed out to re-entrant callers see the final content
            _CACHE["MTYPE_SOFT_DEPS"].update(soft_deps)
            _CACHE["SCITYPE_REGISTER"].extend(scitype_register)
            _CACHE["SCITYPE_LIST"].extend(scitype_list)
        except BaseException:
            # BaseException rather than Exception: an interrupt during
            # generation must not leave empty placeholders cached forever.
            # The exception is always re-raised.
            _CACHE.pop("MTYPE_SOFT_DEPS", None)
            _CACHE.pop("SCITYPE_REGISTER", None)
            _CACHE.pop("SCITYPE_LIST", None)
            raise

        return _CACHE[name]


def __getattr__(name):
    """Resolve module attributes that are not defined statically.

    Python calls this hook when regular attribute lookup on the module fails;
    the lookup is delegated to ``_get_registry``.
    """
    resolved = _get_registry(name)
    return resolved


def __dir__():
    """List the module's names, including the dynamically generated ones."""
    lazy_exports = ("MTYPE_SOFT_DEPS", "SCITYPE_REGISTER", "SCITYPE_LIST")
    # the set union de-duplicates names that already appear in globals()
    visible = set(globals()).union(lazy_exports)
    return sorted(visible)


def mtype_to_scitype(mtype: str, return_unique=False, coerce_to_list=False):
Expand Down Expand Up @@ -173,8 +263,8 @@ def scitype_to_mtype(scitype: str, softdeps: str = "exclude"):
if not isinstance(scitype, str):
raise TypeError(msg)

# now we know scitype is a string, check if it is in the register
if scitype not in SCITYPE_LIST:
scitype_list = _get_registry("SCITYPE_LIST")
if scitype not in scitype_list:
raise ValueError(
f'"{scitype}" is not a valid scitype string, see datatypes.SCITYPE_REGISTER'
)
Expand All @@ -188,20 +278,22 @@ def scitype_to_mtype(scitype: str, softdeps: str = "exclude"):
if softdeps not in ["exclude", "present"]:
return mtypes

mtype_soft_deps = _get_registry("MTYPE_SOFT_DEPS")

if softdeps == "exclude":
# subset to mtypes that require no soft deps
mtypes = [m for m in mtypes if m not in MTYPE_SOFT_DEPS.keys()]
mtypes = [m for m in mtypes if m not in mtype_soft_deps.keys()]
return mtypes

if softdeps == "present":
from skbase.utils.dependencies import _check_soft_dependencies

def present(x):
"""Return True if x has satisfied soft dependency or has no soft dep."""
if x not in MTYPE_SOFT_DEPS.keys():
if x not in mtype_soft_deps.keys():
return True
else:
return _check_soft_dependencies(MTYPE_SOFT_DEPS[x], severity="none")
return _check_soft_dependencies(mtype_soft_deps[x], severity="none")

# return only mtypes with soft dependencies present (or requiring none)
mtypes = [m for m in mtypes if present(m)]
Expand Down
14 changes: 7 additions & 7 deletions skpro/datatypes/_table/_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@


MTYPE_REGISTER_TABLE = [
("pd_DataFrame_Table", "Table", "pd.DataFrame representation of a data table"),
("numpy1D", "Table", "1D np.narray representation of a univariate table"),
("numpy2D", "Table", "2D np.narray representation of a univariate table"),
("pd_Series_Table", "Table", "pd.Series representation of a data table"),
("list_of_dict", "Table", "list of dictionaries with primitive entries"),
("polars_eager_table", "Table", "polars.DataFrame representation of a data table"),
("polars_lazy_table", "Table", "polars.LazyFrame representation of a data table"),
("pd_DataFrame_Table", "Table", "pd.DataFrame representation of a data table", None),
("numpy1D", "Table", "1D np.narray representation of a univariate table", None),
("numpy2D", "Table", "2D np.narray representation of a univariate table", None),
("pd_Series_Table", "Table", "pd.Series representation of a data table", None),
("list_of_dict", "Table", "list of dictionaries with primitive entries", None),
("polars_eager_table", "Table", "polars.DataFrame representation of a data table", "polars"),
("polars_lazy_table", "Table", "polars.LazyFrame representation of a data table", "polars"),
]

MTYPE_LIST_TABLE = pd.DataFrame(MTYPE_REGISTER_TABLE)[0].values
Loading