Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
89c7f8a
fix: handle missing package metadata for conda packages
Marius1311 Jan 15, 2026
e6f6ec2
fix: prevent stale knn state when neighbor computation fails
Marius1311 Jan 15, 2026
14d55cb
feat: add batch_size parameter to rapids backend for GPU OOM handling
Marius1311 Jan 15, 2026
a27e9ae
refactor: make batch_size backend-agnostic in Kernel.compute_neighbors
Marius1311 Jan 15, 2026
3882057
feat: allow empty prediction_postfix to use key without underscore
Marius1311 Jan 15, 2026
9394bd6
refactor: include underscore in postfix defaults
Marius1311 Jan 15, 2026
b61d650
feat: add subset parameter to plot_confusion_matrix
Marius1311 Jan 15, 2026
a4867db
fix: handle NaN values in both y_true and y_pred for confusion matrix
Marius1311 Jan 15, 2026
7114bc3
fix: handle mismatched category sets in confusion matrix
Marius1311 Jan 15, 2026
c0be4ca
fix: convert float categories to strings for confusion matrix
Marius1311 Jan 15, 2026
a918cd3
fix: add random_state to pynndescent test for cross-platform reproduc…
Marius1311 Jan 16, 2026
89eda50
ci: allow pre-release tests to fail without blocking PR
Marius1311 Jan 16, 2026
b46dcf1
docs: pin docutils<0.22 for sphinx-tabs compatibility
Marius1311 Jan 16, 2026
8089d37
test: skip pynndescent connectivity test (platform-dependent results)
Marius1311 Jan 16, 2026
af4fb35
test: use correlation-based comparison for pynndescent instead of skip
Marius1311 Jan 16, 2026
501190b
docs: emphasize hatch for testing in copilot instructions
Marius1311 Jan 16, 2026
44abfbd
test: relax pynndescent correlation threshold to 0.95
Marius1311 Jan 16, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/copilot-instructions.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
- Avoid drafting summary documents or endless markdown files. Just summarize in chat what you did, why, and any open questions.
- Don't update Jupyter notebooks - those are managed manually.
- When running terminal commands, use `uv run` to execute commands within the project's virtual environment (e.g., `uv run python script.py`).
- **Testing: ALWAYS use `hatch test`, NEVER `uv run pytest` or standalone pytest.** Hatch manages the test matrix (Python versions, dependencies) that CI uses. See "Testing Strategy" section for details.
- Rather than making assumptions, ask for clarification when uncertain.
- **GitHub workflows**: Use GitHub CLI (`gh`) when possible. For GitHub MCP server tools, ensure Docker Desktop is running first (`open -a "Docker Desktop"`).

Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ jobs:

name: ${{ matrix.env.label }}
runs-on: ${{ matrix.os }}
# Allow pre-release tests to fail without blocking the PR (dependency compatibility issues)
continue-on-error: ${{ contains(matrix.env.name, 'pre') }}

