Pytorch: 利用预训练的残差网络ResNet50进行图像特征提取,并可视化特征图热图

本文主要是介绍Pytorch: 利用预训练的残差网络ResNet50进行图像特征提取,并可视化特征图热图,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1. 残差网络ResNet的结构

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

2.图像特征提取和可视化分析

import cv2
import time
import os
import matplotlib.pyplot as plt
import torch
from torch import nn
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as npimgname = 'bottle_broken_large.png' 
savepath='vis_resnet50/features_bottle'
if not os.path.isdir(savepath):os.makedirs(savepath)def draw_features(width,height,x,savename):tic = time.time()fig = plt.figure(figsize=(16, 16))fig.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.95, wspace=0.05, hspace=0.05)for i in range(width*height):plt.subplot(height, width, i + 1)plt.axis('off')img = x[0, i, :, :]pmin = np.min(img)pmax = np.max(img)img = ((img - pmin) / (pmax - pmin + 0.000001))*255  #float在[0,1]之间,转换成0-255img=img.astype(np.uint8)  #转成unit8img=cv2.applyColorMap(img, cv2.COLORMAP_JET) #生成heat mapimg = img[:, :, ::-1]#注意cv2(BGR)和matplotlib(RGB)通道是相反的plt.imshow(img)print("{}/{}".format(i,width*height))fig.savefig(savename, dpi=100)fig.clf()plt.close()print("time:{}".format(time.time()-tic))class ft_net(nn.Module):def __init__(self):super(ft_net, self).__init__()model_ft = models.resnet50(pretrained=True)self.model = model_ftdef forward(self, x):if True: # draw features or notx = self.model.conv1(x)draw_features(8, 8, x.cpu().numpy(),"{}/f1_conv1.png".format(savepath))x = self.model.bn1(x)draw_features(8, 8, x.cpu().numpy(),"{}/f2_bn1.png".format(savepath))x = self.model.relu(x)draw_features(8, 8, x.cpu().numpy(), "{}/f3_relu.png".format(savepath))x = self.model.maxpool(x)draw_features(8, 8, x.cpu().numpy(), "{}/f4_maxpool.png".format(savepath))x = self.model.layer1(x)draw_features(16, 16, x.cpu().numpy(), "{}/f5_layer1.png".format(savepath))x = self.model.layer2(x)draw_features(16, 32, x.cpu().numpy(), "{}/f6_layer2.png".format(savepath))x = self.model.layer3(x)draw_features(32, 32, x.cpu().numpy(), "{}/f7_layer3.png".format(savepath))x = self.model.layer4(x)draw_features(32, 32, x.cpu().numpy()[:, 0:1024, :, :], "{}/f8_layer4_1.png".format(savepath))draw_features(32, 32, x.cpu().numpy()[:, 1024:2048, :, :], "{}/f8_layer4_2.png".format(savepath))x = self.model.avgpool(x)plt.plot(np.linspace(1, 2048, 2048), x.cpu().numpy()[0, :, 0, 0])plt.savefig("{}/f9_avgpool.png".format(savepath))plt.clf()plt.close()x = x.view(x.size(0), -1)x = self.model.fc(x)plt.plot(np.linspace(1, 1000, 1000), x.cpu().numpy()[0, :])plt.savefig("{}/f10_fc.png".format(savepath))plt.clf()plt.close()else :x = self.model.conv1(x)x = self.model.bn1(x)x = self.model.relu(x)x = self.model.maxpool(x)x = self.model.layer1(x)x = self.model.layer2(x)x = self.model.layer3(x)x = self.model.layer4(x)x = self.model.avgpool(x)x = x.view(x.size(0), -1)x = self.model.fc(x)return xmodel = ft_net().cuda()# pretrained_dict = resnet50.state_dict()
# pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# model_dict.update(pretrained_dict)
# net.load_state_dict(model_dict)
model.eval()
img = cv2.imread(imgname)
img = cv2.resize(img, (288, 288))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
img = transform(img).cuda()
img = img.unsqueeze(0)with torch.no_grad():start = time.time()out = model(img)print("total time:{}".format(time.time()-start))result = out.cpu().numpy()# ind=np.argmax(out.cpu().numpy())ind = np.argsort(result, axis=1)for i in range(5):print("predict:top {} = cls {} : score {}".format(i+1,ind[0,1000-i-1],result[0,1000-i-1]))print("done")

