《李航:统计学习方法》笔记之感知机

2024-06-09 17:38

本文主要是介绍《李航:统计学习方法》笔记之感知机,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

感知机学习旨在求出将训练数据集进行线性划分的分类超平面,为此,导入了基于误分类的损失函数,然后利用梯度下降法对损失函数进行极小化,从而求出感知机模型。感知机模型是神经网络和支持向量机的基础。也是现代流行的深度学习网络模型的基础。下面分别从感知机学习的模型、策略和算法三个方面来介绍。

1. 感知机模型

      感知机模型如下:

f(x)= sign(w*x+b)

      其中,x为输入向量,sign为符号函数,括号里面大于等于0,则其值为1,括号里面小于0,则其值为-1。w为权值向量,b为偏置。求感知机模型即求模型参数w和b。感知机预测,即通过学习得到的感知机模型,对于新的输入实例给出其对应的输出类别1或者-1。

2. 感知机策略

      假设训练数据集是线性可分的,感知机学习的目标就是求得一个能够将训练数据集中正负实例完全分开的分类超平面,为了找到分类超平面,即确定感知机模型中的参数w和b,需要定义一个损失函数并通过将损失函数最小化来求w和b。

       这里选择的损失函数是误分类点到分类超平面S的总距离。输入空间中任一点x 0到超平面S的距离为:

其中,||w||为w的L2范数。

        其次,对于误分类点来说,当-y (wx i + b)>0时,y i=-1,当-y i(wx i + b)<0时,y i=+1。所以对误分类点(x i, y i)满足:

-y(wxi +b) > 0

所以误分类点(x i, y i)到分类超平面S的距离是:

3. 感知机算法

       感知机学习问题转化为求解损失函数式(1)的最优化问题,最优化的方法是随机梯度下降法。感知机学习算法是误分类驱动的,具体采用随机梯度下降法。首先,任意选取一个超平面w0,b0,然后用梯度下降法不断极小化目标函数式(1)。极小化的过程不是一次使M中所有误分类点的梯度下降,而是一次随机选取一个误分类点使其梯度下降。

       损失函数L(w,b)的梯度是对w和b求偏导,即:

其中,(0<<=1)是学习率,即学习的步长。综上,感知机学习算法如下:

        这种算法的基本思想是:当一个实例点被误分类,即位于分类超平面错误的一侧时,则调整w和b,使分类超平面向该误分类点的一侧移动,以减少该误分类点与超平面的距离,直到超平面越过该误分类点使其被正确分类为止。

       需要注意的是,这种感知机学习算法得到的模型参数不是唯一的,它会由于采用不同的参数初始值或选取不同的误分类点,而导致解不同。为了得到唯一的分类超平面,需要对分类超平面增加约束条件,线性支持向量机就是这个想法。另外,当训练数据集线性不可分时,感知机学习算法不收敛,迭代结果会发生震荡。而对于线性可分的数据集,算法一定是收敛的,即经过有限次迭代,一定可以得到一个将数据集完全正确划分的分类超平面及感知机模型。

       以上是感知机学习算法的原始形式,下面介绍感知机学习算法的对偶形式,对偶形式的基本想法是,将w和b表示为实例x i和标记y i的线性组合形式,通过求解其系数而求得w和b。对误分类点(x i, y i)通过


所以,感知机学习算法的对偶形式如下:

感知机的原始形式和对偶形式在解决问题的计算上是一致的,但是他们的思想不同,原始形式的基本思想是对于误分类点,调整w和b,使分类超平面向该误分类点的一侧移动,以减少该误分类点与超平面的距离,直到超平面越过该误分类点使其被正确分类为止。 而对偶形式的基本思想是将w和b表示成x和y的线性组合形式,从而求出w和b。 

1)原始形式代码如下:

