从零实现GPT【1】——BPE

2024-06-23 17:04
文章标签 实现 gpt bpe

本文主要是介绍从零实现GPT【1】——BPE,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

  • Embedding 的原理
  • 训练
  • 特殊 token 处理和保存
  • 编码
  • 解码
  • 完整代码

BPE,字节对编码

Embedding 的原理

image.png

  • 简单来说就是查表
# 解释embedding
from torch.nn import Embedding
import torch# 标准的正态分布初始化 也可以用均匀分布初始化
emb = Embedding(10, 32)
res = emb(torch.tensor([[0, 1, 2]
]))
print(res.shape)  # torch.Size([1, 3, 32]) [batch, seq_len, dim]
  • 自己实现
# 解释embedding
from torch.nn import Embedding, Parameter, Module
import torchclass MyEmbing(Module):def __init__(self, vocab_size, dim):super().__init__()self.emb_matrix = Parameter(torch.randn(vocab_size, dim))# Parameter标记self.emb_matrix需要被训练def forward(self, ids):return self.emb_matrix[ids]  # 取索引这个操作可以反向传播emb = MyEmbedding(10, 32)
res = emb(torch.tensor([[0, 1, 2]
]))
print(res.shape)  # torch.Size([1, 3, 32]) [batch, seq_len, dim]

训练

  1. 初始化词表,一般是 0-255 个 ASCII 编码
  2. 设置词表大小 Max_size
  3. 循环统计相邻两个字节的频率,取最高的合并后作为新的 token 加入到词表中
  4. 合并新的 token
  5. 重复 c、d,直到词表大小到Max_size 或者 没有更多的相邻 token
class BPETokenizer:def __init__(self):self.b2i = OrderedDict()  # bytes to idself.i2b = OrderedDict()  # id to bytes (b2i的反向映射)self.next_id = 0# special tokenself.sp_s2i = {}  # str to idself.sp_i2s = {}  # id to str# 相邻token统计def _pair_stats(self, tokens, stats):for i in range(len(tokens)-1):new_token = tokens[i]+tokens[i+1]if new_token not in stats:stats[new_token] = 0stats[new_token] += 1# 合并相邻tokendef _merge_pair(self, tokens, new_token):merged_tokens = []i = 0while i < len(tokens):if i+1 < len(tokens) and tokens[i]+tokens[i+1] == new_token:merged_tokens.append(tokens[i]+tokens[i+1])i += 2else:merged_tokens.append(tokens[i])i += 1return merged_tokensdef train(self, text_list, vocab_size):# 单字节是最基础的token,初始化词表for i in range(256):self.b2i[bytes([i])] = iself.next_id = 256# 语料转bytetokens_list = []for text in text_list:tokens = [bytes([b]) for b in text.encode('utf-8')]tokens_list.append(tokens)# 进度条progress = tqdm(total=vocab_size, initial=256)while True:# 词表足够大了,退出训练if self.next_id >= vocab_size:break# 统计相邻token频率stats = {}for tokens in tokens_list:self._pair_stats(tokens, stats)# 没有更多相邻token, 无法生成更多token,退出训练if not stats:break# 合并最高频的相邻token,作为新的token加入词表new_token = max(stats, key=stats.get)new_tokens_list = []for tokens in tokens_list:# self._merge_pair(tokens, new_token) -> listnew_tokens_list.append(self._merge_pair(tokens, new_token))tokens_list = new_tokens_list# new token加入词表self.b2i[new_token] = self.next_idself.next_id += 1# 刷新进度条progress.update(1)self.i2b = {v: k for k, v in self.b2i.items()}

特殊 token 处理和保存

  • 特殊 token 加到词表中
tokenizer = BPETokenizer()# 特殊token
tokenizer.add_special_tokens((['<|im_start|>', '<|im_end|>', '<|endoftext|>', '<|padding|>']))# 特殊token
def add_special_tokens(self, special_tokens):for token in special_tokens:if token not in self.sp_s2i:self.sp_s2i[token] = self.next_idself.sp_i2s[self.next_id] = tokenself.next_id += 1
  • 保存和加载
tokenizer.save('tokenizer.bin')def save(self, file):with open(file, 'wb') as fp:fp.write(pickle.dumps((self.b2i, self.sp_s2i, self.next_id)))# 还原
tokenizer = BPETokenizer()
tokenizer.load('tokenizer.bin')
print('vocab size:', tokenizer.vocab_size())def load(self, file):with open(file, 'rb') as fp:self.b2i, self.sp_s2i, self.next_id = pickle.loads(fp.read())self.i2b = {v: k for k, v in self.b2i.items()}self.sp_i2s = {v: k for k, v in self.sp_s2i.items()}

