本文主要是介绍超分辨率(3)--基于RCAN网络实现图像超分辨率重建,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
一.项目介绍
RCAN:Residual Channel Attention Network(残差通道注意网络)
卷积神经网络(CNN)的深度对于图像超分辨率(SR)是极其关键的因素。然而,我们观察到,更深层次的图像SR网络更难训练。低分辨率的输入和特征包含丰富的低频信息,这些信息在通道间被平等对待,从而阻碍了CNNs的表征能力。为了解决这些问题,我们提出了一种非常深的残差通道注意网络(RCAN)。具体地,我们提出了一种residual in residual(RIR)结构来形成非常深的网络,它由几个具有长跳连接的残差组组成。每个残差组包含一些具有短跳连接的残差块。与此同时,RIR允许大量的低频信息通过多个跳跃连接被绕过,使得主网络专注于学习高频信息。在此基础上,我们提出了一种通道注意机制,通过考虑通道间的相互依赖关系,自适应地重新调整通道特征。大量的实验表明,与之前最先进的方法相比,我们的RCAN实现了更好的精度和视觉效果。
背景:
- 卷积神经网络(CNN)的深度对于图像超分辨率(SR)是极其关键的因素。然而,作者观察到,更深层次的图像SR网络更难训练。
- 低分辨率图像(LR)的输入和特征包含大量的低频信息,这些信息在通道间被平等对待,从而阻碍了CNNs的表征能力。
解决方案:
- 针对第一个问题(更深的网络更难训练):已有研究表明,在网络中引入残差块可以使网络深度达到1000层,但仅仅通过堆叠残差块来构建更深的网络很难获得更好的性能提升。因此,作者提出了残差嵌套(residual in residual,RIR)结构来构造非常深且可训练的网络;RIR中的长跳连接和短跳连接使大量低频信息得以被直接绕过,从而让主网络专注于学习更有效的(高频)信息。
- 针对第二个问题(LR输入中的低频和高频信息在通道间被平等对待):注意力机制可以使可用处理资源的分配偏向于输入中信息量最大的部分,因此作者引入了通道注意(Channel Attention,CA)机制。
网络架构:
RCAN主要由四个部分组成:浅层特征提取、残差嵌套(RIR)深度特征提取、上采样模块和重建部分。
- RIR组成:G个RG(带长跳连接)
- 每个RG:B个RCAB组成(带短跳连接)
- 每个RCAB组成:Conv + ReLU + Conv + CA
- CA组成:Global pooling + Conv + ReLU + Conv + Sigmoid(得到的通道权重再与输入特征逐通道相乘)
名词解释:
- Residual Channel Attention Network,RCAN:残差通道注意网络
- residual in residual,RIR:残差嵌套
- residual groups,RG:残差组
- Residual Channel Attention Block,RCAB:残差通道注意块
- Channel Attention,CA:通道注意
- long skip connection,LSC:长跳连接
- short skip connection,SSC:短跳连接
论文地址:
[1807.02758] Image Super-Resolution Using Very Deep Residual Channel Attention Networks (arxiv.org): https://arxiv.org/abs/1807.02758
参考文章:
RCAN论文笔记:Image Super-Resolution Using Very Deep Residual Channel Attention Networks-CSDN博客: https://blog.csdn.net/weixin_46773169/article/details/105600346
源码地址:
yulunzhang/RCAN: PyTorch code for our ECCV 2018 paper "Image Super-Resolution Using Very Deep Residual Channel Attention Networks" (github.com): https://github.com/yulunzhang/RCAN
二.项目流程详解
2.1.数据处理模块
__init__.py
from importlib import import_modulefrom dataloader import MSDataLoader
from torch.utils.data.dataloader import default_collate


class Data:
    """Build the training and test data loaders described by ``args``.

    Attributes:
        loader_train: ``MSDataLoader`` over the training set, or ``None``
            when ``args.test_only`` is set.
        loader_test: ``MSDataLoader`` over the test set (batch size 1,
            no shuffling).
    """

    def __init__(self, args):
        # The CPU and GPU cases only differ in pinned memory, which is
        # enabled when not training on the CPU.
        kwargs = {
            'collate_fn': default_collate,
            'pin_memory': not args.cpu,
        }

        self.loader_train = None
        if not args.test_only:
            # Dataset modules live in the ``data`` package: the module name is
            # the lower-cased dataset name, the class keeps the given name.
            module_train = import_module('data.' + args.data_train.lower())
            trainset = getattr(module_train, args.data_train)(args)
            self.loader_train = MSDataLoader(
                args,
                trainset,
                batch_size=args.batch_size,
                shuffle=True,
                **kwargs
            )

        # These benchmark sets are all served by one generic dataset class.
        if args.data_test in ['Set5', 'Set14', 'B100', 'Urban100']:
            if not args.benchmark_noise:
                module_test = import_module('data.benchmark')
                testset = getattr(module_test, 'Benchmark')(args, train=False)
            else:
                module_test = import_module('data.benchmark_noise')
                testset = getattr(module_test, 'BenchmarkNoise')(
                    args,
                    train=False
                )
        else:
            module_test = import_module('data.' + args.data_test.lower())
            testset = getattr(module_test, args.data_test)(args, train=False)

        # MSDataLoader (imported from ``dataloader``) needs args and a dataset.
        self.loader_test = MSDataLoader(
            args,
            testset,
            batch_size=1,
            shuffle=False,
            **kwargs
        )
benchmark.py
import osfrom data import common
from data import srdataimport numpy as np
import scipy.misc as miscimport torch
import torch.utils.data as data


class Benchmark(srdata.SRData):
    """Benchmark test set named by ``args.data_test``.

    Layout: ``<dir_data>/benchmark/<data_test>/HR/<name>.png`` and
    ``.../LR_bicubic/X<s>/<name>x<s>.png`` for every scale ``s``.
    """

    def __init__(self, args, train=True):
        super(Benchmark, self).__init__(args, train, benchmark=True)

    def _scan(self):
        """Return sorted HR paths and, per scale, the matching LR paths.

        Only regular files in the HR folder whose extension equals
        ``self.ext`` are used. Other entries (sub-directories, hidden
        files, other formats) would otherwise yield nonexistent or
        duplicated paths.
        """
        list_hr = []
        list_lr = [[] for _ in self.scale]
        with os.scandir(self.dir_hr) as entries:
            for entry in entries:
                # os.path.splitext('abc.png') -> ('abc', '.png')
                filename, ext = os.path.splitext(entry.name)
                if not entry.is_file() or ext != self.ext:
                    continue
                list_hr.append(os.path.join(self.dir_hr, filename + self.ext))
                for si, s in enumerate(self.scale):
                    list_lr[si].append(os.path.join(
                        self.dir_lr,
                        'X{}/{}x{}{}'.format(s, filename, s, self.ext)
                    ))

        # Sort so HR and LR lists stay aligned by file name.
        list_hr.sort()
        for l in list_lr:
            l.sort()

        return list_hr, list_lr

    def _set_filesystem(self, dir_data):
        """Set the dataset root, HR/LR folders and the image extension."""
        self.apath = os.path.join(dir_data, 'benchmark', self.args.data_test)
        self.dir_hr = os.path.join(self.apath, 'HR')
        self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
        self.ext = '.png'
common.py
import randomimport numpy as np
import skimage.io as sio
import skimage.color as sc
import skimage.transform as stimport torch
from torchvision import transformsdef get_patch(img_in, img_tar, patch_size, scale, multi_scale=False):# shape得到图片的高度、宽度、颜色通道# 所以shape[:2}就是获取图片的前两个维度,获得图片的高度和宽度ih, iw = img_in.shape[:2]p = scale if multi_scale else 1tp = p * patch_sizeip = tp // scaleix = random.randrange(0, iw - ip + 1)iy = random.randrange(0, ih - ip + 1)tx, ty = scale * ix, scale * iyimg_in = img_in[iy:iy + ip, ix:ix + ip, :]img_tar = img_tar[ty:ty + tp, tx:tx + tp, :]return img_in, img_tar# 设置channel值
def set_channel(l, n_channel):
    """Convert every image in ``l`` to ``n_channel`` channels.

    2-D (grayscale) arrays first get a trailing channel axis. A 3-channel
    image requested as 1 channel is reduced to channel 0 (Y) of
    ``sc.rgb2ycbcr``. A 1-channel image requested as 3 channels is repeated
    along the channel axis. Any other combination is returned unchanged.

    Returns:
        list of converted arrays, in the same order as ``l``.
    """
    def _convert(img):
        if img.ndim == 2:
            img = img[:, :, np.newaxis]

        channels = img.shape[2]
        if n_channel == 1 and channels == 3:
            return sc.rgb2ycbcr(img)[:, :, 0:1]
        if n_channel == 3 and channels == 1:
            return np.repeat(img, n_channel, axis=2)
        return img

    return list(map(_convert, l))


