当前位置: 首页 > news >正文

【NLP 16、实践 ③ 找出特定字符在字符串中的位置】

看着父亲苍老的白发和渐渐老态的面容

希望时间再慢一些

                                                —— 24.12.19

一、定义模型

1.初始化模型

① 初始化父类

super(TorchModel, self).__init__(): 调用父类 nn.Module 初始化方法,确保模型能够正确初始化。

② 创建嵌入层

self.embedding = nn.Embedding(len(vocab), vector_dim): 创建一个嵌入层,将词汇表中的每个词映射到一个 vector_dim 维度的向量。

③ 创建RNN层

self.rnn = nn.RNN(vector_dim, vector_dim, batch_first=True): 创建一个 RNN 层输入和输出的特征维度均为 vector_dim,并且输入数据的第一维是批量大小。

④ 创建线性分类层

self.classify = nn.Linear(vector_dim, sentence_length + 1): 创建一个线性层,将 RNN 输出的特征向量映射到 sentence_length + 1 个分类标签。+1 是因为可能有某个词不存在的情况,此时的真实标签被设为 sentence_length。

⑤ 定义损失函数

self.loss = nn.functional.cross_entropy: 定义交叉熵损失函数,用于计算模型预测值与真实标签之间的差异。

class TorchModel(nn.Module):def __init__(self, vector_dim, sentence_length, vocab):super(TorchModel, self).__init__()self.embedding = nn.Embedding(len(vocab), vector_dim)  #embedding层# self.pool = nn.AvgPool1d(sentence_length)   #池化层#可以自行尝试切换使用rnnself.rnn = nn.RNN(vector_dim, vector_dim, batch_first=True)# +1的原因是可能出现a不存在的情况,那时的真实label在构造数据时设为了sentence_lengthself.classify = nn.Linear(vector_dim, sentence_length + 1)self.loss = nn.functional.cross_entropy

2、前向传播定义

① 输入嵌入

x = self.embedding(x):将输入 x 通过嵌入层转换为向量表示

② RNN处理

rnn_out, hidden = self.rnn(x):将嵌入后的向量输入到RNN层,得到每个时间步的输出 rnn_out 和最后一个时间步的隐藏状态 hidden。

③ 提取特征

x = rnn_out[:, -1, :]:从RNN的输出中提取最后一个时间步(最后一维)的特征向量。

④ 分类

y_pred = self.classify(x):将提取的特征向量通过线性层进行分类,得到预测值 y_pred。

⑤ 损失计算

如果提供了真实标签 y,则计算并返回损失值;否则,返回预测值。

    #当输入真实标签,返回loss值;无真实标签,返回预测值def forward(self, x, y=None):x = self.embedding(x)#使用pooling的情况,先使用pooling池化层会丢失模型语句的时序信息# x = x.transpose(1, 2)# x = self.pool(x)# x = x.squeeze()#使用rnn的情况# rnn_out:每个字对应的向量  hidden:最后一个输出的隐含层对应的向量rnn_out, hidden = self.rnn(x)# 中间维度改变,变成(batch_size数据样本数量, sentence_length文本长度, vector_dim向量维度)x = rnn_out[:, -1, :]  #或者写hidden.squeeze()也是可以的,因为rnn的hidden就是最后一个位置的输出#接线性层做分类y_pred = self.classify(x)if y is not None:return self.loss(y_pred, y)   #预测值和真实值计算损失else:return y_pred     

二、数据

1.建立词表

① 定义字符集

定义一个字符集 chars,包含字母 'a' 到 'k'。

② 定义字典

初始化一个字典 vocab,其中键为 'pad',值为 0。

③ 遍历字符集

使用 enumerate 遍历字符集 chars,为每个字符分配一个唯一的序号,从 1 开始。

④ 定义unk键

添加一个特殊的键 'unk',其值为当前字典的长度(即 26)。

⑤ 返回词汇表

将生成的词汇表返回

