使用estimator结构训练tf模型

2024-08-23 14:18

本文主要是介绍使用estimator结构训练tf模型,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

一、使用estimator训练模型的流程

1、构建model_fn

def my_metric_fn(labels, predictions):return {'accuracy': tf.metrics.accuracy(labels, predictions)}def model_fn(features, labels, mode, params):""" TODO: 模型函数必须有这四个参数:param features: # 输入的特征数据:param labels: # 输入的标签数据:param mode: # train、evaluate或predict:param params: #超参数,对应Estimator传来的参数:return: TPUEstimatorSpec类型的对象"""eval_metrics=(my_metric_fn, [labels, predictions])output_spec = tf.contrib.tpu.TPUEstimatorSpec(mode=mode, # "train" or "eval" or "predict"loss=total_loss, # double类型eval_metrics=eval_metrics, scaffold_fn=scaffold_fn)  # None or funreturn output_spec

2、定义estimator

run_config = tf.contrib.tpu.RunConfig(cluster=tpu_cluster_resolver,master=FLAGS.master,model_dir=FLAGS.output_dir,save_checkpoints_steps=FLAGS.save_checkpoints_steps,keep_checkpoint_max=FLAGS.keep_checkpoint_max,tf_random_seed=FLAGS.random_seed,tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=FLAGS.save_checkpoints_steps,num_shards=FLAGS.num_tpu_cores,per_host_input_for_training=is_per_host))# 自定义估算器
estimator = tf.contrib.tpu.TPUEstimator(use_tpu=FLAGS.use_tpu,model_fn=model_fn,  # 模型函数config=run_config,  # 设置参数对象train_batch_size=FLAGS.train_batch_size,eval_batch_size=FLAGS.eval_batch_size,predict_batch_size=FLAGS.predict_batch_size)

3、训练模型

def train_input_fn(params):batch_size = params["batch_size"]d = tf.data.TFRecordDataset(input_file)if is_training:d = d.repeat()d = d.shuffle(buffer_size=100, seed=random.randint(1, 10000))d = d.apply(tf.data.experimental.map_and_batch(lambda record: _decode_record(record, name_to_features),batch_size=batch_size,drop_remainder=drop_remainder))return destimator.train(input_fn=train_input_fn, max_steps=next_checkpoint)

4、验证模型

def eval_input_fn(params): # 部分代码 只看框架即可batch_size = params["batch_size"]d = tf.data.TFRecordDataset(input_file)if is_training:d = d.repeat()d = d.shuffle(buffer_size=100, seed=random.randint(1, 10000))d = d.apply(tf.data.experimental.map_and_batch(lambda record: _decode_record(record, name_to_features),batch_size=batch_size,drop_remainder=drop_remainder))return d
result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)  # type:dict
for key in sorted(result.keys()):log_info = "  %s = %s"%(key, str(result[key]))

5、测试模型

def predict_input_fn(params): # 部分代码 只看框架即可batch_size = params["batch_size"]d = tf.data.TFRecordDataset(input_file)if is_training:d = d.repeat()d = d.shuffle(buffer_size=100, seed=random.randint(1, 10000))d = d.apply(tf.data.experimental.map_and_batch(lambda record: _decode_record(record, name_to_features),batch_size=batch_size,drop_remainder=drop_remainder))return d
result = estimator.predict(input_fn=predict_input_fn)  # type:dict
for key in sorted(result.keys()):log_info = "  %s = %s"%(key, str(result[key]))

二、使用estimator训练模型的样例

这篇关于使用estimator结构训练tf模型的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!


原文地址:
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.chinasem.cn/article/1099613

相关文章

使用jenv工具管理多个JDK版本的方法步骤

《使用jenv工具管理多个JDK版本的方法步骤》jenv是一个开源的Java环境管理工具,旨在帮助开发者在同一台机器上轻松管理和切换多个Java版本,:本文主要介绍使用jenv工具管理多个JD... 目录一、jenv到底是干啥的?二、jenv的核心功能(一)管理多个Java版本(二)支持插件扩展(三)环境隔

SQL中JOIN操作的条件使用总结与实践

《SQL中JOIN操作的条件使用总结与实践》在SQL查询中,JOIN操作是多表关联的核心工具,本文将从原理,场景和最佳实践三个方面总结JOIN条件的使用规则,希望可以帮助开发者精准控制查询逻辑... 目录一、ON与WHERE的本质区别二、场景化条件使用规则三、最佳实践建议1.优先使用ON条件2.WHERE用

Java中Map.Entry()含义及方法使用代码

《Java中Map.Entry()含义及方法使用代码》:本文主要介绍Java中Map.Entry()含义及方法使用的相关资料,Map.Entry是Java中Map的静态内部接口,用于表示键值对,其... 目录前言 Map.Entry作用核心方法常见使用场景1. 遍历 Map 的所有键值对2. 直接修改 Ma

MySQL 衍生表(Derived Tables)的使用

《MySQL衍生表(DerivedTables)的使用》本文主要介绍了MySQL衍生表(DerivedTables)的使用,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学... 目录一、衍生表简介1.1 衍生表基本用法1.2 自定义列名1.3 衍生表的局限在SQL的查询语句select

Mybatis Plus Join使用方法示例详解

《MybatisPlusJoin使用方法示例详解》:本文主要介绍MybatisPlusJoin使用方法示例详解,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,... 目录1、pom文件2、yaml配置文件3、分页插件4、示例代码:5、测试代码6、和PageHelper结合6

MySQL分区表的具体使用

《MySQL分区表的具体使用》MySQL分区表通过规则将数据分至不同物理存储,提升管理与查询效率,本文主要介绍了MySQL分区表的具体使用,具有一定的参考价值,感兴趣的可以了解一下... 目录一、分区的类型1. Range partition(范围分区)2. List partition(列表分区)3. H

使用SpringBoot整合Sharding Sphere实现数据脱敏的示例

《使用SpringBoot整合ShardingSphere实现数据脱敏的示例》ApacheShardingSphere数据脱敏模块,通过SQL拦截与改写实现敏感信息加密存储,解决手动处理繁琐及系统改... 目录痛点一:痛点二:脱敏配置Quick Start——Spring 显示配置:1.引入依赖2.创建脱敏

Python使用smtplib库开发一个邮件自动发送工具

《Python使用smtplib库开发一个邮件自动发送工具》在现代软件开发中,自动化邮件发送是一个非常实用的功能,无论是系统通知、营销邮件、还是日常工作报告,Python的smtplib库都能帮助我们... 目录代码实现与知识点解析1. 导入必要的库2. 配置邮件服务器参数3. 创建邮件发送类4. 实现邮件

Go语言中Recover机制的使用

《Go语言中Recover机制的使用》Go语言的recover机制通过defer函数捕获panic,实现异常恢复与程序稳定性,具有一定的参考价值,感兴趣的可以了解一下... 目录引言Recover 的基本概念基本代码示例简单的 Recover 示例嵌套函数中的 Recover项目场景中的应用Web 服务器中

CnPlugin是PL/SQL Developer工具插件使用教程

《CnPlugin是PL/SQLDeveloper工具插件使用教程》:本文主要介绍CnPlugin是PL/SQLDeveloper工具插件使用教程,具有很好的参考价值,希望对大家有所帮助,如有错... 目录PL/SQL Developer工具插件使用安装拷贝文件配置总结PL/SQL Developer工具插