第23周:使用Word2vec实现文本分类

2024-09-07 02:28

本文主要是介绍第23周:使用Word2vec实现文本分类,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

目录

前言

一、数据预处理

1.1 加载数据

1.2 构建词典

1.3 生成数据批次和迭代器

二、模型构建

2.1 搭建模型

2.2 初始化模型

2.3 定义训练和评估函数

三、训练模型

3.1 拆分数据集并运行模型

3.2 测试指定数据

总结


前言

  • 🍨 本文为[🔗365天深度学习训练营]中的学习记录博客
  • 🍖 原作者:[K同学啊]

说在前面

本周任务:1)基础任务---结合Word2vec文本内容(第1列)预测文本标签(第2列);优化网络结果,将准确率提升至89%;绘制出验证集的ACC与Loss图;2)进阶任务---尝试第2周的内容独立实现,尽可能不看本文的代码

我的环境:Python3.8、Pycharm2020、torch1.12.1+cu113

数据来源:[K同学啊]


一、数据预处理

1.1 加载数据

数据示例:

       zip是Python中的一个内置函数,它可以将多个序列(元组、列表等)中对应的元素打包成一个个元组,然后返回这些元组组成的一个迭代器。例如,在代码中zip(texts,labels)就是将texts和labels两个列表中对应位置的元素一一打包成一个元组,返回一个迭代器,每次迭代返回一个元组(x,y),其中x是texts中的一个元素,y是labels中对应的一个元素。这样,每次从迭代器中获取一个元素,就相当于从texts和labels中获取了一组对应的数据。在这里,zip函数主要用于将输入的texts和labels打包成一个可迭代的数据集,然后传给后续的模型训练过程中使用。

代码如下:

import torch
import os,PIL,pathlib,warnings
from torch import nn
import time
import pandas as pd
from torchvision import transforms, datasets
import jiebawarnings.filterwarnings("ignore")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)#1.2 加载自定义中文数据
train_data = pd.read_csv('./train.csv', sep='\t', header=None)
print(train_data.head())#1.3 构造数据集迭代器
def custom_data_iter(texts, labels):for x, y in zip(texts, labels):yield x, yx = train_data[0].values[:]
# 多类标签的one-hot展开
y = train_data[1].values[:]

打印输出:

1.2 构建词典

代码如下:

