计算psnr ssim niqe fid mae lpips等指标的代码

2024-04-10 21:28

本文主要是介绍计算psnr ssim niqe fid mae lpips等指标的代码,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

  • 以下代码仅供参考,路径处理最好自己改一下
# Author: Wu
# Created: 2023/8/15
# module containing metrics functions
# using package in https://github.com/chaofengc/IQA-PyTorch
import torch
from PIL import Image
import numpy as np
from piqa import PSNR, SSIM
import pyiqa
import argparse
import os
from collections import defaultdict
first = True
first2 = True
lpips_metric = None
niqe_metric = None
config = None
def read_img(img_path, ref_image=None):img = Image.open(img_path).convert('RGB')# resize gt to size of inputif ref_image is not None: w,h = img.size_,_, h_ref, w_ref = ref_image.shapeif w_ref!=w or h_ref!=h:img = img.resize((w_ref, h_ref), Image.ANTIALIAS)img = (np.asarray(img)/255.0)img = torch.from_numpy(img).float()img = img.permute(2,0,1)img = img.to(torch.device(f'cuda:{config.device}')).unsqueeze(0)return img.contiguous()def get_NIQE(enhanced_image, gt_path=None):niqe_metric = pyiqa.create_metric('niqe', device=enhanced_image.device).to(torch.device(f'cuda:{config.device}'))return  niqe_metric(enhanced_image)
def get_FID(enhanced_image_path, gt_path):fid_metric = pyiqa.create_metric('fid').to(torch.device(f'cuda:{config.device}'))score = fid_metric(enhanced_image_path, gt_path)return score
def get_psnr(enhanced_image, gt_path):gtimg = Image.open(gt_path).convert('RGB')gtimg = gtimg.resize((1200, 900), Image.ANTIALIAS)gtimg = (np.asarray(gtimg)/255.0)gtimg = torch.from_numpy(gtimg).float()gtimg = gtimg.permute(2,0,1)gtimg = gtimg.to(torch.device(f'cuda:{config.device}')).unsqueeze(0).contiguous()criterion = PSNR().to(torch.device(f'cuda:{config.device}'))return criterion(enhanced_image, gtimg).cpu().item()
def get_ssim(enhanced_image, gt_path):gtimg = Image.open(gt_path).convert('RGB')gtimg = gtimg.resize((1200, 900), Image.ANTIALIAS)gtimg = (np.asarray(gtimg)/255.0)gtimg = torch.from_numpy(gtimg).float()gtimg = gtimg.permute(2,0,1)gtimg = gtimg.to(torch.device(f'cuda:{config.device}')).unsqueeze(0).contiguous()criterion = SSIM().to(torch.device(f'cuda:{config.device}'))return criterion(enhanced_image, gtimg).cpu().item()
def get_lpips(enhanced_image, gt_path):gtimg = Image.open(gt_path).convert('RGB')gtimg = gtimg.resize((1200, 900), Image.ANTIALIAS)gtimg = (np.asarray(gtimg)/255.0)gtimg = torch.from_numpy(gtimg).float()gtimg = gtimg.permute(2,0,1)gtimg = gtimg.to(torch.device(f'cuda:{config.device}')).unsqueeze(0).contiguous()iqa_metric = pyiqa.create_metric('lpips', device=enhanced_image.device)return iqa_metric(enhanced_image, gtimg).cpu().item()
def get_MAE(enhanced_image, gt_path):gtimg = Image.open(gt_path).convert('RGB')gtimg = gtimg.resize((1200, 900), Image.ANTIALIAS)gtimg = (np.asarray(gtimg)/255.0)gtimg = torch.from_numpy(gtimg).float()gtimg = gtimg.permute(2,0,1)gtimg = gtimg.to(torch.device(f'cuda:{config.device}')).unsqueeze(0).contiguous()return torch.mean(torch.abs(enhanced_image-gtimg)).cpu().item()def get_metric(enhanced_image, gt_path, metrics):if gt_path is not None:gtimg = read_img(gt_path, enhanced_image)else:gtimg = Noneres = dict()if 'psnr' in metrics:psnr = PSNR().to(torch.device(f'cuda:{config.device}'))res['psnr'] = psnr(enhanced_image, gtimg).cpu().item()if 'ssim' in metrics:ssim = SSIM().to(torch.device(f'cuda:{config.device}'))res['ssim'] = ssim(enhanced_image, gtimg).cpu().item()if 'mae' in metrics:res['mae'] = torch.mean(torch.abs(enhanced_image-gtimg)).cpu().item()if 'niqe' in metrics:global first2global niqe_metricif first2:first2 = Falseniqe_metric = pyiqa.create_metric('niqe', device=enhanced_image.device)res['niqe'] = niqe_metric(enhanced_image).cpu().item()if 'lpips' in metrics:global firstglobal lpips_metricif first:first = Falselpips_metric = pyiqa.create_metric('lpips', device=enhanced_image.device)res['lpips'] = lpips_metric(enhanced_image, gtimg).cpu().item()return resdef get_metrics_dataset(pred_path, gt_path, dataset='lol'):if dataset == 'fivek':input_file_path_list = []gt_file_path_list = []file_list = os.listdir(os.path.join(gt_path))for filename in file_list:input_file_path_list.append(os.path.join(pred_path, filename))gt_file_path_list.append(os.path.join(gt_path,  filename))elif dataset == 'lol':input_file_path_list = []gt_file_path_list = []file_list = os.listdir(os.path.join(gt_path))for filename in file_list:input_file_path_list.append(os.path.join(pred_path, filename.replace('normal', 'low')))gt_file_path_list.append(os.path.join(gt_path,  filename))elif dataset == 'EE':input_file_path_list = []gt_file_path_list = []file_list = os.listdir(os.path.join(pred_path))for filename in file_list:input_file_path_list.append(os.path.join(pred_path, filename))suffix = filename.split('_')[-1]new_filename = filename[:-len(suffix)-1]+'.jpg'gt_file_path_list.append(os.path.join(gt_path,  new_filename))elif dataset == 'upair':input_file_path_list = []gt_file_path_list = []file_list = os.listdir(os.path.join(pred_path))for filename in file_list:input_file_path_list.append(os.path.join(pred_path, filename))gt_file_path_list.append(None)else:print(f'{dataset} not supported')exit()return input_file_path_list, gt_file_path_listif __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('--gt', type=str, default="/data1/wjh/LOL_v2/Real_captured/eval/gt")parser.add_argument('--pred', type=str, default="/data1/wjh/ECNet/baseline/gt_referenced/output")parser.add_argument('--dataset', type=str, default="lol")parser.add_argument('--device', type=str, default="0")parser.add_argument('--psnr', action='store_true')parser.add_argument('--ssim', action='store_true')parser.add_argument('--fid', action='store_true')parser.add_argument('--niqe', action='store_true')parser.add_argument('--lpips', action='store_true')parser.add_argument('--mae', action='store_true')config = parser.parse_args()print(config)gt_path = config.gtpred_path = config.pred# os.environ['CUDA_VISIBLE_DEVICES']=config.deviceassert os.path.exists(gt_path), 'gt_path not exits'assert os.path.exists(pred_path), 'pred_path not exits'metrics_names = []for metrics_name in ['psnr', 'ssim', 'niqe', 'lpips', 'mae']:if vars(config)[metrics_name]:metrics_names.append(metrics_name)# compute metricsmetrics_dict = defaultdict(list)metrics = dict()with torch.no_grad():# load img pathinput_file_paths,  gt_file_paths = get_metrics_dataset(pred_path, gt_path, config.dataset)# read img and compute metricsfor input_file_path, gt_file_path in zip(input_file_paths, gt_file_paths):# print(input_file_path)pred = read_img(input_file_path)metrics = get_metric(pred, gt_file_path, metrics_names)for metrics_name in metrics:metrics_dict[metrics_name].append(metrics[metrics_name])for metrics_name in metrics:print(f'{metrics_name}: {np.mean(metrics_dict[metrics_name])}')if config.fid:fid_score = get_FID(pred_path, gt_path)print(F'fid: {fid_score}')

