YOLOv9训练损失、精度、mAP绘图功能 | 支持多结果对比,多结果绘在一个图片(消融实验、科研必备)

本文主要是介绍YOLOv9训练损失、精度、mAP绘图功能 | 支持多结果对比,多结果绘在一个图片(消融实验、科研必备),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

一、本文介绍

本文给大家带来的是YOLOv9系列的绘图功能,我将向大家介绍YOLO系列的绘图功能。我们在进行实验时,经常需要比较多个结果,针对这一问题,我写了点代码来解决这个问题,它可以根据训练结果绘制损失(loss)和mAP(平均精度均值)的对比图。这个工具不仅支持多个文件的对比分析,还允许大家在现有代码的基础上进行修,从而达到数据可视化的功能,大家也可以将对比图放在论文中进行对比也是非常不错的选择。

先展示一下效果图-> 

专栏地址:YOLOv9有效涨点专栏-持续复现各种顶会内容-有效涨点-全网改进最全的专栏 

损失对比图片->

目录

一、本文介绍

二、绘图工具核心代码 

三、使用讲解 

四、本文总结


二、绘图工具核心代码 

import os
import pandas as pd
import matplotlib.pyplot as pltdef plot_metrics_and_loss(experiment_names, metrics_info, loss_info, metrics_subplot_layout, loss_subplot_layout,metrics_figure_size=(15, 10), loss_figure_size=(15, 10), base_directory='runs/train'):# Plot metricsplt.figure(figsize=metrics_figure_size)for i, (metric_name, title) in enumerate(metrics_info):plt.subplot(*metrics_subplot_layout, i + 1)for name in experiment_names:file_path = os.path.join(base_directory, name, 'results.csv')data = pd.read_csv(file_path)column_name = [col for col in data.columns if col.strip() == metric_name][0]plt.plot(data[column_name], label=name)plt.xlabel('Epoch')plt.title(title)plt.legend()plt.tight_layout()metrics_filename = 'metrics_curves.png'plt.savefig(metrics_filename)plt.show()# Plot lossplt.figure(figsize=loss_figure_size)for i, (loss_name, title) in enumerate(loss_info):plt.subplot(*loss_subplot_layout, i + 1)for name in experiment_names:file_path = os.path.join(base_directory, name, 'results.csv')data = pd.read_csv(file_path)column_name = [col for col in data.columns if col.strip() == loss_name][0]plt.plot(data[column_name], label=name)plt.xlabel('Epoch')plt.title(title)plt.legend()plt.tight_layout()loss_filename = 'loss_curves.png'plt.savefig(loss_filename)plt.show()return metrics_filename, loss_filename# Metrics to plot
metrics_info = [('metrics/precision', 'Precision'),('metrics/recall', 'Recall'),('metrics/mAP_0.5', 'mAP at IoU=0.5'),('metrics/mAP_0.5:0.95', 'mAP for IoU Range 0.5-0.95')
]# Loss to plot
loss_info = [('train/box_loss', 'Training Box Loss'),('train/cls_loss', 'Training Classification Loss'),('train/obj_loss', 'Training OBJ Loss'),('val/box_loss', 'Validation Box Loss'),('val/cls_loss', 'Validation Classification Loss'),('val/obj_loss', 'Validation obj Loss')
]# Plot the metrics and loss from multiple experiments
metrics_filename, loss_filename = plot_metrics_and_loss(experiment_names=['exp40', 'exp38'],metrics_info=metrics_info,loss_info=loss_info,metrics_subplot_layout=(2, 2),loss_subplot_layout=(2, 3)
)


三、使用讲解 

使用方式非常简单,我们首先创建一个文件,将核心代码粘贴进去,其中experiment_names这个参数就代表我们的每个训练结果的名字, 我们只需要修改这个即可,我这里就是五个结果进行对比,修改完成之后大家运行该文件即可。

五、热力图代码 

使用方式我会单独更一篇,这个热力图代码的进阶版,这里只是先放一下。 