#1.4 构建词典
from gensim.models.word2vec import Word2Vec
import numpy as np# 训练 Word2Vec 浅层神经网络模型
w2v = Word2Vec(vector_size=100,  # 是指特征向量的维度,默认为100。min_count=3)  # 可以对字典做截断. 词频少于min_count次数的单词会被丢弃掉, 默认值为5。w2v.build_vocab(x)
w2v.train(x,total_examples=w2v.corpus_count,epochs=20)

Word2Vec可以直接训练模型,一步到位。这里分了三步
第一步构建一个空模型
第二步使用 build vocab 方法根据输入的文本数据 x构建词典。build vocab 方法会统计输入文本中每个词汇出现的次数,并按照词频从高到低的顺序将词汇加入词典中。
第三步使用 train 方法对模型进行训练,total examples 参数指定了训练时使用的文本数量,这里使用的是 w2v.corpus count 属性,表示输入文本的数量如果一步到位的话代码为:w2v = Word2Vec(x, vector_size=100, min_count=3, epochs=20)

#将文本转化为向量
def average_vec(text):vec = np.zeros(100).reshape((1, 100))for word in text:try:vec += w2v.wv[word].reshape((1, 100))except KeyError:continuereturn vec#将词向量保存为Ndarray
x_vec = np.concatenate([average_vec(z) for z in x])
w2v.save('w2v_model.pkl')train_iter = custom_data_iter(x_vec, y)
print(len(x), len(x_vec))label_name = list(set(train_data[1].values[:]))
print(label_name)

打印输出如下:

12100 12100
['FilmTele-Play', 'Weather-Query', 'Radio-Listen', 'Travel-Query', 'Alarm-Update', 'Calendar-Query', 'Music-Play', 'Other', 'Video-Play', 'TVProgram-Play', 'HomeAppliance-Control', 'Audio-Play']

这段代码定义了一个函数 average_vec(text),它接受一个包含多个词的列表text 作为输入,并返回这些词对应词向量的平均值。该函数
首先初始化一个形状为(1,100)的全零 numpy 数组来表示平均向量然后遍历 text 中的每个词,并尝试从 Word2Vec 模型 w2v 中使用 wv 属性获取其对应的词向量。如果在模型中找到了该词,函数将其向量加到 vec中。如果未找到该词,函数会继续选代下一个词
最后,函数返回平均向量 vec然后使用列表推导式将 average_vec()函数应用于列表x中的每个元素。得到的平均向量列表使用 np.concatenate()连接成一个numpy数组xvec,该数组表示x中所有元素的平均向量。xvec的形状为(n,100),其中n是x中元素的数

1.3 生成数据批次和迭代器

代码如下:

#生成数据批次和迭代器
text_pipeline = lambda x: average_vec(x)
label_pipeline = lambda x: label_name.index(x)print(text_pipeline("你在干嘛"))
print(label_pipeline("Travel-Query"))from torch.utils.data import DataLoader
def collate_batch(batch):label_list, text_list = [], []for (_text, _label) in batch:#标签列表label_list.append(label_pipeline(_label))#文本列表processed_text = torch.tensor(text_pipeline(_text), dtype=torch.float32)text_list.append(processed_text)label_list = torch.tensor(label_list, dtype=torch.int64)text_list = torch.cat(text_list)return text_list.to(device), label_list.to(device)dataloader = DataLoader(train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch)

打印输出如下:

[[ 0.71341366  1.81265011  1.44437733  0.81895351 -2.10842876 -0.53163828
   1.69011965  0.61772476  0.45250437 -1.19413337 -0.93801782 -4.51869714
   2.08006459 -0.19431476  0.54761063  0.53993282  2.45709884 -1.7272636
   3.3405721  -2.00952683  2.7074931   0.44549108 -0.3798939   0.50284129
  -2.03344245 -1.0066061  -1.57383255 -1.02822024  1.22481698 -0.74399903
   2.72032912  0.68474213 -1.08696781 -0.43206174  0.17515172  0.04883668
   0.68131649  3.37725095 -1.73957334  0.44227505  0.35449219  0.9353995
  -0.53143035  0.5939152   0.15114589 -0.67918842  1.19383969 -0.40012862
  -2.7421315   2.3960007   0.93965465 -2.33946571 -1.03136044  0.44977702
  -0.20926718 -0.48943431  1.56342356 -1.81069714  0.2234989   1.05807498
   1.99193773 -0.18156157  2.24787551 -0.63780972 -0.12800559 -0.43717601
  -2.1173833   1.23210199  2.40076267  0.39000577 -0.50040299  0.29607797
   1.25565214 -0.45914613  0.40915862 -0.72103182 -4.15503209  0.32175705
   1.13466016 -1.11661778 -0.90987498  0.02924608 -2.1390073   2.00657488
  -2.04405907 -2.21540118  2.36201783 -2.28765213 -1.62947962 -0.23354006
  -0.26953844 -2.08598122  0.30332083  1.65787105  0.44275794 -2.15785465
  -0.49007402  0.6538553  -2.73823986  0.34911314]]
3

二、模型构建

2.1 搭建模型

代码如下:

#2.1 搭建模型
class TextClassificationModel(nn.Module):def __init__(self, num_class):super(TextClassificationModel, self).__init__()self.fc = nn.Linear(100, num_class)def forward(self, text):return self.fc(text)

注意⚠️:这里使用的是最简单的网络,可根据自己的需求替换成其他网络,这里就不需要嵌入层了。

2.2 初始化模型

代码如下:

#2.2 初始化模型
num_class = len(label_name)
vocab_size = 100000
em_size = 12
model = TextClassificationModel(num_class).to(device)

2.3 定义训练和评估函数

代码如下:

#2.3 定义训练和评估函数
def train(dataloader):model.train()  # 切换为训练模式total_acc, train_loss, total_count = 0, 0, 0log_interval = 50start_time = time.time()for idx, (text, label) in enumerate(dataloader):predicted_label = model(text)optimizer.zero_grad()  # grad属性归零loss = criterion(predicted_label, label)  # 计算网络输出和真实值之间的差距,label为真loss.backward()  # 反向传播torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)  # 梯度裁剪optimizer.step()  # 每一步自动更新# 记录acc与losstotal_acc += (predicted_label.argmax(1) == label).sum().item()train_loss += loss.item()total_count += label.size(0)if idx % log_interval == 0 and idx > 0:elapsed = time.time() - start_timeprint('|epoch{:d}|{:4d}/{:4d} batches|train_acc{:4.3f} train_loss{:4.5f}'.format(epoch,idx,len(dataloader),total_acc / total_count, train_loss / total_count))total_acc, train_loss, total_count = 0, 0, 0start_time = time.time()def evaluate(dataloader):model.eval()  # 切换为测试模式total_acc, train_loss, total_count = 0, 0, 0with torch.no_grad():for idx, (text, label) in enumerate(dataloader):predicted_label = model(text)loss = criterion(predicted_label, label)  # 计算loss值# 记录测试数据total_acc += (predicted_label.argmax(1) == label).sum().item()train_loss += loss.item()total_count += label.size(0)return total_acc / total_count, train_loss / total_count

torch.nn.utils.clip_grad norm_(model.parameters(), 0.1)是一个PyTorch函数,用于在训练神经网络时限制梯度的大小。这种操作被称为梯度裁剪(gradientclipping),可以防止梯度爆炸问题,从而提高神经网络的稳定性和性能。
在这个函数中:

  • model.parameters()表示模型的所有参数。对于一个神经网络,参数通常包括权重和偏置项。
  • 0.1 是一个指定的阈值,表示梯度的最大范数(L2范数)。如果计算出的梯度范数超过这个阈值,梯度会被缩放,使其范数等于阈值,

梯度裁剪的主要目的是防止梯度爆炸。梯度爆炸通常发生在训练深度神经网络时,尤其是在处理长序列数据的循环神经网络(RNN)中。当梯度爆炸时,参数更新可能会变得非常大,导致模型无法收敛或出现数值不稳定。通过限制梯度的大小,梯度裁剪有助于解决这些问题,使模型训练变得更加稳定。

三、训练模型

3.1 拆分数据集并运行模型

代码如下:

#三、训练模型
#3.1 拆分数据集并运行模型
from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
# 超参数设定
EPOCHS = 10  # epoch
LR = 4  # learningRate
BATCH_SIZE = 64  # batch size for training# 设置损失函数、选择优化器、设置学习率调整函数
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)
total_accu = None# 构建数据集
train_iter = custom_data_iter(train_data[0].values[:], train_data[1].values[:])
train_dataset = to_map_style_dataset(train_iter)
split_train_, split_valid_ = random_split(train_dataset,[int(len(train_dataset) * 0.8), int(len(train_dataset) * 0.2)])train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)#3.2 正式训练
for epoch in range(1, EPOCHS + 1):epoch_start_time = time.time()train(train_dataloader)val_acc, val_loss = evaluate(valid_dataloader)# 获取当前的学习率lr = optimizer.state_dict()['param_groups'][0]['lr']if total_accu is not None and total_accu > val_acc:scheduler.step()else:total_accu = val_accprint('-' * 69)print('| epoch {:d} | time:{:4.2f}s |'' valid_acc {:4.3f} valid_loss {:4.3f} | lr {:4.6f}'.format(epoch, time.time() - epoch_start_time,val_acc, val_loss,lr))print('-' * 69)test_acc, test_loss = evaluate(valid_dataloader)
print('模型准确率为:{:5.4f}'.format(test_acc))

