(重要) tensorflow 2.0 functional API

2024-08-22 03:32

本文主要是介绍(重要) tensorflow 2.0 functional API,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

https://tensorflow.google.cn/guide/keras/functional?hl=zh_cn

1. 使用 functional API 流程

1. 创建输入节点

img_inputs = keras.Input(shape=(32, 32, 3))

数据的形状设置为 784 维向量。由于仅指定了每个样本的形状,因此始终忽略批次大小(batch size)。

2. 创建层

x = layers.Dense(64, activation="relu")(inputs)
x = layers.Dense(64, activation="relu")(x)
outputs = layers.Dense(10)(x)

3. 创建model

model = keras.Model(inputs=inputs, outputs=outputs, name="mnist_model")

4. 查看模型

model.summary()
keras.utils.plot_model(model, "my_first_model.png") ## 网络图
keras.utils.plot_model(model, "my_first_model_with_shape_info.png", show_shapes=True)

5. 保存模型

model.save会把,网络和weight,参数配置,优化器状态一起保存。

model.save("path_to_my_model")
del model
# Recreate the exact same model purely from the file:
model = keras.models.load_model("path_to_my_model")

2 复用结构和参数

使用keras.Model创建模块即可,这个模块即可用来训练也可以当做子模块。

2.1 结构和参数都共享

encoder_input = keras.Input(shape=(28, 28, 1), name="original_img")
x = layers.Conv2D(16, 3, activation="relu")(encoder_input)
x = layers.Conv2D(32, 3, activation="relu")(x)
x = layers.MaxPooling2D(3)(x)
x = layers.Conv2D(32, 3, activation="relu")(x)
x = layers.Conv2D(16, 3, activation="relu")(x)
encoder_output = layers.GlobalMaxPooling2D()(x)encoder = keras.Model(encoder_input, encoder_output, name="encoder")
encoder.summary()decoder_input = keras.Input(shape=(16,), name="encoded_img")
x = layers.Reshape((4, 4, 1))(decoder_input)
x = layers.Conv2DTranspose(16, 3, activation="relu")(x)
x = layers.Conv2DTranspose(32, 3, activation="relu")(x)
x = layers.UpSampling2D(3)(x)
x = layers.Conv2DTranspose(16, 3, activation="relu")(x)
decoder_output = layers.Conv2DTranspose(1, 3, activation="relu")(x)decoder = keras.Model(decoder_input, decoder_output, name="decoder")
decoder.summary()autoencoder_input = keras.Input(shape=(28, 28, 1), name="img")
encoded_img = encoder(autoencoder_input)
decoded_img = decoder(encoded_img)
autoencoder = keras.Model(autoencoder_input, decoded_img, name="autoencoder")
autoencoder.summary()

2.2 结构共享,参数不共享

如果像下面这种,那么网络参数就是不共享的,因为都重新创建一个keras.Model对象。

def get_model():inputs = keras.Input(shape=(128,))outputs = layers.Dense(1)(inputs)return keras.Model(inputs, outputs)model1 = get_model()
model2 = get_model()
model3 = get_model()inputs = keras.Input(shape=(128,))
y1 = model1(inputs)
y2 = model2(inputs)
y3 = model3(inputs)
outputs = layers.average([y1, y2, y3])
ensemble_model = keras.Model(inputs=inputs, outputs=outputs)

3. 构建多输入多输出

这样构建即可:

model = keras.Model(inputs=[title_input, body_input, tags_input],outputs=[priority_pred, department_pred],
)

4 共享embedding

# Embedding for 1000 unique words mapped to 128-dimensional vectors
shared_embedding = layers.Embedding(1000, 128)# Variable-length sequence of integers
text_input_a = keras.Input(shape=(None,), dtype="int32")# Variable-length sequence of integers
text_input_b = keras.Input(shape=(None,), dtype="int32")# Reuse the same layer to encode both inputs
encoded_input_a = shared_embedding(text_input_a)
encoded_input_b = shared_embedding(text_input_b)

5 提取和重用层计算图中的节点

## 下面是一个 VGG19 模型,其权重已在 ImageNet 上进行了预训练:
vgg19 = tf.keras.applications.VGG19()
##通过查询计算图数据结构获得的模型的中间激活:
features_list = [layer.output for layer in vgg19.layers]
## 使用以下特征来创建新的特征提取模型,该模型会返回中间层激活的值:
feat_extraction_model = keras.Model(inputs=vgg19.input, outputs=features_list)img = np.random.random((1, 224, 224, 3)).astype("float32")
extracted_features = feat_extraction_model(img)

