教你用TensorFlow和自编码器模型生成手写数字(附代码)

2024-04-14 01:38

本文主要是介绍教你用TensorFlow和自编码器模型生成手写数字(附代码),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!


来源:机器之心

本文长度为1876字,建议阅读4分钟

本文介绍了如何使用 TensorFlow 实现变分自编码器(VAE)模型,并通过简单的手写数字生成案例一步步引导读者实现这一强大的生成模型。


自编码器是一种能够用来学习对输入数据高效编码的神经网络。若给定一些输入,神经网络首先会使用一系列的变换来将数据映射到低维空间,这部分神经网络就被称为编码器。


然后,网络会使用被编码的低维数据去尝试重建输入,这部分网络称之为解码器。我们可以使用编码器将数据压缩为神经网络可以理解的类型。然而自编码器很少用做这个目的,因为通常存在比它更为有效的手工编写的算法(例如 jpg 压缩)。


此外,自编码器还被经常用来执行降噪任务,它能够学会如何重建原始图像。


什么是变分自编码器?


有很多与自编码器相关的有趣应用。


其中之一被称为变分自编码器(variational autoencoder)。使用变分自编码器不仅可以压缩数据--还能生成自编码器曾经遇到过的新对象。


使用通用自编码器的时候,我们根本不知道网络所生成的编码具体是什么。虽然我们可以对比不同的编码对象,但是要理解它内部编码的方式几乎是不可能的。这也就意味着我们不能使用编码器来生成新的对象。我们甚至连输入应该是什么样子的都不知道。


而我们用相反的方法使用变分自编码器。我们不会尝试着去关注隐含向量所服从的分布,只需要告诉网络我们想让这个分布转换为什么样子就行了。


通常情况,我们会限制网络来生成具有单位正态分布性质的隐含向量。然后,在尝试生成数据的时候,我们只需要从这种分布中进行采样,然后把样本喂给解码器就行,解码器会返回新的对象,看上去就和我们用来训练网络的对象一样。


下面我们将介绍如何使用 Python 和 TensorFlow 实现这一过程,我们要教会我们的网络来画 MNIST 字符。


第一步加载训练数据


首先我们来执行一些基本的导入操作。TensorFlow 具有非常便利的函数来让我们能够很容易地访问 MNIST 数据集。

import tensorflow as tfimport numpy as npimport matplotlib.pyplot as plt
%matplotlib inlinefrom tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data')

定义输入数据和输出数据


MNIST 图像的维度是 28*28 像素,只有单色通道。我们的输入数据 X_in 是一批一批的 MNIST 字符,网络会学习如何重建它们。然后在一个占位符 Y 中输出它们,输出和输入具有相同的维度。


Y_flat 将会在后面计算损失函数的时候用到,keep_prob 将会在应用 dropout 的时候用到(作为一种正则化的方法)。在训练的过程中,它的值会设为 0.8,当生成新数据的时候,我们不使用 dropout,所以它的值会变成 1。


lrelu 函数需要自及定义,因为 TensorFlow 中并没有预定义一个 Leaky ReLU 函数。

tf.reset_default_graph()batch_size = 64X_in = tf.placeholder(dtype=tf.float32, shape=[None, 28, 28], name='X')
Y    = tf.placeholder(dtype=tf.float32, shape=[None, 28, 28], name='Y')
Y_flat = tf.reshape(Y, shape=[-1, 28 * 28])
keep_prob = tf.placeholder(dtype=tf.float32, shape=(), name='keep_prob')dec_in_channels = 1n_latent = 8reshaped_dim = [-1, 7, 7, dec_in_channels]
inputs_decoder = 49 * dec_in_channels / 2def lrelu(x, alpha=0.3):    return tf.maximum(x, tf.multiply(x, alpha))


定义编码器


因为我们的输入是图像,所以使用一些卷积变换会更加合理。最值得注意的是我们在编码器中创建了两个向量,因为编码器应该创建服从高斯分布的对象。


  • 一个是均值向量

  • 一个是标准差向量



在后面你会看到,我们是如何「强制」编码器来保证它确实生成 了服从正态分布的数据点,我们可以把将会被输入到解码器中的编码值表示为 z。在计算损失函数的时候,我们会需要我们所选分布的均值和标准差。

