tensorflow tf.train.Saver()在网络模型参数保存以及提取时的用法

2023-10-20 21:10

本文主要是介绍tensorflow tf.train.Saver()在网络模型参数保存以及提取时的用法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

前言

在辛辛苦苦跑了几个小时甚至几天之后,你训练出了几十万个或者更多的参数,那么你肯定不想只使用这些参数仅仅一次,那么就涉及到这些参数的保存以及提取,幸运的是,tensorflow已经帮我们集成好了相关函数,就是接下来要介绍的tf.train.Saver() 类。

tf.train.Saver()

一 . 用于保存权重和偏重(参数)
在使用之前要先实例化一个类,例如以下代码:

saver = tf.train.Saver()

如何保存?

import tensorflow as tf
import numpy as np# Save to file
#remember to define the same dtype and shape when restore
W = tf.Variable([[1,2,3],[3,4,5]], dtype=tf.float32, name='weights')
b = tf.Variable([[1,2,3]], dtype=tf.float32, name='biases')# init= tf.initialize_all_variables() # tf 马上就要废弃这种写法
# 替换成下面的写法:
init = tf.global_variables_initializer()saver = tf.train.Saver()with tf.Session() as sess:sess.run(init)save_path = saver.save(sess, "my_net/save_net.ckpt")#print("Save to path: ", save_path)

这里Saver()类有一个save方法,其参数为(会话名称,要保存的文件路径以及具体文件)保存后的结果如下:
在这里插入图片描述
这里会自动生成一个“checkpoint”文件以及其他几个.ckpt文件,用来存储参数。

保存了以后如何提取,或者说读取参数?
见以下代码:

w = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")# 这里不需要初始化步骤 init= tf.initialize_all_variables()saver = tf.train.Saver()
with tf.Session() as sess:# 提取变量saver.restore(sess, "my_net/save_net.ckpt")print("weights:", sess.run(w))print("biases:", sess.run(b))

这里Saver() 提供了一个restore方法,其参数为(会话,需要提取的文件)

这里有几点需要说明一下:

  1. Saver() 类只能存储和提取神经网络的参数,现在还不能存储整个网络架构,这个比较操蛋(不过我相信以后肯定会出现类似的存储整个训练好的架构的函数),现如今想要使用已经训练好的参数,还是需要重新定义一个一模一样的参数变量,无论是在数据类型上,还是shape上

这篇关于tensorflow tf.train.Saver()在网络模型参数保存以及提取时的用法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/249735

相关文章

Python中logging模块用法示例总结

《Python中logging模块用法示例总结》在Python中logging模块是一个强大的日志记录工具,它允许用户将程序运行期间产生的日志信息输出到控制台或者写入到文件中,:本文主要介绍Pyt... 目录前言一. 基本使用1. 五种日志等级2.  设置报告等级3. 自定义格式4. C语言风格的格式化方法

SpringBoot 获取请求参数的常用注解及用法

《SpringBoot获取请求参数的常用注解及用法》SpringBoot通过@RequestParam、@PathVariable等注解支持从HTTP请求中获取参数,涵盖查询、路径、请求体、头、C... 目录SpringBoot 提供了多种注解来方便地从 HTTP 请求中获取参数以下是主要的注解及其用法:1

HTTP 与 SpringBoot 参数提交与接收协议方式

《HTTP与SpringBoot参数提交与接收协议方式》HTTP参数提交方式包括URL查询、表单、JSON/XML、路径变量、头部、Cookie、GraphQL、WebSocket和SSE,依据... 目录HTTP 协议支持多种参数提交方式,主要取决于请求方法(Method)和内容类型(Content-Ty

Debian 13升级后网络转发等功能异常怎么办? 并非错误而是管理机制变更

《Debian13升级后网络转发等功能异常怎么办?并非错误而是管理机制变更》很多朋友反馈,更新到Debian13后网络转发等功能异常,这并非BUG而是Debian13Trixie调整... 日前 Debian 13 Trixie 发布后已经有众多网友升级到新版本,只不过升级后发现某些功能存在异常,例如网络转

Java中HashMap的用法详细介绍

《Java中HashMap的用法详细介绍》JavaHashMap是一种高效的数据结构,用于存储键值对,它是基于哈希表实现的,提供快速的插入、删除和查找操作,:本文主要介绍Java中HashMap... 目录一.HashMap1.基本概念2.底层数据结构:3.HashCode和equals方法为什么重写Has

Python从Word文档中提取图片并生成PPT的操作代码

《Python从Word文档中提取图片并生成PPT的操作代码》在日常办公场景中,我们经常需要从Word文档中提取图片,并将这些图片整理到PowerPoint幻灯片中,手动完成这一任务既耗时又容易出错,... 目录引言背景与需求解决方案概述代码解析代码核心逻辑说明总结引言在日常办公场景中,我们经常需要从 W

Android协程高级用法大全

《Android协程高级用法大全》这篇文章给大家介绍Android协程高级用法大全,本文结合实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友跟随小编一起学习吧... 目录1️⃣ 协程作用域(CoroutineScope)与生命周期绑定Activity/Fragment 中手

Java+AI驱动实现PDF文件数据提取与解析

《Java+AI驱动实现PDF文件数据提取与解析》本文将和大家分享一套基于AI的体检报告智能评估方案,详细介绍从PDF上传、内容提取到AI分析、数据存储的全流程自动化实现方法,感兴趣的可以了解下... 目录一、核心流程:从上传到评估的完整链路二、第一步:解析 PDF,提取体检报告内容1. 引入依赖2. 封装

Python异步编程之await与asyncio基本用法详解

《Python异步编程之await与asyncio基本用法详解》在Python中,await和asyncio是异步编程的核心工具,用于高效处理I/O密集型任务(如网络请求、文件读写、数据库操作等),接... 目录一、核心概念二、使用场景三、基本用法1. 定义协程2. 运行协程3. 并发执行多个任务四、关键

python中的显式声明类型参数使用方式

《python中的显式声明类型参数使用方式》文章探讨了Python3.10+版本中类型注解的使用,指出FastAPI官方示例强调显式声明参数类型,通过|操作符替代Union/Optional,可提升代... 目录背景python函数显式声明的类型汇总基本类型集合类型Optional and Union(py