Skip to content

Commit 1900f49

Browse files
authored
embedding density update (#464)
* embedding density * add release note * fix
1 parent 9450efa commit 1900f49

2 files changed

Lines changed: 5 additions & 6 deletions

File tree

docs/release-notes/0.13.3.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
```{rubric} Bug fixes
1414
```
15+
* Updates `tl.embedding_density` to work with `rapids-25.10` {pr}`464` {smaller}`S Dicks`
1516

1617
```{rubric} Misc
1718
```

src/rapids_singlecell/tools/_embedding_density.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,7 @@ def embedding_density(
129129
embed_x = adata.obsm[f"X_{basis}"][:, components[0]]
130130
embed_y = adata.obsm[f"X_{basis}"][:, components[1]]
131131

132-
adata.obs[density_covariate] = _calc_density(
133-
cp.array(embed_x), cp.array(embed_y), batchsize
134-
)
132+
adata.obs[density_covariate] = _calc_density(embed_x, embed_y, batchsize)
135133

136134
# Reduce diffmap components for labeling
137135
# Note: plot_scatter takes care of correcting diffmap components
@@ -152,15 +150,15 @@ def _calc_density(x: cp.ndarray, y: cp.ndarray, batchsize: int):
152150
from cuml.neighbors import KernelDensity
153151

154152
# Calculate the point density
155-
xy = cp.vstack([x, y]).T
156-
bandwidth = cp.power(xy.shape[0], (-1.0 / (xy.shape[1] + 4)))
153+
xy = np.vstack([x, y]).T
154+
bandwidth = np.power(xy.shape[0], (-1.0 / (xy.shape[1] + 4)))
157155
kde = KernelDensity(kernel="gaussian", bandwidth=bandwidth).fit(xy)
158156
z = cp.zeros(xy.shape[0])
159157
n_batches = math.ceil(xy.shape[0] / batchsize)
160158
for batch in range(n_batches):
161159
start_idx = batch * batchsize
162160
stop_idx = min(batch * batchsize + batchsize, xy.shape[0])
163-
z[start_idx:stop_idx] = kde.score_samples(xy[start_idx:stop_idx, :])
161+
z[start_idx:stop_idx] = cp.array(kde.score_samples(xy[start_idx:stop_idx, :]))
164162
min_z = cp.min(z)
165163
max_z = cp.max(z)
166164

0 commit comments

Comments
 (0)