def encoder(X_in, keep_prob):activation = lrelu    with tf.variable_scope("encoder", reuse=None):X = tf.reshape(X_in, shape=[-1, 28, 28, 1])x = tf.layers.conv2d(X, filters=64, kernel_size=4, strides=2, padding='same', activation=activation)x = tf.nn.dropout(x, keep_prob)x = tf.layers.conv2d(x, filters=64, kernel_size=4, strides=2, padding='same', activation=activation)x = tf.nn.dropout(x, keep_prob)x = tf.layers.conv2d(x, filters=64, kernel_size=4, strides=1, padding='same', activation=activation)x = tf.nn.dropout(x, keep_prob)x = tf.contrib.layers.flatten(x)mn = tf.layers.dense(x, units=n_latent)sd       = 0.5 * tf.layers.dense(x, units=n_latent)            epsilon = tf.random_normal(tf.stack([tf.shape(x)[0], n_latent])) z  = mn + tf.multiply(epsilon, tf.exp(sd))        return z, mn, sd


定义解码器


解码器不会关心输入值是不是从我们定义的某个特定分布中采样得到的。它仅仅会尝试重建输入图像。最后,我们使用了一系列的转置卷积(transpose convolution)。

def decoder(sampled_z, keep_prob):    with tf.variable_scope("decoder", reuse=None):x = tf.layers.dense(sampled_z, units=inputs_decoder, activation=lrelu)x = tf.layers.dense(x, units=inputs_decoder * 2 + 1, activation=lrelu)x = tf.reshape(x, reshaped_dim)x = tf.layers.conv2d_transpose(x, filters=64, kernel_size=4, strides=2, padding='same', activation=tf.nn.relu)x = tf.nn.dropout(x, keep_prob)x = tf.layers.conv2d_transpose(x, filters=64, kernel_size=4, strides=1, padding='same', activation=tf.nn.relu)x = tf.nn.dropout(x, keep_prob)x = tf.layers.conv2d_transpose(x, filters=64, kernel_size=4, strides=1, padding='same', activation=tf.nn.relu)x = tf.contrib.layers.flatten(x)x = tf.layers.dense(x, units=28*28, activation=tf.nn.sigmoid)img = tf.reshape(x, shape=[-1, 28, 28])        return img


现在,我们将两部分连在一起。

sampled, mn, sd = encoder(X_in, keep_prob)
dec = decoder(sampled, keep_prob)

计算损失函数,并实施一个高斯隐藏分布


