PyTorch 的 Pooling 和 UnPooling函数中的 indices 参数:nn.MaxPool2d/nn.MaxUnpool2d、F.max_pool2d/F.max_unpool2d

本文主要是介绍PyTorch 的 Pooling 和 UnPooling函数中的 indices 参数:nn.MaxPool2d/nn.MaxUnpool2d、F.max_pool2d/F.max_unpool2d,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

这篇博文主要介绍 PyTorch 的 MaxPooling 和 MAxUnPooling 函数中涉及到的 indices 参数。

indices 是“索引”的意思,对于一些结构对称的网络模型,上采样和下采样的结构往往是对称的,我们可以在下采样做 MaxPooling 的时候记录下来最大值所在的位置,当做上采样的时候把最大值还原到其对应的位置,然后其余的位置补 0 。

indices 参数的作用就是保存 MaxPooling 操作时最大值位置的索引。

如下图所示:

在这里插入图片描述
PyTorch 的 torch.nn 和 torch.nn.functional 模块中均有实现 Pooling 和 UnPooling 的 api ,它们的作用和效果是完全相同的。

1、nn.MaxPool2d 和 nn.MaxUnpool2d

使用 nn.MaxPool2d 和 nn.MaxUnpool2d 时要先实例化,事实上 nn 模块下面的函数都是如此(需要先实例化),比如 nn.Conv

import torch
from torch import nn# 使用 nn.MaxPool2d 和 nn.MaxUnpool2d 时要先实例化
pool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)data = torch.tensor([[[[1, 2, 3, 4],[5, 6, 7, 8],[9, 10, 11, 12],[13, 14, 15, 16]]]], dtype=torch.float32)
pool_out, indice = pool(data)
unpool_out = unpool(input=pool_out, indices=indice)print("pool_out = ", pool_out)
print("indice = ", indice)
print("unpool_out = ", unpool_out)"""
pool_out =  tensor([[[[ 6.,  8.],[14., 16.]]]])
indice =  tensor([[[[ 5,  7],[13, 15]]]])
unpool_out =  tensor([[[[ 0.,  0.,  0.,  0.],[ 0.,  6.,  0.,  8.],[ 0.,  0.,  0.,  0.],[ 0., 14.,  0., 16.]]]])
"""

2、F.max_pool2d 和 F.max_unpool2d

使用 F.max_pool2d 和 F.max_unpool2d 时不需要实例化,可以直接使用。

import torch
import torch.nn.functional as Fdata = torch.tensor([[[[1, 2, 3, 4],[5, 6, 7, 8],[9, 10, 11, 12],[13, 14, 15, 16]]]], dtype=torch.float32)
pool_out, indice = F.max_pool2d(input=data, kernel_size=2, stride=2, return_indices=True)
unpool_out = F.max_unpool2d(input=pool_out, indices=indice, kernel_size=2, stride=2)print("pool_out = ", pool_out)
print("indice = ", indice)
print("unpool_out = ", unpool_out)"""
pool_out =  tensor([[[[ 6.,  8.],[14., 16.]]]])
indice =  tensor([[[[ 5,  7],[13, 15]]]])
unpool_out =  tensor([[[[ 0.,  0.,  0.,  0.],[ 0.,  6.,  0.,  8.],[ 0.,  0.,  0.,  0.],[ 0., 14.,  0., 16.]]]])
"""

可以看到,nn.MaxPool2d / nn.MaxUnpool2d 和 F.max_pool2d / F.max_unpool2d 的作用和输出结果完全相同。

3、使用 Pooling 和 Conv2d 实现上/下采样的区别和产生的影响

使用 Pooling 和 Conv2d 实现上/下采样的区别主要体现在对奇数大小的特征图的处理中,以特征图大小为 65*65为例。

使用 nn.MaxPool2d 和 F.max_pool2d 实现下采样时,得到的特征图大小是 32*32,上采样得到的特征图大小是 64*64
使用 nn.Conv2d 实现下采样时,得到的特征图大小是 33*33;再使用nn.ConvTranspose2d 上采样得到的特征图大小是 66*66

在很多对称的网络结构中(如 UNet、SegNet),需要对上采样和下采样的对应的特征图进行大小对齐。

若网络中间某个特征图大小是 65 ,不论使用哪种上/下采样策略,得到的特征图大小必然不可能还是64(不可能是奇数)。此时就要考虑64和66的区别了。

