Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions src/squidpy/experimental/im/_stain/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from __future__ import annotations

from squidpy.experimental.im._stain._constants import (
RUDERMAN_LAB_TO_LMS,
RUDERMAN_LMS_TO_LAB,
RUDERMAN_LMS_TO_RGB,
RUDERMAN_RGB_TO_LMS,
RUIFROK_HE,
SDA_SCALE,
)
from squidpy.experimental.im._stain._conversion import (
lab_ruderman_to_rgb,
rgb_to_lab_ruderman,
rgb_to_sda,
sda_to_rgb,
)
from squidpy.experimental.im._stain._reference import StainMethod, StainReference

__all__ = [
"RUDERMAN_LAB_TO_LMS",
"RUDERMAN_LMS_TO_LAB",
"RUDERMAN_LMS_TO_RGB",
"RUDERMAN_RGB_TO_LMS",
"RUIFROK_HE",
"SDA_SCALE",
"StainMethod",
"StainReference",
"lab_ruderman_to_rgb",
"rgb_to_lab_ruderman",
"rgb_to_sda",
"sda_to_rgb",
]
51 changes: 51 additions & 0 deletions src/squidpy/experimental/im/_stain/_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Canonical stain vectors, color-space matrices, and module-wide defaults."""

from __future__ import annotations

import numpy as np
from skimage.color import rgb_from_hed


def _unit(v: np.ndarray) -> np.ndarray:
a = np.asarray(v, dtype=np.float64)
return a / np.linalg.norm(a)


# Ruifrok and Johnston (2001) canonical stain vectors, unit-normalised.
# Sourced from `skimage.color.rgb_from_hed`; skimage stores them
# un-normalised so the `_unit` step is load-bearing.
RUIFROK_HE: dict[str, np.ndarray] = {
"hematoxylin": _unit(rgb_from_hed[0]),
"eosin": _unit(rgb_from_hed[1]),
"dab": _unit(rgb_from_hed[2]),
}

# HistomicsTK-compatible SDA scale so luminosity thresholds taken from
# the H&E literature transfer directly.
SDA_SCALE: float = 255.0 / np.log(256.0)

# Ruderman, Cronin, Chiao (1998) decorrelated colour space, as used by
# Reinhard et al. (2001). NOT equivalent to CIE Lab and NOT interchangeable
# with `skimage.color.rgb2lab`.
RUDERMAN_RGB_TO_LMS: np.ndarray = np.array(
[
[0.3811, 0.5783, 0.0402],
[0.1967, 0.7244, 0.0782],
[0.0241, 0.1288, 0.8444],
],
dtype=np.float64,
)
RUDERMAN_LMS_TO_RGB: np.ndarray = np.linalg.inv(RUDERMAN_RGB_TO_LMS)

# Reinhard 2001 eq. 4.
_DIAG = np.diag([1.0 / np.sqrt(3.0), 1.0 / np.sqrt(6.0), 1.0 / np.sqrt(2.0)])
_MIX = np.array(
[
[1.0, 1.0, 1.0],
[1.0, 1.0, -2.0],
[1.0, -1.0, 0.0],
],
dtype=np.float64,
)
RUDERMAN_LMS_TO_LAB: np.ndarray = _DIAG @ _MIX
RUDERMAN_LAB_TO_LMS: np.ndarray = np.linalg.inv(RUDERMAN_LMS_TO_LAB)
172 changes: 172 additions & 0 deletions src/squidpy/experimental/im/_stain/_conversion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
"""RGB <-> optical density (SDA) and RGB <-> Ruderman Lab conversions.

