使用变分编解码器实现自动图像生成

2024-04-30 22:08

本文主要是介绍使用变分编解码器实现自动图像生成,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

深度学习不仅仅在擅长于从现有数据中发现规律,而且它能主动运用规律创造出现实世界没有的实例来。例如给网络输入大量的人脸图片,让它识别人脸特征,然后我们可以指导网络创建出现实世界中不存在的人脸图像,把深度学习应用在创造性生成上是当前AI领域非常热门的应用。

从本节开始,我们将接触神经网络在图像生成方面的应用。有两种专门构建的网络在图像生成上能实现良好效果,一种网络叫变分编解码器,另一种叫生成型对抗性网络。这两种网络不仅仅能有与图片生成,还能用于音乐,声音,以及文本生成,但是在图像生成的效果上表现最好,因此接下来我们看看如何构建相应网络实现生成功能。

图像生成的关键思想是,使用网络构造一个向量空间,空间中每一个向量都可以映射成一张真实图片。在网络中有一个模块,读入该向量后,能够经过一系列运算把向量转换成一张图片所对应的二维向量,这个模块在编解码器网络里称为解码器。

编解码器网络的运行流畅如下:

屏幕快照 2019-02-15 下午5.44.51.png

首先我们把大量图片输入到网络中,网络识别图片并抽取图片中蕴含的规律,它把这些规律进行编码,以向量的形式存储,向量的长度越大,它就能存储越多的图片信息。接着网络的解码器模块解读编码向量,由于向量存储的是所有图片共同展现的人脸特征,而不是某个具体人的人脸特征,因此解码器解读编码向量后,就能根据向量蕴含的人脸特征进行绘图,最终构造出原来训练图片里没有的人脸图案,但这个人脸图案的特征与训练图片里面的人脸特征有相关性。

其实我们在前面章节已经接触过特征向量。在前面讲解单词向量时,所谓的单词向量就是一种特征向量。向量空间中,某个方向,也就是向量里面的某些分量可能记录了训练数据的某一方面的特征,对应人脸图片来说,向量可能有一部分分量用来记录笑容特征,某些分量可能记录了眼睛特征,某些分量可能记录了头发特征,所有这些特征综合起来就可能形成一张人脸。由于解码器能够识别向量中不同分量代表的信息,因此它把向量拆分解读之后,再按照向量分量表达的信息来绘制像素点,最终就可以完成一张人脸图片的绘制。

编解码器网络发明与2013和2014年,它能够把高维数据所展现的特征编码成低维向量,然后再把低维向量转换为原来数据所表示的高维向量。但这种还原并非原封不动的还原,而是把低维向量编码的信息展现出来。编解码网络有点像压缩和解压,把解码器模块把输入数据转变成另一种数据量较小的数据格式,而解码器再把该数据格式还原成输入数据,然而编解码器网络可不是简单的进行数据压缩和解压。

屏幕快照 2019-02-16 下午4.55.15.png

如上图,编解码器网络本质上是在学习输入图片像素点的统计信息,知道了像素点在统计上的分布规律后,它再按照相应的分布规律产生像素点,于是产生的图片与输入图片很像,但因为是根据统计规律随机产生的,因此生成的图片会产生某些变异。当我们把大量图片输入网络进行学习时,网络的编码器统计图片像素点变化的均值和方差,以及变化特征,这些特征编码成中间向量格式,然后解码器读取该向量,用随机方法把还原图片像素点的变化规律。

接下来我们看看代码的实现:

from keras.models import Model
from keras import layers
import numpy as np
import keras#输入图片为28*28的灰度图
img_shape = (28, 28, 1)
batch_size = 16
#将输入图片编码为只含有2个分量的向量
latent_dim = 2input_img = keras.Input(shape = img_shape)
#设计编码器部分
x = layers.Conv2D(32, 3, padding = 'same', activation = 'relu')(input_img)
x = layers.Conv2D(64, 3, padding = 'same', activation = 'relu', strides = (2,2))(x)
x = layers.Conv2D(64, 3, padding = 'same', activation = 'relu')(x)
x = layers.Conv2D(64, 3, padding = 'same', activation = 'relu')(x)shape = K.int_shape(x)
#把x压扁成一维向量
x = layers.Flatten()(x)
x = layers.Dense(32, activation = 'relu')(x)
#统计输入图片像素点统计规律上的均值
z_mean = layers.Dense(latent_dim)(x)
#统计输入图片像素点统计规律上的方差
z_log_var = layers.Dense(latent_dim)(x)'''
均值和方差决定了像素点的变化规律,在统计上发现,大量的事物在数据上的变化都遵守正太分布,一旦掌握
了其数值变化的方差和均值,我们就掌握了它变化的规律。在实现解码器时,我们也认为输入图片的像素点
同样符合正太分布,下面函数根据上面得到的均值和方差构造正太分布,然后从这个分布中进行抽样构成要
还原图片的像素点
'''
def  sampling(args):z_mean, z_log_var = args#构造一个随机值,然后使用它到给定正太分布中生成一个结果,这类似于丢一个骰子然后看点数epsilon = K.random_normal(shape=(K.shape(z_mean)[0], latent_dim),mean = 0, stddev = 1.)return  z_mean + K.exp(z_log_var) * epsilonz = layers.Lambda(sampling)([z_mean, z_log_var])

