Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions src/cedalion/dataclasses/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(self, xarray_obj):
"""Initialize the CedalionAccessor.

Args:
xarray_obj (xr.DataArray): The DataArray to which this accessor is attached.
xarray_obj: The DataArray to which this accessor is attached.
"""
self._validate(xarray_obj)
self._obj = xarray_obj
Expand Down Expand Up @@ -65,16 +65,16 @@ def to_epochs(

return to_epochs(self._obj, df_stim, trial_types, before, after)

def freq_filter(self, fmin, fmax, butter_order=4):
def freq_filter(self, fmin: float, fmax: float, butter_order: int =4) -> xr.DataArray:
"""Applys a Butterworth filter.

Args:
fmin (float): The lower cutoff frequency.
fmax (float): The upper cutoff frequency.
butter_order (int): The order of the Butterworth filter.
fmin: The lower cutoff frequency.
fmax: The upper cutoff frequency.
butter_order: The order of the Butterworth filter.

Returns:
result (xarray.DataArray): The filtered time series.
result: The filtered time series.
"""
array = self._obj

Expand Down Expand Up @@ -292,11 +292,11 @@ def _validate(obj):
f"Stimulus DataFame must have column {column_name}."
)

def rename_events(self, rename_dict):
def rename_events(self, rename_dict: Dict[str, str]) -> None:
"""Renames trial types in the DataFrame based on the provided dictionary.

Args:
rename_dict (dict): A dictionary with the old trial type as key and the new
rename_dict: A dictionary with the old trial type as key and the new
trial type as value.
"""
stim = self._obj
Expand Down
38 changes: 28 additions & 10 deletions src/cedalion/dataclasses/geometry.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""Dataclasses for representing geometric objects."""

from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from functools import total_ordering
from typing import Any
from typing import Any, List

import mne
import numpy as np
Expand Down Expand Up @@ -197,7 +198,7 @@ def apply_transform(self, transform: cdt.AffineTransform) -> "TrimeshSurface":
"""Apply an affine transformation to this surface.

Args:
transform (cdt.AffineTransform): The affine transformation to apply.
transform: The affine transformation to apply.

Returns:
TrimeshSurface: The transformed surface.
Expand Down Expand Up @@ -234,7 +235,11 @@ def smooth(self, lamb: float) -> "TrimeshSurface":
smoothed = trimesh.smoothing.filter_taubin(self.mesh, lamb=lamb)
return TrimeshSurface(smoothed, self.crs, self.units)

def get_vertex_normals(self, points: cdt.LabeledPointCloud, normalized=True):
def get_vertex_normals(
self,
points: cdt.LabeledPointCloud,
normalized: bool = True
):
"""Get normals of vertices closest to the provided points."""

assert points.points.crs == self.crs
Expand Down Expand Up @@ -400,7 +405,11 @@ def apply_transform(self, transform: cdt.AffineTransform) -> "PycortexSurface":
def decimate(self, face_count: int) -> "PycortexSurface":
raise NotImplementedError("Decimation not implemented for PycortexSurface")

def get_vertex_normals(self, points: cdt.LabeledPointCloud, normalized=True):
def get_vertex_normals(
self,
points: cdt.LabeledPointCloud,
normalized: bool = True
):
assert points.points.crs == self.crs
assert points.pint.units == self.units
points = points.pint.dequantify()
Expand Down Expand Up @@ -601,7 +610,11 @@ def avg_edge_length(self):
return edgelens.mean()


def surface_gradient(self, scalars, at_verts=True):
def surface_gradient(
self,
scalars: np.ndarray,
at_verts: bool = True
) -> np.ndarray:
"""Gradient of a function with values `scalars` at each vertex on the surface.

If `at_verts`, returns values at each vertex. Otherwise, returns values at each
Expand All @@ -610,9 +623,8 @@ def surface_gradient(self, scalars, at_verts=True):
Args:
scalars : 1D ndarray, shape (total_verts,) a scalar-valued function across
the cortex.
at_verts : bool, optional
If True (default), values will be returned for each vertex. Otherwise,
values will be returned for each face.
at_verts : If True (default), values will be returned for each vertex.
Otherwise, values will be returned for each face.

Returns:
gradu : 2D ndarray, shape (total_verts,3) or (total_polys,3)
Expand Down Expand Up @@ -642,7 +654,7 @@ def _facenorm_cross_edge(self):
return fe12, fe23, fe31


def geodesic_distance(self, verts, m=1.0, fem=False):
def geodesic_distance(self, verts, m: float = 1.0, fem: bool = False) -> np.ndarray:
"""Calcualte the inimum mesh geodesic distance (in mm).

