[异常检测]Deep One-Class Classfication(Deep-SVDD) 论文阅读源码分析

本文主要是介绍[异常检测]Deep One-Class Classfication(Deep-SVDD) 论文阅读源码分析,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

论文链接:http://proceedings.mlr.press/v80/ruff18a

文章目录

    • 论文阅读
      • 摘要
      • 相关工作
      • 模型结构
      • 实验
    • 源码分析
      • 数据集制作
      • 训练

论文阅读

摘要

“尽管deep learning在很多machine learning任务中取得成功,相对较少的deep learning方法被用在异常检测任务中。一些原本被用来做其他任务的深度模型例如生成模型或压缩模型尝试被用在异常检测任务中,但是没有基于异常检测目标训练的网络”

在18年的时候,anomaly detection这个任务还没有什么深度学习的方法,摘要中提到基于异常检测的目标进行训练是整篇论文的核心,也为后面的DL-based AD任务提供了一些思路。

相关工作

相关工作部分论文提到了基于传统方法和深度方法两个方向

  • One-class SVM & support vector data description(SVDD)

    One-class SVM
    m i n 1 2 ∣ ∣ w ∣ ∣ F k 2 − ρ + 1 v n ∑ i = 1 n ξ i s . t . ⟨ w , ϕ ( x i ) ⟩ ≥ ρ − ξ i , ξ i ≥ 0 min \frac{1}{2}||w||^2_{F_k}- \rho + \frac{1}{vn}\sum_{i=1}^n\xi_i \\ s.t. \left \langle w,\phi(x_i) \right \rangle \geq \rho - \xi_i, \xi_i \geq 0 min21wFk2ρ+vn1i=1nξis.t.w,ϕ(xi)ρξi,ξi0
    SVDD
    m i n R 2 + 1 v n ∑ i ξ i s . t . ∣ ∣ ϕ ( x i ) − c ∣ ∣ 2 ≤ R 2 + ξ i min \ R^2 + \frac{1}{vn}\sum_i\xi_i \\ s.t. ||\phi(x_i) - c||^2 \leq R^2 + \xi_i min R2+vn1iξis.t.ϕ(xi)c2R2+ξi
    其中 ϕ ( x i ) \phi(x_i) ϕ(xi)都是特征映射,OC-SVM是找一个间隔超平面使得数据与原点距离最大;SVDD是找到一个超球面,使得正常数据全部落在该超球面内,具体可参考网上的分析。

  • Deep learning method(Autoencoder GAN)

    基于深度学习的方法包括利用深度网络进行特征提取再与传统异常检测算法结合的,也有纯深度学习方法的,包括利用自编码器进行特征压缩和样本重构,将重构误差作为异常分数。
    s c o r e = ∣ ∣ x − x ^ ∣ ∣ 2 score = ||x - \hat{x}||^2 score=xx^2
    也有利用GAN网络在latent space对测试样本进行分析,这里不再赘述。

    这些方法的弊端统一有两点:1. 不是以anomaly detection为目标的(都是一些其他任务的副产物) 2. 压缩的dimension不好确定,是否能找到一个很好的低维latent space是决定性因素。

模型结构

在这里插入图片描述

  • 对于图像级数据,利用CNN进行特征提取,进行 X − F \mathcal{X} - \mathcal{F} XF的映射,在 F \mathcal{F} F空间中,尽可能希望大量normal样本被聚集在一个hypersphere中,而anomaly 样本在该hypersphere外。(思想与SVDD一致),这里的hypersphere中由两个参数确定(c, R),具体调整在后面会详解。
  • 论文提出了两个objective: soft-boundary & one-class

1 soft-boundary Deep SVDD

