使用PyTorch实现手写数字识别功能

2025-03-24 14:50

本文主要是介绍使用PyTorch实现手写数字识别功能,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

《使用PyTorch实现手写数字识别功能》在人工智能的世界里,计算机视觉是最具魅力的领域之一,通过PyTorch这一强大的深度学习框架,我们将在经典的MNIST数据集上,见证一个神经网络从零开始学会识...

当计算机学会“看”数字

在人工智能的世界里,计算机视觉是最具魅力的领域之一。通过PyTorch这一强大的深度学习框架,我们将在经典的MNIST数据集上,见证一个神经网络从零开始学会识别数字的全过程。本文将以通俗易懂的方式,带你走进这个看似神秘实则充满逻辑的美妙世界。

搭建开发环境

在开始训练之前,我们需要准备好三个基础要素:腾讯云HAI,腾讯云HAI,腾讯云HAI。导入必要的工具库:

import torch  # 深度学习框架核心
import torch.nn as nn  # 神经网络模块
from torchvision import datasets, transforms  # 数据处理利器

MNIST数据集解析

1. 认识手写数字数据库

MNIST数据集包含6万张训练图片和1万张测试图片,每张都是28x28像素的灰度图。这些数字由美国高中生和人口普查局员工书写,构成了计算机视觉领域的"Hello World"。

2. 数据预处理的艺术

原始图片需要经过精心处理才能被模型理解:

transform = transforms.Compose([
    transforms.ToTensor(),  # 将图像转换为数值矩阵
    transforms.Normalize((0.1307,), (0.3081,))  # 标准化处理
])

3. 可视化的重要性

通过Matplotlib展示样本图片,我们能直观感受数据的特征:

plt.imshow(images[0].squeeze(), cmap='gray')
plt.title(f'Label: {labels[0]}')

神经网络设计

1. 网络结构蓝图

我们设计一个全连接网络(FCN),其结构如同人类神经系统的简化版:

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()  # 将图片展开为向量
        self.fc1 = nn.Linear(28*28, 128)  # 第一隐藏层
        self.fc2 = nn.Linear(128, 64)    # 第二隐藏层
        self.fc3 = nn.Linear(64, 10)     # 输出层
        self.dropout = nn.Dropout(0.5)   # 正则化装置
  • 神经元数量的选择需要平衡学习能力与过拟合风险
  • Dropout层像随机关闭部分神经元,防止模型"死记硬背"

2. 信息传递机制

前向传播模拟人脑的信息处理过程:ReLU激活函数如同神经元的开关,决定是否传递信号。

def forward(self, x):
    x = self.flatten(x)  # 展平操作:将图片变为784维向量
    x = torch.relu(self.fc1(x))  # 通过第一个全连接层
    x = self.dropout(x)         # 随机屏蔽部分神经元
    x = torch.relu(self.fc2(x))  # 第二个全连接层
    return self.fc3(China编程x)          # 最终输出10个数字的概率

让模型学会思考

1. 配置学习参数

  • 损失函数:交叉熵损失(CrossEntropyLoss),衡量预测与真实的差距
  • 优化:Adam优化器,智能调节学习步伐的导航员
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

2. 训练循环解析

每个epoch都是一次完整的学习轮回:

def train(epoch):
    model.train()  # 切换至训练模式
    for BATch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()  # 清空之前的梯度
        output = model(data)    # 前向传播
        loss = criterion(output, target)  # 计算损失值
        loss.backward()         # 反向传播求梯度
        optimizer.step()        # 更新网络参数
  • 梯度清零避免不同批次数据的干扰
  • 反向传播就像纠错老师,沿着计算链修正参数

完整代码示例

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision

from torchvision import transforms
import matplotlib.pyplot as plt

# 2. 数据准备
# 定义数据预处理:转换为Tensor并标准化(MNIST的均值和标准差)
transfoChina编程rm = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# 加载训练集和测试集
train_dataset = torchvision.datasets.MNIST(
    root='./data', 
    train=True,
    download=True,
    transform=transform
)

test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    transform=transform
)

