C++卷积神经网络实例:tiny_cnn代码详解(10)——layer_base和layer类结构分析

本文主要是介绍C++卷积神经网络实例:tiny_cnn代码详解(10)——layer_base和layer类结构分析,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

  在之前的博文中,我们已经队大部分层结构类都进行了分析,在这篇博文中我们准备针对最后两个,也是处于层结构类继承体系中最底层的两个基类layer_base和layer做一下简要分析。由于layer类只是对layer_base的一个简单实例化,因此这里着重分析layer_base类。

  首先,给出layer_base类的基本结构框图:

  一、成员变量

  由于layer_base是这个类体系结构的基类,是构建网络层的基石,因此其内部封装了网络层的基本属性,相应的也有大量对应的成员变量:

  接下来一一对这些成员变量的基本含义做一下大致介绍:

  (1)in_size_、out_size_:保存了当前层的输入数据尺寸和输出数据尺寸。

  (2)parallelize_:布尔类型标志位,用以标记当前工程是否使用TBB多线程加速。

  (3)next_、prev_:两个指向layer_base类型的指针,用以指向当前层的下一层以及当前层的上一层,是维持层间联系的关键纽带。

  (4)a_:保留当前层卷积运算的中间结果。

  (5)output_:经过激活函数处理之后的当前层的最终特征输出。

  (6)prev_delta_:有前一层传播过来的误差灵敏度(梯度下降法过程中使用)。

  (7)W_、b_:当前层的卷积核权重以及偏置。

  (8)dW_、db_:权重的导数和偏置的导数,用以对权重和偏置进行更新。

  (9)Whessian_、bhessian_:海森矩阵的相关变量,具体含义在后续博文中会详细解释。

  (10)prev_delta2_:误差相对于输入的二阶导数,主要用于全连接层中的误差计算。

  二、构造函数

  构造函数的功能十分简单,通过调用set_size()成员函数来完成网络层中各个相关变量的初始化: 

