MINIST数据集测试不同参数对网络的影响

2024-03-18 11:08

本文主要是介绍MINIST数据集测试不同参数对网络的影响,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

目录

  • 一.介绍
    • 1.实验环境
    • 2.网络结构
  • 二.网络效果
    • 1.初始状态
    • 2.有BN层
    • 3.激活函数
      • tanh
      • sigmoid
      • relu
    • 4. 正则化
      • L2正则化
      • Dropout
    • 5.优化器
    • 6. 学习率衰减
  • 三.最优测试
  • 附: 完整代码

一.介绍

本实验使用两个不同的神经网络,通过MINIST数据集进行训练,查看不同情况下最后的效果。

1.实验环境

  1. Python 3.8
  2. Pytorch 1.8
  3. Pycharm

2.网络结构

单层卷积:一层卷积+一层池化+两层全连接

class Net_1(nn.Module):def __init__(self):super(Net_1, self).__init__()self.model = nn.Sequential(nn.Conv2d(1,10,kernel_size=3,stride=1),#nn.BatchNorm2d(10),nn.ReLU(),nn.MaxPool2d(kernel_size=2,stride=2),nn.Flatten(),nn.Linear(10*13*13, 120),nn.ReLU(),#nn.Dropout(p=0.5),nn.Linear(120, 10),)def forward(self,x):return self.model(x)

多层网络:三层卷积+三层池化+两层全连接

class Net_2(nn.Module):def __init__(self):super(Net_2, self).__init__()self.model = nn.Sequential(# out-> [40,28,28]nn.Conv2d(1, 40, kernel_size=3, stride=1,padding=1),#nn.BatchNorm2d(40),nn.ReLU(),# out->[40,14,14]nn.MaxPool2d(kernel_size=2, stride=2),# out->[20,12,12]nn.Conv2d(40, 20, kernel_size=3, stride=1),#nn.BatchNorm2d(20),nn.ReLU(),# out->[20,6,6]nn.MaxPool2d(kernel_size=2, stride=2),# out->[20,4,4]nn.Conv2d(20, 20, kernel_size=3, stride=1),#nn.BatchNorm2d(20),nn.ReLU(),# out->[20,2,2]nn.MaxPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(2 * 2 * 20, 100),nn.ReLU(),#nn.Dropout(p=0.5),nn.Linear(100, 10),)def forward(self, x):return self.model(x)

二.网络效果

1.初始状态

Batchsize :200
Learning_rate:0.001
Epochs:15
优化器:torch.optim.SGD
损失函数:nn.CrossEntropyLoss()
激活函数:Relu
以下变更均建立在此基础之上
单层网络效果:
在这里插入图片描述

多层网络效果:
在这里插入图片描述

最后一次的结果:

train_losstest_lossaccuracy
单层网络0.1360.1280.9634
多层网络0.0900.0800.9734

2.有BN层

单层网络效果:
在这里插入图片描述

多层网络效果:
在这里插入图片描述

最后一次的结果:

train_losstest_lossaccuracy
单层网络0.0700.0690.9784
多层网络0.0400.0480.9839

添加BN层结果有所好转

3.激活函数

tanh

单层网络效果:
在这里插入图片描述

多层网络效果:
在这里插入图片描述

最后一次的结果:

train_losstest_lossaccuracy
单层网络0.2100.2020.9417
多层网络0.1350.1220.9675

单层多层的表现都不如relu,比较接近

sigmoid

单层网络效果:
在这里插入图片描述

多层网络效果:
在这里插入图片描述

最后一次的结果:

train_losstest_lossaccuracy
单层网络0.7400.6930.8436
多层网络2.3012.3010.1135

relu

单层网络效果:
在这里插入图片描述

多层网络效果:
在这里插入图片描述

最后一次的结果:

train_losstest_lossaccuracy
单层网络0.1360.1280.9634
多层网络0.0900.0800.9734

4. 正则化

L2正则化

optimizer=optim.SGD(model.parameters(),lr=learning_rate,momentum=0.9,weight_decay=0.001)
weight_decay设置为0.001
单层网络效果:
在这里插入图片描述

多层网络效果:
在这里插入图片描述

最后一次的结果:

train_losstest_lossaccuracy
单层网络0.1290.1200.9646
多层网络0.0960.0810.974

Dropout

nn.Dropout(p=0.5)
单层网络效果:
在这里插入图片描述

多层网络效果:
在这里插入图片描述

最后一次的结果:

train_losstest_lossaccuracy
单层网络0.2010.1280.9617
多层网络0.1620.0860.9718

5.优化器

将SGD优化器变更为Adam优化器
单层网络效果:
在这里插入图片描述
多层网络效果:

在这里插入图片描述
最后一次的结果:

train_losstest_lossaccuracy
单层网络0.0070.0530.9836
多层网络0.0200.0340.9889

6. 学习率衰减

scheduler_step = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.2)此处使用Adam优化器,weight_decay为0.001,添加BN层,Dropout层,使用Relu函数
单层网络效果:

在这里插入图片描述
多层网络效果:
在这里插入图片描述
最后一次的结果:

train_losstest_lossaccuracy
单层网络0.0580.0430.9853
多层网络0.0240.0220.9929

三.最优测试

Batchsize :200
Learning_rate:0.01
Epochs:15
优化器:torch.optim.Adam,weight_decay为0.001
学习率衰减:step_size=2, gamma=0.2
损失函数:nn.CrossEntropyLoss()
激活函数:Relu
添加BN,dropout层
多层网络效果:
在这里插入图片描述

train_losstest_lossaccuracy
多层网络0.0310.0210.9925

附: 完整代码

import matplotlib.pyplot as plt
import torch
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import nn, optimdef main():batch_size = 200learning_rate = 0.01epochs = 50train_loader = DataLoader(datasets.MNIST('MNIST', train=True, download=True, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.5, 0.5)])),batch_size=batch_size,shuffle=True)test_loader = DataLoader(datasets.MNIST('MNIST', train=False, download=True, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.5, 0.5)])),batch_size=batch_size,shuffle=True)device = torch.device('cuda')model = Net_2().to(device)criteon = nn.CrossEntropyLoss().to(device)optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0.001)scheduler_step = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.2)train_loss = []test_loss = []acc = []for epoch in range(epochs):loss_1 = 0.loss_2 = 0.model.train()for i, (x, label) in enumerate(train_loader):x, label = x.to(device), label.to(device)out = model(x)loss = criteon(out, label)optimizer.zero_grad()loss.backward()optimizer.step()loss_1 += loss.item()scheduler_step.step()loss_1 = loss_1 / (i + 1)train_loss.append(loss_1)print('train: 第{}次,loss为{}'.format(epoch, loss_1))model.eval()with torch.no_grad():correct = 0.for i, (x, label) in enumerate(test_loader):x, label = x.to(device), label.to(device)out = model(x)loss = criteon(out, label)pred = out.argmax(dim=1)correct += torch.eq(pred, label).sum().item()loss_2 += loss.item()accuracy = correct / len(test_loader.dataset)loss_2 = loss_2 / (i + 1)test_loss.append(loss_2)acc.append(accuracy)print('test: 第{}次,loss为{},accuracy为{}'.format(epoch, loss_2, accuracy))plt.figure(num=1, figsize=(10, 5.4))plt.subplot(121)# plt.title("loss")plt.plot(train_loss, 'b-', label='train_loss')plt.plot(test_loss, 'g-', label='test_loss')plt.xlabel('epoch')plt.ylabel('loss')plt.legend()plt.subplot(122)plt.plot(acc, 'g-', label='accuracy')plt.xlabel('epoch')plt.ylabel('accuracy')plt.legend()plt.show()if __name__ == '__main__':main()