这篇关于计算psnr ssim niqe fid mae lpips等指标的代码的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/892186

相关文章

Python并行处理实战之如何使用ProcessPoolExecutor加速计算

《Python并行处理实战之如何使用ProcessPoolExecutor加速计算》Python提供了多种并行处理的方式,其中concurrent.futures模块的ProcessPoolExecu... 目录简介完整代码示例代码解释1. 导入必要的模块2. 定义处理函数3. 主函数4. 生成数字列表5.

Python实例题之pygame开发打飞机游戏实例代码

《Python实例题之pygame开发打飞机游戏实例代码》对于python的学习者,能够写出一个飞机大战的程序代码,是不是感觉到非常的开心,:本文主要介绍Python实例题之pygame开发打飞机... 目录题目pygame-aircraft-game使用 Pygame 开发的打飞机游戏脚本代码解释初始化部

Java中Map.Entry()含义及方法使用代码

《Java中Map.Entry()含义及方法使用代码》:本文主要介绍Java中Map.Entry()含义及方法使用的相关资料,Map.Entry是Java中Map的静态内部接口,用于表示键值对,其... 目录前言 Map.Entry作用核心方法常见使用场景1. 遍历 Map 的所有键值对2. 直接修改 Ma

深入解析 Java Future 类及代码示例