打印输出如下:

|epoch1|  50/ 152 batches|train_acc0.748 train_loss0.02063
|epoch1| 100/ 152 batches|train_acc0.834 train_loss0.01444
|epoch1| 150/ 152 batches|train_acc0.831 train_loss0.01594
---------------------------------------------------------------------
| epoch 1 | time:9.10s | valid_acc 0.794 valid_loss 0.016 | lr 4.000000
---------------------------------------------------------------------
|epoch2|  50/ 152 batches|train_acc0.844 train_loss0.01494
|epoch2| 100/ 152 batches|train_acc0.846 train_loss0.01505
|epoch2| 150/ 152 batches|train_acc0.849 train_loss0.01442
---------------------------------------------------------------------
| epoch 2 | time:1.71s | valid_acc 0.834 valid_loss 0.017 | lr 4.000000
---------------------------------------------------------------------
|epoch3|  50/ 152 batches|train_acc0.856 train_loss0.01387
|epoch3| 100/ 152 batches|train_acc0.848 train_loss0.01376
|epoch3| 150/ 152 batches|train_acc0.863 train_loss0.01322
---------------------------------------------------------------------
| epoch 3 | time:1.93s | valid_acc 0.843 valid_loss 0.017 | lr 4.000000
---------------------------------------------------------------------
|epoch4|  50/ 152 batches|train_acc0.865 train_loss0.01316
|epoch4| 100/ 152 batches|train_acc0.856 train_loss0.01294
|epoch4| 150/ 152 batches|train_acc0.847 train_loss0.01460
---------------------------------------------------------------------
| epoch 4 | time:1.48s | valid_acc 0.830 valid_loss 0.017 | lr 4.000000
---------------------------------------------------------------------
|epoch5|  50/ 152 batches|train_acc0.882 train_loss0.00931
|epoch5| 100/ 152 batches|train_acc0.900 train_loss0.00714
|epoch5| 150/ 152 batches|train_acc0.902 train_loss0.00734
---------------------------------------------------------------------
| epoch 5 | time:1.46s | valid_acc 0.881 valid_loss 0.009 | lr 0.400000
---------------------------------------------------------------------
|epoch6|  50/ 152 batches|train_acc0.902 train_loss0.00681
|epoch6| 100/ 152 batches|train_acc0.903 train_loss0.00632
|epoch6| 150/ 152 batches|train_acc0.906 train_loss0.00641
---------------------------------------------------------------------
| epoch 6 | time:1.65s | valid_acc 0.876 valid_loss 0.009 | lr 0.400000
---------------------------------------------------------------------
|epoch7|  50/ 152 batches|train_acc0.907 train_loss0.00579
|epoch7| 100/ 152 batches|train_acc0.904 train_loss0.00610
|epoch7| 150/ 152 batches|train_acc0.908 train_loss0.00561
---------------------------------------------------------------------
| epoch 7 | time:1.50s | valid_acc 0.885 valid_loss 0.008 | lr 0.040000
---------------------------------------------------------------------
|epoch8|  50/ 152 batches|train_acc0.912 train_loss0.00535
|epoch8| 100/ 152 batches|train_acc0.915 train_loss0.00571
|epoch8| 150/ 152 batches|train_acc0.901 train_loss0.00604
---------------------------------------------------------------------
| epoch 8 | time:1.52s | valid_acc 0.884 valid_loss 0.008 | lr 0.040000
---------------------------------------------------------------------
|epoch9|  50/ 152 batches|train_acc0.909 train_loss0.00564
|epoch9| 100/ 152 batches|train_acc0.907 train_loss0.00553
|epoch9| 150/ 152 batches|train_acc0.912 train_loss0.00565
---------------------------------------------------------------------
| epoch 9 | time:1.49s | valid_acc 0.884 valid_loss 0.008 | lr 0.004000
---------------------------------------------------------------------
|epoch10|  50/ 152 batches|train_acc0.912 train_loss0.00535
|epoch10| 100/ 152 batches|train_acc0.911 train_loss0.00558
|epoch10| 150/ 152 batches|train_acc0.905 train_loss0.00580
---------------------------------------------------------------------
| epoch 10 | time:1.47s | valid_acc 0.884 valid_loss 0.008 | lr 0.000400
---------------------------------------------------------------------
模型准确率为:0.8839