上面是编码器的实现,从这里我们看到,深度学习其本质并没有什么神奇的魔力,它本质是对大量的输入数据进行数理统计,由此就能掌握事物的变化规律。我们再看看解码器的实现:

#解码过程是对编码过程的逆运算
decoder_input = layers.Input(K.int_shape(z)[1:])
x = layers.Dense(np.prod(shape[1:]), activation = 'relu')(decoder_input)
#恢复为向量压扁前的格式
x = layers.Reshape(shape[1:])(x)
#对编码器的卷积运输进行逆操作
x = layers.Conv2DTranspose(32, 3, padding = 'same', activation = 'relu',strides = (2, 2))(x)
x = layers.Conv2D(1, 3, padding = 'same', activation = 'sigmoid')(x)
decoder = Model(decoder_input, x)z_decoded = decoder(z)

接下来我们设置网络的损失函数:

'''
我们定义网络的损失框架没有提供,因此我们自己动手写
'''
class  CustomVariationLayer(keras.layers.Layer):def  vae_loss(self, x, z_decoded):x = K.flatten(x)z_decoded = K.flatten(z_decoded)#计算生成二维数组与输入图片二维数组对应元素的差方和xent_loss = keras.metrics.binary_crossentropy(x, z_decoded)#计算网络生成像素点统计分布与输入图片像素点变化分布的差异k1_loss = -5e-4 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis = -1)return K.mean(xent_loss + k1_loss)def  call(self, inputs):x = inputs[0]z_decoded = inputs[1]loss = self.vae_loss(x, z_decoded)self.add_loss(loss, inputs = inputs)return xy = CustomVariationLayer()([input_img, z_decoded])

损失函数要计算两部分,一部分是网络解码得到的二维数组与输入图片二维数组对应元素的差方和,第二是网络构造的二维数组,其元素变化规律与输入图片元素变化规律在统计上的差异,也就是我们希望网络生成的二维数组,其元素变化在统计上的均值与方差和输入图片像素点在统计上的均值和方差要尽可能的小。这里涉及到数理统计方面的知识,不了解可以直接忽略掉。

最后我们看看网络的训练过程:

from keras.datasets import mnistvae = Model(input_img, y)
vae.compile(optimizer = 'rmsprop', loss = None)
vae.summary()(x_train, _), (x_test, y_test) = mnist.load_data()x_train = x_train.astype('float32') / 255.
x_train = x_train.reshape(x_train.shape + (1,))
x_test = x_test.astype('float32') / 255.
x_test = x_test.reshape(x_test.shape + (1,))vae.fit(x = x_train, y = None,shuffle = True,epochs = 10,batch_size = batch_size,validation_data = (x_test, None))

训练后,我们看看网络对输入图片的还原效果:

import matplotlib.pyp![
](https://upload-images.jianshu.io/upload_images/2849961-5d9959da5b4d37e8.png?imageMogr2/auto-orient/strip%7CimageView2/2/w/1240)
lot as plt
from scipy.stats import norm#一次呈现15*15个数字
n = 15
digit_size = 28
figure = np.zeros((digit_size * n, digit_size * n))grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
grid_y = norm.ppf(np.linspace(0.05, 0.95, n))for i, yi in enumerate(grid_x):for j, xi in enumerate(grid_y):z_sample = np.array([[xi, yi]])z_sample = np.tile(z_sample, batch_size).reshape(batch_size, 2)x_decoded = decoder.predict(z_sample, batch_size = batch_size)digit = x_decoded[0].reshape(digit_size, digit_size)figure[i * digit_size: (i+1) * digit_size,j * digit_size: (j+1) * digit_size] = digitplt.figure(figsize = (10, 10))
plt.imshow(figure, cmap = 'Greys_r')
plt.show()

上面代码运行后可以看到,网络学习了数字图片中像素点的分布规律后,按照规律构造还原会相应的数字图片,还原的图片与输入图片大致相同,但在细节上有些许差异:
1.png

本节的内容比较抽象,不好理解。因为它用到了很多数学知识,没有深厚的数学功底你很难掌握本节内容,这也是现在程序员很难转行到人工智能,特别是深度学习领域的根本原因,因为他们具备的是工程思维,而人工智能要求你具备深厚的数学基础以及科学研究思维,如果你理解不了本节内容不要紧,只要把代码敲一遍,看看结果,具有一个感性认识也就可以了。

更多内容,请点击进入csdn学院

更多技术信息,包括操作系统,编译器,面试算法,机器学习,人工智能,请关照我的公众号:
这里写图片描述

这篇关于使用变分编解码器实现自动图像生成的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/950047

相关文章

使用Java将各种数据写入Excel表格的操作示例

《使用Java将各种数据写入Excel表格的操作示例》在数据处理与管理领域,Excel凭借其强大的功能和广泛的应用,成为了数据存储与展示的重要工具,在Java开发过程中,常常需要将不同类型的数据,本文... 目录前言安装免费Java库1. 写入文本、或数值到 Excel单元格2. 写入数组到 Excel表格

redis中使用lua脚本的原理与基本使用详解

《redis中使用lua脚本的原理与基本使用详解》在Redis中使用Lua脚本可以实现原子性操作、减少网络开销以及提高执行效率,下面小编就来和大家详细介绍一下在redis中使用lua脚本的原理... 目录Redis 执行 Lua 脚本的原理基本使用方法使用EVAL命令执行 Lua 脚本使用EVALSHA命令

Python中pywin32 常用窗口操作的实现

《Python中pywin32常用窗口操作的实现》本文主要介绍了Python中pywin32常用窗口操作的实现,pywin32主要的作用是供Python开发者快速调用WindowsAPI的一个... 目录获取窗口句柄获取最前端窗口句柄获取指定坐标处的窗口根据窗口的完整标题匹配获取句柄根据窗口的类别匹配获取句

Java 中的 @SneakyThrows 注解使用方法(简化异常处理的利与弊)

《Java中的@SneakyThrows注解使用方法(简化异常处理的利与弊)》为了简化异常处理,Lombok提供了一个强大的注解@SneakyThrows,本文将详细介绍@SneakyThro... 目录1. @SneakyThrows 简介 1.1 什么是 Lombok?2. @SneakyThrows

在 Spring Boot 中实现异常处理最佳实践

《在SpringBoot中实现异常处理最佳实践》本文介绍如何在SpringBoot中实现异常处理,涵盖核心概念、实现方法、与先前查询的集成、性能分析、常见问题和最佳实践,感兴趣的朋友一起看看吧... 目录一、Spring Boot 异常处理的背景与核心概念1.1 为什么需要异常处理?1.2 Spring B

Python位移操作和位运算的实现示例

《Python位移操作和位运算的实现示例》本文主要介绍了Python位移操作和位运算的实现示例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一... 目录1. 位移操作1.1 左移操作 (<<)1.2 右移操作 (>>)注意事项:2. 位运算2.1

如何在 Spring Boot 中实现 FreeMarker 模板

《如何在SpringBoot中实现FreeMarker模板》FreeMarker是一种功能强大、轻量级的模板引擎,用于在Java应用中生成动态文本输出(如HTML、XML、邮件内容等),本文... 目录什么是 FreeMarker 模板?在 Spring Boot 中实现 FreeMarker 模板1. 环

Qt实现网络数据解析的方法总结

《Qt实现网络数据解析的方法总结》在Qt中解析网络数据通常涉及接收原始字节流,并将其转换为有意义的应用层数据,这篇文章为大家介绍了详细步骤和示例,感兴趣的小伙伴可以了解下... 目录1. 网络数据接收2. 缓冲区管理(处理粘包/拆包)3. 常见数据格式解析3.1 jsON解析3.2 XML解析3.3 自定义

使用Python和Pyecharts创建交互式地图

《使用Python和Pyecharts创建交互式地图》在数据可视化领域,创建交互式地图是一种强大的方式,可以使受众能够以引人入胜且信息丰富的方式探索地理数据,下面我们看看如何使用Python和Pyec... 目录简介Pyecharts 简介创建上海地图代码说明运行结果总结简介在数据可视化领域,创建交互式地

SpringMVC 通过ajax 前后端数据交互的实现方法

《SpringMVC通过ajax前后端数据交互的实现方法》:本文主要介绍SpringMVC通过ajax前后端数据交互的实现方法,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价... 在前端的开发过程中,经常在html页面通过AJAX进行前后端数据的交互,SpringMVC的controll