空间变换器网络的简介+实现

2024-04-27 02:38

本文主要是介绍空间变换器网络的简介+实现,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

空间变换器网络

是对任何空间变换的差异化关注的概括。空间变换器网络(简称STN)允许神经网
络学习如何在输入图像上执行空间变换, 以增强模型的几何不变性。

例如,它可以裁剪感兴趣的区域,缩放并校正图像的方向。而这可能是一种有用的机制,因为CNN对于旋转和 缩放以及更一
般的仿射变换并不是不变的。

 

空间变换器网络归结为三个主要组成部分:
本地网络(Localisation Network)是常规CNN,其对变换参数进行回归。不会从该数据集中
明确地学习转换,而是网络自动学习增强 全局准确性的空间变换。
网格生成器( Grid Genator)在输入图像中生成与输出图像中的每个像素相对应的坐标网格。
采样器(Sampler)使用变换的参数并将其应用于输入图像

 

更多有关空间变换器网络的内容 :https://arxiv.org/abs/1506.02025

 

STN的最棒的事情之一是能够简单地将其插入任何现有的CNN,而且只需很少的修改

导包

from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
plt.ion() # 交互模式

 

1、加载数据集


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,)),
])train_data = datasets.MNIST(root='.',train=True,download=True,transform= transform)train_loader = torch.utils.data.DataLoader(train_data,batch_size = 64,shuffle = True,num_workers = 4)test_data = datasets.MNIST(root='.',train=False,transform= transform)
# 测试数据集
test_loader = torch.utils.data.DataLoader(test_data,batch_size = 64,shuffle=True,num_workers=4)

2、 定义网络结构


# 定义网络结构
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1,10,kernel_size=5)self.conv2 = nn.Conv2d(10,20, kernel_size=5)self.conv2_drop = nn.Dropout2d()self.fc1 = nn.Linear(320, 50)self.fc2 = nn.Linear(50, 10)# 空间变换器定位 -- 网络self.localization = nn.Sequential(nn.Conv2d(1, 8, kernel_size = 7),nn.MaxPool2d(2, stride =2),nn.ReLU(True),nn.Conv2d(8,10, kernel_size = 5),nn.MaxPool2d(2, stride=2),nn.ReLU(True))# 3*2 affine 矩阵的回归量self.fc_loc = nn.Sequential(nn.Linear(10*3*3, 32),nn.ReLU(True),nn.Linear(32, 3*2))# 使用身份转换初始化权重 / 偏差self.fc_loc[2].weight.data.zero_()self.fc_loc[2].bias.data.copy_(torch.tensor([1,0,0,0,1,0],dtype=torch.float))# 空间变换器网络转发功能def stn(self, x):xs = self.localization(x)#  1--10xs = xs.view(-1, 10*3*3)theta = self.fc_loc(xs) # 90 -- 6theta = theta.view(-1, 2, 3)grid = F.affine_grid(theta, x.size())x = F.grid_sample(x, grid)return xdef forward(self, x):# transform the inputx = self.stn(x)# 执行一般的前进传递x = F.relu(F.max_pool2d(self.conv1(x), 2))x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))x = x.view(-1, 320)x = F.relu(self.fc1(x))x = F.dropout(x, training =self.training)x = self.fc2(x)return F.log_softmax(x, dim=1)

3、定义网络和优化器

model = Net().to(device)
# train model
optimizer = optim.SGD(model.parameters(), lr=0.01)

 

4、训练模型

训练模型 现在我们使用 SGD(随机梯度下降)算法来训练模型。网络正在以有监督的方式学习分
类任务。同时,该模型以端到端的方式自动学习STN。


def train(epoch):model.train()for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = F.nll_loss(output, target)loss.backward()optimizer.step()if batch_idx%500 ==0:print('Train Epoch: {} [{}/{} ({:.0f}%)] \tLoss: {:.6f}'.format(epoch, batch_idx*len(data), len(train_loader.dataset),100. *batch_idx / len(train_loader),loss.item()))

5、测试函数


# 测试函数
def test():with torch.no_grad():model.eval()test_loss = 0correct = 0for data, target in test_loader:data , target = data.to(device), target.to(device)output = model(data)# 累加批量损失test_loss += F.nll_loss(output, target, size_average=False).item()# 获取最大对数概率的索引pred = output.max(1, keepdim = True)[1]correct += pred.eq(target.view_as(pred)).sum().item()test_loss /= len(test_loader.dataset)print('\n Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f} %))\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))

 

6 可视化


# 可视化 STN 结果
def convert_image_np(inp):inp = inp.numpy().transpose((1,2,0))mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])inp = std*inp +meaninp = np.clip(inp, 0, 1)return inp# STN 可视化一批输入图像和相应变换批次def visualize_stn():with torch.no_grad():data = next(iter(test_loader))[0].to(device)input_tensor = data.cpu()transformed_input_tensor = model.stn(data).cpu()in_grid = convert_image_np(torchvision.utils.make_grid(input_tensor))out_grid = convert_image_np(torchvision.utils.make_grid(transformed_input_tensor))# Plot the results side-by_sidef, axarr= plt.subplots(1,2)axarr[0].imshow(in_grid)axarr[0].set_title("Dataset Images")axarr[1].imshow(out_grid)axarr[1].set_title("Transformed Images")

