From e3fdbb31c0c1d6a5bd8786f44d7e0b3c0335180d Mon Sep 17 00:00:00 2001 From: Camille Pescatore Date: Thu, 7 Aug 2025 13:57:55 +0200 Subject: [PATCH] Implemented stacked_sensitivity for adaptive number of wavelengths --- src/cedalion/imagereco/forward_model.py | 33 ++++++++++++++----------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/src/cedalion/imagereco/forward_model.py b/src/cedalion/imagereco/forward_model.py index 7893e60..94a07fa 100644 --- a/src/cedalion/imagereco/forward_model.py +++ b/src/cedalion/imagereco/forward_model.py @@ -1094,9 +1094,9 @@ def compute_sensitivity( # FIXME: better name for Adot * ext. coeffs - # FIXME: hardcoded for 2 chromophores (HbO and HbR) and wavelengths + # FIXME: hardcoded for 2 chromophores (HbO and HbR) @staticmethod - def compute_stacked_sensitivity(sensitivity: xr.DataArray): + def compute_stacked_sensitivity(sensitivity: xr.DataArray, spectrum: str = "prahl"): """Compute stacked HbO and HbR sensitivity matrices from fluence. Args: @@ -1109,14 +1109,13 @@ def compute_stacked_sensitivity(sensitivity: xr.DataArray): assert "wavelength" in sensitivity.dims wavelengths = sensitivity.wavelength.values - assert len(wavelengths) == 2 if "units" in sensitivity.attrs: units_sens = pint.Unit(sensitivity.attrs["units"]) else: units_sens = pint.Unit("mm") - ec = cedalion.nirs.get_extinction_coefficients("prahl", wavelengths) + ec = cedalion.nirs.get_extinction_coefficients(spectrum, wavelengths) units_ec = ec.pint.units ec = ec.pint.dequantify() @@ -1125,25 +1124,31 @@ def compute_stacked_sensitivity(sensitivity: xr.DataArray): nchannel = sensitivity.sizes["channel"] nvertices = sensitivity.sizes["vertex"] - A = np.zeros((2 * nchannel, 2 * nvertices)) + nwavelengths = len(wavelengths) + A = np.zeros((nwavelengths * nchannel, 2 * nvertices)) - wl1, wl2 = wavelengths # fmt: off - A[:nchannel, :nvertices] = ec.sel(chromo="HbO", wavelength=wl1).values * sensitivity.sel(wavelength=wl1) # noqa: E501 - A[:nchannel, nvertices:] = ec.sel(chromo="HbR", wavelength=wl1).values * sensitivity.sel(wavelength=wl1) # noqa: E501 - A[nchannel:, :nvertices] = ec.sel(chromo="HbO", wavelength=wl2).values * sensitivity.sel(wavelength=wl2) # noqa: E501 - A[nchannel:, nvertices:] = ec.sel(chromo="HbR", wavelength=wl2).values * sensitivity.sel(wavelength=wl2) # noqa: E501 + for i, wl in enumerate(wavelengths): + for i, wl in enumerate(wavelengths): + A[i * nchannel : (i + 1) * nchannel, :nvertices] = ( + ec.sel(chromo="HbO", wavelength=wl).values + * sensitivity.sel(wavelength=wl) + ) + A[i * nchannel : (i + 1) * nchannel, nvertices:] = ( + ec.sel(chromo="HbR", wavelength=wl).values + * sensitivity.sel(wavelength=wl) + ) # fmt: on is_brain = np.hstack([sensitivity.is_brain, sensitivity.is_brain]) flat_chromo = ["HbO"] * nvertices + ["HbR"] * nvertices - flat_wavelength = [wl1] * nchannel + [wl2] * nchannel + flat_wavelength = [wl for wl in wavelengths for _ in range(nchannel)] channel = sensitivity.channel.values source = sensitivity.source.values detector = sensitivity.detector.values - flat_channel = np.hstack((channel, channel)) - flat_source = np.hstack((source, source)) - flat_detector = np.hstack((detector, detector)) + flat_channel = np.hstack([channel] * nwavelengths) + flat_source = np.hstack([source] * nwavelengths) + flat_detector = np.hstack([detector] * nwavelengths) vertex = np.hstack([np.arange(nvertices), np.arange(nvertices)]) coords = {