缺陷检测:PatchCore的代码解读

2024-03-06 17:28

本文主要是介绍缺陷检测:PatchCore的代码解读,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

  • 前言
    • 补充
  • 代码流程
  • Local Patch Features
  • Coreset Subsampling
  • Detection and Localization
  • 运行结果

前言

在这里插入图片描述

该文章发表在2022的CVPR上,用于缺陷检测,继承自SPADE,背后的关键原理为:测试样本与训练样本之间进行特征匹配,将不匹配的点识别出来。该文章探究了深度特征的多尺度性质。
作者的汇报视频:
Youtube1
Youtube2

PatchCore主要包含三个部分

  • 创建特征的内存库
  • 通过贪心策略减少内存块数据量
  • 使用该内存块检测异常

下面本人先介绍一下整个工程的流程,代码见patchcore-inspection,后面再逐一详细介绍三个部分的代码

补充

使用其他数据集训练PatchCore见缺陷检测:使用PatchCore训练自己的数据集

代码流程

输入参数如下

run_patchcore.py
--gpu 0 --seed 0
--save_patchcore_model
--save_segmentation_images
--log_group IM224_WR50_L2-3_P01_D1024-1024_PS-3_AN-1_S0
--log_project MVTecAD_Results results
patch_core
-b wideresnet50
-le layer2
-le layer3
--faiss_on_gpu
--pretrain_embed_dimension 1024
--target_embed_dimension 1024
--anomaly_scorer_num_nn 1
--patchsize 3
sampler
-p 0.1
approx_greedy_coreset 
dataset
--resize 256
--imagesize 224
-d wood
mvtec E:\datasets\mvtec

首先程序运行的是run_patchcore.pyrun(),先创建所需要的文件夹,再通过list_of_dataloaders = methods["get_dataloaders"](seed)来运行get_dataloaders()函数以获得训练数据和测试数据的dataloader。
上述是准备阶段,接着通过PatchCore.fit(dataloaders["training"])来将训练数据集,将图片转换为特征向量,再通过PatchCore.predict(dataloaders["testing"])对测试数据进行分数预测。接着就是是否保存预测结果和模型。

Local Patch Features

在这里插入图片描述

这一部分主要代码如下:

        def _image_to_features(input_image):with torch.no_grad():input_image = input_image.to(torch.float).to(self.device)return self._embed(input_image)

图片一开始会被reesize为(3, 224, 224),这个大小由MVTecDataset.imageszie这个成员来管理,经过_embed(input_image)将特征提取出来,layer2的输出为(2, 512, 28, 28),layer3的输出为(2, 1024, 14, 14)。再对layer2征层进行裁切得到(2, 512, 3, 3, 28, 28),如下图所示(注意下面的蓝格子和绿格子都代表3×3的特征块),这里的2表示batch_size。
在这里插入图片描述
而layer3的特征层要进一步进行插值(通过_features = F.interpolate( _features.unsqueeze(1), size=(ref_num_patches[0], ref_num_patches[1]), mode="bilinear", align_corners=False, )reshape)得到(2, 1024, 3, 3, 28, 28)的形状,如下图所示:
在这里插入图片描述
然后通过下面两句对特征进行自适应均匀池化,得到(1568, 1024)个特征值,所以一张图片得到的特征数量是784个。

features = self.forward_modules["preprocessing"](features)
features = self.forward_modules["preadapt_aggregator"](features)

在这里插入图片描述

遍历完所有的测试图片,得到的features为(193648, 1024)即(784×247, 1024)

Coreset Subsampling

在这里插入图片描述
主要代码features = self.featuresampler.run(features)

    def run(self, features: Union[torch.Tensor, np.ndarray]) -> Union[torch.Tensor, np.ndarray]:"""Subsamples features using Greedy Coreset.Args:features: [N x D]"""if self.percentage == 1:return featuresself._store_type(features)if isinstance(features, np.ndarray):features = torch.from_numpy(features)reduced_features = self._reduce_features(features) # 经过一个全连接层sample_indices = self._compute_greedy_coreset_indices(reduced_features) # 通过贪心策略减少数据量features = features[sample_indices]return self._restore_type(features)

经过一个全连接层将通道数减少至128,然后就是贪心策略的采样,具体如下:
初始阶段,随机选取10个index,通过如下代码计算所有特征点关于这十个点的欧式距离,得到集合D(193648, 10),接着取均值得到平均距离集合d1(193648, 1),选取值最大的那个值作为Mc的一个点x,接着计算x关于所有点的欧式距离d2(193648, 1),将d1和d2进行对应位运算,保留较小的值,从而得到d3,最后取d3中最大的值加入到Mc,并成为新的x,一直循环193648*0.1次。

    def _compute_batchwise_differences(matrix_a: torch.Tensor, matrix_b: torch.Tensor) -> torch.Tensor:"""Computes batchwise Euclidean distances using PyTorch."""a_times_a = matrix_a.unsqueeze(1).bmm(matrix_a.unsqueeze(2)).reshape(-1, 1)b_times_b = matrix_b.unsqueeze(1).bmm(matrix_b.unsqueeze(2)).reshape(1, -1)a_times_b = matrix_a.mm(matrix_b.T)return (-2 * a_times_b + a_times_a + b_times_b).clamp(0, None).sqrt()

