Keras高级--构建复杂的自定义Losses和Metrics

2023-12-20 10:38

本文主要是介绍Keras高级--构建复杂的自定义Losses和Metrics,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

Advanced Keras — Constructing Complex Custom Losses and Metrics

    • 1. 背景
    • 2. 自定义loss函数
    • 3. 多参数loss/metric函数

翻译自:https://towardsdatascience.com/advanced-keras-constructing-complex-custom-losses-and-metrics-c07ca130a618
作者:Eyal Zakkay
转载请注明出处以及本文链接

本文将介绍一个简单技巧来在Keras中构建自定义loss函数,它可以接收除 y_truey_pred 之外的参数。

1. 背景

在Keras中编译模型时,我们为compile函数提供所需的loss与metrics。例如:

model.compile(loss=’mean_squared_error’, optimizer=’sgd’, metrics=‘acc’)

出于可读性目的,从现在开始关注loss函数。但是,大部分内容也适用于metrics。
从Keras的有关loss函数的文档中能得到:

你可以传递一个现有的损失函数名,或者一个 TensorFlow/Theano 符号函数。该符号函数为每个数据点返回一个标量,有以下两个参数:
y_true: 真实标签。TensorFlow/Theano 张量。
y_pred: 预测值。TensorFlow/Theano 张量,其 shape 与 y_true 相同。

所以如果我们想使用常用的loss函数,比如MSE或者 Categorical Cross-entropy,我们可以通过传入合适的参数来很容易实现。
Keras的文档中提供了可用loss和metrics的列表。

2. 自定义loss函数

当我们需要使用实现更复杂的loss函数或者metric时,就需要构造自定义函数并且传给model.compile
例如,构建自定义metric(来自Keras文档):

import keras.backend as Kdef mean_pred(y_true, y_pred):return K.mean(y_pred)model.compile(optimizer='rmsprop',loss='binary_crossentropy',metrics=['accuracy', mean_pred])

3. 多参数loss/metric函数

可能你已经注意到,loss函数必须只接受2个参数y_truey_pred,分别是target张量和模型output张量。但是如果我们希望loss/metric函数依赖于除这两个之外的其他张量呢?
为此,我们需要使用 function closure。首先创建一个loss函数(随意取参数)返回一个y_truey_pred的函数。
例如,如果我们想(由于某种原因)创建一个loss函数,其将第一层中所有activations的均方值添加到MSE:


# Build a model
inputs = Input(shape=(128,))
layer1 = Dense(64, activation='relu')(inputs)
layer2 = Dense(64, activation='relu')(layer1)
predictions = Dense(10, activation='softmax')(layer2)
model = Model(inputs=inputs, outputs=predictions)# Define custom loss
def custom_loss(layer):# Create a loss function that adds the MSE loss to the mean of all squared activations of a specific layerdef loss(y_true,y_pred):return K.mean(K.square(y_pred - y_true) + K.square(layer), axis=-1)# Return a functionreturn loss# Compile the model
model.compile(optimizer='adam',loss=custom_loss(layer), # Call the loss function with the selected layermetrics=['accuracy'])# train
model.fit(data, labels)  

注意,我们创建了一个函数(不限制参数的数量),它返回了一个合法的loss函数,该函数可以访问其封闭函数的参数。
一个更具体的例子:

假设设计一个变分自动编码器。希望模型能够从编码的潜在空间重建其输入。但还希望潜在空间中的编码(近似)正态分布。
前者的目标可以通过设计一个仅依赖于输入和期望输出y_truey_pred的重建损失来实现。对于后者,需要设计一个对潜在张量进行操作的损失项(例如,Kullback Leibler损失)。为了让你的loss函数能够访问这个中间张量,我们刚学到的技巧可以派上用场。
使用示例:

    def model_loss(self):"""" Wrapper function which calculates auxiliary values for the complete loss function.Returns a *function* which calculates the complete loss given only the input and target output """# KL losskl_loss = self.calculate_kl_loss# Reconstruction lossmd_loss_func = self.calculate_md_loss# KL weight (to be used by total loss and by annealing scheduler)self.kl_weight = K.variable(self.hps['kl_weight_start'], name='kl_weight')kl_weight = self.kl_weightdef seq2seq_loss(y_true, y_pred):""" Final loss calculation function to be passed to optimizer"""# Reconstruction lossmd_loss = md_loss_func(y_true, y_pred)# Full lossmodel_loss = kl_weight*kl_loss() + md_lossreturn model_lossreturn seq2seq_loss

此示例是Sequence to Sequence Variational Autoencoder模型的一部分,更多上下文和完整代码访问此

使用的Keras版本=2.2.4

Reference

[1]. Keras — Losses
[2]. Keras — Metrics
[3]. Github Issue — Passing additional arguments to objective function

Thanks to Ludovic Benistant.

这篇关于Keras高级--构建复杂的自定义Losses和Metrics的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/515734

相关文章

python panda库从基础到高级操作分析

《pythonpanda库从基础到高级操作分析》本文介绍了Pandas库的核心功能,包括处理结构化数据的Series和DataFrame数据结构,数据读取、清洗、分组聚合、合并、时间序列分析及大数据... 目录1. Pandas 概述2. 基本操作:数据读取与查看3. 索引操作:精准定位数据4. Group

使用Python构建智能BAT文件生成器的完美解决方案

《使用Python构建智能BAT文件生成器的完美解决方案》这篇文章主要为大家详细介绍了如何使用wxPython构建一个智能的BAT文件生成器,它不仅能够为Python脚本生成启动脚本,还提供了完整的文... 目录引言运行效果图项目背景与需求分析核心需求技术选型核心功能实现1. 数据库设计2. 界面布局设计3

深入浅出SpringBoot WebSocket构建实时应用全面指南

《深入浅出SpringBootWebSocket构建实时应用全面指南》WebSocket是一种在单个TCP连接上进行全双工通信的协议,这篇文章主要为大家详细介绍了SpringBoot如何集成WebS... 目录前言为什么需要 WebSocketWebSocket 是什么Spring Boot 如何简化 We

springboot自定义注解RateLimiter限流注解技术文档详解

《springboot自定义注解RateLimiter限流注解技术文档详解》文章介绍了限流技术的概念、作用及实现方式,通过SpringAOP拦截方法、缓存存储计数器,结合注解、枚举、异常类等核心组件,... 目录什么是限流系统架构核心组件详解1. 限流注解 (@RateLimiter)2. 限流类型枚举 (

SpringBoot 异常处理/自定义格式校验的问题实例详解

《SpringBoot异常处理/自定义格式校验的问题实例详解》文章探讨SpringBoot中自定义注解校验问题,区分参数级与类级约束触发的异常类型,建议通过@RestControllerAdvice... 目录1. 问题简要描述2. 异常触发1) 参数级别约束2) 类级别约束3. 异常处理1) 字段级别约束

Spring Boot Maven 插件如何构建可执行 JAR 的核心配置

《SpringBootMaven插件如何构建可执行JAR的核心配置》SpringBoot核心Maven插件,用于生成可执行JAR/WAR,内置服务器简化部署,支持热部署、多环境配置及依赖管理... 目录前言一、插件的核心功能与目标1.1 插件的定位1.2 插件的 Goals(目标)1.3 插件定位1.4 核

使用Python构建一个高效的日志处理系统

《使用Python构建一个高效的日志处理系统》这篇文章主要为大家详细讲解了如何使用Python开发一个专业的日志分析工具,能够自动化处理、分析和可视化各类日志文件,大幅提升运维效率,需要的可以了解下... 目录环境准备工具功能概述完整代码实现代码深度解析1. 类设计与初始化2. 日志解析核心逻辑3. 文件处

SpringBoot+EasyExcel实现自定义复杂样式导入导出

《SpringBoot+EasyExcel实现自定义复杂样式导入导出》这篇文章主要为大家详细介绍了SpringBoot如何结果EasyExcel实现自定义复杂样式导入导出功能,文中的示例代码讲解详细,... 目录安装处理自定义导出复杂场景1、列不固定,动态列2、动态下拉3、自定义锁定行/列,添加密码4、合并

使用Docker构建Python Flask程序的详细教程

《使用Docker构建PythonFlask程序的详细教程》在当今的软件开发领域,容器化技术正变得越来越流行,而Docker无疑是其中的佼佼者,本文我们就来聊聊如何使用Docker构建一个简单的Py... 目录引言一、准备工作二、创建 Flask 应用程序三、创建 dockerfile四、构建 Docker

Python中你不知道的gzip高级用法分享

《Python中你不知道的gzip高级用法分享》在当今大数据时代,数据存储和传输成本已成为每个开发者必须考虑的问题,Python内置的gzip模块提供了一种简单高效的解决方案,下面小编就来和大家详细讲... 目录前言:为什么数据压缩如此重要1. gzip 模块基础介绍2. 基本压缩与解压缩操作2.1 压缩文