Keras高级--构建复杂的自定义Losses和Metrics

2023-12-20 10:38

本文主要是介绍Keras高级--构建复杂的自定义Losses和Metrics,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

Advanced Keras — Constructing Complex Custom Losses and Metrics

    • 1. 背景
    • 2. 自定义loss函数
    • 3. 多参数loss/metric函数

翻译自:https://towardsdatascience.com/advanced-keras-constructing-complex-custom-losses-and-metrics-c07ca130a618
作者:Eyal Zakkay
转载请注明出处以及本文链接

本文将介绍一个简单技巧来在Keras中构建自定义loss函数,它可以接收除 y_truey_pred 之外的参数。

1. 背景

在Keras中编译模型时,我们为compile函数提供所需的loss与metrics。例如:

model.compile(loss=’mean_squared_error’, optimizer=’sgd’, metrics=‘acc’)

出于可读性目的,从现在开始关注loss函数。但是,大部分内容也适用于metrics。
从Keras的有关loss函数的文档中能得到:

你可以传递一个现有的损失函数名,或者一个 TensorFlow/Theano 符号函数。该符号函数为每个数据点返回一个标量,有以下两个参数:
y_true: 真实标签。TensorFlow/Theano 张量。
y_pred: 预测值。TensorFlow/Theano 张量,其 shape 与 y_true 相同。

所以如果我们想使用常用的loss函数,比如MSE或者 Categorical Cross-entropy,我们可以通过传入合适的参数来很容易实现。
Keras的文档中提供了可用loss和metrics的列表。

2. 自定义loss函数

当我们需要使用实现更复杂的loss函数或者metric时,就需要构造自定义函数并且传给model.compile
例如,构建自定义metric(来自Keras文档):

import keras.backend as Kdef mean_pred(y_true, y_pred):return K.mean(y_pred)model.compile(optimizer='rmsprop',loss='binary_crossentropy',metrics=['accuracy', mean_pred])

3. 多参数loss/metric函数

可能你已经注意到,loss函数必须只接受2个参数y_truey_pred,分别是target张量和模型output张量。但是如果我们希望loss/metric函数依赖于除这两个之外的其他张量呢?
为此,我们需要使用 function closure。首先创建一个loss函数(随意取参数)返回一个y_truey_pred的函数。
例如,如果我们想(由于某种原因)创建一个loss函数,其将第一层中所有activations的均方值添加到MSE:


