27、ResNet50处理STEW数据集,用于情感三分类+全备的代码

2023-12-22 15:44

本文主要是介绍27、ResNet50处理STEW数据集,用于情感三分类+全备的代码,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1、数据介绍

IEEE-Datasets-STEW:SIMULTANEOUS TASK EEG WORKLOAD DATASET :

该数据集由48名受试者的原始EEG数据组成,他们参加了利用SIMKAP多任务测试进行的多任务工作负荷实验。受试者在休息时的大脑活动也在测试前被记录下来,也包括在其中。Emotiv EPOC设备,采样频率为128Hz,有14个通道,用于获取数据,每个案例都有2.5分钟的EEG记录。受试者还被要求在每个阶段后以1到9的评分标准对其感知的心理工作量进行评分,评分结果在单独的文件中提供。

说明:每个受试者的数据遵循命名惯例:subno_task.txt。例如,sub01_lo.txt将是受试者1在休息时的原始脑电数据,而sub23_hi.txt将是受试者23在多任务测试中的原始脑电数据。每个数据文件的行对应于记录中的样本,列对应于EEG设备的14个通道: AF3, F7, F3, FC5, T7, P7, O1, O2, P8, T8, FC6, F4, F8, AF4。

数据说明、下载地址:

STEW: Simultaneous Task EEG Workload Data Set | IEEE Journals & Magazine | IEEE Xplore

2、代码

本次使用ResNet50,去做此情感数据的分类工作,数据导入+模型训练+测试代码如下:

