tensorflow 网络修剪 剪枝操作

2024-04-03 21:18

本文主要是介绍tensorflow 网络修剪 剪枝操作,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

背景知识

模型剪枝(Model Pruning)是一种模型压缩方法,对深度神经网络的稠密连接引入稀疏性,通过将“不重要”的权值直接置零来减少非零权值数量,其历史可追溯到上世纪 90 年代初。

在 Optimal Brain Damage【2】中,使用对角 Hessian 逼近计算每个权值的重要性,重要性低的权值被置零,然后重新训练网络。

在 Optimal Brain Surgeon【3】中,使用逆 Hessian 矩阵计算每个权值的重要性,重要性低的权值被置零,剩下的权值使用二阶泰勒逼近的 loss 增量更新。

最近比较流行基于幅度的权值剪枝方法【4】,该方法将权值取绝对值,与设定的 threshhold 值进行比较,低于门限的权值被置零。基于幅度的权值剪枝算法计算高效,可以应用到大部分模型和数据集。TensorFlow 也使用了基于幅度的权值剪枝算法。


TF 代码实现

TensorFlow 代码目录 tensorflow/contrib/model_pruning/ 提供了对 TensorFlow  框架的扩展,可在模型训练时实现剪枝。

对每个被选中做剪枝的层增加一个二进制掩模(mask)变量,形状和该层的权值张量形状完全相同。该掩模决定了哪些权值参与前向计算。掩模更新算法则需要为 TensorFlow 训练计算图注入特殊运算符,对当前层权值按绝对值大小排序,对幅度小于一定门限的权值将其对应掩模值设为 0。反向传播梯度也经过掩模,被屏蔽的权值(mask 为 0)在反向传播步骤中无法获得更新量。

研究发现稀疏度不宜从一开始就设置最大,这样容易将重要的权值剪掉造成无法挽回的准确率损失,更好的方法是渐进稀疏度,从初始稀疏度 (一般为 0 )开始,逐步增大到最终稀疏度 ,这期间二进制掩模变量 mask 经历了 n 次更新,每次更新时的门限由当时的稀疏度决定,稀疏度由如下公式计算得到:

随着训练过程,逐步提高稀疏度,直到达到期望的稀疏度 为止。

下图很直观地反映了渐进提高稀疏度的过程。

初始时刻,稀疏度提升较快,而越到后面,稀疏度提升速度会逐渐放缓,这个比较符合直觉,因为初始时有大量冗余的权值,而越到后面保留的权值数量越少,不能再“大刀阔斧”地修剪,而需要更谨慎些,避免“误伤无辜”。

下面 TensorFlow 代码创建了带有 mask 变量的 graph:

from tensorflow.contrib.model_pruning.python import pruningwith tf.variable_scope('conv1') as scope:# 创建权值 variablekernel = _variable_with_weight_decay('weights', shape=[5, 5, 3, 64], stddev=5e-2, wd=0.0)# 创建 conv2d op,权值 variable 增加 maskconv = tf.nn.conv2d(images, pruning.apply_mask(kernel, scope), [1, 1, 1, 1], padding='SAME')

下面代码给出了带剪枝的模型训练代码结构:

from tensorflow.contrib.model_pruning.python import pruning# 命令行参数解析pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)# 创建剪枝对象pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)# 使用剪枝对象向训练图增加更新 mask 的运算符# 当且仅当训练步骤位于 [begin_pruning_step, end_pruning_step] 之间时,# conditional_mask_update_op 才会更新 maskmask_update_op = pruning_obj.conditional_mask_update_op()# 使用剪枝对象写入 summaries,用于跟踪每层权值 sparsity 变化pruning_obj.add_pruning_summaries()with tf.train.MonitoredTrainingSession() as mon_sess:while not mon_sess.should_stop():mon_sess.run(train_op)# 更新 maskmon_sess.run(mask_update_op)	

其中 FLAGS.pruning_hparams 为一组逗号分隔的键值对,取值如下表所示:

超参名类型默认值说明
begin_pruning_stepinteger0开始剪枝的全局 step
end_pruning_stepinteger-1结束剪枝的全局 step,默认为 -1 标识剪枝一直持续到训练结束
do_not_prunelist of strings[""]一组层名,标记哪些层不做剪枝
threshold_decayfloat0.9衰减因子,用于门限衰减
pruning_frequencyinteger10mask 更新的频率,计数单位为全局 step 数
initial_sparsityfloat0.0初始稀疏度值
target_sparsityfloat0.5目标稀疏度值
sparsity_function_begin_stepinteger0渐进稀疏度函数开始时刻
sparsity_function_end_stepinteger100渐进稀疏度函数结束时刻
sparsity_function_exponentfloat3.0

指数项,=1 则为线性增长,>1 则初始快后续慢

转者注: 详细代码可参照https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/examples/cifar10    (例子)

           和 https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/model_pruning/python/pruning.py  (pruning源码)                                                                                                                                                                                                                

实践

TensorFlow model pruning 自带 CIFAR10 例程,实现了一个稀疏 CNN 模型,其中卷积层和 local 层的权值均做了稀疏化。

(1) 准备 TensorFlow r1.7 环境

硬件环境:GTX 1080

软件环境:CUDA 9.0 + cuDNN 7, Bazel 0.11.1

git clone https://github.com/tensorflow/tensorflow.git
cd tensorflow/
git checkout r1.7

(2) 编译、运行 tensorflow/contrib/model_pruning/

cd tensorflow/contrib/model_pruning/
bazel build -c opt examples/cifar10:cifar10_{train,val}
cd ../../
bazel-bin/contrib/model_pruning/examples/cifar10/cifar10_train -prune_hparams=name=cifar10_pruning,begin_pruning_step=10000,target_sparsity=0.9,sparsity_function_begin_step=10000,sparsity_function_end_step=100000

(3) 查看训练过程

运行 TensorBoard:

tensorboard --logdir /tmp/cifar10_train/

打开浏览器,输入 localhost:6006

可以看到随着训练步骤增加,conv1 和 conv2 的 sparsity 在不断增长。总的 loss 变化如下图所示:

(4) 查看计算图

切换到 GRAPHS 页面,双击 conv2 节点,可以看到在原有计算图基础上新增了 mask 和 threshold 节点用来做 model pruning。

(5) 模型评估

利用以下命令对训练模型进行评估:

bazel-bin/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_eval

训练 15 万次迭代的结果(仅供参考)

SparsityAccuracy after 150K steps
0%86%
50%86%
75%
90%
95%77%

论文【1】中一些结论

随着稀疏度提高,模型质量逐渐下降,其表现为分类准确率降低。下表为 InceptionV3 模型不同稀疏度的情况【1】:

从表中看到,50% 系数模型和基准模型(0% 稀疏度)表现一致,而 87.5% 稀疏度模型的 top-5 准确率相比基准模型只有 2% 降低,但模型非零权值数量减少为原来 1/8。

我们前面文章《用于移动和嵌入式视觉应用的 MobileNets》介绍过轻量 CNN 模型 MobileNet,是一类特别为移动视觉应用设计的高效卷积神经网络。MobileNet 基于 depthwise separable 卷积,将通道内滤波通道间线性组合分解为两个独立步骤,显著减少了参数数量。MobileNet 网络架构包括一个标准卷积层用于处理输入图片,一大堆 depthwise separable conv,最后为 average pooling 和全连接层。 width multiplier 是 MobileNet 的一个调节参数,能实现模型准确率和模型权值数量、计算量的 trade-off。

我们既可以通过设置更小的 width multiplier 实现尺寸更小的模型(准确率会降低),也可以通过对原始 MobileNet 做稀疏化得到尺寸更小的模型(准确率同样会降低),那么这两种方法哪种更有效呢?论文【1】 给出了结果:

基本结论为:大而稀疏的模型(large-sparse)表现优于小而稠密的模型(small-dense)。

例如,75% 稀疏度模型( 1.09 M 权值,top-1 accuracy 为 67.7%)优于稠密 0.5 MobileNet( 1.32 M 权值,top-1 accuracy 为 63.7%)。

类似地, 90% 稀疏度模型(0.46 M 权值,top-1 accuracy 61.8%)优于稠密 0.25 MobileNet(0.46 M 权值,top-1 accuracy 50.6%)。

通过这些结论,为轻量级模型设计提供了新的思路,同时也为专用硬件加速器设计提供了参考。


参考文献

【1】 To prune, or not to prune: exploring the efficacy of pruning for model compression, arXiv:1710.01878

