第11篇 Fast AI深度学习课程——机器翻译

2024-02-27 00:32

本文主要是介绍第11篇 Fast AI深度学习课程——机器翻译,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在上节课程中,我们使用语言模型对IMDB影评进行了情感分析。对于语言模型而言,使用的神经网络是一个seq2seq的网络,即输入和输出均为序列;每输入一个单词,就需输出一个单词,因此输入输出的序列长度是一致的。对于影评分析,是一个由字词序列得到单一分类结果的网络,即为seq2one的网络。本节将介绍由法语到英语的机器翻译,该类型网络也是seq2seq,但与语言模型不同之处在于,其在读入整个字符序列后,再输出另一个字符序列,两个序列长度可不一致,而序列之间的字词也没有一一对应的关系。(关于RNN的分类情况,可参见CS231n的相关内容。)

本节秉承了本系列课程自顶而下的学习思路,在前一节的基础上(前一节主要还是在FastAI的代码基础上进行网络的构建),将从底层开始实现用于机器学习的网络。

一、数据

1. 构建词库

本课所用数据为某网站的法语版和英语版的文章,运行如下命令进行下载:

wget http://www.statmt.org/wmt10/training-giga-fren.tar
tar -xvf training-giga-fren.tar
gunzip giga-fren.release2.fixed.en.gz
gunzip giga-fren.release2.fixed.fr.gz

数据压缩包为2.5G,可能得下个把小时。为简化问题,我们将在数据集的问句集合上进行讨论,具体而言,就是英文句库中以whatwherewhichwhen开头的语句。筛选后所得语句大致为52000条。

对语句进行分词。在对法语进行分词前,可能需要下载spacy的法语支持数据包:

python -m spacy download fr

分词代码如下:

en_tok = Tokenizer.proc_all_mp(partition_by_cores(en_qs))
fr_tok = Tokenizer.proc_all_mp(partition_by_cores(fr_qs), 'fr')

分词后,所得英文序列的90%23个词以内,发文序列90%在28个词以内。接下来构建词库。注意按照词频进行筛选,并补充特殊字词:_bos__pad__unk__eos_

得到词库后,将字词数字化。这一步的转化是用fasttext包的词向量实现的。转化后,每个字词使用300维的向量标识。词向量下载:

wget https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki.fr.zip
wget https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki.en.zip

得下好长时间。

2. 构建数据模型

首先构建适合seq2seq网络的Dataset。由上节课已知,Dataset实际为一个索引类,其只需实现__getitem__()__len__()函数:

class Seq2SeqDataset(Dataset):def __init__(self, x, y): self.x,self.y = x,ydef __getitem__(self, idx): return A(self.x[idx], self.y[idx])def __len__(self): return len(self.x)
np.random.seed(42)
trn_keep = np.random.rand(len(en_ids_tr))>0.1
en_trn,fr_trn = en_ids_tr[trn_keep],fr_ids_tr[trn_keep]
en_val,fr_val = en_ids_tr[~trn_keep],fr_ids_tr[~trn_keep]
trn_ds = Seq2SeqDataset(fr_trn,en_trn)
val_ds = Seq2SeqDataset(fr_val,en_val)

然后由Dataset构建数据加载器Dataloader,这一部分与上一节大致相同。不同之处在于对文本序列进行补齐时,本例中是在序列末尾补齐,而分类网络是在序列开头补齐。这里的直观理解是:在分类网络里,对于一个批次中的最长文本,那么在读完文本后再做判定是合适的;而对短文本,如果在末端补齐,则填充的无意义字符会极大影响分类结果。在翻译网络中,我们只关心句子结束符之前的内容,这一部分要尽量减少填充字符的影响,因此在句子末尾补齐是合适的。

bs=125
trn_samp = SortishSampler(en_trn, key=lambda x: len(en_trn[x]), bs=bs)
val_samp = SortSampler(en_val, key=lambda x: len(en_val[x]))
trn_dl = DataLoader(trn_ds, bs, transpose=True, transpose_y=True, num_workers=1, pad_idx=1, pre_pad=False, sampler=trn_samp)
val_dl = DataLoader(val_ds, int(bs*1.6), transpose=True, transpose_y=True, num_workers=1, pad_idx=1, pre_pad=False, sampler=val_samp)

