tensorflow2中自定义损失、传递loss函数字典/compile(optimizer=Adam(lr = lr), loss= lambda y_true, y_pred: y_pred)理解

本文主要是介绍tensorflow2中自定义损失、传递loss函数字典/compile(optimizer=Adam(lr = lr), loss= lambda y_true, y_pred: y_pred)理解,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在阅读yolov3代码的时候有下面这样一样代码:
model.compile(optimizer=Adam(lr = lr), loss={'yolo_loss': lambda y_true, y_pred: y_pred}),这行代码在网上有人进行解释过,但是都是看的云里雾里,一般使用compile的时候我们都是直接传递的一个函数对象,这里竟然传递的是一个字典,对此很是不解。


经过大量的饿查阅别人写的博客:最后在这篇博客中得到了答案的启发:链接,这篇文章 写的很好,大家可以去看看。


我在上面文章的基础上,会尽量使用简单的语言来描述这个函数的作用,并给出一个例子帮助大家进行理解。


因为这里是在compile模型,因此,要理解其原委,我们还需要到其模型中去看起所以然,进入模型定义中,我们会发现有下面这样一个loss的层定义:

    model_loss  = Lambda(get_yolo_loss(input_shape, len(model_body.output), num_classes), output_shape    = (1, ), name            = 'yolo_loss',)([*model_body.output, *y_true])

而且我们会发现,这里面给该Lambda层起了一个名字:yolo_loss,是的。你没有看错,就是和前面compile里面的loss的键值一样,这是巧合吗?然而当我将这个name进行修改成其他名字的时候,发现无法进行训练,因此,我们可以确定,这个name就是在comple中进行引用的键值。间接性的将,上面的loss引用的是这里的这个Lambda层。但是否是这样呢?我们在上面的那篇博客中可以得到答案,的确是这样

为了进一步的验证该猜想,我们自定义一个简单的层,然后将最后一层当做Loss层进行处理,及最后一层的输出是一个数,这个数既代表预测的结果,也用来表示函数的损失。

在这里我们定义一个简单的LSTM层来进行说明:

from tensorflow.keras.layers import *
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Input, Lambda
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input,Embedding,LSTM,Dense
import tensorflow as tfword_size = 128
nb_features = 10
nb_classes = 10
encode_size = 64
margin = 0.1embedding = Embedding(nb_features,word_size) # 对单词进行编码
lstm_encoder = LSTM(encode_size) # LSTM层进行定义def encode(input): # 定义一个函数,进行层的传播return lstm_encoder(embedding(input))q_input = Input(shape=(100,)) # 定义一个输入
q_encoded = Dense(encode_size)(q_encoded)  # 将LSTM层的输出放入全连接层进行整合loss = Lambda(lambda x: K.relu(0.001+x[0][:,1:2]+100),name="test_loss")([q_encoded]) # 随便写了一个算法 让第一个数据*0.001+100作为输出,然后让Dense层的输入通过该Lambda层,这一层也是最后一层,模型的整体组成请看下面model_train = Model(inputs=[q_input], outputs=loss) # 定义模型model_train.compile(optimizer='adam', loss={'test_loss':lambda y_true,y_pred: y_pred})# 对模型进行编译,这里也是本篇文章的重点,loss={'test_loss':lambda #y_true,y_pred: y_pred} 表示loss函数引用的是test_loss这个层,后面的两个#参数是tensorflow2中对loss进行重定义的标准输入,在这里表示直接输出预测#值。这样锁可能不太好理解,我们还可以将上面的compile换成下面这个形式:#model_train.compile(optimizer='adam', loss=lambda y_true,y_pred: y_pred)#这样是不是很好理解了呢?loss和之前的传递自定义函数是不是很向呢?想想在我们传递自定义loss函数的时候是怎么传递的,直接将一个函数对象赋给loss,是的,#这里的Lambda就是一个匿名对象,至于后面的参数这是标准的tensorflow自定义#loss必须要传递的链各个值: y_true,y_pred,不好理解的地方在于,这样不是直#接返回的y_predect嘛,是的,在Lambda函数中,我们要求函数直接返回预测值,#也就是这里的函数输出,这这个输出就是最后一层的输出,因此,通过这样定义,#我们即将最后一层当做输出,也将最后一层当做`loss`损失进行优化。t1 = tf.range(10) # 随便定义一个数据进行预测
y = tf.range(10) #  宿便定义一个输出,因为这里我们后面要进行优化,因此这个值随便定义。这里定义y只是为了瞒住fit的时候需要一个y值而已model_train.fit([t1], y, epochs=10) # 进行训练p = model_train.predict([5]) # 预测5这个数的lossprint(p) # 打印p的值

