YOLO的Anchor聚类代码

2024-05-01 15:32
文章标签 代码 yolo anchor 聚类

本文主要是介绍YOLO的Anchor聚类代码,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

代码来源于GiantPandaCV ,作者BBuf

 


#coding=utf-8import xml.etree.ElementTree as ET
import numpy as npdef iou(box, clusters):"""计算一个ground truth边界盒和k个先验框(Anchor)的交并比(IOU)值。参数box: 元组或者数据,代表ground truth的长宽。参数clusters: 形如(k,2)的numpy数组,其中k是聚类Anchor框的个数返回:ground truth和每个Anchor框的交并比。"""x = np.minimum(clusters[:, 0], box[0])y = np.minimum(clusters[:, 1], box[1])if np.count_nonzero(x == 0) > 0 or np.count_nonzero(y == 0) > 0:raise ValueError("Box has no area")intersection = x * ybox_area = box[0] * box[1]cluster_area = clusters[:, 0] * clusters[:, 1]iou_ = intersection / (box_area + cluster_area - intersection)return iou_def avg_iou(boxes, clusters):"""计算一个ground truth和k个Anchor的交并比的均值。"""return np.mean([np.max(iou(boxes[i], clusters)) for i in range(boxes.shape[0])])def kmeans(boxes, k, dist=np.median):"""利用IOU值进行K-means聚类参数boxes: 形状为(r, 2)的ground truth框,其中r是ground truth的个数参数k: Anchor的个数参数dist: 距离函数返回值:形状为(k, 2)的k个Anchor框"""# 即是上面提到的rrows = boxes.shape[0]# 距离数组,计算每个ground truth和k个Anchor的距离distances = np.empty((rows, k))# 上一次每个ground truth"距离"最近的Anchor索引last_clusters = np.zeros((rows,))# 设置随机数种子np.random.seed()# 初始化聚类中心,k个簇,从r个ground truth随机选k个clusters = boxes[np.random.choice(rows, k, replace=False)]# 开始聚类while True:# 计算每个ground truth和k个Anchor的距离,用1-IOU(box,anchor)来计算for row in range(rows):distances[row] = 1 - iou(boxes[row], clusters)# 对每个ground truth,选取距离最小的那个Anchor,并存下索引nearest_clusters = np.argmin(distances, axis=1)# 如果当前每个ground truth"距离"最近的Anchor索引和上一次一样,聚类结束if (last_clusters == nearest_clusters).all():break# 更新簇中心为簇里面所有的ground truth框的均值for cluster in range(k):clusters[cluster] = dist(boxes[nearest_clusters == cluster], axis=0)# 更新每个ground truth"距离"最近的Anchor索引last_clusters = nearest_clustersreturn clusters# 加载自己的数据集,只需要所有labelimg标注出来的xml文件即可
def load_dataset(path):dataset = []for xml_file in glob.glob("{}/*xml".format(path)):tree = ET.parse(xml_file)# 图片高度height = int(tree.findtext("./size/height"))# 图片宽度width = int(tree.findtext("./size/width"))for obj in tree.iter("object"):# 偏移量xmin = int(obj.findtext("bndbox/xmin")) / widthymin = int(obj.findtext("bndbox/ymin")) / heightxmax = int(obj.findtext("bndbox/xmax")) / widthymax = int(obj.findtext("bndbox/ymax")) / heightxmin = np.float64(xmin)ymin = np.float64(ymin)xmax = np.float64(xmax)ymax = np.float64(ymax)if xmax == xmin or ymax == ymin:print(xml_file)# 将Anchor的长宽放入dateset,运行kmeans获得Anchordataset.append([xmax - xmin, ymax - ymin])return np.array(dataset)if __name__ == '__main__':ANNOTATIONS_PATH = "F:\Annotations" #xml文件所在文件夹CLUSTERS = 9 #聚类数量,anchor数量INPUTDIM = 416 #输入网络大小data = load_dataset(ANNOTATIONS_PATH)out = kmeans(data, k=CLUSTERS)print('Boxes:')print(np.array(out)*INPUTDIM)print("Accuracy: {:.2f}%".format(avg_iou(data, out) * 100))final_anchors = np.around(out[:, 0] / out[:, 1], decimals=2).tolist()print("Before Sort Ratios:\n {}".format(final_anchors))print("After Sort Ratios:\n {}".format(sorted(final_anchors)))

 

