Pytorch实例----CAFAR10数据集分类(VGG)

2024-03-02 16:20

本文主要是介绍Pytorch实例----CAFAR10数据集分类(VGG),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

      在上一篇 Pytorch实例----CAFAR10数据集分类(AlexNet)的识别统计,本篇主要调整Net()类,设计VGG网络(+BN)后的识别统计(其他设计注释同上)。

                                                       VGG与AlexNet在CAFAR10数据集的统计结果对比图

可以看到,对于之前cat(19%), bird(33%), truck(47%)有显著提高:cat(50%), bird(42%), truck(80%), 最高识别的类别为:ship(86%), car(81%), frog(80%), turck(80%), 由原来55%的平均识别率提升为71%,各类的识别显著提升。此时的VGG网络仅为VGG11,通过使用VGG16, VGG19有望进一步提升准确率。

VGG网络结构编程实现:

#define the network
cfg = {'VGG11':[64, 'M', 128, 'M', 256, 'M', 512, 'M', 512, 'M'],'VGG13':[64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],'VGG16':[64, 64, 64, 'M', 128, 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
}
class Net(nn.Module):def __init__(self, vgg_name):super(Net, self).__init__()self.features = self._make_layer(cfg[vgg_name])self.classifer = nn.Linear(512, 10)def forward(self, x):out = self.features(x)out = out.view(out.size(0), -1)out = self.classifer(out)return outdef _make_layer(self, cfg):layers = []in_channels = 3for x in cfg:if x == 'M':layers += [nn.MaxPool2d(kernel_size=2, stride=2)]else:layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(x),nn.ReLU(True)]in_channels = xreturn nn.Sequential(*layers)net = Net('VGG11')

 整体代码实现:

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.transforms as transforms
from torchvision import modelsimport matplotlib.pyplot as plt
import numpy as npdef imshow(img):img = img / 2 + 0.5np_img = img.numpy()plt.imshow(np.transpose(np_img, (1, 2, 0)))#define transform
#hint: Normalize(mean, var) to normalize RGB
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))])
#define trainloader
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
#define testloader
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=True, num_workers=2)
#define class
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')#define the network
cfg = {'VGG11':[64, 'M', 128, 'M', 256, 'M', 512, 'M', 512, 'M'],'VGG13':[64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],'VGG16':[64, 64, 64, 'M', 128, 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
}
class Net(nn.Module):def __init__(self, vgg_name):super(Net, self).__init__()self.features = self._make_layer(cfg[vgg_name])self.classifer = nn.Linear(512, 10)def forward(self, x):out = self.features(x)out = out.view(out.size(0), -1)out = self.classifer(out)return outdef _make_layer(self, cfg):layers = []in_channels = 3for x in cfg:if x == 'M':layers += [nn.MaxPool2d(kernel_size=2, stride=2)]else:layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(x),nn.ReLU(True)]in_channels = xreturn nn.Sequential(*layers)net = Net('VGG11')
if torch.cuda.is_available():net.cuda()
print(net)
#define loss
cost = nn.CrossEntropyLoss()
#define optimizer
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)print('start')
#iteration for training
for epoch in range(2):running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = datainputs, labels = Variable(inputs.cuda()), Variable(labels.cuda())optimizer.zero_grad()outputs = net(inputs)loss = cost(outputs, labels)loss.backward()optimizer.step()#print loss resultrunning_loss += loss.item()if i % 2000 == 1999:print('[%d, %5d]  loss: %.3f'%(epoch + 1, i + 1, running_loss / 2000))running_loss = 0.001
print('done')#get random image and label
dataiter = iter(testloader)
images, labels = dataiter.next()
imshow(torchvision.utils.make_grid(images))
print('groundTruth: ', ''.join('%6s' %classes[labels[j]] for j in range(4)))#get the predict result
outputs = net(Variable(images.cuda()))
_, pred = torch.max(outputs.data, 1)
print('prediction: ', ''.join('%6s' %classes[labels[j]] for j in range(4)))#test the whole result
correct = 0.0
total = 0
for data in testloader:images, labels = dataoutputs = net(Variable(images.cuda()))_, pred = torch.max(outputs.data, 1)total += labels.size(0)correct += (pred == labels.cuda()).sum()
print('average Accuracy: %d %%' %(100*correct / total))#list each class prediction
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
for data in testloader:images, labels = dataoutputs = net(Variable(images.cuda()))_, pred = torch.max(outputs.data, 1)c = (pred == labels.cuda()).squeeze()for i in range(4):label = labels[i]class_correct[label] += float(c[i])class_total[label] += 1
print('each class accuracy: \n')
for i in range(10):print('Accuracy: %6s %2d %%' %(classes[i], 100 * class_correct[i] / class_total[i]))

