@@ -91,26 +91,26 @@ def __aggregate_dask(X_part, mask_part, groupby_part):
9191 X_part .indptr .data .ptr ,
9292 X_part .indices .data .ptr ,
9393 X_part .data .data .ptr ,
94- out .data .ptr ,
95- gb .data .ptr ,
96- mk .data .ptr ,
97- X_part .shape [0 ],
98- X_part .shape [1 ],
99- int ( n_groups ) ,
100- bool ( 0 ) ,
101- int ( X_part .data .dtype .itemsize ) ,
94+ out = out .data .ptr ,
95+ cats = gb .data .ptr ,
96+ mask = mk .data .ptr ,
97+ n_cells = X_part .shape [0 ],
98+ n_genes = X_part .shape [1 ],
99+ n_groups = n_groups ,
100+ is_csc = False ,
101+ dtype_itemsize = X_part .data .dtype .itemsize ,
102102 )
103103 else :
104104 _aggr_cuda .dense_aggr (
105- int ( X_part .data .ptr ) ,
106- int ( out .data .ptr ) ,
107- int ( gb .data .ptr ) ,
108- int ( mk .data .ptr ) ,
109- int ( X_part .shape [0 ]) ,
110- int ( X_part .shape [1 ]) ,
111- int ( n_groups ) ,
112- bool ( 0 if X_part .flags .c_contiguous else 1 ) ,
113- int ( X_part .dtype .itemsize ) ,
105+ X_part .data .ptr ,
106+ out = out .data .ptr ,
107+ cats = gb .data .ptr ,
108+ mask = mk .data .ptr ,
109+ n_cells = X_part .shape [0 ],
110+ n_genes = X_part .shape [1 ],
111+ n_groups = n_groups ,
112+ is_fortran = X_part .flags .f_contiguous ,
113+ dtype_itemsize = X_part .dtype .itemsize ,
114114 )
115115 return out
116116
@@ -170,14 +170,14 @@ def count_mean_var_sparse(self, dof: int = 1):
170170 self .data .indptr .data .ptr ,
171171 self .data .indices .data .ptr ,
172172 self .data .data .data .ptr ,
173- out .data .ptr ,
174- self .groupby .data .ptr ,
175- mask .data .ptr ,
176- int (self .data .shape [0 ]),
177- int (self .data .shape [1 ]),
178- int (self .n_cells .shape [0 ]),
179- self .data .format == "csc" ,
180- int (self .data .data .dtype .itemsize ),
173+ out = out .data .ptr ,
174+ cats = self .groupby .data .ptr ,
175+ mask = mask .data .ptr ,
176+ n_cells = int (self .data .shape [0 ]),
177+ n_genes = int (self .data .shape [1 ]),
178+ n_groups = int (self .n_cells .shape [0 ]),
179+ is_csc = self .data .format == "csc" ,
180+ dtype_itemsize = int (self .data .data .dtype .itemsize ),
181181 )
182182 sums , counts , sq_sums = out [0 , :], out [1 , :], out [2 , :]
183183 sums = sums .reshape (self .n_cells .shape [0 ], self .data .shape [1 ])
@@ -211,13 +211,13 @@ def count_mean_var_sparse_sparse(self, funcs, dof: int = 1):
211211 self .data .indptr .data .ptr ,
212212 self .data .indices .data .ptr ,
213213 self .data .data .data .ptr ,
214- src_row .data .ptr ,
215- src_col .data .ptr ,
216- src_data .data .ptr ,
217- self .groupby .data .ptr ,
218- mask .data .ptr ,
219- int ( self .data .shape [0 ]) ,
220- int ( self .data .data .dtype .itemsize ) ,
214+ out_row = src_row .data .ptr ,
215+ out_col = src_col .data .ptr ,
216+ out_data = src_data .data .ptr ,
217+ cats = self .groupby .data .ptr ,
218+ mask = mask .data .ptr ,
219+ n_cells = self .data .shape [0 ],
220+ dtype_itemsize = self .data .data .dtype .itemsize ,
221221 )
222222
223223 keys = cp .stack ([src_col , src_row ])
@@ -309,10 +309,10 @@ def count_mean_var_sparse_sparse(self, funcs, dof: int = 1):
309309 var .indptr .data .ptr ,
310310 var .indices .data .ptr ,
311311 var .data .data .ptr ,
312- means .data .ptr ,
313- self .n_cells .data .ptr ,
314- int (dof ),
315- int ( var .shape [0 ]) ,
312+ means = means .data .ptr ,
313+ n_cells = self .n_cells .data .ptr ,
314+ dof = int (dof ),
315+ n_groups = var .shape [0 ],
316316 )
317317 results ["var" ] = var
318318 if "count_nonzero" in funcs :
@@ -347,14 +347,14 @@ def count_mean_var_dense(self, dof: int = 1):
347347
348348 _aggr_cuda .dense_aggr (
349349 self .data .data .ptr ,
350- out .data .ptr ,
351- self .groupby .data .ptr ,
352- mask .data .ptr ,
353- self .data .shape [0 ],
354- int ( self .data .shape [1 ]) ,
355- int ( self .n_cells .shape [0 ]) ,
356- bool ( 0 if self .data .flags .c_contiguous else 1 ) ,
357- int ( self .data .dtype .itemsize ) ,
350+ out = out .data .ptr ,
351+ cats = self .groupby .data .ptr ,
352+ mask = mask .data .ptr ,
353+ n_cells = self .data .shape [0 ],
354+ n_genes = self .data .shape [1 ],
355+ n_groups = self .n_cells .shape [0 ],
356+ is_fortran = self .data .flags .f_contiguous ,
357+ dtype_itemsize = self .data .dtype .itemsize ,
358358 )
359359 sums , counts , sq_sums = out [0 ], out [1 ], out [2 ]
360360 sums = sums .reshape (self .n_cells .shape [0 ], self .data .shape [1 ])
0 commit comments