超分中的GAN总结:常用的判别器类型和GAN loss类型

2024-08-25 08:28

本文主要是介绍超分中的GAN总结:常用的判别器类型和GAN loss类型,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1. 概述

在真实数据超分任务上,从SRGAN开始,Loss函数基本是Pixel loss + GAN loss + Perceptual loss的组合。

与生成任务不同,对于超分这种复原任务,如果只使用Gan loss或者GAN loss的权重比较大的话,效果就比较差。

SRGAN成功的两个关键点:1. 引入了感知损失函数(Perceptual Loss),它是让生成图像产生细节的关键,而不是对抗损失函数。2. 将对抗损失函数的权重调小,让它不能影响训练的方向,只会微调生成图像的清晰度,消除感知损失函数带来的噪声。参见底层视觉之美。

在实践中,一般gan loss的权重设置为Pixel loss的千分之一。

2. 超分中的判别器

判别器一般来说有三种:

  • 分类网络 vgg,resnet等
    最后一层输出输出一个数字,代表整张图的判别结果
  • Patch gan
    最后一层不再输出一个数字,而是输出1xnxn的特征图,其中的每一个数字代表了原图中一个patch的判别结果;最后的loss通过对这nxn个点求均值得到;
  • U-Net discriminator with spectral normalization (SN).
    在Real ESRGAN中提出的,因为unet的输入分辨率和输出分辨率一致,相当于unet判别器对每个像素进行了判别,最后的loss求均值得到;引入spectral normalization 是为了稳定训练,同时可以消除一些artifacts;

3. 超分中的几种 Gan loss

3.1 Vanilla GAN

最原始的gan loss,判别器做的是二分类任务,判别器的最后输出经过sigmoid后计算交叉熵;一般用
self.loss = nn.BCEWithLogitsLoss()实现,其相当于sigmoid + 交叉熵;

3.2 LSGAN (最小平方gan)

不去算sigmoid和交叉熵,而是直接算判别器预测输出与真实标签值的MSE;一般用self.loss = nn.MSELoss()

3.3 WGAN loss

WGAN是对原始的GAN的改进,优化了其会发生梯度消失训练不稳定的问题,原始的GAN最小化生成器loss等价于最小化真实分布P_r与生成分布P_g之间的JS散度 → WGAN最小化真实分布P_r与生成分布P_g之间的Wasserstein距离;
具体来说,WGAN去掉了sigmoid, 同时也不再计算交叉熵,而是直接返回D(x)的均值。因为一般来说,都是最小化loss,对于真实样本直接输出-input.mean();对于生成样本,如果是优化生成器的时候,wgan loss为-input.mean(),如果是优化判别器,则输出input.mean();代码如下;

def wgan_loss(input, target):# target is booleanreturn -1 * input.mean() if target else input.mean()

3.4 RAGAN (相对Gan)

衡量的是真实数据比生成数据真实的概率,也就是说原始的GAN是将判别器的输出直接计算loss,而RAGAN会先计算真实样本的判别器输出和生成样本的判别器输出,做差值后再进行loss计算;比如生成器loss如下:
D_loss = self.D_lossfn_weight * (
self.D_lossfn(pred_d_real - torch.mean(pred_g_fake, 0, True), False) +
self.D_lossfn(pred_g_fake - torch.mean(pred_d_real, 0, True), True)) / 2

3.5 代码(来自KAIR)

  • Loss函数定义代码
    class GANLoss(nn.Module):def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0):super(GANLoss, self).__init__()self.gan_type = gan_type.lower()self.real_label_val = real_label_valself.fake_label_val = fake_label_val# 原始gan和ragan都是二分类if self.gan_type == 'gan' or self.gan_type == 'ragan':self.loss = nn.BCEWithLogitsLoss()elif self.gan_type == 'lsgan':self.loss = nn.MSELoss()elif self.gan_type == 'wgan':def wgan_loss(input, target):# target is booleanreturn -1 * input.mean() if target else input.mean()self.loss = wgan_losselif self.gan_type == 'softplusgan':def softplusgan_loss(input, target):# target is booleanreturn F.softplus(-input).mean() if target else F.softplus(input).mean()self.loss = softplusgan_losselse:raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type))def get_target_label(self, input, target_is_real):if self.gan_type in ['wgan', 'softplusgan']:return target_is_real# 返回标签,如果target_is_real为true,则返回全1的标签;如果为false则返回全0的标签if target_is_real:return torch.empty_like(input).fill_(self.real_label_val)else:return torch.empty_like(input).fill_(self.fake_label_val)def forward(self, input, target_is_real):target_label = self.get_target_label(input, target_is_real)loss = self.loss(input, target_label)return loss```
    
  • 判别器Loss计算代码
    if self.opt_train['gan_type'] in ['gan', 'lsgan', 'wgan', 'softplusgan']:# realpred_d_real = self.netD(self.H)                # 1) real datal_d_real = self.D_lossfn(pred_d_real, True)l_d_real.backward()# fakepred_d_fake = self.netD(self.E.detach().clone()) # 2) fake data, detach to avoid BP to Gl_d_fake = self.D_lossfn(pred_d_fake, False)l_d_fake.backward()
    elif self.opt_train['gan_type'] == 'ragan':# realpred_d_fake = self.netD(self.E).detach()       # 1) fake data, detach to avoid BP to Gpred_d_real = self.netD(self.H)                # 2) real datal_d_real = 0.5 * self.D_lossfn(pred_d_real - torch.mean(pred_d_fake, 0, True), True)l_d_real.backward()# fakepred_d_fake = self.netD(self.E.detach())l_d_fake = 0.5 * self.D_lossfn(pred_d_fake - torch.mean(pred_d_real.detach(), 0, True), False)l_d_fake.backward()```
    
  • 生成器loss计算代码
    if self.opt['train']['gan_type'] in ['gan', 'lsgan', 'wgan', 'softplusgan']:pred_g_fake = self.netD(self.E)D_loss = self.D_lossfn_weight * self.D_lossfn(pred_g_fake, True)
    elif self.opt['train']['gan_type'] == 'ragan':pred_d_real = self.netD(self.H).detach()pred_g_fake = self.netD(self.E)# 相对判别器D_loss = self.D_lossfn_weight * (self.D_lossfn(pred_d_real - torch.mean(pred_g_fake, 0, True), False) +self.D_lossfn(pred_g_fake - torch.mean(pred_d_real, 0, True), True)) / 2
    