模型的摘要:
在这里插入图片描述

训练的输出:
在这里插入图片描述
可以看到这里训练10步之后输出也即loss为99.57左右,那么可以猜想我们的预测下一个值的输出也应该在99.57左右,因为我们的输出即做预测值使用,也做Loss使用,那到底是不是这样呢?
预测输出:
在这里插入图片描述
可以看到,这和我们的猜想是一样的,也验证了我们上面的说法。

这篇关于tensorflow2中自定义损失、传递loss函数字典/compile(optimizer=Adam(lr = lr), loss= lambda y_true, y_pred: y_pred)理解的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/533795

相关文章

MySQL常用字符串函数示例和场景介绍

《MySQL常用字符串函数示例和场景介绍》MySQL提供了丰富的字符串函数帮助我们高效地对字符串进行处理、转换和分析,本文我将全面且深入地介绍MySQL常用的字符串函数,并结合具体示例和场景,帮你熟练... 目录一、字符串函数概述1.1 字符串函数的作用1.2 字符串函数分类二、字符串长度与统计函数2.1

python使用try函数详解

《python使用try函数详解》Pythontry语句用于异常处理,支持捕获特定/多种异常、else/final子句确保资源释放,结合with语句自动清理,可自定义异常及嵌套结构,灵活应对错误场景... 目录try 函数的基本语法捕获特定异常捕获多个异常使用 else 子句使用 finally 子句捕获所

C++11右值引用与Lambda表达式的使用

《C++11右值引用与Lambda表达式的使用》C++11引入右值引用,实现移动语义提升性能,支持资源转移与完美转发;同时引入Lambda表达式,简化匿名函数定义,通过捕获列表和参数列表灵活处理变量... 目录C++11新特性右值引用和移动语义左值 / 右值常见的左值和右值移动语义移动构造函数移动复制运算符

springboot自定义注解RateLimiter限流注解技术文档详解

《springboot自定义注解RateLimiter限流注解技术文档详解》文章介绍了限流技术的概念、作用及实现方式,通过SpringAOP拦截方法、缓存存储计数器,结合注解、枚举、异常类等核心组件,... 目录什么是限流系统架构核心组件详解1. 限流注解 (@RateLimiter)2. 限流类型枚举 (

Java Spring的依赖注入理解及@Autowired用法示例详解

《JavaSpring的依赖注入理解及@Autowired用法示例详解》文章介绍了Spring依赖注入(DI)的概念、三种实现方式(构造器、Setter、字段注入),区分了@Autowired(注入... 目录一、什么是依赖注入(DI)?1. 定义2. 举个例子二、依赖注入的几种方式1. 构造器注入(Con

SpringBoot 异常处理/自定义格式校验的问题实例详解

《SpringBoot异常处理/自定义格式校验的问题实例详解》文章探讨SpringBoot中自定义注解校验问题,区分参数级与类级约束触发的异常类型,建议通过@RestControllerAdvice... 目录1. 问题简要描述2. 异常触发1) 参数级别约束2) 类级别约束3. 异常处理1) 字段级别约束

postgresql使用UUID函数的方法

《postgresql使用UUID函数的方法》本文给大家介绍postgresql使用UUID函数的方法,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录PostgreSQL有两种生成uuid的方法。可以先通过sql查看是否已安装扩展函数,和可以安装的扩展函数

MySQL字符串常用函数详解

《MySQL字符串常用函数详解》本文给大家介绍MySQL字符串常用函数,本文结合实例代码给大家介绍的非常详细,对大家学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录mysql字符串常用函数一、获取二、大小写转换三、拼接四、截取五、比较、反转、替换六、去空白、填充MySQL字符串常用函数一、

Python 字典 (Dictionary)使用详解

《Python字典(Dictionary)使用详解》字典是python中最重要,最常用的数据结构之一,它提供了高效的键值对存储和查找能力,:本文主要介绍Python字典(Dictionary)... 目录字典1.基本特性2.创建字典3.访问元素4.修改字典5.删除元素6.字典遍历7.字典的高级特性默认字典

C++中assign函数的使用

《C++中assign函数的使用》在C++标准模板库中,std::list等容器都提供了assign成员函数,它比操作符更灵活,支持多种初始化方式,下面就来介绍一下assign的用法,具有一定的参考价... 目录​1.assign的基本功能​​语法​2. 具体用法示例​​​(1) 填充n个相同值​​(2)