diff --git a/src/cedalion/models/glm/basis_functions.py b/src/cedalion/models/glm/basis_functions.py index 24ef180..fcafa64 100644 --- a/src/cedalion/models/glm/basis_functions.py +++ b/src/cedalion/models/glm/basis_functions.py @@ -1,4 +1,16 @@ -"""Temporal basis functions for the GLM.""" +"""Temporal basis functions for the GLM. +Modifications for Image Space (Parcel/Vertex) Compatibility +----------------------------------------------------------- + +This script extends Cedalion functions originally designed for channel space, +allowing them to also support image space data such as parcel-level or vertex-level time series. + +Key changes include: +- Added flexible handling of spatial dimensions using: + spatial_dim = xrutils.other_dim(ts, "time", "chromo") + This enables the code to automatically detect and operate over 'parcel', 'vertex', or 'channel' dimensions. +- Replaced hardcoded references to 'channel' with dynamic spatial dimension references (e.g., [spatial_dim].values), + ensuring compatibility with parcel-level and vertex-level data.""" from __future__ import annotations from abc import ABC, abstractmethod @@ -223,7 +235,8 @@ def __call__( self, ts: cdt.NDTimeSeries, ) -> xr.DataArray: - other_dim = xrutils.other_dim(ts, "time", "channel") + spatial_dim = xrutils.other_dim(ts, "time", "chromo") + other_dim = xrutils.other_dim(ts, "time", spatial_dim) other_dim_values = ts[other_dim].values tau = _to_dict(self.tau, other_dim_values) @@ -290,7 +303,8 @@ def __call__( self, ts: cdt.NDTimeSeries, ) -> xr.DataArray: - other_dim = xrutils.other_dim(ts, "time", "channel") + spatial_dim = xrutils.other_dim(ts, "time", "chromo") + other_dim = xrutils.other_dim(ts, "time", spatial_dim) other_dim_values = ts[other_dim].values tau = _to_dict(self.tau, other_dim_values) @@ -361,7 +375,9 @@ def __call__( self, ts: cdt.NDTimeSeries, ) -> xr.DataArray: - other_dim = xrutils.other_dim(ts, "time", "channel") + + spatial_dim = xrutils.other_dim(ts, "time", "chromo") + other_dim = xrutils.other_dim(ts, "time", spatial_dim) other_dim_values = ts[other_dim].values p = _to_dict(self.p, other_dim_values) @@ -414,7 +430,8 @@ def __call__( self, ts: cdt.NDTimeSeries, ) -> xr.DataArray: - other_dim = xrutils.other_dim(ts, "time", "channel") + spatial_dim = xrutils.other_dim(ts, "time", "chromo") + other_dim = xrutils.other_dim(ts, "time", spatial_dim) other_dim_values = ts[other_dim].values n_samples = 2 diff --git a/src/cedalion/models/glm/design_matrix.py b/src/cedalion/models/glm/design_matrix.py index 2183338..5e98180 100644 --- a/src/cedalion/models/glm/design_matrix.py +++ b/src/cedalion/models/glm/design_matrix.py @@ -1,4 +1,30 @@ -"""Functions to create the design matrix for the GLM.""" +"""Functions to create the design matrix for the GLM. +Modifications for Image Space (Parcel/Vertex) Compatibility +----------------------------------------------------------- + +This script extends Cedalion functions originally designed for channel space, +allowing them to also support image space data such as parcel-level or vertex-level time series. + +Key changes include: +- Added flexible handling of spatial dimensions using: + spatial_dim = xrutils.other_dim(ts, "time", "chromo") + This enables the code to automatically detect and operate over "parcel", "vertex", or "channel" dimensions. +- Replaced hardcoded references to "channel" with dynamic spatial dimension references (e.g., [spatial_dim].values), + ensuring compatibility with parcel-level and vertex-level data. +- Added a check for "samples" coordinate in the input time series. + If missing, it assigns a default sample index via: + ts = ts.assign_coords(samples=("time", np.arange(ts.sizes["time"]))) + This ensures compatibility with downstream functions that expect a "samples" coordinate, + particularly during design matrix construction and plotting routines. + Additionally, the "time" coordinate is explicitly labeled with units ("s") for clarity. +- Added assertions in all short-channel regressor functions + ("closest_short_channel_regressor", "max_corr_short_channel_regressor", + and "average_short_channel_regressor") to ensure they are only used in channel space. +- Added a new helper function "global_mean_regressor" that creates a global signal regressor + by averaging over the spatial dimension (works for channel, parcel, or vertex space). +- (Suggested improvement) In __repr__, add a check for "self.common is not None" + to make the representation robust when only channel-wise regressors are present. +""" from __future__ import annotations @@ -82,6 +108,8 @@ def iter_computational_groups( A tuple containing: - dim3 (str): The third dimension name. - group_y (cdt.NDTimeSeries): The grouped time series. + - "unit_values" (np.ndarray): Values of the spatial dimension + (e.g., "channel", "parcel", or "vertex") that belong to the current computational group. - group_design_matrix (xr.DataArray): The grouped design matrix. """ @@ -89,19 +117,21 @@ def iter_computational_groups( dim3_name = xrutils.other_dim(self.common, "time", "regressor") + spatial_dim = xrutils.other_dim(ts, "time", "chromo") + for cwreg in self.channel_wise: assert cwreg.sizes["regressor"] == 1 - assert (ts.channel.values == cwreg.channel.values).all() + assert (ts[spatial_dim].values == cwreg[spatial_dim].values).all() comp_groups = [] for reg in self.channel_wise: if "comp_group" in reg.coords: comp_groups.append(reg["comp_group"].values) else: - comp_groups.append(_hash_channel_wise_regressor(reg)) + comp_groups.append(_hash_channel_wise_regressor(reg,spatial_dim)) if channel_groups is not None: - assert len(channel_groups) == ts.sizes["channel"] + assert len(channel_groups) == ts.sizes[spatial_dim] comp_groups.append(channel_groups) if len(comp_groups) == 0: @@ -110,9 +140,9 @@ def iter_computational_groups( for dim3 in self.common[dim3_name].values: dm = self.common.sel({dim3_name: dim3}) # group_y = ts.sel({dim3_name: dim3}) - channels = ts.channel.values + values = ts[spatial_dim].values # yield dim3, group_y, dm - yield dim3, channels, dm + yield dim3, values, dm return else: @@ -120,45 +150,48 @@ def iter_computational_groups( # the channel-wise regressors are identical, we have to assemble and yield # the design-matrix. - chan_idx_with_same_comp_group = defaultdict(list) + idx_with_same_comp_group = defaultdict(list) - for i_ch, all_comp_groups in enumerate(zip(*comp_groups)): - chan_idx_with_same_comp_group[all_comp_groups].append(i_ch) + for i_unit, all_comp_groups in enumerate(zip(*comp_groups)): + idx_with_same_comp_group[all_comp_groups].append(i_unit) for dim3 in self.common[dim3_name].values: dm = self.common.sel({dim3_name: dim3}) - for chan_indices in chan_idx_with_same_comp_group.values(): - channels = ts.channel[np.asarray(chan_indices)].values + for chan_indices in idx_with_same_comp_group.values(): + unit_values = ts[spatial_dim][np.asarray(chan_indices)].values regs = [] for reg in channel_wise_regressors: regs.append( - reg.sel({"channel": channels, dim3_name: dim3}) - .isel(channel=0) # regs are identical within a group + reg.sel({spatial_dim: unit_values, dim3_name: dim3}) + .isel(spatial_dim=0) # regs are identical within a group .pint.dequantify() ) group_design_matrix = xr.concat([dm] + regs, dim="regressor") # yield dim3, group_y, group_design_matrix - yield dim3, channels, group_design_matrix + yield dim3, unit_values, group_design_matrix def _hash_channel_wise_regressor(regressor: xr.DataArray) -> list[int]: - """Hashes each channel slice of the regressor array. + """Hashes each unit slice of the regressor array along the spatial dimension. Args: - regressor: array of channel-wise regressors. Dims - (channel, regressor, time, chromo|wavelength) + regressor: array of regressors. Dims + (spatial_dim, regressor, time, chromo|wavelength) Returns: - A list of hash values, one hash for each channel. + A list of hash values, one for each element of the spatial dimension. """ + + spatial_dim = xrutils.other_dim(regressor, "time", "chromo") tmp = regressor.pint.dequantify() - n_channel = regressor.sizes["channel"] - return [hash(tmp.isel(channel=i).values.data.tobytes()) for i in range(n_channel)] + n_channel = regressor.sizes[spatial_dim] + + return [hash(tmp.isel({spatial_dim: i}).values.data.tobytes()) for i in range(n_channel)] def hrf_regressors( @@ -180,13 +213,17 @@ def hrf_regressors( # so that users can pass their own individual hrf function trial_types: np.ndarray = stim.trial_type.unique() - + if "samples" not in ts.coords: + ts = ts.assign_coords(samples=("time", np.arange(ts.sizes["time"]))) + ts.time.attrs["units"] = "s" + basis = basis_function(ts) components = basis.component.values + spatial_dim = xrutils.other_dim(ts, "time", "chromo") # could be "chromo" or "wavelength" - other_dim = xrutils.other_dim(ts, "channel", "time") + other_dim = xrutils.other_dim(ts, spatial_dim, "time") n_time = ts.sizes["time"] n_other = ts.sizes[other_dim] @@ -269,7 +306,8 @@ def drift_regressors(ts: cdt.NDTimeSeries, drift_order) -> DesignMatrix: Returns: xr.DataArray: A DataArray containing the drift regressors. """ - dim3 = xrutils.other_dim(ts, "channel", "time") + spatial_dim = xrutils.other_dim(ts, "time", "chromo") + dim3 = xrutils.other_dim(ts, spatial_dim, "time") ndim3 = ts.sizes[dim3] nt = ts.sizes["time"] @@ -306,7 +344,8 @@ def drift_legendre_regressors(ts : cdt.NDTimeSeries, order : int) -> DesignMatri xr.DataArray: A DataArray containing the drift regressors. """ - dim3 = xrutils.other_dim(ts, "channel", "time") + spatial_dim = xrutils.other_dim(ts, "time", "chromo") + dim3 = xrutils.other_dim(ts, spatial_dim, "time") ndim3 = ts.sizes[dim3] nt = ts.sizes["time"] @@ -344,8 +383,8 @@ def drift_cosine_regressors(ts: cdt.NDTimeSeries, fmax: cdt.QFrequency) -> Desig Returns: xr.DataArray: A DataArray containing the drift regressors. """ - - dim3 = xrutils.other_dim(ts, "channel", "time") + spatial_dim = xrutils.other_dim(ts, "time", "chromo") + dim3 = xrutils.other_dim(ts, spatial_dim, "time") ndim3 = ts.sizes[dim3] nt = ts.sizes["time"] @@ -429,6 +468,12 @@ def _regressors_from_selected_short_channels( selected_short_ch_indices: np.ndarray, ) -> xr.DataArray: """Build channel-wise short-channel regressors from a selection.""" + + spatial_dim = xrutils.other_dim(ts_long, "time", "chromo") + assert spatial_dim == "channel", ( + f"Short-channel regressors only make sense in channel space, " + f"but got '{spatial_dim}'." + ) # pick for each long channel from ts_short the selected closest channel # regressors has same dims as ts_long/ts_short and same channels as ts_long @@ -474,6 +519,11 @@ def closest_short_channel_regressor( Returns: regressors (xr.DataArray): Channel-wise regressor """ + spatial_dim = xrutils.other_dim(ts_long, "time", "chromo") + assert spatial_dim == "channel", ( + f"Short-channel regressors only make sense in channel space, " + f"but got '{spatial_dim}'." + ) # calculate midpoints between channel optode pairs. dims: (channel, crs) long_channel_pos = (geo3d.loc[ts_long.source] + geo3d.loc[ts_long.detector]) / 2 short_channel_pos = (geo3d.loc[ts_short.source] + geo3d.loc[ts_short.detector]) / 2 @@ -511,7 +561,11 @@ def max_corr_short_channel_regressor( Returns: xr.DataArray: channel-wise regressors """ - + spatial_dim = xrutils.other_dim(ts_long, "time", "chromo") + assert spatial_dim == "channel", ( + f"Short-channel regressors only make sense in channel space, " + f"but got '{spatial_dim}'." + ) dim3 = xrutils.other_dim(ts_long, "channel", "time") z_long = (ts_long - ts_long.mean("time")) / ts_long.std("time") @@ -550,10 +604,32 @@ def average_short_channel_regressor(ts_short: cdt.NDTimeSeries): Returns: xr.DataArray: regressors """ - + spatial_dim = xrutils.other_dim(ts_short, "time", "chromo") + assert spatial_dim == "channel", ( + f"Average short-channel regressor only makes sense in channel space, " + f"but got '{spatial_dim}'." + ) ts_short = ts_short.pint.dequantify() regressor = ts_short.mean("channel", skipna=True).expand_dims("regressor") regressor = regressor.assign_coords({"regressor": ["short"]}) regressor = regressor.transpose("time", "regressor", ...) return DesignMatrix(common=regressor, channel_wise=[]) +def global_mean_regressor(ts: cdt.NDTimeSeries) -> DesignMatrix: + """Create a global regressor by averaging over the spatial dimension. + + Args: + ts (NDTimeSeries): time series data (time x spatial_dim x chromo) + + Returns: + DesignMatrix: design matrix with one global regressor + """ + # detect the spatial dimension dynamically (channel, parcel, or vertex) + spatial_dim = xrutils.other_dim(ts, "time", "chromo") + + # mean over spatial dimension → global signal + regressor = ts.mean(spatial_dim, skipna=True).expand_dims("regressor") + regressor = regressor.assign_coords({"regressor": ["global"]}) + regressor = regressor.transpose("time", "regressor", ...) + + return DesignMatrix(common=regressor, channel_wise=[]) \ No newline at end of file diff --git a/src/cedalion/models/glm/solve.py b/src/cedalion/models/glm/solve.py index ecc0b11..fe6f60c 100644 --- a/src/cedalion/models/glm/solve.py +++ b/src/cedalion/models/glm/solve.py @@ -1,4 +1,18 @@ -"""Solve the GLM model.""" +"""Solve the GLM model. + +Modifications for Image Space (Parcel/Vertex) Compatibility +----------------------------------------------------------- + +This script extends Cedalion functions originally designed for channel space, +allowing them to also support image space data such as parcel-level or vertex-level time series. + +Key changes include: +- Added flexible handling of spatial dimensions using: + spatial_dim = xrutils.other_dim(ts, "time", "chromo") + This enables the code to automatically detect and operate over 'parcel', 'vertex', or 'channel' dimensions. +- Replaced hardcoded references to 'channel' with dynamic spatial dimension references (e.g., [spatial_dim].values), + ensuring compatibility with parcel-level and vertex-level data.""" + from __future__ import annotations from collections import defaultdict @@ -91,13 +105,15 @@ def fit( # shoud the design matrix be dimensionless? -> thetas will have units ts = ts.pint.dequantify() + spatial_dim = xrutils.other_dim(ts, "time","chromo") + dim3_name = xrutils.other_dim(design_matrix.common, "time", "regressor") reg_results = xr.DataArray( - np.empty((ts.sizes["channel"], ts.sizes[dim3_name]), dtype=object), - dims=("channel", dim3_name), - coords=xrutils.coords_from_other(ts.isel(time=0), dims=("channel", dim3_name)) + np.empty((ts.sizes[spatial_dim], ts.sizes[dim3_name]), dtype=object), + dims=(spatial_dim, dim3_name), + coords=xrutils.coords_from_other(ts.isel(time=0), dims=(spatial_dim, dim3_name)) ) for ( @@ -105,22 +121,21 @@ def fit( group_channels, group_design_matrix, ) in design_matrix.iter_computational_groups(ts): - group_y = ts.sel({"channel": group_channels, dim3_name: dim3}).transpose( - "time", "channel" - ) - # pass x as a DataFrame to statsmodel to make it aware of regressor names + group_y = ts.sel({spatial_dim: group_channels, dim3_name: dim3}).transpose( + "time", spatial_dim + ) x = pd.DataFrame( group_design_matrix.values, columns=group_design_matrix.regressor.values ) - if(max_jobs==1): - for chan in tqdm(group_y.channel.values, disable=not verbose): + if max_jobs == 1: + for chan in tqdm(group_y[spatial_dim].values, disable=not verbose): result = _channel_fit(group_y.loc[:, chan], x, noise_model, ar_order) reg_results.loc[chan, dim3] = result else: - args_list=[] - for chan in group_y.channel.values: + args_list = [] + for chan in group_y[spatial_dim].values: args_list.append([group_y.loc[:, chan], x, noise_model, ar_order]) with parallel_config(backend='threading', n_jobs=max_jobs): @@ -131,9 +146,10 @@ def fit( total=len(args_list) ) - for chan, result in zip(group_y.channel.values, batch_results): + for chan, result in zip(group_y[spatial_dim].values, batch_results): reg_results.loc[chan, dim3] = result + #try: # coloring_matrix=np.linalg.cholesky(np.corrcoef(np.array(resid))) # coloring_matrix=xr.DataArray(data=coloring_matrix,dims=['channel','type'], @@ -178,6 +194,8 @@ def predict( """ dim3_name = xrutils.other_dim(design_matrix.common, "time", "regressor") + + spatial_dim = xrutils.other_dim(ts, "time","chromo") prediction = defaultdict(list) @@ -187,11 +205,11 @@ def predict( group_design_matrix, ) in design_matrix.iter_computational_groups(ts): # (dim3, channel, regressor) - t = thetas.sel({"channel": group_channels, dim3_name: [dim3]}) + t = thetas.sel({spatial_dim: group_channels, dim3_name: [dim3]}) prediction[dim3].append(xr.dot(group_design_matrix, t, dim="regressor")) # concatenate channels - prediction = [xr.concat(v, dim="channel") for v in prediction.values()] + prediction = [xr.concat(v, dim=spatial_dim) for v in prediction.values()] # concatenate dim3 prediction = xr.concat(prediction, dim=dim3_name)