# 创建数据加载器
train_loader = torch.utils.data.DataLoader(
    train_dataset,

    batch_size=64,
    shuffle=True
)

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=1000,
    shuffle=False
)

# 查看数据集信息
print(f'Train samples: {len(train_dataset)}')
print(f'Test samples: {len(test_dataset)}')

# 可视化样本
images, labels = next(iter(train_loader))
plt.imshow(images[0].squeeze(), cmap='gray')
plt.title(f'Label: {labels[0]}')
plt.show()

# 3. 定义神经网络模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28*28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)
        self.dropout = nn.Dropout(0.5)
        
    def forward(self, x):
        x = self.flatten(x)
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = Net()
print(model)

# 4. 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(model.parameters(), lr=0.001)


# 5. 训练模型
def train(epoch):

    model.train()
    running_loss = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        

        running_loss += loss.item()
        if batch_idx % 100 == 99:

            print(f'Epoch: {epoch+1}, Batch: {batch_idx+1}, Loss: {running_loss/100:.3f}')
            running_loss = 0.0

# 6. 测试模型

def test():
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            outputs = model(data)
            _, predicted = torch.max(outputs.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    
    accuracy = 100 * correct / total
    print(f'Test Accuracy: {accuracy:.2f}%')
    return accuracy

# 7. 执行训练和测试

epochs = 5
for epoch in range(epochs):
    train(epoch)
    test()

# 8.编程 保存模型
torch.save(model.state_dict(), 'mnist_model.pth')

使用PyTorch实现手写数字识别功能

针对1-9数字的测试

# 扩展测试函数,增加按数字统计的功能

def detailed_test():
    model.eval()
    class_correct = [0] * 10  # 存储每个数字的正确计数
    class_total = [0] * 10    # 存储每个数字的总样本数
    
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            
            # 遍历每个预测结果
            for label, prediction in zip(labels, predicted):
                class_total[label] += 1
                if label == prediction:
                    class_correct[label] += 1


    # 打印每个数字的准确率
    print("{:^10} | {:^10} | {:^10}".format("数字", "正确数", "准确率"))
    print("-"*33)

    for i in range(10):
        acc = 100 * class_correct[i] / class_total[i]
        print("{:^10} | {:^10} | {:^10.2f}%".format(i, class_correct[i], acc))
    
    # 可视化错误案例
    wrong_examples = []

    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        mask = predicted != labels
        wrong_examples.extend(zip(images[mask], labels[mask], predicted[mask]))
    
    # 随机展示3个错误样本

    fig, axes = plt.subplots(1, 3, figsize=(12,4))

    for ax, (img, true, pred) in zip(axes, wrong_examples[:3]):
        ax.imshow(img.squeeze(), cmap='gray')

        ax.set_title(f'True: {true}\nPred: {pred}')
        ax.axis('off')
    plt.show()

# 执行详细测试
detailed_test()

使用PyTorch实现手写数字识别功能

PyTorch vs TensorFlow 深度对比

1. 核心架构差异

特性PyTorchTensorFlow
计算图动态图(即时执行)静态图(需预先定义)
调试便利性支持标准python调试工具需要特殊工具(tfdbg)
API设计更接近Python原生语法自成体系的API风格
移动端部署支持但生态较弱通过TF Lite有成熟解决方案

2. 相同功能的代码对比

以定义全连接层为例:

# PyTorch版
import torch.nn as nn
layer = nn.Linear(in_features=784, out_features=128)

# TensorFlow版
from tensorflow.keras.layers import Dense
layer = Dense(units=128, input_dim=784)

3. 训练流程对比

PyTorch训练循环

for epoandroidch in range(epochs):
    for data, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

TensorFlow训练流程

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
model.fit(train_dataset, epochs=epochs)  # 自动完成训练循环

4. 性能对比(MNIST示例)

指标PyTorch(CPU)TensorFlow(CPU)
训练时间/epoch~45秒~50秒
内存占用~800MB~1GB
测试准确率97.8-98.2%97.5-98.0%

工具的本质

PyTorch与TensorFlow的差异,本质上是灵活性规范性的不同追求。就像画家选择画笔,PyTorch提供的是自由挥洒的水彩,TensorFlow则是精准可控的钢笔。理解它们的特性差异,根据项目需求选择合适的工具,才是提升开发效率的关键。无论是哪个框架,最终目标都是将数学公式转化为智能的力量。

以上就是使用PyTorch实现手写数字识别功能的详细内容,更多关于PyTorch手写数字识别的资料请关注编程javascript栈(www.chinasem.cn)其它相关文章!

这篇关于使用PyTorch实现手写数字识别功能的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1153919

相关文章

使用jenv工具管理多个JDK版本的方法步骤

《使用jenv工具管理多个JDK版本的方法步骤》jenv是一个开源的Java环境管理工具,旨在帮助开发者在同一台机器上轻松管理和切换多个Java版本,:本文主要介绍使用jenv工具管理多个JD... 目录一、jenv到底是干啥的?二、jenv的核心功能(一)管理多个Java版本(二)支持插件扩展(三)环境隔

Nexus安装和启动的实现教程

《Nexus安装和启动的实现教程》:本文主要介绍Nexus安装和启动的实现教程,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、Nexus下载二、Nexus安装和启动三、关闭Nexus总结一、Nexus下载官方下载链接:DownloadWindows系统根

SQL中JOIN操作的条件使用总结与实践

《SQL中JOIN操作的条件使用总结与实践》在SQL查询中,JOIN操作是多表关联的核心工具,本文将从原理,场景和最佳实践三个方面总结JOIN条件的使用规则,希望可以帮助开发者精准控制查询逻辑... 目录一、ON与WHERE的本质区别二、场景化条件使用规则三、最佳实践建议1.优先使用ON条件2.WHERE用

SpringBoot集成LiteFlow实现轻量级工作流引擎的详细过程

《SpringBoot集成LiteFlow实现轻量级工作流引擎的详细过程》LiteFlow是一款专注于逻辑驱动流程编排的轻量级框架,它以组件化方式快速构建和执行业务流程,有效解耦复杂业务逻辑,下面给大... 目录一、基础概念1.1 组件(Component)1.2 规则(Rule)1.3 上下文(Conte

Java中Map.Entry()含义及方法使用代码

《Java中Map.Entry()含义及方法使用代码》:本文主要介绍Java中Map.Entry()含义及方法使用的相关资料,Map.Entry是Java中Map的静态内部接口,用于表示键值对,其... 目录前言 Map.Entry作用核心方法常见使用场景1. 遍历 Map 的所有键值对2. 直接修改 Ma

MySQL 衍生表(Derived Tables)的使用

《MySQL衍生表(DerivedTables)的使用》本文主要介绍了MySQL衍生表(DerivedTables)的使用,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学... 目录一、衍生表简介1.1 衍生表基本用法1.2 自定义列名1.3 衍生表的局限在SQL的查询语句select

MySQL 横向衍生表(Lateral Derived Tables)的实现

《MySQL横向衍生表(LateralDerivedTables)的实现》横向衍生表适用于在需要通过子查询获取中间结果集的场景,相对于普通衍生表,横向衍生表可以引用在其之前出现过的表名,本文就来... 目录一、横向衍生表用法示例1.1 用法示例1.2 使用建议前面我们介绍过mysql中的衍生表(From子句

MybatisPlus service接口功能介绍

《MybatisPlusservice接口功能介绍》:本文主要介绍MybatisPlusservice接口功能介绍,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友... 目录Service接口基本用法进阶用法总结:Lambda方法Service接口基本用法MyBATisP

Mybatis的分页实现方式

《Mybatis的分页实现方式》MyBatis的分页实现方式主要有以下几种,每种方式适用于不同的场景,且在性能、灵活性和代码侵入性上有所差异,对Mybatis的分页实现方式感兴趣的朋友一起看看吧... 目录​1. 原生 SQL 分页(物理分页)​​2. RowBounds 分页(逻辑分页)​​3. Page

Mybatis Plus Join使用方法示例详解

《MybatisPlusJoin使用方法示例详解》:本文主要介绍MybatisPlusJoin使用方法示例详解,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,... 目录1、pom文件2、yaml配置文件3、分页插件4、示例代码:5、测试代码6、和PageHelper结合6