@@ -24,6 +24,7 @@ def generate_pandas_dataframe(
2424 n_features : int = 1 ,
2525 index_name : Optional [str ] = None ,
2626 indexes : Optional [Union [int , List ]] = None ,
27+ str_values : Optional [List [str ]] = None ,
2728 index_position : int = 0 ,
2829 include_nan : bool = True ,
2930 float_min : float = - 10.0 ,
@@ -46,6 +47,7 @@ def generate_pandas_dataframe(
4647 index_name (Optional[str]): The index's name. Default to None ("index").
4748 indexes (Optional[Union[int, List]]): Custom indexes to consider. Default to None (5 rows,
4849 indexed from 1 to 5).
50+ str_values (Optional[List[str]]): String values to use, one per index. Default to None.
4951 index_position (int): The index's column position in the data-frame. Default to 0.
5052 include_nan (bool): If NaN values should be put in the data-frame. If True, they are
5153 inserted in the first row. Default to True.
@@ -75,6 +77,10 @@ def generate_pandas_dataframe(
7577 if isinstance (indexes , int ):
7678 indexes = list (range (1 , indexes + 1 ))
7779
80+ assert str_values is None or len (str_values ) == len (
81+ indexes
82+ ), "Parameter 'str_values' must either be None or a list of length equal to 'indexes'."
83+
7884 if index_name is None :
7985 index_name = "index"
8086
@@ -102,11 +108,16 @@ def generate_pandas_dataframe(
102108
103109 # Add a column with string values (including NaN or not)
104110 if dtype in ["str" , "mixed" ]:
105- str_values = ["apple" , "orange" , "watermelon" , "cherry" , "banana" ]
111+ str_values_default = ["apple" , "orange" , "watermelon" , "cherry" , "banana" ]
106112
107113 for i in range (1 , n_features + 1 ):
108114 column_name = f"{ feat_name } _str_{ i } "
109- columns [column_name ] = list (numpy .random .choice (str_values , size = (len (indexes ),)))
115+ rand_str_values = (
116+ list (numpy .random .choice (str_values_default , size = (len (indexes ),)))
117+ if str_values is None
118+ else str_values
119+ )
120+ columns [column_name ] = rand_str_values
110121
111122 if include_nan :
112123 columns [column_name ][0 ] = numpy .nan
@@ -694,8 +705,17 @@ def check_invalid_schema_values():
694705
695706 client = ClientEngine (keys_path = keys_path )
696707
708+ # Fix the string values to consider in order to avoid flaky tests (one of the checks requires
709+ # a data-frame with at least 2 unique string values)
710+ str_values = ["apple" , "orange" , "watermelon" , "cherry" , "banana" ]
711+
697712 pandas_df = generate_pandas_dataframe (
698- feat_name = feat_name , index_name = selected_column , float_min = float_min , float_max = float_max
713+ feat_name = feat_name ,
714+ index_name = selected_column ,
715+ indexes = len (str_values ),
716+ str_values = str_values ,
717+ float_min = float_min ,
718+ float_max = float_max ,
699719 )
700720
701721 schema_int_column = {f"{ feat_name } _int_1" : {None : None }}
0 commit comments