The geodesic distance is calculated from each vertex in surface to any vertex in
Expand Down Expand Up @@ -745,7 +757,13 @@ def geodesic_distance(self, verts, m=1.0, fem=False):
return phi


def geodesic_path(self, a, b, max_len=1000, d=None, **kwargs):
def geodesic_path(
self, a: int,
b: int,
max_len: int = 1000,
d: np.ndarray = None,
**kwargs
) -> List:
"""Finds the shortest path between two points `a` and `b`.

This shortest path is based on geodesic distances across the surface.
Expand Down
12 changes: 6 additions & 6 deletions src/cedalion/dataclasses/recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def get_timeseries(self, key: Optional[str] = None) -> NDTimeSeries:
"""Get a timeseries object by key.

Args:
key (Optional[str]): The key of the timeseries to retrieve. If None, the
key: The key of the timeseries to retrieve. If None, the
last timeseries is returned.

Returns:
Expand Down Expand Up @@ -104,7 +104,7 @@ def get_mask(self, key: Optional[str] = None) -> xr.DataArray:
"""Get a mask by key.

Args:
key (Optional[str]): The key of the mask to retrieve. If None, the last
key: The key of the mask to retrieve. If None, the last
mask is returned.

Returns:
Expand All @@ -124,9 +124,9 @@ def set_mask(self, key: str, value: xr.DataArray, overwrite: bool = False):
"""Set a mask.

Args:
key (str): The key of the mask to set.
value (xr.DataArray): The mask to set.
overwrite (bool): Whether to overwrite an existing mask with the same key.
key: The key of the mask to set.
value: The mask to set.
overwrite: Whether to overwrite an existing mask with the same key.
Defaults to False.
"""
if (overwrite is False) and (key in self.masks):
Expand All @@ -138,7 +138,7 @@ def get_timeseries_type(self, key):
"""Get the type of a timeseries.

Args:
key (str): The key of the timeseries.
key: The key of the timeseries.

Returns:
str: The type of the timeseries.
Expand Down
19 changes: 10 additions & 9 deletions src/cedalion/dataclasses/schemas.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Data array schemas and utilities to build labeled data arrays."""

from __future__ import annotations
import functools
import inspect
import typing
Expand Down Expand Up @@ -99,20 +100,20 @@ def build_timeseries(
value_units: str,
time_units: str,
other_coords: dict[str, ArrayLike] = {},
):
) -> xr.DataArray:
"""Build a labeled time series data array.

Args:
data (ArrayLike): The data values.
dims (List[str]): The dimension names.
time (ArrayLike): The time values.
channel (List[str]): The channel names.
value_units (str): The units of the data values.
time_units (str): The units of the time values.
other_coords (dict[str, ArrayLike]): Additional coordinates.
data: The data values.
dims: The dimension names.
time: The time values.
channel: The channel names.
value_units: The units of the data values.
time_units: The units of the time values.
other_coords: Additional coordinates.

Returns:
da (xr.DataArray): The labeled time series data array.
da: The labeled time series data array.
"""
assert len(dims) == data.ndim
assert "time" in dims
Expand Down
15 changes: 8 additions & 7 deletions src/cedalion/geometry/landmarks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Module for constructing the 10-10-system on the scalp surface."""

from __future__ import annotations
import warnings
from typing import List, Optional

