【DA-CLIP】test.py解读,调用DA-CLIP和IRSDE模型复原计算复原图与GT图SSIM、PSNR、LPIPS

2024-04-05 02:12

本文主要是介绍【DA-CLIP】test.py解读,调用DA-CLIP和IRSDE模型复原计算复原图与GT图SSIM、PSNR、LPIPS,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文件路径daclip-uir-main/universal-image-restoration/config/daclip-sde/test.py

代码有部分修改

导包

import argparse
import logging
import os.path
import sys
import time
from collections import OrderedDict
import torchvision.utils as tvutilsimport numpy as np
import torch
from IPython import embed
import lpipsimport options as option
from models import create_modelsys.path.insert(0, "../../")
import open_clip
import utils as util
from data import create_dataloader, create_dataset
from data.util import bgr2ycbcr

注意open_clip使用的是项目里的代码,而非环境里装的那个。data、util、option同样是项目里有的包

声明

#### options
parser = argparse.ArgumentParser()
parser.add_argument("-opt", type=str, default='options/test.yml', help="Path to options YMAL file.")
opt = option.parse(parser.parse_args().opt, is_train=False)opt = option.dict_to_nonedict(opt)

配置文件 

设置配置文件相对地址options/test.yml

在该配置文件中配置GT和LQ图像文件地址

datasets:test1:name: Testmode: LQGTdataroot_GT: C:\Users\86136\Desktop\LQ_test\shadow\GTdataroot_LQ: C:\Users\86136\Desktop\LQ_test\shadow\LQ

设置results_root结果地址,每次计算结束这个地址保存要求记录的计算结果

该目录下Test文件夹将保存一张GT一张LQ一张复原图像  。

不设置也会默认在项目内 daclip-uir-main\results\daclip-sde\universal-ir

#### path
path:pretrain_model_G: E:\daclip\pretrained\universal-ir.pthdaclip: E:\daclip\pretrained\daclip_ViT-B-32.ptresults_root: C:\Users\86136\Desktop\daclip-uir-main\results\daclip-sde\universal-irlog: 

 

#### mkdir and logger
util.mkdirs((pathfor key, path in opt["path"].items()if not key == "experiments_root"and "pretrain_model" not in keyand "resume" not in key)
)# os.system("rm ./result")
# os.symlink(os.path.join(opt["path"]["results_root"], ".."), "./result")

 报错执行代码没有删除再创建权限?我把相关os操作注释了,全部保存到result对我影响不大

加载创建数据对

#### Create test dataset and dataloader
test_loaders = []
for phase, dataset_opt in sorted(opt["datasets"].items()):test_set = create_dataset(dataset_opt)test_loader = create_dataloader(test_set, dataset_opt)logger.info("Number of test images in [{:s}]: {:d}".format(dataset_opt["name"], len(test_set)))test_loaders.append(test_loader)

 自定义包含复原IR-SDE模型的外层类model,参考app.py

# load pretrained model by default
model = create_model(opt)
device = model.device

 加载DA-CLIP、IR-SDE

# clip_model, _preprocess = clip.load("ViT-B/32", device=device)
if opt['path']['daclip'] is not None:clip_model, preprocess = open_clip.create_model_from_pretrained('daclip_ViT-B-32', pretrained=opt['path']['daclip'])
else:clip_model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
tokenizer = open_clip.get_tokenizer('ViT-B-32')
clip_model = clip_model.to(device)

else是直接使用CLIP的ViT-B-32模型进行测试的代码。与我测DA-CLIP无关。

想使用的话 目测要预先下载对应模型权重并手动修改pretrained为文件地址,否则报错hf无法连接

sde = util.IRSDE(max_sigma=opt["sde"]["max_sigma"], T=opt["sde"]["T"], schedule=opt["sde"]["schedule"], eps=opt["sde"]["eps"], device=device)
sde.set_model(model.model)
lpips_fn = lpips.LPIPS(net='alex').to(device)scale = opt['degradation']['scale']

加载IR-SDE、LPIPS

如果不指定crop_border后续crop_border=scale

处理并计算


