Batch Normalization原理与实战(下)

2024-03-19 01:20

本文主要是介绍Batch Normalization原理与实战(下),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

来自 | 知乎   作者 | 天雨粟

链接 | https://zhuanlan.zhihu.com/p/34879333

编辑 | 深度学习这件小事公众号

本文仅作学术交流,如有侵权,请联系后台删除。

   前言

本文主要从理论与实战视角对深度学习中的Batch Normalization的思路进行讲解、归纳和总结,并辅以代码让小伙伴儿们对Batch Normalization的作用有更加直观的了解。

本文主要分为两大部分,由于篇幅过长,分为上下两篇。本文为第二部分实战板块,主要以MNIST数据集作为整个代码测试的数据,通过比较加入Batch Normalization前后网络的性能来让大家对Batch Normalization的作用与效果有更加直观的感知。

   二、实战板块

经过了上面了理论学习,我们对BN有了理论上的认知。“Talk is cheap, show me the code”。接下来我们就通过实际的代码来对比加入BN前后的模型效果。实战部分使用MNIST数据集作为数据基础,并使用TensorFlow中的Batch Normalization结构来进行BN的实现。

数据准备:MNIST手写数据集

代码地址:https://github.com/NELSONZHAO/zhihu/tree/master/batch_normalization_discussion

注:TensorFlow版本为1.6.0

实战板块主要分为两部分:

  • 网络构建与辅助函数

  • BN测试

1、网络构建与辅助函数

首先我们先定义一下神经网络的类,这个类里面主要包括了以下方法:

  • build_network:前向计算

  • fully_connected:全连接计算

  • train:训练模型

  • test:测试模型

1.1 build_network

我们首先通过构造函数,把权重、激活函数以及是否使用BN这些变量传入,并生成一个training_accuracies来记录训练过程中的模型准确率变化。这里的initial_weights是一个list,list中每一个元素是一个矩阵(二维tuple),存储了每一层的权重矩阵。build_network实现了网络的构建,并调用了fully_connected函数(下面会提)进行计算。要注意的是,由于MNIST是多分类,在这里我们不需要对最后一层进行激活,保留计算的logits就好。

1.2 fully_connected

这里的fully_connected主要用来每一层的线性与非线性计算。通过self.use_batch_norm来控制是否使用BN。

另外,值得注意的是,tf.layers.batch_normalization接口中training参数非常重要,官方文档中描述为:

training: Either a Python boolean, or a TensorFlow boolean scalar tensor (e.g. a placeholder). Whether to return the output in training mode (normalized with statistics of the current batch) or in inference mode (normalized with moving statistics). NOTE: make sure to set this parameter correctly, or else your training/inference will not work properly.

当我们训练时,要设置为True,保证在训练过程中使用的是mini-batch的统计量进行normalization;在Inference阶段,使用False,也就是使用总体样本的无偏估计。

1.3 train

train函数主要用来进行模型的训练。除了要定义label,loss以及optimizer以外,我们还需要注意,官方文档指出在使用BN时的事项:

Note: when training, the moving_mean and moving_variance need to be updated. By default the update ops are placed in tf.GraphKeys.UPDATE_OPS, so they need to be added as a dependency to the train_op.

因此当self.use_batch_norm为True时,要使用tf.control_dependencies保证模型正常训练。

注意:在训练过程中batch_size选了60(mnist.train.next_batch(60)),这里是因为BN的原paper中用的60。( We trained the network for 50000 steps, with 60 examples per mini-batch.)

1.4 test

test阶段与train类似,只是要设置self.is_training=False,保证Inference阶段BN的正确。

经过上面的步骤,我们的框架基本就搭好了,接下来我们再写一个辅助函数train_and_test以及plot绘图函数就可以开始对BN进行测试啦。train_and_test以及plot函数见GitHub代码中,这里不再赘述

2、BN测试

在这里,我们构造一个4层神经网络,输入层结点数784,三个隐层均为128维,输出层10个结点,如下图所示:

实验中,我们主要控制一下三个变量:

  • 权重矩阵(较小初始化权重,标准差为0.05;较大初始化权重,标准差为10)

  • 学习率(较小学习率:0.01;较大学习率:2)

  • 隐层激活函数(relu,sigmoid)

