Pytorch实战3:DCGAN深度卷积对抗生成网络生成动漫头像

2024-02-19 15:38

本文主要是介绍Pytorch实战3:DCGAN深度卷积对抗生成网络生成动漫头像,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

版权申明:本文章为本人原创内容,转载请注明出处,谢谢合作!

实验环境:

1.Pytorch 0.4.0
2.torchvision 0.2.1
3.Python 3.6
4.Win10+Pycharm
本项目是基于DCGAN的,代码是在《深度学习框架PyTorch:入门与实践》第七章的配套代码上做过大量修改过的。项目所用数据集获取:点击获取 提取码:g5qa,感谢知乎用户何之源爬取的数据。 请将下载的压缩包里的图片完整解压至data/face/目录下。整个项目的代码结构如下图:
这里写图片描述
其中data/face里是存放训练图片的,imgs/存放的是最终的训练结果,model.py是DCGAN的结构,train.py是主要的训练文件。

首先是,model.py:
import torch.nn as nn
# 定义生成器网络G
class NetG(nn.Module):def __init__(self, ngf, nz):super(NetG, self).__init__()# layer1输入的是一个100x1x1的随机噪声, 输出尺寸(ngf*8)x4x4self.layer1 = nn.Sequential(nn.ConvTranspose2d(nz, ngf * 8, kernel_size=4, stride=1, padding=0, bias=False),nn.BatchNorm2d(ngf * 8),nn.ReLU(inplace=True))# layer2输出尺寸(ngf*4)x8x8self.layer2 = nn.Sequential(nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf * 4),nn.ReLU(inplace=True))# layer3输出尺寸(ngf*2)x16x16self.layer3 = nn.Sequential(nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf * 2),nn.ReLU(inplace=True))# layer4输出尺寸(ngf)x32x32self.layer4 = nn.Sequential(nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf),nn.ReLU(inplace=True))# layer5输出尺寸 3x96x96self.layer5 = nn.Sequential(nn.ConvTranspose2d(ngf, 3, 5, 3, 1, bias=False),nn.Tanh())# 定义NetG的前向传播def forward(self, x):out = self.layer1(x)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = self.layer5(out)return out# 定义鉴别器网络D
class NetD(nn.Module):def __init__(self, ndf):super(NetD, self).__init__()# layer1 输入 3 x 96 x 96, 输出 (ndf) x 32 x 32self.layer1 = nn.Sequential(nn.Conv2d(3, ndf, kernel_size=5, stride=3, padding=1, bias=False),nn.BatchNorm2d(ndf),nn.LeakyReLU(0.2, inplace=True))# layer2 输出 (ndf*2) x 16 x 16self.layer2 = nn.Sequential(nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 2),nn.LeakyReLU(0.2, inplace=True))# layer3 输出 (ndf*4) x 8 x 8self.layer3 = nn.Sequential(nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 4),nn.LeakyReLU(0.2, inplace=True))# layer4 输出 (ndf*8) x 4 x 4self.layer4 = nn.Sequential(nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 8),nn.LeakyReLU(0.2, inplace=True))# layer5 输出一个数(概率)self.layer5 = nn.Sequential(nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),nn.Sigmoid())# 定义NetD的前向传播def forward(self,x):out = self.layer1(x)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = self.layer5(out)return out
然后是,train.py:
import argparse
import torch
import torchvision
import torchvision.utils as vutils
import torch.nn as nn
from random import randint
from model import NetD, NetGparser = argparse.ArgumentParser()
parser.add_argument('--batchSize', type=int, default=64)
parser.add_argument('--imageSize', type=int, default=96)
parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')
parser.add_argument('--ngf', type=int, default=64)
parser.add_argument('--ndf', type=int, default=64)
parser.add_argument('--epoch', type=int, default=25, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--data_path', default='data/', help='folder to train data')
parser.add_argument('--outf', default='imgs/', help='folder to output images and model checkpoints')
opt = parser.parse_args()
# 定义是否使用GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")#图像读入与预处理
transforms = torchvision.transforms.Compose([torchvision.transforms.Scale(opt.imageSize),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ])dataset = torchvision.datasets.ImageFolder(opt.data_path, transform=transforms)dataloader = torch.utils.data.DataLoader(dataset=dataset,batch_size=opt.batchSize,shuffle=True,drop_last=True,
)netG = NetG(opt.ngf, opt.nz).to(device)
netD = NetD(opt.ndf).to(device)criterion = nn.BCELoss()
optimizerG = torch.optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
optimizerD = torch.optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))label = torch.FloatTensor(opt.batchSize)
real_label = 1
fake_label = 0for epoch in range(1, opt.epoch + 1):for i, (imgs,_) in enumerate(dataloader):# 固定生成器G,训练鉴别器DoptimizerD.zero_grad()## 让D尽可能的把真图片判别为1imgs=imgs.to(device)output = netD(imgs)label.data.fill_(real_label)label=label.to(device)errD_real = criterion(output, label)errD_real.backward()## 让D尽可能把假图片判别为0label.data.fill_(fake_label)noise = torch.randn(opt.batchSize, opt.nz, 1, 1)noise=noise.to(device)fake = netG(noise)  # 生成假图output = netD(fake.detach()) #避免梯度传到G,因为G不用更新errD_fake = criterion(output, label)errD_fake.backward()errD = errD_fake + errD_realoptimizerD.step()# 固定鉴别器D,训练生成器GoptimizerG.zero_grad()# 让D尽可能把G生成的假图判别为1label.data.fill_(real_label)label = label.to(device)output = netD(fake)errG = criterion(output, label)errG.backward()optimizerG.step()print('[%d/%d][%d/%d] Loss_D: %.3f Loss_G %.3f'% (epoch, opt.epoch, i, len(dataloader), errD.item(), errG.item()))vutils.save_image(fake.data,'%s/fake_samples_epoch_%03d.png' % (opt.outf, epoch),normalize=True)torch.save(netG.state_dict(), '%s/netG_%03d.pth' % (opt.outf, epoch))torch.save(netD.state_dict(), '%s/netD_%03d.pth' % (opt.outf, epoch))

