Self-Attention Generative Adversarial Networks解读+部分代码

本文主要是介绍Self-Attention Generative Adversarial Networks解读+部分代码,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

     Self-Attention Generative Adversarial Networks解读+部分代码
 

引言

这篇是文章是Ian goodfellow他们的新工作,在GAN中引入Attention。
在文章的摘要中作者主要突出了三点。
Self-Attention Generative Adversarial Network(SAGAN)是一个注意力驱动,长范围 关联模型(attention-driven, long-range dependency modeling )。
传统的GAN在生成高分辨率的细节时,是基于低分辨率的feature map中的某一个小部分的。而SAGAN是基于所有的特征点(all feature locations).
在训练时使用了光谱归一化(spectral normalization )来提升训练强度(training dynamics)。

SAGAN的优势

  • 可以很好的处理长范围、多层次的依赖(可以很好的发现图像中的依赖关系)
  • 生成图像时每一个位置的细节和远端的细节协调好
  • 判别器还可以更准确地对全局图像结构实施复杂的几何约束

因为文章提到了long range 所以这里的远端,个人的理解是前几层卷积的output。

SAGAN

作者提到,大多数的GAN都使用了卷积,但是在处理long range依赖时,卷积的效率很低,所以他们采用了non-local model

x 被送入两个特征空间f,g去计算attention。

Bij 表示在生成第j个区域时,是否关注第i个位置。



上面是每个可学习矩阵的纬度,都是用1X1卷积实现的。



在文章的所有实验中都用到了上面这个超参。

之后再带权相加,得到融合了attention的feature map


γ的值初始化为0,这是因为在最开始,只需要依赖于局部信息,之后在慢慢增大权重加入non-local evidence.
在训练过程中还使用了光谱归一化(spectral normalization)和two-timescale update rule(TTUR)来稳定训练。

部分代码