实验结果:

practice makes perfect !

github source code:  https://github.com/GinkgoX/CAFAR10_Classification_Task/blob/master/CAFAR10_VGG.ipynb

这篇关于Pytorch实例----CAFAR10数据集分类(VGG)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/766725

相关文章

Linux下利用select实现串口数据读取过程

《Linux下利用select实现串口数据读取过程》文章介绍Linux中使用select、poll或epoll实现串口数据读取,通过I/O多路复用机制在数据到达时触发读取,避免持续轮询,示例代码展示设... 目录示例代码(使用select实现)代码解释总结在 linux 系统里,我们可以借助 select、

PyQt6 键盘事件处理的实现及实例代码

《PyQt6键盘事件处理的实现及实例代码》本文主要介绍了PyQt6键盘事件处理的实现示例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起... 目录一、键盘事件处理详解1、核心事件处理器2、事件对象 QKeyEvent3、修饰键处理(1)、修饰键类

C#使用iText获取PDF的trailer数据的代码示例

《C#使用iText获取PDF的trailer数据的代码示例》开发程序debug的时候,看到了PDF有个trailer数据,挺有意思,于是考虑用代码把它读出来,那么就用到我们常用的iText框架了,所... 目录引言iText 核心概念C# 代码示例步骤 1: 确保已安装 iText步骤 2: C# 代码程

Pandas处理缺失数据的方式汇总

《Pandas处理缺失数据的方式汇总》许多教程中的数据与现实世界中的数据有很大不同,现实世界中的数据很少是干净且同质的,本文我们将讨论处理缺失数据的一些常规注意事项,了解Pandas如何表示缺失数据,... 目录缺失数据约定的权衡Pandas 中的缺失数据None 作为哨兵值NaN:缺失的数值数据Panda

C++中处理文本数据char与string的终极对比指南

《C++中处理文本数据char与string的终极对比指南》在C++编程中char和string是两种用于处理字符数据的类型,但它们在使用方式和功能上有显著的不同,:本文主要介绍C++中处理文本数... 目录1. 基本定义与本质2. 内存管理3. 操作与功能4. 性能特点5. 使用场景6. 相互转换核心区别

python库pydantic数据验证和设置管理库的用途

《python库pydantic数据验证和设置管理库的用途》pydantic是一个用于数据验证和设置管理的Python库,它主要利用Python类型注解来定义数据模型的结构和验证规则,本文给大家介绍p... 目录主要特点和用途:Field数值验证参数总结pydantic 是一个让你能够 confidentl

JAVA实现亿级千万级数据顺序导出的示例代码

《JAVA实现亿级千万级数据顺序导出的示例代码》本文主要介绍了JAVA实现亿级千万级数据顺序导出的示例代码,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面... 前提:主要考虑控制内存占用空间,避免出现同时导出,导致主程序OOM问题。实现思路:A.启用线程池

SpringBoot分段处理List集合多线程批量插入数据方式

《SpringBoot分段处理List集合多线程批量插入数据方式》文章介绍如何处理大数据量List批量插入数据库的优化方案:通过拆分List并分配独立线程处理,结合Spring线程池与异步方法提升效率... 目录项目场景解决方案1.实体类2.Mapper3.spring容器注入线程池bejsan对象4.创建

PHP轻松处理千万行数据的方法详解

《PHP轻松处理千万行数据的方法详解》说到处理大数据集,PHP通常不是第一个想到的语言,但如果你曾经需要处理数百万行数据而不让服务器崩溃或内存耗尽,你就会知道PHP用对了工具有多强大,下面小编就... 目录问题的本质php 中的数据流处理:为什么必不可少生成器:内存高效的迭代方式流量控制:避免系统过载一次性

C#实现千万数据秒级导入的代码

《C#实现千万数据秒级导入的代码》在实际开发中excel导入很常见,现代社会中很容易遇到大数据处理业务,所以本文我就给大家分享一下千万数据秒级导入怎么实现,文中有详细的代码示例供大家参考,需要的朋友可... 目录前言一、数据存储二、处理逻辑优化前代码处理逻辑优化后的代码总结前言在实际开发中excel导入很