当前位置: 首页 > news >正文

NLP文本自动生成介绍及Char-RNN中文文本自动生成训练demo

前言

文本自动生成是自然语言处理领域的一个重要研究方向,实现文本自动生成也是人工智能走向成熟的一个重要标志。文本自动生成技术极具应用前景。
例如,文本自动生成技术可以应用于智能问答与对话、机器翻译等系统,实现更加智能和自然的人机交互;也可以通过文本自动生成系统替代编辑实现新闻的自动撰写与发布,最终将有可能颠覆新闻出版行业;该项技术甚至可以用来帮助学者进行学术论文撰写,进而改变科研创作模式。

按照不同的输入划分,文本自动生成可包括文本到文本的生成(text-to-text generation)、意义到文本的生成(meaning-to-text generation)、数据到文本的生成(data-to-text generation)以及图像到文本的生成(image-to-text generation)等。上述每项技术均极具挑战性,在自然语言处理与人工智能领域均有相当多的前沿研究,近几年业界已产生了若干具有国际影响力的成果与应用。

本文主要简单介绍文本生成中最为成熟的领域的——文本到文本的生成的一些常用算法,最后实操部分则是使用中文数据训练Char-RNN
模型生成中文文本。

文本生成算法

首先,啥是文本生成,简单来说,就是输入一段文本,经过自然语言模型之后,生成一段新的文本,如下图所示
在这里插入图片描述
这便是文本自动补全场景下的文本生成,这种应用场景多见于智能问答与对话;如果是机器翻译场景下的文本生成,那模型输入则是需要被翻译的文本,模型输出是翻译后的文本语言;同样的,文本摘要,则更好理解,就是输入一段文本,模型输出这段文本的概括文本。

在很长一段时间里,文本生成主要都基于Seq2Seq模型,所谓的Seq2Seq模型就是使用上一个时刻的值来预测下一个时刻的值,两个常用的模型是GRU和LSTM。然而,用 RNN 生成的文本远非也会有一些问题,比如,RNN模型有时候会输出一些莫名其妙的文本,有时还包括一些基本的拼写错误,而其中一个时刻的错误输出,则会让整段文本的输出变得不可用。此外,在推理过程中的无法并行化也是RNN模型在处理序列数据时的一个致命缺陷。
Seq2seq model
后来,为了解决RNN模型的缺陷,谷歌在2017年发布了一篇经典文章"Attention Is All You Need", 文章中提出了Transformer模型。Transformer是包含了自注意力机制、全连接层的同样带有编码器和解码器的全新的网络结构,同时由于Transformer模型中没有包含RNN网络结构,使其可以并行运算,大大提升了模型的训练和推理时间。当然,模型的参数量也比RNN提升了数倍,模型拟合能力也得到了大大的提升。

: Transformer-model
再后来随着深度学习领域的发展,业界提出了更多更大的模型来解决NLP领域的问题,当然,这里也包括了文本生成这一领域。随着BERT、GPT-2、GPT-3等等大模型的提出,使得文本生成的开发可以使用少量场景数据在大模型的基础上做fine-tuned,这样也可以得到远超过简单的Seq2Seq模型的效果。
在这里插入图片描述
当然,文本生成领域内容太多,所涉及的算法也很复杂,笔者这里提到的只是一些常规的模型和技术方法,对其他的模型和算法感兴趣的可以参考文末的参考文章继续深入阅读。

中文char-rnn文本生成训练demo

char-rnn之于文本生成领域的地位,与手写mnist之一图像分类领域地位一样,可以说,就是一个入门级别的模型,就是使用RNN模型,输入一个字符,输出一个字符,最后要么达到字数限制,要么输出结束字符,这样就完成文本生成的任务。

这里的中文char-rnn,训练数据使用的中文小说,使用结巴分词将语料进行预处理,然后将分词的结果再进行Embedding编码。

数据预处理