import torch
import torchvision.datasets
from torch.utils.data import Dataset        # 继承Dataset类
import os
from PIL import Image
import numpy as np
from torchvision import transforms# 预处理
data_transform = transforms.Compose([transforms.Resize((224,224)),           # 缩放图像transforms.ToTensor(),                  # 转为Tensotransforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))       # 标准化
])path =  r'C:\STEW\test'for root,dirs,files in os.walk(path):print('root',root) #遍历到该目录地址print('dirs',dirs) #遍历到该目录下的子目录名 []print('files',files)  #遍历到该目录下的文件  []def read_txt_files(path):# 创建文件名列表file_names = []# 遍历给定目录及其子目录下的所有文件for root, dirs, files in os.walk(path):# 遍历所有文件for file in files:# 如果是 .txt 文件,则加入文件名列表if file.endswith('.txt'): # endswith () 方法用于判断字符串是否以指定后缀结尾,如果以指定后缀结尾返回True,否则返回False。file_names.append(os.path.join(root, file))# 返回文件名列表return file_namesclass DogCat(Dataset):      # 数据处理def __init__(self,root,transforms = None):                  # 初始化,指定路径,是否预处理等等#['cat.15454.jpg', 'cat.445.jpg', 'cat.46456.jpg', 'cat.656165.jpg', 'dog.123.jpg', 'dog.15564.jpg', 'dog.4545.jpg', 'dog.456465.jpg']imgs = os.listdir(root)self.imgs = [os.path.join(root,img) for img in imgs]    # 取出root下所有的文件self.transforms = data_transform                        # 图像预处理def __getitem__(self, index):       # 读取图片img_path = self.imgs[index]label = 1 if 'dog' in img_path.split('/')[-1] else 0 #然后,就可以根据每个路径的id去做label了。将img_path 路径按照 '/ '分割,-1代表取最后一个字符串,如果里面有dog就为1,cat就为0.data = Image.open(img_path)if self.transforms:     # 图像预处理data = self.transforms(data)return data,labeldef __len__(self):return len(self.imgs)dataset = DogCat('./data/',transforms=True)for img,label in dataset:print('img:',img.size(),'label:',label)
'''
img: torch.Size([3, 224, 224]) label: 0
img: torch.Size([3, 224, 224]) label: 0
img: torch.Size([3, 224, 224]) label: 0
img: torch.Size([3, 224, 224]) label: 0
img: torch.Size([3, 224, 224]) label: 1
img: torch.Size([3, 224, 224]) label: 1
img: torch.Size([3, 224, 224]) label: 1
img: torch.Size([3, 224, 224]) label: 1
'''import os# 获取file_path路径下的所有TXT文本内容和文件名
def get_text_list(file_path):files = os.listdir(file_path)text_list = []for file in files:with open(os.path.join(file_path, file), "r", encoding="UTF-8") as f:text_list.append(f.read())return text_list, filesclass ImageFolderCustom(Dataset):# 2. Initialize with a targ_dir and transform (optional) parameterdef __init__(self, targ_dir: str, transform=None) -> None:# 3. Create class attributes# Get all image pathsself.paths = list(pathlib.Path(targ_dir).glob("*/*.jpg")) # note: you'd have to update this if you've got .png's or .jpeg's# Setup transformsself.transform = transform# Create classes and class_to_idx attributesself.classes, self.class_to_idx = find_classes(targ_dir)# 4. Make function to load imagesdef load_image(self, index: int) -> Image.Image:"Opens an image via a path and returns it."image_path = self.paths[index]return Image.open(image_path) # 5. Overwrite the __len__() method (optional but recommended for subclasses of torch.utils.data.Dataset)def __len__(self) -> int:"Returns the total number of samples."return len(self.paths)# 6. Overwrite the __getitem__() method (required for subclasses of torch.utils.data.Dataset)def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:"Returns one sample of data, data and label (X, y)."img = self.load_image(index)class_name  = self.paths[index].parent.name # expects path in data_folder/class_name/image.jpegclass_idx = self.class_to_idx[class_name]# Transform if necessaryif self.transform:return self.transform(img), class_idx # return data, label (X, y)else:return img, class_idx # return data, label (X, y)import torchvision as tv
import numpy as np
import torch
import time
import os
from torch import nn, optim
from torchvision.models import resnet50
from torchvision.transforms import transformsos.environ["CUDA_VISIBLE_DEVICE"] = "0,1,2"# cifar-10进行测验class Cutout(object):"""Randomly mask out one or more patches from an image.Args:n_holes (int): Number of patches to cut out of each image.length (int): The length (in pixels) of each square patch."""def __init__(self, n_holes, length):self.n_holes = n_holesself.length = lengthdef __call__(self, img):"""Args:img (Tensor): Tensor image of size (C, H, W).Returns:Tensor: Image with n_holes of dimension length x length cut out of it."""h = img.size(1)w = img.size(2)mask = np.ones((h, w), np.float32)for n in range(self.n_holes):y = np.random.randint(h)x = np.random.randint(w)y1 = np.clip(y - self.length // 2, 0, h)y2 = np.clip(y + self.length // 2, 0, h)x1 = np.clip(x - self.length // 2, 0, w)x2 = np.clip(x + self.length // 2, 0, w)mask[y1: y2, x1: x2] = 0.mask = torch.from_numpy(mask)mask = mask.expand_as(img)img = img * maskreturn imgdef load_data_cifar10(batch_size=128,num_workers=2):# 操作合集# Data augmentationtrain_transform_1 = transforms.Compose([transforms.Resize((224, 224)),transforms.RandomHorizontalFlip(),  # 随机水平翻转transforms.RandomRotation(degrees=(-80,80)),  # 随机角度翻转transforms.ToTensor(),transforms.Normalize((0.491339968,0.48215827,0.44653124), (0.24703233,0.24348505,0.26158768)  # 两者分别为(mean,std)),Cutout(1, 16),  # 务必放在ToTensor的后面])train_transform_2 = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.491339968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)  # 两者分别为(mean,std))])test_transform = transforms.Compose([transforms.Resize((224,224)),transforms.ToTensor(),transforms.Normalize((0.491339968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)  # 两者分别为(mean,std))])# 训练集1trainset1 = tv.datasets.CIFAR10(root='data',train=True,download=False,transform=train_transform_1,)# 训练集2trainset2 = tv.datasets.CIFAR10(root='data',train=True,download=False,transform=train_transform_2,)# 测试集testset = tv.datasets.CIFAR10(root='data',train=False,download=False,transform=test_transform,)# 训练数据加载器1trainloader1 = torch.utils.data.DataLoader(trainset1,batch_size=batch_size,shuffle=True,num_workers=num_workers,pin_memory=(torch.cuda.is_available()))# 训练数据加载器2trainloader2 = torch.utils.data.DataLoader(trainset2,batch_size=batch_size,shuffle=True,num_workers=num_workers,pin_memory=(torch.cuda.is_available()))# 测试数据加载器testloader = torch.utils.data.DataLoader(testset,batch_size=batch_size,shuffle=False,num_workers=num_workers,pin_memory=(torch.cuda.is_available()))return trainloader1,trainloader2,testloaderdef main():start = time.time()batch_size = 128cifar_train1,cifar_train2,cifar_test = load_data_cifar10(batch_size=batch_size)model = resnet50().cuda()# model.load_state_dict(torch.load('_ResNet50.pth'))# 存在已保存的参数文件# model = nn.DataParallel(model,device_ids=[0,])  # 又套一层model = nn.DataParallel(model,device_ids=[0,1,2])loss = nn.CrossEntropyLoss().cuda()optimizer = optim.Adam(model.parameters(),lr=0.001)for epoch in range(50):model.train()  # 训练时务必写loss_=0.0num=0.0# train on trainloader1(data augmentation) and trainloader2for i,data in enumerate(cifar_train1,0):x, label = datax, label = x.cuda(),label.cuda()# xp = model(x) #outputl = loss(p,label) #lossoptimizer.zero_grad()l.backward()optimizer.step()loss_ += float(l.mean().item())num+=1for i, data in enumerate(cifar_train2, 0):x, label = datax, label = x.cuda(), label.cuda()# xp = model(x)l = loss(p, label)optimizer.zero_grad()l.backward()optimizer.step()loss_ += float(l.mean().item())num += 1model.eval()  # 评估时务必写print("loss:",float(loss_)/num)# test on trainloader2,testloaderwith torch.no_grad():total_correct = 0total_num = 0for x, label in cifar_train2:# [b, 3, 32, 32]# [b]x, label = x.cuda(), label.cuda()# [b, 10]logits = model(x)# [b]pred = logits.argmax(dim=1)# [b] vs [b] => scalar tensorcorrect = torch.eq(pred, label).float().sum().item()total_correct += correcttotal_num += x.size(0)# print(correct)acc_1 = total_correct / total_num# Testwith torch.no_grad():total_correct = 0total_num = 0for x, label in cifar_test:# [b, 3, 32, 32]# [b]x, label = x.cuda(), label.cuda()# [b, 10]logits = model(x) #output# [b]pred = logits.argmax(dim=1)# [b] vs [b] => scalar tensorcorrect = torch.eq(pred, label).float().sum().item()total_correct += correcttotal_num += x.size(0)# print(correct)acc_2 = total_correct / total_numprint(epoch+1,'train acc',acc_1,'|','test acc:', acc_2)# 保存时只保存model.moduletorch.save(model.module.state_dict(),'resnet50.pth')print("The interval is :",time.time() - start)if __name__ == '__main__':main()

