sklearn中SVM的可视化

2024-06-21 13:32
文章标签 可视化 svm sklearn

本文主要是介绍sklearn中SVM的可视化,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

    • 第一部分:如何绘制三维散点图和分类平面
    • 第二部分:sklearn中的SVM参数介绍
    • 第三部分:源代码and数据

最近遇到一个简单的二分类任务,本来可用一维的线性分类器来解决,但是为了获得更好的泛化性能,我选取了三个特征,变成了一个三维空间的二分类任务。目的就是使两类样本之间的间隔再大一些,为了满足这种需求,自然而然的想到使用SVM作为分类器,并且该任务是线性可分,自然的选用LinearSVM——核函数为线性函数。为了充分理解SVM,我还对SVM的分类平面、支持向量、bad case均进行可视化,通过本文可以了解:

1 .如何用matplotlib绘制三维散点图
2 .sklearn中SVM的核函数

先看图:

这里写图片描述

其中蓝色的是负样本,红色为正样本,带绿色圈圈的是支持向量,蓝色平面就是分类平面;

再看测试集,其中绿色圈圈,圈出来的是分类错误的样本:
这里写图片描述

第一部分:如何绘制三维散点图和分类平面

这里采用sklearn里面的SVC作为我的分类器,由于分类任务较为简单,几乎是线性可分,所以采用线性核函数,通过以下语句构建SVM并训练:

cls = svm.SVC(kernel='linear', C=1.5)
cls.fit(x_train, y_train)

其中C为惩罚因子,C越大,模型越不能容忍错误,则会使模型更容易过拟合,反之C越小,模型对错误样本容忍性很强,可能导致模型欠拟合。关于系数C的理论可以参考博客,点这里呀

训练好我们的cls之后就是如何绘制分类平面了,这就得知道我们分类平面的表达式是什么。SVM分类平面通式为Wφ(X)+b = 0 ,当采用线性核函数时,分类平面简化为:WX+b=0 (φ(X)=X),其中W,X为向量,b为标量,想进一步了解核函数作用的朋友可以参考博客,点这里呀

本文用例X是一个三维向量,因此W也应该是一个三维的向量,W和b 分别可从cls的coef_ , intercept_这两个属性中获取,具体如下:

w = cls.coef_  
b = cls.intercept_

则绘制分类平面步骤:

	ax = plt.subplot(111, projection='3d')x = np.arange(0,1,0.01)y = np.arange(0,1,0.11)x, y = np.meshgrid(x, y)z = (w[0,0]*x + w[0,1]*y + b) / (-w[0,2])surf = ax.plot_surface(x, y, z, rstride=1, cstride=1)

首先,创建一个3d的画布,其次要构建分类平面表达式 z = (w[0,0]x + w[0,1]y + b) / (-w[0,2])
其实是这样演变的:
Wφ(X)+b = 0 $ \Rightarrow $ WX+b=0 $ \Rightarrow $ w1
x1+w2
x2+w3*x3 + b = 0

有了分类平面,我们还想知道,支持向量是哪些,那么可以通过cls中的support_ 属性获取支持向量的idx,然后依据idx去训练集中找到我们的支持向量

第二部分:sklearn中的SVM参数介绍

SVM中最关键的就是核函数的选择,上一部分中仅仅采用了最简单的线性核函数(其实等于没用核函数,哈哈哈),SVM中常用的核函数有高斯核(rbf,径向基)、多项式核以及sigmoid核。在这里就简单介绍sklearn中SVM的这些核函数具体使用方法。

1.高斯核(rbf) 表达式:$ K(x,z)=exp(−γ||x−z||^{2})$
涉及参数 γ,默认值为 1/特征维度
创建一个高斯核的SVM分类器:
cls = svm.SVC(kenerl = ‘rbf’,gamma = 0.5 )

2.sigmoid核函数表达式: K ( x , z ) = t a n h ( γ x ∙ z + r ) K(x,z)=tanh(γx∙z+r) K(x,z)=tanhγxz+r)
涉及两个参数:γ,r
γ通过gamma设置,默认值为1/特征维度; r通过coef0设置,默认值为0;
创建一个sigmoid核函数的SVM:
cls = svm.SVC(kenerl = ‘sigmoid’,gamma = 0.3,coef0=0)

