-
Notifications
You must be signed in to change notification settings - Fork 113
chore: fix spatialdata access in squidpy via a helper function #1134
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 18 commits
5f61dd5
40dad3e
c81969f
7eef650
8a7cc98
365dad7
2751464
cde62d5
137f2c7
b358304
da420f8
1608430
8c7c9e3
dbdc166
6039e68
f53ae10
53ba958
466124d
6e60795
df6c923
d2439f7
d185f8f
67f34fc
0bdbf8d
d409c76
d99bbdc
c0d696f
cc67e44
0518d51
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -47,6 +47,7 @@ | |
| _assert_categorical_obs, | ||
| _assert_spatial_basis, | ||
| _save_data, | ||
| extract_adata, | ||
| ) | ||
|
|
||
| __all__ = ["spatial_neighbors"] | ||
|
|
@@ -142,54 +143,14 @@ def spatial_neighbors( | |
| - :attr:`anndata.AnnData.obsp` ``['{{key_added}}_distances']`` - the spatial distances. | ||
| - :attr:`anndata.AnnData.uns` ``['{{key_added}}']`` - :class:`dict` containing parameters. | ||
| """ | ||
| if isinstance(adata, SpatialData): | ||
| assert elements_to_coordinate_systems is not None, ( | ||
| "Since `adata` is a :class:`spatialdata.SpatialData`, `elements_to_coordinate_systems` must not be `None`." | ||
| ) | ||
| assert table_key is not None, ( | ||
| "Since `adata` is a :class:`spatialdata.SpatialData`, `table_key` must not be `None`." | ||
| ) | ||
| elements, table = match_element_to_table(adata, list(elements_to_coordinate_systems), table_key) | ||
| assert table.obs_names.equals(adata.tables[table_key].obs_names), ( | ||
| "The spatialdata table must annotate all elements keys. Some elements are missing, please check the `elements_to_coordinate_systems` dictionary." | ||
| ) | ||
| regions, region_key, instance_key = get_table_keys(adata.tables[table_key]) | ||
| regions = [regions] if isinstance(regions, str) else regions | ||
| ordered_regions_in_table = adata.tables[table_key].obs[region_key].unique() | ||
|
|
||
| # TODO: remove this after https://github.com/scverse/spatialdata/issues/614 | ||
| remove_centroids = {} | ||
| elem_instances = [] | ||
| for e in regions: | ||
| schema = get_model(elements[e]) | ||
| element_instances = get_element_instances(elements[e]).to_series() | ||
| if np.isin(0, element_instances.values) and (schema in (Labels2DModel, Labels3DModel)): | ||
| element_instances = element_instances.drop(index=0) | ||
| remove_centroids[e] = True | ||
| else: | ||
| remove_centroids[e] = False | ||
| elem_instances.append(element_instances) | ||
|
|
||
| element_instances = pd.concat(elem_instances) | ||
| if (not np.all(element_instances.values == adata.tables[table_key].obs[instance_key].values)) or ( | ||
| not np.all(ordered_regions_in_table == regions) | ||
| ): | ||
| raise ValueError( | ||
| "The spatialdata table must annotate all elements keys. Some elements are missing or not ordered correctly, please check the `elements_to_coordinate_systems` dictionary." | ||
| ) | ||
| centroids = [] | ||
| for region_ in ordered_regions_in_table: | ||
| cs = elements_to_coordinate_systems[region_] | ||
| centroid = get_centroids(adata[region_], coordinate_system=cs)[["x", "y"]].compute() | ||
|
|
||
| # TODO: remove this after https://github.com/scverse/spatialdata/issues/614 | ||
| if remove_centroids[region_]: | ||
| centroid = centroid[1:].copy() | ||
| centroids.append(centroid) | ||
|
|
||
| adata.tables[table_key].obsm[spatial_key] = np.concatenate(centroids) | ||
| adata = adata.tables[table_key] | ||
| library_key = region_key | ||
| adata, library_key = _resolve_sdata( | ||
| adata=adata, | ||
| elements_to_coordinate_systems=elements_to_coordinate_systems, | ||
| table_key=table_key, | ||
| spatial_key=spatial_key, | ||
| library_key=library_key, | ||
| ) | ||
|
|
||
| assert_positive(n_rings, name="n_rings") | ||
| assert_positive(n_neighs, name="n_neighs") | ||
|
|
@@ -436,6 +397,62 @@ def _build_connectivity( | |
| return Adj | ||
|
|
||
|
|
||
| def _resolve_sdata( | ||
| adata: AnnData | SpatialData, | ||
| elements_to_coordinate_systems: dict[str, str] | None, | ||
| table_key: str = "table", | ||
| spatial_key: str = Key.obsm.spatial, | ||
| library_key: str | None = None, | ||
| ) -> tuple[AnnData, str | None]: | ||
| if not isinstance(adata, SpatialData): | ||
| return adata, library_key | ||
|
|
||
| assert elements_to_coordinate_systems is not None, ( | ||
| "Since `adata` is a :class:`spatialdata.SpatialData`, `elements_to_coordinate_systems` must not be `None`." | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "adata is sdata" is weird? Maybe it should be
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes but the error message would still have to be the same since the public signature takes adata, otherwise the error message would be confusing
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is done in many places lets make an issue about it and move on for here. Ideally we would call it data since it's positional argument anyway |
||
| ) | ||
| table = extract_adata(adata, table_key=table_key) | ||
| elements, matched_table = match_element_to_table(adata, list(elements_to_coordinate_systems), table_key) | ||
| assert matched_table.obs_names.equals(table.obs_names), ( | ||
| "The spatialdata table must annotate all elements keys. Some elements are missing, please check the `elements_to_coordinate_systems` dictionary." | ||
| ) | ||
| regions, region_key, instance_key = get_table_keys(table) | ||
| regions = [regions] if isinstance(regions, str) else regions | ||
| ordered_regions_in_table = table.obs[region_key].unique() | ||
|
|
||
| # TODO: remove this after https://github.com/scverse/spatialdata/issues/614 | ||
| remove_centroids = {} | ||
| elem_instances = [] | ||
| for e in regions: | ||
| schema = get_model(elements[e]) | ||
| element_instances = get_element_instances(elements[e]).to_series() | ||
| if np.isin(0, element_instances.values) and (schema in (Labels2DModel, Labels3DModel)): | ||
| element_instances = element_instances.drop(index=0) | ||
| remove_centroids[e] = True | ||
| else: | ||
| remove_centroids[e] = False | ||
| elem_instances.append(element_instances) | ||
|
|
||
| element_instances = pd.concat(elem_instances) | ||
| if (not np.all(element_instances.values == table.obs[instance_key].values)) or ( | ||
| not np.all(ordered_regions_in_table == regions) | ||
| ): | ||
| raise ValueError( | ||
| "The spatialdata table must annotate all elements keys. Some elements are missing or not ordered correctly, please check the `elements_to_coordinate_systems` dictionary." | ||
| ) | ||
| centroids = [] | ||
| for region_ in ordered_regions_in_table: | ||
| cs = elements_to_coordinate_systems[region_] | ||
| centroid = get_centroids(adata[region_], coordinate_system=cs)[["x", "y"]].compute() | ||
|
|
||
| # TODO: remove this after https://github.com/scverse/spatialdata/issues/614 | ||
| if remove_centroids[region_]: | ||
| centroid = centroid[1:].copy() | ||
| centroids.append(centroid) | ||
|
|
||
| table.obsm[spatial_key] = np.concatenate(centroids) | ||
|
selmanozleyen marked this conversation as resolved.
|
||
| return table, region_key | ||
|
|
||
|
|
||
| @njit | ||
| def _csr_bilateral_diag_scale_helper( | ||
| mat: csr_array | csr_matrix, | ||
|
|
@@ -559,8 +576,7 @@ def mask_graph( | |
| if not isinstance(polygon_mask, Polygon | MultiPolygon): | ||
| raise ValueError(f"`polygon_mask` should be of type `Polygon` or `MultiPolygon`, got {type(polygon_mask)}") | ||
|
|
||
| # get elements | ||
| table = sdata.tables[table_key] | ||
| table = extract_adata(sdata, table_key=table_key) | ||
| coords = table.obsm[spatial_key] | ||
| Adj = table.obsp[conns_key] | ||
| Dst = table.obsp[dists_key] | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.