Bert中文文本分类
这是一个经典的文本分类问题,使用google的预训练模型BERT中文版bert-base-chinese来做中文文本分类。可以先在Huggingface上下载预训练模型备用。https://huggingface.co/google-bert/bert-base-chinese/tree/main
我使用的训练环境是
pip install torch==2.0.0;
pip install transformers==4.30.2;
pip install gensim==4.3.3;
pip install huggingface-hub==0.15.1;
pip install modelscope==1.20.1;
一、准备训练数据
1.1 准备中文文本分类任务的训练数据
这里Demo数据如下:
各银行信用卡挂失费迥异 北京银行收费最高 0 莫泰酒店流拍 大摩叫价或降至6亿美元 4 乌兹别克斯坦议会立法院主席获连任 6 德媒披露鲁能引援关键人物 是他力荐德甲亚洲强人 7 辉立证券给予广汽集团持有评级 2 图文-业余希望赛海南站第二轮 球场的菠萝蜜 7 陆毅鲍蕾:近乎完美的爱情(组图)(2) 9 7000亿美元救市方案将成期市毒药 0 保诚启动210亿美元配股交易以融资收购AIG部门 2
分类class类别文件:
finance realty stocks education science society politics sports game entertainment
1.2 数据读取和截断,使满足BERT模型输入
读取训练数据,对文本进行处理,如截取过长的文本、补齐较短的文本,加上起始标示、对文本进行编码、添加掩码、转为tensor等操作。
import os
from config import parsers
from transformers import BertTokenizer
from torch.utils.data import Dataset, DataLoader
import torchfrom transformers import AutoTokenizer, AutoModelForMaskedLMdef read_data(file):# 读取文件all_data = open(file, "r", encoding="utf-8").read().split("\n")# 得到所有文本、所有标签、句子的最大长度texts, labels, max_length = [], [], []for data in all_data:if data:text, label = data.split("\t")max_length.append(len(text))texts.append(text)labels.append(label)# 根据不同的数据集返回不同的内容if os.path.split(file)[1] == "train.txt":max_len = max(max_length)return texts, labels, max_lenreturn texts, labels,class MyDataset(Dataset):def __init__(self, texts, labels, max_length):self.all_text = textsself.all_label = labelsself.max_len = max_lengthself.tokenizer = BertTokenizer.from_pretrained(parsers().bert_pred)
# self.tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")def __getitem__(self, index):# 取出一条数据并截断长度text = self.all_text[index][:self.max_len]label = self.all_label[index]# 分词text_id = self.tokenizer.tokenize(text)# 加上起始标志text_id = ["[CLS]"] + text_id# 编码token_id = self.tokenizer.convert_tokens_to_ids(text_id)# 掩码 -》mask = [1] * len(token_id) + [0] * (self.max_len + 2 - len(token_id))# 编码后 -》长度一致token_ids = token_id + [0] * (self.max_len + 2 - len(token_id))# str -》 intlabel = int(label)# 转化成tensortoken_ids = torch.tensor(token_ids)mask = torch.tensor(mask)label = torch.tensor(label)return (token_ids, mask), labeldef __len__(self):# 得到文本的长度return len(self.all_text)
将文本处理后,就可以使用torch.utils.data中自带的DataLoader模块来加载训练数据了。
二、微调BERT模型
我们是微调BERT模型,需要获取BERT最后一个隐藏层的输出作为输入到下一个全连接层。
至于选择BERT模型的哪个输出作为linear层的输入,可以通过实验尝试,或者遵循常理。
pooler_output:这是通过将最后一层的隐藏状态的第一个token(通常是[CLS] token)通过一个线性层和激活函数得到的输出,常用于分类任务。
last_hidden_state:这是模型所有层的最后一个隐藏状态的输出,包含了整个序列的上下文信息,适用于序列级别的任务。
简单调用下BERT模型,打印出来最后一层看下:
import torch
import time
import torch.nn as nn
from transformers import BertTokenizer
from transformers import BertModel
from transformers import AutoTokenizer, AutoModelForMaskedLMdef process_text(text, bert_pred):tokenizer = BertTokenizer.from_pretrained(bert_pred)token_id = tokenizer.convert_tokens_to_ids(["[CLS]"] + tokenizer.tokenize(text))mask = [1] * len(token_id) + [0] * (38 + 2 - len(token_id))token_ids = token_id + [0] * (38 + 2 - len(token_id))token_ids = torch.tensor(token_ids).unsqueeze(0)mask = torch.tensor(mask).unsqueeze(0)x = torch.stack([token_ids, mask])return xdevice = "cpu"
bert = BertModel.from_pretrained('./bert-base-chinese/')
texts = ["沈腾和马丽的电影《独行月球》挺好看"]
for text in texts:x = process_text(text, './bert-base-chinese/')input_ids, attention_mask = x[0].to(device), x[1].to(device)hidden_out = bert(input_ids, attention_mask=attention_mask,output_hidden_states=False) print(hidden_out)
输出结果:


2.1 文本分类任务,选择使用pooler_output作为线性层的输入。
import torch.nn as nn
from transformers import BertModel
from transformers import AutoTokenizer, AutoModelForMaskedLM
from config import parsers
import torchclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.args = parsers()self.device = "cuda:0" if torch.cuda.is_available() else "cpu" self.bert = BertModel.from_pretrained(self.args.bert_pred) # bert 模型进行微调for param in self.bert.parameters():param.requires_grad = True# 一个全连接层self.linear = nn.Linear(self.args.num_filters, self.args.class_num)def forward(self, x):input_ids, attention_mask = x[0].to(self.device), x[1].to(self.device)hidden_out = self.bert(input_ids, attention_mask=attention_mask,output_hidden_states=False) # 是否输出所有encoder层的结果# shape (batch_size, hidden_size) pooler_output --> hidden_out[0]pred = self.linear(hidden_out.pooler_output)# 返回预测结果return pred
2.2 优化器使用Adam、损失函数使用交叉熵损失函数
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = MyModel().to(device)
opt = AdamW(model.parameters(), lr=args.learn_rate)
loss_fn = nn.CrossEntropyLoss()
三、训练模型
3.1 参数配置
def parsers():parser = argparse.ArgumentParser(description="Bert model of argparse")parser.add_argument("tx_date",nargs='?') #可选输入参数,计算日期parser.add_argument("--train_file", type=str, default=os.path.join("./data_all", "train.txt"))parser.add_argument("--dev_file", type=str, default=os.path.join("./data_all", "dev.txt"))parser.add_argument("--test_file", type=str, default=os.path.join("./data_all", "test.txt"))parser.add_argument("--classification", type=str, default=os.path.join("./data_all", "class.txt"))parser.add_argument("--bert_pred", type=str, default="./bert-base-chinese")parser.add_argument("--class_num", type=int, default=12)parser.add_argument("--max_len", type=int, default=38)parser.add_argument("--batch_size", type=int, default=32)parser.add_argument("--epochs", type=int, default=10)parser.add_argument("--learn_rate", type=float, default=1e-5)parser.add_argument("--num_filters", type=int, default=768)parser.add_argument("--save_model_best", type=str, default=os.path.join("model", "all_best_model.pth"))parser.add_argument("--save_model_last", type=str, default=os.path.join("model", "all_last_model.pth"))args = parser.parse_args()return args
3.2 模型训练
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
import torch.nn as nn
from sklearn.metrics import accuracy_score
import timeif __name__ == "__main__":start = time.time()args = parsers()device = "cuda:0" if torch.cuda.is_available() else "cpu"print("device:", device)train_text, train_label, max_len = read_data(args.train_file)dev_text, dev_label = read_data(args.dev_file)args.max_len = max_lentrain_dataset = MyDataset(train_text, train_label, args.max_len)train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)dev_dataset = MyDataset(dev_text, dev_label, args.max_len)dev_dataloader = DataLoader(dev_dataset, batch_size=args.batch_size, shuffle=False)model = MyModel().to(device)opt = AdamW(model.parameters(), lr=args.learn_rate)loss_fn = nn.CrossEntropyLoss()acc_max = float("-inf")for epoch in range(args.epochs):loss_sum, count = 0, 0model.train()for batch_index, (batch_text, batch_label) in enumerate(train_dataloader):batch_label = batch_label.to(device)pred = model(batch_text)loss = loss_fn(pred, batch_label)opt.zero_grad()loss.backward()opt.step()loss_sum += losscount += 1# 打印内容if len(train_dataloader) - batch_index <= len(train_dataloader) % 1000 and count == len(train_dataloader) % 1000:msg = "[{0}/{1:5d}]\tTrain_Loss:{2:.4f}"print(msg.format(epoch + 1, batch_index + 1, loss_sum / count))loss_sum, count = 0.0, 0if batch_index % 1000 == 999:msg = "[{0}/{1:5d}]\tTrain_Loss:{2:.4f}"print(msg.format(epoch + 1, batch_index + 1, loss_sum / count))loss_sum, count = 0.0, 0model.eval()all_pred, all_true = [], []with torch.no_grad():for batch_text, batch_label in dev_dataloader:batch_label = batch_label.to(device)pred = model(batch_text)pred = torch.argmax(pred, dim=1).cpu().numpy().tolist()label = batch_label.cpu().numpy().tolist()all_pred.extend(pred)all_true.extend(label)acc = accuracy_score(all_pred, all_true)print(f"dev acc:{acc:.4f}")if acc > acc_max:print(acc, acc_max)acc_max = acctorch.save(model.state_dict(), args.save_model_best)print(f"以保存最佳模型")torch.save(model.state_dict(), args.save_model_last)end = time.time()print(f"运行时间:{(end-start)/60%60:.4f} min")
模型保存为:
-rw-rw-r-- 1 gaoToby gaoToby 391M Dec 24 14:02 all_best_model.pth
-rw-rw-r-- 1 gaoToby gaoToby 391M Dec 24 14:02 all_last_model.pth
四、模型推理预测
准备预测文本文件,加载模型,进行文本的类别预测。
def text_class_name(pred):result = torch.argmax(pred, dim=1)print(torch.argmax(pred, dim=1).cpu().numpy().tolist())result = result.cpu().numpy().tolist()classification = open(args.classification, "r", encoding="utf-8").read().split("\n")classification_dict = dict(zip(range(len(classification)), classification))print(f"文本:{text}\t预测的类别为:{classification_dict[result[0]]}")if __name__ == "__main__":start = time.time()args = parsers()device = "cuda:0" if torch.cuda.is_available() else "cpu"model = load_model(device, args.save_model_best)texts = ["沈腾和马丽的新电影《独行月球》好看", "最近金融环境不太好,投资需谨慎"]print("模型预测结果:")for text in texts:x = process_text(text, args.bert_pred)with torch.no_grad():pred = model(x)text_class_name(pred)end = time.time()print(f"耗时为:{end - start} s")
以上,基本流程完成。当然模型还需要调优来改进预测效果的。
代码是实际跑通的,我训练和预测均使用的是GPU。如果是使用GPU做模型训练,再使用CPU做推理预测的情况,推理预测加载模型的时候注意修改下:
myModel.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
Done
相关文章:
Bert中文文本分类
这是一个经典的文本分类问题,使用google的预训练模型BERT中文版bert-base-chinese来做中文文本分类。可以先在Huggingface上下载预训练模型备用。https://huggingface.co/google-bert/bert-base-chinese/tree/main 我使用的训练环境是 pip install torch2.0.0; pi…...
【深度学习】Java DL4J基于 CNN 构建车辆识别与跟踪模型
🧑 博主简介:CSDN博客专家,历代文学网(PC端可以访问:https://literature.sinhy.com/#/?__c1000,移动端可微信小程序搜索“历代文学”)总架构师,15年工作经验,精通Java编…...
【C#】C#打印当前时间以及TimeSpan()介绍
1. C#打印当前时间 string currentDate DateTime.Now.ToString("yyyy-MM-dd HH:mm:ss.fff");Console.WriteLine(currentDate);2. TimeSpan()介绍 TimeSpan(long ticks)的单位是100ns //500ms new TimeSpan(10*1000*500);参考: C#-TimeSpan-计算时间差...
【Linux 网络 (五)】Tcp/Udp协议
Linux 网络 一前言二、Udp协议1)、Udp协议特点2)、Udp协议格式3)、Udp报文封装和解包过程4)、UDP的缓冲区 三、TCP协议1)、TCP协议特点2)、TCP协议格式1、4位首部长度、源端口、目的端口2、16位窗口大小3、…...
多旋翼无人机理论 | 四旋翼动力学数学模型与Matlab仿真
多旋翼无人机理论 | 四旋翼动力学数学模型与Matlab仿真 力的来源数学模型数学模型总结Matlab 仿真 力的来源 无人机的动力系统:电调-电机-螺旋桨 。 给人最直观的感受就是 电机带动螺旋桨转,产生升力。 螺旋桨旋转产生升力的原因,在很多年…...
Vue3项目中引入TailwindCSS(图文详情)
Vue3项目中引入TailwindCSS(图文详细) Tailwind CSS 是一个实用工具优先的 CSS 框架,提供丰富的低级类(如 text-center、bg-blue-500),允许开发者通过组合这些类快速构建自定义设计,而无需编写…...
【开源项目】数字孪生化工厂—开源工程及源码
飞渡科技数字孪生化工厂管理平台,基于自研孪生引擎,将物联网IOT、人工智能、大数据、云计算等技术应用于化工厂,为化工厂提供实时数据分析、工艺优化、设备运维等功能,助力提高生产效率以及提供安全保障。 通过可视化点位标注各厂…...
咨询团队如何通过轻量型工具优化项目管理和提高团队协作效率?
引言 在咨询行业,项目的复杂性和多样性往往意味着团队成员需要协同工作、迅速适应客户需求的变化并且在较短的时间内交付高质量的成果。对于咨询团队来说,选择一个适合的项目管理工具,不仅能够提高工作效率,还能促进团队的协作、…...
javaWeb开发
Java Web开发作为软件开发领域的一个重要分支,已经历经数十年的发展,并凭借其强大的跨平台能力、丰富的生态系统以及高度的安全性,成为构建企业级应用的首选技术之一。以下是对Java Web开发的详细解析: 一、Java Web开发的基本概…...
如何在 Vue 中处理 API 请求?
在 Vue.js 中处理 API 请求是构建动态、交互式 Web 应用程序的核心部分。为了有效地与后端服务器通信,Vue 生态系统提供了多种方式来发起和管理 API 请求。以下是几种常见的方法和最佳实践: 1. 使用 Axios Axios 是一个基于 Promise 的 HTTP 客户端&am…...
基于Debian的Linux发行版的包管理工具
基于Debian的Linux发行版中除了apt和apt-get之外,还有以下几种包管理工具: dpkg:这是Debian系发行版中最基础的包管理工具,专门用于安装、卸载和查询.deb包。与高级包管理器不同,dpkg不自动解决包的依赖关系࿰…...
2022年国家公考《申论》题(行政执法)
2022年国家公考《申论》题(行政执法) 材料一 新型冠状病毒肺炎疫情发生后,党中央、国务院出台了一系列支持企业发展的惠企政策。N市积极落实各项惠企政策,不断优化营商环境,推动区域经济高质量跨越式发展。 “当时…...
贪心算法(常见贪心模型)
常见贪心模型 简单排序模型 最小化战斗力差距 题目分析: #include <bits/stdc.h> using namespace std;const int N 1e5 10;int n; int a[N];int main() {// 请在此输入您的代码cin >> n;for (int i 1;i < n;i) cin >> a[i];sort(a1,a1n);…...
git自动压缩提交的脚本
可以将当前未提交的代码自动执行 git addgit commitgit squash Git 命令安装指南 1. 创建脚本目录 如果目录不存在,创建它: mkdir -p ~/.local/bin2. 创建脚本文件 vim ~/.local/bin/git-squash将完整的脚本代码复制到此文件中。 3. 设置脚本权限…...
Kinova在开源家庭服务机器人TidyBot++研究里大展身手
在科技日新月异的今天,机器人技术在家庭场景中的应用逐渐成为现实,改变着我们的生活方式。今天,我们将深入探讨一篇关于家用机器人研究的论文,剖析其中的创新成果, 论文引用链接:http://tidybot2.github.i…...
使用 Spring Boot 实现文件上传:从配置文件中动态读取上传路径
使用 Spring Boot 实现文件上传:从配置文件中动态读取上传路径 一、前言二、文件上传的基本概念三、环境准备1. 引入依赖2. 配置文件设置application.yml 配置示例:application.properties 配置示例: 四、编写文件上传功能代码1. 控制器类2. …...
《鸿蒙HarmonyOS应用开发从入门到精通(第2版)》学习笔记——HarmonyOS技术理念
1.2 技术理念 在万物智联时代重要机遇期,HarmonyOS结合移动生态发展的趋势,提出了三大技术理念(如下图3-1所示):一次开发,多端部署;可分可合,自由流转;统一生态…...
将多个 k8s yaml 配置文件合并为一个文件
如下bash脚本实现功能 “将多个k8s的yaml 配置文件” 合并为一个 yaml,使用 --- 分割文件配置。 创建文件 merge_yaml.sh ,内容如下: #!/bin/bash# 默认参数 input_patterns() # 匹配的文件模式数组 output_file"combined.yaml"…...
Linux 文件的特殊权限—Sticky Bit(SBIT)权限
本文为Ubuntu Linux操作系统- 第十九期~~ 其他特殊权限: 【SUID 权限】和【SGID 权限】 更多Linux 相关内容请点击👉【Linux专栏】~ 主页:【练小杰的CSDN】 文章目录 Sticky(SBIT)权限基本概念Sticky Bit 的表示方式举例 设置和取…...
MIPI D-PHY/C-PHY/M-PHY 高速串行接口标准
MIPI D-PHY、C-PHY和M-PHY都是MIPI联盟制定的高速串行接口标准。它们都具有低功耗、高速传输速率等特点,但各有侧重: ➢MIPI D-PHY:适用于手机与其他设备之间的数据传输。 ➢MIPI C-PHY:专为手机摄像头而设计。 ➢MIPI M-PHY&am…...
练习(含atoi的模拟实现,自定义类型等练习)
一、结构体大小的计算及位段 (结构体大小计算及位段 详解请看:自定义类型:结构体进阶-CSDN博客) 1.在32位系统环境,编译选项为4字节对齐,那么sizeof(A)和sizeof(B)是多少? #pragma pack(4)st…...
《从零掌握MIPI CSI-2: 协议精解与FPGA摄像头开发实战》-- CSI-2 协议详细解析 (一)
CSI-2 协议详细解析 (一) 1. CSI-2层定义(CSI-2 Layer Definitions) 分层结构 :CSI-2协议分为6层: 物理层(PHY Layer) : 定义电气特性、时钟机制和传输介质(导线&#…...
理解 MCP 工作流:使用 Ollama 和 LangChain 构建本地 MCP 客户端
🌟 什么是 MCP? 模型控制协议 (MCP) 是一种创新的协议,旨在无缝连接 AI 模型与应用程序。 MCP 是一个开源协议,它标准化了我们的 LLM 应用程序连接所需工具和数据源并与之协作的方式。 可以把它想象成你的 AI 模型 和想要使用它…...
华为OD机试-食堂供餐-二分法
import java.util.Arrays; import java.util.Scanner;public class DemoTest3 {public static void main(String[] args) {Scanner in new Scanner(System.in);// 注意 hasNext 和 hasNextLine 的区别while (in.hasNextLine()) { // 注意 while 处理多个 caseint a in.nextIn…...
关于 WASM:1. WASM 基础原理
一、WASM 简介 1.1 WebAssembly 是什么? WebAssembly(WASM) 是一种能在现代浏览器中高效运行的二进制指令格式,它不是传统的编程语言,而是一种 低级字节码格式,可由高级语言(如 C、C、Rust&am…...
【C++从零实现Json-Rpc框架】第六弹 —— 服务端模块划分
一、项目背景回顾 前五弹完成了Json-Rpc协议解析、请求处理、客户端调用等基础模块搭建。 本弹重点聚焦于服务端的模块划分与架构设计,提升代码结构的可维护性与扩展性。 二、服务端模块设计目标 高内聚低耦合:各模块职责清晰,便于独立开发…...
AspectJ 在 Android 中的完整使用指南
一、环境配置(Gradle 7.0 适配) 1. 项目级 build.gradle // 注意:沪江插件已停更,推荐官方兼容方案 buildscript {dependencies {classpath org.aspectj:aspectjtools:1.9.9.1 // AspectJ 工具} } 2. 模块级 build.gradle plu…...
Java线上CPU飙高问题排查全指南
一、引言 在Java应用的线上运行环境中,CPU飙高是一个常见且棘手的性能问题。当系统出现CPU飙高时,通常会导致应用响应缓慢,甚至服务不可用,严重影响用户体验和业务运行。因此,掌握一套科学有效的CPU飙高问题排查方法&…...
Unsafe Fileupload篇补充-木马的详细教程与木马分享(中国蚁剑方式)
在之前的皮卡丘靶场第九期Unsafe Fileupload篇中我们学习了木马的原理并且学了一个简单的木马文件 本期内容是为了更好的为大家解释木马(服务器方面的)的原理,连接,以及各种木马及连接工具的分享 文件木马:https://w…...
HashMap中的put方法执行流程(流程图)
1 put操作整体流程 HashMap 的 put 操作是其最核心的功能之一。在 JDK 1.8 及以后版本中,其主要逻辑封装在 putVal 这个内部方法中。整个过程大致如下: 初始判断与哈希计算: 首先,putVal 方法会检查当前的 table(也就…...