import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')
import torch, yaml, cv2, os, shutil
import numpy as np
np.random.seed(0)
import matplotlib.pyplot as plt
from tqdm import trange
from PIL import Image
from ultralytics.nn.tasks import DetectionModel as Model
from ultralytics.utils.torch_utils import intersect_dicts
from ultralytics.utils.ops import xywh2xyxy
from pytorch_grad_cam import GradCAMPlusPlus, GradCAM, XGradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradientsdef letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):# Resize and pad image while meeting stride-multiple constraintsshape = im.shape[:2]  # current shape [height, width]if isinstance(new_shape, int):new_shape = (new_shape, new_shape)# Scale ratio (new / old)r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])if not scaleup:  # only scale down, do not scale up (for better val mAP)r = min(r, 1.0)# Compute paddingratio = r, r  # width, height ratiosnew_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh paddingif auto:  # minimum rectangledw, dh = np.mod(dw, stride), np.mod(dh, stride)  # wh paddingelif scaleFill:  # stretchdw, dh = 0.0, 0.0new_unpad = (new_shape[1], new_shape[0])ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratiosdw /= 2  # divide padding into 2 sidesdh /= 2if shape[::-1] != new_unpad:  # resizeim = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))left, right = int(round(dw - 0.1)), int(round(dw + 0.1))im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add borderreturn im, ratio, (dw, dh)class yolov8_heatmap:def __init__(self, weight, cfg, device, method, layer, backward_type, conf_threshold, ratio):device = torch.device(device)ckpt = torch.load(weight)model_names = ckpt['model'].namescsd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32model = Model(cfg, ch=3, nc=len(model_names)).to(device)csd = intersect_dicts(csd, model.state_dict(), exclude=['anchor'])  # intersectmodel.load_state_dict(csd, strict=False)  # loadmodel.eval()print(f'Transferred {len(csd)}/{len(model.state_dict())} items')target_layers = [eval(layer)]method = eval(method)colors = np.random.uniform(0, 255, size=(len(model_names), 3)).astype(np.int)self.__dict__.update(locals())def post_process(self, result):logits_ = result[:, 4:]boxes_ = result[:, :4]sorted, indices = torch.sort(logits_.max(1)[0], descending=True)return torch.transpose(logits_[0], dim0=0, dim1=1)[indices[0]], torch.transpose(boxes_[0], dim0=0, dim1=1)[indices[0]], xywh2xyxy(torch.transpose(boxes_[0], dim0=0, dim1=1)[indices[0]]).cpu().detach().numpy()def draw_detections(self, box, color, name, img):xmin, ymin, xmax, ymax = list(map(int, list(box)))cv2.rectangle(img, (xmin, ymin), (xmax, ymax), tuple(int(x) for x in color), 2)cv2.putText(img, str(name), (xmin, ymin - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.8, tuple(int(x) for x in color), 2, lineType=cv2.LINE_AA)return imgdef __call__(self, img_path, save_path):# remove dir if existif os.path.exists(save_path):shutil.rmtree(save_path)# make dir if not existos.makedirs(save_path, exist_ok=True)# img processimg = cv2.imread(img_path)img = letterbox(img)[0]img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)img = np.float32(img) / 255.0tensor = torch.from_numpy(np.transpose(img, axes=[2, 0, 1])).unsqueeze(0).to(self.device)# init ActivationsAndGradientsgrads = ActivationsAndGradients(self.model, self.target_layers, reshape_transform=None)# get ActivationsAndResultresult = grads(tensor)activations = grads.activations[0].cpu().detach().numpy()# postprocess to yolo outputpost_result, pre_post_boxes, post_boxes = self.post_process(result[0])for i in trange(int(post_result.size(0) * self.ratio)):if float(post_result[i].max()) < self.conf_threshold:breakself.model.zero_grad()# get max probability for this predictionif self.backward_type == 'class' or self.backward_type == 'all':score = post_result[i].max()score.backward(retain_graph=True)if self.backward_type == 'box' or self.backward_type == 'all':for j in range(4):score = pre_post_boxes[i, j]score.backward(retain_graph=True)# process heatmapif self.backward_type == 'class':gradients = grads.gradients[0]elif self.backward_type == 'box':gradients = grads.gradients[0] + grads.gradients[1] + grads.gradients[2] + grads.gradients[3]else:gradients = grads.gradients[0] + grads.gradients[1] + grads.gradients[2] + grads.gradients[3] + grads.gradients[4]b, k, u, v = gradients.size()weights = self.method.get_cam_weights(self.method, None, None, None, activations, gradients.detach().numpy())weights = weights.reshape((b, k, 1, 1))saliency_map = np.sum(weights * activations, axis=1)saliency_map = np.squeeze(np.maximum(saliency_map, 0))saliency_map = cv2.resize(saliency_map, (tensor.size(3), tensor.size(2)))saliency_map_min, saliency_map_max = saliency_map.min(), saliency_map.max()if (saliency_map_max - saliency_map_min) == 0:continuesaliency_map = (saliency_map - saliency_map_min) / (saliency_map_max - saliency_map_min)# add heatmap and box to imagecam_image = show_cam_on_image(img.copy(), saliency_map, use_rgb=True)cam_image = self.draw_detections(post_boxes[i], self.colors[int(post_result[i, :].argmax())], f'{self.model_names[int(post_result[i, :].argmax())]} {float(post_result[i].max()):.2f}', cam_image)cam_image = Image.fromarray(cam_image)cam_image.save(f'{save_path}/{i}.png')def get_params():params = {'weight': 'yolov8n.pt','cfg': 'ultralytics/cfg/models/v8/yolov8n.yaml','device': 'cuda:0','method': 'GradCAM', # GradCAMPlusPlus, GradCAM, XGradCAM'layer': 'model.model[9]','backward_type': 'all', # class, box, all'conf_threshold': 0.6, # 0.6'ratio': 0.02 # 0.02-0.1}return paramsif __name__ == '__main__':model = yolov8_heatmap(**get_params())model(r'ultralytics/assets/bus.jpg', 'result')