这篇关于MINIST数据集测试不同参数对网络的影响的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/822143

相关文章

SpringBoot实现不同接口指定上传文件大小的具体步骤

《SpringBoot实现不同接口指定上传文件大小的具体步骤》:本文主要介绍在SpringBoot中通过自定义注解、AOP拦截和配置文件实现不同接口上传文件大小限制的方法,强调需设置全局阈值远大于... 目录一  springboot实现不同接口指定文件大小1.1 思路说明1.2 工程启动说明二 具体实施2

MyBatis-plus处理存储json数据过程

《MyBatis-plus处理存储json数据过程》文章介绍MyBatis-Plus3.4.21处理对象与集合的差异:对象可用内置Handler配合autoResultMap,集合需自定义处理器继承F... 目录1、如果是对象2、如果需要转换的是List集合总结对象和集合分两种情况处理,目前我用的MP的版本

SpringBoot 获取请求参数的常用注解及用法

《SpringBoot获取请求参数的常用注解及用法》SpringBoot通过@RequestParam、@PathVariable等注解支持从HTTP请求中获取参数,涵盖查询、路径、请求体、头、C... 目录SpringBoot 提供了多种注解来方便地从 HTTP 请求中获取参数以下是主要的注解及其用法:1

HTTP 与 SpringBoot 参数提交与接收协议方式

