@@ -121,11 +121,11 @@ def _cuda_e_step(
121121 means ,
122122 prec_chol ,
123123 log_det_half ,
124- log_prob ,
125- responsibilities ,
126- ll_per_cell ,
127- centered ,
128- e_step_y ,
124+ log_prob = log_prob ,
125+ responsibilities = responsibilities ,
126+ ll_per_cell = ll_per_cell ,
127+ centered = centered ,
128+ e_step_y = e_step_y ,
129129 e_step_route = e_step_route ,
130130 stream = cp .cuda .get_current_stream ().ptr ,
131131 handle = cp .cuda .device .get_cublas_handle (),
@@ -147,11 +147,11 @@ def _cuda_m_step(
147147 weights ,
148148 means ,
149149 covariances ,
150- reg_covar ,
151- cp .ones (X .shape [0 ], dtype = X .dtype ),
152- cp .empty (K , dtype = X .dtype ),
153- cp .empty ((K , X .shape [1 ]), dtype = X .dtype ),
154- cp .empty_like (X ),
150+ reg_covar = reg_covar ,
151+ ones = cp .ones (X .shape [0 ], dtype = X .dtype ),
152+ effective_counts = cp .empty (K , dtype = X .dtype ),
153+ weighted_sums = cp .empty ((K , X .shape [1 ]), dtype = X .dtype ),
154+ centered = cp .empty_like (X ),
155155 stream = cp .cuda .get_current_stream ().ptr ,
156156 handle = cp .cuda .device .get_cublas_handle (),
157157 )
@@ -332,9 +332,9 @@ def test_cuda_matches_reference_steps():
332332 means ,
333333 prec_chol ,
334334 log_det_half ,
335- log_prob ,
336- resp ,
337- ll_per_cell ,
335+ log_prob = log_prob ,
336+ responsibilities = resp ,
337+ ll_per_cell = ll_per_cell ,
338338 stream = cp .cuda .get_current_stream ().ptr ,
339339 )
340340
@@ -380,11 +380,11 @@ def test_cuda_512_e_step_matches_reference_for_cublas_route():
380380 means ,
381381 prec_chol ,
382382 log_det_half ,
383- centered ,
384- e_step_y ,
385- log_prob ,
386- resp ,
387- ll_per_cell ,
383+ centered = centered ,
384+ e_step_y = e_step_y ,
385+ log_prob = log_prob ,
386+ responsibilities = resp ,
387+ ll_per_cell = ll_per_cell ,
388388 stream = cp .cuda .get_current_stream ().ptr ,
389389 handle = cp .cuda .device .get_cublas_handle (),
390390 )
@@ -415,11 +415,11 @@ def test_cuda_768_e_step_uses_cublas_route():
415415 means ,
416416 prec_chol ,
417417 log_det_half ,
418- log_prob ,
419- resp ,
420- ll_per_cell ,
421- centered ,
422- e_step_y ,
418+ log_prob = log_prob ,
419+ responsibilities = resp ,
420+ ll_per_cell = ll_per_cell ,
421+ centered = centered ,
422+ e_step_y = e_step_y ,
423423 e_step_route = "cublas" ,
424424 stream = stream ,
425425 handle = handle ,
@@ -433,11 +433,11 @@ def test_cuda_768_e_step_uses_cublas_route():
433433 means ,
434434 prec_chol ,
435435 log_det_half ,
436- centered_b ,
437- e_step_y_b ,
438- log_prob_b ,
439- resp_b ,
440- ll_per_cell_b ,
436+ centered = centered_b ,
437+ e_step_y = e_step_y_b ,
438+ log_prob = log_prob_b ,
439+ responsibilities = resp_b ,
440+ ll_per_cell = ll_per_cell_b ,
441441 stream = stream ,
442442 handle = handle ,
443443 )
@@ -468,11 +468,11 @@ def test_cuda_float64_wide_e_step_uses_cublas_route():
468468 means ,
469469 prec_chol ,
470470 log_det_half ,
471- log_prob ,
472- resp ,
473- ll_per_cell ,
474- centered ,
475- e_step_y ,
471+ log_prob = log_prob ,
472+ responsibilities = resp ,
473+ ll_per_cell = ll_per_cell ,
474+ centered = centered ,
475+ e_step_y = e_step_y ,
476476 e_step_route = route ,
477477 stream = cp .cuda .get_current_stream ().ptr ,
478478 handle = cp .cuda .device .get_cublas_handle (),
@@ -517,11 +517,11 @@ def test_cuda_fused_e_step_matches_reference_for_50_pc_regime():
517517 means ,
518518 prec_chol ,
519519 log_det_half ,
520- log_prob ,
521- resp ,
522- ll_per_cell ,
523- centered ,
524- e_step_y ,
520+ log_prob = log_prob ,
521+ responsibilities = resp ,
522+ ll_per_cell = ll_per_cell ,
523+ centered = centered ,
524+ e_step_y = e_step_y ,
525525 e_step_route = "fused" ,
526526 stream = cp .cuda .get_current_stream ().ptr ,
527527 handle = cp .cuda .device .get_cublas_handle (),
0 commit comments