RNN学习:利用LSTM,GRU层解决航空公司评论数据预测问题

2024-03-29 13:48

本文主要是介绍RNN学习:利用LSTM,GRU层解决航空公司评论数据预测问题,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

RNN学习:利用LSTM,GRU解决航空公司评论数据预测问题

文章目录

  • RNN学习:利用LSTM,GRU解决航空公司评论数据预测问题
    • 1.RNN的介绍
      • 1.1 LSTM的简单介绍
      • 1.2 GRU的简单介绍
    • 2.数据集的介绍
    • 3.读取数据并作预处理
    • 4.模型的搭建
    • 结语

1.RNN的介绍

​ RNN,即循环神经网络,即一般的神经网络同层节点与节点之间并无连接,比如CNN隐藏单元之间并没有连接,那么这相对于一些序列问题上的处理就会效果很差。如翻译单词,全文的意思必须是根据所有单词来进行判断。或判断说话人情绪,评论好坏,最终的输出要和前面所有的输入发生关系,所以这里学者们提出循环神经网络,让上一个节点会对下一个节点传递状态向量,每个节点之间输出两个值,一个是我们要的输出,还有一个就是状态向量,该向量输入下一个下一个节点,最终输出为二维数据(None,units)units为RNN的隐藏单元数。

在这里插入图片描述

1.1 LSTM的简单介绍

​ 刚才我们说明了RNN会不断的向下一个节点传递状态,但是经过长时间的多次传递,最终传递的状态可能会引起梯度爆炸或梯度消失等问题,为了解决这个问题,学者们又提出了LSTM层来解决这个问题,LSTM层的内部存在一些门,他会通过训练门的参数控制了上一状态我们需要遗忘多少,并且在这一层状态的更新。

在这里插入图片描述

可以看到在这一个单元中上一层的输出ht-1和状态Ct-1都传递了进来从而经过我们的门来控制该单元遗忘并更新状态。

1.2 GRU的简单介绍

GRU是LSTM结构的一种变体,他可以做到与LSTM性能相当的情况下,计算量会比LSTM减少,他的网络结构如下

在这里插入图片描述

可以看到他作为LSTM的变体,与LSTM的相似之处,他也会有前一次的状态(但是不会有前一层的输出传入)向他传入并且通过训练控制前一次状态在本单元的遗忘与更新。

2.数据集的介绍

