Skip to content
148 changes: 29 additions & 119 deletions skore/src/skore/sklearn/train_test_split/train_test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,134 +30,44 @@ def train_test_split(
):
"""Perform train-test-split of data.

This is a wrapper over scikit-learn's
:func:`sklearn.model_selection.train_test_split` helper function,
enriching it with various warnings.

The signature is fully compatible with sklearn's ``train_test_split``, and
some keyword arguments are added to make the detection of issues more accurate.
For instance, argument ``y`` has been added to pass the target explicitly, which
makes it easier to detect issues with the target.

See the :ref:`example_train_test_split` example.

Parameters
----------
*arrays : sequence of indexables with same length / shape[0]
Allowed inputs are lists, numpy arrays, scipy-sparse matrices or pandas
dataframes.
X : array-like, optional
If not None, will be appended to the list of arrays passed positionally.
y : array-like, optional
If not None, will be appended to the list of arrays passed positionally, after
``X``. If None, it is assumed that the last array in ``arrays`` is ``y``.
test_size : float or int, optional
If float, should be between 0.0 and 1.0 and represent the proportion of
the dataset to include in the test split. If int, represents the absolute number
of test samples. If None, the value is set to the complement of the train size.
If train_size is also None, it will be set to 0.25.
train_size : float or int, optional
If float, should be between 0.0 and 1.0 and represent the proportion
of the dataset to include in the train split. If int, represents the absolute
number of train samples. If None, the value is automatically set to the
complement of the test size.
random_state : int or numpy RandomState instance, optional
Controls the shuffling applied to the data before applying the split. Pass an
int for reproducible output across multiple function calls.
shuffle : bool, default is True
Whether or not to shuffle the data before splitting. If shuffle=False
then stratify must be None.
stratify : array-like, optional
If not None, data is split in a stratified fashion, using this as the
class labels.
as_dict : bool, default is False
If True, returns a Dictionary with keys values ``X_train``, ``X_test``,
``y_train``, and ``y_test`` instead of a List. Requires data to be
passed as keyword arguments.
**keyword_arrays : array-like, optional
Additional array-like arguments passed by keyword. Used to determine the keys
of the output dict when ``as_dict=True``.

Returns
-------
splitting : list or dict
If ``as_dict=False`` (the default): List containing train-test split of inputs.
The length of the list is twice the number of arrays passed, including
the ``X`` and ``y`` keyword arguments. If arrays are passed positionally as well
as through ``X`` and ``y``, the output arrays are ordered as follows: first the
arrays passed positionally, in the order they were passed, then ``X`` if it
was passed, then ``y`` if it was passed.

If ``as_dict=True``: Dictionary with keys
``X_train``, ``X_test``, ``y_train``, and ``y_test``,
each containing respective split data.

Examples
--------
>>> # xdoctest: +SKIP
>>> import numpy as np
>>> X, y = np.arange(10).reshape((5, 2)), range(5)

>>> # Drop-in replacement for sklearn train_test_split
>>> X_train, X_test, y_train, y_test = train_test_split(X, y,
... test_size=0.33, random_state=42)
>>> X_train
array([[4, 5],
[0, 1],
[6, 7]])

>>> # Explicit X and y, makes detection of problems easier
>>> X_train, X_test, y_train, y_test = train_test_split(X=X, y=y,
... test_size=0.33, random_state=42)
>>> X_train
array([[4, 5],
[0, 1],
[6, 7]])

>>> # When passing X and y explicitly, X is returned before y
>>> arr = np.arange(10).reshape((5, 2))
>>> splits = train_test_split(
... arr, y=y, X=X, test_size=0.33, random_state=42)
>>> arr_train, arr_test, X_train, X_test, y_train, y_test = splits
>>> X_train
array([[4, 5],
[0, 1],
[6, 7]])

>>> # Returns dictionary when as_dict is True, inputs must be keyword arguments.
>>> sample_weights = np.arange(10).reshape((5, 2))
>>> split_dict = train_test_split(
... X=X, y=y, sample_weights=sample_weights, as_dict=True, random_state=0)
>>> split_dict
{'X_train': ..., 'X_test': ...,
'y_train': ..., 'y_test': ...,
'sample_weights_train': ..., 'sample_weights_test': ...}
This is a wrapper over scikit-learn's train_test_split helper function,
enriching it with various warnings and additional functionality.
"""
import sklearn.model_selection

new_arrays = list(arrays)
keys = []

if X is not None:
new_arrays.append(X)
keys += ["X"]
keys.append("X")
if y is not None:
new_arrays.append(y)
keys += ["y"]

if as_dict and arrays:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you remove this? If both are true, it will cause a conflict.

raise ValueError(
"When as_dict=True, arrays must be passed as keyword arguments.\n"
"Example: train_test_split(X=X, y=y, sw=sample_weight, as_dict=True)"
)
keys.append("y")

if keyword_arrays:
if as_dict:
if X is None and y is None:
arrays = tuple(
keyword_arrays.values()
) # if X and y is not passed but other variables
if not keyword_arrays:
raise ValueError(
"When as_dict=True, arrays must be passed as keyword arguments"
)

new_arrays = list(keyword_arrays.values())

if X is not None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this repeated inside and outside the if? It is redundant.

new_arrays.append(X)
keys.append("X")
if y is not None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this repeated inside and outside the if? It is redundant.

new_arrays.append(y)
keys.append("y")

keys += list(keyword_arrays.keys())
new_arrays += list(keyword_arrays.values())

if not new_arrays:
raise ValueError("At least one array must be provided")

# Perform the train-test split using sklearn
output = sklearn.model_selection.train_test_split(
*new_arrays,
test_size=test_size,
Expand All @@ -167,21 +77,20 @@ class labels.
stratify=stratify,
)

if X is None:
X = arrays[0] if len(arrays) == 1 else arrays[-2]
if X is None and len(arrays) >= 1:
Copy link
Contributor

@nkapila6 nkapila6 Apr 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Incorrect value for case when len(arrays)>1?

X = arrays[0]

if y is None and len(arrays) >= 2:
y = arrays[-1]

if y is not None:
y_labels = np.unique(y)
y_test = (
output[3] if keyword_arrays else output[-1]
) # when more kwargs are given
y_test = output[3] if as_dict else output[-1]
else:
y_labels = None
y_test = None

# Determine the ML task based on y
ml_task = _find_ml_task(y)

kwargs = dict(
Expand All @@ -198,6 +107,7 @@ class labels.
ml_task=ml_task,
)

# Display any warnings related to train-test split
from skore import console # avoid circular import

for warning_class in TRAIN_TEST_SPLIT_WARNINGS:
Expand Down