objective function:
m i n R 2 + 1 ν n ∑ i = 1 n m a x { 0 , ∣ ∣ ϕ ( x i ; W ) − c ∣ ∣ 2 − R 2 } + λ 2 ∑ l = 1 L ∣ ∣ W l ∣ ∣ F 2 min \ R^2 + \frac{1}{{\nu} n}\sum_{i=1}^nmax \left \{ 0,||\phi(x_i;W)-c||^2 -R^2 \right \} + \frac{\lambda}{2}\sum_{l=1}^L||W^l||_F^2 min R2+νn1i=1nmax{0,ϕ(xi;W)c2R2}+2λl=1LWlF2
软边界Deep-SVDD的核心思想是通过假设训练数据中不全是normal数据(因此强迫所有样本聚于超球面内是不合适的,但是从源码来看,训练的时候还是用的one-class数据,不过这样的假设更软,更适合实际场景), 因此利用 ν ∈ ( 0 , 1 ] \nu \in (0,1] ν(0,1]在超球体体积和超出边界的程度进行一个trade-off control. (这个参数控制着允许超出球面的样本比例)。

2 one-class Deep SVDD

objective function:
m i n 1 n ∑ i = 1 n ∣ ∣ ϕ ( x i ; W ) − c ∣ ∣ 2 + λ 2 ∑ l = 1 L ∣ ∣ W l ∣ ∣ F 2 min \frac{1}{n}\sum_{i=1}^{n}||\phi(x_i;W) - c||^2 + \frac{\lambda}{2}\sum_{l=1}^L||W^l||_F^2 minn1i=1nϕ(xi;W)c2+2λl=1LWlF2

一类Deep SVDD是在一个较强的假设,训练数据大部分都为normal数据,也就是处在一个one-classification的情境下,不用优化参数R,超球体的体积通过函数第一项进行隐式的压缩。该损失相较于soft-boundry是一种更硬的目标,要求所有样本离超球体中心更近,因此这种方法一定要在训练集大部分为normal下进行,否则测试阶段的异常样本也会被认为正常(因为把所有数据一起压缩了)。

3 测试

定义异常分数
s ( x ) = ∣ ∣ ϕ ( x ; W ∗ ) − c ∣ ∣ 2 s(x) = ||\phi(x;W^*) - c||^2 s(x)=ϕ(x;W)c2
对于soft-boundary来说可以利用
∣ ∣ ϕ ( x ; W ∗ ) − c ∣ ∣ 2 − R > 0 ||\phi(x;W^*) - c||^2 - R > 0 ϕ(x;W)c2R>0
进行判断。

4 优化

Adam作为优化器优化网络权重。当用soft-boundary损失时,因为要优化 W W W R R R, 两个参数优化是不同的尺度,因此利用alternating minimization/block coordinate descent approach,先将R固定住,训练k epoch 的W参数,之后的每个epoch先更新W,再利用最新的W,训练R。

5 性质

论文提到了该模型的四个性质,前三个性质为可能产生平凡解的情况,最后一个为参数 ν \nu ν的性质。

球心c的选择

当所有权重为零,得到的网络输出 ϕ ( x n ; W ) = ϕ ( x m ; W ) = c 0 \phi(x_n;W) = \phi(x_m;W) = c_0 ϕ(xn;W)=ϕ(xm;W)=c0时,即当权重为零时,所有样本映射至一个常数,则此时的loss为零,即达到一个最优解,但该解是无意义的,因此必须避免这种情况。

偏置项

传统的深度网络中往往有偏置项,但是该模型中不能有偏置项。

如果某个hidden layer有bias,则该层输出为 z l ( x ) = σ l ( W l z l − 1 ( x ) + b l ) z^l(x) = \sigma^l(W^lz^{l-1}(x) + b^l) zl(x)=σl(Wlzl1(x)+bl),则当权重全为零时,所有样本的输出同样相同即 z l ( x i ) = z l ( x j ) = σ l ( b l ) z^l(x_i) = z^l(x_j) = \sigma^l(b^l) zl(xi)=zl(xj)=σl(bl),最后的结果跟1一样,同样会输出到同一个点,使得R = 0。因此避免该情况就需要所有层不包含bias项,包括batch normal层($ x = \frac{x-mean}{\sqrt{var}}*\gamma + \varepsilon $)。

