Pytorch: 利用预训练的残差网络ResNet50进行图像特征提取,并可视化特征图热图

本文主要是介绍Pytorch: 利用预训练的残差网络ResNet50进行图像特征提取,并可视化特征图热图,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1. 残差网络ResNet的结构

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

2.图像特征提取和可视化分析

import cv2
import time
import os
import matplotlib.pyplot as plt
import torch
from torch import nn
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as npimgname = 'bottle_broken_large.png' 
savepath='vis_resnet50/features_bottle'
if not os.path.isdir(savepath):os.makedirs(savepath)def draw_features(width,height,x,savename):tic = time.time()fig = plt.figure(figsize=(16, 16))fig.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.95, wspace=0.05, hspace=0.05)for i in range(width*height):plt.subplot(height, width, i + 1)plt.axis('off')img = x[0, i, :, :]pmin = np.min(img)pmax = np.max(img)img = ((img - pmin) / (pmax - pmin + 0.000001))*255  #float在[0,1]之间,转换成0-255img=img.astype(np.uint8)  #转成unit8img=cv2.applyColorMap(img, cv2.COLORMAP_JET) #生成heat mapimg = img[:, :, ::-1]#注意cv2(BGR)和matplotlib(RGB)通道是相反的plt.imshow(img)print("{}/{}".format(i,width*height))fig.savefig(savename, dpi=100)fig.clf()plt.close()print("time:{}".format(time.time()-tic))class ft_net(nn.Module):def __init__(self):super(ft_net, self).__init__()model_ft = models.resnet50(pretrained=True)self.model = model_ftdef forward(self, x):if True: # draw features or notx = self.model.conv1(x)draw_features(8, 8, x.cpu().numpy(),"{}/f1_conv1.png".format(savepath))x = self.model.bn1(x)draw_features(8, 8, x.cpu().numpy(),"{}/f2_bn1.png".format(savepath))x = self.model.relu(x)draw_features(8, 8, x.cpu().numpy(), "{}/f3_relu.png".format(savepath))x = self.model.maxpool(x)draw_features(8, 8, x.cpu().numpy(), "{}/f4_maxpool.png".format(savepath))x = self.model.layer1(x)draw_features(16, 16, x.cpu().numpy(), "{}/f5_layer1.png".format(savepath))x = self.model.layer2(x)draw_features(16, 32, x.cpu().numpy(), "{}/f6_layer2.png".format(savepath))x = self.model.layer3(x)draw_features(32, 32, x.cpu().numpy(), "{}/f7_layer3.png".format(savepath))x = self.model.layer4(x)draw_features(32, 32, x.cpu().numpy()[:, 0:1024, :, :], "{}/f8_layer4_1.png".format(savepath))draw_features(32, 32, x.cpu().numpy()[:, 1024:2048, :, :], "{}/f8_layer4_2.png".format(savepath))x = self.model.avgpool(x)plt.plot(np.linspace(1, 2048, 2048), x.cpu().numpy()[0, :, 0, 0])plt.savefig("{}/f9_avgpool.png".format(savepath))plt.clf()plt.close()x = x.view(x.size(0), -1)x = self.model.fc(x)plt.plot(np.linspace(1, 1000, 1000), x.cpu().numpy()[0, :])plt.savefig("{}/f10_fc.png".format(savepath))plt.clf()plt.close()else :x = self.model.conv1(x)x = self.model.bn1(x)x = self.model.relu(x)x = self.model.maxpool(x)x = self.model.layer1(x)x = self.model.layer2(x)x = self.model.layer3(x)x = self.model.layer4(x)x = self.model.avgpool(x)x = x.view(x.size(0), -1)x = self.model.fc(x)return xmodel = ft_net().cuda()# pretrained_dict = resnet50.state_dict()
# pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# model_dict.update(pretrained_dict)
# net.load_state_dict(model_dict)
model.eval()
img = cv2.imread(imgname)
img = cv2.resize(img, (288, 288))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
img = transform(img).cuda()
img = img.unsqueeze(0)with torch.no_grad():start = time.time()out = model(img)print("total time:{}".format(time.time()-start))result = out.cpu().numpy()# ind=np.argmax(out.cpu().numpy())ind = np.argsort(result, axis=1)for i in range(5):print("predict:top {} = cls {} : score {}".format(i+1,ind[0,1000-i-1],result[0,1000-i-1]))print("done")