《HTTP与SpringBoot参数提交与接收协议方式》HTTP参数提交方式包括URL查询、表单、JSON/XML、路径变量、头部、Cookie、GraphQL、WebSocket和SSE,依据... 目录HTTP 协议支持多种参数提交方式,主要取决于请求方法(Method)和内容类型(Content-Ty

Debian 13升级后网络转发等功能异常怎么办? 并非错误而是管理机制变更

《Debian13升级后网络转发等功能异常怎么办?并非错误而是管理机制变更》很多朋友反馈,更新到Debian13后网络转发等功能异常,这并非BUG而是Debian13Trixie调整... 日前 Debian 13 Trixie 发布后已经有众多网友升级到新版本,只不过升级后发现某些功能存在异常,例如网络转

GSON框架下将百度天气JSON数据转JavaBean

《GSON框架下将百度天气JSON数据转JavaBean》这篇文章主要为大家详细介绍了如何在GSON框架下实现将百度天气JSON数据转JavaBean,文中的示例代码讲解详细,感兴趣的小伙伴可以了解下... 目录前言一、百度天气jsON1、请求参数2、返回参数3、属性映射二、GSON属性映射实战1、类对象映

sysmain服务可以禁用吗? 电脑sysmain服务关闭后的影响与操作指南

《sysmain服务可以禁用吗?电脑sysmain服务关闭后的影响与操作指南》在Windows系统中,SysMain服务(原名Superfetch)作为一个旨在提升系统性能的关键组件,一直备受用户关... 在使用 Windows 系统时,有时候真有点像在「开盲盒」。全新安装系统后的「默认设置」,往往并不尽编

C# LiteDB处理时间序列数据的高性能解决方案

《C#LiteDB处理时间序列数据的高性能解决方案》LiteDB作为.NET生态下的轻量级嵌入式NoSQL数据库,一直是时间序列处理的优选方案,本文将为大家大家简单介绍一下LiteDB处理时间序列数... 目录为什么选择LiteDB处理时间序列数据第一章:LiteDB时间序列数据模型设计1.1 核心设计原则

Java+AI驱动实现PDF文件数据提取与解析

《Java+AI驱动实现PDF文件数据提取与解析》本文将和大家分享一套基于AI的体检报告智能评估方案,详细介绍从PDF上传、内容提取到AI分析、数据存储的全流程自动化实现方法,感兴趣的可以了解下... 目录一、核心流程:从上传到评估的完整链路二、第一步:解析 PDF,提取体检报告内容1. 引入依赖2. 封装

python中的显式声明类型参数使用方式

《python中的显式声明类型参数使用方式》文章探讨了Python3.10+版本中类型注解的使用,指出FastAPI官方示例强调显式声明参数类型,通过|操作符替代Union/Optional,可提升代... 目录背景python函数显式声明的类型汇总基本类型集合类型Optional and Union(py