|
8 | 8 | import pytest
|
9 | 9 | from glum import GeneralizedLinearRegressor, GeneralizedLinearRegressorCV
|
10 | 10 | from lightgbm import LGBMClassifier, LGBMRegressor
|
| 11 | +from scipy.sparse import csr_matrix |
11 | 12 | from sklearn.ensemble import HistGradientBoostingClassifier
|
12 | 13 | from sklearn.linear_model import LinearRegression
|
13 | 14 | from xgboost import XGBClassifier, XGBRegressor
|
|
20 | 21 | convert_treatment,
|
21 | 22 | function_has_argument,
|
22 | 23 | get_linear_dimension,
|
| 24 | + index_matrix, |
| 25 | + index_vector, |
23 | 26 | supports_categoricals,
|
24 | 27 | validate_all_vectors_same_index,
|
25 | 28 | validate_model_and_predict_method,
|
@@ -345,3 +348,59 @@ def test_validate_valid_treatment_variant_not_control(
|
345 | 348 | else:
|
346 | 349 | with pytest.raises(ValueError, match="variant"):
|
347 | 350 | validate_valid_treatment_variant_not_control(treatment_variant, n_variants)
|
| 351 | + |
| 352 | + |
| 353 | +@pytest.mark.parametrize("matrix_backend", [np.ndarray, pd.DataFrame, csr_matrix]) |
| 354 | +@pytest.mark.parametrize("rows_backend", [np.array, pd.Series]) |
| 355 | +def test_index_matrix(matrix_backend, rows_backend): |
| 356 | + n_samples = 10 |
| 357 | + if matrix_backend == np.ndarray: |
| 358 | + matrix = np.array(list(range(n_samples))).reshape((-1, 1)) |
| 359 | + elif matrix_backend == pd.DataFrame: |
| 360 | + # We make sure that the index is not equal to the row number. |
| 361 | + matrix = pd.DataFrame( |
| 362 | + list(range(n_samples)), index=list(range(20, 20 + n_samples)) |
| 363 | + ) |
| 364 | + elif matrix_backend == csr_matrix: |
| 365 | + matrix = csr_matrix(np.array(list(range(n_samples))).reshape((-1, 1))) |
| 366 | + else: |
| 367 | + raise ValueError() |
| 368 | + rows = rows_backend([1, 4, 5]) |
| 369 | + result = index_matrix(matrix=matrix, rows=rows) |
| 370 | + |
| 371 | + assert isinstance(result, matrix_backend) |
| 372 | + assert result.shape[1] == matrix.shape[1] |
| 373 | + |
| 374 | + if isinstance(result, pd.DataFrame): |
| 375 | + processed_result = result.values[:, 0] |
| 376 | + else: |
| 377 | + processed_result = result[:, 0] |
| 378 | + |
| 379 | + expected = np.array([1, 4, 5]) |
| 380 | + assert (processed_result == expected).sum() == len(expected) |
| 381 | + |
| 382 | + |
| 383 | +@pytest.mark.parametrize("vector_backend", [np.ndarray, pd.Series]) |
| 384 | +@pytest.mark.parametrize("rows_backend", [np.array, pd.Series]) |
| 385 | +def test_index_vector(vector_backend, rows_backend): |
| 386 | + n_samples = 10 |
| 387 | + if vector_backend == np.ndarray: |
| 388 | + vector = np.array(list(range(n_samples))) |
| 389 | + elif vector_backend == pd.Series: |
| 390 | + # We make sure that the index is not equal to the row number. |
| 391 | + vector = pd.Series( |
| 392 | + list(range(n_samples)), index=list(range(20, 20 + n_samples)) |
| 393 | + ) |
| 394 | + else: |
| 395 | + raise ValueError() |
| 396 | + |
| 397 | + rows = rows_backend([1, 4, 5]) |
| 398 | + |
| 399 | + result = index_vector(vector=vector, rows=rows) |
| 400 | + assert isinstance(result, vector_backend) |
| 401 | + |
| 402 | + if isinstance(result, pd.Series): |
| 403 | + result = result.values |
| 404 | + |
| 405 | + expected = np.array([1, 4, 5]) |
| 406 | + assert (result == expected).all() |
0 commit comments