Skip to content

Commit 19af399

Browse files
authored
Add shape check to Dataset initialization (#106)
* Check input array shapes when Dataset is initialized * Add test for nimare.utils._check_inputs_shape * Cover two missing lines by the test * Add test when n or v is None * Remove extra clause
1 parent 0a99840 commit 19af399

File tree

4 files changed

+76
-1
lines changed

4 files changed

+76
-1
lines changed

.zenodo.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@
1010
"affiliation": "Florida International University",
1111
"orcid": "0000-0001-9813-3167"
1212
},
13+
{
14+
"name": "Peraza, Julio A.",
15+
"affiliation": "Florida International University",
16+
"orcid": "0000-0003-3816-5903"
17+
},
1318
{
1419
"name": "Nichols, Thomas E.",
1520
"affiliation": "Big Data Institute, University of Oxford",

pymare/core.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66
import pandas as pd
77

8-
from pymare.utils import _listify
8+
from pymare.utils import _check_inputs_shape, _listify
99

1010
from .estimators import (
1111
DerSimonianLaird,
@@ -94,6 +94,10 @@ def __init__(
9494
self.X = X
9595
self.X_names = names
9696

97+
_check_inputs_shape(self.y, self.X, "y", "X", row=True)
98+
_check_inputs_shape(self.y, self.v, "y", "v", row=True, column=True)
99+
_check_inputs_shape(self.y, self.n, "y", "n", row=True, column=True)
100+
97101
def _get_predictors(self, X, names, add_intercept):
98102
if X is None and not add_intercept:
99103
raise ValueError(

pymare/tests/test_utils.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,40 @@
11
"""Tests for pymare.utils."""
22
import os.path as op
33

4+
import numpy as np
5+
import pytest
6+
47
from pymare import utils
58

69

710
def test_get_resource_path():
811
"""Test nimare.utils.get_resource_path."""
912
print(utils.get_resource_path())
1013
assert op.isdir(utils.get_resource_path())
14+
15+
16+
def test_check_inputs_shape():
17+
"""Test nimare.utils._check_inputs_shape."""
18+
n_rows = 5
19+
n_columns = 4
20+
n_pred = 3
21+
y = np.random.randint(1, 100, size=(n_rows, n_columns))
22+
v = np.random.randint(1, 100, size=(n_rows + 1, n_columns))
23+
n = np.random.randint(1, 100, size=(n_rows, n_columns))
24+
X = np.random.randint(1, 100, size=(n_rows, n_pred))
25+
X_names = [f"X{x}" for x in range(n_pred)]
26+
27+
utils._check_inputs_shape(y, X, "y", "X", row=True)
28+
utils._check_inputs_shape(y, n, "y", "n", row=True, column=True)
29+
utils._check_inputs_shape(X, np.array(X_names)[None, :], "X", "X_names", column=True)
30+
31+
# Raise error if the number of rows and columns of v don't match y
32+
with pytest.raises(ValueError):
33+
utils._check_inputs_shape(y, v, "y", "v", row=True, column=True)
34+
35+
# Raise error if neither row or column is True
36+
with pytest.raises(ValueError):
37+
utils._check_inputs_shape(y, n, "y", "n")
38+
39+
# Dataset may be initialized with n or v as None
40+
utils._check_inputs_shape(y, None, "y", "n", row=True, column=True)

pymare/utils.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,39 @@ def _listify(obj):
1919
This provides a simple way to accept flexible arguments.
2020
"""
2121
return obj if isinstance(obj, (list, tuple, type(None), np.ndarray)) else [obj]
22+
23+
24+
def _check_inputs_shape(param1, param2, param1_name, param2_name, row=False, column=False):
25+
"""Check whether 'param1' and 'param2' have the same shape.
26+
27+
Parameters
28+
----------
29+
param1 : array
30+
param2 : array
31+
param1_name : str
32+
param2_name : str
33+
row : bool, default to False.
34+
column : bool, default to False.
35+
"""
36+
if (param1 is not None) and (param2 is not None):
37+
if row and not column:
38+
shape1 = param1.shape[0]
39+
shape2 = param2.shape[0]
40+
message = "rows"
41+
elif column and not row:
42+
shape1 = param1.shape[1]
43+
shape2 = param2.shape[1]
44+
message = "columns"
45+
elif row and column:
46+
shape1 = param1.shape
47+
shape2 = param2.shape
48+
message = "rows and columns"
49+
else:
50+
raise ValueError("At least one of the two parameters (row or column) should be True.")
51+
52+
if shape1 != shape2:
53+
raise ValueError(
54+
f"{param1_name} and {param2_name} should have the same number of {message}. "
55+
f"You provided {param1_name} with shape {param1.shape} and {param2_name} "
56+
f"with shape {param2.shape}."
57+
)

0 commit comments

Comments
 (0)