训练生成手写体数字 对抗神经网络

2024-01-02 01:44

本文主要是介绍训练生成手写体数字 对抗神经网络,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

下面是一个使用TensorFlow和Keras的生成对抗网络(GAN)的基本示例,用于生成手写体数字。这个示例基于MNIST数据集。

 

我没有包括所有可能的最佳实践,如模型保存、加载、超参数调整、日志记录等。

首先,确保你安装了所需的库,特别是TensorFlow:

pip install tensorflow

接下来是GAN的代码:

import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Dense, Flatten, Reshape, LeakyReLU, BatchNormalization
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
import numpy as np
import matplotlib.pyplot as plt# 加载MNIST数据集
(X_train, _), (_, _) = mnist.load_data()
X_train = X_train / 255.0 * 2 - 1  # 将像素值缩放到[-1, 1]# GAN参数
img_rows, img_cols, channels = 28, 28, 1
img_shape = (img_rows, img_cols, channels)
latent_dim = 100# 生成器
def build_generator():model = Sequential()model.add(Dense(256, input_dim=latent_dim))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(1024))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(np.prod(img_shape), activation='tanh'))model.add(Reshape(img_shape))return model# 判别器
def build_discriminator():model = Sequential()model.add(Flatten(input_shape=img_shape))model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(Dense(256))model.add(LeakyReLU(alpha=0.2))model.add(Dense(1, activation='sigmoid'))return model# 编译判别器
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])# 编译生成器
generator = build_generator()# 生成器输入噪声并生成图像
z = tf.keras.Input(shape=(latent_dim,))
img = generator(z)# 对于组合模型,我们只训练生成器
discriminator.trainable = False# 判别器尝试区分真实和生成的图像
valid = discriminator(img)# 组合模型(叠加生成器和判别器)
combined = tf.keras.Model(z, valid)
combined.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))# 训练GAN
def train(epochs, batch_size=128, save_interval=50):half_batch = batch_size // 2for epoch in range(epochs):# 随机选择一半的图像idx = np.random.randint(0, X_train.shape[0], half_batch)imgs = X_train[idx]# 生成一半的新图像noise = np.random.normal(0, 1, (half_batch, latent_dim))gen_imgs = generator.predict(noise)# 训练判别器d_loss_real = discriminator.train_on_batch(imgs, np.ones((half_batch, 1)))d_loss_fake = discriminator.train_on_batch(gen_imgs, np.zeros((half_batch, 1)))d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)# 训练生成器noise = np.random.normal(0, 1, (batch_size, latent_dim))valid_y = np.array([1] * batch_size)g_loss = combined.train_on_batch(noise, valid_y)# 打印进度print(f"{epoch} [D loss: {d_loss[0]}, acc.: {100 * d_loss[1]}] [G loss: {g_loss}]")# 如果到了保存间隔,则保存生成的图像样本if epoch % save_interval == 0:save_imgs(epoch)# 保存图像
def save_imgs(epoch):r, c = 5, 5noise = np.random.normal(0, 1, (r * c, latent_dim))gen_imgs = generator.predict(noise)# 缩放图片从[-1, 1]到[0, 1]gen_imgs = 0.5 * gen_imgs + 0.5fig, axs = plt.subplots(r, c)cnt = 0for i in range(r):for j in range(c):axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')axs[i, j].axis('off')cnt += 1fig.savefig(f"mnist_{epoch}.png")plt.close()# 训练GAN
train(epochs=10000, batch_size=32, save_interval=1000)

在这个代码中,首先定义了生成器和判别器的架构,然后将它们结合起来形成一个GAN网络。

train函数负责训练过程,它交替地训练判别器和生成器。save_imgs函数用于保存生成的图像,以便我们可以查看GAN在训练过程中的进步。

这是一个非常基础的GAN实现,对于实际应用,你可能需要进行很多调整和优化,包括更复杂的模型架构、更细致的训练过程控制、超参数调整等。

逐行注释的代码,解释了每一步的作用:

# 导入所需的库
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Dense, LeakyReLU, BatchNormalization, Reshape, Flatten
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf# 加载MNIST数据集
(X_train, _), (_, _) = mnist.load_data()# 数据预处理,将像素值归一化到[-1, 1]的范围
X_train = X_train / 127.5 - 1.0
# 增加一个维度,使图像有单通道,即形状为(batch_size, 28, 28, 1)
X_train = np.expand_dims(X_train, axis=-1)# 定义生成器模型
def build_generator():model = Sequential()model.add(Dense(256, input_shape=(100,)))  # 输入层,输入维度为100(噪声向量)model.add(LeakyReLU(alpha=0.2))  # 使用LeakyReLU激活函数model.add(BatchNormalization(momentum=0.8))  # 批量归一化model.add(Dense(512))  # 第二层,512个单元model.add(LeakyReLU(alpha=0.2))  # LeakyReLU激活函数model.add(BatchNormalization(momentum=0.8))  # 批量归一化model.add(Dense(1024))  # 第三层,1024个单元model.add(LeakyReLU(alpha=0.2))  # LeakyReLU激活函数model.add(BatchNormalization(momentum=0.8))  # 批量归一化model.add(Dense(np.prod((28, 28, 1)), activation='tanh'))  # 输出层,输出与图像像素数相同的单元数model.add(Reshape((28, 28, 1)))  # 将输出重塑为28x28图像return model# 定义判别器模型
def build_discriminator():model = Sequential()model.add(Flatten(input_shape=(28, 28, 1)))  # 输入层,将28x28图像展平model.add(Dense(512))  # 第二层,512个单元model.add(LeakyReLU(alpha=0.2))  # LeakyReLU激活函数model.add(Dense(256))  # 第三层,256个单元model.add(LeakyReLU(alpha=0.2))  # LeakyReLU激活函数model.add(Dense(1, activation='sigmoid'))  # 输出层,一个单元输出0到1之间的值return model# 编译判别器和生成器
discriminator = build_discriminator()
# 使用二元交叉熵作为损失函数,Adam优化器,以及准确度评估
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])
generator = build_generator()# GAN模型组合
z = tf.keras.Input(shape=(100,))  # 输入层,100维噪声向量
img = generator(z)  # 生成器生成图像
discriminator.trainable = False  # 在训练生成器时冻结判别器的权重
valid = discriminator(img)  # 判别器对生成的图像进行评估
combined = tf.keras.Model(z, valid)  # 组合模型,输入是噪声,输出是判别器的评估结果
combined.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))# 训练GAN
epochs = 10000  # 训练轮数
batch_size = 32  # 批量大小
save_interval = 1000  # 保存图片的间隔
noise_dim = 100  # 噪声向量的维度
half_batch = batch_size // 2  # 半批量大小
valid = np.ones((half_batch, 1))  # 真实图片标签
fake = np.zeros((half_batch, 1))  # 伪造图片标签for epoch in range(epochs):# 随机选择真实图片idx = np.random.randint(0, X_train.shape[0], half_batch)imgs = X_train[idx]# 生成噪声noise = np.random.normal(0, 1, (half_batch, noise_dim))# 使用噪声生成伪造图片gen_imgs = generator(noise, training=False)# 训练判别器d_loss_real = discriminator.train_on_batch(imgs, valid)d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)# 生成更多噪声noise = np.random.normal(0, 1, (batch_size, noise_dim))# 训练生成器g_loss = combined.train_on_batch(noise, np.ones((batch_size, 1)))# 如果达到保存间隔,打印损失并保存生成的图片if epoch % save_interval == 0:print("Epoch {}/{} [D loss: {:.4f}, acc.: {:.2f}%] [G loss: {:.4f}]".format(epoch, epochs, d_loss[0], 100 * d_loss[1], g_loss))save_imgs(generator, epoch, noise_dim)# 定义函数以保存生成的手写数字图像
def save_imgs(generator, epoch, noise_dim):r, c = 5, 5  # 生成5x5网格的图片noise = np.random.normal(0, 1, (r * c, noise_dim))  # 生成噪声gen_imgs = generator(noise, training=False)  # 使用噪声生成图片gen_imgs = 0.5 * gen_imgs + 0.5  # 将图片的像素值从[-1, 1]缩放到[0, 1]fig, axs = plt.subplots(r, c)  # 创建子图cnt = 0for i in range(r):for j in range(c):axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')  # 显示生成的图片axs[i, j].axis('off')  # 关闭坐标轴cnt += 1fig.savefig("mnist_%d.png" % epoch)  # 保存生成的图片plt.close()  # 关闭图形显示窗口# 选择性地保存生成器模型
generator.save('mnist_generator.h5')