3.2 测试指定数据

代码如下:

# 测试指定的数据
def predict(text, text_pipeline):with torch.no_grad():text = torch.tensor(text_pipeline(text), dtype=torch.float32)print(text.shape)output = model(text)return output.argmax(1).item()ex_text_str = "还有双鸭山到淮阴的汽车票吗13号的"
model = model.to("cpu")print("该文本的类别是: %s" % label_name[predict(ex_text_str, text_pipeline)])

打印输出如下:

torch.Size([1, 100])
该文本的类别是: Travel-Query


总结

结合Word2vec文本内容(第1列)预测文本标签(第2列),准确率能达到88%以上,后续将继续优化网络结构和进行结果的可视化。

这篇关于第23周:使用Word2vec实现文本分类的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1143796

相关文章

SpringBoot集成redisson实现延时队列教程

《SpringBoot集成redisson实现延时队列教程》文章介绍了使用Redisson实现延迟队列的完整步骤,包括依赖导入、Redis配置、工具类封装、业务枚举定义、执行器实现、Bean创建、消费... 目录1、先给项目导入Redisson依赖2、配置redis3、创建 RedissonConfig 配

Python的Darts库实现时间序列预测

《Python的Darts库实现时间序列预测》Darts一个集统计、机器学习与深度学习模型于一体的Python时间序列预测库,本文主要介绍了Python的Darts库实现时间序列预测,感兴趣的可以了解... 目录目录一、什么是 Darts?二、安装与基本配置安装 Darts导入基础模块三、时间序列数据结构与

