C++卷积神经网络实例:tiny_cnn代码详解(10)——layer_base和layer类结构分析

本文主要是介绍C++卷积神经网络实例:tiny_cnn代码详解(10)——layer_base和layer类结构分析,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

  在之前的博文中,我们已经队大部分层结构类都进行了分析,在这篇博文中我们准备针对最后两个,也是处于层结构类继承体系中最底层的两个基类layer_base和layer做一下简要分析。由于layer类只是对layer_base的一个简单实例化,因此这里着重分析layer_base类。

  首先,给出layer_base类的基本结构框图:

  一、成员变量

  由于layer_base是这个类体系结构的基类,是构建网络层的基石,因此其内部封装了网络层的基本属性,相应的也有大量对应的成员变量:

  接下来一一对这些成员变量的基本含义做一下大致介绍:

  (1)in_size_、out_size_:保存了当前层的输入数据尺寸和输出数据尺寸。

  (2)parallelize_:布尔类型标志位,用以标记当前工程是否使用TBB多线程加速。

  (3)next_、prev_:两个指向layer_base类型的指针,用以指向当前层的下一层以及当前层的上一层,是维持层间联系的关键纽带。

  (4)a_:保留当前层卷积运算的中间结果。

  (5)output_:经过激活函数处理之后的当前层的最终特征输出。

  (6)prev_delta_:有前一层传播过来的误差灵敏度(梯度下降法过程中使用)。

  (7)W_、b_:当前层的卷积核权重以及偏置。

  (8)dW_、db_:权重的导数和偏置的导数,用以对权重和偏置进行更新。

  (9)Whessian_、bhessian_:海森矩阵的相关变量,具体含义在后续博文中会详细解释。

  (10)prev_delta2_:误差相对于输入的二阶导数,主要用于全连接层中的误差计算。

  二、构造函数

  构造函数的功能十分简单,通过调用set_size()成员函数来完成网络层中各个相关变量的初始化: 

