TensorFlow搭建搭建卷积神经网络EEGNet处理脑电数据过程代码

本文主要是介绍TensorFlow搭建搭建卷积神经网络EEGNet处理脑电数据过程代码,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

TensorFlow搭建搭建卷积神经网络EEGNet处理脑电数据过程代码
脑电信号采集设备是由NT9200-32D型号脑电图仪和NeuSen W系列无线脑电采集系统组成,采集后的信号用Matlab打开,保存在结构体数据中,采集到的原始信号形式是:16x640000 double,最开始对数据进行手动分段分成[280,16,1000],280指trials,22指channels,1000指 samples,
整个代码可分为:**数据切分,搭建网络,训练数据,测试数据,**四个部分
1.导入包

import numpy as np
from tensorflow.keras import utils as np_utils
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras import backend as K
# PyRiemann imports
from pyriemann.estimation import XdawnCovariances
from pyriemann.tangentspace import TangentSpace
from pyriemann.utils.viz import plot_confusion_matrix
from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LogisticRegression
import scipy.io
from matplotlib import pyplot as plt

2.数据切分

K.set_image_data_format('channels_last')
samplesfile = scipy.io.loadmat('F:/holiday_code/attention/TSA/data/foursecond.mat')
X = samplesfile['eeg']#提取数组,结构体名称是eeg
event_id = dict(l=1, m=2, lm=3, ml=4)#四分类运动想象数据
# Setup for reading the raw data
labels = samplesfile['Mark']#加载标签数据
y = labels[:,-1]#标签数据
kernels, chans, samples = 1, 16, 1000# take 50/25/25 percent of the data to train/validate/test
X_train = X[0:140, ]
Y_train = y[0:140]
X_validate = X[140:210, ]
Y_validate = y[140:210]
X_test = X[210:, ]
Y_test = y[210:]
#把标签数据转换成one-hot编码
Y_train = np_utils.to_categorical(Y_train - 1)
Y_validate = np_utils.to_categorical(Y_validate - 1)
Y_test = np_utils.to_categorical(Y_test - 1)
#根据网络结构设置数据的输入形式(trials, channels, samples, kernels)
X_train = X_train.reshape(X_train.shape[0], chans, samples, kernels)
X_validate = X_validate.reshape(X_validate.shape[0], chans, samples, kernels)
X_test = X_test.reshape(X_test.shape[0], chans, samples, kernels)

4.搭建网络

#导入需要的库
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Activation, Permute, Dropout
from tensorflow.keras.layers import Conv2D, MaxPooling2D, AveragePooling2D
from tensorflow.keras.layers import SeparableConv2D, DepthwiseConv2D
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import SpatialDropout2D
from tensorflow.keras.regularizers import l1_l2
from tensorflow.keras.layers import Input, Flatten
from tensorflow.keras.constraints import max_norm
def EEGNet(nb_classes, Chans = 16, Samples = 1000,dropoutRate = 0.5, kernLength = 64, F1 = 8, D = 2, F2 = 16, norm_rate = 0.25, dropoutType = 'Dropout'):if dropoutType == 'SpatialDropout2D':dropoutType = SpatialDropout2Delif dropoutType == 'Dropout':dropoutType = Dropoutelse:raise ValueError('dropoutType must be one of SpatialDropout2D ''or Dropout, passed as a string.')input1 = Input(shape = (Chans, Samples, 1))print("input shape", input1.shape, Chans, Samples, kernLength)##################################################################block1 = Conv2D(F1, (1, kernLength), padding = 'same',input_shape = (Chans, Samples, 1),use_bias = False)(input1)block1 = BatchNormalization()(block1)block1 = DepthwiseConv2D((Chans, 1), use_bias = False,depth_multiplier = D,depthwise_constraint = max_norm(1.))(block1)block1 = BatchNormalization()(block1)block1 = Activation('elu')(block1)block1 = AveragePooling2D((1, 4))(block1)block1 = dropoutType(dropoutRate)(block1)block2 = SeparableConv2D(F2, (1, 16),use_bias = False, padding = 'same')(block1)block2 = BatchNormalization()(block2)block2 = Activation('elu')(block2)block2 = AveragePooling2D((1, 8))(block2)block2 = dropoutType(dropoutRate)(block2)flatten = Flatten(name = 'flatten')(block2)dense = Dense(nb_classes, name = 'dense',kernel_constraint = max_norm(norm_rate))(flatten)softmax = Activation('softmax', name = 'softmax')(dense)return Model(inputs=input1, outputs=softmax)

5.训练模型

model = EEGNet(nb_classes = 4, Chans = 16, Samples = 1000,dropoutRate = 0.5, kernLength = 32, F1 = 8, D = 2, F2 = 16,dropoutType = 'Dropout')
model.compile(loss='categorical_crossentropy', optimizer='adam',metrics=['accuracy'])
# count number of parameters in the model
numParams = model.count_params()
# set a valid path for your system to record model checkpoints
checkpointer = ModelCheckpoint(filepath='F:/holiday_code/attention/TSA/tmptwo/tmp/checkpoint.h5', verbose=1,save_best_only=True)
class_weights = {0: 1, 1: 1, 2: 1, 3: 1}
fittedModel = model.fit(X_train, Y_train, batch_size=16, epochs=300,verbose=2, validation_data=(X_validate, Y_validate),callbacks=[checkpointer], class_weight=class_weights)