layer_base(layer_size_t in_dim, layer_size_t out_dim, size_t weight_dim, size_t bias_dim) : parallelize_(true), next_(nullptr), prev_(nullptr) 
{set_size(in_dim, out_dim, weight_dim, bias_dim);//初始化神经网络层的参数
}

   需要注意的一点是这里默认将parallelize_标志位初始化为true,即默认使用TBB加速。至于set_size()函数,主要是通过调用vector的成员函数resize()来对各个参数进行初始化。

  三、权重初始化

  权重初始化主要通过set_size()函数完成(注意,这个函数不仅仅在构造函数中有所调用),正如上文所说,这个函数本质上就是在调用resize():

        void set_size(layer_size_t in_dim, layer_size_t out_dim, size_t weight_dim, size_t bias_dim) {in_size_ = in_dim;out_size_ = out_dim;W_.resize(weight_dim);b_.resize(bias_dim);Whessian_.resize(weight_dim);bhessian_.resize(bias_dim);prev_delta2_.resize(in_dim);for (auto& o : output_)     o.resize(out_dim);for (auto& a : a_)          a.resize(out_dim);for (auto& p : prev_delta_) p.resize(in_dim);for (auto& dw : dW_) dw.resize(weight_dim);for (auto& db : db_) db.resize(bias_dim);}

  需要注意的一点就是这里使用了范围for循环来完成这个vector容器中元素的遍历和操作,这算是C++11的一个特点,需要慢慢体会,不过单纯的从遍历的角度讲,这的确比传统的for循环更为方便而安全。

  四、纯虚函数集

  由于layer_base是一个公共基类,有必要定义一些虚函数以及纯虚函数供派生出来的不同类型的子类进行改写。这里作者选择将与激活函数和前向/反向传播算法定义成纯虚函数,原因很明确:不同层的前向/反向传播算法是不同的,并且激活函数也是可有可无: 

        /**********将激活函数、前向传播和反向传播全部声明为纯虚函数,在子类中进行定义**********/virtual activation::function& activation_function() = 0;virtual const vec_t& forward_propagation(const vec_t& in, size_t worker_index) = 0;virtual const vec_t& back_propagation(const vec_t& current_delta, size_t worker_index) = 0;virtual const vec_t& back_propagation_2nd(const vec_t& current_delta2) = 0;

   五、中间状态保存

  由于卷积神经网络的训练时间都较长,因此有必要定义保存中间训练结果的接口以完成断点续传(这个用词可能不太恰当),因此在layer_base中提供了用以保存和加载网络中间训练状态的结构函数save和load:

        /**********保存网络层中的权重和偏置(中间训练结果)**********/virtual void save(std::ostream& os) const {if (is_exploded()) throw nn_error("failed to save weights because of infinite weight");for (auto w : W_) os << w << " ";for (auto b : b_) os << b << " ";}/**********加载中间训练值**********/virtual void load(std::istream& is) {for (auto& w : W_) is >> w;for (auto& b : b_) is >> b;}

   这里主要通过流操作来完成结果的输入输出操作,同样体现出了强力的C++特性。

  六、权值更新

  layer_base对权值更新的操作主要有两个,一是权值和偏置的参数的初始化操作set_size(),这个前文已经介绍过了;二是更新函数update_weight()。update_weight()函数主要是通过调用各个收敛算法(如这里默认使用的gradient_descent_levenberg_marquardt算法)中的update()函数来完成对应权值和偏置的更新操作:

  至于update函数的具体实现细节则取决于所使用的收敛算法,有关这部分内容我会在之后介绍收敛算法(Optimizer结构体)的博文中专门进行详细的介绍。不过从表面的调用形式上可以看出,在BP算法对权值进行更新的过程中,需要用到dW(一阶导数)和海森矩阵(二阶导数)。

  七、属性返回参数

  这部分结构函数几乎是各个网络层的必备函数,方便用户查看对应网络层的具体参数信息和特征输出结果,一般都包含两个方面,return语句和output_to_image类型的视觉转换函数。return语句负责返回网络层的相关成员变量(可以在内部进行一些简单运算),output_to_image()函数则负责将映射核、特征输出结果转换成图像的形式供我们观赏,这些在之前的博文中都有提到过,这里不再赘述。

  八、layer类结构分析

  相对于layer_base类,layer的结构功能则简单了很多,大体上可以分为三类。激活函数实例化,保存/加载函数具体化,定义错误提示信息。

  8.1 激活函数实例化

  由于在layer_base类中将激活函数定义为纯虚函数,作者选择在子类layer中对其进行实例化:

  这里涉及到了Activation类的使用,在这个类中封装了各种各样类型的激活函数,在后续的博文中会专门拿出一两篇的篇幅来对这个类进行分析。

  8.2 保存、加载中间训练值函数具体化

  这里没什么可细说的,通过流操作basic_ostream来进行输入输出:

    /**********辅助的保存、加载操作**********/template <typename Char, typename CharTraits>std::basic_ostream<Char, CharTraits>& operator << (std::basic_ostream<Char, CharTraits>& os, const layer_base& v) {v.save(os);return os;}template <typename Char, typename CharTraits>std::basic_istream<Char, CharTraits>& operator >> (std::basic_istream<Char, CharTraits>& os, layer_base& v) {v.load(os);return os;}

  8.3 错误提示函数定义

  在layer中定义了三种错误类型的信息提示函数:连接不匹配、输入特征维数不匹配、下采样维数不匹配:

  (1)连接不匹配信息提示函数connection_mismatch。这个函数主要是在程序发现当前一层的特征输出维数与后一层的特征输入维数不同时调用,格式化输出错误信息,指明出现问题的具体层。

  (2)输入特征维数不匹配信息提示函数data_mismatch:这个函数主要是在程序发现输入数据的维数与当前层的输入维数不匹配时调用,格式化输出错误信息,指明出现问题的具体层。

  (3)下采样维数不匹配信息提示函数pooling_size_mismatch:这个函数主要是在程序发现当前特征维数不能被下采样窗口尺寸整除时调用,格式化输出错误信息,指明出现问题的具体层。

  需要注意的一点是,以上三个函数只负责格式化输出错误信息提示,具体错误检查机制需要在对应的可能的调用环境中中自行编写进行判断。

  九、注意事项

  1、范围for循环

  在tiny_cnn工程中对容器进行遍历时,全部采用了范围for循环,这点对于之前一直使用传统for循环的童鞋来说刚开始可能有点难以接受,但毕竟范围for循环既安全又简答,以后也要多多使用。

  2、layer_base的函数并没有介绍完全

  上文中对layer_base类中的成员函数并没有百分之百的介绍完全,对于一些小的补丁试的成员函数在后续用到时再进行解释。

  3、激活函数不等于收敛算法

  这里强调一个初学者容易混淆的概念,就是激活函数和收敛算法。首先这两者是完全不同的,举个栗子通俗的说明一下:激活函数包含sigmoid,tanh,Relu;收敛算法则主要指梯度下降法,怎么样,是不是茅塞顿开了。

 



