当前位置: 首页 > news >正文

【HuggingFace文档学习】Bert的token分类与句分类

BERT特性:

  • BERT的嵌入是位置绝对(position absolute)的。
  • BERT擅长于预测掩码tokenNLU,但是不擅长下一文本生成。

1.BertForTokenClassification

一个用于token级分类的模型,可用于命名实体识别(NER)、部分语音标记(POS)等。对于给定的输入序列,模型将为每个标记/词产生一个标签。
输出的维度是 [batch_size, sequence_length, num_labels],其中 num_labels 是可能的标签数量。

class transformers.BertForTokenClassification(config)

继承父类:BertPreTrainedModel、torch.nn.Module

参数:config (BertConfig)——包含模型所有参数的模型配置类

包含一个token分类的任务头(线性层,可用于NER)。

forward方法:

参数

  • input_ids (torch.LongTensor of shape (batch_size, sequence_length)) — 输入序列对应的分词索引列表(indices list)。索引根据AutoTokenizer得到。
  • attention_mask (torch.FloatTensor of shape (batch_size, sequence_length), optional) — 对输入序列的部分token加上掩码,使得注意力机制不会计算到。如填充token的索引(padding token indices)。取值为 [0, 1]二者之一。取0则表明掩码,取1则表明不掩码。
  • token_type_ids (torch.LongTensor of shape (batch_size, sequence_length), optional) — 在分句任务中,表明token是属于第一句还是第二句。取值为 [0, 1]二者之一。
  • position_ids (torch.LongTensor of shape (batch_size, sequence_length), optional) — 输入序列对应的位置索引列表(positional indices list)。 取值范围为 [0, config.max_position_embeddings - 1],从而加入位置信息。
  • head_mask (torch.FloatTensor of shape (num_heads,) or (num_layers, num_heads), optional) — 掩码(多头)自注意力模块的头。取值为 [0, 1]二者之一:取0则表示对应的头要掩码,取1则表示对应的头不掩码。
  • inputs_embeds (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size), optional) — 如果想要直接将嵌入向量传入给模型,由自己控制 input_ids的关联向量,那么就传这个参数。这样就不需要由本模型内部的嵌入层矩阵运算 input_ids
  • output_attentions (bool, optional) — 是否希望模型返回所有的注意力分数
  • output_hidden_states (bool, optional) — 是否希望模型返回所有层的隐藏状态
  • return_dict (bool, optional) — 是否希望输出的是ModelOutput,而不是直接的元组tuple。
  • labels (torch.LongTensor of shape (batch_size, sequence_length), optional) — 提供标签,用于计算loss。取值范围为 [0, config.max_position_embeddings - 1]

返回值

transformers.modeling_outputs.TokenClassifierOutput 或 tuple(torch.FloatTensor)

  1. 如果 return_dictFalse(或 return_dict 为空但配置文件中 self.config.use_return_dictFalse):
    • 如果提供了 labels 参数,输出是一个元组,包含:
      • loss: 计算的损失值。
      • logits: 分类头的输出,形状为 (batch_size, sequence_length, num_labels)
      • 其他 BERT 的输出(例如隐藏状态和注意力权重),但这取决于 BERT 的配置和输入参数。
    • 如果没有提供 labels 参数,输出只包含 logits 和其他 BERT 的输出。
  2. 如果 return_dictTrue(或 return_dict 为空但配置文件中 self.config.use_return_dictFalse):
    • 输出是一个 TokenClassifierOutput 对象,包含以下属性:
      • loss: 如果提供了 labels 参数,这是计算的损失值。
      • logits: 分类头的输出,形状为 (batch_size, sequence_length, num_labels)
      • hidden_states: BERT 的隐藏状态输出。
      • attentions: BERT 的注意力权重输出。

代码实现