# Convert numpy arrays to torch tensors
def np2Tensor(l, rgb_range):
    """Convert HWC numpy images to float CHW tensors.

    Pixel values are multiplied by ``rgb_range / 255``.

    Returns:
        list of tensors, in the same order as ``l``.
    """
    def _np2Tensor(img):
        # HWC -> CHW. ascontiguousarray copies the transposed view into
        # contiguous memory before torch.from_numpy wraps it.
        np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
        tensor = torch.from_numpy(np_transpose).float()
        tensor.mul_(rgb_range / 255)

        return tensor

    return [_np2Tensor(_l) for _l in l]


def add_noise(x, noise='.'):
    """Add synthetic noise to an image array.

    Args:
        x: image array; the noisy result is clipped to [0, 255].
        noise: ``'.'`` for no noise. Otherwise a type letter followed by an
            integer: ``'G<std>'`` adds rounded Gaussian noise with that
            standard deviation, ``'S<k>'`` uses the Poisson-based branch
            scaled by ``k``.

    Returns:
        ``x`` itself when ``noise == '.'``, otherwise a new uint8 array.

    Raises:
        ValueError: if the noise type letter is neither ``'G'`` nor ``'S'``.
    """
    # Compare by value. The former ``noise is not '.'`` relied on string
    # interning and is a SyntaxWarning on modern Python.
    if noise == '.':
        return x

    noise_type = noise[0]
    noise_value = int(noise[1:])
    if noise_type == 'G':
        noises = np.random.normal(scale=noise_value, size=x.shape)
        noises = noises.round()
    elif noise_type == 'S':
        noises = np.random.poisson(x * noise_value) / noise_value
        # Subtract the mean over the first two axes (per channel for an
        # (H, W, C) array) so only the fluctuation is added.
        noises = noises - noises.mean(axis=0).mean(axis=0)
    else:
        # Previously this fell through to an UnboundLocalError on `noises`.
        raise ValueError('Unknown noise type: {}'.format(noise_type))

    # int16 avoids uint8 wrap-around before clipping back to [0, 255].
    x_noise = x.astype(np.int16) + noises.astype(np.int16)
    x_noise = x_noise.clip(0, 255).astype(np.uint8)
    return x_noise


def augment(l, hflip=True, rot=True):
    """Apply one random flip/transpose combination to every array in ``l``.

    The random choices are drawn once, so all arrays (e.g. an LR/HR pair)
    receive the same transformation. ``hflip`` enables the horizontal
    flip; ``rot`` enables both the vertical flip and the swap of the two
    spatial axes.
    """
    hflip = hflip and random.random() < 0.5
    vflip = rot and random.random() < 0.5
    rot90 = rot and random.random() < 0.5

    def _augment(img):
        if hflip: img = img[:, ::-1, :]
        if vflip: img = img[::-1, :, :]
        if rot90: img = img.transpose(1, 0, 2)

        return img

    return [_augment(_l) for _l in l]
demo.py
import osfrom data import commonimport numpy as np
import scipy.misc as miscimport torch
import torch.utils.data as data


class Demo(data.Dataset):
    """Unlabelled images from ``args.dir_demo`` used for inference.

    Every file whose name contains ``'.png'`` or ``'.jp'`` is included,
    sorted by path. Each item is ``(lr_tensor, -1, filename_stem)``.
    """

    def __init__(self, args, train=False):
        self.args = args
        self.name = 'Demo'
        self.scale = args.scale
        self.idx_scale = 0
        self.train = False
        self.benchmark = False

        demo_dir = args.dir_demo
        self.filelist = sorted(
            os.path.join(demo_dir, name)
            for name in os.listdir(demo_dir)
            if '.png' in name or '.jp' in name
        )

    def __getitem__(self, idx):
        path = self.filelist[idx]
        stem = os.path.splitext(os.path.basename(path))[0]
        image = common.set_channel([misc.imread(path)], self.args.n_colors)[0]
        return common.np2Tensor([image], self.args.rgb_range)[0], -1, stem

    def __len__(self):
        return len(self.filelist)

    def set_scale(self, idx_scale):
        self.idx_scale = idx_scale
srdata.py
import osfrom data import commonimport numpy as np
import scipy.misc as miscimport torch
import torch.utils.data as data


class SRData(data.Dataset):
    """Base class for paired LR/HR super-resolution datasets.

    Subclasses implement ``_scan`` (return HR paths and per-scale LR paths),
    ``_set_filesystem`` (set ``self.apath`` and ``self.ext``, plus whatever
    ``_scan`` needs) and, for ``'bin'`` storage, ``_name_hrbin`` /
    ``_name_lrbin``.

    ``args.ext`` selects the storage mode:
      * ``'img'`` (or ``benchmark=True``): read image files on every access;
      * contains ``'sep'``: one ``.npy`` file per image, regenerated from the
        images when it also contains ``'reset'``;
      * contains ``'bin'``: one ``.npy`` file per split/scale loaded into
        memory, rebuilt when loading fails or ``'reset'`` is given.
    """

    def __init__(self, args, train=True, benchmark=False):
        self.args = args
        self.train = train
        self.split = 'train' if train else 'test'
        self.benchmark = benchmark
        self.scale = args.scale
        self.idx_scale = 0

        self._set_filesystem(args.dir_data)

        def _load_bin():
            self.images_hr = np.load(self._name_hrbin())
            self.images_lr = [
                np.load(self._name_lrbin(s)) for s in self.scale
            ]

        def _to_npy(path):
            # Swap only the file suffix. str.replace(self.ext, '.npy') also
            # rewrote matching text in directory names.
            return os.path.splitext(path)[0] + '.npy'

        if args.ext == 'img' or benchmark:
            self.images_hr, self.images_lr = self._scan()
        elif args.ext.find('sep') >= 0:
            self.images_hr, self.images_lr = self._scan()
            if args.ext.find('reset') >= 0:
                print('Preparing seperated binary files')
                for v in self.images_hr:
                    hr = misc.imread(v)
                    np.save(_to_npy(v), hr)
                for si, s in enumerate(self.scale):
                    for v in self.images_lr[si]:
                        lr = misc.imread(v)
                        np.save(_to_npy(v), lr)

            self.images_hr = [_to_npy(v) for v in self.images_hr]
            self.images_lr = [
                [_to_npy(v) for v in self.images_lr[i]]
                for i in range(len(self.scale))
            ]
        elif args.ext.find('bin') >= 0:
            try:
                if args.ext.find('reset') >= 0:
                    raise IOError
                print('Loading a binary file')
                _load_bin()
            except Exception:
                # Best effort: any failure to load (missing or unreadable
                # file, or a requested reset) rebuilds the binaries. A bare
                # ``except:`` also swallowed KeyboardInterrupt/SystemExit.
                print('Preparing a binary file')
                bin_path = os.path.join(self.apath, 'bin')
                if not os.path.isdir(bin_path):
                    os.mkdir(bin_path)

                list_hr, list_lr = self._scan()
                hr = [misc.imread(f) for f in list_hr]
                np.save(self._name_hrbin(), hr)
                del hr
                for si, s in enumerate(self.scale):
                    lr_scale = [misc.imread(f) for f in list_lr[si]]
                    np.save(self._name_lrbin(s), lr_scale)
                    del lr_scale
                _load_bin()
        else:
            print('Please define data type')

    def _scan(self):
        raise NotImplementedError

    def _set_filesystem(self, dir_data):
        raise NotImplementedError

    def _name_hrbin(self):
        raise NotImplementedError

    def _name_lrbin(self, scale):
        raise NotImplementedError

    def __getitem__(self, idx):
        """Return ``(lr_tensor, hr_tensor, filename)`` for item ``idx``."""
        lr, hr, filename = self._load_file(idx)
        lr, hr = self._get_patch(lr, hr)
        lr, hr = common.set_channel([lr, hr], self.args.n_colors)
        lr_tensor, hr_tensor = common.np2Tensor(
            [lr, hr], self.args.rgb_range
        )
        return lr_tensor, hr_tensor, filename

    def __len__(self):
        return len(self.images_hr)

    def _get_index(self, idx):
        # Hook for subclasses that repeat the dataset (e.g. idx % length).
        return idx

    def _load_file(self, idx):
        """Load the LR (current scale) and HR arrays plus a filename stem."""
        idx = self._get_index(idx)
        lr = self.images_lr[self.idx_scale][idx]
        hr = self.images_hr[idx]
        if self.args.ext == 'img' or self.benchmark:
            filename = hr
            lr = misc.imread(lr)
            hr = misc.imread(hr)
        elif self.args.ext.find('sep') >= 0:
            filename = hr
            lr = np.load(lr)
            hr = np.load(hr)
        else:
            # 'bin' mode: arrays are already in memory; name by 1-based index.
            filename = str(idx + 1)

        filename = os.path.splitext(os.path.split(filename)[-1])[0]

        return lr, hr, filename

    def _get_patch(self, lr, hr):
        """Random crop + augmentation + noise when training. Otherwise crop
        HR to exactly ``scale`` times the LR size."""
        patch_size = self.args.patch_size
        scale = self.scale[self.idx_scale]
        multi_scale = len(self.scale) > 1
        if self.train:
            lr, hr = common.get_patch(
                lr, hr, patch_size, scale, multi_scale=multi_scale
            )
            lr, hr = common.augment([lr, hr])
            lr = common.add_noise(lr, self.args.noise)
        else:
            ih, iw = lr.shape[0:2]
            hr = hr[0:ih * scale, 0:iw * scale]

        return lr, hr

    def set_scale(self, idx_scale):
        self.idx_scale = idx_scale