All functions operate on :class:`xarray.DataArray` with a channel dimension
named ``"c"`` of length 3. Numpy-backed and dask-backed inputs are both
supported transparently; nothing here forces materialisation of large arrays.
Each public function compiles down to a single :func:`xarray.apply_ufunc`
call so that dask schedules one task per chunk regardless of how many
elementwise and matrix steps the transform contains.
"""

from __future__ import annotations

import numpy as np
import xarray as xr

from squidpy.experimental.im._stain._constants import (
RUDERMAN_LAB_TO_LMS,
RUDERMAN_LMS_TO_LAB,
RUDERMAN_LMS_TO_RGB,
RUDERMAN_RGB_TO_LMS,
SDA_SCALE,
)

_CHANNEL_DIM = "c"


def _check_channel_dim(arr: xr.DataArray) -> None:
if _CHANNEL_DIM not in arr.dims:
raise ValueError(f"Input must have a dimension named {_CHANNEL_DIM!r}; got dims {arr.dims}.")
if arr.sizes[_CHANNEL_DIM] != 3:
raise ValueError(f"Channel dimension {_CHANNEL_DIM!r} must have length 3; got {arr.sizes[_CHANNEL_DIM]}.")


def _working_dtype(arr: xr.DataArray) -> np.dtype:
# Integer/uint inputs are promoted to float32 to keep dask graphs cheap
# on whole-slide images; already-float inputs preserve caller dtype.
return arr.dtype if np.issubdtype(arr.dtype, np.floating) else np.dtype(np.float32)


def _apply_along_channel(
arr: xr.DataArray,
kernel,
*,
out_dtype: np.dtype,
**kwargs,
) -> xr.DataArray:
"""Run a per-chunk kernel that consumes and returns the channel axis.

``apply_ufunc`` moves the ``c`` core dim to the end of the output; we
transpose back to the caller's original dim order so downstream
consumers see a stable layout.
"""
original_dims = arr.dims
out = xr.apply_ufunc(
kernel,
arr,
input_core_dims=[[_CHANNEL_DIM]],
output_core_dims=[[_CHANNEL_DIM]],
kwargs=kwargs,
dask="parallelized",
output_dtypes=[out_dtype],
)
return out.transpose(*original_dims)


def _rgb_to_sda_kernel(x: np.ndarray, *, bg: np.ndarray, dtype: np.dtype) -> np.ndarray:
x = x.astype(dtype, copy=False)
return (-np.log((x + 1.0) / (bg + 1.0)) * SDA_SCALE).astype(dtype, copy=False)


def _sda_to_rgb_kernel(x: np.ndarray, *, bg: np.ndarray, dtype: np.dtype) -> np.ndarray:
rgb = (bg + 1.0) * np.exp(-x.astype(dtype, copy=False) / SDA_SCALE) - 1.0
np.clip(rgb, 0.0, 255.0, out=rgb)
return rgb.astype(dtype, copy=False)


def _rgb_to_lab_kernel(x: np.ndarray, *, dtype: np.dtype) -> np.ndarray:
x = x.astype(dtype, copy=False)
lms = x @ RUDERMAN_RGB_TO_LMS.T.astype(dtype, copy=False)
np.log(lms + 1.0, out=lms)
return (lms @ RUDERMAN_LMS_TO_LAB.T.astype(dtype, copy=False)).astype(dtype, copy=False)


def _lab_to_rgb_kernel(x: np.ndarray, *, dtype: np.dtype) -> np.ndarray:
x = x.astype(dtype, copy=False)
log_lms = x @ RUDERMAN_LAB_TO_LMS.T.astype(dtype, copy=False)
# The +1.0 / -1.0 pair is paired with the matching offset in
# `_rgb_to_lab_kernel` so the round trip remains exact for valid RGB.
lms = np.exp(log_lms) - 1.0
rgb = lms @ RUDERMAN_LMS_TO_RGB.T.astype(dtype, copy=False)
np.clip(rgb, 0.0, 255.0, out=rgb)
return rgb.astype(dtype, copy=False)


def rgb_to_sda(
rgb: xr.DataArray,
background_intensity: np.ndarray,
) -> xr.DataArray:
"""Convert RGB intensities to standard deviation per absorbance (SDA).

Equivalent to optical density with a per-channel background ``I_0``::

sda = -log((rgb + 1) / (I_0 + 1)) * SDA_SCALE