6 自定义layer和重用layer的区别

  • 自定义layer通常是没有算子的情况下需要做

  • 而重用layer只需要定义keras.Model的对象即可。

这篇关于(重要) tensorflow 2.0 functional API的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1095119

相关文章

使用Go调用第三方API的方法详解

《使用Go调用第三方API的方法详解》在现代应用开发中,调用第三方API是非常常见的场景,比如获取天气预报、翻译文本、发送短信等,Go作为一门高效并发的编程语言,拥有强大的标准库和丰富的第三方库,可以... 目录引言一、准备工作二、案例1:调用天气查询 API1. 注册并获取 API Key2. 代码实现3

PHP应用中处理限流和API节流的最佳实践

《PHP应用中处理限流和API节流的最佳实践》限流和API节流对于确保Web应用程序的可靠性、安全性和可扩展性至关重要,本文将详细介绍PHP应用中处理限流和API节流的最佳实践,下面就来和小编一起学习... 目录限流的重要性在 php 中实施限流的最佳实践使用集中式存储进行状态管理(如 Redis)采用滑动

Go语言使用net/http构建一个RESTful API的示例代码

《Go语言使用net/http构建一个RESTfulAPI的示例代码》Go的标准库net/http提供了构建Web服务所需的强大功能,虽然众多第三方框架(如Gin、Echo)已经封装了很多功能,但... 目录引言一、什么是 RESTful API?二、实战目标:用户信息管理 API三、代码实现1. 用户数据

Python用Flask封装API及调用详解

《Python用Flask封装API及调用详解》本文介绍Flask的优势(轻量、灵活、易扩展),对比GET/POST表单/JSON请求方式,涵盖错误处理、开发建议及生产环境部署注意事项... 目录一、Flask的优势一、基础设置二、GET请求方式服务端代码客户端调用三、POST表单方式服务端代码客户端调用四

SpringBoot结合Knife4j进行API分组授权管理配置详解

《SpringBoot结合Knife4j进行API分组授权管理配置详解》在现代的微服务架构中,API文档和授权管理是不可或缺的一部分,本文将介绍如何在SpringBoot应用中集成Knife4j,并进... 目录环境准备配置 Swagger配置 Swagger OpenAPI自定义 Swagger UI 底

使用Python的requests库调用API接口的详细步骤

《使用Python的requests库调用API接口的详细步骤》使用Python的requests库调用API接口是开发中最常用的方式之一,它简化了HTTP请求的处理流程,以下是详细步骤和实战示例,涵... 目录一、准备工作:安装 requests 库二、基本调用流程(以 RESTful API 为例)1.

SpringBoot监控API请求耗时的6中解决解决方案

《SpringBoot监控API请求耗时的6中解决解决方案》本文介绍SpringBoot中记录API请求耗时的6种方案,包括手动埋点、AOP切面、拦截器、Filter、事件监听、Micrometer+... 目录1. 简介2.实战案例2.1 手动记录2.2 自定义AOP记录2.3 拦截器技术2.4 使用Fi

Olingo分析和实践之ODataImpl详细分析(重要方法详解)

《Olingo分析和实践之ODataImpl详细分析(重要方法详解)》ODataImpl.java是ApacheOlingoOData框架的核心工厂类,负责创建序列化器、反序列化器和处理器等组件,... 目录概述主要职责类结构与继承关系核心功能分析1. 序列化器管理2. 反序列化器管理3. 处理器管理重要方

Knife4j+Axios+Redis前后端分离架构下的 API 管理与会话方案(最新推荐)

《Knife4j+Axios+Redis前后端分离架构下的API管理与会话方案(最新推荐)》本文主要介绍了Swagger与Knife4j的配置要点、前后端对接方法以及分布式Session实现原理,... 目录一、Swagger 与 Knife4j 的深度理解及配置要点Knife4j 配置关键要点1.Spri

Python中Tensorflow无法调用GPU问题的解决方法

《Python中Tensorflow无法调用GPU问题的解决方法》文章详解如何解决TensorFlow在Windows无法识别GPU的问题,需降级至2.10版本,安装匹配CUDA11.2和cuDNN... 当用以下代码查看GPU数量时,gpuspython返回的是一个空列表,说明tensorflow没有找到