div2k.py
import osfrom data import common
from data import srdataimport numpy as np
import scipy.misc as miscimport torch
import torch.utils.data as data


class DIV2K(srdata.SRData):
    """DIV2K images read from ``<dir_data>/DIV2K``.

    Training uses image numbers 1..n_train. Validation uses
    offset_val+1..offset_val+n_val. Both are taken from the
    ``DIV2K_train_HR`` / ``DIV2K_train_LR_bicubic`` folders.
    """

    def __init__(self, args, train=True):
        super(DIV2K, self).__init__(args, train)
        # Repeat the training list so one "epoch" holds roughly test_every
        # batches. Both terms are clamped to 1: n_train < batch_size used to
        # divide by zero, and a small test_every gave repeat == 0 (an empty
        # training set).
        batches_per_pass = max(1, args.n_train // args.batch_size)
        self.repeat = max(1, args.test_every // batches_per_pass)

    def _scan(self):
        """Return HR paths and per-scale LR paths for the current split."""
        list_hr = []
        list_lr = [[] for _ in self.scale]
        if self.train:
            idx_begin = 0
            idx_end = self.args.n_train
        else:
            # Start right after offset_val. Using n_train here ignored
            # offset_val whenever the two options differed.
            idx_begin = self.args.offset_val
            idx_end = self.args.offset_val + self.args.n_val

        for i in range(idx_begin + 1, idx_end + 1):
            filename = '{:0>4}'.format(i)
            list_hr.append(os.path.join(self.dir_hr, filename + self.ext))
            for si, s in enumerate(self.scale):
                list_lr[si].append(os.path.join(
                    self.dir_lr,
                    'X{}/{}x{}{}'.format(s, filename, s, self.ext)
                ))

        return list_hr, list_lr

    def _set_filesystem(self, dir_data):
        self.apath = dir_data + '/DIV2K'
        self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
        self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')
        self.ext = '.png'

    def _name_hrbin(self):
        return os.path.join(
            self.apath,
            'bin',
            '{}_bin_HR.npy'.format(self.split)
        )

    def _name_lrbin(self, scale):
        return os.path.join(
            self.apath,
            'bin',
            '{}_bin_LR_X{}.npy'.format(self.split, scale)
        )

    def __len__(self):
        if self.train:
            return len(self.images_hr) * self.repeat
        else:
            return len(self.images_hr)

    def _get_index(self, idx):
        # Map the repeated training index back onto the real image list.
        if self.train:
            return idx % len(self.images_hr)
        else:
            return idx
2.2.损失函数设置
__init__.py
import os
from importlib import import_moduleimport matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as pltimport numpy as npimport torch
import torch.nn as nn
import torch.nn.functional as F


class Loss(nn.modules.loss._Loss):
    """Weighted sum of the loss terms configured by ``args.loss``.

    ``args.loss`` is a ``+``-separated list of ``<weight>*<type>`` terms,
    e.g. ``'1*L1'``. Supported types are ``MSE``, ``L1``, types containing
    ``VGG`` (the text after the first three characters goes to
    ``loss.vgg.VGG``), and types containing ``GAN`` (handled by
    ``loss.adversarial.Adversarial``, with an extra ``DIS`` log entry). When
    more than one entry exists, a ``Total`` entry is appended for logging.

    ``self.log`` has one row per ``start_log`` call and one column per entry
    of ``self.loss``.

    Raises:
        ValueError: for an unknown loss type.
    """

    def __init__(self, args, ckp):
        super(Loss, self).__init__()
        print('Preparing loss function:')

        self.n_GPUs = args.n_GPUs
        self.loss = []
        # Modules held in an nn.ModuleList are registered as submodules, so
        # their parameters follow .to()/.half() and appear in state_dict().
        self.loss_module = nn.ModuleList()
        for loss in args.loss.split('+'):
            weight, loss_type = loss.split('*')
            if loss_type == 'MSE':
                loss_function = nn.MSELoss()
            elif loss_type == 'L1':
                loss_function = nn.L1Loss()
            elif loss_type.find('VGG') >= 0:
                module = import_module('loss.vgg')
                loss_function = getattr(module, 'VGG')(
                    loss_type[3:],
                    rgb_range=args.rgb_range
                )
            elif loss_type.find('GAN') >= 0:
                module = import_module('loss.adversarial')
                loss_function = getattr(module, 'Adversarial')(
                    args,
                    loss_type
                )
            else:
                # Previously an unknown type silently reused the previous
                # term's function, or raised UnboundLocalError on the first.
                raise ValueError('Unknown loss type: {}'.format(loss_type))

            self.loss.append({
                'type': loss_type,
                'weight': float(weight),
                'function': loss_function
            })
            if loss_type.find('GAN') >= 0:
                self.loss.append({'type': 'DIS', 'weight': 1, 'function': None})

        if len(self.loss) > 1:
            self.loss.append({'type': 'Total', 'weight': 0, 'function': None})

        for l in self.loss:
            if l['function'] is not None:
                print('{:.3f} * {}'.format(l['weight'], l['type']))
                self.loss_module.append(l['function'])

        self.log = torch.Tensor()

        device = torch.device('cpu' if args.cpu else 'cuda')
        self.loss_module.to(device)
        if args.precision == 'half': self.loss_module.half()
        if not args.cpu and args.n_GPUs > 1:
            self.loss_module = nn.DataParallel(
                self.loss_module, range(args.n_GPUs)
            )

        if args.load != '.': self.load(ckp.dir, cpu=args.cpu)

    def forward(self, sr, hr):
        """Return the weighted loss sum and accumulate each term in the log."""
        losses = []
        for i, l in enumerate(self.loss):
            if l['function'] is not None:
                loss = l['function'](sr, hr)
                effective_loss = l['weight'] * loss
                losses.append(effective_loss)
                self.log[-1, i] += effective_loss.item()
            elif l['type'] == 'DIS':
                # The adversarial term (previous entry) stores its
                # discriminator loss in ``.loss``.
                self.log[-1, i] += self.loss[i - 1]['function'].loss

        loss_sum = sum(losses)
        if len(self.loss) > 1:
            self.log[-1, -1] += loss_sum.item()

        return loss_sum

    def step(self):
        """Advance the schedulers owned by loss terms (e.g. adversarial)."""
        for l in self.get_loss_module():
            if hasattr(l, 'scheduler'):
                l.scheduler.step()

    def start_log(self):
        """Append a zeroed log row for a new epoch."""
        self.log = torch.cat((self.log, torch.zeros(1, len(self.loss))))

    def end_log(self, n_batches):
        """Turn the accumulated sums of the last row into per-batch means."""
        self.log[-1].div_(n_batches)

    def display_loss(self, batch):
        """Format the running mean of each term after ``batch + 1`` batches."""
        n_samples = batch + 1
        log = []
        for l, c in zip(self.loss, self.log[-1]):
            log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples))

        return ''.join(log)

    def plot_loss(self, apath, epoch):
        """Save one ``loss_<type>.pdf`` curve per term into ``apath``."""
        axis = np.linspace(1, epoch, epoch)
        for i, l in enumerate(self.loss):
            label = '{} Loss'.format(l['type'])
            fig = plt.figure()
            plt.title(label)
            plt.plot(axis, self.log[:, i].numpy(), label=label)
            plt.legend()
            plt.xlabel('Epochs')
            plt.ylabel('Loss')
            plt.grid(True)
            plt.savefig('{}/loss_{}.pdf'.format(apath, l['type']))
            plt.close(fig)

    def get_loss_module(self):
        """Return the ModuleList of loss functions, unwrapping DataParallel.

        The wrapper is only created for GPU runs with ``n_GPUs > 1``.
        Checking the wrapper itself, rather than ``n_GPUs``, avoids an
        AttributeError for CPU runs configured with several GPUs.
        """
        if isinstance(self.loss_module, nn.DataParallel):
            return self.loss_module.module
        return self.loss_module

    def save(self, apath):
        torch.save(self.state_dict(), os.path.join(apath, 'loss.pt'))
        torch.save(self.log, os.path.join(apath, 'loss_log.pt'))

    def load(self, apath, cpu=False):
        """Restore weights and log from ``apath`` and replay scheduler steps."""
        if cpu:
            kwargs = {'map_location': lambda storage, loc: storage}
        else:
            kwargs = {}

        self.load_state_dict(torch.load(
            os.path.join(apath, 'loss.pt'),
            **kwargs
        ))
        self.log = torch.load(os.path.join(apath, 'loss_log.pt'))
        for l in self.get_loss_module():
            if hasattr(l, 'scheduler'):
                # One scheduler step per already-logged epoch.
                for _ in range(len(self.log)): l.scheduler.step()
