Skip to content
34 changes: 20 additions & 14 deletions skore/src/skore/sklearn/train_test_split/train_test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,27 +137,30 @@ class labels.

new_arrays = list(arrays)
keys = []

if X is not None:
new_arrays.append(X)
keys += ["X"]

if y is not None:
new_arrays.append(y)
keys += ["y"]

if as_dict and arrays:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you remove this? If both are true, it will cause a conflict.

raise ValueError(
"When as_dict=True, arrays must be passed as keyword arguments.\n"
"Example: train_test_split(X=X, y=y, sw=sample_weight, as_dict=True)"
)
if as_dict and X is None and y is None:
if keyword_arrays:
new_arrays = list(keyword_arrays.values())
else:
X, y = (arrays[0], arrays[1]) if len(arrays) >= 2 else (arrays[0], None)
new_arrays = [X, y]
keys = ["X", "y"]

if keyword_arrays:
if X is None and y is None:
arrays = tuple(
keyword_arrays.values()
) # if X and y is not passed but other variables
keys += list(keyword_arrays.keys())
new_arrays += list(keyword_arrays.values())
keys += list(keyword_arrays.keys())
new_arrays += list(keyword_arrays.values())

if not new_arrays:
raise ValueError("At least one array must be provided")

# Perform the train-test split using sklearn
output = sklearn.model_selection.train_test_split(
*new_arrays,
test_size=test_size,
Expand All @@ -168,7 +171,10 @@ class labels.
)

if X is None:
X = arrays[0] if len(arrays) == 1 else arrays[-2]
if arrays:
X = arrays[0] if len(arrays) == 1 else arrays[-2]
elif keyword_arrays and "X" in keyword_arrays:
X = keyword_arrays["X"]

if y is None and len(arrays) >= 2:
y = arrays[-1]
Expand All @@ -183,7 +189,6 @@ class labels.
y_test = None

ml_task = _find_ml_task(y)

kwargs = dict(
arrays=new_arrays,
test_size=test_size,
Expand All @@ -198,6 +203,7 @@ class labels.
ml_task=ml_task,
)

# Display any warnings related to train-test split
from skore import console # avoid circular import

for warning_class in TRAIN_TEST_SPLIT_WARNINGS:
Expand Down
103 changes: 96 additions & 7 deletions skore/tests/unit/sklearn/train_test_split/test_train_test_split.py
Copy link
Contributor

@auguste-probabl auguste-probabl Apr 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test with

arr1 = [[1]] * 20
arr2 = [0] * 10 + [1] * 10
train_test_split(arr2, z=arr1, as_dict=True)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is tested through two functions named( test_train_test_split_check_dict()) and test_train_test_split_dict_kwargs().

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a bit different: right now we test either all arguments passed by keyword, or all arguments passed by position. I'd like to also test the combination of both (one array passed by position, one array passed by keyword).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should also test what happens with

train_test_split(X, X=X)

I think there should be an error like

X cannot be passed both by position and by keyword.

Same for y.

Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import warnings
from datetime import datetime

import numpy as np
import pandas
import polars
import pytest
Expand Down Expand Up @@ -183,17 +184,17 @@ def test_train_test_split_kwargs():


def test_train_test_split_dict_kwargs():
"""Passing data without keyword arguments with return_dict=True
should raise ValueError."""
"""Passing data with positional arguments and as_dict=True should work."""

X = [[1]] * 20
y = [0] * 10 + [1] * 10

with pytest.raises(
ValueError,
match="When as_dict=True, arrays must be passed as keyword arguments",
):
train_test_split(X, y, random_state=0, as_dict=True)
result = train_test_split(X, y, random_state=0, as_dict=True)

assert "X_train" in result
assert "X_test" in result
assert "y_train" in result
assert "y_test" in result
Comment on lines +194 to +197
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need for this



def test_train_test_split_check_dict():
Expand Down Expand Up @@ -221,3 +222,91 @@ def test_train_test_split_check_dict_no_X_no_y():
output = train_test_split(z=z, random_state=0, as_dict=True)
keys = output.keys()
assert list(keys) == ["z_train", "z_test"]


def test_train_test_split_as_dict_with_all_keyword_args():
"""Ensure result is a dict with correct keys when as_dict=True
and all arrays are keyword args."""
X = np.arange(10).reshape(10, 1)
y = np.arange(10)
weights = np.ones(10)

result = train_test_split(
X=X,
y=y,
sample_weights=weights,
test_size=0.2,
as_dict=True,
random_state=0,
)
Copy link
Contributor

@auguste-probabl auguste-probabl Apr 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test feels a bit redundant, but

train_test_split(
    X,
    y,
    sample_weights=weights,
    ...
)

(i.e. a mix of positional and keyword arguments) would be interesting. See also my other comment


assert set(result.keys()) == {
"X_train",
"X_test",
"y_train",
"y_test",
"sample_weights_train",
"sample_weights_test",
}
assert result["X_train"].shape[0] == 8
assert result["X_test"].shape[0] == 2


def test_train_test_split_as_dict_with_multiple_named_inputs():
"""Ensure train_test_split works with multiple inputs when using as_dict=True."""
X = np.arange(10).reshape(10, 1)
y = np.arange(10)
z = np.arange(10, 20)

result = train_test_split(
X=X,
y=y,
z=z,
test_size=0.5,
as_dict=True,
random_state=42,
)

expected_keys = {"X_train", "X_test", "y_train", "y_test", "z_train", "z_test"}

assert all(key in result for key in expected_keys)
assert result["z_train"].shape[0] == 5
assert result["z_test"].shape[0] == 5


def test_train_test_split_as_dict_with_mixed_input_types():
"""Ensure train_test_split handles a mix of array-like types with as_dict=True."""
X = [[i] for i in range(10)]
y = np.arange(10)

result = train_test_split(X=X, y=y, test_size=0.3, as_dict=True, random_state=1)

assert set(result.keys()) == {"X_train", "X_test", "y_train", "y_test"}
assert len(result["X_train"]) == 7
assert len(result["X_test"]) == 3


def test_train_test_split_only_X():
X = [[1]] * 20
result = train_test_split(X=X, random_state=0, as_dict=True)
assert "X_train" in result
assert "X_test" in result


def test_empty_input():
X = []
y = []
with pytest.raises(ValueError):
train_test_split(X, y)


def test_train_test_split_as_dict_with_positioned_args():
X = [[1]] * 20
y = [0] * 10 + [1] * 10

result = train_test_split(X, y, test_size=0.2, as_dict=True, random_state=0)

assert "X_train" in result
assert "X_test" in result
assert "y_train" in result
assert "y_test" in result