TensorFlow搭建搭建卷积神经网络EEGNet处理脑电数据过程代码

本文主要是介绍TensorFlow搭建搭建卷积神经网络EEGNet处理脑电数据过程代码,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

TensorFlow搭建搭建卷积神经网络EEGNet处理脑电数据过程代码
脑电信号采集设备是由NT9200-32D型号脑电图仪和NeuSen W系列无线脑电采集系统组成,采集后的信号用Matlab打开,保存在结构体数据中,采集到的原始信号形式是:16x640000 double,最开始对数据进行手动分段分成[280,16,1000],280指trials,22指channels,1000指 samples,
整个代码可分为:**数据切分,搭建网络,训练数据,测试数据,**四个部分
1.导入包

import numpy as np
from tensorflow.keras import utils as np_utils
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras import backend as K
# PyRiemann imports
from pyriemann.estimation import XdawnCovariances
from pyriemann.tangentspace import TangentSpace
from pyriemann.utils.viz import plot_confusion_matrix
from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LogisticRegression
import scipy.io
from matplotlib import pyplot as plt

2.数据切分

K.set_image_data_format('channels_last')
samplesfile = scipy.io.loadmat('F:/holiday_code/attention/TSA/data/foursecond.mat')
X = samplesfile['eeg']#提取数组,结构体名称是eeg
event_id = dict(l=1, m=2, lm=3, ml=4)#四分类运动想象数据
# Setup for reading the raw data
labels = samplesfile['Mark']#加载标签数据
y = labels[:,-1]#标签数据
kernels, chans, samples = 1, 16, 1000# take 50/25/25 percent of the data to train/validate/test
X_train = X[0:140, ]
Y_train = y[0:140]
X_validate = X[140:210, ]
Y_validate = y[140:210]
X_test = X[210:, ]
Y_test = y[210:]
#把标签数据转换成one-hot编码
Y_train = np_utils.to_categorical(Y_train - 1)
Y_validate = np_utils.to_categorical(Y_validate - 1)
Y_test = np_utils.to_categorical(Y_test - 1)
#根据网络结构设置数据的输入形式(trials, channels, samples, kernels)
X_train = X_train.reshape(X_train.shape[0], chans, samples, kernels)
X_validate = X_validate.reshape(X_validate.shape[0], chans, samples, kernels)
X_test = X_test.reshape(X_test.shape[0], chans, samples, kernels)

4.搭建网络

#导入需要的库
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Activation, Permute, Dropout
from tensorflow.keras.layers import Conv2D, MaxPooling2D, AveragePooling2D
from tensorflow.keras.layers import SeparableConv2D, DepthwiseConv2D
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import SpatialDropout2D
from tensorflow.keras.regularizers import l1_l2
from tensorflow.keras.layers import Input, Flatten
from tensorflow.keras.constraints import max_norm
def EEGNet(nb_classes, Chans = 16, Samples = 1000,dropoutRate = 0.5, kernLength = 64, F1 = 8, D = 2, F2 = 16, norm_rate = 0.25, dropoutType = 'Dropout'):if dropoutType == 'SpatialDropout2D':dropoutType = SpatialDropout2Delif dropoutType == 'Dropout':dropoutType = Dropoutelse:raise ValueError('dropoutType must be one of SpatialDropout2D ''or Dropout, passed as a string.')input1 = Input(shape = (Chans, Samples, 1))print("input shape", input1.shape, Chans, Samples, kernLength)##################################################################block1 = Conv2D(F1, (1, kernLength), padding = 'same',input_shape = (Chans, Samples, 1),use_bias = False)(input1)block1 = BatchNormalization()(block1)block1 = DepthwiseConv2D((Chans, 1), use_bias = False,depth_multiplier = D,depthwise_constraint = max_norm(1.))(block1)block1 = BatchNormalization()(block1)block1 = Activation('elu')(block1)block1 = AveragePooling2D((1, 4))(block1)block1 = dropoutType(dropoutRate)(block1)block2 = SeparableConv2D(F2, (1, 16),use_bias = False, padding = 'same')(block1)block2 = BatchNormalization()(block2)block2 = Activation('elu')(block2)block2 = AveragePooling2D((1, 8))(block2)block2 = dropoutType(dropoutRate)(block2)flatten = Flatten(name = 'flatten')(block2)dense = Dense(nb_classes, name = 'dense',kernel_constraint = max_norm(norm_rate))(flatten)softmax = Activation('softmax', name = 'softmax')(dense)return Model(inputs=input1, outputs=softmax)

5.训练模型

model = EEGNet(nb_classes = 4, Chans = 16, Samples = 1000,dropoutRate = 0.5, kernLength = 32, F1 = 8, D = 2, F2 = 16,dropoutType = 'Dropout')
model.compile(loss='categorical_crossentropy', optimizer='adam',metrics=['accuracy'])
# count number of parameters in the model
numParams = model.count_params()
# set a valid path for your system to record model checkpoints
checkpointer = ModelCheckpoint(filepath='F:/holiday_code/attention/TSA/tmptwo/tmp/checkpoint.h5', verbose=1,save_best_only=True)
class_weights = {0: 1, 1: 1, 2: 1, 3: 1}
fittedModel = model.fit(X_train, Y_train, batch_size=16, epochs=300,verbose=2, validation_data=(X_validate, Y_validate),callbacks=[checkpointer], class_weight=class_weights)

