使用estimator结构训练tf模型

2024-08-23 14:18

本文主要是介绍使用estimator结构训练tf模型,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

一、使用estimator训练模型的流程

1、构建model_fn

def my_metric_fn(labels, predictions):return {'accuracy': tf.metrics.accuracy(labels, predictions)}def model_fn(features, labels, mode, params):""" TODO: 模型函数必须有这四个参数:param features: # 输入的特征数据:param labels: # 输入的标签数据:param mode: # train、evaluate或predict:param params: #超参数,对应Estimator传来的参数:return: TPUEstimatorSpec类型的对象"""eval_metrics=(my_metric_fn, [labels, predictions])output_spec = tf.contrib.tpu.TPUEstimatorSpec(mode=mode, # "train" or "eval" or "predict"loss=total_loss, # double类型eval_metrics=eval_metrics, scaffold_fn=scaffold_fn)  # None or funreturn output_spec

2、定义estimator

run_config = tf.contrib.tpu.RunConfig(cluster=tpu_cluster_resolver,master=FLAGS.master,model_dir=FLAGS.output_dir,save_checkpoints_steps=FLAGS.save_checkpoints_steps,keep_checkpoint_max=FLAGS.keep_checkpoint_max,tf_random_seed=FLAGS.random_seed,tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=FLAGS.save_checkpoints_steps,num_shards=FLAGS.num_tpu_cores,per_host_input_for_training=is_per_host))# 自定义估算器
estimator = tf.contrib.tpu.TPUEstimator(use_tpu=FLAGS.use_tpu,model_fn=model_fn,  # 模型函数config=run_config,  # 设置参数对象train_batch_size=FLAGS.train_batch_size,eval_batch_size=FLAGS.eval_batch_size,predict_batch_size=FLAGS.predict_batch_size)

3、训练模型

def train_input_fn(params):batch_size = params["batch_size"]d = tf.data.TFRecordDataset(input_file)if is_training:d = d.repeat()d = d.shuffle(buffer_size=100, seed=random.randint(1, 10000))d = d.apply(tf.data.experimental.map_and_batch(lambda record: _decode_record(record, name_to_features),batch_size=batch_size,drop_remainder=drop_remainder))return destimator.train(input_fn=train_input_fn, max_steps=next_checkpoint)

4、验证模型

def eval_input_fn(params): # 部分代码 只看框架即可batch_size = params["batch_size"]d = tf.data.TFRecordDataset(input_file)if is_training:d = d.repeat()d = d.shuffle(buffer_size=100, seed=random.randint(1, 10000))d = d.apply(tf.data.experimental.map_and_batch(lambda record: _decode_record(record, name_to_features),batch_size=batch_size,drop_remainder=drop_remainder))return d
result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)  # type:dict
for key in sorted(result.keys()):log_info = "  %s = %s"%(key, str(result[key]))

5、测试模型

def predict_input_fn(params): # 部分代码 只看框架即可batch_size = params["batch_size"]d = tf.data.TFRecordDataset(input_file)if is_training:d = d.repeat()d = d.shuffle(buffer_size=100, seed=random.randint(1, 10000))d = d.apply(tf.data.experimental.map_and_batch(lambda record: _decode_record(record, name_to_features),batch_size=batch_size,drop_remainder=drop_remainder))return d
result = estimator.predict(input_fn=predict_input_fn)  # type:dict
for key in sorted(result.keys()):log_info = "  %s = %s"%(key, str(result[key]))

二、使用estimator训练模型的样例

这篇关于使用estimator结构训练tf模型的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1099613

相关文章

Java使用Thumbnailator库实现图片处理与压缩功能

《Java使用Thumbnailator库实现图片处理与压缩功能》Thumbnailator是高性能Java图像处理库,支持缩放、旋转、水印添加、裁剪及格式转换,提供易用API和性能优化,适合Web应... 目录1. 图片处理库Thumbnailator介绍2. 基本和指定大小图片缩放功能2.1 图片缩放的

Oracle查询表结构建表语句索引等方式

《Oracle查询表结构建表语句索引等方式》使用USER_TAB_COLUMNS查询表结构可避免系统隐藏字段(如LISTUSER的CLOB与VARCHAR2同名字段),这些字段可能为dbms_lob.... 目录oracle查询表结构建表语句索引1.用“USER_TAB_COLUMNS”查询表结构2.用“a

Python使用Tenacity一行代码实现自动重试详解

《Python使用Tenacity一行代码实现自动重试详解》tenacity是一个专为Python设计的通用重试库,它的核心理念就是用简单、清晰的方式,为任何可能失败的操作添加重试能力,下面我们就来看... 目录一切始于一个简单的 API 调用Tenacity 入门:一行代码实现优雅重试精细控制:让重试按我

MySQL中EXISTS与IN用法使用与对比分析

《MySQL中EXISTS与IN用法使用与对比分析》在MySQL中,EXISTS和IN都用于子查询中根据另一个查询的结果来过滤主查询的记录,本文将基于工作原理、效率和应用场景进行全面对比... 目录一、基本用法详解1. IN 运算符2. EXISTS 运算符二、EXISTS 与 IN 的选择策略三、性能对比

使用Python构建智能BAT文件生成器的完美解决方案

《使用Python构建智能BAT文件生成器的完美解决方案》这篇文章主要为大家详细介绍了如何使用wxPython构建一个智能的BAT文件生成器,它不仅能够为Python脚本生成启动脚本,还提供了完整的文... 目录引言运行效果图项目背景与需求分析核心需求技术选型核心功能实现1. 数据库设计2. 界面布局设计3

使用IDEA部署Docker应用指南分享

《使用IDEA部署Docker应用指南分享》本文介绍了使用IDEA部署Docker应用的四步流程:创建Dockerfile、配置IDEADocker连接、设置运行调试环境、构建运行镜像,并强调需准备本... 目录一、创建 dockerfile 配置文件二、配置 IDEA 的 Docker 连接三、配置 Do

Android Paging 分页加载库使用实践

《AndroidPaging分页加载库使用实践》AndroidPaging库是Jetpack组件的一部分,它提供了一套完整的解决方案来处理大型数据集的分页加载,本文将深入探讨Paging库... 目录前言一、Paging 库概述二、Paging 3 核心组件1. PagingSource2. Pager3.

python使用try函数详解

《python使用try函数详解》Pythontry语句用于异常处理,支持捕获特定/多种异常、else/final子句确保资源释放,结合with语句自动清理,可自定义异常及嵌套结构,灵活应对错误场景... 目录try 函数的基本语法捕获特定异常捕获多个异常使用 else 子句使用 finally 子句捕获所

C++11右值引用与Lambda表达式的使用

《C++11右值引用与Lambda表达式的使用》C++11引入右值引用,实现移动语义提升性能,支持资源转移与完美转发;同时引入Lambda表达式,简化匿名函数定义,通过捕获列表和参数列表灵活处理变量... 目录C++11新特性右值引用和移动语义左值 / 右值常见的左值和右值移动语义移动构造函数移动复制运算符

Python对接支付宝支付之使用AliPay实现的详细操作指南

《Python对接支付宝支付之使用AliPay实现的详细操作指南》支付宝没有提供PythonSDK,但是强大的github就有提供python-alipay-sdk,封装里很多复杂操作,使用这个我们就... 目录一、引言二、准备工作2.1 支付宝开放平台入驻与应用创建2.2 密钥生成与配置2.3 安装ali