for test_loader in test_loaders:test_set_name = test_loader.dataset.opt["name"]  # path opt['']logger.info("\nTesting [{:s}]...".format(test_set_name))test_start_time = time.time()dataset_dir = os.path.join(opt["path"]["results_root"], test_set_name)util.mkdir(dataset_dir)test_results = OrderedDict()test_results["psnr"] = []test_results["ssim"] = []test_results["psnr_y"] = []test_results["ssim_y"] = []test_results["lpips"] = []test_times = []for i, test_data in enumerate(test_loader):single_img_psnr = []single_img_ssim = []single_img_psnr_y = []single_img_ssim_y = []need_GT = False if test_loader.dataset.opt["dataroot_GT"] is None else Trueimg_path = test_data["GT_path"][0] if need_GT else test_data["LQ_path"][0]img_name = os.path.splitext(os.path.basename(img_path))[0]#### input dataset_LQLQ, GT = test_data["LQ"], test_data["GT"]img4clip = test_data["LQ_clip"].to(device)with torch.no_grad(), torch.cuda.amp.autocast():image_context, degra_context = clip_model.encode_image(img4clip, control=True)image_context = image_context.float()degra_context = degra_context.float()noisy_state = sde.noise_state(LQ)model.feed_data(noisy_state, LQ, GT, text_context=degra_context, image_context=image_context)tic = time.time()model.test(sde, save_states=False)toc = time.time()test_times.append(toc - tic)visuals = model.get_current_visuals()SR_img = visuals["Output"]output = util.tensor2img(SR_img.squeeze())  # uint8LQ_ = util.tensor2img(visuals["Input"].squeeze())  # uint8GT_ = util.tensor2img(visuals["GT"].squeeze())  # uint8suffix = opt["suffix"]if suffix:save_img_path = os.path.join(dataset_dir, img_name + suffix + ".png")else:save_img_path = os.path.join(dataset_dir, img_name + ".png")util.save_img(output, save_img_path)# remove it if you only want to save output imagesLQ_img_path = os.path.join(dataset_dir, img_name + "_LQ.png")GT_img_path = os.path.join(dataset_dir, img_name + "_HQ.png")util.save_img(LQ_, LQ_img_path)util.save_img(GT_, GT_img_path)if need_GT:gt_img = GT_ / 255.0sr_img = output / 255.0crop_border = opt["crop_border"] if opt["crop_border"] else scaleif crop_border == 0:cropped_sr_img = sr_imgcropped_gt_img = gt_imgelse:cropped_sr_img = sr_img[crop_border:-crop_border, crop_border:-crop_border]cropped_gt_img = gt_img[crop_border:-crop_border, crop_border:-crop_border]psnr = util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255)ssim = util.calculate_ssim(cropped_sr_img * 255, cropped_gt_img * 255)lp_score = lpips_fn(GT.to(device) * 2 - 1, SR_img.to(device) * 2 - 1).squeeze().item()test_results["psnr"].append(psnr)test_results["ssim"].append(ssim)test_results["lpips"].append(lp_score)if len(gt_img.shape) == 3:if gt_img.shape[2] == 3:  # RGB imagesr_img_y = bgr2ycbcr(sr_img, only_y=True)gt_img_y = bgr2ycbcr(gt_img, only_y=True)if crop_border == 0:cropped_sr_img_y = sr_img_ycropped_gt_img_y = gt_img_yelse:cropped_sr_img_y = sr_img_y[crop_border:-crop_border, crop_border:-crop_border]cropped_gt_img_y = gt_img_y[crop_border:-crop_border, crop_border:-crop_border]psnr_y = util.calculate_psnr(cropped_sr_img_y * 255, cropped_gt_img_y * 255)ssim_y = util.calculate_ssim(cropped_sr_img_y * 255, cropped_gt_img_y * 255)test_results["psnr_y"].append(psnr_y)test_results["ssim_y"].append(ssim_y)logger.info("img{:3d}:{:15s} - PSNR: {:.6f} dB; SSIM: {:.6f}; LPIPS: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.".format(i, img_name, psnr, ssim, lp_score, psnr_y, ssim_y))else:logger.info("img:{:15s} - PSNR: {:.6f} dB; SSIM: {:.6f}.".format(img_name, psnr, ssim))test_results["psnr_y"].append(psnr)test_results["ssim_y"].append(ssim)else:logger.info(img_name)ave_lpips = sum(test_results["lpips"]) / len(test_results["lpips"])ave_psnr = sum(test_results["psnr"]) / len(test_results["psnr"])ave_ssim = sum(test_results["ssim"]) / len(test_results["ssim"])logger.info("----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n".format(test_set_name, ave_psnr, ave_ssim))if test_results["psnr_y"] and test_results["ssim_y"]:ave_psnr_y = sum(test_results["psnr_y"]) / len(test_results["psnr_y"])ave_ssim_y = sum(test_results["ssim_y"]) / len(test_results["ssim_y"])logger.info("----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n".format(ave_psnr_y, ave_ssim_y))logger.info("----average LPIPS\t: {:.6f}\n".format(ave_lpips))print(f"average test time: {np.mean(test_times):.4f}")

开头往log记录了相应配置文件内容,不需要可以注释。

遍历测试数据集(test_loaders)计算各种评价指标,如峰值信噪比(PSNR)、结构相似性(SSIM)和感知损失(LPIPS)。

在处理过程中,代码首先会创建一个目录来保存测试结果。

然后,对于每个测试图像,代码会加载对应的图像(如果可用),并使用一个名为clip_model的模型对图像进行编码。

接下来,代码会使用一个名为sde的随机微分方程模型和名为model的深度学习模型来处理带有噪声的图像,并生成复原图像(SR_img)。额可能作者拿了以前做超分的代码没改变量名

在这个过程中,text_contextimage_context被用作模型的输入,

图像都会被保存到之前创建的目录中。