由数据加载器构建ModelData。事实上,Model Data就是整合训练集、验证集、可选的测试集,并提供可用于临时存储的路径。

md = ModelData(PATH, trn_dl, val_dl)

二、网络架构

翻译网络的结构如下图所示。整个流程为:将一种语言的语句通过一个Encoder网络,获得最终的一个表征语句句法结构等特征的隐藏状态向量,以之为下一个Decoder网络的初始隐藏状态,并以_bos_为初始输入,按照训练语言模型时的方式,一步一词地生成另一语言的完整语句。

图 1. 翻译网络的架构
class Seq2SeqRNN(nn.Module):def __init__(self, vecs_enc, itos_enc, em_sz_enc, vecs_dec, itos_dec, em_sz_dec, nh, out_sl, nl=2):super().__init__()self.nl,self.nh,self.out_sl = nl,nh,out_slself.emb_enc = create_emb(vecs_enc, itos_enc, em_sz_enc)self.emb_enc_drop = nn.Dropout(0.15)self.gru_enc = nn.GRU(em_sz_enc, nh, num_layers=nl, dropout=0.25)self.out_enc = nn.Linear(nh, em_sz_dec, bias=False)self.emb_dec = create_emb(vecs_dec, itos_dec, em_sz_dec)self.gru_dec = nn.GRU(em_sz_dec, em_sz_dec, num_layers=nl, dropout=0.1)self.out_drop = nn.Dropout(0.35)self.out = nn.Linear(em_sz_dec, len(itos_dec))self.out.weight.data = self.emb_dec.weight.datadef forward(self, inp):sl,bs = inp.size()h = self.initHidden(bs)emb = self.emb_enc_drop(self.emb_enc(inp))enc_out, h = self.gru_enc(emb, h)h = self.out_enc(h)dec_inp = V(torch.zeros(bs).long())res = []for i in range(self.out_sl):emb = self.emb_dec(dec_inp).unsqueeze(0)outp, h = self.gru_dec(emb, h)outp = self.out(self.out_drop(outp[0]))res.append(outp)dec_inp = V(outp.data.max(1)[1])if (dec_inp==1).all(): breakreturn torch.stack(res)def initHidden(self, bs): return V(torch.zeros(self.nl, bs, self.nh))