# Build a model
inputs = Input(shape=(128,))
layer1 = Dense(64, activation='relu')(inputs)
layer2 = Dense(64, activation='relu')(layer1)
predictions = Dense(10, activation='softmax')(layer2)
model = Model(inputs=inputs, outputs=predictions)# Define custom loss
def custom_loss(layer):# Create a loss function that adds the MSE loss to the mean of all squared activations of a specific layerdef loss(y_true,y_pred):return K.mean(K.square(y_pred - y_true) + K.square(layer), axis=-1)# Return a functionreturn loss# Compile the model
model.compile(optimizer='adam',loss=custom_loss(layer), # Call the loss function with the selected layermetrics=['accuracy'])# train
model.fit(data, labels)  

注意,我们创建了一个函数(不限制参数的数量),它返回了一个合法的loss函数,该函数可以访问其封闭函数的参数。
一个更具体的例子:

假设设计一个变分自动编码器。希望模型能够从编码的潜在空间重建其输入。但还希望潜在空间中的编码(近似)正态分布。
前者的目标可以通过设计一个仅依赖于输入和期望输出y_truey_pred的重建损失来实现。对于后者,需要设计一个对潜在张量进行操作的损失项(例如,Kullback Leibler损失)。为了让你的loss函数能够访问这个中间张量,我们刚学到的技巧可以派上用场。
使用示例:

    def model_loss(self):"""" Wrapper function which calculates auxiliary values for the complete loss function.Returns a *function* which calculates the complete loss given only the input and target output """# KL losskl_loss = self.calculate_kl_loss# Reconstruction lossmd_loss_func = self.calculate_md_loss# KL weight (to be used by total loss and by annealing scheduler)self.kl_weight = K.variable(self.hps['kl_weight_start'], name='kl_weight')kl_weight = self.kl_weightdef seq2seq_loss(y_true, y_pred):""" Final loss calculation function to be passed to optimizer"""# Reconstruction lossmd_loss = md_loss_func(y_true, y_pred)# Full lossmodel_loss = kl_weight*kl_loss() + md_lossreturn model_lossreturn seq2seq_loss

此示例是Sequence to Sequence Variational Autoencoder模型的一部分,更多上下文和完整代码访问此

使用的Keras版本=2.2.4

Reference

[1]. Keras — Losses
[2]. Keras — Metrics
[3]. Github Issue — Passing additional arguments to objective function

Thanks to Ludovic Benistant.

这篇关于Keras高级--构建复杂的自定义Losses和Metrics的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/515734

相关文章

如何自定义一个log适配器starter

《如何自定义一个log适配器starter》:本文主要介绍如何自定义一个log适配器starter的问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录需求Starter 项目目录结构pom.XML 配置LogInitializer实现MDCInterceptor

基于Python构建一个高效词汇表

《基于Python构建一个高效词汇表》在自然语言处理(NLP)领域,构建高效的词汇表是文本预处理的关键步骤,本文将解析一个使用Python实现的n-gram词频统计工具,感兴趣的可以了解下... 目录一、项目背景与目标1.1 技术需求1.2 核心技术栈二、核心代码解析2.1 数据处理函数2.2 数据处理流程

MySQL复杂SQL之多表联查/子查询详细介绍(最新整理)

《MySQL复杂SQL之多表联查/子查询详细介绍(最新整理)》掌握多表联查(INNERJOIN,LEFTJOIN,RIGHTJOIN,FULLJOIN)和子查询(标量、列、行、表子查询、相关/非相关、... 目录第一部分:多表联查 (JOIN Operations)1. 连接的类型 (JOIN Types)

Python FastMCP构建MCP服务端与客户端的详细步骤

《PythonFastMCP构建MCP服务端与客户端的详细步骤》MCP(Multi-ClientProtocol)是一种用于构建可扩展服务的通信协议框架,本文将使用FastMCP搭建一个支持St... 目录简介环境准备服务端实现(server.py)客户端实现(client.py)运行效果扩展方向常见问题结

详解如何使用Python构建从数据到文档的自动化工作流

《详解如何使用Python构建从数据到文档的自动化工作流》这篇文章将通过真实工作场景拆解,为大家展示如何用Python构建自动化工作流,让工具代替人力完成这些数字苦力活,感兴趣的小伙伴可以跟随小编一起... 目录一、Excel处理:从数据搬运工到智能分析师二、PDF处理:文档工厂的智能生产线三、邮件自动化:

详解如何使用Python从零开始构建文本统计模型

《详解如何使用Python从零开始构建文本统计模型》在自然语言处理领域,词汇表构建是文本预处理的关键环节,本文通过Python代码实践,演示如何从原始文本中提取多尺度特征,并通过动态调整机制构建更精确... 目录一、项目背景与核心思想二、核心代码解析1. 数据加载与预处理2. 多尺度字符统计3. 统计结果可

Apache 高级配置实战之从连接保持到日志分析的完整指南

《Apache高级配置实战之从连接保持到日志分析的完整指南》本文带你从连接保持优化开始,一路走到访问控制和日志管理,最后用AWStats来分析网站数据,对Apache配置日志分析相关知识感兴趣的朋友... 目录Apache 高级配置实战:从连接保持到日志分析的完整指南前言 一、Apache 连接保持 - 性

Druid连接池实现自定义数据库密码加解密功能

《Druid连接池实现自定义数据库密码加解密功能》在现代应用开发中,数据安全是至关重要的,本文将介绍如何在​​Druid​​连接池中实现自定义的数据库密码加解密功能,有需要的小伙伴可以参考一下... 目录1. 环境准备2. 密码加密算法的选择3. 自定义 ​​DruidDataSource​​ 的密码解密3

spring-gateway filters添加自定义过滤器实现流程分析(可插拔)

《spring-gatewayfilters添加自定义过滤器实现流程分析(可插拔)》:本文主要介绍spring-gatewayfilters添加自定义过滤器实现流程分析(可插拔),本文通过实例图... 目录需求背景需求拆解设计流程及作用域逻辑处理代码逻辑需求背景公司要求,通过公司网络代理访问的请求需要做请

mysql中的group by高级用法详解

《mysql中的groupby高级用法详解》MySQL中的GROUPBY是数据聚合分析的核心功能,主要用于将结果集按指定列分组,并结合聚合函数进行统计计算,本文给大家介绍mysql中的groupby... 目录一、基本语法与核心功能二、基础用法示例1. 单列分组统计2. 多列组合分组3. 与WHERE结合使