@add_start_docstrings("""Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. forNamed-Entity-Recognition (NER) tasks.""",BERT_START_DOCSTRING,
)
class BertForTokenClassification(BertPreTrainedModel):def __init__(self, config):super().__init__(config)self.num_labels = config.num_labels  # 标签的数量self.bert = BertModel(config, add_pooling_layer=False)  # 预训练BERTclassifier_dropout = (config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob)self.dropout = nn.Dropout(classifier_dropout)self.classifier = nn.Linear(config.hidden_size, config.num_labels)  # classification任务头,加在预训练BERT之上# Initialize weights and apply final processingself.post_init()@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))@add_code_sample_docstrings(checkpoint=_CHECKPOINT_FOR_TOKEN_CLASSIFICATION,output_type=TokenClassifierOutput,config_class=_CONFIG_FOR_DOC,expected_output=_TOKEN_CLASS_EXPECTED_OUTPUT,expected_loss=_TOKEN_CLASS_EXPECTED_LOSS,)def forward(self,input_ids: Optional[torch.Tensor] = None,attention_mask: Optional[torch.Tensor] = None,token_type_ids: Optional[torch.Tensor] = None,position_ids: Optional[torch.Tensor] = None,head_mask: Optional[torch.Tensor] = None,inputs_embeds: Optional[torch.Tensor] = None,labels: Optional[torch.Tensor] = None,output_attentions: Optional[bool] = None,output_hidden_states: Optional[bool] = None,return_dict: Optional[bool] = None,) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:r"""labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`."""return_dict = return_dict if return_dict is not None else self.config.use_return_dictoutputs = self.bert(input_ids,attention_mask=attention_mask,token_type_ids=token_type_ids,position_ids=position_ids,head_mask=head_mask,inputs_embeds=inputs_embeds,output_attentions=output_attentions,output_hidden_states=output_hidden_states,return_dict=return_dict,)  # 预训练BERT的计算,得到输入序列经BERT计算的向量序列sequence_output = outputs[0]sequence_output = self.dropout(sequence_output)logits = self.classifier(sequence_output)  # 再经过最后的任务头classificationloss = Noneif labels is not None:loss_fct = CrossEntropyLoss()loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))if not return_dict:output = (logits,) + outputs[2:]return ((loss,) + output) if loss is not None else outputreturn TokenClassifierOutput(loss=loss,logits=logits,hidden_states=outputs.hidden_states,attentions=outputs.attentions,)

使用示例:

from transformers import AutoTokenizer, BertForTokenClassification
import torchtokenizer = AutoTokenizer.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")
model = BertForTokenClassification.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")inputs = tokenizer("HuggingFace is a company based in Paris and New York", add_special_tokens=False, return_tensors="pt"
)with torch.no_grad():logits = model(**inputs).logits  # 想要得到分类后的权重,获取的是输出的logits对象。predicted_token_class_ids = logits.argmax(-1)# Note that tokens are classified rather then input words which means that
# there might be more predicted token classes than words.
# Multiple token classes might account for the same word
predicted_tokens_classes = [model.config.id2label[t.item()] for t in predicted_token_class_ids[0]]
# predicted_tokens_classes = ['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC'] 

2.BertForSequenceClassification

一个用于整个句子或段落级别的分类的模型,可用于情感分析、文本分类等。对于给定的输入,模型将为整个序列产生一个分类标签。
输出的维度是 [batch_size, num_labels],其中 num_labels 是可能的分类数量。

class transformers.BertForSequenceClassification(config)

继承父类:BertPreTrainedModel、torch.nn.Module

参数:config (BertConfig)——包含模型所有参数的模型配置类

forward方法:BertForTokenClassification相同。

BertForTokenClassification的差异:

  • BertForSequenceClassification 在 BERT 的编码器输出上增加了一个**全连接层(通常连接到 [CLS] 标记的输出)**来进行分类。
  • BertForTokenClassification 不需要额外的全连接层,而是直接使用 BERT输出的每个标记的表示,并可能有一个线性层来将其映射到标签空间。

使用示例:

import torch
from transformers import AutoTokenizer, BertForSequenceClassificationtokenizer = AutoTokenizer.from_pretrained("textattack/bert-base-uncased-yelp-polarity")
model = BertForSequenceClassification.from_pretrained("textattack/bert-base-uncased-yelp-polarity")inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")with torch.no_grad():logits = model(**inputs).logitspredicted_class_id = logits.argmax().item()
predicted_class_label = model.config.id2label[predicted_class_id]
# predicted_class_label = LABEL_1