adversarial.py
import utility
from model import common
from loss import discriminatorimport torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable


class Adversarial(nn.Module):
    """Adversarial loss term that also trains its own discriminator.

    Each forward call runs ``gan_k`` discriminator updates on the detached
    fake batch and the real batch. It stores their mean discriminator loss
    in ``self.loss`` and returns the generator loss for ``fake``.

    ``gan_type`` is ``'GAN'`` (BCE with logits) or contains ``'WGAN'``
    (critic loss). Plain ``'WGAN'`` clips discriminator weights to
    [-1, 1]; a type that also contains ``'GP'`` adds a gradient penalty
    with weight 10.

    Raises:
        ValueError: from ``forward`` for any other ``gan_type``.
    """

    def __init__(self, args, gan_type):
        super(Adversarial, self).__init__()
        self.gan_type = gan_type
        self.gan_k = args.gan_k
        self.discriminator = discriminator.Discriminator(args, gan_type)
        if gan_type != 'WGAN_GP':
            self.optimizer = utility.make_optimizer(args, self.discriminator)
        else:
            self.optimizer = optim.Adam(
                self.discriminator.parameters(),
                betas=(0, 0.9), eps=1e-8, lr=1e-5
            )
        self.scheduler = utility.make_scheduler(args, self.optimizer)

    def forward(self, fake, real):
        # Discriminator steps must not backpropagate into the generator.
        fake_detach = fake.detach()

        self.loss = 0
        for _ in range(self.gan_k):
            self.optimizer.zero_grad()
            d_fake = self.discriminator(fake_detach)
            d_real = self.discriminator(real)
            if self.gan_type == 'GAN':
                label_fake = torch.zeros_like(d_fake)
                label_real = torch.ones_like(d_real)
                loss_d \
                    = F.binary_cross_entropy_with_logits(d_fake, label_fake) \
                    + F.binary_cross_entropy_with_logits(d_real, label_real)
            elif self.gan_type.find('WGAN') >= 0:
                loss_d = (d_fake - d_real).mean()
                if self.gan_type.find('GP') >= 0:
                    # One interpolation coefficient per sample, broadcast over
                    # the other dims. The former
                    # ``torch.rand_like(fake).view(-1, 1, 1, 1)`` had one entry
                    # per element and could not broadcast against ``fake``.
                    epsilon = torch.rand(
                        (fake.size(0),) + (1,) * (fake.dim() - 1),
                        dtype=fake.dtype, device=fake.device
                    )
                    hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)
                    hat.requires_grad = True
                    d_hat = self.discriminator(hat)
                    gradients = torch.autograd.grad(
                        outputs=d_hat.sum(), inputs=hat,
                        retain_graph=True, create_graph=True, only_inputs=True
                    )[0]
                    gradients = gradients.view(gradients.size(0), -1)
                    gradient_norm = gradients.norm(2, dim=1)
                    gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean()
                    loss_d += gradient_penalty
            else:
                raise ValueError(
                    'Unsupported GAN type: {}'.format(self.gan_type)
                )

            # Discriminator update
            self.loss += loss_d.item()
            loss_d.backward()
            self.optimizer.step()

            if self.gan_type == 'WGAN':
                for p in self.discriminator.parameters():
                    p.data.clamp_(-1, 1)

        if self.gan_k > 0:
            self.loss /= self.gan_k

        # Generator loss: this time gradients flow back into ``fake``.
        d_fake_for_g = self.discriminator(fake)
        if self.gan_type == 'GAN':
            # Build the "real" targets here instead of reusing the label
            # created inside the discriminator loop.
            loss_g = F.binary_cross_entropy_with_logits(
                d_fake_for_g, torch.ones_like(d_fake_for_g)
            )
        elif self.gan_type.find('WGAN') >= 0:
            loss_g = -d_fake_for_g.mean()
        else:
            raise ValueError('Unsupported GAN type: {}'.format(self.gan_type))

        return loss_g

    def state_dict(self, *args, **kwargs):
        # NOTE(review): this flattens the discriminator and optimizer state
        # dicts into one dict. When a parent module builds its own
        # state_dict it may pass ``destination``/``prefix`` through here, so
        # the keys may not round-trip through load_state_dict. Confirm
        # against the installed PyTorch version before relying on it.
        state_discriminator = self.discriminator.state_dict(*args, **kwargs)
        state_optimizer = self.optimizer.state_dict()

        return dict(**state_discriminator, **state_optimizer)

# Some references
# https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py
# OR
# https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py
discriminator.py
from model import common
import torch.nn as nn


class Discriminator(nn.Module):
    """Conv feature stack followed by a two-layer linear classifier.

    The stack starts with one ``common.BasicBlock`` from ``args.n_colors``
    to 64 channels. It is followed by ``depth`` (= 7) blocks that alternate
    stride 2 with unchanged channels (even index) and stride 1 with doubled
    channels (odd index). The classifier input size assumes a square input
    of side ``args.patch_size``, halved once per stride-2 block.
    ``gan_type`` is accepted but not used; batch norm is always enabled.
    """

    def __init__(self, args, gan_type='GAN'):
        super(Discriminator, self).__init__()

        depth = 7
        bn = True
        act = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        channels = 64
        blocks = [common.BasicBlock(args.n_colors, channels, 3, bn=bn, act=act)]
        for i in range(depth):
            downsample = i % 2 == 0
            next_channels = channels if downsample else channels * 2
            blocks.append(common.BasicBlock(
                channels, next_channels, 3,
                stride=2 if downsample else 1, bn=bn, act=act
            ))
            channels = next_channels
        self.features = nn.Sequential(*blocks)

        # There are ceil(depth / 2) stride-2 blocks.
        feat_size = args.patch_size // (2 ** ((depth + 1) // 2))
        self.classifier = nn.Sequential(
            nn.Linear(channels * feat_size ** 2, 1024),
            act,
            nn.Linear(1024, 1)
        )

    def forward(self, x):
        features = self.features(x)
        return self.classifier(features.view(features.size(0), -1))
