【Keras学习笔记】10:IMDb电影评价数据集文本分类

2023-11-25 17:30

本文主要是介绍【Keras学习笔记】10:IMDb电影评价数据集文本分类,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

读取数据
import keras
from keras import layers
import numpy as np
from matplotlib import pyplot as plt
import pandas as pd
%matplotlib inline
Using TensorFlow backend.
data = keras.datasets.imdb
# 最多提取10000个单词,多的不要
(x_train, y_train), (x_test, y_test) = data.load_data(num_words=10000)
Downloading data from https://s3.amazonaws.com/text-datasets/imdb.npz
17465344/17464789 [==============================] - 761s 44us/step
x_train.shape, y_train.shape, x_test.shape, y_test.shape
((25000,), (25000,), (25000,), (25000,))

数据集已经为每个单词做好数字编码了,所以得到的每个样本都是一个整数形式的向量:

# 看一下第一个样本的前10个单词的数字编码
x_train[0][:10]
[1, 14, 22, 16, 43, 530, 973, 1622, 1385, 65]
# 标签y是非0即1的,表示负面和正面评价
y_train
array([1, 0, 0, ..., 0, 1, 0], dtype=int64)

不妨恢复一条样本看一下原始形式是什么样子的。

# 这个得到的是一个字典,里面是{单词:数字序号,单词:数字序号,...}
word_index = data.get_word_index()
Downloading data from https://s3.amazonaws.com/text-datasets/imdb_word_index.json
1646592/1641221 [==============================] - 100s 61us/step

现在要根据数字序号去得到单词,所以把这个字典的k-v反转一下。这里用生成器来将其反转,再转换成字典。

index_word = dict((value, key) for key, value in word_index.items())

用生成器将第一个样本转换成单词序列,注意这个数据集的word=>index映射时是从0开始编码的,但前面得到的word_index里保留了0,1,2三个编码,也就是所有编码加了3,,这里将其减掉。另外,有些词在word_index里找不到,不妨在找不到时候就给个?标识。

" ".join(index_word.get(index-3,'?') for index in x_train[0])
"? this film was just brilliant casting location scenery story direction everyone's really suited the part they played and you could just imagine being there robert ? is an amazing actor and now the same being director ? father came from the same scottish island as myself so i loved the fact there was a real connection with this film the witty remarks throughout the film were great it was just brilliant so much that i bought the film as soon as it was released for ? and would recommend it to everyone to watch and the fly fishing was amazing really cried at the end it was so sad and you know what they say if you cry at a film it must have been good and this definitely was also ? to the two little boy's that played the ? of norman and paul they were just brilliant children are often left out of the ? list i think because the stars that play them all grown up are such a big profile for the whole film but these children are amazing and should be praised for what they have done don't you think the whole story was so lovely because it was true and was someone's life after all that was shared with us all"

样本有的很短,有的很长,看一下前10个样本的长度:

[len(seq) for seq in x_train[:10]]
[218, 189, 141, 550, 147, 43, 123, 562, 233, 130]

但一定不会超过读取数据集时定义的最大长度10000:

max(max(seq) for seq in x_train)
9999
文本的向量化

因为有10000个单词,可以使用长度为10000的向量,然后将每个词对应一个索引,如果一个词在一条样本中出现了,就将相应位置设置成1(或者+1),这就是次袋模型。

如果设置成1(而不是+1),那么这个向量是有很多为1的分量,其余位置都是0,在学习视频里老师叫它k-hot编码(没查到有这种叫法,估计又是自己扯的),了解一下就好。

def k_hot(seqs, dim=10000):"""数字编码转k-hot编码"""res = np.zeros((len(seqs), dim))for i, seq in enumerate(seqs):res[i, seq] = 1return res
x_tr = k_hot(x_train) 
x_tr.shape
(25000, 10000)
x_ts = k_hot(x_test)
x_ts.shape
(25000, 10000)
建立模型和训练
model = keras.Sequential()
model.add(layers.Dense(32, input_dim=10000, activation='relu'))
model.add(layers.Dense(32, activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))
model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_4 (Dense)              (None, 32)                320032    
_________________________________________________________________
dense_5 (Dense)              (None, 32)                1056      
_________________________________________________________________
dense_6 (Dense)              (None, 1)                 33        
=================================================================
Total params: 321,121
Trainable params: 321,121
Non-trainable params: 0
_________________________________________________________________
model.compile(optimizer='adam',loss='binary_crossentropy',metrics=['acc']
)
history = model.fit(x_tr, y_train, epochs=15, batch_size=256, validation_data=(x_ts, y_test), verbose=0)
plt.plot(history.epoch, history.history.get('val_acc'), c='g', label='validation acc')
plt.plot(history.epoch, history.history.get('acc'), c='b', label='train acc')
plt.legend()
<matplotlib.legend.Legend at 0x1890b908>

在这里插入图片描述

plt.plot(history.epoch, history.history.get('val_loss'), c='g', label='validation loss')
plt.plot(history.epoch, history.history.get('loss'), c='b', label='train loss')
plt.legend()
<matplotlib.legend.Legend at 0x189b79e8>

在这里插入图片描述

可以看到发生了严重的过拟合,下面尝试引入Dropout和正则化项,同时减小网络的规模。