2.1 小权重,小学习率,ReLU

测试结果如下图:

我们可以得到以下结论:

  • 在训练与预测阶段,加入BN的模型准确率都稍高一点;

  • 加入BN的网络收敛更快(黄线)

  • 没有加入BN的网络训练速度更快(483.61it/s>329.23it/s),这是因为BN增加了神经网络中的计算量

为了更清楚地看到BN收敛速度更快,我们把减少Training batches,设置为3000,得到如下结果:

从上图中我们就可以清晰看到,加入BN的网络在第500个batch的时候已经能够在validation数据集上达到90%的准确率;而没有BN的网络的准确率还在不停波动,并且到第3000个batch的时候才达到90%的准确率。

2.2 小权重,小学习率,Sigmoid

学习率与权重均没变,我们把隐层激活函数换为sigmoid。可以发现,BN收敛速度非常之快,而没有BN的网络前期在不断波动,直到第20000个train batch以后才开始进入平稳的训练状态。

2.3 小权重,大学习率,ReLU

在本次实验中,我们使用了较大的学习率,较大的学习率意味着权重的更新跨度很大,而根据我们前面理论部分的介绍,BN不会受到权重scale的影响,因此其能够使模型保持在一个稳定的训练状态;而没有加入BN的网络则在一开始就由于学习率过大导致训练失败。

2.4 小权重,大学习率,Sigmoid

在保持较大学习率(learning rate=2)的情况下,当我们将激活函数换为sigmoid以后,两个模型都能够达到一个很好的效果,并且在test数据集上的准确率非常接近;但加入BN的网络要收敛地更快,同样的,我们来观察3000次batch的训练准确率。

当我们把training batch限制到3000以后,可以发现加入BN后,尽管我们使用较大的学习率,其仍然能够在大约500个batch以后在validation上达到90%的准确率;但不加入BN的准确率前期在一直大幅度波动,到大约1000个batch以后才达到90%的准确率。

2.5 大权重,小学习率,ReLU

当我们使用较大权重时,不加入BN的网络在一开始就失效;而加入BN的网络能够克服如此bad的权重初始化,并达到接近80%的准确率。

2.6 大权重,小学习率,Sigmoid

同样使用较大的权重初始化,当我们激活函数为sigmoid时,不加入BN的网络在一开始的准确率有所上升,但随着训练的进行网络逐渐失效,最终准确率仅有30%;而加入BN的网络依旧出色地克服如此bad的权重初始化,并达到接近85%的准确率

2.7 大权重,大学习率,ReLU

当权重与学习率都很大时,BN网络开始还会训练一段时间,但随后就直接停止训练;而没有BN的神经网络开始就失效。

2.8 大权重,大学习率,Sigmoid

可以看到,加入BN对较大的权重与较大学习率都具有非常好的鲁棒性,最终模型能够达到93%的准确率;而未加入BN的网络则经过一段时间震荡后开始失效。

8个模型的准确率统计如下:

   总结

至此,关于Batch Normalization的理论与实战部分就介绍道这里。总的来说,BN通过将每一层网络的输入进行normalization,保证输入分布的均值与方差固定在一定范围内,减少了网络中的Internal Covariate Shift问题,并在一定程度上缓解了梯度消失,加速了模型收敛;并且BN使得网络对参数、激活函数更加具有鲁棒性,降低了神经网络模型训练和调参的复杂度;最后BN训练过程中由于使用mini-batch的mean/variance作为总体样本统计量估计,引入了随机噪声,在一定程度上对模型起到了正则化的效果。

参考资料:

[1] Ioffe S, Szegedy C. Batch normalization: accelerating deepnetwork training by reducing internal covariate shift[C]// InternationalConference on International Conference on Machine Learning. JMLR.org,2015:448-456.

[2] 吴恩达Cousera Deep Learning课程

[3] 详解深度学习中的Normalization,不只是BN

[4] 深度学习中 Batch Normalization为什么效果好?

[5] Udacity DeepLearning Nanodegree

[6] Implementing Batch Normalization in Tensorflow

—完—

这篇关于Batch Normalization原理与实战(下)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/824343

相关文章