steps:
- uses: actions/checkout@v4
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ optional-dependencies.dev = [
"twine>=4.0.2",
]
optional-dependencies.doc = [
"docutils>=0.8,!=0.18.*,!=0.19.*",
"docutils>=0.8,!=0.18.*,!=0.19.*,<0.22", # sphinx-tabs incompatible with 0.22, see https://github.com/executablebooks/sphinx-tabs/issues/206
"ipykernel",
"ipython",
"myst-nb>=1.1",
Expand Down
2 changes: 1 addition & 1 deletion src/cellmapper/_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

_prediction_postfix = """\
prediction_postfix
Postfix to add to mapped variables to identify them as predictions."""
Postfix to append to mapped variable names (including any separator, e.g. "_pred"). Use "" for no postfix."""

_symmetrize = """\
symmetrize
Expand Down
22 changes: 19 additions & 3 deletions src/cellmapper/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

import importlib
import types
from importlib.metadata import PackageNotFoundError, version

from packaging.version import parse

from . import version
from .logging import logger


class Checker:
Expand Down Expand Up @@ -42,8 +43,23 @@ def check(self) -> None:
importlib.import_module(self.name)
except ModuleNotFoundError as e:
raise RuntimeError(" ".join(filter(None, [self.vreq_hint, self.install_hint]))) from e
v = parse(version(self.package_name))
if self.vmin and v < self.vmin:

# Skip version check if no minimum version is required
if not self.vmin:
return

# Try to get version from package metadata (may fail for conda packages)
try:
v = parse(version(self.package_name))
except PackageNotFoundError:
logger.debug(
"Could not find package metadata for %s. Skipping version check. "
"This can happen with conda-installed packages.",
self.package_name,
)
return

if v < self.vmin:
raise RuntimeError(
" ".join(
[
Expand Down
45 changes: 45 additions & 0 deletions src/cellmapper/model/_knn_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,48 @@ def query(self, points: np.ndarray, k: int) -> tuple[np.ndarray, np.ndarray]:
return distances, indices


def _batched_query(
backend: "_KNNBackend",
points: np.ndarray,
k: int,
batch_size: int | None,
) -> tuple[np.ndarray, np.ndarray]:
"""
Query a k-NN backend in batches to avoid memory issues.

Parameters
----------
backend
The k-NN backend to query.
points
Query points.
k
Number of neighbors to query.
batch_size
Number of query points per batch. If None, no batching is applied.

Returns
-------
Tuple of (distances, indices) arrays.
"""
n_points = points.shape[0]

if batch_size is None or n_points <= batch_size:
return backend.query(points, k)

all_distances = []
all_indices = []

for start in range(0, n_points, batch_size):
end = min(start + batch_size, n_points)
batch = points[start:end]
dist, idx = backend.query(batch, k)
all_distances.append(dist)
all_indices.append(idx)

return np.vstack(all_distances), np.vstack(all_indices)


class _RapidsBackend(_KNNBackend):
def __init__(
self,
Expand Down Expand Up @@ -143,6 +185,9 @@ def fit(self, data: np.ndarray) -> None:
def query(self, points: np.ndarray, k: int) -> tuple[np.ndarray, np.ndarray]:
points_gpu = self.cp.asarray(points)
distances, indices = self._nn.kneighbors(points_gpu)
# Free GPU memory
del points_gpu
self.cp.get_default_memory_pool().free_all_blocks()
return distances, indices


Expand Down
40 changes: 25 additions & 15 deletions src/cellmapper/model/cellmapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,19 +209,22 @@ def compute_neighbors(
xrep = xrep[:, :n_comps]
yrep = yrep[:, :n_comps]

self.knn = Kernel(
# Create kernel and compute neighbors. Only assign to self.knn after
# successful completion to avoid stale state if neighbor computation fails.
knn = Kernel(
np.ascontiguousarray(xrep),
None if self._is_self_mapping else np.ascontiguousarray(yrep),
is_self_mapping=self._is_self_mapping,
)
self.knn.compute_neighbors(
knn.compute_neighbors(
n_neighbors=n_neighbors,
knn_method=knn_method,
knn_dist_metric=knn_dist_metric,
only_yx=self.only_yx,
random_state=random_state,
**(neighbors_kwargs or {}),
)
self.knn = knn

@d.dedent
def compute_mapping_matrix(
Expand Down Expand Up @@ -316,7 +319,7 @@ def map_obsm(
key: str,
t: int | None = None,
diffusion_method: Literal["iterative", "spectral"] = "iterative",
prediction_postfix: str = "pred",
prediction_postfix: str = "_pred",
) -> None:
"""
Map embeddings with optional multi-step diffusion.
Expand Down Expand Up @@ -388,7 +391,7 @@ def map_obsm(
)

# Store the transferred embeddings in query.obsm with descriptive key
output_key = f"{key}_{prediction_postfix}"
output_key = f"{key}{prediction_postfix}"
self.query.obsm[output_key] = query_data
logger.info("Embeddings mapped and stored in query.obsm['%s']", output_key)

Expand Down Expand Up @@ -519,6 +522,7 @@ def map(
knn_method: Literal["sklearn", "pynndescent", "rapids"] = "sklearn",
knn_dist_metric: str = "euclidean",
only_yx: bool = False,
neighbors_kwargs: dict[str, Any] | None = None,
kernel_method: Literal[
"jaccard",
"gauss",
Expand All @@ -532,7 +536,7 @@ def map(
| None = None,
symmetrize: bool | None = None,
self_edges: bool | None = None,
prediction_postfix: str = "pred",
prediction_postfix: str = "_pred",
subset_categories: None | list[str] | str = None,
) -> "CellMapper":
"""
Expand All @@ -554,6 +558,10 @@ def map(
%(knn_method)s
%(knn_dist_metric)s
%(only_yx)s
neighbors_kwargs
Additional keyword arguments to pass to the neighbors computation method.
For rapids backend, you can pass ``batch_size`` to process queries in batches
to avoid GPU OOM errors (e.g., ``neighbors_kwargs={"batch_size": 50000}``).
%(kernel_method)s
%(symmetrize)s
%(self_edges)s
Expand All @@ -567,6 +575,7 @@ def map(
knn_method=knn_method,
knn_dist_metric=knn_dist_metric,
only_yx=only_yx,
neighbors_kwargs=neighbors_kwargs,
)
if self._mapping_operator is None:
self.compute_mapping_matrix(kernel_method=kernel_method, symmetrize=symmetrize, self_edges=self_edges)
Expand Down Expand Up @@ -662,8 +671,8 @@ def map_obs(
key: str,
t: int | None = None,
diffusion_method: Literal["iterative", "spectral"] = "iterative",
prediction_postfix: str = "pred",
confidence_postfix: str = "conf",
prediction_postfix: str = "_pred",
confidence_postfix: str = "_conf",
return_probabilities: bool = False,
subset_categories: None | list[str] | str = None,
) -> pd.DataFrame | None:
Expand Down Expand Up @@ -855,19 +864,19 @@ def _map_obs_categorical(
conf_vals = np.max(ytab, axis=1).ravel()
conf = pd.Series(conf_vals, index=self.query.obs_names)

self.query.obs[f"{key}_{prediction_postfix}"] = pred
self.query.obs[f"{key}_{confidence_postfix}"] = conf
pred_key = f"{key}{prediction_postfix}"
conf_key = f"{key}{confidence_postfix}"
self.query.obs[pred_key] = pred
self.query.obs[conf_key] = conf

# Add colors if available
if f"{key}_colors" in self.reference.uns:
color_lookup = dict(
zip(self.reference.obs[key].cat.categories, self.reference.uns[f"{key}_colors"], strict=True)
)
self.query.uns[f"{key}_{prediction_postfix}_colors"] = [
color_lookup.get(cat, "#383838") for cat in pred.cat.categories
]
self.query.uns[f"{pred_key}_colors"] = [color_lookup.get(cat, "#383838") for cat in pred.cat.categories]

logger.info("Categorical data mapped and stored in query.obs['%s'].", f"{key}_{prediction_postfix}")
logger.info("Categorical data mapped and stored in query.obs['%s'].", pred_key)

# Return probabilities as a sparse pandas DataFrame if requested (never densify)
if return_probabilities:
Expand Down Expand Up @@ -900,6 +909,7 @@ def _map_obs_numerical(
index=self.query.obs_names,
)

self.query.obs[f"{key}_{prediction_postfix}"] = pred
pred_key = f"{key}{prediction_postfix}"
self.query.obs[pred_key] = pred

logger.info("Numerical data mapped and stored in query.obs['%s'].", f"{key}_{prediction_postfix}")
logger.info("Numerical data mapped and stored in query.obs['%s'].", pred_key)
59 changes: 47 additions & 12 deletions src/cellmapper/model/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class EvaluationMixin:
"""Mixin class for evaluation-related methods for CellMapper."""

def register_external_predictions(
self, label_key: str, prediction_postfix: str = "pred", confidence_postfix: str = "conf"
self, label_key: str, prediction_postfix: str = "_pred", confidence_postfix: str = "_conf"
) -> None:
"""
Register externally computed predictions for evaluation.
Expand Down Expand Up @@ -99,8 +99,8 @@ def register_external_predictions(
- ``confidence_postfix``: Postfix for confidence column.
"""
# Verify that the expected columns exist
pred_col = f"{label_key}_{prediction_postfix}"
conf_col = f"{label_key}_{confidence_postfix}"
pred_col = f"{label_key}{prediction_postfix}"
conf_col = f"{label_key}{confidence_postfix}"

if pred_col not in self.query.obs.columns:
raise ValueError(f"Prediction column '{pred_col}' not found in query.obs")
Expand Down Expand Up @@ -163,8 +163,8 @@ def evaluate_label_transfer(

# Extract ground-truth and predicted labels
y_true = self.query.obs[label_key].dropna()
y_pred = self.query.obs.loc[y_true.index, f"{label_key}_{self.prediction_postfix}"]
confidence = self.query.obs.loc[y_true.index, f"{label_key}_{self.confidence_postfix}"]
y_pred = self.query.obs.loc[y_true.index, f"{label_key}{pred_postfix}"]
confidence = self.query.obs.loc[y_true.index, f"{label_key}{conf_postfix}"]

# Apply confidence cutoff
valid_indices = confidence >= confidence_cutoff
Expand Down Expand Up @@ -203,32 +203,67 @@ def evaluate_label_transfer(
self.label_transfer_report = pd.DataFrame(report).transpose()

def plot_confusion_matrix(
self, label_key: str, figsize=(10, 8), cmap="viridis", save: str | Path | None = None, **kwargs
self,
label_key: str,
subset: np.ndarray | pd.Series | None = None,
figsize: tuple[int, int] = (10, 8),
cmap: str = "viridis",
save: str | Path | None = None,
**kwargs,
) -> None:
"""
Plot the confusion matrix as a heatmap using sklearn's ConfusionMatrixDisplay.

Parameters
----------
label_key
Key in .obs storing ground-truth cell type annotations.
subset
Boolean mask to select a subset of cells for the confusion matrix.
Must have the same length as query.obs or be a pandas Series indexed by obs_names.
figsize
Size of the figure (width, height). Default is (10, 8).
cmap
Colormap to use for the heatmap. Default is "viridis".
label_key
Key in .obs storing ground-truth cell type annotations.
save
Path to save the figure. If None, the figure is not saved.
**kwargs
Additional keyword arguments to pass to ConfusionMatrixDisplay.
"""
if self.prediction_postfix is None or self.confidence_postfix is None:
raise ValueError("Label transfer has not been performed. Call map_obs() first.")

# Extract true and predicted labels
y_true = self.query.obs[label_key].dropna()
y_pred = self.query.obs.loc[y_true.index, f"{label_key}_pred"]
# Extract true and predicted labels, dropping NaNs from both
y_true = self.query.obs[label_key]
y_pred = self.query.obs[f"{label_key}{self.prediction_postfix}"]
valid_mask = y_true.notna() & y_pred.notna()
y_true = y_true[valid_mask]
y_pred = y_pred[valid_mask]

# Apply subset filter if provided
if subset is not None:
if isinstance(subset, pd.Series):
subset = subset.loc[y_true.index]
else:
# Assume boolean array aligned with query.obs, reindex to y_true
subset = pd.Series(subset, index=self.query.obs_names).loc[y_true.index]
y_true = y_true[subset]
y_pred = y_pred[subset]

# Get union of categories if categorical, to handle mismatched category sets
# Also convert to string to avoid sklearn interpreting float categories as continuous
labels = None
if hasattr(y_true, "cat") and hasattr(y_pred, "cat"):
all_categories = y_true.cat.categories.union(y_pred.cat.categories)
labels = [str(c) for c in sorted(all_categories)]
y_true = y_true.astype(str)
y_pred = y_pred.astype(str)

# Plot confusion matrix using sklearn's ConfusionMatrixDisplay
_, ax = plt.subplots(1, 1, figsize=figsize)
ConfusionMatrixDisplay.from_predictions(y_true, y_pred, cmap=cmap, xticks_rotation="vertical", ax=ax, **kwargs)
ConfusionMatrixDisplay.from_predictions(
y_true, y_pred, labels=labels, cmap=cmap, xticks_rotation="vertical", ax=ax, **kwargs
)
plt.title("Confusion Matrix")

if save:
Expand Down
Loading
Loading