当前位置: 首页 > news >正文

[oneAPI] 使用Bert进行中文文本分类

[oneAPI] 使用Bert进行中文文本分类

  • Intel® Optimization for PyTorch
  • 基于BERT的文本分类模型
    • 数据预处理
    • 数据集
      • 定义tokenize
      • 建立词表
      • 转换为Token序列
      • padding处理与mask
    • 模型
  • 结果
  • OneAPI
  • 参考资料

比赛:https://marketing.csdn.net/p/f3e44fbfe46c465f4d9d6c23e38e0517
Intel® DevCloud for oneAPI:https://devcloud.intel.com/oneapi/get_started/aiAnalyticsToolkitSamples/

Intel® Optimization for PyTorch

在本次实验中,我们利用PyTorch和Intel® Optimization for PyTorch的强大功能,对PyTorch进行了精心的优化和扩展。这些优化举措极大地增强了PyTorch在各种任务中的性能,尤其是在英特尔硬件上的表现更加突出。通过这些优化策略,我们的模型在训练和推断过程中变得更加敏捷和高效,显著地减少了计算时间,提高了整体效能。我们通过深度融合硬件和软件的精巧设计,成功地释放了硬件潜力,使得模型的训练和应用变得更加快速和高效。这一系列优化举措为人工智能应用开辟了新的前景,带来了全新的可能性。
在这里插入图片描述

基于BERT的文本分类模型

基于BERT的文本分类模型就是在原始的BERT模型后再加上一个分类层即可,同时,对于分类层的输入(也就是原始BERT的输出),默认情况下取BERT输出结果中[CLS]位置对于的向量即可,当然也可以修改为其它方式,例如所有位置向量的均值等。因此,对于基于BERT的文本分类模型来说其输入就是BERT的输入,输出则是每个类别对应的logits值。

数据预处理

在构建数据集之前,我们首先需要知道的是模型到底应该接收什么样的输入,然后才能构建出正确的数据形式。在上面我们说到,基于BERT的文本分类模型的输入就等价于BERT模型的输入,同时BERT模型的输入如图1所示:
在这里插入图片描述

数据集

