Skip to content

Commit 3ae6bcb

Browse files
committed
chore: fix pandas schema flaky test
1 parent b146aab commit 3ae6bcb

File tree

1 file changed

+23
-3
lines changed

1 file changed

+23
-3
lines changed

tests/pandas/test_pandas.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def generate_pandas_dataframe(
2424
n_features: int = 1,
2525
index_name: Optional[str] = None,
2626
indexes: Optional[Union[int, List]] = None,
27+
str_values: Optional[List[str]] = None,
2728
index_position: int = 0,
2829
include_nan: bool = True,
2930
float_min: float = -10.0,
@@ -46,6 +47,7 @@ def generate_pandas_dataframe(
4647
index_name (Optional[str]): The index's name. Default to None ("index").
4748
indexes (Optional[Union[int, List]]): Custom indexes to consider. Default to None (5 rows,
4849
indexed from 1 to 5).
50+
str_values (Optional[List[str]]): String values to use, one per index. Default to None.
4951
index_position (int): The index's column position in the data-frame. Default to 0.
5052
include_nan (bool): If NaN values should be put in the data-frame. If True, they are
5153
inserted in the first row. Default to True.
@@ -75,6 +77,10 @@ def generate_pandas_dataframe(
7577
if isinstance(indexes, int):
7678
indexes = list(range(1, indexes + 1))
7779

80+
assert str_values is None or len(str_values) == len(
81+
indexes
82+
), "Parameter 'str_values' must either be None or a list of length equal to 'indexes'."
83+
7884
if index_name is None:
7985
index_name = "index"
8086

@@ -102,11 +108,16 @@ def generate_pandas_dataframe(
102108

103109
# Add a column with string values (including NaN or not)
104110
if dtype in ["str", "mixed"]:
105-
str_values = ["apple", "orange", "watermelon", "cherry", "banana"]
111+
str_values_default = ["apple", "orange", "watermelon", "cherry", "banana"]
106112

107113
for i in range(1, n_features + 1):
108114
column_name = f"{feat_name}_str_{i}"
109-
columns[column_name] = list(numpy.random.choice(str_values, size=(len(indexes),)))
115+
rand_str_values = (
116+
list(numpy.random.choice(str_values_default, size=(len(indexes),)))
117+
if str_values is None
118+
else str_values
119+
)
120+
columns[column_name] = rand_str_values
110121

111122
if include_nan:
112123
columns[column_name][0] = numpy.nan
@@ -694,8 +705,17 @@ def check_invalid_schema_values():
694705

695706
client = ClientEngine(keys_path=keys_path)
696707

708+
# Fix the string values to consider in order to avoid flaky tests (one of the checks requires
709+
# a data-frame with at least 2 unique string values)
710+
str_values = ["apple", "orange", "watermelon", "cherry", "banana"]
711+
697712
pandas_df = generate_pandas_dataframe(
698-
feat_name=feat_name, index_name=selected_column, float_min=float_min, float_max=float_max
713+
feat_name=feat_name,
714+
index_name=selected_column,
715+
indexes=len(str_values),
716+
str_values=str_values,
717+
float_min=float_min,
718+
float_max=float_max,
699719
)
700720

701721
schema_int_column = {f"{feat_name}_int_1": {None: None}}

0 commit comments

Comments
 (0)