From 48e65113260656f97e11cbf0a495a49b5cb08e09 Mon Sep 17 00:00:00 2001 From: Lei Haixia <1364747481@qq.com> Date: Wed, 9 Apr 2025 20:59:01 +0800 Subject: [PATCH 1/2] add API of "FES" --- demos/brainflow_demos/FES.py | 22 ++ demos/brainflow_demos/Online_mi_FES.py | 340 +++++++++++++++++++++++++ metabci/brainflow/ElectroStimulator.py | 223 ++++++++++++++++ 3 files changed, 585 insertions(+) create mode 100644 demos/brainflow_demos/FES.py create mode 100644 demos/brainflow_demos/Online_mi_FES.py create mode 100644 metabci/brainflow/ElectroStimulator.py diff --git a/demos/brainflow_demos/FES.py b/demos/brainflow_demos/FES.py new file mode 100644 index 00000000..c81bfa31 --- /dev/null +++ b/demos/brainflow_demos/FES.py @@ -0,0 +1,22 @@ +from metabci.brainflow.ElectroStimulator import ElectroStimulator +stim = ElectroStimulator('COM1') # 串口号 + +# 启用通道1并设置参数 +stim.select_channel(1) # 启用通道1 +stim.set_channel_parameters(1, { + ElectroStimulator._Param.current_positive: 2, + ElectroStimulator._Param.current_negative: 2, + ElectroStimulator._Param.pulse_positive: 250, + ElectroStimulator._Param.pulse_negative: 250, + ElectroStimulator._Param.frequency: 50, + ElectroStimulator._Param.rise_time: 500, + ElectroStimulator._Param.stable_time: 3000, + ElectroStimulator._Param.descent_time: 500 + }) + +# 锁定参数并启动 +stim.lock_parameters() +stim.run_stimulation(duration=10) + +if 'stim' in locals(): + stim.close() diff --git a/demos/brainflow_demos/Online_mi_FES.py b/demos/brainflow_demos/Online_mi_FES.py new file mode 100644 index 00000000..75ef5580 --- /dev/null +++ b/demos/brainflow_demos/Online_mi_FES.py @@ -0,0 +1,340 @@ +# -*- coding: utf-8 -*- +# License: MIT License +""" +MI Feedback on NeuroScan and FES. + +""" +import time +import numpy as np + +import mne +from mne.filter import resample +from pylsl import StreamInfo, StreamOutlet +from metabci.brainflow.amplifiers import NeuroScan, Marker +from metabci.brainflow.workers import ProcessWorker +from metabci.brainda.algorithms.decomposition.base import generate_filterbank +from metabci.brainda.algorithms.utils.model_selection \ + import EnhancedLeaveOneGroupOut +from metabci.brainda.algorithms.decomposition.csp import FBCSP +from metabci.brainda.utils import upper_ch_names +from mne.io import read_raw_cnt +from sklearn.svm import SVC +from sklearn.base import BaseEstimator, ClassifierMixin +from sklearn.pipeline import make_pipeline +from scipy import signal +import threading +from metabci.brainflow.ElectroStimulator import ElectroStimulator + + +def label_encoder(y, labels): + new_y = y.copy() + for i, label in enumerate(labels): + ix = (y == label) + new_y[ix] = i + return new_y + + +class MaxClassifier(BaseEstimator, ClassifierMixin): + def __init__(self): + pass + + def fit(self, X, y): + pass + + def predict(self, X): + X = X.reshape((-1, X.shape[-1])) + y = np.argmax(X, axis=-1) + return y + + +def read_data(run_files, chs, interval, labels): + Xs, ys = [], [] + for run_file in run_files: + raw = read_raw_cnt(run_file, preload=True, verbose=False) + raw = upper_ch_names(raw) + raw.filter(6, 30, l_trans_bandwidth=2, h_trans_bandwidth=5, + phase='zero-double') + events = mne.events_from_annotations( + raw, event_id=lambda x: int(x), verbose=False)[0] + ch_picks = mne.pick_channels(raw.ch_names, chs, ordered=True) + epochs = mne.Epochs(raw, events, + event_id=labels, + tmin=interval[0], + tmax=interval[1], + baseline=None, + picks=ch_picks, + verbose=False) + + for label in labels: + X = epochs[str(label)].get_data()[..., 1:] + Xs.append(X) + ys.append(np.ones((len(X))) * label) + Xs = np.concatenate(Xs, axis=0) + ys = np.concatenate(ys, axis=0) + ys = label_encoder(ys, labels) + + return Xs, ys, ch_picks + + +def bandpass(sig, freq0, freq1, srate, axis=-1): + wn1 = 2 * freq0 / srate + wn2 = 2 * freq1 / srate + b, a = signal.butter(4, [wn1, wn2], 'bandpass') + sig_new = signal.filtfilt(b, a, sig, axis=axis) + return sig_new + + +# 训练模型 + + +def train_model(X, y, srate=1000): + y = np.reshape(y, (-1)) + # 降采样 + X = resample(X, up=256, down=srate) + # 滤波 + # X = bandpass(X, 6, 30, 256) + # 零均值单位方差 归一化 + X = X - np.mean(X, axis=-1, keepdims=True) + X = X / np.std(X, axis=(-1, -2), keepdims=True) + # brainda.algorithms.decomposition.csp.MultiCSP + wp = [(4, 8), (8, 12), (12, 30)] + ws = [(2, 10), (6, 14), (10, 32)] + filterbank = generate_filterbank(wp, ws, srate=256, order=4, rp=0.5) + # model = make_pipeline( + # MultiCSP(n_components = 2), + # LinearDiscriminantAnalysis()) + model = make_pipeline(*[ + FBCSP(n_components=5, + n_mutualinfo_components=4, + filterbank=filterbank), + SVC() + ]) + # fit()训练模型 + model = model.fit(X, y) + + return model + + +# 预测标签 + + +def model_predict(X, srate=1000, model=None): + X = np.reshape(X, (-1, X.shape[-2], X.shape[-1])) + # 降采样 + X = resample(X, up=256, down=srate) + # 滤波 + X = bandpass(X, 8, 30, 256) + # 零均值单位方差 归一化 + X = X - np.mean(X, axis=-1, keepdims=True) + X = X / np.std(X, axis=(-1, -2), keepdims=True) + # predict()预测标签 + p_labels = model.predict(X) + return p_labels + + +# 计算离线正确率 + + +def offline_validation(X, y, srate=1000): + y = np.reshape(y, (-1)) + spliter = EnhancedLeaveOneGroupOut(return_validate=False) + + kfold_accs = [] + for train_ind, test_ind in spliter.split(X, y=y): + X_train, y_train = np.copy(X[train_ind]), np.copy(y[train_ind]) + X_test, y_test = np.copy(X[test_ind]), np.copy(y[test_ind]) + + model = train_model(X_train, y_train, srate=srate) + p_labels = model_predict(X_test, srate=srate, model=model) + kfold_accs.append(np.mean(p_labels == y_test)) + + return np.mean(kfold_accs) + + +class FeedbackWorker(ProcessWorker): + def __init__(self, + run_files, + pick_chs, + stim_interval, + stim_labels, + srate, + lsl_source_id, + timeout, + worker_name): + self.run_files = run_files + self.pick_chs = pick_chs + self.stim_interval = stim_interval + self.stim_labels = stim_labels + self.srate = srate + self.lsl_source_id = lsl_source_id + super().__init__(timeout=timeout, name=worker_name) + self.stimulator = None # 电刺激器 + self.stim_lock = None # 线程锁 + + def pre(self): + X, y, ch_ind = read_data(run_files=self.run_files, + chs=self.pick_chs, + interval=self.stim_interval, + labels=self.stim_labels) + print("Loding data successfully") + acc = offline_validation(X, y, srate=self.srate) # 计算离线准确率 + print("Current Model accuracy:", acc) + self.estimator = train_model(X, y, srate=self.srate) + self.stimulator = ElectroStimulator('COM3') + self.stim_lock = threading.Lock() # 在子进程中初始化锁 + print("电刺激器初始化成功") + self.ch_ind = ch_ind + info = StreamInfo( + name='meta_feedback', + type='Markers', + channel_count=1, + nominal_srate=0, + channel_format='int32', + source_id=self.lsl_source_id) + self.outlet = StreamOutlet(info) + print('Waiting connection...') + while not self._exit: + if self.outlet.wait_for_consumers(1e-3): + break + print('Connected') + + def _stimulate(self, channels, params_list, duration=4): + """电刺激线程函数""" + with self.stim_lock: + try: + # 清除所有已选通道 + for ch in list(self.stimulator._selected_channels): + self.stimulator.disable_channel(ch) + + # 设置多个通道参数 + for channel, params in zip(channels, params_list): + self.stimulator.select_channel(channel) + self.stimulator.set_channel_parameters(channel, params) + self.stimulator.lock_parameters() + self.stimulator.run_stimulation(duration) + + except Exception as e: + print(f"电刺激控制出错: {e}") + + def consume(self, data): + # 电刺激参数配置 + params_ch1 = { + ElectroStimulator._Param.current_positive: 2, + ElectroStimulator._Param.current_negative: 2, + ElectroStimulator._Param.pulse_positive: 250, + ElectroStimulator._Param.pulse_negative: 250, + ElectroStimulator._Param.frequency: 50, + ElectroStimulator._Param.rise_time: 500, + ElectroStimulator._Param.stable_time: 3000, + ElectroStimulator._Param.descent_time: 500 + } + params_ch2 = { + ElectroStimulator._Param.current_positive: 2, + ElectroStimulator._Param.current_negative: 2, + ElectroStimulator._Param.pulse_positive: 250, + ElectroStimulator._Param.pulse_negative: 250, + ElectroStimulator._Param.frequency: 50, + ElectroStimulator._Param.rise_time: 500, + ElectroStimulator._Param.stable_time: 3000, + ElectroStimulator._Param.descent_time: 500 + } + data = np.array(data, dtype=np.float64).T + data = data[self.ch_ind] + p_labels = model_predict(data, srate=self.srate, model=self.estimator) + p_labels = int(p_labels) + p_labels = p_labels + 1 + # 根据标签选择通道 + if p_labels == 1: + print("激活通道1") + stim_thread = threading.Thread( + target=self._stimulate, + args=([1], [params_ch1])) + elif p_labels == 2: + print("激活通道2") + stim_thread = threading.Thread( + target=self._stimulate, + args=([2], [params_ch2])) + else: + return + # 启动电刺激线程 + stim_thread.start() + p_labels = [p_labels] + # p_labels = p_labels.tolist() + print(p_labels) + + if self.outlet.have_consumers(): + self.outlet.push_sample(p_labels) + + def post(self): + # 关闭电刺激器连接 + if self.stimulator: + self.stimulator.close() + + +if __name__ == '__main__': + # 放大器的采样率 + srate = 1000 + # 截取数据的时间段,考虑进视觉刺激延迟140ms + stim_interval = [0, 4] + # 事件标签 + stim_labels = list(range(1, 3)) + cnts = 1 # .cnt数目 + # 数据路径 + filepath = "E:\\lhx\\" + runs = list(range(1, cnts + 1)) + run_files = ['{:s}\\{:d}.cnt'.format( + filepath, run) for run in runs] # 具体数据路径 + pick_chs = ['FC5', 'FC3', 'FC1', 'FCZ', 'FC2', + 'FC4', 'FC6', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', + 'CP5', 'CP3', 'CP1', 'CPZ', 'CP2', 'CP4', 'CP6', 'P5', + 'P3', 'P1', 'PZ', 'P2', 'P4', 'P6'] + + lsl_source_id = 'meta_online_worker' + feedback_worker_name = 'feedback_worker' + + worker = FeedbackWorker(run_files=run_files, + pick_chs=pick_chs, + stim_interval=stim_interval, + stim_labels=stim_labels, + srate=srate, + lsl_source_id=lsl_source_id, + timeout=5e-2, + worker_name=feedback_worker_name) # 在线处理 + marker = Marker(interval=stim_interval, srate=srate, + events=stim_labels) # 打标签全为1 + # worker.pre() + + ns = NeuroScan( + device_address=('169.254.80.232', 4000), + srate=srate, + num_chans=64) # NeuroScan parameter + + # 与ns建立tcp连接 + ns.connect_tcp() + # ns开始采集波形数据 + ns.start_acq() + + # register worker来实现在线处理 + ns.register_worker(feedback_worker_name, worker, marker) + # 开启在线处理进程 + ns.up_worker(feedback_worker_name) + # 等待 0.5s + time.sleep(0.5) + + # ns开始截取数据线程,并把数据传递数据给处理进程 + ns.start_trans() + + # 任意键关闭处理进程 + input('press any key to close\n') + # 关闭处理进程 + ns.down_worker('feedback_worker') + # 等待 1s + time.sleep(1) + + # ns停止在线截取线程 + ns.stop_trans() + # ns停止采集波形数据 + ns.stop_acq() + ns.close_connection() # 与ns断开连接 + ns.clear() + print('bye') diff --git a/metabci/brainflow/ElectroStimulator.py b/metabci/brainflow/ElectroStimulator.py new file mode 100644 index 00000000..2a2a0912 --- /dev/null +++ b/metabci/brainflow/ElectroStimulator.py @@ -0,0 +1,223 @@ +# -*- coding: utf-8 -*- +""" +Control the electrical stimulator.Implements control of electrical stimulation parameters via serial communication, including channel selection, waveform parameter configuration, parameter locking, and therapy start/stop operations. + +""" +import serial +import struct +import time +from enum import IntEnum +from serial.serialutil import SerialException +from typing import Set, Dict + + +class ElectroStimulator: + """ + Electrical stimulator controller class for multichannel parameter configuration and pulse therapy control. + + author: Haixia Lei + + Created on: 2024-04-08 + + update log: + None + + Parameters + ---------- + port : str + Serial port device path. + baudrate : int + Communication baud rate (default 115200). + + Attributes + ---------- + _is_locked : bool + Parameter locking status, prohibits parameter modification when locked. + _selected_channels : Set[int] + Currently enabled therapy channels (0-12). + + Raises + ---------- + RuntimeError + Raised when serial connection fails or state machine rules are violated. + ValueError + Raised when channel number or parameter values exceed valid ranges. + + Note + ---------- + 1. Parameters must be modified before locking. Only start/stop operations are allowed after locking. + 2. The device connection must be reinitialized after calling close(). + 3. For channel-specific parameters, ensure the channel is selected first. + + """ + class _Param(IntEnum): + """Parameter address""" + channel_select = 0x10 # 通道选择 + rise_time = 0x11 # 斜升时间 (ms) + stable_time = 0x12 # 稳定时间 (ms) + descent_time = 0x13 # 斜降时间 (ms) + current_positive = 0x18 # 正脉冲峰值 (mA) + pulse_positive = 0x19 # 正脉冲宽度 (us) + current_negative = 0x1A # 负脉冲峰值 (mA) + pulse_negative = 0x1B # 负脉冲宽度 (us) + frequency = 0x1D # 频率 (Hz) + small_cycles = 0x20 # 小周期次数 + big_cycles = 0x21 # 大周期次数 + small_interval = 0x22 # 小周期间隔 + big_interval = 0x23 # 大周期间隔 + lock = 0xF1 # 参数锁定 + start = 0xFA # 开始治疗 + stop = 0xFC # 停止治疗 + + def __init__(self, port, baudrate=115200): + self.ser = None + self._is_locked = False # 参数锁定状态 + self._selected_channels: Set[int] = set() # 存储已选通道 + + # 初始化串口连接 + try: + self.ser = serial.Serial( + port=port, + baudrate=baudrate, + bytesize=8, + parity='N', + stopbits=1, + timeout=1 + ) + print(f"Connected to {port}") + except SerialException as e: + raise RuntimeError(f"Failed to open serial port: {e}") from None + + def _validate_channel(self, channel): + """Verify channel number validity (0-12)""" + if not 0 <= channel <= 12: + raise ValueError(f"Invalid channel {channel}, must be 0-12") + + def select_channel(self, channel: int, enable: bool = True): + """Select or deselect a therapy channel (must be called before locking). + + Parameters + ---------- + channel : int + Target channel number (0-12). + enable : bool + Enable/disable the channel (default True). + """ + if self._is_locked: + raise RuntimeError("Channel selection cannot be modified after the parameters are locked") + + self._validate_channel(channel) + + # 设置通道使能位 + # 数据格式:0x0001 表示启用,0x0000 表示禁用 + value = 0x0001 if enable else 0x0000 + self.set_parameter(channel, self._Param.channel_select, value) + + # 更新已选通道集合 + if enable: + self._selected_channels.add(channel) + elif channel in self._selected_channels: + self._selected_channels.remove(channel) + + print(f"通道 {channel} {'已启用' if enable else '已禁用'}") + + def disable_channel(self, channel: int): + """Disable channel.""" + self.select_channel(channel, enable=False) + + def set_channel_parameters(self, channel: int, params: Dict[_Param, int]): + """Set channel parameters in batches.""" + for param, value in params.items(): + self.set_parameter(channel, param, value) + + def _build_frame(self, channel, param_addr, data_value): + """Build protocol data frames.""" + # 验证通道号 + self._validate_channel(channel) + + # 数据区转换(16位高位在前) + try: + data_bytes = struct.pack('>H', data_value) + except struct.error: + raise ValueError(f"Invalid data value: {data_value} (0-65535)") from None + + # 计算总长度:n*2 + 4(n=1) + total_length = struct.pack('B', 1*2 + 4) + + # 组合数据帧 + return (b'\x5A\xA5' + # 帧头 + total_length + # 总长度 + b'\x93' + # 写命令 + struct.pack('B', channel) + + struct.pack('B', param_addr) + + b'\x01' + # 数据长度 + data_bytes) + + def set_parameter(self, channel, param_addr, value): + """Set stimulation parameters for a channel. + + Parameters + ---------- + channel : int + Target channel number (0-12). Use 0 for global commands. + param_addr : ElectroStimulator._Param + Register address (e.g., _Param.frequency). + value : int + Parameter value (0-65535). Specific ranges depend on the parameter. + """ + try: + # 参数锁定后禁止修改参数(全局命令除外) + if self._is_locked and channel != 0: + raise RuntimeError("Cannot modify parameters after locking") + + frame = self._build_frame(channel, param_addr, value) + self.ser.write(frame) + print(f"Set Channel {channel}: Addr 0x{param_addr:02X} = {value}") + + # 添加操作间隔防止设备过载 + time.sleep(0.1) + + except SerialException as e: + raise RuntimeError(f"Serial communication failed: {e}") from None + + def lock_parameters(self): + """Lock all parameters to prevent accidental modifications.At least one channel must be selected.""" + if not self._selected_channels: + raise RuntimeError("At least one channel must be selected before locking") + + if self._is_locked: + print("Parameters already locked") + return + + self.set_parameter(0, self._Param.lock, 0x0001) + self._is_locked = True + print("Parameters Locked") + + # 等待设备确认锁定 + time.sleep(0.5) + + def run_stimulation(self, duration: int): + """Start therapy and automatically stop after a specified duration. + + Parameters + ---------- + duration : int + Therapy duration. + """ + if not self._is_locked: + raise RuntimeError("Must lock parameters before starting stimulation") + if not self._selected_channels: + raise RuntimeError("There is no effective treatment channel") + self.set_parameter(0, self._Param.start, 0x0001) + print(f"治疗已启动,激活通道: {sorted(self._selected_channels)}") + + time.sleep(duration) + self.set_parameter(0, self._Param.stop, 0x0001) + print("治疗结束") + self._is_locked = False # 停止后自动解锁 + + def close(self): + """Safely terminate the serial connection.""" + if self.ser and self.ser.is_open: + self.ser.close() + print("Serial port closed") From 6ec78c588a160bd14effd452ea407d3233efeaad Mon Sep 17 00:00:00 2001 From: Lei Haixia <1364747481@qq.com> Date: Thu, 5 Jun 2025 14:51:03 +0800 Subject: [PATCH 2/2] add API of "FES" --- metabci/brainflow/ElectroStimulator.py | 32 +++++++++++++++----------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/metabci/brainflow/ElectroStimulator.py b/metabci/brainflow/ElectroStimulator.py index 2a2a0912..b6364e4d 100644 --- a/metabci/brainflow/ElectroStimulator.py +++ b/metabci/brainflow/ElectroStimulator.py @@ -1,7 +1,8 @@ # -*- coding: utf-8 -*- """ -Control the electrical stimulator.Implements control of electrical stimulation parameters via serial communication, including channel selection, waveform parameter configuration, parameter locking, and therapy start/stop operations. - +Control the electrical stimulator.Implements control of electrical stimulation +parameters via serial communication, including channel selection, waveform +parameter configuration, parameter locking, and therapy start/stop operations. """ import serial import struct @@ -13,7 +14,8 @@ class ElectroStimulator: """ - Electrical stimulator controller class for multichannel parameter configuration and pulse therapy control. + Electrical stimulator controller class for + multichannel parameter configuration and pulse therapy control. author: Haixia Lei @@ -39,16 +41,18 @@ class ElectroStimulator: Raises ---------- RuntimeError - Raised when serial connection fails or state machine rules are violated. + Raised when serial connection fails + or state machine rules are violated. ValueError Raised when channel number or parameter values exceed valid ranges. - + Note ---------- - 1. Parameters must be modified before locking. Only start/stop operations are allowed after locking. + 1. Parameters must be modified before locking. + Only start/stop operations are allowed after locking. 2. The device connection must be reinitialized after calling close(). 3. For channel-specific parameters, ensure the channel is selected first. - + """ class _Param(IntEnum): """Parameter address""" @@ -94,7 +98,8 @@ def _validate_channel(self, channel): raise ValueError(f"Invalid channel {channel}, must be 0-12") def select_channel(self, channel: int, enable: bool = True): - """Select or deselect a therapy channel (must be called before locking). + """Select or deselect a therapy channel + (must be called before locking). Parameters ---------- @@ -104,7 +109,7 @@ def select_channel(self, channel: int, enable: bool = True): Enable/disable the channel (default True). """ if self._is_locked: - raise RuntimeError("Channel selection cannot be modified after the parameters are locked") + raise RuntimeError("Channel selection cannot be modified") self._validate_channel(channel) @@ -139,7 +144,7 @@ def _build_frame(self, channel, param_addr, data_value): try: data_bytes = struct.pack('>H', data_value) except struct.error: - raise ValueError(f"Invalid data value: {data_value} (0-65535)") from None + raise ValueError(f"Invalid data value:{data_value}") from None # 计算总长度:n*2 + 4(n=1) total_length = struct.pack('B', 1*2 + 4) @@ -181,9 +186,10 @@ def set_parameter(self, channel, param_addr, value): raise RuntimeError(f"Serial communication failed: {e}") from None def lock_parameters(self): - """Lock all parameters to prevent accidental modifications.At least one channel must be selected.""" + """Lock all parameters to prevent accidental modifications. + At least one channel must be selected.""" if not self._selected_channels: - raise RuntimeError("At least one channel must be selected before locking") + raise RuntimeError("At least one channel must be selected") if self._is_locked: print("Parameters already locked") @@ -205,7 +211,7 @@ def run_stimulation(self, duration: int): Therapy duration. """ if not self._is_locked: - raise RuntimeError("Must lock parameters before starting stimulation") + raise RuntimeError("Must lock parameters before starting") if not self._selected_channels: raise RuntimeError("There is no effective treatment channel") self.set_parameter(0, self._Param.start, 0x0001)