第11篇 Fast AI深度学习课程——机器翻译

2024-02-27 00:32

本文主要是介绍第11篇 Fast AI深度学习课程——机器翻译,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在上节课程中,我们使用语言模型对IMDB影评进行了情感分析。对于语言模型而言,使用的神经网络是一个seq2seq的网络,即输入和输出均为序列;每输入一个单词,就需输出一个单词,因此输入输出的序列长度是一致的。对于影评分析,是一个由字词序列得到单一分类结果的网络,即为seq2one的网络。本节将介绍由法语到英语的机器翻译,该类型网络也是seq2seq,但与语言模型不同之处在于,其在读入整个字符序列后,再输出另一个字符序列,两个序列长度可不一致,而序列之间的字词也没有一一对应的关系。(关于RNN的分类情况,可参见CS231n的相关内容。)

本节秉承了本系列课程自顶而下的学习思路,在前一节的基础上(前一节主要还是在FastAI的代码基础上进行网络的构建),将从底层开始实现用于机器学习的网络。

一、数据

1. 构建词库

本课所用数据为某网站的法语版和英语版的文章,运行如下命令进行下载:

wget http://www.statmt.org/wmt10/training-giga-fren.tar
tar -xvf training-giga-fren.tar
gunzip giga-fren.release2.fixed.en.gz
gunzip giga-fren.release2.fixed.fr.gz

数据压缩包为2.5G,可能得下个把小时。为简化问题,我们将在数据集的问句集合上进行讨论,具体而言,就是英文句库中以whatwherewhichwhen开头的语句。筛选后所得语句大致为52000条。

对语句进行分词。在对法语进行分词前,可能需要下载spacy的法语支持数据包:

python -m spacy download fr

分词代码如下:

en_tok = Tokenizer.proc_all_mp(partition_by_cores(en_qs))
fr_tok = Tokenizer.proc_all_mp(partition_by_cores(fr_qs), 'fr')

分词后,所得英文序列的90%23个词以内,发文序列90%在28个词以内。接下来构建词库。注意按照词频进行筛选,并补充特殊字词:_bos__pad__unk__eos_

得到词库后,将字词数字化。这一步的转化是用fasttext包的词向量实现的。转化后,每个字词使用300维的向量标识。词向量下载:

wget https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki.fr.zip
wget https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki.en.zip

得下好长时间。

2. 构建数据模型

首先构建适合seq2seq网络的Dataset。由上节课已知,Dataset实际为一个索引类,其只需实现__getitem__()__len__()函数:

class Seq2SeqDataset(Dataset):def __init__(self, x, y): self.x,self.y = x,ydef __getitem__(self, idx): return A(self.x[idx], self.y[idx])def __len__(self): return len(self.x)
np.random.seed(42)
trn_keep = np.random.rand(len(en_ids_tr))>0.1
en_trn,fr_trn = en_ids_tr[trn_keep],fr_ids_tr[trn_keep]
en_val,fr_val = en_ids_tr[~trn_keep],fr_ids_tr[~trn_keep]
trn_ds = Seq2SeqDataset(fr_trn,en_trn)
val_ds = Seq2SeqDataset(fr_val,en_val)

然后由Dataset构建数据加载器Dataloader,这一部分与上一节大致相同。不同之处在于对文本序列进行补齐时,本例中是在序列末尾补齐,而分类网络是在序列开头补齐。这里的直观理解是:在分类网络里,对于一个批次中的最长文本,那么在读完文本后再做判定是合适的;而对短文本,如果在末端补齐,则填充的无意义字符会极大影响分类结果。在翻译网络中,我们只关心句子结束符之前的内容,这一部分要尽量减少填充字符的影响,因此在句子末尾补齐是合适的。

bs=125
trn_samp = SortishSampler(en_trn, key=lambda x: len(en_trn[x]), bs=bs)
val_samp = SortSampler(en_val, key=lambda x: len(en_val[x]))
trn_dl = DataLoader(trn_ds, bs, transpose=True, transpose_y=True, num_workers=1, pad_idx=1, pre_pad=False, sampler=trn_samp)
val_dl = DataLoader(val_ds, int(bs*1.6), transpose=True, transpose_y=True, num_workers=1, pad_idx=1, pre_pad=False, sampler=val_samp)

