tensorflow2中自定义损失、传递loss函数字典/compile(optimizer=Adam(lr = lr), loss= lambda y_true, y_pred: y_pred)理解

本文主要是介绍tensorflow2中自定义损失、传递loss函数字典/compile(optimizer=Adam(lr = lr), loss= lambda y_true, y_pred: y_pred)理解,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在阅读yolov3代码的时候有下面这样一样代码:
model.compile(optimizer=Adam(lr = lr), loss={'yolo_loss': lambda y_true, y_pred: y_pred}),这行代码在网上有人进行解释过,但是都是看的云里雾里,一般使用compile的时候我们都是直接传递的一个函数对象,这里竟然传递的是一个字典,对此很是不解。


经过大量的饿查阅别人写的博客:最后在这篇博客中得到了答案的启发:链接,这篇文章 写的很好,大家可以去看看。


我在上面文章的基础上,会尽量使用简单的语言来描述这个函数的作用,并给出一个例子帮助大家进行理解。


因为这里是在compile模型,因此,要理解其原委,我们还需要到其模型中去看起所以然,进入模型定义中,我们会发现有下面这样一个loss的层定义:

    model_loss  = Lambda(get_yolo_loss(input_shape, len(model_body.output), num_classes), output_shape    = (1, ), name            = 'yolo_loss',)([*model_body.output, *y_true])

而且我们会发现,这里面给该Lambda层起了一个名字:yolo_loss,是的。你没有看错,就是和前面compile里面的loss的键值一样,这是巧合吗?然而当我将这个name进行修改成其他名字的时候,发现无法进行训练,因此,我们可以确定,这个name就是在comple中进行引用的键值。间接性的将,上面的loss引用的是这里的这个Lambda层。但是否是这样呢?我们在上面的那篇博客中可以得到答案,的确是这样

为了进一步的验证该猜想,我们自定义一个简单的层,然后将最后一层当做Loss层进行处理,及最后一层的输出是一个数,这个数既代表预测的结果,也用来表示函数的损失。

在这里我们定义一个简单的LSTM层来进行说明:

from tensorflow.keras.layers import *
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Input, Lambda
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input,Embedding,LSTM,Dense
import tensorflow as tfword_size = 128
nb_features = 10
nb_classes = 10
encode_size = 64
margin = 0.1embedding = Embedding(nb_features,word_size) # 对单词进行编码
lstm_encoder = LSTM(encode_size) # LSTM层进行定义def encode(input): # 定义一个函数,进行层的传播return lstm_encoder(embedding(input))q_input = Input(shape=(100,)) # 定义一个输入
q_encoded = Dense(encode_size)(q_encoded)  # 将LSTM层的输出放入全连接层进行整合loss = Lambda(lambda x: K.relu(0.001+x[0][:,1:2]+100),name="test_loss")([q_encoded]) # 随便写了一个算法 让第一个数据*0.001+100作为输出,然后让Dense层的输入通过该Lambda层,这一层也是最后一层,模型的整体组成请看下面model_train = Model(inputs=[q_input], outputs=loss) # 定义模型model_train.compile(optimizer='adam', loss={'test_loss':lambda y_true,y_pred: y_pred})# 对模型进行编译,这里也是本篇文章的重点,loss={'test_loss':lambda #y_true,y_pred: y_pred} 表示loss函数引用的是test_loss这个层,后面的两个#参数是tensorflow2中对loss进行重定义的标准输入,在这里表示直接输出预测#值。这样锁可能不太好理解,我们还可以将上面的compile换成下面这个形式:#model_train.compile(optimizer='adam', loss=lambda y_true,y_pred: y_pred)#这样是不是很好理解了呢?loss和之前的传递自定义函数是不是很向呢?想想在我们传递自定义loss函数的时候是怎么传递的,直接将一个函数对象赋给loss,是的,#这里的Lambda就是一个匿名对象,至于后面的参数这是标准的tensorflow自定义#loss必须要传递的链各个值: y_true,y_pred,不好理解的地方在于,这样不是直#接返回的y_predect嘛,是的,在Lambda函数中,我们要求函数直接返回预测值,#也就是这里的函数输出,这这个输出就是最后一层的输出,因此,通过这样定义,#我们即将最后一层当做输出,也将最后一层当做`loss`损失进行优化。t1 = tf.range(10) # 随便定义一个数据进行预测
y = tf.range(10) #  宿便定义一个输出,因为这里我们后面要进行优化,因此这个值随便定义。这里定义y只是为了瞒住fit的时候需要一个y值而已model_train.fit([t1], y, epochs=10) # 进行训练p = model_train.predict([5]) # 预测5这个数的lossprint(p) # 打印p的值

