diff --git a/src/cedalion/dataclasses/accessors.py b/src/cedalion/dataclasses/accessors.py
index 36b12db4..b62010e3 100644
--- a/src/cedalion/dataclasses/accessors.py
+++ b/src/cedalion/dataclasses/accessors.py
@@ -23,7 +23,7 @@ def __init__(self, xarray_obj):
         """Initialize the CedalionAccessor.
 
         Args:
-            xarray_obj (xr.DataArray): The DataArray to which this accessor is attached.
+            xarray_obj: The DataArray to which this accessor is attached.
         """
         self._validate(xarray_obj)
         self._obj = xarray_obj
@@ -65,16 +65,16 @@ def to_epochs(
 
         return to_epochs(self._obj, df_stim, trial_types, before, after)
 
-    def freq_filter(self, fmin, fmax, butter_order=4):
-        """Applys a Butterworth filter.
+    def freq_filter(
+        self, fmin: float, fmax: float, butter_order: int = 4
+    ) -> xr.DataArray:
+        """Applies a Butterworth filter.
 
         Args:
-            fmin (float): The lower cutoff frequency.
-            fmax (float): The upper cutoff frequency.
-            butter_order (int): The order of the Butterworth filter.
+            fmin: The lower cutoff frequency.
+            fmax: The upper cutoff frequency.
+            butter_order: The order of the Butterworth filter.
 
         Returns:
-            result (xarray.DataArray): The filtered time series.
+            result: The filtered time series.
         """
         array = self._obj
@@ -292,11 +292,11 @@ def _validate(obj):
-                    f"Stimulus DataFame must have column {column_name}."
+                    f"Stimulus DataFrame must have column {column_name}."
                 )
 
-    def rename_events(self, rename_dict):
+    def rename_events(self, rename_dict: dict[str, str]) -> None:
         """Renames trial types in the DataFrame based on the provided dictionary.
 
         Args:
-            rename_dict (dict): A dictionary with the old trial type as key and the new
+            rename_dict: A dictionary with the old trial type as key and the new
                 trial type as value.
         """
         stim = self._obj
diff --git a/src/cedalion/dataclasses/geometry.py b/src/cedalion/dataclasses/geometry.py
index 99620908..405b338f 100644
--- a/src/cedalion/dataclasses/geometry.py
+++ b/src/cedalion/dataclasses/geometry.py
@@ -1,10 +1,11 @@
 """Dataclasses for representing geometric objects."""
 
+from __future__ import annotations
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from enum import Enum
 from functools import total_ordering
-from typing import Any
+from typing import Any, List
 
 import mne
 import numpy as np
@@ -197,7 +198,7 @@ def apply_transform(self, transform: cdt.AffineTransform) -> "TrimeshSurface":
         """Apply an affine transformation to this surface.
 
         Args:
-            transform (cdt.AffineTransform): The affine transformation to apply.
+            transform: The affine transformation to apply.
 
         Returns:
             TrimeshSurface: The transformed surface.
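+
+        Example:
+            A minimal, illustrative sketch; ``surface`` and ``transform`` are
+            assumed to share the same source coordinate reference system::
+
+                transformed = surface.apply_transform(transform)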
@@ -234,7 +235,11 @@ def smooth(self, lamb: float) -> "TrimeshSurface":
         smoothed = trimesh.smoothing.filter_taubin(self.mesh, lamb=lamb)
         return TrimeshSurface(smoothed, self.crs, self.units)
 
-    def get_vertex_normals(self, points: cdt.LabeledPointCloud, normalized=True):
+    def get_vertex_normals(
+        self,
+        points: cdt.LabeledPointCloud,
+        normalized: bool = True
+    ):
         """Get normals of vertices closest to the provided points."""
 
         assert points.points.crs == self.crs
@@ -400,7 +405,11 @@ def apply_transform(self, transform: cdt.AffineTransform) -> "PycortexSurface":
 
     def decimate(self, face_count: int) -> "PycortexSurface":
         raise NotImplementedError("Decimation not implemented for PycortexSurface")
 
-    def get_vertex_normals(self, points: cdt.LabeledPointCloud, normalized=True):
+    def get_vertex_normals(
+        self,
+        points: cdt.LabeledPointCloud,
+        normalized: bool = True
+    ):
         assert points.points.crs == self.crs
         assert points.pint.units == self.units
         points = points.pint.dequantify()
@@ -601,7 +610,11 @@ def avg_edge_length(self):
 
         return edgelens.mean()
 
-    def surface_gradient(self, scalars, at_verts=True):
+    def surface_gradient(
+        self,
+        scalars: np.ndarray,
+        at_verts: bool = True
+    ) -> np.ndarray:
         """Gradient of a function with values `scalars` at each vertex on the surface.
 
         If `at_verts`, returns values at each vertex. Otherwise, returns values at each
@@ -610,9 +623,8 @@ def surface_gradient(self, scalars, at_verts=True):
         Args:
             scalars : 1D ndarray, shape (total_verts,)
                 a scalar-valued function across the cortex.
-            at_verts : bool, optional
-                If True (default), values will be returned for each vertex. Otherwise,
-                values will be returned for each face.
+            at_verts: If True (default), values will be returned for each vertex.
+                Otherwise, values will be returned for each face.
 
         Returns:
             gradu : 2D ndarray, shape (total_verts,3) or (total_polys,3)
@@ -642,7 +654,7 @@ def _facenorm_cross_edge(self):
 
         return fe12, fe23, fe31
 
-    def geodesic_distance(self, verts, m=1.0, fem=False):
-        """Calcualte the inimum mesh geodesic distance (in mm).
+    def geodesic_distance(
+        self, verts, m: float = 1.0, fem: bool = False
+    ) -> np.ndarray:
+        """Calculate the minimum mesh geodesic distance (in mm).
 
         The geodesic distance is calculated from each vertex in surface to any vertex in
@@ -745,7 +757,13 @@ def geodesic_distance(self, verts, m=1.0, fem=False):
 
         return phi
 
-    def geodesic_path(self, a, b, max_len=1000, d=None, **kwargs):
+    def geodesic_path(
+        self,
+        a: int,
+        b: int,
+        max_len: int = 1000,
+        d: np.ndarray | None = None,
+        **kwargs
+    ) -> List:
         """Finds the shortest path between two points `a` and `b`.
 
         This shortest path is based on geodesic distances across the surface.
diff --git a/src/cedalion/dataclasses/recording.py b/src/cedalion/dataclasses/recording.py
index c9057fb3..c9dbda11 100644
--- a/src/cedalion/dataclasses/recording.py
+++ b/src/cedalion/dataclasses/recording.py
@@ -70,7 +70,7 @@ def get_timeseries(self, key: Optional[str] = None) -> NDTimeSeries:
         """Get a timeseries object by key.
 
         Args:
-            key (Optional[str]): The key of the timeseries to retrieve. If None, the
+            key: The key of the timeseries to retrieve. If None, the
                 last timeseries is returned.
 
         Returns:
@@ -104,7 +104,7 @@ def get_mask(self, key: Optional[str] = None) -> xr.DataArray:
         """Get a mask by key.
 
         Args:
-            key (Optional[str]): The key of the mask to retrieve. If None, the last
+            key: The key of the mask to retrieve. If None, the last
                 mask is returned.
 
         Returns:
@@ -124,9 +124,9 @@ def set_mask(self, key: str, value: xr.DataArray, overwrite: bool = False):
         """Set a mask.
Args: - key (str): The key of the mask to set. - value (xr.DataArray): The mask to set. - overwrite (bool): Whether to overwrite an existing mask with the same key. + key: The key of the mask to set. + value: The mask to set. + overwrite: Whether to overwrite an existing mask with the same key. Defaults to False. """ if (overwrite is False) and (key in self.masks): @@ -138,7 +138,7 @@ def get_timeseries_type(self, key): """Get the type of a timeseries. Args: - key (str): The key of the timeseries. + key: The key of the timeseries. Returns: str: The type of the timeseries. diff --git a/src/cedalion/dataclasses/schemas.py b/src/cedalion/dataclasses/schemas.py index 1510ba52..d5c5b0eb 100644 --- a/src/cedalion/dataclasses/schemas.py +++ b/src/cedalion/dataclasses/schemas.py @@ -1,5 +1,6 @@ """Data array schemas and utilities to build labeled data arrays.""" +from __future__ import annotations import functools import inspect import typing @@ -99,20 +100,20 @@ def build_timeseries( value_units: str, time_units: str, other_coords: dict[str, ArrayLike] = {}, -): +) -> xr.DataArray: """Build a labeled time series data array. Args: - data (ArrayLike): The data values. - dims (List[str]): The dimension names. - time (ArrayLike): The time values. - channel (List[str]): The channel names. - value_units (str): The units of the data values. - time_units (str): The units of the time values. - other_coords (dict[str, ArrayLike]): Additional coordinates. + data: The data values. + dims: The dimension names. + time: The time values. + channel: The channel names. + value_units: The units of the data values. + time_units: The units of the time values. + other_coords: Additional coordinates. Returns: - da (xr.DataArray): The labeled time series data array. + da: The labeled time series data array. """ assert len(dims) == data.ndim assert "time" in dims diff --git a/src/cedalion/geometry/landmarks.py b/src/cedalion/geometry/landmarks.py index 3952b468..6bc333e4 100644 --- a/src/cedalion/geometry/landmarks.py +++ b/src/cedalion/geometry/landmarks.py @@ -1,5 +1,6 @@ """Module for constructing the 10-10-system on the scalp surface.""" +from __future__ import annotations import warnings from typing import List, Optional @@ -125,8 +126,8 @@ def __init__(self, scalp_surface: Surface, landmarks: LabeledPointCloud): """Initialize the LandmarksBuilder1010. Args: - scalp_surface (Surface): a triangle-mesh representing the scalp - landmarks (LabeledPointCloud): positions of "Nz", "Iz", "LPA", "RPA" + scalp_surface: a triangle-mesh representing the scalp + landmarks: positions of "Nz", "Iz", "LPA", "RPA" """ if isinstance(scalp_surface, TrimeshSurface): scalp_surface = VTKSurface.from_trimeshsurface(scalp_surface) @@ -197,9 +198,9 @@ def _add_landmarks_along_line( """Add landmarks along a line defined by three landmarks. Args: - triangle_labels (List[str]): Labels of the three landmarks defining the line - labels (List[str]): Labels for the new landmarks - dists (List[float]): Distances along the line where the new landmarks should + triangle_labels: Labels of the three landmarks defining the line + labels: Labels for the new landmarks + dists: Distances along the line where the new landmarks should be placed. """ assert len(triangle_labels) == 3 @@ -319,8 +320,8 @@ def order_ref_points_6(landmarks: xr.DataArray, twoPoints: str) -> xr.DataArray: """Reorder a set of six landmarks based on spatial relationships and give labels. 
    Args:
-        landmarks (xr.DataArray): coordinates for six landmark points
-        twoPoints (str): two reference points ('Nz' or 'Iz') for orientation.
+        landmarks: Coordinates for six landmark points.
+        twoPoints: Two reference points ('Nz' or 'Iz') for orientation.
 
     Returns:
         xr.DataArray: the landmarks ordered as "Nz", "Iz", "RPA", "LPA", "Cz"
diff --git a/src/cedalion/geometry/photogrammetry/processors.py b/src/cedalion/geometry/photogrammetry/processors.py
index 32be2340..55a9310f 100644
--- a/src/cedalion/geometry/photogrammetry/processors.py
+++ b/src/cedalion/geometry/photogrammetry/processors.py
@@ -1,5 +1,6 @@
 """Vertex classifiers."""
 
+from __future__ import annotations
 import colorsys
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
diff --git a/src/cedalion/geometry/registration.py b/src/cedalion/geometry/registration.py
index 93ae2d37..3d50adea 100644
--- a/src/cedalion/geometry/registration.py
+++ b/src/cedalion/geometry/registration.py
@@ -1,10 +1,12 @@
-"""Registrating optodes to scalp surfaces."""
+"""Registering optodes to scalp surfaces."""
 
+from __future__ import annotations
+from typing import Tuple
+
 import numpy as np
 from numpy.linalg import pinv
 from scipy.optimize import linear_sum_assignment, minimize
 from scipy.spatial import KDTree
 import xarray as xr
 
 import cedalion
 import cedalion.dataclasses as cdc
@@ -38,8 +40,8 @@ def register_trans_rot(
     between the two point clouds.
 
     Args:
-        coords_target (LabeledPointCloud): Target point cloud.
-        coords_trafo (LabeledPointCloud): Source point cloud.
+        coords_target: Target point cloud.
+        coords_trafo: Source point cloud.
 
     Returns:
         cdt.AffineTransform: Affine transformation between the two point clouds.
@@ -133,8 +135,8 @@ def register_trans_rot_isoscale(
     between the two point clouds.
 
    Args:
-        coords_target (LabeledPointCloud): Target point cloud.
-        coords_trafo (LabeledPointCloud): Source point cloud.
+        coords_target: Target point cloud.
+        coords_trafo: Source point cloud.
 
     Returns:
         cdt.AffineTransform: Affine transformation between the two point clouds.
@@ -195,9 +197,9 @@ def gen_xform_from_pts(p1: np.ndarray, p2: np.ndarray) -> np.ndarray:
     """Calculate the affine transformation matrix T that transforms p1 to p2.
 
     Args:
-        p1 (np.ndarray): Source points (p x m) where p is the number of points and m is
+        p1: Source points (p x m) where p is the number of points and m is
             the number of dimensions.
-        p2 (np.ndarray): Target points (p x m) where p is the number of points and m is
+        p2: Target points (p x m) where p is the number of points and m is
             the number of dimensions.
 
     Returns:
@@ -237,21 +239,21 @@ def register_icp(
     surface: cdc.Surface,
     landmarks: cdt.LabeledPointCloud,
     geo3d: cdt.LabeledPointCloud,
-    niterations=1000,
-    random_sample_fraction=0.5,
-):
+    niterations: int = 1000,
+    random_sample_fraction: float = 0.5,
+) -> Tuple[np.ndarray, np.ndarray]:
     """Iterative Closest Point algorithm for registration.
 
     Args:
-        surface (Surface): Surface mesh to which to register the points.
-        landmarks (LabeledPointCloud): Landmarks to use for registration.
-        geo3d (LabeledPointCloud): Points to register to the surface.
-        niterations (int): Number of iterations for the ICP algorithm (default 1000).
-        random_sample_fraction (float): Fraction of points to use in each iteration
+        surface: Surface mesh to which to register the points.
+        landmarks: Landmarks to use for registration.
+        geo3d: Points to register to the surface.
+        niterations: Number of iterations for the ICP algorithm (default 1000).
+        random_sample_fraction: Fraction of points to use in each iteration
             (default 0.5).
    Returns:
-        Tuple[np.ndarray, np.ndarray]: Tuple containing the losses and transformations
+        Tuple containing the losses and transformations.
     """
     units = "mm"
     landmarks_mm = landmarks.pint.to(units).points.to_homogeneous().pint.dequantify()
@@ -496,7 +498,7 @@ def simple_scalp_projection(geo3d: cdt.LabeledPointCloud) -> cdt.LabeledPointClo
     """Projects 3D coordinates onto a 2D plane using a simple scalp projection.
 
     Args:
-        geo3d (LabeledPointCloud): 3D coordinates of points to project. Requires the
+        geo3d: 3D coordinates of points to project. Requires the
             landmarks Nz, LPA, and RPA.
 
     Returns:
diff --git a/src/cedalion/geometry/segmentation.py b/src/cedalion/geometry/segmentation.py
index 6ec5b1bd..b4b0af87 100644
--- a/src/cedalion/geometry/segmentation.py
+++ b/src/cedalion/geometry/segmentation.py
@@ -21,24 +21,19 @@ def voxels_from_segmentation(
     segmentation_mask: xr.DataArray,
     segmentation_types: List[str],
-    isovalue=0.9,
-    fill_holes_in_mask=False,
+    isovalue: float = 0.9,
+    fill_holes_in_mask: bool = False,
 ) -> cdc.Voxels:
     """Generate voxels from a segmentation mask.
 
     Args:
-        segmentation_mask : xr.DataArray
-            Segmentation mask.
-        segmentation_types : List[str]
-            List of segmentation types.
-        isovalue : float, optional
-            Isovalue for marching cubes, by default 0.9.
-        fill_holes_in_mask : bool, optional
-            Fill holes in the mask, by default False.
+        segmentation_mask: Segmentation mask.
+        segmentation_types: List of segmentation types.
+        isovalue: Isovalue for marching cubes, by default 0.9.
+        fill_holes_in_mask: Fill holes in the mask, by default False.
 
     Returns:
-        cdc.Voxels
-            Voxels in voxel space.
+        Voxels in voxel space.
     """
     combined_mask = (
         segmentation_mask.sel(segmentation_type=segmentation_types)
@@ -65,16 +60,16 @@ def surface_from_segmentation(
     """Create a surface from a segmentation mask.
 
     Args:
-        segmentation_mask (xr.DataArray): Segmentation mask with dimensions segmentation
+        segmentation_mask: Segmentation mask with dimensions segmentation
            type, i, j, k.
-        segmentation_types (List[str]): A list of segmentation types to include in the
+        segmentation_types: A list of segmentation types to include in the
            surface.
-        isovalue (Float): The isovalue to use for the marching cubes algorithm.
-        fill_holes_in_mask (Bool): Whether to fill holes in the mask before creating the
+        isovalue: The isovalue to use for the marching cubes algorithm.
+        fill_holes_in_mask: Whether to fill holes in the mask before creating the
            surface.
 
     Returns:
-        A cedalion.Surface object.
+        Surface in voxel space.
 
     """
     combined_mask = (
@@ -82,25 +77,7 @@ def surface_from_segmentation(
         .any("segmentation_type")
         .values
     )
-    """ Generate a surface from a segmentation mask.
-
-    Parameters
-    ----------
-    segmentation_mask : xr.DataArray
-        Segmentation mask.
-    segmentation_types : List[str]
-        List of segmentation types.
-    isovalue : float, optional
-        Isovalue for marching cubes, by default 0.9.
-    fill_holes_in_mask : bool, optional
-        Fill holes in the mask, by default False.
-
-    Returns
-    -------
-    cdc.TrimeshSurface
-        Surface in voxel space.
-    """
 
     if fill_holes_in_mask:
         combined_mask = scipy.ndimage.binary_fill_holes(combined_mask).astype(
             combined_mask.dtype
@@ -114,19 +91,14 @@ def surface_from_segmentation(
     return cdc.TrimeshSurface(mesh, "ijk", cedalion.units.Unit("1"))
 
 
-def cell_coordinates(volume, flat: bool = False):
+def cell_coordinates(volume: np.ndarray, flat: bool = False) -> xr.DataArray:
     """Generate cell coordinates from a 3D volume.
 
-    Parameters
-    ----------
-    volume : np.ndarray
-        3D volume.
-    flat : bool, optional
-        If True, return coordinates as a flat array, by default False.
+    Args:
+        volume: 3D volume.
+        flat: If True, return coordinates as a flat array, by default False.
 
     Returns:
-    -------
-    xr.DataArray
         Cell coordinates in voxel space.
 
     """
@@ -190,37 +162,23 @@ def segmentation_postprocessing(
 ) -> dict:
     """Postprocessing of the segmented SPM12 MRI segmentation files.
 
-    Parameters
-    ----------
-    segmentation_dir : str
-        Directory where the segmented files are stored.
-    mask_files : dict[str, str], optional
-        Dictionary containing the filenames of the segmented tissues.
-    isSmooth : bool, optional
-        Smooth the segmented tissues using Gaussian filter.
-    fixCSF : bool, optional
-        Fix the CSF continuity.
-    removeDisconnected : bool, optional
-        Remove disconnected voxels.
-    labelUnassigned : bool, optional
-        Label empty voxels to the nearest tissue type.
-    removeAir : bool, optional
-        Remove air cavities.
-    subtractTissues : bool, optional
-        Subtract tissues from each others
-
+    Args:
+        segmentation_dir: Directory where the segmented files are stored.
+        mask_files: Dictionary containing the filenames of the segmented tissues.
+        isSmooth: Smooth the segmented tissues using Gaussian filter.
+        fixCSF: Fix the CSF continuity.
+        removeDisconnected: Remove disconnected voxels.
+        labelUnassigned: Label empty voxels to the nearest tissue type.
+        removeAir: Remove air cavities.
+        subtractTissues: Subtract tissues from each other.
 
     Returns:
-    -------
-    mask_files : dict
         Dictionary containing the filenames of the postprocessed masks.
 
-    References:
-    ----------
-    This whole postprocessing is based on the following references:
-    :cite:t:`Huang2013`
-    :cite:t:`Harmening2022`
+    References:
+        This whole postprocessing is based on the following references:
+        :cite:t:`Huang2013`, :cite:t:`Harmening2022`
     """
 
     # Load segmented spm output files
diff --git a/src/cedalion/imagereco/forward_model.py b/src/cedalion/imagereco/forward_model.py
index f21400f4..c645b9bb 100644
--- a/src/cedalion/imagereco/forward_model.py
+++ b/src/cedalion/imagereco/forward_model.py
@@ -40,38 +40,22 @@ class TwoSurfaceHeadModel:
     surfaces.
 
     Attributes:
-        segmentation_masks : xr.DataArray
-            Segmentation masks of the head for each tissue type.
-        brain : cdc.Surface
-            Surface of the brain.
-        scalp : cdc.Surface
-            Surface of the scalp.
-        landmarks : cdt.LabeledPointCloud
-            Anatomical landmarks in RAS space.
-        t_ijk2ras : cdt.AffineTransform
-            Affine transformation from ijk to RAS space.
-        t_ras2ijk : cdt.AffineTransform
-            Affine transformation from RAS to ijk space.
-        voxel_to_vertex_brain : scipy.sparse.spmatrix
-            Mapping from voxel to brain vertices.
-        voxel_to_vertex_scalp : scipy.sparse.spmatrix
-            Mapping from voxel to scalp vertices.
-        crs : str
-            Coordinate reference system of the head model.
+        segmentation_masks: Segmentation masks of the head for each tissue type.
+        brain: Surface of the brain.
+        scalp: Surface of the scalp.
+        landmarks: Anatomical landmarks in RAS space.
+        t_ijk2ras: Affine transformation from ijk to RAS space.
+        t_ras2ijk: Affine transformation from RAS to ijk space.
+        voxel_to_vertex_brain: Mapping from voxel to brain vertices.
+        voxel_to_vertex_scalp: Mapping from voxel to scalp vertices.
+        crs: Coordinate reference system of the head model.
 
     Methods:
-        from_segmentation(cls, segmentation_dir, mask_files, landmarks_ras_file,
-            brain_seg_types, scalp_seg_types, smoothing, brain_face_count,
-            scalp_face_count): Construct instance from segmentation masks in NIfTI
-            format.
-        apply_transform(transform)
-            Apply a coordinate transformation to the head model.
-        save(foldername)
-            Save the head model to a folder.
-        load(foldername)
-            Load the head model from a folder.
-        align_and_snap_to_scalp(points)
-            Align and snap optodes or points to the scalp surface.
+        from_segmentation: Construct instance from segmentation masks in NIfTI format.
+        apply_transform: Apply a coordinate transformation to the head model.
+        save: Save the head model to a folder.
+        load: Load the head model from a folder.
+        align_and_snap_to_scalp: Align and snap optodes or points to the scalp surface.
     """
 
     segmentation_masks: xr.DataArray
@@ -107,19 +91,15 @@ def from_segmentation(
         """Constructor from binary masks as gained from segmented MRI scans.
 
         Args:
-            segmentation_dir (str): Folder containing the segmentation masks in NIFTI
-                format.
-            mask_files (Dict[str, str]): Dictionary mapping segmentation types to NIFTI
-                filenames.
-            landmarks_ras_file (Optional[str]): Filename of the landmarks in RAS space.
-            brain_seg_types (list[str]): List of segmentation types to be included in
-                the brain surface.
-            scalp_seg_types (list[str]): List of segmentation types to be included in
-                the scalp surface.
-            smoothing(float): Smoothing factor for the brain and scalp surfaces.
-            brain_face_count (Optional[int]): Number of faces for the brain surface.
-            scalp_face_count (Optional[int]): Number of faces for the scalp surface.
-            fill_holes (bool): Whether to fill holes in the segmentation masks.
+            segmentation_dir: Folder containing the segmentation masks in NIFTI
+                format.
+            mask_files: Dictionary mapping segmentation types to NIFTI filenames.
+            landmarks_ras_file: Filename of the landmarks in RAS space.
+            brain_seg_types: List of segmentation types to be included in the
+                brain surface.
+            scalp_seg_types: List of segmentation types to be included in the
+                scalp surface.
+            smoothing: Smoothing factor for the brain and scalp surfaces.
+            brain_face_count: Number of faces for the brain surface.
+            scalp_face_count: Number of faces for the scalp surface.
+            fill_holes: Whether to fill holes in the segmentation masks.
         """
 
         # load segmentation mask
@@ -231,24 +211,20 @@ def from_surfaces(
         """Constructor from seg.masks, brain and head surfaces as gained from MRI scans.
 
         Args:
-            segmentation_dir (str): Folder containing the segmentation masks in NIFTI
-                format.
-            mask_files (dict[str, str]): Dictionary mapping segmentation types to NIFTI
-                filenames.
-            brain_surface_file (str): Path to the brain surface.
-            scalp_surface_file (str): Path to the scalp surface.
-            landmarks_ras_file (Optional[str]): Filename of the landmarks in RAS space.
-            brain_seg_types (list[str]): List of segmentation types to be included in
-                the brain surface.
-            scalp_seg_types (list[str]): List of segmentation types to be included in
-                the scalp surface.
-            smoothing (float): Smoothing factor for the brain and scalp surfaces.
-            brain_face_count (Optional[int]): Number of faces for the brain surface.
-            scalp_face_count (Optional[int]): Number of faces for the scalp surface.
-            fill_holes (bool): Whether to fill holes in the segmentation masks.
+            segmentation_dir: Folder containing the segmentation masks in NIFTI
+                format.
+            mask_files: Dictionary mapping segmentation types to NIFTI filenames.
+            brain_surface_file: Path to the brain surface.
+            scalp_surface_file: Path to the scalp surface.
+            landmarks_ras_file: Filename of the landmarks in RAS space.
+            brain_seg_types: List of segmentation types to be included in the
+                brain surface.
+            scalp_seg_types: List of segmentation types to be included in the
+                scalp surface.
+            smoothing: Smoothing factor for the brain and scalp surfaces.
+            brain_face_count: Number of faces for the brain surface.
+            scalp_face_count: Number of faces for the scalp surface.
+            fill_holes: Whether to fill holes in the segmentation masks.
 
         Returns:
-            TwoSurfaceHeadModel: An instance of the TwoSurfaceHeadModel class.
+            An instance of the TwoSurfaceHeadModel class.
         """
 
         # load segmentation mask
@@ -382,7 +358,7 @@ def save(self, foldername: str):
         """Save the head model to a folder.
 
         Args:
-            foldername (str): Folder to save the head model into.
+            foldername: Folder to save the head model into.
 
         Returns:
             None
@@ -417,7 +393,7 @@ def load(cls, foldername: str):
         """Load the head model from a folder.
 
         Args:
-            foldername (str): Folder to load the head model from.
+            foldername: Folder to load the head model from.
 
         Returns:
             TwoSurfaceHeadModel: Loaded head model.
@@ -486,11 +462,11 @@ def align_and_snap_to_scalp(
         """Align and snap optodes or points to the scalp surface.
 
         Args:
-            points (cdt.LabeledPointCloud): Points to be aligned and snapped to the
+            points: Points to be aligned and snapped to the
                 scalp surface.
 
         Returns:
-            cdt.LabeledPointCloud: Points aligned and snapped to the scalp surface.
+            Points aligned and snapped to the scalp surface.
         """
 
         assert self.landmarks is not None, "Please add landmarks in RAS to head \
@@ -509,11 +485,11 @@ def snap_to_scalp_voxels(
         """Snap optodes or points to the closest scalp voxel.
 
         Args:
-            points (cdt.LabeledPointCloud): Points to be snapped to the closest scalp
+            points: Points to be snapped to the closest scalp
                 voxel.
 
         Returns:
-            cdt.LabeledPointCloud: Points aligned and snapped to the closest scalp
+            Points aligned and snapped to the closest scalp
                 voxel.
         """
 
         # Align to scalp surface
@@ -611,11 +587,9 @@ def __init__(
         """Constructor for the forward model.
 
         Args:
-            head_model (TwoSurfaceHeadModel): Head model containing voxel projections to
-                brain and scalp surfaces.
-            geo3d (cdt.LabeledPointCloud): Optode positions and directions.
-            measurement_list (pd.DataFrame): List of measurements of experiment with
-                source, detector, channel and wavelength.
+            head_model: Head model containing voxel projections to brain and
+                scalp surfaces.
+            geo3d: Optode positions and directions.
+            measurement_list: List of measurements of experiment with source,
+                detector, channel, and wavelength.
         """
 
         assert head_model.crs == "ijk"  # FIXME
@@ -713,11 +687,11 @@ def _fluence_at_optodes(self, fluence, emitting_opt):
         """Fluence caused by one optode at the positions of all other optodes.
 
         Args:
-            fluence (np.ndarray): Fluence in each voxel.
-            emitting_opt (int): Index of the emitting optode.
+            fluence: Fluence in each voxel.
+            emitting_opt: Index of the emitting optode.
 
         Returns:
-            np.ndarray: Fluence at all optode positions.
+            Fluence at all optode positions.
         """
 
         n_optodes = len(self.optode_pos)
@@ -842,7 +816,7 @@ def compute_fluence_nirfaster(self, meshingparam=None):
             mesher. Note: they should all be double
 
         Returns:
-            xr.DataArray: Fluence in each voxel for each channel and wavelength.
+            Fluence in each voxel for each channel and wavelength.
 
         References:
             (:cite:t:`Dehghani2009`) Dehghani, Hamid, et al.
"Near infrared optical @@ -963,16 +937,15 @@ def compute_fluence_nirfaster(self, meshingparam=None): return fluence_all, fluence_at_optodes - def compute_sensitivity(self, fluence_all, fluence_at_optodes): + def compute_sensitivity(self, fluence_all: xr.DataArray, fluence_at_optodes: xr.DataArray): """Compute sensitivity matrix from fluence. Args: - fluence_all (xr.DataArray): Fluence in each voxel for each wavelength. - fluence_at_optodes (xr.DataArray): Fluence at all optode positions for each - wavelength. + fluence_all: Fluence in each voxel for each wavelength. + fluence_at_optodes: Fluence at all optode positions for each wavelength. Returns: - xr.DataArray: Sensitivity matrix for each channel, vertex and wavelength. + Sensitivity matrix for each channel, vertex and wavelength. """ channels = self.measurement_list.channel.unique().tolist() @@ -1045,11 +1018,10 @@ def compute_stacked_sensitivity(sensitivity: xr.DataArray): """Compute stacked HbO and HbR sensitivity matrices from fluence. Args: - sensitivity (xr.DataArray): Sensitivity matrix for each vertex and - wavelength. + sensitivity: Sensitivity matrix for each vertex and wavelength. Returns: - xr.DataArray: Stacked sensitivity matrix for each channel and vertex. + Stacked sensitivity matrix for each channel and vertex. """ assert "wavelength" in sensitivity.dims diff --git a/src/cedalion/imagereco/solver.py b/src/cedalion/imagereco/solver.py index f8c9e543..48bc77a4 100644 --- a/src/cedalion/imagereco/solver.py +++ b/src/cedalion/imagereco/solver.py @@ -6,15 +6,15 @@ import cedalion.xrutils as xrutils -def pseudo_inverse_stacked(Adot, alpha=0.01): +def pseudo_inverse_stacked(Adot: xr.DataArray, alpha: float = 0.01): """Pseudo-inverse of a stacked matrix. Args: - Adot (xr.DataArray): Stacked matrix. - alpha (float): Regularization parameter. + Adot: Stacked matrix. + alpha: Regularization parameter. Returns: - xr.DataArray: Pseudo-inverse of the stacked matrix. + Pseudo-inverse of the stacked matrix. """ if "units" in Adot.attrs: diff --git a/src/cedalion/imagereco/utils.py b/src/cedalion/imagereco/utils.py index 4e5529f6..3d09cc8f 100644 --- a/src/cedalion/imagereco/utils.py +++ b/src/cedalion/imagereco/utils.py @@ -1,5 +1,6 @@ """Utility functions for image reconstruction.""" +from __future__ import annotations import xarray as xr import numpy as np import cedalion @@ -18,19 +19,16 @@ def map_segmentation_mask_to_surface( segmentation_mask: xr.DataArray, transform_vox2ras: cdt.AffineTransform, # FIXME surface: cdc.Surface, -): +) -> coo_array: """Find for each voxel the closest vertex on the surface. Args: - segmentation_mask (xr.DataArray): A binary mask of shape (segmentation_type, i, - j, k). - transform_vox2ras (xr.DataArray): The affine transformation from voxel to RAS - space. - surface (cedalion.dataclasses.Surface): The surface to map the voxels to. + segmentation_mask: A binary mask of shape (segmentation_type, i, j, k). + transform_vox2ras: The affine transformation from voxel to RAS space. + surface: The surface to map the voxels to. Returns: - coo_array: A sparse matrix of shape (ncells, nvertices) that maps voxels to - cells. + A sparse matrix of shape (ncells, nvertices) that maps voxels to cells. """ assert surface.crs == transform_vox2ras.dims[0] @@ -61,17 +59,17 @@ def map_segmentation_mask_to_surface( return map_voxel_to_vertex -def normal_hrf(t, t_peak, t_std, vmax): - """Create a normal hrf. +def normal_hrf(t: np.ndarray, t_peak: float, t_std: float, vmax: float) -> np.ndarray: + """Create a normal HRF. 
Args: - t (np.ndarray): The time points. - t_peak (float): The peak time. - t_std (float): The standard deviation. - vmax (float): The maximum value of the HRF. + t: The time points. + t_peak: The peak time. + t_std: The standard deviation. + vmax: The maximum value of the HRF. Returns: - np.ndarray: The HRF. + The HRF. """ hrf = scipy.stats.norm.pdf(t, loc=t_peak, scale=t_std) hrf *= vmax / hrf.max() @@ -85,20 +83,19 @@ def create_mock_activation_below_point( sampling_rate: units.Quantity, spatial_size: units.Quantity, vmax: units.Quantity, -): +) -> xr.DataArray: """Create a mock activation below a point. Args: - head_model (cedalion.imagereco.forward_model.TwoSurfaceHeadModel): The head - model. - point (cdt.LabeledPointCloud): The point below which to create the activation. - time_length (units.Quantity): The length of the activation. - sampling_rate (units.Quantity): The sampling rate. - spatial_size (units.Quantity): The spatial size of the activation. - vmax (units.Quantity): The maximum value of the activation. + head_model: The head model. + point: The point below which to create the activation. + time_length: The length of the activation. + sampling_rate: The sampling rate. + spatial_size: The spatial size of the activation. + vmax: The maximum value of the activation. Returns: - xr.DataArray: The activation. + The activation. """ # assert head_model.crs == point.points.crs diff --git a/src/cedalion/io/anatomy.py b/src/cedalion/io/anatomy.py index 40fe434d..bb0cc928 100644 --- a/src/cedalion/io/anatomy.py +++ b/src/cedalion/io/anatomy.py @@ -21,10 +21,10 @@ def _get_affine_from_niftii(image: nibabel.nifti1.Nifti1Image): """Get affine transformation matrix from NIFTI image. Args: - image (nibabel.nifti1.Nifti1Image): NIFTI image object + image: NIFTI image object Returns: - xr.DataArray: Affine transformation matrix + Affine transformation matrix """ transform, code = image.get_sform(coded=True) if code != 0: @@ -56,15 +56,14 @@ def read_segmentation_masks( """Read segmentation masks from NIFTI files. Args: - basedir (str): Directory containing the mask files - mask_files (Dict[str, str]): Dictionary mapping segmentation types to filenames + basedir: Directory containing the mask files + mask_files: Dictionary mapping segmentation types to filenames Returns: - Tuple[xr.DataArray, np.ndarray]: - - masks (xr.DataArray): Concatenated segmentation masks with a new - dimension `segmentation_type`. - - affine (np.ndarray): Affine transformation matrix associated with the - NIFTI files. + - masks (xr.DataArray): Concatenated segmentation masks with a new + dimension `segmentation_type`. + - affine (np.ndarray): Affine transformation matrix associated with the + NIFTI files. """ mask_ids = {seg_type: i + 1 for i, seg_type in enumerate(mask_files.keys())} masks = [] @@ -118,16 +117,16 @@ def read_segmentation_masks( return masks, affine -def cell_coordinates(mask, affine, units="mm"): +def cell_coordinates(mask: xr.DataArray, affine: np.ndarray, units: str = "mm"): """Get the coordinates of each voxel in the transformed mask. Args: - mask (xr.DataArray): A binary mask of shape (i, j, k). - affine (np.ndarray): Affine transformation matrix. - units (str): Units of the output coordinates. + mask: A binary mask of shape (i, j, k). + affine: Affine transformation matrix. + units: Units of the output coordinates. Returns: - xr.DataArray: Coordinates of the center of each voxel in the mask. + Coordinates of the center of each voxel in the mask. 
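+
+    Example:
+        A minimal, illustrative sketch; ``masks`` and ``affine`` are assumed to
+        come from :func:`read_segmentation_masks` above::
+
+            coords = cell_coordinates(masks.isel(segmentation_type=0), affine)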
""" # coordinates in voxel space i = np.arange(mask.shape[0]) diff --git a/src/cedalion/io/forward_model.py b/src/cedalion/io/forward_model.py index 9b5022c5..70a66833 100644 --- a/src/cedalion/io/forward_model.py +++ b/src/cedalion/io/forward_model.py @@ -2,6 +2,7 @@ import h5py import xarray as xr +from typing import Tuple import cedalion.dataclasses as cdc @@ -10,8 +11,8 @@ def save_Adot(fn: str, Adot: xr.DataArray): """Save Adot to a netCDF file. Args: - fn (str): File name to save the data to. - Adot (xr.DataArray): Data to save. + fn: File name to save the data to. + Adot: Data to save. Returns: None @@ -20,14 +21,14 @@ def save_Adot(fn: str, Adot: xr.DataArray): Adot.to_netcdf(fn) return -def load_Adot(fn: str): +def load_Adot(fn: str) -> xr.DataArray: """Load Adot from a netCDF file. Args: - fn (str): File name to load the data from. + fn: File name to load the data from. Returns: - xr.DataArray: Data loaded from the file. + Data loaded from the file. """ Adot = xr.open_dataset(fn) @@ -83,14 +84,14 @@ def save_fluence(fn : str, fluence_all, fluence_at_optodes): f.flush() -def load_fluence(fn : str): +def load_fluence(fn : str) -> Tuple[xr.DataArray, xr.DataArray]: """Load forward model computation results. Args: - fn (str): File name to load the data from. + fn: File name to load the data from. Returns: - Tuple[xr.DataArray, xr.DataArray]: Fluence data loaded from the file. + Fluence data loaded from the file. """ with h5py.File(fn, "r") as f: diff --git a/src/cedalion/io/photogrammetry.py b/src/cedalion/io/photogrammetry.py index 53b811d4..aac4e444 100644 --- a/src/cedalion/io/photogrammetry.py +++ b/src/cedalion/io/photogrammetry.py @@ -1,25 +1,24 @@ """Module for reading photogrammetry output file formats.""" +from __future__ import annotations import cedalion.dataclasses as cdc import numpy as np from collections import OrderedDict -def read_photogrammetry_einstar(fn): +def read_photogrammetry_einstar(fn: str) -> tuple: """Read optodes and fiducials from photogrammetry pipeline. This method reads the output file as returned by the photogrammetry pipeline using an einstar device. Args: - fn (str): The filename of the einstar photogrammetry output file. + fn: The filename of the einstar photogrammetry output file. Returns: - tuple: A tuple containing: - - fiducials (cedalion.LabeledPoints): The fiducials as a cedalion - LabeledPoints object. - - optodes (cedalion.LabeledPoints): The optodes as a cedalion LabeledPoints - object. + A tuple containing: + - fiducials: The fiducials as a cedalion LabeledPoints object. + - optodes: The optodes as a cedalion LabeledPoints object. """ fiducials, optodes = read_einstar(fn) @@ -27,16 +26,16 @@ def read_photogrammetry_einstar(fn): return fiducials, optodes -def read_einstar(fn): +def read_einstar(fn: str) -> tuple: """Read optodes and fiducials from einstar devices. Args: - fn (str): The filename of the einstar photogrammetry output file. + fn: The filename of the einstar photogrammetry output file. Returns: - tuple: A tuple containing: - - fiducials (OrderedDict): The fiducials as an OrderedDict. - - optodes (OrderedDict): The optodes as an OrderedDict. + A tuple containing: + - fiducials: The fiducials as an OrderedDict. + - optodes: The optodes as an OrderedDict. 
""" with open(fn, "r") as f: @@ -52,19 +51,17 @@ def read_einstar(fn): return fiducials, optodes -def opt_fid_to_xr(fiducials, optodes): +def opt_fid_to_xr(fiducials: OrderedDict, optodes: OrderedDict) -> tuple: """Convert OrderedDicts fiducials and optodes to cedalion LabeledPoints objects. Args: - fiducials (OrderedDict): The fiducials as an OrderedDict. - optodes (OrderedDict): The optodes as an OrderedDict. + fiducials: The fiducials as an OrderedDict. + optodes: The optodes as an OrderedDict. Returns: - tuple: A tuple containing: - - fiducials (cedalion.LabeledPoints): The fiducials as a cedalion - LabeledPoints object. - - optodes (cedalion.LabeledPoints): The optodes as a cedalion LabeledPoints - object. + A tuple containing: + - fiducials: The fiducials as a cedalion LabeledPoints object. + - optodes: The optodes as a cedalion LabeledPoints object. """ # FIXME: this should get a different CRS diff --git a/src/cedalion/io/probe_geometry.py b/src/cedalion/io/probe_geometry.py index 09b659cf..f5fbd439 100644 --- a/src/cedalion/io/probe_geometry.py +++ b/src/cedalion/io/probe_geometry.py @@ -1,5 +1,6 @@ """Module for reading and writing probe geometry files.""" +from __future__ import annotations import numpy as np import xarray as xr import trimesh @@ -13,18 +14,13 @@ def load_tsv(tsv_fname: str, crs: str='digitized', units: str='mm') -> xr.DataArray: """Load a tsv file containing optodes or landmarks. - Parameters - ---------- - tsv_fname : str - Path to the tsv file. - crs : str - Coordinate reference system of the points. - units : str + Args: + tsv_fname: Path to the tsv file. + crs: Coordinate reference system of the points. + units: Units of the points. Returns: - ------- - xr.DataArray - Optodes or landmarks as a Data + Optodes or landmarks as a DataArray. """ with open(tsv_fname, 'r') as f: lines = f.readlines() @@ -68,16 +64,11 @@ def load_tsv(tsv_fname: str, crs: str='digitized', units: str='mm') -> xr.DataAr def read_mrk_json(fname: str, crs: str) -> xr.DataArray: """Read a JSON file containing landmarks. - Parameters - ---------- - fname : str - Path to the JSON file. - crs : str - Coordinate reference system of the landmarks. + Args: + fname: Path to the JSON file. + crs: Coordinate reference system of the landmarks. Returns: - ------- - xr.DataArray Landmarks as a DataArray. """ with open(fname) as fin: @@ -122,14 +113,10 @@ def read_mrk_json(fname: str, crs: str) -> xr.DataArray: def save_mrk_json(fname: str, landmarks: xr.DataArray, crs: str): """Save landmarks to a JSON file. - Parameters - ---------- - fname : str - Path to the output file. - landmarks : xr.DataArray - Landmarks to save. - crs: str - Coordinate system of the landmarks. + Args: + fname: Path to the output file. + landmarks: Landmarks to save. + crs: Coordinate system of the landmarks. """ control_points = [{"id": i, "label": lm.label.item(), @@ -150,16 +137,11 @@ def save_mrk_json(fname: str, landmarks: xr.DataArray, crs: str): def read_digpts(fname: str, units: str="mm") -> xr.DataArray: """Read a file containing digitized points. - Parameters - ---------- - fname : str - Path to the file. - units : str - Units of the points. + Args: + fname: Path to the file. + units: Units of the points. Returns: - ------- - xr.DataArray Digitized points as a DataArray. """ with open(fname) as fin: @@ -188,15 +170,11 @@ def read_digpts(fname: str, units: str="mm") -> xr.DataArray: def read_einstar_obj(fname: str) -> TrimeshSurface: """Read a textured triangle mesh generated by Einstar devices. 
-    Parameters
-    ----------
-    fname : str
-        Path to the file.
+    Args:
+        fname: Path to the file.
 
     Returns:
-    -------
-    TrimeshSurface
-        Triangle
+        Triangle mesh surface.
     """
     mesh = trimesh.load(fname)
     return TrimeshSurface(mesh, crs="digitized", units=cedalion.units.mm)
diff --git a/src/cedalion/io/snirf.py b/src/cedalion/io/snirf.py
index 71dbe9e4..94c4b9a1 100644
--- a/src/cedalion/io/snirf.py
+++ b/src/cedalion/io/snirf.py
@@ -156,7 +156,7 @@ def reduce_ndim_sourceLabels(sourceLabels: np.ndarray) -> list:
     to a unique common prefix to obtain only one label per source.
 
     Args:
-        sourceLabels (np.ndarray): The source labels to reduce.
+        sourceLabels: The source labels to reduce.
 
     Returns:
         list: The reduced source labels.
@@ -185,12 +185,12 @@
     return labels
 
 
-def labels_and_positions(probe, dim: int = 3):
+def labels_and_positions(probe, dim: int = 3) -> tuple:
     """Extract 3D coordinates of optodes and landmarks from a nirs probe variable.
 
     Args:
         probe: Nirs probe geometry variable, see snirf docs (:cite:t:`Tucker2022`).
-        dim (int): Must be either 2 or 3.
+        dim: Must be either 2 or 3.
 
     Returns:
         tuple: A tuple containing the source, detector, and landmark labels/positions.
@@ -252,19 +252,19 @@ def convert_none(probe, attrname, default):
     )
 
-def geometry_from_probe(nirs_element: NirsElement, dim: int, crs : str):
+def geometry_from_probe(
+    nirs_element: NirsElement, dim: int, crs: str
+) -> xr.DataArray:
     """Extract 3D coordinates of optodes and landmarks from probe information.
 
     Args:
-        nirs_element (NirsElement): Nirs data element as specified in the snirf
-            documentation (:cite:t:`Tucker2022`).
-        dim (int): Must be either 2 or 3.
-        crs: the name of coordinate reference system
+        nirs_element: Nirs data element as specified in the snirf documentation
+            (:cite:t:`Tucker2022`).
+        dim: Must be either 2 or 3.
+        crs: The name of the coordinate reference system.
 
     Returns:
-        xr.DataArray: A DataArray containing the 3D coordinates of optodes and
-        landmarks, with dimensions 'label' and 'pos' and coordinates 'label' and
-        'type'.
+        A DataArray containing the 3D coordinates of optodes and landmarks, with
+        dimensions 'label' and 'pos' and coordinates 'label' and 'type'.
     """
     probe = nirs_element.probe
     length_unit = nirs_element.metaDataTags.LengthUnit
@@ -312,10 +312,10 @@ def measurement_list_to_dataframe(
 
     Args:
         measurement_list: MeasurementList object from the snirf file.
-        drop_none (bool): If True, drop columns that are None for all rows.
+        drop_none: If True, drop columns that are None for all rows.
 
     Returns:
-        pd.DataFrame: DataFrame containing the measurement list information.
+        DataFrame containing the measurement list information.
     """
     fields = [
         "sourceIndex",
@@ -346,11 +346,11 @@ def meta_data_tags_to_dict(nirs_element: NirsElement) -> OrderedDict[str, Any]:
     """Converts the metaDataTags of a nirs element to a dictionary.
 
     Args:
-        nirs_element (NirsElement): Nirs data element as specified in the snirf
+        nirs_element: Nirs data element as specified in the snirf
             documentation (:cite:t:`Tucker2022`).
 
     Returns:
-        OrderedDict[str, Any]: Dictionary containing the metaDataTags information.
+        Dictionary containing the metaDataTags information.
     """
 
     mdt = nirs_element.metaDataTags
@@ -405,15 +405,15 @@ def read_aux(
     """Reads the aux data from a nirs element into a dictionary of DataArrays.
    Args:
-        nirs_element (NirsElement): Nirs data element as specified in the snirf
+        nirs_element: Nirs data element as specified in the snirf
             documentation (:cite:t:`Tucker2022`).
-        opts (dict[str, Any]): Options for reading the aux data. The following
+        opts: Options for reading the aux data. The following
             options are supported:
             - squeeze_aux (bool): If True, squeeze the aux data to remove
                 dimensions of size 1.
 
     Returns:
-        result (OrderedDict[str, xr.DataArray]): Dictionary containing the aux data
+        Dictionary containing the aux data.
     """
 
     result = OrderedDict()
@@ -460,18 +460,18 @@ def read_aux(
     return result
 
 
-def add_number_to_name(name, keys):
+def add_number_to_name(name: str, keys: list[str]) -> str:
     """Changes name to name_<number>.
 
     Number appended to name is the smallest number that makes the new name
    unique with respect to the list of keys.
 
     Args:
-        name (str): Name to which a number should be added.
-        keys (list[str]): List of keys to which the new name should be compared.
+        name: Name to which a number should be added.
+        keys: List of keys to which the new name should be compared.
 
     Returns:
-        str: New name with number added.
+        New name with number added.
     """
 
     pat = re.compile(rf"{name}(_(\d+))?")
@@ -491,14 +491,12 @@ def read_data_elements(
     """Reads the data elements from a nirs element into a list of DataArrays.
 
     Args:
-        data_element (DataElement): DataElement obj. from the snirf file.
-        nirs_element (NirsElement): Nirs data element as specified in the snirf
-            documentation (:cite:t:`Tucker2022`).
-        stim (pd.DataFrame): DataFrame containing the stimulus information.
+        data_element: DataElement object from the snirf file.
+        nirs_element: Nirs data element as specified in the snirf documentation
+            (:cite:t:`Tucker2022`).
+        stim: DataFrame containing the stimulus information.
 
     Returns:
-        list[tuple[str, NDTimeSeries]]: List of tuples containing the canonical name
-        of the data element and the DataArray.
+        List of tuples containing the canonical name of the data element and
+        the DataArray.
     """
 
     time = data_element.time
@@ -644,18 +642,18 @@ def _get_time_coords(
     nirs_element: NirsElement,
     data_element: DataElement,
     df_measurement_list: pd.DataFrame,
-) -> dict[str, ArrayLike]:
+) -> tuple[None, dict[str, ArrayLike]]:
     """Get time coordinates for the NIRS data element.
 
     Args:
-        nirs_element (NirsElement): NIRS data element containing metadata.
-        data_element (DataElement): Data element containing time and dataTimeSeries.
-        df_measurement_list (pd.DataFrame): DataFrame containing the measurement list.
+        nirs_element: NIRS data element containing metadata.
+        data_element: Data element containing time and dataTimeSeries.
+        df_measurement_list: DataFrame containing the measurement list.
 
     Returns:
         tuple: A tuple containing:
-            - indices (None): Placeholder for indices.
-            - coordinates (dict[str, ArrayLike]): Dictionary with time coordinates.
+            - indices: Placeholder for indices.
+            - coordinates: Dictionary with time coordinates.
     """
     time = data_element.time
     time_unit = nirs_element.metaDataTags.TimeUnit
@@ -678,13 +676,13 @@ def _get_channel_coords(
     """Get channel coordinates for the NIRS data element.
 
     Args:
-        nirs_element (NirsElement): NIRS data element containing probe information.
-        df_measurement_list (pd.DataFrame): DataFrame containing the measurement list.
+        nirs_element: NIRS data element containing probe information.
+        df_measurement_list: DataFrame containing the measurement list.
 
     Returns:
         tuple: A tuple containing:
-            - indices (None): Placeholder for indices.
-            - coordinates (dict[str, ArrayLike]): Dictionary with channel coordinates.
+            - indices: Placeholder for indices.
+            - coordinates: Dictionary with channel coordinates.
     """
     sourceLabels, detectorLabels, landmarkLabels, _, _, _ = labels_and_positions(
         nirs_element.probe
@@ -703,20 +701,18 @@ def _get_channel_coords(
     return indices, coordinates
 
 
-def read_nirs_element(nirs_element, opts):
+def read_nirs_element(nirs_element: NirsElement, opts: dict[str, Any]) -> cdc.Recording:
     """Reads a single nirs element from a .snirf file into a Recording object.
 
     Args:
-        nirs_element (NirsElement): Nirs data element as specified in the snirf
-            documentation (:cite:t:`Tucker2022`).
-        opts (dict[str, Any]): Options for reading the data element. The following
-            options are supported:
-            - squeeze_aux (bool): If True, squeeze the aux data to remove
-                dimensions of size 1.
-            - crs (str): name of the geo?d's coordinate reference system.
+        nirs_element: Nirs data element as specified in the snirf documentation
+            (:cite:t:`Tucker2022`).
+        opts: Options for reading the data element. The following options are
+            supported:
+            - squeeze_aux (bool): If True, squeeze the aux data to remove
+                dimensions of size 1.
+            - crs (str): Name of the geo3d's coordinate reference system.
 
     Returns:
-        rec (Recording): Recording object containing the data from the nirs element.
+        Recording object containing the data from the nirs element.
     """
 
     geo2d = geometry_from_probe(nirs_element, dim=2, crs=opts["crs"])
@@ -767,7 +763,7 @@ def read_snirf(
         squeeze_aux: If True, squeeze the aux data to remove dimensions of size 1.
 
     Returns:
-        list[Recording]: List of Recording objects containing the data from the nirs
+        List of Recording objects containing the data from the nirs
             elements in the .snirf file.
     """
     opts = {"squeeze_aux": squeeze_aux, "crs" : crs}
@@ -779,18 +775,19 @@
 
     return [read_nirs_element(ne, opts) for ne in s.nirs]
 
 
-def denormalize_measurement_list(df_ml: pd.DataFrame, nirs_element: NirsElement):
+def denormalize_measurement_list(
+    df_ml: pd.DataFrame,
+    nirs_element: NirsElement
+) -> pd.DataFrame:
     """Enriches measurement list DataFrame with additional information.
 
     Args:
-        df_ml (pd.DataFrame): DataFrame containing the measurement list information.
-        nirs_element (NirsElement): Nirs data element as specified in the snirf
-            documentation (:cite:t:`Tucker2022`).
+        df_ml: DataFrame containing the measurement list information.
+        nirs_element: Nirs data element as specified in the snirf documentation
+            (:cite:t:`Tucker2022`).
 
     Returns:
-        pd.DataFrame: DataFrame containing the measurement list information with
-        additional columns for channel, source, detector, wavelength and chromo.
-
+        DataFrame containing the measurement list information with additional
+        columns for channel, source, detector, wavelength, and chromo.
     """
     sourceLabels, detectorLabels, landmarkLabels, _, _, _ = labels_and_positions(
         nirs_element.probe
@@ -843,27 +840,27 @@ def denormalize_measurement_list(df_ml: pd.DataFrame, nirs_element: NirsElement)
 
 def measurement_list_from_stacked(
-    stacked_array,
-    data_type,
-    trial_types,
-    stacked_channel="snirf_channel",
-    source_labels=None,
-    detector_labels=None,
-    wavelengths=None,
-):
+    stacked_array: xr.DataArray,
+    data_type: str,
+    trial_types: list[str],
+    stacked_channel: str = "snirf_channel",
+    source_labels: list[str] | None = None,
+    detector_labels: list[str] | None = None,
+    wavelengths: list[float] | None = None,
+) -> tuple:
     """Create a measurement list from a stacked array.
 
     Args:
-        stacked_array (xr.DataArray): Stacked array containing the data.
-        data_type (str): Data type of the data.
-        trial_types (list[str]): List of trial types.
-        stacked_channel (str): Name of the channel dimension in the stacked array.
-        source_labels (list[str]): List of source labels.
-        detector_labels (list[str]): List of detector labels.
-        wavelengths (list[float]): List of wavelengths.
+        stacked_array: Stacked array containing the data.
+        data_type: Data type of the data.
+        trial_types: List of trial types.
+        stacked_channel: Name of the channel dimension in the stacked array.
+        source_labels: List of source labels.
+        detector_labels: List of detector labels.
+        wavelengths: List of wavelengths.
 
     Returns:
-        tuple: A tuple containing the source labels, detector labels, wavelengths, and
+        A tuple containing the source labels, detector labels, wavelengths, and
            the measurement list.
     """
     if source_labels is None:
@@ -927,11 +924,11 @@ def measurement_list_from_stacked(
 def _write_recordings(snirf_file: Snirf, rec: cdc.Recording):
     """Write a recording to a .snirf file.
 
-    See snirf specification for details (:cite:t:`Tucker2022`)
+    See snirf specification for details (:cite:t:`Tucker2022`).
 
     Args:
-        snirf_file (Snirf): Snirf object to write to.
-        rec (Recording): Recording object to write to the file.
+        snirf_file: Snirf object to write to.
+        rec: Recording object to write to the file.
     """
     # create and populate nirs element
     snirf_file.nirs.appendGroup()
@@ -1062,13 +1059,12 @@ def _write_recordings(snirf_file: Snirf, rec: cdc.Recording):
 def write_snirf(
     fname: Path | str,
     recordings: cdc.Recording | list[cdc.Recording],
-):
+) -> None:
     """Write one or more recordings to a .snirf file.
 
     Args:
-        fname (Path | str): Path to .snirf file.
-        recordings (Recording | list[Recording]): Recording object(s) to write to the
-            file.
+        fname: Path to .snirf file.
+        recordings: Recording object(s) to write to the file.
     """
     if isinstance(fname, Path):
         fname = str(fname)
diff --git a/src/cedalion/models/glm/design_matrix.py b/src/cedalion/models/glm/design_matrix.py
index d69624a0..66f8b600 100644
--- a/src/cedalion/models/glm/design_matrix.py
+++ b/src/cedalion/models/glm/design_matrix.py
@@ -24,24 +24,23 @@ def make_design_matrix(
     """Generate the design matrix for the GLM.
 
     Args:
-        ts_long (cdt.NDTimeSeries): Time series of long distance channels.
-        ts_short (cdt.NDTimeSeries): Time series of short distance channels.
-        stim (DataFrame): Stimulus DataFrame
-        geo3d (cdt.LabeledPointCloud): Probe geometry
-        basis_function (TemporalBasisFunction): the temporal basis function(s) to model
-            the HRF.
-        drift_order (int): If not None specify the highest polynomial order of the drift
+        ts_long: Time series of long distance channels.
+        ts_short: Time series of short distance channels.
+        stim: Stimulus DataFrame.
+        geo3d: Probe geometry.
+        basis_function: The temporal basis function(s) to model the HRF.
+        drift_order: If not None, specify the highest polynomial order of the drift
            terms.
-        short_channel_method (str): Specifies the method to add short channel
-            information to the design matrix
+        short_channel_method: Specifies the method to add short channel information to
+            the design matrix.
            Options:
-            'closest': Use the closest short channel
-            'max_corr': Use the short channel with the highest correlation
+            'closest': Use the closest short channel.
+            'max_corr': Use the short channel with the highest correlation.
            'mean': Use the average of all short channels.
 
     Returns:
-        A tuple containing the global design_matrix and a list of channel-wise
-        regressors.
+ A tuple containing the global design matrix and a list of channel-wise + regressors. """ dm = make_hrf_regressors(ts_long, stim, basis_function) @@ -67,15 +66,15 @@ def make_design_matrix( return dm, channel_wise_regressors -def make_drift_regressors(ts: cdt.NDTimeSeries, drift_order) -> xr.DataArray: +def make_drift_regressors(ts: cdt.NDTimeSeries, drift_order: int) -> xr.DataArray: """Create drift regressors. Args: - ts (cdt.NDTimeSeries): Time series data. - drift_order (int): The highest polynomial order of the drift terms. + ts: Time series data. + drift_order: The highest polynomial order of the drift terms. Returns: - xr.DataArray: A DataArray containing the drift regressors. + A DataArray containing the drift regressors. """ dim3 = xrutils.other_dim(ts, "channel", "time") ndim3 = ts.sizes[dim3] @@ -154,16 +153,15 @@ def build_stim_array( def make_hrf_regressors( ts: cdt.NDTimeSeries, stim: pd.DataFrame, basis_function: TemporalBasisFunction ): - """Create regressors modelling the hemodynamic response to stimuli. + """Create regressors modeling the hemodynamic response to stimuli. Args: - ts (NDTimeSeries): Time series data. - stim (pd.DataFrame): Stimulus DataFrame. - basis_function (TemporalBasisFunction): TemporalBasisFunction object defining - the HRF. + ts: Time series data. + stim: Stimulus DataFrame. + basis_function: TemporalBasisFunction object defining the HRF. Returns: - regressors (xr.DataArray): A DataArray containing the regressors. + A DataArray containing the regressors. """ # FIXME allow basis_function to be an xarray as returned by basis_function() @@ -288,15 +286,15 @@ def _regressors_from_selected_short_channels( def closest_short_channel( ts_long: cdt.NDTimeSeries, ts_short: cdt.NDTimeSeries, geo3d: cdt.LabeledPointCloud ): - """Create channel-wise regressors use closest nearby short channel. + """Create channel-wise regressors using the closest nearby short channel. Args: - ts_long (NDTimeSeries): Time series of long channels - ts_short (NDTimeSeries): Time series of short channels - geo3d (LabeledPointCloud): Probe geometry + ts_long: Time series of long channels. + ts_short: Time series of short channels. + geo3d: Probe geometry. Returns: - regressors (xr.DataArray): Channel-wise regressor + Channel-wise regressors. """ # calculate midpoints between channel optode pairs. dims: (channel, crs) long_channel_pos = (geo3d.loc[ts_long.source] + geo3d.loc[ts_long.detector]) / 2 @@ -323,15 +321,15 @@ def closest_short_channel( def max_corr_short_channel(ts_long: cdt.NDTimeSeries, ts_short: cdt.NDTimeSeries): """Create channel-wise regressors using the most correlated short channels. - For each long channel the short channel is selected that has the highest - correleation coefficient in any wavelength or chromophore. + For each long channel, the short channel is selected that has the highest + correlation coefficient in any wavelength or chromophore. Args: - ts_long (NDTimeSeries): time series of long channels - ts_short (NDTimeSeries): time series of short channels + ts_long: Time series of long channels. + ts_short: Time series of short channels. Returns: - xr.DataArray: channel-wise regressors + Channel-wise regressors. """ dim3 = xrutils.other_dim(ts_long, "channel", "time") @@ -366,10 +364,10 @@ def average_short_channel(ts_short: cdt.NDTimeSeries): """Create a regressor by averaging all short channels. Args: - ts_short (NDTimeSeries): time series of short channels + ts_short: Time series of short channels. Returns: - xr.DataArray: regressors + Regressors. 
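+
+    Example:
+        A minimal, illustrative sketch; ``ts_short`` is assumed to hold the
+        short-distance channels of a recording::
+
+            regressor = average_short_channel(ts_short)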
""" ts_short = ts_short.pint.dequantify() diff --git a/src/cedalion/models/glm/solve.py b/src/cedalion/models/glm/solve.py index 50cb46b8..bc1391b6 100644 --- a/src/cedalion/models/glm/solve.py +++ b/src/cedalion/models/glm/solve.py @@ -103,14 +103,14 @@ def predict( """Predict time series from design matrix and thetas. Args: - ts (cdt.NDTimeSeries): The time series to be modeled. - thetas (xr.DataArray): The estimated parameters. - design_matrix (xr.DataArray): DataArray with dims time, regressor, chromo - channel_wise_regressors (list[xr.DataArray]): Optional list of design matrices, - with additional channel dimension. + ts: The time series to be modeled. + thetas: The estimated parameters. + design_matrix: DataArray with dims time, regressor, chromo. + channel_wise_regressors: Optional list of design matrices, with additional + channel dimension. Returns: - prediction (xr.DataArray): The predicted time series. + The predicted time series. """ dim3_name = xrutils.other_dim(design_matrix, "time", "regressor") @@ -154,11 +154,11 @@ def iter_design_matrix( """Iterate over the design matrix and yield the design matrix for each group. Args: - ts (cdt.NDTimeSeries): The time series to be modeled. - design_matrix (xr.DataArray): DataArray with dims time, regressor, chromo. - channel_wise_regressors (list[xr.DataArray] | None, optional): Optional list of + ts: The time series to be modeled. + design_matrix: DataArray with dims time, regressor, chromo. + channel_wise_regressors: Optional list of design matrices, with additional channel dimension. - channel_groups (list[int] | None, optional): Optional list of channel groups. + channel_groups: Optional list of channel groups. Yields: tuple: A tuple containing: diff --git a/src/cedalion/nirs.py b/src/cedalion/nirs.py index c037bded..e5dcc2b6 100644 --- a/src/cedalion/nirs.py +++ b/src/cedalion/nirs.py @@ -14,7 +14,7 @@ import cedalion.data -def get_extinction_coefficients(spectrum: str, wavelengths: ArrayLike): +def get_extinction_coefficients(spectrum: str, wavelengths: ArrayLike) -> xr.DataArray: """Provide a matrix of extinction coefficients from tabulated data. Args: @@ -28,7 +28,7 @@ def get_extinction_coefficients(spectrum: str, wavelengths: ArrayLike): calculate the extinction coefficients. Returns: - xr.DataArray: A matrix of extinction coefficients with dimensions "chromo" + A matrix of extinction coefficients with dimensions "chromo" (chromophore, e.g. HbO/HbR) and "wavelength" (e.g. 750, 850, ...) at which the coefficients for each chromophore are given in units of "mm^-1 / M". @@ -82,17 +82,19 @@ def get_extinction_coefficients(spectrum: str, wavelengths: ArrayLike): raise ValueError(f"unsupported spectrum '{spectrum}'") -def channel_distances(amplitudes: cdt.NDTimeSeries, geo3d: cdt.LabeledPointCloud): +def channel_distances( + amplitudes: cdt.NDTimeSeries, geo3d: cdt.LabeledPointCloud +) -> xr.DataArray: """Calculate distances between channels. Args: amplitudes: A DataArray representing the amplitudes with dimensions (channel, *). - geo3d (xr.DataArray): A DataArray containing the 3D coordinates of the channels + geo3d: A DataArray containing the 3D coordinates of the channels with dimensions (channel, pos). Returns: - dists (xr.DataArray): A DataArray containing the calculated distances between + dists: A DataArray containing the calculated distances between source and detector channels. The resulting DataArray has the dimension 'channel'. 
""" @@ -107,14 +109,14 @@ def channel_distances(amplitudes: cdt.NDTimeSeries, geo3d: cdt.LabeledPointCloud return dists -def int2od(amplitudes: cdt.NDTimeSeries): +def int2od(amplitudes: cdt.NDTimeSeries) -> cdt.NDTimeSeries: """Calculate optical density from intensity amplitude data. Args: - amplitudes (xr.DataArray, (time, channel, *)): amplitude data. + amplitudes: amplitude data, dims (time, channel, *). Returns: - od: (xr.DataArray, (time, channel,*): The optical density data. + od: The optical density data, dims (time, channel,*). """ # check negative values in amplitudes and issue an error if yes if np.any(amplitudes < 0): @@ -134,19 +136,19 @@ def od2conc( geo3d: cdt.LabeledPointCloud, dpf: xr.DataArray, spectrum: str = "prahl", -): +) -> cdt.NDTimeSeries: """Calculate concentration changes from optical density data. Args: - od (xr.DataArray, (channel, wavelength, *)): The optical density data array - geo3d (xr.DataArray): The 3D coordinates of the optodes. - dpf (xr.DataArray, (wavelength, *)): The differential pathlength factor data - spectrum (str, optional): The type of spectrum to use for calculating extinction + od: The optical density data array, dims (time, channel, wavelength, *) + geo3d: The 3D coordinates of the optodes. + dpf: The differential pathlength factor data, dims (wavelength, *) + spectrum: The type of spectrum to use for calculating extinction coefficients. Defaults to "prahl". Returns: - conc (xr.DataArray, (channel, *)): A data array containing - concentration changes by channel. + conc: A data array containing concentration changes by channel, dims + (channel, *) """ validators.has_channel(od) validators.has_wavelengths(od) @@ -181,15 +183,14 @@ def conc2od( """Calculate optical density data from concentration changes. Args: - conc (xr.DataArray, (channel, *)): The concentration changes by channel. - geo3d (xr.DataArray): The 3D coordinates of the optodes. - dpf (xr.DataArray, (wavelength, *)): The differential pathlength factor data. - spectrum (str, optional): The type of spectrum to use for calculating extinction + conc: The concentration changes by channel, dims (channel, *) + geo3d: The 3D coordinates of the optodes. + dpf: The differential pathlength factor data, dims (wavelength, *) + spectrum: The type of spectrum to use for calculating extinction coefficients. Defaults to "prahl". Returns: - od (xr.DataArray, (channel, wavelength, *)): A data array containing - optical density data. + od: A data array containing optical density data, dims (channel, wavelength, *) """ conc = conc.pint.to("molar") @@ -220,16 +221,16 @@ def beer_lambert( """Calculate concentration changes from amplitude using the modified BL law. Args: - amplitudes (xr.DataArray, (channel, wavelength, *)): The input data array - containing the raw intensities. - geo3d (xr.DataArray): The 3D coordinates of the optodes. - dpf (xr.DataArray, (wavelength,*)): The differential pathlength factors - spectrum (str, optional): The type of spectrum to use for calculating extinction + amplitudes: The input data array containing the raw intensities, dims dims + (channel, wavelength, *) + geo3d: The 3D coordinates of the optodes. + dpf: The differential pathlength factors, dims (wavelength,*) + spectrum: The type of spectrum to use for calculating extinction coefficients. Defaults to "prahl". Returns: - conc (xr.DataArray, (channel, *)): A data array containing - concentration changes according to the mBLL. 
+ conc: A data array containing concentration changes according to the mBLL, dims + (channel, *) """ validators.has_channel(amplitudes) validators.has_wavelengths(amplitudes) @@ -248,17 +249,17 @@ def split_long_short_channels( ts: cdt.NDTimeSeries, geo3d: cdt.LabeledPointCloud, distance_threshold: cdt.QLength = 1.5 * cedalion.units.cm, -): +) -> tuple[cdt.NDTimeSeries, cdt.NDTimeSeries]: """Split a time series into two based on channel distances. Args: - ts (cdt.NDTimeSeries) : Time series to split. - geo3d (cdt.LabeledPointCloud) : 3D coordinates of the channels. - distance_threshold (Quantity) : Distance threshold for splitting the channels. + ts: Time series to split. + geo3d: 3D coordinates of the channels. + distance_threshold: Distance threshold for splitting the channels. Returns: - ts_long : time series with channel distances >= distance_threshold - ts_short : time series with channel distances < distance_threshold + ts_long: Time series with channel distances >= distance_threshold. + ts_short: Time series with channel distances < distance_threshold. """ dists = xrutils.norm( geo3d.loc[ts.source] - geo3d.loc[ts.detector], dim=geo3d.points.crs diff --git a/src/cedalion/plots.py b/src/cedalion/plots.py index 76c70d94..56410b1a 100644 --- a/src/cedalion/plots.py +++ b/src/cedalion/plots.py @@ -28,8 +28,8 @@ def plot_montage3D(amp: xr.DataArray, geo3d: xr.DataArray): """Plots a 3D visualization of a montage. Args: - amp (xr.DataArray): Time series data array. - geo3d (xr.DataArray): Landmark coordinates. + amp: Time series data array. + geo3d: Landmark coordinates. """ geo3d = geo3d.pint.dequantify() @@ -56,25 +56,25 @@ def plot_montage3D(amp: xr.DataArray, geo3d: xr.DataArray): def plot3d( - brain_mesh, - scalp_mesh, - geo3d, - timeseries, - poly_lines=[], + brain_mesh: cdc.TrimeshSurface, + scalp_mesh: cdc.TrimeshSurface, + geo3d: xr.Dataset, + timeseries: xr.DataArray, + poly_lines: list[list] = [], brain_scalars=None, - plotter=None, + plotter: pv.Plotter = None, ): """Plots a 3D visualization of brain and scalp meshes. Args: - brain_mesh (TrimeshSurface): The brain mesh as a TrimeshSurface object. - scalp_mesh (TrimeshSurface): The scalp mesh as a TrimeshSurface object. - geo3d (xarray.Dataset): Dataset containing 3-dimentional point centers. + brain_mesh: The brain mesh as a TrimeshSurface object. + scalp_mesh: The scalp mesh as a TrimeshSurface object. + geo3d: Dataset containing 3-dimensional point centers. timeseries: Time series data array. poly_lines: List of lists of points to be plotted as polylines. brain_scalars: Scalars to be used for coloring the brain mesh. - plotter (pv.Plotter, optional): An existing PyVista plotter instance to use for - plotting. If None, a new PyVista plotter instance is created. Default: None. + plotter: An existing PyVista plotter instance to use for plotting. If None, a + new PyVista plotter instance is created. Default: None. Initial Contributors: - Eike Middell | middell@tu-berlin.de | 2024 @@ -367,12 +367,12 @@ def plot_vector_field( """Plots a vector field on a PyVista plotter. Args: - plotter (pv.Plotter): A PyVista plotter instance used for rendering the vector - field. - points (cdt.LabeledPointCloud): A labeled point cloud data structure containing - point coordinates. - vectors (xr.DataArray): A data array containing the vector field. - ppoints (list, optional): A list to store indices of picked points, enables + plotter: A PyVista plotter instance used for rendering the vector field. 
+ points: A labeled point cloud data structure containing point coordinates.
+ vectors: A data array containing the vector field.
+ ppoints: A list to store indices of picked points, enables picking if not None.
Default is None.
"""
assert len(points) == len(vectors)
@@ -935,10 +935,10 @@ def scalp_plot(
Args:
ts: a NDTimeSeries to provide channel definitions
geo3d: a LabeledPointCloud to provide the probe geometry
- metric ((:class:`DataArray`, (channel,) | ArrayLike)): the scalar metric to be
- plotted for each channel. If provided as a DataArray it needs a channel
- dimension. If provided as a plain array or list it must have the same
- length as ts.channel and the matching is done by position.
+ metric: the scalar metric to be plotted for each channel. If provided as a
+ DataArray it needs a channel dimension. If provided as a plain array or list
+ it must have the same length as ts.channel and the matching is done by
+ position.
ax: the matplotlib.Axes object into which to draw
title: the axes title
vmin: the minimum value of the metric
diff --git a/src/cedalion/sigdecomp/ERBM.py b/src/cedalion/sigdecomp/ERBM.py
index 50764f4f..d29a51b6 100644
--- a/src/cedalion/sigdecomp/ERBM.py
+++ b/src/cedalion/sigdecomp/ERBM.py
@@ -292,17 +292,17 @@ def ERBM(X: np.ndarray, p: int = None ) -> np.ndarray:
def lfc(x: np.ndarray, p: int , choice, a0) -> tuple[np.ndarray, np.ndarray]:
- """Helper function for ERBM ICA: computes the linear filtering coefficients (LFC) with length p for entropy rate estimation, and the estimated entropy rate.
+ """Compute the linear filtering coefficients (LFC) with length p for entropy rate estimation, and the estimated entropy rate.

Args:
- x (np.ndarray, (Time Points, 1)): the source estimate [T x 1]
- p (int): the filter length for the source model
- choice : can be 'sub', 'super' or 'unknown'; any other input is handled as 'unknown'
- a0 (np.ndarray or empty list): is the intial guess [p x 1] or an empty list []
+ x: The source estimate [T x 1].
+ p: The filter length for the source model.
+ choice: Can be 'sub', 'super' or 'unknown'; any other input is handled as 'unknown'.
+ a0: The initial guess [p x 1] or an empty list [].

Returns:
- a (np.ndarray, (p, 1)): the filter coefficients [p x 1]
- min_cost (np.ndarray, (1, 1)): the entropy rate estimation [1 x 1]
+ a: The filter coefficients [p x 1].
+ min_cost: The entropy rate estimation [1 x 1].
"""

global nf1, nf2, nf3, nf4, nf5, nf6, nf7, nf8
@@ -489,14 +489,15 @@ def lfc(x: np.ndarray, p: int , choice, a0) -> tuple[np.ndarray, np.ndarray]:
def simplified_ppval(pp: dict, xs: float) -> float:
"""Helper function for ERBM ICA: simplified version of ppval.
- This function evaluates a piecewise polynomial at a specific point.
-
+
+ This function evaluates a piecewise polynomial at a specific point.
+
Args:
- pp (dict): a dictionary containing the piecewise polynomial representation of a function
- xs (float): the evaluation point
+ pp: A dictionary containing the piecewise polynomial representation of a function.
+ xs: The evaluation point.

Returns:
- v (float): the value of the function at xs
+ The value of the function at xs.
""" b = pp['breaks'][0] @@ -535,13 +536,13 @@ def simplified_ppval(pp: dict, xs: float) -> float: def cnstd_and_gain(a: np.ndarray) -> tuple[np.ndarray, np.ndarray]: """Helper function for ERBM ICA: returns constraint direction used for calculating projected gradient and gain of filter a. - + Args: - a (np.ndarray, (p, 1)): the filter coefficients [p x 1] - + a: The filter coefficients [p x 1]. + Returns: - b (np.ndarray, (p, 1)): the constraint direction [p x 1] - G (np.ndarray, (1,)): the gain of the filter a + b: The constraint direction [p x 1]. + G: The gain of the filter a. """ global cosmtx, sinmtx, Simpson_c @@ -581,10 +582,10 @@ def cnstd_and_gain(a: np.ndarray) -> tuple[np.ndarray, np.ndarray]: def calculate_cos_sin_mtx(p: int) -> None : """Helper function for ERBM ICA: calculates the cos and sin matrix for integral calculation in ERBM ICA. - + Args: - p (int): the filter length for the invertible filter source model - + p: the filter length for the invertible filter source model + Returns: None """ @@ -611,14 +612,16 @@ def calculate_cos_sin_mtx(p: int) -> None : def pre_processing(X: np.ndarray) -> tuple[np.ndarray, np.ndarray]: - """Helper function for ERBM ICA: Preprocessing (removal of mean, patial pre-whitening, temporal pre-filtering) - + """Preprocessing for ERBM ICA. + + Removal of mean, patial pre-whitening, temporal pre-filtering. + Args: - X (np.ndarray, (Channels, Time Points)): the [N x T] input multivariate time series with dimensionality N observations/channels and T time points - + X: the [N x T] input multivariate time series with dimensionality N observations/channels and T time points + Returns: - X (np.ndarray, (Channels, Time Points)): the pre-processed input multivariate time series - P (np.ndarray, (Channels, Channels)): the pre-whitening matrix + X: the pre-processed input multivariate time series [N x T] + P: the pre-whitening matrix [N x N] """ # pre-processing of the data N, T = X.shape @@ -648,12 +651,12 @@ def pre_processing(X: np.ndarray) -> tuple[np.ndarray, np.ndarray]: def inv_sqrtmH(B: np.ndarray) -> np.ndarray: """Helper function for ERBM ICA: computes the inverse square root of a matrix. - + Args: - B (np.ndarray): a square matrix - + B: a square matrix + Returns: - A (np.ndarray): the inverse square root of B + A: the inverse square root of B """ D, V = np.linalg.eig(B) order = np.argsort(D) diff --git a/src/cedalion/sigdecomp/ICA_EBM.py b/src/cedalion/sigdecomp/ICA_EBM.py index 379b9438..3e9aa34d 100644 --- a/src/cedalion/sigdecomp/ICA_EBM.py +++ b/src/cedalion/sigdecomp/ICA_EBM.py @@ -15,19 +15,22 @@ def ICA_EBM(X: np.ndarray) -> np.ndarray: are used for entropy bound calculation Args: - X (np.ndarray, (Channels, Time Points)): the [N x T] input multivariate time series with dimensionality N observations/channels and T time points + X: the [N x T] input multivariate time series with dimensionality N + observations/channels and T time points Returns: - W (np.ndarray, (Channels, Channels)): the [N x N] demixing matrix with weights for N channels/sources. - To obtain the independent components, the demixed signals can be calculated as S = W @ X. + W: the [N x N] demixing matrix with weights for N channels/sources. + To obtain the independent components, the demixed signals can be calculated + as S = W @ X. 
Initial Contributors: - Jacqueline Behrendt | jacqueline.behrendt@campus.tu-berlin.de | 2024 References: This code is based on the matlab version by Xi-Lin Li (:cite:t:`Li2010A`) - Xi-Lin Li and Tulay Adali, "Independent component analysis by entropy bound minimization," - IEEE Trans. Signal Processing, vol. 58, no. 10, pp. 5151-5164, Oct. 2010. + Xi-Lin Li and Tulay Adali, "Independent component analysis by entropy bound + minimization," IEEE Trans. Signal Processing, vol. 58, no. 10, pp. 5151-5164, + Oct. 2010. The original matlab version is available at https://mlsp.umbc.edu/resources.html under the name "Real-valued ICA by entropy rate bound minimization (ICA-ERBM)" """ @@ -850,14 +853,15 @@ def ICA_EBM(X: np.ndarray) -> np.ndarray: def simplified_ppval(pp: dict, xs: float) -> float: """Helper function for ICA EBM: simplified version of ppval. - This function evaluates a piecewise polynomial at a specific point. - + + This function evaluates a piecewise polynomial at a specific point. + Args: - pp (dict): a dictionary containing the piecewise polynomial representation of a function - xs (float): the evaluation point + pp: A dictionary containing the piecewise polynomial representation of a function. + xs: The evaluation point. Returns: - v (float): the value of the function at xs + The value of the function at xs. """ b = pp['breaks'][0] c = pp['coefs'] @@ -895,12 +899,12 @@ def simplified_ppval(pp: dict, xs: float) -> float: def inv_sqrtmH(B: np.ndarray) -> np.ndarray: """Helper function for ICA EBM: computes the inverse square root of a matrix. - + Args: - B (np.ndarray): a square matrix - + B: a square matrix + Returns: - A (np.ndarray): the inverse square root of B + The inverse square root of B. """ D, V = np.linalg.eig(B) @@ -912,14 +916,16 @@ def inv_sqrtmH(B: np.ndarray) -> np.ndarray: return A def pre_processing(X: np.ndarray) -> tuple[np.ndarray, np.ndarray]: - """Helper function for ICA EBM: pre-processing (DC removal & spatial pre-whitening). - + """Helper function for ICA EBM pre-processing. + + DC removal & spatial pre-whitening + Args: - X (np.ndarray, (Channels, Time Points) ): the data matrix [N x T] - + X: the data matrix [N channels x T timepoints] + Returns: - X (np.ndarray, (Channels, Time Points)): the pre-processed data matrix [N x T] - P (np.ndarray, (Channels, Channels)): the pre-whitening matrix [N x N] + X: the pre-processed data matrix [N x T] + P: the pre-whitening matrix [N x N] """ # pre-processing program @@ -936,12 +942,12 @@ def pre_processing(X: np.ndarray) -> tuple[np.ndarray, np.ndarray]: def symdecor(M: np.ndarray) -> np.ndarray: """Helper function for ICA EBM: fast symmetric orthogonalization. - + Args: - M (np.ndarray, (Channels, Channels)): the matrix to be orthogonalized [N x N] + M: the matrix to be orthogonalized [N channels x N channels] Returns: - W (np.ndarray, (Channels, Channels)): the orthogonalized matrix [N x N] + W: the orthogonalized matrix [N x N] """ D, V = np.linalg.eig(M.dot(M.T)) diff --git a/src/cedalion/sigproc/epochs.py b/src/cedalion/sigproc/epochs.py index c4bed40a..09c24049 100644 --- a/src/cedalion/sigproc/epochs.py +++ b/src/cedalion/sigproc/epochs.py @@ -23,7 +23,7 @@ def to_epochs( trial_types: list[str], before: cdt.QTime, after: cdt.QTime, -): +) -> xr.DataArray: """Extract epochs from the time series based on stimulus events. Args: @@ -34,7 +34,7 @@ def to_epochs( after: Time after stimulus event to include in epoch. Returns: - xarray.DataArray: Array containing the extracted epochs. 
+ Array containing the extracted epochs.
"""

if not isinstance(before, Quantity):
diff --git a/src/cedalion/sigproc/frequency.py b/src/cedalion/sigproc/frequency.py
index 1682b49a..23a330a4 100644
--- a/src/cedalion/sigproc/frequency.py
+++ b/src/cedalion/sigproc/frequency.py
@@ -1,5 +1,6 @@
"""Frequency-related signal processing methods."""

+from __future__ import annotations
import numpy as np
import scipy.signal
import xarray as xr
@@ -17,7 +18,7 @@ def sampling_rate(timeseries: cdt.NDTimeSeries) -> Quantity:
This functions assumes uniform sampling.

Args:
- timeseries (:class:`NDTimeSeries`, (time,*)): the input time series
+ timeseries: the input time series, coords (time,*).

Returns:
The sampling rate estimated by averaging time differences between samples.
@@ -39,9 +40,9 @@ def freq_filter(
"""Apply a Butterworth bandpass frequency filter.

Args:
- timeseries (:class:`NDTimeSeries`, (time,*)): the input time series
- fmin (:class:`Quantity`, [frequency]): lower threshold of the pass band
- fmax (:class:`Quantity`, [frequency]): higher threshold of the pass band
+ timeseries: the input time series, coords (time,*)
+ fmin: lower threshold of the pass band
+ fmax: higher threshold of the pass band
butter_order: order of the filter

Returns:
diff --git a/src/cedalion/sigproc/motion_correct.py b/src/cedalion/sigproc/motion_correct.py
index 7a49cad7..ab3bc65f 100644
--- a/src/cedalion/sigproc/motion_correct.py
+++ b/src/cedalion/sigproc/motion_correct.py
@@ -35,7 +35,7 @@ def motion_correct_spline(
p: smoothing factor

Returns:
- dodSpline (cdt.NDTimeSeries): The motion-corrected fNIRS data.
+ dodSpline: The motion-corrected fNIRS data.
"""
dtShort = 0.3
dtLong = 3
@@ -163,19 +163,19 @@ def motion_correct_spline(
def compute_window(
SegLength: cdt.NDTimeSeries, dtShort: Quantity, dtLong: Quantity, fs: Quantity
):
- """Computes the window size.
+ """Compute the window size.

Window size is based on the segment length, short time interval, long time
interval, and sampling frequency.

Args:
- SegLength (cdt.NDTimeSeries): The length of the segment.
- dtShort (Quantity): The short time interval.
- dtLong (Quantity): The long time interval.
- fs (Quantity): The sampling frequency.
+ SegLength: The length of the segment.
+ dtShort: The short time interval.
+ dtLong: The long time interval.
+ fs: The sampling frequency.

Returns:
- wind: The computed window size.
+ The computed window size.
"""
if SegLength < dtShort * fs:
wind = SegLength
@@ -197,14 +197,12 @@ def motion_correct_splineSG(
"""Apply motion correction using spline interpolation and Savitzky-Golay filter.

Args:
- fNIRSdata (cdt.NDTimeSeries): The fNIRS data to be motion corrected.
- frame_size (Quantity): The size of the sliding window in seconds for the
- Savitzky-Golay filter. Default is 10 seconds.
- p: smoothing factor
+ fNIRSdata: The fNIRS data to be motion corrected.
+ p: Smoothing factor.
+ frame_size: The size of the sliding window in seconds for the
+ Savitzky-Golay filter. Default is 10 seconds.

Returns:
- dodSplineSG (cdt.NDTimeSeries): The motion-corrected fNIRS data after applying
- spline interpolation and Savitzky-Golay filter.
+ dodSplineSG: The motion-corrected fNIRS data after applying spline
+ interpolation and Savitzky-Golay filter.
"""

fs = sampling_rate(fNIRSdata)
@@ -262,17 +260,17 @@ def motion_correct_PCA(
Boston University Neurophotonics Center
https://github.com/BUNPC/Homer3

- Inputs:
+ Args:
fNIRSdata: The fNIRS data to be motion corrected.
tInc: The time series indicating the presence of motion artifacts.
- nSV (Quantity): Specifies the number of prinicpal components to remove from the - data. If nSV < 1 then the filter removes the first n components of the data - that removes a fraction of the variance up to nSV. + nSV: Specifies the number of principal components to remove from the data. + If nSV < 1 then the filter removes the first n components of the data + that removes a fraction of the variance up to nSV. Returns: - fNIRSdata_cleaned (cdt.NDTimeSeries): The motion-corrected fNIRS data. - svs (np.array): the singular values of the PCA. - nSV (Quantity): the number of principal components removed from the data. + fNIRSdata_cleaned: The motion-corrected fNIRS data. + svs: The singular values of the PCA. + nSV: The number of principal components removed from the data. """ # apply mask to get only points with motion @@ -400,7 +398,7 @@ def motion_correct_PCA_recurse( amp_thresh: Quantity = 5, nSV: Quantity = 0.97, maxIter: Quantity = 5, -): +) -> tuple[cdt.NDTimeSeries, np.array, int, cdt.NDTimeSeries]: """Identify motion artefacts in input fNIRSdata. If any active channel exhibits signal change greater than STDEVthresh or AMPthresh, @@ -409,24 +407,26 @@ def motion_correct_PCA_recurse( until maxIter is reached or there are no motion artefacts identified. Args: - fNIRSdata (cdt.NDTimeSeries): The fNIRS data to be motion corrected. - t_motion: check for signal change indicative of a motion artefact over - time range tMotion. (units of seconds) - t_mask (Quantity): mark data +/- tMask seconds aroundthe identified motion + fNIRSdata: The fNIRS data to be motion corrected. + t_motion: Check for signal change indicative of a motion artefact over + time range tMotion (units of seconds). + t_mask: Mark data +/- tMask seconds around the identified motion artefact as a motion artefact. - stdev_thresh (Quantity): if the signal d for any given active channel changes by + stdev_thresh: If the signal d for any given active channel changes by more than stdev_thresh * stdev(d) over the time interval tMotion then this time point is marked as a motion artefact. - amp_thresh (Quantity): if the signal d for any given active channel changes + amp_thresh: If the signal d for any given active channel changes by more than amp_thresh over the time interval tMotion then this time point is marked as a motion artefact. - nSV: FIXME - maxIter: FIXME + nSV: Specifies the number of principal components to remove from the data. + maxIter: Maximum number of iterations for motion correction. Returns: - fNIRSdata_cleaned (cdt.NDTimeSeries): The motion-corrected fNIRS data. - svs (np.array): the singular values of the PCA. - nSV (int): the number of principal components removed from the data. + Tuple with + fNIRSdata_cleaned: The motion-corrected fNIRS data. + svs: The singular values of the PCA. + nSV: The number of principal components removed from the data. + tInc: The time series indicating the presence of motion artifacts. """ tIncCh = id_motion( diff --git a/src/cedalion/sigproc/quality.py b/src/cedalion/sigproc/quality.py index 624abff1..2d01f76c 100644 --- a/src/cedalion/sigproc/quality.py +++ b/src/cedalion/sigproc/quality.py @@ -33,25 +33,21 @@ def prune_ch( operator: str, flag_drop: bool = True, ): - """Prune channels from the the input data array using quality masks. + """Prune channels from the input data array using quality masks. 
Args:
- amplitudes (:class:`NDTimeSeries`): input time series
- masks (:class:`list[NDTimeSeries]`) : list of boolean masks with coordinates
- comptabile to amplitudes
-
- operator: operators for combination of masks before pruning data_array
-
- - "all": logical AND, keeps channel if it is good across all masks
- - "any": logical OR, keeps channel if it is good in any mask/metric
-
- flag_drop: if True, channels are dropped from the data_array, otherwise they are
- set to NaN (default: True)
+ amplitudes: Input time series.
+ masks: List of boolean masks with coordinates compatible with amplitudes.
+ operator: Operator for combination of masks before pruning data_array.
+ - "all": Logical AND, keeps channel if it is good across all masks.
+ - "any": Logical OR, keeps channel if it is good in any mask/metric.
+ flag_drop: If True, channels are dropped from the data_array, otherwise they are
+ set to NaN (default: True).

Returns:
- A tuple (amplitudes_pruned, prune_list), where amplitudes_pruned is
- a the original time series channels pruned (dropped) according to quality masks.
- A list of removed channels is returned in prune_list.
+ A tuple (amplitudes_pruned, prune_list), where amplitudes_pruned is the original
+ time series with channels pruned (dropped) according to quality masks. A list of
+ removed channels is returned in prune_list.
"""

# check if all dimensions in the all the masks are also existing in data_array
@@ -101,9 +97,9 @@ def psp(
:cite:t:`Pollonini2016`.

Args:
- amplitudes (:class:`NDTimeSeries`, (channel, wavelength, time)): input time
- series
+ amplitudes: input time series, coords (channel, wavelength, time)
- window_length (:class:`Quantity`, [time]): size of the computation window
+ window_length: size of the computation window
psp_thresh: if the calculated PSP metric falls below this threshold then the
corresponding time window should be excluded.
cardiac_fmin : minimm frequency to extract cardiac component
cardiac_fmax : maximum frequency to extract cardiac component
@@ -179,13 +175,12 @@ def gvtd(amplitudes: NDTimeSeries, stat_type: str = "default", n_std: int = 10):
"""Calculate GVTD metric based on :cite:t:`Sherafati2020`.

Args:
- amplitudes (:class:`NDTimeSeries`, (channel, wavelength, time)): input time
- series
+ amplitudes: input time series, coords (channel, wavelength, time)

- stat_type (string): statistic of GVTD time trace to use to set the threshold
+ stat_type: statistic of GVTD time trace to use to set the threshold
(see _get_gvtd_threshold). Default = 'default'

- n_std (int): number of standard deviations for consider above the statistic of
+ n_std: number of standard deviations to consider above the statistic of
interest.

Returns:
@@ -231,13 +226,13 @@ def _get_gvtd_threshold(
GVTD: NDTimeSeries,
stat_type: str = "default",
n_std: int = 10,
-):
+) -> float:
"""Calculate GVTD threshold based on :cite:t:`Sherafati2020`.

Args:
- GVTD (:class:`NDTimeSeries`, (time,)): GVTD timetrace
+ GVTD: GVTD time trace, coords (time,)

- stat_type (string): statistic of GVTD time trace to use to set the threshold
+ stat_type: statistic of GVTD time trace to use to set the threshold

- *default*: threshold is the mode plus the distance between the
smallest GVTD value and the mode.

...

- *median*: same as histogram_mode but using the median instead of the mode.
- *MAD*: same as histogram_mode but using the MAD instead of the mode.

- n_std (int): number of standard deviations for consider above the statistic of
+ n_std: number of standard deviations to consider above the statistic of
interest.
Returns:
- thresh (float): the threshold above which GVTD is considered motion.
+ thresh: the threshold above which GVTD is considered motion.
"""

units = GVTD.pint.units
@@ -458,13 +453,13 @@ def sci(
:cite:t:`Pollonini2016`.

Args:
- amplitudes (:class:`NDTimeSeries`, (channel, wavelength, time)): input time
- series
+ amplitudes: input time series, coords (channel, wavelength, time)
- window_length (:class:`Quantity`, [time]): size of the computation window
+ window_length: size of the computation window
sci_thresh: if the calculated SCI metric falls below this threshold then the
corresponding time window should be excluded.
- cardiac_fmin : minimm frequency to extract cardiac component
- cardiac_fmax : maximum frequency to extract cardiac component
+ cardiac_fmin: minimum frequency to extract cardiac component
+ cardiac_fmax: maximum frequency to extract cardiac component

Returns:
A tuple (sci, sci_mask), where sci is a DataArray with coords from the input
@@ -523,7 +518,7 @@ def snr(amplitudes: cdt.NDTimeSeries, snr_thresh: float = 2.0):
SNR is the ratio of the average signal over time divided by its standard deviation.

Args:
- amplitudes (:class:`NDTimeSeries`, (time, *)): the input time series
+ amplitudes: the input time series, coords (time,*)
snr_thresh: threshold (unitless) below which a channel should be excluded.

Returns:
@@ -550,7 +545,7 @@ def mean_amp(amplitudes: cdt.NDTimeSeries, amp_range: tuple[Quantity, Quantity])
"""Calculate mean amplitudes and mask channels outside amplitude range.

Args:
- amplitudes (:class:`NDTimeSeries`, (time, *)): input time series
+ amplitudes: input time series, coords (time, *)
amp_range: if amplitudes.mean("time") < amp_threshs[0] or > amp_threshs[1]
then it is excluded as an active channel in amp_mask
Returns:
@@ -583,8 +578,8 @@ def sd_dist(
"""Calculate source-detector separations and mask channels outside a distance range.

Args:
- amplitudes (:class:`NDTimeSeries`, (channel, *)): input time series
- geo3D (:class:`LabeledPointCloud`): 3D optode coordinates
+ amplitudes: input time series, coords (channel, *)
+ geo3D: 3D optode coordinates
sd_range: if source-detector separation < sd_range[0] or > sd_range[1]
then it is excluded as an active channelin sd_mask

@@ -627,24 +622,19 @@ def id_motion(
as a motion artifact.

Args:
- fNIRSdata (:class:`NDTimeSeries`, (time, channel, *)): input time series
-
- t_motion (:class:`Quantity`, [time]): time interval for motion artifact
+ fNIRSdata: input time series, coords (time, channel, *)
+ t_motion: time interval for motion artifact
detection. Checks for signal change
indicative of a motion artifact over time range t_motion.
-
-
- t_mask (:class:`Quantity`, [time]): time range to mask around motion artifacts.
+ t_mask: time range to mask around motion artifacts.
Mark data over +/- t_mask around the identified motion artifact as a motion
artifact.
-
stdev_thresh: threshold for std deviation of signal change. If the signal d for
any given active channel changes by more than stdev_thresh * stdev(d) over
the time interval tMotion, then this time point is marked as a motion
artifact. The standard deviation is determined for each channel
independently. Typical value ranges from 5 to 20. Use a value of 100 or
greater if you wish for this condition to not find motion artifacts.
-
amp_thresh: threshold for amplitude of signal change. If the signal d for any
given active channel changes by more that amp_thresh over the time interval
t_motion, then this time point is marked as a motion artifact.
Typical value
@@ -718,8 +708,8 @@ def id_motion_refine(ma_mask: cdt.NDTimeSeries, operator: str):
"""Refines motion artifact mask to simplify and quantify motion artifacts.

Args:
- ma_mask (:class:`NDTimeSeries`, (time, channel, *)): motion artifact mask as
- generated by id_motion().
+ ma_mask: motion artifact mask as generated by id_motion(), coords
+ (time, channel, *).
operator: operation to apply to the mask.

Available operators:
@@ -807,8 +797,8 @@ def detect_outliers_std(
"""Detect outliers in fNIRSdata based on standard deviation of signal.

Args:
- ts :class:`NDTimeSeries`, (time, channel, *): fNIRS timeseries data
- t_window :class:`Quantity`: time window over which to calculate std. deviations
+ ts: fNIRS timeseries data, coords (time, channel, *)
+ t_window: time window over which to calculate std. deviations
iqr_threshold: interquartile range threshold (detect outlier as any std.
deviation outside iqr_threshold * [25th percentile, 75th percentile])

@@ -855,7 +845,7 @@ def detect_outliers_grad(ts: cdt.NDTimeSeries, iqr_threshold: float = 1.5):
"""Detect outliers in fNIRSdata based on gradient of signal.

Args:
- ts (:class:`NDTimeSeries`, (time, channel, *)): fNIRS timeseries data
+ ts: fNIRS timeseries data, coords (time, channel, *)
iqr_threshold: interquartile range threshold (detect outlier as any gradient
outside iqr_threshold * [25th percentile, 75th percentile])

@@ -912,8 +902,8 @@ def detect_outliers(
"""Detect outliers in fNIRSdata based on standard deviation and gradient of signal.

Args:
- ts (:class:`NDTimeSeries`, (time, channel, *)): fNIRS timeseries data
- t_window_std (:class:`Quantity`): time window over which to calculate std. devs.
+ ts: fNIRS timeseries data, coords (time, channel, *)
+ t_window_std: time window over which to calculate std. devs.
iqr_threshold_grad: interquartile range threshold (detect outlier as any
gradient outside iqr_threshold * [25th percentile, 75th percentile])
iqr_threshold_std: interquartile range threshold (detect outlier as any standard
@@ -937,7 +927,7 @@ def _mask1D_to_segments(mask: ArrayLike):
"""Find consecutive segments for a boolean mask.

Args:
- mask (ArrayLike): boolean mask
+ mask: boolean mask

Returns:
Given a boolean mask, this function returns an integer array `segments` of
@@ -961,16 +951,16 @@ def _mask1D_to_segments(mask: ArrayLike):
return segments

-def _calculate_snr(ts, fs, segments):
+def _calculate_snr(ts: ArrayLike, fs: float, segments: ArrayLike):
"""Calculate signal to noise ratio for a time series.

Args:
- ts (ArrayLike): Time series
- fs (float): Sampling rate
- segments (ArrayLike): Segments of the time series
+ ts: Time series
+ fs: Sampling rate
+ segments: Segments of the time series

Returns:
- float: Signal to noise ratio
+ Signal to noise ratio
"""
# Calculate signal to noise ratio by considering only longer segments.
# Only segments longer than 3s are used. Segments may be clean or tainted.
@@ -989,16 +979,16 @@
return snr

-def _calculate_delta_threshold(ts, segments, threshold_samples):
+def _calculate_delta_threshold(ts: ArrayLike, segments: ArrayLike, threshold_samples: int):
"""Calculate delta threshold for a time series.

Args:
- ts (ArrayLike): Time series
- segments (ArrayLike): Segments of the time series
- threshold_samples (int): Threshold samples
+ ts: Time series.
+ segments: Segments of the time series.
+ threshold_samples: Threshold samples.

Returns:
- float: Delta threshold
+ Delta threshold.
""" # for long segments (>threshold_samples (0.5s)) that are not marked as artifacts # calculate the absolute differences of samples that are threshold_samples away @@ -1023,8 +1013,8 @@ def detect_baselineshift(ts: cdt.NDTimeSeries, outlier_mask: cdt.NDTimeSeries): """Detect baselineshifts in fNIRSdata. Args: - ts (:class:`NDTimeSeries`, (time, channel, *)): fNIRS timeseries data - outlier_mask (:class:`NDTimeSeries`): mask containing FALSE anytime an outlier + ts: fNIRS timeseries data, coords (time, channel, *) + outlier_mask: mask containing FALSE anytime an outlier is detected in signal Returns: diff --git a/src/cedalion/sigproc/tasks.py b/src/cedalion/sigproc/tasks.py index 3bd455db..892100bd 100644 --- a/src/cedalion/sigproc/tasks.py +++ b/src/cedalion/sigproc/tasks.py @@ -13,13 +13,12 @@ def int2od( ts_input: str | None = None, ts_output: str = "od", ): - """Calculate optical density from intensity amplitude data. + """Calculate optical density from intensity amplitude data. Args: - rec (Recording): container of timeseries data - ts_input (str): name of intensity timeseries. If None, this tasks operates on - the last timeseries in rec.timeseries. - ts_output (str): name of optical density timeseries. + rec: Container of timeseries data. + ts_input: Name of intensity timeseries. If None, this task operates on the last timeseries in rec.timeseries. + ts_output: Name of optical density timeseries. """ ts = rec.get_timeseries(ts_input) @@ -38,12 +37,11 @@ def od2conc( """Calculate hemoglobin concentrations from optical density data. Args: - rec (Recording): container of timeseries data - dpf (dict[float, float]): differential path length factors - spectrum (str): label of the extinction coefficients to use (default: "prahl") - ts_input (str | None): name of intensity timeseries. If None, this tasks operates - on the last timeseries in rec.timeseries. - ts_output (str): name of optical density timeseries (default: "conc"). + rec: Container of timeseries data. + dpf: Differential path length factors. + spectrum: Label of the extinction coefficients to use (default: "prahl"). + ts_input: Name of intensity timeseries. If None, this task operates on the last timeseries in rec.timeseries. + ts_output: Name of optical density timeseries (default: "conc"). """ ts = rec.get_timeseries(ts_input) @@ -70,13 +68,11 @@ def snr( """Calculate signal-to-noise ratio (SNR) of timeseries data. Args: - rec (Recording): The recording object containing the data. - snr_thresh (float): The SNR threshold. - ts_input (str | None, optional): The input time series. Defaults to None. - aux_obj_output (str, optional): The key for storing the SNR in the auxiliary - object. Defaults to "snr". - mask_output (str, optional): The key for storing the mask in the recording - object. Defaults to "snr". + rec: The recording object containing the data. + snr_thresh: The SNR threshold. + ts_input: The input time series. Defaults to None. + aux_obj_output: The key for storing the SNR in the auxiliary object. Defaults to "snr". + mask_output: The key for storing the mask in the recording object. Defaults to "snr". """ ts = rec.get_timeseries(ts_input) @@ -98,14 +94,12 @@ def sd_dist( """Calculate source-detector separations and mask channels outside a range. Args: - rec (Recording): The recording object containing the data. - sd_min (Annotated[Quantity, "[length]"]): The minimum source-detector separation. - sd_max (Annotated[Quantity, "[length]"]): The maximum source-detector separation. 
- ts_input (str | None, optional): The input time series. Defaults to None.
- aux_obj_output (str, optional): The key for storing the source-detector distances
- in the auxiliary object. Defaults to "sd_dist".
- mask_output (str, optional): The key for storing the mask in the recording object.
- Defaults to "sd_dist".
+ rec: The recording object containing the data.
+ sd_min: The minimum source-detector separation.
+ sd_max: The maximum source-detector separation.
+ ts_input: The input time series. Defaults to None.
+ aux_obj_output: The key for storing the source-detector distances in the
+ auxiliary object. Defaults to "sd_dist".
+ mask_output: The key for storing the mask in the recording object.
+ Defaults to "sd_dist".
"""
ts = rec.get_timeseries(ts_input)

diff --git a/src/cedalion/sim/synthetic_hrf.py b/src/cedalion/sim/synthetic_hrf.py
index cad1b307..9869d3a4 100644
--- a/src/cedalion/sim/synthetic_hrf.py
+++ b/src/cedalion/sim/synthetic_hrf.py
@@ -22,8 +22,8 @@ def generate_hrf(
time_axis: xr.DataArray,
stim_dur: Quantity = 10 * units.seconds,
- params_basis: list = [0.1000, 3.0000, 1.8000, 3.0000],
- scale: list = [10 * units.micromolar, -4 * units.micromolar],
+ params_basis: list[float] = [0.1000, 3.0000, 1.8000, 3.0000],
+ scale: list[Quantity] = [10 * units.micromolar, -4 * units.micromolar],
):
"""Generates HRF basis functions for different chromophores.

@@ -34,14 +34,14 @@ def generate_hrf(
Args:
time_axis: The time axis for the resulting HRF.
stim_dur: Duration of the stimulus.
- params_basis (list of float): Parameters for tau and sigma for the modified
+ params_basis: Parameters for tau and sigma for the modified
gamma function for each chromophore. Expected to be a flat list where
pairs represent [tau, sigma] for each chromophore.
- scale (list of float): Scaling factors for each chromophore, typically
+ scale: Scaling factors for each chromophore, typically
[HbO scale, HbR scale].

Returns:
- xarray.DataArray: A DataArray object with dimensions "time" and "chromo",
+ A DataArray object with dimensions "time" and "chromo",
containing the HRF basis functions for each chromophore.

Initial Contributors:
@@ -119,15 +119,15 @@ def build_blob(
The blob is centered at the vertex closest to the seed landmark.

Args:
- head_model (cfm.TwoSurfaceHeadModel): Head model with brain and scalp surfaces.
- landmark (str): Name of the seed landmark.
- scale (Quantity): Scale of the blob.
- m (float): Geodesic distance parameter. Larger values of m will smooth &
+ head_model: Head model with brain and scalp surfaces.
+ landmark: Name of the seed landmark.
+ scale: Scale of the blob.
+ m: Geodesic distance parameter. Larger values of m will smooth &
regularize the distance computation. Smaller values of m will roughen and
will usually increase error in the distance computation.

Returns:
- xr.DataArray: Blob image with activation values for each vertex.
+ Blob image with activation values for each vertex.

Initial Contributors:
- Thomas Fischer | t.fischer.1@campus.tu-berlin.de | 2024
@@ -161,17 +161,16 @@ def hrfs_from_image_reco(
"""Maps an activation blob on the brain to HRFs in channel space.

Args:
- blob (xr.DataArray): Activation values for each vertex.
- hrf_model (xr.DataArray): HRF model for HbO and HbR.
- Adot (xr.DataArray): Sensitivity matrix for the forward model.
+ blob: Activation values for each vertex.
+ hrf_model: HRF model for HbO and HbR.
+ Adot: Sensitivity matrix for the forward model.

Returns:
- cdt.NDTimeseries: HRFs in channel space.
+ HRFs in channel space.
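+
+ Example (an illustrative sketch, not from the original docstring; the
+ inputs are assumed to be mutually consistent DataArrays as described above):
+ >>> hrfs = hrfs_from_image_reco(blob, hrf_model, Adot)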
Initial Contributors:
- Laura Carlton | lcarlton@bu.edu | 2024
- Thomas Fischer | t.fischer.1@campus.tu-berlin.de | 2024
-
"""

hrf_model = hrf_model.pint.to(units.molar)
@@ -201,28 +200,20 @@ def hrfs_from_image_reco(
def add_hrf_to_vertices(
hrf_basis: xr.DataArray, num_vertices: int, scale: np.array = None
-):
+) -> xr.DataArray:
"""Adds hemodynamic response functions (HRF) for HbO and HbR to specified vertices.

- This function applies temporal HRF profiles to vertices, optionally scaling the
- response by a provided amplitude scale. It generates separate images for HbO and HbR
- and then combines them.
-
Args:
- hrf_basis (xarray.DataArray): Dataset containing HRF time series for
- HbO and HbR.
- num_vertices (int): Total number of vertices in the image space.
- scale (np.array, optional): Array of scale factors of shape (num_vertices) to
- scale the amplitude of HRFs.
+ hrf_basis: Dataset containing HRF time series for HbO and HbR.
+ num_vertices: Total number of vertices in the image space.
+ scale: Array of scale factors of shape (num_vertices) to scale the
+ amplitude of HRFs.

Returns:
- xr.DataArray: Combined image of HbO and HbR responses across all vertices for
- all time points.
+ Combined image of HbO and HbR responses across all vertices for all time
+ points.

Initial Contributors:
- Laura Carlton | lcarlton@bu.edu | 2024
- Thomas Fischer | t.fischer.1@campus.tu-berlin.de | 2024
-
"""

unit = hrf_basis.pint.units
@@ -266,22 +257,22 @@ def build_stim_df(
min_interval: Quantity = 5 * units.seconds,
max_interval: Quantity = 10 * units.seconds,
order: str = "alternating",
-):
+) -> pd.DataFrame:
"""Generates a DataFrame for stimulus metadata based on provided parameters.

Stimuli can be added in an 'alternating' or 'random' order, and the inter-stimulus
interval (ISI) is chosen randomly between the minimum and maximum allowed intervals.

Args:
- num_stims (int): Number of stimuli to be added for each trial type.
- stim_dur (int): Duration of the stimulus in seconds.
- trial_types (list): List of trial types for the stimuli.
- min_interval (int): Minimum inter-stimulus interval in seconds.
- max_interval (int): Maximum inter-stimulus interval in seconds.
- order (str): Order of adding Stims; 'alternating' or 'random'.
+ num_stims: Number of stimuli to be added for each trial type.
+ stim_dur: Duration of the stimulus.
+ trial_types: List of trial types for the stimuli.
+ min_interval: Minimum inter-stimulus interval.
+ max_interval: Maximum inter-stimulus interval.
+ order: Order of adding stimuli; 'alternating' or 'random'.

Returns:
- pd.DataFrame: DataFrame containing stimulus metadata.
+ DataFrame containing stimulus metadata.

Initial Contributors:
- Laura Carlton | lcarlton@bu.edu | 2024
@@ -339,20 +330,21 @@ def build_stim_df(
@cdc.validate_schemas
-def add_hrf_to_od(od: cdt.NDTimeSeries, hrfs: cdt.NDTimeSeries, stim_df: pd.DataFrame):
+def add_hrf_to_od(
+ od: cdt.NDTimeSeries, hrfs: cdt.NDTimeSeries, stim_df: pd.DataFrame
+) -> cdt.NDTimeSeries:
"""Adds Hemodynamic Response Functions (HRFs) to optical density (OD) data.

The timing of the HRFs is based on the provided stimulus dataframe (stim_df).

Args:
- od (cdt.NDTimeSeries): OD timeseries data with dimensions
- ["channel", "wavelength", "time"].
- hrfs (cdt.NDTimeSeries): HRFs in channel space with dimensions
- ["channel", "wavelength", "time"] + maybe ["trial_type"].
- stim_df (pd.DataFrame): DataFrame containing stimulus metadata.
+ od: OD timeseries data with dimensions ["channel", "wavelength", "time"].
+ hrfs: HRFs in channel space with dimensions ["channel", "wavelength", "time"] + + maybe ["trial_type"]. + stim_df: DataFrame containing stimulus metadata. Returns: - cdt.NDTimeSeries: OD data with HRFs added based on the stimulus dataframe. + OD data with HRFs added based on the stimulus dataframe. Initial Contributors: - Laura Carlton | lcarlton@bu.edu | 2024 @@ -402,26 +394,23 @@ def hrf_to_long_channels( y: cdt.NDTimeSeries, geo3d: xr.DataArray, ss_tresh: Quantity = 1.5 * units.cm, -): - """Add HRFs to optical density (OD) data in channel space. +) -> xr.DataArray: + """Adds HRFs to optical density (OD) data in channel space. Broadcasts the HRF model to long channels based on the source-detector distances. - Short channel hrfs are filled with zeros. + Short channel HRFs are filled with zeros. Args: - hrf_model (xr.DataArray): HRF model with dimensions ["time", "wavelength"]. - y (cdt.NDTimeSeries): Raw amp / OD / Chromo timeseries data with dimensions - ["channel", "time"]. - geo3d (xr.DataArray): 3D coordinates of sources and detectors. - ss_tresh (Quantity): Threshold for short/long channels. + hrf_model: HRF model with dimensions ["time", "wavelength"]. + y: Raw amp / OD / Chromo timeseries data with dimensions ["channel", "time"]. + geo3d: 3D coordinates of sources and detectors. + ss_tresh: Threshold for short/long channels. Returns: - xr.DataArray: HRFs in channel space with dimensions - ["channel", "time", "wavelength"]. + HRFs in channel space with dimensions ["channel", "time", "wavelength"]. Initial Contributors: - Thomas Fischer | t.fischer.1@campus.tu-berlin.de | 2024 - """ # Calculate source-detector distances for each channel @@ -450,17 +439,17 @@ def get_colors( vertex_colors: np.array, log_scale: bool = False, max_scale: float = None, -): +) -> np.array: """Maps activations to colors for visualization. Args: - activations (xr.DataArray): Activation values for each vertex. - vertex_colors (np.array): Vertex color array of the brain mesh. - log_scale (bool): Whether to map activations on a logarithmic scale. - max_scale (float): Maximum value to scale the activations. + activations: Activation values for each vertex. + vertex_colors: Vertex color array of the brain mesh. + log_scale: Whether to map activations on a logarithmic scale. + max_scale: Maximum value to scale the activations. Returns: - np.array: New vertex color array with same shape as `vertex_colors`. + New vertex color array with same shape as `vertex_colors`. """ if not isinstance(activations, np.ndarray): diff --git a/src/cedalion/sim/synthetic_utils.py b/src/cedalion/sim/synthetic_utils.py index 9381b7b4..1d699174 100644 --- a/src/cedalion/sim/synthetic_utils.py +++ b/src/cedalion/sim/synthetic_utils.py @@ -69,25 +69,25 @@ def build_event_df( scenarios. Args: - time_axis (xr.DataArray): The time axis of the data. - trial_types (List[str]): List of trial types to draw from. - num_events (int, optional): Number of events to generate. - perc_events (float, optional): Percentage of total time to cover with events. - min_dur (Quantity): Minimum event duration. - max_dur (Quantity): Maximum event duration. - min_interval (Quantity): Minimum interval between events. - min_value (float): Minimum event amplitude. - max_value (float): Maximum event amplitude. - order (str): Order of types ('alternating', 'random', or 'random balanced'). + time_axis: The time axis of the data. + trial_types: List of trial types to draw from. + num_events: Number of events to generate. 
+ perc_events: Percentage of total time to cover with events. + min_dur: Minimum event duration. + max_dur: Maximum event duration. + min_interval: Minimum interval between events. + min_value: Minimum event amplitude. + max_value: Maximum event amplitude. + order: Order of types ('alternating', 'random', or 'random balanced'). Alternating will cycle through trial types. Random will randomly assign trial types. Random balanced will randomly assign trial types, but each type will be assigned the same number of times (if possible). - channels (List[str], optional): List of channel names to add events to. - max_attempts (int): Maximum number of attempts to place events. + channels: List of channel names to add events to. + max_attempts: Maximum number of attempts to place events. Returns: - df (pd.DataFrame): DataFrame containing stimulus metadata. Columns are: + DataFrame containing stimulus metadata. Columns are: - onset: Event onset time. - duration: Event duration. - value: Event amplitude. diff --git a/src/cedalion/vtktutils.py b/src/cedalion/vtktutils.py index 0d47e53f..c6c1e8e7 100644 --- a/src/cedalion/vtktutils.py +++ b/src/cedalion/vtktutils.py @@ -56,10 +56,10 @@ def pyvista_polydata_to_trimesh(polydata: pv.PolyData) -> trimesh.Trimesh: """Convert a PyVista PolyData object to a Trimesh object. Args: - polydata (pv.PolyData): The input PyVista PolyData object. + polydata: The input PyVista PolyData object. Returns: - trimesh.Trimesh: The converted Trimesh object. + The converted Trimesh object. """ vertices = polydata.points faces = polydata.regular_faces diff --git a/src/cedalion/xrutils.py b/src/cedalion/xrutils.py index 6ea3f181..905dcfea 100644 --- a/src/cedalion/xrutils.py +++ b/src/cedalion/xrutils.py @@ -14,10 +14,10 @@ def pinv(array: xr.DataArray) -> xr.DataArray: DataArrays with units in their attrs. Args: - array (xr.DataArray): Input array + array: Input array Returns: - array_inv (xr.DataArray): Pseudoinverse of the input array + array_inv: Pseudoinverse of the input array """ if not array.ndim == 2: raise ValueError("array must have only 2 dimensions") @@ -56,11 +56,11 @@ def norm(array: xr.DataArray, dim: str) -> xr.DataArray: Extends the behavior of numpy.linalg.norm to xarray DataArrays. Args: - array (xr.DataArray): Input array - dim (str): Dimension along which to calculate the norm + array: Input array + dim: Dimension along which to calculate the norm Returns: - normed (xr.DataArray): Array with the norm along the specified dimension + normed: Array with the norm along the specified dimension """ if dim not in array.dims: raise ValueError(f"array does not have dimension '{dim}'") @@ -90,7 +90,7 @@ def apply_mask( """Apply a boolean mask to a DataArray according to the defined "operator". Args: - data_array: NDTimeSeries, input time series data xarray + data_array: input time series data xarray mask: input boolean mask array with a subset of dimensions matching data_array operator: operators to apply to the mask and data_array "nan": inserts NaNs in the data_array where mask is False diff --git a/tests/test_labeled_points.py b/tests/test_labeled_points.py index 5e2678ce..4402fb11 100644 --- a/tests/test_labeled_points.py +++ b/tests/test_labeled_points.py @@ -1,3 +1,4 @@ +from __future__ import annotations import pytest import xarray as xr import numpy as np
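As a closing usage sketch of the API touched above (hedged: `rec`, `dpf`, and the "amp" key are illustrative assumptions, not part of this changeset):

    import cedalion
    import cedalion.nirs as nirs

    # Assumed inputs: a Recording `rec` holding an "amp" intensity timeseries
    # and its probe geometry, plus differential pathlength factors `dpf` as an
    # xr.DataArray over the measured wavelengths.
    od = nirs.int2od(rec.get_timeseries("amp"))
    conc = nirs.od2conc(od, rec.geo3d, dpf, spectrum="prahl")
    ts_long, ts_short = nirs.split_long_short_channels(
        conc, rec.geo3d, distance_threshold=1.5 * cedalion.units.cm
    )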