diff --git a/.gitmodules b/.gitmodules index f5a812d9..7257b728 100644 --- a/.gitmodules +++ b/.gitmodules @@ -34,3 +34,6 @@ [submodule "core_libraries/submodules/networkx"] path = core_libraries/submodules/networkx url = https://github.com/networkx/networkx.git +[submodule "scripts/codehub/utils/acquisition/BIDS"] + path = scripts/codehub/utils/acquisition/BIDS + url = https://github.com/penn-cnt/EEG_BIDS.git diff --git a/README.md b/README.md index 86289d28..8f096703 100644 --- a/README.md +++ b/README.md @@ -6,14 +6,16 @@ CNT Code Hub This code is designed to help with the processing of epilepsy datasets commonly used within the Center for Neuroengineering & Therapeutics (CNT) at the University of Pennsylvania. -This code is meant to be researcher driven, allowing new code libraries to be added to modules that represent common research tasks (i.e. Channel Cleaning, Montaging, Preprocessing, etc.). The code can be accessed both as independent libraries that can be called on for a range of tasks, or as part of a large framework meant to ingest, clean, and prepare data for analysis or deep-learning tasks. +This code is meant to be researcher driven, allowing new code libraries to be added to modules that represent common research tasks (i.e. Channel Cleaning, Montaging, Preprocessing, Feature Extraction, etc.). The code can be accessed both as independent libraries that can be called on for a range of tasks, or as part of a large framework meant to ingest, clean, and prepare data for analysis or deep-learning tasks. -For more information on how to use our code, please see the examples folder for specific use-cases and common practices. +We also provide a number of additional scripts to help with common/important tasks. For more information, please refer [here](https://github.com/penn-cnt/CNT-codehub/tree/main/scripts/codehub/utils/) for what scripts are currently available. -# Prerequisites +# Installation + +## Prerequisites In order to use this repository, you must have access to Python 3+. You must also have access to conda 23.+ if building environments from yaml files. -# Installation +## Using Conda An environment file with all the needed packages to run this suite of code can be found at the following location: @@ -39,17 +41,38 @@ The environment is then activated by running: More information about creating conda environments can be found [here](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html). -## Adding the codehub to your python paths +### Adding the codehub to your conda environment paths You will likely need to add this package to your python path to ensure full functionality of utility scripts and the main pipeline. To do so using anaconda, you can run: > conda develop /scripts/codehub/ -For a virtual environment, an easy way to add `/scripts/codehub/` to your path would be to add a text file with a .pth extention (any filename is fine) to the site-packages subfolder in your virtual environment folder. Within the text file you can just copy and paste the absolute path as the only contents. +## Using venv and pip + +To create a virtual environment, you need to create a location for the environment to install to. For this example, we will specify `/demonstration/environment/cnt_codehub` as our environment location. Using the python version of your choice, in this example we will select 3.10, run the following command: + +> python3.10 -m venv /demonstration/environment/cnt_codehub + +to create a new virtual environment. 
To enter the environment, simply run: + +> source /demonstration/environment/cnt_codehub/bin/activate -## Installation using venv +**NOTE:** To streamline the process, we recommend making an alias command to avoid having to navigate to the activate file every time. + +Once in the environment, a requirements.txt file with all the needed packages to run this suite of code can be found at the following location: + +> [CNT Codehub requirements](core_libraries/python/cnt_codehub/requirements.txt) + +This file can be installed using the following call to pip from the `core_libraries/python/cnt_codehub` subdirectory: + +> pip install -r requirements.txt + +which will install everything to your current virtual environment. + +### Adding the codehub to your virtual environment path +For a virtual environment, an easy way to add `/scripts/codehub/` to your path would be to add a text file with a .pth extension (any filename is fine) to the site-packages subfolder in your virtual environment folder. Within the text file you can just copy and paste the absolute path as the only contents. -To be added soon. +Typically, the path to your site-packages can be found at: `<your environment folder>/lib/python<version>/site-packages`. # Documentation diff --git a/core_libraries/matlab/README.md b/core_libraries/matlab/README.md deleted file mode 100644 index a863263b..00000000 --- a/core_libraries/matlab/README.md +++ /dev/null @@ -1 +0,0 @@ -:woman_shrugging: diff --git a/core_libraries/python/cnt_codehub/cnt_codehub.yml b/core_libraries/python/cnt_codehub/cnt_codehub.yml new file mode 100644 index 00000000..c809454a --- /dev/null +++ b/core_libraries/python/cnt_codehub/cnt_codehub.yml @@ -0,0 +1,27 @@ +name: codehub +channels: + - conda-forge + - defaults +dependencies: + - fooof + - neurodsp + - nltk + - numpy + - pandas + - prettytable + - pybids + - pytorch + - pyyaml + - scikit-learn + - scipy + - tqdm + - pip: + - ./wheels/ieeg-1.6-py3-none-any.whl + - edflib-python + - edfio + - mne + - mne-bids + - mne-icalabel + - nibabel + - pyEDFlib + - yasa diff --git a/core_libraries/python/cnt_codehub/envs/README.md b/core_libraries/python/cnt_codehub/envs/README.md deleted file mode 100644 index ec8ebaec..00000000 --- a/core_libraries/python/cnt_codehub/envs/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# iEEG Environments - -Home for scalp eeg environment files.
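As a concrete illustration of the `.pth` approach described in the README changes above, here is a minimal sketch; the repository path is a placeholder to replace with your own absolute path, and it assumes the target virtual environment is the one currently active.

```python
# Minimal sketch of the .pth approach described above -- paths are placeholders.
import site
from pathlib import Path

# Absolute path to your clone of scripts/codehub (placeholder, not a real path).
codehub_path = "/absolute/path/to/CNT-codehub/scripts/codehub"

# site-packages folder of the currently active environment.
site_packages = Path(site.getsitepackages()[0])

# Any filename with a .pth extension works; Python adds each line of the file
# to sys.path at interpreter startup.
(site_packages / "cnt_codehub.pth").write_text(codehub_path + "\n")
```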
diff --git a/core_libraries/python/cnt_codehub/envs/cnt_codehub.yml b/core_libraries/python/cnt_codehub/envs/cnt_codehub.yml deleted file mode 100644 index ff618825..00000000 --- a/core_libraries/python/cnt_codehub/envs/cnt_codehub.yml +++ /dev/null @@ -1,29 +0,0 @@ -name: cnt_codehub -channels: - - conda-forge - - defaults -dependencies: - - python>=3.0 - - numpy - - scipy - - pandas - - requests - - deprecation - - tqdm - - ipython - - pytorch - - scikit-learn - - paramiko - - nltk - - pip: - - pennprov - - mne - - mne-bids - - mne-icalabel - - pyEDFlib - - EDFlib-Python - - pyyaml - - ../wheels/ieeg-1.6-py3-none-any.whl - - dearpygui - - pyperclip - - keyring diff --git a/core_libraries/python/cnt_codehub/envs/requirements.txt b/core_libraries/python/cnt_codehub/envs/requirements.txt deleted file mode 100644 index 581d811e..00000000 --- a/core_libraries/python/cnt_codehub/envs/requirements.txt +++ /dev/null @@ -1,22 +0,0 @@ -numpy -scipy -pandas -requests -deprecation -tqdm -ipython -pytorch -scikit-learn -paramiko -nltk -pennprov -mne -mne-bids -mne-icalabel -pyEDFlib -EDFlib-Python -pyyaml -../wheels/ieeg-1.6-py3-none-any.whl -dearpygui -pyperclip -keyring \ No newline at end of file diff --git a/core_libraries/python/cnt_codehub/requirements.txt b/core_libraries/python/cnt_codehub/requirements.txt new file mode 100644 index 00000000..35f7b453 --- /dev/null +++ b/core_libraries/python/cnt_codehub/requirements.txt @@ -0,0 +1,117 @@ +antropy==0.1.9 +astor +asttokens==3.0.0 +attrs +bids-validator +bidsschematools +cached-property +certifi==2025.1.31 +charset-normalizer==3.4.1 +click +colorama +comm==0.2.2 +contourpy +cycler +decorator==5.2.1 +deprecation==2.1.0 +docopt +edfio==0.4.8 +EDFlib-Python==1.0.8 +executing==2.2.0 +filelock +fonttools +fooof +formulaic +frozendict +fsspec +gmpy2 +graphlib-backport +greenlet +idna==3.10 +./wheels/ieeg-1.6-py3-none-any.whl +importlib_resources +interface_meta +ipython +ipython_pygments_lexers==1.1.1 +ipywidgets==8.1.6 +jedi==0.19.2 +Jinja2 +joblib +jsonschema +jsonschema-specifications +jupyterlab_widgets==3.0.14 +kiwisolver +lazy_loader==0.4 +lightgbm==4.6.0 +llvmlite +lspopt==1.4.0 +markdown-it-py +MarkupSafe +matplotlib +matplotlib-inline +mdurl +mne +mne-bids +mne-icalabel +mpmath +munkres==1.1.4 +networkx +neurodsp +nibabel +nltk +num2words +numba +numpy +optree +packaging +pandas +parso==0.8.4 +pennprov==2.2.4 +pexpect==4.9.0 +pillow +pkgutil_resolve_name +platformdirs==4.3.7 +pooch==1.8.2 +prettytable +prompt_toolkit==3.0.51 +psutil==7.0.0 +ptyprocess==0.7.0 +pure_eval==0.2.3 +pybids +pybind11 +pybind11_global +pyEDFlib==0.1.40 +Pygments==2.19.1 +pyparsing +pyriemann==0.8 +python-dateutil +pytz +PyYAML +referencing +regex +requests==2.32.3 +rpds-py +scikit-learn +scipy +seaborn==0.13.2 +setuptools==75.8.2 +six +sleepecg +SQLAlchemy +stack-data==0.6.3 +sympy +tabulate +tensorpac==0.6.5 +threadpoolctl +torch +tqdm +traitlets==5.14.3 +typing_extensions +tzdata +universal_pathlib +urllib3==2.4.0 +wcwidth +widgetsnbextension==4.0.14 +wrapt +yasa==0.6.5 +zipp diff --git a/core_libraries/python/cnt_codehub/wheels/README.md b/core_libraries/python/cnt_codehub/wheels/README.md deleted file mode 100644 index 571d4977..00000000 --- a/core_libraries/python/cnt_codehub/wheels/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# iEEG Environments - -Home for scalp eeg wheel files. 
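One note on the new `requirements.txt`: it references the bundled iEEG client by the relative path `./wheels/ieeg-1.6-py3-none-any.whl`, and pip resolves such relative paths against the directory it is run from. Assuming the layout above, the install should therefore be launched from `core_libraries/python/cnt_codehub` so the wheel path resolves:

> pip install -r requirements.txt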
diff --git a/scripts/CNT-STUDENT-DIRECTORY-GOES-HERE/README.md b/scripts/CNT-STUDENT-DIRECTORY-GOES-HERE/README.md new file mode 100644 index 00000000..8ab520f5 --- /dev/null +++ b/scripts/CNT-STUDENT-DIRECTORY-GOES-HERE/README.md @@ -0,0 +1,31 @@ +# Sample User Directory + +This is a placeholder folder to help you orient yourself to the repository for the first time. You can either rename this to your eventual project folder, or just delete it. + +We provide some information on how to make a project folder/rename this folder below. + +## Naming conventions + +This is an example user folder. Personal project work takes place in folders saved at the same level as the central `codehub` folder. + +We require that folder names follow this design pattern: + +> {device name}\_{project name}\_{optional subdirectory} + +We require this naming structure in order to join different projects within the CNT data ecosystem. Multiple users can work within a single repository, either within the same directory or within their own optional subdirectories. + +### Example + +If I am working on a scalp multi-layer perceptron (MLP) project to predict sleep stages, spikes, and PNES events, I might go about making folders as follows: + +- `device_name`: I am working on scalp data, so I will go with `scalp` +- `project_name`: I am using an MLP for a few different tasks, so let's just summarize the project as `MLP`. + +Now I could stop there and just make my folder: `scalp_MLP` and place all of my work within. Or, if I wanted to be careful about the environment for each sub-goal, or maybe I was collaborating and each person was doing their own sub-goal, I could make the following folders: +- scalp_MLP_spikes +- scalp_MLP_sleep-stages +- scalp_MLP_pnes-predictor + +## Updating the codehub libraries + +Any changes to scripts within the [modules](../codehub/modules) subdirectory can be submitted to the main lab repository as its own branch, at which point a pull request will be reviewed before changes are accepted or rejected. diff --git a/scripts/codehub/README.md b/scripts/codehub/README.md index efbf8e2f..407a4828 100644 --- a/scripts/codehub/README.md +++ b/scripts/codehub/README.md @@ -45,4 +45,9 @@ Public components are where lab code can be saved for everyone to use and for ep The utility scripts are not built into the epipy framework, and do not require specific formatting. To add a utility script, simply identify or create a new folder that generally defines the task being done (data acquisition/data validation/etc.) and add your code to the existing folder for that task type, or create a new folder defining the task and add it there. ### Pull Requests -Submit a pull request to share your changes with the lab as a whole. The data team will review the request before merging it, or sending it back to you for more clarity or bug fixes. \ No newline at end of file +Submit a pull request to share your changes with the lab as a whole. The data team will review the request before merging it, or sending it back to you for more clarity or bug fixes. + + +## Remaining Updates/Fixes + +- Remove keys from metadata handler if they do not get promoted to the feature extraction step.
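For contributors following the pull-request workflow described above, a sketch of a typical branch-based submission is shown here; the branch name and paths are examples only, not repository requirements:

> git checkout -b scalp_MLP_new-feature
> git add scripts/codehub/modules/
> git commit -m "Add new module code"
> git push origin scalp_MLP_new-feature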
\ No newline at end of file diff --git a/scripts/codehub/allowed_arguments.yaml b/scripts/codehub/allowed_arguments.yaml index 94a0195c..cdf243a0 100644 --- a/scripts/codehub/allowed_arguments.yaml +++ b/scripts/codehub/allowed_arguments.yaml @@ -16,6 +16,7 @@ allowed_channel_args: NEUROVISTA: Channels for NeuroVista data. RAW: Use all possible channels. Warning, channels may not match across different datasets. allowed_montage_args: + NONE: No montage. HUP1020: Use a 10-20 montage. NEUROVISTA: Use a custom NeuroVista montage. COMMON_AVERAGE: Use a common average montage. diff --git a/scripts/codehub/components/core/internal/config_loader.py b/scripts/codehub/components/core/internal/config_loader.py index a383a8d9..38744192 100644 --- a/scripts/codehub/components/core/internal/config_loader.py +++ b/scripts/codehub/components/core/internal/config_loader.py @@ -27,7 +27,9 @@ def __init__(self,input_file): self.yaml_step = config else: # Read in and typecast the yaml file - config = yaml.safe_load(open(input_file,'r')) + fp = open(input_file,'r') + config = yaml.safe_load(fp) + fp.close() # Add in any looped steps to the correct yaml input format self.loop_handler(config) diff --git a/scripts/codehub/components/core/internal/dataframe_manager.py b/scripts/codehub/components/core/internal/dataframe_manager.py index d07ed03f..04a2507f 100644 --- a/scripts/codehub/components/core/internal/dataframe_manager.py +++ b/scripts/codehub/components/core/internal/dataframe_manager.py @@ -39,16 +39,4 @@ def column_subsection(self,keep_columns): # Get the columns to drop drop_cols = np.setdiff1d(self.dataframe.columns,keep_columns) - self.dataframe = self.dataframe.drop(drop_cols, axis=1) - - def montaged_dataframe(self,data,columns): - """ - Create a dataframe that stores the montaged data. - DEPRECIATE AFTER BETA PIPELINE RELEASE! 
- - Args: - data (array): array of montaged data - columns (list): List of column names - """ - - self.montaged_dataframe = PD.DataFrame(data,columns=columns) \ No newline at end of file + self.dataframe = self.dataframe.drop(drop_cols, axis=1) \ No newline at end of file diff --git a/scripts/codehub/components/core/internal/output_manager.py b/scripts/codehub/components/core/internal/output_manager.py index d7a600ea..0928bb3d 100644 --- a/scripts/codehub/components/core/internal/output_manager.py +++ b/scripts/codehub/components/core/internal/output_manager.py @@ -44,10 +44,20 @@ def save_features(self): """ if not self.args.debug and not self.args.no_feature_flag: - #timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M") - pickle.dump(self.metadata,open("%s/%s_%s_meta.pickle" %(self.args.outdir,self.timestamp,self.unique_id),"wb")) - pickle.dump(self.feature_df,open("%s/%s_%s_features.pickle" %(self.args.outdir,self.timestamp,self.unique_id),"wb")) - pickle.dump(self.feature_commands,open("%s/%s_%s_fconfigs.pickle" %(self.args.outdir,self.timestamp,self.unique_id),"wb")) + + # Pickled objects + fp1 = open("%s/%s_%s_meta.pickle" %(self.args.outdir,self.timestamp,self.unique_id),"wb") + fp2 = open("%s/%s_%s_fconfigs.pickle" %(self.args.outdir,self.timestamp,self.unique_id),"wb") + fp3 = open("%s/%s_%s_features.pickle" %(self.args.outdir,self.timestamp,self.unique_id),"wb") + pickle.dump(self.metadata,fp1) + pickle.dump(self.feature_commands,fp2) + pickle.dump(self.feature_df,fp3) + fp1.close() + fp2.close() + fp3.close() + + # CSV object + #self.feature_df.to_csv("%s/%s_%s_features.pickle" %(self.args.outdir,self.timestamp,self.unique_id),index=False) def save_output_list(self): """ @@ -55,6 +65,9 @@ def save_output_list(self): """ if not self.args.debug: - #timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M") - pickle.dump(self.output_list,open("%s/%s_%s_data.pickle" %(self.args.outdir,self.timestamp,self.unique_id),"wb")) - pickle.dump(self.metadata,open("%s/%s_%s_meta.pickle" %(self.args.outdir,self.timestamp,self.unique_id),"wb")) \ No newline at end of file + fp1 = open("%s/%s_%s_data.pickle" %(self.args.outdir,self.timestamp,self.unique_id),"wb") + fp2 = open("%s/%s_%s_meta.pickle" %(self.args.outdir,self.timestamp,self.unique_id),"wb") + pickle.dump(self.output_list,fp1) + pickle.dump(self.metadata,fp2) + fp1.close() + fp2.close() \ No newline at end of file diff --git a/scripts/codehub/components/core/internal/target_loader.py b/scripts/codehub/components/core/internal/target_loader.py index 19339a73..41179f7d 100644 --- a/scripts/codehub/components/core/internal/target_loader.py +++ b/scripts/codehub/components/core/internal/target_loader.py @@ -43,7 +43,9 @@ def load_targets(self,current_edf,datatype,target_substring): if self.target_file != None: # Load the data - raw_targets = pickle.load(open(self.target_file,"rb")) + fp = open(self.target_file,"rb") + raw_targets = pickle.load(fp) + fp.close() # Apply logic to known target types self.target_logic(raw_targets,current_edf) diff --git a/scripts/codehub/components/curation/internal/data_curation.py b/scripts/codehub/components/curation/internal/data_curation.py index aa479b36..64ac049b 100644 --- a/scripts/codehub/components/curation/internal/data_curation.py +++ b/scripts/codehub/components/curation/internal/data_curation.py @@ -50,8 +50,8 @@ def test_input_data(self): for idx,ifile in enumerate(self.files): # Get the load type - ftype = self.args.datatype - if ftype.lower() == 'mix': + ftype = 
self.args.datatype.lower() + if ftype == 'mix': ftype = ifile.split('.')[-1] # Use the load type and perform a load test @@ -149,7 +149,7 @@ def create_time_windows(self): for ifile in self.files: # Read in just the header to get duration - t_end = self.args.t_end + t_end = self.args.t_end.copy() for idx,ival in enumerate(t_end): if ival == -1: @@ -159,7 +159,8 @@ def create_time_windows(self): dtype = ifile.split('.')[-1].lower() if dtype == 'edf': - t_end[idx] = read_edf_header(ifile)['Duration'] + header = read_edf_header(ifile) + t_end[idx] = header['Duration'] elif dtype == 'pickle': idict = pickle.load(open(ifile,'rb')) t_end[idx] = idict['data'].shape[0]/idict['samp_freq'] @@ -249,6 +250,7 @@ def stratifier_BIDS_subject_count(self): for ifile in self.files: regex_match = re.match(r"(\D+)(\d+)", ifile) self.stratification_array.append(int(regex_match.group(2))) + subcnt = np.unique(self.stratification_array).size if not self.args.silent: print(f"Assuming BIDS data, approximately {subcnt:04d} subjects loaded.") diff --git a/scripts/codehub/components/curation/public/data_loader.py b/scripts/codehub/components/curation/public/data_loader.py index 82da7ad8..67f83447 100644 --- a/scripts/codehub/components/curation/public/data_loader.py +++ b/scripts/codehub/components/curation/public/data_loader.py @@ -9,9 +9,6 @@ from mne.io import read_raw_edf from pyedflib.highlevel import read_edf_header -# CNT/EEG Specific -from ieeg.auth import Session - # Component imports from components.metadata.public.metadata_handler import * @@ -75,9 +72,15 @@ def pipeline(self): self.ssh_username = self.args.ssh_username # Logic gate for filetyping, returns if load succeeded - flag = self.data_loader_logic(self.args.datatype) + readflag = self.data_loader_logic(self.args.datatype) - if flag: + # Get the data slice now so we can do a quick quality check + if readflag: + sample_frequency = np.array([self.sfreq for ichannel in self.channel_metadata]) + setflag = self.raw_dataslice(sample_frequency,majoraxis=self.args.orientation) + + # Set the information to our metadata object if it passes all tests + if setflag: # Create the metadata handler metadata_handler.highlevel_info(self) @@ -86,12 +89,8 @@ def pipeline(self): metadata_handler.set_channels(self,self.channels) # Calculate the sample frequencies to save the information and make time cuts - sample_frequency = np.array([self.sfreq for ichannel in self.channel_metadata]) metadata_handler.set_sampling_frequency(self,sample_frequency) - # Get the rawdata - self.raw_dataslice(sample_frequency,majoraxis=self.args.orientation) - # Set the clip duration referenced to the whole file metadata_handler.set_ref_window(self) @@ -159,6 +158,7 @@ def raw_dataslice(self,sample_frequency,majoraxis='column'): else: samp_end = int(isamp*self.t_end) + # Get the dataslice if majoraxis.lower() == 'column': self.raw_data.append(self.indata[samp_start:samp_end,ii]) elif majoraxis.lower() == 'row': @@ -174,6 +174,12 @@ def raw_dataslice(self,sample_frequency,majoraxis='column'): elif majoraxis.lower() == 'row': self.duration = (samp_end-samp_start)/self.indata.shape[1] + # Check for monotonic or zero data + if (np.ptp(self.raw_data,axis=1)==0).all(): + return False + else: + return True + ################################### #### User Provided Logic Below #### ################################### @@ -210,6 +216,8 @@ def load_edf(self): # Read in the data via mne backend raw = read_raw_edf(self.infile,verbose=False) self.indata = raw.get_data().T + + # Make the MNE objects 
self.channels = raw.ch_names self.sfreq = raw.info.get('sfreq') diff --git a/scripts/codehub/components/features/public/features.py b/scripts/codehub/components/features/public/features.py index 0eb1fc6b..5d235f3f 100644 --- a/scripts/codehub/components/features/public/features.py +++ b/scripts/codehub/components/features/public/features.py @@ -1,13 +1,18 @@ import os import ast import sys +import mne +import yasa import inspect import warnings import numpy as np import pandas as PD from tqdm import tqdm from fooof import FOOOF +from functools import wraps +from scipy.stats import mode from scipy.integrate import simpson +from sklearn.linear_model import LinearRegression from scipy.signal import welch, find_peaks, detrend from neurodsp.spectral import compute_spectrum_welch @@ -20,17 +25,296 @@ persistance_dict = {} # Ignore FutureWarnings. Pandas is giving a warning for concat. But the data is not zero. Might be due to a single channel of all NaNs. +from sklearn.exceptions import InconsistentVersionWarning warnings.simplefilter(action='ignore', category=FutureWarning) +warnings.simplefilter(action='ignore', category=InconsistentVersionWarning) + +def channel_wrapper(method): + """Decorator to apply a method to each column unless the user directly passes the column to analyze.""" + @wraps(method) + def wrapper(self, channel=None, *args, **kwargs): + if channel is not None: + # Run method on a specific column + return [method(self, channel, *args, **kwargs)] + else: + # Run method on all columns and return results as a dictionary + return [method(self, col, *args, **kwargs) for col in self.channels] + return wrapper + +class channel_wise_metrics: + + def __init__ (self, data, fs, file, fidx, channels=None, trace=False): + # Manage data typing and form a dataframe as needed + if isinstance(data, np.ndarray): + if isinstance(channels,list) or isinstance(channels,np.ndarray): + if isinstance(channels,np.ndarray): + if channels.ndim == 1: + channels = list(channels) + else: + raise ValueError("If passing a numpy arra as channel names, channels must be a 1-d array.") + self.data = data + self.channels = channels + else: + raise ValueError("Channels must not be None and be a list or a 1-d array if passing data as a numpy array.") + elif isinstance(data,PD.DataFrame): + self.data = data.values + self.channels = list(data.columns) + + # Save remaining keywords + self.fs = fs + self.file = file + self.fidx = fidx + self.trace = trace + + # Because we are using the unmontaged data to infer this feature, we want to map the results to the output channel mapping + self.outchannel = channels + + def check_persistance(self): + + self.channelwise_key = f"channelwise_{self.file}_{self.fidx}_{self.window_length}" + if self.channelwise_key not in persistance_dict.keys(): + persistance_dict[self.channelwise_key] = {} + self.fit_ar_model() + self.calculate_source_sink_space() + else: + self.Avg_A = persistance_dict[self.channelwise_key]['Avg_A'] + self.rr = persistance_dict[self.channelwise_key]['rr'] + self.cr = persistance_dict[self.channelwise_key]['cr'] + + def source_index(self,window_length=0.5): + """ + Calculate the source index for each channel. + + Args: + window_length (float, optional): Window length for auto regression in seconds. Defaults to 0.5. 
+ """ + + # Save the window length to the class instance + self.window_length = window_length + + # Check for any pre-calculated metrics + self.check_persistance() + + # Get the sink index + source_index = self.source_fnc() + + # Make the optional tag + optional_str = f"windowlength_{self.window_length}" + + # Make the output results + results = [(i_index,optional_str) for i_index in source_index] + + return results + + def sink_index(self,window_length=0.5): + """ + Calculate the sink index for each channel. + + Args: + window_length (float, optional): Window length for auto regression in seconds. Defaults to 0.5. + """ + + # Save the window length to the class instance + self.window_length = window_length + + # Check for any pre-calculated metrics + self.check_persistance() + + # Get the sink index + sink_index = self.sink_fnc() + + # Make the optional tag + optional_str = f"windowlength_{self.window_length}" + + # Make the output results + results = [(i_index,optional_str) for i_index in sink_index] + + return results + + def fit_ar_model(self): + + # Get window properties + window_size = self.window_length*self.fs + n_windows = self.data.shape[0] // window_size + + # Initialize the model insance + model = LinearRegression() + + # Get the auto-regression across windows + A_all = [] + for i in range(int(n_windows)): + + # Get the current window data + window_data = self.data[int(i * window_size):int((i + 1) * window_size), :] + + # Get the input vectors + X = window_data[:-1] + + # Get the output vectors + y = window_data[1:] + + # Fit the model + model.fit(X, y) + + # Get the slope and store the running list + A = model.coef_ + A_all.append(A) + + # Get the absolute mean of all linear slopes + A_all = np.array(A_all) + Avg_A = np.mean(A_all, axis=0) + Avg_A = np.abs(Avg_A) + + # Store result to class instance + self.Avg_A = Avg_A + + # Store to persistance dict + persistance_dict[self.channelwise_key]['Avg_A'] = self.Avg_A + + def calculate_source_sink_space(self): + + # Update the main diagonal to avoid self-same calculatuon + np.fill_diagonal(self.Avg_A, 0) + + # Get the node strength for column and row-wise + i_node_strength = np.sum(self.Avg_A, axis=0) + j_node_strength = np.sum(self.Avg_A, axis=1) + + # rank node strengths to get row rank (rr) and column rank (cr) + self.rr = self.calculate_rank(i_node_strength) + self.cr = self.calculate_rank(j_node_strength) + + # Store to persistance dict + persistance_dict[self.channelwise_key]['rr'] = self.rr + persistance_dict[self.channelwise_key]['cr'] = self.cr + + def calculate_rank(self,arr): + + # Get the indices that would sort the array in descending order + sorted_indices = np.argsort(arr)[::-1] + + # Create a ranking array + ranks = np.zeros_like(arr, dtype=float) + ranks[sorted_indices] = np.arange(1, len(arr) + 1) / len(ranks) + + return ranks + + def source_fnc(self): + + # Calculate the source index + N = len(self.rr) + x = self.rr - (1/N) + y = self.cr - 1 + vector_length = np.sqrt(x**2 + y**2) + source_index = np.sqrt(2) - vector_length + return source_index + + def sink_fnc(self): + + # Calculate source sink + N = len(self.rr) + x = self.rr - 1 + y = self.cr - (1/N) + vector_length = np.sqrt(x**2 + y**2) + sink_index = np.sqrt(2) - vector_length + return sink_index class YASA_processing: + """ + Yasa sleep staging feature extraction. 
+ """ - def __init__ (self,data,channels): - self.data = data - self.channels = channels - - def get_sleep_stage(self): + def __init__ (self, data, fs, channels=None, trace=False): + # Manage data typing and form a dataframe as needed + if isinstance(data, np.ndarray): + if isinstance(channels,list) or isinstance(channels,np.ndarray): + if isinstance(channels,np.ndarray): + if channels.ndim == 1: + channels = list(channels) + else: + raise ValueError("If passing a numpy arra as channel names, channels must be a 1-d array.") + self.data = data + self.channels = channels + else: + raise ValueError("Channels must not be None and be a list or a 1-d array if passing data as a numpy array.") + elif isinstance(data,PD.DataFrame): + self.data = data.values + self.channels = list(data.columns) + + # Save remaining keywords + self.fs = fs + self.trace = trace + + # Because we are using the unmontaged data to infer this feature, we want to map the results to the output channel mapping + self.outchannel = channels + + def make_montage_object(self,config_path): + + # Create the mne channel types + fp = open(config_path,'r') + mapping = yaml.safe_load(fp) + fp.close() + persistance_dict['mne_mapping'] = mapping + + def make_raw_object(self,config_path): + + # Get the channel mappings in mne compliant form + if 'mne_mapping' not in persistance_dict.keys(): + self.make_montage_object(config_path) + mapping = persistance_dict['mne_mapping'] + mapping_keys = list(mapping.keys()) + + # Assign the mapping to each channel + ch_types = [] + for ichannel in self.channels: + if ichannel in mapping_keys: + ch_types.append(mapping[ichannel]) + else: + ch_types.append('eeg') + + # Create the raw mne object and set the reference + info = mne.create_info(self.channels, self.fs, ch_types=ch_types,verbose=False) + self.raw = mne.io.RawArray(self.data.T, info, verbose=False) + + def yasa_sleep_stage(self,config_path,consensus_channels=['CZ','C03','C04']): + + # Check for a long enough duration + if (self.data.shape[0]/self.fs/60 >=5): + # Make the raw object for YASA to work with + self.make_raw_object(config_path) + + # Set the right reference for eyeblink removal (CAR by default) + self.raw = self.raw.set_eeg_reference('average',verbose=False) + + # Apply the minimum needed filter for eyeblink removal + self.raw = self.raw.filter(0.5,30,verbose=False) + + # Resample down to 100 HZ + self.raw = self.raw.resample(100) + + # Get the yasa prediction + results = [] + for ichannel in consensus_channels: + sls = yasa.SleepStaging(self.raw, eeg_name=ichannel) + results.append(list(sls.predict())) + results = np.array(results) + + # Get the epipy formatted output + output = '' + for irow in results.T: + output += ','.join(irow) + output += '|' + output = output[:-1] + else: + output = None + + # Make the optional string. In this case, the consensus channel list + optional_str = ','.join(consensus_channels) + + # Reformat the output to match the output structure + results = [(output,optional_str) for ichannel in self.outchannel] - raw = mne.io(self.data,self.channels) # Fix this + return results class FOOOF_processing: @@ -196,14 +480,42 @@ def fooof_bandpower(self,lo_freq,hi_freq, win_size=2., win_stride=1.): class signal_processing: """ Class devoted to basic signal processing tasks. (Band-power/peak-finder/etc.) + + Uses only one channel at a time. 
""" - def __init__(self, data, fs, trace=False): - self.data = data - self.fs = fs - self.trace = trace + def __init__(self, data, fs, channels=None, trace=False): + """ + Store the dataframe object to the signal processing class for use in different methods. + + Args: + data (array or dataframe): Array/DataFrame of timeseries data. Row=Sample, Column=Channel. + fs (float): Sampling Frequency + channels (list): List of channel names. Same order as columns in data. + trace (bool, optional): _description_. Defaults to False. + """ + + # Manage data typing and form a dataframe as needed + if isinstance(data, np.ndarray): + if isinstance(channels,list) or isinstance(channels,np.ndarray): + if isinstance(channels,np.ndarray): + if channels.ndim == 1: + channels = list(channels) + else: + raise ValueError("If passing a numpy arra as channel names, channels must be a 1-d array.") + self.data = PD.DataFrame(data,columns=channels) + self.channels = channels + else: + raise ValueError("Channels must not be None and be a list or a 1-d array if passing data as a numpy array.") + elif isinstance(data,PD.DataFrame): + self.channels = list(data.columns) + + # Save remaining keywords + self.fs = fs + self.trace = trace - def spectral_energy_welch(self, low_freq=-np.inf, hi_freq=np.inf, win_size=2., win_stride=1.): + @channel_wrapper + def spectral_energy_welch(self, channel, low_freq=-np.inf, hi_freq=np.inf, win_size=2., win_stride=1.): """ Returns the spectral energy using the Welch method. @@ -229,7 +541,8 @@ def spectral_energy_welch(self, low_freq=-np.inf, hi_freq=np.inf, win_size=2., w noverlap = int(float(win_stride) * self.fs) # Calculate the welch periodogram - frequencies, initial_power_spectrum = welch(x=self.data.reshape((-1,1)), fs=self.fs, nperseg=nperseg, noverlap=noverlap, axis=0) + idata = self.data[channel].values + frequencies, initial_power_spectrum = welch(x=idata.reshape((-1,1)), fs=self.fs, nperseg=nperseg, noverlap=noverlap, axis=0) initial_power_spectrum = initial_power_spectrum.flatten() inds = (frequencies>=0.5)&np.isfinite(initial_power_spectrum)&(initial_power_spectrum>0) freqs = frequencies[inds] @@ -244,8 +557,53 @@ def spectral_energy_welch(self, low_freq=-np.inf, hi_freq=np.inf, win_size=2., w return spectral_energy,self.optional_tag else: return spectral_energy,self.optional_tag,(['freqs','psd_welch'],frequencies.astype('float16'),psd.astype('float32')) - - def topographic_peaks(self,prominence_height,min_width,height_unit='zscore',width_unit='seconds',detrend_flag=False): + + @channel_wrapper + def normalized_spectral_energy_welch(self, channel, low_freq=-np.inf, hi_freq=np.inf, win_size=2., win_stride=1.): + """ + Returns the spectral energy using the Welch method. + + Args: + low_freq (float, optional): Low frequency cutoff. Defaults to -np.inf. + hi_freq (float, optional): High frequency cutoff. Defaults to np.inf. + win_size (float, optional): Window size in units of sampling frequency. Defaults to 2. + win_stride (float, optional): Window overlap in units of sampling frequency. Defaults to 1. + + Returns: + spectral_energy (float): Spectral energy within the frequency band. + optional_tag (string): Unique identifier that is added to the output dataframe to show the frequency window for which a welch spectral energy was calculated. + (frequencies,psd): If trace is enabled for this pipeline, return the frequencies and psd for this channel for testing. 
+ """ + + # Add in the optional tagging to denote frequency range of this step + low_freq_str = f"{low_freq:.2f}" + hi_freq_str = f"{hi_freq:.2f}" + self.optional_tag = '['+low_freq_str+','+hi_freq_str+']' + + # Get the number of samples in each window for welch average and the overlap + nperseg = int(float(win_size) * self.fs) + noverlap = int(float(win_stride) * self.fs) + + # Calculate the welch periodogram + idata = self.data[channel].values + frequencies, initial_power_spectrum = welch(x=idata.reshape((-1,1)), fs=self.fs, nperseg=nperseg, noverlap=noverlap, axis=0) + initial_power_spectrum = initial_power_spectrum.flatten() + inds = (frequencies>=0.5)&np.isfinite(initial_power_spectrum)&(initial_power_spectrum>0) + freqs = frequencies[inds] + initial_power_spectrum = initial_power_spectrum[inds] + psd = np.interp(frequencies,freqs,initial_power_spectrum) + + # Calculate the spectral energy + mask = (frequencies >= low_freq) & (frequencies <= hi_freq) + spectral_energy = np.trapz(psd[mask], frequencies[mask])/np.trapz(psd, frequencies) + + if not self.trace: + return spectral_energy,self.optional_tag + else: + return spectral_energy,self.optional_tag,(['freqs','psd_welch'],frequencies.astype('float16'),psd.astype('float32')) + + @channel_wrapper + def topographic_peaks(self,channel,prominence_height,min_width,height_unit='zscore',width_unit='seconds',detrend_flag=False): """ Find the topographic peaks in channel data. This is a naive/fast way of finding spikes or slowing. @@ -266,9 +624,9 @@ def topographic_peaks(self,prominence_height,min_width,height_unit='zscore',widt # Detrend as needed if detrend_flag: - data = detrend(self.data) + data = detrend(self.data[channel].values) else: - data = np.copy(self.data) + data = np.copy(self.data[channel].values) # Recast height into a pure number as needed if height_unit == 'zscore': @@ -301,7 +659,8 @@ def topographic_peaks(self,prominence_height,min_width,height_unit='zscore',widt # Return a tuple of (peak, left width, right width) to store all of the peak info return out,self.optional_tag - def line_length(self): + @channel_wrapper + def line_length(self,channel): """ Return the line length along the given channel. @@ -310,18 +669,48 @@ def line_length(self): optional_tag (string): Optional tag """ - LL = np.sum(np.abs(np.ediff1d(self.data))) + LL = np.sum(np.abs(np.ediff1d(self.data[channel].values))) optional_tag = '' return LL,optional_tag class basic_statistics: + """ + Basic features that can be extracted from the raw time series data. + """ - def __init__(self, data, fs, trace=False): - self.data = data - self.fs = fs - self.trace = trace + def __init__(self, data, fs, channels=None, trace=False): + """ + Store the dataframe object to the signal processing class for use in different methods. - def mean(self): + Args: + data (array or dataframe): Array/DataFrame of timeseries data. Row=Sample, Column=Channel. + fs (float): Sampling Frequency + channels (list): List of channel names. Same order as columns in data. + trace (bool, optional): _description_. Defaults to False. 
+ """ + + # Manage data typing and form a dataframe as needed + if isinstance(data, np.ndarray): + if isinstance(channels,list) or isinstance(channels,np.ndarray): + if isinstance(channels,np.ndarray): + if channels.ndim == 1: + channels = list(channels) + else: + raise ValueError("If passing a numpy arra as channel names, channels must be a 1-d array.") + self.data = PD.DataFrame(data,columns=channels) + self.channels = channels + else: + raise ValueError("Channels must not be None and be a list or a 1-d array if passing data as a numpy array.") + elif isinstance(data,PD.DataFrame): + self.data = data.values + self.channels = list(data.columns) + + # Save remaining keywords + self.fs = fs + self.trace = trace + + @channel_wrapper + def mean(self,channel): """ Returns the mean value in a channel. @@ -329,9 +718,10 @@ def mean(self): float: Mean channel intensity. """ - return np.mean(self.data),'mean' + return np.mean(self.data[channel].values),'mean' - def median(self): + @channel_wrapper + def median(self,channel): """ Returns the median value in a channel. @@ -339,9 +729,10 @@ def median(self): float: Median channel intensity. """ - return np.median(self.data),'median' + return np.median(self.data[channel].values),'median' - def stdev(self): + @channel_wrapper + def stdev(self,channel): """ Returns the standard deviation in a channel. @@ -349,9 +740,10 @@ def stdev(self): float: Standard deviation in a channel. """ - return np.std(self.data),'stdev' + return np.std(self.data[channel].values),'stdev' - def quantile(self,q,method='median_unbiased'): + @channel_wrapper + def quantile(self,channel,q,method='median_unbiased'): """ Returns the q-th quantile of the data. @@ -361,14 +753,15 @@ def quantile(self,q,method='median_unbiased'): """ optional_tag = f"quantile_{q:.2f}" - return np.quantile(self.data,q=q,method=method),optional_tag + return np.quantile(self.data[channel].values,q=q,method=method),optional_tag - def rms(self): + @channel_wrapper + def rms(self,channel): """ Returns the mean root mean square of the channel. """ - val = np.sum(self.data**2)/self.data.size + val = np.sum(self.data[channel].values**2)/self.data[channel].values.size return np.sqrt(val),'rms' class features: @@ -425,52 +818,45 @@ def __init__(self): for idx,dataset in enumerate(self.output_list): # Grab the current meta data object - imeta = self.metadata[idx] + meta_idx = self.output_meta[idx] + imeta = self.metadata[meta_idx] # Get the input frequencies - fs = imeta['fs'] - - # Loop over the channels and get the updated values - output = [] - for ichannel in range(dataset.shape[1]): + fs = imeta['fs'][0] + # Obtain the features + output = [] + try: + # Get the input arguments for the current step for key, value in method_args.items(): try: method_args[key] = ast.literal_eval(value) except: pass - # Perform preprocessing step - try: - - # Grab the data and give it a first pass check for all zeros - idata = dataset[:,ichannel] - if not np.any(idata): - raise ValueError(f"Channel {channels[ichannel]} contains all zeros for file {imeta['file']}.") - - ################################# - ###### CLASS INITILIZATION ###### - ################################# - # Create namespaces for each class. Then choose which style of initilization is used by logic gate. 
- if cls.__name__ == 'FOOOF_processing': - namespace = cls(idata,fs[ichannel],[0.5,32], imeta['file'], idx, ichannel, self.args.trace) - elif cls.__name__ == 'YASA_processing': - namespace = cls(dataset,channels) - else: - namespace = cls(idata,fs[ichannel],self.args.trace) - - # Get the method name and return results from the method - method_call = getattr(namespace,method_name) - results = method_call(**method_args) - result_a = results[0] - result_b = results[1] - - # If the user wants to trace some values (see the results as they are processed), they can return result_c - if len(results) == 3: - + # Create namespaces for each class. Then choose which style of initilization is used by logic gate. + if cls.__name__ == 'FOOOF_processing': + # DEPRECIATED FORMAT! Will not work. Should mirror channel_wise metrics going forward. + namespace = cls(idata,fs,[0.5,32], imeta['file'], idx, ichannel, self.args.trace) + elif cls.__name__ == 'channel_wise_metrics': + namespace = cls(dataset,fs,imeta['file'], idx, channels, self.args.trace) + elif cls.__name__ == 'YASA_processing': + namespace = cls(imeta['unmontaged_data'],fs,channels,self.args.trace) + else: + namespace = cls(dataset,fs,channels,self.args.trace) + + # Get the method name and return results from the method + method_call = getattr(namespace,method_name) + results = method_call(**method_args) + result_a = [iresult[0] for iresult in results] + result_b = results[0][1] + + # If the user wants to trace some values (see the results as they are processed), they can return result_c + if len(results[0]) == 3: + for ii,ichannel in enumerate(channels): # Get the lower level column labels - cols = results[2][0] - vals = results[2][1:] + cols = results[ii][2][0] + vals = results[ii][2][1:] # Make the dictionary to nest into metadata inner_dict = dict(zip(cols,vals)) @@ -479,37 +865,35 @@ def __init__(self): # Add the trace to the metadata metadata_handler.add_metadata(self,idx,method_name,tracemeta) - # Check if we have a multivalue output - if type(result_a) == list: - metadata_handler.add_metadata(self,idx,method_name,result_a) - result_a = result_a[0] - - # Add the results to the output object - output.append(result_a) - - except Exception as e: + # Extend the output with results + output.extend(result_a) + except IndexError:#Exception as e: - # Add the ability to see the error if debugging - if self.args.debug and not self.args.silent: - print(f"Error {e} in step {istep} in {imeta['file']}.") + # Add the ability to see the error if debugging + if self.args.debug: + fname = os.path.split(sys.exc_info()[2].tb_frame.f_code.co_filename)[1] + error_type = sys.exc_info()[0] + line_number = sys.exc_info()[2].tb_lineno + print(f"Error type {error_type} in line {line_number} for {method_name}. 
Error message: {e}") + exit() - # We need a flexible solution to errors, so just populating a nan value - output.append(None) - try: - result_b = getattr(namespace,'optional_tag') - except: - result_b = "None" - - # Save the error for this step - if not error_flag and not self.args.debug: - error_dir = f"{self.args.outdir}errors/" - if not os.path.exists(error_dir): - os.system(f"mkdir -p {error_dir}") - - fp = open(f"{error_dir}{self.worker_number}_features.error","a") - fp.write(f"Step {istep:02}/{method_name}: Error {e}\n") - fp.close() - error_flag = True + # We need a flexible solution to errors, so just populating a nan value + output.extend([None for ii in range(len(channels))]) + try: + result_b = getattr(namespace,'optional_tag') + except: + result_b = "None" + + # Save the error for this step + if not error_flag and not self.args.debug: + error_dir = f"{self.args.outdir}errors/" + if not os.path.exists(error_dir): + os.system(f"mkdir -p {error_dir}") + + fp = open(f"{error_dir}{self.worker_number}_features.error","a") + fp.write(f"Step {istep:02}/{method_name}: Error {e}\n") + fp.close() + error_flag = True # Use metadata to allow proper feature grouping meta_arr = [imeta['file'].split('/')[-1],imeta['t_start'],imeta['t_end'],imeta['t_window'],method_name,result_b] @@ -545,4 +929,4 @@ def __init__(self): pass # The stagger condition seems to add duplicates. Need to fix eventually. - self.feature_df = self.feature_df.drop_duplicates(ignore_index=True) + self.feature_df = self.feature_df.drop_duplicates(ignore_index=True) \ No newline at end of file diff --git a/scripts/codehub/components/metadata/public/metadata_handler.py b/scripts/codehub/components/metadata/public/metadata_handler.py index e30c1943..ce4a4f04 100644 --- a/scripts/codehub/components/metadata/public/metadata_handler.py +++ b/scripts/codehub/components/metadata/public/metadata_handler.py @@ -23,13 +23,12 @@ def highlevel_info(self): self.metadata[self.file_cntr] = {} # Other high level dataset info - self.metadata[self.file_cntr]['file'] = self.infile - self.metadata[self.file_cntr]['t_start'] = self.t_start - self.metadata[self.file_cntr]['t_end'] = self.t_end - self.metadata[self.file_cntr]['history'] = self.args + self.metadata[self.file_cntr]['file'] = str(self.infile) + self.metadata[self.file_cntr]['t_start'] = self.t_start.astype(np.float16) + self.metadata[self.file_cntr]['t_end'] = self.t_end.astype(np.float16) def set_ref_window(self): - self.metadata[self.file_cntr]['t_window'] = self.t_window + self.metadata[self.file_cntr]['t_window'] = self.t_window.astype(np.float16) def set_channels(self,inputs): @@ -45,7 +44,7 @@ def set_sampling_frequency(self,inputs): def set_target_file(self,inputs): - self.metadata[self.file_cntr]['target_file'] = inputs + self.metadata[self.file_cntr]['target_file'] = str(inputs) def add_metadata(self,file_cntr,key,values): """ @@ -56,4 +55,10 @@ def add_metadata(self,file_cntr,key,values): if type(values) == dict and type(self.metadata[file_cntr][key]) == dict: olddict = self.metadata[file_cntr][key] values = {**olddict,**values} - self.metadata[file_cntr][key] = values \ No newline at end of file + self.metadata[file_cntr][key] = values + + def drop_metadata(self,key): + + for icntr in self.metadata.keys(): + if key in self.metadata[icntr]: + del self.metadata[icntr][key] \ No newline at end of file diff --git a/scripts/codehub/components/posthoc/public/marsh_filter.py b/scripts/codehub/components/posthoc/public/marsh_filter.py new file mode 100644 index 00000000..18672a97 
--- /dev/null +++ b/scripts/codehub/components/posthoc/public/marsh_filter.py @@ -0,0 +1,160 @@ +import numpy as np +import pandas as PD +from tqdm import tqdm +import multiprocessing +from sys import argv,exit + +class marsh_rejection: + """ + Applies a marsh rejection mask to a dataframe. + Looks for dt=-1 from the pipeline manager to reference against the full file. + """ + + def __init__(self,DF,channels,multithread,ncpu): + + # Save the input data to class instance + self.DF = DF + self.channels = channels + self.multithread = multithread + self.ncpu = ncpu + self.bar_frmt = '{l_bar}{bar}| {n_fmt}/{total_fmt}|' + + # Find the channel labels + self.ref_cols = np.setdiff1d(self.DF.columns, self.channels) + self.merge_labels = np.concatenate((['file', 'method', 'tag'],self.channels)) + + def workflow(self): + + # Make the keys to break up + marsh_lookup_dict = self.DF.groupby(['file']).indices + marsh_lookup_keys = list(marsh_lookup_dict.keys()) + + # Make the initial subset proposal size. If 0, just use single core + subset_size = len(marsh_lookup_keys) // self.ncpu + + if self.multithread and subset_size>0: + print("A") + + # get the subset list + list_subsets = [marsh_lookup_keys[i:i + subset_size] for i in range(0, subset_size*self.ncpu, subset_size)] + + # Handle leftovers + remainder = list_subsets[self.ncpu*subset_size:] + for idx,ival in enumerate(remainder): + list_subsets[idx] = np.concatenate((list_subsets[idx],np.array([ival]))) + + # Convert to indices + list_subsets_indices = [[] for idx in range(len(list_subsets))] + for idx,subset in enumerate(list_subsets): + for ifile in subset: + list_subsets_indices[idx].extend(marsh_lookup_dict[ifile]) + + # Create processes and start workers + processes = [] + manager = multiprocessing.Manager() + return_dict = manager.dict() + for worker_id, data_chunk in enumerate(list_subsets_indices): + process = multiprocessing.Process(target=self.calculate_marsh, args=(worker_id,data_chunk,return_dict)) + processes.append(process) + process.start() + + # Wait for all processes to complete + for process in processes: + process.join() + else: + self.multithread = False + marsh_keys = list(self.DF.index) + return_dict = self.calculate_marsh(0,marsh_keys,{}) + + # Reformat the output + self.DF = PD.concat(return_dict.values()).reset_index(drop=True) + + return self.DF + + def calculate_marsh(self,worker_num, DF_inds, return_dict): + + try: + # Get the data slice to work on + current_DF = self.DF.loc[DF_inds] + + # Make a dataslice just for rms and just for ll + DF_rms = current_DF.loc[current_DF.method=='rms'] + DF_ll = current_DF.loc[current_DF.method=='line_length'] + + # Convert the data types to numeric + for ichannel in self.channels: + DF_rms.loc[:,ichannel] = DF_rms[ichannel].astype('float32') + DF_ll.loc[:,ichannel] = DF_ll[ichannel].astype('float32') + + # Get the group level values + rms_obj = DF_rms.groupby(['file'])[self.channels] + ll_obj = DF_ll.groupby(['file'])[self.channels] + DF_rms_mean = rms_obj.mean() + DF_rms_stdev = rms_obj.std() + DF_ll_mean = ll_obj.mean() + DF_ll_stdev = ll_obj.std() + + # Make output lists + rms_output = [] + ll_output = [] + + # Apply the filter + DF_rms.set_index(['file'],inplace=True) + DF_ll.set_index(['file'],inplace=True) + DF_rms = DF_rms.sort_values(by=['t_start','t_end','t_window']) + DF_ll = DF_ll.sort_values(by=['t_start','t_end','t_window']) + + # Apply the filter for each group + for ifile in tqdm(DF_rms_mean.index, desc='Applying Marsh Filter', 
total=len(DF_rms_mean.index),bar_format=self.bar_frmt, position=worker_num, leave=False, dynamic_ncols=True): + + # Get the reference values + ref_rms_mean = DF_rms_mean.loc[ifile] + ref_rms_stdev = DF_rms_stdev.loc[ifile] + ref_ll_mean = DF_ll_mean.loc[ifile] + ref_ll_stdev = DF_ll_stdev.loc[ifile] + + # Get the rms mask + DF_rms_slice = DF_rms.loc[[ifile]] + channel_rms_marsh = DF_rms_slice[self.channels]/(ref_rms_mean+2*ref_rms_stdev).values + DF_rms_slice.loc[:,self.channels] = channel_rms_marsh[self.channels].values + DF_rms_slice.loc[:,['method']] = 'marsh_filter' + DF_rms_slice.loc[:,['tag']] = 'rms' + rms_output.append(DF_rms_slice) + + # Get the line length mask + DF_ll_slice = DF_ll.loc[[ifile]] + channel_ll_marsh = DF_ll_slice[self.channels]/(ref_ll_mean+2*ref_ll_stdev).values + DF_ll_slice.loc[:,self.channels] = channel_ll_marsh[self.channels].values + DF_ll_slice.loc[:,['method']] = 'marsh_filter' + DF_ll_slice.loc[:,['tag']] = 'line_length' + ll_output.append(DF_ll_slice) + + # make the output dataframes + DF_rms = PD.concat(rms_output) + DF_ll = PD.concat(ll_output) + + # Clean up the outputs + DF_rms['file'] = DF_rms.index + DF_ll['file'] = DF_ll.index + DF_rms = DF_rms.reset_index(drop=True) + DF_ll = DF_ll.reset_index(drop=True) + + # Append the results to input + current_DF = PD.concat((current_DF,DF_rms)).reset_index(drop=True) + current_DF = PD.concat((current_DF,DF_ll)).reset_index(drop=True) + + # Save the results to the output object + return_dict[worker_num] = current_DF + + if not self.multithread: + return return_dict + except Exception as e: + + print(DF_rms[self.channels].dtypes) + + import os,sys + fname = os.path.split(sys.exc_info()[2].tb_frame.f_code.co_filename)[1] + error_type = sys.exc_info()[0] + line_number = sys.exc_info()[2].tb_lineno + print(f"Error {error_type} in line {line_number}.") + exit() diff --git a/scripts/codehub/components/posthoc/public/misc/create_yasa_lookup.py b/scripts/codehub/components/posthoc/public/misc/create_yasa_lookup.py new file mode 100644 index 00000000..62eed8a7 --- /dev/null +++ b/scripts/codehub/components/posthoc/public/misc/create_yasa_lookup.py @@ -0,0 +1,61 @@ +import numpy as np +import pandas as PD +from sys import argv + +if __name__ == '__main__': + + # Read in the dataframe + raw_DF = PD.read_pickle(argv[1]) + + # Only allow clips that had five minutes of data + inds = (raw_DF['t_end'].values-raw_DF['t_start'].values)>=300 + raw_DF = raw_DF.iloc[inds] + + # Make the lookup column list and get channel names + lookup_cols = ['file', 't_start', 't_end', 't_window', 'method', 'tag'] + channels = np.setdiff1d(raw_DF.columns,lookup_cols) + + # get the column headers for predictions + cols = np.array(raw_DF.tag.values[0].split(',')) + + # Get the YASA prediction. 
Which should be the same for all channels as we use a consensus across channels + predictions = [] + for idx,ival in enumerate(raw_DF[channels[0]].values): + try: + formatted_pred = ival.replace('|',',') + formatted_pred = np.array(formatted_pred.split(',')).reshape((-1,cols.size)) + predictions.append(formatted_pred) + except: + predictions.append(np.nan*np.ones((10,3))) + + # Get the start time and filename for each row + files = raw_DF.file.values + t_start = raw_DF.t_start.values + + # Make the final lookup tables + outfile = [] + outstart = [] + outend = [] + outstage = [] + for idx,ifile in enumerate(files): + istart = t_start[idx] + ipred = predictions[idx] + for jdx,sleep_stage in enumerate(ipred): + outfile.append(ifile) + outstart.append(istart+(jdx*30)) + outend.append(istart+((jdx+1)*30)) + outstage.append(sleep_stage) + outstage = np.array(outstage) + + # Make the lookup dataframe + outDF = PD.DataFrame(outfile,columns=['file']) + outDF['t_start'] = outstart + outDF['t_end'] = outend + for idx,icol in enumerate(cols): + outDF[f"yasa_{icol}"] = outstage[:,idx] + + # Sort the results + outDF = outDF.sort_values(by=['file','t_start']) + + # Save the results + outDF.to_csv(argv[2],index=False) diff --git a/scripts/codehub/components/posthoc/public/marsh_rejection.py b/scripts/codehub/components/posthoc/public/misc/marsh_rejection.py similarity index 100% rename from scripts/codehub/components/posthoc/public/marsh_rejection.py rename to scripts/codehub/components/posthoc/public/misc/marsh_rejection.py diff --git a/scripts/codehub/components/posthoc/public/misc/merge_yasa_lookup.py b/scripts/codehub/components/posthoc/public/misc/merge_yasa_lookup.py new file mode 100644 index 00000000..543b5da0 --- /dev/null +++ b/scripts/codehub/components/posthoc/public/misc/merge_yasa_lookup.py @@ -0,0 +1,72 @@ +import numpy as np +import pandas as PD +from sys import argv +from tqdm import tqdm + +if __name__ == '__main__': + + # Read in the raw data + YASA_DF = PD.read_csv(argv[1]) + FEATURE_DF = PD.read_pickle(argv[2]) + + # Clean up the labels to just be sleep or wake + new_map = {'N1':'S','N2':'S','N3':'S','R':'S','W':'W'} + consensus_cols = [icol for icol in YASA_DF if 'yasa' in icol] + for icol in consensus_cols: + YASA_DF[icol] = YASA_DF[icol].apply(lambda x: new_map[x] if x in new_map.keys() else 'U') + + # Get the consensus prediction + preds = YASA_DF[consensus_cols].mode(axis=1).values + YASA_DF['yasa_consensus'] = preds.flatten() + + # Drop the original columns + YASA_DF = YASA_DF.drop(consensus_cols,axis=1) + + # Create the yasa lookup arrays + yasa_files = YASA_DF.file.values + yasa_tstart = YASA_DF.t_start.values + yasa_tend = YASA_DF.t_end.values + unique_files = np.unique(yasa_files) + + # Create the feature dataframe lookup arrays + feature_files = FEATURE_DF.file.values + feature_tstart = FEATURE_DF.t_start.values + + # Populate the YASA feature column with unknowns that we can replace by index with the correct value + YASA_FEATURE = np.array(['U' for ii in range(FEATURE_DF.shape[0])]) + YASA_LOOKUP = YASA_DF['yasa_consensus'].values + + # Step through the unique files + for ifile in tqdm(unique_files,total=unique_files.size): + + # Get the file indices + yasa_file_inds = (yasa_files==ifile) + feature_file_inds = (feature_files==ifile) + + # The yasa lookup was made for more than just the PNES project. 
So we can cull for files in the feature df + if feature_file_inds.sum() > 0: + + # Step through the time values + unique_tstart = np.unique(yasa_tstart[yasa_file_inds]) + for istart in unique_tstart: + + # Get the time indices + yasa_time_inds = (yasa_tstart==istart) + feature_time_inds = (feature_tstart>=istart)&(feature_tstart<(istart+30)) + + # Get the current prediction, if available + YASA_slice = YASA_LOOKUP[yasa_file_inds&yasa_time_inds] + + # Step through the possible outcomes for the yasa slice size + combined_inds = feature_file_inds&feature_time_inds + if combined_inds.sum() > 0: + if YASA_slice.size == 1: + YASA_FEATURE[combined_inds] = YASA_slice[0] + elif YASA_slice.size > 1: + raise Exception("Too many YASA values map to this feature. Check YASA generation.") + else: + pass + + # Add the prediction to the feature dataframe and save + FEATURE_DF['yasa_prediction'] = YASA_FEATURE + FEATURE_DF.to_csv(argv[3],index=False) diff --git a/scripts/codehub/components/posthoc/public/misc/yasa_join.py b/scripts/codehub/components/posthoc/public/misc/yasa_join.py new file mode 100644 index 00000000..6594cccc --- /dev/null +++ b/scripts/codehub/components/posthoc/public/misc/yasa_join.py @@ -0,0 +1,197 @@ +import argparse +import numpy as np +import pandas as PD +from sys import argv +from tqdm import tqdm + +class clean_yasa: + """ + Clean up the yasa feature output to display results on thirty second windows. + """ + + def __init__(self,yasa_path,yasa_window_size): + + self.raw_yasa = PD.read_pickle(yasa_path) + self.yasa_window_size = yasa_window_size + + def pipeline(self): + + self.data_prep() + self.format_predictions() + self.make_dataframe() + return self.outDF + + def data_prep(self): + """ + Only grab segments with the minimum time requirement and get the metadata out. + """ + + # Only allow clips that had five minutes of data + inds = (self.raw_yasa['t_end'].values-self.raw_yasa['t_start'].values)>=300 + self.raw_yasa = self.raw_yasa.iloc[inds] + + # Make the lookup column list and get channel names + lookup_cols = ['file', 't_start', 't_end', 't_window', 'method', 'tag'] + self.channels = np.setdiff1d(self.raw_yasa.columns,lookup_cols) + + # get the column headers for predictions + self.yasa_cols = np.array(self.raw_yasa.tag.values[0].split(',')) + + def format_predictions(self): + """ + Format the predictions to be a time by yasa channel array + """ + + # Get the YASA prediction. Which should be the same for all channels as we use a consensus across channels + self.predictions = [] + for idx,ival in enumerate(self.raw_yasa[self.channels[0]].values): + + # Get the expected prediction shape + try: + formatted_pred = ival.replace('|',',') + formatted_pred = np.array(formatted_pred.split(',')).reshape((-1,self.yasa_cols.size)) + self.predictions.append(formatted_pred) + except: + nrow = np.floor(self.yasa_window_size/30).astype('int') + self.predictions.append(np.nan*np.ones((nrow,self.yasa_cols.size))) + + def make_dataframe(self): + """ + Make the predictions into a similar format as the feature dataframe for easier merging. 
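+        The result has one row per 30 second epoch, with 'file', 't_start', and
+        't_end' columns plus a 'yasa_<name>' column for each prediction column
+        parsed from the original tag string.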
+ """ + + # Get the start time and filename for each row + files = self.raw_yasa.file.values + t_start = self.raw_yasa.t_start.values + + # Make the final lookup tables + outfile = [] + outstart = [] + outend = [] + outstage = [] + for idx,ifile in tqdm(enumerate(files),total=len(files)): + istart = t_start[idx] + ipred = self.predictions[idx] + for jdx,sleep_stage in enumerate(ipred): + outfile.append(ifile) + outstart.append(istart+(jdx*30)) + outend.append(istart+((jdx+1)*30)) + outstage.append(sleep_stage) + outstage = np.array(outstage) + + # Make the lookup dataframe + outDF = PD.DataFrame(outfile,columns=['file']) + outDF['t_start'] = outstart + outDF['t_end'] = outend + for idx,icol in enumerate(self.yasa_cols): + outDF[f"yasa_{icol}"] = outstage[:,idx] + + # Sort the results + self.outDF = outDF.sort_values(by=['file','t_start']) + +class merge_yasa: + + def __init__(self,feature_path,yasa_df,outpath): + + self.yasa = yasa_df + self.features = PD.read_pickle(feature_path) + self.outpath = outpath + + def pipeline(self): + + self.yasa_mapping() + self.joint_prediction() + self.merge_results() + self.save_results() + + def yasa_mapping(self): + """ + Apply project specific mapping to the yasa labels. + """ + + # Clean up the labels to just be sleep or wake + new_map = {'N1':'S','N2':'S','N3':'S','R':'S','W':'W'} + self.consensus_cols = [icol for icol in self.yasa if 'yasa' in icol] + for icol in self.consensus_cols: + self.yasa[icol] = self.yasa[icol].apply(lambda x: new_map[x] if x in new_map.keys() else 'U') + + def joint_prediction(self): + """ + Get the joint prediction for yasa + """ + + # Get the consensus prediction + preds = self.yasa[self.consensus_cols].mode(axis=1).values + self.yasa['yasa_consensus'] = preds.flatten() + self.yasa = self.yasa.drop(self.consensus_cols,axis=1) + + def merge_results(self): + + # Create the yasa lookup arrays + yasa_files = self.yasa.file.values + yasa_tstart = self.yasa.t_start.values + yasa_tend = self.yasa.t_end.values + unique_files = np.unique(yasa_files) + + # Create the feature dataframe lookup arrays + feature_files = self.features.file.values + feature_tstart = self.features.t_start.values + + # Populate the YASA feature column with unknowns that we can replace by index with the correct value + YASA_FEATURE = np.array(['U' for ii in range(self.features.shape[0])]) + YASA_LOOKUP = self.yasa['yasa_consensus'].values + + # Step through the unique files + for ifile in tqdm(unique_files,total=unique_files.size): + + # Get the file indices + yasa_file_inds = (yasa_files==ifile) + feature_file_inds = (feature_files==ifile) + + # The yasa lookup was made for more than just the PNES project. So we can cull for files in the feature df + if feature_file_inds.sum() > 0: + + # Step through the time values + unique_tstart = np.unique(yasa_tstart[yasa_file_inds]) + for istart in unique_tstart: + + # Get the time indices + yasa_time_inds = (yasa_tstart==istart) + feature_time_inds = (feature_tstart>=istart)&(feature_tstart<(istart+30)) + + # Get the current prediction, if available + YASA_slice = YASA_LOOKUP[yasa_file_inds&yasa_time_inds] + + # Step through the possible outcomes for the yasa slice size + combined_inds = feature_file_inds&feature_time_inds + if combined_inds.sum() > 0: + if YASA_slice.size == 1: + YASA_FEATURE[combined_inds] = YASA_slice[0] + elif YASA_slice.size > 1: + raise Exception("Too many YASA values map to this feature. 
Check YASA generation.") + else: + pass + self.YASA_FEATURE = YASA_FEATURE + + def save_results(self): + + self.features['yasa_prediction'] = self.YASA_FEATURE + self.features.to_pickle(args.outpath) + +if __name__ == '__main__': + + # Command line options needed to obtain data. + parser = argparse.ArgumentParser() + parser.add_argument('--feature_path', type=str, help='Input path to the feature dataframe.') + parser.add_argument('--yasa_path', type=str, help='Input path to the yasa dataframe.') + parser.add_argument('--outpath', type=str, help='Output path for the feature dataframe with yasa added.') + parser.add_argument('--yasa_window_size', type=int, default=300, help='Input path to the yasa dataframe.') + args = parser.parse_args() + + # Clean up the yasa data + CLN = clean_yasa(args.yasa_path,args.yasa_window_size) + yasaDF = CLN.pipeline() + + # Merge the results + MRG = merge_yasa(args.feature_path,yasaDF,args.outpath) + MRG.pipeline() \ No newline at end of file diff --git a/scripts/codehub/components/posthoc/public/yasa_reformat.py b/scripts/codehub/components/posthoc/public/yasa_reformat.py new file mode 100644 index 00000000..bb7c9d71 --- /dev/null +++ b/scripts/codehub/components/posthoc/public/yasa_reformat.py @@ -0,0 +1,154 @@ +import numpy as np +import pandas as PD +from tqdm import tqdm +import multiprocessing +from sys import argv,exit +from collections import Counter + +class yasa_reformat: + + def __init__(self,DF,channels,multithread,ncpu): + + # Save input data to instance + self.channels = channels + self.multithread = multithread + self.ncpu = ncpu + self.bar_frmt = '{l_bar}{bar}| {n_fmt}/{total_fmt}|' + + # Make a yasa lookup df slice + self.YASA_DF = DF.loc[(DF.method == 'yasa_sleep_stage')&(DF.t_window==300)] + + # Save the data slice to update + self.DF = DF.drop(self.YASA_DF.index) + + def workflow(self): + + # Clean the lookup data + self.cleanup() + + # Get the indices for the lookup groups + self.YASA_lookup_dict = self.YASA_DF.groupby(['file','t_start','t_end']).indices + YASA_keys = list(self.YASA_lookup_dict.keys()) + + if self.multithread: + + # Make the initial subset proposal + subset_size = len(YASA_keys) // self.ncpu + list_subsets = [YASA_keys[i:i + subset_size] for i in range(0, subset_size*self.ncpu, subset_size)] + + # Handle leftovers + remainder = list_subsets[self.ncpu*subset_size:] + for idx,ival in enumerate(remainder): + list_subsets[idx] = np.concatenate((list_subsets[idx],np.array([ival]))) + + # Create processes and start workers + processes = [] + manager = multiprocessing.Manager() + return_dict = manager.dict() + for worker_id, data_chunk in enumerate(list_subsets): + process = multiprocessing.Process(target=self.reformat, args=(worker_id,data_chunk,return_dict)) + processes.append(process) + process.start() + + # Wait for all processes to complete + for process in processes: + process.join() + else: + return_dict = self.reformat(0,YASA_keys,{}) + + # Reformat the output + self.DF = PD.concat(return_dict.values()).reset_index(drop=True) + + return self.DF + + def cleanup(self): + + # Create the mapping for cleanup + new_map = {'N1':'S','N2':'S','N3':'S','R':'S','W':'W'} + + # Mapping function + def replace_stages(x): + for ikey in new_map.keys(): + x=x.replace(ikey,new_map[ikey]) + return x + + def consensus_stage(x): + time_list = x.split('|') + for idx,itime in enumerate(time_list): + vals = itime.split(',') + count = Counter(vals) + time_list[idx] = max(count, key=count.get) + return time_list + + # Loop over the lookup table for 
cleanup + cleaned_output = [] + for i_index in self.YASA_DF.index: + + # Get the full row slice so we can build the new output with modified times + original_slice = self.YASA_DF.loc[i_index] + + for ichannel in self.channels[:1]: + + # grab the entry to modify + input_yasa = original_slice[ichannel] + + # Update the mapping labels in the lookup table + mapped_yasa = replace_stages(input_yasa) + + # get the consensus + output_yasa = consensus_stage(mapped_yasa) + + # Get the time offsets for the new rows + dt = 30*np.arange(len(output_yasa)) + + # Make the new outputs + for idx,time_offset in enumerate(dt): + + # Update the entries with the new timing and sleep stage + new_slice = original_slice.copy() + new_slice['t_start'] += time_offset + new_slice['t_end'] = new_slice['t_start']+30 + new_slice[self.channels] = output_yasa[idx] + + # Store the results to the output array + cleaned_output.append(new_slice.values) + + # YASA dataframe creation and cleanup + self.YASA_DF = PD.DataFrame(cleaned_output,columns=self.YASA_DF.columns) + self.YASA_DF['t_start'] = self.YASA_DF['t_start'].astype('int16') + self.YASA_DF['t_end'] = self.YASA_DF['t_end'].astype('int16') + self.YASA_DF = self.YASA_DF.drop_duplicates(subset=['file','t_start'],keep='first').reset_index(drop=True) + + def reformat(self,worker_num,inkeys,return_dict): + + # Loop over the keys + output = [] + for ikey in tqdm(inkeys, desc='Applying YASA Restructure', total=len(inkeys),bar_format=self.bar_frmt, position=worker_num, leave=False, dynamic_ncols=True): + + # Get the value to propagate + newval = self.YASA_DF.loc[self.YASA_lookup_dict[ikey]][self.channels[0]].values[0] + + # Get the indices to update + base_slice = (self.DF.file==ikey[0])&(self.DF.t_start>=ikey[1])&(self.DF.t_start 0: + # Remove unwanted metadata + for dropkey in self.args.dropkeys: + metadata_handler.drop_metadata(self,dropkey) + # Save the results output_manager.save_features(self) if self.args.clean_save: output_manager.save_output_list(self) - + def feature_manager(self): """ Kick off function for feature extraction if requested by the user. @@ -111,17 +120,6 @@ def feature_manager(self): """ if not self.args.no_feature_flag: - if self.args.multithread: - self.barrier.wait() - - # Add a wait for proper progress bars - time.sleep(self.worker_number) - - # Clean up the screen - if self.worker_number == 0: - sys.stdout.write("\033[H") - sys.stdout.flush() - # In the case that all of the data is removed, skip the feature step if len(self.metadata.keys()) > 0: features.__init__(self) @@ -179,12 +177,12 @@ def parse_list(input_str): values = input_str.replace(',', ' ').split() return [float(value) for value in values] -def start_analysis(data_chunk,args,timestamp,worker_id,barrier): +def start_analysis(data_chunk,args,timestamp,worker_id,barrier,active_workers): """ Helper function to allow for easy multiprocessing initialization. 
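    The active_workers argument is a shared multiprocessing value, initialised to
    the number of requested workers, that is passed through to data_manager.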
""" - DM = data_manager(data_chunk,args,timestamp,worker_id,barrier) + DM = data_manager(data_chunk,args,timestamp,worker_id,barrier,active_workers) def merge_outputs(args,timestamp): """ @@ -200,28 +198,31 @@ def merge_outputs(args,timestamp): # Make a merged downcasted feature file if len(feature_files) > 0: for idx,ifile in enumerate(feature_files): - - # Read in the dataframe - iDF = PD.read_pickle(ifile) - - # Attempt downcasting as much as possible - for icol in iDF.columns: - itype = iDF[icol].dtype - try: - iDF[icol] = PD.to_numeric(iDF[icol],downcast='integer') - if iDF[icol].dtype == itype: - iDF[icol] = PD.to_numeric(iDF[icol],downcast='float') - except ValueError: - pass - - # Merge the outputs to one final file - if idx == 0: - output_DF = iDF.copy() - else: - output_DF = PD.concat((output_DF,iDF)) + + # In case the feature file was empty due to malformed data inputs. + if os.path.getsize(ifile) > 0: + # Read in the dataframe + iDF = PD.read_pickle(ifile) + + # Attempt downcasting as much as possible + for icol in iDF.columns: + itype = iDF[icol].dtype + try: + iDF[icol] = PD.to_numeric(iDF[icol],downcast='integer') + if iDF[icol].dtype == itype: + iDF[icol] = PD.to_numeric(iDF[icol],downcast='float') + except ValueError: + pass + + # Merge the outputs to one final file + if idx == 0: + output_DF = iDF.copy() + else: + output_DF = PD.concat((output_DF,iDF)) # Make the new output and only remove files once things were confirmed to work - output_DF.to_pickle(f"{args.outdir}/{timestamp}_features.pickle") + base_path = f"{args.outdir}/{timestamp}_features_" + output_DF.to_csv(f"{args.outdir}/{timestamp}_features.csv", index=False) for ifile in feature_files:os.remove(ifile) # Clean up the feature config files (if present) @@ -232,7 +233,9 @@ def merge_outputs(args,timestamp): # Clean up the meta files as needed if len(metadata_files) > 0: for idx,ifile in enumerate(metadata_files): - imeta = pickle.load(open(ifile,"rb")) + fp = open(ifile,"rb") + imeta = pickle.load(fp) + fp.close() if idx == 0: metadata = imeta.copy() else: @@ -242,7 +245,9 @@ def merge_outputs(args,timestamp): for ikey in newkeys: imeta[ikey+offset] = imeta.pop(ikey) metadata = {**metadata,**imeta} - pickle.dump(metadata,open(f"{args.outdir}/{timestamp}_meta.pickle","wb")) + fp = open(f"{args.outdir}/{timestamp}_meta.pickle","wb") + pickle.dump(metadata,fp) + fp.close() for ifile in metadata_files:os.remove(ifile) # Clean up the raw data files as needed @@ -253,16 +258,28 @@ def merge_outputs(args,timestamp): data = idata.copy() else: data.extend(idata) - pickle.dump(data,open(f"{args.outdir}/{timestamp}_data.pickle","wb")) + fp = open(f"{args.outdir}/{timestamp}_data.pickle","wb") + pickle.dump(data,fp) + fp.close() for ifile in data_list:os.remove(ifile) + # Find the first key with montage channels + ### SHOULD BE DEPRECIATED ONCE KEY DELETION IS QAed. 
+ for ikey in metadata.keys(): + if 'montage_channels' in metadata[ikey].keys(): + mkey=ikey + break + + return output_DF,metadata[mkey]['montage_channels'],base_path def argument_handler(argument_dir='./',require_flag=True): # Read in the allowed arguments - raw_args = yaml.safe_load(open(f"{argument_dir}allowed_arguments.yaml","r")) + fp = open(f"{argument_dir}allowed_arguments.yaml","r") + raw_args = yaml.safe_load(fp) for key, inner_dict in raw_args.items(): globals()[key] = inner_dict + fp.close() # Make a useful help string for each keyword allowed_project_help = make_help_str(allowed_project_args) @@ -332,12 +349,21 @@ def argument_handler(argument_dir='./',require_flag=True): Also allows for skipping on subsequent loads. Default=outdir+excluded.txt (In Dev. Just gets initial load fails.)") output_group.add_argument("--nomerge", action='store_true', default=False, help="Do not merge the outputs from multiprocessing into one final set of files.") output_group.add_argument("--clean_save", action='store_true', default=False, help="Save cleaned up raw data. Mostly useful if you need time series and not just features.") + output_group.add_argument("--dropkeys", type=parse_list, default=['unmontaged_data'], help="Drop these keys from the output metadata. Useful if your workflow has some unique entries for processing") + + posthoc_group = parser.add_argument_group('Posthoc analysis Options') + posthoc_group.add_argument("--postprocess_only", action='store_true', default=False, help="Perform post processing only.") + posthoc_group.add_argument("--postprocess_feature_file", type=str, help="Path to feature dataframe to postprocess. Only for --postprocess_only use.") + posthoc_group.add_argument("--postprocess_meta_file", type=str, help="Path to feature metadata to postprocess. Only for --postprocess_only use.") + posthoc_group.add_argument("--yasa_cleanup", action='store_true', default=False, help="Don't restructure YASA data to have staging at other time part levels. Requires t=300 run.") + posthoc_group.add_argument("--nomarsh", action='store_false', default=True, help="Do not run posthoc analysis.") + misc_group = parser.add_argument_group('Misc Options') - misc_group.add_argument("--nfreq_window", type=int, default=8, help="Optional. Minimum number of samples required to send to preprocessing and feature extraction.") + misc_group.add_argument("--nfreq_window", type=int, default=2, help="Optional. Minimum number of samples required to send to preprocessing and feature extraction.") misc_group.add_argument("--input_str", type=str, help="Optional. If glob input, wildcard path. If csv/manual, filepath to input csv/raw data.") misc_group.add_argument("--silent", action='store_true', default=False, help="Silent mode.") - misc_group.add_argument("--debug", action='store_true', default=False, help="Debug mode. If set, does not save results. Useful for testing code.") + misc_group.add_argument("--debug", action='store_true', default=False, help="Debug mode. If set, does not save results. Useful for testing code. Cannot perform posthoc with this option on.") misc_group.add_argument("--trace", action='store_true', default=False, help="Trace data through the code. If selected, any user function that looks for trace can return extra information (i.e. 
intermediate calculations) to the metadata object.") args = parser.parse_args() @@ -375,11 +401,56 @@ def argument_handler(argument_dir='./',require_flag=True): type_info[action.dest] = str return args,(help_info,type_info,default_info,raw_args) +def postprocessing(args,feature_df,channels,base_path): + + # Perform post-hoc yasa cleanup + if args.yasa_cleanup: + YR = yasa_reformat(feature_df,channels,args.multithread,args.ncpu) + feature_df = YR.workflow() + newsaveflag = True + + # Perform post-hoc marsh analysis + if args.nomarsh: + MR = marsh_rejection(feature_df,channels,args.multithread,args.ncpu) + feature_df = MR.workflow() + newsaveflag = True + + outpath = base_path+'clean.csv' + feature_df.to_csv(outpath,index=False) + if __name__ == "__main__": # Get the argument handler args,_ = argument_handler() + # If performing post processing only, do so here and exit + if args.postprocess_only: + + # Get the features out + print("Reading in features...") + feature_df = PD.read_csv(args.postprocess_feature_file) + + # Get the channel info out + print("Reading in metadata....") + fp = open(args.postprocess_meta_file,'rb') + metadata = pickle.load(fp) + + # Find the first key with montage channels + ### SHOULD BE DEPRECIATED ONCE KEY DELETION IS QAed. + for ikey in metadata.keys(): + if 'montage_channels' in metadata[ikey].keys(): + mkey=ikey + break + channels = metadata[mkey]['montage_channels'] + fp.close() + + # Get the basepath + base_path = args.postprocess_feature_file.strip('.csv')+'_' + + # Run post processing + postprocessing(args,feature_df,channels,base_path) + exit() + # Make the output directory as needed if not os.path.exists(args.outdir) and not args.debug: print("Output directory does not exist. Make directory at %s (Y/y)?" %(args.outdir)) @@ -494,12 +565,14 @@ def argument_handler(argument_dir='./',require_flag=True): list_subsets[idx] = np.concatenate((list_subsets[idx],np.array([ival]))) # Create a barrier for synchronization - barrier = multiprocessing.Barrier(args.ncpu) + manager = multiprocessing.Manager() + active_workers = manager.Value("i", args.ncpu) + barrier = multiprocessing.Barrier(active_workers) # Create processes and start workers processes = [] for worker_id, data_chunk in enumerate(list_subsets): - process = multiprocessing.Process(target=start_analysis, args=(data_chunk,args,timestamp,worker_id,barrier)) + process = multiprocessing.Process(target=start_analysis, args=(data_chunk,args,timestamp,worker_id,barrier,active_workers)) processes.append(process) process.start() @@ -508,8 +581,12 @@ def argument_handler(argument_dir='./',require_flag=True): process.join() else: # Run a non parallel version. 
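        # A single worker (id 0) handles everything here, so no barrier or shared
        # active_workers counter needs to be passed in.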
- start_analysis(input_parameters, args, timestamp, 0, None) + start_analysis(input_parameters, args, timestamp, 0, None, None) # Perform merge if requested - if not args.nomerge: - merge_outputs(args,timestamp) \ No newline at end of file + if not args.nomerge and not args.debug: + newsaveflag = False + feature_df,channels,base_path = merge_outputs(args,timestamp) + + postprocessing(args,feature_df,channels,base_path) + \ No newline at end of file diff --git a/scripts/codehub/utils/acquisition/BIDS b/scripts/codehub/utils/acquisition/BIDS new file mode 160000 index 00000000..fb6a0500 --- /dev/null +++ b/scripts/codehub/utils/acquisition/BIDS @@ -0,0 +1 @@ +Subproject commit fb6a0500c627836862edbd47bca73c4828c47e69 diff --git a/scripts/codehub/utils/acquisition/BIDS/EEG_BIDS.py b/scripts/codehub/utils/acquisition/BIDS/EEG_BIDS.py deleted file mode 100644 index 7d909502..00000000 --- a/scripts/codehub/utils/acquisition/BIDS/EEG_BIDS.py +++ /dev/null @@ -1,124 +0,0 @@ -import os -import argparse -import pandas as PD -from sys import exit -from prettytable import PrettyTable,ALL - -# Locale import -from components.internal.BIDS_handler import * -from components.public.edf_handler import edf_handler -from components.public.iEEG_handler import ieeg_handler -from components.public.jar_handler import jar_handler - -# MNE is very chatty. Turn off some warnings. -import warnings -warnings.simplefilter(action='ignore', category=FutureWarning) -warnings.simplefilter(action='ignore', category=RuntimeWarning) - -def print_examples(): - - # Read in the sample time csv - script_dir = '/'.join(os.path.abspath(__file__).split('/')[:-1]) - example_csv = PD.read_csv(f"{script_dir}/samples/sample_times.csv") - - # Initialize a pretty table for easy reading - table = PrettyTable(hrules=ALL) - table.field_names = example_csv.columns - for irow in example_csv.index: - iDF = example_csv.loc[irow] - formatted_row = [iDF[icol] for icol in example_csv.columns] - table.add_row(formatted_row) - table.align['path'] = 'l' - print("Sample inputs that explicitly set the download times.") - print(table) - - # Read in the sample annotation csv - script_dir = '/'.join(os.path.abspath(__file__).split('/')[:-1]) - example_csv = PD.read_csv(f"{script_dir}/samples/sample_annot.csv") - - # Initialize a pretty table for easy reading - table = PrettyTable(hrules=ALL) - table.field_names = example_csv.columns - for irow in example_csv.index: - iDF = example_csv.loc[irow] - formatted_row = [iDF[icol] for icol in example_csv.columns] - table.add_row(formatted_row) - table.align['path'] = 'l' - print("Sample inputs that use annotations.") - print(table) - -def ieeg(args): - IH = ieeg_handler(args) - IH.workflow() - -def raw_edf(args): - EH = edf_handler(args) - EH.workflow() - -def read_jar(args): - JH = jar_handler(args) - JH.workflow() - -if __name__ == '__main__': - - # Command line options needed to obtain data. - parser = argparse.ArgumentParser(description="Make an EEG BIDS dataset from various sources. 
Also manages helper scripts for the CNT.") - - source_group = parser.add_mutually_exclusive_group() - source_group.add_argument("--ieeg", action='store_true', default=False, help="iEEG data pull.") - source_group.add_argument("--edf", action='store_true', default=False, help="Raw edf data pull.") - source_group.add_argument("--jar", action='store_true', default=False, help="Convert jar file to EDF Bids.") - - data_group = parser.add_argument_group('Data configuration options') - data_group.add_argument("--bids_root", type=str, required=True, default=None, help="Output directory to store BIDS data.") - data_group.add_argument("--data_record", type=str, default='subject_map.csv', help="Filename for data record. Outputs to bids_root.") - - ieeg_group = parser.add_argument_group('iEEG connection options') - ieeg_group.add_argument("--username", type=str, help="Username for iEEG.org.") - ieeg_group.add_argument("--input_csv", type=str, help="CSV file with the relevant filenames, start times, durations, and keywords. For an example, use the --example_input flag.") - ieeg_group.add_argument("--dataset", type=str, help="iEEG.org Dataset name. Useful if downloading just one dataset,") - ieeg_group.add_argument("--start", type=float, help="Start time of clip in usec. Useful if downloading just one dataset,") - ieeg_group.add_argument("--duration", type=float, help="Duration of clip in usec. Useful if downloading just one dataset,") - ieeg_group.add_argument("--failure_file", default='./failed_ieeg_calls.csv', type=str, help="CSV containing failed iEEG calls.") - ieeg_group.add_argument("--annotations", action='store_true', default=False, help="Download by annotation layers. Defaults to scalp layer names.") - ieeg_group.add_argument("--time_layer", type=str, default='EEG clip times', help="Annotation layer name for clip times.") - ieeg_group.add_argument("--annot_layer", type=str, default='Imported Natus ENT annotations', help="Annotation layer name for annotation strings.") - ieeg_group.add_argument("--timeout", type=int, default=60, help="Timeout interval for ieeg.org calls") - - bids_group = parser.add_argument_group('BIDS keyword options') - bids_group.add_argument("--uid_number", type=str, help="Unique identifier string to use when not referencing a input_csv file. Only used for single data pulls. Can be used to map the same patient across different datasets to something like an MRN behind clinical firewalls.") - bids_group.add_argument("--subject_number", type=str, help="Subject string to use when not referencing a input_csv file. Only used for single data pulls.") - bids_group.add_argument("--session", type=str, help="Session string to use when not referencing a input_csv file. Only used for single data pulls.") - bids_group.add_argument("--run", type=str, help="Run string to use when not referencing a input_csv file. Only used for single data pulls.") - bids_group.add_argument("--task", type=str, default='rest', help="Task string to use when not referencing a input_csv file value. 
Used to populate all entries if not explicitly set.") - - multithread_group = parser.add_argument_group('Multithreading Options') - multithread_group.add_argument("--multithread", action='store_true', default=False, help="Multithreaded download.") - multithread_group.add_argument("--ncpu", default=1, type=int, help="Number of CPUs to use when downloading.") - - misc_group = parser.add_argument_group('Misc options') - misc_group.add_argument("--include_annotation", action='store_true', default=False, help="If downloading by time, include annotations/events file. Defaults to scalp layer names.") - misc_group.add_argument("--target", type=str, help="Target to associate with the data. (i.e. PNES/EPILEPSY/etc.)") - misc_group.add_argument("--example_input", action='store_true', default=False, help="Show example input file structure.") - misc_group.add_argument("--backend", type=str, default='MNE', help="Backend data handler.") - misc_group.add_argument("--ch_type", default=None, type=str, help="Manual set of channel type if not matched by known patterns. (i.e. 'seeg' for intracranial data)") - misc_group.add_argument("--debug", action='store_true', default=False, help="Debug tools. Mainly removes files after generation.") - args = parser.parse_args() - - # If the user wants an example input file, print it then close application - if args.example_input: - print_examples() - exit() - - # Basic clean-up - if args.bids_root[-1] != '/': args.bids_root+='/' - - # Main Logic - if args.ieeg: - ieeg(args) - elif args.edf: - raw_edf(args) - elif args.jar: - read_jar(args) - else: - print("Please select at least one source from the source group. (--help for all options.)") diff --git a/scripts/codehub/utils/acquisition/BIDS/README.md b/scripts/codehub/utils/acquisition/BIDS/README.md deleted file mode 100644 index 42b8a9e8..00000000 --- a/scripts/codehub/utils/acquisition/BIDS/README.md +++ /dev/null @@ -1,45 +0,0 @@ -# EEG BIDS Creation - -EEG Bids is a package designed to convert timeseries data into BIDS-compliant datasets. As the push for standardized datasets grows, harmonizing how we collect and store data has become increasingly important. - -## Features - -Currently, the package supports: - -- Pulling data from iEEG.org -- Converting raw EDF files to BIDS format - -We aim to make it easy to add new data pull methods by using an observer coding style, allowing new code to integrate with just a few lines. For more details, refer to the contribution section. - -Additionally, the package generates various sidecar files used by other components of the CNT codehub for a range of tasks. - -## Files - -### `EEG_BIDS.py` -This is the user-interface portion of the code. You can access detailed usage instructions by running: -```bash -python EEG_BIDS.py --help -``` - -## Folders - -### `modules` -This folder contains the backend code that makes up EEG BIDS, providing functionality to convert and handle timeseries data. - -### `samples` -Includes sample CLI calls and input files to help you get started using the package. - -## Installation - -EEG_BIDS uses a number of specific packages, and it can be time consuming to build an environment just for the purposes of this script. We recommend starting with the directions for installing the cnt-codehub python environment found [here](https://github.com/penn-cnt/CNT-codehub/blob/main/README.md). You can then modify the cnt_codehub.yaml file as needed to match your needs. 
- -## Usage Examples - -For a few example use cases, see [here](https://github.com/penn-cnt/CNT-codehub/blob/main/scripts/codehub/utils/acquisition/BIDS/samples/sample_cmds.txt) - -## Contributing -(In Progress) - -If adding support for new data inputs, you can make a new object in components.public that reads in your raw data and generates the proper bids keywords. - -Once you have read in your data and generated keywords, you just need to alert the observers to generate the actual backend data. You can do this by diff --git a/scripts/codehub/utils/acquisition/BIDS/components/internal/BIDS_handler.py b/scripts/codehub/utils/acquisition/BIDS/components/internal/BIDS_handler.py deleted file mode 100644 index 06e2a10f..00000000 --- a/scripts/codehub/utils/acquisition/BIDS/components/internal/BIDS_handler.py +++ /dev/null @@ -1,167 +0,0 @@ -import time -import getpass -import pickle -import numpy as np -import pandas as PD -from mne_bids import BIDSPath,write_raw_bids - -# Local Imports -from components.internal.observer_handler import * - -class BIDS_observer(Observer): - - def listen_metadata(self): - - def clean_value(key,value): - - # Track the typing - dtype = 'str' - - # Try a float conversion first - try: - newval = float(value) - dtype = 'float' - except: - newval = value - - # Try an integer conversion - if dtype == 'float': - if newval % 1 == 0: - newval = int(newval) - dtype = 'int' - - # Clean up the value as much as possible - if dtype == 'float': - newval = f"{newval:06.1f}" - elif dtype == 'int': - newval = f"{newval:04d}" - else: - if key in ['start','duration']: - if value == None: - newval = "None" - return newval - - # Define the required BIDS keywords - BIDS_keys = ['root','datatype','session','subject','run','task','filename','start','duration','uid'] - - # Populate the bids dictionary with the new values - for ikey,ivalue in self.keywords.items(): - if ikey in BIDS_keys: - self.BIDS_keywords[ikey]=clean_value(ikey,ivalue) - - # If all keywords are set, send information to the BIDS handler. - if all(self.BIDS_keywords.values()): - - # Update the bids path - self.BH.update_path(self.BIDS_keywords) - - if self.args.include_annotation or self.args.annotations: - # Update the events - self.BH.create_events(self.keywords['filename'],int(self.keywords['run']), - self.keywords['fs'],self.annotations) - else: - print(f"Unable to create BIDS keywords for file: {self.keywords['filename']}.") - print(f"{self.BIDS_keywords}") - - -class BIDS_handler: - - def __init__(self): - pass - - def update_path(self,keywords): - """ - Update the bidspath. 
- """ - - self.current_keywords = keywords - self.bids_path = BIDSPath(root=keywords['root'], - datatype=keywords['datatype'], - session=keywords['session'], - subject=keywords['subject'], - run=keywords['run'], - task=keywords['task']) - - self.target_path = str(self.bids_path.copy()).rstrip('.edf')+'_targets.pickle' - - def create_events(self,ifile,run,fs,annotations): - - # Make the events file and save the results - events = [] - self.alldesc = [] - self.event_mapping = {} - for ii,iannot in enumerate(annotations[ifile][run].keys()): - - # Get the raw annotation and the index - desc = annotations[ifile][run][iannot] - index = (1e-6*iannot)*fs - - # Make the required mne event mapper - self.event_mapping[str(iannot)] = ii - - # Store the results - events.append([index,0,ii]) - self.alldesc.append(desc) - self.events = np.array(events) - - def save_targets(self,target): - - # Store the targets - target_dict = {'target':target,'annotation':'||'.join(self.alldesc)} - pickle.dump(target_dict,open(self.target_path,"wb")) - - def save_data_w_events(self, raw, debug=False): - """ - Save EDF data into a BIDS structure. With events. - - Args: - raw (_type_): MNE Raw objext. - debug (bool, optional): Debug flag. Acts for verbosity. - - Returns: - _type_: _description_ - """ - - # Save the bids data - try: - write_raw_bids(bids_path=self.bids_path, raw=raw, events_data=self.events,event_id=self.event_mapping, allow_preload=True, format='EDF', overwrite=True, verbose=False) - return True - except Exception as e: - if debug: - print(f"Write error: {e}") - return False - - def save_data_wo_events(self, raw, debug=False): - """ - Save EDF data into a BIDS structure. - - Args: - raw (_type_): MNE Raw objext. - debug (bool, optional): Debug flag. Acts for verbosity. 
- - Returns: - _type_: _description_ - """ - - # Save the bids data - try: - write_raw_bids(bids_path=self.bids_path, raw=raw, allow_preload=True, format='EDF',verbose=False) - return True - except Exception as e: - if debug: - print(f"Write error: {e}") - return False - - def make_records(self,source): - - self.current_record = PD.DataFrame([self.current_keywords['filename']],columns=['orig_filename']) - self.current_record['source'] = source - self.current_record['creator'] = getpass.getuser() - self.current_record['gendate'] = time.strftime('%d-%m-%y', time.localtime()) - self.current_record['uid'] = self.current_keywords['uid'] - self.current_record['subject_number'] = self.current_keywords['subject'] - self.current_record['session_number'] = self.current_keywords['session'] - self.current_record['run_number'] = self.current_keywords['run'] - self.current_record['start_sec'] = self.current_keywords['start'] - self.current_record['duration_sec'] = self.current_keywords['duration'] - return self.current_record \ No newline at end of file diff --git a/scripts/codehub/utils/acquisition/BIDS/components/internal/data_backends.py b/scripts/codehub/utils/acquisition/BIDS/components/internal/data_backends.py deleted file mode 100644 index c5aa8e9e..00000000 --- a/scripts/codehub/utils/acquisition/BIDS/components/internal/data_backends.py +++ /dev/null @@ -1,128 +0,0 @@ -import re -import mne -import numpy as np -import pandas as PD -from mne.io.constants import FIFF - -# Local Imports -from components.internal.observer_handler import * - -def return_backend(user_request='MNE'): - if user_request == 'MNE': - return MNE_handler() - -class backend_observer(Observer): - - def listen_data(self): - idata,itype = self.backend.workflow(self.args,self.data,self.channels,self.fs) - self.data_list.append(idata) - self.type_list.append(itype) - -class MNE_handler: - - def __init__(self): - pass - - def workflow(self,args,data,channels,fs): - - # Save the inputs to class instance - self.args = args - self.indata = data - self.channels = channels - self.fs = fs - - # Prepare the data according to the backend - try: - passflag = self.get_channel_type() - if passflag: - self.make_info() - self.make_raw() - else: - self.irow = None - self.bids_datatype = None - except Exception as e: - if self.args.debug: - print(f"Load error {e}") - - # Return raw to the list of raws being tracked by the Subject class - return self.iraw,self.bids_datatype - - def make_raw(self): - self.iraw = mne.io.RawArray(self.indata.T, self.data_info, verbose=False) - self.iraw.set_channel_types(self.channel_types.type) - - def make_info(self): - self.data_info = mne.create_info(ch_names=list(self.channels), sfreq=self.fs, verbose=False) - for idx,ichannel in enumerate(self.channels): - if self.channel_types.loc[ichannel]['type'] in ['seeg','eeg']: - self.data_info['chs'][idx]['unit'] = FIFF.FIFF_UNIT_V - - def get_channel_type(self, threshold=15): - - # Define the expression that gets lead info - regex = re.compile(r"(\D+)(\d+)") - - # Get the outputs of each channel - try: - channel_expressions = [regex.match(ichannel) for ichannel in self.channels] - - # Make the channel types - self.channel_types = [] - for (i, iexpression), channel in zip(enumerate(channel_expressions), self.channels): - if iexpression == None: - if channel.lower() in ['fz','cz']: - self.channel_types.append('eeg') - else: - self.channel_types.append('misc') - else: - lead = iexpression.group(1) - contact = int(iexpression.group(2)) - if lead.lower() in ["ecg", 
"ekg"]: - self.channel_types.append('ecg') - elif lead.lower() in ['c', 'cz', 'cz', 'f', 'fp', 'fp', 'fz', 'fz', 'o', 'p', 'pz', 'pz', 't']: - self.channel_types.append('eeg') - elif "NVC" in iexpression.group(0): # NeuroVista data - self.channel_types.append('eeg') - self.channels[i] = f"{channel[-2:]}" - elif lead.lower() in ['a']: - self.channel_types.append('misc') - else: - self.channel_types.append(1) - - # Do some final clean ups based on number of leads - lead_sum = 0 - for ival in self.channel_types: - if isinstance(ival,int):lead_sum+=1 - if self.args.ch_type == None: - if lead_sum > threshold: - remaining_leads = 'ecog' - else: - remaining_leads = 'seeg' - else: - remaining_leads = self.args.ch_type - for idx,ival in enumerate(self.channel_types): - if isinstance(ival,int):self.channel_types[idx] = remaining_leads - self.channel_types = np.array(self.channel_types) - except: - if self.args.ch_type != None: - self.channel_types = np.array([self.args.ch_type for ichannel in self.channels]) - else: - return False - - # Make the dictionary for mne - self.channel_types = PD.DataFrame(self.channel_types.reshape((-1,1)),index=self.channels,columns=["type"]) - - # Get the best guess datatype to send to bids writer - raw_datatype = self.channel_types['type'].mode().values[0] - - # perform some common mappings to the bids keywords - if raw_datatype == 'ecog': - datatype = 'ieeg' - elif raw_datatype == 'seeg': - datatype = 'ieeg' - else: - datatype = raw_datatype - - # Store the data type to use for write out - self.bids_datatype = datatype - return True diff --git a/scripts/codehub/utils/acquisition/BIDS/components/internal/exception_handler.py b/scripts/codehub/utils/acquisition/BIDS/components/internal/exception_handler.py deleted file mode 100644 index de6e9303..00000000 --- a/scripts/codehub/utils/acquisition/BIDS/components/internal/exception_handler.py +++ /dev/null @@ -1,71 +0,0 @@ -# API timeout class -import signal -class TimeoutException(Exception): - pass - -class Timeout: - """ - Manage timeouts to the iEEG.org API call. It can go stale, and sit for long periods of time otherwise. - """ - - def __init__(self, seconds=1, multiflag=False, error_message='Function call timed out'): - self.seconds = seconds - self.error_message = error_message - self.multiflag = multiflag - - def handle_timeout(self, signum, frame): - raise TimeoutException(self.error_message) - - def __enter__(self): - if not self.multiflag: - signal.signal(signal.SIGALRM, self.handle_timeout) - signal.alarm(self.seconds) - else: - pass - - def __exit__(self, exc_type, exc_value, traceback): - if not self.multiflag: - signal.alarm(0) - else: - pass - -class DataExists: - """ - Checks data records for existing data. - """ - - def __init__(self,data_record): - self.data_record = data_record - self.record_checkfile = '' - self.record_start = -1 - self.record_duration = -1 - - def check_default_records(self,checkfile,checkstart,checkduration): - """ - Check the data record for data that matched the current query. - - Args: - checkfile (_type_): Current ieeg.org filename. - checkstart (_type_): Current iEEG.org start time in seconds. - checkduration (_type_): Current iEEG.org duration in seconds. - - Returns: - bool: True if no data found in record. False is found. 
- """ - - # Update file mask as needed - if checkfile != self.record_checkfile: - self.record_checkfile = checkfile - self.record_file_mask = (self.data_record['orig_filename'].values==checkfile) - if checkstart != self.record_start: - self.record_start = checkstart - self.record_start_mask = (self.data_record['start_sec'].values==checkstart) - if checkduration != self.record_duration: - self.record_duration = checkduration - self.record_duration_mask = (self.data_record['duration_sec'].values==checkduration) - - # Get the combined mask - mask = self.record_file_mask*self.record_start_mask*self.record_duration_mask - - # Check for any existing records - return not(any(mask)) diff --git a/scripts/codehub/utils/acquisition/BIDS/components/internal/observer_handler.py b/scripts/codehub/utils/acquisition/BIDS/components/internal/observer_handler.py deleted file mode 100644 index d40ff361..00000000 --- a/scripts/codehub/utils/acquisition/BIDS/components/internal/observer_handler.py +++ /dev/null @@ -1,43 +0,0 @@ -from abc import ABC, abstractmethod - -class Subject: - """ - Subject class to allow the BIDS handler to listen for new keywords. - """ - def add_meta_observer(self, observer): - if observer not in self._meta_observers: - self._meta_observers.append(observer) - - def add_data_observer(self, observer): - if observer not in self._data_observers: - self._data_observers.append(observer) - - def notify_metadata_observers(self): - for observer in self._meta_observers: - observer.listen_metadata(self) - - def notify_data_observers(self): - for observer in self._data_observers: - observer.listen_data(self) - -class Observer(ABC): - """ - Observer class to allow the BIDS handler to listen for new keywords. - - Args: - ABC (object): Abstract Base Class object. Enforces the use of abstractmethod to prevent accidental access to listen_keyword without matching - class in the observer. - - Raises: - NotImplementedError: Error if the observing class doesn't have the right class object. - """ - - # Listener for BIDS keyword generation to create the correct pathing. 
- @abstractmethod - def listen_metadata(self): - raise NotImplementedError("Subclass must implement abstract method") - - # Listener for backend data work - @abstractmethod - def listen_data(self): - raise NotImplementedError("Subclass must implement abstract method") \ No newline at end of file diff --git a/scripts/codehub/utils/acquisition/BIDS/components/public/edf_handler.py b/scripts/codehub/utils/acquisition/BIDS/components/public/edf_handler.py deleted file mode 100644 index eb78039b..00000000 --- a/scripts/codehub/utils/acquisition/BIDS/components/public/edf_handler.py +++ /dev/null @@ -1,278 +0,0 @@ -import os -import time -import getpass -from mne.io import read_raw_edf - -# Local import -from components.internal.BIDS_handler import * -from components.internal.observer_handler import * -from components.internal.exception_handler import * -from components.internal.data_backends import * - -class edf_handler(Subject): - - def __init__(self,args): - - # Save the input objects - self.args = self.input_exceptions(args) - - # Create the object pointers - self.BH = BIDS_handler() - self.backend = return_backend(args.backend) - - # Get the data record - self.get_data_record() - - # Create objects that interact with observers - self.data_list = [] - self.type_list = [] - self.BIDS_keywords = {'root':self.args.bids_root,'datatype':None,'session':None,'subject':None,'run':None,'task':None} - - def workflow(self): - """ - Run a workflow that downloads data from iEEG.org, creates the correct objects in memory, and saves it to BIDS format. - """ - - # Attach observers - self.attach_objects() - - # Determine how to save the data - self.get_inputs() - - # Begin downloading the data - self.load_data_manager() - - # Save the data - self.save_data() - - # Save the data record - self.new_data_record = self.new_data_record.sort_values(by=['subject_number','session_number','run_number']) - self.new_data_record.to_csv(self.data_record_path,index=False) - - # Remove if debugging - if self.args.debug: - os.system(f"rm -r {self.args.bids_root}*") - - def attach_objects(self): - """ - Attach observers here so we can have each multiprocessor see the pointers correctly. - """ - - # Create the observer objects - self._meta_observers = [] - self._data_observers = [] - - # Attach observers - self.add_meta_observer(BIDS_observer) - self.add_data_observer(backend_observer) - - def get_inputs(self, multiflag=False, multiinds=None): - """ - Create the input objects that track what files and times to download, and any relevant keywords for the BIDS process. - For single core pulls, has more flexibility to set parameters. For multicore, we restrict it to a pre-built input_args. - """ - - # Make sure we have some required inputs - - - # Check for an input csv to manually set entries - if self.args.input_csv != None: - - # Read in the input data - input_args = PD.read_csv(self.args.input_csv) - - # Pull out the relevant data pointers for required columns. 
- self.edf_files = list(input_args['orig_filename'].values) - - # Get the unique identifier if provided - if 'start' in input_args.columns: - self.start_times=list(input_args['start'].values) - else: - self.start_times=[self.args.start for idx in range(input_args.shape[0])] - - # Get the unique identifier if provided - if 'duration' in input_args.columns: - self.durations=list(input_args['duration'].values) - else: - self.durations=[self.args.duration for idx in range(input_args.shape[0])] - - # Get the unique identifier if provided - if 'uid' in input_args.columns: - self.uid_list=list(input_args['uid'].values) - else: - self.uid_list=[self.args.uid for idx in range(input_args.shape[0])] - - # Get the subejct number if provided - if 'subject_number' in input_args.columns: - self.subject_list=list(input_args['subject_number'].values) - else: - self.subject_list=[self.args.subject_number for idx in range(input_args.shape[0])] - - # Get the session number if provided - if 'session_number' in input_args.columns: - self.session_list=list(input_args['session_number'].values) - else: - self.session_list=[self.args.session for idx in range(input_args.shape[0])] - - # Get the run number if provided - if 'run_number' in input_args.columns: - self.run_list=list(input_args['run_number'].values) - else: - self.run_list=[self.args.run for idx in range(input_args.shape[0])] - - # Get the task if provided - if 'task' in input_args.columns: - self.task_list=list(input_args['task'].values) - - # Get the target if provided - if 'target' in input_args.columns: - self.target_list = list(input_args['target'].values) - else: - # Get the required information if we don't have an input csv - self.edf_files = [self.args.dataset] - self.start_times = [self.args.start] - self.durations = [self.args.duration] - - # Get the information that can be inferred - if self.args.uid_number != None: - self.uid_list = [self.args.uid_number] - - if self.args.subject_number != None: - self.subject_list = [self.args.subject_number] - - if self.args.session != None: - self.session_list = [self.args.session] - - if self.args.run != None: - self.run_list = [self.args.run] - - if self.args.task != None: - self.task_list = [self.args.task] - - if self.args.target != None: - self.target_list = [self.args.target] - - def get_data_record(self): - """ - Get the data record. This is typically 'subject_map.csv' and is used to locate data and prevent duplicate downloads. - """ - - # Get the proposed data record - self.data_record_path = self.args.bids_root+self.args.data_record - - # Check if the file exists - if os.path.exists(self.data_record_path): - self.data_record = PD.read_csv(self.data_record_path) - else: - self.data_record = PD.DataFrame(columns=['orig_filename','source','creator','gendate','uid','subject_number','session_number','run_number','start_sec','duration_sec']) - - def load_data_manager(self): - """ - Loop over the ieeg file list and download data. If annotations, does a first pass to get annotation layers and times, then downloads. - """ - - # Load the data exists exception handler so we can avoid already downloaded data. 
- DE = DataExists(self.data_record) - - # Loop over the requested data - for idx in range(len(self.edf_files)): - - # Check if we have a specific set of times for this file - try: - istart = self.start_times[idx] - iduration = self.durations[idx] - except TypeError: - istart = None - iduration = None - - if DE.check_default_records(self.edf_files[idx],istart,iduration): - self.load_data(self.edf_files[idx]) - - # If successful, notify data observer. Else, add a skip - if self.success_flag: - self.notify_data_observers() - else: - self.data_list.append(None) - else: - print(f"Skipping {self.edf_files[idx]}.") - self.data_list.append(None) - - def load_data(self,infile): - try: - raw = read_raw_edf(infile,verbose=False) - self.data = raw.get_data().T - self.channels = raw.ch_names - self.fs = raw.info.get('sfreq') - self.success_flag = True - except Exception as e: - self.success_flag = False - if self.args.debug: - print(f"Load error {e}") - - def save_data(self): - """ - Notify the BIDS code about data updates and save the results when possible. - """ - - # Loop over the data, assign keys, and save - self.new_data_record = self.data_record.copy() - for idx,iraw in enumerate(self.data_list): - if iraw != None: - - # Define start time and duration. Can differ for different filetypes - # May not exist for a raw edf transfer, so add a None outcome. - try: - istart = self.start_times[idx] - iduration = self.durations[idx] - except TypeError: - istart = None - iduration = None - - # Update keywords - self.keywords = {'filename':self.edf_files[idx],'root':self.args.bids_root,'datatype':self.type_list[idx], - 'session':self.session_list[idx],'subject':self.subject_list[idx],'run':self.run_list[idx], - 'task':'rest','fs':iraw.info["sfreq"],'start':istart,'duration':iduration,'uid':self.uid_list[idx]} - self.notify_metadata_observers() - - # Save the data without events until a future release - print(f"Converting {self.edf_files[idx]} to BIDS...") - success_flag = self.BH.save_data_wo_events(iraw, debug=self.args.debug) - - # If the data wrote out correctly, update the data record - if success_flag: - # Save the target info - try: - self.BH.save_targets(self.target_list[idx]) - except: - pass - - # Add the datarow to the records - self.current_record = self.BH.make_records('edf_file') - self.new_data_record = PD.concat((self.new_data_record,self.current_record)) - - ############################### - ###### Custom exceptions ###### - ############################### - - def input_exceptions(self,args): - - # Input csv exceptions - if args.input_csv: - input_cols = PD.read_csv(args.input_csv, index_col=0, nrows=0).columns.tolist() - if 'subject_number' not in input_cols: - raise Exception("Please provide a --subject_number to the input csv.") - if 'session_number' not in input_cols: - raise Exception("Please provide a --session_number to the input csv.") - if 'run_number' not in input_cols: - raise Exception("Please provide a --run_number to the input csv.") - if 'uid_number' not in input_cols: - raise Exception("Please provide a --uid_number to the input csv.") - else: - if args.subject_number == None: - raise Exception("Please provide a --subject_number input to the command line.") - if args.uid_number == None: - raise Exception("Please provide a --uid_number input to the command line.") - if args.session == None: args.session=1 - if args.run == None: args.run=1 - - return args \ No newline at end of file diff --git a/scripts/codehub/utils/acquisition/BIDS/components/public/iEEG_handler.py 
b/scripts/codehub/utils/acquisition/BIDS/components/public/iEEG_handler.py deleted file mode 100644 index b96c042c..00000000 --- a/scripts/codehub/utils/acquisition/BIDS/components/public/iEEG_handler.py +++ /dev/null @@ -1,634 +0,0 @@ -import re -import os -import time -import uuid -import getpass -import numpy as np -import multiprocessing -from time import sleep -from typing import List -import ieeg.ieeg_api as IIA -from ieeg.auth import Session -from requests.exceptions import ReadTimeout as RTIMEOUT - -# Local import -from components.internal.BIDS_handler import * -from components.internal.observer_handler import * -from components.internal.exception_handler import * -from components.internal.data_backends import * - -class ieeg_handler(Subject): - - def __init__(self,args): - - # Save the input objects - self.args = args - - # Create the object pointers - self.BH = BIDS_handler() - self.backend = return_backend(args.backend) - - # Get the data record - self.get_data_record() - - # Create objects that interact with observers - self.data_list = [] - self.type_list = [] - self.BIDS_keywords = {'root':self.args.bids_root,'datatype':None,'session':None,'subject':None,'run':None,'task':None} - - def workflow(self): - """ - Run a workflow that downloads data from iEEG.org, creates the correct objects in memory, and saves it to BIDS format. - """ - - # Get credentials - self.get_password() - - # Manage mutltithreading requests - if not self.args.multithread: - - # Make a unique id for this core - self.unique_id = uuid.uuid4() - - # Attach observers - self.attach_objects() - - # Determine what files to download and to where - self.get_inputs() - - # Begin downloading the data - self.download_data_manager() - - # Save the data - self.save_data() - - # Save the data record - self.new_data_record = self.new_data_record.sort_values(by=['subject_number','session_number','run_number']) - self.new_data_record.to_csv(self.data_record_path,index=False) - else: - self.multipull_manager() - - # Remove if debugging - if self.args.debug: - os.system(f"rm -r {self.args.bids_root}*") - - def attach_objects(self): - """ - Attach observers here so we can have each multiprocessor see the pointers correctly. - """ - - # Create the observer objects - self._meta_observers = [] - self._data_observers = [] - - # Attach observers - self.add_meta_observer(BIDS_observer) - self.add_data_observer(backend_observer) - - ######################################### - ####### Multiprocessing functions ####### - ######################################### - - def multipull_manager(self): - - # Make sure we have an input csv for multithreading. By default, this should be used for large data pulls. - if self.args.input_csv == None: - raise Exception("Please provide an input_csv with multiple files if using multithreading. For single files, you can just turn off --multithread.") - - # Read in the input csv - input_args = PD.read_csv(self.args.input_csv) - if input_args.shape[0] == 1: - error_msg = "--multithread requires the number of files to be greater than the requested cores." - error_msg += " For single files, you can just turn off --multithread. Otherwise adjust --ncpu." 
- raise Exception(error_msg) - - # Read in the load data for us to figure out best load strategy - input_args = PD.read_csv(self.args.input_csv) - - # Add a sempahore to allow orderly file access - semaphore = multiprocessing.Semaphore(1) - - # Create a load list for each cpu - all_inds = np.arange(input_args.shape[0]) - split_arrays = np.array_split(all_inds, self.args.ncpu) - - # Start the multipull processing - processes = [] - for data_chunk in split_arrays: - process = multiprocessing.Process(target=self.multipull, args=(data_chunk,semaphore)) - processes.append(process) - process.start() - - # Wait for all processes to complete - for process in processes: - process.join() - - def multipull(self,multiind,semaphore,writeout_freq=10): - """ - Handles a multithread data pull. - - Args: - multiind (_type_): _description_ - semaphore (_type_): _description_ - writeout_freq (int, optional): How many ieeg calls to make before saving out to disk. Defaults to 10. - """ - - # Make a unique id for this core - self.unique_id = uuid.uuid4() - - # Attach observers - self.attach_objects() - - # Loop over the writeout frequency - niter = np.ceil(multiind.size/writeout_freq).astype('int') - for iwrite in range(niter): - - # Get the current indice slice - index_slice = multiind[iwrite*writeout_freq:(iwrite+1)*writeout_freq] - - # Determine what files to download and to where - self.get_inputs(multiflag=True,multiinds=index_slice) - - # Begin downloading the data - self.download_data_manager() - - # Save the data - self.save_data() - - with semaphore: - self.get_data_record() - self.new_data_record = PD.concat((self.data_record,self.new_data_record)) - self.new_data_record = self.new_data_record.drop_duplicates() - self.new_data_record = self.new_data_record.sort_values(by=['subject_number','session_number','run_number']) - self.new_data_record.to_csv(self.data_record_path,index=False) - - ############################## - ####### iEEG functions ####### - ############################## - - def get_password(self): - """ - Get password for iEEG.org via Keyring or user input. - """ - - # Determine the method to get passwords. Not all systems can use a keyring easily. - try: - import keyring - - # Get the password from the user or the keyring. If needed, add to keyring. - self.password = keyring.get_password("eeg_bids_ieeg_pass", self.args.username) - if self.password == None: - self.password = getpass.getpass("Enter your password. (This will be stored to your keyring): ") - keyring.set_password("eeg_bids_ieeg_pass", self.args.username, self.password) - except: - self.password = getpass.getpass("Enter your password: ") - - def get_data_record(self): - """ - Get the data record. This is typically 'subject_map.csv' and is used to locate data and prevent duplicate downloads. - """ - - # Get the proposed data record - self.data_record_path = self.args.bids_root+self.args.data_record - - # Check if the file exists - if os.path.exists(self.data_record_path): - self.data_record = PD.read_csv(self.data_record_path) - else: - self.data_record = PD.DataFrame(columns=['orig_filename','source','creator','gendate','uid','subject_number','session_number','run_number','start_sec','duration_sec']) - - def get_inputs(self, multiflag=False, multiinds=None): - """ - Create the input objects that track what files and times to download, and any relevant keywords for the BIDS process. - For single core pulls, has more flexibility to set parameters. For multicore, we restrict it to a pre-built input_args. 
- """ - - # Check for an input csv to manually set entries - if self.args.input_csv != None: - - # Read in the input data - input_args = PD.read_csv(self.args.input_csv) - - # Check for any exceptions in the inputs - input_args = self.input_exceptions(input_args) - - # Grab the relevant indices if using multithreading - if multiflag: - input_args = input_args.iloc[multiinds].reset_index(drop=True) - - # Pull out the relevant data pointers for required columns. - self.ieeg_files = list(input_args['orig_filename'].values) - if not self.args.annotations: - self.start_times = list(input_args['start'].values) - self.durations = list(input_args['duration'].values) - - # Get candidate keywords for missing columns - self.ieegfile_to_keys() - - # Get the unique identifier if provided - if 'uid' in input_args.columns: - self.uid_list=list(input_args['uid'].values) - - # Get the subejct number if provided - if 'subject_number' in input_args.columns: - self.subject_list=list(input_args['subject_number'].values) - - # Get the session number if provided - if 'session_number' in input_args.columns: - self.session_list=list(input_args['session_number'].values) - - # Get the run number if provided - if 'run_number' in input_args.columns: - self.run_list=list(input_args['run_number'].values) - - # Get the task if provided - if 'task' in input_args.columns: - self.task_list=list(input_args['task'].values) - - # Get the target if provided - if 'target' in input_args.columns: - self.target_list = list(input_args['target'].values) - - # Conditions for no input csv file - else: - # Get the required information if we don't have an input csv - self.ieeg_files = [self.args.dataset] - self.start_times = [self.args.start] - self.durations = [self.args.duration] - - # Infer input information from filename - self.ieegfile_to_keys() - - # Get the information that can be inferred - if self.args.uid_number != None: - self.uid_list = [self.args.uid_number] - - if self.args.subject_number != None: - self.subject_list = [self.args.subject_number] - - if self.args.session != None: - self.session_list = [self.args.session] - - if self.args.run != None: - self.run_list = [self.args.run] - - if self.args.task != None: - self.task_list = [self.args.task] - - if self.args.target != None: - self.target_list = [self.args.target] - - # Add an object to store information via annotation downloads - if self.args.annotations: - self.annot_files = [] - self.start_times = [] - self.durations = [] - self.annotation_uid = [] - self.annotation_sub = [] - self.annotation_ses = [] - self.run_list = [] - self.annotation_flats = [] - - # Make the annotation object - self.annotations = {} - - def annotation_cleanup(self,ifile,iuid,isub,ises,itarget): - """ - Restructure annotation information to be used as new inputs. 
- """ - - # Remove start clip time if it is just the machine starting up - if self.clips[0].type.lower() == 'clip end' and self.clips[0].end_time_offset_usec == 2000: - self.clips = self.clips[1:] - - # Manage edge cases - if self.clips[0].type.lower() == 'clip end': - self.clips = list(np.concatenate(([0],self.clips), axis=0)) - if self.clips[-1].type.lower() == 'clip start': - self.clips = list(np.concatenate((self.clips,[self.ieeg_end_time-self.ieeg_start_time]), axis=0)) - - clip_vals = [] - for iclip in self.clips: - try: - clip_vals.append(iclip.start_time_offset_usec) - except AttributeError: - clip_vals.append(iclip) - - # Turn the clip times into start and end arrays - clip_start_times = np.array([iclip for iclip in clip_vals[::2]]) - clip_end_times = np.array([iclip for iclip in clip_vals[1::2]]) - clip_durations = clip_end_times-clip_start_times - - # Match the annotations to the clips - self.annotations[ifile] = {ival:{} for ival in range(clip_start_times.size)} - annotation_flats = [] - for annot in self.raw_annotations: - time = annot.start_time_offset_usec - desc = annot.description - for idx, istart in enumerate(clip_start_times): - if (time >= istart) and (time <= clip_end_times[idx]): - event_time_shift = (time-istart) - self.annotations[ifile][idx][event_time_shift] = desc - annotation_flats.append(desc) - - # Update the instance wide values - self.annot_files.extend([ifile for idx in range(len(clip_start_times))]) - self.annotation_uid.extend([iuid for idx in range(len(clip_start_times))]) - self.annotation_sub.extend([isub for idx in range(len(clip_start_times))]) - self.annotation_ses.extend([ises for idx in range(len(clip_start_times))]) - self.target_list.extend([itarget for idx in range(len(clip_start_times))]) - self.start_times.extend(clip_start_times) - self.durations.extend(clip_durations) - self.run_list.extend(np.arange(len(clip_start_times))) - self.annotation_flats.extend(annotation_flats) - - def annotation_cleanup_set_time(self,idx): - - # Get just the current annotation block from ieeg - self.download_data(self.ieeg_files[idx],self.start_times[idx],self.durations[idx],True) - - # Make the annotation object - if self.ieeg_files[idx] not in self.annotations.keys(): - self.annotations[self.ieeg_files[idx]] = {} - self.annotations[self.ieeg_files[idx]][self.run_list[idx]] = {} - - for annot in self.raw_annotations: - - # get the information out of the annotation layer - time = annot.start_time_offset_usec - desc = annot.description - - # figure out its time relative to the download start - event_time_shift = (time-self.start_times[idx]) - - # Store the results - self.annotations[self.ieeg_files[idx]][self.run_list[idx]][event_time_shift] = desc - - - def ieegfile_to_keys(self): - """ - Use the iEEG.org filename to determine keywords. 
- """ - - # Extract possible keywords from the ieeg filename - self.uid_list = [] - self.subject_list = [] - self.session_list = [] - self.run_list = [] - for ifile in self.ieeg_files: - - # Create a match object to search for relevant subject and session data - match = re.search(r'\D+(\d+)_\D+(\d+)', ifile) - - # Get numerical portions of filename that correspond to subject and session - if match: - candidate_uid = int(match.group(1)) - candidate_sub = match.group(1) - candidate_ses = match.group(2) - else: - candidate_uid = None - candidate_sub = None - candidate_ses = None - - # Look for this informaion in the records - iDF = self.data_record.loc[self.data_record.orig_filename==ifile] - - # If the data already exists, get its previous information - if iDF.shape[0] > 0: - - # Get the existing information - candidate_uid = iDF.iloc[0].uid - candidate_sub = str(iDF.iloc[0].subject_number) - candidate_ses = str(iDF.iloc[0].session_number) - - # Create the subject and session lists - self.uid_list.append(candidate_uid) - self.subject_list.append(candidate_sub) - self.session_list.append(candidate_ses) - self.run_list.append(1) - - def download_data_manager(self): - """ - Loop over the ieeg file list and download data. If annotations, does a first pass to get annotation layers and times, then downloads. - """ - - # Load the data exists exception handler so we can avoid already downloaded data. - DE = DataExists(self.data_record) - - # Loop over the requested data - for idx in range(len(self.ieeg_files)): - - # Download the data - if self.args.annotations: - self.download_data(self.ieeg_files[idx],0,0,True) - self.annotation_cleanup(self.ieeg_files[idx],self.uid_list[idx],self.subject_list[idx],self.session_list[idx],self.target_list[idx]) - else: - # If-else around if the data already exists in our records. Add a skip to the data list if found to maintain run order. - if DE.check_default_records(self.ieeg_files[idx],1e-6*self.start_times[idx],1e-6*self.durations[idx]): - - # Get the annotations for just this download if requested - if self.args.include_annotation: - self.annotation_cleanup_set_time(idx) - - # Download the data - self.download_data(self.ieeg_files[idx],self.start_times[idx],self.durations[idx],False) - - # If successful, notify data observer. Else, add a skip - if self.success_flag: - self.notify_data_observers() - else: - self.data_list.append(None) - else: - print(f"Skipping {self.ieeg_files[idx]} starting at {1e-6*self.start_times[idx]:011.2f} seconds for {1e-6*self.durations[idx]:08.2f} seconds.") - self.data_list.append(None) - - # If downloading by annotations, now loop over the clip level info and save - if self.args.annotations: - # Update the object pointers for subject, session, etc. info - self.ieeg_files = self.annot_files - self.uid_list = self.annotation_uid - self.subject_list = self.annotation_sub - self.session_list = self.annotation_ses - - # Loop over the file list that is expanded by all the annotations - for idx in range(len(self.ieeg_files)): - - # If-else around if the data already exists in our records. Add a skip to the data list if found to maintain run order. - if DE.check_default_records(self.ieeg_files[idx],1e-6*self.start_times[idx],1e-6*self.durations[idx]): - - # Download the data - self.download_data(self.ieeg_files[idx],self.start_times[idx],self.durations[idx],False) - - # If successful, notify data observer. 
Else, add a skip - if self.success_flag: - self.notify_data_observers() - else: - self.data_list.append(None) - else: - print(f"Skipping {self.ieeg_files[idx]} starting at {1e-6*self.start_times[idx]:011.2f} seconds for {1e-6*self.durations[idx]:08.2f} seconds.") - self.data_list.append(None) - - def save_data(self): - """ - Notify the BIDS code about data updates and save the results when possible. - """ - - # Loop over the data, assign keys, and save - self.new_data_record = self.data_record.copy() - for idx,iraw in enumerate(self.data_list): - if iraw != None: - - # Define start time and duration. Can differ for different filetypes - # iEEG.org uses microseconds. So we convert here to seconds for output. - istart = 1e-6*self.start_times[idx] - iduration = 1e-6*self.durations[idx] - - # Update keywords - self.keywords = {'filename':self.ieeg_files[idx],'root':self.args.bids_root,'datatype':self.type_list[idx], - 'session':self.session_list[idx],'subject':self.subject_list[idx],'run':self.run_list[idx], - 'task':'rest','fs':iraw.info["sfreq"],'start':istart,'duration':iduration,'uid':self.uid_list[idx]} - self.notify_metadata_observers() - - # Save the data - if self.args.include_annotation or self.args.annotations: - success_flag = self.BH.save_data_w_events(iraw, debug=self.args.debug) - else: - success_flag = self.BH.save_data_wo_events(iraw, debug=self.args.debug) - - # If the data wrote out correctly, update the data record - if success_flag: - # Save the target info - try: - self.BH.save_targets(self.target_list[idx]) - except: - pass - - # Add the datarow to the records - self.current_record = self.BH.make_records('ieeg.org') - self.new_data_record = PD.concat((self.new_data_record,self.current_record)) - - ############################################### - ###### IEEG Connection related functions ###### - ############################################### - - def download_data(self,ieegfile,start,duration,annotation_flag,n_retry=5): - - # Attempt connection to iEEG.org up to the retry limit - self.global_timeout = self.args.timeout - n_attempts = 0 - self.success_flag = False - while True: - with Timeout(self.global_timeout,False): - try: - self.ieeg_session(ieegfile,start,duration,annotation_flag) - self.success_flag = True - break - except (IIA.IeegConnectionError,IIA.IeegServiceError,TimeoutException,RTIMEOUT,TypeError) as e: - if n_attempts= end_time: - chunks.append([ival,end_time-ival]) - else: - chunks.append([ival,time_cutoff]) - ival += time_cutoff - - # Call data and concatenate calls if greater than 10 min - self.data = [] - for ival in chunks: - self.data.append(dataset.get_data(ival[0],ival[1],channel_cntr)) - if len(self.data) > 1: - self.data = np.concatenate(self.data) - else: - self.data = self.data[0] - - # Apply the voltage factors - self.data = 1e-6*self.data - - # Get the channel labels - self.channels = dataset.ch_labels - - # Get the samping frequencies - self.fs = [dataset.get_time_series_details(ichannel).sample_rate for ichannel in self.channels] - - # Data quality checks before saving - if np.unique(self.fs).size == 1: - self.fs = self.fs[0] - else: - raise Exception("Too many unique values for sampling frequency.") - else: - self.clips = dataset.get_annotations(self.args.time_layer) - self.raw_annotations = dataset.get_annotations(self.args.annot_layer) - self.ieeg_start_time = dataset.start_time - self.ieeg_end_time = dataset.end_time - session.close() - - ############################### - ###### Custom exceptions ###### - ############################### - 
- def input_exceptions(self,input_args): - - # Raise some exceptions if we find data we can't work with - if 'orig_filename' not in input_args.columns: - raise Exception("Please provide 'orig_filename' in the input csv file.") - elif 'orig_filename' in input_args.columns: - if 'start' not in input_args.columns and not self.args.annotations: - raise Exception("A 'start' column is required in the input csv if not using the --annotations flag.") - elif 'duration' not in input_args.columns and not self.args.annotations: - raise Exception("A 'duration' column is required in the input csv if not using the --annotations flag.") - - # Handle situations where the user requested annotations but also provided times - if self.args.annotations: - if 'start' in input_args.columns or 'duration' in input_args.columns: - userinput = '' - while userinput.lower() not in ['y','n']: - userinput = input("--annotations flag set to True, but start times and durations were provided in the input. Override these times with annotations clips (Yy/Nn)? ") - if userinput.lower() == 'n': - print("Ignoring --annotation flag. Using user provided times.") - self.args.annotations = False - if userinput.lower() == 'y': - print("Ignoring user provided times in favor of annotation layer times.") - if 'start' in input_args.columns: input_args.drop(['start'],axis=1,inplace=True) - if 'duration' in input_args.columns: input_args.drop(['duration'],axis=1,inplace=True) - - return input_args \ No newline at end of file diff --git a/scripts/codehub/utils/acquisition/BIDS/components/public/jar_handler.py b/scripts/codehub/utils/acquisition/BIDS/components/public/jar_handler.py deleted file mode 100644 index cccef782..00000000 --- a/scripts/codehub/utils/acquisition/BIDS/components/public/jar_handler.py +++ /dev/null @@ -1,260 +0,0 @@ -import os -import time -import getpass - -# Local import -from components.internal.BIDS_handler import * -from components.internal.observer_handler import * -from components.internal.exception_handler import * -from components.internal.data_backends import * - -class jar_handler(Subject): - - def __init__(self,args): - - # Save the input objects - self.args = args - - # Create the object pointers - self.BH = BIDS_handler() - self.backend = return_backend(args.backend) - - # Get the data record - self.get_data_record() - - # Create objects that interact with observers - self.data_list = [] - self.type_list = [] - self.BIDS_keywords = {'root':self.args.bids_root,'datatype':None,'session':None,'subject':None,'run':None,'task':None} - - def workflow(self): - """ - Run a workflow that downloads data from iEEG.org, creates the correct objects in memory, and saves it to BIDS format. - """ - - # Attach observers - self.attach_objects() - - # Determine what files to download and to where - self.get_inputs() - - # Begin downloading the data - self.convert_data_manager() - - # Save the data - self.save_data() - - # Save the data record - self.new_data_record = self.new_data_record.sort_values(by=['subject_number','session_number','run_number']) - self.new_data_record.to_csv(self.data_record_path,index=False) - - # Remove if debugging - if self.args.debug: - os.system(f"rm -r {self.args.bids_root}*") - - def attach_objects(self): - """ - Attach observers here so we can have each multiprocessor see the pointers correctly. 
- """ - - # Create the observer objects - self._meta_observers = [] - self._data_observers = [] - - # Attach observers - self.add_meta_observer(BIDS_observer) - self.add_data_observer(backend_observer) - - def get_data_record(self): - """ - Get the data record. This is typically 'subject_map.csv' and is used to locate data and prevent duplicate downloads. - """ - - # Get the proposed data record - self.data_record_path = self.args.bids_root+self.args.data_record - - # Check if the file exists - if os.path.exists(self.data_record_path): - self.data_record = PD.read_csv(self.data_record_path) - else: - self.data_record = PD.DataFrame(columns=['orig_filename','source','creator','gendate','uid','subject_number','session_number','run_number','start_sec','duration_sec']) - - def get_inputs(self, multiflag=False, multiinds=None): - """ - Create the input objects that track what files and times to download, and any relevant keywords for the BIDS process. - For single core pulls, has more flexibility to set parameters. For multicore, we restrict it to a pre-built input_args. - """ - - # Check for an input csv to manually set entries - if self.args.input_csv != None: - - # Read in the input data - input_args = PD.read_csv(self.args.input_csv) - - # Pull out the relevant data pointers for required columns. - self.jar_files = list(input_args['orig_filename'].values) - - # Get the unique identifier if provided - if 'start' in input_args.columns: - self.start_times=list(input_args['start'].values) - else: - self.start_times=[self.args.start for idx in range(input_args.shape[0])] - - # Get the unique identifier if provided - if 'duration' in input_args.columns: - self.durations=list(input_args['duration'].values) - else: - self.durations=[self.args.duration for idx in range(input_args.shape[0])] - - # Get the unique identifier if provided - if 'uid' in input_args.columns: - self.uid_list=list(input_args['uid'].values) - else: - self.uid_list=[self.args.uid for idx in range(input_args.shape[0])] - - # Get the subejct number if provided - if 'subject_number' in input_args.columns: - self.subject_list=list(input_args['subject_number'].values) - else: - self.subject_list=[self.args.subject_number for idx in range(input_args.shape[0])] - - # Get the session number if provided - if 'session_number' in input_args.columns: - self.session_list=list(input_args['session_number'].values) - else: - self.session_list=[self.args.session for idx in range(input_args.shape[0])] - - # Get the run number if provided - if 'run_number' in input_args.columns: - self.run_list=list(input_args['run_number'].values) - else: - self.run_list=[self.args.run for idx in range(input_args.shape[0])] - - # Get the task if provided - if 'task' in input_args.columns: - self.task_list=list(input_args['task'].values) - - # Get the target if provided - if 'target' in input_args.columns: - self.target_list = list(input_args['target'].values) - else: - # Get the required information if we don't have an input csv - self.jar_files = [self.args.dataset] - self.start_times = [self.args.start] - self.durations = [self.args.duration] - - # Get the information that can be inferred - if self.args.uid_number != None: - self.uid_list = [self.args.uid_number] - - if self.args.subject_number != None: - self.subject_list = [self.args.subject_number] - - if self.args.session != None: - self.session_list = [self.args.session] - - if self.args.run != None: - self.run_list = [self.args.run] - - if self.args.task != None: - self.task_list = [self.args.task] - - 
if self.args.target != None: - self.target_list = [self.args.target] - - def convert_data_manager(self): - - # Load the data exists exception handler so we can avoid already downloaded data. - DE = DataExists(self.data_record) - - # Loop over the requested data - for idx in range(len(self.jar_files)): - - # Check if we have a specific set of times for this file - try: - istart = self.start_times[idx] - iduration = self.durations[idx] - except TypeError: - istart = None - iduration = None - - if DE.check_default_records(self.jar_files[idx],istart,iduration): - - # Run the java script here - # Reference the orig_filename to the mef folder - # java.run() - - # Look for data quality flags from java - #java.pass_fail() - # if pass then continue - # if fail: self.success_flag=False - - # Loop over channel files - - self.read_jar_data(self.jar_files[idx]) - - # If successful, notify data observer. Else, add a skip - if self.success_flag: - self.notify_data_observers() - else: - self.data_list.append(None) - else: - print(f"Skipping {self.jar_files[idx]}.") - self.data_list.append(None) - - def read_jar_data(self,data_file): - - try: - - header_file = data_file.split('values_data')[0]+"header_info.csv" - self.data = PD.read_csv(data_file).values - #self.channels = PD.read_csv(header_file)['Channel Name'].values - #self.fs = PD.read_csv(header_file)['Sampling Frequency'].values - self.channels = ['Sin 10Hz'] - self.fs = 800 - self.success_flag = True - except Exception as e: - self.success_flag = False - if self.args.debug: - print(f"Load error {e}") - - def save_data(self): - """ - Notify the BIDS code about data updates and save the results when possible. - """ - - # Loop over the data, assign keys, and save - self.new_data_record = self.data_record.copy() - for idx,iraw in enumerate(self.data_list): - if iraw != None: - - # Define start time and duration. Can differ for different filetypes - # May not exist for a raw file transfer, so add a None outcome. 
- try: - istart = self.start_times[idx] - iduration = self.durations[idx] - except TypeError: - istart = None - iduration = None - - # Update keywords - self.keywords = {'filename':self.jar_files[idx],'root':self.args.bids_root,'datatype':self.type_list[idx], - 'session':self.session_list[idx],'subject':self.subject_list[idx],'run':self.run_list[idx], - 'task':'rest','fs':iraw.info["sfreq"],'start':istart,'duration':iduration,'uid':self.uid_list[idx]} - self.notify_metadata_observers() - - # Save the data without events until a future release - print(f"Converting {self.jar_files[idx]} to BIDS...") - success_flag = self.BH.save_data_wo_events(iraw, debug=self.args.debug) - - # If the data wrote out correctly, update the data record - if success_flag: - # Save the target info - try: - self.BH.save_targets(self.target_list[idx]) - except: - pass - - # Add the datarow to the records - self.current_record = self.BH.make_records('jar_file') - self.new_data_record = PD.concat((self.new_data_record,self.current_record)) \ No newline at end of file diff --git a/scripts/codehub/utils/acquisition/BIDS/samples/sample_annot.csv b/scripts/codehub/utils/acquisition/BIDS/samples/sample_annot.csv deleted file mode 100644 index 695416a7..00000000 --- a/scripts/codehub/utils/acquisition/BIDS/samples/sample_annot.csv +++ /dev/null @@ -1,5 +0,0 @@ -orig_filename,uid,subject_number,session_number,task,target -EMU0562_Day01_1,1,1,1.0,rest,target_A -EMU0562_Day02_1,1,1,2.0,rest,target_B -EMU0562_Day03_1,1,1,3.0,rest,target_B -EMU0562_Day04_1,1,1,4.0,rest,target_B \ No newline at end of file diff --git a/scripts/codehub/utils/acquisition/BIDS/samples/sample_cmds.txt b/scripts/codehub/utils/acquisition/BIDS/samples/sample_cmds.txt deleted file mode 100644 index 9044bb03..00000000 --- a/scripts/codehub/utils/acquisition/BIDS/samples/sample_cmds.txt +++ /dev/null @@ -1,29 +0,0 @@ -# Single download without inputs. Should create subject 562, session 1, run 1 -python utils/acquisition/BIDS/EEG_BIDS.py --ieeg --username BJPrager --bids_root /Users/bjprager/Documents/GitHub/CNT-codehub/user_data/tests/single/ --dataset EMU0562_Day01_1 --start 2925000000 --duration 10000000 - -# Single download without inputs. Should create subject HUP001, session 1, run 1 -python utils/acquisition/BIDS/EEG_BIDS.py --ieeg --username BJPrager --bids_root /Users/bjprager/Documents/GitHub/CNT-codehub/user_data/tests/single/ --dataset EMU0562_Day01_1 --start 2925000000 --duration 10000000 --subject HUP001 - -# Single download without inputs. Should create subject 001, session HUP001, run 1 -python utils/acquisition/BIDS/EEG_BIDS.py --ieeg --username BJPrager --bids_root /Users/bjprager/Documents/GitHub/CNT-codehub/user_data/tests/single/ --dataset EMU0562_Day01_1 --start 2925000000 --duration 10000000 --session HUP001 --subject 001 --run 1 - -# Single download without inputs. Should create subject 001, session HUP001, run 1. 
Prevents all output -python utils/acquisition/BIDS/EEG_BIDS.py --ieeg --username BJPrager --bids_root /Users/bjprager/Documents/GitHub/CNT-codehub/user_data/tests/single/ --dataset EMU0562_Day01_1 --start 2925000000 --duration 10000000 --session HUP001 --subject 001 --run 1 --debug - -# Download with inputs, specific times -python utils/acquisition/BIDS/EEG_BIDS.py --ieeg --username BJPrager --bids_root /Users/bjprager/Documents/GitHub/CNT-codehub/user_data/tests/single/ --input_csv utils/acquisition/BIDS/samples/sample_times.csv - -# Download with inputs, specific times, include an annotation/events file -python utils/acquisition/BIDS/EEG_BIDS.py --ieeg --username BJPrager --bids_root /Users/bjprager/Documents/GitHub/CNT-codehub/user_data/tests/single/ --input_csv utils/acquisition/BIDS/samples/sample_times.csv --include_annotation - -# Download with inputs, annotations -python utils/acquisition/BIDS/EEG_BIDS.py --ieeg --username BJPrager --bids_root /Users/bjprager/Documents/GitHub/CNT-codehub/user_data/tests/single/ --input_csv utils/acquisition/BIDS/samples/sample_annot.csv --annotations - -# Download with inputs, annotations, using multithreading -python utils/acquisition/BIDS/EEG_BIDS.py --ieeg --username BJPrager --bids_root /Users/bjprager/Documents/GitHub/CNT-codehub/user_data/tests/single/ --input_csv utils/acquisition/BIDS/samples/sample_annot.csv --annotations --multithread --ncpu 2 - -# Download with inputs, specific times, setting annotation layer to test exception handling -python utils/acquisition/BIDS/EEG_BIDS.py --ieeg --username BJPrager --bids_root /Users/bjprager/Documents/GitHub/CNT-codehub/user_data/tests/single/ --input_csv utils/acquisition/BIDS/samples/sample_times.csv --annotations - -# Single raw edf file conversion without inputs. 
Should create subject HUP001, session 1, run 1 -python utils/acquisition/BIDS/EEG_BIDS.py --edf --username BJPrager --bids_root /Users/bjprager/Documents/GitHub/CNT-codehub/user_data/tests/single/ --dataset /Users/bjprager/Documents/GitHub/CNT-codehub/user_data/BIDS/sub-00001/ses-preimplant001/eeg/sub-00001_ses-preimplant001_task-task_run-01_eeg.edf --subject HUP001 --uid_number 1 diff --git a/scripts/codehub/utils/acquisition/BIDS/samples/sample_jarfiles.csv b/scripts/codehub/utils/acquisition/BIDS/samples/sample_jarfiles.csv deleted file mode 100644 index 82b7d2d3..00000000 --- a/scripts/codehub/utils/acquisition/BIDS/samples/sample_jarfiles.csv +++ /dev/null @@ -1,2 +0,0 @@ -orig_filename,uid,subject_number,session_number,run_number,task -/Users/bjprager/Documents/GitHub/CNT-codehub/user_data/examples/ieeg_migration/jar_data/Sin_10Hz_values_data.csv,0,HUP0001,001,1,rest \ No newline at end of file diff --git a/scripts/codehub/utils/acquisition/BIDS/samples/sample_tereza.csv b/scripts/codehub/utils/acquisition/BIDS/samples/sample_tereza.csv deleted file mode 100644 index 9c03d47a..00000000 --- a/scripts/codehub/utils/acquisition/BIDS/samples/sample_tereza.csv +++ /dev/null @@ -1,2 +0,0 @@ -orig_filename,uid,start,duration,subject_number,session_number,run_number -HUP223_phaseII,1,1132280000,10000000,HUP001,1,1 \ No newline at end of file diff --git a/scripts/codehub/utils/acquisition/BIDS/samples/sample_times.csv b/scripts/codehub/utils/acquisition/BIDS/samples/sample_times.csv deleted file mode 100644 index dadb38c6..00000000 --- a/scripts/codehub/utils/acquisition/BIDS/samples/sample_times.csv +++ /dev/null @@ -1,3 +0,0 @@ -orig_filename,uid,start,duration,subject_number,session_number,run_number,target -EMU0562_Day01_1,0,4000000000.0,10000000,HUP0562,001,3,target_D -EMU0562_Day01_1,0,5740000000.0,10000000,HUP0562,001,4,target_E \ No newline at end of file diff --git a/scripts/codehub/utils/annotations/README.md b/scripts/codehub/utils/annotations/README.md new file mode 100644 index 00000000..e34df357 --- /dev/null +++ b/scripts/codehub/utils/annotations/README.md @@ -0,0 +1,3 @@ +# Data Annotations + +This folder contains scripts to associate annotations/predictors/etc. with data files. \ No newline at end of file diff --git a/scripts/codehub/utils/annotations/find_data.py b/scripts/codehub/utils/annotations/find_data.py new file mode 100644 index 00000000..e69de29b diff --git a/scripts/codehub/utils/association/add_events/add_event.py b/scripts/codehub/utils/association/add_events/add_event.py new file mode 100644 index 00000000..ebece30f --- /dev/null +++ b/scripts/codehub/utils/association/add_events/add_event.py @@ -0,0 +1,43 @@ +import os +import re +import argparse +import pandas as PD + + +if __name__ == '__main__': + + # Command line options needed to obtain data. 
+ parser = argparse.ArgumentParser(description="Add an annotation/event to the events file associated with an EDF recording.") + + data_group = parser.add_argument_group('Data configuration options') + data_group.add_argument("--edfpath", type=str, required=True, default=None, help="Path to the EDF file to add an annotation for.") + data_group.add_argument("--sampfreq", type=int, help="Sampling frequency.") + data_group.add_argument("--time", type=float, help="Onset time, in seconds, of the annotation to add.") + data_group.add_argument("--annot", type=str, help="Annotation to add.") + args = parser.parse_args() + + # Determine the annotation filepath + pattern = r"(.+)_\w+\.edf$" + match = re.match(pattern, args.edfpath) + basename = match.group(1) + + # Make the events path + event_path = basename+'_events.tsv' + + # Try to find an existing annotation file + if os.path.exists(event_path): + event_DF = PD.read_csv(event_path,delimiter='\t') + makeflag = False + else: + event_DF = PD.DataFrame(columns=['onset','duration','trial_type','value','sample']) + makeflag = True + + # Add annotation as needed + if args.annot != None: + iDF = PD.DataFrame([[args.time, 0.0, args.annot, 0, args.time*args.sampfreq]],columns=event_DF.columns) + event_DF = PD.concat((event_DF,iDF)) + makeflag = True + + # Write out the results + if makeflag: + event_DF.to_csv(event_path,sep='\t',index=False) \ No newline at end of file diff --git a/scripts/codehub/utils/association/data_conversion/README.md b/scripts/codehub/utils/association/data_conversion/README.md new file mode 100644 index 00000000..7a341f5c --- /dev/null +++ b/scripts/codehub/utils/association/data_conversion/README.md @@ -0,0 +1,7 @@ +# File converters + +A few important tools for converting between file formats. + +## Lay to EDF + +Convert a Persyst .lay format dataset into an EDF file. diff --git a/scripts/codehub/utils/association/data_conversion/lay_to_edf/README.md b/scripts/codehub/utils/association/data_conversion/lay_to_edf/README.md new file mode 100644 index 00000000..c094a601 --- /dev/null +++ b/scripts/codehub/utils/association/data_conversion/lay_to_edf/README.md @@ -0,0 +1,3 @@ +# Lay to EDF + +Convert a Persyst .lay format dataset into an EDF file. \ No newline at end of file diff --git a/scripts/codehub/utils/association/data_conversion/lay_to_edf/convert_lay_edf.py b/scripts/codehub/utils/association/data_conversion/lay_to_edf/convert_lay_edf.py index db36708e..44684e4a 100644 --- a/scripts/codehub/utils/association/data_conversion/lay_to_edf/convert_lay_edf.py +++ b/scripts/codehub/utils/association/data_conversion/lay_to_edf/convert_lay_edf.py @@ -7,6 +7,12 @@ from mne_bids import BIDSPath, write_raw_bids def DateException(inpath): + """ + LAY data can have a float type for its date, which is not allowed by MNE's Persyst reader (read_raw_persyst), so we need to clean it up and remove the float. + + Args: + inpath (filepath,string): Path to the lay file. + """ # Read in the lay file DF = PD.read_csv(inpath,delimiter='=',names=['key','value']) diff --git a/scripts/codehub/utils/association/data_merge/README.md b/scripts/codehub/utils/association/data_merge/README.md deleted file mode 100644 index 1d4b6c85..00000000 --- a/scripts/codehub/utils/association/data_merge/README.md +++ /dev/null @@ -1,14 +0,0 @@ -# Data Merging Scripts - -This folder contains scripts aimed at associating and merging different data sources. - -## merge_features.py -This script merges the pipeline outputs into one file with downcasting and mapping to reduce data volume.
- -### Example command - -``` -python merge_features.py --indir XX/YY/ZZ/ -``` - -where XX/YY/ZZ is the output directory path for the EPIPY pipeline. \ No newline at end of file diff --git a/scripts/codehub/utils/association/data_merge/merge_and_map.py b/scripts/codehub/utils/association/data_merge/merge_and_map.py deleted file mode 100644 index d6a53934..00000000 --- a/scripts/codehub/utils/association/data_merge/merge_and_map.py +++ /dev/null @@ -1,117 +0,0 @@ -import glob -import pickle -import argparse -import numpy as np -import pandas as PD - -def merge_data(searchpath,map_dict): - - # Get the columns to map - map_keys = list(map_dict.keys()) - - # Read in the files in order, and apply mapping, storing to the output - for idx,ifile in enumerate(searchpath): - - # Special condition for first time load, otherwise append - if idx == 0: - - # Read in data to final variable name - new_df = PD.read_pickle(ifile) - - # Apply any needed mapping - if len(map_keys) > 0: - for ikey in map_keys: - new_df[ikey] = new_df[ikey].apply(lambda x: map_dict[ikey][x]) - else: - # Read in data to temporary namespace - idf = PD.read_pickle(ifile) - - # Apply any needed mapping - if len(map_keys) > 0: - for ikey in map_keys: - idf[ikey] = idf[ikey].apply(lambda x: map_dict[ikey][x]) - - # Append results to final variable name - new_df = PD.concat((new_df,idf)) - return new_df - -def create_map(searchpath,cols): - - # Make the data dictionary to store info for each column that needs to be mapped - data_dict = {} - for icol in cols: - data_dict[icol] = [] - - # Read in the files in order, and grab just the mapping columns to reduce memory usage - print("Generating a unique mapping for each column. This may take a while.") - for ifile in searchpath: - iDF = PD.read_pickle(ifile)[cols] - for icol in cols: - vals = list(iDF[icol].unique()) - data_dict[icol].extend(vals) - - # Create the mapping dictionary - map_dict = {} - for icol in cols: - - # Create the mapping dictionary - uvals = np.unique(data_dict[icol]) - newvals = np.arange(uvals.size) - udict = dict(zip(uvals.ravel(),newvals.ravel())) - - # Save the mapping - map_dict[icol] = udict - - return map_dict - - -def parse_list(input_str): - """ - Helper function to allow list inputs to argparse using a space or comma - - Args: - input_str (str): User's input string - - Returns: - list: Input argument list as python list - - """ - - # Split the input using either spaces or commas as separators - values = input_str.replace(',', ' ').split() - return [value for value in values] - -if __name__ == '__main__': - """ - Create a mapping file from EPIPY output for given columns, downcast to mapped variables, and merge the files together. - Reduces the in-memory size of the dataframe. - """ - - # Argument parsing - parser = argparse.ArgumentParser(description="Simplified data merging tool.") - parser.add_argument("--searchpath", type=str, required=True, help='Search path for files to map and downcast. Wildcard enabled.') - parser.add_argument("--cols", type=parse_list, help="Comma separated list of columns to map and downcast.") - parser.add_argument("--outfile_data", default="merged_data.pickle", type=str, help='Output filename for the new merged and downcast data') - parser.add_argument("--outfile_map", default="merged_map.pickle", type=str, help='Output filename for the mapped data column dictionaries') - parser.add_argument("--mapping_file", default=None, type=str, help="Optional filepath to an existing mapping file.
Useful if you are doing sensitivity analysis or reprocessing the same dataset.") - args = parser.parse_args() - - # Read in the searchpath and create a filelist - filelist = glob.glob(args.searchpath) - - # Get the list of columns to map, if any - if args.mapping_file == None: - if len(args.cols) > 0: - map_dict = create_map(filelist,args.cols) - pickle.dump(map_dict,open(args.outfile_map,"wb")) - else: - map_dict = {} - else: - print("Using existing mapping file.") - map_dict = pickle.load(open(args.mapping_file,"rb")) - - # Merge the files - out_DF = merge_data(filelist,map_dict) - - # Save the results - out_DF.to_pickle(args.outfile_data) - diff --git a/scripts/codehub/utils/association/data_merge/merge_features.py b/scripts/codehub/utils/association/data_merge/merge_features.py deleted file mode 100644 index 029d3318..00000000 --- a/scripts/codehub/utils/association/data_merge/merge_features.py +++ /dev/null @@ -1,90 +0,0 @@ -import yaml -import glob -import pickle -import argparse -import pandas as PD -from sys import argv,exit - - -if __name__ == '__main__': - """ - Creates a merged output for the EPIPY feature dataframes. - - DEPRECIATE: 07/08/24: New format from EPIPY makes this defunct. - """ - - # Argument parsing - parser = argparse.ArgumentParser(description="Simplified data merging tool.") - parser.add_argument("--indir", type=str, help='Input directory') - parser.add_argument("--col_config", type=str, help="Optional path to a yaml with drop_col and obj_col definitions.") - parser.add_argument("--outfile_model", default="merged_model.pickle", type=str, help='Output filename for model data') - parser.add_argument("--outfile_meta", default="merged_meta.pickle", type=str, help='Output filename for metadata') - parser.add_argument("--outfile_map", default="merged_map.pickle", type=str, help='Output filename for any mapped data column dictionaries') - args = parser.parse_args() - - # get the files to merge - files = glob.glob(f"{args.indir}*feature*.pickle") - files = [ifile for ifile in files if ifile != args.outfile_model and ifile != args.outfile_meta] - - if len(files) > 0: - - # Object columns - if args.col_config == None: - drop_cols = ['t_end','method'] - obj_cols = ['dt','annotation'] - map_cols = ['file','uid','tag'] - else: - col_info = yaml.safe_load(open(args.col_config,'r')) - for key, inner_dict in col_info.items(): - globals()[key] = inner_dict - - # Loop over the files and save the outputs - meta_obj = [] - model_obj = [] - for ifile in files: - - # Read in data and clean up - print(f"Working on {ifile}.") - iDF = PD.read_pickle(ifile) - iDF['tag'] = iDF['method']+'_'+iDF['tag'] - iDF = iDF.drop(drop_cols,axis=1) - - # Get the model columns - model_cols = [icol for icol in iDF.columns if icol not in obj_cols] - - # Store results in serialized object - meta_obj.append(iDF[obj_cols]) - model_obj.append(iDF[model_cols]) - - # Meta generation - print("Making the meta file") - iDF = PD.concat(meta_obj) - pickle.dump(iDF,open(f"{args.indir}{args.outfile_meta}","wb")) - - # Make the cleaned up model view - output_dict = {} - iDF = PD.concat(model_obj) - for imap in map_cols: - iDF[imap], output_dict[imap] = PD.factorize(iDF[imap]) - if 'file' in iDF.columns: - iDF['file'], file_mapping_dict = PD.factorize(iDF['file']) - - # Final downcasting attempt - for icol in iDF: - itype = iDF[icol].dtype - try: - iDF[icol] = PD.to_numeric(iDF[icol],downcast='integer') - if iDF[icol].dtype == itype: - iDF[icol] = PD.to_numeric(iDF[icol],downcast='float') - except ValueError: - pass - 
- # Make the mapping dictionary - if 'file' in iDF.columns: - output_dict['file'] = file_mapping_dict - - print("Making the model file") - pickle.dump(iDF,open(f"{args.indir}{args.outfile_model}","wb")) - pickle.dump(output_dict,open(f"{args.indir}{args.outfile_map}","wb")) - else: - print("No files found.") \ No newline at end of file diff --git a/scripts/codehub/utils/association/edf_merge/edf_merge.py b/scripts/codehub/utils/association/edf_merge/edf_merge.py new file mode 100644 index 00000000..c045c0cf --- /dev/null +++ b/scripts/codehub/utils/association/edf_merge/edf_merge.py @@ -0,0 +1,57 @@ +import argparse +import numpy as np +import pandas as PD +from tqdm import tqdm +from mne.io import read_raw_edf +from mne import concatenate_raws +from mne.export import export_raw +from sys import exit + +if __name__ == '__main__': + + # Command line options needed to obtain data. + parser = argparse.ArgumentParser(description="Merge EDF files together given a manifest document.") + + data_group = parser.add_argument_group('Data configuration options') + data_group.add_argument("--outdir", type=str, required=True, default=None, help="Output directory to store merged files.") + data_group.add_argument("--manifest", type=str, required=True, help="Filepath to the manifest document.") + data_group.add_argument("--blocksize", type=int, default=-1, help="Number of files to combine together. -1 means merge all.") + args = parser.parse_args() + + # Filepath cleanup + if args.outdir[-1] != '/':args.outdir+='/' + + # Load the manifest + manifest_DF = PD.read_csv(args.manifest) + + # Get the filepaths as an array + filepaths = manifest_DF.filepath.values + + # Get the fileblocks + if args.blocksize != -1: + + # Get the total number of blocks with remainder + number_blocks = filepaths.size/args.blocksize + + # Get the number of blocks with the exact right sizing + nitr = int(np.floor(number_blocks)) + fileblocks = [] + for itr in range(nitr): + fileblocks.append(filepaths[itr*args.blocksize:(itr+1)*args.blocksize]) + + # Add the remainder if needed + if number_blocks > nitr: + fileblocks.append(filepaths[args.blocksize*nitr:]) + else: + fileblocks = [filepaths] + + # Begin the merging + for idx,iblock in enumerate(fileblocks): + outraw = read_raw_edf(iblock[0]) + for ifile in tqdm(iblock[1:], total=len(iblock)-1, desc=f"Merging Block {idx:02d}"): + newraw = read_raw_edf(ifile) + outraw = concatenate_raws([outraw,newraw],on_mismatch='ignore') + try: + export_raw(f"{args.outdir}merged_{idx:03d}.edf",outraw,fmt='edf') + except ValueError: + export_raw(f"{args.outdir}merged_{idx:03d}.edf",outraw,fmt='edf',physical_range=(0,1)) \ No newline at end of file diff --git a/scripts/codehub/utils/association/enrich_targets/README.md b/scripts/codehub/utils/association/enrich_targets/README.md deleted file mode 100644 index 9a72e237..00000000 --- a/scripts/codehub/utils/association/enrich_targets/README.md +++ /dev/null @@ -1,7 +0,0 @@ -# Enrich Targets - -This folder contains scripts aimed at associating user inputed information sources (cli/file/etc.) and adding it to exisitng target files. - -## enrich_targets.py - -This script takes user provided information and appending it to a target file. 
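As a rough usage sketch for enrich_targets.py (the file names below are hypothetical, not taken from the repository): in its keypair mode the script reads a single comma-separated `key,value` line from the enrichment file and merges it into the pickled target dictionary named by --target_file.

```
# sleep_stage.txt (hypothetical enrichment file), a single key,value pair:
#   sleep_stage,N2
python enrich_targets.py --target_file sub-001_targets.pickle --enrichment_file sleep_stage.txt --enrichment_type keypair
```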
diff --git a/scripts/codehub/utils/association/enrich_targets/enrich_targets.py b/scripts/codehub/utils/association/enrich_targets/enrich_targets.py deleted file mode 100644 index f53d41dc..00000000 --- a/scripts/codehub/utils/association/enrich_targets/enrich_targets.py +++ /dev/null @@ -1,131 +0,0 @@ -import os -import pickle -import argparse -import numpy as np -import pandas as PD -from sys import exit - -class CustomFormatter(argparse.HelpFormatter): - """ - Custom formatting class to get a better argument parser help output. - """ - - def _split_lines(self, text, width): - if text.startswith("R|"): - return text[2:].splitlines() - return super()._split_lines(text, width) - -class data_reader: - - def __init__(self,infile): - - # Create class-wide variables - self.infile = infile - - def enrichment_keypair(self): - - # Read in the key-pair - fp = open(self.infile,'r') - data = fp.readline() - fp.close() - - # Clean up the string and break it up - data = data.replace('\n','') - data_array = data.split(',') - - # Make the new output target dict - output = {} - output[data_array[0]] = data_array[1] - - return output - - def TUEG_dt(self): - - # Read in the tsv using pandas so we can just skip rows and assign column headers - DF = PD.read_csv(self.infile,skiprows=2,delimiter=' ', names=['t0','t1','tag','prob']) - - # Create an output dictionary that will be merged with the current targets - output = {} - output['TUEG_dt_t0'] = '_'.join([f"{ival:.1f}" for ival in DF['t0'].values]) - output['TUEG_dt_t1'] = '_'.join([f"{ival:.1f}" for ival in DF['t1'].values]) - output['TUEG_dt_tag'] = '_'.join(DF['tag'].values) - - return output - -def make_help_str(idict): - """ - Make a well-formated help string for the possible keyword mappings - - Args: - idict (dict): Dictionary containing the allowed keywords values and their explanation. - - Returns: - str: Formatted help string - """ - - return "\n".join([f"{key:15}: {value}" for key, value in idict.items()]) - -if __name__ == '__main__': - - # Define the allowed enrichment information - allowed_enrichment_types = {} - allowed_enrichment_types['TUEG_TSV_dt'] = 'Read in a TUEG tsv file and assign the target variable to a time window.' - allowed_enrichment_types['keypair'] = 'Add a keypair to the targets' - allowed_enrichment_help = make_help_str(allowed_enrichment_types) - - # Command line options needed to obtain data. - parser = argparse.ArgumentParser(description="iEEG to bids conversion tool.", formatter_class=CustomFormatter) - parser.add_argument("--target_file", type=str, help="Path to target file to enrich.") - parser.add_argument("--enrichment_type", type=str, choices=list(allowed_enrichment_types.keys()), default="TUEG_TSV_dt", help=f"R|Choose an option:\n{allowed_enrichment_help}") - - input_group = parser.add_mutually_exclusive_group() - input_group.add_argument("--enrichment_map", type=str, help="Csv file with enrichment info. 
Columns:[path_to_datafile_that_enriches_target,path_to_target_file,enrichment_type].") - input_group.add_argument("--enrichment_file", type=str, help="Path to datafile that will enrich target.") - args = parser.parse_args() - - # Check which type of input format we are working with, and create relevant work list - if args.enrichment_map != None and args.enrichment_file == None: - enrichment_df = PD.read_csv(args.enrichment_map) - enrichment_files = enrichment_df['enrichment_files'].values - target_files = enrichment_df['target_files'].values - enrichment_types = enrichment_df['enrichment_types'].values - elif args.enrichment_map == None and args.enrichment_file != None: - # Store the enrichment data - enrichment_files = [args.enrichment_file] - - # Check for target file data and store if provided, or warn user - if args.target_file != None: - target_files = [args.target_file] - else: - raise FileNotFoundError("Please provide a target file to enrich using the --target_file keyword.") - - # Check for enrichment type and store if provided, or warn user - if args.target_file != None: - enrichment_types = [args.enrichment_type] - else: - raise NameError("Please provide an enrichment type using the --enrichment_type flag. Using --help will show allowed enrichment types.") - - else: - raise FileNotFoundError("Please provide an enrichment file or an enrichment map file.") - - # Loop over all of the target files to enrich - for idx,ifile in enumerate(target_files): - - # Confirm if the target file exists, read in or create as needed - if os.path.exists(ifile): - targets = pickle.load(open(ifile,"rb")) - else: - targets = {} - - # Initialize the data reading class - DR = data_reader(enrichment_files[idx]) - - # Apply the right enrichment logic to get out data - if args.enrichment_type == 'TUEG_TSV_dt': - additional_targets = DR.TUEG_dt() - elif args.enrichment_type == 'keypair': - additional_targets = DR.enrichment_keypair() - new_targets = {**targets,**additional_targets} - - # Save the new target file - pickle.dump(new_targets,open(ifile,"wb")) diff --git a/scripts/codehub/utils/association/misc/README.md b/scripts/codehub/utils/association/misc/README.md deleted file mode 100644 index f5e9bb9a..00000000 --- a/scripts/codehub/utils/association/misc/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# Misc - -This folder contains scripts aimed at associating different data sources together. Most scripts in this folder are designed to help with staff tasks in the CNT. 
diff --git a/scripts/codehub/utils/association/misc/builddcm_python3.py b/scripts/codehub/utils/association/misc/builddcm_python3.py deleted file mode 100644 index 948bd2a2..00000000 --- a/scripts/codehub/utils/association/misc/builddcm_python3.py +++ /dev/null @@ -1,199 +0,0 @@ -#!/usr/bin/env python -import sys,os -import argparse -import numpy as np -from pptx import Presentation - -zeropadjoin=lambda a_b:'%s%02d'%(a_b[0],int(a_b[1])) if all([a_b[0],a_b[1]]) else None -def build_lines(myid,electrodes,gnd,ref): - header_lines=['# %s channel mapping'%myid,''] - fname='%s_channelMapping.txt'%myid - - #build and sort the lines - idx_gnd,idx_ref,last,empties,elec_lines=None,None,None,[],[] - for idx,name in enumerate(electrodes): - if name=='':empties.append(idx) - elif gnd in [name]:idx_gnd=idx - elif ref in [name]:idx_ref=idx - else: - last=idx - elec_lines.append('%d %s'%(idx,name)) - - empties=[idx for idx in empties if idx 0: - if rowcnt < nrows: - channels = output[:ncol] - output = output[ncol:] - leads = output[:ncol] - output = output[ncol:] - rowcnt += 1 - else: - output.pop(ntail) - channels = output[:ntail] - output = output[ntail:] - leads = output[:ntail] - output = [] - for idx,ichannel in enumerate(channels): - chmap[ichannel] = leads[idx] - electrodes.append(f"{ichannel.replace(" ", "")}-{leads[idx]}") - - return electrodes - -def main(myid=None, out=None, gnd=None, ref=None, ppt_file=None, nlead=None): - - # Obtain patient id if calld as a function. - if not myid: - myid = str(input("patient ID: ")).strip() - myid = myid.upper() - - # Call electrode naming functions - if ppt_file == None: - electrodes = manual_entry(myid, out, gnd, ref) - else: - electrodes = ppt_read(ppt_file,nlead) - - # Output logic - if electrodes: - # get ground and ref - if not all([gnd,ref]): - gnd,ref=[input(prompt) for prompt in ['ground? ','reference? ']] - gnd,ref=[zeropadjoin(split_elec(e)) for e in [gnd,ref]] - # write file - fname,lines=build_lines(myid,electrodes,gnd,ref) - if out != None: - with open(os.path.join(os.path.abspath(out),fname),'w') as f: - f.write('\n'.join(lines)) - else: - with open(os.path.join(os.getcwd(),fname),'w') as f: - f.write('\n'.join(lines)) - -if __name__=='__main__': - - """ - Unattributed creation. - - Minor refactorization, comments, and clean up, Brian Prager. 12/15/2023. - """ - - # Argument parser - parser = argparse.ArgumentParser() - parser.add_argument("id", type=str, help="id", default=None) - parser.add_argument("--out", dest="out", type=str, help="output", default=None) - parser.add_argument("--gnd", dest="gnd", type=str, help="ground", default=None) - parser.add_argument("--ref", dest="ref", type=str, help="reference", default=None) - parser.add_argument("--ppt_file", dest="ppt_file", type=str, help="Path to powerpoint file. If provided, automatically assign channels.", default=None) - parser.add_argument("--nlead", dest="nlead", type=int, help="Number of leads in the powerpoint file. 
Needed for auto assigning.", default=None) - args = parser.parse_args() - - # Call the main function - main(myid=args.id, out=args.out, gnd=args.gnd, ref=args.ref, ppt_file=args.ppt_file, nlead=args.nlead) diff --git a/scripts/codehub/utils/association/misc/ccep_leads_from_pptx.py b/scripts/codehub/utils/association/misc/ccep_leads_from_pptx.py deleted file mode 100644 index d42a7482..00000000 --- a/scripts/codehub/utils/association/misc/ccep_leads_from_pptx.py +++ /dev/null @@ -1,73 +0,0 @@ -import numpy as np -from sys import argv,exit -from pptx import Presentation - -if __name__ == '__main__': - - # Read in the file provided by the user - prs = Presentation(argv[1]) - nlead = int(argv[2]) - - # Get the slides in memory - slides_list = [] - for slide in prs.slides: - slides_list.append(slide) - - # Calculate the number of rows that are complete and the remainder - ncol = 16 - nrows = np.floor(nlead/ncol).astype('int') - ntail = nlead%ncol - rowcnt = 0 - - # Parse the channel pages, save to dict. (Leads are supposed to always be the last four pages) - chmap = {} - for slide_cnt,slide in enumerate(slides_list[-4:]): - - # Read in the table data - output = [] - for shape in slide.shapes: - values = [] - if hasattr(shape, "table"): - cells = shape.table.iter_cells() - for icell in cells: - values.append(icell.text) - - # Clean up the ocassional character array (versus strings) - flag = True - for idx,ivalue in enumerate(values): - if ivalue == '': - if flag: - output.append(ivalue) - flag = False - else: - output.append(ivalue) - flag = True - - # Attempt to clean up the grnd and ref cells - try: - output.pop(0) - if slide_cnt in [0,2]: output.pop(ncol) - output.pop(2*ncol) - output.pop(4*ncol) - output.pop(6*ncol) - if slide_cnt in [0,2]: output.pop(7*ncol) - except IndexError: - pass - - # Iterate over number of rows on this page, checking against maximum number - while len(output) > 0: - if rowcnt < nrows: - channels = output[:ncol] - output = output[ncol:] - leads = output[:ncol] - output = output[ncol:] - rowcnt += 1 - else: - output.pop(ntail) - channels = output[:ntail] - output = output[ntail:] - leads = output[:ntail] - output = [] - for idx,ichannel in enumerate(channels): - chmap[ichannel] = leads[idx] - diff --git a/scripts/codehub/utils/association/visualize_targets/README.md b/scripts/codehub/utils/association/visualize_targets/README.md deleted file mode 100644 index cb858162..00000000 --- a/scripts/codehub/utils/association/visualize_targets/README.md +++ /dev/null @@ -1,7 +0,0 @@ -# Visualize Targets - -This folder contains scripts aimed at associating target data with different visualizations. - -## target_to_timeseries.py - -This script creates a plotting environment that shows which data points are associated with target information. (i.e. Marked Sleep and Awake, slowing, etc.) 
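The association works by simple keyword matching on the annotation strings; a minimal, self-contained sketch of that logic is below (the keyword lists mirror the script's get_state method, while the annotation strings are made-up examples).

```
# Minimal sketch of the sleep/awake keyword masking used by target_to_timeseries.py.
# The annotation strings below are made-up examples.
import numpy as np
import pandas as PD

annots = PD.Series(["Awake, PDR present", "Sleep spindle", None, "SWS run"])
awake_kw = ("wake", "awake", "pdr")
sleep_kw = ("sleep", "spindle", "k complex", "sws")

# Build boolean masks: a row counts as awake/sleep if any keyword appears in its annotation
awake = np.array([any(k in a.lower() for k in awake_kw) if isinstance(a, str) else False for a in annots])
sleep = np.array([any(k in a.lower() for k in sleep_kw) if isinstance(a, str) else False for a in annots])
print(awake, sleep)  # masks later used to slice the feature dataframe into awake/sleep subsets
```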
diff --git a/scripts/codehub/utils/association/visualize_targets/target_to_timeseries.py b/scripts/codehub/utils/association/visualize_targets/target_to_timeseries.py deleted file mode 100644 index ec4228f1..00000000 --- a/scripts/codehub/utils/association/visualize_targets/target_to_timeseries.py +++ /dev/null @@ -1,225 +0,0 @@ -import sys -import pickle -import argparse -import itertools -import numpy as np -import pandas as PD -from glob import glob -from tqdm import tqdm -from scipy.signal import find_peaks - -class sleep_state_power: - - def __init__(self): - pass - - def peaks(self, vals, prominence=1, width=3, height=None): - - if height == None: - height = 0.1*max(vals) - - return find_peaks(vals, prominence=prominence, width=width, height=height) - - def histogram_data(self,values): - - # Make a better plotting baseline - self.logbins = np.logspace(7,12,50) - self.log_x = (0.5*(self.logbins[1:]+self.logbins[:-1])) - - # Get the histogram counts - cnts = np.histogram(values,bins=self.logbins)[0] - - # Get the peak information - peaks, properties = self.peaks(cnts) - - # Convert peaks into the right units for plotting - peaks_x = self.log_x[peaks] - peaks_y = cnts[peaks] - - properties["left_ips"] = [self.log_x[int(np.floor(ival))] for ival in properties["left_ips"]] - properties["right_ips"] = [self.log_x[int(np.ceil(ival))] for ival in properties["right_ips"]] - - return cnts,peaks_x,peaks_y,properties - - def get_state(self): - """ - Parse the annotations for sleep state - """ - - # Get list of annotations to parse - annots = self.rawdata.annotation.values - uannot = self.rawdata.annotation.unique() - - # Create sleep awake masks - sleep = np.zeros(annots.size) - awake = sleep.copy() - - # Loop over annotations - for iannot in uannot: - if iannot != None: - ann = iannot.lower() - if 'wake' in ann or 'awake' in ann or 'pdr' in ann: - inds = (annots==iannot) - awake[inds]=1 - if 'sleep' in ann or 'spindle' in ann or 'k complex' in ann or 'sws' in ann: - inds = (annots==iannot) - sleep[inds]=1 - - # Use sleep awake masks to get data splits - try: - self.sleep_list.append(self.rawdata.iloc[sleep.astype('bool')]) - except AttributeError: - self.sleep_list = [self.rawdata.iloc[sleep.astype('bool')]] - - try: - self.awake_list.append(self.rawdata.iloc[awake.astype('bool')]) - except AttributeError: - self.awake_list = [self.rawdata.iloc[awake.astype('bool')]] - - def model_compile(self): - - # Merge the datasets together - if len(self.awake_list) == 1: - self.awake_df = self.awake_list[0] - self.sleep_df = self.sleep_list[0] - else: - self.awake_df = PD.concat(self.awake_list) - self.sleep_df = PD.concat(self.sleep_list) - - # Get the unique patient ids - self.uids = self.rawdata['uid'].unique() - - # Create the output object - id_cols = ['file','t_start','t_end'] - self.output = {} - - # Get the histogram data for awake and asleep in alpha and delta - alpha_delta_tags = ['[8.0,12.0]','[1.0,4.0]'] - awake_sleep_tags = ['awake','sleep'] - tag_combinations = list(itertools.product(awake_sleep_tags,alpha_delta_tags)) - for itag in tag_combinations: - print("Working on %s data in the %s band." 
%(itag[0],itag[1])) - - # Make the data cuts - if itag[0] == 'awake': - iDF = self.awake_df.loc[(self.awake_df['tag']==itag[1])] - elif itag[1] == 'sleep': - iDF = self.sleep_df.loc[(self.sleep_df['tag']==itag[1])] - - # Create the outputs for this combo - self.output[itag] = iDF[id_cols].copy().reset_index(drop=True) - for ichan in self.channels: - self.output[itag][ichan] = -1 - file_ref = self.output[itag]['file'].values - tstart_ref = self.output[itag]['t_start'].values - range_ref = np.arange(file_ref.size) - - # Create a numpy array to reference for searcing (faster than pandas dataframe lookup) - lookup_array = iDF[id_cols].values - - for ichannel in tqdm(self.channels, desc='Channel searches:', total=len(self.channels), leave=False): - values = iDF[ichannel].values.astype('float') - outvals = self.output[itag][ichannel].values - cnts,peaks_x,peaks_y,properties = self.histogram_data(values) - - # Loop over the peaks to associate it with original dataframe - for idx,ipeak in enumerate(peaks_x): - - # get the boundaries - lo_bound = properties['left_ips'][idx] - hi_bound = properties['right_ips'][idx] - - # Get the indices in bounds - jinds = (values>=lo_bound)&(values<=hi_bound) - jarr = lookup_array[jinds] - - # Loop over the results (yes again ;_;) to populate the reference dict ) - for irow in jarr: - #inds = (file_ref==irow[0])&(tstart_ref==irow[1]) - finds = np.zeros(file_ref.size).astype('bool') - inds = (tstart_ref==irow[1]) - finds_numeric = [i for i in range_ref[inds] if file_ref[i] == irow[0]] - finds[finds_numeric] = True - outvals[inds&finds] = ipeak - self.output[itag][ichannel] = outvals - pickle.dump(self.output,open(self.args.outfile,"wb")) - -class data_manager(sleep_state_power): - - def __init__(self,args): - self.args = args - self.output_list = [] - - def load_data(self,infile): - - # Read in and clean up the data a bit - self.rawdata = PD.read_pickle(infile) - files = self.rawdata['file'] - files = [ifile.split('/')[-1] for ifile in files] - self.rawdata['file'] = files - - # Get the relevant channels - black_list = ['file','t_start','t_end','dt','method','tag','uid','target','annotation'] - self.channels = [] - for icol in self.rawdata.columns: - if icol not in black_list: - self.channels.append(icol) - - def model_handler(self): - - if self.args.sleep_awake_power: - sleep_state_power.get_state(self) - - def model_compile(self): - if self.args.sleep_awake_power: - sleep_state_power.model_compile(self) - - -def parse_list(input_str): - """ - Helper function to allow list inputs to argparse using a space or comma - - Args: - input_str (str): Users inputted string - - Returns: - list: Input argument list as python list - """ - - # Split the input using either spaces or commas as separators - values = input_str.replace(',', ' ').split() - try: - return [int(value) for value in values] - except: - return [str(value) for value in values] - -if __name__ == '__main__': - - # Command line options needed to obtain data. 
- parser = argparse.ArgumentParser(description="Simplified data merging tool.") - - input_group = parser.add_mutually_exclusive_group() - input_group.add_argument("--file", type=str, help="Input pickle file to read in.") - input_group.add_argument("--wildcard", type=str, help="Wildcard enabled path to pickle files to read in.") - - datachunk_group = parser.add_argument_group('Data Chunking Options') - datachunk_group.add_argument("--group_cols", required=True, type=parse_list, help="List of columns to group by.") - - model_group = parser.add_argument_group('Type of models to associate with timeseries.') - model_group.add_argument("--sleep_awake_power", default=True, action='store_true', help="List of columns to group by.") - model_group.add_argument("--outfile", required=True, type=str, help="Output file path.") - args = parser.parse_args() - - # Create the file list to read in - if args.file != None: - files = [args.file] - else: - files = glob(args.wildcard) - - # Iterate over the data and create the relevant plots - DM = data_manager(args) - for ifile in files: - DM.load_data(ifile) - DM.model_handler() - DM.model_compile() - diff --git a/scripts/codehub/utils/interfaces/EPIPY.py b/scripts/codehub/utils/interfaces/EPIPY_GUI.py similarity index 99% rename from scripts/codehub/utils/interfaces/EPIPY.py rename to scripts/codehub/utils/interfaces/EPIPY_GUI.py index 60f00aed..e80dc099 100644 --- a/scripts/codehub/utils/interfaces/EPIPY.py +++ b/scripts/codehub/utils/interfaces/EPIPY_GUI.py @@ -2,7 +2,7 @@ import dearpygui.dearpygui as dpg # Local imports -import pipeline_manager as PM +import epipy as PM # Interface imports from EPIPY_modules.theme import applyTheme diff --git a/scripts/codehub/utils/interfaces/README.md b/scripts/codehub/utils/interfaces/README.md new file mode 100644 index 00000000..72cfc4eb --- /dev/null +++ b/scripts/codehub/utils/interfaces/README.md @@ -0,0 +1,14 @@ +# EPIPY GUI + +This script helps construct the commands needed to run the pipeline manager for EPIPY. EPIPY is designed to allow rapid deployment of tested research code in any order that the user desires. It also takes new research code, putting it into a tested pipeline. The intention is to remove redundant steps in analyzing data that can introduce errors, and allow for better data versioning. + +## Installation +Please consult the [readme file for the whole repository](https://github.com/penn-cnt/CNT-codehub) for more information. This is a wrapper script to the main pipeline code for the lab and requires a full installation to work properly. + +## Usage + +You can run this code with the following command: + +``` +python EPIPY_GUI.py +``` \ No newline at end of file diff --git a/scripts/codehub/utils/visualization/edf_viewer/README.md b/scripts/codehub/utils/visualization/edf_viewer/README.md index da946c29..6872c2f5 100644 --- a/scripts/codehub/utils/visualization/edf_viewer/README.md +++ b/scripts/codehub/utils/visualization/edf_viewer/README.md @@ -6,6 +6,7 @@ Toolkit for visualizing all of the channel data for an EDF file. This uses the p To use this toolkit, we highly recommend you create a python environment. This protects your base python environment from running into conflicts or versioning issues. We describe how to install the CNT environment below. +### Conda 1. Clone this repository to your local workstation. 2. Install Anaconda - Please visit the [Anaconda Downloads](https://www.anaconda.com/download) page to download the appropriate Anaconda installer for your operating system. 
@@ -21,7 +22,28 @@ To use this toolkit, we highly recommend you create a python environment. This p - The environment is set to default to the name `cnt_codehub`. (You can change this by modifying the `name` entry in your local copy of the yaml file. If you change this, you would run the above command on the new name.) 5. Finally, all you need to do is add your new code repository to your anaconda path. Run the following command in the terminal or powershell - conda develop \/scripts/codehub - + +### Venv and pip + +1. To create a virtual environment, you need to create a location for the environment to install to. For this example, we will specify `/demonstration/environment/cnt_codehub` as our environment location. Using the Python version of your choice (in this example we will select 3.10), run the following command: + + > python3.10 -m venv /demonstration/environment/cnt_codehub +2. To enter the environment, simply run: + + > source /demonstration/environment/cnt_codehub/bin/activate +3. Once in the environment, a requirements.txt file with all the needed packages to run this suite of code can be found at the following location: + + > [CNT Codehub requirements file](core_libraries/python/cnt_codehub/envs/requirements.txt) + + This file can be installed using the following call to pip from the envs subdirectory: + + > pip install -r requirements.txt + + which will install everything to your current virtual environment. +4. Add the codehub to your virtual environment path. For a virtual environment, an easy way to add `/scripts/codehub/` to your path would be to add a text file with a .pth extension (any filename is fine) to the site-packages subfolder in your virtual environment folder. Within the text file you can just copy and paste the absolute path as the only contents. + + Typically, the path to your site-packages can be found at: `/lib/python/site-packages`. + ## Sample Commands Please consult the `--help` flag for more detailed information on different inputs to the viewer. We provide a few common uses below.
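As a quick aside before the sample commands: step 4 of the venv setup above can also be done programmatically from inside the activated environment. The sketch below is one way to write the `.pth` file; the clone location `/demonstration/CNT-codehub` is an assumed example path, and `site.getsitepackages()` is used on the assumption that the interpreter is the venv's own Python.

```python
# Drop a .pth file into the active environment's site-packages so that
# scripts/codehub becomes importable. The clone path below is an example.
import site
from pathlib import Path

codehub = "/demonstration/CNT-codehub/scripts/codehub"  # assumed clone location
site_packages = Path(site.getsitepackages()[-1])        # venv's site-packages directory
(site_packages / "cnt_codehub.pth").write_text(codehub + "\n")
print(f"Wrote {site_packages / 'cnt_codehub.pth'} -> {codehub}")
```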
@@ -29,33 +51,25 @@ Please consult the `--help` flag for more detailed information on different inpu ### Random start time via seed (default behavior) with for all wildcard data matches ``` -python utils/visualization/edf_viewer/edf_viewer.py --wildcard "../../../scalp_deep-learning/user_data/BIDS/BIDS/sub-0008/ses-preimplant01/eeg/sub-0008_ses-preimplant01_task-task_run-*_eeg.edf" --username bjprager --flagging +python utils/visualization/edf_viewer/edf_viewer.py --wildcard "../../../scalp_deep-learning/user_data/BIDS/BIDS/sub-0008/ses-preimplant01/eeg/sub-0008_ses-preimplant01_task-task_run-*_eeg.edf" ``` ### Random start time via seed (default behavior) with for all wildcard data matches with a common average montage (default=bipolar) ``` -python utils/visualization/edf_viewer/edf_viewer.py --wildcard "../../../scalp_deep-learning/user_data/BIDS/BIDS/sub-0008/ses-preimplant01/eeg/sub-0008_ses-preimplant01_task-task_run-*_eeg.edf" --username bjprager --flagging --montage common_average +python utils/visualization/edf_viewer/edf_viewer.py --wildcard "../../../scalp_deep-learning/user_data/BIDS/BIDS/sub-0008/ses-preimplant01/eeg/sub-0008_ses-preimplant01_task-task_run-*_eeg.edf" --montage common_average ``` ### Random start time via seed (default behavior) with for all wildcard data matches without any montage ``` -python utils/visualization/edf_viewer/edf_viewer.py --wildcard "../../../scalp_deep-learning/user_data/BIDS/BIDS/sub-0008/ses-preimplant01/eeg/sub-0008_ses-preimplant01_task-task_run-*_eeg.edf" --username bjprager --flagging --montage None +python utils/visualization/edf_viewer/edf_viewer.py --wildcard "../../../scalp_deep-learning/user_data/BIDS/BIDS/sub-0008/ses-preimplant01/eeg/sub-0008_ses-preimplant01_task-task_run-*_eeg.edf" --montage None ``` -### Random start time via seed (default behavior) with flagging for all wildcard data matches -``` -python utils/visualization/edf_viewer/edf_viewer.py --wildcard "../../../scalp_deep-learning/user_data/BIDS/BIDS/sub-0008/ses-preimplant01/eeg/sub-0008_ses-preimplant01_task-task_run-*_eeg.edf" --username bjprager --flagging -``` -flagging enables an interactive mode where the user can denote if certain events occur within the observed time window and save the results to a csv file. - -By default the code outputs to `./edf_viewer_flags.csv` but can be changed using the `--outfile` option at runtime. 
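The flagging CSV described above is superseded in this change by the annotation output written by `plot_handler.save_annotations` (columns `user`, `channel`, `time`, `annotation`; default path `./edf_viewer_annotations.csv` per the new `--outfile` default). A minimal sketch of reading those annotations back for review, assuming the default output location:

```python
# Load the annotations written by the viewer and show the latest entry per channel.
import pandas as pd

annots = pd.read_csv("./edf_viewer_annotations.csv")  # user, channel, time, annotation
latest = annots.sort_values("time").groupby("channel").tail(1)
print(latest[["user", "channel", "time", "annotation"]])
```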
- -### Set start time of t=0 with flagging for wildcard datamatches +### Set start time of t=0 for wildcard datamatches ``` -python utils/visualization/edf_viewer/edf_viewer.py --wildcard "../../../scalp_deep-learning/user_data/BIDS/BIDS/sub-0008/ses-preimplant01/eeg/sub-0008_ses-preimplant01_task-task_run-*_eeg.edf" --username bjprager --flagging --t0 0 +python utils/visualization/edf_viewer/edf_viewer.py --wildcard "../../../scalp_deep-learning/user_data/BIDS/BIDS/sub-0008/ses-preimplant01/eeg/sub-0008_ses-preimplant01_task-task_run-*_eeg.edf" --t0 0 ``` -### Set start time of t=0 and duration 15 with flagging for wildcard datamatches +### Set start time of t=0 and duration 15 for wildcard datamatches ``` python utils/visualization/edf_viewer/edf_viewer.py --wildcard "../../../scalp_deep-learning/user_data/BIDS/BIDS/sub-0008/ses-preimplant01/eeg/sub-0008_ses-preimplant01_task-task_run-*_eeg.edf" --username bjprager --flagging --t0 0 --dur 15 ``` diff --git a/scripts/codehub/utils/visualization/edf_viewer/components/internal/annotation_handler.py b/scripts/codehub/utils/visualization/edf_viewer/components/internal/annotation_handler.py new file mode 100644 index 00000000..050aa69c --- /dev/null +++ b/scripts/codehub/utils/visualization/edf_viewer/components/internal/annotation_handler.py @@ -0,0 +1,97 @@ +import os +import yaml +import tkinter as tk +from tkinter import messagebox + +class annotation_handler: + + def __init__(self,root,width=450,height=300): + + # Create the tkinter object + self.root = root + self.root.title("Annotation Widget") + self.root.geometry(f"{width}x{height}") + + # Read in the selection variables + script_dir = '/'.join(os.path.abspath(__file__).split('/')[:-3]) + self.annot_config = yaml.safe_load(open(f"{script_dir}/configs/annotation_config.yaml","r")) + + # Use the annotation config to make the inputs + self.selection_vars = {ikey:tk.StringVar() for ikey in self.annot_config.keys()} + + # Make an object to break us out of waiting for user input + self.submitted = False + + def workflow(self): + + # Create the row widgets from annotation config + for current_key in self.annot_config.keys(): + self.create_widgets(current_key) + + # Make the user input widget + self.create_user_input() + + # Add the submit and reset buttons + self.create_submit_reset() + + def reset_action(self): + for var in self.selection_vars.values(): + var.set('') + self.user_input_entry.delete(0, tk.END) + + def submit_action(self): + self.selections = {label: var.get() for label, var in self.selection_vars.items()} + self.selections['user'] = self.user_input_entry.get() + self.root.destroy() + self.submitted = True + + def create_widgets(self, current_key): + + row_frame = tk.Frame(self.root) + row_frame.pack(anchor=tk.W, pady=5) + + label = tk.Label(row_frame, text=current_key, width=15) + label.pack(side=tk.LEFT) + + # Create buttons + if self.annot_config[current_key]['type'] == 'radio': + for ivalue in self.annot_config[current_key]['options']: + tk.Radiobutton(row_frame, text=ivalue, variable=self.selection_vars[current_key], value=ivalue,indicatoron=False).pack(side=tk.LEFT, padx=5) + + def create_user_input(self): + + # Row for user provided text input (on the same line) + user_frame = tk.Frame(self.root) + user_frame.pack(anchor=tk.W, pady=5) + + label_user = tk.Label(user_frame, text="User Provided", width=15) + label_user.pack(side=tk.LEFT) + + self.user_input_entry = tk.Entry(user_frame, width=30) + self.user_input_entry.pack(side=tk.LEFT, padx=5) + + def 
create_submit_reset(self): + + # Row for Submit and Reset buttons + buttons_frame = tk.Frame(self.root) + buttons_frame.pack(pady=20) + + submit_button = tk.Button(buttons_frame, text="Submit", command=self.submit_action) + submit_button.pack(side=tk.LEFT, padx=10) + + reset_button = tk.Button(buttons_frame, text="Reset", command=self.reset_action) + reset_button.pack(side=tk.LEFT, padx=10) + +def annot_main(return_dict): + + #root = tk.Toplevel() + #input_window = annotation_handler(root) + #input_window.workflow() + #root.mainloop() + + #return_dict = input_window.selections + pass + +# Main function just for testing +if __name__ == '__main__': + results = annot_main() \ No newline at end of file diff --git a/scripts/codehub/utils/visualization/edf_viewer/components/internal/data_handler.py b/scripts/codehub/utils/visualization/edf_viewer/components/internal/data_handler.py new file mode 100644 index 00000000..0ceab6e6 --- /dev/null +++ b/scripts/codehub/utils/visualization/edf_viewer/components/internal/data_handler.py @@ -0,0 +1,86 @@ +# Local imports +from components.curation.public.data_loader import * +from components.workflows.public.channel_clean import * +from components.workflows.public.channel_mapping import * +from components.workflows.public.channel_montage import * + +class data_handler: + + def __init__(self,args,infile): + self.args = args + self.infile = infile + + def workflow(self): + self.load_data() + self.clean_channels() + self.map_channels() + self.montage_channels() + self.get_time_info() + return self.DF, self.fs, self.t_max, self.duration, self.t0 + + def load_data(self): + + # Initialize class + DL = data_loader() + + # Get the raw data and pointers + if not self.args.pickle_load: + self.DF,self.fs = DL.direct_inputs(self.infile,'edf') + else: + raw_input = pickle.load(open(self.infile,"rb")) + if type(raw_input) == np.ndarray or type(raw_input) == PD.core.frame.DataFrame: + self.DF = raw_input + if self.args.fs == None: + raise Exception("Must provide sampling frequency if passing a pickle file with only an array or dataframe.") + else: + self.fs = self.args.fs + else: + self.DF = raw_input[0] + self.fs = raw_input[1] + + def clean_channels(self): + + # Initialize class + CHCLN = channel_clean() + + # Get the cleaned channel names + clean_channels = CHCLN.direct_inputs(self.DF.columns,clean_method=self.args.chcln) + channel_dict = dict(zip(self.DF.columns,clean_channels)) + self.DF.rename(columns=channel_dict,inplace=True) + + def map_channels(self): + + # Initialize class + CHMAP = channel_mapping() + + # Get the channel mapping + if self.args.chmap != None: + channel_map = CHMAP.direct_inputs(self.DF.columns,self.args.chmap) + self.DF = self.DF[channel_map] + + def montage_channels(self): + + # Initialize class + CHMON = channel_montage() + + # Get the montage + if self.args.montage != None: + self.DF = CHMON.direct_inputs(self.DF,self.args.montage) + + def get_time_info(self): + + # Get the duration + self.t_max = self.DF.shape[0]/self.fs + if self.args.dur_frac: + self.duration = self.args.dur*self.t_max + else: + self.duration = self.args.dur + + # Get the start time + if self.args.t0_frac and self.args.t0 != None: + self.t0 = self.args.t0*self.t_max + else: + if self.args.t0 != None: + self.t0 = self.args.t0 + else: + self.t0 = np.random.rand()*(self.t_max-self.args.dur) \ No newline at end of file diff --git a/scripts/codehub/utils/visualization/edf_viewer/components/internal/event_handler.py 
b/scripts/codehub/utils/visualization/edf_viewer/components/internal/event_handler.py new file mode 100644 index 00000000..b28c406a --- /dev/null +++ b/scripts/codehub/utils/visualization/edf_viewer/components/internal/event_handler.py @@ -0,0 +1,222 @@ +import numpy as np +import pylab as PLT +from sys import exit +from prompt_toolkit import prompt +from prompt_toolkit.completion import WordCompleter + +# Local Imports +from components.internal.observer_handler import * +from components.internal.annotation_handler import * + +class event_observer(Observer): + + def listen_event(self,event): + + # Event logic + button_flag = hasattr(event,'button') + if button_flag: + event_handler.button_response(self,event) + else: + # Quit the plot + if event.key == 'Q': + event_handler.quit(self) + # Enlarge subplot options + elif event.key == 'e': + event_handler.enlarge(self,event) + # Annotation options + elif event.key == 'a': + event_handler.annotate(self,event) + elif event.key == 'backspace': + event_handler.delete_annotation(self,event) + # Shift back in time + elif event.key == 'left': + event_handler.shift_time(self,-1*self.duration) + # Shift forward in time + elif event.key == 'right': + event_handler.shift_time(self,self.duration) + # Small shift back in time + elif event.key == '<': + event_handler.shift_time(self,-0.5*self.duration) + # Small shift forward in time + elif event.key == '>': + event_handler.shift_time(self,0.5*self.duration) + # Increase gain + elif event.key == 'up': + event_handler.change_gain(self,0.1) + # Decrease gain + elif event.key == 'down': + event_handler.change_gain(self,-0.1) + # Reset x range + elif event.key == 'r': + event_handler.change_time(self,self.plot_info['xlim_orig']) + # Show entire x range + elif event.key == 'x': + event_handler.change_time(self,[0,self.t_max]) + # Reset the gain + elif event.key == '0': + event_handler.reset_gain(self) + # Zoom in on the user requested portion + elif event.key == 'z': + event_handler.zoom_lines(self) + + ######################################### + #### Send results to other observers #### + ######################################### + elif event.key == 'q': + event_handler.quit_action(self) + +class event_handler: + + def __init__(self): + pass + + def button_response(self,event,line_container=None,line_color='r'): + + # Add zoom lines + if event.button == 1: + + # Get the zoom line index + zcnt = self.plot_info['zoom_cntr'] + zcnt_mod = zcnt%2 + + # Loop around the zoom lines as needed to we only have two at a time + if zcnt >= 2: + for iobj in self.plot_info['zoom_lines'][zcnt_mod]: iobj.remove() + + # Draw the zoom lines + self.plot_info['zoom_lines'][zcnt_mod] = self.draw_zoom(event.xdata) + + # Save the positions + self.plot_info['zlim'][zcnt_mod] = event.xdata + + # iterate the zoom count + self.plot_info['zoom_cntr'] += 1 + + # Draw the lines + PLT.draw() + + + ##################### + #### Key options #### + ##################### + + def quit(self): + PLT.close("all") + exit() + + def enlarge(self,event): + + for ikey in self.plot_info['axes'].keys(): + if event.inaxes == self.plot_info['axes'][ikey]: + self.enlarged_plot(ikey) + + def annotate(self,event): + + # Read in the selection variables + script_dir = '/'.join(os.path.abspath(__file__).split('/')[:-3]) + self.annot_config = yaml.safe_load(open(f"{script_dir}/configs/annotation_config.yaml","r")) + + # Do some cleanup on the input options + for ikey in self.annot_config.keys():self.annot_config[ikey] = dict(self.annot_config[ikey]) + for ikey in 
list(self.annot_config.keys()): + self.annot_config[ikey.lower()] = self.annot_config.pop(ikey) + + # Print the options for the user. + if len(self.plot_info['annots'].keys()) == 0: + print("Available options:") + for ikey in self.annot_config.keys():print(f" - {ikey}") + print("Inputs not in this list will be taken as your annotation.") + + # Query the user for information + initial_options = list(self.annot_config.keys()) + + # Setup auto-completion for the prompt + options_completer = WordCompleter(initial_options) + + # Ask user to choose one of the options + selected_option = prompt('Please choose an option (Q/q to quit): ', completer=options_completer) + + if selected_option.lower() != 'q': + # Check for annotations in the config to give specific answers + if selected_option.lower() in self.annot_config.keys(): + + # get the current option list and display for user + new_options = list(self.annot_config[selected_option.lower()]['options']) + print("Available options:") + for ival in new_options: + print(f" - {ival}") + + options_completer = WordCompleter(new_options) + selected_value = prompt('Please choose an option: ', completer=options_completer) + annotation = f"{selected_option}_{selected_value}" + else: + annotation = selected_option + + # Get the channel the user clicked on + for ikey in self.plot_info['axes'].keys(): + if event.inaxes == self.plot_info['axes'][ikey]: + ichannel = ikey + + # Draw the annotation line at the event location + self.draw_annotations(event.xdata,annotation,ichannel) + + # Redraw the plot to update the display + PLT.draw() + + def delete_annotation(self,event): + + # Get the position of the click + xpos = event.xdata + for ikey in self.plot_info['axes'].keys(): + if event.inaxes == self.plot_info['axes'][ikey]: + ichannel = ikey + + # Get the current x range so we can try to approximate the annotation + xlim = self.plot_info['axes'][self.first_chan].get_xlim() + xrange = .025*(xlim[1]-xlim[0]) + + # Look across the annotation labels to try and remove the selected entry + lo = xpos-xrange + hi = xpos+xrange + for ikey in self.plot_info['annots']: + if (ikey >= lo) & (ikey <=hi): + annot_obj = self.plot_info['annots'][ikey] + if ichannel == annot_obj[0]: + annot_obj[2].remove() + for iline in annot_obj[3]:iline.remove() + PLT.draw() + self.plot_info['annots'].pop(ikey) + + def shift_time(self,shiftvalue): + current_xlim = self.plot_info['axes'][self.first_chan].get_xlim() + current_xlim = [ival+shiftvalue for ival in current_xlim] + self.change_time(current_xlim) + + def change_time(self,new_range): + for ikey in self.plot_info['axes'].keys(): + self.plot_info['axes'][ikey].set_xlim(new_range) + PLT.draw() + + def zoom_lines(self): + if self.plot_info['zoom_cntr'] >= 2: + new_range = np.sort(self.plot_info['zlim']) + for iobj in self.plot_info['zoom_lines'][0]: iobj.remove() + for iobj in self.plot_info['zoom_lines'][1]: iobj.remove() + self.plot_info['zoom_cntr'] = 0 + + for ikey in self.plot_info['axes'].keys(): + self.plot_info['axes'][ikey].set_xlim(new_range) + PLT.draw() + + def change_gain(self,frac_change): + for ikey in self.plot_info['axes'].keys(): + self.yscaling(ikey,frac_change) + PLT.draw() + + def reset_gain(self): + for ikey in self.plot_info['axes'].keys(): + self.plot_info['axes'][ikey].set_ylim(self.plot_info['ylim'][ikey]) + PLT.draw() + + def quit_action(self): + print("Foo") \ No newline at end of file diff --git a/scripts/codehub/utils/visualization/edf_viewer/components/internal/observer_handler.py 
b/scripts/codehub/utils/visualization/edf_viewer/components/internal/observer_handler.py new file mode 100644 index 00000000..ff995de1 --- /dev/null +++ b/scripts/codehub/utils/visualization/edf_viewer/components/internal/observer_handler.py @@ -0,0 +1,30 @@ +from abc import ABC, abstractmethod + +class Subject: + """ + Subject class to allow the BIDS handler to listen for new keywords. + """ + def add_event_observer(self, observer): + if observer not in self._event_observers: + self._event_observers.append(observer) + + def notify_event_observers(self,event): + for observer in self._event_observers: + observer.listen_event(self,event) + +class Observer(ABC): + """ + Observer class to allow the BIDS handler to listen for new keywords. + + Args: + ABC (object): Abstract Base Class object. Enforces the use of abstractmethod to prevent accidental access to listen_keyword without matching + class in the observer. + + Raises: + NotImplementedError: Error if the observing class doesn't have the right class object. + """ + + # Listener for BIDS keyword generation to create the correct pathing. + @abstractmethod + def listen_event(self): + raise NotImplementedError("Subclass must implement abstract method") \ No newline at end of file diff --git a/scripts/codehub/utils/visualization/edf_viewer/components/internal/plot_handler.py b/scripts/codehub/utils/visualization/edf_viewer/components/internal/plot_handler.py new file mode 100644 index 00000000..b0dfa038 --- /dev/null +++ b/scripts/codehub/utils/visualization/edf_viewer/components/internal/plot_handler.py @@ -0,0 +1,337 @@ +# User interface imports +import getpass +import pyautogui +import numpy as np + +# Matplotlib import and settings +import matplotlib.pyplot as PLT +from matplotlib.ticker import MultipleLocator + +# Local Imports +from components.internal.data_handler import * +from components.internal.event_handler import * +from components.internal.observer_handler import * + +class data_viewer(Subject,event_handler): + + def __init__(self, infile, args, tight_layout_dict): + """ + Initialize the data viewer class. + + Args: + infile (string): Filepath to an edf or pickle file. + args (object): UI arguments + tight_layout_dict (object): If `None` the code will calculate the best tight layout. + Otherwise, a dictionary with layout arguments will be used for plotting. + Saves time on loading multiple datasets. + """ + + # Save the input info + self.infile = infile + self.fname = infile.split('/')[-1] + self.args = args + self.tight_layout_dict = tight_layout_dict + + # Get the approx screen dimensions and set some plot variables + self.height = args.winfrac*pyautogui.size().height/100 + self.width = args.winfrac*pyautogui.size().width/100 + self.supsize = self.fontsize_scaler(16,14,self.width) + self.supsize = np.min([self.supsize,16]) + + # Prepare the data + DH = data_handler(args,infile) + rawobj = DH.workflow() + self.DF = rawobj[0] + self.fs = rawobj[1] + self.t_max = rawobj[2] + self.duration = rawobj[3] + self.t0 = rawobj[4] + + def workflow(self): + """ + Workflow for plotting data and managing flow of information. 
+ """ + + # Data Enrichment options + if self.args.epilepsy_prob_file!=None: + self.enrich_dataframe_epilepsy() + + # Attach the observers + self.attach_objects() + + # Make the initial plot info + self.create_plot_info() + + # Draw the plot for the first time + self.draw_base_plots() + + # Save the annotations, if any + if self.plot_info['annots'].keys(): + self.save_annotations() + + def enrich_dataframe_epilepsy(self): + + # Make a running time column for the plot data + tvals = np.arange(self.DF.shape[0])/self.fs + + # Read in the probability data + probability_df = PD.read_pickle(self.args.epilepsy_prob_file) + + # get the right dataslice + basefile = self.args.infile.split('/')[-1] + probability_df = probability_df.loc[probability_df.file==basefile] + + # get the time ranges + prob_times = probability_df['t_start'].values + prob_times = np.concatenate((prob_times,[tvals[-1]])) + time_ranges = np.hstack((prob_times[:-1].reshape(-1,1),prob_times[1:].reshape(-1,1))) + + # Populate the probabilities into a vector + probs = probability_df['Epilepsy_Prob'].values + outvals = [] + for ival in tvals: + mask = (ival>=time_ranges[:,0])&(ival<=time_ranges[:,1]) + iprob = probs[mask][0] + outvals.append(iprob) + + # Add the epilepsy info to the dataframe + self.DF['P_epi'] = outvals + + def attach_objects(self): + """ + Attach observers here so we can have each multiprocessor see the pointers correctly. + + event_observer: Manages what happens when a button or key is pressed. This also sends info to other scripts as needed. + """ + + # Create the observer objects + self._event_observers = [] + + # Attach observers + self.add_event_observer(event_observer) + + def create_plot_info(self): + + # Store some valuable information about the plot to reference for events and modifications + self.plot_info = {} + self.plot_info['axes'] = {} + self.plot_info['ylim'] = {} + self.plot_info['shade'] = {} + self.plot_info['xlim_orig'] = [self.t0,self.t0+self.duration] + self.plot_info['xvals'] = np.arange(self.DF.shape[0])/self.fs + self.plot_info['zoom_cntr'] = 0 + self.plot_info['zoom_lines'] = [0,0] + self.plot_info['zlim'] = [0,0] + self.plot_info['annots'] = {} + + def save_annotations(self): + """ + Save any user annotations to an output CSV. + + Looks for the self.plot_info['annots'] object and iterates over the keys. + """ + + # Loop over the annotations and make the output array object + output = [] + for ikey in self.plot_info['annots'].keys(): + ival = self.plot_info['annots'][ikey] + output.append([getpass.getuser(),ival[0],ikey,ival[1]]) + + # Make the output dataframe + outDF = PD.DataFrame(output,columns=['user','channel','time','annotation']) + + # Append as needed to existing records + if os.path.exists(self.args.outfile): + annot_DF = PD.read_csv(self.args.outfile) + outDF = PD.concat((annot_DF,outDF)) + + # Write out the results + outDF.to_csv(self.args.outfile,index=False) + + ############################ + #### Plotting functions #### + ############################ + + def draw_base_plots(self): + + # Set the label shift. 
72 points equals ~1 inch in pyplot + width_frac = (0.025*self.width) + npnt = int(72*width_frac) + + # Create the plotting environment + nrows = len(self.DF.columns) + self.fig = PLT.figure(dpi=100,figsize=(self.width,self.height)) + gs = self.fig.add_gridspec(nrows, 1, hspace=0) + for idx,ichan in enumerate(self.DF.columns): + # Define the axes + if idx == 0: + self.plot_info['axes'][ichan] = self.fig.add_subplot(gs[idx, 0]) + self.first_chan = ichan + else: + self.plot_info['axes'][ichan] = self.fig.add_subplot(gs[idx, 0],sharex=self.plot_info['axes'][self.first_chan]) + + # Get the data stats + if ichan not in ['P_epi']: + idata,ymin,ymax = self.get_stats(ichan) + else: + idata = self.DF[ichan].values + + # Plot the data + self.plot_info['axes'][ichan].plot(self.plot_info['xvals'][::self.args.nstride],idata[::self.args.nstride],color='k') + if ichan not in ['P_epi']: + self.plot_info['axes'][ichan].set_ylim([ymin,ymax]) + else: + + # Figure out the best ymin and ymax for this larger dynamic range + ymin, ymax = self.plot_info['axes'][ichan].get_ylim() + + # Some conditionals on allowed ranges + delta = 0.00001 + new_min = 1-delta + new_max = 1+delta + if yminnew_max:new_max = ymax + self.plot_info['axes'][ichan].set_ylim([new_min,new_max]) + self.plot_info['ylim'][ichan] = [ymin,ymax] + + # Add in shading for the original axes limits + if self.args.shade: + self.plot_info['shade'][ichan] = self.plot_info['axes'][ichan].axvspan(self.plot_info['xlim_orig'][0], self.plot_info['xlim_orig'][1], facecolor='orange',alpha=0.2) + + # Clean up the plot + for label in self.plot_info['axes'][ichan].get_xticklabels(): + label.set_alpha(0) + self.plot_info['axes'][ichan].set_yticklabels([]) + self.plot_info['axes'][ichan].set_ylabel(ichan,fontsize=12,rotation=0,labelpad=npnt) + self.plot_info['axes'][ichan].xaxis.grid(True) + + # X-axis cleanup + self.last_chan = ichan + self.plot_info['axes'][ichan].set_xlim(self.plot_info['xlim_orig']) + + # Add an xlabel to the final object + self.plot_info['axes'][self.last_chan].xaxis.set_major_locator(MultipleLocator(1)) + self.plot_info['axes'][self.last_chan].set_xlabel("Time (s)",fontsize=14) + for label in self.plot_info['axes'][self.last_chan].get_xticklabels(): + label.set_alpha(1) + + # Set the axes title object + self.generate_title_str() + self.plot_info['axes'][self.first_chan].set_title(self.title_str,fontsize=10) + + # Set the figure title object + self.generate_suptitle_str() + PLT.suptitle(self.suptitle,fontsize=self.supsize) + + # Layout handling using previous plot layout or find it for the first time + if self.tight_layout_dict == None: + self.fig.tight_layout() + else: + self.fig.subplots_adjust(**self.tight_layout_dict) + + # Event associations + self.fig.canvas.mpl_connect('button_press_event', self.notify_event_observers) + self.fig.canvas.mpl_connect('key_press_event', self.notify_event_observers) + + # Show the results + PLT.show() + + # Store and return tight layout params for faster subsequent plots + if self.tight_layout_dict == None: + self.tight_layout_dict = {par : getattr(self.fig.subplotpars, par) for par in ["left", "right", "bottom", "top", "wspace", "hspace"]} + return self.tight_layout_dict + + def draw_annotations(self,xpos,annotation,ichannel): + + # Add the annotation + ymin,ymax = self.plot_info['axes'][ichannel].get_ylim() + ypos = 0.5*(ymin+ymax) + pltobj = self.plot_info['axes'][ichannel].annotate(text=annotation, xy =(xpos,ypos),bbox=dict(boxstyle="round", facecolor="white", alpha=0.8)) + self.plot_info['annots'][xpos] 
= (ichannel,annotation,pltobj,[]) + + # Draw the line for the user to see + for ikey in self.plot_info['axes'].keys(): + self.plot_info['annots'][xpos][3].append(self.plot_info['axes'][ikey].axvline(xpos, color='g', linestyle='--')) + + def draw_zoom(self,xpos): + + zoom_lines = [] + for ikey in self.plot_info['axes'].keys(): + zoom_lines.append(self.plot_info['axes'][ikey].axvline(xpos, color='r', linestyle='--')) + return zoom_lines + + def enlarged_plot(self,channel): + + # Get the data view + idata,ymin,ymax = self.get_stats(channel) + xvals = np.arange(idata.size)/self.fs + + # Get the current limits of the main viewer + xlims = self.ax_dict[self.first_chan].get_xlim() + ylims = [ymin,ymax] + + # Plot the enlarged view + fig = PLT.figure(dpi=100,figsize=(self.width,self.height)) + self.ax_enl = fig.add_subplot(111) + self.ax_enl.plot(xvals,idata,color='k') + self.ax_enl.set_xlabel("Time (s)",fontsize=14) + self.ax_enl.set_ylabel(channel,fontsize=14) + self.ax_enl.set_xlim(xlims) + self.ax_enl.set_ylim(ylims) + PLT.title(self.fname,fontsize=14) + fig.tight_layout() + PLT.show() + + ########################## + #### Helper functions #### + ########################## + + def fontsize_scaler(self,font_ref,width_ref,width_val): + return font_ref+2*np.floor((width_val-width_ref)) + + def get_stats(self,ichan): + + idata = self.DF[ichan].values + median = np.median(idata) + stdev = np.std(idata) + idata -= median + ymin = -5*stdev + ymax = 5*stdev + return idata,ymin,ymax + + def yscaling(self,ikey,dy): + + # Get the limits of the current plot for rescaling and recreating + xlim = self.plot_info['axes'][ikey].get_xlim() + ylim = self.plot_info['axes'][ikey].get_ylim() + + # Get the approximate new scale + scale = ylim[1]-ylim[0] + ymin = ylim[0]+dy*scale + ymax = ylim[1]-dy*scale + + # Get the data and limits with a good vertical offset + vals = self.DF[ikey].values + inds = (vals>=ymin)&(vals<=ymax) + vals = vals[inds] + offset = np.median(vals) + vals -= offset + ymin -= offset + ymax -= offset + + # Generate new limits + self.plot_info['axes'][ikey].set_ylim([ymin,ymax]) + + def generate_title_str(self): + upa = u'\u2191' # Up arrow + downa = u'\u2193' # Down arrow + lefta = u'\u2190' # Left arrow + righta = u'\u2192' # Right arrow + self.title_str = r"z=Zoom between mouse clicks; 'r'=reset x-scale; 'x'=Show entire x-axis; '0'=reset y-scale; 't'=Toggle targets; 'q'=quit current plot; 'Q'=quit the program entirely" + self.title_str += '\n' + self.title_str += r"'%s'=Increase Gain; '%s'=Decrease Gain; '%s'=Shift Left; '%s'=Shift Right; '<'=Minor Shift Left; '>'=Minor Shift Right; 'e'=Zoom-in plot of axis the mouse is on;" %(upa, downa, lefta, righta) + + def generate_suptitle_str(self): + + # Base string + self.suptitle = self.fname \ No newline at end of file diff --git a/scripts/codehub/utils/visualization/edf_viewer/configs/annotation_config.yaml b/scripts/codehub/utils/visualization/edf_viewer/configs/annotation_config.yaml new file mode 100644 index 00000000..ce13febf --- /dev/null +++ b/scripts/codehub/utils/visualization/edf_viewer/configs/annotation_config.yaml @@ -0,0 +1,25 @@ +sleep_state: + options: + - 'awake' + - 'sleep' + - unknown +spikes: + options: + - 'yes' + - 'no' + - unknown +seizures: + options: + - 'yes' + - 'no' + - unknown +slowing: + options: + - 'yes' + - 'no' + - unknown +artifact: + options: + - 'yes' + - 'no' + - unknown diff --git a/scripts/codehub/utils/visualization/edf_viewer/edf_viewer.py b/scripts/codehub/utils/visualization/edf_viewer/edf_viewer.py index 
a46b5919..573fc61e 100644 --- a/scripts/codehub/utils/visualization/edf_viewer/edf_viewer.py +++ b/scripts/codehub/utils/visualization/edf_viewer/edf_viewer.py @@ -3,604 +3,13 @@ rnd.seed(42) # Basic Python Imports -import re -import sys import glob -import pickle import argparse -from os import path +import pandas as PD +import pylab as PLT -# User interface imports -import tkinter as tk - -# Matplotlib import and settings -import matplotlib.pyplot as PLT -from matplotlib.ticker import MultipleLocator - -# Local imports -from components.curation.public.data_loader import * -from components.workflows.public.channel_clean import * -from components.workflows.public.channel_mapping import * -from components.workflows.public.channel_montage import * - -################# -#### Classes #### -################# - -class data_handler: - - def __init__(self): - pass - - def data_prep(self, chcln, chmap, montage): - - # Create pointers to the relevant classes - DL = data_loader() - CHCLN = channel_clean() - CHMAP = channel_mapping() - CHMON = channel_montage() - - # Get the raw data and pointers - if not self.args.pickle_load: - DF,self.fs = DL.direct_inputs(self.infile,filetype,ssh_host=self.args.ssh_host,ssh_username=self.args.ssh_username) - else: - DF,self.fs = pickle.load(open(self.infile,"rb")) - self.fs = self.fs[0] - - # Get the cleaned channel names - clean_channels = CHCLN.direct_inputs(DF.columns,clean_method=chcln) - channel_dict = dict(zip(DF.columns,clean_channels)) - DF.rename(columns=channel_dict,inplace=True) - - # Get the channel mapping - if chmap != None: - channel_map = CHMAP.direct_inputs(DF.columns,chmap) - DF = DF[channel_map] - - # Get the montage - if montage != None: - self.DF = CHMON.direct_inputs(DF,montage) - else: - self.DF = DF - - # Read in optional sleep wake power data if provided - if self.args.sleep_wake_power != None: - self.read_sleep_wake_data() - else: - self.t_flag = False - - def read_sleep_wake_data(self): - - # Read in the pickled data associations - self.assoc_dict = pickle.load(open(self.args.sleep_wake_power,"rb")) - - # Get the relevant keys - self.assoc_keys = (self.assoc_dict.keys()) - - # Look to see if the current file is in the associations - self.color_dict = {} - for ikey in self.assoc_keys: - fvals = self.assoc_dict[ikey]['file'].values - ufiles = np.unique(fvals) - if self.fname in ufiles: - inds = (fvals==self.fname) - self.color_dict[ikey] = self.assoc_dict[ikey].iloc[inds] - - # Create a few variables to iterate through the dictionary as needed - self.color_cnt = 0 - self.color_keys = list(self.assoc_keys) - self.ncolor = len(self.color_keys) - self.t_flag = True - self.t_colors = ['r','b','g','m'] - -class data_viewer(data_handler): - - def __init__(self, infile, args, tight_layout_dict, filetype): - - # Save the input info - self.infile = infile - self.fname = infile.split('/')[-1] - self.args = args - self.tight_layout_dict = tight_layout_dict - self.filetype = filetype - - # Some tracking variables - self.flagged_out = ['','','','','',''] - self.sleep_counter = 0 - self.spike_counter = 0 - self.seizure_counter = 0 - self.focal_slow_counter = 0 - self.general_slow_counter = 0 - self.artifact_counter = 0 - self.sleep_labels = ['','awake','sleep','unknown_sleep_state'] - self.spike_labels = ['','spikes','spike_free','unknown_spike_state'] - self.seizure_labels = ['','seizures','seizure_free','unknown_seizure_state'] - self.focal_slow_labels = ['','focal_slowing','no_focal_slowing','unknown_focal_slowing'] - self.general_slow_labels = 
['','general_slowing','no_general_slowing','unknown_general_slowing'] - self.artifact_labels = ['','artifact_heavy'] - - # Get the approx screen dimensions and set some plot variables - root = tk.Tk() - self.height = 0.9*root.winfo_screenheight()/100 - self.width = 0.9*root.winfo_screenwidth()/100 - root.destroy() - self.supsize = self.fontsize_scaler(16,14,self.width) - self.supsize = np.min([self.supsize,16]) - - # Save event driven variables - self.xlim = [] - self.drawn_y = [] - self.drawn_a = [] - - # Prepare the data - data_handler.data_prep(self, args.chcln, args.chmap, args.montage) - - # Get the duration - self.t_max = self.DF.shape[0]/self.fs - if self.args.dur_frac: - self.duration = self.args.dur*self.t_max - else: - self.duration = self.args.dur - - # Get the start time - if self.args.t0_frac and self.args.t0 != None: - self.t0 = self.args.t0*self.t_max - else: - if self.args.t0 != None: - self.t0 = self.args.t0 - else: - self.t0 = np.random.rand()*(self.t_max-self.args.dur) - - # Attempt to get any annotations - pattern = r'(.*?)_(\D+).edf' - match = re.search(pattern, infile) - if match: - base_filename = match.group(1) - events_filename = f"{base_filename}_events.tsv" - if path.exists(events_filename): - self.events_df = PD.read_csv(events_filename,delimiter="\t") - else: - self.events_df = PD.DataFrame() - else: - self.events_df = PD.DataFrame() - - def plot_sleep_wake(self): - - x_list = {} - y_list = {} - c_list = {} - for ikey in self.color_keys: - x_list[ikey] = {} - y_list[ikey] = {} - c_list[ikey] = {} - iDF = self.color_dict[ikey] - - # Loop over the channels and plot results - for ichan in self.DF.columns: - - x_list[ikey][ichan] = [] - y_list[ikey][ichan] = [] - c_list[ikey][ichan] = [] - - # Get the data stats - idata,ymin,ymax = self.get_stats(ichan) - xvals = np.arange(idata.size)/self.fs - - # Get the list of target info to iterate over - values = iDF[ichan].values - uvalues = np.unique(values) - uvalues = uvalues[(uvalues!=-1)] - for ii,ivalue in enumerate(uvalues): - - # Get the different boundaries - inds = (values==ivalue) - t0_vals = iDF['t_start'].values[inds].astype('float') - t1_vals = iDF['t_end'].values[inds].astype('float') - - # Loop over the times and then plot the scatter points - for itr in range(t0_vals.size): - inds_t = (xvals>=t0_vals[itr])&(xvals<=t1_vals[itr]) - x_list[ikey][ichan].append(xvals[inds_t]) - y_list[ikey][ichan].append(idata[inds_t]) - c_list[ikey][ichan].append(self.t_colors[ii]) - return x_list,y_list,c_list - - def montage_plot(self): - - # Get the number of channels to plot - nchan = self.DF.columns.size - - # Set the label shift. 
72 points equals ~1 inch in pyplot - width_frac = (0.025*self.width) - npnt = int(72*width_frac) - - # Create the plotting environment - self.fig = PLT.figure(dpi=100,figsize=(self.width,self.height)) - gs = self.fig.add_gridspec(nchan, 1, hspace=0) - self.ax_dict = {} - self.lim_dict = {} - self.shade_dict = {} - self.xlim_orig = [self.t0,self.t0+self.duration] - self.xvals = np.arange(self.DF.shape[0])/self.fs - for idx,ichan in enumerate(self.DF.columns): - # Define the axes - if idx == 0: - self.ax_dict[ichan] = self.fig.add_subplot(gs[idx, 0]) - self.refkey = ichan - else: - self.ax_dict[ichan] = self.fig.add_subplot(gs[idx, 0],sharex=self.ax_dict[self.refkey]) - - # Get the data stats - idata,ymin,ymax = self.get_stats(ichan) - - # Plot the data - self.ax_dict[ichan].plot(self.xvals[::self.args.nstride],idata[::self.args.nstride],color='k') - self.ax_dict[ichan].set_ylim([ymin,ymax]) - self.lim_dict[ichan] = [ymin,ymax] - - # Add in shading for the original axes limits - self.shade_dict[ichan] = self.ax_dict[ichan].axvspan(self.xlim_orig[0], self.xlim_orig[1], facecolor='orange',alpha=0.2) - - # Clean up the plot - for label in self.ax_dict[ichan].get_xticklabels(): - label.set_alpha(0) - self.ax_dict[ichan].set_yticklabels([]) - self.ax_dict[ichan].set_ylabel(ichan,fontsize=12,rotation=0,labelpad=npnt) - self.ax_dict[ichan].xaxis.grid(True) - - # X-axis cleanup - self.refkey2 = ichan - self.ax_dict[ichan].set_xlim(self.xlim_orig) - - # Plot and hide target data as needed - self.t_obj = {} - if self.t_flag: - if self.args.sleep_wake_power != None: - x_list,y_list,c_list = self.plot_sleep_wake() - - for ii,ikey in enumerate(list(x_list.keys())): - self.t_obj[ikey] = [] - for ichan in list(x_list[ikey].keys()): - ix = x_list[ikey][ichan] - iy = y_list[ikey][ichan] - ic = c_list[ikey][ichan] - for jj in range(len(ix)): - self.t_obj[ikey].append(self.ax_dict[ichan].scatter(ix[jj],iy[jj],s=2,c=ic[jj],visible=False)) - - # Add an xlabel to the final object - self.ax_dict[self.refkey2].xaxis.set_major_locator(MultipleLocator(1)) - self.ax_dict[self.refkey2].set_xlabel("Time (s)",fontsize=14) - for label in self.ax_dict[self.refkey2].get_xticklabels(): - label.set_alpha(1) - - # Set the axes title object - self.generate_title_str() - self.ax_dict[self.refkey].set_title(self.title_str,fontsize=10) - - # Set the figure title object - self.generate_suptitle_str() - PLT.suptitle(self.suptitle,fontsize=self.supsize) - - # Layout handling using previous plot layout or find it for the first time - if self.tight_layout_dict == None: - self.fig.tight_layout() - else: - self.fig.subplots_adjust(**self.tight_layout_dict) - - # Even associations - self.fig.canvas.mpl_connect('button_press_event', self.on_click) - self.fig.canvas.mpl_connect('key_press_event', self.update_plot) - - # Show the results - PLT.show() - - # Update predictions if needed - if self.args.flagging: - self.save_flag_state() - - # Store and return tight layout params for faster subsequent plots - if self.tight_layout_dict == None: - self.tight_layout_dict = {par : getattr(self.fig.subplotpars, par) for par in ["left", "right", "bottom", "top", "wspace", "hspace"]} - return self.tight_layout_dict - - def enlarged_plot(self,channel): - - # Get the data view - idata,ymin,ymax = self.get_stats(channel) - xvals = np.arange(idata.size)/self.fs - - # Get the current limits of the main viewer - xlims = self.ax_dict[self.refkey].get_xlim() - ylims = self.ax_dict[self.refkey].get_ylim() - - # Plot the enlarged view - fig = 
PLT.figure(dpi=100,figsize=(self.width,self.height)) - self.ax_enl = fig.add_subplot(111) - self.ax_enl.plot(xvals,idata,color='k') - self.ax_enl.set_xlabel("Time (s)",fontsize=14) - self.ax_enl.set_ylabel(channel,fontsize=14) - self.ax_enl.set_xlim(xlims) - self.ax_enl.set_ylim(ylims) - PLT.title(self.fname,fontsize=14) - fig.tight_layout() - PLT.show() - - ########################## - #### Helper functions #### - ########################## - - def fontsize_scaler(self,font_ref,width_ref,width_val): - return font_ref+2*np.floor((width_val-width_ref)) - - def get_stats(self,ichan): - - idata = self.DF[ichan].values - median = np.median(idata) - stdev = np.std(idata) - idata -= median - ymin = -5*stdev - ymax = 5*stdev - return idata,ymin,ymax - - def yscaling(self,ikey,dy): - - # Get the limits of the current plot for rescaling and recreating - xlim = self.ax_dict[ikey].get_xlim() - ylim = self.ax_dict[ikey].get_ylim() - - # Get the approximate new scale - scale = ylim[1]-ylim[0] - ymin = ylim[0]+dy*scale - ymax = ylim[1]-dy*scale - - # Get the data and limits with a good vertical offset - vals = self.DF[ikey].values - inds = (vals>=ymin)&(vals<=ymax) - vals = vals[inds] - offset = np.median(vals) - vals -= offset - ymin -= offset - ymax -= offset - - # Generate new limits - self.ax_dict[ikey].set_ylim([ymin,ymax]) - - def generate_title_str(self): - upa = u'\u2191' # Up arrow - downa = u'\u2193' # Down arrow - lefta = u'\u2190' # Left arrow - righta = u'\u2192' # Right arrow - self.title_str = r"z=Zoom between mouse clicks; 'r'=reset x-scale; 'x'=Show entire x-axis; '0'=reset y-scale; 't'=Toggle targets; 'q'=quit current plot; 'Q'=quit the program entirely" - self.title_str += '\n' - self.title_str += r"'%s'=Increase Gain; '%s'=Decrease Gain; '%s'=Shift Left; '%s'=Shift Right; '<'=Minor Shift Left; '>'=Minor Shift Right; 'e'=Zoom-in plot of axis the mouse is on;" %(upa, downa, lefta, righta) - if self.args.flagging: - self.title_str += '\n' - self.title_str += r"1=Sleep State; 2=Spike Presence; 3=Seizure; 4=Focal Slowing; 5=Generalized Slowing; 6=Artifact Heavy" - - def generate_suptitle_str(self): - - # Base string - self.suptitle = self.fname - - # If using flagging, create new string - if self.args.flagging: - self.suptitle += '\n' - for ival in self.flagged_out: - self.suptitle += f" {ival} |" - self.suptitle = self.suptitle[:-1] - - def flag_toggle(self,label_name,counter_name,str_pos): - - # Get the labels - labels = getattr(self,label_name) - counter = getattr(self,counter_name) - - # Handle the counter logic - counter+=1 - if counter == 4 and counter_name != 'artifact_counter': - counter = 0 - elif counter == 2 and counter_name == 'artifact_counter': - counter = 0 - - # Update the substring - newval = labels[counter] - self.flagged_out[str_pos] = newval - - # Generate the new suptilte - self.generate_suptitle_str() - - # Set the new title - PLT.suptitle(f"{self.suptitle}",fontsize=self.supsize) - - # Set the new counter values - setattr(self,counter_name,counter) - - def save_flag_state(self): - - # Create output column list - xlims = self.ax_dict[self.refkey2].get_xlim() - outcols = ['filename','username','assigned_t0','assigned_t1','evaluated_t0','evaluated_t1','sleep_state','spike_state','seizure_state','focal_slowing','general_slowing','artifacts'] - outvals = [self.infile,self.args.username,self.xlim_orig[0],self.xlim_orig[1],xlims[0],xlims[1]] - outvals = outvals+self.flagged_out - - # Make the temporary dataframe to concat to outputs - iDF = 
PD.DataFrame([outvals],columns=outcols) - - # Check for file - if path.exists(self.args.outfile): - out_DF = PD.read_csv(self.args.outfile) - out_DF = PD.concat((out_DF,iDF),ignore_index=True) - else: - out_DF = iDF - - # Save the results - if not self.args.debug: - out_DF = out_DF.drop_duplicates() - out_DF.to_csv(self.args.outfile,index=False) - - ################################ - #### Event driven functions #### - ################################ - - def on_click(self,event): - """ - Click driven events for the plot object. - - Args: - Matplotlib event. - """ - - # Left click defines the zoom ranges - if event.button == 1: - # Loop over the axes and draw the zoom ranges - for ikey in self.ax_dict.keys(): - self.drawn_y.append(self.ax_dict[ikey].axvline(event.xdata, color='red', linestyle='--')) - - # Redraw the plot to update the display - PLT.draw() - - # Update the event driven zoom object - self.xlim.append(event.xdata) - - def update_plot(self,event): - """ - Key driven events for the plot object. - - Args: - Matplotlib event. - """ - - # Zoom on 'z' press and when there are two bounds - if event.key == 'z' and len(self.xlim) == 2: - - # Set the xlimits - self.ax_dict[self.refkey].set_xlim(self.xlim) - self.xlim = [] - - # Remove the vertical lines on the plot - for iobj in self.drawn_y: - iobj.remove() - self.draw_y = [] - # Reset the -axes of the plot - elif event.key == 'r': - self.ax_dict[self.refkey].set_xlim(self.xlim_orig) - self.xlim = [] - # Increase gain - elif event.key == 'up': - for ikey in self.ax_dict.keys(): - self.yscaling(ikey,0.1) - # Decrease gain - elif event.key == 'down': - for ikey in self.ax_dict.keys(): - self.yscaling(ikey,-0.1) - # Shift back in time - elif event.key == 'left': - current_xlim = self.ax_dict[self.refkey].get_xlim() - current_xlim = [ival-self.duration for ival in current_xlim] - for ikey in self.ax_dict.keys(): - self.ax_dict[ikey].set_xlim(current_xlim) - # Shift forward in time - elif event.key == 'right': - current_xlim = self.ax_dict[self.refkey].get_xlim() - current_xlim = [ival+self.duration for ival in current_xlim] - for ikey in self.ax_dict.keys(): - self.ax_dict[ikey].set_xlim(current_xlim) - # Shift back in time - elif event.key == '<': - current_xlim = self.ax_dict[self.refkey].get_xlim() - current_xlim = [ival-self.duration/2. for ival in current_xlim] - for ikey in self.ax_dict.keys(): - self.ax_dict[ikey].set_xlim(current_xlim) - # Shift forward in time - elif event.key == '>': - current_xlim = self.ax_dict[self.refkey].get_xlim() - current_xlim = [ival+self.duration/2. 
for ival in current_xlim] - for ikey in self.ax_dict.keys(): - self.ax_dict[ikey].set_xlim(current_xlim) - # Show the entire x-axis - elif event.key == 'x': - for ikey in self.ax_dict.keys(): - self.ax_dict[ikey].set_xlim([0,self.t_max]) - # Reset gain - elif event.key == '0': - for ikey in self.ax_dict.keys(): - self.ax_dict[ikey].set_ylim(self.lim_dict[ikey]) - # Enlarge a singular plot - elif event.key == 'e': - for ikey in self.ax_dict.keys(): - if event.inaxes == self.ax_dict[ikey]: - self.enlarged_plot(ikey) - # Show annotations - elif event.key == 'A': - if 'trial_type' in self.events_df.columns: - annot_xval = self.events_df['onset'][0] - annot_text = self.events_df['trial_type'][0] - - if len(self.drawn_a) == 0: - for ikey in self.ax_dict.keys(): - self.drawn_a.append(self.ax_dict[ikey].axvline(annot_xval, color='blue', linestyle='--',lw=2)) - self.annot_obj = self.ax_dict[self.refkey2].text(annot_xval, 0, annot_text,bbox=dict(boxstyle='round', facecolor='lightgray', - edgecolor='none', alpha=1.0),verticalalignment='center', horizontalalignment='left', fontsize=12) - - else: - # Remove the vertical lines on the plot - for iobj in self.drawn_a: - iobj.remove() - self.annot_obj.remove() - self.draw_a = [] - # Iterate over target dictionary if available to show mapped colors - elif event.key == 't' and hasattr(self,'color_dict'): - for icnt in range(len(self.t_colors)): - ikey = self.color_keys[icnt] - objects = self.t_obj[ikey] - for iobj in objects: - if icnt == self.color_cnt: - # Change visibility - iobj.set_visible(True) - - # Update title - PLT.suptitle(self.fname+" | "+str(ikey),fontsize=self.supsize) - else: - iobj.set_visible(False) - - if self.color_cnt < self.ncolor: - self.color_cnt += 1 - else: - self.color_cnt = 0 - PLT.suptitle(self.fname,fontsize=self.supsize) - # Sleep/awake event mapping - elif event.key == '1' and self.args.flagging: - self.flag_toggle('sleep_labels','sleep_counter',0) - # Spike State Mapping - elif event.key == '2' and self.args.flagging: - self.flag_toggle('spike_labels','spike_counter',1) - # Seizure State Mapping - elif event.key == '3' and self.args.flagging: - self.flag_toggle('seizure_labels','seizure_counter',2) - # Seizure State Mapping - elif event.key == '4' and self.args.flagging: - self.flag_toggle('focal_slow_labels','focal_slow_counter',3) - # Seizure State Mapping - elif event.key == '5' and self.args.flagging: - self.flag_toggle('general_slow_labels','general_slow_counter',4) - # Seizure State Mapping - elif event.key == '6' and self.args.flagging: - self.flag_toggle('artifact_labels','artifact_counter',5) - # Quit functionality - elif event.key == 'Q': - PLT.close("all") - sys.exit() - - # Make sure the axes colorscheme is updated - newlim = self.ax_dict[self.refkey2].get_xlim() - if (newlim[0] == self.xlim_orig[0]) and (newlim[1] == self.xlim_orig[1]): - ialpha = 0.2 - else: - ialpha = 0 - for ichan in self.DF.columns: - self.shade_dict[ichan].set_alpha(ialpha) - - PLT.draw() +# Local Imports +from components.internal.plot_handler import * class CustomFormatter(argparse.HelpFormatter): """ @@ -635,13 +44,16 @@ def make_help_str(idict): parser = argparse.ArgumentParser(description="Simplified data merging tool.", formatter_class=CustomFormatter) input_group = parser.add_mutually_exclusive_group() - input_group.add_argument("--cli", type=str, help="Single input file to plot from cli.") + input_group.add_argument("--infile", type=str, help="Single input file to plot from cli.") input_group.add_argument("--wildcard", type=str, 
help="Wildcard enabled path to plot multiple datasets.") input_group.add_argument("--file", type=str, help="Filepath to txt or csv of input files.") + dtype_group = parser.add_argument_group('Datatype options') + dtype_group.add_argument("--pickle_load", action='store_true', default=False, help="Load from pickledata. Accepts pickled tuple/list of dataframe/fs or just dataframe. If only dataframe, must provide --fs sampling frequency.") + dtype_group.add_argument("--fs", type=float, help="Sampling frequency.") + output_group = parser.add_argument_group('Output options') - output_group.add_argument("--outfile", default='./edf_viewer_flags.csv', type=str, help="Output filepath if predicting sleep/spikes/etc.") - output_group.add_argument("--username", type=str, help="Username to tag any outputs with.") + output_group.add_argument("--outfile", default='./edf_viewer_annotations.csv', type=str, help="Output filepath if predicting sleep/spikes/etc.") prep_group = parser.add_argument_group('Data preparation options') prep_group.add_argument("--chcln", type=str, default="hup", help="Channel cleaning option") @@ -654,37 +66,25 @@ def make_help_str(idict): duration_group = parser.add_mutually_exclusive_group() duration_group.add_argument("--dur", type=float, default=10, help="Duration to plot in seconds.") - duration_group.add_argument("--dur_frac", action='store_true', default=False, help="Flag. Duration in fraction of total data.") + duration_group.add_argument("--dur_frac", action='store_true', default=False, help="Flag. Duration is interpreted as a fraction of total data.") - ssh_group = parser.add_argument_group('SSH Data Loading Options') - ssh_group.add_argument("--ssh_host", type=str, help="If provided, look for data on this host connection string rather than local.") - ssh_group.add_argument("--ssh_username", type=str, help="When loading data via ssh tunnel, this is the host ssh username to log in as.") + enrichment_group = parser.add_argument_group('Add any data enrichment.') + enrichment_group.add_argument("--epilepsy_prob_file", type=str, default=None, help="Path to probability of epilepsy enrichment vectors.") misc_group = parser.add_argument_group('Misc options') + misc_group.add_argument("--winfrac", type=float, default=0.9, help="Fraction of the window for the plot.") misc_group.add_argument("--nstride", type=int, default=8, help="Stride factor for plotting.") misc_group.add_argument("--debug", action='store_true', default=False, help="Debug mode. 
Save no outputs.") - misc_group.add_argument("--sleep_wake_power", type=str, help="Optional file with identified groups in alpha/delta for sleep/wake patients") - misc_group.add_argument("--pickle_load", action='store_true', default=False, help="Load from pickled tuple of dataframe,fs.") - misc_group.add_argument("--flagging", action='store_true', default=False, help="Let user flag EEG for important properties.") - misc_group.add_argument("--review_mode", action='store_true', default=False, help="Allows us to prevent reviewers from seeing data they have already reviewed.") + misc_group.add_argument("--shade", action='store_true', default=False, help="Shade the original viewing region.") args = parser.parse_args() # Clean up some argument types args.chmap = None if args.chmap == 'None' else args.chmap args.montage = None if args.montage == 'None' else args.montage - # Get username and output path if needed - if args.flagging: - if args.username == None: - args.username = input("Please enter a username for tagging data: ") - if args.outfile == None: - args.outfile = './edf_annotations.csv' - else: - args.outfile = '' - # Create the file list to read in - if args.cli != None: - files = [args.cli] + if args.infile != None: + files = [args.infile] elif args.wildcard != None: files = glob.glob(args.wildcard) elif args.file != None: @@ -694,31 +94,9 @@ def make_help_str(idict): if len(files) == 0: print("No files found matching your criteria.") - # Set ssh filetype if a connection string is provided - if args.ssh_host != None: - filetype = 'ssh_edf' - else: - filetype = 'edf' - - # Use the output file to skip already reviewed files for state analysis - if path.exists(args.outfile): - ref_DF = PD.read_csv(args.outfile) - else: - ref_DF = PD.DataFrame(columns=['filename','username']) - # Iterate over the data and create the relevant plots tight_layout_dict = None for ifile in files: - - # Check if this user has already reviewed this data - iDF = ref_DF.loc[(ref_DF.username==args.username)&(ref_DF.filename==ifile)] - if iDF.shape[0] == 0 or args.review_mode == False: - - try: - DV = data_viewer(ifile,args,tight_layout_dict,filetype) - tight_layout_dict = DV.montage_plot() - PLT.close("all") - except ValueError as e: - print("Unable to load data. This is likely due to formatting issues in an EDF header.") - print(f"A detail error is as follows: {e}") - PLT.close("all") + DV = data_viewer(ifile,args,tight_layout_dict) + tight_layout_dict = DV.workflow() + PLT.show() diff --git a/scripts/device-name_project-name_optional-subdir-name/README.md b/scripts/device-name_project-name_optional-subdir-name/README.md deleted file mode 100644 index 3d78ab5c..00000000 --- a/scripts/device-name_project-name_optional-subdir-name/README.md +++ /dev/null @@ -1,16 +0,0 @@ -# Sample User Directory - -## Naming conventions - -This is an example user folder. Personal project work takes place in folders saved at the same level as the central `codehub' folder. - -We require that the naming of the folder follow the following design pattern: - -> {device name}\_{project name}\_{optional subdirectory} - - -We require this naming structure in order to join different projects within the CNT data ecosystem. Multiple users can work within a single repository, either within the same directory or within their own optional subdirectories. 
- -## Updating the codehub libraries - -Any changes to scripts within the [modules](../codehub/modules) subdiretory can be submitted to the main lab repository as its own branch, at which point a pull request will be reviewed before changes are accepted or rejected. diff --git a/unit_tests/EEG/unit_tests.py b/unit_tests/EEG/unit_tests.py index 60d82c7d..df809381 100644 --- a/unit_tests/EEG/unit_tests.py +++ b/unit_tests/EEG/unit_tests.py @@ -41,27 +41,28 @@ def __init__(self,args): def run_tests(self): - try: - self.test_header() - self.test_channels() - self.test_sampfreq() - self.load_data_mne() - self.load_data_pyedf() - self.compare_libraries() - self.check_nan() - self.check_running_stats(self.args.sampfreq+1) - except Exception as e: - print(e) + self.test_header() + self.test_channels() + self.test_sampfreq() + self.load_data_mne() + self.load_data_pyedf() + self.compare_libraries() + self.check_nan() + self.check_running_stats(self.args.sampfreq+1) + + def failure(self,istr): + print(istr) + if not args.silent: exit(1) - def failure(self): - raise Exception() - def test_header(self): # Read in the header self.header = read_edf_header(args.infile) + + print("HEADER:\n===============") print(self.header) + print("\n===============") # Ensure casing of the keywords header_keys = list(self.header.keys()) @@ -71,34 +72,34 @@ def test_header(self): # Check the dataset level required header info for ikey in self.required_dataset_headers: - if ikey.lower() not in self.header_keys(): - raise Exception(f"Header missing the {ikey} information.") + if ikey.lower() not in self.header_keys: + self.failure(f"Header missing the {ikey} information.") if self.header[ikey] == None or self.header[ikey] == '': - raise Exception(f"Header missing the {ikey} information.") + self.failure(f"Header missing the {ikey} information.") # Check that the channel headers are all present and contain data channel_header_mask = [] channel_header_entry_mask = [] - for ival in self.header['SignalHeaders']: + for ival in self.header['signalheaders']: ikeys = list(ival.keys()) channel_header_mask.append(all(tmp in ikeys for tmp in self.required_channel_headers)) channel_header_entry_mask.extend([ival[tmp]==None for tmp in self.required_channel_headers]) # Raise exceptions if poorly defined header is found if any(channel_header_mask) == False: - raise Exception("Header contains missing information.") + self.failure("Header contains missing information.") if any(channel_header_entry_mask) == True: - raise Exception("Header contains missing information.") + self.failure("Header contains missing information.") def test_sampfreq(self): # Obtain raw channel names - samp_freqs = np.array([int(ival['sample_rate']) for ival in self.header['SignalHeaders']]) + samp_freqs = np.array([int(ival['sample_rate']) for ival in self.header['signalheaders']]) # Check against the expected frequency freq_mask = (samp_freqs!=self.args.sampfreq) if (freq_mask).any(): - raise Exception(f"Unexpted sampling frequency found in {self.channels[freq_mask]}") + self.failure(f"Unexpected sampling frequency found in {self.channels[freq_mask]}") def test_channels(self): @@ -122,23 +123,25 @@ def test_channels(self): # Make sure all channels are present if not all(channel_check): - raise Exception() + self.failure("Could not find all the expected channels") # Check number of channels if self.channels.size != self.ref_channels.size: - raise Exception("Did not receive expected number of channels. 
This can arise due to poorly inputted channels.") + self.failure("Did not receive expected number of channels. This can arise from incorrectly specified channels.") def load_data_mne(self): - self.mne_data = read_raw_edf(self.args.infile).get_data() + self.mne_data = read_raw_edf(self.args.infile).get_data()[:,self.args.start_samp:self.args.end_samp] def load_data_pyedf(self): self.pyedf_data, self.pyedf_chan_info,_ = read_edf(self.args.infile) self.pyedf_data *= 1e-6 + self.pyedf_data = self.pyedf_data[:,self.args.start_samp:self.args.end_samp] def compare_libraries(self,tol=1e-8): diffs=self.mne_data-self.pyedf_data if (diffs>tol).any(): + print("MNE and pyEDFlib data differ by more than the allowed tolerance.") exit(1) # Drop the pyedf data to reduce memory usage now that we dont need it @@ -147,7 +150,7 @@ def compare_libraries(self,tol=1e-8): def check_nan(self): if np.isnan(self.mne_data).any(): - raise Exception("NaNs found in the data.") + self.failure("NaNs found in the data.") def check_running_stats(self,window_size): @@ -177,7 +180,7 @@ def check_running_stats(self,window_size): # Get the channel wide variance sum. Zero means all channels had zero variance for the window size mask = variance_array.sum(axis=0)==0 if (mask).any(): - raise Exception(f"All channels have zero variance around second {self.args.sampfreq*np.arange(mask.size)[mask]} seconds.") + self.failure(f"All channels have zero variance around {self.args.sampfreq*np.arange(mask.size)[mask]} seconds.") if __name__ == '__main__': @@ -187,6 +190,8 @@ def check_running_stats(self,window_size): parser.add_argument("--sampfreq", type=int, default=256, help='Expected sampling frequency') parser.add_argument("--channel_file", type=str, default='configs/hup_standard.csv', help='CSV file containing the expected channels') parser.add_argument("--silent", action='store_true', default=False, help="Silence exceptions.") + parser.add_argument("--start_samp", type=int, default=0, help="Start sample to read data from. Useful when spot checking a large file. (Warning: the full file is still loaded into memory first.)") + parser.add_argument("--end_samp", type=int, default=-1, help="End sample to read data to. Useful when spot checking a large file. (Warning: the full file is still loaded into memory first.)") args = parser.parse_args() # Run machine level tests
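As a usage sketch for the updated unit tests: assuming the script also exposes an --infile argument (it is referenced as args.infile above but defined outside this hunk), spot checking the first minute of a 256 Hz recording could look like:

> python unit_tests/EEG/unit_tests.py --infile /path/to/recording.edf --sampfreq 256 --channel_file configs/hup_standard.csv --start_samp 0 --end_samp 15360

Here 15360 samples corresponds to 60 s at 256 Hz; as the help text warns, the full file is still loaded into memory before the sample range is applied.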
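The viewer's new --pickle_load option, per its help text, accepts a pickled tuple/list of (dataframe, fs) or a bare dataframe combined with --fs. A minimal sketch of preparing such a file follows; the channel names, sampling frequency, output filename, and the edf_viewer.py invocation are illustrative assumptions, since the loader itself now lives in components/internal/plot_handler.py and is not shown in this diff.

    # Sketch only: build a pickle in the format the --pickle_load help text describes.
    import pickle
    import numpy as np
    import pandas as pd

    fs   = 256.0                                  # sampling frequency in Hz (illustrative)
    data = np.random.randn(10 * int(fs), 3)       # 10 seconds of placeholder data, 3 channels
    DF   = pd.DataFrame(data, columns=['C03', 'C04', 'CZ'])   # hypothetical channel names

    with open('example_eeg.pickle', 'wb') as fp:
        pickle.dump((DF, fs), fp)                 # tuple of (dataframe, sampling frequency)

    # Assumed invocations (script name is hypothetical):
    #   python edf_viewer.py --infile example_eeg.pickle --pickle_load
    #   python edf_viewer.py --infile example_eeg.pickle --pickle_load --fs 256   # if only the dataframe is pickled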