#字符集随便挑了一些字,实际上还可以扩充
#为每个字生成一个标号
#{"a":1, "b":2, "c":3...}
#abc -> [1,2,3]
def build_vocab():chars = "abcdefghijk"  #字符集vocab = {"pad":0}for index, char in enumerate(chars):vocab[char] = index+1   #每个字对应一个序号vocab['unk'] = len(vocab) #26return vocab

2.随机生成样本

① 采样

random.sample(list(vocab.keys()), sentence_length):从词汇表 vocab 的键中随机选择 sentence_length 个不同的字符,生成列表 x

② 标签生成

index('a'):检查列表 x 中是否包含字符 "a",如果包含,记录 "a" 在列表中的索引位置为 y,否则,设置 y 为 sentence_length。

③ 转换

将列表 x 中的每个字符转换为其在词汇表中的序号,如果字符不在词汇表中,则使用 unk 的序号

④ 返回结果

返回转换后的列表 x 和标签 y

#随机生成一个样本
def build_sample(vocab, sentence_length):#注意这里用sample,是不放回的采样,每个字母不会重复出现,但是要求字符串长度要小于词表长度x = random.sample(list(vocab.keys()), sentence_length)#指定哪些字出现时为正样本if "a" in x:y = x.index("a")else:y = sentence_lengthx = [vocab.get(word, vocab['unk']) for word in x]   #将字转换成序号,为了做embeddingreturn x, y

3.建立数据集

① 初始化数据集

创建两个空列表 dataset_x 和 dataset_y,用于存储生成的样本和对应的标签

② 生成样本

使用 for 循环,循环次数为 sample_length,即需要生成的样本数量。在每次循环中,调用 build_sample 函数生成一个样本 (x, y),其中 x 是输入数据,y 是标签

③ 存储样本

将生成的样本 x 添加到 dataset_x 列表中。将生成的标签 y 添加到 dataset_y 列表

④ 返回数据集

将 dataset_x 和 dataset_y 转换为 torch.LongTensor 类型,以便在 PyTorch 中使用。返回转换后的数据集。

#建立数据集
#输入需要的样本数量。需要多少生成多少
def build_dataset(sample_length, vocab, sentence_length):dataset_x = []dataset_y = []for i in range(sample_length):x, y = build_sample(vocab, sentence_length)dataset_x.append(x)dataset_y.append(y)return torch.LongTensor(dataset_x), torch.LongTensor(dataset_y)

三、模型测试、训练、评估

1.建立模型

① 参数:

vocab:词汇表,通常是一个包含所有字符或单词的列表或字典

char_dim:字符的维度,即每个字符在嵌入层中的向量长度

sentence_length:句子的最大长度

② 过程:

使用传入的参数 char_dim、sentence_length 和 vocab 实例化一个 TorchModel 对象并返回

#建立模型
def build_model(vocab, char_dim, sentence_length):model = TorchModel(char_dim, sentence_length, vocab)return model

2.测试模型

① 设置模型为评估模式

model.eval():将模型设置为评估模式禁用 dropout 训练时的行为

② 生成测试数据集

调用 build_dataset 函数生成 200 个用于测试的样本

③ 打印样本数量

输出当前测试集中样本的数量

④ 模型预测

使用 torch.no_grad() 禁用梯度计算,提高推理速度并减少内存消耗,然后对生成的测试数据进行预测

⑤ 计算准确率

遍历预测结果和真实标签,统计正确和错误的预测数量,并计算准确率

⑥ 输出结果

打印正确预测的数量和准确率,并返回准确率

#测试代码
#用来测试每轮模型的准确率
def evaluate(model, vocab, sample_length):model.eval()x, y = build_dataset(200, vocab, sample_length)   #建立200个用于测试的样本print("本次预测集中共有%d个样本"%(len(y)))correct, wrong = 0, 0with torch.no_grad():y_pred = model(x)      #模型预测for y_p, y_t in zip(y_pred, y):  #与真实标签进行对比if int(torch.argmax(y_p)) == int(y_t):correct += 1else:wrong += 1print("正确预测个数:%d, 正确率:%f"%(correct, correct/(correct+wrong)))return correct/(correct+wrong)

3.模型训练

