tensorflow模型的恢复和加载ckpt

2024-01-10 20:32

本文主要是介绍tensorflow模型的恢复和加载ckpt,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

链接:https://www.jianshu.com/p/c9fd5c01715e

   总的来说,模型在保存和恢复时最重要的是留下数据接口,方便使用时传入数据和获取结果。TensorFlow 中常用的模型保存格式为 .ckpt 和 .pb,下面分别进行详细说明。

一、ckpt 格式模型保存与恢复

        .ckpt 格式保存与恢复都很简单,具体可参考 TensorFlow 训练 CNN 分类器。

1. ckpt 格式模型保存

inputs = tf.placeholder(tf.float32, shape=[None, ···], name='inputs')  <-- 入口
···
prediction = tf.nn.softmax(logits, name='prediction')  <-- 出口(仅作为例子,下同)
···
saver = tf.train.Saver()
···with tf.Session() as sess:···    <-- 训练过程saver.save(sess, './xxx/xxx.ckpt')  <-- 模型保存

如上述代码所示,假设你定义了一个 TensorFlow 模型,数据入口由占位符 inputs 给定,结果出口由张量 prediction 给定。通过语句 saver = tf.train.Saver() 定义了模型保存的一个实例对象 saver,当模型训练结束之后只需要简单的一条语句:

saver.save(sess, path_to_model.ckpt)

就把训练结果保存到了指定的路径。

        以上代码之所以把变量 inputsprediction 单独列出,一方面是因为它们是模型 Graph 的起点和终点(戏称为数据入口、出口),另一方面的原因是它们被特别的指定了名称,因而在模型恢复时可以通过它们的名称而得到 Graph 中对应的节点。

2. ckpt 格式模型恢复

        当你需要导入模型进行推断时,只需要通过张量名获取数据入口和出口,然后传入数据即可:

with tf.Session() as sess:saver = tf.train.import_meta_graph('./xxx/xxx.ckpt.meta')saver.restore(sess, './xxx/xxx.ckpt')inputs = tf.get_default_graph().get_tensor_by_name('inputs:0')prediction = tf.get_default_graph().get_tensor_by_name('prediction:0')pred = sess.run(prediction, feed_dict={inputs: xxx}

  保存为 .ckpt 模型的一个好处是,当需要继续训练时,只需要将训练过的模型结果导入,然后在这个基础上再继续训练。而下面的 .pb 格式则不能继续训练,因为这种格式保存的模型参数都已经转化为了常量(而不再是变量)。

二、pb 格式模型保存与恢复

        .pb 格式模型保存与恢复相比于前面的 .ckpt 格式而言要稍微麻烦一点,但使用更灵活,特别是模型恢复,因为它可以脱离会话(Session)而存在,便于部署。

1. pb 格式模型保存

        与 .ckpt 格式模型保存类似,首先定义数据入口、出口:

from tensorflow.python.framework import graph_util···
inputs = tf.placeholder(tf.float32, shape=[None, ···], name='inputs') 
···
prediction = tf.nn.softmax(logits, name='prediction') 
···with tf.Session() as sess:···    <-- 训练过程graph_def = tf.get_default_graph().as_graph_def()output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['prediction']  <-- 参数:output_node_names,输出节点名)with tf.gfile.GFile('./xxx/xxx.pb', 'wb') as fid:serialized_graph = output_graph_def.SerializeToString()fid.write(serialized_graph)

然后通过函数 graph_util.convert_variables_to_constants 将模型固话,使得所有变量转化为常量,之后写入到指定的路径完成模型保存过程。

2. pb 格式模型恢复

        .pb 格式模型恢复自由度较大,不需要在会话里进行操作,可以独立存在:

import osdef load_model(path_to_model.pb):if not os.path.exists(path_to_model.pb):raise ValueError("'path_to_model.pb' is not exist.")model_graph = tf.Graph()with model_graph.as_default():od_graph_def = tf.GraphDef()with tf.gfile.GFile(path_to_model.pb, 'rb') as fid:serialized_graph = fid.read()od_graph_def.ParseFromString(serialized_graph)tf.import_graph_def(od_graph_def, name='')return model_graph

模型导入之后,便可以获取数据入口和出口,然后进行推断:

model_graph = load_model('./xxx/xxx.pb')inputs = model_graph.get_tensor_by_name('inputs:0')
prediction = model_graph.get_tensor_by_name('prediction:0')with model_graph.as_default():with tf.Session(graph=model_graph) as sess:···pred = sess.run(prediction, feed_dict={inputs: xxx}

三、ckpt 格式转 pb 格式

        一般情况下,为了便于从断点之处继续训练,模型通常保存为 .ckpt 格式,而一旦对训练结果很满意之后则可能需要将 .ckpt 格式转化为 .pb 格式。转化方法很简单,只需要综合前面的一、二两步即可:

from tensorflow.python.framework import graph_utilwith tf.Session() as sess:# Load .ckpt filesaver = tf.train.import_meta_graph('./xxx/xxx.ckpt.meta')saver.restore(sess, './xxx/xxx.ckpt')# Save as .pb filegraph_def = tf.get_default_graph().as_graph_def()output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['prediction']  <-- 输出节点名,以实际情况为准)with tf.gfile.GFile('./xxx/xxx.pb', 'wb') as fid:serialized_graph = output_graph_def.SerializeToString()fid.write(serialized_graph)


 

 

这篇关于tensorflow模型的恢复和加载ckpt的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/591983

相关文章

springboot加载不到nacos配置中心的配置问题处理

《springboot加载不到nacos配置中心的配置问题处理》:本文主要介绍springboot加载不到nacos配置中心的配置问题处理,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑... 目录springboot加载不到nacos配置中心的配置两种可能Spring Boot 版本Nacos

详解如何使用Python从零开始构建文本统计模型

《详解如何使用Python从零开始构建文本统计模型》在自然语言处理领域,词汇表构建是文本预处理的关键环节,本文通过Python代码实践,演示如何从原始文本中提取多尺度特征,并通过动态调整机制构建更精确... 目录一、项目背景与核心思想二、核心代码解析1. 数据加载与预处理2. 多尺度字符统计3. 统计结果可

SpringBoot整合Sa-Token实现RBAC权限模型的过程解析

《SpringBoot整合Sa-Token实现RBAC权限模型的过程解析》:本文主要介绍SpringBoot整合Sa-Token实现RBAC权限模型的过程解析,本文给大家介绍的非常详细,对大家的学... 目录前言一、基础概念1.1 RBAC模型核心概念1.2 Sa-Token核心功能1.3 环境准备二、表结

使用Python获取JS加载的数据的多种实现方法

《使用Python获取JS加载的数据的多种实现方法》在当今的互联网时代,网页数据的动态加载已经成为一种常见的技术手段,许多现代网站通过JavaScript(JS)动态加载内容,这使得传统的静态网页爬取... 目录引言一、动态 网页与js加载数据的原理二、python爬取JS加载数据的方法(一)分析网络请求1

IDEA下"File is read-only"可能原因分析及"找不到或无法加载主类"的问题

《IDEA下Fileisread-only可能原因分析及找不到或无法加载主类的问题》:本文主要介绍IDEA下Fileisread-only可能原因分析及找不到或无法加载主类的问题,具有很好的参... 目录1.File is read-only”可能原因2.“找不到或无法加载主类”问题的解决总结1.File

重新对Java的类加载器的学习方式

《重新对Java的类加载器的学习方式》:本文主要介绍重新对Java的类加载器的学习方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、介绍1.1、简介1.2、符号引用和直接引用1、符号引用2、直接引用3、符号转直接的过程2、加载流程3、类加载的分类3.1、显示

在 PyQt 加载 UI 三种常见方法

《在PyQt加载UI三种常见方法》在PyQt中,加载UI文件通常指的是使用QtDesigner设计的.ui文件,并将其转换为Python代码,以便在PyQt应用程序中使用,这篇文章给大家介绍在... 目录方法一:使用 uic 模块动态加载 (不推荐用于大型项目)方法二:将 UI 文件编译为 python 模

Spring框架中@Lazy延迟加载原理和使用详解

《Spring框架中@Lazy延迟加载原理和使用详解》:本文主要介绍Spring框架中@Lazy延迟加载原理和使用方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐... 目录一、@Lazy延迟加载原理1.延迟加载原理1.1 @Lazy三种配置方法1.2 @Component

SpringBoot中配置文件的加载顺序解读

《SpringBoot中配置文件的加载顺序解读》:本文主要介绍SpringBoot中配置文件的加载顺序,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录SpringBoot配置文件的加载顺序1、命令⾏参数2、Java系统属性3、操作系统环境变量5、项目【外部】的ap

Spring Security基于数据库的ABAC属性权限模型实战开发教程

《SpringSecurity基于数据库的ABAC属性权限模型实战开发教程》:本文主要介绍SpringSecurity基于数据库的ABAC属性权限模型实战开发教程,本文给大家介绍的非常详细,对大... 目录1. 前言2. 权限决策依据RBACABAC综合对比3. 数据库表结构说明4. 实战开始5. MyBA