在这里,我们使用到的数据集是今日头条开放的一个新闻分类数据集(https://github.com/aceimnorstuvwxz/toutiao-text-classfication-dataset),一共包含有382688条数据,15个类别,经过处理后数据集格式为:

千万不要乱申请网贷,否则后果很严重_!_4
10年前的今天,纪念5.12汶川大地震10周年_!_11
怎么看待杨毅在一NBA直播比赛中说詹姆斯的球场统治力已经超过乔丹、伯德和科比?_!_3
戴安娜王妃的车祸有什么谜团?_!_2

其中_!_左边为新闻标题,也就是后面需要用到的分类文本,右边为类别标签。

定义tokenize

将输入进来的文本序列tokenize到字符级别。对于中文语料来说就是将每个字和标点符号都给切分开。在这里,我们可以借用transformers包中的BertTokenizer方法来完成,如下所示:

1 if __name__ == '__main__':
2     model_config = ModelConfig()
3     tokenizer = BertTokenizer.from_pretrained(model_config.pretrained_model_dir).tokenize
4     print(tokenizer("青山不改,绿水长流,我们月来客栈见!"))
5     print(tokenizer("10年前的今天,纪念5.12汶川大地震10周年"))
6 
7 # ['青', '山', '不', '改', ',', '绿', '水', '长', '流', ',', '我', '们', '月', '来', '客', '栈', '见', '!']
8 # ['10', '年', '前', '的', '今', '天', ',', '纪', '念', '5', '.', '12', '汶', '川', '大', '地', '震', '10', '周', '年']

建立词表

将vocab.txt中的内容读取进来形成一个词表即可

1 class Vocab:2     UNK = '[UNK]'3     def __init__(self, vocab_path):4         self.stoi = {}5         self.itos = []6         with open(vocab_path, 'r', encoding='utf-8') as f:7             for i, word in enumerate(f):8                 w = word.strip('\n')9                 self.stoi[w] = i
10                 self.itos.append(w)
11 
12     def __getitem__(self, token):
13         return self.stoi.get(token, self.stoi.get(Vocab.UNK))
14 
15     def __len__(self):
16         return len(self.itos)

转换为Token序列

在得到构建的字典后,便可以通过如下函数来将训练集、验证集和测试集转换成Token序列:

 1 def data_process(self, filepath):2     raw_iter = open(filepath, encoding="utf8").readlines()3     data = []4     max_len = 05     for raw in tqdm(raw_iter, ncols=80):6         line = raw.rstrip("\n").split(self.split_sep)7         s, l = line[0], line[1]8         tmp = [self.CLS_IDX] + [self.vocab[token] for token in self.tokenizer(s)]9         if len(tmp) > self.max_position_embeddings - 1:
10             tmp = tmp[:self.max_position_embeddings - 1]  # BERT预训练模型只取前512个字符
11         tmp += [self.SEP_IDX]
12         tensor_ = torch.tensor(tmp, dtype=torch.long)
13         l = torch.tensor(int(l), dtype=torch.long)
14         max_len = max(max_len, tensor_.size(0))
15         data.append((tensor_, l))
16     return data, max_len

padding处理与mask

对原始文本序列tokenize转换为Token ID后还需要对其进行padding处理。对于这一处理过程可以通过如下代码来完成:

 1 def pad_sequence(sequences, batch_first=False, max_len=None, padding_value=0):2     if max_len is None:3         max_len = max([s.size(0) for s in sequences])4     out_tensors = []5     for tensor in sequences:6         if tensor.size(0) < max_len:7             tensor = torch.cat([tensor, torch.tensor(8               [padding_value] * (max_len - tensor.size(0)))], dim=0)9         else:
10             tensor = tensor[:max_len]
11         out_tensors.append(tensor)
12     out_tensors = torch.stack(out_tensors, dim=1)
13     if batch_first:
14         return out_tensors.transpose(0, 1)
15     return out_tensors

模型

class BertModel(nn.Module):""""""def __init__(self, config):super().__init__()self.bert_embeddings = BertEmbeddings(config)self.bert_encoder = BertEncoder(config)self.bert_pooler = BertPooler(config)self.config = configself._reset_parameters()def forward(self,input_ids=None,attention_mask=None,token_type_ids=None,position_ids=None):"""***** 一定要注意,attention_mask中,被mask的Token用1(True)表示,没有mask的用0(false)表示这一点一定一定要注意:param input_ids:  [src_len, batch_size]:param attention_mask: [batch_size, src_len] mask掉padding部分的内容:param token_type_ids: [src_len, batch_size]  # 如果输入模型的只有一个序列,那么这个参数也不用传值:param position_ids: [1,src_len] # 在实际建模时这个参数其实可以不用传值:return:"""embedding_output = self.bert_embeddings(input_ids=input_ids,position_ids=position_ids,token_type_ids=token_type_ids)# embedding_output: [src_len, batch_size, hidden_size]all_encoder_outputs = self.bert_encoder(embedding_output,attention_mask=attention_mask)# all_encoder_outputs 为一个包含有num_hidden_layers个层的输出sequence_output = all_encoder_outputs[-1]  # 取最后一层# sequence_output: [src_len, batch_size, hidden_size]pooled_output = self.bert_pooler(sequence_output)# 默认是最后一层的first token 即[cls]位置经dense + tanh 后的结果# pooled_output: [batch_size, hidden_size]return pooled_output, all_encoder_outputsdef _reset_parameters(self):r"""Initiate parameters in the transformer model.""""""初始化"""for p in self.parameters():if p.dim() > 1:normal_(p, mean=0.0, std=self.config.initializer_range)@classmethoddef from_pretrained(cls, config, pretrained_model_dir=None):model = cls(config)  # 初始化模型,cls为未实例化的对象,即一个未实例化的BertModel对象pretrained_model_path = os.path.join(pretrained_model_dir, "pytorch_model.bin")if not os.path.exists(pretrained_model_path):raise ValueError(f"<路径:{pretrained_model_path} 中的模型不存在,请仔细检查!>\n"f"中文模型下载地址:https://huggingface.co/bert-base-chinese/tree/main\n"f"英文模型下载地址:https://huggingface.co/bert-base-uncased/tree/main\n")loaded_paras = torch.load(pretrained_model_path)state_dict = deepcopy(model.state_dict())loaded_paras_names = list(loaded_paras.keys())[:-8]model_paras_names = list(state_dict.keys())[1:]if 'use_torch_multi_head' in config.__dict__ and config.use_torch_multi_head:torch_paras = format_paras_for_torch(loaded_paras_names, loaded_paras)for i in range(len(model_paras_names)):logging.debug(f"## 成功赋值参数:{model_paras_names[i]},形状为: {torch_paras[i].size()}")if "position_embeddings" in model_paras_names[i]:# 这部分代码用来消除预训练模型只能输入小于512个字符的限制if config.max_position_embeddings > 512:new_embedding = replace_512_position(state_dict[model_paras_names[i]],loaded_paras[loaded_paras_names[i]])state_dict[model_paras_names[i]] = new_embeddingcontinuestate_dict[model_paras_names[i]] = torch_paras[i]logging.info(f"## 注意,正在使用torch框架中的MultiHeadAttention实现")else:for i in range(len(loaded_paras_names)):logging.debug(f"## 成功将参数:{loaded_paras_names[i]}赋值给{model_paras_names[i]},"f"参数形状为:{state_dict[model_paras_names[i]].size()}")if "position_embeddings" in model_paras_names[i]:# 这部分代码用来消除预训练模型只能输入小于512个字符的限制if config.max_position_embeddings > 512:new_embedding = replace_512_position(state_dict[model_paras_names[i]],loaded_paras[loaded_paras_names[i]])state_dict[model_paras_names[i]] = new_embeddingcontinuestate_dict[model_paras_names[i]] = loaded_paras[loaded_paras_names[i]]logging.info(f"## 注意,正在使用本地MyTransformer中的MyMultiHeadAttention实现,"f"如需使用torch框架中的MultiHeadAttention模块可通过config.__dict__['use_torch_multi_head'] = True实现")model.load_state_dict(state_dict)return model

结果

在这里插入图片描述

OneAPI

import intel_extension_for_pytorch as ipexmodel = model.to(config.device)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)'''
Apply Intel Extension for PyTorch optimization against the model object and optimizer object.
'''
model, optimizer = ipex.optimize(model, optimizer=optimizer)

