C++卷积神经网络实例:tiny_cnn代码详解(10)——layer_base和layer类结构分析

本文主要是介绍C++卷积神经网络实例:tiny_cnn代码详解(10)——layer_base和layer类结构分析,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

  在之前的博文中,我们已经队大部分层结构类都进行了分析,在这篇博文中我们准备针对最后两个,也是处于层结构类继承体系中最底层的两个基类layer_base和layer做一下简要分析。由于layer类只是对layer_base的一个简单实例化,因此这里着重分析layer_base类。

  首先,给出layer_base类的基本结构框图:

  一、成员变量

  由于layer_base是这个类体系结构的基类,是构建网络层的基石,因此其内部封装了网络层的基本属性,相应的也有大量对应的成员变量:

  接下来一一对这些成员变量的基本含义做一下大致介绍:

  (1)in_size_、out_size_:保存了当前层的输入数据尺寸和输出数据尺寸。

  (2)parallelize_:布尔类型标志位,用以标记当前工程是否使用TBB多线程加速。

  (3)next_、prev_:两个指向layer_base类型的指针,用以指向当前层的下一层以及当前层的上一层,是维持层间联系的关键纽带。

  (4)a_:保留当前层卷积运算的中间结果。

  (5)output_:经过激活函数处理之后的当前层的最终特征输出。

  (6)prev_delta_:有前一层传播过来的误差灵敏度(梯度下降法过程中使用)。

  (7)W_、b_:当前层的卷积核权重以及偏置。

  (8)dW_、db_:权重的导数和偏置的导数,用以对权重和偏置进行更新。

  (9)Whessian_、bhessian_:海森矩阵的相关变量,具体含义在后续博文中会详细解释。

  (10)prev_delta2_:误差相对于输入的二阶导数,主要用于全连接层中的误差计算。

  二、构造函数

  构造函数的功能十分简单,通过调用set_size()成员函数来完成网络层中各个相关变量的初始化: 

