【图像合成】基于DCGAN典型网络的MNIST字符生成(pytorch)

2024-03-29 06:12

本文主要是介绍【图像合成】基于DCGAN典型网络的MNIST字符生成(pytorch),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

关于

 

近年来,基于卷积网络(CNN)的监督学习已经 在计算机视觉应用中得到了广泛的采用。相比之下,无监督 使用 CNN 进行学习受到的关注较少。在这项工作中,我们希望能有所帮助 缩小了 CNN 在监督学习和无监督学习方面的成功之间的差距。我们介绍一类称为深度卷积生成的 CNN 对抗性网络(DCGAN),具有一定的架构限制,以及 证明他们是无监督学习的有力候选人。训练 在各种图像数据集上,我们展示了令人信服的证据,表明我们的深度卷积对抗对学习了从对象部分到 生成器和鉴别器中的场景。此外,我们使用学到的 新任务的特征 - 证明它们作为一般图像表示的适用性。(https://arxiv.org/pdf/1511.06434.pdf)

工具

 数据集

方法实现

加载必要的库函数和自定义函数

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as Ffrom torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
def get_sample_image(G, n_noise):"""save sample 100 images"""z = torch.randn(100, n_noise).to(DEVICE)y_hat = G(z).view(100, 28, 28) # (100, 28, 28)result = y_hat.cpu().data.numpy()img = np.zeros([280, 280])for j in range(10):img[j*28:(j+1)*28] = np.concatenate([x for x in result[j*10:(j+1)*10]], axis=-1)return img

定义判别模型

class Discriminator(nn.Module):"""Convolutional Discriminator for MNIST"""def __init__(self, in_channel=1, num_classes=1):super(Discriminator, self).__init__()self.conv = nn.Sequential(# 28 -> 14nn.Conv2d(in_channel, 512, 3, stride=2, padding=1, bias=False),nn.BatchNorm2d(512),nn.LeakyReLU(0.2),# 14 -> 7nn.Conv2d(512, 256, 3, stride=2, padding=1, bias=False),nn.BatchNorm2d(256),nn.LeakyReLU(0.2),# 7 -> 4nn.Conv2d(256, 128, 3, stride=2, padding=1, bias=False),nn.BatchNorm2d(128),nn.LeakyReLU(0.2),nn.AvgPool2d(4),)self.fc = nn.Sequential(# reshape input, 128 -> 1nn.Linear(128, 1),nn.Sigmoid(),)def forward(self, x, y=None):y_ = self.conv(x)y_ = y_.view(y_.size(0), -1)y_ = self.fc(y_)return y_

定义生成模型

class Generator(nn.Module):"""Convolutional Generator for MNIST"""def __init__(self, input_size=100, num_classes=784):super(Generator, self).__init__()self.fc = nn.Sequential(nn.Linear(input_size, 4*4*512),nn.ReLU(),)self.conv = nn.Sequential(# input: 4 by 4, output: 7 by 7nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, bias=False),nn.BatchNorm2d(256),nn.ReLU(),# input: 7 by 7, output: 14 by 14nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1, bias=False),nn.BatchNorm2d(128),nn.ReLU(),# input: 14 by 14, output: 28 by 28nn.ConvTranspose2d(128, 1, 4, stride=2, padding=1, bias=False),nn.Tanh(),)def forward(self, x, y=None):x = x.view(x.size(0), -1)y_ = self.fc(x)y_ = y_.view(y_.size(0), 512, 4, 4)y_ = self.conv(y_)return y_

 模型超参数定义配置

batch_size = 64criterion = nn.BCELoss()
D_opt = torch.optim.Adam(D.parameters(), lr=0.001, betas=(0.5, 0.999))
G_opt = torch.optim.Adam(G.parameters(), lr=0.001, betas=(0.5, 0.999))max_epoch = 30 # need more than 20 epochs for training generator
step = 0
n_critic = 1 # for training more k steps about Discriminator
n_noise = 100D_labels = torch.ones([batch_size, 1]).to(DEVICE) # Discriminator Label to real
D_fakes = torch.zeros([batch_size, 1]).to(DEVICE) # Discriminator Label to fake

 模型训练

for epoch in range(max_epoch):for idx, (images, labels) in enumerate(data_loader):# Training Discriminatorx = images.to(DEVICE)x_outputs = D(x)D_x_loss = criterion(x_outputs, D_labels)z = torch.randn(batch_size, n_noise).to(DEVICE)z_outputs = D(G(z))D_z_loss = criterion(z_outputs, D_fakes)D_loss = D_x_loss + D_z_lossD.zero_grad()D_loss.backward()D_opt.step()if step % n_critic == 0:# Training Generatorz = torch.randn(batch_size, n_noise).to(DEVICE)z_outputs = D(G(z))G_loss = criterion(z_outputs, D_labels)D.zero_grad()G.zero_grad()G_loss.backward()G_opt.step()if step % 500 == 0:print('Epoch: {}/{}, Step: {}, D Loss: {}, G Loss: {}'.format(epoch, max_epoch, step, D_loss.item(), G_loss.item()))if step % 1000 == 0:G.eval()img = get_sample_image(G, n_noise)imsave('./{}_step{}.jpg'.format(MODEL_NAME, str(step).zfill(3)), img, cmap='gray')G.train()step += 1

测试生成效果

# generation to image
G.eval()
imshow(get_sample_image(G, n_noise), cmap='gray')

 

模型和状态参量保存

def save_checkpoint(state, file_name='checkpoint.pth.tar'):torch.save(state, file_name)# Saving params.
# torch.save(D.state_dict(), 'D_c.pkl')
# torch.save(G.state_dict(), 'G_c.pkl')
save_checkpoint({'epoch': epoch + 1, 'state_dict':D.state_dict(), 'optimizer' : D_opt.state_dict()}, 'D_dc.pth.tar')
save_checkpoint({'epoch': epoch + 1, 'state_dict':G.state_dict(), 'optimizer' : G_opt.state_dict()}, 'G_dc.pth.tar')

应用

DCGAN作为一个成熟的生成模型,在自然图像,医学图像,医学电生理信号数据分析中,都可以用来实现数据的合成,达到数据增强的目的,同时,如何减少增强数据对于后端任务的不利干扰,也是一个需要关注的方面。

这篇关于【图像合成】基于DCGAN典型网络的MNIST字符生成(pytorch)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/857822

相关文章

C#如何去掉文件夹或文件名非法字符

《C#如何去掉文件夹或文件名非法字符》:本文主要介绍C#如何去掉文件夹或文件名非法字符的问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录C#去掉文件夹或文件名非法字符net类库提供了非法字符的数组这里还有个小窍门总结C#去掉文件夹或文件名非法字符实现有输入字

Python实现自动化Word文档样式复制与内容生成

《Python实现自动化Word文档样式复制与内容生成》在办公自动化领域,高效处理Word文档的样式和内容复制是一个常见需求,本文将展示如何利用Python的python-docx库实现... 目录一、为什么需要自动化 Word 文档处理二、核心功能实现:样式与表格的深度复制1. 表格复制(含样式与内容)2

python如何生成指定文件大小

《python如何生成指定文件大小》:本文主要介绍python如何生成指定文件大小的实现方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录python生成指定文件大小方法一(速度最快)方法二(中等速度)方法三(生成可读文本文件–较慢)方法四(使用内存映射高效生成

Pytorch介绍与安装过程

《Pytorch介绍与安装过程》PyTorch因其直观的设计、卓越的灵活性以及强大的动态计算图功能,迅速在学术界和工业界获得了广泛认可,成为当前深度学习研究和开发的主流工具之一,本文给大家介绍Pyto... 目录1、Pytorch介绍1.1、核心理念1.2、核心组件与功能1.3、适用场景与优势总结1.4、优

conda安装GPU版pytorch默认却是cpu版本

《conda安装GPU版pytorch默认却是cpu版本》本文主要介绍了遇到Conda安装PyTorchGPU版本却默认安装CPU的问题,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的... 目录一、问题描述二、网上解决方案罗列【此节为反面方案罗列!!!】三、发现的根本原因[独家]3.1 p

Python中OpenCV与Matplotlib的图像操作入门指南

《Python中OpenCV与Matplotlib的图像操作入门指南》:本文主要介绍Python中OpenCV与Matplotlib的图像操作指南,本文通过实例代码给大家介绍的非常详细,对大家的学... 目录一、环境准备二、图像的基本操作1. 图像读取、显示与保存 使用OpenCV操作2. 像素级操作3.

C/C++的OpenCV 进行图像梯度提取的几种实现

《C/C++的OpenCV进行图像梯度提取的几种实现》本文主要介绍了C/C++的OpenCV进行图像梯度提取的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的... 目录预www.chinasem.cn备知识1. 图像加载与预处理2. Sobel 算子计算 X 和 Y

c/c++的opencv图像金字塔缩放实现

《c/c++的opencv图像金字塔缩放实现》本文主要介绍了c/c++的opencv图像金字塔缩放实现,通过对原始图像进行连续的下采样或上采样操作,生成一系列不同分辨率的图像,具有一定的参考价值,感兴... 目录图像金字塔简介图像下采样 (cv::pyrDown)图像上采样 (cv::pyrUp)C++ O

Maven项目中集成数据库文档生成工具的操作步骤

《Maven项目中集成数据库文档生成工具的操作步骤》在Maven项目中,可以通过集成数据库文档生成工具来自动生成数据库文档,本文为大家整理了使用screw-maven-plugin(推荐)的完... 目录1. 添加插件配置到 pom.XML2. 配置数据库信息3. 执行生成命令4. 高级配置选项5. 注意事

MybatisX快速生成增删改查的方法示例

《MybatisX快速生成增删改查的方法示例》MybatisX是基于IDEA的MyBatis/MyBatis-Plus开发插件,本文主要介绍了MybatisX快速生成增删改查的方法示例,文中通过示例代... 目录1 安装2 基本功能2.1 XML跳转2.2 代码生成2.2.1 生成.xml中的sql语句头2