1.tensorflow线性回归示例:保存模型,载入模型,打印模型参数,修改模型

本文主要是介绍1.tensorflow线性回归示例:保存模型,载入模型,打印模型参数,修改模型,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

#coding:utf-8
'''
a liner regression by tenosrflow.
input dimension: 1, output dimension: 1.
显示每个epoch的loss
利用模型预测
保存模型
载入模型
打印模型中的参数
修改模型中的参数
'''
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file# data
x_train = np.linspace(-1, 1, 100)
y_train = 10 * x_train + np.random.randn(x_train.shape[0])
# plt.plot(x_train, y_train, "ro", label="data")
# plt.legend()
# plt.show()epochs = 30
display_step = 2
# input, output
x = tf.placeholder(dtype="float", name="input")
y = tf.placeholder(dtype="float", name="label")
# w, b
w = tf.Variable(initial_value=tf.random_normal([1]), name="weight")
b = tf.Variable(initial_value=tf.zeros([1]), name="bias")
# model
z = tf.multiply(x, w) + b
# loss functon
cost = tf.reduce_mean(tf.square(y - z))
# optimizer
optim = tf.train.GradientDescentOptimizer(learning_rate=0.01).minimize(cost)
saver = tf.train.Saver(max_to_keep=4)  # save 4 model
init = tf.global_variables_initializer()
with tf.Session() as sess:sess.run(init)for epoch in range(epochs):for x_batch, y_batch in zip(x_train, y_train):  # batch is all data theresess.run(optim, feed_dict={x:x_batch, y:y_batch})if epoch % display_step ==0:loss = sess.run(cost, feed_dict={x:x_train, y:y_train})print("epoch: %d, loss: %d" %(epoch, loss))# 保存训练过程中的模型saver.save(sess, "line_regression_model/regress.cpkt", global_step=epoch)print("train finished...")# 保存最终的模型saver.save(sess, "line_regression_model/regress.cpkt")print("final loss:", sess.run(cost, feed_dict={x:x_train, y:y_train}))print("weight:", sess.run(w))print("bias:", sess.run(b))# show train data and predict dataplt.plot(x_train, y_train, "ro", label="train")predict = sess.run(w) * x_train + sess.run(b)plt.plot(x_train, predict, "b", label="predict")plt.legend()plt.show()# 载入模型
print("*"*50)
saver = tf.train.Saver()
with tf.Session() as sess2:sess2.run(tf.global_variables_initializer())saver.restore(sess2, "line_regression_model/regress.cpkt")print(sess2.run(w))print(sess2.run(b))predict2 = sess2.run(z, feed_dict={x:0.5})print(predict2)# 打印出模型中的变量及参数
print("-"*50)
print("the params in model:")
print_tensors_in_checkpoint_file("line_regression_model/regress.cpkt", None, True)# 修改模型中的参数,并重新保存
print("-"*50)
# 以上得到了模型中参数名字为weight,bias, 下面对他们进行修改
w_change = tf.Variable(10, name="weight")
b_change = tf.Variable(0.001, name="bias")
# 把他们放到一个字典里并写在saver里
saver = tf.train.Saver({"weighs":w_change, "bias":b_change})
with tf.Session() as sess3:sess3.run(tf.global_variables_initializer())# 保存修改后的参数saver.save(sess3, "line_regression_model/regress.cpkt")
# 发现参数已经被修改
print_tensors_in_checkpoint_file("line_regression_model/regress.cpkt", None, True)

输出:

/usr/local/bin/python2.7 /Users/ming/Downloads/zhangming/tf_demo/liner_regression.py
2018-11-17 16:07:32.138907: I tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
epoch: 0, loss: 21
epoch: 2, loss: 2
epoch: 4, loss: 1
epoch: 6, loss: 1
epoch: 8, loss: 1
epoch: 10, loss: 1
epoch: 12, loss: 1
epoch: 14, loss: 1
epoch: 16, loss: 1
epoch: 18, loss: 1
epoch: 20, loss: 1
epoch: 22, loss: 1
epoch: 24, loss: 1
epoch: 26, loss: 1
epoch: 28, loss: 1
train finished...
('final loss:', 1.0535882)
('weight:', array([10.063329], dtype=float32))
('bias:', array([0.03052005], dtype=float32))
**************************************************
[10.063329]
[0.03052005]
[5.0621843]
--------------------------------------------------
the params in model:
tensor_name:  bias
[0.03052005]
tensor_name:  weight
[10.063329]
--------------------------------------------------
tensor_name:  bias
0.001
tensor_name:  weighs
10
 

