tensorflow 网络修剪 剪枝操作

2024-04-03 21:18

本文主要是介绍tensorflow 网络修剪 剪枝操作,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

背景知识

模型剪枝(Model Pruning)是一种模型压缩方法,对深度神经网络的稠密连接引入稀疏性,通过将“不重要”的权值直接置零来减少非零权值数量,其历史可追溯到上世纪 90 年代初。

在 Optimal Brain Damage【2】中,使用对角 Hessian 逼近计算每个权值的重要性,重要性低的权值被置零,然后重新训练网络。

在 Optimal Brain Surgeon【3】中,使用逆 Hessian 矩阵计算每个权值的重要性,重要性低的权值被置零,剩下的权值使用二阶泰勒逼近的 loss 增量更新。

最近比较流行基于幅度的权值剪枝方法【4】,该方法将权值取绝对值,与设定的 threshhold 值进行比较,低于门限的权值被置零。基于幅度的权值剪枝算法计算高效,可以应用到大部分模型和数据集。TensorFlow 也使用了基于幅度的权值剪枝算法。


TF 代码实现

TensorFlow 代码目录 tensorflow/contrib/model_pruning/ 提供了对 TensorFlow  框架的扩展,可在模型训练时实现剪枝。

对每个被选中做剪枝的层增加一个二进制掩模(mask)变量,形状和该层的权值张量形状完全相同。该掩模决定了哪些权值参与前向计算。掩模更新算法则需要为 TensorFlow 训练计算图注入特殊运算符,对当前层权值按绝对值大小排序,对幅度小于一定门限的权值将其对应掩模值设为 0。反向传播梯度也经过掩模,被屏蔽的权值(mask 为 0)在反向传播步骤中无法获得更新量。

研究发现稀疏度不宜从一开始就设置最大,这样容易将重要的权值剪掉造成无法挽回的准确率损失,更好的方法是渐进稀疏度,从初始稀疏度 (一般为 0 )开始,逐步增大到最终稀疏度 ,这期间二进制掩模变量 mask 经历了 n 次更新,每次更新时的门限由当时的稀疏度决定,稀疏度由如下公式计算得到:

随着训练过程,逐步提高稀疏度,直到达到期望的稀疏度 为止。

下图很直观地反映了渐进提高稀疏度的过程。

初始时刻,稀疏度提升较快,而越到后面,稀疏度提升速度会逐渐放缓,这个比较符合直觉,因为初始时有大量冗余的权值,而越到后面保留的权值数量越少,不能再“大刀阔斧”地修剪,而需要更谨慎些,避免“误伤无辜”。

下面 TensorFlow 代码创建了带有 mask 变量的 graph:

from tensorflow.contrib.model_pruning.python import pruningwith tf.variable_scope('conv1') as scope:# 创建权值 variablekernel = _variable_with_weight_decay('weights', shape=[5, 5, 3, 64], stddev=5e-2, wd=0.0)# 创建 conv2d op,权值 variable 增加 maskconv = tf.nn.conv2d(images, pruning.apply_mask(kernel, scope), [1, 1, 1, 1], padding='SAME')

下面代码给出了带剪枝的模型训练代码结构:

from tensorflow.contrib.model_pruning.python import pruning# 命令行参数解析pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)# 创建剪枝对象pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)# 使用剪枝对象向训练图增加更新 mask 的运算符# 当且仅当训练步骤位于 [begin_pruning_step, end_pruning_step] 之间时,# conditional_mask_update_op 才会更新 maskmask_update_op = pruning_obj.conditional_mask_update_op()# 使用剪枝对象写入 summaries,用于跟踪每层权值 sparsity 变化pruning_obj.add_pruning_summaries()with tf.train.MonitoredTrainingSession() as mon_sess:while not mon_sess.should_stop():mon_sess.run(train_op)# 更新 maskmon_sess.run(mask_update_op)	

其中 FLAGS.pruning_hparams 为一组逗号分隔的键值对,取值如下表所示:

超参名类型默认值说明
begin_pruning_stepinteger0开始剪枝的全局 step
end_pruning_stepinteger-1结束剪枝的全局 step,默认为 -1 标识剪枝一直持续到训练结束
do_not_prunelist of strings[""]一组层名,标记哪些层不做剪枝
threshold_decayfloat0.9衰减因子,用于门限衰减
pruning_frequencyinteger10mask 更新的频率,计数单位为全局 step 数
initial_sparsityfloat0.0初始稀疏度值
target_sparsityfloat0.5目标稀疏度值
sparsity_function_begin_stepinteger0渐进稀疏度函数开始时刻
sparsity_function_end_stepinteger100渐进稀疏度函数结束时刻
sparsity_function_exponentfloat3.0

指数项,=1 则为线性增长,>1 则初始快后续慢

转者注: 详细代码可参照https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/examples/cifar10    (例子)

           和 https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/model_pruning/python/pruning.py  (pruning源码)                                                                                                                                                                                                                

实践

TensorFlow model pruning 自带 CIFAR10 例程,实现了一个稀疏 CNN 模型,其中卷积层和 local 层的权值均做了稀疏化。

(1) 准备 TensorFlow r1.7 环境

硬件环境:GTX 1080

软件环境:CUDA 9.0 + cuDNN 7, Bazel 0.11.1

git clone https://github.com/tensorflow/tensorflow.git
cd tensorflow/
git checkout r1.7

(2) 编译、运行 tensorflow/contrib/model_pruning/

cd tensorflow/contrib/model_pruning/
bazel build -c opt examples/cifar10:cifar10_{train,val}
cd ../../
bazel-bin/contrib/model_pruning/examples/cifar10/cifar10_train -prune_hparams=name=cifar10_pruning,begin_pruning_step=10000,target_sparsity=0.9,sparsity_function_begin_step=10000,sparsity_function_end_step=100000

(3) 查看训练过程

运行 TensorBoard:

tensorboard --logdir /tmp/cifar10_train/

打开浏览器,输入 localhost:6006

可以看到随着训练步骤增加,conv1 和 conv2 的 sparsity 在不断增长。总的 loss 变化如下图所示:

(4) 查看计算图

切换到 GRAPHS 页面,双击 conv2 节点,可以看到在原有计算图基础上新增了 mask 和 threshold 节点用来做 model pruning。

(5) 模型评估

利用以下命令对训练模型进行评估:

bazel-bin/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_eval

训练 15 万次迭代的结果(仅供参考)

SparsityAccuracy after 150K steps
0%86%
50%86%
75%
90%
95%77%

论文【1】中一些结论

随着稀疏度提高,模型质量逐渐下降,其表现为分类准确率降低。下表为 InceptionV3 模型不同稀疏度的情况【1】:

从表中看到,50% 系数模型和基准模型(0% 稀疏度)表现一致,而 87.5% 稀疏度模型的 top-5 准确率相比基准模型只有 2% 降低,但模型非零权值数量减少为原来 1/8。

我们前面文章《用于移动和嵌入式视觉应用的 MobileNets》介绍过轻量 CNN 模型 MobileNet,是一类特别为移动视觉应用设计的高效卷积神经网络。MobileNet 基于 depthwise separable 卷积,将通道内滤波通道间线性组合分解为两个独立步骤,显著减少了参数数量。MobileNet 网络架构包括一个标准卷积层用于处理输入图片,一大堆 depthwise separable conv,最后为 average pooling 和全连接层。 width multiplier 是 MobileNet 的一个调节参数,能实现模型准确率和模型权值数量、计算量的 trade-off。

我们既可以通过设置更小的 width multiplier 实现尺寸更小的模型(准确率会降低),也可以通过对原始 MobileNet 做稀疏化得到尺寸更小的模型(准确率同样会降低),那么这两种方法哪种更有效呢?论文【1】 给出了结果:

基本结论为:大而稀疏的模型(large-sparse)表现优于小而稠密的模型(small-dense)。

例如,75% 稀疏度模型( 1.09 M 权值,top-1 accuracy 为 67.7%)优于稠密 0.5 MobileNet( 1.32 M 权值,top-1 accuracy 为 63.7%)。

类似地, 90% 稀疏度模型(0.46 M 权值,top-1 accuracy 61.8%)优于稠密 0.25 MobileNet(0.46 M 权值,top-1 accuracy 50.6%)。

通过这些结论,为轻量级模型设计提供了新的思路,同时也为专用硬件加速器设计提供了参考。


参考文献

【1】 To prune, or not to prune: exploring the efficacy of pruning for model compression, arXiv:1710.01878

