如何通过绘制【学习曲线】来判断模型是否【过拟合】

2024-01-07 13:04

本文主要是介绍如何通过绘制【学习曲线】来判断模型是否【过拟合】,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

      学习曲线是一种图形化工具,用于展示模型在训练集和验证集(或测试集)上的性能随着训练样本数量的增加而如何变化。它可以帮助我们理解模型是否受益于更多的训练数据,以及模型是否可能存在过拟合或欠拟合问题。学习曲线的x轴通常是训练样本的数量或训练迭代的次数,y轴是模型的性能指标,如准确率或损失函数的值。

- 如果模型在训练集上的性能随着训练样本数量的增加而提高,但在验证集上的性能提高不明显或者甚至下降,那么模型可能存在过拟合问题。
- 如果模型在训练集和验证集上的性能都随着训练样本数量的增加而提高,且两者的性能都还有提升的空间,那么模型可能会从更多的训练数据中受益。
- 如果模型在训练集和验证集上的性能都随着训练样本数量的增加而提高,但两者的性能提升已经很小或者没有提升,那么模型可能存在欠拟合问题,或者已经达到了它的性能上限。

在这里,我们以贝叶斯算法为例:

我们先来导入相应的库:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.naive_bayes import GaussianNB
from sklearn.datasets import load_digits
from sklearn.model_selection import learning_curve #画学习曲线的类
from sklearn.model_selection import ShuffleSplit #设定交叉验证模式的类

接下来定义一个绘制学习曲线的函数:

def plot_learning_curve(estimator,title, X, y, ax, #选择子图ylim=None, #设置纵坐标的取值范围cv=None, #交叉验证n_jobs=None #设定索要使用的线程):train_sizes, train_scores, test_scores = learning_curve(estimator, X, y,cv=cv,n_jobs=n_jobs)    ax.set_title(title)if ylim is not None:ax.set_ylim(*ylim)ax.set_xlabel("Training examples")ax.set_ylabel("Score")ax.grid() #显示网格作为背景,不是必须ax.plot(train_sizes, np.mean(train_scores, axis=1), 'o-', color="r",label="Training score") # 画出训练集学习曲线ax.plot(train_sizes, np.mean(test_scores, axis=1), 'o-', color="g",label="Test score") # 画出验证集学习曲线ax.legend(loc="best")return ax

这段代码使用了`learning_curve`函数,该函数是一个非常有用的工具,用于生成学习曲线的数据。学习曲线可以帮助我们理解随着训练样本数量的增加,模型的性能如何变化。

`learning_curve`函数的参数包括:

- `estimator`:这是用于训练的模型。
- `X`和`y`:这是用于训练的数据和对应的标签。
- `cv`:这是交叉验证的策略。
- `n_jobs`:这是用于计算的线程数。

`learning_curve`函数返回三个值:

- `train_sizes`:这是用于生成学习曲线的训练集的样本数。
- `train_scores`:这是在每个训练集大小下,模型在训练集上的得分。
- `test_scores`:这是在每个训练集大小下,模型在交叉验证集上的得分。

这些返回的值可以用于绘制学习曲线,以帮助我们理解模型随着训练样本数量的增加,其性能如何变化。

接下来再导入手写数据集:

digits = load_digits()
X, y = digits.data, digits.target

再用如下代码绘制子图和学习曲线:

fig, axes = plt.subplots(1, 1, figsize=(10, 6))  # Define the axes variable
cv = ShuffleSplit(n_splits=50, test_size=0.2, random_state=0)
plot_learning_curve(GaussianNB(), "Naive Bayes", X, y,  ax=axes, ylim=[0.7, 1.05], n_jobs=4, cv=cv)
plt.show()

结果分析:可以看出贝叶斯作为一个分类器,效果不是很理想。可以观察到,随着样本量逐渐增大,训练分数逐渐降低,从95%下降到85%,但是测试分数逐渐增高,从75%上升到85%。测试分数在逐渐逼近训练分数,过拟合问题在逐渐减弱。但是,可以想象,接下来即使再增大样本量,测试分数和训练分数也不会变高,只会趋近于某个值。综上所述,朴素贝叶斯是依赖于训练集准确率的下降,测试集准确率上升来解决过拟合问题。

这篇关于如何通过绘制【学习曲线】来判断模型是否【过拟合】的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/579998

相关文章

Python使用Matplotlib绘制3D曲面图详解

《Python使用Matplotlib绘制3D曲面图详解》:本文主要介绍Python使用Matplotlib绘制3D曲面图,在Python中,使用Matplotlib库绘制3D曲面图可以通过mpl... 目录准备工作绘制简单的 3D 曲面图绘制 3D 曲面图添加线框和透明度控制图形视角Matplotlib

判断PyTorch是GPU版还是CPU版的方法小结

《判断PyTorch是GPU版还是CPU版的方法小结》PyTorch作为当前最流行的深度学习框架之一,支持在CPU和GPU(NVIDIACUDA)上运行,所以对于深度学习开发者来说,正确识别PyTor... 目录前言为什么需要区分GPU和CPU版本?性能差异硬件要求如何检查PyTorch版本?方法1:使用命

Python如何精准判断某个进程是否在运行

《Python如何精准判断某个进程是否在运行》这篇文章主要为大家详细介绍了Python如何精准判断某个进程是否在运行,本文为大家整理了3种方法并进行了对比,有需要的小伙伴可以跟随小编一起学习一下... 目录一、为什么需要判断进程是否存在二、方法1:用psutil库(推荐)三、方法2:用os.system调用

Python实现特殊字符判断并去掉非字母和数字的特殊字符

《Python实现特殊字符判断并去掉非字母和数字的特殊字符》在Python中,可以通过多种方法来判断字符串中是否包含非字母、数字的特殊字符,并将这些特殊字符去掉,本文为大家整理了一些常用的,希望对大家... 目录1. 使用正则表达式判断字符串中是否包含特殊字符去掉字符串中的特殊字符2. 使用 str.isa

Python中判断对象是否为空的方法

《Python中判断对象是否为空的方法》在Python开发中,判断对象是否为“空”是高频操作,但看似简单的需求却暗藏玄机,从None到空容器,从零值到自定义对象的“假值”状态,不同场景下的“空”需要精... 目录一、python中的“空”值体系二、精准判定方法对比三、常见误区解析四、进阶处理技巧五、性能优化

Spring Security基于数据库的ABAC属性权限模型实战开发教程

《SpringSecurity基于数据库的ABAC属性权限模型实战开发教程》:本文主要介绍SpringSecurity基于数据库的ABAC属性权限模型实战开发教程,本文给大家介绍的非常详细,对大... 目录1. 前言2. 权限决策依据RBACABAC综合对比3. 数据库表结构说明4. 实战开始5. MyBA

Java的IO模型、Netty原理解析

《Java的IO模型、Netty原理解析》Java的I/O是以流的方式进行数据输入输出的,Java的类库涉及很多领域的IO内容:标准的输入输出,文件的操作、网络上的数据传输流、字符串流、对象流等,这篇... 目录1.什么是IO2.同步与异步、阻塞与非阻塞3.三种IO模型BIO(blocking I/O)NI

基于Flask框架添加多个AI模型的API并进行交互

《基于Flask框架添加多个AI模型的API并进行交互》:本文主要介绍如何基于Flask框架开发AI模型API管理系统,允许用户添加、删除不同AI模型的API密钥,感兴趣的可以了解下... 目录1. 概述2. 后端代码说明2.1 依赖库导入2.2 应用初始化2.3 API 存储字典2.4 路由函数2.5 应

C#集成DeepSeek模型实现AI私有化的流程步骤(本地部署与API调用教程)

《C#集成DeepSeek模型实现AI私有化的流程步骤(本地部署与API调用教程)》本文主要介绍了C#集成DeepSeek模型实现AI私有化的方法,包括搭建基础环境,如安装Ollama和下载DeepS... 目录前言搭建基础环境1、安装 Ollama2、下载 DeepSeek R1 模型客户端 ChatBo

C++实现回文串判断的两种高效方法

《C++实现回文串判断的两种高效方法》文章介绍了两种判断回文串的方法:解法一通过创建新字符串来处理,解法二在原字符串上直接筛选判断,两种方法都使用了双指针法,文中通过代码示例讲解的非常详细,需要的朋友... 目录一、问题描述示例二、解法一:将字母数字连接到新的 string思路代码实现代码解释复杂度分析三、