【HuggingFace文档学习】Bert的token分类与句分类
BERT特性:
- BERT的嵌入是位置绝对(position absolute)的。
- BERT擅长于预测掩码token和NLU,但是不擅长下一文本生成。
1.BertForTokenClassification
一个用于token级分类的模型,可用于命名实体识别(NER)、部分语音标记(POS)等。对于给定的输入序列,模型将为每个标记/词产生一个标签。
输出的维度是 [batch_size, sequence_length, num_labels],其中 num_labels 是可能的标签数量。
class transformers.BertForTokenClassification(config)
继承父类:BertPreTrainedModel、torch.nn.Module
参数:config (BertConfig)——包含模型所有参数的模型配置类。
包含一个token分类的任务头(线性层,可用于NER)。
forward方法:
参数
- input_ids (
torch.LongTensorof shape(batch_size, sequence_length)) — 输入序列对应的分词索引列表(indices list)。索引根据AutoTokenizer得到。 - attention_mask (
torch.FloatTensorof shape(batch_size, sequence_length), optional) — 对输入序列的部分token加上掩码,使得注意力机制不会计算到。如填充token的索引(padding token indices)。取值为[0, 1]二者之一。取0则表明掩码,取1则表明不掩码。 - token_type_ids (
torch.LongTensorof shape(batch_size, sequence_length), optional) — 在分句任务中,表明token是属于第一句还是第二句。取值为[0, 1]二者之一。 - position_ids (
torch.LongTensorof shape(batch_size, sequence_length), optional) — 输入序列对应的位置索引列表(positional indices list)。 取值范围为[0, config.max_position_embeddings - 1],从而加入位置信息。 - head_mask (
torch.FloatTensorof shape(num_heads,)or(num_layers, num_heads), optional) — 掩码(多头)自注意力模块的头。取值为[0, 1]二者之一:取0则表示对应的头要掩码,取1则表示对应的头不掩码。 - inputs_embeds (
torch.FloatTensorof shape(batch_size, sequence_length, hidden_size), optional) — 如果想要直接将嵌入向量传入给模型,由自己控制input_ids的关联向量,那么就传这个参数。这样就不需要由本模型内部的嵌入层矩阵运算input_ids。 - output_attentions (
bool, optional) — 是否希望模型返回所有的注意力分数。 - output_hidden_states (
bool, optional) — 是否希望模型返回所有层的隐藏状态。 - return_dict (
bool, optional) — 是否希望输出的是ModelOutput,而不是直接的元组tuple。 - labels (
torch.LongTensorof shape(batch_size, sequence_length), optional) — 提供标签,用于计算loss。取值范围为[0, config.max_position_embeddings - 1]。
返回值
transformers.modeling_outputs.TokenClassifierOutput 或 tuple(torch.FloatTensor)
- 如果
return_dict为False(或return_dict为空但配置文件中self.config.use_return_dict为False):- 如果提供了
labels参数,输出是一个元组,包含:loss: 计算的损失值。logits: 分类头的输出,形状为(batch_size, sequence_length, num_labels)。- 其他 BERT 的输出(例如隐藏状态和注意力权重),但这取决于 BERT 的配置和输入参数。
- 如果没有提供
labels参数,输出只包含logits和其他 BERT 的输出。
- 如果提供了
- 如果
return_dict为True(或return_dict为空但配置文件中self.config.use_return_dict为False):- 输出是一个
TokenClassifierOutput对象,包含以下属性:loss: 如果提供了labels参数,这是计算的损失值。logits: 分类头的输出,形状为(batch_size, sequence_length, num_labels)。hidden_states: BERT 的隐藏状态输出。attentions: BERT 的注意力权重输出。
- 输出是一个
代码实现
@add_start_docstrings("""Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. forNamed-Entity-Recognition (NER) tasks.""",BERT_START_DOCSTRING,
)
class BertForTokenClassification(BertPreTrainedModel):def __init__(self, config):super().__init__(config)self.num_labels = config.num_labels # 标签的数量self.bert = BertModel(config, add_pooling_layer=False) # 预训练BERTclassifier_dropout = (config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob)self.dropout = nn.Dropout(classifier_dropout)self.classifier = nn.Linear(config.hidden_size, config.num_labels) # classification任务头,加在预训练BERT之上# Initialize weights and apply final processingself.post_init()@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))@add_code_sample_docstrings(checkpoint=_CHECKPOINT_FOR_TOKEN_CLASSIFICATION,output_type=TokenClassifierOutput,config_class=_CONFIG_FOR_DOC,expected_output=_TOKEN_CLASS_EXPECTED_OUTPUT,expected_loss=_TOKEN_CLASS_EXPECTED_LOSS,)def forward(self,input_ids: Optional[torch.Tensor] = None,attention_mask: Optional[torch.Tensor] = None,token_type_ids: Optional[torch.Tensor] = None,position_ids: Optional[torch.Tensor] = None,head_mask: Optional[torch.Tensor] = None,inputs_embeds: Optional[torch.Tensor] = None,labels: Optional[torch.Tensor] = None,output_attentions: Optional[bool] = None,output_hidden_states: Optional[bool] = None,return_dict: Optional[bool] = None,) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:r"""labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`."""return_dict = return_dict if return_dict is not None else self.config.use_return_dictoutputs = self.bert(input_ids,attention_mask=attention_mask,token_type_ids=token_type_ids,position_ids=position_ids,head_mask=head_mask,inputs_embeds=inputs_embeds,output_attentions=output_attentions,output_hidden_states=output_hidden_states,return_dict=return_dict,) # 预训练BERT的计算,得到输入序列经BERT计算的向量序列sequence_output = outputs[0]sequence_output = self.dropout(sequence_output)logits = self.classifier(sequence_output) # 再经过最后的任务头classificationloss = Noneif labels is not None:loss_fct = CrossEntropyLoss()loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))if not return_dict:output = (logits,) + outputs[2:]return ((loss,) + output) if loss is not None else outputreturn TokenClassifierOutput(loss=loss,logits=logits,hidden_states=outputs.hidden_states,attentions=outputs.attentions,)
使用示例:
from transformers import AutoTokenizer, BertForTokenClassification
import torchtokenizer = AutoTokenizer.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")
model = BertForTokenClassification.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")inputs = tokenizer("HuggingFace is a company based in Paris and New York", add_special_tokens=False, return_tensors="pt"
)with torch.no_grad():logits = model(**inputs).logits # 想要得到分类后的权重,获取的是输出的logits对象。predicted_token_class_ids = logits.argmax(-1)# Note that tokens are classified rather then input words which means that
# there might be more predicted token classes than words.
# Multiple token classes might account for the same word
predicted_tokens_classes = [model.config.id2label[t.item()] for t in predicted_token_class_ids[0]]
# predicted_tokens_classes = ['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC']
2.BertForSequenceClassification
一个用于整个句子或段落级别的分类的模型,可用于情感分析、文本分类等。对于给定的输入,模型将为整个序列产生一个分类标签。
输出的维度是 [batch_size, num_labels],其中 num_labels 是可能的分类数量。
class transformers.BertForSequenceClassification(config)
继承父类:BertPreTrainedModel、torch.nn.Module
参数:config (BertConfig)——包含模型所有参数的模型配置类。
forward方法:与BertForTokenClassification相同。
与BertForTokenClassification的差异:
- BertForSequenceClassification 在 BERT 的编码器输出上增加了一个**全连接层(通常连接到 [CLS] 标记的输出)**来进行分类。
- BertForTokenClassification 不需要额外的全连接层,而是直接使用 BERT输出的每个标记的表示,并可能有一个线性层来将其映射到标签空间。
使用示例:
import torch
from transformers import AutoTokenizer, BertForSequenceClassificationtokenizer = AutoTokenizer.from_pretrained("textattack/bert-base-uncased-yelp-polarity")
model = BertForSequenceClassification.from_pretrained("textattack/bert-base-uncased-yelp-polarity")inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")with torch.no_grad():logits = model(**inputs).logitspredicted_class_id = logits.argmax().item()
predicted_class_label = model.config.id2label[predicted_class_id]
# predicted_class_label = LABEL_1
相关文章:
【HuggingFace文档学习】Bert的token分类与句分类
BERT特性: BERT的嵌入是位置绝对(position absolute)的。BERT擅长于预测掩码token和NLU,但是不擅长下一文本生成。 1.BertForTokenClassification 一个用于token级分类的模型,可用于命名实体识别(NER)、部分语音标记…...
354 俄罗斯套娃信封问题(贪心+二分)
题目 链接 给你一个二维整数数组 envelopes ,其中 envelopes[i] [wi, hi] ,表示第 i 个信封的宽度和高度。 当另一个信封的宽度和高度都比这个信封大的时候,这个信封就可以放进另一个信封里,如同俄罗斯套娃一样。 请计算 最多…...
Vue页面结构
Vue页面结构 App.vue <!--html标签--> <template><div><h1>饿了么?</h1></div><HelloWorld msg"Vite Vue" /> </template> <!--js代码 vue3的语法--> <script setup> import HelloWorld f…...
【广州华锐互动】利用VR开展高压电缆运维实训,提供更加真实、安全的学习环境
VR高压电缆维护实训系统由广州华锐互动开发,应用于多家供电企业的员工培训中,该系统突破了传统培训的限制,为学员提供了更加真实、安全的学习环境,提高了培训效率和效果。 传统电缆井下运维培训通常是在实际井下环境中进行&#x…...
git的介绍和安装、常用命令、忽略文件、分支
git介绍和安装 首页功能写完了 ⇢ \dashrightarrow ⇢ 正常应该提交到版本仓库 ⇢ \dashrightarrow ⇢ 大家都能看到这个 ⇢ \dashrightarrow ⇢ 运维应该把现在这个项目部署到测试环境中 ⇢ \dashrightarrow ⇢ 测试开始测试 ⇢ \dashrightarrow ⇢ 客户可以看到目前做的…...
DNS(二)
实现 Internet DNS 架构 架构图 实验环境 关闭SELinux、Firewalld。时间保持一致 主机名IP角色client192.168.28.146DNS客户端,DNS地址为192.168.28.145localdns192.168.28.145本地DNS服务器(只缓存)forward192.168.28.144转发目标DNS服务…...
win 10怎么录屏?教你轻松捕捉屏幕活动
在当今科技快速发展的时代,录屏已成为信息分享、教学、游戏直播等方面的重要工具。无论是为了制作教程、分享游戏过程还是保存重要信息,录屏功能都发挥着举足轻重的作用。可是很多人不知道win 10怎么录屏,本文将详细介绍win10的三种常用录屏方…...
IP 协议的相关特性(部分)
IP 协议的报文格式 4位版本号: 用来表示IP协议的版本,现有的IP协议只有两个版本,IPv4,IPv6。 4位首部长度: 设定和TCP的首部长度一样 8位服务类型: (真正只有4位才有效果)…...
Java设计模式之代表模式
代表模式(Mediator Pattern)是一种行为型设计模式,它通过封装一组对象之间的交互方式,使得这些对象之间的通信变得松散耦合,从而降低了对象之间的直接依赖关系。代表模式通过引入一个中介者(Mediator&#…...
MySQL 查询 唯一约束 对应的字段,列名称合并
MySQL 查询 唯一约束 对应的字段,列名称合并 SELECT F.DbName,F.TableName,F.ConstraintName,GROUP_CONCAT(ColumnName) ColumnName FROM ( SELECT t1.TABLE_SCHEMA DbName, t1.TABLE_NAME TableName,t1.CONSTRAINT_NAME ConstraintName,t2.COLUMN_NAME ColumnNam…...
JDBC-day05(DAO及相关实现类)
七:DAO及相关实现类 1. DAO介绍 DAO:全称Data Access Object,是数据访问对象.在java服务器开发的三层架构中分成控制层(Controller),表示层(Service),数据访问层(Dao),数据访问层专门负责跟数据库进行数据交互.,包括了对数据的CRUDÿ…...
华为汪涛:5.5G时代UBB目标网,跃升数字生产力
[阿联酋,迪拜,2023年10月12日] 在2023全球超宽带高峰论坛上,华为常务董事、ICT基础设施业务管理委员会主任汪涛发表了“5.5G时代UBB目标网,跃升数字生产力”的主题发言,分享了超宽带产业的最新思考与实践,探…...
docker部署多个node-red操作过程
docker部署多个node-red操作过程 一、docker安装教程二、docker安装node-red2.1 在线安装node-red镜像2.1.1 拉取镜像2.1.2 创建目录并分配权限 2.2 离线安装node-red镜像 三、 docker操作node-red3.1 部署node-red3.2 查看\关闭\删除容器 四、Docker删除Redis镜像五、离线安装…...
王兴投资5G小基站
边缘计算社区获悉,近期深圳佳贤通信正式完成数亿元股权融资,本轮融资由美团龙珠领投。本轮融资资金主要用于技术研发、市场拓展等,将进一步巩固和扩大佳贤通信在5G小基站领域的技术及市场领先地位。 01 佳贤通信是什么样的公司? 深…...
【SA8295P 源码分析 (一)】54 - /ifs/bin/startupmgr 程序工作流程分析 及 script.c 介绍
【SA8295P 源码分析】54 - /ifs/bin/startupmgr 程序工作流程分析 及 script.c 介绍 一、startupmgr 可执行程序工作解析1. startupmgr\src\script.c 入口 main 函数:调用 init_loader_and_launcher 解析 scripts 数组二、ifsloader镜像加载流程分析:init_loader_and_launche…...
git 使用
参考 https://git-scm.com/book/zh/v2/Git-%E5%9F%BA%E7%A1%80-%E8%8E%B7%E5%8F%96-Git-%E4%BB%93%E5%BA%93 文件的状态变化周期 文章目录 git 基础检查当前文件状态、查看已暂存和未暂存的修改暂存前后的变化跟踪新文件提交更新移除文件移动文件、重命名操作查看提交历史撤消…...
MFC扩展库BCGControlBar Pro v33.6新版亮点 - 图形管理器改造升级
BCGControlBar库拥有500多个经过全面设计、测试和充分记录的MFC扩展类。 我们的组件可以轻松地集成到您的应用程序中,并为您节省数百个开发和调试时间。 BCGControlBar专业版 v33.6已正式发布了,此版本包含了对图表组件的改进、带隐藏标签的单类功能区栏…...
云上攻防-云原生篇KubernetesK8s安全APIKubelet未授权访问容器执行
文章目录 K8S集群架构解释K8S集群攻击点-重点API Server未授权访问&kubelet未授权访问复现k8s集群环境搭建1、攻击8080端口:API Server未授权访问2、攻击6443端口:API Server未授权访问3、攻击10250端口:kubelet未授权访问 K8S集群架构解…...
Django 访问静态文件的APP staticfiles
Django 框架默认带的 APP: django.contrib.staticfiles Django文档中也写明了:如何管理静态文件(如图片、JavaScript、CSS) |姜戈 文档 |姜戈 (djangoproject.com)https://docs.djangoproject.com/zh-hans/4.2/howto/static-file…...
Airbnb 迁移 SwiftUI 实践
从 2022 年开始,Airbnb 的 iOS 团队就认为 SwiftUI 已经足够成熟,可以在他们的官方应用中使用它。但 Airbnb 的工程师 Bryn Bodayle 表示,这需要一个谨慎的转换过程。 Airbnb 的工程师认为,SwiftUI 的主要优势是它的灵活性和可组合性、声明性、简洁性和惯用性,他们希望这…...
JavaSec-RCE
简介 RCE(Remote Code Execution),可以分为:命令注入(Command Injection)、代码注入(Code Injection) 代码注入 1.漏洞场景:Groovy代码注入 Groovy是一种基于JVM的动态语言,语法简洁,支持闭包、动态类型和Java互操作性,…...
label-studio的使用教程(导入本地路径)
文章目录 1. 准备环境2. 脚本启动2.1 Windows2.2 Linux 3. 安装label-studio机器学习后端3.1 pip安装(推荐)3.2 GitHub仓库安装 4. 后端配置4.1 yolo环境4.2 引入后端模型4.3 修改脚本4.4 启动后端 5. 标注工程5.1 创建工程5.2 配置图片路径5.3 配置工程类型标签5.4 配置模型5.…...
1688商品列表API与其他数据源的对接思路
将1688商品列表API与其他数据源对接时,需结合业务场景设计数据流转链路,重点关注数据格式兼容性、接口调用频率控制及数据一致性维护。以下是具体对接思路及关键技术点: 一、核心对接场景与目标 商品数据同步 场景:将1688商品信息…...
dedecms 织梦自定义表单留言增加ajax验证码功能
增加ajax功能模块,用户不点击提交按钮,只要输入框失去焦点,就会提前提示验证码是否正确。 一,模板上增加验证码 <input name"vdcode"id"vdcode" placeholder"请输入验证码" type"text&quo…...
Spring Boot+Neo4j知识图谱实战:3步搭建智能关系网络!
一、引言 在数据驱动的背景下,知识图谱凭借其高效的信息组织能力,正逐步成为各行业应用的关键技术。本文聚焦 Spring Boot与Neo4j图数据库的技术结合,探讨知识图谱开发的实现细节,帮助读者掌握该技术栈在实际项目中的落地方法。 …...
浅谈不同二分算法的查找情况
二分算法原理比较简单,但是实际的算法模板却有很多,这一切都源于二分查找问题中的复杂情况和二分算法的边界处理,以下是博主对一些二分算法查找的情况分析。 需要说明的是,以下二分算法都是基于有序序列为升序有序的情况…...
初学 pytest 记录
安装 pip install pytest用例可以是函数也可以是类中的方法 def test_func():print()class TestAdd: # def __init__(self): 在 pytest 中不可以使用__init__方法 # self.cc 12345 pytest.mark.api def test_str(self):res add(1, 2)assert res 12def test_int(self):r…...
代码随想录刷题day30
1、零钱兑换II 给你一个整数数组 coins 表示不同面额的硬币,另给一个整数 amount 表示总金额。 请你计算并返回可以凑成总金额的硬币组合数。如果任何硬币组合都无法凑出总金额,返回 0 。 假设每一种面额的硬币有无限个。 题目数据保证结果符合 32 位带…...
莫兰迪高级灰总结计划简约商务通用PPT模版
莫兰迪高级灰总结计划简约商务通用PPT模版,莫兰迪调色板清新简约工作汇报PPT模版,莫兰迪时尚风极简设计PPT模版,大学生毕业论文答辩PPT模版,莫兰迪配色总结计划简约商务通用PPT模版,莫兰迪商务汇报PPT模版,…...
在 Spring Boot 项目里,MYSQL中json类型字段使用
前言: 因为程序特殊需求导致,需要mysql数据库存储json类型数据,因此记录一下使用流程 1.java实体中新增字段 private List<User> users 2.增加mybatis-plus注解 TableField(typeHandler FastjsonTypeHandler.class) private Lis…...