3.多项式核表达式: K ( x , z ) = ( γ x ∙ z + r ) d K(x,z)=(γx∙z+r)^{d} K(x,z)=γxz+r)d
涉及三个参数:γ,r,d
γ通过gamma设置,默认值为1/特征维度; r通过coef0设置,默认值为0;,d通过degree设置,默认值为3
创建一个二阶多项式核的SVM:
cls = svm.SVC(kenerl = ‘poly’,gamma = 0.3,coef0=0,dgree=2 )

在SVC中还有一个参数可以控制样本的权重,用以解决unbalance问题,class_weight,具体参考官方文档:
http://scikit-learn.org/stable/modules/generated/sklearn.svm.SVC.html
推荐博客:
https://www.cnblogs.com/pinard/p/6117515.html

最后留下一个疑问,倘若我采用非线性核,如高斯核函数,我应该如何绘制分类平面?

我的思路是这样的,分类平面表达式:Wφ(X) + b =0, 当采用非线性核的时候,我们如何能知道这个映射函数φ(·)呢?

第三部分:源代码and数据

代码+数据文件可从:
1.CSDN下载:https://download.csdn.net/download/u011995719/10557270
2.百度云: https://pan.baidu.com/s/1s5Xu_h2nlTSum7jeoKniGQ 密码: gtc8

代码:

# coding: utf-8
import numpy as np
import csv
from sklearn import svm
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D"""
采用 sklearn 中的svm 
"""
train_path = './train_set.csv'
test_path = './test_set.csv'def load_data(data_path):X,Y = [],[]csv_reader = csv.reader(open(data_path,'r'))for row in csv_reader:a = row[0][1:-1].split()X.append(np.array(a))Y.append(np.array(row[1]))return X, Ydef find_badcase(X, Y):bad_list = []y = cls.predict(X)for i in range(len(X)):if y[i] != Y[i]:bad_list.append(i)return bad_listif __name__=="__main__":# load datax_train, y_train = load_data(train_path)x_test, y_test = load_data(test_path)# trainingcls = svm.SVC(kernel='linear', C=1.5)cls.fit(x_train, y_train)# accuracyprint('Test score: %.4f' % cls.score(x_test, y_test))print('Train score: %.4f' % cls.score(x_train, y_train))# print bad case idbad_idx = find_badcase(x_test,y_test)n_Support_vector = cls.n_support_  # 支持向量个数sv_idx = cls.support_  # 支持向量索引w = cls.coef_  # 方向向量Wb = cls.intercept_# plot# 绘制分类平面ax = plt.subplot(111, projection='3d')x = np.arange(0,1,0.01)y = np.arange(0,1,0.11)x, y = np.meshgrid(x, y)z = (w[0,0]*x + w[0,1]*y + b) / (-w[0,2])surf = ax.plot_surface(x, y, z, rstride=1, cstride=1)# 绘制三维散点图x_array = np.array(x_train, dtype=float)y_array = np.array(y_train, dtype=int)pos = x_array[np.where(y_array==1)]neg = x_array[np.where(y_array==-1)]ax.scatter(pos[:,0], pos[:,1], pos[:,2], c='r', label='pos')ax.scatter(neg[:,0], neg[:,1], neg[:,2], c='b', label='neg')# 绘制支持向量X = np.array(x_train,dtype=float)for i in range(len(sv_idx)):ax.scatter(X[sv_idx[i],0], X[sv_idx[i],1], X[sv_idx[i],2],s=50,c='',marker='o', edgecolors='g')# 绘制 bad case# x_test = np.array(x_test,dtype=float)# for i in range(len(bad_idx)):#     j = bad_idx[i]#     ax.scatter(x_test[j,0], x_test[j,1], x_test[j,2],s=60,#                c='',marker='o', edgecolors='g')ax.set_zlabel('Z')    # 坐标轴ax.set_ylabel('Y')ax.set_xlabel('X')ax.set_zlim([0, 1])plt.legend(loc='upper left')ax.view_init(35,300)plt.show()

再次强调:请问如何绘制非线性核的分类平面呢??

这篇关于sklearn中SVM的可视化的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1081356

相关文章