3、对你有帮助的话,给个关注吧~

这篇关于27、ResNet50处理STEW数据集,用于情感三分类+全备的代码的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/524525

相关文章

Spring Boot @RestControllerAdvice全局异常处理最佳实践

《SpringBoot@RestControllerAdvice全局异常处理最佳实践》本文详解SpringBoot中通过@RestControllerAdvice实现全局异常处理,强调代码复用、统... 目录前言一、为什么要使用全局异常处理?二、核心注解解析1. @RestControllerAdvice2

MySQL 删除数据详解(最新整理)

《MySQL删除数据详解(最新整理)》:本文主要介绍MySQL删除数据的相关知识,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录一、前言二、mysql 中的三种删除方式1.DELETE语句✅ 基本语法: 示例:2.TRUNCATE语句✅ 基本语

Java中调用数据库存储过程的示例代码

《Java中调用数据库存储过程的示例代码》本文介绍Java通过JDBC调用数据库存储过程的方法,涵盖参数类型、执行步骤及数据库差异,需注意异常处理与资源管理,以优化性能并实现复杂业务逻辑,感兴趣的朋友... 目录一、存储过程概述二、Java调用存储过程的基本javascript步骤三、Java调用存储过程示

Visual Studio 2022 编译C++20代码的图文步骤

《VisualStudio2022编译C++20代码的图文步骤》在VisualStudio中启用C++20import功能,需设置语言标准为ISOC++20,开启扫描源查找模块依赖及实验性标... 默认创建Visual Studio桌面控制台项目代码包含C++20的import方法。右键项目的属性:

MyBatisPlus如何优化千万级数据的CRUD

《MyBatisPlus如何优化千万级数据的CRUD》最近负责的一个项目,数据库表量级破千万,每次执行CRUD都像走钢丝,稍有不慎就引起数据库报警,本文就结合这个项目的实战经验,聊聊MyBatisPl... 目录背景一、MyBATis Plus 简介二、千万级数据的挑战三、优化 CRUD 的关键策略1. 查

python实现对数据公钥加密与私钥解密

《python实现对数据公钥加密与私钥解密》这篇文章主要为大家详细介绍了如何使用python实现对数据公钥加密与私钥解密,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录公钥私钥的生成使用公钥加密使用私钥解密公钥私钥的生成这一部分,使用python生成公钥与私钥,然后保存在两个文

mysql中的数据目录用法及说明

《mysql中的数据目录用法及说明》:本文主要介绍mysql中的数据目录用法及说明,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、背景2、版本3、数据目录4、总结1、背景安装mysql之后,在安装目录下会有一个data目录,我们创建的数据库、创建的表、插入的

MySQL数据库的内嵌函数和联合查询实例代码

《MySQL数据库的内嵌函数和联合查询实例代码》联合查询是一种将多个查询结果组合在一起的方法,通常使用UNION、UNIONALL、INTERSECT和EXCEPT关键字,下面:本文主要介绍MyS... 目录一.数据库的内嵌函数1.1聚合函数COUNT([DISTINCT] expr)SUM([DISTIN

Navicat数据表的数据添加,删除及使用sql完成数据的添加过程

《Navicat数据表的数据添加,删除及使用sql完成数据的添加过程》:本文主要介绍Navicat数据表的数据添加,删除及使用sql完成数据的添加过程,具有很好的参考价值,希望对大家有所帮助,如有... 目录Navicat数据表数据添加,删除及使用sql完成数据添加选中操作的表则出现如下界面,查看左下角从左

MySQL中的索引结构和分类实战案例详解

《MySQL中的索引结构和分类实战案例详解》本文详解MySQL索引结构与分类,涵盖B树、B+树、哈希及全文索引,分析其原理与优劣势,并结合实战案例探讨创建、管理及优化技巧,助力提升查询性能,感兴趣的朋... 目录一、索引概述1.1 索引的定义与作用1.2 索引的基本原理二、索引结构详解2.1 B树索引2.2