diff --git a/src/cedalion/imagereco/forward_model.py b/src/cedalion/imagereco/forward_model.py index 0d926b0d..a8aee6df 100644 --- a/src/cedalion/imagereco/forward_model.py +++ b/src/cedalion/imagereco/forward_model.py @@ -627,6 +627,7 @@ def __init__( assert head_model.crs == geo3d.points.crs self.head_model = head_model + self.measurement_list = measurement_list self.optode_pos = geo3d[ geo3d.type.isin([cdc.PointType.SOURCE, cdc.PointType.DETECTOR]) @@ -645,9 +646,10 @@ def __init__( self.optode_pos = self.optode_pos.pint.dequantify() self.optode_dir = self.optode_dir.pint.dequantify() - + self.tissue_properties = get_tissue_properties( - self.head_model.segmentation_masks + self.head_model.segmentation_masks, + self.measurement_list.wavelength.unique() ) self.volume = self.head_model.segmentation_masks.sum("segmentation_type") diff --git a/src/cedalion/imagereco/tissue_properties.py b/src/cedalion/imagereco/tissue_properties.py index 042571c1..59f6b84c 100644 --- a/src/cedalion/imagereco/tissue_properties.py +++ b/src/cedalion/imagereco/tissue_properties.py @@ -82,26 +82,27 @@ class TissueType(Enum): # FIXME allow for wavelength dependencies -def get_tissue_properties(segmentation_masks: xr.DataArray) -> np.ndarray: +def get_tissue_properties(segmentation_masks: xr.DataArray, wavelengths: list) -> np.ndarray: """Return tissue properties for the given segmentation mask.""" ntissues = segmentation_masks.sizes["segmentation_type"] + 1 - tissue_props = np.zeros((ntissues, 4)) - tissue_props[0, :] = [0.0, 0.0, 1.0, 1.0] # background - - for st in segmentation_masks.segmentation_type.values: - m = segmentation_masks.sel(segmentation_type=st).values - int_labels = np.unique(m[m > 0]) - if len(int_labels) == 0: - warn("Segmentation type %s is empty." % st) - continue - int_label = int_labels.item() - - if (tissue_type := TISSUE_LABELS.get(st, None)) is None: - raise ValueError(f"unknown tissue type '{st}'") - - tissue_props[int_label, 0] = TISSUE_PROPS_ABSORPTION[tissue_type] - tissue_props[int_label, 1] = TISSUE_PROPS_SCATTERING[tissue_type] - tissue_props[int_label, 2] = TISSUE_PROPS_ANISOTROPY[tissue_type] - tissue_props[int_label, 3] = TISSUE_PROPS_REFRACTION[tissue_type] + n_wavelength = len(wavelengths) + tissue_props = np.zeros((ntissues, 4, n_wavelength)) #FIXME add dimension for multiple wavelengths + + for i_wl in range(n_wavelength): + tissue_props[0, :,i_wl] = [0.0, 0.0, 1.0, 1.0] # background + + for st in segmentation_masks.segmentation_type.values: + m = segmentation_masks.sel(segmentation_type=st).values + int_label = np.unique(m[m > 0]).item() + + if (tissue_type := TISSUE_LABELS.get(st, None)) is None: + raise ValueError(f"unknown tissue type '{st}'") + + #FIXME made it so the same properties were assigned to each wavelength + tissue_props[int_label, 0, i_wl] = TISSUE_PROPS_ABSORPTION[tissue_type] + tissue_props[int_label, 1, i_wl] = TISSUE_PROPS_SCATTERING[tissue_type] + tissue_props[int_label, 2, i_wl] = TISSUE_PROPS_ANISOTROPY[tissue_type] + tissue_props[int_label, 3, i_wl] = TISSUE_PROPS_REFRACTION[tissue_type] return tissue_props +