使用early stopping解决神经网络过拟合问题

2024-05-26 08:48

本文主要是介绍使用early stopping解决神经网络过拟合问题,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

神经网络训练多少轮是一个很关键的问题,训练轮数少了欠拟合(underfit),训练轮数多了过拟合(overfit),那如何选择训练轮数呢?

Early stopping可以帮助我们解决这个问题,它的作用就是当模型在验证集上的性能不再增加的时候就停止训练,从而达到充分训练的作用,又避免过拟合。

一、在Keras中使用early stopping

完整代码

Keras中有EarlyStopping类,可以直接拿来使用,非常方便

from keras.callbacks import EarlyStoppingearlystop = EarlyStopping(monitor = 'val_loss',mode='min',min_delta = 0,patience = 3,verbose = 1,)
  1. monitor。想要监控的指标,比如在这里我们主要看的是验证集上的loss,当loss不再降低的时候就停止
  2. mode。想要最大值还是最小值,在这里我们使用的min,当时loss越小越好
  3. min_delta。指标的变化超过min_delta才认为产生了变化,否则都认为不再上升或下降
  4. patience。多少轮不发生变化才停止
  5. verbose。设置为1的时候,训练结束会打印出epoch的情况

二、保存最佳模型

完整代码

在early stopping结束后得到模型不一定是最佳模型,所以我们需要把训练过程中表现最好的模型保存下来,以便使用。在这里我们可以使用Keras提供的另一callback来实现:

from keras.callbacks import ModelCheckpointmc = ModelCheckpoint(file_path='./best_model.h5',monitor='val_accuracy',mode='max',verbose=1,save_best_only=True)
  1. filepath,模型存储的路径
  2. monitor,监控的指标
  3. mode,最大还是最小模式
  4. verbose,日志显示控制
  5. save_best_only,是否只存储最好的模型

通过使用这个方法我们就可以把最好的模型存储下来,在使用的时候直接load就可以了。

三、在IMDB数据集上使用Early Stopping

完整代码​​​​​​​

IMDB是一个情感分析数据集,我们首先在这个数据集上使用一个简单的CNN看看效果,然后再使用Early Stopping作为对比。首先看看CNN代码。先对句子embedding, 然后使用一层Conv1D+Maxpooling。

# Build model
sentence = Input(batch_shape=(None, max_words), dtype='int32', name='sentence')
embedding_layer = Embedding(top_words, embedding_dims, input_length=max_words)
sent_embed = embedding_layer(sentence)
conv_layer = Conv1D(filters, kernel_size, padding='valid', activation='relu')
sent_conv = conv_layer(sent_embed)
sent_pooling = GlobalMaxPooling1D()(sent_conv)
sent_repre = Dense(250)(sent_pooling)
sent_repre = Activation('relu')(sent_repre)
sent_repre = Dense(1)(sent_repre)
pred = Activation('sigmoid')(sent_repre)
model = Model(inputs=sentence, outputs=pred)
rmsprop = optimizers.rmsprop(lr=0.0003)
model.compile(loss='binary_crossentropy', optimizer=rmsprop, metrics=['accuracy'])

最终在数据集上的结果如下,在训练集上基本达到了100,而在测试集上还不到90,看起来有点过拟合了

Training Accuracy: 100%
Test Accuracy: 88.50%

我们再看Loss曲线,大约在第8轮的时候,验证集上的Loss达到最低,但是在往后Loss开始升高,这就更加确定发生了过拟合,我们需要提前停止训练,最好在第8轮之后就停下来。

在IMDB数据集上使用Early Stopping

我们再训练过程中加上一个patience=10的earlystop,监控验证集loss。当验证集的loss在近10轮都没有下降的话就停止。

#early stopping
earlystop = EarlyStopping(monitor='val_loss',min_delta=0,patience=10,verbose=1)# fit the model
history = model.fit(x_train, y_train, batch_size=batch_size,epochs=epochs, verbose=1, validation_data=(x_test, y_test), callbacks[earlystop])

