FGSM(Fast Gradient Sign Method)算法源码解析

2023-10-29 00:36

本文主要是介绍FGSM(Fast Gradient Sign Method)算法源码解析,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

论文链接:https://arxiv.org/abs/1412.6572
源码出处:https://github.com/Harry24k/adversarial-attacks-pytorch/tree/master


源码

import torch
import torch.nn as nnfrom ..attack import Attackclass FGSM(Attack):r"""FGSM in the paper 'Explaining and harnessing adversarial examples'[https://arxiv.org/abs/1412.6572]Distance Measure : LinfArguments:model (nn.Module): model to attack.eps (float): maximum perturbation. (Default: 8/255)Shape:- images: :math:`(N, C, H, W)` where `N = number of batches`, `C = number of channels`,        `H = height` and `W = width`. It must have a range [0, 1].- labels: :math:`(N)` where each value :math:`y_i` is :math:`0 \leq y_i \leq` `number of labels`.- output: :math:`(N, C, H, W)`.Examples::>>> attack = torchattacks.FGSM(model, eps=8/255)>>> adv_images = attack(images, labels)"""def __init__(self, model, eps=8/255):super().__init__("FGSM", model)self.eps = epsself.supported_mode = ['default', 'targeted']def forward(self, images, labels):r"""Overridden."""self._check_inputs(images)images = images.clone().detach().to(self.device)labels = labels.clone().detach().to(self.device)if self.targeted:target_labels = self.get_target_label(images, labels)loss = nn.CrossEntropyLoss()images.requires_grad = Trueoutputs = self.get_logits(images)# Calculate lossif self.targeted:cost = -loss(outputs, target_labels)else:cost = loss(outputs, labels)# Update adversarial imagesgrad = torch.autograd.grad(cost, images,retain_graph=False, create_graph=False)[0]adv_images = images + self.eps*grad.sign()adv_images = torch.clamp(adv_images, min=0, max=1).detach()return adv_images

解析

FGSM的全称是Fast Gradient Sign Method(快速梯度下降法),在白盒环境下,通过求出损失cost对输入的导数,然后用符号函数sign()得到其具体的梯度方向,接着乘以一个步长eps,得到的“扰动”加在原来的输入 上就得到了在FGSM攻击下的样本。
可以仔细回忆一下,在神经网络的反向传播当中,我们在训练过程时就是沿着梯度下降的方向来更新更新 w , b w,b w,b的值。这样做可以使得网络往损失cost减小的方向收敛。简单来说,梯度方向代表了损失cost增大速度最快的方向,FGSM算法假设目标损失函数 J ( x , y ) J(x,y) J(x,y) x x x之间是近似线性的,即 J ( x , y ) ≈ w T x J(x ,y)≈w^Tx J(x,y)wTx,所以沿着梯度方向改变输入 x x x可以增大损失,从而达到使模型分类错误的目的。具体做法是在图像上加一个扰动 η \eta η η = ϵ s i g n ( ▽ x J ( θ , x , y ) ) \eta= \epsilon sign(\bigtriangledown_{x}J(\theta,x,y)) η=ϵsign(xJ(θ,x,y)),其中 ▽ x \bigtriangledown_{x} x即梯度, ϵ \epsilon ϵ即步长,也就是每个像素扰动的最大值。

forward()函数就是攻击过程,输入图像images和标签y,即可返回对抗图像adv_images
images = images.clone().detach().to(self.device)clone()将图像克隆到一块新的内存区(pytorch默认同样的tensor共享一块内存区);detach()是将克隆的新的tensor从当前计算图中分离下来,作为叶节点,从而可以计算其梯度;to()作用就是将其载入设备。
target_labels = self.get_target_label(images, labels):是有目标攻击的情况,由于该论文并没有探讨有目标攻击,这里就先不做解释。
loss = nn.CrossEntropyLoss():设置损失函数为交叉熵损失。
images.requires_grad = True:将这个参数设置为True,pytorch就会在程序运行过程中自动生成计算图,供计算梯度使用。
outputs = self.get_logits(images):获得图像的在模型中的输出值。
cost = loss(outputs, labels):计算损失
grad = torch.autograd.grad(cost, images, retain_graph=False, create_graph=False)[0]costimages求导,得到梯度grad
adv_images = images + self.eps*grad.sign():根据公式在原图像上增加一个扰动,得到对抗图像。
adv_images = torch.clamp(adv_images, min=0, max=1).detach():将images中大于1的部分设为1,小于0的部分设为0,防止越界。

思考

