训练生成手写体数字 对抗神经网络

2024-01-02 01:44

本文主要是介绍训练生成手写体数字 对抗神经网络,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

下面是一个使用TensorFlow和Keras的生成对抗网络(GAN)的基本示例,用于生成手写体数字。这个示例基于MNIST数据集。

 

我没有包括所有可能的最佳实践,如模型保存、加载、超参数调整、日志记录等。

首先,确保你安装了所需的库,特别是TensorFlow:

pip install tensorflow

接下来是GAN的代码:

import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Dense, Flatten, Reshape, LeakyReLU, BatchNormalization
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
import numpy as np
import matplotlib.pyplot as plt# 加载MNIST数据集
(X_train, _), (_, _) = mnist.load_data()
X_train = X_train / 255.0 * 2 - 1  # 将像素值缩放到[-1, 1]# GAN参数
img_rows, img_cols, channels = 28, 28, 1
img_shape = (img_rows, img_cols, channels)
latent_dim = 100# 生成器
def build_generator():model = Sequential()model.add(Dense(256, input_dim=latent_dim))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(1024))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(np.prod(img_shape), activation='tanh'))model.add(Reshape(img_shape))return model# 判别器
def build_discriminator():model = Sequential()model.add(Flatten(input_shape=img_shape))model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(Dense(256))model.add(LeakyReLU(alpha=0.2))model.add(Dense(1, activation='sigmoid'))return model# 编译判别器
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])# 编译生成器
generator = build_generator()# 生成器输入噪声并生成图像
z = tf.keras.Input(shape=(latent_dim,))
img = generator(z)# 对于组合模型,我们只训练生成器
discriminator.trainable = False# 判别器尝试区分真实和生成的图像
valid = discriminator(img)# 组合模型(叠加生成器和判别器)
combined = tf.keras.Model(z, valid)
combined.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))# 训练GAN
def train(epochs, batch_size=128, save_interval=50):half_batch = batch_size // 2for epoch in range(epochs):# 随机选择一半的图像idx = np.random.randint(0, X_train.shape[0], half_batch)imgs = X_train[idx]# 生成一半的新图像noise = np.random.normal(0, 1, (half_batch, latent_dim))gen_imgs = generator.predict(noise)# 训练判别器d_loss_real = discriminator.train_on_batch(imgs, np.ones((half_batch, 1)))d_loss_fake = discriminator.train_on_batch(gen_imgs, np.zeros((half_batch, 1)))d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)# 训练生成器noise = np.random.normal(0, 1, (batch_size, latent_dim))valid_y = np.array([1] * batch_size)g_loss = combined.train_on_batch(noise, valid_y)# 打印进度print(f"{epoch} [D loss: {d_loss[0]}, acc.: {100 * d_loss[1]}] [G loss: {g_loss}]")# 如果到了保存间隔,则保存生成的图像样本if epoch % save_interval == 0:save_imgs(epoch)# 保存图像
def save_imgs(epoch):r, c = 5, 5noise = np.random.normal(0, 1, (r * c, latent_dim))gen_imgs = generator.predict(noise)# 缩放图片从[-1, 1]到[0, 1]gen_imgs = 0.5 * gen_imgs + 0.5fig, axs = plt.subplots(r, c)cnt = 0for i in range(r):for j in range(c):axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')axs[i, j].axis('off')cnt += 1fig.savefig(f"mnist_{epoch}.png")plt.close()# 训练GAN
train(epochs=10000, batch_size=32, save_interval=1000)

在这个代码中,首先定义了生成器和判别器的架构,然后将它们结合起来形成一个GAN网络。

train函数负责训练过程,它交替地训练判别器和生成器。save_imgs函数用于保存生成的图像,以便我们可以查看GAN在训练过程中的进步。

这是一个非常基础的GAN实现,对于实际应用,你可能需要进行很多调整和优化,包括更复杂的模型架构、更细致的训练过程控制、超参数调整等。

逐行注释的代码,解释了每一步的作用:

# 导入所需的库
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Dense, LeakyReLU, BatchNormalization, Reshape, Flatten
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf# 加载MNIST数据集
(X_train, _), (_, _) = mnist.load_data()# 数据预处理,将像素值归一化到[-1, 1]的范围
X_train = X_train / 127.5 - 1.0
# 增加一个维度,使图像有单通道,即形状为(batch_size, 28, 28, 1)
X_train = np.expand_dims(X_train, axis=-1)# 定义生成器模型
def build_generator():model = Sequential()model.add(Dense(256, input_shape=(100,)))  # 输入层,输入维度为100(噪声向量)model.add(LeakyReLU(alpha=0.2))  # 使用LeakyReLU激活函数model.add(BatchNormalization(momentum=0.8))  # 批量归一化model.add(Dense(512))  # 第二层,512个单元model.add(LeakyReLU(alpha=0.2))  # LeakyReLU激活函数model.add(BatchNormalization(momentum=0.8))  # 批量归一化model.add(Dense(1024))  # 第三层,1024个单元model.add(LeakyReLU(alpha=0.2))  # LeakyReLU激活函数model.add(BatchNormalization(momentum=0.8))  # 批量归一化model.add(Dense(np.prod((28, 28, 1)), activation='tanh'))  # 输出层,输出与图像像素数相同的单元数model.add(Reshape((28, 28, 1)))  # 将输出重塑为28x28图像return model# 定义判别器模型
def build_discriminator():model = Sequential()model.add(Flatten(input_shape=(28, 28, 1)))  # 输入层,将28x28图像展平model.add(Dense(512))  # 第二层,512个单元model.add(LeakyReLU(alpha=0.2))  # LeakyReLU激活函数model.add(Dense(256))  # 第三层,256个单元model.add(LeakyReLU(alpha=0.2))  # LeakyReLU激活函数model.add(Dense(1, activation='sigmoid'))  # 输出层,一个单元输出0到1之间的值return model# 编译判别器和生成器
discriminator = build_discriminator()
# 使用二元交叉熵作为损失函数,Adam优化器,以及准确度评估
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])
generator = build_generator()# GAN模型组合
z = tf.keras.Input(shape=(100,))  # 输入层,100维噪声向量
img = generator(z)  # 生成器生成图像
discriminator.trainable = False  # 在训练生成器时冻结判别器的权重
valid = discriminator(img)  # 判别器对生成的图像进行评估
combined = tf.keras.Model(z, valid)  # 组合模型,输入是噪声,输出是判别器的评估结果
combined.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))# 训练GAN
epochs = 10000  # 训练轮数
batch_size = 32  # 批量大小
save_interval = 1000  # 保存图片的间隔
noise_dim = 100  # 噪声向量的维度
half_batch = batch_size // 2  # 半批量大小
valid = np.ones((half_batch, 1))  # 真实图片标签
fake = np.zeros((half_batch, 1))  # 伪造图片标签for epoch in range(epochs):# 随机选择真实图片idx = np.random.randint(0, X_train.shape[0], half_batch)imgs = X_train[idx]# 生成噪声noise = np.random.normal(0, 1, (half_batch, noise_dim))# 使用噪声生成伪造图片gen_imgs = generator(noise, training=False)# 训练判别器d_loss_real = discriminator.train_on_batch(imgs, valid)d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)# 生成更多噪声noise = np.random.normal(0, 1, (batch_size, noise_dim))# 训练生成器g_loss = combined.train_on_batch(noise, np.ones((batch_size, 1)))# 如果达到保存间隔,打印损失并保存生成的图片if epoch % save_interval == 0:print("Epoch {}/{} [D loss: {:.4f}, acc.: {:.2f}%] [G loss: {:.4f}]".format(epoch, epochs, d_loss[0], 100 * d_loss[1], g_loss))save_imgs(generator, epoch, noise_dim)# 定义函数以保存生成的手写数字图像
def save_imgs(generator, epoch, noise_dim):r, c = 5, 5  # 生成5x5网格的图片noise = np.random.normal(0, 1, (r * c, noise_dim))  # 生成噪声gen_imgs = generator(noise, training=False)  # 使用噪声生成图片gen_imgs = 0.5 * gen_imgs + 0.5  # 将图片的像素值从[-1, 1]缩放到[0, 1]fig, axs = plt.subplots(r, c)  # 创建子图cnt = 0for i in range(r):for j in range(c):axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')  # 显示生成的图片axs[i, j].axis('off')  # 关闭坐标轴cnt += 1fig.savefig("mnist_%d.png" % epoch)  # 保存生成的图片plt.close()  # 关闭图形显示窗口# 选择性地保存生成器模型
generator.save('mnist_generator.h5')

