Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions src/cedalion/models/glm/basis_functions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,16 @@
"""Temporal basis functions for the GLM."""
"""Temporal basis functions for the GLM.
Modifications for Image Space (Parcel/Vertex) Compatibility
-----------------------------------------------------------

This script extends Cedalion functions originally designed for channel space,
allowing them to also support image space data such as parcel-level or vertex-level time series.

Key changes include:
- Added flexible handling of spatial dimensions using:
spatial_dim = xrutils.other_dim(ts, "time", "chromo")
This enables the code to automatically detect and operate over 'parcel', 'vertex', or 'channel' dimensions.
- Replaced hardcoded references to 'channel' with dynamic spatial dimension references (e.g., [spatial_dim].values),
ensuring compatibility with parcel-level and vertex-level data."""

from __future__ import annotations
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -223,7 +235,8 @@ def __call__(
self,
ts: cdt.NDTimeSeries,
) -> xr.DataArray:
other_dim = xrutils.other_dim(ts, "time", "channel")
spatial_dim = xrutils.other_dim(ts, "time", "chromo")
other_dim = xrutils.other_dim(ts, "time", spatial_dim)
other_dim_values = ts[other_dim].values

tau = _to_dict(self.tau, other_dim_values)
Expand Down Expand Up @@ -290,7 +303,8 @@ def __call__(
self,
ts: cdt.NDTimeSeries,
) -> xr.DataArray:
other_dim = xrutils.other_dim(ts, "time", "channel")
spatial_dim = xrutils.other_dim(ts, "time", "chromo")
other_dim = xrutils.other_dim(ts, "time", spatial_dim)
other_dim_values = ts[other_dim].values

tau = _to_dict(self.tau, other_dim_values)
Expand Down Expand Up @@ -361,7 +375,9 @@ def __call__(
self,
ts: cdt.NDTimeSeries,
) -> xr.DataArray:
other_dim = xrutils.other_dim(ts, "time", "channel")

spatial_dim = xrutils.other_dim(ts, "time", "chromo")
other_dim = xrutils.other_dim(ts, "time", spatial_dim)
other_dim_values = ts[other_dim].values

p = _to_dict(self.p, other_dim_values)
Expand Down Expand Up @@ -414,7 +430,8 @@ def __call__(
self,
ts: cdt.NDTimeSeries,
) -> xr.DataArray:
other_dim = xrutils.other_dim(ts, "time", "channel")
spatial_dim = xrutils.other_dim(ts, "time", "chromo")
other_dim = xrutils.other_dim(ts, "time", spatial_dim)
other_dim_values = ts[other_dim].values

n_samples = 2
Expand Down
132 changes: 104 additions & 28 deletions src/cedalion/models/glm/design_matrix.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,30 @@
"""Functions to create the design matrix for the GLM."""
"""Functions to create the design matrix for the GLM.
Modifications for Image Space (Parcel/Vertex) Compatibility
-----------------------------------------------------------

This script extends Cedalion functions originally designed for channel space,
allowing them to also support image space data such as parcel-level or vertex-level time series.

Key changes include:
- Added flexible handling of spatial dimensions using:
spatial_dim = xrutils.other_dim(ts, "time", "chromo")
This enables the code to automatically detect and operate over "parcel", "vertex", or "channel" dimensions.
- Replaced hardcoded references to "channel" with dynamic spatial dimension references (e.g., [spatial_dim].values),
ensuring compatibility with parcel-level and vertex-level data.
- Added a check for "samples" coordinate in the input time series.
If missing, it assigns a default sample index via:
ts = ts.assign_coords(samples=("time", np.arange(ts.sizes["time"])))
This ensures compatibility with downstream functions that expect a "samples" coordinate,
particularly during design matrix construction and plotting routines.
Additionally, the "time" coordinate is explicitly labeled with units ("s") for clarity.
- Added assertions in all short-channel regressor functions
("closest_short_channel_regressor", "max_corr_short_channel_regressor",
and "average_short_channel_regressor") to ensure they are only used in channel space.
- Added a new helper function "global_mean_regressor" that creates a global signal regressor
by averaging over the spatial dimension (works for channel, parcel, or vertex space).
- (Suggested improvement) In __repr__, add a check for "self.common is not None"
to make the representation robust when only channel-wise regressors are present.
"""

from __future__ import annotations

