19 changes: 17 additions & 2 deletions src/cedalion/imagereco/forward_model.py
@@ -1124,13 +1124,16 @@ def compute_stacked_sensitivity(sensitivity: xr.DataArray):


def apply_inv_sensitivity(
od: cdt.NDTimeSeries, inv_sens: xr.DataArray
od: cdt.NDTimeSeries, inv_sens: xr.DataArray, chunk: bool = True,
) -> tuple[xr.DataArray, xr.DataArray]:
"""Apply the inverted sensitivity matrix to optical density data.

Args:
od: time series of optical density data
inv_sens: the inverted sensitivity matrix
chunk: if True (default), perform the matrix multiplication piecewise
whenever more than 1000 time samples are to be converted; if False,
chunking is always skipped.

Returns:
Two DataArrays for the brain and scalp with the reconstructed time series per
@@ -1142,7 +1145,19 @@ def apply_inv_sensitivity(
od_stacked = od.stack({"flat_channel": ["wavelength", "channel"]})
od_stacked = od_stacked.pint.dequantify()

delta_conc = inv_sens @ od_stacked
# for image recon we have time-series data either with "time" or "reltime" dimension
sample_dim = next((d for d in ["time", "reltime"] if d in od_stacked.dims), None)

# if od_stacked has more than 1000 samples along its time dimension, chunk it
if chunk and sample_dim is not None and od_stacked.sizes[sample_dim] > 1000:
delta_conc = xrutils.chunked_eff_xr_matmult(
od_stacked,
inv_sens,
contract_dim="flat_channel",
sample_dim=sample_dim,
chunksize=1000)
else:
delta_conc = inv_sens @ od_stacked

# Construct a multiindex for dimension flat_vertex from chromo and vertex.
# Afterwards use this multiindex to unstack flat_vertex. The resulting array
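A quick usage sketch for the new `chunk` flag (hypothetical: `od` and `inv_sens` stand in for a real optical-density time series and an inverted sensitivity matrix, not defined here):

from cedalion.imagereco import forward_model as fw

# Default: chunked multiplication activates automatically once the series
# exceeds 1000 time samples.
dc_brain, dc_scalp = fw.apply_inv_sensitivity(od, inv_sens)

# Force a single dense matrix multiplication, regardless of series length.
dc_brain, dc_scalp = fw.apply_inv_sensitivity(od, inv_sens, chunk=False)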
99 changes: 99 additions & 0 deletions src/cedalion/xrutils.py
@@ -5,6 +5,9 @@
import numpy as np
import pint
import xarray as xr
import os
import tempfile
import shutil


def pinv(array: xr.DataArray) -> xr.DataArray:
@@ -231,3 +234,99 @@ def unit_stripping_is_error(is_error : bool = True):
if f[0] =="error" and f[2] == pint.errors.UnitStrippedWarning:
del warnings.filters[i]
break


def chunked_eff_xr_matmult(
A: xr.DataArray,
B: xr.DataArray,
contract_dim: str,
sample_dim: str,
chunksize: int = 5000,
tmpdir: str | None = None
) -> xr.DataArray:
"""Performs a large matrix multiplication of A and B, chunking A along `sample_dim`; to avoid memory issues, streams each chunk to disk, and then rebuilds a full DataArray.

Args:
A: DataArray to multiply (dims include `contract_dim` and `sample_dim` among others)
B: DataArray to multiply with (dims include `contract_dim` and others)
contract_dim: name of the dimension to contract (e.g. "flat_channel")
sample_dim: name of the dimension along which to chunk (e.g. "time")
chunksize: max size of each chunk along dimension `sample_dim`
tmpdir: optional path to temp directory (auto-created and removed if None)

Returns:
A new DataArray containing the result of the matrix multiplication over `contract_dim`,
with coords, dims, and attrs preserved. Should yield the same result as `xr.dot(A, B, dims=[contract_dim])`
but at increased speed and with a much lower memory footprint.

Initial Contributors:
- Alexander von Lühmann | [email protected] | 2025
"""
# Total samples & number of chunks
N = A.sizes[sample_dim]
n_chunks = int(np.ceil(N / chunksize))

# Build a "shell" result for metadata by doing the dot on the first sample
A0 = A.isel({sample_dim: slice(0, 1)})
Xres = xr.dot(B, A0, dims=[contract_dim])

# Prepare for raw numpy multiply
dims_B_not = [d for d in B.dims if d != contract_dim]
dims_A_not = [d for d in A.dims if d != contract_dim]
B_mat = B.transpose(*dims_B_not, contract_dim).values
A2 = A.transpose(contract_dim, *dims_A_not)

# Create Temp directory
cleanup = False
if tmpdir is None:
tmpdir = tempfile.mkdtemp()
cleanup = True
else:
os.makedirs(tmpdir, exist_ok=True)

print(f"Large Matrix Multiplication: Processing {n_chunks} chunks...")

# Stream-compute each chunk
file_paths = []
for i in range(n_chunks):
start = i * chunksize
stop = min((i + 1) * chunksize, N)
A_chunk = A2.isel({sample_dim: slice(start, stop)})
C_chunk = B_mat.dot(A_chunk.values) # raw (out_dim, chunk_len, ...)
fn = os.path.join(tmpdir, f"chunk_{i:04d}.npy")
np.save(fn, C_chunk)
file_paths.append(fn)
del A_chunk, C_chunk
print(f"Chunk {i+1}/{n_chunks} done.")

# Read back & concatenate along the sample axis
arrs = [np.load(fp) for fp in sorted(file_paths)]
axis = Xres.get_axis_num(sample_dim)
full_arr = np.concatenate(arrs, axis=axis)

if cleanup:
shutil.rmtree(tmpdir)

# create set of coordinates
coords = {
name: coord
for name, coord in Xres.coords.items()
if sample_dim not in coord.dims
}
sample_coords = {
name: coord
for name, coord in A.coords.items()
if sample_dim in coord.dims
}
coords.update(sample_coords)

# rebuild the DataArray using the Xres metadata
result = xr.DataArray(
data=full_arr,
dims=Xres.dims,
coords=coords,
attrs=Xres.attrs,
)
return result
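For review, a small self-contained check that the chunked helper matches plain `xr.dot` (the dimension names and shapes below are illustrative only, not taken from the PR):

import numpy as np
import xarray as xr

from cedalion import xrutils

# Toy inputs: 50 channels x 2500 samples, mapped onto 200 vertices.
A = xr.DataArray(
    np.random.rand(50, 2500),
    dims=["flat_channel", "time"],
    coords={"time": np.arange(2500)},
)
B = xr.DataArray(np.random.rand(200, 50), dims=["flat_vertex", "flat_channel"])

# Reference result via a single dense contraction.
expected = xr.dot(B, A, dims=["flat_channel"])

# Chunked result: 2500 samples with chunksize=1000 gives 3 chunks.
result = xrutils.chunked_eff_xr_matmult(
    A, B, contract_dim="flat_channel", sample_dim="time", chunksize=1000
)
assert np.allclose(result.values, expected.values)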