@@ -33,6 +33,15 @@ def setUp(self):
3333 * repeat
3434 )
3535 self .y = np .array ([0 , 1 , 2 , 3 ] * repeat )
36+ self .multi_y = np .array (
37+ [
38+ [1 , 0 , 0 , 0 ],
39+ [0 , 1 , 0 , 0 ],
40+ [0 , 0 , 1 , 1 ],
41+ [0 , 0 , 1 , 0 ],
42+ ]
43+ * repeat
44+ )
3645
3746 @classmethod
3847 def setUpClass (cls ):
@@ -62,7 +71,7 @@ def testColumnOrdering(self):
6271
6372 assert data .columns .tolist () == cols [:- 1 ]
6473
65- def _testMatrixCreation (self , in_x , in_y , ** kwargs ):
74+ def _testMatrixCreation (self , in_x , in_y , multi_label = False , ** kwargs ):
6675 if "sharding" not in kwargs :
6776 kwargs ["sharding" ] = RayShardingMode .BATCH
6877 mat = RayDMatrix (in_x , in_y , ** kwargs )
@@ -81,7 +90,10 @@ def _load_data(params):
8190 x , y = _load_data (params )
8291
8392 self .assertTrue (np .allclose (self .x , x ))
84- self .assertTrue (np .allclose (self .y , y ))
93+ if multi_label :
94+ self .assertTrue (np .allclose (self .multi_y , y ))
95+ else :
96+ self .assertTrue (np .allclose (self .y , y ))
8597
8698 # Multi actor check
8799 mat = RayDMatrix (in_x , in_y , ** kwargs )
@@ -95,7 +107,10 @@ def _load_data(params):
95107 x2 , y2 = _load_data (params )
96108
97109 self .assertTrue (np .allclose (self .x , concat_dataframes ([x1 , x2 ])))
98- self .assertTrue (np .allclose (self .y , concat_dataframes ([y1 , y2 ])))
110+ if multi_label :
111+ self .assertTrue (np .allclose (self .multi_y , concat_dataframes ([y1 , y2 ])))
112+ else :
113+ self .assertTrue (np .allclose (self .y , concat_dataframes ([y1 , y2 ])))
99114
100115 def testFromNumpy (self ):
101116 in_x = self .x
@@ -276,6 +291,22 @@ def testFromMultiCSVString(self):
276291 [data_file_1 , data_file_2 ], "label" , distributed = True
277292 )
278293
294+ def testFromParquetStringMultiLabel (self ):
295+ with tempfile .TemporaryDirectory () as dir :
296+ data_file = os .path .join (dir , "data.parquet" )
297+
298+ data_df = pd .DataFrame (self .x , columns = ["a" , "b" , "c" , "d" ])
299+ labels = [f"label_{ label } " for label in range (4 )]
300+ data_df [labels ] = self .multi_y
301+ data_df .to_parquet (data_file )
302+
303+ self ._testMatrixCreation (
304+ data_file , labels , multi_label = True , distributed = False
305+ )
306+ self ._testMatrixCreation (
307+ data_file , labels , multi_label = True , distributed = True
308+ )
309+
279310 def testFromParquetString (self ):
280311 with tempfile .TemporaryDirectory () as dir :
281312 data_file = os .path .join (dir , "data.parquet" )
@@ -287,6 +318,28 @@ def testFromParquetString(self):
287318 self ._testMatrixCreation (data_file , "label" , distributed = False )
288319 self ._testMatrixCreation (data_file , "label" , distributed = True )
289320
321+ def testFromMultiParquetStringMultiLabel (self ):
322+ with tempfile .TemporaryDirectory () as dir :
323+ data_file_1 = os .path .join (dir , "data_1.parquet" )
324+ data_file_2 = os .path .join (dir , "data_2.parquet" )
325+
326+ data_df = pd .DataFrame (self .x , columns = ["a" , "b" , "c" , "d" ])
327+ labels = [f"label_{ label } " for label in range (4 )]
328+ data_df [labels ] = self .multi_y
329+
330+ df_1 = data_df [0 : len (data_df ) // 2 ]
331+ df_2 = data_df [len (data_df ) // 2 :]
332+
333+ df_1 .to_parquet (data_file_1 )
334+ df_2 .to_parquet (data_file_2 )
335+
336+ self ._testMatrixCreation (
337+ [data_file_1 , data_file_2 ], labels , multi_label = True , distributed = False
338+ )
339+ self ._testMatrixCreation (
340+ [data_file_1 , data_file_2 ], labels , multi_label = True , distributed = True
341+ )
342+
290343 def testFromMultiParquetString (self ):
291344 with tempfile .TemporaryDirectory () as dir :
292345 data_file_1 = os .path .join (dir , "data_1.parquet" )
0 commit comments