注意forward()函数。其中``Decoder的输入dec_inp初始化为0,即bos的索引值;Decoder的初始隐藏状态为Encoder的输出;outp`表示在词库中所有词上的概率。

值得说明的要点如下:

1. Encoder的内嵌矩阵

使用fast.text提供的词向量矩阵作为内嵌矩阵。由于fast.text的词向量矩阵的标准差为0.3,为得到大致满足高斯分布的内嵌矩阵,需要乘以系数3

2. 如何确定目标语句完结

首先统计一个目标语言的最长语句长度。然后以这个长度为终值做循环,直至结束或输出_pad_

三、损失函数

损失函数使用的是交互熵函数。由于生成的翻译语句可能和目标语句长度不一致,所以可能需要做填充。所使用的Pytorchpad函数,其需要六个参数,分别指明了在次序列方向、批索引方向的填充的头尾起始位置以及长度。

def seq2seq_loss(input, target):sl,bs = target.size()sl_in,bs_in,nc = input.size()if sl>sl_in: input = F.pad(input, (0,0,0,0,0,sl-sl_in))input = input[:sl]return F.cross_entropy(input.view(-1,nc), target.view(-1))#, ignore_index=1)

四、一些技巧

1. 双向训练设置

一般设置Encoderbidirectional=True,而不对Decoder做双向设置。这样,网络会同时在输入序列的倒序序列上训练得到相应的隐藏状态。

2. 初始阶段的强制校正

考虑训练初始时,网络对两种语言还未学习到有效信息,此时Decoder的每一步输出的单词都是随机的,从而导致后续输出远偏离于真值。而如果此时强制以正确的目标语句进行Decoder状态的推进,可有效提高网络收敛的速度。(这实际和GAN的策略很接近。)实际应用中,设置pr_force参数,当预测出的词的概率低于pr_force时,就采取强制校正措施。然后逐渐缩小pr_force,减弱强制校正的力度。

在前向传播中加入强制校正还是挺直观的,修改Seq2SeqRNNforward()函数:

    def forward(self, inp, y=None):sl,bs = inp.size()h = self.initHidden(bs)emb = self.emb_enc_drop(self.emb_enc(inp))enc_out, h = self.gru_enc(emb, h)h = self.out_enc(h)dec_inp = V(torch.zeros(bs).long())res = []for i in range(self.out_sl):emb = self.emb_dec(dec_inp).unsqueeze(0)outp, h = self.gru_dec(emb, h)outp = self.out(self.out_drop(outp[0]))res.append(outp)dec_inp = V(outp.data.max(1)[1])if (dec_inp==1).all(): breakif (y is not None) and (random.random()<self.pr_force):if i>=len(y): breakdec_inp = y[i]return torch.stack(res)

注意和pr_force相关的那一行。

那么如何加入使得pr_force逐步减小的机制呢?实际上控制epoch之间的循环的是fit()函数,在其定义中,调用了stepper.step(),该函数实现了模型的前向传播、损失函数的计算、梯度的反向传播等。因此只需定义一个新的stepper,重写其step()函数,实现pr_force的逐步减小即可。

class Seq2SeqStepper(Stepper):def step(self, xs, y, epoch):self.m.pr_force = (10-epoch)*0.1 if epoch<10 else 0return super.step(xs, y, epoch)

在调用learner.fit()时,指明stepper=Seq2SeqStepper

3. 注意力模型

Encoder不仅输出了最后一步的隐藏状态,还保存了前面步骤的隐藏状态。如果能够在输出目标语言的某个字词时,在源语言的语句中找到与之最相关的部分,然后对该相关部分的隐藏状态进行加权求和,并传递到Decoder中,那么Decoder所获取的信息就更全面,应当能够改善翻译效果。而这种加权信息,可以通过一个小型网络得到。

def forward(self, inp, y=None, ret_attn=False):sl,bs = inp.size()h = self.initHidden(bs)emb = self.emb_enc_drop(self.emb_enc(inp))enc_out, h = self.gru_enc(emb, h)h = self.out_enc(h)dec_inp = V(torch.zeros(bs).long())res,attns = [],[]w1e = enc_out @ self.W1for i in range(self.out_sl):w2h = self.l2(h[-1])u = F.tanh(w1e + w2h)a = F.softmax(u @ self.V, 0)attns.append(a)Xa = (a.unsqueeze(2) * enc_out).sum(0)emb = self.emb_dec(dec_inp)wgt_enc = self.l3(torch.cat([emb, Xa], 1))outp, h = self.gru_dec(wgt_enc.unsqueeze(0), h)outp = self.out(self.out_drop(outp[0]))res.append(outp)dec_inp = V(outp.data.max(1)[1])if (dec_inp==1).all(): breakif (y is not None) and (random.random()<self.pr_force):if i>=len(y): breakdec_inp = y[i]res = torch.stack(res)if ret_attn: res = res,torch.stack(attns)return res

五、更广泛的应用实例

附注

  • 若一个python代码包的git库中,包含setup.pyrequirements.txt,那么可通过如下命令进行安装:pip install git+https://github.com/facebookresearch/fastText.git
  • 一个小技巧:对于网络,可以使用to_gpu()函数替代model.cuda()方法,这样在没有GPU时,会自动使用CPU进行计算。在调试时,可通过设置fastai.core.GPUFalse,以提供方便。

一些有用的链接

  • 课程wiki: 本节课程的一些相关资源,包括课程笔记、课上提到的博客地址等。

  • 注意力模型在机器翻译中的应用: 首次引入注意力模型的论文。

  • 注意力模型的博客: 博客很有意思,还支持用户交互。

这篇关于第11篇 Fast AI深度学习课程——机器翻译的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/750660

相关文章

深度解析Nginx日志分析与499状态码问题解决

《深度解析Nginx日志分析与499状态码问题解决》在Web服务器运维和性能优化过程中,Nginx日志是排查问题的重要依据,本文将围绕Nginx日志分析、499状态码的成因、排查方法及解决方案展开讨论... 目录前言1. Nginx日志基础1.1 Nginx日志存放位置1.2 Nginx日志格式2. 499

Spring AI使用tool Calling和MCP的示例详解

《SpringAI使用toolCalling和MCP的示例详解》SpringAI1.0.0.M6引入ToolCalling与MCP协议,提升AI与工具交互的扩展性与标准化,支持信息检索、行动执行等... 目录深入探索 Spring AI聊天接口示例Function CallingMCPSTDIOSSE结束语

三频BE12000国补到手2549元! ROG 魔盒Pro WIFI7电竞AI路由器上架

《三频BE12000国补到手2549元!ROG魔盒ProWIFI7电竞AI路由器上架》近日,华硕带来了ROG魔盒ProWIFI7电竞AI路由器(ROGSTRIXGR7Pro),目前新... 华硕推出了ROG 魔盒Pro WIFI7电竞AI路由器(ROG STRIX GR7 Phttp://www.cppcn

深度解析Java DTO(最新推荐)

《深度解析JavaDTO(最新推荐)》DTO(DataTransferObject)是一种用于在不同层(如Controller层、Service层)之间传输数据的对象设计模式,其核心目的是封装数据,... 目录一、什么是DTO?DTO的核心特点:二、为什么需要DTO?(对比Entity)三、实际应用场景解析

深度解析Java项目中包和包之间的联系

《深度解析Java项目中包和包之间的联系》文章浏览阅读850次,点赞13次,收藏8次。本文详细介绍了Java分层架构中的几个关键包:DTO、Controller、Service和Mapper。_jav... 目录前言一、各大包1.DTO1.1、DTO的核心用途1.2. DTO与实体类(Entity)的区别1

深度解析Python装饰器常见用法与进阶技巧

《深度解析Python装饰器常见用法与进阶技巧》Python装饰器(Decorator)是提升代码可读性与复用性的强大工具,本文将深入解析Python装饰器的原理,常见用法,进阶技巧与最佳实践,希望可... 目录装饰器的基本原理函数装饰器的常见用法带参数的装饰器类装饰器与方法装饰器装饰器的嵌套与组合进阶技巧

深度解析Spring Boot拦截器Interceptor与过滤器Filter的区别与实战指南

《深度解析SpringBoot拦截器Interceptor与过滤器Filter的区别与实战指南》本文深度解析SpringBoot中拦截器与过滤器的区别,涵盖执行顺序、依赖关系、异常处理等核心差异,并... 目录Spring Boot拦截器(Interceptor)与过滤器(Filter)深度解析:区别、实现

深度解析Spring AOP @Aspect 原理、实战与最佳实践教程

《深度解析SpringAOP@Aspect原理、实战与最佳实践教程》文章系统讲解了SpringAOP核心概念、实现方式及原理,涵盖横切关注点分离、代理机制(JDK/CGLIB)、切入点类型、性能... 目录1. @ASPect 核心概念1.1 AOP 编程范式1.2 @Aspect 关键特性2. 完整代码实

SpringBoot开发中十大常见陷阱深度解析与避坑指南

《SpringBoot开发中十大常见陷阱深度解析与避坑指南》在SpringBoot的开发过程中,即使是经验丰富的开发者也难免会遇到各种棘手的问题,本文将针对SpringBoot开发中十大常见的“坑... 目录引言一、配置总出错?是不是同时用了.properties和.yml?二、换个位置配置就失效?搞清楚加

Go学习记录之runtime包深入解析

《Go学习记录之runtime包深入解析》Go语言runtime包管理运行时环境,涵盖goroutine调度、内存分配、垃圾回收、类型信息等核心功能,:本文主要介绍Go学习记录之runtime包的... 目录前言:一、runtime包内容学习1、作用:① Goroutine和并发控制:② 垃圾回收:③ 栈和