vgg.py
from model import commonimport torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.autograd import Variable


class VGG(nn.Module):
    """Perceptual loss: MSE between frozen VGG-19 features of SR and HR.

    Args:
        conv_index: ``'22'`` keeps ``vgg19().features[:8]``; ``'54'`` keeps
            ``features[:35]``.
        rgb_range: value range of the input images. It scales the
            mean/std normalisation applied by ``common.MeanShift``.

    Raises:
        ValueError: for any other ``conv_index``.
    """

    def __init__(self, conv_index, rgb_range=1):
        super(VGG, self).__init__()
        # Validate before fetching the pretrained weights. An unknown index
        # used to fail only afterwards, with AttributeError on ``self.vgg``.
        n_layers = {'22': 8, '54': 35}.get(conv_index)
        if n_layers is None:
            raise ValueError(
                'Unsupported VGG conv_index: {}'.format(conv_index)
            )

        # pretrained=True loads the ImageNet-trained weights.
        vgg_features = models.vgg19(pretrained=True).features
        modules = [m for m in vgg_features]
        self.vgg = nn.Sequential(*modules[:n_layers])

        vgg_mean = (0.485, 0.456, 0.406)
        vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
        self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
        # ``self.vgg.requires_grad = False`` only set a plain attribute on the
        # module. The parameters themselves have to be frozen.
        for p in self.vgg.parameters():
            p.requires_grad = False

    def forward(self, sr, hr):
        def _forward(x):
            x = self.sub_mean(x)
            x = self.vgg(x)
            return x

        vgg_sr = _forward(sr)
        # HR features are a fixed target; no graph is needed for them.
        with torch.no_grad():
            vgg_hr = _forward(hr.detach())

        loss = F.mse_loss(vgg_sr, vgg_hr)

        return loss
2.3.网络模型构建
dataloader.py
import sys
import threading
import queue
import random
import collectionsimport torch
import torch.multiprocessing as multiprocessingfrom torch._C import _set_worker_signal_handlers, _update_worker_pids, \_remove_worker_pids, _error_if_any_worker_fails
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataloader import _DataLoaderIterfrom torch.utils.data.dataloader import ExceptionWrapper
from torch.utils.data.dataloader import _use_shared_memory
from torch.utils.data.dataloader import _worker_manager_loop
from torch.utils.data.dataloader import numpy_type_map
from torch.utils.data.dataloader import default_collate
from torch.utils.data.dataloader import pin_memory_batch
from torch.utils.data.dataloader import _SIGCHLD_handler_set
from torch.utils.data.dataloader import _set_SIGCHLD_handler

# Python 2 / 3 compatible import of the thread-safe queue module.
if sys.version_info[0] == 2:
    import Queue as queue
else:
    import queue


def _ms_loop(dataset, index_queue, data_queue, collate_fn, scale, seed, init_fn, worker_id):
    """Worker-process loop of the multi-scale data loader.

    Reads ``(idx, batch_indices)`` requests from ``index_queue`` until it
    receives ``None``. For each request it picks a scale index: random when
    more than one scale is configured and ``dataset.train`` is true,
    otherwise 0. It then collates the samples, appends the scale index to
    the batch and puts ``(idx, samples)`` on ``data_queue``. Exceptions are
    sent back as ``(idx, ExceptionWrapper(...))``.

    NOTE(review): ``init_fn`` and ``worker_id`` are accepted but not used in
    this loop.
    """
    # NOTE(review): ``global`` rebinds the ``_use_shared_memory`` name that was
    # imported into THIS module, not the flag inside
    # torch.utils.data.dataloader. Confirm whether default_collate sees it.
    global _use_shared_memory
    _use_shared_memory = True
    _set_worker_signal_handlers()

    # One intra-op thread per worker, plus a per-worker seed.
    torch.set_num_threads(1)
    torch.manual_seed(seed)
    while True:
        r = index_queue.get()
        if r is None:
            # Sentinel from the main process: stop this worker.
            break
        idx, batch_indices = r
        try:
            idx_scale = 0
            if len(scale) > 1 and dataset.train:
                idx_scale = random.randrange(0, len(scale))
                dataset.set_scale(idx_scale)

            samples = collate_fn([dataset[i] for i in batch_indices])
            # assumes collate_fn returns a mutable list - TODO confirm for the
            # installed PyTorch. The scale index travels as the last item.
            samples.append(idx_scale)

        except Exception:
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples))


class _MSDataLoaderIter(_DataLoaderIter):
    """Iterator whose worker processes run ``_ms_loop``.

    NOTE(review): this __init__ does not call the parent __init__. It sets up
    the attributes that inherited private methods (e.g. ``_put_indices``)
    use, so it must stay in sync with the installed PyTorch's private
    ``_DataLoaderIter``. The imported private helpers are not public API.
    """

    def __init__(self, loader):
        self.dataset = loader.dataset
        self.scale = loader.scale
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
        self.timeout = loader.timeout
        self.done_event = threading.Event()

        self.sample_iter = iter(self.batch_sampler)

        if self.num_workers > 0:
            self.worker_init_fn = loader.worker_init_fn
            # One index queue per worker; all workers share one result queue.
            self.index_queues = [
                multiprocessing.Queue() for _ in range(self.num_workers)
            ]
            self.worker_queue_idx = 0
            self.worker_result_queue = multiprocessing.SimpleQueue()
            self.batches_outstanding = 0
            self.worker_pids_set = False
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            self.reorder_dict = {}

            # Worker i is seeded with base_seed + i.
            base_seed = torch.LongTensor(1).random_()[0]
            self.workers = [
                multiprocessing.Process(
                    target=_ms_loop,
                    args=(
                        self.dataset,
                        self.index_queues[i],
                        self.worker_result_queue,
                        self.collate_fn,
                        self.scale,
                        base_seed + i,
                        self.worker_init_fn,
                        i
                    )
                )
                for i in range(self.num_workers)]

            if self.pin_memory or self.timeout > 0:
                # A manager thread moves results from the worker queue into
                # data_queue (pinning memory when enabled).
                self.data_queue = queue.Queue()
                if self.pin_memory:
                    maybe_device_id = torch.cuda.current_device()
                else:
                    # do not initialize cuda context if not necessary
                    maybe_device_id = None
                self.worker_manager_thread = threading.Thread(
                    target=_worker_manager_loop,
                    args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
                          maybe_device_id))
                self.worker_manager_thread.daemon = True
                self.worker_manager_thread.start()
            else:
                self.data_queue = self.worker_result_queue

            for w in self.workers:
                w.daemon = True  # ensure that the worker exits on process exit
                w.start()

            _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
            _set_SIGCHLD_handler()
            self.worker_pids_set = True

            # prime the prefetch loop
            for _ in range(2 * self.num_workers):
                self._put_indices()


class MSDataLoader(DataLoader):
    """DataLoader whose iterator is ``_MSDataLoaderIter``.

    ``num_workers`` is taken from ``args.n_threads``, and ``args.scale`` is
    kept so the workers can pick a scale per batch.
    """

    def __init__(
        self, args, dataset, batch_size=1, shuffle=False,
        sampler=None, batch_sampler=None,
        collate_fn=default_collate, pin_memory=False, drop_last=False,
        timeout=0, worker_init_fn=None):

        super(MSDataLoader, self).__init__(
            dataset, batch_size=batch_size, shuffle=shuffle,
            sampler=sampler, batch_sampler=batch_sampler,
            num_workers=args.n_threads, collate_fn=collate_fn,
            pin_memory=pin_memory, drop_last=drop_last,
            timeout=timeout, worker_init_fn=worker_init_fn)

        self.scale = args.scale

    def __iter__(self):
        return _MSDataLoaderIter(self)
main.py
import torchimport utility
import data
import model
import loss
from option import args
from trainer import Trainer

# Entry point: seed, set up the experiment directory, then alternate
# training and testing until the trainer reports it is done.
torch.manual_seed(args.seed)
checkpoint = utility.checkpoint(args)

