深度学习Keras保存模型(当包含自定义层时)

2024-06-15 16:38

本文主要是介绍深度学习Keras保存模型(当包含自定义层时),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

20210409 -

0.引言

一般来说,如果没有什么特殊情况,那么在进行保存模型的时候,通过调用一些api进行保存即可,在文章《深度学习的基础知识与问题汇总》简单介绍了一种方式,直接保存模型,并重新载入。

但是在今天的实验中,出现的需求就是,如果定义的模型中,包含了用户自定义的层就会报错,具体情况见[1],在保存模型的时候没有问题,但是载入时就会报错。
简单说明一下实验环境

python 3.6.8
tensorflow-gpu 2.3.1
Keras 2.4.3

1. 加载模型报错:未定义层

在[1]中,如果没有对自定义的层进行一系列的规定,那么在加载模型的时候, 就会报错为定义层。

ValueError: Unknown layer: CustomLayer

上面这种问题使用的api是load_model,通过这种方式加载整个模型以及各种权值,针对这种错误,可以通过两种方法来解决。在问答[2]中都提到了,一种是在加载模型的时候,在api指定自定义类,如下:

new_model = tf.keras.models.load_model('model.h5', custom_objects={'CustomLayer': CustomLayer})

可能如果你使用了自定义的损失函数,也需要将这部分内容传输进去。
而另外一种方法更方便,针对如果定义了多个自定义内容:

import tensorflow as tf@tf.keras.utils.register_keras_serializable()
class CustomLayer(tf.keras.layers.Layer):def __init__(self, k, **kwargs):self.k = ksuper(CustomLayer, self).__init__(**kwargs)def get_config(self):config = super().get_config()config["k"] = self.kreturn configdef call(self, input):return tf.multiply(input, 2)

在自定义层上加上修饰器。

注意看,这里他自定义类的时候,上面这些要实现的函数。最重要的是,将自己的参数在函数get_config中进行保存。具体可以看[2]的方式,可以以他的编程方式作为模板,将模型保存起来,其中还包含了参数初始化的内容,而且可以看到在build过程中引入的add_weight不用在get_config中声明。

2. 保存权值

本次实验中,最后使用的方式是仅仅保存权值,反正逻辑上都得先定义这个模型。使用的方式是在训练的时候加入了保存模型的回调函数。

参考

[1]Saving Keras models with Custom Layers
[2]Not able to load a saved model with custom layer

这篇关于深度学习Keras保存模型(当包含自定义层时)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!


原文地址:
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.chinasem.cn/article/1064004

相关文章

深度解析Java @Serial 注解及常见错误案例

《深度解析Java@Serial注解及常见错误案例》Java14引入@Serial注解,用于编译时校验序列化成员,替代传统方式解决运行时错误,适用于Serializable类的方法/字段,需注意签... 目录Java @Serial 注解深度解析1. 注解本质2. 核心作用(1) 主要用途(2) 适用位置3

Java MCP 的鉴权深度解析

《JavaMCP的鉴权深度解析》文章介绍JavaMCP鉴权的实现方式,指出客户端可通过queryString、header或env传递鉴权信息,服务器端支持工具单独鉴权、过滤器集中鉴权及启动时鉴权... 目录一、MCP Client 侧(负责传递,比较简单)(1)常见的 mcpServers json 配置

Maven中生命周期深度解析与实战指南

《Maven中生命周期深度解析与实战指南》这篇文章主要为大家详细介绍了Maven生命周期实战指南,包含核心概念、阶段详解、SpringBoot特化场景及企业级实践建议,希望对大家有一定的帮助... 目录一、Maven 生命周期哲学二、default生命周期核心阶段详解(高频使用)三、clean生命周期核心阶

深度剖析SpringBoot日志性能提升的原因与解决

《深度剖析SpringBoot日志性能提升的原因与解决》日志记录本该是辅助工具,却为何成了性能瓶颈,SpringBoot如何用代码彻底破解日志导致的高延迟问题,感兴趣的小伙伴可以跟随小编一起学习一下... 目录前言第一章:日志性能陷阱的底层原理1.1 日志级别的“双刃剑”效应1.2 同步日志的“吞吐量杀手”

Unity新手入门学习殿堂级知识详细讲解(图文)

《Unity新手入门学习殿堂级知识详细讲解(图文)》Unity是一款跨平台游戏引擎,支持2D/3D及VR/AR开发,核心功能模块包括图形、音频、物理等,通过可视化编辑器与脚本扩展实现开发,项目结构含A... 目录入门概述什么是 UnityUnity引擎基础认知编辑器核心操作Unity 编辑器项目模式分类工程

聊聊springboot中如何自定义消息转换器

《聊聊springboot中如何自定义消息转换器》SpringBoot通过HttpMessageConverter处理HTTP数据转换,支持多种媒体类型,接下来通过本文给大家介绍springboot中... 目录核心接口springboot默认提供的转换器如何自定义消息转换器Spring Boot 中的消息

深度解析Python yfinance的核心功能和高级用法

《深度解析Pythonyfinance的核心功能和高级用法》yfinance是一个功能强大且易于使用的Python库,用于从YahooFinance获取金融数据,本教程将深入探讨yfinance的核... 目录yfinance 深度解析教程 (python)1. 简介与安装1.1 什么是 yfinance?

Python学习笔记之getattr和hasattr用法示例详解

《Python学习笔记之getattr和hasattr用法示例详解》在Python中,hasattr()、getattr()和setattr()是一组内置函数,用于对对象的属性进行操作和查询,这篇文章... 目录1.getattr用法详解1.1 基本作用1.2 示例1.3 原理2.hasattr用法详解2.

Python自定义异常的全面指南(入门到实践)

《Python自定义异常的全面指南(入门到实践)》想象你正在开发一个银行系统,用户转账时余额不足,如果直接抛出ValueError,调用方很难区分是金额格式错误还是余额不足,这正是Python自定义异... 目录引言:为什么需要自定义异常一、异常基础:先搞懂python的异常体系1.1 异常是什么?1.2

Linux中的自定义协议+序列反序列化用法

《Linux中的自定义协议+序列反序列化用法》文章探讨网络程序在应用层的实现,涉及TCP协议的数据传输机制、结构化数据的序列化与反序列化方法,以及通过JSON和自定义协议构建网络计算器的思路,强调分层... 目录一,再次理解协议二,序列化和反序列化三,实现网络计算器3.1 日志文件3.2Socket.hpp