这篇关于YOLO的Anchor聚类代码的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/952028

相关文章

Python实例题之pygame开发打飞机游戏实例代码

《Python实例题之pygame开发打飞机游戏实例代码》对于python的学习者,能够写出一个飞机大战的程序代码,是不是感觉到非常的开心,:本文主要介绍Python实例题之pygame开发打飞机... 目录题目pygame-aircraft-game使用 Pygame 开发的打飞机游戏脚本代码解释初始化部

Java中Map.Entry()含义及方法使用代码

《Java中Map.Entry()含义及方法使用代码》:本文主要介绍Java中Map.Entry()含义及方法使用的相关资料,Map.Entry是Java中Map的静态内部接口,用于表示键值对,其... 目录前言 Map.Entry作用核心方法常见使用场景1. 遍历 Map 的所有键值对2. 直接修改 Ma

深入解析 Java Future 类及代码示例

《深入解析JavaFuture类及代码示例》JavaFuture是java.util.concurrent包中用于表示异步计算结果的核心接口,下面给大家介绍JavaFuture类及实例代码,感兴... 目录一、Future 类概述二、核心工作机制代码示例执行流程2. 状态机模型3. 核心方法解析行为总结:三

python获取cmd环境变量值的实现代码

《python获取cmd环境变量值的实现代码》:本文主要介绍在Python中获取命令行(cmd)环境变量的值,可以使用标准库中的os模块,需要的朋友可以参考下... 前言全局说明在执行py过程中,总要使用到系统环境变量一、说明1.1 环境:Windows 11 家庭版 24H2 26100.4061

pandas实现数据concat拼接的示例代码

《pandas实现数据concat拼接的示例代码》pandas.concat用于合并DataFrame或Series,本文主要介绍了pandas实现数据concat拼接的示例代码,具有一定的参考价值,... 目录语法示例:使用pandas.concat合并数据默认的concat:参数axis=0,join=

C#代码实现解析WTGPS和BD数据

《C#代码实现解析WTGPS和BD数据》在现代的导航与定位应用中,准确解析GPS和北斗(BD)等卫星定位数据至关重要,本文将使用C#语言实现解析WTGPS和BD数据,需要的可以了解下... 目录一、代码结构概览1. 核心解析方法2. 位置信息解析3. 经纬度转换方法4. 日期和时间戳解析5. 辅助方法二、L

Python使用Code2flow将代码转化为流程图的操作教程

《Python使用Code2flow将代码转化为流程图的操作教程》Code2flow是一款开源工具,能够将代码自动转换为流程图,该工具对于代码审查、调试和理解大型代码库非常有用,在这篇博客中,我们将深... 目录引言1nVflRA、为什么选择 Code2flow?2、安装 Code2flow3、基本功能演示

IIS 7.0 及更高版本中的 FTP 状态代码

《IIS7.0及更高版本中的FTP状态代码》本文介绍IIS7.0中的FTP状态代码,方便大家在使用iis中发现ftp的问题... 简介尝试使用 FTP 访问运行 Internet Information Services (IIS) 7.0 或更高版本的服务器上的内容时,IIS 将返回指示响应状态的数字代

MySQL 添加索引5种方式示例详解(实用sql代码)

《MySQL添加索引5种方式示例详解(实用sql代码)》在MySQL数据库中添加索引可以帮助提高查询性能,尤其是在数据量大的表中,下面给大家分享MySQL添加索引5种方式示例详解(实用sql代码),... 在mysql数据库中添加索引可以帮助提高查询性能,尤其是在数据量大的表中。索引可以在创建表时定义,也可

使用C#删除Excel表格中的重复行数据的代码详解

《使用C#删除Excel表格中的重复行数据的代码详解》重复行是指在Excel表格中完全相同的多行数据,删除这些重复行至关重要,因为它们不仅会干扰数据分析,还可能导致错误的决策和结论,所以本文给大家介绍... 目录简介使用工具C# 删除Excel工作表中的重复行语法工作原理实现代码C# 删除指定Excel单元