此外,代码还会计算并记录每个图像的PSNR、SSIM和LPIPS分数,并在最后打印出这些分数的平均值。 代码中还包含了一些用于图像处理的实用函数,如util.tensor2img用于将张量转换为图像,util.save_img用于保存图像,以及util.calculate_psnrutil.calculate_ssim用于计算PSNR和SSIM分数。psnr_y和ssim_y 不用可以把相关代码注释。

最后,代码还计算了平均测试时间,并将其打印出来。

结果

log处理的单张图像报错的信息 0是该处理的图像排序序号,即正在处理第0张图

24-04-03 17:28:24.697 - INFO: img  0:_MG_2374_no_shadow - PSNR: 27.779773 dB; SSIM: 0.863140; LPIPS: 0.078669; PSNR_Y: 29.135256 dB; SSIM_Y: 0.869278.

 

可以给复原结果图加个后缀方便区分。

这篇关于【DA-CLIP】test.py解读,调用DA-CLIP和IRSDE模型复原计算复原图与GT图SSIM、PSNR、LPIPS的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/877402

相关文章

Linux jq命令的使用解读

《Linuxjq命令的使用解读》jq是一个强大的命令行工具,用于处理JSON数据,它可以用来查看、过滤、修改、格式化JSON数据,通过使用各种选项和过滤器,可以实现复杂的JSON处理任务... 目录一. 简介二. 选项2.1.2.2-c2.3-r2.4-R三. 字段提取3.1 普通字段3.2 数组字段四.

MySQL之搜索引擎使用解读

《MySQL之搜索引擎使用解读》MySQL存储引擎是数据存储和管理的核心组件,不同引擎(如InnoDB、MyISAM)采用不同机制,InnoDB支持事务与行锁,适合高并发场景;MyISAM不支持事务,... 目录mysql的存储引擎是什么MySQL存储引擎的功能MySQL的存储引擎的分类查看存储引擎1.命令

Spring的基础事务注解@Transactional作用解读

《Spring的基础事务注解@Transactional作用解读》文章介绍了Spring框架中的事务管理,核心注解@Transactional用于声明事务,支持传播机制、隔离级别等配置,结合@Tran... 目录一、事务管理基础1.1 Spring事务的核心注解1.2 注解属性详解1.3 实现原理二、事务事

MyBatis/MyBatis-Plus同事务循环调用存储过程获取主键重复问题分析及解决

《MyBatis/MyBatis-Plus同事务循环调用存储过程获取主键重复问题分析及解决》MyBatis默认开启一级缓存,同一事务中循环调用查询方法时会重复使用缓存数据,导致获取的序列主键值均为1,... 目录问题原因解决办法如果是存储过程总结问题myBATis有如下代码获取序列作为主键IdMappe

Linux五种IO模型的使用解读

《Linux五种IO模型的使用解读》文章系统解析了Linux的五种IO模型(阻塞、非阻塞、IO复用、信号驱动、异步),重点区分同步与异步IO的本质差异,强调同步由用户发起,异步由内核触发,通过对比各模... 目录1.IO模型简介2.五种IO模型2.1 IO模型分析方法2.2 阻塞IO2.3 非阻塞IO2.4

使用Go调用第三方API的方法详解

《使用Go调用第三方API的方法详解》在现代应用开发中,调用第三方API是非常常见的场景,比如获取天气预报、翻译文本、发送短信等,Go作为一门高效并发的编程语言,拥有强大的标准库和丰富的第三方库,可以... 目录引言一、准备工作二、案例1:调用天气查询 API1. 注册并获取 API Key2. 代码实现3

MySQL8.0临时表空间的使用及解读

《MySQL8.0临时表空间的使用及解读》MySQL8.0+引入会话级(temp_N.ibt)和全局(ibtmp1)InnoDB临时表空间,用于存储临时数据及事务日志,自动创建与回收,重启释放,管理高... 目录一、核心概念:为什么需要“临时表空间”?二、InnoDB 临时表空间的两种类型1. 会话级临时表

Python实现精确小数计算的完全指南

《Python实现精确小数计算的完全指南》在金融计算、科学实验和工程领域,浮点数精度问题一直是开发者面临的重大挑战,本文将深入解析Python精确小数计算技术体系,感兴趣的小伙伴可以了解一下... 目录引言:小数精度问题的核心挑战一、浮点数精度问题分析1.1 浮点数精度陷阱1.2 浮点数误差来源二、基础解决

Python文本相似度计算的方法大全

《Python文本相似度计算的方法大全》文本相似度是指两个文本在内容、结构或语义上的相近程度,通常用0到1之间的数值表示,0表示完全不同,1表示完全相同,本文将深入解析多种文本相似度计算方法,帮助您选... 目录前言什么是文本相似度?1. Levenshtein 距离(编辑距离)核心公式实现示例2. Jac

Java调用Python脚本实现HelloWorld的示例详解

《Java调用Python脚本实现HelloWorld的示例详解》作为程序员,我们经常会遇到需要在Java项目中调用Python脚本的场景,下面我们来看看如何从基础到进阶,一步步实现Java与Pyth... 目录一、环境准备二、基础调用:使用 Runtime.exec()2.1 实现步骤2.2 代码解析三、