Expand Down Expand Up @@ -125,8 +126,8 @@ def __init__(self, scalp_surface: Surface, landmarks: LabeledPointCloud):
"""Initialize the LandmarksBuilder1010.

Args:
scalp_surface (Surface): a triangle-mesh representing the scalp
landmarks (LabeledPointCloud): positions of "Nz", "Iz", "LPA", "RPA"
scalp_surface: a triangle-mesh representing the scalp
landmarks: positions of "Nz", "Iz", "LPA", "RPA"
"""
if isinstance(scalp_surface, TrimeshSurface):
scalp_surface = VTKSurface.from_trimeshsurface(scalp_surface)
Expand Down Expand Up @@ -197,9 +198,9 @@ def _add_landmarks_along_line(
"""Add landmarks along a line defined by three landmarks.

Args:
triangle_labels (List[str]): Labels of the three landmarks defining the line
labels (List[str]): Labels for the new landmarks
dists (List[float]): Distances along the line where the new landmarks should
triangle_labels: Labels of the three landmarks defining the line
labels: Labels for the new landmarks
dists: Distances along the line where the new landmarks should
be placed.
"""
assert len(triangle_labels) == 3
Expand Down Expand Up @@ -319,8 +320,8 @@ def order_ref_points_6(landmarks: xr.DataArray, twoPoints: str) -> xr.DataArray:
"""Reorder a set of six landmarks based on spatial relationships and give labels.

Args:
landmarks (xr.DataArray): coordinates for six landmark points
twoPoints (str): two reference points ('Nz' or 'Iz') for orientation.
landmarks: coordinates for six landmark points
twoPoints: two reference points ('Nz' or 'Iz') for orientation.

Returns:
xr.DataArray: the landmarks ordered as "Nz", "Iz", "RPA", "LPA", "Cz"
Expand Down
1 change: 1 addition & 0 deletions src/cedalion/geometry/photogrammetry/processors.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Vertex classifiers."""

from __future__ import annotations
import colorsys
from abc import ABC, abstractmethod
from dataclasses import dataclass
Expand Down
34 changes: 18 additions & 16 deletions src/cedalion/geometry/registration.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Registrating optodes to scalp surfaces."""

from __future__ import annotations
import numpy as np
from numpy.linalg import pinv
from scipy.optimize import linear_sum_assignment, minimize
from scipy.spatial import KDTree
import xarray as xr
from typing import Tuple

import cedalion
import cedalion.dataclasses as cdc
Expand Down Expand Up @@ -38,8 +40,8 @@ def register_trans_rot(
between the two point clouds.

Args:
coords_target (LabeledPointCloud): Target point cloud.
coords_trafo (LabeledPointCloud): Source point cloud.
coords_target: Target point cloud.
coords_trafo: Source point cloud.

Returns:
cdt.AffineTransform: Affine transformation between the two point clouds.
Expand Down Expand Up @@ -133,8 +135,8 @@ def register_trans_rot_isoscale(
between the two point clouds.

Args:
coords_target (LabeledPointCloud): Target point cloud.
coords_trafo (LabeledPointCloud): Source point cloud.
coords_target: Target point cloud.
coords_trafo: Source point cloud.

Returns:
cdt.AffineTransform: Affine transformation between the two point clouds.
Expand Down Expand Up @@ -195,9 +197,9 @@ def gen_xform_from_pts(p1: np.ndarray, p2: np.ndarray) -> np.ndarray:
"""Calculate the affine transformation matrix T that transforms p1 to p2.

Args:
p1 (np.ndarray): Source points (p x m) where p is the number of points and m is
p1: Source points (p x m) where p is the number of points and m is
the number of dimensions.
p2 (np.ndarray): Target points (p x m) where p is the number of points and m is
p2: Target points (p x m) where p is the number of points and m is
the number of dimensions.

Returns:
Expand Down Expand Up @@ -237,21 +239,21 @@ def register_icp(
surface: cdc.Surface,
landmarks: cdt.LabeledPointCloud,
geo3d: cdt.LabeledPointCloud,
niterations=1000,
random_sample_fraction=0.5,
):
niterations: int =1000,
random_sample_fraction: float =0.5,
) -> Tuple[np.ndarray, np.ndarray]:
"""Iterative Closest Point algorithm for registration.

Args:
surface (Surface): Surface mesh to which to register the points.
landmarks (LabeledPointCloud): Landmarks to use for registration.
geo3d (LabeledPointCloud): Points to register to the surface.
niterations (int): Number of iterations for the ICP algorithm (default 1000).
random_sample_fraction (float): Fraction of points to use in each iteration
surface: Surface mesh to which to register the points.
landmarks: Landmarks to use for registration.
geo3d: Points to register to the surface.
niterations: Number of iterations for the ICP algorithm (default 1000).
random_sample_fraction: Fraction of points to use in each iteration
(default 0.5).

Returns:
Tuple[np.ndarray, np.ndarray]: Tuple containing the losses and transformations
Tuple containing the losses and transformations
"""
units = "mm"
landmarks_mm = landmarks.pint.to(units).points.to_homogeneous().pint.dequantify()
Expand Down Expand Up @@ -496,7 +498,7 @@ def simple_scalp_projection(geo3d: cdt.LabeledPointCloud) -> cdt.LabeledPointClo
"""Projects 3D coordinates onto a 2D plane using a simple scalp projection.

Args:
geo3d (LabeledPointCloud): 3D coordinates of points to project. Requires the
geo3d: 3D coordinates of points to project. Requires the
landmarks Nz, LPA, and RPA.

Returns:
Expand Down
Loading
Loading