Bert中文文本分类
这是一个经典的文本分类问题,使用google的预训练模型BERT中文版bert-base-chinese来做中文文本分类。可以先在Huggingface上下载预训练模型备用。https://huggingface.co/google-bert/bert-base-chinese/tree/main
我使用的训练环境是
pip install torch==2.0.0;
pip install transformers==4.30.2;
pip install gensim==4.3.3;
pip install huggingface-hub==0.15.1;
pip install modelscope==1.20.1;
一、准备训练数据
1.1 准备中文文本分类任务的训练数据
这里Demo数据如下:
各银行信用卡挂失费迥异 北京银行收费最高 0 莫泰酒店流拍 大摩叫价或降至6亿美元 4 乌兹别克斯坦议会立法院主席获连任 6 德媒披露鲁能引援关键人物 是他力荐德甲亚洲强人 7 辉立证券给予广汽集团持有评级 2 图文-业余希望赛海南站第二轮 球场的菠萝蜜 7 陆毅鲍蕾:近乎完美的爱情(组图)(2) 9 7000亿美元救市方案将成期市毒药 0 保诚启动210亿美元配股交易以融资收购AIG部门 2
分类class类别文件:
finance realty stocks education science society politics sports game entertainment
1.2 数据读取和截断,使满足BERT模型输入
读取训练数据,对文本进行处理,如截取过长的文本、补齐较短的文本,加上起始标示、对文本进行编码、添加掩码、转为tensor等操作。
import os
from config import parsers
from transformers import BertTokenizer
from torch.utils.data import Dataset, DataLoader
import torchfrom transformers import AutoTokenizer, AutoModelForMaskedLMdef read_data(file):# 读取文件all_data = open(file, "r", encoding="utf-8").read().split("\n")# 得到所有文本、所有标签、句子的最大长度texts, labels, max_length = [], [], []for data in all_data:if data:text, label = data.split("\t")max_length.append(len(text))texts.append(text)labels.append(label)# 根据不同的数据集返回不同的内容if os.path.split(file)[1] == "train.txt":max_len = max(max_length)return texts, labels, max_lenreturn texts, labels,class MyDataset(Dataset):def __init__(self, texts, labels, max_length):self.all_text = textsself.all_label = labelsself.max_len = max_lengthself.tokenizer = BertTokenizer.from_pretrained(parsers().bert_pred)
# self.tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")def __getitem__(self, index):# 取出一条数据并截断长度text = self.all_text[index][:self.max_len]label = self.all_label[index]# 分词text_id = self.tokenizer.tokenize(text)# 加上起始标志text_id = ["[CLS]"] + text_id# 编码token_id = self.tokenizer.convert_tokens_to_ids(text_id)# 掩码 -》mask = [1] * len(token_id) + [0] * (self.max_len + 2 - len(token_id))# 编码后 -》长度一致token_ids = token_id + [0] * (self.max_len + 2 - len(token_id))# str -》 intlabel = int(label)# 转化成tensortoken_ids = torch.tensor(token_ids)mask = torch.tensor(mask)label = torch.tensor(label)return (token_ids, mask), labeldef __len__(self):# 得到文本的长度return len(self.all_text)
将文本处理后,就可以使用torch.utils.data中自带的DataLoader模块来加载训练数据了。
二、微调BERT模型
我们是微调BERT模型,需要获取BERT最后一个隐藏层的输出作为输入到下一个全连接层。
至于选择BERT模型的哪个输出作为linear层的输入,可以通过实验尝试,或者遵循常理。
pooler_output:这是通过将最后一层的隐藏状态的第一个token(通常是[CLS] token)通过一个线性层和激活函数得到的输出,常用于分类任务。
last_hidden_state:这是模型所有层的最后一个隐藏状态的输出,包含了整个序列的上下文信息,适用于序列级别的任务。
简单调用下BERT模型,打印出来最后一层看下:
import torch
import time
import torch.nn as nn
from transformers import BertTokenizer
from transformers import BertModel
from transformers import AutoTokenizer, AutoModelForMaskedLMdef process_text(text, bert_pred):tokenizer = BertTokenizer.from_pretrained(bert_pred)token_id = tokenizer.convert_tokens_to_ids(["[CLS]"] + tokenizer.tokenize(text))mask = [1] * len(token_id) + [0] * (38 + 2 - len(token_id))token_ids = token_id + [0] * (38 + 2 - len(token_id))token_ids = torch.tensor(token_ids).unsqueeze(0)mask = torch.tensor(mask).unsqueeze(0)x = torch.stack([token_ids, mask])return xdevice = "cpu"
bert = BertModel.from_pretrained('./bert-base-chinese/')
texts = ["沈腾和马丽的电影《独行月球》挺好看"]
for text in texts:x = process_text(text, './bert-base-chinese/')input_ids, attention_mask = x[0].to(device), x[1].to(device)hidden_out = bert(input_ids, attention_mask=attention_mask,output_hidden_states=False) print(hidden_out)
输出结果:


2.1 文本分类任务,选择使用pooler_output作为线性层的输入。
import torch.nn as nn
from transformers import BertModel
from transformers import AutoTokenizer, AutoModelForMaskedLM
from config import parsers
import torchclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.args = parsers()self.device = "cuda:0" if torch.cuda.is_available() else "cpu" self.bert = BertModel.from_pretrained(self.args.bert_pred) # bert 模型进行微调for param in self.bert.parameters():param.requires_grad = True# 一个全连接层self.linear = nn.Linear(self.args.num_filters, self.args.class_num)def forward(self, x):input_ids, attention_mask = x[0].to(self.device), x[1].to(self.device)hidden_out = self.bert(input_ids, attention_mask=attention_mask,output_hidden_states=False) # 是否输出所有encoder层的结果# shape (batch_size, hidden_size) pooler_output --> hidden_out[0]pred = self.linear(hidden_out.pooler_output)# 返回预测结果return pred
2.2 优化器使用Adam、损失函数使用交叉熵损失函数
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = MyModel().to(device)
opt = AdamW(model.parameters(), lr=args.learn_rate)
loss_fn = nn.CrossEntropyLoss()
三、训练模型
3.1 参数配置
def parsers():parser = argparse.ArgumentParser(description="Bert model of argparse")parser.add_argument("tx_date",nargs='?') #可选输入参数,计算日期parser.add_argument("--train_file", type=str, default=os.path.join("./data_all", "train.txt"))parser.add_argument("--dev_file", type=str, default=os.path.join("./data_all", "dev.txt"))parser.add_argument("--test_file", type=str, default=os.path.join("./data_all", "test.txt"))parser.add_argument("--classification", type=str, default=os.path.join("./data_all", "class.txt"))parser.add_argument("--bert_pred", type=str, default="./bert-base-chinese")parser.add_argument("--class_num", type=int, default=12)parser.add_argument("--max_len", type=int, default=38)parser.add_argument("--batch_size", type=int, default=32)parser.add_argument("--epochs", type=int, default=10)parser.add_argument("--learn_rate", type=float, default=1e-5)parser.add_argument("--num_filters", type=int, default=768)parser.add_argument("--save_model_best", type=str, default=os.path.join("model", "all_best_model.pth"))parser.add_argument("--save_model_last", type=str, default=os.path.join("model", "all_last_model.pth"))args = parser.parse_args()return args
3.2 模型训练
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
import torch.nn as nn
from sklearn.metrics import accuracy_score
import timeif __name__ == "__main__":start = time.time()args = parsers()device = "cuda:0" if torch.cuda.is_available() else "cpu"print("device:", device)train_text, train_label, max_len = read_data(args.train_file)dev_text, dev_label = read_data(args.dev_file)args.max_len = max_lentrain_dataset = MyDataset(train_text, train_label, args.max_len)train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)dev_dataset = MyDataset(dev_text, dev_label, args.max_len)dev_dataloader = DataLoader(dev_dataset, batch_size=args.batch_size, shuffle=False)model = MyModel().to(device)opt = AdamW(model.parameters(), lr=args.learn_rate)loss_fn = nn.CrossEntropyLoss()acc_max = float("-inf")for epoch in range(args.epochs):loss_sum, count = 0, 0model.train()for batch_index, (batch_text, batch_label) in enumerate(train_dataloader):batch_label = batch_label.to(device)pred = model(batch_text)loss = loss_fn(pred, batch_label)opt.zero_grad()loss.backward()opt.step()loss_sum += losscount += 1# 打印内容if len(train_dataloader) - batch_index <= len(train_dataloader) % 1000 and count == len(train_dataloader) % 1000:msg = "[{0}/{1:5d}]\tTrain_Loss:{2:.4f}"print(msg.format(epoch + 1, batch_index + 1, loss_sum / count))loss_sum, count = 0.0, 0if batch_index % 1000 == 999:msg = "[{0}/{1:5d}]\tTrain_Loss:{2:.4f}"print(msg.format(epoch + 1, batch_index + 1, loss_sum / count))loss_sum, count = 0.0, 0model.eval()all_pred, all_true = [], []with torch.no_grad():for batch_text, batch_label in dev_dataloader:batch_label = batch_label.to(device)pred = model(batch_text)pred = torch.argmax(pred, dim=1).cpu().numpy().tolist()label = batch_label.cpu().numpy().tolist()all_pred.extend(pred)all_true.extend(label)acc = accuracy_score(all_pred, all_true)print(f"dev acc:{acc:.4f}")if acc > acc_max:print(acc, acc_max)acc_max = acctorch.save(model.state_dict(), args.save_model_best)print(f"以保存最佳模型")torch.save(model.state_dict(), args.save_model_last)end = time.time()print(f"运行时间:{(end-start)/60%60:.4f} min")
模型保存为:
-rw-rw-r-- 1 gaoToby gaoToby 391M Dec 24 14:02 all_best_model.pth
-rw-rw-r-- 1 gaoToby gaoToby 391M Dec 24 14:02 all_last_model.pth
四、模型推理预测
准备预测文本文件,加载模型,进行文本的类别预测。
def text_class_name(pred):result = torch.argmax(pred, dim=1)print(torch.argmax(pred, dim=1).cpu().numpy().tolist())result = result.cpu().numpy().tolist()classification = open(args.classification, "r", encoding="utf-8").read().split("\n")classification_dict = dict(zip(range(len(classification)), classification))print(f"文本:{text}\t预测的类别为:{classification_dict[result[0]]}")if __name__ == "__main__":start = time.time()args = parsers()device = "cuda:0" if torch.cuda.is_available() else "cpu"model = load_model(device, args.save_model_best)texts = ["沈腾和马丽的新电影《独行月球》好看", "最近金融环境不太好,投资需谨慎"]print("模型预测结果:")for text in texts:x = process_text(text, args.bert_pred)with torch.no_grad():pred = model(x)text_class_name(pred)end = time.time()print(f"耗时为:{end - start} s")
以上,基本流程完成。当然模型还需要调优来改进预测效果的。
代码是实际跑通的,我训练和预测均使用的是GPU。如果是使用GPU做模型训练,再使用CPU做推理预测的情况,推理预测加载模型的时候注意修改下:
myModel.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
Done
相关文章:
Bert中文文本分类
这是一个经典的文本分类问题,使用google的预训练模型BERT中文版bert-base-chinese来做中文文本分类。可以先在Huggingface上下载预训练模型备用。https://huggingface.co/google-bert/bert-base-chinese/tree/main 我使用的训练环境是 pip install torch2.0.0; pi…...
【深度学习】Java DL4J基于 CNN 构建车辆识别与跟踪模型
🧑 博主简介:CSDN博客专家,历代文学网(PC端可以访问:https://literature.sinhy.com/#/?__c1000,移动端可微信小程序搜索“历代文学”)总架构师,15年工作经验,精通Java编…...
【C#】C#打印当前时间以及TimeSpan()介绍
1. C#打印当前时间 string currentDate DateTime.Now.ToString("yyyy-MM-dd HH:mm:ss.fff");Console.WriteLine(currentDate);2. TimeSpan()介绍 TimeSpan(long ticks)的单位是100ns //500ms new TimeSpan(10*1000*500);参考: C#-TimeSpan-计算时间差...
【Linux 网络 (五)】Tcp/Udp协议
Linux 网络 一前言二、Udp协议1)、Udp协议特点2)、Udp协议格式3)、Udp报文封装和解包过程4)、UDP的缓冲区 三、TCP协议1)、TCP协议特点2)、TCP协议格式1、4位首部长度、源端口、目的端口2、16位窗口大小3、…...
多旋翼无人机理论 | 四旋翼动力学数学模型与Matlab仿真
多旋翼无人机理论 | 四旋翼动力学数学模型与Matlab仿真 力的来源数学模型数学模型总结Matlab 仿真 力的来源 无人机的动力系统:电调-电机-螺旋桨 。 给人最直观的感受就是 电机带动螺旋桨转,产生升力。 螺旋桨旋转产生升力的原因,在很多年…...
Vue3项目中引入TailwindCSS(图文详情)
Vue3项目中引入TailwindCSS(图文详细) Tailwind CSS 是一个实用工具优先的 CSS 框架,提供丰富的低级类(如 text-center、bg-blue-500),允许开发者通过组合这些类快速构建自定义设计,而无需编写…...
【开源项目】数字孪生化工厂—开源工程及源码
飞渡科技数字孪生化工厂管理平台,基于自研孪生引擎,将物联网IOT、人工智能、大数据、云计算等技术应用于化工厂,为化工厂提供实时数据分析、工艺优化、设备运维等功能,助力提高生产效率以及提供安全保障。 通过可视化点位标注各厂…...
咨询团队如何通过轻量型工具优化项目管理和提高团队协作效率?
引言 在咨询行业,项目的复杂性和多样性往往意味着团队成员需要协同工作、迅速适应客户需求的变化并且在较短的时间内交付高质量的成果。对于咨询团队来说,选择一个适合的项目管理工具,不仅能够提高工作效率,还能促进团队的协作、…...
javaWeb开发
Java Web开发作为软件开发领域的一个重要分支,已经历经数十年的发展,并凭借其强大的跨平台能力、丰富的生态系统以及高度的安全性,成为构建企业级应用的首选技术之一。以下是对Java Web开发的详细解析: 一、Java Web开发的基本概…...
如何在 Vue 中处理 API 请求?
在 Vue.js 中处理 API 请求是构建动态、交互式 Web 应用程序的核心部分。为了有效地与后端服务器通信,Vue 生态系统提供了多种方式来发起和管理 API 请求。以下是几种常见的方法和最佳实践: 1. 使用 Axios Axios 是一个基于 Promise 的 HTTP 客户端&am…...
基于Debian的Linux发行版的包管理工具
基于Debian的Linux发行版中除了apt和apt-get之外,还有以下几种包管理工具: dpkg:这是Debian系发行版中最基础的包管理工具,专门用于安装、卸载和查询.deb包。与高级包管理器不同,dpkg不自动解决包的依赖关系࿰…...
2022年国家公考《申论》题(行政执法)
2022年国家公考《申论》题(行政执法) 材料一 新型冠状病毒肺炎疫情发生后,党中央、国务院出台了一系列支持企业发展的惠企政策。N市积极落实各项惠企政策,不断优化营商环境,推动区域经济高质量跨越式发展。 “当时…...
贪心算法(常见贪心模型)
常见贪心模型 简单排序模型 最小化战斗力差距 题目分析: #include <bits/stdc.h> using namespace std;const int N 1e5 10;int n; int a[N];int main() {// 请在此输入您的代码cin >> n;for (int i 1;i < n;i) cin >> a[i];sort(a1,a1n);…...
git自动压缩提交的脚本
可以将当前未提交的代码自动执行 git addgit commitgit squash Git 命令安装指南 1. 创建脚本目录 如果目录不存在,创建它: mkdir -p ~/.local/bin2. 创建脚本文件 vim ~/.local/bin/git-squash将完整的脚本代码复制到此文件中。 3. 设置脚本权限…...
Kinova在开源家庭服务机器人TidyBot++研究里大展身手
在科技日新月异的今天,机器人技术在家庭场景中的应用逐渐成为现实,改变着我们的生活方式。今天,我们将深入探讨一篇关于家用机器人研究的论文,剖析其中的创新成果, 论文引用链接:http://tidybot2.github.i…...
使用 Spring Boot 实现文件上传:从配置文件中动态读取上传路径
使用 Spring Boot 实现文件上传:从配置文件中动态读取上传路径 一、前言二、文件上传的基本概念三、环境准备1. 引入依赖2. 配置文件设置application.yml 配置示例:application.properties 配置示例: 四、编写文件上传功能代码1. 控制器类2. …...
《鸿蒙HarmonyOS应用开发从入门到精通(第2版)》学习笔记——HarmonyOS技术理念
1.2 技术理念 在万物智联时代重要机遇期,HarmonyOS结合移动生态发展的趋势,提出了三大技术理念(如下图3-1所示):一次开发,多端部署;可分可合,自由流转;统一生态…...
将多个 k8s yaml 配置文件合并为一个文件
如下bash脚本实现功能 “将多个k8s的yaml 配置文件” 合并为一个 yaml,使用 --- 分割文件配置。 创建文件 merge_yaml.sh ,内容如下: #!/bin/bash# 默认参数 input_patterns() # 匹配的文件模式数组 output_file"combined.yaml"…...
Linux 文件的特殊权限—Sticky Bit(SBIT)权限
本文为Ubuntu Linux操作系统- 第十九期~~ 其他特殊权限: 【SUID 权限】和【SGID 权限】 更多Linux 相关内容请点击👉【Linux专栏】~ 主页:【练小杰的CSDN】 文章目录 Sticky(SBIT)权限基本概念Sticky Bit 的表示方式举例 设置和取…...
MIPI D-PHY/C-PHY/M-PHY 高速串行接口标准
MIPI D-PHY、C-PHY和M-PHY都是MIPI联盟制定的高速串行接口标准。它们都具有低功耗、高速传输速率等特点,但各有侧重: ➢MIPI D-PHY:适用于手机与其他设备之间的数据传输。 ➢MIPI C-PHY:专为手机摄像头而设计。 ➢MIPI M-PHY&am…...
QMC5883L的驱动
简介 本篇文章的代码已经上传到了github上面,开源代码 作为一个电子罗盘模块,我们可以通过I2C从中获取偏航角yaw,相对于六轴陀螺仪的yaw,qmc5883l几乎不会零飘并且成本较低。 参考资料 QMC5883L磁场传感器驱动 QMC5883L磁力计…...
SCAU期末笔记 - 数据分析与数据挖掘题库解析
这门怎么题库答案不全啊日 来简单学一下子来 一、选择题(可多选) 将原始数据进行集成、变换、维度规约、数值规约是在以下哪个步骤的任务?(C) A. 频繁模式挖掘 B.分类和预测 C.数据预处理 D.数据流挖掘 A. 频繁模式挖掘:专注于发现数据中…...
pam_env.so模块配置解析
在PAM(Pluggable Authentication Modules)配置中, /etc/pam.d/su 文件相关配置含义如下: 配置解析 auth required pam_env.so1. 字段分解 字段值说明模块类型auth认证类模块,负责验证用户身份&am…...
全面解析各类VPN技术:GRE、IPsec、L2TP、SSL与MPLS VPN对比
目录 引言 VPN技术概述 GRE VPN 3.1 GRE封装结构 3.2 GRE的应用场景 GRE over IPsec 4.1 GRE over IPsec封装结构 4.2 为什么使用GRE over IPsec? IPsec VPN 5.1 IPsec传输模式(Transport Mode) 5.2 IPsec隧道模式(Tunne…...
在QWebEngineView上实现鼠标、触摸等事件捕获的解决方案
这个问题我看其他博主也写了,要么要会员、要么写的乱七八糟。这里我整理一下,把问题说清楚并且给出代码,拿去用就行,照着葫芦画瓢。 问题 在继承QWebEngineView后,重写mousePressEvent或event函数无法捕获鼠标按下事…...
JVM虚拟机:内存结构、垃圾回收、性能优化
1、JVM虚拟机的简介 Java 虚拟机(Java Virtual Machine 简称:JVM)是运行所有 Java 程序的抽象计算机,是 Java 语言的运行环境,实现了 Java 程序的跨平台特性。JVM 屏蔽了与具体操作系统平台相关的信息,使得 Java 程序只需生成在 JVM 上运行的目标代码(字节码),就可以…...
GitFlow 工作模式(详解)
今天再学项目的过程中遇到使用gitflow模式管理代码,因此进行学习并且发布关于gitflow的一些思考 Git与GitFlow模式 我们在写代码的时候通常会进行网上保存,无论是github还是gittee,都是一种基于git去保存代码的形式,这样保存代码…...
CSS | transition 和 transform的用处和区别
省流总结: transform用于变换/变形,transition是动画控制器 transform 用来对元素进行变形,常见的操作如下,它是立即生效的样式变形属性。 旋转 rotate(角度deg)、平移 translateX(像素px)、缩放 scale(倍数)、倾斜 skewX(角度…...
android13 app的触摸问题定位分析流程
一、知识点 一般来说,触摸问题都是app层面出问题,我们可以在ViewRootImpl.java添加log的方式定位;如果是touchableRegion的计算问题,就会相对比较麻烦了,需要通过adb shell dumpsys input > input.log指令,且通过打印堆栈的方式,逐步定位问题,并找到修改方案。 问题…...
DiscuzX3.5发帖json api
参考文章:PHP实现独立Discuz站外发帖(直连操作数据库)_discuz 发帖api-CSDN博客 简单改造了一下,适配我自己的需求 有一个站点存在多个采集站,我想通过主站拿标题,采集站拿内容 使用到的sql如下 CREATE TABLE pre_forum_post_…...