编码

  1. 分离特殊 token,用于直接映射特殊 token
  2. 进行编码,特殊 token 直接编码就好,普通 token 继续

while True:

  1. 对于普通 token, 统计相邻 token 频率
  2. 选择合并后的 id 最小的 pair token 合并(也就是优先合并短的)
  3. 重复 c d,直到没有合并的 pair token

不断循环 token,统计相邻 token 的频率,取 id 最小的 pair 进行合并,从而可以拼接成更大的 id

# 编码
ids, tokens = tokenizer.encode('<|im_start|>system\nyou are a helper assistant\n<|im_end|>\n<|im_start|>user\n今天的天气\n<|im_end|><|im_start|>assistant\n')
print('encode:', ids, tokens)
'''
encode: 
[300, 115, 121, 115, 116, 101, 109, 10, 121, 111, 117, 32, 97, 114, 276, 97, 32, 104, 101, 108, 112, 101, 293, 97, 115, 115, 105, 115, 116, 97, 110, 116, 10, 301, 10, 300, 117, 115, 101, 114, 10, 265, 138, 266, 169, 261, 266, 169, 230, 176, 148, 10, 301, 300, 97, 115, 115, 105, 115, 116, 97, 110, 116, 10] [b'<|im_start|>', b's', b'y', b's', b't', b'e', b'm', b'\n', b'y', b'o', b'u', b' ', b'a', b'r', b'e ', b'a', b' ', b'h', b'e', b'l', b'p', b'e', b'r ', b'a', b's', b's', b'i', b's', b't', b'a', b'n', b't', b'\n', b'<|im_end|>', b'\n', b'<|im_start|>', b'u', b's', b'e', b'r', b'\n', b'\xe4\xbb', b'\x8a', b'\xe5\xa4', b'\xa9', b'\xe7\x9a\x84', b'\xe5\xa4', b'\xa9', b'\xe6', b'\xb0', b'\x94', b'\n', b'<|im_end|>', b'<|im_start|>', b'a', b's', b's', b'i', b's', b't', b'a', b'n', b't', b'\n']
''''''
在Python中,Unicode字符通常以"\x"开头,后面跟着两个十六进制数字,或者以"\u"开头,后面跟着四个十六进制数字。
'''def encode(self, text):# 特殊token分离pattern = '('+'|'.join([re.escape(tok) for tok in self.sp_s2i])+')'splits = re.split(pattern, text)  # [ '<|im_start|>', 'user', '<||>' ]# 编码结果enc_ids = []enc_tokens = []for sub_text in splits:if sub_text in self.sp_s2i:  # 特殊token,直接对应idenc_ids.append(self.sp_s2i[sub_text])enc_tokens.append(sub_text.encode('utf-8'))else:tokens = [bytes([b]) for b in sub_text.encode('utf-8')]while True:# 统计相邻token频率stats = {}self._pair_stats(tokens, stats)# 选择合并后id最小的pair合并(也就是优先合并短的)new_token = Nonefor merge_token in stats:if merge_token in self.b2i and (new_token is None or self.b2i[merge_token] < self.b2i[new_token]):new_token = merge_token# 没有可以合并的pair,退出if new_token is None:break# 合并pairtokens = self._merge_pair(tokens, new_token)enc_ids.extend([self.b2i[tok] for tok in tokens])enc_tokens.extend(tokens)return enc_ids, enc_tokens

解码

# 解码
s = tokenizer.decode(ids)
print('decode:', s)
'''
decode: 
<|im_start|>system
you are a helper assistant
<|im_end|>
<|im_start|>user
今天的天气
<|im_end|><|im_start|>assistant
'''def decode(self, ids):bytes_list = []for id in ids:if id in self.sp_i2s:bytes_list.append(self.sp_i2s[id].encode('utf-8'))else:bytes_list.append(self.i2b[id])  # self.i2b 这里已经是字节了 id to byte return b''.join(bytes_list).decode('utf-8', errors='replace')

完整代码

