tensorflow之tf.nn.moments()函数解析

2023-10-14 01:20

本文主要是介绍tensorflow之tf.nn.moments()函数解析,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

这两天在看batch normalization的代码时碰到了这个tf.nn.moments()函数,特此记录。

tf.nn.moments()函数用于计算均值和方差。
# 用于在指定维度计算均值与方差
tf.nn.moments(x,axes,shift=None,	# pylint: disable=unused-argumentname=None,keep_dims=False)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

参数:

  • x:一个Tensor,可以理解为我们输出的数据,形如 [batchsize, height, width, kernels]。

  • axes:整数数组,用于指定计算均值和方差的轴。如果x是1-D向量且axes=[0] 那么该函数就是计算整个向量的均值与方差。

  • shift:未在当前实现中使用。

  • name:用于计算moment的操作范围的名称。

  • keep_dims:产生与输入具有相同维度的moment,通俗点说就是是否保持维度。

返回:

Two Tensor objects: mean and variance.

两个Tensor对象:mean和variance.

解释如下:

  • mean 就是均值
  • variance 就是方差
例子1:

计算3 * 3维向量的mean和variance,程序如下:

import tensorflow as tf

img = tf.Variable(tf.random_normal([3, 3]))
axis = list(range(len(img.get_shape()) - 1))
mean, variance = tf.nn.moments(img, axis)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(img))
print(axis)
resultMean = sess.run(mean)
print(‘resultMean’,resultMean)
resultVar = sess.run(variance)
print(‘resultVar’,resultVar)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
> [[ 0.9398157   1.1222504  -0.6046098 ][ 1.4187386  -0.298682    1.033441  ][ 0.64805275  0.40496045  1.4371132 ]]
> [0]
> resultMean [1.0022024  0.40950966 0.62198144]
> resultVar [0.10093883 0.33651853 0.77942157]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

根据上面的代码,很容易可以看出axis=[0],那么tf.moments()函数就是在 [0] 维度上求了个均值和方差。

而针对3 * 3的矩阵,其实可以这么理解,当axis=[0]时,那么我们3 * 3的矩阵就可以看成三个长度为3的一维向量,然后就是三个向量的均值和方差计算,也就是对应三个向量的对应第一个数进行一次计算,对应第二个数进行一次计算,对应第三个数进行一次计算,这么说的就非常的通俗了。这是一个非常简单的例子,如果换做形如

例子2:

计算卷积神经网络某层的的mean和variance,程序如下:

import tensorflow as tf

img = tf.Variable(tf.random_normal([128, 32, 32, 64]))
axis = list(range(len(img.get_shape()) - 1))
mean, variance = tf.nn.moments(img, axis)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
# print(sess.run(img))
print(axis)
resultMean = sess.run(mean)
print(‘resultMean’,resultMean)
resultVar = sess.run(variance)
print(‘resultVar’,resultVar)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

形如[128, 32, 32, 64]的数据在CNN的中间层非常常见,为了给出一个直观的认识,函数的输出结果如下,可能输出的数字比较多。

> [0, 1, 2]
> resultMean [ 2.6853075e-03 -3.0257576e-03  2.3035323e-03 -2.6062224e-03-3.2305701e-03  2.1037455e-03  6.7329779e-03 -2.2388114e-042.6253066e-03 -6.7248638e-03  2.3191441e-04 -5.8187090e-04-9.0473756e-04 -1.6551851e-03 -1.1362392e-03 -1.9381186e-03-1.2468656e-03 -3.5813404e-03  7.0505054e-04  5.0926261e-04-3.4001360e-03 -2.2198933e-03 -1.8291552e-04 -2.9487342e-032.8003550e-03  2.0361040e-03  7.5038348e-04  1.1216532e-031.1721103e-03  4.0136781e-03 -1.3581098e-03 -1.9081675e-03-5.7506924e-03  1.4085017e-04  9.2261989e-04  3.6248637e-03-3.4064866e-04 -1.7123687e-03  2.8599303e-03  3.3247408e-03-3.0919732e-04 -2.5428729e-03 -1.8558424e-03  6.8022363e-04-2.3567205e-04  2.0230825e-03 -5.6563923e-03 -4.9449857e-03-1.5591505e-03  5.4281385e-04  3.4175792e-03  3.4342592e-03-2.2981209e-03 -1.1064336e-03 -2.4347606e-03 -8.7688277e-034.2153443e-03  1.8990067e-03 -1.7339690e-03 -4.1099632e-042.9905797e-05 -2.2589187e-03  1.3317640e-03 -1.0637580e-03]
> resultVar [0.99827653 0.99892205 1.0023996  1.0008711  1.0027382  1.00621831.0062574  0.9907291  1.0007423  1.0074934  0.9987777  0.997345860.99948376 0.9996146  0.9981512  0.9992911  1.0065222  0.99599120.99847895 0.9947184  1.0043     1.004565   0.9955365  1.00639280.9991787  0.99631685 1.0008278  1.0084031  1.0019135  1.00098971.0022242  1.0076597  1.0040829  0.9944737  1.0008909  0.99621671.002177   1.0043476  1.0003107  1.0018493  1.0021918  1.00386640.9958006  0.99403363 1.0066489  1.001033   0.9994988  0.99438080.9973529  0.9969688  1.0023019  1.004277   1.0000937  1.00093651.0067816  1.0005956  0.9942864  1.0030564  0.99745005 0.99089261.0037254  0.9974016  0.99849343 1.0066065 ]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28

