bilstm -crf 的torch官方的教程 不支持 batchsize 。。。。有更效率更高的方案吗
根据您的需求,您正在寻找一个支持批量处理(batch size)的BiLSTM-CRF模型的PyTorch教程。以下是一些资源,它们提供了更高效的解决方案,包括批量处理的支持:
1. **a-PyTorch-Tutorial-to-Sequence-Labeling**:这个GitHub仓库提供了一个关于序列标注的PyTorch教程,其中包括了BiLSTM-CRF模型的实现。教程中详细解释了如何准备数据、构建模型以及训练模型,支持批量处理。[GitHub - sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling)
GitHub - sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling: Empower Sequence Labeling with Task-Aware Neural Language Model | a PyTorch Tutorial to Sequence Labeling
2. **cn-PyTorch-Sequence-Labeling-Tutorial-with-LM-LSTM-CRF**:这个资源是上述教程的中文翻译版本,同样提供了关于如何使用BiLSTM-CRF进行序列标注的详细指导,包括批量处理的支持。[GitHub - ofooo/cn-PyTorch-Sequence-Labeling-Tutorial-with-LM-LSTM-CRF](https://github.com/ofooo/cn-PyTorch-Sequence-Labeling-Tutorial-with-LM-LSTM-CRF)
https://github.com/ofooo/cn-PyTorch-Sequence-Labeling-Tutorial-with-LM-LSTM-CRF
3. **NLP with PyTorch: A Comprehensive Guide**:这个DataCamp教程提供了一个全面的指南,介绍了如何使用PyTorch进行自然语言处理,包括数据准备、模型定义、训练和预测。教程中提到了如何使用`DataLoader`来创建数据集,这对于批量处理是非常有用的。[DataCamp - NLP with PyTorch: A Comprehensive Guide](https://www.datacamp.com/tutorial/nlp-with-pytorch-a-comprehensive-guide)
https://www.datacamp.com/tutorial/nlp-with-pytorch-a-comprehensive-guide
改成批处理关键代码 previous_score = score[t - 1].view(batch_size, -1, 1)
def viterbi_decode(self, h: FloatTensor, mask: BoolTensor) -> List[List[int]]:"""decode labels using viterbi algorithm:param h: hidden matrix (batch_size, seq_len, num_labels):param mask: mask tensor of each sequencein mini batch (batch_size, batch_size):return: labels of each sequence in mini batch"""batch_size, seq_len, _ = h.size()# prepare the sequence lengths in each sequenceseq_lens = mask.sum(dim=1)# In mini batch, prepare the score# from the start sequence to the first labelscore = [self.start_trans.data + h[:, 0]]path = []for t in range(1, seq_len):# extract the score of previous sequence# (batch_size, num_labels, 1)previous_score = score[t - 1].view(batch_size, -1, 1)# extract the score of hidden matrix of sequence# (batch_size, 1, num_labels)h_t = h[:, t].view(batch_size, 1, -1)# extract the score in transition# from label of t-1 sequence to label of sequence of t# self.trans_matrix has the score of the transition# from sequence A to sequence B# (batch_size, num_labels, num_labels)score_t = previous_score + self.trans_matrix + h_t# keep the maximum value# and point where maximum value of each sequence# (batch_size, num_labels)best_score, best_path = score_t.max(1)score.append(best_score)path.append(best_path)
torchcrf 使用 支持批处理,torchcrf的简单使用-CSDN博客文章浏览阅读9.7k次,点赞5次,收藏33次。本文介绍了如何在PyTorch中安装和使用TorchCRF库,重点讲解了CRF模型参数设置、自定义掩码及损失函数的计算。作者探讨了如何将CRF的NLL损失与交叉熵结合,并通过自适应权重优化训练过程。虽然在单任务中效果不显著,但对于多任务学习提供了有价值的方法。
https://blog.csdn.net/csdndogo/article/details/125541213
torchcrf的简单使用-CSDN博客
为了防止文章丢失 ,吧内容转发在这里
https://blog.csdn.net/csdndogo/article/details/125541213
. 安装torchcrf,模型使用
安装:pip install TorchCRF
CRF的使用:在官网里有简单的使用说明
注意输入的格式。在其他地方下载的torchcrf有多个版本,有些版本有batch_first参数,有些没有,要看清楚有没有这个参数,默认batch_size是第一维度。
这个代码是我用来熟悉使用crf模型和损失函数用的,模拟多分类任务输入为随机数据和随机标签,所以最后的结果预测不能很好的跟标签对应。
import torch
import torch.nn as nn
import numpy as np
import random
from TorchCRF import CRF
from torch.optim import Adam
seed = 100
def seed_everything(seed=seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
num_tags = 5
model = CRF(num_tags, batch_first=True) # 这里根据情况而定
seq_len = 3
batch_size = 50
seed_everything()
trainset = torch.randn(batch_size, seq_len, num_tags) # features
traintags = (torch.rand([batch_size, seq_len])*4).floor().long() # (batch_size, seq_len)
testset = torch.randn(5, seq_len, num_tags) # features
testtags = (torch.rand([5, seq_len])*4).floor().long() # (batch_size, seq_len)
# 训练阶段
for e in range(50):
optimizer = Adam(model.parameters(), lr=0.05)
model.train()
optimizer.zero_grad()
loss = -model(trainset, traintags)
print('epoch{}: loss score is {}'.format(e, loss))
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(),5)
optimizer.step()
#测试阶段
model.eval()
loss = model(testset, testtags)
model.decode(testset)
1.1模型参数,自定义掩码mask注意事项
def forward(self, emissions, labels: LongTensor, mask: BoolTensor)
1
分别为发射矩阵(各标签的预测值),标签,掩码(注意这里的mask类型为BoolTensor)
注意:此处自定义mask掩码时,使用LongTensor类型的[1,1,1,1,0,0]会报错,需要转换成ByteTensor,下面是一个简单的获取mask的函数,输入为标签数据:
def get_crfmask(self, labels):
crfmask = []
for batch in labels:
res = [0 if d == -1 else 1 for d in batch]
crfmask.append(res)
return torch.ByteTensor(crfmask)
运行运行
2. CRF的损失函数是什么?
损失函数由真实转移路径值和所有可能情况路径转移值两部分组成,损失函数的公式为
分子为真实转移路径值,分母为所有路径总分数,上图公式在crf原始代码中为:
def forward(
self, h: FloatTensor, labels: LongTensor, mask: BoolTensor) -> FloatTensor:
log_numerator = self._compute_numerator_log_likelihood(h, labels, mask)
log_denominator = self._compute_denominator_log_likelihood(h, mask)
return log_numerator - log_denominator
CRF损失函数值为负对数似然函数(NLL),所以如果原来的模型损失函数使用的是交叉熵损失函数,两个损失函数相加时要对CRF返回的损失取负。
loss = -model(trainset, traintags)
1
3. 如何联合CRF的损失函数和自己的网络模型的交叉熵损失函数进行训练?
我想在自己的模型上添加CRF,就需要联合原本的交叉熵损失函数和CRF的损失函数,因为CRF输出的时NLL,所以在模型在我仅对该损失函数取负之后和原先函数相加。
loss2 = -crf_layer(log_prob, label, mask=crfmask)
loss1 = loss_function(log_prob.permute(0, 2, 1), label)
loss = loss1 + loss2
loss.backward()
缺陷: 效果不佳,可以尝试对loss2添加权重。此处贴一段包含两个损失函数的自适应权重训练的函数。
3.1.自适应损失函数权重
由于CRF返回的损失与原来的损失数值不在一个量级,所以产生了自适应权重调整两个权重的大小来达到优化的目的。自适应权重原本属于多任务学习部分,未深入了解,代码源自某篇复现论文的博客。
class AutomaticWeightedLoss(nn.Module):
def __init__(self, num=2):
super(AutomaticWeightedLoss, self).__init__()
params = torch.ones(num, requires_grad=True)
self.params = torch.nn.Parameter(params)
def forward(self, *x):
loss_sum = 0
for i, loss in enumerate(x):
loss_sum += 0.5 / (self.params[i] ** 2) * loss + torch.log(1 + self.params[i] ** 2)
return loss_sum
相关文章:
bilstm -crf 的torch官方的教程 不支持 batchsize 。。。。有更效率更高的方案吗
根据您的需求,您正在寻找一个支持批量处理(batch size)的BiLSTM-CRF模型的PyTorch教程。以下是一些资源,它们提供了更高效的解决方案,包括批量处理的支持: 1. **a-PyTorch-Tutorial-to-Sequence-Labeling*…...
Python面试常见问题及答案6
一、基础部分 问题1: 在Python中,如何将字符串转换为整数?如果字符串不是合法的数字字符串会怎样? 答案: 在Python中,可以使用int()函数将字符串转换为整数。如果字符串是合法的数字字符串,转换…...
代码随想录算法训练营第三天 | 链表理论基础 | 203.移除链表元素
感觉上是可以轻松完成的,因为对链接的结构,元素的删除过程心里明镜似的 实际上四处跑气 结构体的初始化好像完全忘掉了,用malloc折腾半天,忘记了用new,真想扇自己嘴巴子到飞起删除后写一个函数,把链表打印…...
1. 机器学习基本知识(5)——练习题(1)
1.7 🐦🔥练习题(本章重点回顾与总结) 0.回答格式约定: 对于书本内容的回答,将优先寻找书本内容作为答案进行回答。 书本内容回答完毕后,将对问题进行补充回答,上面分割线作为两个…...

vue 自定义组件image 和 input
本章主要是介绍自定义的组件:WInput:这是一个验证码输入框,自动校验,输入完成回调等;WImage:这是一个图片展示组件,集成了缩放,移动等操作。 目录 一、安装 二、引入组件 三、使用…...
系列3:基于Centos-8.6 Kubernetes使用nfs挂载pod的应用日志文件
每日禅语 古代,一位官员被革职遣返,心中苦闷无处排解,便来到一位禅师的法堂。禅师静静地听完了此人的倾诉,将他带入自己的禅房之中。禅师指着桌上的一瓶水,微笑着对官员说:“你看这瓶水,它已经…...
Jfinal项目整合Redis
1、引入相关依赖 <!-- https://mvnrepository.com/artifact/redis.clients/jedis --> <dependency><groupId>redis.clients</groupId><artifactId>jedis</artifactId><version>2.9.0</version> </dependency><depen…...

在Ubuntu服务器上备份文件到自己的百度网盘
文章目录 概述安装bypy同步文件定时任务脚本 概述 之前自购了一台阿里云服务器,系统镜像为Ubuntu 22.04, 并且搭建了LNMP开发环境(可以参考:《Ubuntu搭建PHP开发环境操作步骤(保姆级教程)》)。由于项目运行中会产生附…...

Unity 模板测试透视效果(URP)
可以实现笼中窥梦和PicoVR中通过VST局部透视效果。 使用到的Shader: Shader "Unlit/StencilShader" {Properties{[IntRange]_Index("Stencil Index",Range(0,255))0}SubShader{Tags{"RenderType""Opaque""Queue""Geo…...

《计算机视觉证书:开启职业发展新航道》
一、引言 在当今科技飞速发展的时代,计算机视觉技术正以惊人的速度改变着我们的生活和工作方式。从智能手机的人脸识别解锁到自动驾驶汽车的环境感知,计算机视觉技术的应用无处不在。而计算机视觉证书作为这一领域的专业认证,其作用愈发凸显…...

.NET6 WebApi第1讲:VSCode开发.NET项目、区别.NET5框架【两个框架启动流程详解】
一、使用VSCode开发.NET项目 1、创建文件夹,使用VSCode打开 2、安装扩展工具 1>C# 2>安装NuGet包管理工具,外部dll包依靠它来加载 法1》:NuGet Gallery,注意要启动科学的工具 法2》NuGet Package Manager GUl,…...

Git-分布式版本控制工具
目录 1. 概述 1. 1集中式版本控制工具 1.2分布式版本控制工具 2.Git 2.1 git 工作流程 1. 概述 在开发活动中,我们经常会遇到以下几个场景:备份、代码回滚、协同开发、追溯问题代码编写人和编写时间(追责)等。备份的话是为了…...
C++ 第10章 对文件的输入输出
https://www.bilibili.com/video/BV1cx4y1d7Ut/?p147&spm_id_from333.1007.top_right_bar_window_history.content.click&vd_sourcee8984989cddeb3ef7b7e9fd89098dbe8 🍁🍁🍁本篇为贺宏宏老师C语言视频教程文件输入输出部分笔记整理…...

【机器学习】手写数字识别的最优解:CNN+Softmax、Sigmoid与SVM的对比实战
一、基于CNNSoftmax函数进行分类 1数据集准备 2模型设计 3模型训练 4模型评估 5结果分析 二、 基于CNNsigmoid函数进行分类 1数据集准备 2模型设计 3模型训练 4模型评估 5结果分析 三、 基于CNNSVM进行分类 1数据集准备 2模型设计 3模型训练 4模型评估 5结果分…...
android 聊天界面键盘、表情切换丝滑
1、我们在聊天页面时候,往往会遇到,键盘、表情、其他选择切换时候页面会出现掉下来再弹起问题,这是因为,我们切换时候,键盘异步导致内容View高度变化,页面掉下来后,又被其他内容顶起这种很差视觉…...

Web项目图片视频加载缓慢/首屏加载白屏
Web项目图片视频加载缓慢/首屏加载白屏 文章目录 Web项目图片视频加载缓慢/首屏加载白屏一、原因二、 解决方案2.1、 图片和视频的优化2.1.1、压缩图片或视频2.1.2、 选择合适的图片或视频格式2.1.3、 使用图片或视频 CDN 加速2.1.4、Nginx中开启gzip 三、压缩工具推荐 一、原因…...

关于Git分支合并,跨仓库合并方式
关于Git合并代码的方式说明 文章目录 关于Git合并代码的方式说明前情提要开始合并方式一:git merge方式二:git cherry-pick方式三:git checkout Git跨仓库合并的准备事项前提拉取源仓库代码 前情提要 同仓库不同分支代码的合并可直接往下看文…...

[网络] UDP协议16位校验和
16位校验和是udp报头中的一个字段,绝大多数的教材和网课都会忽略这个字段,不去细究,我闲的蛋疼问了问ai,得到了一个答案,故作此文,以证明我爱学习之心惊天地泣鬼神(狗头 ai的回答 仅从作用来说,它会根据整个应用层报文进行运算,生成一个准确的数字,这个数字不能保证唯一性,但根…...
Vue 3 中的 `update:modelValue` 事件详解
在 Vue 3 中,update:modelValue 事件通常与 v-model 指令一起使用,以实现自定义组件的双向数据绑定。以下是对该事件的详细分析: 事件定义 首先,我们需要在组件中定义 update:modelValue 事件。可以使用 defineEmits 函…...

vue3+vite+ts 使用webrtc-streamer播放海康rtsp监控视频
了解webrtc-streamer webrtc-streamer 是一个使用简单机制通过 WebRTC 流式传输视频捕获设备和 RTSP 源的项目,它内置了一个小型的 HTTP server 来对 WebRTC需要的相关接口提供支持。相对于ffmpegflv.js的方案,延迟降低到了0.4秒左右,画面的…...

【HarmonyOS 5.0】DevEco Testing:鸿蒙应用质量保障的终极武器
——全方位测试解决方案与代码实战 一、工具定位与核心能力 DevEco Testing是HarmonyOS官方推出的一体化测试平台,覆盖应用全生命周期测试需求,主要提供五大核心能力: 测试类型检测目标关键指标功能体验基…...

CentOS下的分布式内存计算Spark环境部署
一、Spark 核心架构与应用场景 1.1 分布式计算引擎的核心优势 Spark 是基于内存的分布式计算框架,相比 MapReduce 具有以下核心优势: 内存计算:数据可常驻内存,迭代计算性能提升 10-100 倍(文档段落:3-79…...
第25节 Node.js 断言测试
Node.js的assert模块主要用于编写程序的单元测试时使用,通过断言可以提早发现和排查出错误。 稳定性: 5 - 锁定 这个模块可用于应用的单元测试,通过 require(assert) 可以使用这个模块。 assert.fail(actual, expected, message, operator) 使用参数…...

Java-41 深入浅出 Spring - 声明式事务的支持 事务配置 XML模式 XML+注解模式
点一下关注吧!!!非常感谢!!持续更新!!! 🚀 AI篇持续更新中!(长期更新) 目前2025年06月05日更新到: AI炼丹日志-28 - Aud…...

如何在最短时间内提升打ctf(web)的水平?
刚刚刷完2遍 bugku 的 web 题,前来答题。 每个人对刷题理解是不同,有的人是看了writeup就等于刷了,有的人是收藏了writeup就等于刷了,有的人是跟着writeup做了一遍就等于刷了,还有的人是独立思考做了一遍就等于刷了。…...

AI,如何重构理解、匹配与决策?
AI 时代,我们如何理解消费? 作者|王彬 封面|Unplash 人们通过信息理解世界。 曾几何时,PC 与移动互联网重塑了人们的购物路径:信息变得唾手可得,商品决策变得高度依赖内容。 但 AI 时代的来…...
Python+ZeroMQ实战:智能车辆状态监控与模拟模式自动切换
目录 关键点 技术实现1 技术实现2 摘要: 本文将介绍如何利用Python和ZeroMQ消息队列构建一个智能车辆状态监控系统。系统能够根据时间策略自动切换驾驶模式(自动驾驶、人工驾驶、远程驾驶、主动安全),并通过实时消息推送更新车…...

Golang——6、指针和结构体
指针和结构体 1、指针1.1、指针地址和指针类型1.2、指针取值1.3、new和make 2、结构体2.1、type关键字的使用2.2、结构体的定义和初始化2.3、结构体方法和接收者2.4、给任意类型添加方法2.5、结构体的匿名字段2.6、嵌套结构体2.7、嵌套匿名结构体2.8、结构体的继承 3、结构体与…...

【MATLAB代码】基于最大相关熵准则(MCC)的三维鲁棒卡尔曼滤波算法(MCC-KF),附源代码|订阅专栏后可直接查看
文章所述的代码实现了基于最大相关熵准则(MCC)的三维鲁棒卡尔曼滤波算法(MCC-KF),针对传感器观测数据中存在的脉冲型异常噪声问题,通过非线性加权机制提升滤波器的抗干扰能力。代码通过对比传统KF与MCC-KF在含异常值场景下的表现,验证了后者在状态估计鲁棒性方面的显著优…...
省略号和可变参数模板
本文主要介绍如何展开可变参数的参数包 1.C语言的va_list展开可变参数 #include <iostream> #include <cstdarg>void printNumbers(int count, ...) {// 声明va_list类型的变量va_list args;// 使用va_start将可变参数写入变量argsva_start(args, count);for (in…...