由数据加载器构建ModelData。事实上,Model Data就是整合训练集、验证集、可选的测试集,并提供可用于临时存储的路径。

md = ModelData(PATH, trn_dl, val_dl)

二、网络架构

翻译网络的结构如下图所示。整个流程为:将一种语言的语句通过一个Encoder网络,获得最终的一个表征语句句法结构等特征的隐藏状态向量,以之为下一个Decoder网络的初始隐藏状态,并以_bos_为初始输入,按照训练语言模型时的方式,一步一词地生成另一语言的完整语句。

图 1. 翻译网络的架构
class Seq2SeqRNN(nn.Module):def __init__(self, vecs_enc, itos_enc, em_sz_enc, vecs_dec, itos_dec, em_sz_dec, nh, out_sl, nl=2):super().__init__()self.nl,self.nh,self.out_sl = nl,nh,out_slself.emb_enc = create_emb(vecs_enc, itos_enc, em_sz_enc)self.emb_enc_drop = nn.Dropout(0.15)self.gru_enc = nn.GRU(em_sz_enc, nh, num_layers=nl, dropout=0.25)self.out_enc = nn.Linear(nh, em_sz_dec, bias=False)self.emb_dec = create_emb(vecs_dec, itos_dec, em_sz_dec)self.gru_dec = nn.GRU(em_sz_dec, em_sz_dec, num_layers=nl, dropout=0.1)self.out_drop = nn.Dropout(0.35)self.out = nn.Linear(em_sz_dec, len(itos_dec))self.out.weight.data = self.emb_dec.weight.datadef forward(self, inp):sl,bs = inp.size()h = self.initHidden(bs)emb = self.emb_enc_drop(self.emb_enc(inp))enc_out, h = self.gru_enc(emb, h)h = self.out_enc(h)dec_inp = V(torch.zeros(bs).long())res = []for i in range(self.out_sl):emb = self.emb_dec(dec_inp).unsqueeze(0)outp, h = self.gru_dec(emb, h)outp = self.out(self.out_drop(outp[0]))res.append(outp)dec_inp = V(outp.data.max(1)[1])if (dec_inp==1).all(): breakreturn torch.stack(res)def initHidden(self, bs): return V(torch.zeros(self.nl, bs, self.nh))

