当前位置: 首页 > news >正文

深度学习之生成唐诗案例(Pytorch版)

主要思路:

对于唐诗生成来说,我们定义一个"S" 和 "E"作为开始和结束。

 示例的唐诗大概有40000多首,

首先数据预处理,将唐诗加载到内存,生成对应的word2idx、idx2word、以及唐诗按顺序的字序列。

Dataset_Dataloader.py
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoaderdef deal_tangshi():with open("poems.txt", "r", encoding="utf-8") as fr:lines = fr.read().strip().split("\n")tangshis = []for line in lines:splits = line.split(":")if len(splits) != 2:continuetangshis.append("S" + splits[1] + "E")word2idx = {"S": 0, "E": 1}word2idx_count = 2tangshi_ids = []for tangshi in tangshis:for word in tangshi:if word not in word2idx:word2idx[word] = word2idx_countword2idx_count += 1idx2word = {idx: w for w, idx in word2idx.items()}for tangshi in tangshis:tangshi_ids.extend([word2idx[w] for w in tangshi])return word2idx, idx2word, tangshis, word2idx_count, tangshi_idsword2idx, idx2word, tangshis, word2idx_count, tangshi_ids = deal_tangshi()class TangShiDataset(Dataset):def __init__(self, tangshi_ids, num_chars):# 语料数据self.tangshi_ids = tangshi_ids# 语料长度self.num_chars = num_chars# 词的数量self.word_count = len(self.tangshi_ids)# 句子数量self.number = self.word_count // self.num_charsdef __len__(self):return self.numberdef __getitem__(self, idx):# 修正索引值到: [0, self.word_count - 1]start = min(max(idx, 0), self.word_count - self.num_chars - 2)x = self.tangshi_ids[start: start + self.num_chars]y = self.tangshi_ids[start + 1: start + 1 + self.num_chars]return torch.tensor(x), torch.tensor(y)def __test_Dataset():dataset = TangShiDataset(tangshi_ids, 8)x, y = dataset[0]print(x, y)if __name__ == '__main__':# deal_tangshi()__test_Dataset()
TangShiModel.py:唐诗的模型
import torch
import torch.nn as nn
from Dataset_Dataloader import *
import torch.nn.functional as Fclass TangShiRNN(nn.Module):def __init__(self, vocab_size):super().__init__()# 初始化词嵌入层self.ebd = nn.Embedding(vocab_size, 128)# 循环网络层self.rnn = nn.RNN(128, 128, 1)# 输出层self.out = nn.Linear(128, vocab_size)def forward(self, inputs, hidden):embed = self.ebd(inputs)# 正则化层embed = F.dropout(embed, p=0.2)output, hidden = self.rnn(embed.transpose(0, 1), hidden)# 正则化层embed = F.dropout(output, p=0.2)output = self.out(output.squeeze())return output, hiddendef init_hidden(self):return torch.zeros(1, 64, 128)

 main.py:

import timeimport torchfrom Dataset_Dataloader import *
from TangShiModel import *
import torch.optim as optim
from tqdm import tqdmdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")def train():dataset = TangShiDataset(tangshi_ids, 128)epochs = 100model = TangShiRNN(word2idx_count).to(device)criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=1e-3)for idx in range(epochs):dataloader = DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True)start_time = time.time()total_loss = 0total_num = 0total_correct = 0total_correct_num = 0hidden = model.init_hidden()for x, y in tqdm(dataloader):x = x.to(device)y = y.to(device)# 隐藏状态hidden = model.init_hidden()hidden = hidden.to(device)# 模型计算output, hidden = model(x, hidden)# print(output.shape)# print(y.shape)# 计算损失loss = criterion(output.permute(1, 2, 0), y)# 梯度清零optimizer.zero_grad()# 反向传播loss.backward()# 参数更新optimizer.step()total_loss += loss.sum().item()total_num += len(y)total_correct_num += y.shape[0] * y.shape[1]# print(output.shape)total_correct += (torch.argmax(output.permute(1, 0, 2), dim=-1) == y).sum().item()print("epoch : %d average_loss : %.3f average_correct : %.3f use_time : %ds" %(idx + 1, total_loss / total_num, total_correct / total_correct_num, time.time() - start_time))torch.save(model.state_dict(), f"./modules/tangshi_module_{idx + 1}.bin")if __name__ == '__main__':train()

predict.py:

