Paddle 实现DCGAN

2024-05-10 23:28
文章标签 实现 paddle dcgan

本文主要是介绍Paddle 实现DCGAN,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

传统GAN

传统的GAN可以看我的这篇文章:Paddle 基于ANN(全连接神经网络)的GAN(生成对抗网络)实现-CSDN博客

DCGAN

DCGAN是适用于图像生成的GAN,它的特点是:

  • 只采用卷积层和转置卷积层,而不采用全连接层
  • 在每个卷积层或转置卷积层之间,插入一个批归一化层和ReLU激活函数

转置卷积层

转置卷积层执行的是转置卷积或反卷积的操作,即它是常规卷积层的反向操作。它接收一个低分辨率的输入,然后将其通过转置滤波器升采样到更高的分辨率。

对于一个卷积层,它的输出大小公式是:

o = \frac{i + 2p - k}{s} + 1

其中,o表示输出大小,i表示输入大小,p表示填充(padding),k表示卷积核大小(kernel_size),s表示步长(stride)。也就是说:输出大小 = (输入大小 - 卷积核大小 + 2 × 填充数) ÷ 步长 + 1

而对于一个转置卷积层,它的输出大小公式是:

o = s(i-1)-2p+k+u

 其中,o表示输出大小,i表示输入大小,p表示填充(padding),k表示反卷积核大小(kernel_size),s表示步长(stride),u表示输出填充(output padding)。也就是说:输出大小 = (输入大小 - 1) * 步长 - 2*填充 + 反卷积大小 + 输出填充

在paddle中,转置卷积层可以这么定义:

paddle.nn.Conv2DTranspose(in_channels, out_channels, kernel_size, stride, padding)

像卷积层一样,反卷积层的in_channels表示输入通道数(如形如(3, 32, 32)的图片张量的通道数就是3),out_channels表示输出通道数(如把(64, 32, 32)变成3通道的彩色图像(3, 32, 32))。 

代码实现

这里我们采用NWPU-RESISC45数据集,从中选择“freeway”(高速公路)作为训练数据,让机器生成高速公路的图片。这个训练数据内有700张256x256的图片,但由于我的电脑显存不足,因此将图片大小设置为64x64.

先写dataset.py:

import paddle
import numpy as np
from PIL import Image
import osdef getAllPath(path):return [os.path.join(path, f) for f in os.listdir(path)]class FreewayDataset(paddle.io.Dataset):def __init__(self, transform=None):super().__init__()self.data = []for path in getAllPath('./freeway'):img = Image.open(path)img = img.resize((64, 64))img = np.array(img, dtype=np.float32).transpose((2, 1, 0))if transform is not None:img = transform(img)self.data.append(img)self.data = np.array(self.data, dtype=np.float32)def __getitem__(self, idx):return self.data[idx]def __len__(self):return len(self.data)

然后写训练脚本:

from dataset import FreewayDataset
import paddle
from models import Generator, Discriminator
import numpy as npdataset = FreewayDataset()
dataloader = paddle.io.DataLoader(dataset, batch_size=32, shuffle=True)netG = Generator()
netD = Discriminator()if 1:try:mydict = paddle.load('generator.params')netG.set_dict(mydict)mydict = paddle.load('discriminator.params')netD.set_dict(mydict)except:print('fail to load model')loss = paddle.nn.BCELoss()optimizerD = paddle.optimizer.Adam(parameters=netD.parameters(), learning_rate=0.0002, beta1=0.5, beta2=0.999)
optimizerG = paddle.optimizer.Adam(parameters=netG.parameters(), learning_rate=0.0002, beta1=0.5, beta2=0.999)# 最大迭代epoch
max_epoch = 1000for epoch in range(max_epoch):now_step = 0for step, data in enumerate(dataloader):############################# (1) 更新鉴别器############################ 清除D的梯度optimizerD.clear_grad()# 传入正样本,并更新梯度pos_img = datalabel = paddle.full([pos_img.shape[0], 1, 1, 1], 1, dtype='float32')pre = netD(pos_img)loss_D_1 = loss(pre, label)loss_D_1.backward()# 通过randn构造随机数,制造负样本,并传入D,更新梯度noise = paddle.randn([pos_img.shape[0], 100, 1, 1], 'float32')neg_img = netG(noise)label = paddle.full([pos_img.shape[0], 1, 1, 1], 0, dtype='float32')pre = netD(neg_img.detach())  # 通过detach阻断网络梯度传播,不影响G的梯度计算loss_D_2 = loss(pre, label)loss_D_2.backward()# 更新D网络参数optimizerD.step()optimizerD.clear_grad()loss_D = loss_D_1 + loss_D_2############################# (2) 更新生成器############################ 清除D的梯度optimizerG.clear_grad()noise = paddle.randn([pos_img.shape[0], 100, 1, 1], 'float32')fake = netG(noise)label = paddle.full((pos_img.shape[0], 1, 1, 1), 1, dtype=np.float32, )output = netD(fake)# 这个写法没有问题,因为这个loss既会影响到netG(output=netD(netG(noise)))的梯度,也会影响到netD的梯度,但是之后的代码并没有更新netD的参数,而循环开头就清除了netD的梯度loss_G = loss(output, label)loss_G.backward()# 更新G网络参数optimizerG.step()optimizerG.clear_grad()now_step += 1############################ 输出日志###########################if now_step % 10 == 0:print(f'Epoch ID={epoch} Batch ID={now_step} \n\n D-Loss={float(loss_D)} G-Loss={float(loss_G)}')paddle.save(netG.state_dict(), "generator.params")
paddle.save(netD.state_dict(), "discriminator.params")

 最后编写图片生成脚本:

import paddle
from models import Generator
import matplotlib.pyplot as plt# 加载模型
netG = Generator()
mydict = paddle.load('generator.params')
netG.set_dict(mydict)# 设置matplotlib的显示环境
fig, axs = plt.subplots(nrows=2, ncols=5, figsize=(15, 6))  # 创建一个2x5的子图网格# 生成10个噪声向量
for i, ax in enumerate(axs.flatten()):noise = paddle.randn([1, 100, 1, 1], 'float32')img = netG(noise)img = img.numpy()[0].transpose((2, 1, 0))  # img.numpy():张量转np数组img[img < 0] = 0  # 将img中所有小于0的元素赋值为0# 显示图片ax.imshow(img)ax.axis('off')  # 不显示坐标轴# 显示图像
plt.show()

经过数次训练,最终的效果如下:

这样看来,至少有点高速公路的感觉了。 

参考

通过DCGAN实现人脸图像生成-使用文档-PaddlePaddle深度学习平台

卷积层和反卷积层输出特征图大小计算_输出特征图大小的计算方法-CSDN博客 

这篇关于Paddle 实现DCGAN的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/977863

相关文章

SpringBoot+RustFS 实现文件切片极速上传的实例代码

《SpringBoot+RustFS实现文件切片极速上传的实例代码》本文介绍利用SpringBoot和RustFS构建高性能文件切片上传系统,实现大文件秒传、断点续传和分片上传等功能,具有一定的参考... 目录一、为什么选择 RustFS + SpringBoot?二、环境准备与部署2.1 安装 RustF

Nginx部署HTTP/3的实现步骤

《Nginx部署HTTP/3的实现步骤》本文介绍了在Nginx中部署HTTP/3的详细步骤,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学... 目录前提条件第一步:安装必要的依赖库第二步:获取并构建 BoringSSL第三步:获取 Nginx

MyBatis Plus实现时间字段自动填充的完整方案

《MyBatisPlus实现时间字段自动填充的完整方案》在日常开发中,我们经常需要记录数据的创建时间和更新时间,传统的做法是在每次插入或更新操作时手动设置这些时间字段,这种方式不仅繁琐,还容易遗漏,... 目录前言解决目标技术栈实现步骤1. 实体类注解配置2. 创建元数据处理器3. 服务层代码优化填充机制详

Python实现Excel批量样式修改器(附完整代码)

《Python实现Excel批量样式修改器(附完整代码)》这篇文章主要为大家详细介绍了如何使用Python实现一个Excel批量样式修改器,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一... 目录前言功能特性核心功能界面特性系统要求安装说明使用指南基本操作流程高级功能技术实现核心技术栈关键函

Java实现字节字符转bcd编码

《Java实现字节字符转bcd编码》BCD是一种将十进制数字编码为二进制的表示方式,常用于数字显示和存储,本文将介绍如何在Java中实现字节字符转BCD码的过程,需要的小伙伴可以了解下... 目录前言BCD码是什么Java实现字节转bcd编码方法补充总结前言BCD码(Binary-Coded Decima

SpringBoot全局域名替换的实现

《SpringBoot全局域名替换的实现》本文主要介绍了SpringBoot全局域名替换的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一... 目录 项目结构⚙️ 配置文件application.yml️ 配置类AppProperties.Ja

Python实现批量CSV转Excel的高性能处理方案

《Python实现批量CSV转Excel的高性能处理方案》在日常办公中,我们经常需要将CSV格式的数据转换为Excel文件,本文将介绍一个基于Python的高性能解决方案,感兴趣的小伙伴可以跟随小编一... 目录一、场景需求二、技术方案三、核心代码四、批量处理方案五、性能优化六、使用示例完整代码七、小结一、

Java实现将HTML文件与字符串转换为图片

《Java实现将HTML文件与字符串转换为图片》在Java开发中,我们经常会遇到将HTML内容转换为图片的需求,本文小编就来和大家详细讲讲如何使用FreeSpire.DocforJava库来实现这一功... 目录前言核心实现:html 转图片完整代码场景 1:转换本地 HTML 文件为图片场景 2:转换 H

C#使用Spire.Doc for .NET实现HTML转Word的高效方案

《C#使用Spire.Docfor.NET实现HTML转Word的高效方案》在Web开发中,HTML内容的生成与处理是高频需求,然而,当用户需要将HTML页面或动态生成的HTML字符串转换为Wor... 目录引言一、html转Word的典型场景与挑战二、用 Spire.Doc 实现 HTML 转 Word1

C#实现一键批量合并PDF文档

《C#实现一键批量合并PDF文档》这篇文章主要为大家详细介绍了如何使用C#实现一键批量合并PDF文档功能,文中的示例代码简洁易懂,感兴趣的小伙伴可以跟随小编一起学习一下... 目录前言效果展示功能实现1、添加文件2、文件分组(书签)3、定义页码范围4、自定义显示5、定义页面尺寸6、PDF批量合并7、其他方法