from collections import OrderedDict
import pickle
import re
from tqdm import tqdm# Byte-Pair Encoding tokenizationclass BPETokenizer:def __init__(self):self.b2i = OrderedDict()  # bytes to idself.i2b = OrderedDict()  # id to bytes (b2i的反向映射)self.next_id = 0# special tokenself.sp_s2i = {}  # str to idself.sp_i2s = {}  # id to str# 相邻token统计def _pair_stats(self, tokens, stats):for i in range(len(tokens)-1):new_token = tokens[i]+tokens[i+1]if new_token not in stats:stats[new_token] = 0stats[new_token] += 1# 合并相邻tokendef _merge_pair(self, tokens, new_token):merged_tokens = []i = 0while i < len(tokens):if i+1 < len(tokens) and tokens[i]+tokens[i+1] == new_token:merged_tokens.append(tokens[i]+tokens[i+1])i += 2else:merged_tokens.append(tokens[i])i += 1return merged_tokensdef train(self, text_list, vocab_size):# 单字节是最基础的token,初始化词表for i in range(256):self.b2i[bytes([i])] = iself.next_id = 256# 语料转bytetokens_list = []for text in text_list:tokens = [bytes([b]) for b in text.encode('utf-8')]tokens_list.append(tokens)# 进度条progress = tqdm(total=vocab_size, initial=256)while True:# 词表足够大了,退出训练if self.next_id >= vocab_size:break# 统计相邻token频率stats = {}for tokens in tokens_list:self._pair_stats(tokens, stats)# 没有更多相邻token, 无法生成更多token,退出训练if not stats:break# 合并最高频的相邻token,作为新的token加入词表new_token = max(stats, key=stats.get)new_tokens_list = []for tokens in tokens_list:# self._merge_pair(tokens, new_token) -> listnew_tokens_list.append(self._merge_pair(tokens, new_token))tokens_list = new_tokens_list# new token加入词表self.b2i[new_token] = self.next_idself.next_id += 1# 刷新进度条progress.update(1)self.i2b = {v: k for k, v in self.b2i.items()}# 词表大小def vocab_size(self):return self.next_id# 词表def vocab(self):v = {}v.update(self.i2b)v.update({id: token.encode('utf-8')for id, token in self.sp_i2s.items()})return v# 特殊tokendef add_special_tokens(self, special_tokens):for token in special_tokens:if token not in self.sp_s2i:self.sp_s2i[token] = self.next_idself.sp_i2s[self.next_id] = tokenself.next_id += 1def encode(self, text):# 特殊token分离pattern = '('+'|'.join([re.escape(tok) for tok in self.sp_s2i])+')'splits = re.split(pattern, text)  # [ '<|im_start|>', 'user', '<||>' ]# 编码结果enc_ids = []enc_tokens = []for sub_text in splits:if sub_text in self.sp_s2i:  # 特殊token,直接对应idenc_ids.append(self.sp_s2i[sub_text])enc_tokens.append(sub_text.encode('utf-8'))else:tokens = [bytes([b]) for b in sub_text.encode('utf-8')]while True:# 统计相邻token频率stats = {}self._pair_stats(tokens, stats)# 选择合并后id最小的pair合并(也就是优先合并短的)new_token = Nonefor merge_token in stats:if merge_token in self.b2i and (new_token is None or self.b2i[merge_token] < self.b2i[new_token]):new_token = merge_token# 没有可以合并的pair,退出if new_token is None:break# 合并pairtokens = self._merge_pair(tokens, new_token)enc_ids.extend([self.b2i[tok] for tok in tokens])enc_tokens.extend(tokens)return enc_ids, enc_tokensdef decode(self, ids):bytes_list = []for id in ids:if id in self.sp_i2s:bytes_list.append(self.sp_i2s[id].encode('utf-8'))else:bytes_list.append(self.i2b[id])  # self.i2b 这里已经是字节了 id to bytereturn b''.join(bytes_list).decode('utf-8', errors='replace')def save(self, file):with open(file, 'wb') as fp:fp.write(pickle.dumps((self.b2i, self.sp_s2i, self.next_id)))def load(self, file):with open(file, 'rb') as fp:self.b2i, self.sp_s2i, self.next_id = pickle.loads(fp.read())self.i2b = {v: k for k, v in self.b2i.items()}self.sp_i2s = {v: k for k, v in self.sp_s2i.items()}if __name__ == '__main__':# 加载语料cn = open('dataset/train-cn.txt', 'r').read()en = open('dataset/train-en.txt', 'r').read()# 训练tokenizer = BPETokenizer()tokenizer.train(text_list=[cn, en], vocab_size=300)# 特殊tokentokenizer.add_special_tokens((['<|im_start|>', '<|im_end|>', '<|endoftext|>', '<|padding|>']))# 保存tokenizer.save('tokenizer.bin')# 还原tokenizer = BPETokenizer()tokenizer.load('tokenizer.bin')print('vocab size:', tokenizer.vocab_size())# 编码ids, tokens = tokenizer.encode('<|im_start|>system\nyou are a helper assistant\n<|im_end|>\n<|im_start|>user\n今天的天气\n<|im_end|><|im_start|>assistant\n')print('encode:', ids, tokens)# 解码s = tokenizer.decode(ids)print('decode:', s)# 打印词典print('vocab:', tokenizer.vocab())