对于 [128, 32, 32, 64] 这样的4维矩阵来说,一个batch里的128个图,经过一个64kernels卷积层处理,得到了128 * 64个图,再针对每一个kernel所对应的128个图,求它们所有像素的mean和variance,因为总共有64个kernels,输出的结果就是一个一维长度64的数组。
在这里插入图片描述

参考文章:

谈谈Tensorflow的Batch Normalization

                                </div><link href="https://csdnimg.cn/release/phoenix/mdeditor/markdown_views-b6c3c6d139.css" rel="stylesheet"><div class="more-toolbox"><div class="left-toolbox"><ul class="toolbox-list"><li class="tool-item tool-active is-like "><a href="javascript:;"><svg class="icon" aria-hidden="true"><use xlink:href="#csdnc-thumbsup"></use></svg><span class="name">点赞</span><span class="count">14</span></a></li><li class="tool-item tool-active is-collection "><a href="javascript:;" data-report-click="{&quot;mod&quot;:&quot;popu_824&quot;}"><svg class="icon" aria-hidden="true"><use xlink:href="#icon-csdnc-Collection-G"></use></svg><span class="name">收藏</span></a></li><li class="tool-item tool-active is-share"><a href="javascript:;"><svg class="icon" aria-hidden="true"><use xlink:href="#icon-csdnc-fenxiang"></use></svg>分享</a></li><!--打赏开始--><!--打赏结束--><li class="tool-item tool-more"><a><svg t="1575545411852" class="icon" viewBox="0 0 1024 1024" version="1.1" xmlns="http://www.w3.org/2000/svg" p-id="5717" xmlns:xlink="http://www.w3.org/1999/xlink" width="200" height="200"><defs><style type="text/css"></style></defs><path d="M179.176 499.222m-113.245 0a113.245 113.245 0 1 0 226.49 0 113.245 113.245 0 1 0-226.49 0Z" p-id="5718"></path><path d="M509.684 499.222m-113.245 0a113.245 113.245 0 1 0 226.49 0 113.245 113.245 0 1 0-226.49 0Z" p-id="5719"></path><path d="M846.175 499.222m-113.245 0a113.245 113.245 0 1 0 226.49 0 113.245 113.245 0 1 0-226.49 0Z" p-id="5720"></path></svg></a><ul class="more-box"><li class="item"><a class="article-report">文章举报</a></li></ul></li></ul></div></div><div class="person-messagebox"><div class="left-message"><a href="https://blog.csdn.net/TeFuirnever"><img src="https://profile.csdnimg.cn/6/1/2/3_tefuirnever" class="avatar_pic" username="TeFuirnever"><img src="https://g.csdnimg.cn/static/user-reg-year/1x/1.png" class="user-years"></a></div><div class="middle-message"><div class="title"><span class="tit"><a href="https://blog.csdn.net/TeFuirnever" data-report-click="{&quot;mod&quot;:&quot;popu_379&quot;}" target="_blank">我是管小亮 :)</a></span><span class="flag expert"><a href="https://blog.csdn.net/home/help.html#classicfication" target="_blank"><svg class="icon" aria-hidden="true"><use xlink:href="#csdnc-blogexpert"></use></svg>博客专家</a></span></div><div class="text"><span>发布了192 篇原创文章</span> · <span>获赞 4116</span> · <span>访问量 44万+</span></div></div><div class="right-message"><a href="https://im.csdn.net/im/main.html?userName=TeFuirnever" target="_blank" class="btn btn-sm btn-red-hollow bt-button personal-letter">私信</a><a class="btn btn-sm  bt-button personal-watch" data-report-click="{&quot;mod&quot;:&quot;popu_379&quot;}">关注</a></div></div></div>

