Tensorflow反卷积(conv2d_transpose)实现原理+手写python代码实现反卷积(DeConv)

本文主要是介绍Tensorflow反卷积(conv2d_transpose)实现原理+手写python代码实现反卷积(DeConv),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1 反卷积原理

反卷积原理不太好用文字描述,这里直接以一个简单例子描述反卷积过程。

假设输入如下:

[[1,0,1],[0,2,1],[1,1,0]]

反卷积卷积核如下:

[[ 1, 0, 1],[-1, 1, 0],[ 0,-1, 0]]

现在通过stride=2来进行反卷积,使得尺寸由原来的3*3变为6*6.那么在Tensorflow框架中,反卷积的过程如下(不同框架在裁剪这步可能不一样):

反卷积实现例子

其实通过我绘制的这张图,就已经把原理讲的很清楚了。大致步奏就是,先填充0,然后进行卷积,卷积过程跟上一篇文章讲述的一致。最后一步还要进行裁剪。好了,原理讲完了,(#.#)....

2 代码实现

上一篇文章我们只针对了输出通道数为1进行代码实现,在这篇文章中,反卷积我们将输出通道设置为多个,这样更符合实际场景。

先定义输入和卷积核:

input_data=[[[1,0,1],[0,2,1],[1,1,0]],[[2,0,2],[0,1,0],[1,0,0]],[[1,1,1],[2,2,0],[1,1,1]],[[1,1,2],[1,0,1],[0,2,2]]]
weights_data=[ [[[ 1, 0, 1],[-1, 1, 0],[ 0,-1, 0]],[[-1, 0, 1],[ 0, 0, 1],[ 1, 1, 1]],[[ 0, 1, 1],[ 2, 0, 1],[ 1, 2, 1]], [[ 1, 1, 1],[ 0, 2, 1],[ 1, 0, 1]]],[[[ 1, 0, 2],[-2, 1, 1],[ 1,-1, 0]],[[-1, 0, 1],[-1, 2, 1],[ 1, 1, 1]],[[ 0, 0, 0],[ 2, 2, 1],[ 1,-1, 1]], [[ 2, 1, 1],[ 0,-1, 1],[ 1, 1, 1]]]  ]

上面定义的输入和卷积核,在接下的运算过程如下图所示:

执行过程

可以看到实际上,反卷积和卷积基本一致,差别在于,反卷积需要填充过程,并在最后一步需要裁剪。具体实现代码如下:

#根据输入map([h,w])和卷积核([k,k]),计算卷积后的feature map
import numpy as np
def compute_conv(fm,kernel):[h,w]=fm.shape [k,_]=kernel.shape r=int(k/2)#定义边界填充0后的mappadding_fm=np.zeros([h+2,w+2],np.float32)#保存计算结果rs=np.zeros([h,w],np.float32) #将输入在指定该区域赋值,即除了4个边界后,剩下的区域padding_fm[1:h+1,1:w+1]=fm #对每个点为中心的区域遍历for i in range(1,h+1):for j in range(1,w+1): #取出当前点为中心的k*k区域roi=padding_fm[i-r:i+r+1,j-r:j+r+1]#计算当前点的卷积,对k*k个点点乘后求和rs[i-1][j-1]=np.sum(roi*kernel)return rs#填充0
def fill_zeros(input):[c,h,w]=input.shapers=np.zeros([c,h*2+1,w*2+1],np.float32)for i in range(c):for j in range(h):for k in range(w): rs[i,2*j+1,2*k+1]=input[i,j,k] return rsdef my_deconv(input,weights):#weights shape=[out_c,in_c,h,w][out_c,in_c,h,w]=weights.shape   out_h=h*2out_w=w*2rs=[]for i in range(out_c):w=weights[i]tmp=np.zeros([out_h,out_w],np.float32)for j in range(in_c):conv=compute_conv(input[j],w[j])#注意裁剪,最后一行和最后一列去掉tmp=tmp+conv[0:out_h,0:out_w]rs.append(tmp)return rs def main():  input=np.asarray(input_data,np.float32)input= fill_zeros(input)weights=np.asarray(weights_data,np.float32)deconv=my_deconv(input,weights)print(np.asarray(deconv))if __name__=='__main__':main()

计算卷积代码,跟上一篇文章一致。代码直接看注释,不再解释。运行结果如下:

[[[  4.   3.   6.   2.   7.   3.][  4.   3.   3.   2.   7.   5.][  8.   6.   8.   5.  11.   2.][  3.   2.   7.   2.   3.   3.][  5.   5.  11.   3.   9.   3.][  2.   1.   4.   5.   4.   4.]][[  4.   1.   7.   0.   7.   2.][  5.   6.   0.   1.   8.   5.][  8.   0.   8.  -2.  14.   2.][  3.   3.   9.   8.   1.   0.][  3.   0.  13.   0.  11.   2.][  3.   5.   3.   1.   3.   0.]]]

为了验证实现的代码的正确性,我们使用tensorflow的conv2d_transpose函数执行相同的输入和卷积核,看看结果是否一致。验证代码如下:

import tensorflow as tf
import numpy as np 
def tf_conv2d_transpose(input,weights):#input_shape=[n,height,width,channel]input_shape = input.get_shape().as_list()#weights shape=[height,width,out_c,in_c]weights_shape=weights.get_shape().as_list() output_shape=[input_shape[0], input_shape[1]*2 , input_shape[2]*2 , weights_shape[2]]print("output_shape:",output_shape)deconv=tf.nn.conv2d_transpose(input,weights,output_shape=output_shape,strides=[1, 2, 2, 1], padding='SAME')return deconvdef main(): weights_np=np.asarray(weights_data,np.float32)#将输入的每个卷积核旋转180°weights_np=np.rot90(weights_np,2,(2,3))const_input = tf.constant(input_data , tf.float32)const_weights = tf.constant(weights_np , tf.float32 )input = tf.Variable(const_input,name="input")#[c,h,w]------>[h,w,c]input=tf.transpose(input,perm=(1,2,0))#[h,w,c]------>[n,h,w,c]input=tf.expand_dims(input,0)#weights shape=[out_c,in_c,h,w]weights = tf.Variable(const_weights,name="weights")#[out_c,in_c,h,w]------>[h,w,out_c,in_c]weights=tf.transpose(weights,perm=(2,3,0,1))#执行tensorflow的反卷积deconv=tf_conv2d_transpose(input,weights) init=tf.global_variables_initializer()sess=tf.Session()sess.run(init)deconv_val  = sess.run(deconv) hwc=deconv_val[0]print(hwc) if __name__=='__main__':main() 

上面代码中,有几点需要注意:

  1. 每个卷积核需要旋转180°后,再传入tf.nn.conv2d_transpose函数中,因为tf.nn.conv2d_transpose内部会旋转180°,所以提前旋转,再经过内部旋转后,能保证卷积核跟我们所使用的卷积核的数据排列一致。
  2. 我们定义的输入的shape为[c,h,w]需要转为tensorflow所使用的[n,h,w,c]。
  3. 我们定义的卷积核shape为[out_c,in_c,h,w],需要转为tensorflow反卷积中所使用的[h,w,out_c,in_c]

执行上面代码后,执行结果如下:

[[  4.   3.   6.   2.   7.   3.][  4.   3.   3.   2.   7.   5.][  8.   6.   8.   5.  11.   2.][  3.   2.   7.   2.   3.   3.][  5.   5.  11.   3.   9.   3.][  2.   1.   4.   5.   4.   4.]]
[[  4.   1.   7.   0.   7.   2.][  5.   6.   0.   1.   8.   5.][  8.   0.   8.  -2.  14.   2.][  3.   3.   9.   8.   1.   0.][  3.   0.  13.   0.  11.   2.][  3.   5.   3.   1.   3.   0.]]

对比结果可以看到,数据是一致的,证明前面手写的python实现的反卷积代码是正确的。



作者:huachao1001
链接:https://www.jianshu.com/p/f0674e48894c
來源:简书
简书著作权归作者所有,任何形式的转载都请联系作者获得授权并注明出处。

这篇关于Tensorflow反卷积(conv2d_transpose)实现原理+手写python代码实现反卷积(DeConv)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/224832

相关文章

golang版本升级如何实现

《golang版本升级如何实现》:本文主要介绍golang版本升级如何实现问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录golanwww.chinasem.cng版本升级linux上golang版本升级删除golang旧版本安装golang最新版本总结gola

SpringBoot中SM2公钥加密、私钥解密的实现示例详解

《SpringBoot中SM2公钥加密、私钥解密的实现示例详解》本文介绍了如何在SpringBoot项目中实现SM2公钥加密和私钥解密的功能,通过使用Hutool库和BouncyCastle依赖,简化... 目录一、前言1、加密信息(示例)2、加密结果(示例)二、实现代码1、yml文件配置2、创建SM2工具

Mysql实现范围分区表(新增、删除、重组、查看)

《Mysql实现范围分区表(新增、删除、重组、查看)》MySQL分区表的四种类型(范围、哈希、列表、键值),主要介绍了范围分区的创建、查询、添加、删除及重组织操作,具有一定的参考价值,感兴趣的可以了解... 目录一、mysql分区表分类二、范围分区(Range Partitioning1、新建分区表:2、分

MySQL 定时新增分区的实现示例

《MySQL定时新增分区的实现示例》本文主要介绍了通过存储过程和定时任务实现MySQL分区的自动创建,解决大数据量下手动维护的繁琐问题,具有一定的参考价值,感兴趣的可以了解一下... mysql创建好分区之后,有时候会需要自动创建分区。比如,一些表数据量非常大,有些数据是热点数据,按照日期分区MululbU

Python中你不知道的gzip高级用法分享

《Python中你不知道的gzip高级用法分享》在当今大数据时代,数据存储和传输成本已成为每个开发者必须考虑的问题,Python内置的gzip模块提供了一种简单高效的解决方案,下面小编就来和大家详细讲... 目录前言:为什么数据压缩如此重要1. gzip 模块基础介绍2. 基本压缩与解压缩操作2.1 压缩文

Python设置Cookie永不超时的详细指南

《Python设置Cookie永不超时的详细指南》Cookie是一种存储在用户浏览器中的小型数据片段,用于记录用户的登录状态、偏好设置等信息,下面小编就来和大家详细讲讲Python如何设置Cookie... 目录一、Cookie的作用与重要性二、Cookie过期的原因三、实现Cookie永不超时的方法(一)

MySQL中查找重复值的实现

《MySQL中查找重复值的实现》查找重复值是一项常见需求,比如在数据清理、数据分析、数据质量检查等场景下,我们常常需要找出表中某列或多列的重复值,具有一定的参考价值,感兴趣的可以了解一下... 目录技术背景实现步骤方法一:使用GROUP BY和HAVING子句方法二:仅返回重复值方法三:返回完整记录方法四:

Python内置函数之classmethod函数使用详解

《Python内置函数之classmethod函数使用详解》:本文主要介绍Python内置函数之classmethod函数使用方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地... 目录1. 类方法定义与基本语法2. 类方法 vs 实例方法 vs 静态方法3. 核心特性与用法(1编程客

IDEA中新建/切换Git分支的实现步骤

《IDEA中新建/切换Git分支的实现步骤》本文主要介绍了IDEA中新建/切换Git分支的实现步骤,通过菜单创建新分支并选择是否切换,创建后在Git详情或右键Checkout中切换分支,感兴趣的可以了... 前提:项目已被Git托管1、点击上方栏Git->NewBrancjsh...2、输入新的分支的

Python函数作用域示例详解

《Python函数作用域示例详解》本文介绍了Python中的LEGB作用域规则,详细解析了变量查找的四个层级,通过具体代码示例,展示了各层级的变量访问规则和特性,对python函数作用域相关知识感兴趣... 目录一、LEGB 规则二、作用域实例2.1 局部作用域(Local)2.2 闭包作用域(Enclos