① 配置参数

设置训练轮数epoch_num批量大小batch_size训练样本数train_sample字符维度char_dim句子长度sentence_length学习率learning_rate

② 建立字表

调用 build_vocab 函数生成字符到索引的映射。

③ 建立模型

调用 build_model 函数创建模型。

④ 选择优化器

torch.optim.Adam(model.parameters(), lr=learning_rate):使用 Adam 优化器

⑤ 训练过程

model.train():模型进入训练模式。每个 epoch 中,按批量生成训练数据,计算损失,反向传播并更新权重。记录每个 epoch 的平均损失

⑥ 评估模型

每个 epoch 结束后,调用 evaluate 函数评估模型性能。

⑦ 记录日志

记录每个 epoch 的准确率和平均损失

⑧ 绘制图表

绘制准确率和损失的变化曲线

⑨ 保存模型和词表

保存模型参数和词表

def main():#配置参数epoch_num = 20        #训练轮数batch_size = 40       #每次训练样本个数train_sample = 1000    #每轮训练总共训练的样本总数char_dim = 30         #每个字的维度sentence_length = 10   #样本文本长度learning_rate = 0.001 #学习率# 建立字表vocab = build_vocab()# 建立模型model = build_model(vocab, char_dim, sentence_length)# 选择优化器optim = torch.optim.Adam(model.parameters(), lr=learning_rate)log = []# 训练过程for epoch in range(epoch_num):model.train()watch_loss = []for batch in range(int(train_sample / batch_size)):x, y = build_dataset(batch_size, vocab, sentence_length) #构造一组训练样本optim.zero_grad()    #梯度归零loss = model(x, y)   #计算lossloss.backward()      #计算梯度optim.step()         #更新权重watch_loss.append(loss.item())print("=========\n第%d轮平均loss:%f" % (epoch + 1, np.mean(watch_loss)))acc = evaluate(model, vocab, sentence_length)   #测试本轮模型结果log.append([acc, np.mean(watch_loss)])#画图plt.plot(range(len(log)), [l[0] for l in log], label="acc")  #画acc曲线plt.plot(range(len(log)), [l[1] for l in log], label="loss")  #画loss曲线plt.legend()plt.show()#保存模型torch.save(model.state_dict(), "model.pth")# 保存词表writer = open("vocab.json", "w", encoding="utf8")writer.write(json.dumps(vocab, ensure_ascii=False, indent=2))writer.close()return

四、模型预测

1.训练并保存模型

if __name__ == "__main__":main()

 


2.预测数据

用保存的训练好的模型进行预测

① 初始化参数

设置每个字的维度 char_dim 和样本文本长度 sentence_length

② 加载字符表

从指定路径加载字符表 vocab

③ 建立模型

调用 build_model 函数构建模型

④ 加载模型权重

从指定路径加载预训练的模型权重

⑤ 序列化输入

将输入字符串转换为模型所需的输入格式

⑥ 模型预测

将输入数据传递给模型进行预测

⑦ 输出结果

打印每个输入字符串的预测类别和概率值

#使用训练好的模型做预测
def predict(model_path, vocab_path, input_strings):char_dim = 30  # 每个字的维度sentence_length = 10  # 样本文本长度vocab = json.load(open(vocab_path, "r", encoding="utf8")) #加载字符表model = build_model(vocab, char_dim, sentence_length)     #建立模型model.load_state_dict(torch.load(model_path,weights_only=True))             #加载训练好的权重x = []for input_string in input_strings:x.append([vocab[char] for char in input_string])  #将输入序列化model.eval()   #测试模式with torch.no_grad():  #不计算梯度result = model.forward(torch.LongTensor(x))  #模型预测for i, input_string in enumerate(input_strings):print("输入:%s, 预测类别:%s, 概率值:%s" % (input_string, torch.argmax(result[i]), result[i])) #打印结果

3.调用函数进行预测

if __name__ == "__main__":# main()test_strings = ["kijabcdefh", "gijkbcdeaf", "gkijadfbec", "kijhdefacb"]predict("model.pth", "vocab.json", test_strings)

