diff --git a/.mailmap b/.mailmap index b7d87833dcc..662d097f945 100644 --- a/.mailmap +++ b/.mailmap @@ -357,6 +357,7 @@ Yousra Bekhti Yoursa BEKHTI Yoursa BEKHTI Yousra Bekhti Yousra BEKHTI Yousra Bekhti yousrabk +Zahra M. Aghajan Zahra M. Aghajan Zhi Zhang <850734033@qq.com> ZHANG Zhi <850734033@qq.com> Zhi Zhang <850734033@qq.com> ZHANG Zhi Ziyi ZENG ZIYI ZENG diff --git a/doc/changes/devel/11064.newfeature.rst b/doc/changes/devel/11064.newfeature.rst new file mode 100644 index 00000000000..24f3819a85d --- /dev/null +++ b/doc/changes/devel/11064.newfeature.rst @@ -0,0 +1 @@ +Added basic support for TD fNIRS data, by :newcontrib:`Zahra Aghajan`, :newcontrib:`Julien Dubois`, :newcontrib:`John Griffiths`, `Robert Luke`_, and `Eric Larson`_. \ No newline at end of file diff --git a/doc/changes/names.inc b/doc/changes/names.inc index 0d5ee6a5c73..206ca2fec8a 100644 --- a/doc/changes/names.inc +++ b/doc/changes/names.inc @@ -140,6 +140,7 @@ .. _Joan Massich: https://github.com/massich .. _Johann Benerradi: https://github.com/HanBnrd .. _Johannes Niediek: https://github.com/jniediek +.. _John Griffiths: https://www.grifflab.com .. _John Samuelsson: https://github.com/johnsam7 .. _John Veillette: https://psychology.uchicago.edu/directory/john-veillette .. _Jon Houck: https://www.mrn.org/people/jon-m.-houck/principal-investigators @@ -153,6 +154,7 @@ .. _Judy D Zhu: https://github.com/JD-Zhu .. _Juergen Dammers: https://github.com/jdammers .. _Jukka Nenonen: https://www.linkedin.com/pub/jukka-nenonen/28/b5a/684 +.. _Julien Dubois: https://github.com/julien-dubois-k .. _Jussi Nurminen: https://github.com/jjnurminen .. _Kaisu Lankinen: http://bishoplab.berkeley.edu/Kaisu.html .. _Katarina Slama: https://github.com/katarinaslama @@ -324,6 +326,7 @@ .. _Yiping Zuo: https://github.com/frostime .. _Yousra Bekhti: https://www.linkedin.com/pub/yousra-bekhti/56/886/421 .. _Yu-Han Luo: https://github.com/yh-luo +.. _Zahra Aghajan: https://github.com/Zahra-M-Aghajan .. _Zhi Zhang: https://github.com/tczhangzhi/ .. _Ziyi ZENG: https://github.com/ZiyiTsang .. _Zvi Baratz: https://github.com/ZviBaratz diff --git a/environment.yml b/environment.yml index 78c773e56bf..4284e02591d 100644 --- a/environment.yml +++ b/environment.yml @@ -23,6 +23,7 @@ dependencies: - joblib - jupyter - lazy_loader >=0.3 + - libxml2 !=2.14.0 - mamba - matplotlib >=3.7 - mffpy >=0.5.7 diff --git a/mne/_fiff/constants.py b/mne/_fiff/constants.py index cf604db530c..fc7212a779f 100644 --- a/mne/_fiff/constants.py +++ b/mne/_fiff/constants.py @@ -1044,7 +1044,9 @@ FIFF.FIFFV_COIL_FNIRS_FD_PHASE = 305 # fNIRS frequency domain phase FIFF.FIFFV_COIL_FNIRS_RAW = FIFF.FIFFV_COIL_FNIRS_CW_AMPLITUDE # old alias FIFF.FIFFV_COIL_FNIRS_TD_GATED_AMPLITUDE = 306 # fNIRS time-domain gated amplitude -FIFF.FIFFV_COIL_FNIRS_TD_MOMENTS_AMPLITUDE = 307 # fNIRS time-domain moments amplitude +FIFF.FIFFV_COIL_FNIRS_TD_MOMENTS_INTENSITY = 307 # fNIRS time-domain moments intensity +FIFF.FIFFV_COIL_FNIRS_TD_MOMENTS_MEAN = 308 # fNIRS time-domain moments mean +FIFF.FIFFV_COIL_FNIRS_TD_MOMENTS_VARIANCE = 309 # fNIRS time-domain moments variance FIFF.FIFFV_COIL_EYETRACK_POS = 400 # Eye-tracking gaze position FIFF.FIFFV_COIL_EYETRACK_PUPIL = 401 # Eye-tracking pupil size @@ -1145,7 +1147,9 @@ FIFF.FIFFV_COIL_FNIRS_FD_AC_AMPLITUDE, FIFF.FIFFV_COIL_FNIRS_FD_PHASE, FIFF.FIFFV_COIL_FNIRS_TD_GATED_AMPLITUDE, - FIFF.FIFFV_COIL_FNIRS_TD_MOMENTS_AMPLITUDE, + FIFF.FIFFV_COIL_FNIRS_TD_MOMENTS_INTENSITY, + FIFF.FIFFV_COIL_FNIRS_TD_MOMENTS_MEAN, + FIFF.FIFFV_COIL_FNIRS_TD_MOMENTS_VARIANCE, FIFF.FIFFV_COIL_MCG_42, FIFF.FIFFV_COIL_EYETRACK_POS, FIFF.FIFFV_COIL_EYETRACK_PUPIL, diff --git a/mne/_fiff/pick.py b/mne/_fiff/pick.py index ec3479f8a26..a00adac5a95 100644 --- a/mne/_fiff/pick.py +++ b/mne/_fiff/pick.py @@ -104,6 +104,26 @@ def get_channel_type_constants(include_defaults=False): unit=FIFF.FIFF_UNIT_RAD, coil_type=FIFF.FIFFV_COIL_FNIRS_FD_PHASE, ), + fnirs_td_gated_amplitude=dict( + kind=FIFF.FIFFV_FNIRS_CH, + unit=FIFF.FIFF_UNIT_V, + coil_type=FIFF.FIFFV_COIL_FNIRS_TD_GATED_AMPLITUDE, + ), + fnirs_td_moments_intensity=dict( + kind=FIFF.FIFFV_FNIRS_CH, + unit=FIFF.FIFF_UNIT_UNITLESS, + coil_type=FIFF.FIFFV_COIL_FNIRS_TD_MOMENTS_INTENSITY, + ), + fnirs_td_moments_mean=dict( + kind=FIFF.FIFFV_FNIRS_CH, + unit=FIFF.FIFF_UNIT_S, + coil_type=FIFF.FIFFV_COIL_FNIRS_TD_MOMENTS_MEAN, + ), + fnirs_td_moments_variance=dict( + kind=FIFF.FIFFV_FNIRS_CH, + unit=FIFF.FIFF_UNIT_NONE, # TODO: Maybe someday add s^2 + coil_type=FIFF.FIFFV_COIL_FNIRS_TD_MOMENTS_VARIANCE, + ), fnirs_od=dict(kind=FIFF.FIFFV_FNIRS_CH, coil_type=FIFF.FIFFV_COIL_FNIRS_OD), hbo=dict( kind=FIFF.FIFFV_FNIRS_CH, @@ -197,6 +217,10 @@ def get_channel_type_constants(include_defaults=False): FIFF.FIFFV_COIL_FNIRS_FD_AC_AMPLITUDE: "fnirs_fd_ac_amplitude", FIFF.FIFFV_COIL_FNIRS_FD_PHASE: "fnirs_fd_phase", FIFF.FIFFV_COIL_FNIRS_OD: "fnirs_od", + FIFF.FIFFV_COIL_FNIRS_TD_GATED_AMPLITUDE: "fnirs_td_gated_amplitude", + FIFF.FIFFV_COIL_FNIRS_TD_MOMENTS_INTENSITY: "fnirs_td_moments_intensity", + FIFF.FIFFV_COIL_FNIRS_TD_MOMENTS_MEAN: "fnirs_td_moments_mean", + FIFF.FIFFV_COIL_FNIRS_TD_MOMENTS_VARIANCE: "fnirs_td_moments_variance", }, ), "eeg": ( @@ -385,6 +409,26 @@ def _triage_fnirs_pick(ch, fnirs, warned): return True elif ch["coil_type"] == FIFF.FIFFV_COIL_FNIRS_OD and "fnirs_od" in fnirs: return True + elif ( + ch["coil_type"] == FIFF.FIFFV_COIL_FNIRS_TD_GATED_AMPLITUDE + and "fnirs_td_gated_amplitude" in fnirs + ): + return True + elif ( + ch["coil_type"] == FIFF.FIFFV_COIL_FNIRS_TD_MOMENTS_INTENSITY + and "fnirs_td_moments_intensity" in fnirs + ): + return True + elif ( + ch["coil_type"] == FIFF.FIFFV_COIL_FNIRS_TD_MOMENTS_MEAN + and "fnirs_td_moments_mean" in fnirs + ): + return True + elif ( + ch["coil_type"] == FIFF.FIFFV_COIL_FNIRS_TD_MOMENTS_VARIANCE + and "fnirs_td_moments_variance" in fnirs + ): + return True return False @@ -569,7 +613,7 @@ def pick_types( pick[k] = _triage_meg_pick(info["chs"][k], ref_meg) elif ch_type in ("eyegaze", "pupil"): pick[k] = _triage_eyetrack_pick(info["chs"][k], eyetrack) - else: # ch_type in ('hbo', 'hbr') + else: # ch_type in ('hbo', 'hbr', ...) pick[k] = _triage_fnirs_pick(info["chs"][k], fnirs, warned) # restrict channels to selection if provided @@ -862,6 +906,10 @@ def channel_indices_by_type(info, picks=None): fnirs_fd_ac_amplitude=list(), fnirs_fd_phase=list(), fnirs_od=list(), + fnirs_td_gated_amplitude=list(), + fnirs_td_moments_intensity=list(), + fnirs_td_moments_mean=list(), + fnirs_td_moments_variance=list(), eyegaze=list(), pupil=list(), ) @@ -1099,6 +1147,10 @@ def _check_excludes_includes(chs, info=None, allow_bads=False): "fnirs_fd_ac_amplitude", "fnirs_fd_phase", "fnirs_od", + "fnirs_td_gated_amplitude", + "fnirs_td_moments_intensity", + "fnirs_td_moments_mean", + "fnirs_td_moments_variance", ) _EYETRACK_CH_TYPES_SPLIT = ("eyegaze", "pupil") _DATA_CH_TYPES_ORDER_DEFAULT = ( diff --git a/mne/_fiff/tests/test_constants.py b/mne/_fiff/tests/test_constants.py index c857375c46e..2bf742808c0 100644 --- a/mne/_fiff/tests/test_constants.py +++ b/mne/_fiff/tests/test_constants.py @@ -27,8 +27,8 @@ from mne.utils import requires_good_network # https://github.com/mne-tools/fiff-constants/commits/master -REPO = "mne-tools" -COMMIT = "e27f68cbf74dbfc5193ad429cc77900a59475181" +REPO = "larsoner" # TODO: Replace with upstream once merged +COMMIT = "ba2288355b61b00d65d4f1d8a47ef82b83414201" # These are oddities that we won't address: iod_dups = (355, 359) # these are in both MEGIN and MNE files @@ -91,7 +91,9 @@ 304, # fNIRS frequency domain AC amplitude 305, # fNIRS frequency domain phase 306, # fNIRS time domain gated amplitude - 307, # fNIRS time domain moments amplitude + 307, # fNIRS time domain moments intensity + 308, # fNIRS time domain moments mean + 309, # fNIRS time domain moments variance 400, # Eye-tracking gaze position 401, # Eye-tracking pupil size 1000, # For testing the MCG software diff --git a/mne/cov.py b/mne/cov.py index 694c836d0cd..c49a5fdd051 100644 --- a/mne/cov.py +++ b/mne/cov.py @@ -1517,6 +1517,7 @@ def __init__( grad=0.1, mag=0.1, eeg=0.1, + *, seeg=0.1, ecog=0.1, hbo=0.1, @@ -1525,6 +1526,10 @@ def __init__( fnirs_fd_ac_amplitude=0.1, fnirs_fd_phase=0.1, fnirs_od=0.1, + fnirs_td_gated_amplitude=0.1, + fnirs_td_moments_intensity=0.1, + fnirs_td_moments_mean=0.1, + fnirs_td_moments_variance=0.1, csd=0.1, dbs=0.1, store_precision=False, @@ -1545,6 +1550,10 @@ def __init__( self.fnirs_fd_ac_amplitude = fnirs_fd_ac_amplitude self.fnirs_fd_phase = fnirs_fd_phase self.fnirs_od = fnirs_od + self.fnirs_td_gated_amplitude = fnirs_td_gated_amplitude + self.fnirs_td_moments_intensity = fnirs_td_moments_intensity + self.fnirs_td_moments_mean = fnirs_td_moments_mean + self.fnirs_td_moments_variance = fnirs_td_moments_variance self.csd = csd self.store_precision = store_precision self.assume_centered = assume_centered @@ -1577,6 +1586,15 @@ def fit(self, X): dbs=self.dbs, hbo=self.hbo, hbr=self.hbr, + fnirs_cw_amplitude=self.fnirs_cw_amplitude, + fnirs_fd_ac_amplitude=self.fnirs_fd_ac_amplitude, + fnirs_fd_phase=self.fnirs_fd_phase, + fnirs_od=self.fnirs_od, + fnirs_td_gated_amplitude=self.fnirs_td_gated_amplitude, + fnirs_td_moments_intensity=self.fnirs_td_moments_intensity, + fnirs_td_moments_mean=self.fnirs_td_moments_mean, + fnirs_td_moments_variance=self.fnirs_td_moments_variance, + csd=self.csd, rank="full", ) self.estimator_.covariance_ = self.covariance_ = cov_.data @@ -1904,6 +1922,7 @@ def regularize( eeg=0.1, exclude="bads", proj=True, + *, seeg=0.1, ecog=0.1, hbo=0.1, @@ -1912,6 +1931,10 @@ def regularize( fnirs_fd_ac_amplitude=0.1, fnirs_fd_phase=0.1, fnirs_od=0.1, + fnirs_td_gated_amplitude=0.1, + fnirs_td_moments_intensity=0.1, + fnirs_td_moments_mean=0.1, + fnirs_td_moments_variance=0.1, csd=0.1, dbs=0.1, rank=None, @@ -1963,6 +1986,14 @@ def regularize( Regularization factor for fNIRS raw phase signals. fnirs_od : float (default 0.1) Regularization factor for fNIRS optical density signals. + fnirs_td_gated_amplitude : float (default 0.1) + Regularization factor for fNIRS time domain gated amplitude signals. + fnirs_td_moments_intensity : float (default 0.1) + Regularization factor for fNIRS time domain moments amplitude signals. + fnirs_td_moments_mean : float (default 0.1) + Regularization factor for fNIRS time domain moments mean signals. + fnirs_td_moments_variance : float (default 0.1) + Regularization factor for fNIRS time domain moments variance signals. csd : float (default 0.1) Regularization factor for EEG-CSD signals. dbs : float (default 0.1) @@ -2003,6 +2034,10 @@ def regularize( fnirs_fd_ac_amplitude=fnirs_fd_ac_amplitude, fnirs_fd_phase=fnirs_fd_phase, fnirs_od=fnirs_od, + fnirs_td_gated_amplitude=fnirs_td_gated_amplitude, + fnirs_td_moments_intensity=fnirs_td_moments_intensity, + fnirs_td_moments_mean=fnirs_td_moments_mean, + fnirs_td_moments_variance=fnirs_td_moments_variance, csd=csd, ) diff --git a/mne/datasets/config.py b/mne/datasets/config.py index 75eff184cd1..16df4e9ff1f 100644 --- a/mne/datasets/config.py +++ b/mne/datasets/config.py @@ -87,7 +87,7 @@ # update the checksum in the MNE_DATASETS dict below, and change version # here: ↓↓↓↓↓↓↓↓ RELEASES = dict( - testing="0.156", + testing="0.162", misc="0.27", phantom_kit="0.2", ucl_opm_auditory="0.2", @@ -115,7 +115,7 @@ # Testing and misc are at the top as they're updated most often MNE_DATASETS["testing"] = dict( archive_name=f"{TESTING_VERSIONED}.tar.gz", - hash="md5:d94fe9f3abe949a507eaeb865fb84a3f", + hash="md5:34d4f174adbb211ba58a584b6c1d348c", url=( "https://codeload.github.com/mne-tools/mne-testing-data/" f"tar.gz/{RELEASES['testing']}" diff --git a/mne/defaults.py b/mne/defaults.py index d5aab1a8d38..bcd183d1ae6 100644 --- a/mne/defaults.py +++ b/mne/defaults.py @@ -32,6 +32,10 @@ fnirs_fd_ac_amplitude="k", fnirs_fd_phase="k", fnirs_od="k", + fnirs_td_gated_amplitude="k", + fnirs_td_moments_intensity="k", + fnirs_td_moments_mean="k", + fnirs_td_moments_variance="k", csd="k", whitened="k", gsr="#666633", @@ -60,6 +64,10 @@ fnirs_fd_ac_amplitude="V", fnirs_fd_phase="rad", fnirs_od="V", + fnirs_td_gated_amplitude="AU", # counts + fnirs_td_moments_intensity="AU", # counts + fnirs_td_moments_mean="S", + fnirs_td_moments_variance="S²", csd="V/m²", whitened="Z", gsr="S", @@ -88,6 +96,10 @@ fnirs_fd_ac_amplitude="V", fnirs_fd_phase="rad", fnirs_od="V", + fnirs_td_gated_amplitude="AU", + fnirs_td_moments_intensity="AU", + fnirs_td_moments_mean="S", + fnirs_td_moments_variance="S²", csd="mV/m²", whitened="Z", gsr="S", @@ -117,6 +129,10 @@ fnirs_fd_ac_amplitude=1.0, fnirs_fd_phase=1.0, fnirs_od=1.0, + fnirs_td_gated_amplitude=1.0, + fnirs_td_moments_intensity=1.0, + fnirs_td_moments_mean=1.0, + fnirs_td_moments_variance=1.0, csd=1e3, whitened=1.0, gsr=1.0, @@ -151,6 +167,10 @@ fnirs_fd_ac_amplitude=2e-2, fnirs_fd_phase=2e-1, fnirs_od=2e-2, + fnirs_td_gated_amplitude=1.0, + fnirs_td_moments_intensity=1.0, + fnirs_td_moments_mean=1.0, + fnirs_td_moments_variance=1.0, csd=200e-4, dipole=1e-7, gof=1e2, @@ -206,6 +226,10 @@ fnirs_fd_phase="fNIRS (FD phase)", fnirs_od="fNIRS (OD)", hbr="Deoxyhemoglobin", + fnirs_td_gated_amplitude="fNIRS (TD amplitude)", + fnirs_td_moments_intensity="fNIRS (TD moment intensity)", + fnirs_td_moments_mean="fNIRS (TD moment mean)", + fnirs_td_moments_variance="fNIRS (TD moment variance)", gof="Goodness of fit", csd="Current source density", stim="Stimulus", diff --git a/mne/io/snirf/_snirf.py b/mne/io/snirf/_snirf.py index c07790b5845..adba61462de 100644 --- a/mne/io/snirf/_snirf.py +++ b/mne/io/snirf/_snirf.py @@ -14,10 +14,58 @@ from ..._freesurfer import get_mni_fiducials from ...annotations import Annotations from ...transforms import _frame_to_str, apply_trans -from ...utils import _check_fname, _import_h5py, fill_doc, logger, verbose, warn +from ...utils import ( + NamedInt, + _check_fname, + _check_option, + _import_h5py, + fill_doc, + logger, + verbose, + warn, +) from ..base import BaseRaw from ..nirx.nirx import _convert_fnirs_to_head +SNIRF_CW_AMPLITUDE = NamedInt("SNIRF_CW_AMPLITUDE", 1) +SNIRF_TD_GATED_AMPLITUDE = NamedInt("SNIRF_TD_GATED_AMPLITUDE", 201) +SNIRF_TD_MOMENTS_AMPLITUDE = NamedInt("SNIRF_TD_MOMENTS_AMPLITUDE", 301) +SNIRF_PROCESSED = NamedInt("SNIRF_PROCESSED", 99999) +_AVAILABLE_SNIRF_DATA_TYPES = ( + SNIRF_CW_AMPLITUDE, + SNIRF_TD_GATED_AMPLITUDE, + SNIRF_TD_MOMENTS_AMPLITUDE, + SNIRF_PROCESSED, +) + + +# SNIRF: Supported measurementList(k).dataTypeLabel values in dataTimeSeries +FNIRS_SNIRF_DATATYPELABELS = { + # These types are specified here: + # https://github.com/fNIRS/snirf/blob/master/snirf_specification.md#supported-measurementlistkdatatypelabel-values-in-datatimeseries # noqa: E501 + "HbO": 1, # Oxygenated hemoglobin (oxyhemoglobin) concentration + "HbR": 2, # Deoxygenated hemoglobin (deoxyhemoglobin) concentration + "HbT": 3, # Total hemoglobin concentration + "dOD": 4, # Change in optical density + "mua": 5, # Absorption coefficient + "musp": 6, # Scattering coefficient + "H2O": 7, # Water content + "Lipid": 8, # Lipid concentration + "BFi": 9, # Blood flow index + "HRF dOD": 10, # HRF for change in optical density + "HRF HbO": 11, # HRF for oxyhemoglobin concentration + "HRF HbR": 12, # HRF for deoxyhemoglobin concentration + "HRF HbT": 13, # HRF for total hemoglobin concentration + "HRF BFi": 14, # HRF for blood flow index +} + +# In each file, the TD moment order maps to these values +_TD_MOMENT_ORDER_MAP = { + 0: "intensity", + 1: "mean", + 2: "variance", +} + @fill_doc def read_raw_snirf( @@ -104,19 +152,14 @@ def __init__(self, fname, optode_frame="unknown", preload=False, verbose=None): if (optode_frame == "unknown") & (manufacturer == "Gowerlabs"): optode_frame = "head" - snirf_data_type = np.array( - dat.get("nirs/data1/measurementList1/dataType") - ).item() - if snirf_data_type not in [1, 99999]: - # 1 = Continuous Wave - # 99999 = Processed - raise RuntimeError( - "MNE only supports reading continuous" - " wave amplitude and processed haemoglobin" - " SNIRF files. Expected type" - " code 1 or 99999 but received type " - f"code {snirf_data_type}" - ) + snirf_data_type = _correct_shape( + np.array(dat.get("nirs/data1/measurementList1/dataType")) + )[0] + _check_option( + "SNIRF data type", + snirf_data_type, + list(_AVAILABLE_SNIRF_DATA_TYPES), + ) last_samps = dat.get("/nirs/data1/dataTimeSeries").shape[0] - 1 @@ -138,6 +181,15 @@ def __init__(self, fname, optode_frame="unknown", preload=False, verbose=None): "with two wavelengths." ) + # Get data type specific probe information + if snirf_data_type == SNIRF_TD_GATED_AMPLITUDE: + fnirs_time_delays = np.array(dat.get("nirs/probe/timeDelays"), float) + fnirs_time_delay_widths = np.array( + dat.get("nirs/probe/timeDelayWidths"), float + ) + elif snirf_data_type == SNIRF_TD_MOMENTS_AMPLITUDE: + fnirs_moment_orders = np.array(dat.get("nirs/probe/momentOrders"), int) + # Extract channels def atoi(text): return int(text) if text.isdigit() else text @@ -162,7 +214,7 @@ def natural_keys(text): sources = np.unique( [ _correct_shape( - np.array(dat.get("nirs/data1/" + c + "/sourceIndex")) + np.array(dat.get(f"nirs/data1/{c}/sourceIndex")) )[0] for c in channels ] @@ -179,7 +231,7 @@ def natural_keys(text): detectors = np.unique( [ _correct_shape( - np.array(dat.get("nirs/data1/" + c + "/detectorIndex")) + np.array(dat.get(f"nirs/data1/{c}/detectorIndex")) )[0] for c in channels ] @@ -225,63 +277,101 @@ def natural_keys(text): chnames = [] ch_types = [] + need_data_scale = False for chan in channels: + ch_root = f"nirs/data1/{chan}" src_idx = int( - _correct_shape( - np.array(dat.get("nirs/data1/" + chan + "/sourceIndex")) - )[0] + _correct_shape(np.array(dat.get(f"{ch_root}/sourceIndex")))[0] ) det_idx = int( - _correct_shape( - np.array(dat.get("nirs/data1/" + chan + "/detectorIndex")) - )[0] + _correct_shape(np.array(dat.get(f"{ch_root}/detectorIndex")))[0] ) + ch_name = f"{sources[src_idx]}_{detectors[det_idx]}" - if snirf_data_type == 1: + if snirf_data_type in ( + SNIRF_CW_AMPLITUDE, + SNIRF_TD_GATED_AMPLITUDE, + SNIRF_TD_MOMENTS_AMPLITUDE, + ): wve_idx = int( - _correct_shape( - np.array(dat.get("nirs/data1/" + chan + "/wavelengthIndex")) - )[0] - ) - ch_name = ( - sources[src_idx] - + "_" - + detectors[det_idx] - + " " - + str(fnirs_wavelengths[wve_idx - 1]) + _correct_shape(np.array(dat.get(f"{ch_root}/wavelengthIndex")))[ + 0 + ] ) - chnames.append(ch_name) - ch_types.append("fnirs_cw_amplitude") + # append wavelength + ch_name = f"{ch_name} {fnirs_wavelengths[wve_idx - 1]}" + if snirf_data_type == SNIRF_CW_AMPLITUDE: + ch_type = "fnirs_cw_amplitude" + elif snirf_data_type == SNIRF_TD_GATED_AMPLITUDE: + bin_idx = int( + _correct_shape( + np.array(dat.get(f"{ch_root}/dataTypeIndex")) + )[0] + ) + # append time delay + ch_name = f"{ch_name} bin{fnirs_time_delays[bin_idx - 1]}" + ch_type = "fnirs_td_gated_amplitude" + need_data_scale = True + else: + assert snirf_data_type == SNIRF_TD_MOMENTS_AMPLITUDE + moment_idx = int( + _correct_shape( + np.array(dat.get(f"{ch_root}/dataTypeIndex")) + )[0] + ) + # append moment order + order = fnirs_moment_orders[moment_idx - 1] + _check_option( + f"SNIRF channel {chan} moment order", + order, + _TD_MOMENT_ORDER_MAP, + ) + ch_name = f"{ch_name} moment{order}" + ch_type = f"fnirs_td_moments_{_TD_MOMENT_ORDER_MAP[order]}" - elif snirf_data_type == 99999: + elif snirf_data_type == SNIRF_PROCESSED: dt_id = _correct_shape( - np.array(dat.get("nirs/data1/" + chan + "/dataTypeLabel")) + np.array(dat.get(f"{ch_root}/dataTypeLabel")) )[0].decode("UTF-8") # Convert between SNIRF processed names and MNE type names dt_id = dt_id.lower().replace("dod", "fnirs_od") - ch_name = sources[src_idx] + "_" + detectors[det_idx] - if dt_id == "fnirs_od": wve_idx = int( _correct_shape( - np.array( - dat.get("nirs/data1/" + chan + "/wavelengthIndex") - ) + np.array(dat.get(f"{ch_root}/wavelengthIndex")) )[0] ) - suffix = " " + str(fnirs_wavelengths[wve_idx - 1]) + suffix = str(fnirs_wavelengths[wve_idx - 1]) else: - suffix = " " + dt_id.lower() - ch_name = ch_name + suffix - - chnames.append(ch_name) - ch_types.append(dt_id) + if dt_id not in ("hbo", "hbr"): + raise RuntimeError( + "read_raw_snirf can only handle processed " + "data in the form of optical density or " + f"HbO/HbR, but got type f{dt_id}" + ) + suffix = dt_id.lower() + need_data_scale = True + ch_name = f"{ch_name} {suffix}" + ch_type = dt_id + chnames.append(ch_name) + ch_types.append(ch_type) + del ch_root, ch_name, ch_type # Create mne structure info = create_info(chnames, sampling_rate, ch_types=ch_types) + if need_data_scale: + snirf_data_unit = np.array( + dat.get("nirs/data1/measurementList1/dataUnit", b"M") + ) + snirf_data_unit = snirf_data_unit.item().decode("utf-8") + scale = _get_dataunit_scaling(snirf_data_unit) # " " or "M") + if scale is not None: + for ch in info["chs"]: + ch["cal"] = scale + subject_info = {} names = np.array(dat.get("nirs/metaDataTags/SubjectID")) names = _correct_shape(names)[0].decode("UTF-8") @@ -335,15 +425,12 @@ def natural_keys(text): coord_frame = FIFF.FIFFV_COORD_UNKNOWN for idx, chan in enumerate(channels): + ch_root = f"nirs/data1/{chan}" src_idx = int( - _correct_shape( - np.array(dat.get("nirs/data1/" + chan + "/sourceIndex")) - )[0] + _correct_shape(np.array(dat.get(f"{ch_root}/sourceIndex")))[0] ) det_idx = int( - _correct_shape( - np.array(dat.get("nirs/data1/" + chan + "/detectorIndex")) - )[0] + _correct_shape(np.array(dat.get(f"{ch_root}/detectorIndex")))[0] ) info["chs"][idx]["loc"][3:6] = srcPos3D[src_idx - 1, :] @@ -355,15 +442,48 @@ def natural_keys(text): info["chs"][idx]["loc"][0:3] = midpoint info["chs"][idx]["coord_frame"] = coord_frame - if (snirf_data_type in [1]) or ( - (snirf_data_type == 99999) and (ch_types[idx] == "fnirs_od") + # get data type specific info: + wve_idx = int( + _correct_shape( + np.array(dat.get(f"{ch_root}/wavelengthIndex", [1])) + )[0] + ) + if snirf_data_type == SNIRF_CW_AMPLITUDE or ( + snirf_data_type == SNIRF_PROCESSED and ch_types[idx] == "fnirs_od" ): - wve_idx = int( - _correct_shape( - np.array(dat.get("nirs/data1/" + chan + "/wavelengthIndex")) - )[0] - ) info["chs"][idx]["loc"][9] = fnirs_wavelengths[wve_idx - 1] + elif snirf_data_type in ( + SNIRF_TD_GATED_AMPLITUDE, + SNIRF_TD_MOMENTS_AMPLITUDE, + ): + info["chs"][idx]["loc"][9] = fnirs_wavelengths[wve_idx - 1] + if snirf_data_type == SNIRF_TD_GATED_AMPLITUDE: + bin_idx = int( + _correct_shape( + np.array(dat.get(f"{ch_root}/dataTypeIndex")) + )[0] + ) + info["chs"][idx]["loc"][10] = ( + fnirs_time_delays[bin_idx - 1] + * fnirs_time_delay_widths[bin_idx - 1] + ) + else: + assert snirf_data_type == SNIRF_TD_MOMENTS_AMPLITUDE + moment_idx = int( + _correct_shape( + np.array(dat.get(f"{ch_root}/dataTypeIndex")) + )[0] + ) + info["chs"][idx]["loc"][10] = fnirs_moment_orders[ + moment_idx - 1 + ] + elif snirf_data_type == SNIRF_PROCESSED: + hb_id = ( + np.array(dat.get(f"{ch_root}/dataTypeLabel")) + .item() + .decode("UTF-8") + ) + info["chs"][idx]["loc"][9] = FNIRS_SNIRF_DATATYPELABELS[hb_id] if "landmarkPos3D" in dat.get("nirs/probe/"): diglocs = np.array(dat.get("/nirs/probe/landmarkPos3D")) @@ -477,11 +597,9 @@ def natural_keys(text): annot = Annotations([], [], []) for key in dat["nirs"]: if "stim" in key: - data = np.atleast_2d(np.array(dat.get("/nirs/" + key + "/data"))) + data = np.atleast_2d(np.array(dat.get(f"/nirs/{key}/data"))) if data.shape[1] >= 3: - desc = _correct_shape( - np.array(dat.get("/nirs/" + key + "/name")) - )[0] + desc = _correct_shape(np.array(dat.get(f"/nirs/{key}/name")))[0] annot.append(data[:, 0], data[:, 1], desc.decode("UTF-8")) self.set_annotations(annot, emit_warning=False) @@ -531,6 +649,19 @@ def _get_lengthunit_scaling(length_unit): ) +def _get_dataunit_scaling(hbx_unit): + """MNE expects hbo/hbr in M, return required scaling.""" + scalings = {"M": None, "uM": 1e-6} + try: + return scalings[hbx_unit] + except KeyError: + raise RuntimeError( + f"The Hb unit {repr(hbx_unit)} is not supported " + "by MNE. Please report this error as a GitHub " + "issue to inform the developers." + ) from None + + def _extract_sampling_rate(dat): """Extract the sample rate from the time field.""" # This is a workaround to provide support for Artinis data. diff --git a/mne/io/snirf/tests/test_snirf.py b/mne/io/snirf/tests/test_snirf.py index 1f69d9b9df7..7ec12186041 100644 --- a/mne/io/snirf/tests/test_snirf.py +++ b/mne/io/snirf/tests/test_snirf.py @@ -57,7 +57,14 @@ ) # Kernel -kernel_hb = testing_path / "SNIRF" / "Kernel" / "Flow50" / "Portal_2021_11" / "hb.snirf" +kernel_flow1_path = testing_path / "SNIRF" / "Kernel" / "Flow50" / "Portal_2021_11" +kernel_hb_old = kernel_flow1_path / "hb.snirf" +kernel_td_moments_old = kernel_flow1_path / "td_moments.snirf" +kernel_flow2_path = testing_path / "SNIRF" / "Kernel" / "Flow2" / "Portal_2024_10_23" +kernel_td_gated = kernel_flow2_path / "c345d04_2.snirf" # Type 201 (TD Gated, 201) +kernel_td_moments = kernel_flow2_path / "c345d04_3.snirf" # Type 202 (TD Moments, 301) +kernel_hb = kernel_flow2_path / "c345d04_5.snirf" # Type 203 (Hb, 99999) + h5py = pytest.importorskip("h5py") # module-level @@ -85,7 +92,13 @@ def _get_loc(raw, ch_name): nirx_nirsport2_103, nirx_nirsport2_103_2, nirx_nirsport2_103_2, - kernel_hb, + pytest.param(kernel_hb_old, id=f"kernel: {kernel_hb_old.stem}"), + pytest.param( + kernel_td_moments_old, id=f"kernel: {kernel_td_moments_old.stem}" + ), + pytest.param(kernel_td_gated, id=f"kernel: {kernel_td_gated.stem}"), + pytest.param(kernel_td_moments, id=f"kernel: {kernel_td_moments.stem}"), + pytest.param(kernel_hb, id=f"kernel: {kernel_hb.stem}"), lumo110, ] ), @@ -94,12 +107,29 @@ def test_basic_reading_and_min_process(fname): """Test reading SNIRF files and minimum typical processing.""" raw = read_raw_snirf(fname, preload=True) # SNIRF data can contain several types, so only apply appropriate functions + kinds = [ + "fnirs_cw_amplitude", + "fnirs_od", + "fnirs_td_gated_amplitude", + "fnirs_td_moments_intensity", + "hbo", + # TODO: add fd_* + ] + ch_types = raw.get_channel_types(unique=True) + got_kinds = [kind for kind in kinds if kind in raw] + assert len(got_kinds) == 1, f"Need one data type, {got_kinds=} and {ch_types=}" if "fnirs_cw_amplitude" in raw: raw = optical_density(raw) - if "fnirs_od" in raw: + elif "fnirs_od" in raw: raw = beer_lambert_law(raw, ppf=6) - assert "hbo" in raw - assert "hbr" in raw + elif "fnirs_td_gated_amplitude" in raw: + pass + elif "fnirs_td_moments_intensity" in raw: + assert "fnirs_td_moments_mean" in raw + assert "fnirs_td_moments_variance" in raw + else: + assert "hbo" in raw + assert "hbr" in raw @requires_testing_data @@ -413,25 +443,80 @@ def test_snirf_fieldtrip_od(): @requires_testing_data -def test_snirf_kernel_hb(): - """Test reading Kernel SNIRF files with haemoglobin data.""" - raw = read_raw_snirf(kernel_hb, preload=True) - - # Test data import - assert raw._data.shape == (180 * 2, 14) - assert raw.copy().pick("hbo")._data.shape == (180, 14) - assert raw.copy().pick("hbr")._data.shape == (180, 14) - - assert_allclose(raw.info["sfreq"], 8.257638) +@pytest.mark.parametrize( + "kind, ver, shape, n_nan, fname", + [ + pytest.param("hb", "new", (4, 38), 0, kernel_hb, id="hb"), + pytest.param("hb", "old", (180 * 2, 14), 20, kernel_hb_old, id="hb old"), + pytest.param( + "td moments", "new", (12, 38), 0, kernel_td_moments, id="td moments" + ), + pytest.param("td gated", "new", (100, 38), 0, kernel_td_gated, id="td gated"), + pytest.param( + "td moments", + "old", + (1080, 14), + 60, + kernel_td_moments_old, + id="td moments old", + ), + ], +) +def test_snirf_kernel_basic(kind, ver, shape, n_nan, fname): + """Test reading Kernel SNIRF files with haemoglobin or TD data.""" + raw = read_raw_snirf(fname, preload=True) + if kind == "hb": + # Test data import + assert raw._data.shape == shape + hbo_data = raw.get_data("hbo") + hbr_data = raw.get_data("hbr") + assert hbo_data.shape == hbr_data.shape == (shape[0] // 2, shape[1]) + hbo_norm = np.nanmedian(np.linalg.norm(hbo_data, axis=-1)) + hbr_norm = np.nanmedian(np.linalg.norm(hbr_data, axis=-1)) + # TODO: Old file vs new file scaling, one is wrong! + if ver == "new": + assert 1e-5 < hbr_norm < hbo_norm < 1e-4 + else: + assert 1 < hbr_norm < 3 + elif kind == "td moments": + assert raw._data.shape == shape + n_ch = 0 + # TODO: Reasonable values here??? + lims = dict(intensity=(1e4, 1e7), mean=(1e3, 1e4), variance=(1e5, 1e7)) + for key, val in lims.items(): + data = raw.get_data(f"fnirs_td_moments_{key}") + assert data.shape[1] == len(raw.times) + norm = np.nanmedian(np.linalg.norm(data, axis=-1)) + min_, max_ = val + assert min_ < norm < max_, key + n_ch += data.shape[0] + assert raw._data.shape[0] == len(raw.ch_names) == n_ch + else: + pass # TODO: add some gated tests + if ver == "old": + sfreq = 8.257638 + n_annot = 2 + else: + sfreq = 3.759398 + n_annot = 8 + + assert_allclose(raw.info["sfreq"], sfreq, atol=1e-5) bad_nans = np.isnan(raw.get_data()).any(axis=1) - assert np.sum(bad_nans) == 20 - - assert len(raw.annotations.description) == 2 - assert raw.annotations.onset[0] == 0.036939 - assert raw.annotations.onset[1] == 0.874633 - assert raw.annotations.description[0] == "StartTrial" - assert raw.annotations.description[1] == "StartIti" + assert np.sum(bad_nans) == n_nan + + if n_annot == 2: + assert len(raw.annotations.description) == n_annot + assert raw.annotations.onset[0] == 0.036939 + assert raw.annotations.onset[1] == 0.874633 + assert raw.annotations.description[0] == "StartTrial" + assert raw.annotations.description[1] == "StartIti" + else: + assert len(raw.annotations.description) == n_annot + assert raw.annotations.onset[0] == 4.988107 + assert raw.annotations.onset[1] == 5.988107 + assert raw.annotations.description[0] == "StartBlock" + assert raw.annotations.description[1] == "StartTrial" @requires_testing_data diff --git a/mne/preprocessing/nirs/nirs.py b/mne/preprocessing/nirs/nirs.py index 94c7c78468c..5de8ef51916 100644 --- a/mne/preprocessing/nirs/nirs.py +++ b/mne/preprocessing/nirs/nirs.py @@ -104,10 +104,33 @@ def _check_channels_ordered(info, pair_vals, *, throw_errors=True, check_bads=Tr # All chromophore fNIRS data picks_chroma = _picks_to_idx(info, ["hbo", "hbr"], exclude=[], allow_empty=True) - if (len(picks_wave) > 0) & (len(picks_chroma) > 0): + # All TD moments + td_moments = [ + "fnirs_td_moments_intensity", + "fnirs_td_moments_mean", + "fnirs_td_moments_variance", + ] + picks_moments = _picks_to_idx(info, td_moments, exclude=[], allow_empty=True) + + # All TD gated + picks_gated = _picks_to_idx( + info, ["fnirs_td_gated_amplitude"], exclude=[], allow_empty=True + ) + + n_found = sum( + len(x) > 0 + for x in ( + picks_wave, + picks_chroma, + picks_moments, + picks_gated, + ) + ) + if n_found != 1: picks = _throw_or_return_empty( - "MNE does not support a combination of amplitude, optical " - "density, and haemoglobin data in the same raw structure.", + "MNE supports exactly one of amplitude, optical density, " + "TD moments, TD gated, and haemoglobin data in a given raw " + f"structure, found {n_found}", throw_errors, ) @@ -116,10 +139,13 @@ def _check_channels_ordered(info, pair_vals, *, throw_errors=True, check_bads=Tr error_word = "frequencies" use_RE = _S_D_F_RE picks = picks_wave - else: + elif len(picks_chroma): error_word = "chromophore" use_RE = _S_D_H_RE picks = picks_chroma + else: + assert len(picks_moments) or len(picks_gated) + return # nothing to check pair_vals = np.array(pair_vals) if pair_vals.shape != (2,): diff --git a/mne/preprocessing/nirs/tests/test_nirs.py b/mne/preprocessing/nirs/tests/test_nirs.py index 89fa17c0c8d..ef0e47ba838 100644 --- a/mne/preprocessing/nirs/tests/test_nirs.py +++ b/mne/preprocessing/nirs/tests/test_nirs.py @@ -386,7 +386,7 @@ def test_fnirs_channel_naming_and_order_custom_optical_density(): info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=1.0) raw2 = RawArray(data, info, verbose=True) raw.add_channels([raw2]) - with pytest.raises(ValueError, match="does not support a combination"): + with pytest.raises(ValueError, match="exactly one of"): _check_channels_ordered(raw.info, [760, 850]) diff --git a/mne/tests/test_defaults.py b/mne/tests/test_defaults.py index ba3a8395fa8..aae295c3cdf 100644 --- a/mne/tests/test_defaults.py +++ b/mne/tests/test_defaults.py @@ -44,7 +44,7 @@ def test_si_units(): want_scale = _get_scaling(key, units[key]) else: want_scale = _get_scaling(key, units[key]) - assert_allclose(scale, want_scale, rtol=1e-12) + assert_allclose(scale, want_scale, rtol=1e-12, err_msg=key) @pytest.mark.parametrize("key", ("si_units", "color", "scalings", "scalings_plot_raw")) diff --git a/mne/utils/__init__.pyi b/mne/utils/__init__.pyi index 46d272e972d..6e43e8a14f6 100644 --- a/mne/utils/__init__.pyi +++ b/mne/utils/__init__.pyi @@ -6,6 +6,7 @@ __all__ = [ "ClosingStringIO", "ExtendedTimeMixin", "GetEpochsMixin", + "NamedInt", "ProgressBar", "SizeMixin", "TimeMixin", @@ -182,7 +183,7 @@ __all__ = [ "warn", "wrapped_stdout", ] -from ._bunch import Bunch, BunchConst, BunchConstNamed +from ._bunch import Bunch, BunchConst, BunchConstNamed, NamedInt from ._logging import ( ClosingStringIO, _get_call_line, diff --git a/tools/hooks/update_environment_file.py b/tools/hooks/update_environment_file.py index 0b5380a16b5..419d5b91a6c 100755 --- a/tools/hooks/update_environment_file.py +++ b/tools/hooks/update_environment_file.py @@ -22,7 +22,12 @@ deps |= set(section_deps) recursive_deps = set(d for d in deps if d.startswith("mne[")) deps -= recursive_deps -deps |= {"pip", "mamba", "nomkl"} +deps |= { # ones we add to environment.yml in addition to those from pyproject.toml + "pip", + "mamba", + "nomkl", + "libxml2 !=2.14.0", # https://github.com/conda-forge/libxml2-feedstock/issues/146 +} def remove_spaces(version_spec):