这样的注释有助于理解代码的每一步,特别是对于初学者来说,可以更好地理解GAN的工作原理和实现细节。

版权所有 © 2023 王一帆。除非另有说明,本作品采用[知识共享 署名-非衍生作品 4.0 国际许可协议](https://creativecommons.org/licenses/by-nd/4.0/)进行许可。

这篇关于训练生成手写体数字 对抗神经网络的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/561028

相关文章

在ASP.NET项目中如何使用C#生成二维码

《在ASP.NET项目中如何使用C#生成二维码》二维码(QRCode)已广泛应用于网址分享,支付链接等场景,本文将以ASP.NET为示例,演示如何实现输入文本/URL,生成二维码,在线显示与下载的完整... 目录创建前端页面(Index.cshtml)后端二维码生成逻辑(Index.cshtml.cs)总结

Python实现数据可视化图表生成(适合新手入门)

《Python实现数据可视化图表生成(适合新手入门)》在数据科学和数据分析的新时代,高效、直观的数据可视化工具显得尤为重要,下面:本文主要介绍Python实现数据可视化图表生成的相关资料,文中通过... 目录前言为什么需要数据可视化准备工作基本图表绘制折线图柱状图散点图使用Seaborn创建高级图表箱线图热

基于Python实现数字限制在指定范围内的五种方式

《基于Python实现数字限制在指定范围内的五种方式》在编程中,数字范围限制是常见需求,无论是游戏开发中的角色属性值、金融计算中的利率调整,还是传感器数据处理中的异常值过滤,都需要将数字控制在合理范围... 目录引言一、基础条件判断法二、数学运算巧解法三、装饰器模式法四、自定义类封装法五、NumPy数组处理

SQLServer中生成雪花ID(Snowflake ID)的实现方法

《SQLServer中生成雪花ID(SnowflakeID)的实现方法》:本文主要介绍在SQLServer中生成雪花ID(SnowflakeID)的实现方法,文中通过示例代码介绍的非常详细,... 目录前言认识雪花ID雪花ID的核心特点雪花ID的结构(64位)雪花ID的优势雪花ID的局限性雪花ID的应用场景

Django HTTPResponse响应体中返回openpyxl生成的文件过程

《DjangoHTTPResponse响应体中返回openpyxl生成的文件过程》Django返回文件流时需通过Content-Disposition头指定编码后的文件名,使用openpyxl的sa... 目录Django返回文件流时使用指定文件名Django HTTPResponse响应体中返回openp

python生成随机唯一id的几种实现方法

《python生成随机唯一id的几种实现方法》在Python中生成随机唯一ID有多种方法,根据不同的需求场景可以选择最适合的方案,文中通过示例代码介绍的非常详细,需要的朋友们下面随着小编来一起学习学习... 目录方法 1:使用 UUID 模块(推荐)方法 2:使用 Secrets 模块(安全敏感场景)方法

Python实现自动化Word文档样式复制与内容生成

《Python实现自动化Word文档样式复制与内容生成》在办公自动化领域,高效处理Word文档的样式和内容复制是一个常见需求,本文将展示如何利用Python的python-docx库实现... 目录一、为什么需要自动化 Word 文档处理二、核心功能实现:样式与表格的深度复制1. 表格复制(含样式与内容)2

python如何生成指定文件大小

《python如何生成指定文件大小》:本文主要介绍python如何生成指定文件大小的实现方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录python生成指定文件大小方法一(速度最快)方法二(中等速度)方法三(生成可读文本文件–较慢)方法四(使用内存映射高效生成

Maven项目中集成数据库文档生成工具的操作步骤

《Maven项目中集成数据库文档生成工具的操作步骤》在Maven项目中,可以通过集成数据库文档生成工具来自动生成数据库文档,本文为大家整理了使用screw-maven-plugin(推荐)的完... 目录1. 添加插件配置到 pom.XML2. 配置数据库信息3. 执行生成命令4. 高级配置选项5. 注意事

MybatisX快速生成增删改查的方法示例

《MybatisX快速生成增删改查的方法示例》MybatisX是基于IDEA的MyBatis/MyBatis-Plus开发插件,本文主要介绍了MybatisX快速生成增删改查的方法示例,文中通过示例代... 目录1 安装2 基本功能2.1 XML跳转2.2 代码生成2.2.1 生成.xml中的sql语句头2