The matched ``+1`` terms avoid ``log(0)`` at fully saturated black
pixels and guarantee that pixels at the white point map exactly to
zero. Scaling matches the HistomicsTK convention so that luminosity
thresholds from the published H&E literature transfer directly.

Parameters
----------
rgb
Image with a ``"c"`` dimension of length 3. May be numpy- or
dask-backed; the operation is purely elementwise and stays lazy.
background_intensity
Per-channel white-point ``I_0`` as a shape-``(3,)`` numpy array.
Required: no scanner produces a pure-white background, so the
caller must supply either an estimate (PR 3 will ship the
estimator) or, knowingly, an explicit
``np.array([255., 255., 255.])``.

Returns
-------
SDA-space DataArray, float dtype. Lazy if and only if ``rgb`` was lazy.
"""
_check_channel_dim(rgb)
dtype = _working_dtype(rgb)
bg = np.asarray(background_intensity, dtype=dtype)
return _apply_along_channel(rgb, _rgb_to_sda_kernel, out_dtype=dtype, bg=bg, dtype=dtype)


def sda_to_rgb(
sda: xr.DataArray,
background_intensity: np.ndarray,
) -> xr.DataArray:
"""Convert SDA back to RGB intensities in ``[0, 255]``.

Inverse of :func:`rgb_to_sda`. Pass the same ``background_intensity``
used at encode time. The result is clipped to ``[0, 255]`` but kept in
float dtype; uint8 conversion is the caller's choice.
"""
_check_channel_dim(sda)
dtype = _working_dtype(sda)
bg = np.asarray(background_intensity, dtype=dtype)
return _apply_along_channel(sda, _sda_to_rgb_kernel, out_dtype=dtype, bg=bg, dtype=dtype)


def rgb_to_lab_ruderman(rgb: xr.DataArray) -> xr.DataArray:
"""Convert RGB to Ruderman et al. (1998) decorrelated Lab space.

This is the Lab variant used by Reinhard et al. (2001) for colour
transfer, not CIE Lab. Results differ from
:func:`skimage.color.rgb2lab`.

The pipeline is RGB -> LMS via :data:`RUDERMAN_RGB_TO_LMS`, then
``log(LMS + 1)``, then LMS -> Lab via :data:`RUDERMAN_LMS_TO_LAB`. All
steps fuse into a single per-chunk numpy kernel.
"""
_check_channel_dim(rgb)
dtype = _working_dtype(rgb)
return _apply_along_channel(rgb, _rgb_to_lab_kernel, out_dtype=dtype, dtype=dtype)


def lab_ruderman_to_rgb(lab: xr.DataArray) -> xr.DataArray:
"""Inverse of :func:`rgb_to_lab_ruderman`.

Returns RGB clipped to ``[0, 255]`` in float dtype; uint8 conversion is
the caller's choice.
"""
_check_channel_dim(lab)
dtype = _working_dtype(lab)
return _apply_along_channel(lab, _lab_to_rgb_kernel, out_dtype=dtype, dtype=dtype)
95 changes: 95 additions & 0 deletions src/squidpy/experimental/im/_stain/_reference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""Slim container for a fitted stain reference.

Holds either a 3x3 stain matrix (Macenko/Vahadane, ships in PR 3) or a
pair of Ruderman Lab channel statistics (Reinhard, ships in PR 2). The
dataclass is intentionally minimal in this PR; cohort fields, persistence,
and provenance metadata land alongside their first consumers.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Literal

import numpy as np

StainMethod = Literal["macenko", "vahadane", "reinhard"]
_DECOMPOSITION_METHODS: frozenset[str] = frozenset({"macenko", "vahadane"})
_VALID_METHODS: frozenset[str] = _DECOMPOSITION_METHODS | {"reinhard"}


def _coerce_finite(arr: Any, *, shape: tuple[int, ...], name: str) -> np.ndarray:
out = np.asarray(arr, dtype=np.float64)
if out.shape != shape:
raise ValueError(f"{name} must have shape {shape}; got {out.shape}.")
if not np.all(np.isfinite(out)):
raise ValueError(f"{name} contains non-finite values.")
return out


@dataclass(frozen=True)
class StainReference:
"""Container for a fitted stain reference.