attention 具体实现

    def attention(self, x, ch, sn=False, scope='attention', reuse=False):with tf.variable_scope(scope, reuse=reuse):f = conv(x, ch // 8, kernel=1, stride=1, sn=sn, scope='f_conv') # [bs, h, w, c']g = conv(x, ch // 8, kernel=1, stride=1, sn=sn, scope='g_conv') # [bs, h, w, c']h = conv(x, ch, kernel=1, stride=1, sn=sn, scope='h_conv') # [bs, h, w, c]# N = h * ws = tf.matmul(hw_flatten(g), hw_flatten(f), transpose_b=True) # # [bs, N, N]beta = tf.nn.softmax(s, axis=-1)  # attention mapo = tf.matmul(beta, hw_flatten(h)) # [bs, N, C]gamma = tf.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0))o = tf.reshape(o, shape=x.shape) # [bs, h, w, C]x = gamma * o + xreturn x

生成器

    def generator(self, z, is_training=True, reuse=False):with tf.variable_scope("generator", reuse=reuse):ch = 1024x = deconv(z, channels=ch, kernel=4, stride=1, padding='VALID', use_bias=False, sn=self.sn, scope='deconv')x = batch_norm(x, is_training, scope='batch_norm')x = relu(x)for i in range(self.layer_num // 2):if self.up_sample:x = up_sample(x, scale_factor=2)x = conv(x, channels=ch // 2, kernel=3, stride=1, pad=1, sn=self.sn, scope='up_conv_' + str(i))x = batch_norm(x, is_training, scope='batch_norm_' + str(i))x = relu(x)else:x = deconv(x, channels=ch // 2, kernel=4, stride=2, use_bias=False, sn=self.sn, scope='deconv_' + str(i))x = batch_norm(x, is_training, scope='batch_norm_' + str(i))x = relu(x)ch = ch // 2# Self Attentionx = self.attention(x, ch, sn=self.sn, scope="attention", reuse=reuse)for i in range(self.layer_num // 2, self.layer_num):if self.up_sample:x = up_sample(x, scale_factor=2)x = conv(x, channels=ch // 2, kernel=3, stride=1, pad=1, sn=self.sn, scope='up_conv_' + str(i))x = batch_norm(x, is_training, scope='batch_norm_' + str(i))x = relu(x)else:x = deconv(x, channels=ch // 2, kernel=4, stride=2, use_bias=False, sn=self.sn, scope='deconv_' + str(i))x = batch_norm(x, is_training, scope='batch_norm_' + str(i))x = relu(x)ch = ch // 2if self.up_sample:x = up_sample(x, scale_factor=2)x = conv(x, channels=self.c_dim, kernel=3, stride=1, pad=1, sn=self.sn, scope='G_conv_logit')x = tanh(x)else:x = deconv(x, channels=self.c_dim, kernel=4, stride=2, use_bias=False, sn=self.sn, scope='G_deconv_logit')x = tanh(x)return x

判别器

    def discriminator(self, x, is_training=True, reuse=False):with tf.variable_scope("discriminator", reuse=reuse):ch = 64x = conv(x, channels=ch, kernel=4, stride=2, pad=1, sn=self.sn, use_bias=False, scope='conv')x = lrelu(x, 0.2)for i in range(self.layer_num // 2):x = conv(x, channels=ch * 2, kernel=4, stride=2, pad=1, sn=self.sn, use_bias=False, scope='conv_' + str(i))x = batch_norm(x, is_training, scope='batch_norm' + str(i))x = lrelu(x, 0.2)ch = ch * 2# Self Attentionx = self.attention(x, ch, sn=self.sn, scope="attention", reuse=reuse)for i in range(self.layer_num // 2, self.layer_num):x = conv(x, channels=ch * 2, kernel=4, stride=2, pad=1, sn=self.sn, use_bias=False, scope='conv_' + str(i))x = batch_norm(x, is_training, scope='batch_norm' + str(i))x = lrelu(x, 0.2)ch = ch * 2x = conv(x, channels=4, stride=1, sn=self.sn, use_bias=False, scope='D_logit')return x

更多细节请参考SAGAN


上面贴的代码是 tensorflow版的没有用spectral normalization。
这个pytorch版使用了spectral normalization。
spectral normalization的具体实现可以看这里

 

【转载】:https://www.jianshu.com/p/0540fb554088

                   https://github.com/heykeetae/Self-Attention-GAN

                   https://github.com/taki0112/Self-Attention-GAN-Tensorflow

这篇关于Self-Attention Generative Adversarial Networks解读+部分代码的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1123753

相关文章

Python实例题之pygame开发打飞机游戏实例代码

《Python实例题之pygame开发打飞机游戏实例代码》对于python的学习者,能够写出一个飞机大战的程序代码,是不是感觉到非常的开心,:本文主要介绍Python实例题之pygame开发打飞机... 目录题目pygame-aircraft-game使用 Pygame 开发的打飞机游戏脚本代码解释初始化部

Java中Map.Entry()含义及方法使用代码

《Java中Map.Entry()含义及方法使用代码》:本文主要介绍Java中Map.Entry()含义及方法使用的相关资料,Map.Entry是Java中Map的静态内部接口,用于表示键值对,其... 目录前言 Map.Entry作用核心方法常见使用场景1. 遍历 Map 的所有键值对2. 直接修改 Ma

深入解析 Java Future 类及代码示例

《深入解析JavaFuture类及代码示例》JavaFuture是java.util.concurrent包中用于表示异步计算结果的核心接口,下面给大家介绍JavaFuture类及实例代码,感兴... 目录一、Future 类概述二、核心工作机制代码示例执行流程2. 状态机模型3. 核心方法解析行为总结:三

Nacos注册中心和配置中心的底层原理全面解读

《Nacos注册中心和配置中心的底层原理全面解读》:本文主要介绍Nacos注册中心和配置中心的底层原理的全面解读,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录临时实例和永久实例为什么 Nacos 要将服务实例分为临时实例和永久实例?1.x 版本和2.x版本的区别

python获取cmd环境变量值的实现代码

《python获取cmd环境变量值的实现代码》:本文主要介绍在Python中获取命令行(cmd)环境变量的值,可以使用标准库中的os模块,需要的朋友可以参考下... 前言全局说明在执行py过程中,总要使用到系统环境变量一、说明1.1 环境:Windows 11 家庭版 24H2 26100.4061

pandas实现数据concat拼接的示例代码

《pandas实现数据concat拼接的示例代码》pandas.concat用于合并DataFrame或Series,本文主要介绍了pandas实现数据concat拼接的示例代码,具有一定的参考价值,... 目录语法示例:使用pandas.concat合并数据默认的concat:参数axis=0,join=

C#代码实现解析WTGPS和BD数据

《C#代码实现解析WTGPS和BD数据》在现代的导航与定位应用中,准确解析GPS和北斗(BD)等卫星定位数据至关重要,本文将使用C#语言实现解析WTGPS和BD数据,需要的可以了解下... 目录一、代码结构概览1. 核心解析方法2. 位置信息解析3. 经纬度转换方法4. 日期和时间戳解析5. 辅助方法二、L

Python使用Code2flow将代码转化为流程图的操作教程

《Python使用Code2flow将代码转化为流程图的操作教程》Code2flow是一款开源工具,能够将代码自动转换为流程图,该工具对于代码审查、调试和理解大型代码库非常有用,在这篇博客中,我们将深... 目录引言1nVflRA、为什么选择 Code2flow?2、安装 Code2flow3、基本功能演示

C++类和对象之默认成员函数的使用解读

《C++类和对象之默认成员函数的使用解读》:本文主要介绍C++类和对象之默认成员函数的使用方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、默认成员函数有哪些二、各默认成员函数详解默认构造函数析构函数拷贝构造函数拷贝赋值运算符三、默认成员函数的注意事项总结一

IIS 7.0 及更高版本中的 FTP 状态代码

《IIS7.0及更高版本中的FTP状态代码》本文介绍IIS7.0中的FTP状态代码,方便大家在使用iis中发现ftp的问题... 简介尝试使用 FTP 访问运行 Internet Information Services (IIS) 7.0 或更高版本的服务器上的内容时,IIS 将返回指示响应状态的数字代