四、本文总结

到此本文的正式分享内容就结束了,在这里给大家推荐我的YOLOv9改进有效涨点专栏,本专栏目前为新开的平均质量分98分,后期我会根据各种最新的前沿顶会进行论文复现,也会对一些老的改进机制进行补充,目前本专栏免费阅读(暂时,大家尽早关注不迷路~),如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~

专栏地址:YOLOv9有效涨点专栏-持续复现各种顶会内容-有效涨点-全网改进最全的专栏 

这篇关于YOLOv9训练损失、精度、mAP绘图功能 | 支持多结果对比,多结果绘在一个图片(消融实验、科研必备)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/932975

相关文章

Java使用Thumbnailator库实现图片处理与压缩功能

《Java使用Thumbnailator库实现图片处理与压缩功能》Thumbnailator是高性能Java图像处理库,支持缩放、旋转、水印添加、裁剪及格式转换,提供易用API和性能优化,适合Web应... 目录1. 图片处理库Thumbnailator介绍2. 基本和指定大小图片缩放功能2.1 图片缩放的

深度解析Spring Security 中的 SecurityFilterChain核心功能

《深度解析SpringSecurity中的SecurityFilterChain核心功能》SecurityFilterChain通过组件化配置、类型安全路径匹配、多链协同三大特性,重构了Spri... 目录Spring Security 中的SecurityFilterChain深度解析一、Security

MySQL中EXISTS与IN用法使用与对比分析

《MySQL中EXISTS与IN用法使用与对比分析》在MySQL中,EXISTS和IN都用于子查询中根据另一个查询的结果来过滤主查询的记录,本文将基于工作原理、效率和应用场景进行全面对比... 目录一、基本用法详解1. IN 运算符2. EXISTS 运算符二、EXISTS 与 IN 的选择策略三、性能对比

详解MySQL中JSON数据类型用法及与传统JSON字符串对比

《详解MySQL中JSON数据类型用法及与传统JSON字符串对比》MySQL从5.7版本开始引入了JSON数据类型,专门用于存储JSON格式的数据,本文将为大家简单介绍一下MySQL中JSON数据类型... 目录前言基本用法jsON数据类型 vs 传统JSON字符串1. 存储方式2. 查询方式对比3. 索引

Java实现预览与打印功能详解

《Java实现预览与打印功能详解》在Java中,打印功能主要依赖java.awt.print包,该包提供了与打印相关的一些关键类,比如PrinterJob和PageFormat,它们构成... 目录Java 打印系统概述打印预览与设置使用 PageFormat 和 PrinterJob 类设置页面格式与纸张

MySQL 8 中的一个强大功能 JSON_TABLE示例详解

《MySQL8中的一个强大功能JSON_TABLE示例详解》JSON_TABLE是MySQL8中引入的一个强大功能,它允许用户将JSON数据转换为关系表格式,从而可以更方便地在SQL查询中处理J... 目录基本语法示例示例查询解释应用场景不适用场景1. ‌jsON 数据结构过于复杂或动态变化‌2. ‌性能要

Kotlin Map映射转换问题小结

《KotlinMap映射转换问题小结》文章介绍了Kotlin集合转换的多种方法,包括map(一对一转换)、mapIndexed(带索引)、mapNotNull(过滤null)、mapKeys/map... 目录Kotlin 集合转换:map、mapIndexed、mapNotNull、mapKeys、map

SpringBoot中六种批量更新Mysql的方式效率对比分析

《SpringBoot中六种批量更新Mysql的方式效率对比分析》文章比较了MySQL大数据量批量更新的多种方法,指出REPLACEINTO和ONDUPLICATEKEY效率最高但存在数据风险,MyB... 目录效率比较测试结构数据库初始化测试数据批量修改方案第一种 for第二种 case when第三种

Qt使用QSqlDatabase连接MySQL实现增删改查功能

《Qt使用QSqlDatabase连接MySQL实现增删改查功能》这篇文章主要为大家详细介绍了Qt如何使用QSqlDatabase连接MySQL实现增删改查功能,文中的示例代码讲解详细,感兴趣的小伙伴... 目录一、创建数据表二、连接mysql数据库三、封装成一个完整的轻量级 ORM 风格类3.1 表结构

基于Python实现一个图片拆分工具

《基于Python实现一个图片拆分工具》这篇文章主要为大家详细介绍了如何基于Python实现一个图片拆分工具,可以根据需要的行数和列数进行拆分,感兴趣的小伙伴可以跟随小编一起学习一下... 简单介绍先自己选择输入的图片,默认是输出到项目文件夹中,可以自己选择其他的文件夹,选择需要拆分的行数和列数,可以通过