Parameters
----------
method
Fitting method: ``"macenko"``, ``"vahadane"``, or ``"reinhard"``.
stain_matrix
Shape ``(3, 3)`` unit-norm matrix in canonical order
``(H, E, complement)``. Required for decomposition methods.
mu
Shape ``(3,)`` Ruderman Lab channel means. Reinhard only.
sigma
Shape ``(3,)`` Ruderman Lab channel standard deviations. Reinhard
only.
background_intensity
Shape ``(3,)`` per-channel white-point estimate. Required for
decomposition methods (apply consumes it). Forbidden for Reinhard
because Reinhard's color transfer operates in Ruderman Lab and
does not model absorbance. There is no universal default; pass an
estimate from your data (PR 3 ships the estimator).
"""

method: StainMethod
stain_matrix: np.ndarray | None = None
mu: np.ndarray | None = None
sigma: np.ndarray | None = None
background_intensity: np.ndarray | None = None

def __post_init__(self) -> None:
if self.method not in _VALID_METHODS:
raise ValueError(f"Unknown method {self.method!r}; expected one of {sorted(_VALID_METHODS)}.")

if self.method in _DECOMPOSITION_METHODS:
if self.stain_matrix is None:
raise ValueError(f"method={self.method!r} requires stain_matrix.")
if self.mu is not None or self.sigma is not None:
raise ValueError(f"method={self.method!r} forbids mu/sigma; pass them only for Reinhard.")
if self.background_intensity is None:
raise ValueError(f"method={self.method!r} requires background_intensity.")
object.__setattr__(
self,
"stain_matrix",
_coerce_finite(self.stain_matrix, shape=(3, 3), name="stain_matrix"),
)
bg = _coerce_finite(self.background_intensity, shape=(3,), name="background_intensity")
if np.any(bg <= 0):
raise ValueError("background_intensity must be strictly positive.")
object.__setattr__(self, "background_intensity", bg)
else:
if self.mu is None or self.sigma is None:
raise ValueError("method='reinhard' requires both mu and sigma.")
if self.stain_matrix is not None:
raise ValueError("method='reinhard' forbids stain_matrix.")
if self.background_intensity is not None:
raise ValueError(
"method='reinhard' forbids background_intensity; Reinhard's color "
"transfer is in Ruderman Lab and does not use a white point."
)
mu = _coerce_finite(self.mu, shape=(3,), name="mu")
sigma = _coerce_finite(self.sigma, shape=(3,), name="sigma")
if np.any(sigma <= 0):
raise ValueError("sigma must be strictly positive.")
object.__setattr__(self, "mu", mu)
object.__setattr__(self, "sigma", sigma)
32 changes: 32 additions & 0 deletions tests/experimental/test_stain_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from __future__ import annotations

import numpy as np
import pytest

from squidpy.experimental.im._stain._constants import (
RUDERMAN_LAB_TO_LMS,
RUDERMAN_LMS_TO_LAB,
RUDERMAN_LMS_TO_RGB,
RUDERMAN_RGB_TO_LMS,
RUIFROK_HE,
SDA_SCALE,
)


@pytest.mark.parametrize("stain", ["hematoxylin", "eosin", "dab"])
def test_ruifrok_unit_norm(stain: str) -> None:
v = RUIFROK_HE[stain]
assert v.shape == (3,)
np.testing.assert_allclose(np.linalg.norm(v), 1.0, atol=1e-12)


def test_rgb_lms_round_trip() -> None:
np.testing.assert_allclose(RUDERMAN_RGB_TO_LMS @ RUDERMAN_LMS_TO_RGB, np.eye(3), atol=1e-10)


def test_lms_lab_round_trip() -> None:
np.testing.assert_allclose(RUDERMAN_LMS_TO_LAB @ RUDERMAN_LAB_TO_LMS, np.eye(3), atol=1e-10)


def test_sda_scale_positive() -> None:
assert SDA_SCALE > 0.0
Loading
Loading