pytorch 加权CE_loss实现(语义分割中的类不平衡使用)

2024-06-09 06:04

本文主要是介绍pytorch 加权CE_loss实现(语义分割中的类不平衡使用),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

加权CE_loss和BCE_loss稍有不同

1.标签为long类型,BCE标签为float类型
2.当reduction为mean时计算每个像素点的损失的平均,BCE除以像素数得到平均值,CE除以像素对应的权重之和得到平均值。
在这里插入图片描述

参数配置torch.nn.CrossEntropyLoss(weight=None,size_average=None,ignore_index=-100,reduce=None,reduction=‘mean’,label_smoothing=0.0)

增加加权的CE_loss代码实现

# 总之, CrossEntropyLoss() = softmax + log + NLLLoss() = log_softmax + NLLLoss(), 具体等价应用如下:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import numpy as npclass CrossEntropyLoss2d(nn.Module):def __init__(self, weight=None):super(CrossEntropyLoss2d, self).__init__()self.nll_loss = nn.CrossEntropyLoss(weight, reduction='mean')def forward(self, preds, targets):return self.nll_loss(preds, targets)

语义分割类别计算

class CE_w_loss(nn.Module):def __init__(self,ignore_index=255):super(CE_w_loss, self).__init__()self.ignore_index = ignore_index# self.CE = nn.CrossEntropyLoss(ignore_index=self.ignore_index)def forward(self, outputs, targets):class_num = outputs.shape[1]# print("class_num :",class_num )# # 计算每个类别在整个 batch 中的像素数占比class_pixel_counts = torch.bincount(targets.flatten(), minlength=class_num)  # 假设有class_num个类别class_pixel_proportions = class_pixel_counts.float() / torch.numel(targets)# # 根据类别占比计算权重class_weights = 1.0 / (torch.log(1.02 + class_pixel_proportions)).double()  # 使用对数变换平衡权重# # print("class_weights :",class_weights)## 定义交叉熵损失函数,并使用动态计算的类别权重criterion = nn.CrossEntropyLoss(ignore_index=self.ignore_index,weight= class_weights)# 计算损失loss = criterion(outputs, targets)print(loss.item())  # 打印损失值return lossnp.random.seed(666)pred = np.ones((2, 5, 256,256))seg = np.ones((2, 5, 256, 256)) # 灰度label = np.ones((2, 256, 256))  # 灰度pred = torch.from_numpy(pred)seg = torch.from_numpy(seg).int()  # 灰度label = torch.from_numpy(label).long()ce = CE_w_loss()loss = ce(pred, label)print("loss:",loss.item())

报错

Weight=torch.from_numpy(np.array([0.1, 0.8, 1.0, 1.0])).float() 报错
Weight=torch.from_numpy(np.array([0.1, 0.8, 1.0, 1.0])).double() 正确

参考:[1]https://blog.csdn.net/CSDN_of_ding/article/details/111515226
[2] https://blog.csdn.net/qq_40306845/article/details/137651442
[3] https://www.zhihu.com/question/400443029/answer/2477658229

这篇关于pytorch 加权CE_loss实现(语义分割中的类不平衡使用)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1044384

相关文章

Python使用Tenacity一行代码实现自动重试详解

《Python使用Tenacity一行代码实现自动重试详解》tenacity是一个专为Python设计的通用重试库,它的核心理念就是用简单、清晰的方式,为任何可能失败的操作添加重试能力,下面我们就来看... 目录一切始于一个简单的 API 调用Tenacity 入门:一行代码实现优雅重试精细控制:让重试按我

MySQL中EXISTS与IN用法使用与对比分析

《MySQL中EXISTS与IN用法使用与对比分析》在MySQL中,EXISTS和IN都用于子查询中根据另一个查询的结果来过滤主查询的记录,本文将基于工作原理、效率和应用场景进行全面对比... 目录一、基本用法详解1. IN 运算符2. EXISTS 运算符二、EXISTS 与 IN 的选择策略三、性能对比

Redis客户端连接机制的实现方案

《Redis客户端连接机制的实现方案》本文主要介绍了Redis客户端连接机制的实现方案,包括事件驱动模型、非阻塞I/O处理、连接池应用及配置优化,具有一定的参考价值,感兴趣的可以了解一下... 目录1. Redis连接模型概述2. 连接建立过程详解2.1 连php接初始化流程2.2 关键配置参数3. 最大连

Python实现网格交易策略的过程

《Python实现网格交易策略的过程》本文讲解Python网格交易策略,利用ccxt获取加密货币数据及backtrader回测,通过设定网格节点,低买高卖获利,适合震荡行情,下面跟我一起看看我们的第一... 网格交易是一种经典的量化交易策略,其核心思想是在价格上下预设多个“网格”,当价格触发特定网格时执行买

使用Python构建智能BAT文件生成器的完美解决方案

《使用Python构建智能BAT文件生成器的完美解决方案》这篇文章主要为大家详细介绍了如何使用wxPython构建一个智能的BAT文件生成器,它不仅能够为Python脚本生成启动脚本,还提供了完整的文... 目录引言运行效果图项目背景与需求分析核心需求技术选型核心功能实现1. 数据库设计2. 界面布局设计3

使用IDEA部署Docker应用指南分享

《使用IDEA部署Docker应用指南分享》本文介绍了使用IDEA部署Docker应用的四步流程:创建Dockerfile、配置IDEADocker连接、设置运行调试环境、构建运行镜像,并强调需准备本... 目录一、创建 dockerfile 配置文件二、配置 IDEA 的 Docker 连接三、配置 Do

Android Paging 分页加载库使用实践

《AndroidPaging分页加载库使用实践》AndroidPaging库是Jetpack组件的一部分,它提供了一套完整的解决方案来处理大型数据集的分页加载,本文将深入探讨Paging库... 目录前言一、Paging 库概述二、Paging 3 核心组件1. PagingSource2. Pager3.

python设置环境变量路径实现过程

《python设置环境变量路径实现过程》本文介绍设置Python路径的多种方法:临时设置(Windows用`set`,Linux/macOS用`export`)、永久设置(系统属性或shell配置文件... 目录设置python路径的方法临时设置环境变量(适用于当前会话)永久设置环境变量(Windows系统

python使用try函数详解

《python使用try函数详解》Pythontry语句用于异常处理,支持捕获特定/多种异常、else/final子句确保资源释放,结合with语句自动清理,可自定义异常及嵌套结构,灵活应对错误场景... 目录try 函数的基本语法捕获特定异常捕获多个异常使用 else 子句使用 finally 子句捕获所

C++11右值引用与Lambda表达式的使用

《C++11右值引用与Lambda表达式的使用》C++11引入右值引用,实现移动语义提升性能,支持资源转移与完美转发;同时引入Lambda表达式,简化匿名函数定义,通过捕获列表和参数列表灵活处理变量... 目录C++11新特性右值引用和移动语义左值 / 右值常见的左值和右值移动语义移动构造函数移动复制运算符