这篇关于从零实现GPT【1】——BPE的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1087748

相关文章

Spring Boot 实现 IP 限流的原理、实践与利弊解析

《SpringBoot实现IP限流的原理、实践与利弊解析》在SpringBoot中实现IP限流是一种简单而有效的方式来保障系统的稳定性和可用性,本文给大家介绍SpringBoot实现IP限... 目录一、引言二、IP 限流原理2.1 令牌桶算法2.2 漏桶算法三、使用场景3.1 防止恶意攻击3.2 控制资源

springboot下载接口限速功能实现

《springboot下载接口限速功能实现》通过Redis统计并发数动态调整每个用户带宽,核心逻辑为每秒读取并发送限定数据量,防止单用户占用过多资源,确保整体下载均衡且高效,本文给大家介绍spring... 目录 一、整体目标 二、涉及的主要类/方法✅ 三、核心流程图解(简化) 四、关键代码详解1️⃣ 设置

Nginx 配置跨域的实现及常见问题解决

《Nginx配置跨域的实现及常见问题解决》本文主要介绍了Nginx配置跨域的实现及常见问题解决,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来... 目录1. 跨域1.1 同源策略1.2 跨域资源共享(CORS)2. Nginx 配置跨域的场景2.1

Python中提取文件名扩展名的多种方法实现

《Python中提取文件名扩展名的多种方法实现》在Python编程中,经常会遇到需要从文件名中提取扩展名的场景,Python提供了多种方法来实现这一功能,不同方法适用于不同的场景和需求,包括os.pa... 目录技术背景实现步骤方法一:使用os.path.splitext方法二:使用pathlib模块方法三

CSS实现元素撑满剩余空间的五种方法

《CSS实现元素撑满剩余空间的五种方法》在日常开发中,我们经常需要让某个元素占据容器的剩余空间,本文将介绍5种不同的方法来实现这个需求,并分析各种方法的优缺点,感兴趣的朋友一起看看吧... css实现元素撑满剩余空间的5种方法 在日常开发中,我们经常需要让某个元素占据容器的剩余空间。这是一个常见的布局需求

HTML5 getUserMedia API网页录音实现指南示例小结

《HTML5getUserMediaAPI网页录音实现指南示例小结》本教程将指导你如何利用这一API,结合WebAudioAPI,实现网页录音功能,从获取音频流到处理和保存录音,整个过程将逐步... 目录1. html5 getUserMedia API简介1.1 API概念与历史1.2 功能与优势1.3

Java实现删除文件中的指定内容

《Java实现删除文件中的指定内容》在日常开发中,经常需要对文本文件进行批量处理,其中,删除文件中指定内容是最常见的需求之一,下面我们就来看看如何使用java实现删除文件中的指定内容吧... 目录1. 项目背景详细介绍2. 项目需求详细介绍2.1 功能需求2.2 非功能需求3. 相关技术详细介绍3.1 Ja

使用Python和OpenCV库实现实时颜色识别系统

《使用Python和OpenCV库实现实时颜色识别系统》:本文主要介绍使用Python和OpenCV库实现的实时颜色识别系统,这个系统能够通过摄像头捕捉视频流,并在视频中指定区域内识别主要颜色(红... 目录一、引言二、系统概述三、代码解析1. 导入库2. 颜色识别函数3. 主程序循环四、HSV色彩空间详解

PostgreSQL中MVCC 机制的实现

《PostgreSQL中MVCC机制的实现》本文主要介绍了PostgreSQL中MVCC机制的实现,通过多版本数据存储、快照隔离和事务ID管理实现高并发读写,具有一定的参考价值,感兴趣的可以了解一下... 目录一 MVCC 基本原理python1.1 MVCC 核心概念1.2 与传统锁机制对比二 Postg

SpringBoot整合Flowable实现工作流的详细流程

《SpringBoot整合Flowable实现工作流的详细流程》Flowable是一个使用Java编写的轻量级业务流程引擎,Flowable流程引擎可用于部署BPMN2.0流程定义,创建这些流程定义的... 目录1、流程引擎介绍2、创建项目3、画流程图4、开发接口4.1 Java 类梳理4.2 查看流程图4