相关文章:

【NLP 16、实践 ③ 找出特定字符在字符串中的位置】

看着父亲苍老的白发和渐渐老态的面容 希望时间再慢一些 —— 24.12.19 一、定义模型 1.初始化模型 ① 初始化父类 super(TorchModel, self).__init__(): 调用父类 nn.Module 的初始化方法,确保模型能够正确初始化。 ② 创建嵌入层 self.embedding n…...

费解的开关(bfs + 哈希表 or 递推)

题目描述: 25盏灯排成一个5x5的方形。每一个灯都有一个开关,游戏者可以改变它的状态。每一步,游戏者可以改变某一个灯的状态。游戏者改变一个灯的状态会产生连锁反应:和这个灯上下左右相邻的灯也要相应地改变其状态。 我们用数字“1”表示一盏开着的灯,用数字“0”表示关…...

C语言——实现求出最大值

问题描述&#xff1a;利用C语言自定义函数求出一维数组里边最大的数字 //利用函数找最大数#include<stdio.h>int search(int s[9]) //查找函数 {int i , max s[0] , max_xia 0;for(i0;i<9;i){if(s[i] > max){max_xia i;max s[max_xia];}}return max; } in…...

基于微信小程序的短视频系统(SpringBoot)+文档

&#x1f497;博主介绍&#x1f497;&#xff1a;✌在职Java研发工程师、专注于程序设计、源码分享、技术交流、专注于Java技术领域和毕业设计✌ 温馨提示&#xff1a;文末有 CSDN 平台官方提供的老师 Wechat / QQ 名片 :) Java精品实战案例《700套》 2025最新毕业设计选题推荐…...

Flutter 中 Sliver 的各种装饰器介绍与使用

在 Flutter 中&#xff0c;Sliver 是一种可以在滚动视图中实现自定义效果的组件。Sliver 组件可以根据滚动位置动态改变其外观和行为。本文将介绍几种常用的 Sliver 装饰器及其使用方法。 1. SliverAppBar SliverAppBar 是一个可以随着滚动而变化的应用栏。它可以在用户向下滚…...

电感的基本概念

电感的定义&#xff1a; 电感一般是由导线绕成空芯线圈或带铁芯的线圈而制成。 当线圈中有电流通过时&#xff0c;线圈周围就会产生磁场&#xff0c;当线圈中流过的是直流电流时&#xff0c;线圆周围就会产生固定的磁场&#xff0c;线圈产生的物理现象就是电磁铁&#xff0c;当…...

linux基于systemd自启守护进程 systemctl自定义服务傻瓜式教程

系统服务 书接上文: linux自启任务详解 演示系统:ubuntu 20.04 开发部署项目的时候常常有这样的场景: 业务功能以后台服务的形式提供,部署完成后可以随着系统的重启而自动启动;服务异常挂掉后可以再次拉起 这个功能在ubuntu系统中通常由systemd提供 如果仅仅需要达成上述的场…...

HTTP协议和接口测试详解

介绍接口测试前我们先来介绍一下HTTP协议&#xff0c;为什么先要介绍HTTP协议呢因为因为我们做接口测试其实就是用测试工具&#xff08;postman,fiddler,jmeter等等&#xff09;或代码来模拟用户使用软件的场景&#xff0c;在我们模拟的时候不像平时功能测试时我们有已经开发完…...

vue3【实战】定义全局方法(两种方案)

以全局方法 calculate 为例 src/utils/calculate.ts export default {sum: function (a: number, b: number) {return a b} }方案1&#xff1a; 依赖注入 provide inject main.ts import calculate from ./utils/calculateapp.provide(calculate, calculate)页面中 // esl…...

基于JavaScript的DBUtils增删改查操作实验

1、实验目的 学习和掌握数据库连接池的配置与管理。使用DBUtils进行增删改查操作。按照步骤&#xff0c;掌握并实现使用DBUtils实现增删改查的全过程。 2、实验所用方法 上机实践 3、实验步骤及截图 创建一个数据库表&#xff0c;使用下面sql语句创建数据库表并插入数据&#x…...