7、训练并显示结果


for epoch in range(1, 20+1):train(epoch)test()plt.ioff()
plt.imshow()

 

完整项目链接:https://github.com/Whq123/Space-transformer-network

这篇关于空间变换器网络的简介+实现的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/939316

相关文章

enable_shared_from_this 实现原理

前言 enable_shared_from_this 可以帮助我们用 this 指针安全地创建 shared_ptr。 enable_shared_from_this 假设我们的程序使用 shared_ptr 管理 Widget 对象,我们用一个 vector 来记录已经处理过的 Widget 对象: std::vector<std::shared_ptr<Widget>> process

【PyTorch与深度学习】6、PyTorch中搭建分类网络实例

课程地址 最近做实验发现自己还是基础框架上掌握得不好,于是开始重学一遍PyTorch框架,这个是课程笔记,此节课很详细,笔记记的比较粗,这个视频课是需要有点深度学习数学基础的,如果没有数学基础,可以一边学一边查一查 1. Transforms 我们导入到数据集中的图片可能大小不一样,数据并不总是以训模型所需的最终处理形式出现。我们使用Transforms对数据进行一些操作,使其适合训练(比如统

Java微信小程序订阅消息提醒的实现与对接

文章目录 一、准备工作1. 注册微信小程序,并开通订阅消息功能。2. 获取小程序的AppID和AppSecret。3. 在微信小程序管理后台,设置提醒模板,并获取模板ID。4. 小程序端需要获取用户订阅允许提醒的权限(1)引导用户触发订阅(2)编写订阅消息的JavaScript代码(3)用户授权订阅(4)后台发送订阅消息(5)注意事项 二、实现步骤1. 获取Access Token2. 用

网络审计:为什么定期检查您的网络很重要

在数字化时代,网络安全成为组织和个人必须面对的重要挑战。网络审计是一种关键的安全措施,通过定期检查和评估网络系统的安全性,帮助发现潜在的安全漏洞和弱点,从而防止数据泄露和其他安全威胁。本文将介绍网络审计的重要性,并提供一系列具体的操作步骤,帮助您有效地执行网络审计。   为什么网络审计很重要? 发现潜在的安全漏洞:网络审计可以帮助识别网络中存在的安全漏洞,防止未经授权的访问和数据泄露。 遵守

实现腾讯地图的接口调用以及微信小程序的地图标注

目录 微信小程序端1. 引入腾讯地图SDK2. 使用地图组件3. 地图页面编写4. 地图标注 Java后端业务逻辑1. 引入腾讯地图Java SDK2. 配置API密钥3. 调用腾讯地图API4. 提供小程序调用的接口 总结 要实现腾讯地图的接口调用以及微信小程序的地图标注,需要分为两个部分:微信小程序端和Java后端。下面将分别介绍这两部分的实现过程和代码。 微信小程序端

二叉树基本操作--java实现

public class Node { //node class, the base of treeint data;int index;Node leftChild;Node rightChild;} import java.util.Scanner;/*** @author NEU 灏忓畤* @version 1.1*/public class TreeD

红黑树模拟实现map与set

目录 1.红黑树的泛型 1.1 库里的做法 1.2 我们的模拟实现  map部分实现 红黑树部分区域的改造 insert与find的改造  节点的定义  2. 红黑树的迭代器  库里的红黑树。 ​编辑 迭代器中的方法 迭代器的实现  迭代器的定义 迭代器的方法实现  ​编辑 代码展示  红黑树迭代器实现  树内实现begin,end  map与set封装  测试

Spring Security实现用户认证一:简单示例

Spring Security实现用户认证一:简单示例 1 原理1.1 用户认证怎么进行和保存的?认证流程SecurityContext保存 2 创建简单的登录认证示例2.1 pom.xml依赖添加2.2 application.yaml配置2.3 创建WebSecurityConfig配置类2.4 测试 1 原理 Spring Security是一个Java框架,用于保护应

公共命名空间和RHP

概述 RHP的全称是:the little Robot that Helped me Program,帮我编程序的小机器人。 RHP必然存在,C语言的宏、C++的模板,都是RHP;更复杂的例子,是lex和yacc,它们是制作程序的程序,也可以认为是RHP。可能从某个角度看,它和游戏中的NPC有点像。 公共命名空间是一张巨大的表格,里面收录了所有方言的所有句子。 公共命名空间可以把一行行源代码,翻

InfiniGate自研网关实现五

17.核心通信组件管理和处理服务映射 引入模块api-gateway-core 到 api-gateway-assist 中进行创建和使用,并拉取自注册中心的映射信息注册到本地的网关通信组件中。 第17节是在第15节的基础上继续完善服务发现的相关功能,把从注册中心拉取的网关映射信息【系统、接口、方法】映射到本地通信组件中。这样就算完成了注册中心到本地服务的一个打通处理,映射完成后就可以通过HT