import torch
import torch.nn as nn
from Dataset_Dataloader import *
from TangShiModel import *device = torch.device("cuda" if torch.cuda.is_available() else "cpu")def predict():model = TangShiRNN(word2idx_count)model.load_state_dict(torch.load("./modules/tangshi_module_100.bin", map_location=torch.device('cpu')))model.eval()hidden = torch.zeros(1, 1, 128)start_word = input("输入第一个字:")flag = Nonetangshi_strs = []while True:if not flag:outputs, hidden = model(torch.tensor([[word2idx["S"]]], dtype=torch.long), hidden)tangshi_strs.append("S")flag = Trueelse:tangshi_strs.append(start_word)outputs, hidden = model(torch.tensor([[word2idx[start_word]]], dtype=torch.long), hidden)top_i = torch.argmax(outputs, dim=-1)if top_i.item() == word2idx["E"]:breakprint(top_i)start_word = idx2word[top_i.item()]print(tangshi_strs)if __name__ == '__main__':predict()

完整代码如下:

https://github.com/STZZ-1992/tangshi-generator.giticon-default.png?t=N7T8https://github.com/STZZ-1992/tangshi-generator.git

相关文章:

深度学习之生成唐诗案例(Pytorch版)

主要思路: 对于唐诗生成来说,我们定义一个"S" 和 "E"作为开始和结束。 示例的唐诗大概有40000多首, 首先数据预处理,将唐诗加载到内存,生成对应的word2idx、idx2word、以及唐诗按顺序的字序列。…...

算法设计与分析算法实现——删数问题

通过棋盘输入一个高精度的正整数n(n的有效位数<=240)去掉其中任意s个数字后,剩下的数字按原左右次序将组成一个新的正整数。变成对给定的n和s,寻找一种方案,使得剩下的数字组成的新数最小。 输入:n,s 输出:最后剩下的最小数 输入实例: 178543 4 输出示例: 13 首先…...

基于Vue+SpringBoot的超市账单管理系统 开源项目

项目编号&#xff1a; S 032 &#xff0c;文末获取源码。 \color{red}{项目编号&#xff1a;S032&#xff0c;文末获取源码。} 项目编号&#xff1a;S032&#xff0c;文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块三、系统设计3.1 总体设计3.2 前端设计3…...

【Linux 内核分析课程作业 1】mmap 实现一个 key-valueMap

作业一 功能要求利用 mmap(虚拟内存映射文件) 机制实现一个带持久化能力的 key-valueMap 系统&#xff0c;至少支持单机单进程访问。(可能用到的 linux API: mmap、msync、mremap、munmap、ftruncate、fallocate 等) 电子版提交方式&#xff1a; 2023 年 11 月 20 日 18:00 前通…...

docker compose使用教程(docker-compose教程)

文章目录 Docker Compose 使用教程安装Docker ComposeLinuxWindows 和 macOS Docker Compose 基础Compose 文件结构配置服务网络与卷 Docker Compose 命令启动服务停止服务查看服务状态查看日志缩放服务 多环境部署健康检查与依赖管理Docker Compose 最佳实践常见问题解析如何覆…...

印刷企业实施MES管理系统需要哪些硬件设施

随着科技的飞速发展&#xff0c;印刷行业正面临着前所未有的挑战和机遇。为了提高生产效率&#xff0c;降低成本&#xff0c;并增强市场竞争力&#xff0c;越来越多的印刷企业开始实施制造执行系统&#xff08;MES&#xff09;管理系统。本文将重点讨论印刷企业在实施MES管理系…...

Java JSON字符串替换其中对应的值

代码&#xff1a; public static void main(String[] args) { // String theData crmScene.getData();String theData "[{\"type\":1,\"values\":[\"审批中\",\"未交付\"],\"name\":\"status\"}]"…...

Android VSYNC发展历程

0 前言 安卓直到android-4.1.1_r1才首次引入VSYNC实现&#xff0c;然后逐步演进到android-4.4才得以完善&#xff0c;并在android-11、12后继续大改。 1 尚未引入 android-4.0.4_r2.1之前尚未引入VSYNC[1]&#xff0c;SurfaceFlinger被实现为一个线程&#xff0c;通过睡眠来实…...

外呼系统作用和优势有哪些okcc,ai源码

随着外呼系统诞生&#xff0c;普通中小企业也开始广泛使用&#xff0c;系统给他们带来更多的服务方式和提升业绩的可能。然而&#xff0c;许多企业对外呼系统的理解相对片面和简单&#xff0c;认为它是一个成本中心&#xff0c;需要继续投入人力和使用。事实上&#xff0c;外呼…...