结果如下,我们可以看到训练最终在第16轮停止了,停止时在测试集上的准确率为88.40%,并没有高于不使用Early Stopping的情况,但是在训练的第12轮模型的准确达到了89.30%,超过了Baseline。所以我们需要加上存储最好模型的callback。

Epoch 2/50
5000/5000 [==============================] - 5s 951us/step - loss: 0.4851 - acc: 0.7986 - val_loss: 0.4320 - val_acc: 0.8170
Epoch 3/50
5000/5000 [==============================] - 5s 918us/step - loss: 0.3193 - acc: 0.8802 - val_loss: 0.3599 - val_acc: 0.8370
Epoch 4/50
5000/5000 [==============================] - 4s 882us/step - loss: 0.2093 - acc: 0.9322 - val_loss: 0.3392 - val_acc: 0.8530
Epoch 5/50
5000/5000 [==============================] - 4s 880us/step - loss: 0.1209 - acc: 0.9702 - val_loss: 0.4001 - val_acc: 0.8260
Epoch 6/50
5000/5000 [==============================] - 4s 887us/step - loss: 0.0600 - acc: 0.9884 - val_loss: 0.2900 - val_acc: 0.8710
Epoch 7/50
5000/5000 [==============================] - 4s 865us/step - loss: 0.0208 - acc: 0.9986 - val_loss: 0.2978 - val_acc: 0.8840
Epoch 8/50
5000/5000 [==============================] - 4s 883us/step - loss: 0.0053 - acc: 1.0000 - val_loss: 0.3180 - val_acc: 0.8840
Epoch 9/50
5000/5000 [==============================] - 4s 856us/step - loss: 0.0011 - acc: 1.0000 - val_loss: 0.3570 - val_acc: 0.8830
Epoch 10/50
5000/5000 [==============================] - 4s 845us/step - loss: 1.7574e-04 - acc: 1.0000 - val_loss: 0.4035 - val_acc: 0.8800
Epoch 11/50
5000/5000 [==============================] - 4s 869us/step - loss: 2.0190e-05 - acc: 1.0000 - val_loss: 0.4490 - val_acc: 0.8820
Epoch 12/50
5000/5000 [==============================] - 4s 846us/step - loss: 1.6874e-06 - acc: 1.0000 - val_loss: 0.5164 - val_acc: 0.8930
Epoch 13/50
5000/5000 [==============================] - 4s 860us/step - loss: 2.6231e-07 - acc: 1.0000 - val_loss: 0.5429 - val_acc: 0.8840
Epoch 14/50
5000/5000 [==============================] - 4s 870us/step - loss: 1.4614e-07 - acc: 1.0000 - val_loss: 0.5754 - val_acc: 0.8810
Epoch 15/50
5000/5000 [==============================] - 4s 888us/step - loss: 1.2477e-07 - acc: 1.0000 - val_loss: 0.5744 - val_acc: 0.8850
Epoch 16/50
5000/5000 [==============================] - 4s 876us/step - loss: 1.1823e-07 - acc: 1.0000 - val_loss: 0.5909 - val_acc: 0.8840
Epoch 00016: early stopping
Accuracy: 88.40%

存储最好模型

我们使用ModelCheckPoint存储最好的模型,具体如下,通过监控验证集上的准确率,我们把准确率最高的模型存储下来

from keras.callbacks import EarlyStopping, ModelCheckpointmc = ModelCheckpoint(filepath='best_model.h5',monitor='val_acc',mode='max',verbose=1,save_best_only=True)

然后在使用的时候进行load,然后就可以进行预测了

from keras.models import load_model
saved_model = load_model('best_model.h5')
# evaluate the model
_, train_acc = saved_model.evaluate(x_train, y_train, verbose=0)
_, test_acc = saved_model.evaluate(x_test, y_test, verbose=0)
print('Train: %.3f, Test: %.3f' % (train_acc, test_acc))

最终的结果如下