《深入解析JavaFuture类及代码示例》JavaFuture是java.util.concurrent包中用于表示异步计算结果的核心接口,下面给大家介绍JavaFuture类及实例代码,感兴... 目录一、Future 类概述二、核心工作机制代码示例执行流程2. 状态机模型3. 核心方法解析行为总结:三

python获取cmd环境变量值的实现代码

《python获取cmd环境变量值的实现代码》:本文主要介绍在Python中获取命令行(cmd)环境变量的值,可以使用标准库中的os模块,需要的朋友可以参考下... 前言全局说明在执行py过程中,总要使用到系统环境变量一、说明1.1 环境:Windows 11 家庭版 24H2 26100.4061

pandas实现数据concat拼接的示例代码

《pandas实现数据concat拼接的示例代码》pandas.concat用于合并DataFrame或Series,本文主要介绍了pandas实现数据concat拼接的示例代码,具有一定的参考价值,... 目录语法示例:使用pandas.concat合并数据默认的concat:参数axis=0,join=

C#代码实现解析WTGPS和BD数据

《C#代码实现解析WTGPS和BD数据》在现代的导航与定位应用中,准确解析GPS和北斗(BD)等卫星定位数据至关重要,本文将使用C#语言实现解析WTGPS和BD数据,需要的可以了解下... 目录一、代码结构概览1. 核心解析方法2. 位置信息解析3. 经纬度转换方法4. 日期和时间戳解析5. 辅助方法二、L

Python使用Code2flow将代码转化为流程图的操作教程

《Python使用Code2flow将代码转化为流程图的操作教程》Code2flow是一款开源工具,能够将代码自动转换为流程图,该工具对于代码审查、调试和理解大型代码库非常有用,在这篇博客中,我们将深... 目录引言1nVflRA、为什么选择 Code2flow?2、安装 Code2flow3、基本功能演示

IIS 7.0 及更高版本中的 FTP 状态代码

《IIS7.0及更高版本中的FTP状态代码》本文介绍IIS7.0中的FTP状态代码,方便大家在使用iis中发现ftp的问题... 简介尝试使用 FTP 访问运行 Internet Information Services (IIS) 7.0 或更高版本的服务器上的内容时,IIS 将返回指示响应状态的数字代

MySQL 添加索引5种方式示例详解(实用sql代码)

《MySQL添加索引5种方式示例详解(实用sql代码)》在MySQL数据库中添加索引可以帮助提高查询性能,尤其是在数据量大的表中,下面给大家分享MySQL添加索引5种方式示例详解(实用sql代码),... 在mysql数据库中添加索引可以帮助提高查询性能,尤其是在数据量大的表中。索引可以在创建表时定义,也可