如果觉得这篇文章对您有所启发,欢迎关注我的公众号,我会尽可能积极和大家交流,谢谢。


这篇关于C++卷积神经网络实例:tiny_cnn代码详解(10)——layer_base和layer类结构分析的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1115504

相关文章

MySQL 内存使用率常用分析语句

《MySQL内存使用率常用分析语句》用户整理了MySQL内存占用过高的分析方法,涵盖操作系统层确认及数据库层bufferpool、内存模块差值、线程状态、performance_schema性能数据... 目录一、 OS层二、 DB层1. 全局情况2. 内存占js用详情最近连续遇到mysql内存占用过高导致

idea的终端(Terminal)cmd的命令换成linux的命令详解

《idea的终端(Terminal)cmd的命令换成linux的命令详解》本文介绍IDEA配置Git的步骤:安装Git、修改终端设置并重启IDEA,强调顺序,作为个人经验分享,希望提供参考并支持脚本之... 目录一编程、设置前二、前置条件三、android设置四、设置后总结一、php设置前二、前置条件

python中列表应用和扩展性实用详解

《python中列表应用和扩展性实用详解》文章介绍了Python列表的核心特性:有序数据集合,用[]定义,元素类型可不同,支持迭代、循环、切片,可执行增删改查、排序、推导式及嵌套操作,是常用的数据处理... 目录1、列表定义2、格式3、列表是可迭代对象4、列表的常见操作总结1、列表定义是处理一组有序项目的

python使用try函数详解

《python使用try函数详解》Pythontry语句用于异常处理,支持捕获特定/多种异常、else/final子句确保资源释放,结合with语句自动清理,可自定义异常及嵌套结构,灵活应对错误场景... 目录try 函数的基本语法捕获特定异常捕获多个异常使用 else 子句使用 finally 子句捕获所

C++11范围for初始化列表auto decltype详解

《C++11范围for初始化列表autodecltype详解》C++11引入auto类型推导、decltype类型推断、统一列表初始化、范围for循环及智能指针,提升代码简洁性、类型安全与资源管理效... 目录C++11新特性1. 自动类型推导auto1.1 基本语法2. decltype3. 列表初始化3

深度解析Nginx日志分析与499状态码问题解决

《深度解析Nginx日志分析与499状态码问题解决》在Web服务器运维和性能优化过程中,Nginx日志是排查问题的重要依据,本文将围绕Nginx日志分析、499状态码的成因、排查方法及解决方案展开讨论... 目录前言1. Nginx日志基础1.1 Nginx日志存放位置1.2 Nginx日志格式2. 499

C++11右值引用与Lambda表达式的使用

《C++11右值引用与Lambda表达式的使用》C++11引入右值引用,实现移动语义提升性能,支持资源转移与完美转发;同时引入Lambda表达式,简化匿名函数定义,通过捕获列表和参数列表灵活处理变量... 目录C++11新特性右值引用和移动语义左值 / 右值常见的左值和右值移动语义移动构造函数移动复制运算符

SQL Server 中的 WITH (NOLOCK) 示例详解

《SQLServer中的WITH(NOLOCK)示例详解》SQLServer中的WITH(NOLOCK)是一种表提示,等同于READUNCOMMITTED隔离级别,允许查询在不获取共享锁的情... 目录SQL Server 中的 WITH (NOLOCK) 详解一、WITH (NOLOCK) 的本质二、工作

springboot自定义注解RateLimiter限流注解技术文档详解

《springboot自定义注解RateLimiter限流注解技术文档详解》文章介绍了限流技术的概念、作用及实现方式,通过SpringAOP拦截方法、缓存存储计数器,结合注解、枚举、异常类等核心组件,... 目录什么是限流系统架构核心组件详解1. 限流注解 (@RateLimiter)2. 限流类型枚举 (

Java Thread中join方法使用举例详解

《JavaThread中join方法使用举例详解》JavaThread中join()方法主要是让调用改方法的thread完成run方法里面的东西后,在执行join()方法后面的代码,这篇文章主要介绍... 目录前言1.join()方法的定义和作用2.join()方法的三个重载版本3.join()方法的工作原