Skip to content

Commit c4bb806

Browse files
committed
Add get_raw method for easier data access
1 parent 4e0e039 commit c4bb806

File tree

1 file changed

+51
-0
lines changed

1 file changed

+51
-0
lines changed

pyprep/prep_pipeline.py

+51
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,57 @@ def current_reference_signal(self):
206206
post_interp = self._interpolated_reference_signal
207207
return post_interp if post_interp else post_ref
208208

209+
def get_raw(self, stage=None):
210+
"""Retrieve the full recording data at a given stage of the pipeline.
211+
212+
Valid pipeline stages include 'unprocessed' (the raw data prior to running
213+
the pipeline), 'filtered' (the data following adaptive line noise
214+
removal), 'post-reference' (the data after robust referencing, prior to any
215+
bad channel interpolation), and 'post-interpolation' (the data after robust
216+
referencing and bad channel interpolation).
217+
218+
Parameters
219+
----------
220+
stage : str, optional
221+
The stage of the pipeline for which the full data will be retrieved. If
222+
not specified, the current state of the data will be retrieved.
223+
224+
Returns
225+
-------
226+
full_raw: mne.io.Raw
227+
An MNE Raw object containing the EEG data for the given stage of the
228+
pipeline, along with any non-EEG channels that were present in the
229+
original input data.
230+
231+
"""
232+
interpolated = self.interpolated_channels is not None
233+
stages = {
234+
"unprocessed": self.EEG_raw,
235+
"filtered": self.EEG_filtered,
236+
"post-reference": self.EEG_post_reference,
237+
"post-interpolation": self.raw_eeg._data if interpolated else None,
238+
}
239+
if stage is not None and stage.lower() not in stages.keys():
240+
raise ValueError(
241+
"'{stage}' is not a valid pipeline stage. Valid stages are "
242+
"'unprocessed', 'filtered', 'post-reference', and 'post-interpolation'."
243+
)
244+
245+
eeg_data = self.raw_eeg._data # Default to most recent stage of pipeline
246+
if stage:
247+
eeg_data = stages[stage.lower()]
248+
if not eeg_data:
249+
raise ValueError(
250+
"Could not retrieve {stage} data, as that stage of the pipeline "
251+
"has not yet been performed."
252+
)
253+
full_raw = self.raw_eeg.copy()
254+
full_raw._data = eeg_data
255+
if self.raw_non_eeg is not None:
256+
full_raw.add_channels([self.raw_non_eeg])
257+
258+
return full_raw
259+
209260
def remove_line_noise(self, line_freqs):
210261
"""Remove line noise from all EEG channels using multi-taper decomposition.
211262

0 commit comments

Comments
 (0)