1.tensorflow线性回归示例:保存模型,载入模型,打印模型参数,修改模型

本文主要是介绍1.tensorflow线性回归示例:保存模型,载入模型,打印模型参数,修改模型,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

#coding:utf-8
'''
a liner regression by tenosrflow.
input dimension: 1, output dimension: 1.
显示每个epoch的loss
利用模型预测
保存模型
载入模型
打印模型中的参数
修改模型中的参数
'''
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file# data
x_train = np.linspace(-1, 1, 100)
y_train = 10 * x_train + np.random.randn(x_train.shape[0])
# plt.plot(x_train, y_train, "ro", label="data")
# plt.legend()
# plt.show()epochs = 30
display_step = 2
# input, output
x = tf.placeholder(dtype="float", name="input")
y = tf.placeholder(dtype="float", name="label")
# w, b
w = tf.Variable(initial_value=tf.random_normal([1]), name="weight")
b = tf.Variable(initial_value=tf.zeros([1]), name="bias")
# model
z = tf.multiply(x, w) + b
# loss functon
cost = tf.reduce_mean(tf.square(y - z))
# optimizer
optim = tf.train.GradientDescentOptimizer(learning_rate=0.01).minimize(cost)
saver = tf.train.Saver(max_to_keep=4)  # save 4 model
init = tf.global_variables_initializer()
with tf.Session() as sess:sess.run(init)for epoch in range(epochs):for x_batch, y_batch in zip(x_train, y_train):  # batch is all data theresess.run(optim, feed_dict={x:x_batch, y:y_batch})if epoch % display_step ==0:loss = sess.run(cost, feed_dict={x:x_train, y:y_train})print("epoch: %d, loss: %d" %(epoch, loss))# 保存训练过程中的模型saver.save(sess, "line_regression_model/regress.cpkt", global_step=epoch)print("train finished...")# 保存最终的模型saver.save(sess, "line_regression_model/regress.cpkt")print("final loss:", sess.run(cost, feed_dict={x:x_train, y:y_train}))print("weight:", sess.run(w))print("bias:", sess.run(b))# show train data and predict dataplt.plot(x_train, y_train, "ro", label="train")predict = sess.run(w) * x_train + sess.run(b)plt.plot(x_train, predict, "b", label="predict")plt.legend()plt.show()# 载入模型
print("*"*50)
saver = tf.train.Saver()
with tf.Session() as sess2:sess2.run(tf.global_variables_initializer())saver.restore(sess2, "line_regression_model/regress.cpkt")print(sess2.run(w))print(sess2.run(b))predict2 = sess2.run(z, feed_dict={x:0.5})print(predict2)# 打印出模型中的变量及参数
print("-"*50)
print("the params in model:")
print_tensors_in_checkpoint_file("line_regression_model/regress.cpkt", None, True)# 修改模型中的参数,并重新保存
print("-"*50)
# 以上得到了模型中参数名字为weight,bias, 下面对他们进行修改
w_change = tf.Variable(10, name="weight")
b_change = tf.Variable(0.001, name="bias")
# 把他们放到一个字典里并写在saver里
saver = tf.train.Saver({"weighs":w_change, "bias":b_change})
with tf.Session() as sess3:sess3.run(tf.global_variables_initializer())# 保存修改后的参数saver.save(sess3, "line_regression_model/regress.cpkt")
# 发现参数已经被修改
print_tensors_in_checkpoint_file("line_regression_model/regress.cpkt", None, True)

输出:

/usr/local/bin/python2.7 /Users/ming/Downloads/zhangming/tf_demo/liner_regression.py
2018-11-17 16:07:32.138907: I tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
epoch: 0, loss: 21
epoch: 2, loss: 2
epoch: 4, loss: 1
epoch: 6, loss: 1
epoch: 8, loss: 1
epoch: 10, loss: 1
epoch: 12, loss: 1
epoch: 14, loss: 1
epoch: 16, loss: 1
epoch: 18, loss: 1
epoch: 20, loss: 1
epoch: 22, loss: 1
epoch: 24, loss: 1
epoch: 26, loss: 1
epoch: 28, loss: 1
train finished...
('final loss:', 1.0535882)
('weight:', array([10.063329], dtype=float32))
('bias:', array([0.03052005], dtype=float32))
**************************************************
[10.063329]
[0.03052005]
[5.0621843]
--------------------------------------------------
the params in model:
tensor_name:  bias
[0.03052005]
tensor_name:  weight
[10.063329]
--------------------------------------------------
tensor_name:  bias
0.001
tensor_name:  weighs
10
 