whole = open('text/白夜行.txt', encoding='utf-8').read()
all_words = list(jieba.cut(whole, cut_all=False))  # jieba分词
words = sorted(list(set(all_words)))
word_indices = dict((word, words.index(word)) for word in words)maxlen = 30
epoch_num = 100 
class TextTensorDataset(Dataset):def __init__(self, all_words, maxlen, word_indices):sentences = []next_word = []for i in range(0, len(all_words) - maxlen):sentences.append(all_words[i: i + maxlen])next_word.append(all_words[i + maxlen])print('提取的句子总数:', len(sentences))self.inputs = np.zeros((len(sentences), maxlen), dtype='float32') # 先将每个inputs切成30个词的句子列表,然后将句子中的词转化成index索引self.labels = np.zeros((len(sentences)), dtype='float32')for i, sentence in enumerate(sentences):for t, word in enumerate(sentence):self.inputs[i, t] = word_indices[word]self.labels[i] = word_indices[next_word[i]]def __getitem__(self, item):# x = np.expand_dims(self.inputs[item], axis=0)# y = np.expand_dims(self.labels[item], axis=0)return self.inputs[item], self.labels[item]def __len__(self):return len(self.inputs)

模型定义

class LSTM(torch.nn.Module):def __init__(self, hidden_size1, hidden_size2, vocab_size, input_size, num_layers):super().__init__()self.embed = torch.nn.Embedding(vocab_size, input_size, max_norm = 1)self.lstm1 = torch.nn.LSTM(input_size, hidden_size1, num_layers, batch_first=True, bidirectional=True)self.lstm2 = torch.nn.LSTM(hidden_size1*2, hidden_size2, num_layers, batch_first=True, bidirectional=True)self.dropout = torch.nn.Dropout(0.1)self.line = torch.nn.Linear(hidden_size2 * maxlen * 2, vocab_size)self.softmax = torch.nn.Softmax(dim=1)def forward(self, x):x = self.embed(x)      output1, _ = self.lstm1(x) output, _ = self.lstm2(output1) out_d_reshaped = output.reshape(output.shape[0], (output.shape[1] * output.shape[2]) )line_o = self.line(out_d_reshaped)pred = self.softmax(line_o)#print(pred.shape)return pred

模型使用了两个双向的LSTM,然后再接了一个全连接层,整体都比较简单,没有什么可以详细阐述的

模型训练

hidden_size1, hidden_size2, vocab_size, input_size, num_layers = 256, 128, len(words), 128, 2model = LSTM(hidden_size1, hidden_size2, vocab_size, input_size, num_layers).to(device)loss_function = torch.nn.NLLLoss().to(device)optimizer = torch.optim.RMSprop(model.parameters(), lr=3e-3)mydataset = TextTensorDataset(all_words, maxlen, word_indices)train_loader = DataLoader(mydataset, batch_size=1024, shuffle=True)# training
model.train()
h_state = Nonefor epoch in range(epoch_num):total_loss = 0items = 0for batch_x, batch_label in (train_loader):x = Variable(torch.LongTensor(batch_x.numpy())).cuda()#torch.Size([1024, 30, 1])pred = model(x)pred = torch.log(pred.view(-1, vocab_size) + 1e-20)    #print('pred shape ',  pred.shape)target = Variable(batch_label.view(-1)).cuda()#print('target shape ',  target.shape)loss = loss_function(pred, target.long())        optimizer.zero_grad()loss.backward()optimizer.step()total_loss += loss.item()items += 1print('Epoch {}, Step {} Train Loss {}'.format(epoch, items, loss.item() ) )#save model every 10 epochesif epoch % 10 == 0:if not os.path.exists("./new_trained"):os.makedirs("./new_trained")directory = './new_trained/rnn_novel'+str(epoch)+'.pkl'torch.save(model, directory)

预测代码

