1 Star 3 Fork 0

Aro/脑电新手救星

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
pre.py 9.77 KB
一键复制 编辑 原始数据 按行查看 历史
qq429379744 提交于 2023-05-22 16:20 . 6
# -*- coding: utf-8 -*-
"""
Created on Tue May 10 15:43:29 2022
@author: 42937
"""
import numpy as np
import mne
import cv2
import sys
from mne.preprocessing import ICA
from mne.time_frequency import tfr_morlet
from matplotlib import pyplot as plt
import matplotlib
from mne_icalabel import label_components
import math
from PIL import Image
from matplotlib.figure import Figure
import base64
# 可设置参数
# 3e-4
def param(drop=[], lowf=0.1, highf=30, split_reject=2e-4):
drop = drop
lowf = lowf
highf = highf
split_reject = split_reject
return drop, lowf, highf, split_reject
# 读取文件(该部分没有改动,只删减了电极文件部分,以及提取了文件名用于最后的数据文件命名)
def read_data(data_path, show_plot=False, show_plot_psd=False,
show_plot_sensors=False, show_plot_topo=False, locs_info_path=None):
# 文件后缀
split_names = data_path.rsplit(".")
names_length = len(split_names)
file_suffix_name = split_names[names_length - 1]
file_name = split_names[names_length - 2]
if file_suffix_name == "set":
raw = mne.io.read_raw_eeglab(data_path, preload=True)
elif file_suffix_name == "vhdr":
raw = mne.io.read_raw_brainvision(data_path, preload=True)
elif file_suffix_name == "edf":
raw = mne.io.read_raw_edf(data_path, preload=True)
elif file_suffix_name == "bdf":
raw = mne.io.read_raw_bdf(data_path, preload=True)
elif file_suffix_name == "gdf":
raw = mne.io.read_raw_gdf(data_path, preload=True)
elif file_suffix_name == "cnt":
raw = mne.io.read_raw_cnt(data_path, preload=True)
elif file_suffix_name == "egi" or file_suffix_name == "mff":
raw = mne.io.read_raw_egi(data_path, preload=True)
elif file_suffix_name == "data":
raw = mne.io.read_raw_nicolet(data_path, preload=True)
elif file_suffix_name == "nxe":
raw = mne.io.read_raw_eximia(data_path, preload=True)
elif file_suffix_name == "lay" or file_suffix_name == "dat":
raw = mne.io.read_raw_persyst(data_path, preload=True)
else:
# 如果选择的文件不包含在以上后缀文件中将返回字符串
return "没有该指定的后缀文件"
# 提取文件名
if "\\" in file_name:
split_names = file_name.rsplit("\\")
names_length = len(split_names)
file_name = split_names[names_length - 1]
if "/" in file_name:
split_names = file_name.rsplit("/")
names_length = len(split_names)
file_name = split_names[names_length - 1]
# 展示给后台的数据
# print(raw)
# print(raw.info)
# 返回原始数据
return raw, file_name
# 滤波
# 工频可能为50Hz或60Hz,检测出来;高低通滤波默认为0.1Hz和30Hz
def wave_filter(raw, lowf, highf):
a = raw.plot_psd(average=True)
ax = a.axes[0]
# x坐标为频率,y坐标为能量
freqs = ax.lines[-1].get_xdata()
psds = ax.lines[-1].get_ydata()
cou1 = cou2 = 0
# 比较50Hz和60Hz的能量,较高者则为工频,执行陷波滤波
for f in freqs:
if f < 60:
if f < 50:
cou1 = cou1 + 1
cou2 = cou2 + 1
else:
break
if psds[cou1] > psds[cou2]:
filt = 50
else:
filt = 60
raw = raw.notch_filter(filt) # 陷波滤波
raw.compute_psd().plot(average=True)
raw = raw.filter(lowf, highf) # ⾼低通滤波
return raw
# 导联格式转换及筛选
# 将格式与1020系统不符合的导联转换,1020系统中没有的导联去除,如此以⽣成标准的国际10-20系统
def channels_trans(raw, drop):
# 1020系统包含的导联
channels_1020 = ['LPA', 'RPA', 'Nz', 'Fp1', 'Fpz', 'Fp2', 'AF9', 'AF7', 'AF5', 'AF3', 'AF1', 'AFz', 'AF2', 'AF4',
'AF6', 'AF8', 'AF10', 'F9', 'F7', 'F5', 'F3', 'F1', 'Fz', 'F2', 'F4', 'F6', 'F8', 'F10', 'FT9',
'FT7', 'FC5', 'FC3',
'FC1', 'FCz', 'FC2', 'FC4', 'FC6', 'FT8', 'FT10', 'T9', 'T7', 'C5', 'C3', 'C1', 'Cz', 'C2', 'C4',
'C6', 'T8', 'T10',
'TP9', 'TP7', 'CP5', 'CP3', 'CP1', 'CPz', 'CP2', 'CP4', 'CP6', 'TP8', 'TP10', 'P9', 'P7', 'P5',
'P3', 'P1', 'Pz',
'P2', 'P4', 'P6', 'P8', 'P10', 'PO9', 'PO7', 'PO5', 'PO3', 'PO1', 'POz', 'PO2', 'PO4', 'PO6',
'PO8',
'PO10', 'O1', 'Oz', 'O2', 'O9', 'Iz', 'O10', 'T3', 'T5', 'T4', 'T6', 'M1', 'M2', 'A1', 'A2']
channels = raw.info["ch_names"]
trans_old = [] # 需要被替换的旧导联名
trans_new = [] # 用以替换的新导联名
low_1020 = [x.lower() for x in channels_1020] # 将1020系统的导联名换成小写
for i in channels:
if i not in channels_1020: # 筛选出输入与1020系统不匹配的导联名
j = i.lower()
# 将输入导联名换成小写,与小写的1020系统比较,若匹配,则说明只是大小写错误,将输入转换成1020系统中的导联名;
# 否则说明1020系统没有该导联,将其去掉
if j in low_1020:
trans_old.append(i)
k = low_1020.index(j)
trans_new.append(channels_1020[k])
else:
drop.append(i)
trans = {trans_old[i]: trans_new[i] for i in range(len(trans_old))}
raw.rename_channels(trans) # 替换不匹配的导联名
raw.drop_channels(drop) # 去掉导联
print("去掉导联:" + str(drop)) # 被去掉的导联
print("剩下的导联" + str(raw.info["ch_names"]))
montage = mne.channels.make_standard_montage("standard_1020") # ⽣成标准的国际10-20系统,不需电极图
raw.set_montage(montage)
return raw
# 自动判断坏导联并插值重建
def artifact_remove(raw):
data = raw.get_data() # 提取数据
dta = len(data[0])
flag = 0
while dta > 1000: # 若数据量太大处理时长太久,所以抽样1000以下时刻对比
dta = int(dta / 10)
flag = flag + 1
dt = np.zeros((len(data), len(data[0][::pow(10, flag)])))
# print(len(dt[0]))
# print(len(data[0][::pow(10, flag)]))
# print(len(data[0]))
for chan1 in range(len(data)): # 抽样操作
dt[chan1] = data[chan1][::pow(10, flag)]
avg = [None] * len(dt[0])
sum = 0
bad = []
for time1 in range(len(dt[0])):
for chan4 in range(len(dt)):
sum = sum + dt[chan4][time1] # 累计,用以算平均数
avg[time1] = sum / len(dt[0]) # 算平均数
for time2 in range(len(dt[0])):
for chan2 in range(len(dt)):
if abs(dt[chan2][time2] - avg[time2]) > 1e-03: # 记录大于阈值的导联
if chan2 not in bad:
bad.append(chan2)
channels = raw.info["ch_names"]
for chan3 in bad:
raw.info['bads'].append(channels[chan3])
raw = raw.interpolate_bads() # 插值重建坏道
return raw
def ica(raw):
#raw.crop(tmax=60.0).pick_types(eeg=True)
raw.load_data()
filt_raw = raw.copy().filter(l_freq=1.0, h_freq=None)
filt_raw = filt_raw.set_eeg_reference("average")
ica = ICA(
n_components=15,
max_iter="auto",
method="infomax",
random_state=97,
fit_params=dict(extended=True),
)
# print(filt_raw)
ica.fit(filt_raw)
raw.load_data()
ica.plot_components()
ic_labels = label_components(filt_raw, ica, method="iclabel")
print("ICA成分依次为:" + str(ic_labels["labels"]))
labels = ic_labels["labels"]
exclude_idx = [idx for idx, label in enumerate(labels) if label not in ["brain", "other"]]
print(f"被剔除的成分序号: {exclude_idx}")
ica.exclude = exclude_idx
ica.apply(raw)
return raw
def split(raw, split_reject):
# 事件信息数据类型转换
events, event_id = mne.events_from_annotations(raw)
print(events.shape, event_id)
# print(raw.annotations)
# 数据分段
# 一开始想让客户选择处理哪个event,但如此就不够自动,所以还是把各个event都处理并生成预处理文件,再让用户选择进行分析或者保存
epochs = []
for key in event_id.values():
epoch = mne.Epochs(raw, events, event_id=int(key), tmin=-1, tmax=2, baseline=(-0.5, 0), \
preload=True, reject=dict(eeg=split_reject))
epochs.append(epoch)
print(epochs)
return epochs, list(event_id.keys())
def average(epochs):
evokeds = []
for epoch in epochs:
evoked = epoch.average()
evokeds.append(evoked)
return evokeds
def get_data(evokeds, file_name, event_id):
a = 1
for evoked in evokeds:
evoked_array = evoked.get_data()
# 提取获取的数据
# 默认将数据添加到可选列表以直接用于多被试分析(临时储存),另外加一个按钮可选择长期保存且可选储存路径
print("第" + str(a) + "个event的evoked:")
print(evoked_array)
# 命名获取的数据的文件名:“读取的文件的文件名”+“/”+“event名”
filename = str(file_name) + "\\" + str(event_id[a - 1])
print("该文件将被命名为" + filename)
# np.savetxt('%s.text' % filename, evoked_array)
a = a + 1
#data_path = "F:\迅雷下载\ICLabel-master\ICLabel-master\\tests\\eeglab_data.set"
data_path = "F:\\work\广大305_CPS\\0.原始数据\\7.vhdr"
drop, lowf, highf, split_reject = param() # 参数
raw, file_name = read_data(data_path) # 读取文件
raw = channels_trans(raw, drop) # 导联格式转换及筛选
raw = wave_filter(raw, lowf, highf) # 滤波
raw = artifact_remove(raw) # 去伪迹
raw = ica(raw) # ICA处理
epochs, event_id = split(raw, split_reject) # 分段
evokeds = average(epochs) # 叠加平均
get_data(evokeds, file_name, event_id) # 提取数据
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/JamesZhu0/eeg-analysis-system.git
git@gitee.com:JamesZhu0/eeg-analysis-system.git
JamesZhu0
eeg-analysis-system
脑电新手救星
master

搜索帮助

0d507c66 1850385 C8b1a773 1850385