自然语言处理入门6——RNN生成文本
一、文本生成
我们在前面的文章中介绍了LSTM,根据输入时序数据可以输出下一个可能性最高的数据,如果应用在文字上,就是根据输入的文字,可以预测下一个可能性最高的文字。利用这个特点,我们可以用LSTM来生成文本。输入一个单词,做embedding处理后,再输入到LSTM,会输出备选单词的得分,经过softmax得到概率,我们选择概率最高的单词作为输出。这种输出叫做确定性输出,如下图所示:

如果我们根据概率随机选择一个单词输出,这叫做概率性输出,类似于大语言模型中的temperature,当temperature高时,生成的变化越多,确定性越低。下面的代码是根据RNN模型生成文字的代码,其中start_id代表生成开始的第一个单词编号,skip_ids代表需要过滤的单词,比如空或者数字等等,sample_size代表采样大小,这里就是生成的单词的数量。
class RnnlmGen(Rnnlm):def generate(self, start_id, skip_ids=None, sample_size=100):word_ids = [start_id] # 最终的单词编号列表x = start_id # 第一个单词编号# 如果生成的单词列表长度还小于sample_size就继续生成while len(word_ids) < sample_size:# 转变shape便于输入模型x = np.array(x).reshape(1,1)# 根据模型预测输出单词得分score = self.predict(x)# 得到概率p = softmax(score.flatten())# 根据概率选择输出的单词编号sampled = np.random.choice(len(p), size=1, p=p)if (skip_ids is None) or (sampled not in skip_ids):x = sampledword_ids.append(int(x))return word_idscorpus, word_to_id, id_to_word = load_data('train')
vocab_size = len(word_to_id)
corpus_size = len(corpus)
model = RnnlmGen()
# 设定start单词和skip单词
start_word = 'you'
start_id = word_to_id[start_word]
skip_words = ['N','<unk>','$']
skip_ids = [word_to_id[w] for w in skip_words]
# 生成文本
word_ids = model.generate(start_id, skip_ids)
txt = ' '.join(id_to_word[i] for i in word_ids)
txt = txt.replace(' <eos>','.\n')
print(txt)
# 输出:
you freshman retreat teeth instantly enhanced university brands exceptionally affiliates
various unfair leslie our assumes studies begin monitored bart leap reasonably gary poorer
industry southeast cemetery tables epo supportive nervous sooner inc soybeans scientific
expertise applying lufthansa introduction leventhal casting lights carries feared revamping
solar sachs widen training reins moves industrials technologies extent diagnostic narcotics
regularly literally hanover primarily reinsurance pro-life serve specifications fm jumbo
penalty actions l.p. ann keenan princeton despite stuart arise instrumentation classic exposed
violation dishonesty warner-lambert nicaragua infringed fantasy marcus portrait imported jordan
spurring component perestroika undo remic sacrifice veterans arms-control postal relying
homelessness quack voters
可以看到输出的文字几乎没有什么含义,这是因为我们使用的模型是原始的LSTM,没有经过训练,如果我们加载了上篇文字训练过后的模型BetterRnnlm后,效果会有明显提升。
... ...
model = RnnlmGen()
model.load_params('Rnnlm.pkl')
... ...
# 输出:
you place a short part of their nation 's relatively rumored air.far everyone agreed and yet as out next is a market to insist in out of wohlstetteris violates that on why it took the max for one. a caution on a chiefs of substantial chips three-year firstsouth in the ratio for candidates more investors the tax in medical lawsuits will be something about three to government simultaneous.a new market will n't be surprising these specify must very be exactly like economic houses to manufacturers competitors where he adds some error in not new
二、序列到序列模型
用RNN生成文本,还有一种更通用的用法,称为序列到序列的模型,也就是sequence to sequence。最典型的seq2seq应用就是机器翻译,输入一串用某种语言表示的文字,输出用另一种文字表示的文字。另外典型的应用包括:
自动摘要,它是输入一个长文本,输出一个表达核心含义的短文本;
问答系统:输入一个问题文本,输出一个答案文本;
聊天机器人:输入人类的文本,输出机器的文本;
算法学习:输入一串算法描述,输出计算答案;
图像文字生成:输入图像,这里图像也可以通过CNN等网络表示成一串向量,输出描述图像的文字;
可以看出,seq2seq可以有两个模块构成,一个模块处理输入文本,一个模块生成输出文本,处理输入文本的模块我们称为Encoder,生成文本的模块我们称为Decoder。一般来说,一个seq2seq就是由一个Encoder和一个Decoder合并在一起得到的。以书中的例子,日语翻译成英语为例的话,流程图如下:

