终于把tensorflow输入层和输出层搞懂了!fit函数与输入层,输出层,tf.keras.Model输入和输出的关系

2024-06-10 19:44

本文主要是介绍终于把tensorflow输入层和输出层搞懂了!fit函数与输入层,输出层,tf.keras.Model输入和输出的关系,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

结论

fit函数与输入层,输出层,tf.keras.Model输入和输出的关系

  • fit函数使用dataset格式,输入为字典格式,假设tf.keras.Model中输入和输出为字典格式(2.2或2.3),dataset的key必须和2.2或2.3中字典的key一致,否则报错
  • fit函数使用dataset格式,输入为仍然是字典格式,假设tf.keras.Model中输入和输出为list格式(2.1),dataset的key必须和2.1涉及到的**输入层和输出层(1.1和1.2)**的层名一致,否则报错

1. 定义模型输入和输出

1.1 定义模型输入层

continuous_input = {key: tf.keras.layers.Input(shape=(), name=key) for key in continuous_feature}
discrete_input = {key: tf.keras.layers.Input(shape=(), name=key) for key in discrete_feature}  

1.2 定义模型输出层

output_1 = tf.keras.layers.Dense(1, activation='sigmoid', name='is_click')(x)
output_2 = tf.keras.layers.Dense(1, activation='sigmoid', name='is_play')(x)
output_3 = tf.keras.layers.Dense(1, activation='sigmoid', name='is_pay')(x)

2. tf.keras.Model输入和输出

2.1 输入和输出为list格式

model_func = tf.keras.Model(inputs=list(continuous_input.values()) + list(discrete_input.values()),outputs=[output_1, output_2,  output_3])

2.2 输出为dict格式

model_func = tf.keras.Model(inputs=list(continuous_input.values()) + list(discrete_input.values()),outputs={'is_click': output_1, 'is_play': output_2, 'is_pay': output_3})

2.3 输入为dict格式

# 构造输入字典,也可以其他方式构造,此处只是为了说明,continuous_input为字典
continuous_input.update(discrete_input)
model_func = tf.keras.Model(inputs=continuous_input,outputs=[output_1, output_2,  output_3])

3. fit函数中输入和输出-dataset(tfrecord格式)

3.1 dataset定义

def _parse_function(example_proto, feature_description):# Parse the input `tf.Example` proto using the dictionary above.data = tf.io.parse_single_example(example_proto, feature_description)is_click = data.pop('is_click')is_play = data.pop('is_play')is_pay = data.pop('is_pay')return data, {'is_click': is_click, 'is_play': is_play, 'is_pay': is_pay}

3.2 dataset示例-batch_size=1024

dataset为字典格式,请注意!