if checkpoint.ok:
    loader = data.Data(args)
    model = model.Model(args, checkpoint)
    # No loss function is needed for test-only runs.
    loss = None if args.test_only else loss.Loss(args, checkpoint)
    t = Trainer(args, loader, model, loss, checkpoint)
    while True:
        if t.terminate():
            break
        t.train()
        t.test()

    checkpoint.done()
option.py
import argparse
import template

# Command-line options for training/testing. After parsing, templates are
# applied and a few values are normalised (see the end of this block).
parser = argparse.ArgumentParser(description='EDSR and MDSR')

parser.add_argument('--debug', action='store_true',
                    help='Enables debug mode')
parser.add_argument('--template', default='.',
                    help='You can set various templates in option.py')

# Hardware specifications
parser.add_argument('--n_threads', type=int, default=3,
                    help='number of threads for data loading')
parser.add_argument('--cpu', action='store_true',
                    help='use cpu only')
parser.add_argument('--n_GPUs', type=int, default=1,
                    help='number of GPUs')
parser.add_argument('--seed', type=int, default=1,
                    help='random seed')

# Data specifications
parser.add_argument('--dir_data', type=str, default='/home/yulun/data/SR/traindata/DIV2K/bicubic',
                    help='dataset directory')
parser.add_argument('--dir_demo', type=str, default='../test',
                    help='demo image directory')
parser.add_argument('--data_train', type=str, default='DIV2K',
                    help='train dataset name')
parser.add_argument('--data_test', type=str, default='DIV2K',
                    help='test dataset name')
parser.add_argument('--benchmark_noise', action='store_true',
                    help='use noisy benchmark sets')
parser.add_argument('--n_train', type=int, default=800,
                    help='number of training set')
parser.add_argument('--n_val', type=int, default=5,
                    help='number of validation set')
parser.add_argument('--offset_val', type=int, default=800,
                    help='validation index offest')
parser.add_argument('--ext', type=str, default='sep_reset',
                    help='dataset file extension')
parser.add_argument('--scale', default='4',
                    help='super resolution scale')
parser.add_argument('--patch_size', type=int, default=192,
                    help='output patch size')
parser.add_argument('--rgb_range', type=int, default=255,
                    help='maximum value of RGB')
parser.add_argument('--n_colors', type=int, default=3,
                    help='number of color channels to use')
parser.add_argument('--noise', type=str, default='.',
                    help='Gaussian noise std.')
parser.add_argument('--chop', action='store_true',
                    help='enable memory-efficient forward')

# Model specifications
parser.add_argument('--model', default='RCAN',
                    help='model name')
parser.add_argument('--act', type=str, default='relu',
                    help='activation function')
parser.add_argument('--pre_train', type=str, default='.',
                    help='pre-trained model directory')
parser.add_argument('--extend', type=str, default='.',
                    help='pre-trained model directory')
parser.add_argument('--n_resblocks', type=int, default=20,
                    help='number of residual blocks')
parser.add_argument('--n_feats', type=int, default=64,
                    help='number of feature maps')
parser.add_argument('--res_scale', type=float, default=1,
                    help='residual scaling')
parser.add_argument('--shift_mean', default=True,
                    help='subtract pixel mean from the input')
parser.add_argument('--precision', type=str, default='single',
                    choices=('single', 'half'),
                    help='FP precision for test (single | half)')

# Training specifications
parser.add_argument('--reset', action='store_true',
                    help='reset the training')
parser.add_argument('--test_every', type=int, default=1000,
                    help='do test per every N batches')
parser.add_argument('--epochs', type=int, default=1000,
                    help='number of epochs to train')
parser.add_argument('--batch_size', type=int, default=16,
                    help='input batch size for training')
parser.add_argument('--split_batch', type=int, default=1,
                    help='split the batch into smaller chunks')
parser.add_argument('--self_ensemble', action='store_true',
                    help='use self-ensemble method for test')
parser.add_argument('--test_only', action='store_true',
                    help='set this option to test the model')
parser.add_argument('--gan_k', type=int, default=1,
                    help='k value for adversarial loss')

# Optimization specifications
parser.add_argument('--lr', type=float, default=1e-4,
                    help='learning rate')
parser.add_argument('--lr_decay', type=int, default=200,
                    help='learning rate decay per N epochs')
parser.add_argument('--decay_type', type=str, default='step',
                    help='learning rate decay type')
parser.add_argument('--gamma', type=float, default=0.5,
                    help='learning rate decay factor for step decay')
parser.add_argument('--optimizer', default='ADAM',
                    choices=('SGD', 'ADAM', 'RMSprop'),
                    help='optimizer to use (SGD | ADAM | RMSprop)')
parser.add_argument('--momentum', type=float, default=0.9,
                    help='SGD momentum')
parser.add_argument('--beta1', type=float, default=0.9,
                    help='ADAM beta1')
parser.add_argument('--beta2', type=float, default=0.999,
                    help='ADAM beta2')
parser.add_argument('--epsilon', type=float, default=1e-8,
                    help='ADAM epsilon for numerical stability')
parser.add_argument('--weight_decay', type=float, default=0,
                    help='weight decay')

# Loss specifications
parser.add_argument('--loss', type=str, default='1*L1',
                    help='loss function configuration')
parser.add_argument('--skip_threshold', type=float, default='1e6',
                    help='skipping batch that has large error')

# Log specifications
parser.add_argument('--save', type=str, default='test',
                    help='file name to save')
parser.add_argument('--load', type=str, default='.',
                    help='file name to load')
parser.add_argument('--resume', type=int, default=0,
                    help='resume from specific checkpoint')
parser.add_argument('--print_model', action='store_true',
                    help='print model')
parser.add_argument('--save_models', action='store_true',
                    help='save all intermediate models')
parser.add_argument('--print_every', type=int, default=100,
                    help='how many batches to wait before logging training status')
parser.add_argument('--save_results', action='store_true',
                    help='save output results')

# options for residual group and feature channel reduction
parser.add_argument('--n_resgroups', type=int, default=10,
                    help='number of residual groups')
parser.add_argument('--reduction', type=int, default=16,
                    help='number of feature maps reduction')
# options for test
parser.add_argument('--testpath', type=str, default='../test/DIV2K_val_LR_our',
                    help='dataset directory for testing')
parser.add_argument('--testset', type=str, default='Set5',
                    help='dataset name for testing')

args = parser.parse_args()
# Templates run before the scale string is parsed, because a template may
# overwrite args.scale with a string.
template.set_template(args)

# '--scale 2+3+4' -> [2, 3, 4]
args.scale = [int(x) for x in args.scale.split('+')]

# --epochs 0 means "practically unlimited". Keep the value an int to match
# the declared type=int (1e8 was a float).
if args.epochs == 0:
    args.epochs = int(1e8)

# Options passed as the strings 'True' / 'False' (e.g. --shift_mean, which
# has no type) become real booleans.
for arg in vars(args):
    if vars(args)[arg] == 'True':
        vars(args)[arg] = True
    elif vars(args)[arg] == 'False':
        vars(args)[arg] = False
template.py
def set_template(args):
    """Apply preset overrides to ``args`` for each template keyword found.

    ``args.template`` is searched for each keyword below, in this order.
    Every matching preset overwrites its attributes, so when several match
    the later preset wins for shared attributes.
    """
    presets = [
        ('jpeg', (
            ('data_train', 'DIV2K_jpeg'),
            ('data_test', 'DIV2K_jpeg'),
            ('epochs', 200),
            ('lr_decay', 100),
        )),
        ('EDSR_paper', (
            ('model', 'EDSR'),
            ('n_resblocks', 32),
            ('n_feats', 256),
            ('res_scale', 0.1),
        )),
        ('MDSR', (
            ('model', 'MDSR'),
            ('patch_size', 48),
            ('epochs', 650),
        )),
        ('DDBPN', (
            ('model', 'DDBPN'),
            ('patch_size', 128),
            ('scale', '4'),
            ('data_test', 'Set5'),
            ('batch_size', 20),
            ('epochs', 1000),
            ('lr_decay', 500),
            ('gamma', 0.1),
            ('weight_decay', 1e-4),
            ('loss', '1*MSE'),
        )),
        ('GAN', (
            ('epochs', 200),
            ('lr', 5e-5),
            ('lr_decay', 150),
        )),
    ]

    for keyword, overrides in presets:
        if keyword not in args.template:
            continue
        for attr, value in overrides:
            setattr(args, attr, value)