[cpp]  view plain copy
  1. #include <iostream>  
  2. using namespace std;  
  3.   
  4. int x[3][2] = {  
  5.     {3, 3},  
  6.     {4, 3},  
  7.     {1, 1}  
  8. };  
  9.   
  10. int y[3] = {1, 1, -1};  
  11.   
  12. int w[2] = {0};  
  13. int b = 0;  
  14.   
  15. int L(int y, int* x)  
  16. {  
  17.     int temp = (w[0] * x[0] + w[1] * x[1] + b) * y;  
  18.     if (temp <= 0)  
  19.         return 1;//存在错误点  
  20.     else  
  21.         return 0;  
  22. }  
  23.   
  24. int main(void)  
  25. {  
  26.     int j = 1;  
  27.     while (true)  
  28.     {  
  29.         cout << j++ << " ";  
  30.           
  31.         int i;  
  32.         int num = 0;  
  33.         for (i = 0; i < 3; i++)  
  34.         {  
  35.             if (L(y[i], x[i]) == 1)  
  36.             {  
  37.                 cout << "error point:";  
  38.                 cout << "x" << i <<" w:";  
  39.                 int j;  
  40.                 for (j = 0; j < 2; j++)  
  41.                 {  
  42.                     w[j] += y[i] * x[i][j];  
  43.                     cout << w[j] << " ";  
  44.                 }  
  45.                 b += y[i];  
  46.                 cout << "b:" << b <<endl;  
  47.                 num++;  
  48.                 break;  
  49.             }  
  50.         }  
  51.         if (num == 0)  
  52.             break;  
  53.     }  
  54.     return 0;  
  55. }  

实验结果:

1 error point:x0 w:3 3 b:1

2 error point:x2 w:2 2 b:0

3 error point:x2 w:1 1 b:-1

4 error point:x2 w:0 0 b:-2

5 error point:x0 w:3 3 b:-1

6 error point:x2 w:2 2 b:-2

7 error point:x2 w:1 1 b:-3

 

这跟p30的结果是一样的,不过要注意的是,在极小化的过程中,为了达到书中的结果,选择的误分类点都是第一次遇到的误分类点,而实际上在选择误分类点时应该采用随机的方法来选取,而且每次梯度下降的时候并不是对所有误分类点进行梯度下降,而是只对随机选择的一个误分类点进行梯度下降。结果与误分类点的选择有关。

2)对偶形式,代码如下:

[cpp]  view plain copy
  1. #include <iostream>  
  2. using namespace std;  
  3.   
  4. int x[3][2] = {  
  5.     {3, 3},  
  6.     {4, 3},  
  7.     {1, 1}  
  8. };  
  9.   
  10. int y[3] = {1, 1, -1};  
  11.   
  12. int b = 0;  
  13. int a[3] = {0};  
  14. int G[3][3] = {  
  15.     {18, 21, 6},  
  16.     {21, 25, 7},  
  17.     {6, 7, 2}  
  18. };//Gram matrix  
  19. int L(int j)  
  20. {  
  21.     int temp = 0;  
  22.     for (int i=0 ;i < 3; i++)  
  23.     {  
  24.         temp += a[i] * G[i][j] * y[i];  
  25.     }  
  26.     temp += b;  
  27.     temp *= y[j];  
  28.     if (temp <= 0)  
  29.         return 1;//存在错误点  
  30.     else  
  31.         return 0;  
  32. }  
  33.   
  34. int main(void)  
  35. {  
  36.     int j = 1;  
  37.     while (true)  
  38.     {  
  39.         cout << j++ << " ";  
  40.           
  41.         int i;  
  42.         int num = 0;  
  43.         for (i = 0; i < 3; i++)  
  44.         {  
  45.             if (L(i) == 1)  
  46.             {  
  47.                 cout << "error point:";  
  48.                 cout << "x" << i <<" a:";  
  49.                 int j;  
  50.                 a[i] += 1;  
  51.                 for (j = 0; j < 3; j++)  
  52.                 {  
  53.                     cout << a[j] << " ";  
  54.                 }  
  55.                 b += y[i];  
  56.                 cout << "b:" << b <<endl;  
  57.                 num++;  
  58.                 break;  
  59.             }  
  60.         }  
  61.         if (num == 0)  
  62.             break;  
  63.     }  
  64.     return 0;  
  65. }  

实验结果如下:

1 error point:x0 a:1 0 0 b:1

2 error point:x2 a:1 0 1 b:0

3 error point:x2 a:1 0 2 b:-1

4 error point:x2 a:1 0 3 b:-2

5 error point:x0 a:2 0 3 b:-1

6 error point:x2 a:2 0 4 b:-2

7 error point:x2 a:2 0 5 b:-3