这篇关于1.tensorflow线性回归示例:保存模型,载入模型,打印模型参数,修改模型的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/986052

相关文章

Python中logging模块用法示例总结

《Python中logging模块用法示例总结》在Python中logging模块是一个强大的日志记录工具,它允许用户将程序运行期间产生的日志信息输出到控制台或者写入到文件中,:本文主要介绍Pyt... 目录前言一. 基本使用1. 五种日志等级2.  设置报告等级3. 自定义格式4. C语言风格的格式化方法

Spring 中的切面与事务结合使用完整示例

《Spring中的切面与事务结合使用完整示例》本文给大家介绍Spring中的切面与事务结合使用完整示例,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考... 目录 一、前置知识:Spring AOP 与 事务的关系 事务本质上就是一个“切面”二、核心组件三、完

SpringBoot 获取请求参数的常用注解及用法

《SpringBoot获取请求参数的常用注解及用法》SpringBoot通过@RequestParam、@PathVariable等注解支持从HTTP请求中获取参数,涵盖查询、路径、请求体、头、C... 目录SpringBoot 提供了多种注解来方便地从 HTTP 请求中获取参数以下是主要的注解及其用法:1

HTTP 与 SpringBoot 参数提交与接收协议方式

《HTTP与SpringBoot参数提交与接收协议方式》HTTP参数提交方式包括URL查询、表单、JSON/XML、路径变量、头部、Cookie、GraphQL、WebSocket和SSE,依据... 目录HTTP 协议支持多种参数提交方式,主要取决于请求方法(Method)和内容类型(Content-Ty

sky-take-out项目中Redis的使用示例详解

《sky-take-out项目中Redis的使用示例详解》SpringCache是Spring的缓存抽象层,通过注解简化缓存管理,支持Redis等提供者,适用于方法结果缓存、更新和删除操作,但无法实现... 目录Spring Cache主要特性核心注解1.@Cacheable2.@CachePut3.@Ca

QT Creator配置Kit的实现示例

《QTCreator配置Kit的实现示例》本文主要介绍了使用Qt5.12.12与VS2022时,因MSVC编译器版本不匹配及WindowsSDK缺失导致配置错误的问题解决,感兴趣的可以了解一下... 目录0、背景:qt5.12.12+vs2022一、症状:二、原因:(可以跳过,直奔后面的解决方法)三、解决方

MySQL中On duplicate key update的实现示例

《MySQL中Onduplicatekeyupdate的实现示例》ONDUPLICATEKEYUPDATE是一种MySQL的语法,它在插入新数据时,如果遇到唯一键冲突,则会执行更新操作,而不是抛... 目录1/ ON DUPLICATE KEY UPDATE的简介2/ ON DUPLICATE KEY UP

Python中Json和其他类型相互转换的实现示例

《Python中Json和其他类型相互转换的实现示例》本文介绍了在Python中使用json模块实现json数据与dict、object之间的高效转换,包括loads(),load(),dumps()... 项目中经常会用到json格式转为object对象、dict字典格式等。在此做个记录,方便后续用到该方

MySQL分库分表的实践示例

《MySQL分库分表的实践示例》MySQL分库分表适用于数据量大或并发压力高的场景,核心技术包括水平/垂直分片和分库,需应对分布式事务、跨库查询等挑战,通过中间件和解决方案实现,最佳实践为合理策略、备... 目录一、分库分表的触发条件1.1 数据量阈值1.2 并发压力二、分库分表的核心技术模块2.1 水平分

SpringBoot请求参数传递与接收示例详解

《SpringBoot请求参数传递与接收示例详解》本文给大家介绍SpringBoot请求参数传递与接收示例详解,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋... 目录I. 基础参数传递i.查询参数(Query Parameters)ii.路径参数(Path Va