参考资料

基于BERT预训练模型的中文文本分类任务: https://www.ylkz.life/deeplearning/p10979382/

相关文章:

[oneAPI] 使用Bert进行中文文本分类

[oneAPI] 使用Bert进行中文文本分类 Intel Optimization for PyTorch基于BERT的文本分类模型数据预处理数据集定义tokenize建立词表转换为Token序列padding处理与mask 模型 结果OneAPI参考资料 比赛&#xff1a;https://marketing.csdn.net/p/f3e44fbfe46c465f4d9d6c23e38e0517…...

【数据治理】什么是数据库归档

文章目录 前言什么是数据归档 前言 如果您的日常工作中需要对数据库进行管理&#xff0c;那您肯定已经或即将遭遇这样的困惑&#xff1a;随着业务的蓬勃发展&#xff0c;数据库文件的大小逐渐增大&#xff0c;您需要为在线业务提供越来越大的高性能磁盘容量&#xff0c;但数据…...

AI代码补全 案例 - 阿里云智能编码插件Cosy

文章目录 Cosy简介Cosy安装Marketplace安装【推荐】离线安装安装效果Cosy功能体验代码智能补全代码示例搜索API搜索自然语言搜索控制台异常搜索优质文档搜索Cosy体验有感参考Cosy简介 阿里云智能编码插件(Alibaba Cloud AI Coding Assistant)是一款AI编程助手,提供代码智能…...