layer_base(layer_size_t in_dim, layer_size_t out_dim, size_t weight_dim, size_t bias_dim) : parallelize_(true), next_(nullptr), prev_(nullptr) 
{set_size(in_dim, out_dim, weight_dim, bias_dim);//初始化神经网络层的参数
}

   需要注意的一点是这里默认将parallelize_标志位初始化为true,即默认使用TBB加速。至于set_size()函数,主要是通过调用vector的成员函数resize()来对各个参数进行初始化。

  三、权重初始化

  权重初始化主要通过set_size()函数完成(注意,这个函数不仅仅在构造函数中有所调用),正如上文所说,这个函数本质上就是在调用resize():

        void set_size(layer_size_t in_dim, layer_size_t out_dim, size_t weight_dim, size_t bias_dim) {in_size_ = in_dim;out_size_ = out_dim;W_.resize(weight_dim);b_.resize(bias_dim);Whessian_.resize(weight_dim);bhessian_.resize(bias_dim);prev_delta2_.resize(in_dim);for (auto& o : output_)     o.resize(out_dim);for (auto& a : a_)          a.resize(out_dim);for (auto& p : prev_delta_) p.resize(in_dim);for (auto& dw : dW_) dw.resize(weight_dim);for (auto& db : db_) db.resize(bias_dim);}

  需要注意的一点就是这里使用了范围for循环来完成这个vector容器中元素的遍历和操作,这算是C++11的一个特点,需要慢慢体会,不过单纯的从遍历的角度讲,这的确比传统的for循环更为方便而安全。

  四、纯虚函数集

  由于layer_base是一个公共基类,有必要定义一些虚函数以及纯虚函数供派生出来的不同类型的子类进行改写。这里作者选择将与激活函数和前向/反向传播算法定义成纯虚函数,原因很明确:不同层的前向/反向传播算法是不同的,并且激活函数也是可有可无: 

        /**********将激活函数、前向传播和反向传播全部声明为纯虚函数,在子类中进行定义**********/virtual activation::function& activation_function() = 0;virtual const vec_t& forward_propagation(const vec_t& in, size_t worker_index) = 0;virtual const vec_t& back_propagation(const vec_t& current_delta, size_t worker_index) = 0;virtual const vec_t& back_propagation_2nd(const vec_t& current_delta2) = 0;

   五、中间状态保存

  由于卷积神经网络的训练时间都较长,因此有必要定义保存中间训练结果的接口以完成断点续传(这个用词可能不太恰当),因此在layer_base中提供了用以保存和加载网络中间训练状态的结构函数save和load:

        /**********保存网络层中的权重和偏置(中间训练结果)**********/virtual void save(std::ostream& os) const {if (is_exploded()) throw nn_error("failed to save weights because of infinite weight");for (auto w : W_) os << w << " ";for (auto b : b_) os << b << " ";}/**********加载中间训练值**********/virtual void load(std::istream& is) {for (auto& w : W_) is >> w;for (auto& b : b_) is >> b;}

   这里主要通过流操作来完成结果的输入输出操作,同样体现出了强力的C++特性。

  六、权值更新

  layer_base对权值更新的操作主要有两个,一是权值和偏置的参数的初始化操作set_size(),这个前文已经介绍过了;二是更新函数update_weight()。update_weight()函数主要是通过调用各个收敛算法(如这里默认使用的gradient_descent_levenberg_marquardt算法)中的update()函数来完成对应权值和偏置的更新操作:

  至于update函数的具体实现细节则取决于所使用的收敛算法,有关这部分内容我会在之后介绍收敛算法(Optimizer结构体)的博文中专门进行详细的介绍。不过从表面的调用形式上可以看出,在BP算法对权值进行更新的过程中,需要用到dW(一阶导数)和海森矩阵(二阶导数)。

  七、属性返回参数

  这部分结构函数几乎是各个网络层的必备函数,方便用户查看对应网络层的具体参数信息和特征输出结果,一般都包含两个方面,return语句和output_to_image类型的视觉转换函数。return语句负责返回网络层的相关成员变量(可以在内部进行一些简单运算),output_to_image()函数则负责将映射核、特征输出结果转换成图像的形式供我们观赏,这些在之前的博文中都有提到过,这里不再赘述。

  八、layer类结构分析

  相对于layer_base类,layer的结构功能则简单了很多,大体上可以分为三类。激活函数实例化,保存/加载函数具体化,定义错误提示信息。

  8.1 激活函数实例化

  由于在layer_base类中将激活函数定义为纯虚函数,作者选择在子类layer中对其进行实例化:

  这里涉及到了Activation类的使用,在这个类中封装了各种各样类型的激活函数,在后续的博文中会专门拿出一两篇的篇幅来对这个类进行分析。

  8.2 保存、加载中间训练值函数具体化

  这里没什么可细说的,通过流操作basic_ostream来进行输入输出:

    /**********辅助的保存、加载操作**********/template <typename Char, typename CharTraits>std::basic_ostream<Char, CharTraits>& operator << (std::basic_ostream<Char, CharTraits>& os, const layer_base& v) {v.save(os);return os;}template <typename Char, typename CharTraits>std::basic_istream<Char, CharTraits>& operator >> (std::basic_istream<Char, CharTraits>& os, layer_base& v) {v.load(os);return os;}

  8.3 错误提示函数定义

  在layer中定义了三种错误类型的信息提示函数:连接不匹配、输入特征维数不匹配、下采样维数不匹配:

  (1)连接不匹配信息提示函数connection_mismatch。这个函数主要是在程序发现当前一层的特征输出维数与后一层的特征输入维数不同时调用,格式化输出错误信息,指明出现问题的具体层。

  (2)输入特征维数不匹配信息提示函数data_mismatch:这个函数主要是在程序发现输入数据的维数与当前层的输入维数不匹配时调用,格式化输出错误信息,指明出现问题的具体层。

  (3)下采样维数不匹配信息提示函数pooling_size_mismatch:这个函数主要是在程序发现当前特征维数不能被下采样窗口尺寸整除时调用,格式化输出错误信息,指明出现问题的具体层。

  需要注意的一点是,以上三个函数只负责格式化输出错误信息提示,具体错误检查机制需要在对应的可能的调用环境中中自行编写进行判断。

  九、注意事项

  1、范围for循环

  在tiny_cnn工程中对容器进行遍历时,全部采用了范围for循环,这点对于之前一直使用传统for循环的童鞋来说刚开始可能有点难以接受,但毕竟范围for循环既安全又简答,以后也要多多使用。

  2、layer_base的函数并没有介绍完全

  上文中对layer_base类中的成员函数并没有百分之百的介绍完全,对于一些小的补丁试的成员函数在后续用到时再进行解释。

  3、激活函数不等于收敛算法

  这里强调一个初学者容易混淆的概念,就是激活函数和收敛算法。首先这两者是完全不同的,举个栗子通俗的说明一下:激活函数包含sigmoid,tanh,Relu;收敛算法则主要指梯度下降法,怎么样,是不是茅塞顿开了。

 