本次使用的是Twitter 美国航空公司情绪:2015年2月美国航空公司的Twitter数据,分类为正面,负面和中性推文(https://www.kaggle.com/crowdflower/twitter-airline-sentiment)

在这里插入图片描述

整个数据集使用CSV格式存储,这种文件格式是一种经常用来数据科学存储数据的纯文本文件,可以用EXCEL直接打开。

可以看到该数据集上有该评价的好坏有neural,positive,negative三种,关于评价的具体 文本是text下的,我们在此次任务中只会用到评价文本(作为数据),情绪好坏作为我们的标签,也就是真值。

3.读取数据并作预处理

​ 首先先清楚我们的目标在预处理过程,是想要提取一个序列(这个序列是由我们的评论转换的),和一个标签(标签也要数字化),那么我们接下来就开始从CSV格式文件中提取文本和标签并分别将他们转化成序列和数字。

import tensorflow as tf
keras=tf.keras
layers=keras.layers
import numpy as np
import pandas as pd
import re
data=pd.read_csv('../input/twitter-airline-sentiment/Tweets.csv')
data.head()#文件内部数据太多使用这个默认查看前五行

在这里插入图片描述

然后我们此次只需要提取每个人评论的text,和评论观点的倾向,所以我们提取以下两列

data=data[['airline_sentiment','text']]
data.head()

在这里插入图片描述

我们成功提取每个评论的情绪,和文本,接下来我们先将情绪用数字表示,可以先查看有多少种情绪

data.airline_sentiment.value_counts()#使用该方法可以查看每个值的个数
negative    9178
neutral     3099
positive    2363
Name: airline_sentiment, dtype: int64

可以看到这里有三个倾向的情绪,消极,中立,积极,那么也就是说这是一个多分类单标签问题,那么我们直接对每个情绪进行编码然后转化即可

sentiment_to_index={'positive':0,'neutral':1,'negative':2}
def to_index(sentiment):#写函数来转化return sentiment_to_index.get(sentiment)
data['sentiment']=data.airline_sentiment.apply(to_index)
del data['airline_sentiment'] #删除原有的一列
data.head()

在这里插入图片描述

可以看到我们的标签被成功的转化成对应的数字标签。

并且我们还要注意一点,消极的评论远远多余积极的评论,我们在训练分类问题上最好是将每个类别上的数据的数量都保持一致,防止模型对于某些分类的特征过分学习。也就是说我们在这里使用消极和中立的数量都必须被降为和积极一样,那么这里我们就直接使用切片,对于series数据切片使用iloc函数

data_positive=data[data.airline_sentiment=='positive']
data_negative=data[data.airline_sentiment=='negative']
data_neutral=data[data.airline_sentiment=='neutral']
data_negative=data_negative.iloc[:len(data_positive)]  
data_neutral=data_neutral.iloc[:len(data_positive)]
len(data_negative),len(data_neutral),len(data_positive)(2363, 2363, 2363) #可以看到我们将三个数据全部转化为相同个数

那么接下来我们合并我们的这些数据并且使用sample方法随机打乱(sample的用法是从原有数据随机抽出一部分数据,但是如果我们把抽出数据的规模等于所有数据,就相当于打乱)

data=pd.concat([data_negative,data_positive,data_neutral])
data=data.sample(len(data))  #smaple的意思是从dataframe中随机抽取指定数量的数据
data.head()

那么接下来我们就将每个文本转化为一个序列,怎么转化呢,其实很简单,那就是将每个句子里的单词映射成一个数字,那么整个句子就成为了一个数字序列,那么如何来完成了,接下来我们开始贴代码

token=re.compile('[A-Za-z]+|[!?,.()]')
#我们设置匹配的时候不要特殊字符,只要标点符号和字母,并且大小写不会影响单词原意,我这里也直接将所有大写转化成小写
def constractor_text(text):res_text=token.findall(text)res_text=[word.lower() for word in res_text]return res_text
#上面是使用re库提供的一个正则匹配方法在除去特殊符号其他均匹配情况下效果显著
new_data=data.text.apply(constractor_text)
data['text']=new_data
data.head()

在这里插入图片描述

那么接下里我们将单词全部映射成一个个数字其实想法很简单,先做一个集合将所有单词添加进集合吗,由于集合本身的特性,会自动删除重复的,然后我们将该集合中的单词转化成字典,就可以将单词转化成序列了,这里也简单的贴代码

word_list=list(word_set) #因为集合并没有下标这个概念,所以为了后面的方便我们转化成列表
word_dict=dict((k,v+1) for v,k in enumerate(word_list))
word_dict#同时为了防止填充单词之后填充0影响结果我们将所有数据,转化
{'win': 1,'DEFINITELY': 2,'gfc': 3,'OI': 4,'pearl': 5,'briughy': 6,'necessity': 7,'flyingwithUS': 8,'agreement': 9,...

这里需要非常注意的一点就是,每个评论的数据都是有一定长度的,但最后为了规范化我们一定是要将所有评论长度都处理到相同长度,那么我们填充的数字一般用0来填充,所以我们在字典里不能对0进行赋值,防止影响结果,所以我这里将所有单词对应的编号加一。可以看到我们单词编号从一开始。

好的有了单词的转换表,那么我们接下来编写函数将句子转换成序列

def word_to_vector(text):vector=[word_dict.get(word,0) for word in text]return vector
data['text']=data.text.apply(word_to_vector)
data_text=data['text']
data_text.head()8263    [3228, 11239, 9075, 694, 1133, 4364, 10324, 10...
4953    [1721, 11079, 870, 10, 11285, 9390, 10642, 724...
5489    [1721, 443, 6165, 4999, 4859, 4806, 7367, 7013...
2452    [3436, 10200, 6758, 10, 310, 1660, 8275, 10324...
8219    [3228, 11460, 10774, 10324, 1291, 6804, 516, 7...
Name: text, dtype: object

这里我们可以看到每个句子就都被转换为对应的序列,那么我们接下里将所有向量处理成完全一样的长度,

maxlen=max(len(x) for x in data_text)
max_word=len(word_set)+1
data_text=keras.preprocessing.sequence.pad_sequences(data_text.values,maxlen=maxlen)
data_text.shape
(7089, 40)

可以看到每个序列都被填充到了长度为40,那么我们接下来提取标签然后制作dataset,划分测试集与训练集

label=data.sentiment.values
test_count=int(7089*0.2)
train_count=7089-test_count
test_data=train_data.take(test_count)
train_data=train_data.skip(test_count)
train_data=train_data.shuffle(train_count).repeat().batch(64)
test_data=test_data.batch(64)

划分完毕后我们总算是完成了我们数据的预处理,接下来开始我们模型的搭建。

4.模型的搭建

我们输入的是一个长度为40的序列,但这样并不适合我们模型对他的处理,对此已经有提出词嵌入方法,WORD2VEC的方法,即将每个单词转化成固定维度的向量,向量之间差的大小,表示每个单词之间关系的大小(我理解为单词之间的相似性),这里我们可以用RGB表示颜色的方式来理解,每个颜色的值都可以用一个三维向量来表示,对于单词就是我们设置一个几十个维度的词向量,假设所有词都可以用这个高维向量来表示,那么具体怎么转换,有多种方法,我们这里使用keras提供的embelding层来将所有单词转换成我们设定维度的向量

model=keras.Sequential()
#Embedding层可以吧文本映射为一个密集向量
model.add(layers.Embedding(max_word,50,input_length=maxlen))
#然后我们多次未见的主角GRU登场,用它来处理这种序列数据效果是十分好的
model.add(layers.GRU(64))#LSTM的参数是一个隐藏单元数
model.add(layers.Dense(3,activation='softmax'))
#最后输出这是一个三分类的问题,所以我们激活函数用softmax
model.compile(optimizer=keras.optimizers.Adam(0.0001),loss='sparse_categorical_crossentropy',metrics=['acc'])
#设置模型的优化器这里没什么好说的

5.训练结果分析与网络调整

model.fit(train_data,steps_per_epoch=train_count//64,epochs=10,validation_data=test_data,validation_steps=test_count//64)

这里我们开始训练查看结果却发现

在这里插入图片描述

网络已经达到严重过拟合,测试集准确率极高,但验证集却非常低,两者相差达到20%,那么为了抑制过拟合我这里采取两种方法一是增加网络深度,添加Dropout层抑制过拟合

model=keras.Sequential()
#Embedding层可以吧文本映射为一个密集向量
model.add(layers.Embedding(max_word,50,input_length=maxlen))
model.add(layers.GRU(64))#LSTM的参数是一个隐藏单元数
model.add(layers.Dropout(0.2))
model.add(layers.Dense(32,activation='relu'))
model.add(layers.Dropout(0.2))
model.add(layers.Dense(16,activation='relu'))
model.add(layers.Dense(3,activation='softmax'))

,二是我将数据增加一倍,(等于是复制了一遍数据再,打乱),最终数据翻倍达到14000多条那么我们再次开始训练,查看结果

Epoch 15/15
177/177 [==============================] - 6s 34ms/step - loss: 0.0542 - acc: 0.9852 - val_loss: 0.1551 - val_acc: 0.9592

可以看到在训练最后,过拟合被抑制了,模型无论在训练集,测试集都达到了极高的正确率

结语

本篇博客简单介绍了RNN网络,并且非常具体的展示了如何从CSV文件读取数据,预处理并制作成模型可以接收的数据,在最后利用GRU搭建模型,并且对于训练结果产生过拟合如何去抑制方面做了处理,如果有任何建议或者问题欢迎评论区指出,谢谢!

这篇关于RNN学习:利用LSTM,GRU层解决航空公司评论数据预测问题的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/858763

相关文章

Python在二进制文件中进行数据搜索的实战指南

《Python在二进制文件中进行数据搜索的实战指南》在二进制文件中搜索特定数据是编程中常见的任务,尤其在日志分析、程序调试和二进制数据处理中尤为重要,下面我们就来看看如何使用Python实现这一功能吧... 目录简介1. 二进制文件搜索概述2. python二进制模式文件读取(rb)2.1 二进制模式与文本

JAVA Calendar设置上个月时,日期不存在或错误提示问题及解决

《JAVACalendar设置上个月时,日期不存在或错误提示问题及解决》在使用Java的Calendar类设置上个月的日期时,如果遇到不存在的日期(如4月31日),默认会自动调整到下个月的相应日期(... 目录Java Calendar设置上个月时,日期不存在或错误提示java进行日期计算时如果出现不存在的

Mybatis对MySQL if 函数的不支持问题解读

《Mybatis对MySQLif函数的不支持问题解读》接手项目后,为了实现多租户功能,引入了Mybatis-plus,发现之前运行正常的SQL语句报错,原因是Mybatis不支持MySQL的if函... 目录MyBATis对mysql if 函数的不支持问题描述经过查询网上搜索资料找到原因解决方案总结Myb

C#实现将XML数据自动化地写入Excel文件

《C#实现将XML数据自动化地写入Excel文件》在现代企业级应用中,数据处理与报表生成是核心环节,本文将深入探讨如何利用C#和一款优秀的库,将XML数据自动化地写入Excel文件,有需要的小伙伴可以... 目录理解XML数据结构与Excel的对应关系引入高效工具:使用Spire.XLS for .NETC

Nginx错误拦截转发 error_page的问题解决

《Nginx错误拦截转发error_page的问题解决》Nginx通过配置错误页面和请求处理机制,可以在请求失败时展示自定义错误页面,提升用户体验,下面就来介绍一下Nginx错误拦截转发error_... 目录1. 准备自定义错误页面2. 配置 Nginx 错误页面基础配置示例:3. 关键配置说明4. 生效

Java调用DeepSeek API的8个高频坑与解决方法

《Java调用DeepSeekAPI的8个高频坑与解决方法》现在大模型开发特别火,DeepSeek因为中文理解好、反应快、还便宜,不少Java开发者都用它,本文整理了最常踩的8个坑,希望对... 目录引言一、坑 1:Token 过期未处理,鉴权异常引发服务中断问题本质典型错误代码解决方案:实现 Token

springboot3.x使用@NacosValue无法获取配置信息的解决过程

《springboot3.x使用@NacosValue无法获取配置信息的解决过程》在SpringBoot3.x中升级Nacos依赖后,使用@NacosValue无法动态获取配置,通过引入SpringC... 目录一、python问题描述二、解决方案总结一、问题描述springboot从2android.x

MySQL数据目录迁移的完整过程

《MySQL数据目录迁移的完整过程》文章详细介绍了将MySQL数据目录迁移到新硬盘的整个过程,包括新硬盘挂载、创建新的数据目录、迁移数据(推荐使用两遍rsync方案)、修改MySQL配置文件和重启验证... 目录1,新硬盘挂载(如果有的话)2,创建新的 mysql 数据目录3,迁移 MySQL 数据(推荐两

Python数据验证神器Pydantic库的使用和实践中的避坑指南

《Python数据验证神器Pydantic库的使用和实践中的避坑指南》Pydantic是一个用于数据验证和设置的库,可以显著简化API接口开发,文章通过一个实际案例,展示了Pydantic如何在生产环... 目录1️⃣ 崩溃时刻:当你的API接口又双叒崩了!2️⃣ 神兵天降:3行代码解决验证难题3️⃣ 深度

MySQL快速复制一张表的四种核心方法(包括表结构和数据)

《MySQL快速复制一张表的四种核心方法(包括表结构和数据)》本文详细介绍了四种复制MySQL表(结构+数据)的方法,并对每种方法进行了对比分析,适用于不同场景和数据量的复制需求,特别是针对超大表(1... 目录一、mysql 复制表(结构+数据)的 4 种核心方法(面试结构化回答)方法 1:CREATE