智元机器人岗位内推

Hi there &#x1f44b; 智元机器人招聘信息 官网&#xff1a; https://www.agibot.com/ 内推联系 邮箱&#xff1a;jiejietopgmail.com微信&#xff1a;yij1799 高级C软件工程师&#xff08;上海&#xff09; 岗位职责&#xff1a; 开发自研机器人操作系统&#xff0c;…...

el-popover和el-tooltip样式修改(普通的组件样式修改方法,对popover是不生效的)

第一步:‘popper-class’=‘popperClass’ //添加类名 <el-table-column label="审核状态" align="center"><template slot-scope="scope"><el-popoverpopper-class="addformPanel"placement="top"width=&…...

【AI实用技巧】GPT写sql统计语句

编写sql的统计语句是一项复杂的任务&#xff0c;特别是涉及多表的情况下。但有了GPT的帮助&#xff0c;一切变得轻松愉快。 AI7号 - 最强人工智能&#xff08;GPT&#xff09;中文版https://ai7.pro/s/9v2um 举例说明 有表结构如下&#xff1a; users(user_id, name) bills(…...

LeetCode(31)无重复字符的最长子串【滑动窗口】【中等】

目录 1.题目2.答案3.提交结果截图 链接&#xff1a; 无重复字符的最长子串 1.题目 给定一个字符串 s &#xff0c;请你找出其中不含有重复字符的 最长子串 的长度。 示例 1: 输入: s "abcabcbb" 输出: 3 解释: 因为无重复字符的最长子串是 "abc"&…...

天猫超市电商营销系统:无代码开发实现API连接集成

无代码开发实现天猫超市与电商系统的高效连接 天猫超市&#xff0c;作为天猫推出的网络零售超市&#xff0c;为广大网购消费者提供了一站式的购物服务。而通过无代码开发的方式&#xff0c;天猫超市能够实现与各种电商系统的连接和集成&#xff0c;这种方式无需进行繁琐的API开…...

element表格分页+数据过滤筛选

目录 前言效果展示分页效果展示搜索效果展示 代码分析分页功能过滤数据功能 全部代码 前言 在el-element的标签里的tableData数据过多时&#xff0c;会导致表格页面操作卡顿。为解决这一问题&#xff0c;有以下解决方法&#xff1a; 分页加载&#xff1a; 将大量数据进行分页&…...

小程序判断是否授权位置信息和手动授权

文章目录 概要微信小程序的&#xff0c;使用的是高德地图 概要 当用户来到页面之后就会弹出是否授权弹框&#xff0c;但是如果第一次关闭之后&#xff0c;下一次再过来的话页面的授权弹框就不出现了&#xff0c;针对于这种情况写了一个方法 微信小程序的&#xff0c;使用的是…...

2023年亚太杯数学建模亚太赛A题思路解析+代码+论文

下文包含&#xff1a;2023年亚太杯数学建模亚太赛A题思路解析代码参考论文等及如何准备数学建模竞赛&#xff08;23号比赛开始后逐步更新&#xff09; C君将会第一时间发布选题建议、所有题目的思路解析、相关代码、参考文献、参考论文等多项资料&#xff0c;帮助大家取得好成…...

【Android】画面卡顿优化列表流畅度六(终篇)

上一篇&#xff1a; 【Android】画面卡顿优化列表流畅度五之下拉刷新上拉加载更多组件RefreshLayout修改 场景回顾&#xff1a; 业务经过一年半左右的运行后&#xff0c;出现了明显的列表卡顿情况&#xff1b;于是开始着手进行列表卡顿优化。目前的情况是&#xff1a; 网络图…...

一文了解:离散型制造业轻量化MES解决方案

离散型制造业的特点 离散型生产行业主要是通过对原材料物理形状的改变、组装&#xff0c;成为产品&#xff0c;使其增值。典型的离散型行业包括&#xff1a;机械、电子、航空、汽车等行业。这些企业既有按订单生产&#xff08;MTO&#xff09;&#xff0c;也有按照库存生产&am…...

《云计算:云端协同,智慧互联》

《云计算&#xff1a;云端协同&#xff0c;智慧互联》 云计算&#xff0c;这个科技领域中的热门词汇&#xff0c;正在逐渐改变我们的生活方式。它像一座座无形的桥梁&#xff0c;将世界各地的设备、数据、应用紧密连接在一起&#xff0c;实现了云端协同&#xff0c;智慧互联的愿…...