这篇关于超分中的GAN总结:常用的判别器类型和GAN loss类型的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1105037

相关文章

MyBatis常用XML语法详解

《MyBatis常用XML语法详解》文章介绍了MyBatis常用XML语法,包括结果映射、查询语句、插入语句、更新语句、删除语句、动态SQL标签以及ehcache.xml文件的使用,感兴趣的朋友跟随小... 目录1、定义结果映射2、查询语句3、插入语句4、更新语句5、删除语句6、动态 SQL 标签7、ehc

Python版本与package版本兼容性检查方法总结

《Python版本与package版本兼容性检查方法总结》:本文主要介绍Python版本与package版本兼容性检查方法的相关资料,文中提供四种检查方法,分别是pip查询、conda管理、PyP... 目录引言为什么会出现兼容性问题方法一:用 pip 官方命令查询可用版本方法二:conda 管理包环境方法

pycharm跑python项目易出错的问题总结

《pycharm跑python项目易出错的问题总结》:本文主要介绍pycharm跑python项目易出错问题的相关资料,当你在PyCharm中运行Python程序时遇到报错,可以按照以下步骤进行排... 1. 一定不要在pycharm终端里面创建环境安装别人的项目子模块等,有可能出现的问题就是你不报错都安装

Python打包成exe常用的四种方法小结

《Python打包成exe常用的四种方法小结》本文主要介绍了Python打包成exe常用的四种方法,包括PyInstaller、cx_Freeze、Py2exe、Nuitka,文中通过示例代码介绍的非... 目录一.PyInstaller11.安装:2. PyInstaller常用参数下面是pyinstal

Python 常用数据类型详解之字符串、列表、字典操作方法

《Python常用数据类型详解之字符串、列表、字典操作方法》在Python中,字符串、列表和字典是最常用的数据类型,它们在数据处理、程序设计和算法实现中扮演着重要角色,接下来通过本文给大家介绍这三种... 目录一、字符串(String)(一)创建字符串(二)字符串操作1. 字符串连接2. 字符串重复3. 字

python语言中的常用容器(集合)示例详解

《python语言中的常用容器(集合)示例详解》Python集合是一种无序且不重复的数据容器,它可以存储任意类型的对象,包括数字、字符串、元组等,下面:本文主要介绍python语言中常用容器(集合... 目录1.核心内置容器1. 列表2. 元组3. 集合4. 冻结集合5. 字典2.collections模块

Python中logging模块用法示例总结

《Python中logging模块用法示例总结》在Python中logging模块是一个强大的日志记录工具,它允许用户将程序运行期间产生的日志信息输出到控制台或者写入到文件中,:本文主要介绍Pyt... 目录前言一. 基本使用1. 五种日志等级2.  设置报告等级3. 自定义格式4. C语言风格的格式化方法

JavaScript中比较两个数组是否有相同元素(交集)的三种常用方法

《JavaScript中比较两个数组是否有相同元素(交集)的三种常用方法》:本文主要介绍JavaScript中比较两个数组是否有相同元素(交集)的三种常用方法,每种方法结合实例代码给大家介绍的非常... 目录引言:为什么"相等"判断如此重要?方法1:使用some()+includes()(适合小数组)方法2

SpringBoot 获取请求参数的常用注解及用法

《SpringBoot获取请求参数的常用注解及用法》SpringBoot通过@RequestParam、@PathVariable等注解支持从HTTP请求中获取参数,涵盖查询、路径、请求体、头、C... 目录SpringBoot 提供了多种注解来方便地从 HTTP 请求中获取参数以下是主要的注解及其用法:1

Spring 依赖注入与循环依赖总结

《Spring依赖注入与循环依赖总结》这篇文章给大家介绍Spring依赖注入与循环依赖总结篇,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录1. Spring 三级缓存解决循环依赖1. 创建UserService原始对象2. 将原始对象包装成工