layer_base(layer_size_t in_dim, layer_size_t out_dim, size_t weight_dim, size_t bias_dim) : parallelize_(true), next_(nullptr), prev_(nullptr) 
{set_size(in_dim, out_dim, weight_dim, bias_dim);//初始化神经网络层的参数
}

   需要注意的一点是这里默认将parallelize_标志位初始化为true,即默认使用TBB加速。至于set_size()函数,主要是通过调用vector的成员函数resize()来对各个参数进行初始化。

  三、权重初始化

  权重初始化主要通过set_size()函数完成(注意,这个函数不仅仅在构造函数中有所调用),正如上文所说,这个函数本质上就是在调用resize():

        void set_size(layer_size_t in_dim, layer_size_t out_dim, size_t weight_dim, size_t bias_dim) {in_size_ = in_dim;out_size_ = out_dim;W_.resize(weight_dim);b_.resize(bias_dim);Whessian_.resize(weight_dim);bhessian_.resize(bias_dim);prev_delta2_.resize(in_dim);for (auto& o : output_)     o.resize(out_dim);for (auto& a : a_)          a.resize(out_dim);for (auto& p : prev_delta_) p.resize(in_dim);for (auto& dw : dW_) dw.resize(weight_dim);for (auto& db : db_) db.resize(bias_dim);}

  需要注意的一点就是这里使用了范围for循环来完成这个vector容器中元素的遍历和操作,这算是C++11的一个特点,需要慢慢体会,不过单纯的从遍历的角度讲,这的确比传统的for循环更为方便而安全。

  四、纯虚函数集

  由于layer_base是一个公共基类,有必要定义一些虚函数以及纯虚函数供派生出来的不同类型的子类进行改写。这里作者选择将与激活函数和前向/反向传播算法定义成纯虚函数,原因很明确:不同层的前向/反向传播算法是不同的,并且激活函数也是可有可无: 

        /**********将激活函数、前向传播和反向传播全部声明为纯虚函数,在子类中进行定义**********/virtual activation::function& activation_function() = 0;virtual const vec_t& forward_propagation(const vec_t& in, size_t worker_index) = 0;virtual const vec_t& back_propagation(const vec_t& current_delta, size_t worker_index) = 0;virtual const vec_t& back_propagation_2nd(const vec_t& current_delta2) = 0;

   五、中间状态保存

  由于卷积神经网络的训练时间都较长,因此有必要定义保存中间训练结果的接口以完成断点续传(这个用词可能不太恰当),因此在layer_base中提供了用以保存和加载网络中间训练状态的结构函数save和load:

        /**********保存网络层中的权重和偏置(中间训练结果)**********/virtual void save(std::ostream& os) const {if (is_exploded()) throw nn_error("failed to save weights because of infinite weight");for (auto w : W_) os << w << " ";for (auto b : b_) os << b << " ";}/**********加载中间训练值**********/virtual void load(std::istream& is) {for (auto& w : W_) is >> w;for (auto& b : b_) is >> b;}

   这里主要通过流操作来完成结果的输入输出操作,同样体现出了强力的C++特性。

  六、权值更新

  layer_base对权值更新的操作主要有两个,一是权值和偏置的参数的初始化操作set_size(),这个前文已经介绍过了;二是更新函数update_weight()。update_weight()函数主要是通过调用各个收敛算法(如这里默认使用的gradient_descent_levenberg_marquardt算法)中的update()函数来完成对应权值和偏置的更新操作:

  至于update函数的具体实现细节则取决于所使用的收敛算法,有关这部分内容我会在之后介绍收敛算法(Optimizer结构体)的博文中专门进行详细的介绍。不过从表面的调用形式上可以看出,在BP算法对权值进行更新的过程中,需要用到dW(一阶导数)和海森矩阵(二阶导数)。

  七、属性返回参数

  这部分结构函数几乎是各个网络层的必备函数,方便用户查看对应网络层的具体参数信息和特征输出结果,一般都包含两个方面,return语句和output_to_image类型的视觉转换函数。return语句负责返回网络层的相关成员变量(可以在内部进行一些简单运算),output_to_image()函数则负责将映射核、特征输出结果转换成图像的形式供我们观赏,这些在之前的博文中都有提到过,这里不再赘述。

  八、layer类结构分析

  相对于layer_base类,layer的结构功能则简单了很多,大体上可以分为三类。激活函数实例化,保存/加载函数具体化,定义错误提示信息。

  8.1 激活函数实例化

  由于在layer_base类中将激活函数定义为纯虚函数,作者选择在子类layer中对其进行实例化:

  这里涉及到了Activation类的使用,在这个类中封装了各种各样类型的激活函数,在后续的博文中会专门拿出一两篇的篇幅来对这个类进行分析。

  8.2 保存、加载中间训练值函数具体化

  这里没什么可细说的,通过流操作basic_ostream来进行输入输出:

    /**********辅助的保存、加载操作**********/template <typename Char, typename CharTraits>std::basic_ostream<Char, CharTraits>& operator << (std::basic_ostream<Char, CharTraits>& os, const layer_base& v) {v.save(os);return os;}template <typename Char, typename CharTraits>std::basic_istream<Char, CharTraits>& operator >> (std::basic_istream<Char, CharTraits>& os, layer_base& v) {v.load(os);return os;}

  8.3 错误提示函数定义

  在layer中定义了三种错误类型的信息提示函数:连接不匹配、输入特征维数不匹配、下采样维数不匹配:

  (1)连接不匹配信息提示函数connection_mismatch。这个函数主要是在程序发现当前一层的特征输出维数与后一层的特征输入维数不同时调用,格式化输出错误信息,指明出现问题的具体层。

  (2)输入特征维数不匹配信息提示函数data_mismatch:这个函数主要是在程序发现输入数据的维数与当前层的输入维数不匹配时调用,格式化输出错误信息,指明出现问题的具体层。

  (3)下采样维数不匹配信息提示函数pooling_size_mismatch:这个函数主要是在程序发现当前特征维数不能被下采样窗口尺寸整除时调用,格式化输出错误信息,指明出现问题的具体层。

  需要注意的一点是,以上三个函数只负责格式化输出错误信息提示,具体错误检查机制需要在对应的可能的调用环境中中自行编写进行判断。

  九、注意事项

  1、范围for循环

  在tiny_cnn工程中对容器进行遍历时,全部采用了范围for循环,这点对于之前一直使用传统for循环的童鞋来说刚开始可能有点难以接受,但毕竟范围for循环既安全又简答,以后也要多多使用。

  2、layer_base的函数并没有介绍完全

  上文中对layer_base类中的成员函数并没有百分之百的介绍完全,对于一些小的补丁试的成员函数在后续用到时再进行解释。

  3、激活函数不等于收敛算法

  这里强调一个初学者容易混淆的概念,就是激活函数和收敛算法。首先这两者是完全不同的,举个栗子通俗的说明一下:激活函数包含sigmoid,tanh,Relu;收敛算法则主要指梯度下降法,怎么样,是不是茅塞顿开了。

 