下面我们以模拟一个加法的学习来实现这个seq2seq模型。这个模型的输入是一个三位数字以内的加法表达式,如“32+100”,输出是运算的结果,如“132”,编码器对“32+100”这个表达式拆分成“3”,“2”,“+”,“1”,“0”,“0”等几个字符,作为文本输入到编码器,得到隐藏信息,解码器输入隐藏信息以及“1”,“3”,“2”作为标签值,得到输出值,将输出值与“1”,“3”,“2”比较,得到损失,进行反向传播,实现整个训练过程。
不过这里有几个需要注意的地方:因为输入到编码器中的加法表达式长度可能不同,所以需要解决这个问题,最方便的方法是padding,也就是在表达式的前后插入填充字符。如:

总体流程如下图所示。训练的时候采用编码器和解码器训练,实际使用中,把编码器得到的隐藏信息输出到生成器,生成结果。

代码实现是基于之前的LSTM代码基础之上的,其实和LSTM的代码构建有很多类似的地方,编码器基本就是一个普通的LSTM,编码器代码:
class Encoder:def __init__(self, vocab_size, wordvec_size, hidden_size):V,D,H = vocab_size, wordvec_size, hidden_sizern = np.random.randnembed_W = (rn(V,D)/100).astype('f')lstm_Wx = (rn(D,4*H)/np.sqrt(D)).astype('f')lstm_Wh = (rn(H,4*H)/np.sqrt(H)).astype('f')lstm_b = np.zeros(4*H).astype('f')self.embed = TimeEmbedding(embed_W)self.lstm = TimeLSTM(lstm_Wx, lstm_Wh, lstm_b, stateful = False)self.params = self.embed.params + self.lstm.paramsself.grads = self.embed.grads + self.lstm.gradsself.hs = Nonedef forward(self,xs):xs = self.embed.forward(xs)hs = self.lstm.forward(xs)self.hs = hsreturn self.hs[:,-1,:]def backward(self,dh):dhs = np.zeros_like(self.hs)dhs[:,-1,:] = dhdout = self.lstm.backward(dhs)dout = self.embed.backward(dout)return dout
解码器和编码器的区别就在于,解码器还要多输入一个隐藏信息,并且正向传播输出多一个打分步骤。generate是实际生成文本结果的函数,和前面所述生成文本的区别在于,这里是生成加法结果的,所以不用概率性输出,而采用确定性输出,就是用argmax选择得分最高的输出,解码器代码:
class Decoder:def __init__(self, vocab_size, wordvec_size, hidden_size):V,D,H = vocab_size, wordvec_size, hidden_sizern = np.random.randnembed_W = (rn(V,D)/100).astype('f')lstm_Wx = (rn(D,4*H)/np.sqrt(D)).astype('f')lstm_Wh = (rn(H,4*H)/np.sqrt(H)).astype('f')lstm_b = np.zeros(4*H).astype('f')affine_W = (rn(H,V)/np.sqrt(H)).astype('f')affine_b = np.zeros(V).astype('f')self.embed = TimeEmbedding(embed_W)self.lstm = TimeLSTM(lstm_Wx, lstm_Wh, lstm_b, stateful = True)self.affine = TimeAffine(affine_W, affine_b)self.params, self.grads = [], []for layer in (self.embed, self.lstm, self.affine):self.params += layer.paramsself.grads += layer.gradsdef forward(self, xs, h):self.lstm.set_state(h)out = self.embed.forward(xs)out = self.lstm.forward(out)score = self.affine.forward(out)return scoredef backward(self, dscore):dout = self.affine.backward(dscore)dout = self.lstm.backward(dout)dout = self.embed.backward(dout)dh = self.lstm.dhreturn dhdef generate(self, h, start_id, sample_size):sampled = []sample_id = start_idself.lstm.set_state(h)for _ in range(sample_size):x = np.array(sample_id).reshape((1,1))out = self.embed.forward(x)out = self.lstm.forward(out)score = self.affine.forward(out)sample_id = np.argmax(score.flatten())sampled.append(int(sample_id))return sampled
基于上述编码器和解码器,构建seq2seq模型:
class Seq2seq(BaseModel):def __init__(self, vocab_size, wordvec_size, hidden_size):V,D,H = vocab_size, wordvec_size, hidden_sizeself.encoder = Encoder(V,D,H)self.decoder = Decoder(V,D,H)self.softmax = TimeSoftmaxWithLoss()self.params = self.encoder.params + self.decoder.paramsself.grads = self.encoder.grads + self.decoder.gradsdef forward(self, xs, ts):# 样本从开始到倒数第二个,标签从第1个开始到最后一个decoder_xs, decoder_ts = ts[:,:-1], ts[:,1:]h = self.encoder.forward(xs)score = self.decoder.forward(decoder_xs, h)loss = self.softmax.forward(score, decoder_ts)return lossdef backward(self, dout=1):dout = self.softmax.backward(dout)dh = self.decoder.backward(dout)dout = self.encoder.backward(dh)return doutdef generate(self, xs, start_id, sample_size):h = self.encoder.forward(xs)sampled = self.decoder.generate(h, start_id, sample_size)return sampled
采用该模型训练25个epoch后,预测精确度约在11%左右。一个改进办法是,将输入的表达式反转:

另一个改进方法是Peeky,它的特点就是把编码器传过来的隐藏信息,都输入到解码器的每个节点中,而之前只有解码器的第一个节点接收编码器传过来的隐藏信息。

训练代码如下(完整代码可以参考书的附带源代码):
# 读入数据集
(x_train,t_train),(x_test,t_test) = load_data('addition.txt', seed=1984)
char_to_id, id_to_char = get_vocab()
x_train, x_test = x_train[:,::-1], x_test[:,::-1] # 反转# 设定超参数
vocab_size = len(char_to_id)
wordvec_size = 16
hidden_size = 128
batch_size = 128
max_epoch = 25
max_grad = 5.0# 生成模型/优化器/训练器
model = PeekySeq2seq(vocab_size, wordvec_size, hidden_size)
optimizer = Adam()
trainer = Trainer(model, optimizer)acc_list = []
for epoch in range(max_epoch):trainer.fit(x_train, t_train, max_epoch=1, batch_size=batch_size, max_grad=max_grad)correct_num = 0for i in range(len(x_test)):question, correct = x_test[[i]], t_test[[i]]verbose = i < 10correct_num += eval_seq2seq(model, question, correct, id_to_char, verbose)acc = float(correct_num)/len(x_test)acc_list.append(acc)print('val acc %.3f%%' % (acc*100))acc_list3 = acc_list
plt.plot([i for i in range(max_epoch)], [acc*100 for acc in acc_list3], label='peeky+reverse', c='r',linestyle='--',marker='o')
plt.plot([i for i in range(max_epoch)], [acc*100 for acc in acc_list2], label='reverse', c='y',linestyle='-.',marker='>')
plt.plot([i for i in range(max_epoch)], [acc*100 for acc in acc_list1], label='original', c='g',linestyle=':',marker='*')
plt.xlabel('iterations')
plt.ylabel('accuracy')
plt.legend()
plt.show()
这里,我稍微修改了一下书中代码,我把三个精度图放在一起比较了。acc_list1代表原始seq2seq模型的训练精度,acc_list2代表输入反转后的模型训练精度,acc_list3代码输入反转并且加入Peeky后的模型训练精度。

