Skip to content

Commit 7b72292

Browse files
committed
Merge branch 'numba-co-occurrence' of github.com:wenjie1991/squidpy into numba-co-occurrence
2 parents c490d45 + 05dd724 commit 7b72292

1 file changed

Lines changed: 6 additions & 14 deletions

File tree

src/squidpy/gr/_ppatterns.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -267,15 +267,10 @@ def _score_helper(
267267

268268
return score_perms
269269

270+
270271
@njit(parallel=True, fastmath=True, cache=True)
271272
def _occur_count(
272-
spatial_x: NDArrayA,
273-
spatial_y: NDArrayA,
274-
thresholds: NDArrayA,
275-
label_idx: NDArrayA,
276-
n: int,
277-
k: int,
278-
l_val: int
273+
spatial_x: NDArrayA, spatial_y: NDArrayA, thresholds: NDArrayA, label_idx: NDArrayA, n: int, k: int, l_val: int
279274
) -> NDArrayA:
280275
# Allocate a 2D array to store a flat local result per point.
281276
k2 = k * k
@@ -291,8 +286,8 @@ def _occur_count(
291286
dy = spatial_y[i] - spatial_y[j]
292287
d2 = dx * dx + dy * dy
293288

294-
pair = label_idx[i] * k + label_idx[j] # fixed in r–loop
295-
base = pair * l_val # first cell for that pair
289+
pair = label_idx[i] * k + label_idx[j] # fixed in r–loop
290+
base = pair * l_val # first cell for that pair
296291

297292
for r in range(l_val):
298293
if d2 <= thresholds[r]:
@@ -306,12 +301,9 @@ def _occur_count(
306301

307302
return cast(NDArray[np.int32], result)
308303

304+
309305
@njit(parallel=True, fastmath=True, cache=True)
310-
def _co_occurrence_helper(
311-
v_x: NDArrayA,
312-
v_y: NDArrayA,
313-
v_radium: NDArrayA,
314-
labs: NDArrayA) -> NDArrayA:
306+
def _co_occurrence_helper(v_x: NDArrayA, v_y: NDArrayA, v_radium: NDArrayA, labs: NDArrayA) -> NDArrayA:
315307
"""
316308
Fast co-occurrence probability computation using the new numba-accelerated counting.
317309

0 commit comments

Comments
 (0)