def write_words(model, word_num, begin_sentence):gg = begin_sentence[:30]print(''.join(gg), end='/// ')for _ in range(word_num):sampled = np.zeros((1, maxlen)) for t, char in enumerate(gg):sampled[0, t] = word_indices[char]x = Variable(torch.LongTensor(sampled)).cuda()preds = model(x)next_word = words[np.argmax(preds.data.cpu().numpy())]gg.append(next_word)gg = gg[1:]sys.stdout.write(next_word)sys.stdout.flush()begin_sentence = whole[50003: 50100]
print("初始句:", begin_sentence[:30])
begin_sentence = list(jieba.cut(begin_sentence, cut_all=False))write_words(model, 300, begin_sentence)

这里为了方便简单,在模型完成训练之后,即刻进行模型预测,模型预测的效果如下:

参考

运用深度学习进行文本生成
torch.nn.Embedding使用详解
【pytorch】关于Embedding和GRU、LSTM的使用详解
Pytorch损失函数torch.nn.NLLLoss()详解
Text Generation
Char RNN原理介绍以及文本生成实践
MODERN METHODS OF TEXT GENERATION
文本自动生成研究进展与趋势

相关文章:

NLP文本自动生成介绍及Char-RNN中文文本自动生成训练demo

前言 文本自动生成是自然语言处理领域的一个重要研究方向,实现文本自动生成也是人工智能走向成熟的一个重要标志。文本自动生成技术极具应用前景。 例如,文本自动生成技术可以应用于智能问答与对话、机器翻译等系统,实现更加智能和自然的人机…...

Teradata 离场,企业数据分析平台如何应对变革?

近日大数据分析和数仓软件巨头 Teradata(TD)宣布基于中国商业环境的评估,退出在中国的直接运营。TD 是全球最大的专注于大数据分析、数仓和整合营销管理解决方案的供应商之一,其早在 1997 年就进入中国,巅峰期占据半数…...

QWebEngineView-官翻

文章目录特性公共成员函数重实现公共成员函数公有槽函数信号静态公有成员函数保护成员函数重实现保护成员函数额外继承成员详细描述特性文档编制成员函数文档QWebEngineView::**QWebEngineView**([QWidget](../../W/QWidget.md) **parent* Q_NULLPTR)[virtual] QWebEngineView…...

网络安全高级攻击

对分类器的高层次攻击可以分为以下三种类型:对抗性输入:这是专门设计的输入,旨在确保被误分类,以躲避检测。对抗性输入包含专门用来躲避防病毒程序的恶意文档和试图逃避垃圾邮件过滤器的电子邮件。数据中毒攻击:这涉及…...

优思学院:六西格玛中的水平对比方法是什么?

水平对比,就是比较不同事物之间的差异。 这个概念在六西格玛管理中也很重要,也就是我们经常说的标杆管理,经常被用来寻找行业中最好的做法,以帮助组织改进自身的绩效。 在六西格玛管理中,水平对比有三种常见的应用方式…...

UVa 690 Pipeline Scheduling 流水线调度 二进制表示状态 DFS 剪枝

题目链接:Pipeline Scheduling 题目描述: 给定一张5n(1≤n≤20)5\times n(1\le n\le20)5n(1≤n≤20)的资源需求表,第iii行第jjj列的值为’X’表示进程在jjj时刻需要使用使用资源iii,如果为’.则表示不需要使用。你的任务是安排十个…...

【ArcGIS Pro二次开发】(6):工程(Project)的基本操作

在ArcGIS Pro中我们对工程的基本操作一般包括打开、新建、保存等。下面演示在二次开发中如何用代码进行以上操作。 新建一个项目,命名为【ProjectManager】,添加8个按钮,命名为【CreateEmptyProject、CreateProjectByDefault、OpenExProjest…...

Qt OpenGL(四十)——Qt OpenGL 核心模式-雷达扫描效果

提示:本系列文章的索引目录在下面文章的链接里(点击下面可以跳转查看): Qt OpenGL 核心模式版本文章目录 Qt OpenGL(四十)——Qt OpenGL 核心模式-雷达扫描效果 一、场景 上一篇文章介绍了在雷达坐标系中绘制飞行的飞机,其实雷达坐标系应该还有一个效果,就是扫描的效…...