这样的注释有助于理解代码的每一步,特别是对于初学者来说,可以更好地理解GAN的工作原理和实现细节。

版权所有 © 2023 王一帆。除非另有说明,本作品采用[知识共享 署名-非衍生作品 4.0 国际许可协议](https://creativecommons.org/licenses/by-nd/4.0/)进行许可。

这篇关于训练生成手写体数字 对抗神经网络的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/561028

相关文章

k8s admin用户生成token方式

《k8sadmin用户生成token方式》用户使用Kubernetes1.28创建admin命名空间并部署,通过ClusterRoleBinding为jenkins用户授权集群级权限,生成并获取其t... 目录k8s admin用户生成token创建一个admin的命名空间查看k8s namespace 的

Vue3 如何通过json配置生成查询表单

《Vue3如何通过json配置生成查询表单》本文给大家介绍Vue3如何通过json配置生成查询表单,本文结合实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录功能实现背景项目代码案例功能实现背景通过vue3实现后台管理项目一定含有表格功能,通常离不开表单

Java使用Javassist动态生成HelloWorld类

《Java使用Javassist动态生成HelloWorld类》Javassist是一个非常强大的字节码操作和定义库,它允许开发者在运行时创建新的类或者修改现有的类,本文将简单介绍如何使用Javass... 目录1. Javassist简介2. 环境准备3. 动态生成HelloWorld类3.1 创建CtC

Python从Word文档中提取图片并生成PPT的操作代码

《Python从Word文档中提取图片并生成PPT的操作代码》在日常办公场景中,我们经常需要从Word文档中提取图片,并将这些图片整理到PowerPoint幻灯片中,手动完成这一任务既耗时又容易出错,... 目录引言背景与需求解决方案概述代码解析代码核心逻辑说明总结引言在日常办公场景中,我们经常需要从 W

C#使用Spire.XLS快速生成多表格Excel文件

《C#使用Spire.XLS快速生成多表格Excel文件》在日常开发中,我们经常需要将业务数据导出为结构清晰的Excel文件,本文将手把手教你使用Spire.XLS这个强大的.NET组件,只需几行C#... 目录一、Spire.XLS核心优势清单1.1 性能碾压:从3秒到0.5秒的质变1.2 批量操作的优雅

Python使用python-pptx自动化操作和生成PPT

《Python使用python-pptx自动化操作和生成PPT》这篇文章主要为大家详细介绍了如何使用python-pptx库实现PPT自动化,并提供实用的代码示例和应用场景,感兴趣的小伙伴可以跟随小编... 目录使用python-pptx操作PPT文档安装python-pptx基础概念创建新的PPT文档查看

在ASP.NET项目中如何使用C#生成二维码

《在ASP.NET项目中如何使用C#生成二维码》二维码(QRCode)已广泛应用于网址分享,支付链接等场景,本文将以ASP.NET为示例,演示如何实现输入文本/URL,生成二维码,在线显示与下载的完整... 目录创建前端页面(Index.cshtml)后端二维码生成逻辑(Index.cshtml.cs)总结

Python实现数据可视化图表生成(适合新手入门)

《Python实现数据可视化图表生成(适合新手入门)》在数据科学和数据分析的新时代,高效、直观的数据可视化工具显得尤为重要,下面:本文主要介绍Python实现数据可视化图表生成的相关资料,文中通过... 目录前言为什么需要数据可视化准备工作基本图表绘制折线图柱状图散点图使用Seaborn创建高级图表箱线图热

基于Python实现数字限制在指定范围内的五种方式

《基于Python实现数字限制在指定范围内的五种方式》在编程中,数字范围限制是常见需求,无论是游戏开发中的角色属性值、金融计算中的利率调整,还是传感器数据处理中的异常值过滤,都需要将数字控制在合理范围... 目录引言一、基础条件判断法二、数学运算巧解法三、装饰器模式法四、自定义类封装法五、NumPy数组处理

SQLServer中生成雪花ID(Snowflake ID)的实现方法

《SQLServer中生成雪花ID(SnowflakeID)的实现方法》:本文主要介绍在SQLServer中生成雪花ID(SnowflakeID)的实现方法,文中通过示例代码介绍的非常详细,... 目录前言认识雪花ID雪花ID的核心特点雪花ID的结构(64位)雪花ID的优势雪花ID的局限性雪花ID的应用场景