第P7周:咖啡豆识别(VGG-16复现)

2023-12-16 04:28
文章标签 16 复现 识别 p7 vgg 咖啡豆

本文主要是介绍第P7周:咖啡豆识别(VGG-16复现),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

>- **🍨 本文为[🔗365天深度学习训练营](https://mp.weixin.qq.com/s/rbOOmire8OocQ90QM78DRA) 中的学习记录博客**
>- **🍖 原作者:[K同学啊 | 接辅导、项目定制](https://mtyjkh.blog.csdn.net/)**

一、前期工作

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision
from torchvision import transforms, datasets
import os,PIL,pathlib,warningswarnings.filterwarnings("ignore")             #忽略警告信息device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device



2. 导入数据

import os,PIL,random,pathlibdata_dir = './7-data/'
data_dir = pathlib.Path(data_dir)data_paths  = list(data_dir.glob('*'))
classeNames = [str(path).split("\\")[1] for path in data_paths]
classeNames

# 关于transforms.Compose的更多介绍可以参考:https://blog.csdn.net/qq_38251616/article/details/124878863
train_transforms = transforms.Compose([transforms.Resize([224, 224]),  # 将输入图片resize成统一尺寸# transforms.RandomHorizontalFlip(), # 随机水平翻转transforms.ToTensor(),          # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间transforms.Normalize(           # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])test_transform = transforms.Compose([transforms.Resize([224, 224]),  # 将输入图片resize成统一尺寸transforms.ToTensor(),          # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间transforms.Normalize(           # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])total_data = datasets.ImageFolder("./7-data/",transform=train_transforms)
total_data

3. 划分数据集





二、手动搭建VGG-16模型

VGG-16结构说明:

●13个卷积层(Convolutional Layer),分别用blockX_convX表示
●3个全连接层(Fully connected Layer),分别用fcX与predictions表示
●5个池化层(Pool layer),分别用blockX_pool表示

VGG-16包含了16个隐藏层(13个卷积层和3个全连接层),故称为VGG-16

这里,我制作了一个视频来展示VGG-16的传播过程

play

0:00/0:22

倍速

volumeUp

fullscreen

fullscreen

VGG-16网络动画展示

FC7

FE8

FC6

7*7*512

1*1*512

1*1*1000

1*1*512

CONV5

CONV4

14*14*512

CONV3

28*28*512

CONV2

56*56*256

112*112*128

CONVOLUTION+RELU

K同学啊制作

MAX POOLING

CONVI

百度/谷歌/微信搜索:K同学啊

224*224*64

FULLY CONNECTED+RELU

image.png


1. 搭建模型



2. 查看模型详情


三、 训练模型
1. 编写训练函数


2. 编写测试函数

测试函数和训练函数大致相同,但是由于不进行梯度下降对网络权重进行更新,所以不需要传入优化器


3. 正式训练

model.train()、model.eval()训练营往期文章中有详细的介绍。

📌如果将优化器换成 SGD 会发生什么呢?请自行探索接下来发生的诡异事件的原因。



四、 结果可视化
1. Loss与Accuracy图

 

TRAINING AND VALIDATION ACCURACY

TRAINING AND

VALIDATION LOSS

1.4

1.0

TRAINING LOSS

1.2

TEST LOSS

1.0

0.8

0.8

0.6

0.6

0.4

0.4

0.2

TRAINING ACCURACY

TEST ACCURACY

0.0

75

35

15

35

25

20

40

30

40

30

output_29_0.png


2. 指定图片进行预测
 

Python复制代码

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

from PIL import Image

classes = list(total_data.class_to_idx)

def predict_one_image(image_path, model, transform, classes):

test_img = Image.open(image_path).convert('RGB')

plt.imshow(test_img) # 展示预测的图片

test_img = transform(test_img)

img = test_img.to(device).unsqueeze(0)

model.eval()

output = model(img)

_,pred = torch.max(output,1)

pred_class = classes[pred]

print(f'预测结果是:{pred_class}')

Python复制代码

1

2

3

4

5

# 预测训练集中的某张照片

predict_one_image(image_path='./7-data/Dark/dark (1).png',

model=model,

transform=train_transforms,

classes=classes)

Plain Text复制代码

1

预测结果是:Dark

25

50

75

100

125

150

175

200

50

100

150

200

output_32_1.png


3. 模型评估
 

Python复制代码

1

2

best_model.eval()

epoch_test_acc, epoch_test_loss = test(test_dl, best_model, loss_fn)

Python复制代码

1

epoch_test_acc, epoch_test_loss

Plain Text复制代码

1

(0.9916666666666667, 0.035762640996836126)

Python复制代码

1

2

# 查看是否与我们记录的最高准确率一致

epoch_test_acc

Plain Text复制代码

1

0.9916666666666667

这篇关于第P7周:咖啡豆识别(VGG-16复现)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/499098

相关文章

Python中图片与PDF识别文本(OCR)的全面指南

《Python中图片与PDF识别文本(OCR)的全面指南》在数据爆炸时代,80%的企业数据以非结构化形式存在,其中PDF和图像是最主要的载体,本文将深入探索Python中OCR技术如何将这些数字纸张转... 目录一、OCR技术核心原理二、python图像识别四大工具库1. Pytesseract - 经典O

Python基于微信OCR引擎实现高效图片文字识别

《Python基于微信OCR引擎实现高效图片文字识别》这篇文章主要为大家详细介绍了一款基于微信OCR引擎的图片文字识别桌面应用开发全过程,可以实现从图片拖拽识别到文字提取,感兴趣的小伙伴可以跟随小编一... 目录一、项目概述1.1 开发背景1.2 技术选型1.3 核心优势二、功能详解2.1 核心功能模块2.

Python验证码识别方式(使用pytesseract库)

《Python验证码识别方式(使用pytesseract库)》:本文主要介绍Python验证码识别方式(使用pytesseract库),具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全... 目录1、安装Tesseract-OCR2、在python中使用3、本地图片识别4、结合playwrigh

使用Python和PaddleOCR实现图文识别的代码和步骤

《使用Python和PaddleOCR实现图文识别的代码和步骤》在当今数字化时代,图文识别技术的应用越来越广泛,如文档数字化、信息提取等,PaddleOCR是百度开源的一款强大的OCR工具包,它集成了... 目录一、引言二、环境准备2.1 安装 python2.2 安装 PaddlePaddle2.3 安装

使用PyTorch实现手写数字识别功能

《使用PyTorch实现手写数字识别功能》在人工智能的世界里,计算机视觉是最具魅力的领域之一,通过PyTorch这一强大的深度学习框架,我们将在经典的MNIST数据集上,见证一个神经网络从零开始学会识... 目录当计算机学会“看”数字搭建开发环境MNIST数据集解析1. 认识手写数字数据库2. 数据预处理的

Pytorch微调BERT实现命名实体识别

《Pytorch微调BERT实现命名实体识别》命名实体识别(NER)是自然语言处理(NLP)中的一项关键任务,它涉及识别和分类文本中的关键实体,BERT是一种强大的语言表示模型,在各种NLP任务中显著... 目录环境准备加载预训练BERT模型准备数据集标记与对齐微调 BERT最后总结环境准备在继续之前,确

讯飞webapi语音识别接口调用示例代码(python)

《讯飞webapi语音识别接口调用示例代码(python)》:本文主要介绍如何使用Python3调用讯飞WebAPI语音识别接口,重点解决了在处理语音识别结果时判断是否为最后一帧的问题,通过运行代... 目录前言一、环境二、引入库三、代码实例四、运行结果五、总结前言基于python3 讯飞webAPI语音

使用Python开发一个图像标注与OCR识别工具

《使用Python开发一个图像标注与OCR识别工具》:本文主要介绍一个使用Python开发的工具,允许用户在图像上进行矩形标注,使用OCR对标注区域进行文本识别,并将结果保存为Excel文件,感兴... 目录项目简介1. 图像加载与显示2. 矩形标注3. OCR识别4. 标注的保存与加载5. 裁剪与重置图像

Python爬虫selenium验证之中文识别点选+图片验证码案例(最新推荐)

《Python爬虫selenium验证之中文识别点选+图片验证码案例(最新推荐)》本文介绍了如何使用Python和Selenium结合ddddocr库实现图片验证码的识别和点击功能,感兴趣的朋友一起看... 目录1.获取图片2.目标识别3.背景坐标识别3.1 ddddocr3.2 打码平台4.坐标点击5.图

如何通过海康威视设备网络SDK进行Java二次开发摄像头车牌识别详解

《如何通过海康威视设备网络SDK进行Java二次开发摄像头车牌识别详解》:本文主要介绍如何通过海康威视设备网络SDK进行Java二次开发摄像头车牌识别的相关资料,描述了如何使用海康威视设备网络SD... 目录前言开发流程问题和解决方案dll库加载不到的问题老旧版本sdk不兼容的问题关键实现流程总结前言作为