Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 20 additions & 12 deletions src/cedalion/imagereco/forward_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,7 +743,8 @@ def _get_fluence_from_mcx(self, i_optode: int, **kwargs) -> np.ndarray:

fluence = result["flux"][:, :, :, 0] # there is only one time bin

return fluence
return fluence.astype(np.float32)


def _fluence_at_optodes(self, fluence, emitting_opt):
"""Fluence caused by one optode at the positions of all other optodes.
Expand All @@ -765,14 +766,14 @@ def _fluence_at_optodes(self, fluence, emitting_opt):
MAX_DISTANCE_IN_MM = 50
MAX_STEPS = int(np.ceil(MAX_DISTANCE_IN_MM / self.unitinmm))

result = np.zeros(n_optodes)
result = np.zeros(n_optodes, dtype=np.float32)
for i_opt in range(n_optodes):
for i_step in range(MAX_STEPS):
pos = self.optode_pos[i_opt] + i_step * self.optode_dir[i_opt]
i, j, k = np.floor(pos.values).astype(int)

if fluence[i, j, k] > 0:
result[i_opt] = fluence[i, j, k]
result[i_opt] = fluence[i, j, k].astype(np.float32)
break
else:
l_emit = self.optode_pos.label.values[emitting_opt]
Expand All @@ -781,7 +782,8 @@ def _fluence_at_optodes(self, fluence, emitting_opt):
f"fluence from {l_emit} to optode {l_rcv} "
f"is zero within {MAX_DISTANCE_IN_MM} mm."
)

assert result.dtype == np.float32
assert result.dtype == np.float32
return result

def compute_fluence_mcx(self, fluence_fname : str | Path, **kwargs):
Expand Down Expand Up @@ -845,6 +847,7 @@ def compute_fluence_mcx(self, fluence_fname : str | Path, **kwargs):
# run MCX or MCXCL
# shape: [i,j,k]
fluence = self._get_fluence_from_mcx(i_opt, **kwargs)
assert fluence.dtype == np.float32

# FIXME shortcut:
# currently tissue props are wavelength independent -> copy
Expand Down Expand Up @@ -984,11 +987,11 @@ def compute_fluence_nirfaster(self, fluence_fname : str | Path, meshingparam=Non
data.phi[:, :, :, i_opt], (1, 0, 2)
) # xyz to ijk

fluence_file.set_fluence_by_index(i_opt,i_wl, fluence)
fluence_file.set_fluence_by_index(i_opt,i_wl, fluence.astype(np.float32))

fluence_at_optodes[i_opt, :, i_wl] = amplitude_optode[:,i_opt]

fluence_file.set_fluence_at_optodes(fluence_at_optodes)
fluence_file.set_fluence_at_optodes(fluence_at_optodes.astype(np.float32))

def compute_sensitivity(
self,
Expand Down Expand Up @@ -1016,8 +1019,8 @@ def compute_sensitivity(

n_brain = self.head_model.brain.nvertices
n_scalp = self.head_model.scalp.nvertices
Adot_brain = np.zeros((n_channel, n_brain, n_wavelength))
Adot_scalp = np.zeros((n_channel, n_scalp, n_wavelength))
Adot_brain = np.zeros((n_channel, n_brain, n_wavelength), dtype=np.float32)
Adot_scalp = np.zeros((n_channel, n_scalp, n_wavelength), dtype=np.float32)

# fluence_all: (label, wavelength, i, j, k)
# fluence_at_optodes: (optode1, optode2, wavelength)
Expand Down Expand Up @@ -1045,16 +1048,17 @@ def compute_sensitivity(

Adot_brain[i_ch, :, i_wl] = (
pertubation @ self.head_model.voxel_to_vertex_brain / normfactor
)
).astype(np.float32)
Adot_scalp[i_ch, :, i_wl] = (
pertubation @ self.head_model.voxel_to_vertex_scalp / normfactor
)
).astype(np.float32)

is_brain = np.zeros((n_brain + n_scalp), dtype=bool)
is_brain[:n_brain] = True

# shape [nchannel, nvertices, nwavelength]
Adot = np.concatenate([Adot_brain, Adot_scalp], axis=1)
assert Adot.dtype == np.float32

# Adot calculated from fluence has units 1/mm^2. Multiplied with
# the voxel volume (mm^3) and the change in absorption coefficient (1/mm)
Expand Down Expand Up @@ -1089,6 +1093,8 @@ def compute_sensitivity(
)
)
Adot = Adot.assign_coords(parcel = ("vertex", parcels))

assert Adot.values.dtype == np.float32

save_Adot(sensitivity_fname, Adot)

Expand Down Expand Up @@ -1119,13 +1125,14 @@ def compute_stacked_sensitivity(sensitivity: xr.DataArray):
ec = cedalion.nirs.get_extinction_coefficients("prahl", wavelengths)

units_ec = ec.pint.units
ec = ec.pint.dequantify()
ec = ec.pint.dequantify().astype(np.float32)

units_A = units_sens * units_ec

nchannel = sensitivity.sizes["channel"]
nvertices = sensitivity.sizes["vertex"]
A = np.zeros((2 * nchannel, 2 * nvertices))
sensitivity = sensitivity.astype(np.float32)
A = np.zeros((2 * nchannel, 2 * nvertices), dtype=np.float32)

wl1, wl2 = wavelengths
# fmt: off
Expand All @@ -1134,6 +1141,7 @@ def compute_stacked_sensitivity(sensitivity: xr.DataArray):
A[nchannel:, :nvertices] = ec.sel(chromo="HbO", wavelength=wl2).values * sensitivity.sel(wavelength=wl2) # noqa: E501
A[nchannel:, nvertices:] = ec.sel(chromo="HbR", wavelength=wl2).values * sensitivity.sel(wavelength=wl2) # noqa: E501
# fmt: on
assert A.dtype == np.float32

is_brain = np.hstack([sensitivity.is_brain, sensitivity.is_brain])
flat_chromo = ["HbO"] * nvertices + ["HbR"] * nvertices
Expand Down
Loading