【HuggingFace文档学习】Bert的token分类与句分类
BERT特性:
- BERT的嵌入是位置绝对(position absolute)的。
- BERT擅长于预测掩码token和NLU,但是不擅长下一文本生成。
1.BertForTokenClassification
一个用于token级分类的模型,可用于命名实体识别(NER)、部分语音标记(POS)等。对于给定的输入序列,模型将为每个标记/词产生一个标签。
输出的维度是 [batch_size, sequence_length, num_labels],其中 num_labels 是可能的标签数量。
class transformers.BertForTokenClassification(config)
继承父类:BertPreTrainedModel、torch.nn.Module
参数:config (BertConfig)——包含模型所有参数的模型配置类。
包含一个token分类的任务头(线性层,可用于NER)。
forward方法:
参数
- input_ids (
torch.LongTensorof shape(batch_size, sequence_length)) — 输入序列对应的分词索引列表(indices list)。索引根据AutoTokenizer得到。 - attention_mask (
torch.FloatTensorof shape(batch_size, sequence_length), optional) — 对输入序列的部分token加上掩码,使得注意力机制不会计算到。如填充token的索引(padding token indices)。取值为[0, 1]二者之一。取0则表明掩码,取1则表明不掩码。 - token_type_ids (
torch.LongTensorof shape(batch_size, sequence_length), optional) — 在分句任务中,表明token是属于第一句还是第二句。取值为[0, 1]二者之一。 - position_ids (
torch.LongTensorof shape(batch_size, sequence_length), optional) — 输入序列对应的位置索引列表(positional indices list)。 取值范围为[0, config.max_position_embeddings - 1],从而加入位置信息。 - head_mask (
torch.FloatTensorof shape(num_heads,)or(num_layers, num_heads), optional) — 掩码(多头)自注意力模块的头。取值为[0, 1]二者之一:取0则表示对应的头要掩码,取1则表示对应的头不掩码。 - inputs_embeds (
torch.FloatTensorof shape(batch_size, sequence_length, hidden_size), optional) — 如果想要直接将嵌入向量传入给模型,由自己控制input_ids的关联向量,那么就传这个参数。这样就不需要由本模型内部的嵌入层矩阵运算input_ids。 - output_attentions (
bool, optional) — 是否希望模型返回所有的注意力分数。 - output_hidden_states (
bool, optional) — 是否希望模型返回所有层的隐藏状态。 - return_dict (
bool, optional) — 是否希望输出的是ModelOutput,而不是直接的元组tuple。 - labels (
torch.LongTensorof shape(batch_size, sequence_length), optional) — 提供标签,用于计算loss。取值范围为[0, config.max_position_embeddings - 1]。
返回值
transformers.modeling_outputs.TokenClassifierOutput 或 tuple(torch.FloatTensor)
- 如果
return_dict为False(或return_dict为空但配置文件中self.config.use_return_dict为False):- 如果提供了
labels参数,输出是一个元组,包含:loss: 计算的损失值。logits: 分类头的输出,形状为(batch_size, sequence_length, num_labels)。- 其他 BERT 的输出(例如隐藏状态和注意力权重),但这取决于 BERT 的配置和输入参数。
- 如果没有提供
labels参数,输出只包含logits和其他 BERT 的输出。
- 如果提供了
- 如果
return_dict为True(或return_dict为空但配置文件中self.config.use_return_dict为False):- 输出是一个
TokenClassifierOutput对象,包含以下属性:loss: 如果提供了labels参数,这是计算的损失值。logits: 分类头的输出,形状为(batch_size, sequence_length, num_labels)。hidden_states: BERT 的隐藏状态输出。attentions: BERT 的注意力权重输出。
- 输出是一个
代码实现
@add_start_docstrings("""Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. forNamed-Entity-Recognition (NER) tasks.""",BERT_START_DOCSTRING,
)
class BertForTokenClassification(BertPreTrainedModel):def __init__(self, config):super().__init__(config)self.num_labels = config.num_labels # 标签的数量self.bert = BertModel(config, add_pooling_layer=False) # 预训练BERTclassifier_dropout = (config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob)self.dropout = nn.Dropout(classifier_dropout)self.classifier = nn.Linear(config.hidden_size, config.num_labels) # classification任务头,加在预训练BERT之上# Initialize weights and apply final processingself.post_init()@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))@add_code_sample_docstrings(checkpoint=_CHECKPOINT_FOR_TOKEN_CLASSIFICATION,output_type=TokenClassifierOutput,config_class=_CONFIG_FOR_DOC,expected_output=_TOKEN_CLASS_EXPECTED_OUTPUT,expected_loss=_TOKEN_CLASS_EXPECTED_LOSS,)def forward(self,input_ids: Optional[torch.Tensor] = None,attention_mask: Optional[torch.Tensor] = None,token_type_ids: Optional[torch.Tensor] = None,position_ids: Optional[torch.Tensor] = None,head_mask: Optional[torch.Tensor] = None,inputs_embeds: Optional[torch.Tensor] = None,labels: Optional[torch.Tensor] = None,output_attentions: Optional[bool] = None,output_hidden_states: Optional[bool] = None,return_dict: Optional[bool] = None,) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:r"""labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`."""return_dict = return_dict if return_dict is not None else self.config.use_return_dictoutputs = self.bert(input_ids,attention_mask=attention_mask,token_type_ids=token_type_ids,position_ids=position_ids,head_mask=head_mask,inputs_embeds=inputs_embeds,output_attentions=output_attentions,output_hidden_states=output_hidden_states,return_dict=return_dict,) # 预训练BERT的计算,得到输入序列经BERT计算的向量序列sequence_output = outputs[0]sequence_output = self.dropout(sequence_output)logits = self.classifier(sequence_output) # 再经过最后的任务头classificationloss = Noneif labels is not None:loss_fct = CrossEntropyLoss()loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))if not return_dict:output = (logits,) + outputs[2:]return ((loss,) + output) if loss is not None else outputreturn TokenClassifierOutput(loss=loss,logits=logits,hidden_states=outputs.hidden_states,attentions=outputs.attentions,)
使用示例:
from transformers import AutoTokenizer, BertForTokenClassification
import torchtokenizer = AutoTokenizer.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")
model = BertForTokenClassification.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")inputs = tokenizer("HuggingFace is a company based in Paris and New York", add_special_tokens=False, return_tensors="pt"
)with torch.no_grad():logits = model(**inputs).logits # 想要得到分类后的权重,获取的是输出的logits对象。predicted_token_class_ids = logits.argmax(-1)# Note that tokens are classified rather then input words which means that
# there might be more predicted token classes than words.
# Multiple token classes might account for the same word
predicted_tokens_classes = [model.config.id2label[t.item()] for t in predicted_token_class_ids[0]]
# predicted_tokens_classes = ['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC']
2.BertForSequenceClassification
一个用于整个句子或段落级别的分类的模型,可用于情感分析、文本分类等。对于给定的输入,模型将为整个序列产生一个分类标签。
输出的维度是 [batch_size, num_labels],其中 num_labels 是可能的分类数量。
class transformers.BertForSequenceClassification(config)
继承父类:BertPreTrainedModel、torch.nn.Module
参数:config (BertConfig)——包含模型所有参数的模型配置类。
forward方法:与BertForTokenClassification相同。
与BertForTokenClassification的差异:
- BertForSequenceClassification 在 BERT 的编码器输出上增加了一个**全连接层(通常连接到 [CLS] 标记的输出)**来进行分类。
- BertForTokenClassification 不需要额外的全连接层,而是直接使用 BERT输出的每个标记的表示,并可能有一个线性层来将其映射到标签空间。
使用示例:
import torch
from transformers import AutoTokenizer, BertForSequenceClassificationtokenizer = AutoTokenizer.from_pretrained("textattack/bert-base-uncased-yelp-polarity")
model = BertForSequenceClassification.from_pretrained("textattack/bert-base-uncased-yelp-polarity")inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")with torch.no_grad():logits = model(**inputs).logitspredicted_class_id = logits.argmax().item()
predicted_class_label = model.config.id2label[predicted_class_id]
# predicted_class_label = LABEL_1
相关文章:
【HuggingFace文档学习】Bert的token分类与句分类
BERT特性: BERT的嵌入是位置绝对(position absolute)的。BERT擅长于预测掩码token和NLU,但是不擅长下一文本生成。 1.BertForTokenClassification 一个用于token级分类的模型,可用于命名实体识别(NER)、部分语音标记…...
354 俄罗斯套娃信封问题(贪心+二分)
题目 链接 给你一个二维整数数组 envelopes ,其中 envelopes[i] [wi, hi] ,表示第 i 个信封的宽度和高度。 当另一个信封的宽度和高度都比这个信封大的时候,这个信封就可以放进另一个信封里,如同俄罗斯套娃一样。 请计算 最多…...
Vue页面结构
Vue页面结构 App.vue <!--html标签--> <template><div><h1>饿了么?</h1></div><HelloWorld msg"Vite Vue" /> </template> <!--js代码 vue3的语法--> <script setup> import HelloWorld f…...
【广州华锐互动】利用VR开展高压电缆运维实训,提供更加真实、安全的学习环境
VR高压电缆维护实训系统由广州华锐互动开发,应用于多家供电企业的员工培训中,该系统突破了传统培训的限制,为学员提供了更加真实、安全的学习环境,提高了培训效率和效果。 传统电缆井下运维培训通常是在实际井下环境中进行&#x…...
git的介绍和安装、常用命令、忽略文件、分支
git介绍和安装 首页功能写完了 ⇢ \dashrightarrow ⇢ 正常应该提交到版本仓库 ⇢ \dashrightarrow ⇢ 大家都能看到这个 ⇢ \dashrightarrow ⇢ 运维应该把现在这个项目部署到测试环境中 ⇢ \dashrightarrow ⇢ 测试开始测试 ⇢ \dashrightarrow ⇢ 客户可以看到目前做的…...
DNS(二)
实现 Internet DNS 架构 架构图 实验环境 关闭SELinux、Firewalld。时间保持一致 主机名IP角色client192.168.28.146DNS客户端,DNS地址为192.168.28.145localdns192.168.28.145本地DNS服务器(只缓存)forward192.168.28.144转发目标DNS服务…...
win 10怎么录屏?教你轻松捕捉屏幕活动
在当今科技快速发展的时代,录屏已成为信息分享、教学、游戏直播等方面的重要工具。无论是为了制作教程、分享游戏过程还是保存重要信息,录屏功能都发挥着举足轻重的作用。可是很多人不知道win 10怎么录屏,本文将详细介绍win10的三种常用录屏方…...
IP 协议的相关特性(部分)
IP 协议的报文格式 4位版本号: 用来表示IP协议的版本,现有的IP协议只有两个版本,IPv4,IPv6。 4位首部长度: 设定和TCP的首部长度一样 8位服务类型: (真正只有4位才有效果)…...
Java设计模式之代表模式
代表模式(Mediator Pattern)是一种行为型设计模式,它通过封装一组对象之间的交互方式,使得这些对象之间的通信变得松散耦合,从而降低了对象之间的直接依赖关系。代表模式通过引入一个中介者(Mediator&#…...
MySQL 查询 唯一约束 对应的字段,列名称合并
MySQL 查询 唯一约束 对应的字段,列名称合并 SELECT F.DbName,F.TableName,F.ConstraintName,GROUP_CONCAT(ColumnName) ColumnName FROM ( SELECT t1.TABLE_SCHEMA DbName, t1.TABLE_NAME TableName,t1.CONSTRAINT_NAME ConstraintName,t2.COLUMN_NAME ColumnNam…...
JDBC-day05(DAO及相关实现类)
七:DAO及相关实现类 1. DAO介绍 DAO:全称Data Access Object,是数据访问对象.在java服务器开发的三层架构中分成控制层(Controller),表示层(Service),数据访问层(Dao),数据访问层专门负责跟数据库进行数据交互.,包括了对数据的CRUDÿ…...
华为汪涛:5.5G时代UBB目标网,跃升数字生产力
[阿联酋,迪拜,2023年10月12日] 在2023全球超宽带高峰论坛上,华为常务董事、ICT基础设施业务管理委员会主任汪涛发表了“5.5G时代UBB目标网,跃升数字生产力”的主题发言,分享了超宽带产业的最新思考与实践,探…...
docker部署多个node-red操作过程
docker部署多个node-red操作过程 一、docker安装教程二、docker安装node-red2.1 在线安装node-red镜像2.1.1 拉取镜像2.1.2 创建目录并分配权限 2.2 离线安装node-red镜像 三、 docker操作node-red3.1 部署node-red3.2 查看\关闭\删除容器 四、Docker删除Redis镜像五、离线安装…...
王兴投资5G小基站
边缘计算社区获悉,近期深圳佳贤通信正式完成数亿元股权融资,本轮融资由美团龙珠领投。本轮融资资金主要用于技术研发、市场拓展等,将进一步巩固和扩大佳贤通信在5G小基站领域的技术及市场领先地位。 01 佳贤通信是什么样的公司? 深…...
【SA8295P 源码分析 (一)】54 - /ifs/bin/startupmgr 程序工作流程分析 及 script.c 介绍
【SA8295P 源码分析】54 - /ifs/bin/startupmgr 程序工作流程分析 及 script.c 介绍 一、startupmgr 可执行程序工作解析1. startupmgr\src\script.c 入口 main 函数:调用 init_loader_and_launcher 解析 scripts 数组二、ifsloader镜像加载流程分析:init_loader_and_launche…...
git 使用
参考 https://git-scm.com/book/zh/v2/Git-%E5%9F%BA%E7%A1%80-%E8%8E%B7%E5%8F%96-Git-%E4%BB%93%E5%BA%93 文件的状态变化周期 文章目录 git 基础检查当前文件状态、查看已暂存和未暂存的修改暂存前后的变化跟踪新文件提交更新移除文件移动文件、重命名操作查看提交历史撤消…...
MFC扩展库BCGControlBar Pro v33.6新版亮点 - 图形管理器改造升级
BCGControlBar库拥有500多个经过全面设计、测试和充分记录的MFC扩展类。 我们的组件可以轻松地集成到您的应用程序中,并为您节省数百个开发和调试时间。 BCGControlBar专业版 v33.6已正式发布了,此版本包含了对图表组件的改进、带隐藏标签的单类功能区栏…...
云上攻防-云原生篇KubernetesK8s安全APIKubelet未授权访问容器执行
文章目录 K8S集群架构解释K8S集群攻击点-重点API Server未授权访问&kubelet未授权访问复现k8s集群环境搭建1、攻击8080端口:API Server未授权访问2、攻击6443端口:API Server未授权访问3、攻击10250端口:kubelet未授权访问 K8S集群架构解…...
Django 访问静态文件的APP staticfiles
Django 框架默认带的 APP: django.contrib.staticfiles Django文档中也写明了:如何管理静态文件(如图片、JavaScript、CSS) |姜戈 文档 |姜戈 (djangoproject.com)https://docs.djangoproject.com/zh-hans/4.2/howto/static-file…...
Airbnb 迁移 SwiftUI 实践
从 2022 年开始,Airbnb 的 iOS 团队就认为 SwiftUI 已经足够成熟,可以在他们的官方应用中使用它。但 Airbnb 的工程师 Bryn Bodayle 表示,这需要一个谨慎的转换过程。 Airbnb 的工程师认为,SwiftUI 的主要优势是它的灵活性和可组合性、声明性、简洁性和惯用性,他们希望这…...
【根据当天日期输出明天的日期(需对闰年做判定)。】2022-5-15
缘由根据当天日期输出明天的日期(需对闰年做判定)。日期类型结构体如下: struct data{ int year; int month; int day;};-编程语言-CSDN问答 struct mdata{ int year; int month; int day; }mdata; int 天数(int year, int month) {switch (month){case 1: case 3:…...
iOS 26 携众系统重磅更新,但“苹果智能”仍与国行无缘
美国西海岸的夏天,再次被苹果点燃。一年一度的全球开发者大会 WWDC25 如期而至,这不仅是开发者的盛宴,更是全球数亿苹果用户翘首以盼的科技春晚。今年,苹果依旧为我们带来了全家桶式的系统更新,包括 iOS 26、iPadOS 26…...
TRS收益互换:跨境资本流动的金融创新工具与系统化解决方案
一、TRS收益互换的本质与业务逻辑 (一)概念解析 TRS(Total Return Swap)收益互换是一种金融衍生工具,指交易双方约定在未来一定期限内,基于特定资产或指数的表现进行现金流交换的协议。其核心特征包括&am…...
C# 求圆面积的程序(Program to find area of a circle)
给定半径r,求圆的面积。圆的面积应精确到小数点后5位。 例子: 输入:r 5 输出:78.53982 解释:由于面积 PI * r * r 3.14159265358979323846 * 5 * 5 78.53982,因为我们只保留小数点后 5 位数字。 输…...
【网络安全】开源系统getshell漏洞挖掘
审计过程: 在入口文件admin/index.php中: 用户可以通过m,c,a等参数控制加载的文件和方法,在app/system/entrance.php中存在重点代码: 当M_TYPE system并且M_MODULE include时,会设置常量PATH_OWN_FILE为PATH_APP.M_T…...
C# 表达式和运算符(求值顺序)
求值顺序 表达式可以由许多嵌套的子表达式构成。子表达式的求值顺序可以使表达式的最终值发生 变化。 例如,已知表达式3*52,依照子表达式的求值顺序,有两种可能的结果,如图9-3所示。 如果乘法先执行,结果是17。如果5…...
为什么要创建 Vue 实例
核心原因:Vue 需要一个「控制中心」来驱动整个应用 你可以把 Vue 实例想象成你应用的**「大脑」或「引擎」。它负责协调模板、数据、逻辑和行为,将它们变成一个活的、可交互的应用**。没有这个实例,你的代码只是一堆静态的 HTML、JavaScript 变量和函数,无法「活」起来。 …...
系统掌握PyTorch:图解张量、Autograd、DataLoader、nn.Module与实战模型
本文较长,建议点赞收藏,以免遗失。更多AI大模型应用开发学习视频及资料,尽在聚客AI学院。 本文通过代码驱动的方式,系统讲解PyTorch核心概念和实战技巧,涵盖张量操作、自动微分、数据加载、模型构建和训练全流程&#…...
十九、【用户管理与权限 - 篇一】后端基础:用户列表与角色模型的初步构建
【用户管理与权限 - 篇一】后端基础:用户列表与角色模型的初步构建 前言准备工作第一部分:回顾 Django 内置的 `User` 模型第二部分:设计并创建 `Role` 和 `UserProfile` 模型第三部分:创建 Serializers第四部分:创建 ViewSets第五部分:注册 API 路由第六部分:后端初步测…...
mac:大模型系列测试
0 MAC 前几天经过学生优惠以及国补17K入手了mac studio,然后这两天亲自测试其模型行运用能力如何,是否支持微调、推理速度等能力。下面进入正文。 1 mac 与 unsloth 按照下面的进行安装以及测试,是可以跑通文章里面的代码。训练速度也是很快的。 注意…...