模型优化
from keras import regularizers
model = keras.Sequential()
model.add(layers.Dense(8, input_dim=10000, activation='relu', kernel_regularizer=regularizers.l2(0.005)))
model.add(layers.Dropout(rate=0.4)) # keeep_prob=0.6
model.add(layers.Dense(8, activation='relu', kernel_regularizer=regularizers.l2(0.005)))
model.add(layers.Dense(1, activation='sigmoid'))
model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_13 (Dense)             (None, 8)                 80008     
_________________________________________________________________
dropout_3 (Dropout)          (None, 8)                 0         
_________________________________________________________________
dense_14 (Dense)             (None, 8)                 72        
_________________________________________________________________
dense_15 (Dense)             (None, 1)                 9         
=================================================================
Total params: 80,089
Trainable params: 80,089
Non-trainable params: 0
_________________________________________________________________
model.compile(optimizer='adam',loss='binary_crossentropy',metrics=['acc']
)
history = model.fit(x_tr, y_train, epochs=15, batch_size=256, validation_data=(x_ts, y_test), verbose=0)
plt.plot(history.epoch, history.history.get('val_acc'), c='g', label='validation acc')
plt.plot(history.epoch, history.history.get('acc'), c='b', label='train acc')
plt.legend()
<matplotlib.legend.Legend at 0x1b2a4f28>

在这里插入图片描述

plt.plot(history.epoch, history.history.get('val_loss'), c='g', label='validation loss')
plt.plot(history.epoch, history.history.get('loss'), c='b', label='train loss')
plt.legend()
<matplotlib.legend.Legend at 0x1b319208>

在这里插入图片描述

好了很多。

这篇关于【Keras学习笔记】10:IMDb电影评价数据集文本分类的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/424249

相关文章

MySQL 删除数据详解(最新整理)

《MySQL删除数据详解(最新整理)》:本文主要介绍MySQL删除数据的相关知识,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录一、前言二、mysql 中的三种删除方式1.DELETE语句✅ 基本语法: 示例:2.TRUNCATE语句✅ 基本语

MyBatisPlus如何优化千万级数据的CRUD

《MyBatisPlus如何优化千万级数据的CRUD》最近负责的一个项目,数据库表量级破千万,每次执行CRUD都像走钢丝,稍有不慎就引起数据库报警,本文就结合这个项目的实战经验,聊聊MyBatisPl... 目录背景一、MyBATis Plus 简介二、千万级数据的挑战三、优化 CRUD 的关键策略1. 查

python实现对数据公钥加密与私钥解密

《python实现对数据公钥加密与私钥解密》这篇文章主要为大家详细介绍了如何使用python实现对数据公钥加密与私钥解密,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录公钥私钥的生成使用公钥加密使用私钥解密公钥私钥的生成这一部分,使用python生成公钥与私钥,然后保存在两个文

mysql中的数据目录用法及说明

《mysql中的数据目录用法及说明》:本文主要介绍mysql中的数据目录用法及说明,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、背景2、版本3、数据目录4、总结1、背景安装mysql之后,在安装目录下会有一个data目录,我们创建的数据库、创建的表、插入的

Navicat数据表的数据添加,删除及使用sql完成数据的添加过程

《Navicat数据表的数据添加,删除及使用sql完成数据的添加过程》:本文主要介绍Navicat数据表的数据添加,删除及使用sql完成数据的添加过程,具有很好的参考价值,希望对大家有所帮助,如有... 目录Navicat数据表数据添加,删除及使用sql完成数据添加选中操作的表则出现如下界面,查看左下角从左

MySQL中的索引结构和分类实战案例详解

《MySQL中的索引结构和分类实战案例详解》本文详解MySQL索引结构与分类,涵盖B树、B+树、哈希及全文索引,分析其原理与优劣势,并结合实战案例探讨创建、管理及优化技巧,助力提升查询性能,感兴趣的朋... 目录一、索引概述1.1 索引的定义与作用1.2 索引的基本原理二、索引结构详解2.1 B树索引2.2

SpringBoot中4种数据水平分片策略

《SpringBoot中4种数据水平分片策略》数据水平分片作为一种水平扩展策略,通过将数据分散到多个物理节点上,有效解决了存储容量和性能瓶颈问题,下面小编就来和大家分享4种数据分片策略吧... 目录一、前言二、哈希分片2.1 原理2.2 SpringBoot实现2.3 优缺点分析2.4 适用场景三、范围分片

Redis分片集群、数据读写规则问题小结

《Redis分片集群、数据读写规则问题小结》本文介绍了Redis分片集群的原理,通过数据分片和哈希槽机制解决单机内存限制与写瓶颈问题,实现分布式存储和高并发处理,但存在通信开销大、维护复杂及对事务支持... 目录一、分片集群解android决的问题二、分片集群图解 分片集群特征如何解决的上述问题?(与哨兵模

浅析如何保证MySQL与Redis数据一致性

《浅析如何保证MySQL与Redis数据一致性》在互联网应用中,MySQL作为持久化存储引擎,Redis作为高性能缓存层,两者的组合能有效提升系统性能,下面我们来看看如何保证两者的数据一致性吧... 目录一、数据不一致性的根源1.1 典型不一致场景1.2 关键矛盾点二、一致性保障策略2.1 基础策略:更新数

Oracle 数据库数据操作如何精通 INSERT, UPDATE, DELETE

《Oracle数据库数据操作如何精通INSERT,UPDATE,DELETE》在Oracle数据库中,对表内数据进行增加、修改和删除操作是通过数据操作语言来完成的,下面给大家介绍Oracle数... 目录思维导图一、插入数据 (INSERT)1.1 插入单行数据,指定所有列的值语法:1.2 插入单行数据,指