(1)如果使用 64 * 64 大小的特征图,则需要进行 padding 得到 65 * 65 的特征图。但padding操作可能会导致在训练和推理过程中的不确定性问题。

(2)如果使用 66 * 66 大小的特征图,则只需要进行切片得到 65 * 65 的特征图。该操作是更靠谱的。

因此,个人建议使用 nn.Conv2d 和其他函数实现采样操作。

这篇关于PyTorch 的 Pooling 和 UnPooling函数中的 indices 参数:nn.MaxPool2d/nn.MaxUnpool2d、F.max_pool2d/F.max_unpool2d的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/342448

相关文章

MySQL常用字符串函数示例和场景介绍

《MySQL常用字符串函数示例和场景介绍》MySQL提供了丰富的字符串函数帮助我们高效地对字符串进行处理、转换和分析,本文我将全面且深入地介绍MySQL常用的字符串函数,并结合具体示例和场景,帮你熟练... 目录一、字符串函数概述1.1 字符串函数的作用1.2 字符串函数分类二、字符串长度与统计函数2.1

python使用try函数详解

《python使用try函数详解》Pythontry语句用于异常处理,支持捕获特定/多种异常、else/final子句确保资源释放,结合with语句自动清理,可自定义异常及嵌套结构,灵活应对错误场景... 目录try 函数的基本语法捕获特定异常捕获多个异常使用 else 子句使用 finally 子句捕获所

PyTorch中的词嵌入层(nn.Embedding)详解与实战应用示例

《PyTorch中的词嵌入层(nn.Embedding)详解与实战应用示例》词嵌入解决NLP维度灾难,捕捉语义关系,PyTorch的nn.Embedding模块提供灵活实现,支持参数配置、预训练及变长... 目录一、词嵌入(Word Embedding)简介为什么需要词嵌入?二、PyTorch中的nn.Em

postgresql使用UUID函数的方法

《postgresql使用UUID函数的方法》本文给大家介绍postgresql使用UUID函数的方法,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录PostgreSQL有两种生成uuid的方法。可以先通过sql查看是否已安装扩展函数,和可以安装的扩展函数

MySQL字符串常用函数详解

《MySQL字符串常用函数详解》本文给大家介绍MySQL字符串常用函数,本文结合实例代码给大家介绍的非常详细,对大家学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录mysql字符串常用函数一、获取二、大小写转换三、拼接四、截取五、比较、反转、替换六、去空白、填充MySQL字符串常用函数一、

C++中assign函数的使用

《C++中assign函数的使用》在C++标准模板库中,std::list等容器都提供了assign成员函数,它比操作符更灵活,支持多种初始化方式,下面就来介绍一下assign的用法,具有一定的参考价... 目录​1.assign的基本功能​​语法​2. 具体用法示例​​​(1) 填充n个相同值​​(2)

MySql基本查询之表的增删查改+聚合函数案例详解

《MySql基本查询之表的增删查改+聚合函数案例详解》本文详解SQL的CURD操作INSERT用于数据插入(单行/多行及冲突处理),SELECT实现数据检索(列选择、条件过滤、排序分页),UPDATE... 目录一、Create1.1 单行数据 + 全列插入1.2 多行数据 + 指定列插入1.3 插入否则更

PostgreSQL中rank()窗口函数实用指南与示例

《PostgreSQL中rank()窗口函数实用指南与示例》在数据分析和数据库管理中,经常需要对数据进行排名操作,PostgreSQL提供了强大的窗口函数rank(),可以方便地对结果集中的行进行排名... 目录一、rank()函数简介二、基础示例:部门内员工薪资排名示例数据排名查询三、高级应用示例1. 每

全面掌握 SQL 中的 DATEDIFF函数及用法最佳实践

《全面掌握SQL中的DATEDIFF函数及用法最佳实践》本文解析DATEDIFF在不同数据库中的差异,强调其边界计算原理,探讨应用场景及陷阱,推荐根据需求选择TIMESTAMPDIFF或inte... 目录1. 核心概念:DATEDIFF 究竟在计算什么?2. 主流数据库中的 DATEDIFF 实现2.1

MySQL中的LENGTH()函数用法详解与实例分析

《MySQL中的LENGTH()函数用法详解与实例分析》MySQLLENGTH()函数用于计算字符串的字节长度,区别于CHAR_LENGTH()的字符长度,适用于多字节字符集(如UTF-8)的数据验证... 目录1. LENGTH()函数的基本语法2. LENGTH()函数的返回值2.1 示例1:计算字符串