FGSM算法假设目标损失函数 J ( x , y ) J(x,y) J(x,y) x x x之间是近似线性的,但是这个线性假设不一定正确,如果J JJ和x xx不是线性的,那么在 ( 0 , ϵ s i g n ( ▽ x J ( θ , x , y ) ) ) (0,\epsilon sign(\bigtriangledown_{x}J(\theta,x,y))) (0,ϵsign(xJ(θ,x,y)))之间是否存在某个扰动,使得 J J J增加的也很大,此时 x x x的修改量就可以小于 ϵ \epsilon ϵ。于是,有学者就提出迭代的方式来找各个像素点的扰动,也就是BIM算法。

这篇关于FGSM(Fast Gradient Sign Method)算法源码解析的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/297045

相关文章

Spring组件实例化扩展点之InstantiationAwareBeanPostProcessor使用场景解析

《Spring组件实例化扩展点之InstantiationAwareBeanPostProcessor使用场景解析》InstantiationAwareBeanPostProcessor是Spring... 目录一、什么是InstantiationAwareBeanPostProcessor?二、核心方法解

深入解析 Java Future 类及代码示例

《深入解析JavaFuture类及代码示例》JavaFuture是java.util.concurrent包中用于表示异步计算结果的核心接口,下面给大家介绍JavaFuture类及实例代码,感兴... 目录一、Future 类概述二、核心工作机制代码示例执行流程2. 状态机模型3. 核心方法解析行为总结:三

springboot项目中使用JOSN解析库的方法

《springboot项目中使用JOSN解析库的方法》JSON,全程是JavaScriptObjectNotation,是一种轻量级的数据交换格式,本文给大家介绍springboot项目中使用JOSN... 目录一、jsON解析简介二、Spring Boot项目中使用JSON解析1、pom.XML文件引入依

Python中文件读取操作漏洞深度解析与防护指南

《Python中文件读取操作漏洞深度解析与防护指南》在Web应用开发中,文件操作是最基础也最危险的功能之一,这篇文章将全面剖析Python环境中常见的文件读取漏洞类型,成因及防护方案,感兴趣的小伙伴可... 目录引言一、静态资源处理中的路径穿越漏洞1.1 典型漏洞场景1.2 os.path.join()的陷

C#代码实现解析WTGPS和BD数据

《C#代码实现解析WTGPS和BD数据》在现代的导航与定位应用中,准确解析GPS和北斗(BD)等卫星定位数据至关重要,本文将使用C#语言实现解析WTGPS和BD数据,需要的可以了解下... 目录一、代码结构概览1. 核心解析方法2. 位置信息解析3. 经纬度转换方法4. 日期和时间戳解析5. 辅助方法二、L

Mybatis Plus JSqlParser解析sql语句及JSqlParser安装步骤

《MybatisPlusJSqlParser解析sql语句及JSqlParser安装步骤》JSqlParser是一个用于解析SQL语句的Java库,它可以将SQL语句解析为一个Java对象树,允许... 目录【一】jsqlParser 是什么【二】JSqlParser 的安装步骤【三】使用场景【1】sql语

SpringBoot整合Sa-Token实现RBAC权限模型的过程解析

《SpringBoot整合Sa-Token实现RBAC权限模型的过程解析》:本文主要介绍SpringBoot整合Sa-Token实现RBAC权限模型的过程解析,本文给大家介绍的非常详细,对大家的学... 目录前言一、基础概念1.1 RBAC模型核心概念1.2 Sa-Token核心功能1.3 环境准备二、表结

Java 关键字transient与注解@Transient的区别用途解析

《Java关键字transient与注解@Transient的区别用途解析》在Java中,transient是一个关键字,用于声明一个字段不会被序列化,这篇文章给大家介绍了Java关键字transi... 在Java中,transient 是一个关键字,用于声明一个字段不会被序列化。当一个对象被序列化时,被

Java JSQLParser解析SQL的使用指南

《JavaJSQLParser解析SQL的使用指南》JSQLParser是一个Java语言的SQL语句解析工具,可以将SQL语句解析成为Java类的层次结构,还支持改写SQL,下面我们就来看看它的具... 目录一、引言二、jsQLParser常见类2.1 Class Diagram2.2 Statement

python进行while遍历的常见错误解析

《python进行while遍历的常见错误解析》在Python中选择合适的遍历方式需要综合考虑可读性、性能和具体需求,本文就来和大家讲解一下python中while遍历常见错误以及所有遍历方法的优缺点... 目录一、超出数组范围问题分析错误复现解决方法关键区别二、continue使用问题分析正确写法关键点三