群智能优化算法求解标准测试函数F1~F23之种群动态分布图(视频)

群智能优化算法求解标准测试函数F1的种群动态分布图群智能优化算法求解标准测试函数F2的种群动态分布图群智能优化算法求解标准测试函数F3的种群动态分布图群智能优化算法求解标准测试函数F4的种群动态分布图群智能优化算法求解标准测试函数F5的种群动态分布图群智能优化算法求…...

vue-axios封装与使用

一、简介 Axios 是一个基于 promise 网络请求库,作用于node.js 和浏览器中。 这是一个使用率很高的前端网络请求库,几乎所有的前端项目都会使用,本文主要介绍的是如何在vue项目中使用axios,并对其进行全面的封装。 注意&#x…...

重要节点排序方法

文章目录研究背景提前约定基于节点近邻的排序方法度中心性(degree centrality, DC)半局部中心性(semilocal centrality, SLC)k-壳分解法基于路径排序的方法离心中心性 (Eccentricity, ECC)接近中心性 (closeness centrality, CC)K…...

【2.20】动态规划 +项目 + 存储引擎

01背包问题 现有一容量为w的背包,有3个物品,每个物品重量不同,价值不同,问,怎样装才能价值最大化? 明确dp数组含义和下标含义:dp[j]表示当前背包的最大价值。j表示背包容量。递推公式&#xf…...

触摸屏单个按键远程控制led

一、硬件 arduino2块 淘晶驰串口屏7寸增强型带外壳1块,不支持视频音频 nRF24L0模块2块 扩展板2块 跳线若干 面包板1块 led灯1个 电阻二极管若干 下载线两个 usb转串口1个 二、实验内容 一个arduino作为触摸屏的控制器,接收触摸屏双向开关的信号,同时通过nRF24L01发送“open”…...

JVM12 class文件

1. Class 文件结构 1.1. Class 字节码文件结构 类型名称说明长度数量魔数u4magic魔数,识别Class文件格式4个字节1版本号u2minor_version副版本号(小版本)2个字节1u2major_version主版本号(大版本)2个字节1常量池集合u2constant_pool_count常量池计数器2个字节1cp_infoconstan…...

等保三级认证基本要求

一、什么是等保测评? 企业单位委托经公安部认证的具有资质的测评机构,按照管理规范和技术标准,对相应的测评对象(信息系统)的状况进行测评。 1、安全技术测评:包括物理安全、网络安全、主机系统安全、应用安…...

Python 基本数据类型(一)