【Linux】进程信号篇Ⅰ:信号的产生(signal、kill、raise、abort、alarm)、信号的保存(core dump)

文章目录 一、 signal 函数&#xff1a;用户自定义捕捉信号二、信号的产生1. 通过中断按键产生信号2. 调用系统函数向进程发信号2.1 kill 函数&#xff1a;给任意进程发送任意信号2.2 raise 函数&#xff1a;给调用进程发送任意信号2.3 abort 函数&#xff1a;给调用进程发送 6…...

漏洞指北-VulFocus靶场专栏-中级03

漏洞指北-VulFocus靶场专栏-初级03 中级009 &#x1f338;gxlcms-cve_2018_14685&#x1f338;step1&#xff1a;安装系统 密码rootstep2 进入后台页面 账号密码&#xff1a;admin amdin888step3 查看详细 有phpinfo() 中级010 &#x1f338;dedecms-cnvd_2018_01221&#x1f3…...

【leetcode 力扣刷题】数组交集(数组、set、map都可实现哈希表)

数组交集 349. 两个数组的交集排序&#xff0b;双指针数组实现哈希表unordered_setunordered_map 350. 两个数组的交集Ⅱ排序 双指针数组实现哈希表unordered_map 349. 两个数组的交集 题目链接&#xff1a;349. 两个数组的交集 题目内容如下&#xff0c;理解题意&#xff1a…...

MySQL 8.0.31 登录提示caching_sha2_password问题解决方法

MySQL 8.0.31 登录提示caching_sha2_password问题解决方法 MySQL 8.0.31 使用了 caching_sha2_password 作为默认的身份验证插件&#xff0c;这可能导致一些旧的客户端和库无法连接到服务器。以下是一些解决此类问题的常见步骤和建议&#xff1a; 确保MySQL服务正在运行&#…...

[Google] DeepMind Gemini: 新一代LLM结合AlphaGo技术将力压 GPT-4|未来 AI 领域的新巨头

2016年&#xff0c;Google DeepMind 人工智能实验室孕育出的 AlphaGo 人工智能程序在围棋赛场上一举击败冠军选手&#xff0c;成为历史的见证者。如今&#xff0c;DeepMind 联合创始人兼首席执行官 Demis Hassabis 表示&#xff0c;他们的工程师正借鉴 AlphaGo 的技术研发一款名…...

Maven高级

目录 一、分模块开发与设计 1. 分模块开发的意义 2. 分模块开发&#xff08;模块拆分&#xff09; &#xff08;1&#xff09;创建Maven模块 &#xff08;2&#xff09;书写模块代码 &#xff08;3&#xff09;通过maven指令安装模块到本地仓库&#xff08;install指令&…...

【视觉SLAM入门】5.2. 2D-3D PNP 3D-3D ICP BA非线性优化方法 数学方法SVD DLT

"养气之学&#xff0c;戒之躁急" 1. 3D-2D PNP1.1 代数法1.1.1 DLT(直接线性变换法)1.1.2. P3P 1.2 优化法BA (Bundle Adjustment)法 2. 3D-3D ICP2.1 代数法2.1.1 SVD方法 2.2 优化(BA)法2.2.2 非线性优化方法 前置事项&#xff1a; 1. 3D-2D PNP 该问题描述为&am…...

人脸老化预测(Python)

本次项目的文件 main.py主程序如下 导入必要的库和模块&#xff1a; 导入 TensorFlow 库以及自定义的 FaceAging 模块。导入操作系统库和参数解析库。 定义 str2bool 函数&#xff1a; 自定义函数用于将字符串转换为布尔值。 创建命令行参数解析器&#xff1a; 使用 argparse.A…...