初学stm32 --- 系统时钟配置

众所周知&#xff0c;时钟系统是 CPU 的脉搏&#xff0c;就像人的心跳一样。所以时钟系统的重要性就不言而喻了。 STM32 的时钟系统比较复杂&#xff0c;不像简单的 51 单片机一个系统时钟就可以解决一切。于是有人要问&#xff0c;采用一个系统时钟不是很简单吗&#xff1f;为…...

实现星星评分系统

使用HTML、CSS和JavaScript实现星星评分系统 本文将详细讲解如何使用 HTML、CSS 和 JavaScript 实现一个简单的星星评分系统。用户可以通过点击星星进行评分&#xff0c;并且还能够看到星星的悬浮效果和已选中状态。 1. HTML 结构 我们首先在 HTML 中定义了一个星星评分的结…...

数据库建模工具 PDManer

数据库建模工具 PDManer 1.PDManer简介2.PDManer使用 1.PDManer简介 PDManer&#xff08;元数建模&#xff09;是一款功能强大且易于使用的开源数据库建模工具。它不仅支持多种常见数据库&#xff0c;如MySQL、PostgreSQL、Oracle、SQL Server等&#xff0c;还特别支持国产数据…...

后台运维操作建议

文章目录 1.版本升级2.配置发布3.数据库/脚本操作4.发布依赖确认5.发布规范6.服务下线参考文献 1.版本升级 版本升级是软件维护和演进中的关键环节&#xff0c;但它可能带来一系列问题。这些问题涉及兼容性、功能、性能、安全性等方面。 【强制】版本管理&#xff1a;使用版本…...

NX二次开发调用内部函数设置对象穿透显示DSS_ATTR_set_show_through