激活函数的选择

论文提到如果选择的激活函数包含上下界,则也会陷入hypersphere collapse。例如sigmoid函数,如果一个神经元对应所有样本输入都为正,则模型会朝着将其他神经元输出置零,该神经元输出继续增大,直到接近上界。那么也就和前述一样,所有样本输出为常数。

针对上述三点,论文的解决方案为:采用将网络进行初始化后的所有样本输出均值作为球心;所有层不添加bias;激活函数选择单侧无界函数(论文里使用的是leaky relu)

v-property

该性质与soft-boundary中的 ν \nu ν相关,较为复杂,暂没看懂,后续会专门结合OC-SVM中的性质一起总结。

实验

论文在MNIST 和 CIFAR10上进行实验,每个实验分别选择一类为normal类,其他为anomaly类。比较的方法包括传统方法OS-SVM/ 核密度估计 KDE/ isolation forest,深度方法 DCAE/ AnoGAN。

在提出的方法的实验中,具体实现是先用DCAE进行预训练,将encode部分作为网络权重,learning_rate为两阶段分别是1e-4, 1e-5; R的学习每5个epoch利用line search进行。

实验结果:
在这里插入图片描述

源码分析

源码地址:https://github.com/lukasruff/Deep-SVDD-PyTorch

数据集制作

利用L1-norm进行 全局对比度标准化global contrast normalization, 计算图像中所有像素的均值和标准差,然后每个像素分别减去权值并除以标准差。

def global_contrast_normalization(x, scale = 'l1'):assert scale in ('l1', 'l2')n_features = int(np.prod(x.shape)) # 所有维度特征数量和 n * w * hmean = torch.mean(x)x -= meanif scale == 'l1':x_scale = torch.mean(torch.abs(x)) if scale == 'l2':x_scale = torch.sqrt(torch.sum( x**2 )) / n_features # 这个是对比度x /= x_scalereturn x

训练集选择只选择某一类

def get_target_label_index(labels, targets):return np.argwhere(np.isin(labels, targets)).flatten().tolist() # 返回某一类的索引

训练

论文特意定义了一个trainer类。正式训练之前,将初始化网络的输出均值作为目标函数中的c

def init_center_c(self, train_loader, net, eps = 0.1):n_samples = 0c = torch.zeros(net.rep_dim, device = self.device)net.eval()with torch.no_grad():for data in train_loader:inputs, _ = datainputs = inputs.to(self.device)outputs = net(inputs)n_samples += outputs.shape[0]c += torch.sum(outputs, dim = 0)c /= n_samplesc[(abs(c) < eps) & (c < 0)] = -eps # 为了避免初始化后输出过小(前面提到输出的常数值,用relu的话不能输出零)c[(abs(c) > eps) & (c > 0)] = epsreturn c

loss部分参考源码自己写了一个函数

def Compute_loss(objective, outputs, r, c, v):assert objective in ('soft-boundry', 'one-class')dist = torch.sum((outputs-c)**2, dim = 1) # 行求和 N * 1if objective == 'soft-boundry':scores = dist - r**2 # 也是测试用的异常分数loss = r ** 2 + (1/v) * torch.mean(torch.max(scores, torch.zeros_like(scores)))else:loss = torch.mean(dist)return dist, loss

R的更新。论文中写的是line search,代码中给出的方法是:

def Get_radius(dist, v): # 参数\nureturn np.quantile(np.sqrt(dist.clone().data.cpu().numpy()), 1-v)
if (self.objective == 'soft-boundary') and (epoch >= self.warm_up_n_epochs):self.R.data = torch.tensor(get_radius(dist, self.nu), device=self.device) # 一个batch中的距离计算出来的

这篇关于[异常检测]Deep One-Class Classfication(Deep-SVDD) 论文阅读源码分析的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/742995

相关文章

Python脚本轻松实现检测麦克风功能