AWS SDK 3.x for .NET Framework 4.0 可行性测试

前言 为了应对日益增长的网络安全挑战, 越来越多的互联网厂商已经陆续开始或者已经彻底停止了对 SSL 3 / TLS 1.0 / TLS1.1 等上古加密算法的支持. 而对于一些同样拥有悠久历史的和 AWS 服务相关联的应用程序, 是否可以通过仅更新 SDK 版本的方式来适应新的环境. 本文将以 Win…...

两个list。如何使用流的写法将一个list中的对象中的某些属性根据另外一个list中的属性值赋值进去?

两个list。如何使用流的写法将一个list中的对象中的某些属性根据另外一个list中的属性值赋值进去? 你可以使用Java 8以上版本中的流(Stream)和Lambda表达式来实现这个需求。假设有两个List&#xff0c;一个是sourceList&#xff0c;包含要赋值属性的对象&#xff1b;另一个是…...

美国陆军希望大数据技术能够帮助保护其云安全

随着陆军采用更大型的云服务&#xff0c;一位高级官员警告说&#xff0c;一些在私营部门有效的快速软件开发技巧和简单解决方案&#xff08;例如开放代码库&#xff09;如果没有额外的安全性&#xff0c;将无法为军队工作。 我们知道现代软件开发确实依赖于第三方库&#xff…...

vue 文字跑马灯

<template><div class"marquee-container"><div class"marquee-content"><div>{{ marqueeText }}</div><div>{{ marqueeText }}</div> <!-- 复制一份文本&#xff0c;用于无缝衔接 --></div></d…...

开源ChatGPT系统源码 采用NUXT3+Laravel9后端开发 前后端分离版本

开源ChatGPT系统源码 采用NUXT3Laravel9后端开发 前后端分离版本 ChatGPT是一种基于AI的聊天机器人技术&#xff0c;它可以帮助用户与聊天机器人进行自然语言交流&#xff0c;以解决用户的问题或满足用户的需求。ChatGPT的核心技术是使用自然语言处理&#xff08;NLP&#xff…...

【LeetCode|数据结构】剑指 Offer 33. 二叉搜索树的后序遍历序列

题目链接 剑指 Offer 33. 二叉搜索树的后序遍历序列 标签 二叉搜索树、后序遍历 步骤 二叉搜索树的左子树的节点值 ≤ \le ≤根节点值 ≤ \le ≤右子树的节点值&#xff1b;对于后序遍历序列最后一个元素的值为根节点的值&#xff1b; 由上面的两个性质可以得出&#xff…...

自定义协程

难点 自己写了一遍协程&#xff0c;困难的地方在于unity中的执行顺序突然发现unity里面可以 yield return 的其实有很多 WaitForSeconds WaitForSecondsRealtime WaitForEndOfFrame WaitForFixedUpdate WaitUntil WaitWhile IEnumerator&#xff08;可以用于协程嵌套&#xf…...

【Atcoder】 [ABC240Ex] Sequence of Substrings

题目链接 Atcoder方向 Luogu方向 题目解法 先考虑一个性质&#xff0c;选出的子串长度不会超过 2 n \sqrt {2n} 2n ​ 考虑最劣的选法是选出长度为 1 , 2 , 3 , . . . 1,2,3,... 1,2,3,... 的子串&#xff08;如果后一个选出的串比前一个子串长度大超过1&#xff0c;那么后…...

真机二阶段之堆叠技术

堆叠技术 --- 可以将多台真实的物理设备逻辑上抽象成一台 思科 -- VPC 华为 -- iStack和CSS 华三 -- IRF 锐捷 -- VSU iStack和CSS的区别&#xff1a; CSS --- 集群 --- 它仅支持将两台支持集群的交换机逻辑上整合成一台设备。 iStack --- 堆叠 --- 可以将多台支持堆叠的交换…...

java调用dll出现unsatisfiedLinkError以及JNA和JNI的区别