Python使用FastAPI实现大文件分片上传与断点续传功能

《Python使用FastAPI实现大文件分片上传与断点续传功能》大文件直传常遇到超时、网络抖动失败、失败后只能重传的问题,分片上传+断点续传可以把大文件拆成若干小块逐个上传,并在中断后从已完成分片继... 目录一、接口设计二、服务端实现(FastAPI)2.1 运行环境2.2 目录结构建议2.3 serv

C#实现千万数据秒级导入的代码

《C#实现千万数据秒级导入的代码》在实际开发中excel导入很常见,现代社会中很容易遇到大数据处理业务,所以本文我就给大家分享一下千万数据秒级导入怎么实现,文中有详细的代码示例供大家参考,需要的朋友可... 目录前言一、数据存储二、处理逻辑优化前代码处理逻辑优化后的代码总结前言在实际开发中excel导入很

Spring Security简介、使用与最佳实践

《SpringSecurity简介、使用与最佳实践》SpringSecurity是一个能够为基于Spring的企业应用系统提供声明式的安全访问控制解决方案的安全框架,本文给大家介绍SpringSec... 目录一、如何理解 Spring Security?—— 核心思想二、如何在 Java 项目中使用?——

SpringBoot+RustFS 实现文件切片极速上传的实例代码

《SpringBoot+RustFS实现文件切片极速上传的实例代码》本文介绍利用SpringBoot和RustFS构建高性能文件切片上传系统,实现大文件秒传、断点续传和分片上传等功能,具有一定的参考... 目录一、为什么选择 RustFS + SpringBoot?二、环境准备与部署2.1 安装 RustF

Nginx部署HTTP/3的实现步骤

《Nginx部署HTTP/3的实现步骤》本文介绍了在Nginx中部署HTTP/3的详细步骤,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学... 目录前提条件第一步:安装必要的依赖库第二步:获取并构建 BoringSSL第三步:获取 Nginx

springboot中使用okhttp3的小结

《springboot中使用okhttp3的小结》OkHttp3是一个JavaHTTP客户端,可以处理各种请求类型,比如GET、POST、PUT等,并且支持高效的HTTP连接池、请求和响应缓存、以及异... 在 Spring Boot 项目中使用 OkHttp3 进行 HTTP 请求是一个高效且流行的方式。

MyBatis Plus实现时间字段自动填充的完整方案

《MyBatisPlus实现时间字段自动填充的完整方案》在日常开发中,我们经常需要记录数据的创建时间和更新时间,传统的做法是在每次插入或更新操作时手动设置这些时间字段,这种方式不仅繁琐,还容易遗漏,... 目录前言解决目标技术栈实现步骤1. 实体类注解配置2. 创建元数据处理器3. 服务层代码优化填充机制详

Python实现Excel批量样式修改器(附完整代码)

《Python实现Excel批量样式修改器(附完整代码)》这篇文章主要为大家详细介绍了如何使用Python实现一个Excel批量样式修改器,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一... 目录前言功能特性核心功能界面特性系统要求安装说明使用指南基本操作流程高级功能技术实现核心技术栈关键函