当前位置: 首页 > news >正文

240713_昇思学习打卡-Day25-LSTM+CRF序列标注(4)

240713_昇思学习打卡-Day25-LSTM+CRF序列标注(4)

最后一天咯,做第四部分。

BiLSTM+CRF模型

在实现CRF后,我们设计一个双向LSTM+CRF的模型来进行命名实体识别任务的训练。模型结构如下:

nn.Embedding -> nn.LSTM -> nn.Dense -> CRF

其中LSTM提取序列特征,经过Dense层变换获得发射概率矩阵,最后送入CRF层。具体实现如下:

# 定义双向LSTM结合CRF的序列标注模型
class BiLSTM_CRF(nn.Cell):def __init__(self, vocab_size, embedding_dim, hidden_dim, num_tags, padding_idx=0):"""初始化BiLSTM_CRF模型。参数:vocab_size: 词汇表大小。embedding_dim: 词嵌入维度。hidden_dim: LSTM隐藏层维度。num_tags: 标签种类数量。padding_idx: 填充索引,默认为0。"""super().__init__()# 初始化词嵌入层self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)# 初始化双向LSTM层self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2, bidirectional=True, batch_first=True)# 初始化从LSTM输出到标签的全连接层self.hidden2tag = nn.Dense(hidden_dim, num_tags, 'he_uniform')# 初始化条件随机场层self.crf = CRF(num_tags, batch_first=True)def construct(self, inputs, seq_length, tags=None):"""模型的前向传播方法。参数:inputs: 输入序列,形状为(batch_size, seq_length)。seq_length: 序列长度,形状为(batch_size,)。tags: 真实标签,形状为(batch_size, seq_length),可选。返回:crf_outs: CRF层的输出,如果输入了真实标签则为损失值,否则为解码后的标签序列。"""# 通过词嵌入层获取词向量表示embeds = self.embedding(inputs)# 通过双向LSTM层获取序列特征outputs, _ = self.lstm(embeds, seq_length=seq_length)# 通过全连接层转换LSTM输出到标签空间feats = self.hidden2tag(outputs)# 通过CRF层计算损失或解码crf_outs = self.crf(feats, tags, seq_length)return crf_outs

完成模型设计后,我们生成两句例子和对应的标签,并构造词表和标签表。