({'age': <tf.Tensor: shape=(1024,), dtype=float32, numpy=
array([0.3448276 , 0.27586207, 0.31034482, ..., 0.37931034, 0.44827586,0.1724138 ], dtype=float32)>, 'first_class_id': <tf.Tensor: shape=(1024,), dtype=int64, numpy=array([ 4,  1,  1, ...,  4, 15,  1], dtype=int64)>, 'gender': <tf.Tensor: shape=(1024,), dtype=int64, numpy=array([2, 1, 2, ..., 2, 2, 1], dtype=int64)>, 'married': <tf.Tensor: shape=(1024,), dtype=int64, numpy=array([2, 2, 2, ..., 1, 1, 1], dtype=int64)>, 'province': <tf.Tensor: shape=(1024,), dtype=int64, numpy=array([22, 23, 25, ..., 14, 18, 20], dtype=int64)>, 'second_class_id': <tf.Tensor: shape=(1024,), dtype=int64, numpy=array([23, 53, 24, ..., 29, 11, 47], dtype=int64)>, 'tag_id': <tf.Tensor: shape=(1024,), dtype=int64, numpy=array([ 5, 58,  6, ..., 17, 76, 49], dtype=int64)>, 'target_item_id': <tf.Tensor: shape=(1024,), dtype=int64, numpy=array([ 5, 58,  6, ..., 17, 76, 49], dtype=int64)>, 'type': <tf.Tensor: shape=(1024,), dtype=int64, numpy=array([1, 4, 1, ..., 2, 1, 1], dtype=int64)>, 'user_click_cnt': <tf.Tensor: shape=(1024,), dtype=float32, numpy=
array([0.5       , 0.5833333 , 1.        , ..., 0.41666666, 0.33333334,0.6666667 ], dtype=float32)>, 'user_click_video_cnt': <tf.Tensor: shape=(1024,), dtype=float32, numpy=
array([0.5       , 0.5833333 , 0.8333333 , ..., 0.41666666, 0.33333334,0.6666667 ], dtype=float32)>, 'user_exp_cnt': <tf.Tensor: shape=(1024,), dtype=float32, numpy=
array([0.47826087, 0.5797101 , 0.6231884 , ..., 0.3768116 , 0.6086956 ,0.26086956], dtype=float32)>, 'user_exp_video_cnt': <tf.Tensor: shape=(1024,), dtype=float32, numpy=
array([0.61290324, 0.61290324, 0.58064514, ..., 0.32258064, 0.61290324,0.4516129 ], dtype=float32)>, 'user_id': <tf.Tensor: shape=(1024,), dtype=string, numpy=
array([b'ffb07508-9acc-4253-a1a0-e3e7fc6fad58',b'1ac654df-2b93-47b8-80ba-ca15642b5919',b'69daac99-ad14-4fc8-80f7-8c80cbc221b3', ...,b'97366ccc-b10d-47cb-9ad6-956c535ccf87',b'a7f43278-9e9b-4500-98b7-6d536d680ac1',b'297e2eec-a491-4aab-bc96-24f806751eb1'], dtype=object)>, 'user_name': <tf.Tensor: shape=(1024,), dtype=int64, numpy=array([234, 239, 285, ..., 753, 222, 563], dtype=int64)>, 'user_pay_cnt': <tf.Tensor: shape=(1024,), dtype=float32, numpy=array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)>, 'user_pay_video_cnt': <tf.Tensor: shape=(1024,), dtype=float32, numpy=array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)>, 'user_play_cnt': <tf.Tensor: shape=(1024,), dtype=float32, numpy=array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)>, 'user_play_video_cnt': <tf.Tensor: shape=(1024,), dtype=float32, numpy=array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)>, 'video_click_cnt': <tf.Tensor: shape=(1024,), dtype=float32, numpy=
array([0.4722222, 0.5      , 0.6666667, ..., 0.8333333, 0.6666667,0.4722222], dtype=float32)>, 'video_click_user_cnt': <tf.Tensor: shape=(1024,), dtype=float32, numpy=
array([0.4       , 0.51428574, 0.6857143 , ..., 0.7714286 , 0.71428573,0.4857143 ], dtype=float32)>, 'video_duration': <tf.Tensor: shape=(1024,), dtype=float32, numpy=
array([0.8046324 , 0.19939578, 0.52970797, ..., 0.40584087, 0.5800604 ,0.15005036], dtype=float32)>, 'video_exp_cnt': <tf.Tensor: shape=(1024,), dtype=float32, numpy=
array([0.7175141 , 0.4858757 , 0.7627119 , ..., 0.69491524, 0.4519774 ,0.6384181 ], dtype=float32)>, 'video_exp_user_cnt': <tf.Tensor: shape=(1024,), dtype=float32, numpy=
array([0.85057473, 0.4827586 , 0.83908045, ..., 0.7011494 , 0.3908046 ,0.54022986], dtype=float32)>, 'video_id': <tf.Tensor: shape=(1024,), dtype=string, numpy=
array([b'HVgLcemGqaFAYgyEemtb', b'YNfPZPQwWggZRBkSsjMG',b'AvTonQbyvahPSCjsLvqN', ..., b'tvcZUdJBXAzJxsOZkXIc',b'HxnekvQEXBAgptCkNpXQ', b'RGumXWzhSqSoikFAZcWH'], dtype=object)>, 'video_name': <tf.Tensor: shape=(1024,), dtype=int64, numpy=array([15, 60,  1, ..., 25, 70, 47], dtype=int64)>, 'video_pay_cnt': <tf.Tensor: shape=(1024,), dtype=float32, numpy=array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)>, 'video_pay_user_cnt': <tf.Tensor: shape=(1024,), dtype=float32, numpy=array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)>, 'video_play_cnt': <tf.Tensor: shape=(1024,), dtype=float32, numpy=
array([0.6666667 , 0.5555556 , 0.7777778 , ..., 0.6666667 , 0.44444445,0.33333334], dtype=float32)>, 'video_play_user_cnt': <tf.Tensor: shape=(1024,), dtype=float32, numpy=
array([0.6666667 , 0.5555556 , 0.7777778 , ..., 0.6666667 , 0.44444445,0.33333334], dtype=float32)>, 'work': <tf.Tensor: shape=(1024,), dtype=int64, numpy=array([1, 2, 2, ..., 2, 3, 3], dtype=int64)>}, {'is_click': <tf.Tensor: shape=(1024,), dtype=float32, numpy=array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)>, 'is_play': <tf.Tensor: shape=(1024,), dtype=float32, numpy=array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)>, 'is_pay': <tf.Tensor: shape=(1024,), dtype=float32, numpy=array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)>})