参考文献:
http://blog.csdn.net/qll125596718/article/details/8394186
http://wenku.baidu.com/link?url=80H_Cz1qBitKGe2IJKwsHqgxS72wMROtDYO9N9NwyV5qbvOJrA9-0sUsWjbCsGO5eEGswZVuUuLD9xlq4A-FRhg3Sb6XrfNws7yvma5ec1y

这篇关于《李航:统计学习方法》笔记之感知机的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1045853

相关文章

Java中的getBytes()方法使用详解

《Java中的getBytes()方法使用详解》:本文主要介绍Java中getBytes()方法使用的相关资料,getBytes()方法有多个重载形式,可以根据需要指定字符集来进行转换,文中通过代... 目录前言一、常见重载形式二、示例代码三、getBytes(Charset charset)和getByt

nginx负载均衡及详细配置方法

《nginx负载均衡及详细配置方法》Nginx作为一种高效的Web服务器和反向代理服务器,广泛应用于网站的负载均衡中,:本文主要介绍nginx负载均衡及详细配置,需要的朋友可以参考下... 目录一、 nginx负载均衡策略1.1 基本负载均衡策略1.2 第三方策略1.3 策略对比二、 nginx配置2.1

Java调用Python的四种方法小结

《Java调用Python的四种方法小结》在现代开发中,结合不同编程语言的优势往往能达到事半功倍的效果,本文将详细介绍四种在Java中调用Python的方法,并推荐一种最常用且实用的方法,希望对大家有... 目录一、在Java类中直接执行python语句二、在Java中直接调用Python脚本三、使用Run

Android 12解决push framework.jar无法开机的方法小结

《Android12解决pushframework.jar无法开机的方法小结》:本文主要介绍在Android12中解决pushframework.jar无法开机的方法,包括编译指令、框架层和s... 目录1. android 编译指令1.1 framework层的编译指令1.2 替换framework.ja

在.NET平台使用C#为PDF添加各种类型的表单域的方法

《在.NET平台使用C#为PDF添加各种类型的表单域的方法》在日常办公系统开发中,涉及PDF处理相关的开发时,生成可填写的PDF表单是一种常见需求,与静态PDF不同,带有**表单域的文档支持用户直接在... 目录引言使用 PdfTextBoxField 添加文本输入域使用 PdfComboBoxField

SQLyog中DELIMITER执行存储过程时出现前置缩进问题的解决方法

《SQLyog中DELIMITER执行存储过程时出现前置缩进问题的解决方法》在SQLyog中执行存储过程时出现的前置缩进问题,实际上反映了SQLyog对SQL语句解析的一个特殊行为,本文给大家介绍了详... 目录问题根源正确写法示例永久解决方案为什么命令行不受影响?最佳实践建议问题根源SQLyog的语句分

Pandas中统计汇总可视化函数plot()的使用

《Pandas中统计汇总可视化函数plot()的使用》Pandas提供了许多强大的数据处理和分析功能,其中plot()函数就是其可视化功能的一个重要组成部分,本文主要介绍了Pandas中统计汇总可视化... 目录一、plot()函数简介二、plot()函数的基本用法三、plot()函数的参数详解四、使用pl

Java 中的 @SneakyThrows 注解使用方法(简化异常处理的利与弊)

《Java中的@SneakyThrows注解使用方法(简化异常处理的利与弊)》为了简化异常处理,Lombok提供了一个强大的注解@SneakyThrows,本文将详细介绍@SneakyThro... 目录1. @SneakyThrows 简介 1.1 什么是 Lombok?2. @SneakyThrows

判断PyTorch是GPU版还是CPU版的方法小结

《判断PyTorch是GPU版还是CPU版的方法小结》PyTorch作为当前最流行的深度学习框架之一,支持在CPU和GPU(NVIDIACUDA)上运行,所以对于深度学习开发者来说,正确识别PyTor... 目录前言为什么需要区分GPU和CPU版本?性能差异硬件要求如何检查PyTorch版本?方法1:使用命

Qt实现网络数据解析的方法总结

《Qt实现网络数据解析的方法总结》在Qt中解析网络数据通常涉及接收原始字节流,并将其转换为有意义的应用层数据,这篇文章为大家介绍了详细步骤和示例,感兴趣的小伙伴可以了解下... 目录1. 网络数据接收2. 缓冲区管理(处理粘包/拆包)3. 常见数据格式解析3.1 jsON解析3.2 XML解析3.3 自定义