Pytorch 入门之-- 逻辑回归 Logistic_Regression

2024-05-06 22:32

本文主要是介绍Pytorch 入门之-- 逻辑回归 Logistic_Regression,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

代码+注释

__author__ = 'SherlockLiao'import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
import time
# 定义超参数
batch_size = 32
learning_rate = 1e-3
num_epoches = 100# 下载训练集 MNIST 手写数字训练集,不需要下载的话download=FALSE
train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=TRUE)test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)# 定义 Logistic Regression 模型
# pytorch的基本框架init forward class Logstic_Regression(nn.Module):def __init__(self, in_dim, n_class):super(Logstic_Regression, self).__init__()self.logstic = nn.Linear(in_dim, n_class)def forward(self, x):out = self.logstic(x)return out#模型实例化
model = Logstic_Regression(28 * 28, 10)  # 图片大小是28x28
use_gpu = torch.cuda.is_available()  # 判断是否有GPU加速
if use_gpu:model = model.cuda()# 定义loss和optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)# 开始训练
for epoch in range(num_epoches):print('*' * 10)print('epoch {}'.format(epoch + 1))since = time.time()running_loss = 0.0running_acc = 0.0for i, data in enumerate(train_loader, 1):#迭代数据img, label = dataimg = img.view(img.size(0), -1)  # 将图片展开成 28x28if use_gpu:img = Variable(img).cuda()label = Variable(label).cuda()else:img = Variable(img)label = Variable(label)# 向前传播out = model(img)loss = criterion(out, label)#解决方法:把        train_loss += loss.data[0]        修改为        train_loss += loss.item()running_loss += loss.item()* label.size(0)_, pred = torch.max(out, 1)num_correct = (pred == label).sum()running_acc += num_correct.item()# 向后传播optimizer.zero_grad()loss.backward()optimizer.step()if i % 300 == 0:print('[{}/{}] Loss: {:.6f}, Acc: {:.6f}'.format(epoch + 1, num_epoches, running_loss / (batch_size * i),running_acc / (batch_size * i)))
print('Finish {} epoch, Loss: {:.6f}, Acc: {:.6f}'.format(epoch + 1, running_loss / (len(train_dataset)), running_acc / (len(train_dataset))))model.eval()
eval_loss = 0.
eval_acc = 0.
for data in test_loader:img, label = dataimg = img.view(img.size(0), -1)if use_gpu:img = Variable(img, volatile=True).cuda()label = Variable(label, volatile=True).cuda()else:img = Variable(img, volatile=True)label = Variable(label, volatile=True)out = model(img)loss = criterion(out, label)eval_loss += loss.item() * label.size(0)_, pred = torch.max(out, 1)num_correct = (pred == label).sum()eval_acc += num_correct.data[0]
print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(test_dataset)), eval_acc / (len(test_dataset))))
print('Time:{:.1f} s'.format(time.time() - since))
print()# 保存模型
torch.save(model.state_dict(), './logstic.pth')

 

这篇关于Pytorch 入门之-- 逻辑回归 Logistic_Regression的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/965558

相关文章

从入门到精通C++11 <chrono> 库特性

《从入门到精通C++11<chrono>库特性》chrono库是C++11中一个非常强大和实用的库,它为时间处理提供了丰富的功能和类型安全的接口,通过本文的介绍,我们了解了chrono库的基本概念... 目录一、引言1.1 为什么需要<chrono>库1.2<chrono>库的基本概念二、时间段(Durat

解析C++11 static_assert及与Boost库的关联从入门到精通

《解析C++11static_assert及与Boost库的关联从入门到精通》static_assert是C++中强大的编译时验证工具,它能够在编译阶段拦截不符合预期的类型或值,增强代码的健壮性,通... 目录一、背景知识:传统断言方法的局限性1.1 assert宏1.2 #error指令1.3 第三方解决

从入门到精通MySQL 数据库索引(实战案例)

《从入门到精通MySQL数据库索引(实战案例)》索引是数据库的目录,提升查询速度,主要类型包括BTree、Hash、全文、空间索引,需根据场景选择,建议用于高频查询、关联字段、排序等,避免重复率高或... 目录一、索引是什么?能干嘛?核心作用:二、索引的 4 种主要类型(附通俗例子)1. BTree 索引(

Redis 配置文件使用建议redis.conf 从入门到实战

《Redis配置文件使用建议redis.conf从入门到实战》Redis配置方式包括配置文件、命令行参数、运行时CONFIG命令,支持动态修改参数及持久化,常用项涉及端口、绑定、内存策略等,版本8... 目录一、Redis.conf 是什么?二、命令行方式传参(适用于测试)三、运行时动态修改配置(不重启服务

MySQL DQL从入门到精通

《MySQLDQL从入门到精通》通过DQL,我们可以从数据库中检索出所需的数据,进行各种复杂的数据分析和处理,本文将深入探讨MySQLDQL的各个方面,帮助你全面掌握这一重要技能,感兴趣的朋友跟随小... 目录一、DQL 基础:SELECT 语句入门二、数据过滤:WHERE 子句的使用三、结果排序:ORDE

Pytorch介绍与安装过程

《Pytorch介绍与安装过程》PyTorch因其直观的设计、卓越的灵活性以及强大的动态计算图功能,迅速在学术界和工业界获得了广泛认可,成为当前深度学习研究和开发的主流工具之一,本文给大家介绍Pyto... 目录1、Pytorch介绍1.1、核心理念1.2、核心组件与功能1.3、适用场景与优势总结1.4、优

conda安装GPU版pytorch默认却是cpu版本

《conda安装GPU版pytorch默认却是cpu版本》本文主要介绍了遇到Conda安装PyTorchGPU版本却默认安装CPU的问题,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的... 目录一、问题描述二、网上解决方案罗列【此节为反面方案罗列!!!】三、发现的根本原因[独家]3.1 p

Python中OpenCV与Matplotlib的图像操作入门指南

《Python中OpenCV与Matplotlib的图像操作入门指南》:本文主要介绍Python中OpenCV与Matplotlib的图像操作指南,本文通过实例代码给大家介绍的非常详细,对大家的学... 目录一、环境准备二、图像的基本操作1. 图像读取、显示与保存 使用OpenCV操作2. 像素级操作3.

PyTorch中cdist和sum函数使用示例详解

《PyTorch中cdist和sum函数使用示例详解》torch.cdist是PyTorch中用于计算**两个张量之间的成对距离(pairwisedistance)**的函数,常用于点云处理、图神经网... 目录基本语法输出示例1. 简单的 2D 欧几里得距离2. 批量形式(3D Tensor)3. 使用不

PyTorch高级特性与性能优化方式

《PyTorch高级特性与性能优化方式》:本文主要介绍PyTorch高级特性与性能优化方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、自动化机制1.自动微分机制2.动态计算图二、性能优化1.内存管理2.GPU加速3.多GPU训练三、分布式训练1.分布式数据