Detection and Localization

在这里插入图片描述
这里主要是使用faiss.GpuIndexFlatL2,具体原理本人也不太清楚。预测阶段主要代码如下

    def _predict_dataloader(self, dataloader):"""This function provides anomaly scores/maps for full dataloaders."""_ = self.forward_modules.eval()scores = []masks = []labels_gt = []masks_gt = []with tqdm.tqdm(dataloader, desc="Inferring...", leave=False) as data_iterator:for image in data_iterator:if isinstance(image, dict):labels_gt.extend(image["is_anomaly"].numpy().tolist())masks_gt.extend(image["mask"].numpy().tolist())image = image["image"]_scores, _masks = self._predict(image)for score, mask in zip(_scores, _masks):scores.append(score)masks.append(mask)return scores, masks, labels_gt, masks_gt

运行结果

运行结果会保存在patchcore-inspection\bin\results\MVTecAD_Results下,结果如下图所示(大家应该图片都是256*256,我这里稍微修改了一点代码,以原尺寸大小输出图片)
在这里插入图片描述
在这里插入图片描述

这篇关于缺陷检测:PatchCore的代码解读的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/780799

相关文章

Linux系统性能检测命令详解

《Linux系统性能检测命令详解》本文介绍了Linux系统常用的监控命令(如top、vmstat、iostat、htop等)及其参数功能,涵盖进程状态、内存使用、磁盘I/O、系统负载等多维度资源监控,... 目录toppsuptimevmstatIOStatiotopslabtophtopdstatnmon

解读GC日志中的各项指标用法

《解读GC日志中的各项指标用法》:本文主要介绍GC日志中的各项指标用法,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、基础 GC 日志格式(以 G1 为例)1. Minor GC 日志2. Full GC 日志二、关键指标解析1. GC 类型与触发原因2. 堆

Java设计模式---迭代器模式(Iterator)解读

《Java设计模式---迭代器模式(Iterator)解读》:本文主要介绍Java设计模式---迭代器模式(Iterator),具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,... 目录1、迭代器(Iterator)1.1、结构1.2、常用方法1.3、本质1、解耦集合与遍历逻辑2、统一

Java中调用数据库存储过程的示例代码

《Java中调用数据库存储过程的示例代码》本文介绍Java通过JDBC调用数据库存储过程的方法,涵盖参数类型、执行步骤及数据库差异,需注意异常处理与资源管理,以优化性能并实现复杂业务逻辑,感兴趣的朋友... 目录一、存储过程概述二、Java调用存储过程的基本javascript步骤三、Java调用存储过程示

Visual Studio 2022 编译C++20代码的图文步骤

《VisualStudio2022编译C++20代码的图文步骤》在VisualStudio中启用C++20import功能,需设置语言标准为ISOC++20,开启扫描源查找模块依赖及实验性标... 默认创建Visual Studio桌面控制台项目代码包含C++20的import方法。右键项目的属性:

MySQL之InnoDB存储页的独立表空间解读

《MySQL之InnoDB存储页的独立表空间解读》:本文主要介绍MySQL之InnoDB存储页的独立表空间,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、背景2、独立表空间【1】表空间大小【2】区【3】组【4】段【5】区的类型【6】XDES Entry区结构【

MySQL数据库的内嵌函数和联合查询实例代码

《MySQL数据库的内嵌函数和联合查询实例代码》联合查询是一种将多个查询结果组合在一起的方法,通常使用UNION、UNIONALL、INTERSECT和EXCEPT关键字,下面:本文主要介绍MyS... 目录一.数据库的内嵌函数1.1聚合函数COUNT([DISTINCT] expr)SUM([DISTIN

Java实现自定义table宽高的示例代码

《Java实现自定义table宽高的示例代码》在桌面应用、管理系统乃至报表工具中,表格(JTable)作为最常用的数据展示组件,不仅承载对数据的增删改查,还需要配合布局与视觉需求,而JavaSwing... 目录一、项目背景详细介绍二、项目需求详细介绍三、相关技术详细介绍四、实现思路详细介绍五、完整实现代码

Go语言代码格式化的技巧分享

《Go语言代码格式化的技巧分享》在Go语言的开发过程中,代码格式化是一个看似细微却至关重要的环节,良好的代码格式化不仅能提升代码的可读性,还能促进团队协作,减少因代码风格差异引发的问题,Go在代码格式... 目录一、Go 语言代码格式化的重要性二、Go 语言代码格式化工具:gofmt 与 go fmt(一)

C++ 检测文件大小和文件传输的方法示例详解

《C++检测文件大小和文件传输的方法示例详解》文章介绍了在C/C++中获取文件大小的三种方法,推荐使用stat()函数,并详细说明了如何设计一次性发送压缩包的结构体及传输流程,包含CRC校验和自动解... 目录检测文件的大小✅ 方法一:使用 stat() 函数(推荐)✅ 用法示例:✅ 方法二:使用 fsee