pytorch版本的bert模型代码(MLM)

2024-05-11 22:04

本文主要是介绍pytorch版本的bert模型代码(MLM),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

魔改bert就必须要知道Bert的结构:

主要解答与BertForMaskedLM(MLM)有关的类:

下面是MLM的分类头:

class BertLMPredictionHead(nn.Module):def __init__(self, config, bert_model_embedding_weights):super(BertLMPredictionHead, self).__init__()self.transform = BertPredictionHeadTransform(config)# The output weights are the same as the input embeddings, but there is# an output-only bias for each token.self.decoder = nn.Linear(bert_model_embedding_weights.size(1),bert_model_embedding_weights.size(0),bias=False)self.decoder.weight = bert_model_embedding_weightsself.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))"""上面是创建一个线性映射层, 把transformer block输出的[batch_size, seq_len, embed_dim]映射为[batch_size, seq_len, vocab_size], 也就是把最后一个维度映射成字典中字的数量, 获取MaskedLM的预测结果, 注意这里其实也可以直接矩阵成embedding矩阵的转置, 但一般情况下我们要随机初始化新的一层参数"""def forward(self, hidden_states):hidden_states = self.transform(hidden_states)hidden_states = self.decoder(hidden_states) + self.biasreturn hidden_statesclass BertOnlyMLMHead(nn.Module):def __init__(self, config, bert_model_embedding_weights):super(BertOnlyMLMHead, self).__init__()self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)def forward(self, sequence_output):prediction_scores = self.predictions(sequence_output)return prediction_scores

另一个类:

class BertModel(BertPreTrainedModel):"""BERT model ("Bidirectional Embedding Representations from a Transformer").Params:config: a BertConfig class instance with the configuration to build a new modelInputs:`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts`extract_features.py`, `run_classifier.py` and `run_squad.py`)`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the tokentypes indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds toa `sentence B` token (see BERT paper for more details).`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indicesselected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the maxinput sequence length in the current batch. It's the mask that we typically use for attention whena batch has varying length sentences.`output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.Outputs: Tuple of (encoded_layers, pooled_output)`encoded_layers`: controled by `output_all_encoded_layers` argument:- `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the endof each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), eachencoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],- `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states correspondingto the last attention block of shape [batch_size, sequence_length, hidden_size],`pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of aclassifier pretrained on top of the hidden state associated to the first character of theinput (`CLS`) to train on the Next-Sentence task (see BERT's paper).Example usage:```python# Already been converted into WordPiece token idsinput_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)model = modeling.BertModel(config=config)all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)```"""def __init__(self, config):super(BertModel, self).__init__(config)self.embeddings = BertEmbeddings(config)self.encoder = BertEncoder(config)self.pooler = BertPooler(config)self.apply(self.init_bert_weights)def forward(self, input_ids, positional_enc, token_type_ids=None, attention_mask=None,output_all_encoded_layers=True, get_attention_matrices=False):if attention_mask is None:# torch.LongTensor# attention_mask = torch.ones_like(input_ids)attention_mask = (input_ids > 0)# attention_mask [batch_size, length]if token_type_ids is None:token_type_ids = torch.zeros_like(input_ids)# We create a 3D attention mask from a 2D tensor mask.# Sizes are [batch_size, 1, 1, to_seq_length]# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]# this attention mask is more simple than the triangular masking of causal attention# used in OpenAI GPT, we just need to prepare the broadcast dimension here.extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)# 注意力矩阵mask: [batch_size, 1, 1, seq_length]# Since attention_mask is 1.0 for positions we want to attend and 0.0 for# masked positions, this operation will create a tensor which is 0.0 for# positions we want to attend and -10000.0 for masked positions.# Since we are adding it to the raw scores before the softmax, this is# effectively the same as removing these entirely.extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibilityextended_attention_mask = (1.0 - extended_attention_mask) * -10000.0# 给注意力矩阵里padding的无效区域加一个很大的负数的偏置, 为了使softmax之后这些无效区域仍然为0, 不参与后续计算# embedding层embedding_output = self.embeddings(input_ids, positional_enc, token_type_ids)# 经过所有定义的transformer block之后的输出encoded_layers, all_attention_matrices = self.encoder(embedding_output,extended_attention_mask,output_all_encoded_layers=output_all_encoded_layers,get_attention_matrices=get_attention_matrices)# 可输出所有层的注意力矩阵用于可视化if get_attention_matrices:return all_attention_matrices# [-1]为最后一个transformer block的隐藏层的计算结果sequence_output = encoded_layers[-1]# pooled_output为隐藏层中#CLS#对应的token的一条向量pooled_output = self.pooler(sequence_output)if not output_all_encoded_layers:encoded_layers = encoded_layers[-1]return encoded_layers, pooled_output