如果觉得这篇文章对您有所启发,欢迎关注我的公众号,我会尽可能积极和大家交流,谢谢。


这篇关于C++卷积神经网络实例:tiny_cnn代码详解(10)——layer_base和layer类结构分析的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1115504

相关文章

MySQL数据库双机热备的配置方法详解

《MySQL数据库双机热备的配置方法详解》在企业级应用中,数据库的高可用性和数据的安全性是至关重要的,MySQL作为最流行的开源关系型数据库管理系统之一,提供了多种方式来实现高可用性,其中双机热备(M... 目录1. 环境准备1.1 安装mysql1.2 配置MySQL1.2.1 主服务器配置1.2.2 从

C++中unordered_set哈希集合的实现

《C++中unordered_set哈希集合的实现》std::unordered_set是C++标准库中的无序关联容器,基于哈希表实现,具有元素唯一性和无序性特点,本文就来详细的介绍一下unorder... 目录一、概述二、头文件与命名空间三、常用方法与示例1. 构造与析构2. 迭代器与遍历3. 容量相关4

C++中悬垂引用(Dangling Reference) 的实现

《C++中悬垂引用(DanglingReference)的实现》C++中的悬垂引用指引用绑定的对象被销毁后引用仍存在的情况,会导致访问无效内存,下面就来详细的介绍一下产生的原因以及如何避免,感兴趣... 目录悬垂引用的产生原因1. 引用绑定到局部变量,变量超出作用域后销毁2. 引用绑定到动态分配的对象,对象

Linux kill正在执行的后台任务 kill进程组使用详解

《Linuxkill正在执行的后台任务kill进程组使用详解》文章介绍了两个脚本的功能和区别,以及执行这些脚本时遇到的进程管理问题,通过查看进程树、使用`kill`命令和`lsof`命令,分析了子... 目录零. 用到的命令一. 待执行的脚本二. 执行含子进程的脚本,并kill2.1 进程查看2.2 遇到的

MyBatis常用XML语法详解

《MyBatis常用XML语法详解》文章介绍了MyBatis常用XML语法,包括结果映射、查询语句、插入语句、更新语句、删除语句、动态SQL标签以及ehcache.xml文件的使用,感兴趣的朋友跟随小... 目录1、定义结果映射2、查询语句3、插入语句4、更新语句5、删除语句6、动态 SQL 标签7、ehc

详解SpringBoot+Ehcache使用示例

《详解SpringBoot+Ehcache使用示例》本文介绍了SpringBoot中配置Ehcache、自定义get/set方式,并实际使用缓存的过程,文中通过示例代码介绍的非常详细,对大家的学习或者... 目录摘要概念内存与磁盘持久化存储:配置灵活性:编码示例引入依赖:配置ehcache.XML文件:配置

从基础到高级详解Go语言中错误处理的实践指南

《从基础到高级详解Go语言中错误处理的实践指南》Go语言采用了一种独特而明确的错误处理哲学,与其他主流编程语言形成鲜明对比,本文将为大家详细介绍Go语言中错误处理详细方法,希望对大家有所帮助... 目录1 Go 错误处理哲学与核心机制1.1 错误接口设计1.2 错误与异常的区别2 错误创建与检查2.1 基础

Nginx分布式部署流程分析

《Nginx分布式部署流程分析》文章介绍Nginx在分布式部署中的反向代理和负载均衡作用,用于分发请求、减轻服务器压力及解决session共享问题,涵盖配置方法、策略及Java项目应用,并提及分布式事... 目录分布式部署NginxJava中的代理代理分为正向代理和反向代理正向代理反向代理Nginx应用场景

k8s按需创建PV和使用PVC详解

《k8s按需创建PV和使用PVC详解》Kubernetes中,PV和PVC用于管理持久存储,StorageClass实现动态PV分配,PVC声明存储需求并绑定PV,通过kubectl验证状态,注意回收... 目录1.按需创建 PV(使用 StorageClass)创建 StorageClass2.创建 PV

Python版本信息获取方法详解与实战

《Python版本信息获取方法详解与实战》在Python开发中,获取Python版本号是调试、兼容性检查和版本控制的重要基础操作,本文详细介绍了如何使用sys和platform模块获取Python的主... 目录1. python版本号获取基础2. 使用sys模块获取版本信息2.1 sys模块概述2.1.1