为了计算图像重构的损失函数,我们简单地使用了平方差(这有时候会使图像变得有些模糊)。这个损失函数还结合了 KL 散度,这确保了我们的隐藏值将会从一个标准分布中采样。关于这个主题,如果想要了解更多,可以看一下这篇文章(https://jaan.io/what-is-variational-autoencoder-vae-tutorial/)

unreshaped = tf.reshape(dec, [-1, 28*28])
img_loss = tf.reduce_sum(tf.squared_difference(unreshaped, Y_flat), 1)
latent_loss = -0.5 * tf.reduce_sum(1.0 + 2.0 * sd - tf.square(mn) - tf.exp(2.0 * sd), 1)
loss = tf.reduce_mean(img_loss + latent_loss)
optimizer = tf.train.AdamOptimizer(0.0005).minimize(loss)
sess = tf.Session()
sess.run(tf.global_variables_initializer())


训练网络


现在我们终于可以训练我们的 VAE 了!


每隔 200 步,我们会看一下当前的重建是什么样子的。大约在处理了 2000 次迭代后,大多数重建看上去是挺合理的。

for i in range(30000):batch = [np.reshape(b, [28, 28]) for b in mnist.train.next_batch(batch_size=batch_size)[0]]sess.run(optimizer, feed_dict = {X_in: batch, Y: batch, keep_prob: 0.8})        if not i % 200:ls, d, i_ls, d_ls, mu, sigm = sess.run([loss, dec, img_loss, dst_loss, mn, sd], feed_dict = {X_in: batch, Y: batch, keep_prob: 1.0})plt.imshow(np.reshape(batch[0], [28, 28]), cmap='gray')plt.show()plt.imshow(d[0], cmap='gray')plt.show()print(i, ls, np.mean(i_ls), np.mean(d_ls))


生成新数据


最惊人的是我们现在可以生成新的字符了。最后,我们仅仅是从一个单位正态分布里面采集了一个值,输入到解码器。生成的大多数字符都和人类手写的是一样的。

randoms = [np.random.normal(0, 1, n_latent) for _ in range(10)]
imgs = sess.run(dec, feed_dict = {sampled: randoms, keep_prob: 1.0})
imgs = [np.reshape(imgs[i], [28, 28]) for i in range(len(imgs))]for img in imgs:plt.figure(figsize=(1,1))plt.axis('off')plt.imshow(img, cmap='gray')


一些自动生成的字符


总结


这是关于 VAE 应用一个相当简单的例子。但是可以想象一下更多的可能性!神经网络可以学习谱写音乐,它们可以自动地创建对书籍、游戏的描述。借用创新思维,VAE 可以为一些新颖的项目开创空间。


全部 VAE 代码:

https://github.com/FelixMohr/Deep-learning-with-Python/blob/master/VAE.ipynb

原文链接:

https://medium.com/towards-data-science/teaching-a-variational-autoencoder-vae-to-draw-mnist-characters-978675c95776


编辑:王璇

这篇关于教你用TensorFlow和自编码器模型生成手写数字(附代码)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!


原文地址:
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.chinasem.cn/article/901773

相关文章

MySQL进行数据库审计的详细步骤和示例代码

《MySQL进行数据库审计的详细步骤和示例代码》数据库审计通过触发器、内置功能及第三方工具记录和监控数据库活动,确保安全、完整与合规,Java代码实现自动化日志记录,整合分析系统提升监控效率,本文给大... 目录一、数据库审计的基本概念二、使用触发器进行数据库审计1. 创建审计表2. 创建触发器三、Java

python生成随机唯一id的几种实现方法

《python生成随机唯一id的几种实现方法》在Python中生成随机唯一ID有多种方法,根据不同的需求场景可以选择最适合的方案,文中通过示例代码介绍的非常详细,需要的朋友们下面随着小编来一起学习学习... 目录方法 1:使用 UUID 模块(推荐)方法 2:使用 Secrets 模块(安全敏感场景)方法

Java中调用数据库存储过程的示例代码

《Java中调用数据库存储过程的示例代码》本文介绍Java通过JDBC调用数据库存储过程的方法,涵盖参数类型、执行步骤及数据库差异,需注意异常处理与资源管理,以优化性能并实现复杂业务逻辑,感兴趣的朋友... 目录一、存储过程概述二、Java调用存储过程的基本javascript步骤三、Java调用存储过程示

Visual Studio 2022 编译C++20代码的图文步骤

《VisualStudio2022编译C++20代码的图文步骤》在VisualStudio中启用C++20import功能,需设置语言标准为ISOC++20,开启扫描源查找模块依赖及实验性标... 默认创建Visual Studio桌面控制台项目代码包含C++20的import方法。右键项目的属性:

MySQL数据库的内嵌函数和联合查询实例代码

《MySQL数据库的内嵌函数和联合查询实例代码》联合查询是一种将多个查询结果组合在一起的方法,通常使用UNION、UNIONALL、INTERSECT和EXCEPT关键字,下面:本文主要介绍MyS... 目录一.数据库的内嵌函数1.1聚合函数COUNT([DISTINCT] expr)SUM([DISTIN

Python中Tensorflow无法调用GPU问题的解决方法

《Python中Tensorflow无法调用GPU问题的解决方法》文章详解如何解决TensorFlow在Windows无法识别GPU的问题,需降级至2.10版本,安装匹配CUDA11.2和cuDNN... 当用以下代码查看GPU数量时,gpuspython返回的是一个空列表,说明tensorflow没有找到

Java实现自定义table宽高的示例代码

《Java实现自定义table宽高的示例代码》在桌面应用、管理系统乃至报表工具中,表格(JTable)作为最常用的数据展示组件,不仅承载对数据的增删改查,还需要配合布局与视觉需求,而JavaSwing... 目录一、项目背景详细介绍二、项目需求详细介绍三、相关技术详细介绍四、实现思路详细介绍五、完整实现代码

Go语言代码格式化的技巧分享

《Go语言代码格式化的技巧分享》在Go语言的开发过程中,代码格式化是一个看似细微却至关重要的环节,良好的代码格式化不仅能提升代码的可读性,还能促进团队协作,减少因代码风格差异引发的问题,Go在代码格式... 目录一、Go 语言代码格式化的重要性二、Go 语言代码格式化工具:gofmt 与 go fmt(一)

HTML5实现的移动端购物车自动结算功能示例代码

《HTML5实现的移动端购物车自动结算功能示例代码》本文介绍HTML5实现移动端购物车自动结算,通过WebStorage、事件监听、DOM操作等技术,确保实时更新与数据同步,优化性能及无障碍性,提升用... 目录1. 移动端购物车自动结算概述2. 数据存储与状态保存机制2.1 浏览器端的数据存储方式2.1.

基于 HTML5 Canvas 实现图片旋转与下载功能(完整代码展示)

《基于HTML5Canvas实现图片旋转与下载功能(完整代码展示)》本文将深入剖析一段基于HTML5Canvas的代码,该代码实现了图片的旋转(90度和180度)以及旋转后图片的下载... 目录一、引言二、html 结构分析三、css 样式分析四、JavaScript 功能实现一、引言在 Web 开发中,