6.测试模型

model.load_weights('F:/holiday_code/attention/TSA/tmptwo/tmp/checkpoint.h5')
probs = model.predict(X_test)
preds = probs.argmax(axis=-1)
acc = np.mean(preds == Y_test.argmax(axis=-1))
print("Classification accuracy: %f " % (acc))# plot the accuracy and loss graph
plt.plot(fittedModel.history['accuracy'])
plt.plot(fittedModel.history['val_accuracy'])
plt.plot(fittedModel.history['loss'])
plt.plot(fittedModel.history['val_loss'])
plt.title('acc & loss')
plt.xlabel('epoch')
plt.legend(['acc', 'val_acc','loss','val_loss'], loc='upper right')
plt.show()

7.分类结果
在这里插入图片描述
整个网络框架大概就是这样,这是其中一个被试的分类结果,属于分类效果比较好的,其他被试可能由于数据质量,网络结构等原因分类效果不是很理想,考虑数据增强以及网络结构优化去提高分类准确率。

这篇关于TensorFlow搭建搭建卷积神经网络EEGNet处理脑电数据过程代码的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/224835

相关文章

HTML5实现的移动端购物车自动结算功能示例代码

《HTML5实现的移动端购物车自动结算功能示例代码》本文介绍HTML5实现移动端购物车自动结算,通过WebStorage、事件监听、DOM操作等技术,确保实时更新与数据同步,优化性能及无障碍性,提升用... 目录1. 移动端购物车自动结算概述2. 数据存储与状态保存机制2.1 浏览器端的数据存储方式2.1.

基于 HTML5 Canvas 实现图片旋转与下载功能(完整代码展示)

《基于HTML5Canvas实现图片旋转与下载功能(完整代码展示)》本文将深入剖析一段基于HTML5Canvas的代码,该代码实现了图片的旋转(90度和180度)以及旋转后图片的下载... 目录一、引言二、html 结构分析三、css 样式分析四、JavaScript 功能实现一、引言在 Web 开发中,

Python如何去除图片干扰代码示例

《Python如何去除图片干扰代码示例》图片降噪是一个广泛应用于图像处理的技术,可以提高图像质量和相关应用的效果,:本文主要介绍Python如何去除图片干扰的相关资料,文中通过代码介绍的非常详细,... 目录一、噪声去除1. 高斯噪声(像素值正态分布扰动)2. 椒盐噪声(随机黑白像素点)3. 复杂噪声(如伪

Java Spring ApplicationEvent 代码示例解析

《JavaSpringApplicationEvent代码示例解析》本文解析了Spring事件机制,涵盖核心概念(发布-订阅/观察者模式)、代码实现(事件定义、发布、监听)及高级应用(异步处理、... 目录一、Spring 事件机制核心概念1. 事件驱动架构模型2. 核心组件二、代码示例解析1. 事件定义

电脑提示xlstat4.dll丢失怎么修复? xlstat4.dll文件丢失处理办法

《电脑提示xlstat4.dll丢失怎么修复?xlstat4.dll文件丢失处理办法》长时间使用电脑,大家多少都会遇到类似dll文件丢失的情况,不过,解决这一问题其实并不复杂,下面我们就来看看xls... 在Windows操作系统中,xlstat4.dll是一个重要的动态链接库文件,通常用于支持各种应用程序

一文详解如何在idea中快速搭建一个Spring Boot项目

《一文详解如何在idea中快速搭建一个SpringBoot项目》IntelliJIDEA作为Java开发者的‌首选IDE‌,深度集成SpringBoot支持,可一键生成项目骨架、智能配置依赖,这篇文... 目录前言1、创建项目名称2、勾选需要的依赖3、在setting中检查maven4、编写数据源5、开启热

SQL Server修改数据库名及物理数据文件名操作步骤

《SQLServer修改数据库名及物理数据文件名操作步骤》在SQLServer中重命名数据库是一个常见的操作,但需要确保用户具有足够的权限来执行此操作,:本文主要介绍SQLServer修改数据... 目录一、背景介绍二、操作步骤2.1 设置为单用户模式(断开连接)2.2 修改数据库名称2.3 查找逻辑文件名

SQL Server数据库死锁处理超详细攻略

《SQLServer数据库死锁处理超详细攻略》SQLServer作为主流数据库管理系统,在高并发场景下可能面临死锁问题,影响系统性能和稳定性,这篇文章主要给大家介绍了关于SQLServer数据库死... 目录一、引言二、查询 Sqlserver 中造成死锁的 SPID三、用内置函数查询执行信息1. sp_w

Java对异常的认识与异常的处理小结

《Java对异常的认识与异常的处理小结》Java程序在运行时可能出现的错误或非正常情况称为异常,下面给大家介绍Java对异常的认识与异常的处理,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参... 目录一、认识异常与异常类型。二、异常的处理三、总结 一、认识异常与异常类型。(1)简单定义-什么是

Python实例题之pygame开发打飞机游戏实例代码

《Python实例题之pygame开发打飞机游戏实例代码》对于python的学习者,能够写出一个飞机大战的程序代码,是不是感觉到非常的开心,:本文主要介绍Python实例题之pygame开发打飞机... 目录题目pygame-aircraft-game使用 Pygame 开发的打飞机游戏脚本代码解释初始化部