@@ -223,6 +223,42 @@ def test_mean_var(
223
223
np .testing .assert_array_almost_equal (var , var_expected ) # type: ignore[arg-type]
224
224
225
225
226
+ @pytest .mark .skipif (not find_spec ("sklearn" ), reason = "sklearn not installed" )
227
+ @pytest .mark .array_type (Flags .Sparse , skip = Flags .Matrix | Flags .Dask | Flags .Disk | Flags .Gpu )
228
+ @pytest .mark .parametrize ("axis" , [0 , 1 ])
229
+ def test_mean_var_sparse_64 (array_type : ArrayType [types .CSArray ], axis : Literal [0 , 1 ]) -> None :
230
+ """Test that we’re equivalent for 64 bit."""
231
+ from sklearn .utils .sparsefuncs import mean_variance_axis
232
+
233
+ mtx = array_type .random ((10000 , 1000 ), dtype = np .float64 )
234
+
235
+ mean_fau , var_fau = stats .mean_var (mtx , axis = axis )
236
+ mean_skl , var_skl = mean_variance_axis (mtx , axis )
237
+
238
+ np .testing .assert_allclose (mean_fau , mean_skl , rtol = 1.0e-5 , atol = 1.0e-8 )
239
+ np .testing .assert_allclose (var_fau , var_skl , rtol = 1.0e-5 , atol = 1.0e-8 )
240
+
241
+
242
+ @pytest .mark .skipif (not find_spec ("sklearn" ), reason = "sklearn not installed" )
243
+ @pytest .mark .array_type (Flags .Sparse , skip = Flags .Matrix | Flags .Dask | Flags .Disk | Flags .Gpu )
244
+ def test_mean_var_sparse_32 (array_type : ArrayType [types .CSArray ]) -> None :
245
+ """Test whether we are more accurate for 32 bit."""
246
+ from sklearn .utils .sparsefuncs import mean_variance_axis
247
+
248
+ mtx64 = array_type .random ((10000 , 1000 ), dtype = np .float64 )
249
+ mtx32 = mtx64 .astype (np .float32 )
250
+
251
+ fau , skl = {}, {}
252
+ for n_bit , mtx in [(32 , mtx32 ), (64 , mtx64 )]:
253
+ fau [n_bit ] = stats .mean_var (mtx , axis = 0 )
254
+ skl [n_bit ] = mean_variance_axis (mtx , 0 )
255
+
256
+ for stat , _ in enumerate (["mean" , "var" ]):
257
+ resid_fau = np .mean (np .abs (fau [64 ][stat ] - fau [32 ][stat ]))
258
+ resid_skl = np .mean (np .abs (skl [64 ][stat ] - skl [32 ][stat ]))
259
+ assert resid_fau < resid_skl
260
+
261
+
226
262
@pytest .mark .array_type (skip = {Flags .Disk , * ATS_CUPY_SPARSE })
227
263
@pytest .mark .parametrize (
228
264
("axis" , "expected" ),
0 commit comments