获取动态库libdisp.dll的路径 void TcharToChar(const TCHAR* tchar, char* _char) {int iLength; #if UNICODE//获取字节长度 iLength = WideCharToMultiByte(CP_ACP, 0, tchar, -1, NULL, 0, NULL, NULL);//将tchar值赋给_char WideCharToMultiByte(CP_ACP, 0, tchar, …...

ubuntu16.04ros-用海龟机器人仿真循线系统

下载安装sudo apt-get install ros-kinetic-turtlebot ros-kinetic-turtlebot-apps ros-kinetic-turtlebot-interactions ros-kinetic-turtlebot-simulator ros-kinetic-kobuki-ftdi sudo apt-get install ros-kinetic-rocon-*echo "source /opt/ros/kinetic/setup.bash…...

解决Ubuntu 20.04上编译OpenCV 3.2时遇到的stdlib.h缺失错误

解决Ubuntu 20.04上编译OpenCV 3.2时遇到的stdlib.h缺失错误 您在 Ubuntu 20.04 上编译 OpenCV 3.2 时遇到的错误与 C 标准库的头文件配置问题有关。错误消息指出系统无法找到 <stdlib.h>&#xff0c;这通常与预编译头文件的处理、GCC 版本或者头文件搜索路径有关。下面…...

HTML综合案例

为了前端考试。 效果图&#xff1a; HTML代码&#xff1a; <!DOCTYPE html> <html lang"zh-CN"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><…...

TanStack——为现代前端开发提供高性能和灵活的工具

TanStack 是一个由社区主导的开源项目集合&#xff0c;专注于为现代前端开发提供高性能和灵活的工具。它包括多个流行的 JavaScript 和 TypeScript 库&#xff0c;主要用于处理表格、查询、虚拟化、状态管理等功能。 文章目录 1、TanStack Query&#xff1a;1.1 useQuery&#…...

Java爬虫️ 使用Jsoup库进行API请求有什么优势?

在Java的世界里&#xff0c;Jsoup库以其强大的HTML解析能力而闻名。它不仅仅是一个简单的解析器&#xff0c;更是一个功能齐全的工具箱&#xff0c;为开发者提供了从网页抓取到数据处理的一站式解决方案。本文将深入探讨使用Jsoup库进行API请求的优势&#xff0c;并提供代码示例…...

React源码02 - 基础知识 React API 一览

1. JSX到JavaScript的转换 <div id"div" key"key"><span>1</span><span>2</span> </div>React.createElement("div", // 大写开头会当做原生dom标签的字符串&#xff0c;而组件使用大写开头时&#xff0c;这…...

COMSOL with Matlab

文章目录 基本介绍COMSOL with MatlabCOMSOL主Matlab辅Matlab为主Comsol为辅 操作步骤常用指令mphopenmphgeommghmeshmphmeshstatsmphnavigatormphplot常用指令mphsavemphlaunchModelUtil.clear 实例教学自动另存新档**把语法套用到边界条件**把语法套用到另存新档 函数及其微分…...

【报表查询】.NET开源ORM框架 SqlSugar 系列

文章目录 前言实践一、按月统计没有为0实践二、 统计某月每天的数量实践三、对象和表随意JOIN实践四、 List<int>和表随意JOIN实践五、大数据处理实践六、每10分钟统计Count实践七、 每个ID都要对应时间总结 前言 在我们实际开发场景中&#xff0c;报表是最常见的功能&a…...

PostgreSQL数据库访问限制详解

pg_hba.conf 文件是 PostgreSQL 数据库系统中非常重要的一个配置文件&#xff0c;它用于定义哪些用户&#xff08;或客户端&#xff09;可以连接到 PostgreSQL 数据库服务器&#xff0c;以及他们可以使用哪些认证方法进行连接。 pg_hba.conf 的名称来源于 "Host-Based Aut…...

【test linux】创建一个ext4类型的文件系统

创建一个ext4类型的文件系统 dd 是一个非常强大的命令行工具&#xff0c;用于在Unix/Linux系统中进行低级别的数据复制和转换。这条命令的具体参数含义如下&#xff1a; if/dev/zero&#xff1a;指定输入文件&#xff08;input file&#xff09;为 /dev/zero&#xff0c;这是一…...

如何在繁忙的生活中找到自己的节奏?

目录 一、理解生活节奏的重要性 二、分析当前生活节奏 1. 时间分配 2. 心理状态 3. 身体状况 4. 生活习惯 1. 快慢适中 2. 张弛结合 3. 与目标相符 三、掌握调整生活节奏的策略 1. 设定优先级 2. 合理规划时间 3. 学会拒绝与取舍 4. 保持健康的生活方式 5. 留出…...

AI-PR曲线

PR曲线 人工智能里面的一个小概念。 2.3 性能度量&#xff08;查全率&#xff0c;查准率&#xff0c;F1&#xff0c;PR曲线与ROC曲线&#xff09; 预测出来的是一个概率&#xff0c;不能根据概率来说它是正类还是负类&#xff0c;要有一个阈值。 查准率&#xff08;Precision&…...

Guava 提供了集合操作 `List`、`Set` 和 `Map` 三个工具类

入门示例 guava 最佳实践 学习指南 以下是使用Google Guava库中的工具方法来创建和操作List、Set、Map集合的一些示例&#xff1a; List相关操作 创建List 使用Lists.newArrayList()创建一个新的可变ArrayList实例。List<Integer> list Lists.newArrayList(1, 2, 3);/…...

深入解析 Elasticsearch 集群配置文件参数

在自建 Elasticsearch 集群时&#xff0c;我们需要通过 elasticsearch.yml 文件对节点角色、网络设置、集群发现和数据存储路径等进行灵活配置。配置项的合理设置对集群的稳定性、性能与扩展性影响深远。本文将以一个示例配置文件为蓝本&#xff0c;逐条解析各参数的含义与建议…...

WebMvcConfigurer和WebMvcConfigurationSupport(MVC配置)

一:基本介绍 WebMvcConfigurer是接口&#xff0c;用于配置全局的SpringMVC的相关属性&#xff0c;采用JAVABean的方式来代替传统的XML配置文件&#xff0c;提供了跨域设置、静态资源处理器、类型转化器、自定义拦截器、页面跳转等能力。 WebMvcConfigurationSupport是webmvc的…...