可视化结果:

在这里插入图片描述

在这里插入图片描述

这篇关于Pytorch: 利用预训练的残差网络ResNet50进行图像特征提取,并可视化特征图热图的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/899260

相关文章

Python实现简单封装网络请求的示例详解

《Python实现简单封装网络请求的示例详解》这篇文章主要为大家详细介绍了Python实现简单封装网络请求的相关知识,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录安装依赖核心功能说明1. 类与方法概览2.NetHelper类初始化参数3.ApiResponse类属性与方法使用实

Python进行word模板内容替换的实现示例

《Python进行word模板内容替换的实现示例》本文介绍了使用Python自动化处理Word模板文档的常用方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友... 目录技术背景与需求场景核心工具库介绍1.获取你的word模板内容2.正常文本内容的替换3.表格内容的

Git进行版本控制的实战指南

《Git进行版本控制的实战指南》Git是一种分布式版本控制系统,广泛应用于软件开发中,它可以记录和管理项目的历史修改,并支持多人协作开发,通过Git,开发者可以轻松地跟踪代码变更、合并分支、回退版本等... 目录一、Git核心概念解析二、环境搭建与配置1. 安装Git(Windows示例)2. 基础配置(必

Debian 13升级后网络转发等功能异常怎么办? 并非错误而是管理机制变更

《Debian13升级后网络转发等功能异常怎么办?并非错误而是管理机制变更》很多朋友反馈,更新到Debian13后网络转发等功能异常,这并非BUG而是Debian13Trixie调整... 日前 Debian 13 Trixie 发布后已经有众多网友升级到新版本,只不过升级后发现某些功能存在异常,例如网络转

Nginx中配置使用非默认80端口进行服务的完整指南

《Nginx中配置使用非默认80端口进行服务的完整指南》在实际生产环境中,我们经常需要将Nginx配置在其他端口上运行,本文将详细介绍如何在Nginx中配置使用非默认端口进行服务,希望对大家有所帮助... 目录一、为什么需要使用非默认端口二、配置Nginx使用非默认端口的基本方法2.1 修改listen指令

MySQL按时间维度对亿级数据表进行平滑分表

《MySQL按时间维度对亿级数据表进行平滑分表》本文将以一个真实的4亿数据表分表案例为基础,详细介绍如何在不影响线上业务的情况下,完成按时间维度分表的完整过程,感兴趣的小伙伴可以了解一下... 目录引言一、为什么我们需要分表1.1 单表数据量过大的问题1.2 分表方案选型二、分表前的准备工作2.1 数据评估

Python开发简易网络服务器的示例详解(新手入门)

《Python开发简易网络服务器的示例详解(新手入门)》网络服务器是互联网基础设施的核心组件,它本质上是一个持续运行的程序,负责监听特定端口,本文将使用Python开发一个简单的网络服务器,感兴趣的小... 目录网络服务器基础概念python内置服务器模块1. HTTP服务器模块2. Socket服务器模块

Python实现数据可视化图表生成(适合新手入门)

《Python实现数据可视化图表生成(适合新手入门)》在数据科学和数据分析的新时代,高效、直观的数据可视化工具显得尤为重要,下面:本文主要介绍Python实现数据可视化图表生成的相关资料,文中通过... 目录前言为什么需要数据可视化准备工作基本图表绘制折线图柱状图散点图使用Seaborn创建高级图表箱线图热

MySQL进行分片合并的实现步骤

《MySQL进行分片合并的实现步骤》分片合并是指在分布式数据库系统中,将不同分片上的查询结果进行整合,以获得完整的查询结果,下面就来具体介绍一下,感兴趣的可以了解一下... 目录环境准备项目依赖数据源配置分片上下文分片查询和合并代码实现1. 查询单条记录2. 跨分片查询和合并测试结论分片合并(Shardin

Go语言网络故障诊断与调试技巧

《Go语言网络故障诊断与调试技巧》在分布式系统和微服务架构的浪潮中,网络编程成为系统性能和可靠性的核心支柱,从高并发的API服务到实时通信应用,网络的稳定性直接影响用户体验,本文面向熟悉Go基本语法和... 目录1. 引言2. Go 语言网络编程的优势与特色2.1 简洁高效的标准库2.2 强大的并发模型2.