[2025CVPR]DeepVideo-R1:基于难度感知回归GRPO的视频强化微调框架详解

突破视频大语言模型推理瓶颈,在多个视频基准上实现SOTA性能 一、核心问题与创新亮点 1.1 GRPO在视频任务中的两大挑战 ​安全措施依赖问题​ GRPO使用min和clip函数限制策略更新幅度,导致: 梯度抑制:当新旧策略差异过大时梯度消失收敛困难:策略无法充分优化# 传统GRPO的梯…...

React 第五十五节 Router 中 useAsyncError的使用详解

前言 useAsyncError 是 React Router v6.4 引入的一个钩子&#xff0c;用于处理异步操作&#xff08;如数据加载&#xff09;中的错误。下面我将详细解释其用途并提供代码示例。 一、useAsyncError 用途 处理异步错误&#xff1a;捕获在 loader 或 action 中发生的异步错误替…...

RocketMQ延迟消息机制

两种延迟消息 RocketMQ中提供了两种延迟消息机制 指定固定的延迟级别 通过在Message中设定一个MessageDelayLevel参数&#xff0c;对应18个预设的延迟级别指定时间点的延迟级别 通过在Message中设定一个DeliverTimeMS指定一个Long类型表示的具体时间点。到了时间点后&#xf…...

uni-app学习笔记二十二---使用vite.config.js全局导入常用依赖

在前面的练习中&#xff0c;每个页面需要使用ref&#xff0c;onShow等生命周期钩子函数时都需要像下面这样导入 import {onMounted, ref} from "vue" 如果不想每个页面都导入&#xff0c;需要使用node.js命令npm安装unplugin-auto-import npm install unplugin-au…...

Python爬虫实战:研究feedparser库相关技术

1. 引言 1.1 研究背景与意义 在当今信息爆炸的时代,互联网上存在着海量的信息资源。RSS(Really Simple Syndication)作为一种标准化的信息聚合技术,被广泛用于网站内容的发布和订阅。通过 RSS,用户可以方便地获取网站更新的内容,而无需频繁访问各个网站。 然而,互联网…...

关于iview组件中使用 table , 绑定序号分页后序号从1开始的解决方案

问题描述&#xff1a;iview使用table 中type: "index",分页之后 &#xff0c;索引还是从1开始&#xff0c;试过绑定后台返回数据的id, 这种方法可行&#xff0c;就是后台返回数据的每个页面id都不完全是按照从1开始的升序&#xff0c;因此百度了下&#xff0c;找到了…...

江苏艾立泰跨国资源接力:废料变黄金的绿色供应链革命

在华东塑料包装行业面临限塑令深度调整的背景下&#xff0c;江苏艾立泰以一场跨国资源接力的创新实践&#xff0c;重新定义了绿色供应链的边界。 跨国回收网络&#xff1a;废料变黄金的全球棋局 艾立泰在欧洲、东南亚建立再生塑料回收点&#xff0c;将海外废弃包装箱通过标准…...

ElasticSearch搜索引擎之倒排索引及其底层算法

文章目录 一、搜索引擎1、什么是搜索引擎?2、搜索引擎的分类3、常用的搜索引擎4、搜索引擎的特点二、倒排索引1、简介2、为什么倒排索引不用B+树1.创建时间长,文件大。2.其次,树深,IO次数可怕。3.索引可能会失效。4.精准度差。三. 倒排索引四、算法1、Term Index的算法2、 …...

【OSG学习笔记】Day 16: 骨骼动画与蒙皮(osgAnimation)

骨骼动画基础 骨骼动画是 3D 计算机图形中常用的技术&#xff0c;它通过以下两个主要组件实现角色动画。 骨骼系统 (Skeleton)&#xff1a;由层级结构的骨头组成&#xff0c;类似于人体骨骼蒙皮 (Mesh Skinning)&#xff1a;将模型网格顶点绑定到骨骼上&#xff0c;使骨骼移动…...

浪潮交换机配置track检测实现高速公路收费网络主备切换NQA

浪潮交换机track配置 项目背景高速网络拓扑网络情况分析通信线路收费网络路由 收费汇聚交换机相应配置收费汇聚track配置 项目背景 在实施省内一条高速公路时遇到的需求&#xff0c;本次涉及的主要是收费汇聚交换机的配置&#xff0c;浪潮网络设备在高速项目很少&#xff0c;通…...