# 设置词嵌入维度和隐藏层维度
embedding_dim = 16
hidden_dim = 32# 定义训练数据集,每条数据包含一个分词后的句子和相应的实体标签
training_data = [("清 华 大 学 坐 落 于 首 都 北 京".split(),  # 分词后的句子"B I I I O O O O O B I".split()  # 相应的实体标签),("重 庆 是 一 个 魔 幻 城 市".split(),  # 分词后的句子"B I O O O O O O O".split()  # 相应的实体标签)
]# 初始化词典,用于映射词到索引
word_to_idx = {}
# 添加特殊填充词到词典
word_to_idx['<pad>'] = 0
# 遍历训练数据,构建词到索引的映射
for sentence, tags in training_data:for word in sentence:# 如果词不在词典中,则添加到词典if word not in word_to_idx:word_to_idx[word] = len(word_to_idx)# 初始化标签到索引的映射
tag_to_idx = {"B": 0, "I": 1, "O": 2}
len(word_to_idx)

接下来实例化模型,选择优化器并将模型和优化器送入Wrapper。

由于CRF层已经进行了NLLLoss的计算,因此不需要再设置Loss。

# 实例化BiLSTM_CRF模型,传入词汇表大小、词嵌入维度、隐藏层维度以及标签种类数量
model = BiLSTM_CRF(len(word_to_idx), embedding_dim, hidden_dim, len(tag_to_idx))# 初始化随机梯度下降优化器,设置学习率为0.01,权重衰减为1e-4
optimizer = nn.SGD(model.trainable_params(), learning_rate=0.01, weight_decay=1e-4)
# 使用MindSpore的value_and_grad函数创建一个函数,它会同时计算模型的损失值和梯度
# 第二个参数设置为None表示不保留反向图,第三个参数是优化器的参数列表
grad_fn = ms.value_and_grad(model, None, optimizer.parameters)def train_step(data, seq_length, label):"""训练步骤函数,执行一次前向传播和反向传播更新模型参数。参数:data: 输入数据,形状为(batch_size, seq_length)。seq_length: 序列长度,形状为(batch_size,)。label: 真实标签,形状为(batch_size, seq_length)。返回:loss: 当前批次的损失值。"""# 使用grad_fn计算损失值和梯度loss, grads = grad_fn(data, seq_length, label)# 使用优化器更新模型参数optimizer(grads)# 返回损失值return loss

将生成的数据打包成Batch,按照序列最大长度,对长度不足的序列进行填充,分别返回输入序列、输出标签和序列长度构成的Tensor。

def prepare_sequence(seqs, word_to_idx, tag_to_idx):"""准备序列数据,包括填充和转换成张量。参数:seqs: 一个包含句子和对应标签的元组列表。word_to_idx: 词到索引的映射字典。tag_to_idx: 标签到索引的映射字典。返回:seq_outputs: 填充后的序列数据张量。label_outputs: 填充后的标签数据张量。seq_length: 序列的真实长度列表。"""seq_outputs, label_outputs, seq_length = [], [], []# 获取最长序列长度max_len = max([len(i[0]) for i in seqs])for seq, tag in seqs:# 记录序列的真实长度seq_length.append(len(seq))# 将序列中的词转换为索引idxs = [word_to_idx[w] for w in seq]# 将标签转换为索引labels = [tag_to_idx[t] for t in tag]# 对序列进行填充idxs.extend([word_to_idx['<pad>'] for i in range(max_len - len(seq))])# 对标签进行填充,用'O'的索引填充labels.extend([tag_to_idx['O'] for i in range(max_len - len(seq))])# 添加填充后的序列和标签到列表seq_outputs.append(idxs)label_outputs.append(labels)# 将列表转换为MindSpore张量return ms.Tensor(seq_outputs, ms.int64), \ms.Tensor(label_outputs, ms.int64), \ms.Tensor(seq_length, ms.int64)
# 调用prepare_sequence函数处理训练数据,并获取处理后的数据、标签和序列长度
data, label, seq_length = prepare_sequence(training_data, word_to_idx, tag_to_idx)# 打印处理后的数据、标签和序列长度的形状,以确认数据转换是否正确
print(data.shape, label.shape, seq_length.shape)

对模型进行预编译后,训练500个step。

训练流程可视化依赖tqdm库,可使用pip install tqdm命令安装。

from tqdm import tqdm# 定义训练步骤的总数,用于进度条的设置
steps = 500# 使用tqdm创建一个进度条,总进度为steps
with tqdm(total=steps) as t:for i in range(steps):# 执行单步训练,这里假设train_step是一个已定义的训练函数# 参数data为训练数据,seq_length为序列长度,label为标签loss = train_step(data, seq_length, label)# 更新进度条的附带信息,显示当前的损失值t.set_postfix(loss=loss)# 更新进度条,表示完成了一步训练t.update(1)

最后我们来观察训练500个step后的模型效果,首先使用模型预测可能的路径得分以及候选序列。

# 调用模型进行预测或评估,返回得分和历史记录
score, history = model(data, seq_length)# 输出得分,用于查看模型的表现或决策
score

使用后处理函数进行预测得分的后处理。

predict = post_decode(score, history, seq_length)
predict

最后将预测的index序列转换为标签序列,打印输出结果,查看效果。

# 通过索引和标签的映射关系,构建标签到索引的反向映射
idx_to_tag = {idx: tag for tag, idx in tag_to_idx.items()}def sequence_to_tag(sequences, idx_to_tag):"""将序列中的索引转换为对应的标签。参数:sequences: 一个包含标签索引的序列列表。idx_to_tag: 一个字典,用于将索引映射到对应的标签。返回:一个列表,其中每个元素是输入序列中索引转换为标签后的结果。"""# 初始化一个空列表,用于存储转换后的标签序列outputs = []# 遍历输入的序列列表for seq in sequences:# 对每个序列,将索引转换为标签,并添加到输出列表中outputs.append([idx_to_tag[i] for i in seq])# 返回转换后的标签序列列表return outputs
sequence_to_tag(predict, idx_to_tag)

打卡照片:
在这里插入图片描述

相关文章:

240713_昇思学习打卡-Day25-LSTM+CRF序列标注(4)

240713_昇思学习打卡-Day25-LSTMCRF序列标注&#xff08;4&#xff09; 最后一天咯&#xff0c;做第四部分。 BiLSTMCRF模型 在实现CRF后&#xff0c;我们设计一个双向LSTMCRF的模型来进行命名实体识别任务的训练。模型结构如下&#xff1a; nn.Embedding -> nn.LSTM -&…...

python requests关闭https校验

python requests关闭https校验 import requests# 关闭SSL验证 requests.get(https://***.com, verifyFalse)...

PG大会周五于杭州举办;Pika发布4.0;阿里云MySQL上线Zero-ETL集成能力

重要更新 1. PostgreSQL中国技术大会举行12日&#xff08;周五&#xff09;于杭州举办&#xff0c;是PostgreSQL社区年度的大会&#xff0c;举办地点&#xff1a;杭州君尚云郦酒店&#xff08;杭州市上城区临丁路1188号&#xff09;&#xff0c;感兴趣的可以考虑现场参加 ( [1]…...

虚拟机vmware网络设置

一、网络分类 打开vmware workstation网络编辑器可以知道有三种网络类型&#xff0c;分别是&#xff1a;桥接模式、nat模式、仅主机模式。 1、桥接模式 桥接模式是将主机网卡与虚拟机虚拟的网卡利用虚拟网桥进行通信。在桥接的作用下, 类似于把物理主机虚拟为一个交换机, 所有设…...

数学建模国赛入门指南

文章目录 认识数学建模及国赛认识数学建模什么是数学建模&#xff1f;数学建模比赛 国赛参赛规则、评奖原则如何评省、国奖评奖规则如何才能获奖 国赛赛题分类及选题技巧国赛赛题特点赛题分类 国赛历年题型及优秀论文 数学建模分工技巧数模必备软件数模资料文献数据收集资料收集…...

Java基础之集合

集合和数组的类比 数组: 长度固定可以存基本数据类型和引用数据类型 集合: 长度可变只能存引用数据类型存储基本数据类型要把他转化为对应的包装类 ArrayList集合 ArrayList成员方法 添加元素 删除元素 索引删除 查询 遍历数组...

深度学习和NLP中的注意力和记忆

深度学习和NLP中的注意力和记忆 文章目录 一、说明二、注意力解决了什么问题&#xff1f;#三、关注的代价#四、机器翻译之外的关注#五、注意力&#xff08;模糊&#xff09;记忆&#xff1f;# 一、说明 深度学习的最新趋势是注意力机制。在一次采访中&#xff0c;现任 OpenAI 研…...

自用的C++20协程学习资料

C20的一个重要更新就是加入了协程。 在网上找了很多学习资料&#xff0c;看了之后还是不明白。 最后找到下面这些资料总算是讲得比较明白&#xff0c;大家可以按照顺序阅读&#xff1a; 渡劫 C 协程&#xff08;1&#xff09;&#xff1a;C 协程概览C20协程原理和应用...

【C++】优先级队列(底层代码解释)

一. 定义 优先级队列是一个容器适配器&#xff0c;他可以根据不同的需求采用不同的容器来实现这个数据结构&#xff0c;优先级队列采用了堆的数据结构&#xff0c;默认使用vector作为容器&#xff0c;且采用大堆的结构进行存储数据。 &#xff08;1&#xff09;在第一个构造函数…...

华为模拟器防火墙配置实验(二)

一.实验拓扑 二.实验要求 1&#xff0c;DMZ区内的服务器&#xff0c;办公区仅能在办公时间内&#xff08;9&#xff1a;00 - 18&#xff1a;00&#xff09;可以访问&#xff0c;生产区的设备全天可以访问. 2&#xff0c;生产区不允许访问互联网&#xff0c;办公区和游客区允许…...

group 与查询字段

需求 每周周一&#xff0c;统计菜单在过去一周&#xff0c;点击次数&#xff0c;和点击人数&#xff08;同一个人访问多次按一次计算&#xff09; 表及数据 日志表 CREATE TABLE t_data_log ( id varchar(50) NOT NULL COMMENT 主键id, operation_object varchar(500) DE…...

PlantUML 教程:绘制时序图

绘制时序图是 PlantUML 的一个强大功能&#xff0c;下面是详细的 PlantUML 时序图教程&#xff0c;帮助你理解如何使用它来创建清晰的时序图。 基本概念 时序图&#xff08;Sequence Diagram&#xff09;用于展示对象之间的交互以及它们之间的消息传递顺序。它主要由以下元素…...

自定义ViewGroup-流式布局FlowLayout(重点:测量和布局)

效果 child按行显示&#xff0c;显示不下就换行。 分析 继承ViewGrouponDraw()不重写&#xff0c;使用ViewGroup的测量-重点 &#xff08;测量child、测量自己&#xff09;布局-重点 &#xff08;布局child&#xff09; 知识点 执行顺序 构造函数 -> onMeasure() -> …...

C++的入门基础(二)

目录 引用的概念和定义引用的特性引用的使用const引用指针和引用的关系引用的实际作用inlinenullptr 引用的概念和定义 在语法上引用是给一个变量取别名&#xff0c;和这个变量共用同一块空间&#xff0c;并不会给引用开一块空间。 取别名就是一块空间有多个名字 类型& …...

显示产业如何突破芯片短板

尽管中国在显示IC领域面临一定的不足&#xff0c;但新技术的不断涌现为中国企业提供了重要的发展机遇。随着手机、平板电脑和液晶电视对显示屏性能要求的不断提高&#xff0c;显示驱动IC也必须相应地发展&#xff0c;向更高分辨率、更大尺寸和更低功耗的方向迈进。例如&#xf…...

STM32HAL库+ESP8266+cJSON+微信小程序_连接华为云物联网平台

STM32HAL库ESP8266cJSON微信小程序_连接华为云物联网平台 实验使用资源&#xff1a;正点原子F407 USART1&#xff1a;PA9P、A10&#xff08;串口打印调试&#xff09; USART3&#xff1a;PB10、PB11&#xff08;WiFi模块&#xff09; DHT11&#xff1a;PG9&#xff08;采集数据…...

debian或Ubuntu中开启ssh允许root远程ssh登录的方法

debian或Ubuntu中开启ssh允许root远程ssh登录的方法 前因&#xff1a; 因开发需要&#xff0c;需要设置开发板的ssh远程连接。 操作步骤如下&#xff1a; 安装openssh-server sudo apt install openssh-server设置root用户密码&#xff1a; sudo passwd root允许root用户…...

C++《日期》实现

C《日期》实现 头文件实现文件 头文件 在该文件中是为了声明函数和定义类成员 using namespace std; class Date {friend ostream& operator<<(ostream& out, const Date& d);//友元friend istream& operator>>(istream& cin, Date& d);//…...

【面试题】MySQL(第三篇)

目录 1. MySQL中如何处理死锁&#xff1f; 2. MySQL中的主从复制是如何实现的&#xff1f; 3. MySQL中的慢查询日志是什么&#xff1f;如何使用它来优化性能&#xff1f; 4.存储过程 一、定义与基本概念 二、特点与优势 三、类型与分类 四、创建与执行 五、示例 六、总…...

tensorflow之欠拟合与过拟合,正则化缓解

过拟合泛化性弱 欠拟合解决方法&#xff1a; 增加输入特征项 增加网络参数 减少正则化参数 过拟合的解决方法&#xff1a; 数据清洗 增大训练集 采用正则化 增大正则化参数 正则化缓解过拟合 正则化在损失函数中引入模型复杂度指标&#xff0c;利用给w增加权重&#xff0c;…...

Oracle查询表空间大小

1 查询数据库中所有的表空间以及表空间所占空间的大小 SELECTtablespace_name,sum( bytes ) / 1024 / 1024 FROMdba_data_files GROUP BYtablespace_name; 2 Oracle查询表空间大小及每个表所占空间的大小 SELECTtablespace_name,file_id,file_name,round( bytes / ( 1024 …...

MODBUS TCP转CANopen 技术赋能高效协同作业

在现代工业自动化领域&#xff0c;MODBUS TCP和CANopen两种通讯协议因其稳定性和高效性被广泛应用于各种设备和系统中。而随着科技的不断进步&#xff0c;这两种通讯协议也正在被逐步融合&#xff0c;形成了一种新型的通讯方式——开疆智能MODBUS TCP转CANopen网关KJ-TCPC-CANP…...

零基础设计模式——行为型模式 - 责任链模式

第四部分&#xff1a;行为型模式 - 责任链模式 (Chain of Responsibility Pattern) 欢迎来到行为型模式的学习&#xff01;行为型模式关注对象之间的职责分配、算法封装和对象间的交互。我们将学习的第一个行为型模式是责任链模式。 核心思想&#xff1a;使多个对象都有机会处…...

【学习笔记】深入理解Java虚拟机学习笔记——第4章 虚拟机性能监控,故障处理工具

第2章 虚拟机性能监控&#xff0c;故障处理工具 4.1 概述 略 4.2 基础故障处理工具 4.2.1 jps:虚拟机进程状况工具 命令&#xff1a;jps [options] [hostid] 功能&#xff1a;本地虚拟机进程显示进程ID&#xff08;与ps相同&#xff09;&#xff0c;可同时显示主类&#x…...

【开发技术】.Net使用FFmpeg视频特定帧上绘制内容

目录 一、目的 二、解决方案 2.1 什么是FFmpeg 2.2 FFmpeg主要功能 2.3 使用Xabe.FFmpeg调用FFmpeg功能 2.4 使用 FFmpeg 的 drawbox 滤镜来绘制 ROI 三、总结 一、目的 当前市场上有很多目标检测智能识别的相关算法&#xff0c;当前调用一个医疗行业的AI识别算法后返回…...

【Java学习笔记】BigInteger 和 BigDecimal 类

BigInteger 和 BigDecimal 类 二者共有的常见方法 方法功能add加subtract减multiply乘divide除 注意点&#xff1a;传参类型必须是类对象 一、BigInteger 1. 作用&#xff1a;适合保存比较大的整型数 2. 使用说明 创建BigInteger对象 传入字符串 3. 代码示例 import j…...

LangChain知识库管理后端接口:数据库操作详解—— 构建本地知识库系统的基础《二》

这段 Python 代码是一个完整的 知识库数据库操作模块&#xff0c;用于对本地知识库系统中的知识库进行增删改查&#xff08;CRUD&#xff09;操作。它基于 SQLAlchemy ORM 框架 和一个自定义的装饰器 with_session 实现数据库会话管理。 &#x1f4d8; 一、整体功能概述 该模块…...

为什么要创建 Vue 实例

核心原因:Vue 需要一个「控制中心」来驱动整个应用 你可以把 Vue 实例想象成你应用的**「大脑」或「引擎」。它负责协调模板、数据、逻辑和行为,将它们变成一个活的、可交互的应用**。没有这个实例,你的代码只是一堆静态的 HTML、JavaScript 变量和函数,无法「活」起来。 …...

什么是VR全景技术

VR全景技术&#xff0c;全称为虚拟现实全景技术&#xff0c;是通过计算机图像模拟生成三维空间中的虚拟世界&#xff0c;使用户能够在该虚拟世界中进行全方位、无死角的观察和交互的技术。VR全景技术模拟人在真实空间中的视觉体验&#xff0c;结合图文、3D、音视频等多媒体元素…...

华为OD最新机试真题-数组组成的最小数字-OD统一考试(B卷)

题目描述 给定一个整型数组,请从该数组中选择3个元素 组成最小数字并输出 (如果数组长度小于3,则选择数组中所有元素来组成最小数字)。 输入描述 行用半角逗号分割的字符串记录的整型数组,0<数组长度<= 100,0<整数的取值范围<= 10000。 输出描述 由3个元素组成…...