注意forward()函数。其中``Decoder的输入dec_inp初始化为0,即bos的索引值;Decoder的初始隐藏状态为Encoder的输出;outp`表示在词库中所有词上的概率。

值得说明的要点如下:

1. Encoder的内嵌矩阵

使用fast.text提供的词向量矩阵作为内嵌矩阵。由于fast.text的词向量矩阵的标准差为0.3,为得到大致满足高斯分布的内嵌矩阵,需要乘以系数3

2. 如何确定目标语句完结

首先统计一个目标语言的最长语句长度。然后以这个长度为终值做循环,直至结束或输出_pad_

三、损失函数

损失函数使用的是交互熵函数。由于生成的翻译语句可能和目标语句长度不一致,所以可能需要做填充。所使用的Pytorchpad函数,其需要六个参数,分别指明了在次序列方向、批索引方向的填充的头尾起始位置以及长度。

def seq2seq_loss(input, target):sl,bs = target.size()sl_in,bs_in,nc = input.size()if sl>sl_in: input = F.pad(input, (0,0,0,0,0,sl-sl_in))input = input[:sl]return F.cross_entropy(input.view(-1,nc), target.view(-1))#, ignore_index=1)

四、一些技巧

1. 双向训练设置

一般设置Encoderbidirectional=True,而不对Decoder做双向设置。这样,网络会同时在输入序列的倒序序列上训练得到相应的隐藏状态。

2. 初始阶段的强制校正

考虑训练初始时,网络对两种语言还未学习到有效信息,此时Decoder的每一步输出的单词都是随机的,从而导致后续输出远偏离于真值。而如果此时强制以正确的目标语句进行Decoder状态的推进,可有效提高网络收敛的速度。(这实际和GAN的策略很接近。)实际应用中,设置pr_force参数,当预测出的词的概率低于pr_force时,就采取强制校正措施。然后逐渐缩小pr_force,减弱强制校正的力度。

在前向传播中加入强制校正还是挺直观的,修改Seq2SeqRNNforward()函数:

    def forward(self, inp, y=None):sl,bs = inp.size()h = self.initHidden(bs)emb = self.emb_enc_drop(self.emb_enc(inp))enc_out, h = self.gru_enc(emb, h)h = self.out_enc(h)dec_inp = V(torch.zeros(bs).long())res = []for i in range(self.out_sl):emb = self.emb_dec(dec_inp).unsqueeze(0)outp, h = self.gru_dec(emb, h)outp = self.out(self.out_drop(outp[0]))res.append(outp)dec_inp = V(outp.data.max(1)[1])if (dec_inp==1).all(): breakif (y is not None) and (random.random()<self.pr_force):if i>=len(y): breakdec_inp = y[i]return torch.stack(res)

注意和pr_force相关的那一行。

那么如何加入使得pr_force逐步减小的机制呢?实际上控制epoch之间的循环的是fit()函数,在其定义中,调用了stepper.step(),该函数实现了模型的前向传播、损失函数的计算、梯度的反向传播等。因此只需定义一个新的stepper,重写其step()函数,实现pr_force的逐步减小即可。

class Seq2SeqStepper(Stepper):def step(self, xs, y, epoch):self.m.pr_force = (10-epoch)*0.1 if epoch<10 else 0return super.step(xs, y, epoch)

在调用learner.fit()时,指明stepper=Seq2SeqStepper

3. 注意力模型

Encoder不仅输出了最后一步的隐藏状态,还保存了前面步骤的隐藏状态。如果能够在输出目标语言的某个字词时,在源语言的语句中找到与之最相关的部分,然后对该相关部分的隐藏状态进行加权求和,并传递到Decoder中,那么Decoder所获取的信息就更全面,应当能够改善翻译效果。而这种加权信息,可以通过一个小型网络得到。

def forward(self, inp, y=None, ret_attn=False):sl,bs = inp.size()h = self.initHidden(bs)emb = self.emb_enc_drop(self.emb_enc(inp))enc_out, h = self.gru_enc(emb, h)h = self.out_enc(h)dec_inp = V(torch.zeros(bs).long())res,attns = [],[]w1e = enc_out @ self.W1for i in range(self.out_sl):w2h = self.l2(h[-1])u = F.tanh(w1e + w2h)a = F.softmax(u @ self.V, 0)attns.append(a)Xa = (a.unsqueeze(2) * enc_out).sum(0)emb = self.emb_dec(dec_inp)wgt_enc = self.l3(torch.cat([emb, Xa], 1))outp, h = self.gru_dec(wgt_enc.unsqueeze(0), h)outp = self.out(self.out_drop(outp[0]))res.append(outp)dec_inp = V(outp.data.max(1)[1])if (dec_inp==1).all(): breakif (y is not None) and (random.random()<self.pr_force):if i>=len(y): breakdec_inp = y[i]res = torch.stack(res)if ret_attn: res = res,torch.stack(attns)return res

五、更广泛的应用实例

附注

  • 若一个python代码包的git库中,包含setup.pyrequirements.txt,那么可通过如下命令进行安装:pip install git+https://github.com/facebookresearch/fastText.git
  • 一个小技巧:对于网络,可以使用to_gpu()函数替代model.cuda()方法,这样在没有GPU时,会自动使用CPU进行计算。在调试时,可通过设置fastai.core.GPUFalse,以提供方便。

一些有用的链接

  • 课程wiki: 本节课程的一些相关资源,包括课程笔记、课上提到的博客地址等。

  • 注意力模型在机器翻译中的应用: 首次引入注意力模型的论文。

  • 注意力模型的博客: 博客很有意思,还支持用户交互。

这篇关于第11篇 Fast AI深度学习课程——机器翻译的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/750660

相关文章

Java中Redisson 的原理深度解析

《Java中Redisson的原理深度解析》Redisson是一个高性能的Redis客户端,它通过将Redis数据结构映射为Java对象和分布式对象,实现了在Java应用中方便地使用Redis,本文... 目录前言一、核心设计理念二、核心架构与通信层1. 基于 Netty 的异步非阻塞通信2. 编解码器三、

Java HashMap的底层实现原理深度解析

《JavaHashMap的底层实现原理深度解析》HashMap基于数组+链表+红黑树结构,通过哈希算法和扩容机制优化性能,负载因子与树化阈值平衡效率,是Java开发必备的高效数据结构,本文给大家介绍... 目录一、概述:HashMap的宏观结构二、核心数据结构解析1. 数组(桶数组)2. 链表节点(Node

Java 虚拟线程的创建与使用深度解析

《Java虚拟线程的创建与使用深度解析》虚拟线程是Java19中以预览特性形式引入,Java21起正式发布的轻量级线程,本文给大家介绍Java虚拟线程的创建与使用,感兴趣的朋友一起看看吧... 目录一、虚拟线程简介1.1 什么是虚拟线程?1.2 为什么需要虚拟线程?二、虚拟线程与平台线程对比代码对比示例:三

Python函数作用域与闭包举例深度解析

《Python函数作用域与闭包举例深度解析》Python函数的作用域规则和闭包是编程中的关键概念,它们决定了变量的访问和生命周期,:本文主要介绍Python函数作用域与闭包的相关资料,文中通过代码... 目录1. 基础作用域访问示例1:访问全局变量示例2:访问外层函数变量2. 闭包基础示例3:简单闭包示例4

深度解析Python中递归下降解析器的原理与实现

《深度解析Python中递归下降解析器的原理与实现》在编译器设计、配置文件处理和数据转换领域,递归下降解析器是最常用且最直观的解析技术,本文将详细介绍递归下降解析器的原理与实现,感兴趣的小伙伴可以跟随... 目录引言:解析器的核心价值一、递归下降解析器基础1.1 核心概念解析1.2 基本架构二、简单算术表达

深度解析Java @Serial 注解及常见错误案例

《深度解析Java@Serial注解及常见错误案例》Java14引入@Serial注解,用于编译时校验序列化成员,替代传统方式解决运行时错误,适用于Serializable类的方法/字段,需注意签... 目录Java @Serial 注解深度解析1. 注解本质2. 核心作用(1) 主要用途(2) 适用位置3

Java MCP 的鉴权深度解析

《JavaMCP的鉴权深度解析》文章介绍JavaMCP鉴权的实现方式,指出客户端可通过queryString、header或env传递鉴权信息,服务器端支持工具单独鉴权、过滤器集中鉴权及启动时鉴权... 目录一、MCP Client 侧(负责传递,比较简单)(1)常见的 mcpServers json 配置

Maven中生命周期深度解析与实战指南

《Maven中生命周期深度解析与实战指南》这篇文章主要为大家详细介绍了Maven生命周期实战指南,包含核心概念、阶段详解、SpringBoot特化场景及企业级实践建议,希望对大家有一定的帮助... 目录一、Maven 生命周期哲学二、default生命周期核心阶段详解(高频使用)三、clean生命周期核心阶

Java+AI驱动实现PDF文件数据提取与解析

《Java+AI驱动实现PDF文件数据提取与解析》本文将和大家分享一套基于AI的体检报告智能评估方案,详细介绍从PDF上传、内容提取到AI分析、数据存储的全流程自动化实现方法,感兴趣的可以了解下... 目录一、核心流程:从上传到评估的完整链路二、第一步:解析 PDF,提取体检报告内容1. 引入依赖2. 封装

深度剖析SpringBoot日志性能提升的原因与解决

《深度剖析SpringBoot日志性能提升的原因与解决》日志记录本该是辅助工具,却为何成了性能瓶颈,SpringBoot如何用代码彻底破解日志导致的高延迟问题,感兴趣的小伙伴可以跟随小编一起学习一下... 目录前言第一章:日志性能陷阱的底层原理1.1 日志级别的“双刃剑”效应1.2 同步日志的“吞吐量杀手”