如果觉得这篇文章对您有所启发,欢迎关注我的公众号,我会尽可能积极和大家交流,谢谢。


这篇关于C++卷积神经网络实例:tiny_cnn代码详解(10)——layer_base和layer类结构分析的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1115504

相关文章

PHP轻松处理千万行数据的方法详解

《PHP轻松处理千万行数据的方法详解》说到处理大数据集,PHP通常不是第一个想到的语言,但如果你曾经需要处理数百万行数据而不让服务器崩溃或内存耗尽,你就会知道PHP用对了工具有多强大,下面小编就... 目录问题的本质php 中的数据流处理:为什么必不可少生成器:内存高效的迭代方式流量控制:避免系统过载一次性

C++右移运算符的一个小坑及解决

《C++右移运算符的一个小坑及解决》文章指出右移运算符处理负数时左侧补1导致死循环,与除法行为不同,强调需注意补码机制以正确统计二进制1的个数... 目录我遇到了这么一个www.chinasem.cn函数由此可以看到也很好理解总结我遇到了这么一个函数template<typename T>unsigned

C#实现千万数据秒级导入的代码

《C#实现千万数据秒级导入的代码》在实际开发中excel导入很常见,现代社会中很容易遇到大数据处理业务,所以本文我就给大家分享一下千万数据秒级导入怎么实现,文中有详细的代码示例供大家参考,需要的朋友可... 目录前言一、数据存储二、处理逻辑优化前代码处理逻辑优化后的代码总结前言在实际开发中excel导入很

SpringBoot+RustFS 实现文件切片极速上传的实例代码

《SpringBoot+RustFS实现文件切片极速上传的实例代码》本文介绍利用SpringBoot和RustFS构建高性能文件切片上传系统,实现大文件秒传、断点续传和分片上传等功能,具有一定的参考... 目录一、为什么选择 RustFS + SpringBoot?二、环境准备与部署2.1 安装 RustF

MySQL的JDBC编程详解

《MySQL的JDBC编程详解》:本文主要介绍MySQL的JDBC编程,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录前言一、前置知识1. 引入依赖2. 认识 url二、JDBC 操作流程1. JDBC 的写操作2. JDBC 的读操作总结前言本文介绍了mysq

Python实现Excel批量样式修改器(附完整代码)

《Python实现Excel批量样式修改器(附完整代码)》这篇文章主要为大家详细介绍了如何使用Python实现一个Excel批量样式修改器,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一... 目录前言功能特性核心功能界面特性系统要求安装说明使用指南基本操作流程高级功能技术实现核心技术栈关键函

Redis 的 SUBSCRIBE命令详解

《Redis的SUBSCRIBE命令详解》Redis的SUBSCRIBE命令用于订阅一个或多个频道,以便接收发送到这些频道的消息,本文给大家介绍Redis的SUBSCRIBE命令,感兴趣的朋友跟随... 目录基本语法工作原理示例消息格式相关命令python 示例Redis 的 SUBSCRIBE 命令用于订

使用Python批量将.ncm格式的音频文件转换为.mp3格式的实战详解

《使用Python批量将.ncm格式的音频文件转换为.mp3格式的实战详解》本文详细介绍了如何使用Python通过ncmdump工具批量将.ncm音频转换为.mp3的步骤,包括安装、配置ffmpeg环... 目录1. 前言2. 安装 ncmdump3. 实现 .ncm 转 .mp34. 执行过程5. 执行结

Python中 try / except / else / finally 异常处理方法详解

《Python中try/except/else/finally异常处理方法详解》:本文主要介绍Python中try/except/else/finally异常处理方法的相关资料,涵... 目录1. 基本结构2. 各部分的作用tryexceptelsefinally3. 执行流程总结4. 常见用法(1)多个e

C++统计函数执行时间的最佳实践

《C++统计函数执行时间的最佳实践》在软件开发过程中,性能分析是优化程序的重要环节,了解函数的执行时间分布对于识别性能瓶颈至关重要,本文将分享一个C++函数执行时间统计工具,希望对大家有所帮助... 目录前言工具特性核心设计1. 数据结构设计2. 单例模式管理器3. RAII自动计时使用方法基本用法高级用法