深度学习之生成唐诗案例(Pytorch版)

2023-11-21 12:20

本文主要是介绍深度学习之生成唐诗案例(Pytorch版),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

主要思路:

对于唐诗生成来说,我们定义一个"S" 和 "E"作为开始和结束。

 示例的唐诗大概有40000多首,

首先数据预处理,将唐诗加载到内存,生成对应的word2idx、idx2word、以及唐诗按顺序的字序列。

Dataset_Dataloader.py
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoaderdef deal_tangshi():with open("poems.txt", "r", encoding="utf-8") as fr:lines = fr.read().strip().split("\n")tangshis = []for line in lines:splits = line.split(":")if len(splits) != 2:continuetangshis.append("S" + splits[1] + "E")word2idx = {"S": 0, "E": 1}word2idx_count = 2tangshi_ids = []for tangshi in tangshis:for word in tangshi:if word not in word2idx:word2idx[word] = word2idx_countword2idx_count += 1idx2word = {idx: w for w, idx in word2idx.items()}for tangshi in tangshis:tangshi_ids.extend([word2idx[w] for w in tangshi])return word2idx, idx2word, tangshis, word2idx_count, tangshi_idsword2idx, idx2word, tangshis, word2idx_count, tangshi_ids = deal_tangshi()class TangShiDataset(Dataset):def __init__(self, tangshi_ids, num_chars):# 语料数据self.tangshi_ids = tangshi_ids# 语料长度self.num_chars = num_chars# 词的数量self.word_count = len(self.tangshi_ids)# 句子数量self.number = self.word_count // self.num_charsdef __len__(self):return self.numberdef __getitem__(self, idx):# 修正索引值到: [0, self.word_count - 1]start = min(max(idx, 0), self.word_count - self.num_chars - 2)x = self.tangshi_ids[start: start + self.num_chars]y = self.tangshi_ids[start + 1: start + 1 + self.num_chars]return torch.tensor(x), torch.tensor(y)def __test_Dataset():dataset = TangShiDataset(tangshi_ids, 8)x, y = dataset[0]print(x, y)if __name__ == '__main__':# deal_tangshi()__test_Dataset()
TangShiModel.py:唐诗的模型
import torch
import torch.nn as nn
from Dataset_Dataloader import *
import torch.nn.functional as Fclass TangShiRNN(nn.Module):def __init__(self, vocab_size):super().__init__()# 初始化词嵌入层self.ebd = nn.Embedding(vocab_size, 128)# 循环网络层self.rnn = nn.RNN(128, 128, 1)# 输出层self.out = nn.Linear(128, vocab_size)def forward(self, inputs, hidden):embed = self.ebd(inputs)# 正则化层embed = F.dropout(embed, p=0.2)output, hidden = self.rnn(embed.transpose(0, 1), hidden)# 正则化层embed = F.dropout(output, p=0.2)output = self.out(output.squeeze())return output, hiddendef init_hidden(self):return torch.zeros(1, 64, 128)

 main.py:

import timeimport torchfrom Dataset_Dataloader import *
from TangShiModel import *
import torch.optim as optim
from tqdm import tqdmdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")def train():dataset = TangShiDataset(tangshi_ids, 128)epochs = 100model = TangShiRNN(word2idx_count).to(device)criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=1e-3)for idx in range(epochs):dataloader = DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True)start_time = time.time()total_loss = 0total_num = 0total_correct = 0total_correct_num = 0hidden = model.init_hidden()for x, y in tqdm(dataloader):x = x.to(device)y = y.to(device)# 隐藏状态hidden = model.init_hidden()hidden = hidden.to(device)# 模型计算output, hidden = model(x, hidden)# print(output.shape)# print(y.shape)# 计算损失loss = criterion(output.permute(1, 2, 0), y)# 梯度清零optimizer.zero_grad()# 反向传播loss.backward()# 参数更新optimizer.step()total_loss += loss.sum().item()total_num += len(y)total_correct_num += y.shape[0] * y.shape[1]# print(output.shape)total_correct += (torch.argmax(output.permute(1, 0, 2), dim=-1) == y).sum().item()print("epoch : %d average_loss : %.3f average_correct : %.3f use_time : %ds" %(idx + 1, total_loss / total_num, total_correct / total_correct_num, time.time() - start_time))torch.save(model.state_dict(), f"./modules/tangshi_module_{idx + 1}.bin")if __name__ == '__main__':train()

predict.py:

import torch
import torch.nn as nn
from Dataset_Dataloader import *
from TangShiModel import *device = torch.device("cuda" if torch.cuda.is_available() else "cpu")def predict():model = TangShiRNN(word2idx_count)model.load_state_dict(torch.load("./modules/tangshi_module_100.bin", map_location=torch.device('cpu')))model.eval()hidden = torch.zeros(1, 1, 128)start_word = input("输入第一个字:")flag = Nonetangshi_strs = []while True:if not flag:outputs, hidden = model(torch.tensor([[word2idx["S"]]], dtype=torch.long), hidden)tangshi_strs.append("S")flag = Trueelse:tangshi_strs.append(start_word)outputs, hidden = model(torch.tensor([[word2idx[start_word]]], dtype=torch.long), hidden)top_i = torch.argmax(outputs, dim=-1)if top_i.item() == word2idx["E"]:breakprint(top_i)start_word = idx2word[top_i.item()]print(tangshi_strs)if __name__ == '__main__':predict()

完整代码如下:

https://github.com/STZZ-1992/tangshi-generator.giticon-default.png?t=N7T8https://github.com/STZZ-1992/tangshi-generator.git

这篇关于深度学习之生成唐诗案例(Pytorch版)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/402341

相关文章

Django HTTPResponse响应体中返回openpyxl生成的文件过程

《DjangoHTTPResponse响应体中返回openpyxl生成的文件过程》Django返回文件流时需通过Content-Disposition头指定编码后的文件名,使用openpyxl的sa... 目录Django返回文件流时使用指定文件名Django HTTPResponse响应体中返回openp

深度解析Spring Security 中的 SecurityFilterChain核心功能

《深度解析SpringSecurity中的SecurityFilterChain核心功能》SecurityFilterChain通过组件化配置、类型安全路径匹配、多链协同三大特性,重构了Spri... 目录Spring Security 中的SecurityFilterChain深度解析一、Security

深度解析Nginx日志分析与499状态码问题解决

《深度解析Nginx日志分析与499状态码问题解决》在Web服务器运维和性能优化过程中,Nginx日志是排查问题的重要依据,本文将围绕Nginx日志分析、499状态码的成因、排查方法及解决方案展开讨论... 目录前言1. Nginx日志基础1.1 Nginx日志存放位置1.2 Nginx日志格式2. 499

RabbitMQ消费端单线程与多线程案例讲解

《RabbitMQ消费端单线程与多线程案例讲解》文章解析RabbitMQ消费端单线程与多线程处理机制,说明concurrency控制消费者数量,max-concurrency控制最大线程数,prefe... 目录 一、基础概念详细解释:举个例子:✅ 单消费者 + 单线程消费❌ 单消费者 + 多线程消费❌ 多

PyTorch中的词嵌入层(nn.Embedding)详解与实战应用示例

《PyTorch中的词嵌入层(nn.Embedding)详解与实战应用示例》词嵌入解决NLP维度灾难,捕捉语义关系,PyTorch的nn.Embedding模块提供灵活实现,支持参数配置、预训练及变长... 目录一、词嵌入(Word Embedding)简介为什么需要词嵌入?二、PyTorch中的nn.Em

python生成随机唯一id的几种实现方法

《python生成随机唯一id的几种实现方法》在Python中生成随机唯一ID有多种方法,根据不同的需求场景可以选择最适合的方案,文中通过示例代码介绍的非常详细,需要的朋友们下面随着小编来一起学习学习... 目录方法 1:使用 UUID 模块(推荐)方法 2:使用 Secrets 模块(安全敏感场景)方法

MySql基本查询之表的增删查改+聚合函数案例详解

《MySql基本查询之表的增删查改+聚合函数案例详解》本文详解SQL的CURD操作INSERT用于数据插入(单行/多行及冲突处理),SELECT实现数据检索(列选择、条件过滤、排序分页),UPDATE... 目录一、Create1.1 单行数据 + 全列插入1.2 多行数据 + 指定列插入1.3 插入否则更

Python通用唯一标识符模块uuid使用案例详解

《Python通用唯一标识符模块uuid使用案例详解》Pythonuuid模块用于生成128位全局唯一标识符,支持UUID1-5版本,适用于分布式系统、数据库主键等场景,需注意隐私、碰撞概率及存储优... 目录简介核心功能1. UUID版本2. UUID属性3. 命名空间使用场景1. 生成唯一标识符2. 数

PostgreSQL的扩展dict_int应用案例解析

《PostgreSQL的扩展dict_int应用案例解析》dict_int扩展为PostgreSQL提供了专业的整数文本处理能力,特别适合需要精确处理数字内容的搜索场景,本文给大家介绍PostgreS... 目录PostgreSQL的扩展dict_int一、扩展概述二、核心功能三、安装与启用四、字典配置方法

深度解析Java DTO(最新推荐)

《深度解析JavaDTO(最新推荐)》DTO(DataTransferObject)是一种用于在不同层(如Controller层、Service层)之间传输数据的对象设计模式,其核心目的是封装数据,... 目录一、什么是DTO?DTO的核心特点:二、为什么需要DTO?(对比Entity)三、实际应用场景解析