可以看到,原始的seq2seq模型在训练25个epoch后,精度大约11%,反转输入后,训练精度大概55%,再加入Peeky后,精度已经非常接近100%了,一般可以达到96%~98%之间。
相关文章:
自然语言处理入门6——RNN生成文本
一、文本生成 我们在前面的文章中介绍了LSTM,根据输入时序数据可以输出下一个可能性最高的数据,如果应用在文字上,就是根据输入的文字,可以预测下一个可能性最高的文字。利用这个特点,我们可以用LSTM来生成文本。输入…...
$R^n$超平面约束下的向量列
原向量: x → \overset{\rightarrow}{x} x→ 与 x → \overset{\rightarrow}{x} x→法向相同的法向量(与 x → \overset{\rightarrow}{x} x→同向) ( x → ⋅ n → ∣ n → ∣ 2 ) n → (\frac{\overset{\rightarrow}x\cdot\overset{\righta…...
FPGA_DDR错误总结
1otp 31-67 解决 端口没连接 必须赋值; 2.PLACE 30-58 TERM PLINITCALIBZ这里有问题 在顶层输出但是没有管脚约束报错 3.ERROR: [Place 30-675] 这是时钟不匹配IBUF不在同一个时钟域,时钟不在同一个时钟域里,推荐的不建议修改 问题 原本…...
k8s之Service类型详解
1.ClusterIP 类型 2.NodePort 类型 3.LoadBalancer 类型 4.ExternalName 类型 类型为 ExternalName 的 Service 将 Service 映射到 DNS 名称,而不是典型的选择算符, 例如 my-service 或者 cassandra。你可以使用 spec.externalName 参数指定这些服务…...
NOIP2011提高组.玛雅游戏
目录 题目算法标签: 模拟, 搜索, d f s dfs dfs, 剪枝优化思路*详细注释版代码精简注释版代码 题目 185. 玛雅游戏 算法标签: 模拟, 搜索, d f s dfs dfs, 剪枝优化 思路 可行性剪枝 如果某个颜色的格子数量少于 3 3 3一定无解因为要求字典序最小, 因此当一个格子左边有…...
网络安全应急响应之文件痕迹排查:从犯罪现场到数字狩猎的进化论
凌晨3点,某金融企业的服务器突然告警,核心数据库出现未知进程访问。安全团队紧急介入时,攻击者已抹去日志痕迹。在这场与黑客的时间赛跑中,文件痕迹排查成为破局关键。本文将带您深入数字取证的"案发现场",揭…...
移动端六大语言速记:第11部分 - 内存管理
移动端六大语言速记:第11部分 - 内存管理 本文将对比Java、Kotlin、Flutter(Dart)、Python、ArkTS和Swift这六种移动端开发语言在内存管理方面的特性,帮助开发者理解和掌握各语言的内存管理机制。 11. 内存管理 11.1 垃圾回收机制对比 各语言垃圾回收机制的主要特点对比:…...
基于ssm框架的校园代购服务订单管理系统【附源码】
1、系统框架 1.1、项目所用到技术: javaee项目 Spring,springMVC,mybatis,mvc,vue,maven项目。 1.2、项目用到的环境: 数据库 :mysql5.X、mysql8.X都可以jdk1.8tomcat8 及以上开发…...
lib-zo,C语言另一个协程库,函数列表
lib-zo,C语言另一个协程库,函数列表 支持开启/禁用指定fd是否开启协程切换 /* 主动设置fd支持协程切换 */ void zcoroutine_enable_fd(int fd);/* 主动设置fd不支持协程切换 */ void zcoroutine_disable_fd(int fd);函数列表 #pragma once#ifndef ___ZC_LIB_INCLUDE_COROUTI…...
【10】数据结构的矩阵与广义表篇章
目录标题 二维以上矩阵矩阵存储方式行序优先存储列序优先存储 特殊矩阵对称矩阵稀疏矩阵三元组方式存储稀疏矩阵的实现三元组初始化稀疏矩阵的初始化稀疏矩阵的创建展示当前稀疏矩阵稀疏矩阵的转置 三元组稀疏矩阵的调试与总代码十字链表方式存储稀疏矩阵的实现十字链表数据标签…...
本地项目HTTPS访问问题解决方案
本地项目无法通过 HTTPS 访问的原因通常是默认配置未启用 HTTPS 或缺少有效的 SSL 证书。以下是详细解释和解决方案: 原因分析 默认开发服务器仅支持 HTTP 大多数本地开发工具(如 Vite、Webpack、React 等)默认启动的是 HTTP 服务器ÿ…...
猜猜乐游戏(python)
import randomprint(**30) print(欢迎进入娱乐城) print(**30)username input(输入用户名:) cs 0answer input( 是否加入"猜猜乐"游戏(yes/no)? )if answer yes:while True:num int(input(%s! 当前你的金币数为%d! 请充值(100¥30币&…...
spring boot 2.7 集成 Swagger 3.0 API文档工具
背景 Swagger 3.0 是 OpenAPI 规范体系下的重要版本,其前身是 Swagger 2.0。在 Swagger 2.0 之后,该规范正式更名为 OpenAPI 规范,并基于新的版本体系进行迭代,因此 Swagger 3.0 实际对应 OpenAPI 3.0 版本。这一版本着重强化了对…...
Dinky 和 Flink CDC 在实时整库同步的探索之路
摘要:本文整理自 Dinky 社区负责人,Apache Flink CDC contributor 亓文凯老师在 Flink Forward Asia 2024 数据集成(二)专场中的分享。主要讲述 Dinky 的整库同步技术方案演变至 Flink CDC Yaml 作业的探索历程,并深入…...
视频融合平台EasyCVR搭建智慧粮仓系统:为粮仓管理赋能新优势
一、项目背景 当前粮仓管理大多仍处于原始人力监管或初步信息化监管阶段。部分地区虽采用了简单的传感监测设备,仍需大量人力的配合,这不仅难以全面监控粮仓复杂的环境,还容易出现管理 “盲区”,无法实现精细化的管理。而一套先进…...
3D Gaussian Splatting as MCMC 与gsplat中的应用实现
3D高斯泼溅(3D Gaussian splatting)自2023年提出以后,相关研究paper井喷式增长,尽管出现了许多改进版本,但依旧面临着诸多挑战,例如实现照片级真实感、应对高存储需求,而 “悬浮的高斯核” 问题就是其中之一。浮动高斯核通常由输入图像中的曝光或颜色不一致引发,也可能…...
C++初阶-C++的讲解1
目录 1.缺省(sheng)参数 2.函数重载 3.引用 3.1引用的概念和定义 3.2引用的特性 3.3引用的使用 3.4const引用 3.5.指针和引用的关系 4.nullptr 5.总结 1.缺省(sheng)参数 (1)缺省参数是声明或定义是为函数的参数指定一个缺省值。在调用该函数是…...
STM32_USB
概述 本文是使用HAL库的USB驱动 因为官方cubeMX生成的hal库做组合设备时过于繁琐 所以这里使用某大神的插件,可以集成在cubeMX里自动生成组合设备 有小bug会覆盖生成文件里自己写的内容,所以生成一次后注意保存 插件安装 下载地址 https://github.com/alambe94/I-CUBE-USBD-Com…...
STM32 的编程方式总结
🧱 按照“是否可独立工作”来分: 库/方式是否可独立使用是否依赖其他库说明寄存器裸写✅ 是❌ 无完全自主控制,无库依赖标准库(StdPeriph)✅ 是❌ 只依赖 CMSIS自成体系(F1专属),只…...
MFC工具栏CToolBar从专家到小白
CToolBar m_wndTool; //创建控件 m_wndTool.CreateEx(this, TBSTYLE_FLAT|TBSTYLE_NOPREFIX, WS_CHILD | WS_VISIBLE | CBRS_FLYBY | CBRS_TOP | CBRS_SIZE_DYNAMIC); //加载工具栏资源 m_wndTool.LoadToolBar(IDR_TOOL_LOAD) //在.rc中定义:IDR_TOOL_LOAD BITMAP …...
vllm作为服务启动,无需额外编写sh文件,一步到位【Ubuntu】
看到网上有的vllm写法,需要额外建立一个.sh文件,还是不够简捷。这里提供一种直接编写service文件一步到位的写法: vi /etc/systemd/system/vllm.service [Unit] DescriptionvLLM Service Afternetwork.target[Service] Typesimple Userroot…...
大厂机考——各算法与数据结构详解
目录及其索引 哈希双指针滑动窗口子串普通数组矩阵链表二叉树图论回溯二分查找栈堆贪心算法动态规划多维动态规划学科领域与联系总结 哈希 学科领域:计算机科学、密码学、数据结构 定义:通过哈希函数将任意长度的输入映射为固定长度…...
hive中的特殊字符
1、UTF-8 编码的非断空格(对应 Unicode 码点为 \u00A0) 这种空格在网页中常见(HTML 中的 ),用于阻止文本在换行时被分割。由于它不是标准空格(ASCII 20),使用TRIM 函数无法直接…...
10:00开始面试,10:08就出来了,问的问题有点变态。。。
从小厂出来,没想到在另一家公司又寄了。 到这家公司开始上班,加班是每天必不可少的,看在钱给的比较多的份上,就不太计较了。没想到8月一纸通知,所有人不准加班,加班费不仅没有了,薪资还要降40%…...
基于ueditor编辑器的功能开发之给编辑器图片增加水印功能
用户需求,双击编辑器中的图片的时候,出现弹框,用户可以选择水印缩放倍数、距离以及水印所放置的方位(当然有很多水印插件,位置大小透明度用户都能够自定义,但是用户需求如此,就自己写了…...
fabric test-network启动
//按照这个来放,免得出错 mkdir -p $GOPATH/src/github.com/hyperledger cd $GOPATH/src/github.com/hyperledger # 获取fabric-samples源码 git clone https://github.com/hyperledger/fabric-samples.git export FABRIC_LOGGING_SPECdebug cd fabric-samples # …...
【CSS基础】- 02(emmet语法、复合选择器、显示模式、背景标签)
css第二天 一、emmet语法 1、简介 Emmet语法的前身是Zen coding,它使用缩写,来提高html/css的编写速度, Vscode内部已经集成该语法。 快速生成HTML结构语法 快速生成CSS样式语法 2、快速生成HTML结构语法 生成标签 直接输入标签名 按tab键即可 比如 div 然后tab…...
centos7系统搭建nagios监控
~监控节点安装 1. 系统准备 1.1 更新系统并安装依赖 sudo yum install -y httpd php php-cli gcc glibc glibc-common gd gd-devel make net-snmp openssl-devel wget unzip sudo yum install -y epel-release # 安装 EPEL 仓库 sudo yum install -y automake autoconf lib…...
docker的几种网络模式
Bridge(桥接)模式 默认网络模式:Docker的默认网络模式是Bridge模式。工作原理:Docker在宿主机上创建一个虚拟的桥接网络,每个容器在启动时会从这个桥接网络中分配一个IP地址。容器之间可以通过这个桥接网络进行通信。…...
Android 中Intent 相关问题
在回答 Intent 问题时,清晰区分其 定义、类型 和 应用场景。以下是的回答策略: 一、Intent 的核心定义 Intent 是 Android 系统中的 消息传递对象,主要用于三大场景: 2. 隐式 Intent(Implicit Intent) 三、…...