相关文章:

【HuggingFace文档学习】Bert的token分类与句分类

BERT特性: BERT的嵌入是位置绝对(position absolute)的。BERT擅长于预测掩码token和NLU,但是不擅长下一文本生成。 1.BertForTokenClassification 一个用于token级分类的模型,可用于命名实体识别(NER)、部分语音标记…...

354 俄罗斯套娃信封问题(贪心+二分)

题目 链接 给你一个二维整数数组 envelopes ,其中 envelopes[i] [wi, hi] ,表示第 i 个信封的宽度和高度。 当另一个信封的宽度和高度都比这个信封大的时候,这个信封就可以放进另一个信封里,如同俄罗斯套娃一样。 请计算 最多…...

Vue页面结构

Vue页面结构 App.vue <!--html标签--> <template><div><h1>饿了么&#xff1f;</h1></div><HelloWorld msg"Vite Vue" /> </template> <!--js代码 vue3的语法--> <script setup> import HelloWorld f…...

【广州华锐互动】利用VR开展高压电缆运维实训,提供更加真实、安全的学习环境

VR高压电缆维护实训系统由广州华锐互动开发&#xff0c;应用于多家供电企业的员工培训中&#xff0c;该系统突破了传统培训的限制&#xff0c;为学员提供了更加真实、安全的学习环境&#xff0c;提高了培训效率和效果。 传统电缆井下运维培训通常是在实际井下环境中进行&#x…...

git的介绍和安装、常用命令、忽略文件、分支

git介绍和安装 首页功能写完了 ⇢ \dashrightarrow ⇢ 正常应该提交到版本仓库 ⇢ \dashrightarrow ⇢ 大家都能看到这个 ⇢ \dashrightarrow ⇢ 运维应该把现在这个项目部署到测试环境中 ⇢ \dashrightarrow ⇢ 测试开始测试 ⇢ \dashrightarrow ⇢ 客户可以看到目前做的…...

DNS(二)

实现 Internet DNS 架构 架构图 实验环境 关闭SELinux、Firewalld。时间保持一致 主机名IP角色client192.168.28.146DNS客户端&#xff0c;DNS地址为192.168.28.145localdns192.168.28.145本地DNS服务器&#xff08;只缓存&#xff09;forward192.168.28.144转发目标DNS服务…...

win 10怎么录屏?教你轻松捕捉屏幕活动

在当今科技快速发展的时代&#xff0c;录屏已成为信息分享、教学、游戏直播等方面的重要工具。无论是为了制作教程、分享游戏过程还是保存重要信息&#xff0c;录屏功能都发挥着举足轻重的作用。可是很多人不知道win 10怎么录屏&#xff0c;本文将详细介绍win10的三种常用录屏方…...

IP 协议的相关特性(部分)

IP 协议的报文格式 4位版本号&#xff1a; 用来表示IP协议的版本&#xff0c;现有的IP协议只有两个版本&#xff0c;IPv4&#xff0c;IPv6。 4位首部长度&#xff1a; 设定和TCP的首部长度一样 8位服务类型&#xff1a; &#xff08;真正只有4位才有效果&#xff09;&#xf…...

Java设计模式之代表模式

代表模式&#xff08;Mediator Pattern&#xff09;是一种行为型设计模式&#xff0c;它通过封装一组对象之间的交互方式&#xff0c;使得这些对象之间的通信变得松散耦合&#xff0c;从而降低了对象之间的直接依赖关系。代表模式通过引入一个中介者&#xff08;Mediator&#…...

MySQL 查询 唯一约束 对应的字段,列名称合并

MySQL 查询 唯一约束 对应的字段&#xff0c;列名称合并 SELECT F.DbName,F.TableName,F.ConstraintName,GROUP_CONCAT(ColumnName) ColumnName FROM ( SELECT t1.TABLE_SCHEMA DbName, t1.TABLE_NAME TableName,t1.CONSTRAINT_NAME ConstraintName,t2.COLUMN_NAME ColumnNam…...

JDBC-day05(DAO及相关实现类)