这篇关于1.tensorflow线性回归示例:保存模型,载入模型,打印模型参数,修改模型的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/986052

相关文章

详解SpringBoot+Ehcache使用示例

《详解SpringBoot+Ehcache使用示例》本文介绍了SpringBoot中配置Ehcache、自定义get/set方式,并实际使用缓存的过程,文中通过示例代码介绍的非常详细,对大家的学习或者... 目录摘要概念内存与磁盘持久化存储:配置灵活性:编码示例引入依赖:配置ehcache.XML文件:配置

Java高效实现PowerPoint转PDF的示例详解

《Java高效实现PowerPoint转PDF的示例详解》在日常开发或办公场景中,经常需要将PowerPoint演示文稿(PPT/PPTX)转换为PDF,本文将介绍从基础转换到高级设置的多种用法,大家... 目录为什么要将 PowerPoint 转换为 PDF安装 Spire.Presentation fo

Python中isinstance()函数原理解释及详细用法示例

《Python中isinstance()函数原理解释及详细用法示例》isinstance()是Python内置的一个非常有用的函数,用于检查一个对象是否属于指定的类型或类型元组中的某一个类型,它是Py... 目录python中isinstance()函数原理解释及详细用法指南一、isinstance()函数

python中的高阶函数示例详解

《python中的高阶函数示例详解》在Python中,高阶函数是指接受函数作为参数或返回函数作为结果的函数,下面:本文主要介绍python中高阶函数的相关资料,文中通过代码介绍的非常详细,需要的朋... 目录1.定义2.map函数3.filter函数4.reduce函数5.sorted函数6.自定义高阶函数

Vue实现路由守卫的示例代码

《Vue实现路由守卫的示例代码》Vue路由守卫是控制页面导航的钩子函数,主要用于鉴权、数据预加载等场景,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着... 目录一、概念二、类型三、实战一、概念路由守卫(Navigation Guards)本质上就是 在路

JAVA实现Token自动续期机制的示例代码

《JAVA实现Token自动续期机制的示例代码》本文主要介绍了JAVA实现Token自动续期机制的示例代码,通过动态调整会话生命周期平衡安全性与用户体验,解决固定有效期Token带来的风险与不便,感兴... 目录1. 固定有效期Token的内在局限性2. 自动续期机制:兼顾安全与体验的解决方案3. 总结PS

C#中通过Response.Headers设置自定义参数的代码示例

《C#中通过Response.Headers设置自定义参数的代码示例》:本文主要介绍C#中通过Response.Headers设置自定义响应头的方法,涵盖基础添加、安全校验、生产实践及调试技巧,强... 目录一、基础设置方法1. 直接添加自定义头2. 批量设置模式二、高级配置技巧1. 安全校验机制2. 类型

Python屏幕抓取和录制的详细代码示例

《Python屏幕抓取和录制的详细代码示例》随着现代计算机性能的提高和网络速度的加快,越来越多的用户需要对他们的屏幕进行录制,:本文主要介绍Python屏幕抓取和录制的相关资料,需要的朋友可以参考... 目录一、常用 python 屏幕抓取库二、pyautogui 截屏示例三、mss 高性能截图四、Pill

Java中的Schema校验技术与实践示例详解

《Java中的Schema校验技术与实践示例详解》本主题详细介绍了在Java环境下进行XMLSchema和JSONSchema校验的方法,包括使用JAXP、JAXB以及专门的JSON校验库等技术,本文... 目录1. XML和jsON的Schema校验概念1.1 XML和JSON校验的必要性1.2 Sche

使用MapStruct实现Java对象映射的示例代码

《使用MapStruct实现Java对象映射的示例代码》本文主要介绍了使用MapStruct实现Java对象映射的示例代码,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,... 目录一、什么是 MapStruct?二、实战演练:三步集成 MapStruct第一步:添加 Mave