模型的摘要:
在这里插入图片描述

训练的输出:
在这里插入图片描述
可以看到这里训练10步之后输出也即loss为99.57左右,那么可以猜想我们的预测下一个值的输出也应该在99.57左右,因为我们的输出即做预测值使用,也做Loss使用,那到底是不是这样呢?
预测输出:
在这里插入图片描述
可以看到,这和我们的猜想是一样的,也验证了我们上面的说法。

这篇关于tensorflow2中自定义损失、传递loss函数字典/compile(optimizer=Adam(lr = lr), loss= lambda y_true, y_pred: y_pred)理解的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/533795

相关文章

Spring Security自定义身份认证的实现方法

《SpringSecurity自定义身份认证的实现方法》:本文主要介绍SpringSecurity自定义身份认证的实现方法,下面对SpringSecurity的这三种自定义身份认证进行详细讲解,... 目录1.内存身份认证(1)创建配置类(2)验证内存身份认证2.JDBC身份认证(1)数据准备 (2)配置依

Python的time模块一些常用功能(各种与时间相关的函数)

《Python的time模块一些常用功能(各种与时间相关的函数)》Python的time模块提供了各种与时间相关的函数,包括获取当前时间、处理时间间隔、执行时间测量等,:本文主要介绍Python的... 目录1. 获取当前时间2. 时间格式化3. 延时执行4. 时间戳运算5. 计算代码执行时间6. 转换为指

Spring 请求之传递 JSON 数据的操作方法

《Spring请求之传递JSON数据的操作方法》JSON就是一种数据格式,有自己的格式和语法,使用文本表示一个对象或数组的信息,因此JSON本质是字符串,主要负责在不同的语言中数据传递和交换,这... 目录jsON 概念JSON 语法JSON 的语法JSON 的两种结构JSON 字符串和 Java 对象互转

Python正则表达式语法及re模块中的常用函数详解

《Python正则表达式语法及re模块中的常用函数详解》这篇文章主要给大家介绍了关于Python正则表达式语法及re模块中常用函数的相关资料,正则表达式是一种强大的字符串处理工具,可以用于匹配、切分、... 目录概念、作用和步骤语法re模块中的常用函数总结 概念、作用和步骤概念: 本身也是一个字符串,其中

Java中的Lambda表达式及其应用小结

《Java中的Lambda表达式及其应用小结》Java中的Lambda表达式是一项极具创新性的特性,它使得Java代码更加简洁和高效,尤其是在集合操作和并行处理方面,:本文主要介绍Java中的La... 目录前言1. 什么是Lambda表达式?2. Lambda表达式的基本语法例子1:最简单的Lambda表

shell编程之函数与数组的使用详解

《shell编程之函数与数组的使用详解》:本文主要介绍shell编程之函数与数组的使用,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录shell函数函数的用法俩个数求和系统资源监控并报警函数函数变量的作用范围函数的参数递归函数shell数组获取数组的长度读取某下的

深入理解Apache Kafka(分布式流处理平台)

《深入理解ApacheKafka(分布式流处理平台)》ApacheKafka作为现代分布式系统中的核心中间件,为构建高吞吐量、低延迟的数据管道提供了强大支持,本文将深入探讨Kafka的核心概念、架构... 目录引言一、Apache Kafka概述1.1 什么是Kafka?1.2 Kafka的核心概念二、Ka

MySQL高级查询之JOIN、子查询、窗口函数实际案例

《MySQL高级查询之JOIN、子查询、窗口函数实际案例》:本文主要介绍MySQL高级查询之JOIN、子查询、窗口函数实际案例的相关资料,JOIN用于多表关联查询,子查询用于数据筛选和过滤,窗口函... 目录前言1. JOIN(连接查询)1.1 内连接(INNER JOIN)1.2 左连接(LEFT JOI

MySQL中FIND_IN_SET函数与INSTR函数用法解析

《MySQL中FIND_IN_SET函数与INSTR函数用法解析》:本文主要介绍MySQL中FIND_IN_SET函数与INSTR函数用法解析,本文通过实例代码给大家介绍的非常详细,感兴趣的朋友一... 目录一、功能定义与语法1、FIND_IN_SET函数2、INSTR函数二、本质区别对比三、实际场景案例分

C++ Sort函数使用场景分析

《C++Sort函数使用场景分析》sort函数是algorithm库下的一个函数,sort函数是不稳定的,即大小相同的元素在排序后相对顺序可能发生改变,如果某些场景需要保持相同元素间的相对顺序,可使... 目录C++ Sort函数详解一、sort函数调用的两种方式二、sort函数使用场景三、sort函数排序