Skip to content
34 changes: 20 additions & 14 deletions skore/src/skore/sklearn/train_test_split/train_test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,27 +137,30 @@ class labels.

new_arrays = list(arrays)
keys = []

if X is not None:
new_arrays.append(X)
keys += ["X"]

if y is not None:
new_arrays.append(y)
keys += ["y"]

if as_dict and arrays:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you remove this? If both are true, it will cause a conflict.

raise ValueError(
"When as_dict=True, arrays must be passed as keyword arguments.\n"
"Example: train_test_split(X=X, y=y, sw=sample_weight, as_dict=True)"
)
if as_dict and X is None and y is None:
if keyword_arrays:
new_arrays = list(keyword_arrays.values())
else:
X, y = (arrays[0], arrays[1]) if len(arrays) >= 2 else (arrays[0], None)
new_arrays = [X, y]
keys = ["X", "y"]

if keyword_arrays:
if X is None and y is None:
arrays = tuple(
keyword_arrays.values()
) # if X and y is not passed but other variables
keys += list(keyword_arrays.keys())
new_arrays += list(keyword_arrays.values())
keys += list(keyword_arrays.keys())
new_arrays += list(keyword_arrays.values())

if not new_arrays:
raise ValueError("At least one array must be provided")

# Perform the train-test split using sklearn
output = sklearn.model_selection.train_test_split(
*new_arrays,
test_size=test_size,
Expand All @@ -168,7 +171,10 @@ class labels.
)

if X is None:
X = arrays[0] if len(arrays) == 1 else arrays[-2]
if arrays:
X = arrays[0] if len(arrays) == 1 else arrays[-2]
elif keyword_arrays and "X" in keyword_arrays:
X = keyword_arrays["X"]

if y is None and len(arrays) >= 2:
y = arrays[-1]
Expand All @@ -183,7 +189,6 @@ class labels.
y_test = None

ml_task = _find_ml_task(y)

kwargs = dict(
arrays=new_arrays,
test_size=test_size,
Expand All @@ -198,6 +203,7 @@ class labels.
ml_task=ml_task,
)

# Display any warnings related to train-test split
from skore import console # avoid circular import

for warning_class in TRAIN_TEST_SPLIT_WARNINGS:
Expand Down
63 changes: 55 additions & 8 deletions skore/tests/unit/sklearn/train_test_split/test_train_test_split.py
Copy link
Contributor

@auguste-probabl auguste-probabl Apr 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test with

arr1 = [[1]] * 20
arr2 = [0] * 10 + [1] * 10
train_test_split(arr2, z=arr1, as_dict=True)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is tested through two functions named( test_train_test_split_check_dict()) and test_train_test_split_dict_kwargs().

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a bit different: right now we test either all arguments passed by keyword, or all arguments passed by position. I'd like to also test the combination of both (one array passed by position, one array passed by keyword).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should also test what happens with

train_test_split(X, X=X)

I think there should be an error like

X cannot be passed both by position and by keyword.

Same for y.

Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import warnings
from datetime import datetime

import numpy as np
import pandas
import polars
import pytest
Expand Down Expand Up @@ -183,17 +184,27 @@ def test_train_test_split_kwargs():


def test_train_test_split_dict_kwargs():
"""Passing data without keyword arguments with return_dict=True
should raise ValueError."""

"""When passing data with positional arguments and as_dict=True,
the first argument will be interpreted as `X` and the second one as `y`."""
X = [[1]] * 20
y = [0] * 10 + [1] * 10

with pytest.raises(
ValueError,
match="When as_dict=True, arrays must be passed as keyword arguments",
):
train_test_split(X, y, random_state=0, as_dict=True)
result = train_test_split(X, y, test_size=0.2, as_dict=True, random_state=0)

assert "X_train" in result
assert "X_test" in result
assert "y_train" in result
assert "y_test" in result
Comment on lines +194 to +197
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need for this


X_train = np.array(result["X_train"])
X_test = np.array(result["X_test"])
y_train = np.array(result["y_train"])
y_test = np.array(result["y_test"])

assert X_train.ndim == 2
assert X_test.ndim == 2
assert y_train.ndim == 1
assert y_test.ndim == 1


def test_train_test_split_check_dict():
Expand Down Expand Up @@ -221,3 +232,39 @@ def test_train_test_split_check_dict_no_X_no_y():
output = train_test_split(z=z, random_state=0, as_dict=True)
keys = output.keys()
assert list(keys) == ["z_train", "z_test"]


def test_train_test_split_as_dict_with_all_keyword_args():
"""Ensure result is a dict with correct keys when as_dict=True
and all arrays are keyword args."""
X = np.arange(10).reshape(10, 1)
y = np.arange(10)
weights = np.ones(10)

result = train_test_split(
X=X,
y=y,
sample_weights=weights,
test_size=0.2,
as_dict=True,
random_state=0,
)
Copy link
Contributor

@auguste-probabl auguste-probabl Apr 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test feels a bit redundant, but

train_test_split(
    X,
    y,
    sample_weights=weights,
    ...
)

(i.e. a mix of positional and keyword arguments) would be interesting. See also my other comment


assert set(result.keys()) == {
"X_train",
"X_test",
"y_train",
"y_test",
"sample_weights_train",
"sample_weights_test",
}
assert result["X_train"].shape[0] == 8
assert result["X_test"].shape[0] == 2


def test_empty_input():
"""Tests that passing empty lists for X and y raises a ValueError."""
X = []
y = []
with pytest.raises(ValueError):
train_test_split(X, y)
Comment on lines +232 to +237
Copy link
Contributor

@auguste-probabl auguste-probabl Apr 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this test is needed; this behaviour is not specific to our function, but rather to sklearn's.