【2】Yann LeCun et.al. Optimal brain damage. NIPS, 1990

【3】B.Hassibi et.al. Optimal brain surgeeon and general network pruning. ICNN, 1993

【4】Song Han et.al. Learning both weights and connections for efficient neural network. NIPS, 2015


这篇关于tensorflow 网络修剪 剪枝操作的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/873998

相关文章

python panda库从基础到高级操作分析

《pythonpanda库从基础到高级操作分析》本文介绍了Pandas库的核心功能,包括处理结构化数据的Series和DataFrame数据结构,数据读取、清洗、分组聚合、合并、时间序列分析及大数据... 目录1. Pandas 概述2. 基本操作:数据读取与查看3. 索引操作:精准定位数据4. Group

Python操作PDF文档的主流库使用指南

《Python操作PDF文档的主流库使用指南》PDF因其跨平台、格式固定的特性成为文档交换的标准,然而,由于其复杂的内部结构,程序化操作PDF一直是个挑战,本文主要为大家整理了Python操作PD... 目录一、 基础操作1.PyPDF2 (及其继任者 pypdf)2.PyMuPDF / fitz3.Fre

Python对接支付宝支付之使用AliPay实现的详细操作指南

《Python对接支付宝支付之使用AliPay实现的详细操作指南》支付宝没有提供PythonSDK,但是强大的github就有提供python-alipay-sdk,封装里很多复杂操作,使用这个我们就... 目录一、引言二、准备工作2.1 支付宝开放平台入驻与应用创建2.2 密钥生成与配置2.3 安装ali

MySQL 强制使用特定索引的操作

《MySQL强制使用特定索引的操作》MySQL可通过FORCEINDEX、USEINDEX等语法强制查询使用特定索引,但优化器可能不采纳,需结合EXPLAIN分析执行计划,避免性能下降,注意版本差异... 目录1. 使用FORCE INDEX语法2. 使用USE INDEX语法3. 使用IGNORE IND

Python使用openpyxl读取Excel的操作详解

《Python使用openpyxl读取Excel的操作详解》本文介绍了使用Python的openpyxl库进行Excel文件的创建、读写、数据操作、工作簿与工作表管理,包括创建工作簿、加载工作簿、操作... 目录1 概述1.1 图示1.2 安装第三方库2 工作簿 workbook2.1 创建:Workboo

Ubuntu 24.04启用root图形登录的操作流程

《Ubuntu24.04启用root图形登录的操作流程》Ubuntu默认禁用root账户的图形与SSH登录,这是为了安全,但在某些场景你可能需要直接用root登录GNOME桌面,本文以Ubuntu2... 目录一、前言二、准备工作三、设置 root 密码四、启用图形界面 root 登录1. 修改 GDM 配

JSONArray在Java中的应用操作实例

《JSONArray在Java中的应用操作实例》JSONArray是org.json库用于处理JSON数组的类,可将Java对象(Map/List)转换为JSON格式,提供增删改查等操作,适用于前后端... 目录1. jsONArray定义与功能1.1 JSONArray概念阐释1.1.1 什么是JSONA

Java操作Word文档的全面指南

《Java操作Word文档的全面指南》在Java开发中,操作Word文档是常见的业务需求,广泛应用于合同生成、报表输出、通知发布、法律文书生成、病历模板填写等场景,本文将全面介绍Java操作Word文... 目录简介段落页头与页脚页码表格图片批注文本框目录图表简介Word编程最重要的类是org.apach

Linux中压缩、网络传输与系统监控工具的使用完整指南

《Linux中压缩、网络传输与系统监控工具的使用完整指南》在Linux系统管理中,压缩与传输工具是数据备份和远程协作的桥梁,而系统监控工具则是保障服务器稳定运行的眼睛,下面小编就来和大家详细介绍一下它... 目录引言一、压缩与解压:数据存储与传输的优化核心1. zip/unzip:通用压缩格式的便捷操作2.

Python实现对阿里云OSS对象存储的操作详解

《Python实现对阿里云OSS对象存储的操作详解》这篇文章主要为大家详细介绍了Python实现对阿里云OSS对象存储的操作相关知识,包括连接,上传,下载,列举等功能,感兴趣的小伙伴可以了解下... 目录一、直接使用代码二、详细使用1. 环境准备2. 初始化配置3. bucket配置创建4. 文件上传到os