utility.py
import os
import math
import time
import datetime
import shutil
from functools import reduce

import matplotlib
matplotlib.use('Agg')  # non-interactive backend: figures are only written to files
import matplotlib.pyplot as plt
import numpy as np

try:
    # No longer used by this module (``scipy.misc.imsave`` was removed in
    # SciPy 1.2 and ``scipy.misc`` is deprecated); still exposed as
    # ``utility.misc`` where available for backward compatibility.
    import scipy.misc as misc
except ImportError:
    misc = None

import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lrs


class timer():
    """Wall-clock stopwatch with an accumulator.

    ``tic``/``toc`` measure one interval; ``hold`` adds the current interval
    to ``acc``; ``release`` returns the accumulated seconds and clears them.
    """

    def __init__(self):
        self.acc = 0
        self.tic()

    def tic(self):
        """Start (or restart) the current interval."""
        self.t0 = time.time()

    def toc(self):
        """Return seconds elapsed since the last ``tic``."""
        return time.time() - self.t0

    def hold(self):
        """Add the current interval to the accumulator."""
        self.acc += self.toc()

    def release(self):
        """Return the accumulated seconds and reset the accumulator."""
        ret = self.acc
        self.acc = 0
        return ret

    def reset(self):
        """Clear the accumulator without reading it."""
        self.acc = 0


class checkpoint():
    """Manage an experiment directory: text log, config dump, PSNR history, results.

    The directory is ``../experiment/<args.save>`` for a new run (a timestamp
    when ``args.save == '.'``) or ``../experiment/<args.load>`` when resuming.
    ``self.log`` holds the PSNR history, one row per evaluation and one column
    per scale.
    """

    def __init__(self, args):
        self.args = args
        self.ok = True
        self.log = torch.Tensor()
        now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')

        if args.load == '.':
            if args.save == '.':
                args.save = now
            self.dir = '../experiment/' + args.save
        else:
            self.dir = '../experiment/' + args.load
            if not os.path.exists(self.dir):
                # Nothing to resume from: behave like a fresh run.
                args.load = '.'
            else:
                self.log = torch.load(self.dir + '/psnr_log.pt')
                print('Continue from epoch {}...'.format(len(self.log)))

        if args.reset:
            # Remove the directory without handing a user-supplied path to a
            # shell, and drop any PSNR history loaded from it above: that
            # history no longer matches anything on disk.
            shutil.rmtree(self.dir, ignore_errors=True)
            self.log = torch.Tensor()
            args.load = '.'

        os.makedirs(self.dir, exist_ok=True)
        os.makedirs(self.dir + '/model', exist_ok=True)
        os.makedirs(self.dir + '/results', exist_ok=True)

        open_type = 'a' if os.path.exists(self.dir + '/log.txt') else 'w'
        self.log_file = open(self.dir + '/log.txt', open_type)
        with open(self.dir + '/config.txt', open_type) as f:
            f.write(now + '\n\n')
            for arg in vars(args):
                f.write('{}: {}\n'.format(arg, getattr(args, arg)))
            f.write('\n')

    def save(self, trainer, epoch, is_best=False):
        """Persist model, loss history/plot, PSNR log/plot and optimizer state."""
        trainer.model.save(self.dir, epoch, is_best=is_best)
        trainer.loss.save(self.dir)
        trainer.loss.plot_loss(self.dir, epoch)

        self.plot_psnr(epoch)
        torch.save(self.log, os.path.join(self.dir, 'psnr_log.pt'))
        torch.save(
            trainer.optimizer.state_dict(),
            os.path.join(self.dir, 'optimizer.pt')
        )

    def add_log(self, log):
        """Append rows to the PSNR history tensor."""
        self.log = torch.cat([self.log, log])

    def write_log(self, log, refresh=False):
        """Print ``log`` and append it to log.txt; ``refresh`` closes and reopens the file."""
        print(log)
        self.log_file.write(log + '\n')
        if refresh:
            self.log_file.close()
            self.log_file = open(self.dir + '/log.txt', 'a')

    def done(self):
        """Close the text log."""
        self.log_file.close()

    def plot_psnr(self, epoch):
        """Plot PSNR per scale for epochs 1..``epoch`` into test_<data_test>.pdf.

        ``self.log`` must have exactly ``epoch`` rows, otherwise matplotlib
        rejects the mismatched x/y lengths.
        """
        axis = np.linspace(1, epoch, epoch)
        label = 'SR on {}'.format(self.args.data_test)
        fig = plt.figure()
        plt.title(label)
        for idx_scale, scale in enumerate(self.args.scale):
            plt.plot(
                axis,
                self.log[:, idx_scale].numpy(),
                label='Scale {}'.format(scale)
            )
        plt.legend()
        plt.xlabel('Epochs')
        plt.ylabel('PSNR')
        plt.grid(True)
        plt.savefig('{}/test_{}.pdf'.format(self.dir, self.args.data_test))
        plt.close(fig)

    def save_results(self, filename, save_list, scale):
        """Write the first image of each tensor in ``save_list`` as PNG.

        ``save_list`` is ordered (SR, LR, HR) and files are named
        ``<filename>_x<scale>_<SR|LR|HR>.png`` under ``results/``. Each tensor is
        indexed as ``v[0]`` and permuted (1, 2, 0), so it is assumed to be
        (N, C, H, W) with values in [0, rgb_range].
        """
        filename = '{}/results/{}_x{}_'.format(self.dir, filename, scale)
        postfix = ('SR', 'LR', 'HR')
        for v, p in zip(save_list, postfix):
            normalized = v[0].data.mul(255 / self.args.rgb_range)
            # Clamp and round before the uint8 cast: byte() truncates, so values
            # like 254.9999 from float round-trips would otherwise lose a level.
            ndarr = (
                normalized.clamp(0, 255).round().byte()
                .permute(1, 2, 0).cpu().numpy()
            )
            path = '{}{}.png'.format(filename, p)
            if ndarr.shape[2] == 1:
                # Single-channel: give imsave a 2-D array and a fixed 0..255 range
                # so the gray levels are not rescaled to the image's min/max.
                plt.imsave(path, ndarr[:, :, 0], cmap='gray', vmin=0, vmax=255)
            else:
                plt.imsave(path, ndarr)


def quantize(img, rgb_range):
    """Snap ``img`` (values in [0, rgb_range]) to the 256 levels of an 8-bit image."""
    pixel_range = 255 / rgb_range
    return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)


def calc_psnr(sr, hr, scale, rgb_range, benchmark=False):
    """Return the PSNR in dB between ``sr`` and ``hr``.

    Multi-channel inputs are compared on the luma (Y) channel only. ``scale``
    border pixels are shaved from each side before averaging.

    Args:
        sr, hr: tensors of identical shape, indexed as (N, C, H, W).
        scale: super-resolution factor, also used as the border shave width.
        rgb_range: maximum pixel value of the inputs.
        benchmark: accepted for API compatibility and currently ignored; the
            same Y-channel / ``shave = scale`` protocol is used for all data.

    Returns:
        PSNR as a float, or ``math.inf`` when the compared regions are identical.
    """
    diff = (sr - hr).data.div(rgb_range)
    shave = scale
    if diff.size(1) > 1:
        # RGB -> Y (ITU-R BT.601) weights scaled by 256; the +16 offset of the
        # Y formula cancels in a difference, so only the weights are needed.
        convert = diff.new(1, 3, 1, 1)
        convert[0, 0, 0, 0] = 65.738
        convert[0, 1, 0, 0] = 129.057
        convert[0, 2, 0, 0] = 25.064
        diff.mul_(convert).div_(256)
        diff = diff.sum(dim=1, keepdim=True)

    valid = diff[:, :, shave:-shave, shave:-shave]
    mse = valid.pow(2).mean().item()
    if mse == 0:
        # log10(0) is a domain error; identical images have infinite PSNR.
        return math.inf
    return -10 * math.log10(mse)


