Skip to content

Commit 241d681

Browse files
authored
Add tests for index_matrix and index_vector. (#89)
1 parent 0d1958c commit 241d681

File tree

2 files changed

+60
-0
lines changed

2 files changed

+60
-0
lines changed

metalearners/_utils.py

+1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727

2828
def safe_len(X: Matrix) -> int:
    """Determine the length of a Matrix.

    Inputs recognised by ``scipy.sparse.issparse`` report their first
    dimension (``X.shape[0]``); every other input is measured with the
    builtin ``len``.
    """
    # Sparse inputs are sized through their shape instead of through len().
    if scipy.sparse.issparse(X):
        return X.shape[0]
    return len(X)

tests/test__utils.py

+59
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import pytest
99
from glum import GeneralizedLinearRegressor, GeneralizedLinearRegressorCV
1010
from lightgbm import LGBMClassifier, LGBMRegressor
11+
from scipy.sparse import csr_matrix
1112
from sklearn.ensemble import HistGradientBoostingClassifier
1213
from sklearn.linear_model import LinearRegression
1314
from xgboost import XGBClassifier, XGBRegressor
@@ -20,6 +21,8 @@
2021
convert_treatment,
2122
function_has_argument,
2223
get_linear_dimension,
24+
index_matrix,
25+
index_vector,
2326
supports_categoricals,
2427
validate_all_vectors_same_index,
2528
validate_model_and_predict_method,
@@ -345,3 +348,59 @@ def test_validate_valid_treatment_variant_not_control(
345348
else:
346349
with pytest.raises(ValueError, match="variant"):
347350
validate_valid_treatment_variant_not_control(treatment_variant, n_variants)
351+
352+
353+
@pytest.mark.parametrize("matrix_backend", [np.ndarray, pd.DataFrame, csr_matrix])
@pytest.mark.parametrize("rows_backend", [np.array, pd.Series])
def test_index_matrix(matrix_backend, rows_backend):
    """Check that ``index_matrix`` selects rows by position for every backend.

    A single-column matrix holding ``0 .. n_samples - 1`` is built for the
    requested backend, rows ``[1, 4, 5]`` are selected, and the result must
    keep the backend type, keep the column count and contain exactly the
    values ``[1, 4, 5]`` in that order.
    """
    n_samples = 10
    if matrix_backend is np.ndarray:
        matrix = np.arange(n_samples).reshape((-1, 1))
    elif matrix_backend is pd.DataFrame:
        # We make sure that the index is not equal to the row number.
        matrix = pd.DataFrame(
            list(range(n_samples)), index=list(range(20, 20 + n_samples))
        )
    elif matrix_backend is csr_matrix:
        matrix = csr_matrix(np.arange(n_samples).reshape((-1, 1)))
    else:
        raise ValueError()
    rows = rows_backend([1, 4, 5])
    result = index_matrix(matrix=matrix, rows=rows)

    assert isinstance(result, matrix_backend)
    assert result.shape[1] == matrix.shape[1]

    # Reduce every backend to a dense 1-D column so the comparison below is
    # an exact, shape-aware match rather than a broadcast comparison.
    if isinstance(result, pd.DataFrame):
        first_column = result.values[:, 0]
    elif isinstance(result, csr_matrix):
        first_column = result.toarray()[:, 0]
    else:
        first_column = result[:, 0]

    expected = np.array([1, 4, 5])
    assert np.array_equal(first_column, expected)
381+
382+
383+
@pytest.mark.parametrize("vector_backend", [np.ndarray, pd.Series])
@pytest.mark.parametrize("rows_backend", [np.array, pd.Series])
def test_index_vector(vector_backend, rows_backend):
    """Check that ``index_vector`` selects entries by position for every backend.

    A vector holding ``0 .. n_samples - 1`` is built for the requested
    backend, entries ``[1, 4, 5]`` are selected, and the result must keep the
    backend type and contain exactly the values ``[1, 4, 5]`` in that order.
    """
    n_samples = 10
    if vector_backend is np.ndarray:
        vector = np.arange(n_samples)
    elif vector_backend is pd.Series:
        # We make sure that the index is not equal to the row number.
        vector = pd.Series(
            list(range(n_samples)), index=list(range(20, 20 + n_samples))
        )
    else:
        raise ValueError()

    rows = rows_backend([1, 4, 5])

    result = index_vector(vector=vector, rows=rows)
    assert isinstance(result, vector_backend)

    if isinstance(result, pd.Series):
        result = result.values

    expected = np.array([1, 4, 5])
    # array_equal also checks the shape, so a broadcastable but wrongly
    # sized result cannot pass.
    assert np.array_equal(result, expected)

0 commit comments

Comments
 (0)