实验结果:

跑完第1个epoch的结果:
这里写图片描述
跑完第25个epoch的结果:
这里写图片描述

这篇关于Pytorch实战3:DCGAN深度卷积对抗生成网络生成动漫头像的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!


原文地址:
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.chinasem.cn/article/725119

相关文章

Python办公自动化实战之打造智能邮件发送工具

《Python办公自动化实战之打造智能邮件发送工具》在数字化办公场景中,邮件自动化是提升工作效率的关键技能,本文将演示如何使用Python的smtplib和email库构建一个支持图文混排,多附件,多... 目录前言一、基础配置:搭建邮件发送框架1.1 邮箱服务准备1.2 核心库导入1.3 基础发送函数二、

PowerShell中15个提升运维效率关键命令实战指南

《PowerShell中15个提升运维效率关键命令实战指南》作为网络安全专业人员的必备技能,PowerShell在系统管理、日志分析、威胁检测和自动化响应方面展现出强大能力,下面我们就来看看15个提升... 目录一、PowerShell在网络安全中的战略价值二、网络安全关键场景命令实战1. 系统安全基线核查

Linux中压缩、网络传输与系统监控工具的使用完整指南

《Linux中压缩、网络传输与系统监控工具的使用完整指南》在Linux系统管理中,压缩与传输工具是数据备份和远程协作的桥梁,而系统监控工具则是保障服务器稳定运行的眼睛,下面小编就来和大家详细介绍一下它... 目录引言一、压缩与解压:数据存储与传输的优化核心1. zip/unzip:通用压缩格式的便捷操作2.

深度解析Java DTO(最新推荐)

《深度解析JavaDTO(最新推荐)》DTO(DataTransferObject)是一种用于在不同层(如Controller层、Service层)之间传输数据的对象设计模式,其核心目的是封装数据,... 目录一、什么是DTO?DTO的核心特点:二、为什么需要DTO?(对比Entity)三、实际应用场景解析

从原理到实战深入理解Java 断言assert

《从原理到实战深入理解Java断言assert》本文深入解析Java断言机制,涵盖语法、工作原理、启用方式及与异常的区别,推荐用于开发阶段的条件检查与状态验证,并强调生产环境应使用参数验证工具类替代... 目录深入理解 Java 断言(assert):从原理到实战引言:为什么需要断言?一、断言基础1.1 语

深度解析Java项目中包和包之间的联系

《深度解析Java项目中包和包之间的联系》文章浏览阅读850次,点赞13次,收藏8次。本文详细介绍了Java分层架构中的几个关键包:DTO、Controller、Service和Mapper。_jav... 目录前言一、各大包1.DTO1.1、DTO的核心用途1.2. DTO与实体类(Entity)的区别1

Java MQTT实战应用

《JavaMQTT实战应用》本文详解MQTT协议,涵盖其发布/订阅机制、低功耗高效特性、三种服务质量等级(QoS0/1/2),以及客户端、代理、主题的核心概念,最后提供Linux部署教程、Sprin... 目录一、MQTT协议二、MQTT优点三、三种服务质量等级四、客户端、代理、主题1. 客户端(Clien

在Spring Boot中集成RabbitMQ的实战记录

《在SpringBoot中集成RabbitMQ的实战记录》本文介绍SpringBoot集成RabbitMQ的步骤,涵盖配置连接、消息发送与接收,并对比两种定义Exchange与队列的方式:手动声明(... 目录前言准备工作1. 安装 RabbitMQ2. 消息发送者(Producer)配置1. 创建 Spr

深度解析Python装饰器常见用法与进阶技巧

《深度解析Python装饰器常见用法与进阶技巧》Python装饰器(Decorator)是提升代码可读性与复用性的强大工具,本文将深入解析Python装饰器的原理,常见用法,进阶技巧与最佳实践,希望可... 目录装饰器的基本原理函数装饰器的常见用法带参数的装饰器类装饰器与方法装饰器装饰器的嵌套与组合进阶技巧

深度解析Spring Boot拦截器Interceptor与过滤器Filter的区别与实战指南

《深度解析SpringBoot拦截器Interceptor与过滤器Filter的区别与实战指南》本文深度解析SpringBoot中拦截器与过滤器的区别,涵盖执行顺序、依赖关系、异常处理等核心差异,并... 目录Spring Boot拦截器(Interceptor)与过滤器(Filter)深度解析:区别、实现