《Python脚本轻松实现检测麦克风功能》在进行音频处理或开发需要使用麦克风的应用程序时,确保麦克风功能正常是非常重要的,本文将介绍一个简单的Python脚本,能够帮助我们检测本地麦克风的功能,需要的... 目录轻松检测麦克风功能脚本介绍一、python环境准备二、代码解析三、使用方法四、知识扩展轻松检测麦

Java异常捕获及处理方式详解

《Java异常捕获及处理方式详解》异常处理是Java编程中非常重要的一部分,它允许我们在程序运行时捕获并处理错误或不预期的行为,而不是让程序直接崩溃,本文将介绍Java中如何捕获异常,以及常用的异常处... 目录前言什么是异常?Java异常的基本语法解释:1. 捕获异常并处理示例1:捕获并处理单个异常解释:

Android 缓存日志Logcat导出与分析最佳实践

《Android缓存日志Logcat导出与分析最佳实践》本文全面介绍AndroidLogcat缓存日志的导出与分析方法,涵盖按进程、缓冲区类型及日志级别过滤,自动化工具使用,常见问题解决方案和最佳实... 目录android 缓存日志(Logcat)导出与分析全攻略为什么要导出缓存日志?按需过滤导出1. 按

Python自定义异常的全面指南(入门到实践)

《Python自定义异常的全面指南(入门到实践)》想象你正在开发一个银行系统,用户转账时余额不足,如果直接抛出ValueError,调用方很难区分是金额格式错误还是余额不足,这正是Python自定义异... 目录引言:为什么需要自定义异常一、异常基础:先搞懂python的异常体系1.1 异常是什么?1.2

Linux中的HTTPS协议原理分析

《Linux中的HTTPS协议原理分析》文章解释了HTTPS的必要性:HTTP明文传输易被篡改和劫持,HTTPS通过非对称加密协商对称密钥、CA证书认证和混合加密机制,有效防范中间人攻击,保障通信安全... 目录一、什么是加密和解密?二、为什么需要加密?三、常见的加密方式3.1 对称加密3.2非对称加密四、

MySQL中读写分离方案对比分析与选型建议

《MySQL中读写分离方案对比分析与选型建议》MySQL读写分离是提升数据库可用性和性能的常见手段,本文将围绕现实生产环境中常见的几种读写分离模式进行系统对比,希望对大家有所帮助... 目录一、问题背景介绍二、多种解决方案对比2.1 原生mysql主从复制2.2 Proxy层中间件:ProxySQL2.3

python使用Akshare与Streamlit实现股票估值分析教程(图文代码)

《python使用Akshare与Streamlit实现股票估值分析教程(图文代码)》入职测试中的一道题,要求:从Akshare下载某一个股票近十年的财务报表包括,资产负债表,利润表,现金流量表,保存... 目录一、前言二、核心知识点梳理1、Akshare数据获取2、Pandas数据处理3、Matplotl

python panda库从基础到高级操作分析

《pythonpanda库从基础到高级操作分析》本文介绍了Pandas库的核心功能,包括处理结构化数据的Series和DataFrame数据结构,数据读取、清洗、分组聚合、合并、时间序列分析及大数据... 目录1. Pandas 概述2. 基本操作:数据读取与查看3. 索引操作:精准定位数据4. Group

MySQL中EXISTS与IN用法使用与对比分析

《MySQL中EXISTS与IN用法使用与对比分析》在MySQL中,EXISTS和IN都用于子查询中根据另一个查询的结果来过滤主查询的记录,本文将基于工作原理、效率和应用场景进行全面对比... 目录一、基本用法详解1. IN 运算符2. EXISTS 运算符二、EXISTS 与 IN 的选择策略三、性能对比

MySQL 内存使用率常用分析语句

《MySQL内存使用率常用分析语句》用户整理了MySQL内存占用过高的分析方法,涵盖操作系统层确认及数据库层bufferpool、内存模块差值、线程状态、performance_schema性能数据... 目录一、 OS层二、 DB层1. 全局情况2. 内存占js用详情最近连续遇到mysql内存占用过高导致