这篇关于tensorflow之tf.nn.moments()函数解析的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/207142

相关文章

深度解析Spring Security 中的 SecurityFilterChain核心功能

《深度解析SpringSecurity中的SecurityFilterChain核心功能》SecurityFilterChain通过组件化配置、类型安全路径匹配、多链协同三大特性,重构了Spri... 目录Spring Security 中的SecurityFilterChain深度解析一、Security

MySQL常用字符串函数示例和场景介绍

《MySQL常用字符串函数示例和场景介绍》MySQL提供了丰富的字符串函数帮助我们高效地对字符串进行处理、转换和分析,本文我将全面且深入地介绍MySQL常用的字符串函数,并结合具体示例和场景,帮你熟练... 目录一、字符串函数概述1.1 字符串函数的作用1.2 字符串函数分类二、字符串长度与统计函数2.1

全面解析Golang 中的 Gorilla CORS 中间件正确用法

《全面解析Golang中的GorillaCORS中间件正确用法》Golang中使用gorilla/mux路由器配合rs/cors中间件库可以优雅地解决这个问题,然而,很多人刚开始使用时会遇到配... 目录如何让 golang 中的 Gorilla CORS 中间件正确工作一、基础依赖二、错误用法(很多人一开

python使用try函数详解

《python使用try函数详解》Pythontry语句用于异常处理,支持捕获特定/多种异常、else/final子句确保资源释放,结合with语句自动清理,可自定义异常及嵌套结构,灵活应对错误场景... 目录try 函数的基本语法捕获特定异常捕获多个异常使用 else 子句使用 finally 子句捕获所

Mysql中设计数据表的过程解析

《Mysql中设计数据表的过程解析》数据库约束通过NOTNULL、UNIQUE、DEFAULT、主键和外键等规则保障数据完整性,自动校验数据,减少人工错误,提升数据一致性和业务逻辑严谨性,本文介绍My... 目录1.引言2.NOT NULL——制定某列不可以存储NULL值2.UNIQUE——保证某一列的每一

深度解析Nginx日志分析与499状态码问题解决

《深度解析Nginx日志分析与499状态码问题解决》在Web服务器运维和性能优化过程中,Nginx日志是排查问题的重要依据,本文将围绕Nginx日志分析、499状态码的成因、排查方法及解决方案展开讨论... 目录前言1. Nginx日志基础1.1 Nginx日志存放位置1.2 Nginx日志格式2. 499

MySQL CTE (Common Table Expressions)示例全解析

《MySQLCTE(CommonTableExpressions)示例全解析》MySQL8.0引入CTE,支持递归查询,可创建临时命名结果集,提升复杂查询的可读性与维护性,适用于层次结构数据处... 目录基本语法CTE 主要特点非递归 CTE简单 CTE 示例多 CTE 示例递归 CTE基本递归 CTE 结

PyTorch中的词嵌入层(nn.Embedding)详解与实战应用示例

《PyTorch中的词嵌入层(nn.Embedding)详解与实战应用示例》词嵌入解决NLP维度灾难,捕捉语义关系,PyTorch的nn.Embedding模块提供灵活实现,支持参数配置、预训练及变长... 目录一、词嵌入(Word Embedding)简介为什么需要词嵌入?二、PyTorch中的nn.Em

Spring Boot 3.x 中 WebClient 示例详解析

《SpringBoot3.x中WebClient示例详解析》SpringBoot3.x中WebClient是响应式HTTP客户端,替代RestTemplate,支持异步非阻塞请求,涵盖GET... 目录Spring Boot 3.x 中 WebClient 全面详解及示例1. WebClient 简介2.

在MySQL中实现冷热数据分离的方法及使用场景底层原理解析

《在MySQL中实现冷热数据分离的方法及使用场景底层原理解析》MySQL冷热数据分离通过分表/分区策略、数据归档和索引优化,将频繁访问的热数据与冷数据分开存储,提升查询效率并降低存储成本,适用于高并发... 目录实现冷热数据分离1. 分表策略2. 使用分区表3. 数据归档与迁移在mysql中实现冷热数据分