4. fit函数与输入层,输出层,tf.keras.Model输入和输出的关系

  • fit函数使用dataset格式,输入为字典格式,假设tf.keras.Model中输入和输出为字典格式(2.2或2.3),dataset的key必须和2.2或2.3中字典的key一致,否则报错
  • fit函数使用dataset格式,输入为仍然是字典格式,假设tf.keras.Model中输入和输出为list格式(2.1),dataset的key必须和2.1涉及到的**输入层和输出层(1.1和1.2)**的层名一致,否则报错

这篇关于终于把tensorflow输入层和输出层搞懂了!fit函数与输入层,输出层,tf.keras.Model输入和输出的关系的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1049096

相关文章

一篇文章彻底搞懂macOS如何决定java环境

《一篇文章彻底搞懂macOS如何决定java环境》MacOS作为一个功能强大的操作系统,为开发者提供了丰富的开发工具和框架,下面:本文主要介绍macOS如何决定java环境的相关资料,文中通过代码... 目录方法一:使用 which命令方法二:使用 Java_home工具(Apple 官方推荐)那问题来了,

Python函数作用域与闭包举例深度解析

《Python函数作用域与闭包举例深度解析》Python函数的作用域规则和闭包是编程中的关键概念,它们决定了变量的访问和生命周期,:本文主要介绍Python函数作用域与闭包的相关资料,文中通过代码... 目录1. 基础作用域访问示例1:访问全局变量示例2:访问外层函数变量2. 闭包基础示例3:简单闭包示例4

Python中isinstance()函数原理解释及详细用法示例

《Python中isinstance()函数原理解释及详细用法示例》isinstance()是Python内置的一个非常有用的函数,用于检查一个对象是否属于指定的类型或类型元组中的某一个类型,它是Py... 目录python中isinstance()函数原理解释及详细用法指南一、isinstance()函数

python中的高阶函数示例详解

《python中的高阶函数示例详解》在Python中,高阶函数是指接受函数作为参数或返回函数作为结果的函数,下面:本文主要介绍python中高阶函数的相关资料,文中通过代码介绍的非常详细,需要的朋... 目录1.定义2.map函数3.filter函数4.reduce函数5.sorted函数6.自定义高阶函数

Python中的sort方法、sorted函数与lambda表达式及用法详解

《Python中的sort方法、sorted函数与lambda表达式及用法详解》文章对比了Python中list.sort()与sorted()函数的区别,指出sort()原地排序返回None,sor... 目录1. sort()方法1.1 sort()方法1.2 基本语法和参数A. reverse参数B.

Java 中的 equals 和 hashCode 方法关系与正确重写实践案例

《Java中的equals和hashCode方法关系与正确重写实践案例》在Java中,equals和hashCode方法是Object类的核心方法,广泛用于对象比较和哈希集合(如HashMa... 目录一、背景与需求分析1.1 equals 和 hashCode 的背景1.2 需求分析1.3 技术挑战1.4

一文详解MySQL索引(六张图彻底搞懂)

《一文详解MySQL索引(六张图彻底搞懂)》MySQL索引的建立对于MySQL的高效运行是很重要的,索引可以大大提高MySQL的检索速度,:本文主要介绍MySQL索引的相关资料,文中通过代码介绍的... 目录一、什么是索引?为什么需要索引?二、索引该用哪种数据结构?1. 哈希表2. 跳表3. 二叉排序树4.

Python函数的基本用法、返回值特性、全局变量修改及异常处理技巧

《Python函数的基本用法、返回值特性、全局变量修改及异常处理技巧》本文将通过实际代码示例,深入讲解Python函数的基本用法、返回值特性、全局变量修改以及异常处理技巧,感兴趣的朋友跟随小编一起看看... 目录一、python函数定义与调用1.1 基本函数定义1.2 函数调用二、函数返回值详解2.1 有返

Python Excel 通用筛选函数的实现

《PythonExcel通用筛选函数的实现》本文主要介绍了PythonExcel通用筛选函数的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着... 目录案例目的示例数据假定数据来源是字典优化:通用CSV数据处理函数使用说明使用示例注意事项案例目的第一

C++统计函数执行时间的最佳实践

《C++统计函数执行时间的最佳实践》在软件开发过程中,性能分析是优化程序的重要环节,了解函数的执行时间分布对于识别性能瓶颈至关重要,本文将分享一个C++函数执行时间统计工具,希望对大家有所帮助... 目录前言工具特性核心设计1. 数据结构设计2. 单例模式管理器3. RAII自动计时使用方法基本用法高级用法