Train: 1.000, Test: 0.893

正确使用Early Stopping加上存储最佳模型可以帮助我们减轻过拟合,从而训练出表现更好的模型。

完整代码​​​​​​​​​​​​​​

这篇关于使用early stopping解决神经网络过拟合问题的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1003992

相关文章

使用Python实现IP地址和端口状态检测与监控

《使用Python实现IP地址和端口状态检测与监控》在网络运维和服务器管理中,IP地址和端口的可用性监控是保障业务连续性的基础需求,本文将带你用Python从零打造一个高可用IP监控系统,感兴趣的小伙... 目录概述:为什么需要IP监控系统使用步骤说明1. 环境准备2. 系统部署3. 核心功能配置系统效果展

解决IDEA报错:编码GBK的不可映射字符问题

《解决IDEA报错:编码GBK的不可映射字符问题》:本文主要介绍解决IDEA报错:编码GBK的不可映射字符问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录IDEA报错:编码GBK的不可映射字符终端软件问题描述原因分析解决方案方法1:将命令改为方法2:右下jav

使用Java将各种数据写入Excel表格的操作示例

《使用Java将各种数据写入Excel表格的操作示例》在数据处理与管理领域,Excel凭借其强大的功能和广泛的应用,成为了数据存储与展示的重要工具,在Java开发过程中,常常需要将不同类型的数据,本文... 目录前言安装免费Java库1. 写入文本、或数值到 Excel单元格2. 写入数组到 Excel表格

redis中使用lua脚本的原理与基本使用详解

《redis中使用lua脚本的原理与基本使用详解》在Redis中使用Lua脚本可以实现原子性操作、减少网络开销以及提高执行效率,下面小编就来和大家详细介绍一下在redis中使用lua脚本的原理... 目录Redis 执行 Lua 脚本的原理基本使用方法使用EVAL命令执行 Lua 脚本使用EVALSHA命令

Java 中的 @SneakyThrows 注解使用方法(简化异常处理的利与弊)

《Java中的@SneakyThrows注解使用方法(简化异常处理的利与弊)》为了简化异常处理,Lombok提供了一个强大的注解@SneakyThrows,本文将详细介绍@SneakyThro... 目录1. @SneakyThrows 简介 1.1 什么是 Lombok?2. @SneakyThrows

MyBatis模糊查询报错:ParserException: not supported.pos 问题解决

《MyBatis模糊查询报错:ParserException:notsupported.pos问题解决》本文主要介绍了MyBatis模糊查询报错:ParserException:notsuppo... 目录问题描述问题根源错误SQL解析逻辑深层原因分析三种解决方案方案一:使用CONCAT函数(推荐)方案二:

使用Python和Pyecharts创建交互式地图

《使用Python和Pyecharts创建交互式地图》在数据可视化领域,创建交互式地图是一种强大的方式,可以使受众能够以引人入胜且信息丰富的方式探索地理数据,下面我们看看如何使用Python和Pyec... 目录简介Pyecharts 简介创建上海地图代码说明运行结果总结简介在数据可视化领域,创建交互式地

Redis 热 key 和大 key 问题小结

《Redis热key和大key问题小结》:本文主要介绍Redis热key和大key问题小结,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录一、什么是 Redis 热 key?热 key(Hot Key)定义: 热 key 常见表现:热 key 的风险:二、

Java Stream流使用案例深入详解

《JavaStream流使用案例深入详解》:本文主要介绍JavaStream流使用案例详解,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录前言1. Lambda1.1 语法1.2 没参数只有一条语句或者多条语句1.3 一个参数只有一条语句或者多

Java Spring 中 @PostConstruct 注解使用原理及常见场景

《JavaSpring中@PostConstruct注解使用原理及常见场景》在JavaSpring中,@PostConstruct注解是一个非常实用的功能,它允许开发者在Spring容器完全初... 目录一、@PostConstruct 注解概述二、@PostConstruct 注解的基本使用2.1 基本代