七&#xff1a;DAO及相关实现类 1. DAO介绍 DAO&#xff1a;全称Data Access Object,是数据访问对象.在java服务器开发的三层架构中分成控制层(Controller),表示层(Service),数据访问层(Dao),数据访问层专门负责跟数据库进行数据交互.&#xff0c;包括了对数据的CRUD&#xff…...

华为汪涛:5.5G时代UBB目标网,跃升数字生产力

[阿联酋&#xff0c;迪拜&#xff0c;2023年10月12日] 在2023全球超宽带高峰论坛上&#xff0c;华为常务董事、ICT基础设施业务管理委员会主任汪涛发表了“5.5G时代UBB目标网&#xff0c;跃升数字生产力”的主题发言&#xff0c;分享了超宽带产业的最新思考与实践&#xff0c;探…...

docker部署多个node-red操作过程

docker部署多个node-red操作过程 一、docker安装教程二、docker安装node-red2.1 在线安装node-red镜像2.1.1 拉取镜像2.1.2 创建目录并分配权限 2.2 离线安装node-red镜像 三、 docker操作node-red3.1 部署node-red3.2 查看\关闭\删除容器 四、Docker删除Redis镜像五、离线安装…...

王兴投资5G小基站

边缘计算社区获悉&#xff0c;近期深圳佳贤通信正式完成数亿元股权融资&#xff0c;本轮融资由美团龙珠领投。本轮融资资金主要用于技术研发、市场拓展等&#xff0c;将进一步巩固和扩大佳贤通信在5G小基站领域的技术及市场领先地位。 01 佳贤通信是什么样的公司&#xff1f; 深…...

【SA8295P 源码分析 (一)】54 - /ifs/bin/startupmgr 程序工作流程分析 及 script.c 介绍

【SA8295P 源码分析】54 - /ifs/bin/startupmgr 程序工作流程分析 及 script.c 介绍 一、startupmgr 可执行程序工作解析1. startupmgr\src\script.c 入口 main 函数:调用 init_loader_and_launcher 解析 scripts 数组二、ifsloader镜像加载流程分析:init_loader_and_launche…...

git 使用

参考 https://git-scm.com/book/zh/v2/Git-%E5%9F%BA%E7%A1%80-%E8%8E%B7%E5%8F%96-Git-%E4%BB%93%E5%BA%93 文件的状态变化周期 文章目录 git 基础检查当前文件状态、查看已暂存和未暂存的修改暂存前后的变化跟踪新文件提交更新移除文件移动文件、重命名操作查看提交历史撤消…...

MFC扩展库BCGControlBar Pro v33.6新版亮点 - 图形管理器改造升级

BCGControlBar库拥有500多个经过全面设计、测试和充分记录的MFC扩展类。 我们的组件可以轻松地集成到您的应用程序中&#xff0c;并为您节省数百个开发和调试时间。 BCGControlBar专业版 v33.6已正式发布了&#xff0c;此版本包含了对图表组件的改进、带隐藏标签的单类功能区栏…...

云上攻防-云原生篇KubernetesK8s安全APIKubelet未授权访问容器执行

文章目录 K8S集群架构解释K8S集群攻击点-重点API Server未授权访问&kubelet未授权访问复现k8s集群环境搭建1、攻击8080端口&#xff1a;API Server未授权访问2、攻击6443端口&#xff1a;API Server未授权访问3、攻击10250端口&#xff1a;kubelet未授权访问 K8S集群架构解…...

Django 访问静态文件的APP staticfiles

Django 框架默认带的 APP&#xff1a; django.contrib.staticfiles Django文档中也写明了&#xff1a;如何管理静态文件&#xff08;如图片、JavaScript、CSS&#xff09; |姜戈 文档 |姜戈 (djangoproject.com)https://docs.djangoproject.com/zh-hans/4.2/howto/static-file…...

Airbnb 迁移 SwiftUI 实践

从 2022 年开始,Airbnb 的 iOS 团队就认为 SwiftUI 已经足够成熟,可以在他们的官方应用中使用它。但 Airbnb 的工程师 Bryn Bodayle 表示,这需要一个谨慎的转换过程。 Airbnb 的工程师认为,SwiftUI 的主要优势是它的灵活性和可组合性、声明性、简洁性和惯用性,他们希望这…...

业务系统对接大模型的基础方案:架构设计与关键步骤

业务系统对接大模型&#xff1a;架构设计与关键步骤 在当今数字化转型的浪潮中&#xff0c;大语言模型&#xff08;LLM&#xff09;已成为企业提升业务效率和创新能力的关键技术之一。将大模型集成到业务系统中&#xff0c;不仅可以优化用户体验&#xff0c;还能为业务决策提供…...

css实现圆环展示百分比,根据值动态展示所占比例

代码如下 <view class""><view class"circle-chart"><view v-if"!!num" class"pie-item" :style"{background: conic-gradient(var(--one-color) 0%,#E9E6F1 ${num}%),}"></view><view v-else …...

pam_env.so模块配置解析

在PAM&#xff08;Pluggable Authentication Modules&#xff09;配置中&#xff0c; /etc/pam.d/su 文件相关配置含义如下&#xff1a; 配置解析 auth required pam_env.so1. 字段分解 字段值说明模块类型auth认证类模块&#xff0c;负责验证用户身份&am…...

什么是库存周转?如何用进销存系统提高库存周转率?

你可能听说过这样一句话&#xff1a; “利润不是赚出来的&#xff0c;是管出来的。” 尤其是在制造业、批发零售、电商这类“货堆成山”的行业&#xff0c;很多企业看着销售不错&#xff0c;账上却没钱、利润也不见了&#xff0c;一翻库存才发现&#xff1a; 一堆卖不动的旧货…...

前端开发面试题总结-JavaScript篇(一)

文章目录 JavaScript高频问答一、作用域与闭包1.什么是闭包&#xff08;Closure&#xff09;&#xff1f;闭包有什么应用场景和潜在问题&#xff1f;2.解释 JavaScript 的作用域链&#xff08;Scope Chain&#xff09; 二、原型与继承3.原型链是什么&#xff1f;如何实现继承&a…...

视频行为标注工具BehaviLabel(源码+使用介绍+Windows.Exe版本)

前言&#xff1a; 最近在做行为检测相关的模型&#xff0c;用的是时空图卷积网络&#xff08;STGCN&#xff09;&#xff0c;但原有kinetic-400数据集数据质量较低&#xff0c;需要进行细粒度的标注&#xff0c;同时粗略搜了下已有开源工具基本都集中于图像分割这块&#xff0c…...

springboot整合VUE之在线教育管理系统简介

可以学习到的技能 学会常用技术栈的使用 独立开发项目 学会前端的开发流程 学会后端的开发流程 学会数据库的设计 学会前后端接口调用方式 学会多模块之间的关联 学会数据的处理 适用人群 在校学生&#xff0c;小白用户&#xff0c;想学习知识的 有点基础&#xff0c;想要通过项…...

【JVM面试篇】高频八股汇总——类加载和类加载器

目录 1. 讲一下类加载过程&#xff1f; 2. Java创建对象的过程&#xff1f; 3. 对象的生命周期&#xff1f; 4. 类加载器有哪些&#xff1f; 5. 双亲委派模型的作用&#xff08;好处&#xff09;&#xff1f; 6. 讲一下类的加载和双亲委派原则&#xff1f; 7. 双亲委派模…...

莫兰迪高级灰总结计划简约商务通用PPT模版

莫兰迪高级灰总结计划简约商务通用PPT模版&#xff0c;莫兰迪调色板清新简约工作汇报PPT模版&#xff0c;莫兰迪时尚风极简设计PPT模版&#xff0c;大学生毕业论文答辩PPT模版&#xff0c;莫兰迪配色总结计划简约商务通用PPT模版&#xff0c;莫兰迪商务汇报PPT模版&#xff0c;…...

Web中间件--tomcat学习

Web中间件–tomcat Java虚拟机详解 什么是JAVA虚拟机 Java虚拟机是一个抽象的计算机&#xff0c;它可以执行Java字节码。Java虚拟机是Java平台的一部分&#xff0c;Java平台由Java语言、Java API和Java虚拟机组成。Java虚拟机的主要作用是将Java字节码转换为机器代码&#x…...