参考:https://blog.51cto.com/u_15060462/4254056

这篇关于pytorch版本的bert模型代码(MLM)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/980779

相关文章

uniapp小程序中实现无缝衔接滚动效果代码示例

《uniapp小程序中实现无缝衔接滚动效果代码示例》:本文主要介绍uniapp小程序中实现无缝衔接滚动效果的相关资料,该方法可以实现滚动内容中字的不同的颜色更改,并且可以根据需要进行艺术化更改和自... 组件滚动通知只能实现简单的滚动效果,不能实现滚动内容中的字进行不同颜色的更改,下面实现一个无缝衔接的滚动

利用Python实现可回滚方案的示例代码

《利用Python实现可回滚方案的示例代码》很多项目翻车不是因为不会做,而是走错了方向却没法回头,技术选型失败的风险我们都清楚,但真正能提前规划“回滚方案”的人不多,本文从实际项目出发,教你如何用Py... 目录描述题解答案(核心思路)题解代码分析第一步:抽象缓存接口第二步:实现两个版本第三步:根据 Fea

Java计算经纬度距离的示例代码

《Java计算经纬度距离的示例代码》在Java中计算两个经纬度之间的距离,可以使用多种方法(代码示例均返回米为单位),文中整理了常用的5种方法,感兴趣的小伙伴可以了解一下... 目录1. Haversine公式(中等精度,推荐通用场景)2. 球面余弦定理(简单但精度较低)3. Vincenty公式(高精度,

QT6中绘制UI的两种方法详解与示例代码

《QT6中绘制UI的两种方法详解与示例代码》Qt6提供了两种主要的UI绘制技术:​​QML(QtMeta-ObjectLanguage)​​和​​C++Widgets​​,这两种技术各有优势,适用于不... 目录一、QML 技术详解1.1 QML 简介1.2 QML 的核心概念1.3 QML 示例:简单按钮

Java进行日期解析与格式化的实现代码

《Java进行日期解析与格式化的实现代码》使用Java搭配ApacheCommonsLang3和Natty库,可以实现灵活高效的日期解析与格式化,本文将通过相关示例为大家讲讲具体的实践操作,需要的可以... 目录一、背景二、依赖介绍1. Apache Commons Lang32. Natty三、核心实现代

PyTorch高级特性与性能优化方式

《PyTorch高级特性与性能优化方式》:本文主要介绍PyTorch高级特性与性能优化方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、自动化机制1.自动微分机制2.动态计算图二、性能优化1.内存管理2.GPU加速3.多GPU训练三、分布式训练1.分布式数据

使用Python自动化生成PPT并结合LLM生成内容的代码解析

《使用Python自动化生成PPT并结合LLM生成内容的代码解析》PowerPoint是常用的文档工具,但手动设计和排版耗时耗力,本文将展示如何通过Python自动化提取PPT样式并生成新PPT,同时... 目录核心代码解析1. 提取 PPT 样式到 jsON关键步骤:代码片段:2. 应用 JSON 样式到

SpringBoot实现二维码生成的详细步骤与完整代码

《SpringBoot实现二维码生成的详细步骤与完整代码》如今,二维码的应用场景非常广泛,从支付到信息分享,二维码都扮演着重要角色,SpringBoot是一个非常流行的Java基于Spring框架的微... 目录一、环境搭建二、创建 Spring Boot 项目三、引入二维码生成依赖四、编写二维码生成代码五

Android NDK版本迭代与FFmpeg交叉编译完全指南

《AndroidNDK版本迭代与FFmpeg交叉编译完全指南》在Android开发中,使用NDK进行原生代码开发是一项常见需求,特别是当我们需要集成FFmpeg这样的多媒体处理库时,本文将深入分析A... 目录一、android NDK版本迭代分界线二、FFmpeg交叉编译关键注意事项三、完整编译脚本示例四

查看MySQL数据库版本的四种方法

《查看MySQL数据库版本的四种方法》查看MySQL数据库的版本信息可以通过多种方法实现,包括使用命令行工具、SQL查询语句和图形化管理工具等,以下是详细的步骤和示例代码,需要的朋友可以参考下... 目录方法一:使用命令行工具1. 使用 mysql 命令示例:方法二:使用 mysqladmin 命令示例:方