def make_optimizer(args, my_model):
    """Build the optimizer named by ``args.optimizer`` over trainable parameters.

    Raises:
        ValueError: if ``args.optimizer`` is not 'SGD', 'ADAM' or 'RMSprop'.
    """
    trainable = filter(lambda x: x.requires_grad, my_model.parameters())

    if args.optimizer == 'SGD':
        optimizer_function = optim.SGD
        kwargs = {'momentum': args.momentum}
    elif args.optimizer == 'ADAM':
        optimizer_function = optim.Adam
        kwargs = {
            'betas': (args.beta1, args.beta2),
            'eps': args.epsilon
        }
    elif args.optimizer == 'RMSprop':
        optimizer_function = optim.RMSprop
        kwargs = {'eps': args.epsilon}
    else:
        raise ValueError('Unknown optimizer: {}'.format(args.optimizer))

    kwargs['lr'] = args.lr
    kwargs['weight_decay'] = args.weight_decay

    return optimizer_function(trainable, **kwargs)


def make_scheduler(args, my_optimizer):
    """Build the LR scheduler described by ``args.decay_type``.

    'step' gives StepLR(step_size=args.lr_decay); any other value containing
    'step', e.g. 'step_100_200', gives MultiStepLR with the '_'-separated
    integers after the first token as milestones.

    Raises:
        ValueError: if ``args.decay_type`` does not contain 'step'.
    """
    if args.decay_type == 'step':
        scheduler = lrs.StepLR(
            my_optimizer,
            step_size=args.lr_decay,
            gamma=args.gamma
        )
    elif args.decay_type.find('step') >= 0:
        milestones = args.decay_type.split('_')
        milestones.pop(0)
        milestones = list(map(int, milestones))
        scheduler = lrs.MultiStepLR(
            my_optimizer,
            milestones=milestones,
            gamma=args.gamma
        )
    else:
        raise ValueError('Unknown decay type: {}'.format(args.decay_type))

    return scheduler
trainer.py
import os
import math
from decimal import Decimal

import utility

import torch
from torch.autograd import Variable
from tqdm import tqdm


class Trainer():
    """Drive RCAN training and evaluation epoch by epoch.

    ``self.epoch`` counts completed training epochs. It is tracked explicitly
    instead of being derived from ``scheduler.last_epoch``, whose value right
    after construction differs between PyTorch versions (-1 before 1.1, 0
    from 1.1 on). The LR scheduler is stepped once at the end of every
    training epoch, after that epoch's ``optimizer.step()`` calls, which is
    the order PyTorch >= 1.1 requires.
    """

    def __init__(self, args, loader, my_model, my_loss, ckp):
        self.args = args
        self.scale = args.scale

        self.ckp = ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        self.model = my_model
        self.loss = my_loss
        self.optimizer = utility.make_optimizer(args, self.model)
        self.scheduler = utility.make_scheduler(args, self.optimizer)
        # Number of training epochs already completed.
        self.epoch = 0

        if self.args.load != '.':
            # Replay the LR schedule for the epochs recorded in the PSNR log,
            # then restore the saved optimizer state so the saved LR wins.
            # With chainable schedulers (PyTorch >= 1.4) replaying *after*
            # loading would decay an LR that is already decayed.
            for _ in range(len(ckp.log)):
                self.scheduler.step()
            self.optimizer.load_state_dict(
                torch.load(os.path.join(ckp.dir, 'optimizer.pt'))
            )
            self.epoch = len(ckp.log)

        self.error_last = 1e8

    def train(self):
        """Run one training epoch, then advance the LR scheduler by one step."""
        self.loss.step()
        epoch = self.epoch + 1
        # The optimizer's current LR is the one used for this epoch; this
        # avoids the deprecated / version-dependent ``scheduler.get_lr()``.
        learning_rate = self.optimizer.param_groups[0]['lr']

        self.ckp.write_log(
            '[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(learning_rate))
        )
        self.loss.start_log()
        self.model.train()

        timer_data, timer_model = utility.timer(), utility.timer()
        for batch, (lr, hr, _, idx_scale) in enumerate(self.loader_train):
            lr, hr = self.prepare([lr, hr])
            timer_data.hold()
            timer_model.tic()

            self.optimizer.zero_grad()
            sr = self.model(lr, idx_scale)
            loss = self.loss(sr, hr)
            # Skip batches whose loss is far above error_last (the last entry of
            # the previous epoch's loss log), scaled by --skip_threshold.
            if loss.item() < self.args.skip_threshold * self.error_last:
                loss.backward()
                self.optimizer.step()
            else:
                print('Skip this batch {}! (Loss: {})'.format(
                    batch + 1, loss.item()
                ))

            timer_model.hold()

            if (batch + 1) % self.args.print_every == 0:
                self.ckp.write_log('[{}/{}]\t{}\t{:.1f}+{:.1f}s'.format(
                    (batch + 1) * self.args.batch_size,
                    len(self.loader_train.dataset),
                    self.loss.display_loss(batch),
                    timer_model.release(),
                    timer_data.release()))

            timer_data.tic()

        self.loss.end_log(len(self.loader_train))
        self.error_last = self.loss.log[-1, -1]
        # Advance the schedule only after this epoch's optimizer steps.
        self.scheduler.step()
        self.epoch = epoch

    def test(self):
        """Evaluate every scale on the test loader and log the PSNR.

        Appends one row to ``ckp.log``. Unless ``--test_only`` is set, it then
        saves a checkpoint for ``self.epoch``, flagged ``is_best`` when that row
        holds the best PSNR at the first scale.
        """
        epoch = self.epoch
        self.ckp.write_log('\nEvaluation:')
        self.ckp.add_log(torch.zeros(1, len(self.scale)))
        self.model.eval()

        timer_test = utility.timer()
        with torch.no_grad():
            for idx_scale, scale in enumerate(self.scale):
                eval_acc = 0
                self.loader_test.dataset.set_scale(idx_scale)
                tqdm_test = tqdm(self.loader_test, ncols=80)
                for lr, hr, filename, _ in tqdm_test:
                    filename = filename[0]
                    # An HR tensor with a single element is treated as "no
                    # ground truth": PSNR is skipped and only SR is saved.
                    no_eval = (hr.nelement() == 1)
                    if not no_eval:
                        lr, hr = self.prepare([lr, hr])
                    else:
                        lr = self.prepare([lr])[0]

                    sr = self.model(lr, idx_scale)
                    sr = utility.quantize(sr, self.args.rgb_range)

                    save_list = [sr]
                    if not no_eval:
                        eval_acc += utility.calc_psnr(
                            sr, hr, scale, self.args.rgb_range,
                            benchmark=self.loader_test.dataset.benchmark
                        )
                        save_list.extend([lr, hr])

                    if self.args.save_results:
                        self.ckp.save_results(filename, save_list, scale)

                self.ckp.log[-1, idx_scale] = eval_acc / len(self.loader_test)
                best = self.ckp.log.max(0)
                self.ckp.write_log(
                    '[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} @epoch {})'.format(
                        self.args.data_test,
                        scale,
                        self.ckp.log[-1, idx_scale],
                        best[0][idx_scale],
                        # .item() so the log shows "3", not "tensor(3)".
                        best[1][idx_scale].item() + 1
                    )
                )

        self.ckp.write_log(
            'Total time: {:.2f}s\n'.format(timer_test.toc()), refresh=True
        )
        if not self.args.test_only:
            self.ckp.save(self, epoch, is_best=(best[1][0].item() + 1 == epoch))

    def prepare(self, l, volatile=False):
        """Move every tensor in ``l`` to the target device, casting to fp16 if requested.

        ``volatile`` is unused and kept only for backward compatibility.
        """
        device = torch.device('cpu' if self.args.cpu else 'cuda')

        def _prepare(tensor):
            if self.args.precision == 'half':
                tensor = tensor.half()
            return tensor.to(device)

        return [_prepare(_l) for _l in l]

    def terminate(self):
        """Return True when training should stop.

        In ``--test_only`` mode this runs one evaluation first.
        """
        if self.args.test_only:
            self.test()
            return True
        return self.epoch >= self.args.epochs
三.测试网络
利用模型将图片四倍放大的结果如下:
输入图片:
输出图片:
输入图片:
输出图片:
这篇关于超分辨率(3)--基于RCAN网络实现图像超分辨率重建的文章就介绍到这儿,希望我们推荐的文章对程序员们有所帮助!