简单易上手的生成对抗网络

2024-09-02 06:12

本文主要是介绍简单易上手的生成对抗网络,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

模型原理

生成对抗网络是指一类采用对抗训练方式进行学习的深度生成模型,包含的判别网络生成网络都可以根据不同的生成任务使用不同的网络结构。

生成器: 通过机器生成数据,最终目的是骗过判别器。
判别器: 判断这张图像是真实的还是机器生成的,目的是找出生成器做的假数据。

构建GAN模型的基本逻辑: 现实问题需求→建立实现功能的GAN框架(编程)→训练GAN(生成网络、对抗网络)→成熟的GAN模型→应用。

GAN训练过程:
生成器生成假数据,然后将生成的假数据和真数据都输入判别器,判别器要判断出哪些是真的哪些是假的。判别器第一次判别出来的肯定有很大的误差,然后我们根据误差来优化判别器。现在判别器水平提高了,生成器生成的数据很难再骗过判别器了,所以我们得反过来优化生成器,之后生成器水平提高了,然后反过来继续训练判别器,判别器水平又提高了,再反过来训练生成器,就这样循环往复,直到达到纳什均衡。

GAN的发展历程

  1. GAN的基本思想起源于2014年,由伊恩·古德费洛等人首次提出。
  2. DCGAN,它在生成器和判别器中都使用了卷积层,取得了更好的图像生成效果。
  3. ConditionalGAN,通过引入条件信息指导生成器生成特定类型的数据。 Wasserstein
  4. WGAN使用Wasserstein距离作为损失函数,为GAN的训练提供了更稳定的优化方法,提高了生成样本的质量。

代码实现

DCGAN模型:

generator = Sequential()
generator.add(Dense(7 * 7 * 128, input_shape=[100]))
generator.add(Reshape([7, 7, 128]))
generator.add(BatchNormalization())
generator.add(Conv2DTranspose(64, kernel_size=5, strides=2, padding="same",activation="relu"))
generator.add(BatchNormalization())
generator.add(Conv2DTranspose(1, kernel_size=5, strides=2, padding="same",activation="tanh"))discriminator = Sequential()
discriminator.add(Conv2D(64, kernel_size=5, strides=2, padding="same",activation=LeakyReLU(0.3),input_shape=[28, 28, 1]))
discriminator.add(Dropout(0.5))
discriminator.add(Conv2D(128, kernel_size=5, strides=2, padding="same",activation=LeakyReLU(0.3)))
discriminator.add(Dropout(0.5))
discriminator.add(Flatten())
discriminator.add(Dense(1, activation="sigmoid"))

模型训练:

GAN =Sequential([generator,discriminator])
discriminator.compile(optimizer='adam',loss='binary_crossentropy')
discriminator.trainable = FalseGAN.compile(optimizer='adam',loss='binary_crossentropy')epochs = 150 
batch_size = 100
noise_shape=100with tf.device('/gpu:0'):for epoch in range(epochs):print(f"Currently on Epoch {epoch+1}")for i in range(X_train.shape[0]//batch_size):if (i+1)%50 == 0:print(f"\tCurrently on batch number {i+1} of {X_train.shape[0]//batch_size}")noise=np.random.normal(size=[batch_size,noise_shape])gen_image = generator.predict_on_batch(noise)train_dataset = X_train[i*batch_size:(i+1)*batch_size]train_label=np.ones(shape=(batch_size,1))discriminator.trainable = Trued_loss_real=discriminator.train_on_batch(train_dataset,train_label)train_label=np.zeros(shape=(batch_size,1))d_loss_fake=discriminator.train_on_batch(gen_image,train_label)noise=np.random.normal(size=[batch_size,noise_shape])train_label=np.ones(shape=(batch_size,1))discriminator.trainable = False #while training the generator as combined model,discriminator training should be turned offd_g_loss_batch =GAN.train_on_batch(noise, train_label)if epoch % 10 == 0:samples = 10x_fake = generator.predict(np.random.normal(loc=0, scale=1, size=(samples, 100)))for k in range(samples):plt.subplot(2, 5, k+1)plt.imshow(x_fake[k].reshape(28, 28), cmap='gray')plt.xticks([])plt.yticks([])plt.tight_layout()plt.show()print('Training is complete')

使用np.random.normal生成的噪声被作为输入给发生器:

noise=np.random.normal(loc=0, scale=1, size=(100,noise_shape))
gen_image = generator.predict(noise)
plt.imshow(noise)
plt.title('DCGAN Noise')

这篇关于简单易上手的生成对抗网络的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1129214

相关文章

Java使用Javassist动态生成HelloWorld类

《Java使用Javassist动态生成HelloWorld类》Javassist是一个非常强大的字节码操作和定义库,它允许开发者在运行时创建新的类或者修改现有的类,本文将简单介绍如何使用Javass... 目录1. Javassist简介2. 环境准备3. 动态生成HelloWorld类3.1 创建CtC

Debian 13升级后网络转发等功能异常怎么办? 并非错误而是管理机制变更

《Debian13升级后网络转发等功能异常怎么办?并非错误而是管理机制变更》很多朋友反馈,更新到Debian13后网络转发等功能异常,这并非BUG而是Debian13Trixie调整... 日前 Debian 13 Trixie 发布后已经有众多网友升级到新版本,只不过升级后发现某些功能存在异常,例如网络转

Python 基于http.server模块实现简单http服务的代码举例

《Python基于http.server模块实现简单http服务的代码举例》Pythonhttp.server模块通过继承BaseHTTPRequestHandler处理HTTP请求,使用Threa... 目录测试环境代码实现相关介绍模块简介类及相关函数简介参考链接测试环境win11专业版python

Python从Word文档中提取图片并生成PPT的操作代码

《Python从Word文档中提取图片并生成PPT的操作代码》在日常办公场景中,我们经常需要从Word文档中提取图片,并将这些图片整理到PowerPoint幻灯片中,手动完成这一任务既耗时又容易出错,... 目录引言背景与需求解决方案概述代码解析代码核心逻辑说明总结引言在日常办公场景中,我们经常需要从 W

C#使用Spire.XLS快速生成多表格Excel文件

《C#使用Spire.XLS快速生成多表格Excel文件》在日常开发中,我们经常需要将业务数据导出为结构清晰的Excel文件,本文将手把手教你使用Spire.XLS这个强大的.NET组件,只需几行C#... 目录一、Spire.XLS核心优势清单1.1 性能碾压:从3秒到0.5秒的质变1.2 批量操作的优雅

Python使用python-pptx自动化操作和生成PPT

《Python使用python-pptx自动化操作和生成PPT》这篇文章主要为大家详细介绍了如何使用python-pptx库实现PPT自动化,并提供实用的代码示例和应用场景,感兴趣的小伙伴可以跟随小编... 目录使用python-pptx操作PPT文档安装python-pptx基础概念创建新的PPT文档查看

python连接sqlite3简单用法完整例子

《python连接sqlite3简单用法完整例子》SQLite3是一个内置的Python模块,可以通过Python的标准库轻松地使用,无需进行额外安装和配置,:本文主要介绍python连接sqli... 目录1. 连接到数据库2. 创建游标对象3. 创建表4. 插入数据5. 查询数据6. 更新数据7. 删除

Jenkins的安装与简单配置过程

《Jenkins的安装与简单配置过程》本文简述Jenkins在CentOS7.3上安装流程,包括Java环境配置、RPM包安装、修改JENKINS_HOME路径及权限、启动服务、插件安装与系统管理设置... 目录www.chinasem.cnJenkins安装访问并配置JenkinsJenkins配置邮件通知

Python开发简易网络服务器的示例详解(新手入门)

《Python开发简易网络服务器的示例详解(新手入门)》网络服务器是互联网基础设施的核心组件,它本质上是一个持续运行的程序,负责监听特定端口,本文将使用Python开发一个简单的网络服务器,感兴趣的小... 目录网络服务器基础概念python内置服务器模块1. HTTP服务器模块2. Socket服务器模块

在ASP.NET项目中如何使用C#生成二维码

《在ASP.NET项目中如何使用C#生成二维码》二维码(QRCode)已广泛应用于网址分享,支付链接等场景,本文将以ASP.NET为示例,演示如何实现输入文本/URL,生成二维码,在线显示与下载的完整... 目录创建前端页面(Index.cshtml)后端二维码生成逻辑(Index.cshtml.cs)总结