可视化结果:

在这里插入图片描述

在这里插入图片描述

这篇关于Pytorch: 利用预训练的残差网络ResNet50进行图像特征提取,并可视化特征图热图的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/899260

相关文章

SpringBoot结合Docker进行容器化处理指南

《SpringBoot结合Docker进行容器化处理指南》在当今快速发展的软件工程领域,SpringBoot和Docker已经成为现代Java开发者的必备工具,本文将深入讲解如何将一个SpringBo... 目录前言一、为什么选择 Spring Bootjavascript + docker1. 快速部署与

linux解压缩 xxx.jar文件进行内部操作过程

《linux解压缩xxx.jar文件进行内部操作过程》:本文主要介绍linux解压缩xxx.jar文件进行内部操作,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、解压文件二、压缩文件总结一、解压文件1、把 xxx.jar 文件放在服务器上,并进入当前目录#

SpringBoot中如何使用Assert进行断言校验

《SpringBoot中如何使用Assert进行断言校验》Java提供了内置的assert机制,而Spring框架也提供了更强大的Assert工具类来帮助开发者进行参数校验和状态检查,下... 目录前言一、Java 原生assert简介1.1 使用方式1.2 示例代码1.3 优缺点分析二、Spring Fr

Linux中压缩、网络传输与系统监控工具的使用完整指南

《Linux中压缩、网络传输与系统监控工具的使用完整指南》在Linux系统管理中,压缩与传输工具是数据备份和远程协作的桥梁,而系统监控工具则是保障服务器稳定运行的眼睛,下面小编就来和大家详细介绍一下它... 目录引言一、压缩与解压:数据存储与传输的优化核心1. zip/unzip:通用压缩格式的便捷操作2.

Golang如何对cron进行二次封装实现指定时间执行定时任务

《Golang如何对cron进行二次封装实现指定时间执行定时任务》:本文主要介绍Golang如何对cron进行二次封装实现指定时间执行定时任务问题,具有很好的参考价值,希望对大家有所帮助,如有错误... 目录背景cron库下载代码示例【1】结构体定义【2】定时任务开启【3】使用示例【4】控制台输出总结背景

使用Python进行GRPC和Dubbo协议的高级测试

《使用Python进行GRPC和Dubbo协议的高级测试》GRPC(GoogleRemoteProcedureCall)是一种高性能、开源的远程过程调用(RPC)框架,Dubbo是一种高性能的分布式服... 目录01 GRPC测试安装gRPC编写.proto文件实现服务02 Dubbo测试1. 安装Dubb

HTML5 中的<button>标签用法和特征

《HTML5中的<button>标签用法和特征》在HTML5中,button标签用于定义一个可点击的按钮,它是创建交互式网页的重要元素之一,本文将深入解析HTML5中的button标签,详细介绍其属... 目录引言<button> 标签的基本用法<button> 标签的属性typevaluedisabled

Linux使用scp进行远程目录文件复制的详细步骤和示例

《Linux使用scp进行远程目录文件复制的详细步骤和示例》在Linux系统中,scp(安全复制协议)是一个使用SSH(安全外壳协议)进行文件和目录安全传输的命令,它允许在远程主机之间复制文件和目录,... 目录1. 什么是scp?2. 语法3. 示例示例 1: 复制本地目录到远程主机示例 2: 复制远程主

Python数据分析与可视化的全面指南(从数据清洗到图表呈现)

《Python数据分析与可视化的全面指南(从数据清洗到图表呈现)》Python是数据分析与可视化领域中最受欢迎的编程语言之一,凭借其丰富的库和工具,Python能够帮助我们快速处理、分析数据并生成高质... 目录一、数据采集与初步探索二、数据清洗的七种武器1. 缺失值处理策略2. 异常值检测与修正3. 数据

Pytorch介绍与安装过程

《Pytorch介绍与安装过程》PyTorch因其直观的设计、卓越的灵活性以及强大的动态计算图功能,迅速在学术界和工业界获得了广泛认可,成为当前深度学习研究和开发的主流工具之一,本文给大家介绍Pyto... 目录1、Pytorch介绍1.1、核心理念1.2、核心组件与功能1.3、适用场景与优势总结1.4、优