6.测试模型

model.load_weights('F:/holiday_code/attention/TSA/tmptwo/tmp/checkpoint.h5')
probs = model.predict(X_test)
preds = probs.argmax(axis=-1)
acc = np.mean(preds == Y_test.argmax(axis=-1))
print("Classification accuracy: %f " % (acc))# plot the accuracy and loss graph
plt.plot(fittedModel.history['accuracy'])
plt.plot(fittedModel.history['val_accuracy'])
plt.plot(fittedModel.history['loss'])
plt.plot(fittedModel.history['val_loss'])
plt.title('acc & loss')
plt.xlabel('epoch')
plt.legend(['acc', 'val_acc','loss','val_loss'], loc='upper right')
plt.show()

7.分类结果
在这里插入图片描述
整个网络框架大概就是这样,这是其中一个被试的分类结果,属于分类效果比较好的,其他被试可能由于数据质量,网络结构等原因分类效果不是很理想,考虑数据增强以及网络结构优化去提高分类准确率。

这篇关于TensorFlow搭建搭建卷积神经网络EEGNet处理脑电数据过程代码的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/224835

相关文章

Python调用LibreOffice处理自动化文档的完整指南

《Python调用LibreOffice处理自动化文档的完整指南》在数字化转型的浪潮中,文档处理自动化已成为提升效率的关键,LibreOffice作为开源办公软件的佼佼者,其命令行功能结合Python... 目录引言一、环境搭建:三步构建自动化基石1. 安装LibreOffice与python2. 验证安装

Django HTTPResponse响应体中返回openpyxl生成的文件过程

《DjangoHTTPResponse响应体中返回openpyxl生成的文件过程》Django返回文件流时需通过Content-Disposition头指定编码后的文件名,使用openpyxl的sa... 目录Django返回文件流时使用指定文件名Django HTTPResponse响应体中返回openp

使用Python开发一个Ditto剪贴板数据导出工具

《使用Python开发一个Ditto剪贴板数据导出工具》在日常工作中,我们经常需要处理大量的剪贴板数据,下面将介绍如何使用Python的wxPython库开发一个图形化工具,实现从Ditto数据库中读... 目录前言运行结果项目需求分析技术选型核心功能实现1. Ditto数据库结构分析2. 数据库自动定位3

python使用Akshare与Streamlit实现股票估值分析教程(图文代码)

《python使用Akshare与Streamlit实现股票估值分析教程(图文代码)》入职测试中的一道题,要求:从Akshare下载某一个股票近十年的财务报表包括,资产负债表,利润表,现金流量表,保存... 目录一、前言二、核心知识点梳理1、Akshare数据获取2、Pandas数据处理3、Matplotl

pandas数据的合并concat()和merge()方式

《pandas数据的合并concat()和merge()方式》Pandas中concat沿轴合并数据框(行或列),merge基于键连接(内/外/左/右),concat用于纵向或横向拼接,merge用于... 目录concat() 轴向连接合并(1) join='outer',axis=0(2)join='o

Linux线程同步/互斥过程详解

《Linux线程同步/互斥过程详解》文章讲解多线程并发访问导致竞态条件,需通过互斥锁、原子操作和条件变量实现线程安全与同步,分析死锁条件及避免方法,并介绍RAII封装技术提升资源管理效率... 目录01. 资源共享问题1.1 多线程并发访问1.2 临界区与临界资源1.3 锁的引入02. 多线程案例2.1 为

Django开发时如何避免频繁发送短信验证码(python图文代码)

《Django开发时如何避免频繁发送短信验证码(python图文代码)》Django开发时,为防止频繁发送验证码,后端需用Redis限制请求频率,结合管道技术提升效率,通过生产者消费者模式解耦业务逻辑... 目录避免频繁发送 验证码1. www.chinasem.cn避免频繁发送 验证码逻辑分析2. 避免频繁

批量导入txt数据到的redis过程

《批量导入txt数据到的redis过程》用户通过将Redis命令逐行写入txt文件,利用管道模式运行客户端,成功执行批量删除以Product*匹配的Key操作,提高了数据清理效率... 目录批量导入txt数据到Redisjs把redis命令按一条 一行写到txt中管道命令运行redis客户端成功了批量删除k

分布式锁在Spring Boot应用中的实现过程

《分布式锁在SpringBoot应用中的实现过程》文章介绍在SpringBoot中通过自定义Lock注解、LockAspect切面和RedisLockUtils工具类实现分布式锁,确保多实例并发操作... 目录Lock注解LockASPect切面RedisLockUtils工具类总结在现代微服务架构中,分布

Java使用Thumbnailator库实现图片处理与压缩功能

《Java使用Thumbnailator库实现图片处理与压缩功能》Thumbnailator是高性能Java图像处理库,支持缩放、旋转、水印添加、裁剪及格式转换,提供易用API和性能优化,适合Web应... 目录1. 图片处理库Thumbnailator介绍2. 基本和指定大小图片缩放功能2.1 图片缩放的