Expand Down Expand Up @@ -82,26 +108,30 @@ def iter_computational_groups(
A tuple containing:
- dim3 (str): The third dimension name.
- group_y (cdt.NDTimeSeries): The grouped time series.
- "unit_values" (np.ndarray): Values of the spatial dimension
(e.g., "channel", "parcel", or "vertex") that belong to the current computational group.
- group_design_matrix (xr.DataArray): The grouped design matrix.
"""

channel_wise_regressors = self.channel_wise

dim3_name = xrutils.other_dim(self.common, "time", "regressor")

spatial_dim = xrutils.other_dim(ts, "time", "chromo")

for cwreg in self.channel_wise:
assert cwreg.sizes["regressor"] == 1
assert (ts.channel.values == cwreg.channel.values).all()
assert (ts[spatial_dim].values == cwreg[spatial_dim].values).all()

comp_groups = []
for reg in self.channel_wise:
if "comp_group" in reg.coords:
comp_groups.append(reg["comp_group"].values)
else:
comp_groups.append(_hash_channel_wise_regressor(reg))
comp_groups.append(_hash_channel_wise_regressor(reg,spatial_dim))

if channel_groups is not None:
assert len(channel_groups) == ts.sizes["channel"]
assert len(channel_groups) == ts.sizes[spatial_dim]
comp_groups.append(channel_groups)

if len(comp_groups) == 0:
Expand All @@ -110,55 +140,58 @@ def iter_computational_groups(
for dim3 in self.common[dim3_name].values:
dm = self.common.sel({dim3_name: dim3})
# group_y = ts.sel({dim3_name: dim3})
channels = ts.channel.values
values = ts[spatial_dim].values
# yield dim3, group_y, dm
yield dim3, channels, dm
yield dim3, values, dm

return
else:
# there are channel-wise regressors. For each computational group, in which
# the channel-wise regressors are identical, we have to assemble and yield
# the design-matrix.

chan_idx_with_same_comp_group = defaultdict(list)
idx_with_same_comp_group = defaultdict(list)

for i_ch, all_comp_groups in enumerate(zip(*comp_groups)):
chan_idx_with_same_comp_group[all_comp_groups].append(i_ch)
for i_unit, all_comp_groups in enumerate(zip(*comp_groups)):
idx_with_same_comp_group[all_comp_groups].append(i_unit)

for dim3 in self.common[dim3_name].values:
dm = self.common.sel({dim3_name: dim3})

for chan_indices in chan_idx_with_same_comp_group.values():
channels = ts.channel[np.asarray(chan_indices)].values
for chan_indices in idx_with_same_comp_group.values():
unit_values = ts[spatial_dim][np.asarray(chan_indices)].values

regs = []
for reg in channel_wise_regressors:
regs.append(
reg.sel({"channel": channels, dim3_name: dim3})
.isel(channel=0) # regs are identical within a group
reg.sel({spatial_dim: unit_values, dim3_name: dim3})
.isel(spatial_dim=0) # regs are identical within a group
.pint.dequantify()
)

group_design_matrix = xr.concat([dm] + regs, dim="regressor")

# yield dim3, group_y, group_design_matrix
yield dim3, channels, group_design_matrix
yield dim3, unit_values, group_design_matrix


def _hash_channel_wise_regressor(regressor: xr.DataArray) -> list[int]:
"""Hashes each channel slice of the regressor array.
"""Hashes each unit slice of the regressor array along the spatial dimension.

Args:
regressor: array of channel-wise regressors. Dims
(channel, regressor, time, chromo|wavelength)
regressor: array of regressors. Dims
(spatial_dim, regressor, time, chromo|wavelength)

Returns:
A list of hash values, one hash for each channel.
A list of hash values, one for each element of the spatial dimension.
"""

spatial_dim = xrutils.other_dim(regressor, "time", "chromo")

tmp = regressor.pint.dequantify()
n_channel = regressor.sizes["channel"]
return [hash(tmp.isel(channel=i).values.data.tobytes()) for i in range(n_channel)]
n_channel = regressor.sizes[spatial_dim]

return [hash(tmp.isel({spatial_dim: i}).values.data.tobytes()) for i in range(n_channel)]


def hrf_regressors(
Expand All @@ -180,13 +213,17 @@ def hrf_regressors(
# so that users can pass their own individual hrf function

trial_types: np.ndarray = stim.trial_type.unique()

if "samples" not in ts.coords:
ts = ts.assign_coords(samples=("time", np.arange(ts.sizes["time"])))
ts.time.attrs["units"] = "s"

basis = basis_function(ts)

components = basis.component.values
spatial_dim = xrutils.other_dim(ts, "time", "chromo")

# could be "chromo" or "wavelength"
other_dim = xrutils.other_dim(ts, "channel", "time")
other_dim = xrutils.other_dim(ts, spatial_dim, "time")

n_time = ts.sizes["time"]
n_other = ts.sizes[other_dim]
Expand Down Expand Up @@ -269,7 +306,8 @@ def drift_regressors(ts: cdt.NDTimeSeries, drift_order) -> DesignMatrix:
Returns:
xr.DataArray: A DataArray containing the drift regressors.
"""
dim3 = xrutils.other_dim(ts, "channel", "time")
spatial_dim = xrutils.other_dim(ts, "time", "chromo")
dim3 = xrutils.other_dim(ts, spatial_dim, "time")
ndim3 = ts.sizes[dim3]

nt = ts.sizes["time"]
Expand Down Expand Up @@ -306,7 +344,8 @@ def drift_legendre_regressors(ts : cdt.NDTimeSeries, order : int) -> DesignMatri
xr.DataArray: A DataArray containing the drift regressors.
"""

dim3 = xrutils.other_dim(ts, "channel", "time")
spatial_dim = xrutils.other_dim(ts, "time", "chromo")
dim3 = xrutils.other_dim(ts, spatial_dim, "time")
ndim3 = ts.sizes[dim3]

nt = ts.sizes["time"]
Expand Down Expand Up @@ -344,8 +383,8 @@ def drift_cosine_regressors(ts: cdt.NDTimeSeries, fmax: cdt.QFrequency) -> Desig
Returns:
xr.DataArray: A DataArray containing the drift regressors.
"""

dim3 = xrutils.other_dim(ts, "channel", "time")
spatial_dim = xrutils.other_dim(ts, "time", "chromo")
dim3 = xrutils.other_dim(ts, spatial_dim, "time")
ndim3 = ts.sizes[dim3]

nt = ts.sizes["time"]
Expand Down Expand Up @@ -429,6 +468,12 @@ def _regressors_from_selected_short_channels(
selected_short_ch_indices: np.ndarray,
) -> xr.DataArray:
"""Build channel-wise short-channel regressors from a selection."""

spatial_dim = xrutils.other_dim(ts_long, "time", "chromo")
assert spatial_dim == "channel", (
f"Short-channel regressors only make sense in channel space, "
f"but got '{spatial_dim}'."
)

# pick for each long channel from ts_short the selected closest channel
# regressors has same dims as ts_long/ts_short and same channels as ts_long
Expand Down Expand Up @@ -474,6 +519,11 @@ def closest_short_channel_regressor(
Returns:
regressors (xr.DataArray): Channel-wise regressor
"""
spatial_dim = xrutils.other_dim(ts_long, "time", "chromo")
assert spatial_dim == "channel", (
f"Short-channel regressors only make sense in channel space, "
f"but got '{spatial_dim}'."
)
# calculate midpoints between channel optode pairs. dims: (channel, crs)
long_channel_pos = (geo3d.loc[ts_long.source] + geo3d.loc[ts_long.detector]) / 2
short_channel_pos = (geo3d.loc[ts_short.source] + geo3d.loc[ts_short.detector]) / 2
Expand Down Expand Up @@ -511,7 +561,11 @@ def max_corr_short_channel_regressor(
Returns:
xr.DataArray: channel-wise regressors
"""

spatial_dim = xrutils.other_dim(ts_long, "time", "chromo")
assert spatial_dim == "channel", (
f"Short-channel regressors only make sense in channel space, "
f"but got '{spatial_dim}'."
)
dim3 = xrutils.other_dim(ts_long, "channel", "time")

z_long = (ts_long - ts_long.mean("time")) / ts_long.std("time")
Expand Down Expand Up @@ -550,10 +604,32 @@ def average_short_channel_regressor(ts_short: cdt.NDTimeSeries):
Returns:
xr.DataArray: regressors
"""

spatial_dim = xrutils.other_dim(ts_short, "time", "chromo")
assert spatial_dim == "channel", (
f"Average short-channel regressor only makes sense in channel space, "
f"but got '{spatial_dim}'."
)
ts_short = ts_short.pint.dequantify()
regressor = ts_short.mean("channel", skipna=True).expand_dims("regressor")
regressor = regressor.assign_coords({"regressor": ["short"]})
regressor = regressor.transpose("time", "regressor", ...)

return DesignMatrix(common=regressor, channel_wise=[])
def global_mean_regressor(ts: cdt.NDTimeSeries) -> DesignMatrix:
"""Create a global regressor by averaging over the spatial dimension.

Args:
ts (NDTimeSeries): time series data (time x spatial_dim x chromo)

Returns:
DesignMatrix: design matrix with one global regressor
"""
# detect the spatial dimension dynamically (channel, parcel, or vertex)
spatial_dim = xrutils.other_dim(ts, "time", "chromo")

# mean over spatial dimension → global signal
regressor = ts.mean(spatial_dim, skipna=True).expand_dims("regressor")
regressor = regressor.assign_coords({"regressor": ["global"]})
regressor = regressor.transpose("time", "regressor", ...)

return DesignMatrix(common=regressor, channel_wise=[])
Loading