【2】Yann LeCun et.al. Optimal brain damage. NIPS, 1990

【3】B.Hassibi et.al. Optimal brain surgeeon and general network pruning. ICNN, 1993

【4】Song Han et.al. Learning both weights and connections for efficient neural network. NIPS, 2015


这篇关于tensorflow 网络修剪 剪枝操作的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!


原文地址:
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.chinasem.cn/article/873998

相关文章

Python中文件读取操作漏洞深度解析与防护指南

《Python中文件读取操作漏洞深度解析与防护指南》在Web应用开发中,文件操作是最基础也最危险的功能之一,这篇文章将全面剖析Python环境中常见的文件读取漏洞类型,成因及防护方案,感兴趣的小伙伴可... 目录引言一、静态资源处理中的路径穿越漏洞1.1 典型漏洞场景1.2 os.path.join()的陷

Python使用Code2flow将代码转化为流程图的操作教程

《Python使用Code2flow将代码转化为流程图的操作教程》Code2flow是一款开源工具,能够将代码自动转换为流程图,该工具对于代码审查、调试和理解大型代码库非常有用,在这篇博客中,我们将深... 目录引言1nVflRA、为什么选择 Code2flow?2、安装 Code2flow3、基本功能演示

Python中OpenCV与Matplotlib的图像操作入门指南

《Python中OpenCV与Matplotlib的图像操作入门指南》:本文主要介绍Python中OpenCV与Matplotlib的图像操作指南,本文通过实例代码给大家介绍的非常详细,对大家的学... 目录一、环境准备二、图像的基本操作1. 图像读取、显示与保存 使用OpenCV操作2. 像素级操作3.

python操作redis基础

《python操作redis基础》Redis(RemoteDictionaryServer)是一个开源的、基于内存的键值对(Key-Value)存储系统,它通常用作数据库、缓存和消息代理,这篇文章... 目录1. Redis 简介2. 前提条件3. 安装 python Redis 客户端库4. 连接到 Re

Java Stream.reduce()方法操作实际案例讲解

《JavaStream.reduce()方法操作实际案例讲解》reduce是JavaStreamAPI中的一个核心操作,用于将流中的元素组合起来产生单个结果,:本文主要介绍JavaStream.... 目录一、reduce的基本概念1. 什么是reduce操作2. reduce方法的三种形式二、reduce

MySQL表空间结构详解表空间到段页操作

《MySQL表空间结构详解表空间到段页操作》在MySQL架构和存储引擎专题中介绍了使用不同存储引擎创建表时生成的表空间数据文件,在本章节主要介绍使用InnoDB存储引擎创建表时生成的表空间数据文件,对... 目录️‍一、什么是表空间结构1.1 表空间与表空间文件的关系是什么?️‍二、用户数据在表空间中是怎么

Linux网络配置之网桥和虚拟网络的配置指南

《Linux网络配置之网桥和虚拟网络的配置指南》这篇文章主要为大家详细介绍了Linux中配置网桥和虚拟网络的相关方法,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 一、网桥的配置在linux系统中配置一个新的网桥主要涉及以下几个步骤:1.为yum仓库做准备,安装组件epel-re

Python对PDF书签进行添加,修改提取和删除操作

《Python对PDF书签进行添加,修改提取和删除操作》PDF书签是PDF文件中的导航工具,通常包含一个标题和一个跳转位置,本教程将详细介绍如何使用Python对PDF文件中的书签进行操作... 目录简介使用工具python 向 PDF 添加书签添加书签添加嵌套书签Python 修改 PDF 书签Pytho

python如何下载网络文件到本地指定文件夹

《python如何下载网络文件到本地指定文件夹》这篇文章主要为大家详细介绍了python如何实现下载网络文件到本地指定文件夹,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下...  在python中下载文件到本地指定文件夹可以通过以下步骤实现,使用requests库处理HTTP请求,并结合o

Mysql数据库中数据的操作CRUD详解

《Mysql数据库中数据的操作CRUD详解》:本文主要介绍Mysql数据库中数据的操作(CRUD),详细描述对Mysql数据库中数据的操作(CRUD),包括插入、修改、删除数据,还有查询数据,包括... 目录一、插入数据(insert)1.插入数据的语法2.注意事项二、修改数据(update)1.语法2.有