TensorFlow图变量tf.Variable的用法解析

2024-08-29 10:48

本文主要是介绍TensorFlow图变量tf.Variable的用法解析,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

TensorFlow中的图变量,跟我们平时所接触的一般变量在用法上有很大的差异。尤其对于那些初次接触此类深度学习库的编程人员来说,会显得十分难上手。

本文将按照如下篇幅深入剖析tf.Variable这个核心概念:

图变量的初始化方法
两种定义图变量的方法
scope如何划分命名空间
图变量的复用
图变量的种类


1.图变量的初始化方法
对于一般的Python代码,变量的初始化就是变量的定义,向下面这样:

In [1]: x = 3
In [2]: y = 3 * 5
In [3]: y
Out[3]: 15

如果我们模仿上面的写法来进行TensorFlow编程,就会出现下面的”怪现象”:

In [1]: import tensorflow as tf
In [2]: x = tf.Variable(3, name='x')
In [3]: y = x * 5
In [4]: print(y)
Tensor("mul:0", shape=(), dtype=int32)

y的值并不是我们预想中的15,而是一个莫名其妙的输出——”

In [1]: import tensorflow as tf
In [2]: x = tf.Variable(3, name='x')
In [3]: y = x * 5
In [4]: sess = tf.InteractiveSession()
In [5]: sess.run(tf.global_variables_initializer())
In [6]: sess.run(y)
Out[6]: 15

在TensorFlow的世界里,变量的定义和初始化是分开的,所有关于图变量的赋值和计算都要通过tf.Session的run来进行。想要将所有图变量进行集体初始化时应该使用tf.global_variables_initializer。

2.两种定义图变量的方法
tf.Variable
tf.Variable.init(initial_value, trainable=True, collections=None, validate_shape=True, name=None)

参数名称    参数类型    含义
initial_value    所有可以转换为Tensor的类型    变量的初始值
trainable    bool    如果为True,会把它加入到GraphKeys.TRAINABLE_VARIABLES,才能对它使用Optimizer
collections    list    指定该图变量的类型、默认为[GraphKeys.GLOBAL_VARIABLES]
validate_shape    bool    如果为False,则不进行类型和维度检查
name    string    变量的名称,如果没有指定则系统会自动分配一个唯一的值
虽然有一堆参数,但只有第一个参数initial_value是必需的,用法如下(assign函数用于给图变量赋值):

In [1]: import tensorflow as tf
In [2]: v = tf.Variable(3, name='v')
In [3]: v2 = v.assign(5)
In [4]: sess = tf.InteractiveSession()
In [5]: sess.run(v.initializer)
In [6]: sess.run(v)
Out[6]: 3
In [7]: sess.run(v2)
Out[7]: 5


tf.get_variable
tf.get_variable跟tf.Variable都可以用来定义图变量,但是前者的必需参数(即第一个参数)并不是图变量的初始值,而是图变量的名称。

tf.Variable的用法要更丰富一点,当指定名称的图变量已经存在时表示获取它,当指定名称的图变量不存在时表示定义它,用法如下:

In [1]: import tensorflow as tf
In [2]: init = tf.constant_initializer([5])
In [3]: x = tf.get_variable('x', shape=[1], initializer=init)
In [4]: sess = tf.InteractiveSession()
In [5]: sess.run(x.initializer)
In [6]: sess.run(x)
Out[6]: array([ 5.], dtype=float32)

3.scope如何划分命名空间
一个深度学习模型的参数变量往往是成千上万的,不加上命名空间加以分组整理,将会成为可怕的灾难。TensorFlow的命名空间分为两种,tf.variable_scope和tf.name_scope。

下面示范使用tf.variable_scope把图变量划分为4组:

for i in range(4):
    with tf.variable_scope('scope-{}'.format(i)):
        for j in range(25):
             v = tf.Variable(1, name=str(j))
可视化输出的结果如下:

下面让我们来分析tf.variable_scope和tf.name_scope的区别:

tf.variable_scope
当使用tf.get_variable定义变量时,如果出现同名的情况将会引起报错

In [1]: import tensorflow as tf
In [2]: with tf.variable_scope('scope'):
   ...:     v1 = tf.get_variable('var', [1])
   ...:     v2 = tf.get_variable('var', [1])
ValueError: Variable scope/var already exists, disallowed. Did you mean to set reuse=True in VarScope? Originally defined at:

而对于tf.Variable来说,却可以定义“同名”变量

In [1]: import tensorflow as tf
In [2]: with tf.variable_scope('scope'):
   ...:     v1 = tf.Variable(1, name='var')
   ...:     v2 = tf.Variable(2, name='var')
   ...:
In [3]: v1.name, v2.name
Out[3]: ('scope/var:0', 'scope/var_1:0')

但是把这些图变量的name属性打印出来,就可以发现它们的名称并不是一样的。

如果想使用tf.get_variable来定义另一个同名图变量,可以考虑加入新一层scope,比如:

In [1]: import tensorflow as tf
In [2]: with tf.variable_scope('scope1'):
   ...:     v1 = tf.get_variable('var', shape=[1])
   ...:     with tf.variable_scope('scope2'):
   ...:         v2 = tf.get_variable('var', shape=[1])
   ...:
In [3]: v1.name, v2.name
Out[3]: ('scope1/var:0', 'scope1/scope2/var:0')

tf.name_scope
当tf.get_variable遇上tf.name_scope,它定义的变量的最终完整名称将不受这个tf.name_scope的影响,如下:

In [1]: import tensorflow as tf
In [2]: with tf.variable_scope('v_scope'):
   ...:     with tf.name_scope('n_scope'):
   ...:         x = tf.Variable([1], name='x')
   ...:         y = tf.get_variable('x', shape=[1], dtype=tf.int32)
   ...:         z = x + y
   ...:
In [3]: x.name, y.name, z.name
Out[3]: ('v_scope/n_scope/x:0', 'v_scope/x:0', 'v_scope/n_scope/add:0')

4.图变量的复用
想象一下,如果我们正在定义一个循环神经网络RNN,想复用上一层的参数以提高模型最终的表现效果,应该怎么做呢?

做法一:

In [1]: import tensorflow as tf
In [2]: with tf.variable_scope('scope'):
   ...:     v1 = tf.get_variable('var', [1])
   ...:     tf.get_variable_scope().reuse_variables()
   ...:     v2 = tf.get_variable('var', [1])
   ...:
In [3]: v1.name, v2.name
Out[3]: ('scope/var:0', 'scope/var:0')


做法二:

In [1]: import tensorflow as tf
In [2]: with tf.variable_scope('scope'):
   ...:     v1 = tf.get_variable('x', [1])
   ...:
In [3]: with tf.variable_scope('scope', reuse=True):
   ...:     v2 = tf.get_variable('x', [1])
   ...:
In [4]: v1.name, v2.name
Out[4]: ('scope/x:0', 'scope/x:0')

5.图变量的种类
TensorFlow的图变量分为两类:local_variables和global_variables。

如果我们想定义一个不需要长期保存的临时图变量,可以向下面这样定义它:

with tf.name_scope("increment"):
    zero64 = tf.constant(0, dtype=tf.int64)
    current = tf.Variable(zero64, name="incr", trainable=False, collections=[ops.GraphKeys.LOCAL_VARIABLES])


--------------------- 
作者:烧煤的快感 
来源:CSDN 
原文:https://blog.csdn.net/gg_18826075157/article/details/78368924 
版权声明:本文为博主原创文章,转载请附上博文链接!

这篇关于TensorFlow图变量tf.Variable的用法解析的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1117644

相关文章

SpringBoot3.4配置校验新特性的用法详解

《SpringBoot3.4配置校验新特性的用法详解》SpringBoot3.4对配置校验支持进行了全面升级,这篇文章为大家详细介绍了一下它们的具体使用,文中的示例代码讲解详细,感兴趣的小伙伴可以参考... 目录基本用法示例定义配置类配置 application.yml注入使用嵌套对象与集合元素深度校验开发

Qt实现网络数据解析的方法总结

《Qt实现网络数据解析的方法总结》在Qt中解析网络数据通常涉及接收原始字节流,并将其转换为有意义的应用层数据,这篇文章为大家介绍了详细步骤和示例,感兴趣的小伙伴可以了解下... 目录1. 网络数据接收2. 缓冲区管理(处理粘包/拆包)3. 常见数据格式解析3.1 jsON解析3.2 XML解析3.3 自定义

SpringBoot UserAgentUtils获取用户浏览器的用法

《SpringBootUserAgentUtils获取用户浏览器的用法》UserAgentUtils是于处理用户代理(User-Agent)字符串的工具类,一般用于解析和处理浏览器、操作系统以及设备... 目录介绍效果图依赖封装客户端工具封装IP工具实体类获取设备信息入库介绍UserAgentUtils

Golang HashMap实现原理解析

《GolangHashMap实现原理解析》HashMap是一种基于哈希表实现的键值对存储结构,它通过哈希函数将键映射到数组的索引位置,支持高效的插入、查找和删除操作,:本文主要介绍GolangH... 目录HashMap是一种基于哈希表实现的键值对存储结构,它通过哈希函数将键映射到数组的索引位置,支持

Java中的@SneakyThrows注解用法详解

《Java中的@SneakyThrows注解用法详解》:本文主要介绍Java中的@SneakyThrows注解用法的相关资料,Lombok的@SneakyThrows注解简化了Java方法中的异常... 目录前言一、@SneakyThrows 简介1.1 什么是 Lombok?二、@SneakyThrows

Python使用getopt处理命令行参数示例解析(最佳实践)

《Python使用getopt处理命令行参数示例解析(最佳实践)》getopt模块是Python标准库中一个简单但强大的命令行参数处理工具,它特别适合那些需要快速实现基本命令行参数解析的场景,或者需要... 目录为什么需要处理命令行参数?getopt模块基础实际应用示例与其他参数处理方式的比较常见问http

Python中的getopt模块用法小结

《Python中的getopt模块用法小结》getopt.getopt()函数是Python中用于解析命令行参数的标准库函数,该函数可以从命令行中提取选项和参数,并对它们进行处理,本文详细介绍了Pyt... 目录getopt模块介绍getopt.getopt函数的介绍getopt模块的常用用法getopt模

Python利用ElementTree实现快速解析XML文件

《Python利用ElementTree实现快速解析XML文件》ElementTree是Python标准库的一部分,而且是Python标准库中用于解析和操作XML数据的模块,下面小编就来和大家详细讲讲... 目录一、XML文件解析到底有多重要二、ElementTree快速入门1. 加载XML的两种方式2.

mysql中的group by高级用法

《mysql中的groupby高级用法》MySQL中的GROUPBY是数据聚合分析的核心功能,主要用于将结果集按指定列分组,并结合聚合函数进行统计计算,下面给大家介绍mysql中的groupby用法... 目录一、基本语法与核心功能二、基础用法示例1. 单列分组统计2. 多列组合分组3. 与WHERE结合使

Java的栈与队列实现代码解析

《Java的栈与队列实现代码解析》栈是常见的线性数据结构,栈的特点是以先进后出的形式,后进先出,先进后出,分为栈底和栈顶,栈应用于内存的分配,表达式求值,存储临时的数据和方法的调用等,本文给大家介绍J... 目录栈的概念(Stack)栈的实现代码队列(Queue)模拟实现队列(双链表实现)循环队列(循环数组