UnsatisfiedLinkError 在对接硬件设备中&#xff0c;我们会遇到使用 java 调用 dll文件 的情况&#xff0c;此时大概率出现UnsatisfiedLinkError链接错误&#xff0c;原因可能有如下几种 类名错误包名错误方法名参数错误使用 JNI 协议调用&#xff0c;结果 dll 未实现 JNI 协…...

江苏艾立泰跨国资源接力:废料变黄金的绿色供应链革命

在华东塑料包装行业面临限塑令深度调整的背景下&#xff0c;江苏艾立泰以一场跨国资源接力的创新实践&#xff0c;重新定义了绿色供应链的边界。 跨国回收网络&#xff1a;废料变黄金的全球棋局 艾立泰在欧洲、东南亚建立再生塑料回收点&#xff0c;将海外废弃包装箱通过标准…...

[10-3]软件I2C读写MPU6050 江协科技学习笔记(16个知识点)

1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16...

ardupilot 开发环境eclipse 中import 缺少C++

目录 文章目录 目录摘要1.修复过程摘要 本节主要解决ardupilot 开发环境eclipse 中import 缺少C++,无法导入ardupilot代码,会引起查看不方便的问题。如下图所示 1.修复过程 0.安装ubuntu 软件中自带的eclipse 1.打开eclipse—Help—install new software 2.在 Work with中…...

爬虫基础学习day2

# 爬虫设计领域 工商&#xff1a;企查查、天眼查短视频&#xff1a;抖音、快手、西瓜 ---> 飞瓜电商&#xff1a;京东、淘宝、聚美优品、亚马逊 ---> 分析店铺经营决策标题、排名航空&#xff1a;抓取所有航空公司价格 ---> 去哪儿自媒体&#xff1a;采集自媒体数据进…...

AI,如何重构理解、匹配与决策?

AI 时代&#xff0c;我们如何理解消费&#xff1f; 作者&#xff5c;王彬 封面&#xff5c;Unplash 人们通过信息理解世界。 曾几何时&#xff0c;PC 与移动互联网重塑了人们的购物路径&#xff1a;信息变得唾手可得&#xff0c;商品决策变得高度依赖内容。 但 AI 时代的来…...

NXP S32K146 T-Box 携手 SD NAND(贴片式TF卡):驱动汽车智能革新的黄金组合

在汽车智能化的汹涌浪潮中&#xff0c;车辆不再仅仅是传统的交通工具&#xff0c;而是逐步演变为高度智能的移动终端。这一转变的核心支撑&#xff0c;来自于车内关键技术的深度融合与协同创新。车载远程信息处理盒&#xff08;T-Box&#xff09;方案&#xff1a;NXP S32K146 与…...

从“安全密码”到测试体系:Gitee Test 赋能关键领域软件质量保障

关键领域软件测试的"安全密码"&#xff1a;Gitee Test如何破解行业痛点 在数字化浪潮席卷全球的今天&#xff0c;软件系统已成为国家关键领域的"神经中枢"。从国防军工到能源电力&#xff0c;从金融交易到交通管控&#xff0c;这些关乎国计民生的关键领域…...

tauri项目,如何在rust端读取电脑环境变量

如果想在前端通过调用来获取环境变量的值&#xff0c;可以通过标准的依赖&#xff1a; std::env::var(name).ok() 想在前端通过调用来获取&#xff0c;可以写一个command函数&#xff1a; #[tauri::command] pub fn get_env_var(name: String) -> Result<String, Stri…...

Python 高效图像帧提取与视频编码:实战指南

Python 高效图像帧提取与视频编码:实战指南 在音视频处理领域,图像帧提取与视频编码是基础但极具挑战性的任务。Python 结合强大的第三方库(如 OpenCV、FFmpeg、PyAV),可以高效处理视频流,实现快速帧提取、压缩编码等关键功能。本文将深入介绍如何优化这些流程,提高处理…...