SQL Server跟踪自动统计信息更新实战指南

《SQLServer跟踪自动统计信息更新实战指南》本文详解SQLServer自动统计信息更新的跟踪方法,推荐使用扩展事件实时捕获更新操作及详细信息,同时结合系统视图快速检查统计信息状态,重点强调修... 目录SQL Server 如何跟踪自动统计信息更新:深入解析与实战指南 核心跟踪方法1️⃣ 利用系统目录

java中pdf模版填充表单踩坑实战记录(itextPdf、openPdf、pdfbox)

《java中pdf模版填充表单踩坑实战记录(itextPdf、openPdf、pdfbox)》:本文主要介绍java中pdf模版填充表单踩坑的相关资料,OpenPDF、iText、PDFBox是三... 目录准备Pdf模版方法1:itextpdf7填充表单(1)加入依赖(2)代码(3)遇到的问题方法2:pd

Spring Security 单点登录与自动登录机制的实现原理

《SpringSecurity单点登录与自动登录机制的实现原理》本文探讨SpringSecurity实现单点登录(SSO)与自动登录机制,涵盖JWT跨系统认证、RememberMe持久化Token... 目录一、核心概念解析1.1 单点登录(SSO)1.2 自动登录(Remember Me)二、代码分析三、

PyTorch中的词嵌入层(nn.Embedding)详解与实战应用示例

《PyTorch中的词嵌入层(nn.Embedding)详解与实战应用示例》词嵌入解决NLP维度灾难,捕捉语义关系,PyTorch的nn.Embedding模块提供灵活实现,支持参数配置、预训练及变长... 目录一、词嵌入(Word Embedding)简介为什么需要词嵌入?二、PyTorch中的nn.Em

在IntelliJ IDEA中高效运行与调试Spring Boot项目的实战步骤

《在IntelliJIDEA中高效运行与调试SpringBoot项目的实战步骤》本章详解SpringBoot项目导入IntelliJIDEA的流程,教授运行与调试技巧,包括断点设置与变量查看,奠定... 目录引言:为良驹配上好鞍一、为何选择IntelliJ IDEA?二、实战:导入并运行你的第一个项目步骤1

在MySQL中实现冷热数据分离的方法及使用场景底层原理解析

《在MySQL中实现冷热数据分离的方法及使用场景底层原理解析》MySQL冷热数据分离通过分表/分区策略、数据归档和索引优化,将频繁访问的热数据与冷数据分开存储,提升查询效率并降低存储成本,适用于高并发... 目录实现冷热数据分离1. 分表策略2. 使用分区表3. 数据归档与迁移在mysql中实现冷热数据分

Spring Boot3.0新特性全面解析与应用实战

《SpringBoot3.0新特性全面解析与应用实战》SpringBoot3.0作为Spring生态系统的一个重要里程碑,带来了众多令人兴奋的新特性和改进,本文将深入解析SpringBoot3.0的... 目录核心变化概览Java版本要求提升迁移至Jakarta EE重要新特性详解1. Native Ima

Spring Boot 与微服务入门实战详细总结

《SpringBoot与微服务入门实战详细总结》本文讲解SpringBoot框架的核心特性如快速构建、自动配置、零XML与微服务架构的定义、演进及优缺点,涵盖开发环境准备和HelloWorld实战... 目录一、Spring Boot 核心概述二、微服务架构详解1. 微服务的定义与演进2. 微服务的优缺点三

SpringBoot集成MyBatis实现SQL拦截器的实战指南

《SpringBoot集成MyBatis实现SQL拦截器的实战指南》这篇文章主要为大家详细介绍了SpringBoot集成MyBatis实现SQL拦截器的相关知识,文中的示例代码讲解详细,有需要的小伙伴... 目录一、为什么需要SQL拦截器?二、MyBATis拦截器基础2.1 核心接口:Interceptor

从入门到进阶讲解Python自动化Playwright实战指南

《从入门到进阶讲解Python自动化Playwright实战指南》Playwright是针对Python语言的纯自动化工具,它可以通过单个API自动执行Chromium,Firefox和WebKit... 目录Playwright 简介核心优势安装步骤观点与案例结合Playwright 核心功能从零开始学习