Python数据分析与可视化的全面指南(从数据清洗到图表呈现)

《Python数据分析与可视化的全面指南(从数据清洗到图表呈现)》Python是数据分析与可视化领域中最受欢迎的编程语言之一,凭借其丰富的库和工具,Python能够帮助我们快速处理、分析数据并生成高质... 目录一、数据采集与初步探索二、数据清洗的七种武器1. 缺失值处理策略2. 异常值检测与修正3. 数据

使用Python和Matplotlib实现可视化字体轮廓(从路径数据到矢量图形)

《使用Python和Matplotlib实现可视化字体轮廓(从路径数据到矢量图形)》字体设计和矢量图形处理是编程中一个有趣且实用的领域,通过Python的matplotlib库,我们可以轻松将字体轮廓... 目录背景知识字体轮廓的表示实现步骤1. 安装依赖库2. 准备数据3. 解析路径指令4. 绘制图形关键

8种快速易用的Python Matplotlib数据可视化方法汇总(附源码)

《8种快速易用的PythonMatplotlib数据可视化方法汇总(附源码)》你是否曾经面对一堆复杂的数据,却不知道如何让它们变得直观易懂?别慌,Python的Matplotlib库是你数据可视化的... 目录引言1. 折线图(Line Plot)——趋势分析2. 柱状图(Bar Chart)——对比分析3

使用Vue-ECharts实现数据可视化图表功能

《使用Vue-ECharts实现数据可视化图表功能》在前端开发中,经常会遇到需要展示数据可视化的需求,比如柱状图、折线图、饼图等,这类需求不仅要求我们准确地将数据呈现出来,还需要兼顾美观与交互体验,所... 目录前言为什么选择 vue-ECharts?1. 基于 ECharts,功能强大2. 更符合 Vue

Git可视化管理工具(SourceTree)使用操作大全经典

《Git可视化管理工具(SourceTree)使用操作大全经典》本文详细介绍了SourceTree作为Git可视化管理工具的常用操作,包括连接远程仓库、添加SSH密钥、克隆仓库、设置默认项目目录、代码... 目录前言:连接Gitee or github,获取代码:在SourceTree中添加SSH密钥:Cl

Pandas中统计汇总可视化函数plot()的使用

《Pandas中统计汇总可视化函数plot()的使用》Pandas提供了许多强大的数据处理和分析功能,其中plot()函数就是其可视化功能的一个重要组成部分,本文主要介绍了Pandas中统计汇总可视化... 目录一、plot()函数简介二、plot()函数的基本用法三、plot()函数的参数详解四、使用pl

使用Python实现矢量路径的压缩、解压与可视化

《使用Python实现矢量路径的压缩、解压与可视化》在图形设计和Web开发中,矢量路径数据的高效存储与传输至关重要,本文将通过一个Python示例,展示如何将复杂的矢量路径命令序列压缩为JSON格式,... 目录引言核心功能概述1. 路径命令解析2. 路径数据压缩3. 路径数据解压4. 可视化代码实现详解1

Python 交互式可视化的利器Bokeh的使用

《Python交互式可视化的利器Bokeh的使用》Bokeh是一个专注于Web端交互式数据可视化的Python库,本文主要介绍了Python交互式可视化的利器Bokeh的使用,具有一定的参考价值,感... 目录1. Bokeh 简介1.1 为什么选择 Bokeh1.2 安装与环境配置2. Bokeh 基础2

基于Python打造一个可视化FTP服务器

《基于Python打造一个可视化FTP服务器》在日常办公和团队协作中,文件共享是一个不可或缺的需求,所以本文将使用Python+Tkinter+pyftpdlib开发一款可视化FTP服务器,有需要的小... 目录1. 概述2. 功能介绍3. 如何使用4. 代码解析5. 运行效果6.相关源码7. 总结与展望1

Python Dash框架在数据可视化仪表板中的应用与实践记录

《PythonDash框架在数据可视化仪表板中的应用与实践记录》Python的PlotlyDash库提供了一种简便且强大的方式来构建和展示互动式数据仪表板,本篇文章将深入探讨如何使用Dash设计一... 目录python Dash框架在数据可视化仪表板中的应用与实践1. 什么是Plotly Dash?1.1