Pytorch 入门之-- 逻辑回归 Logistic_Regression

2024-05-06 22:32

本文主要是介绍Pytorch 入门之-- 逻辑回归 Logistic_Regression,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

代码+注释

__author__ = 'SherlockLiao'import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
import time
# 定义超参数
batch_size = 32
learning_rate = 1e-3
num_epoches = 100# 下载训练集 MNIST 手写数字训练集,不需要下载的话download=FALSE
train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=TRUE)test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)# 定义 Logistic Regression 模型
# pytorch的基本框架init forward class Logstic_Regression(nn.Module):def __init__(self, in_dim, n_class):super(Logstic_Regression, self).__init__()self.logstic = nn.Linear(in_dim, n_class)def forward(self, x):out = self.logstic(x)return out#模型实例化
model = Logstic_Regression(28 * 28, 10)  # 图片大小是28x28
use_gpu = torch.cuda.is_available()  # 判断是否有GPU加速
if use_gpu:model = model.cuda()# 定义loss和optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)# 开始训练
for epoch in range(num_epoches):print('*' * 10)print('epoch {}'.format(epoch + 1))since = time.time()running_loss = 0.0running_acc = 0.0for i, data in enumerate(train_loader, 1):#迭代数据img, label = dataimg = img.view(img.size(0), -1)  # 将图片展开成 28x28if use_gpu:img = Variable(img).cuda()label = Variable(label).cuda()else:img = Variable(img)label = Variable(label)# 向前传播out = model(img)loss = criterion(out, label)#解决方法:把        train_loss += loss.data[0]        修改为        train_loss += loss.item()running_loss += loss.item()* label.size(0)_, pred = torch.max(out, 1)num_correct = (pred == label).sum()running_acc += num_correct.item()# 向后传播optimizer.zero_grad()loss.backward()optimizer.step()if i % 300 == 0:print('[{}/{}] Loss: {:.6f}, Acc: {:.6f}'.format(epoch + 1, num_epoches, running_loss / (batch_size * i),running_acc / (batch_size * i)))
print('Finish {} epoch, Loss: {:.6f}, Acc: {:.6f}'.format(epoch + 1, running_loss / (len(train_dataset)), running_acc / (len(train_dataset))))model.eval()
eval_loss = 0.
eval_acc = 0.
for data in test_loader:img, label = dataimg = img.view(img.size(0), -1)if use_gpu:img = Variable(img, volatile=True).cuda()label = Variable(label, volatile=True).cuda()else:img = Variable(img, volatile=True)label = Variable(label, volatile=True)out = model(img)loss = criterion(out, label)eval_loss += loss.item() * label.size(0)_, pred = torch.max(out, 1)num_correct = (pred == label).sum()eval_acc += num_correct.data[0]
print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(test_dataset)), eval_acc / (len(test_dataset))))
print('Time:{:.1f} s'.format(time.time() - since))
print()# 保存模型
torch.save(model.state_dict(), './logstic.pth')

 

这篇关于Pytorch 入门之-- 逻辑回归 Logistic_Regression的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/965558

相关文章

Java List 使用举例(从入门到精通)

《JavaList使用举例(从入门到精通)》本文系统讲解JavaList,涵盖基础概念、核心特性、常用实现(如ArrayList、LinkedList)及性能对比,介绍创建、操作、遍历方法,结合实... 目录一、List 基础概念1.1 什么是 List?1.2 List 的核心特性1.3 List 家族成

mybatisplus的逻辑删除过程

《mybatisplus的逻辑删除过程》:本文主要介绍mybatisplus的逻辑删除过程,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录myBATisplus的逻辑删除1、在配置文件中添加逻辑删除的字段2、在实体类上加上@TableLogic3、业务层正常删除即

c++日志库log4cplus快速入门小结

《c++日志库log4cplus快速入门小结》文章浏览阅读1.1w次,点赞9次,收藏44次。本文介绍Log4cplus,一种适用于C++的线程安全日志记录API,提供灵活的日志管理和配置控制。文章涵盖... 目录简介日志等级配置文件使用关于初始化使用示例总结参考资料简介log4j 用于Java,log4c

史上最全MybatisPlus从入门到精通

《史上最全MybatisPlus从入门到精通》MyBatis-Plus是MyBatis增强工具,简化开发并提升效率,支持自动映射表名/字段与实体类,提供条件构造器、多种查询方式(等值/范围/模糊/分页... 目录1.简介2.基础篇2.1.通用mapper接口操作2.2.通用service接口操作3.进阶篇3

Python自定义异常的全面指南(入门到实践)

《Python自定义异常的全面指南(入门到实践)》想象你正在开发一个银行系统,用户转账时余额不足,如果直接抛出ValueError,调用方很难区分是金额格式错误还是余额不足,这正是Python自定义异... 目录引言:为什么需要自定义异常一、异常基础:先搞懂python的异常体系1.1 异常是什么?1.2

Python实现Word转PDF全攻略(从入门到实战)

《Python实现Word转PDF全攻略(从入门到实战)》在数字化办公场景中,Word文档的跨平台兼容性始终是个难题,而PDF格式凭借所见即所得的特性,已成为文档分发和归档的标准格式,下面小编就来和大... 目录一、为什么需要python处理Word转PDF?二、主流转换方案对比三、五套实战方案详解方案1:

Spring WebClient从入门到精通

《SpringWebClient从入门到精通》本文详解SpringWebClient非阻塞响应式特性及优势,涵盖核心API、实战应用与性能优化,对比RestTemplate,为微服务通信提供高效解决... 目录一、WebClient 概述1.1 为什么选择 WebClient?1.2 WebClient 与

PyTorch中的词嵌入层(nn.Embedding)详解与实战应用示例

《PyTorch中的词嵌入层(nn.Embedding)详解与实战应用示例》词嵌入解决NLP维度灾难,捕捉语义关系,PyTorch的nn.Embedding模块提供灵活实现,支持参数配置、预训练及变长... 目录一、词嵌入(Word Embedding)简介为什么需要词嵌入?二、PyTorch中的nn.Em

Spring Boot 与微服务入门实战详细总结

《SpringBoot与微服务入门实战详细总结》本文讲解SpringBoot框架的核心特性如快速构建、自动配置、零XML与微服务架构的定义、演进及优缺点,涵盖开发环境准备和HelloWorld实战... 目录一、Spring Boot 核心概述二、微服务架构详解1. 微服务的定义与演进2. 微服务的优缺点三

从入门到精通详解LangChain加载HTML内容的全攻略

《从入门到精通详解LangChain加载HTML内容的全攻略》这篇文章主要为大家详细介绍了如何用LangChain优雅地处理HTML内容,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录引言:当大语言模型遇见html一、HTML加载器为什么需要专门的HTML加载器核心加载器对比表二