tensorflow tf.train.Saver()在网络模型参数保存以及提取时的用法

2023-10-20 21:10

本文主要是介绍tensorflow tf.train.Saver()在网络模型参数保存以及提取时的用法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

前言

在辛辛苦苦跑了几个小时甚至几天之后,你训练出了几十万个或者更多的参数,那么你肯定不想只使用这些参数仅仅一次,那么就涉及到这些参数的保存以及提取,幸运的是,tensorflow已经帮我们集成好了相关函数,就是接下来要介绍的tf.train.Saver() 类。

tf.train.Saver()

一 . 用于保存权重和偏重(参数)
在使用之前要先实例化一个类,例如以下代码:

saver = tf.train.Saver()

如何保存?

import tensorflow as tf
import numpy as np# Save to file
#remember to define the same dtype and shape when restore
W = tf.Variable([[1,2,3],[3,4,5]], dtype=tf.float32, name='weights')
b = tf.Variable([[1,2,3]], dtype=tf.float32, name='biases')# init= tf.initialize_all_variables() # tf 马上就要废弃这种写法
# 替换成下面的写法:
init = tf.global_variables_initializer()saver = tf.train.Saver()with tf.Session() as sess:sess.run(init)save_path = saver.save(sess, "my_net/save_net.ckpt")#print("Save to path: ", save_path)

这里Saver()类有一个save方法,其参数为(会话名称,要保存的文件路径以及具体文件)保存后的结果如下:
在这里插入图片描述
这里会自动生成一个“checkpoint”文件以及其他几个.ckpt文件,用来存储参数。

保存了以后如何提取,或者说读取参数?
见以下代码:

w = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")# 这里不需要初始化步骤 init= tf.initialize_all_variables()saver = tf.train.Saver()
with tf.Session() as sess:# 提取变量saver.restore(sess, "my_net/save_net.ckpt")print("weights:", sess.run(w))print("biases:", sess.run(b))

这里Saver() 提供了一个restore方法,其参数为(会话,需要提取的文件)

这里有几点需要说明一下:

  1. Saver() 类只能存储和提取神经网络的参数,现在还不能存储整个网络架构,这个比较操蛋(不过我相信以后肯定会出现类似的存储整个训练好的架构的函数),现如今想要使用已经训练好的参数,还是需要重新定义一个一模一样的参数变量,无论是在数据类型上,还是shape上

这篇关于tensorflow tf.train.Saver()在网络模型参数保存以及提取时的用法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!


原文地址:https://blog.csdn.net/qq_29566629/article/details/90181083
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.chinasem.cn/article/249735

相关文章

Python中提取文件名扩展名的多种方法实现

《Python中提取文件名扩展名的多种方法实现》在Python编程中,经常会遇到需要从文件名中提取扩展名的场景,Python提供了多种方法来实现这一功能,不同方法适用于不同的场景和需求,包括os.pa... 目录技术背景实现步骤方法一:使用os.path.splitext方法二:使用pathlib模块方法三

CSS place-items: center解析与用法详解

《CSSplace-items:center解析与用法详解》place-items:center;是一个强大的CSS简写属性,用于同时控制网格(Grid)和弹性盒(Flexbox)... place-items: center; 是一个强大的 css 简写属性,用于同时控制 网格(Grid) 和 弹性盒(F

mysql中insert into的基本用法和一些示例

《mysql中insertinto的基本用法和一些示例》INSERTINTO用于向MySQL表插入新行,支持单行/多行及部分列插入,下面给大家介绍mysql中insertinto的基本用法和一些示例... 目录基本语法插入单行数据插入多行数据插入部分列的数据插入默认值注意事项在mysql中,INSERT I

mapstruct中的@Mapper注解的基本用法

《mapstruct中的@Mapper注解的基本用法》在MapStruct中,@Mapper注解是核心注解之一,用于标记一个接口或抽象类为MapStruct的映射器(Mapper),本文给大家介绍ma... 目录1. 基本用法2. 常用属性3. 高级用法4. 注意事项5. 总结6. 编译异常处理在MapSt

Python实现精准提取 PDF中的文本,表格与图片

《Python实现精准提取PDF中的文本,表格与图片》在实际的系统开发中,处理PDF文件不仅限于读取整页文本,还有提取文档中的表格数据,图片或特定区域的内容,下面我们来看看如何使用Python实... 目录安装 python 库提取 PDF 文本内容:获取整页文本与指定区域内容获取页面上的所有文本内容获取

Java使用HttpClient实现图片下载与本地保存功能

《Java使用HttpClient实现图片下载与本地保存功能》在当今数字化时代,网络资源的获取与处理已成为软件开发中的常见需求,其中,图片作为网络上最常见的资源之一,其下载与保存功能在许多应用场景中都... 目录引言一、Apache HttpClient简介二、技术栈与环境准备三、实现图片下载与保存功能1.

java中long的一些常见用法

《java中long的一些常见用法》在Java中,long是一种基本数据类型,用于表示长整型数值,接下来通过本文给大家介绍java中long的一些常见用法,感兴趣的朋友一起看看吧... 在Java中,long是一种基本数据类型,用于表示长整型数值。它的取值范围比int更大,从-922337203685477

MyBatis ResultMap 的基本用法示例详解

《MyBatisResultMap的基本用法示例详解》在MyBatis中,resultMap用于定义数据库查询结果到Java对象属性的映射关系,本文给大家介绍MyBatisResultMap的基本... 目录MyBATis 中的 resultMap1. resultMap 的基本语法2. 简单的 resul

Python主动抛出异常的各种用法和场景分析

《Python主动抛出异常的各种用法和场景分析》在Python中,我们不仅可以捕获和处理异常,还可以主动抛出异常,也就是以类的方式自定义错误的类型和提示信息,这在编程中非常有用,下面我将详细解释主动抛... 目录一、为什么要主动抛出异常?二、基本语法:raise关键字基本示例三、raise的多种用法1. 抛

java中Optional的核心用法和最佳实践

《java中Optional的核心用法和最佳实践》Java8中Optional用于处理可能为null的值,减少空指针异常,:本文主要介绍java中Optional核心用法和最佳实践的相关资料,文中... 目录前言1. 创建 Optional 对象1.1 常规创建方式2. 访问 Optional 中的值2.1