1. 整型 整型即整数,用 int 表示,在 Python3 中整型没有长度限制。 1.1 内置函数 1. int(num, baseNone) int( ) 函数用于将字符串转换为整型,默认转换为十进制。 >>> int(123) 123 >>> int(123, …...

win10 环境变量及其作用大全

------------------------------------------------------系统变量------------------------------------------------------ ComSpec: C:\WINDOWS\system32\cmd.exe command specification 解释: ComSpec是Windows操作系统中的一个环境变量,它表示Windo…...

@Valid与@Validated的区别

1.介绍 说明: 其实Valid 与 Validated都是做数据校验的,只不过注解位置与用法有点不同。 不同点: (1) Valid是使用Hibernate validation的时候使用。Validated是只用Spring Validator校验机制使用。 (2&…...

【LeetCode】剑指 Offer 09. 用两个栈实现队列 p68 -- Java Version

题目链接:https://leetcode.cn/problems/yong-liang-ge-zhan-shi-xian-dui-lie-lcof/ 1. 题目介绍(09. 用两个栈实现队列) 用两个栈实现一个队列。队列的声明如下,请实现它的两个函数 appendTail 和 deleteHead ,分别…...

Java并发编程面试题——JUC专题

文章目录一、AQS高频问题1.1 AQS是什么?1.2 唤醒线程时,AQS为什么从后往前遍历?1.3 AQS为什么用双向链表,(为啥不用单向链表)?1.4 AQS为什么要有一个虚拟的head节点1.5 ReentrantLock的底层实现…...

(LeetCode 每日一题) 3442. 奇偶频次间的最大差值 I (哈希、字符串)

题目:3442. 奇偶频次间的最大差值 I 思路 :哈希,时间复杂度0(n)。 用哈希表来记录每个字符串中字符的分布情况,哈希表这里用数组即可实现。 C版本: class Solution { public:int maxDifference(string s) {int a[26]…...

【磁盘】每天掌握一个Linux命令 - iostat

目录 【磁盘】每天掌握一个Linux命令 - iostat工具概述安装方式核心功能基础用法进阶操作实战案例面试题场景生产场景 注意事项 【磁盘】每天掌握一个Linux命令 - iostat 工具概述 iostat(I/O Statistics)是Linux系统下用于监视系统输入输出设备和CPU使…...

ffmpeg(四):滤镜命令

FFmpeg 的滤镜命令是用于音视频处理中的强大工具,可以完成剪裁、缩放、加水印、调色、合成、旋转、模糊、叠加字幕等复杂的操作。其核心语法格式一般如下: ffmpeg -i input.mp4 -vf "滤镜参数" output.mp4或者带音频滤镜: ffmpeg…...

ElasticSearch搜索引擎之倒排索引及其底层算法

文章目录 一、搜索引擎1、什么是搜索引擎?2、搜索引擎的分类3、常用的搜索引擎4、搜索引擎的特点二、倒排索引1、简介2、为什么倒排索引不用B+树1.创建时间长,文件大。2.其次,树深,IO次数可怕。3.索引可能会失效。4.精准度差。三. 倒排索引四、算法1、Term Index的算法2、 …...

浪潮交换机配置track检测实现高速公路收费网络主备切换NQA

浪潮交换机track配置 项目背景高速网络拓扑网络情况分析通信线路收费网络路由 收费汇聚交换机相应配置收费汇聚track配置 项目背景 在实施省内一条高速公路时遇到的需求,本次涉及的主要是收费汇聚交换机的配置,浪潮网络设备在高速项目很少,通…...

基于PHP的连锁酒店管理系统

有需要请加文章底部Q哦 可远程调试 基于PHP的连锁酒店管理系统 一 介绍 连锁酒店管理系统基于原生PHP开发,数据库mysql,前端bootstrap。系统角色分为用户和管理员。 技术栈 phpmysqlbootstrapphpstudyvscode 二 功能 用户 1 注册/登录/注销 2 个人中…...

为什么要创建 Vue 实例

核心原因:Vue 需要一个「控制中心」来驱动整个应用 你可以把 Vue 实例想象成你应用的**「大脑」或「引擎」。它负责协调模板、数据、逻辑和行为,将它们变成一个活的、可交互的应用**。没有这个实例,你的代码只是一堆静态的 HTML、JavaScript 变量和函数,无法「活」起来。 …...

6个月Python学习计划 Day 16 - 面向对象编程(OOP)基础

第三周 Day 3 🎯 今日目标 理解类(class)和对象(object)的关系学会定义类的属性、方法和构造函数(init)掌握对象的创建与使用初识封装、继承和多态的基本概念(预告) &a…...

C++实现分布式网络通信框架RPC(2)——rpc发布端

有了上篇文章的项目的基本知识的了解,现在我们就开始构建项目。 目录 一、构建工程目录 二、本地服务发布成RPC服务 2.1理解RPC发布 2.2实现 三、Mprpc框架的基础类设计 3.1框架的初始化类 MprpcApplication 代码实现 3.2读取配置文件类 MprpcConfig 代码实现…...

使用SSE解决获取状态不一致问题

使用SSE解决获取状态不一致问题 1. 问题描述2. SSE介绍2.1 SSE 的工作原理2.2 SSE 的事件格式规范2.3 SSE与其他技术对比2.4 SSE 的优缺点 3. 实战代码 1. 问题描述 目前做的一个功能是上传多个文件,这个上传文件是整体功能的一部分,文件在上传的过程中…...