Bert中文文本分类
这是一个经典的文本分类问题,使用google的预训练模型BERT中文版bert-base-chinese来做中文文本分类。可以先在Huggingface上下载预训练模型备用。https://huggingface.co/google-bert/bert-base-chinese/tree/main
我使用的训练环境是
pip install torch==2.0.0;
pip install transformers==4.30.2;
pip install gensim==4.3.3;
pip install huggingface-hub==0.15.1;
pip install modelscope==1.20.1;
一、准备训练数据
1.1 准备中文文本分类任务的训练数据
这里Demo数据如下:
各银行信用卡挂失费迥异 北京银行收费最高 0 莫泰酒店流拍 大摩叫价或降至6亿美元 4 乌兹别克斯坦议会立法院主席获连任 6 德媒披露鲁能引援关键人物 是他力荐德甲亚洲强人 7 辉立证券给予广汽集团持有评级 2 图文-业余希望赛海南站第二轮 球场的菠萝蜜 7 陆毅鲍蕾:近乎完美的爱情(组图)(2) 9 7000亿美元救市方案将成期市毒药 0 保诚启动210亿美元配股交易以融资收购AIG部门 2
分类class类别文件:
finance realty stocks education science society politics sports game entertainment
1.2 数据读取和截断,使满足BERT模型输入
读取训练数据,对文本进行处理,如截取过长的文本、补齐较短的文本,加上起始标示、对文本进行编码、添加掩码、转为tensor等操作。
import os
from config import parsers
from transformers import BertTokenizer
from torch.utils.data import Dataset, DataLoader
import torchfrom transformers import AutoTokenizer, AutoModelForMaskedLMdef read_data(file):# 读取文件all_data = open(file, "r", encoding="utf-8").read().split("\n")# 得到所有文本、所有标签、句子的最大长度texts, labels, max_length = [], [], []for data in all_data:if data:text, label = data.split("\t")max_length.append(len(text))texts.append(text)labels.append(label)# 根据不同的数据集返回不同的内容if os.path.split(file)[1] == "train.txt":max_len = max(max_length)return texts, labels, max_lenreturn texts, labels,class MyDataset(Dataset):def __init__(self, texts, labels, max_length):self.all_text = textsself.all_label = labelsself.max_len = max_lengthself.tokenizer = BertTokenizer.from_pretrained(parsers().bert_pred)
# self.tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")def __getitem__(self, index):# 取出一条数据并截断长度text = self.all_text[index][:self.max_len]label = self.all_label[index]# 分词text_id = self.tokenizer.tokenize(text)# 加上起始标志text_id = ["[CLS]"] + text_id# 编码token_id = self.tokenizer.convert_tokens_to_ids(text_id)# 掩码 -》mask = [1] * len(token_id) + [0] * (self.max_len + 2 - len(token_id))# 编码后 -》长度一致token_ids = token_id + [0] * (self.max_len + 2 - len(token_id))# str -》 intlabel = int(label)# 转化成tensortoken_ids = torch.tensor(token_ids)mask = torch.tensor(mask)label = torch.tensor(label)return (token_ids, mask), labeldef __len__(self):# 得到文本的长度return len(self.all_text)
将文本处理后,就可以使用torch.utils.data中自带的DataLoader模块来加载训练数据了。
二、微调BERT模型
我们是微调BERT模型,需要获取BERT最后一个隐藏层的输出作为输入到下一个全连接层。
至于选择BERT模型的哪个输出作为linear层的输入,可以通过实验尝试,或者遵循常理。
pooler_output:这是通过将最后一层的隐藏状态的第一个token(通常是[CLS] token)通过一个线性层和激活函数得到的输出,常用于分类任务。
last_hidden_state:这是模型所有层的最后一个隐藏状态的输出,包含了整个序列的上下文信息,适用于序列级别的任务。
简单调用下BERT模型,打印出来最后一层看下:
import torch
import time
import torch.nn as nn
from transformers import BertTokenizer
from transformers import BertModel
from transformers import AutoTokenizer, AutoModelForMaskedLMdef process_text(text, bert_pred):tokenizer = BertTokenizer.from_pretrained(bert_pred)token_id = tokenizer.convert_tokens_to_ids(["[CLS]"] + tokenizer.tokenize(text))mask = [1] * len(token_id) + [0] * (38 + 2 - len(token_id))token_ids = token_id + [0] * (38 + 2 - len(token_id))token_ids = torch.tensor(token_ids).unsqueeze(0)mask = torch.tensor(mask).unsqueeze(0)x = torch.stack([token_ids, mask])return xdevice = "cpu"
bert = BertModel.from_pretrained('./bert-base-chinese/')
texts = ["沈腾和马丽的电影《独行月球》挺好看"]
for text in texts:x = process_text(text, './bert-base-chinese/')input_ids, attention_mask = x[0].to(device), x[1].to(device)hidden_out = bert(input_ids, attention_mask=attention_mask,output_hidden_states=False) print(hidden_out)
输出结果:


2.1 文本分类任务,选择使用pooler_output作为线性层的输入。
import torch.nn as nn
from transformers import BertModel
from transformers import AutoTokenizer, AutoModelForMaskedLM
from config import parsers
import torchclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.args = parsers()self.device = "cuda:0" if torch.cuda.is_available() else "cpu" self.bert = BertModel.from_pretrained(self.args.bert_pred) # bert 模型进行微调for param in self.bert.parameters():param.requires_grad = True# 一个全连接层self.linear = nn.Linear(self.args.num_filters, self.args.class_num)def forward(self, x):input_ids, attention_mask = x[0].to(self.device), x[1].to(self.device)hidden_out = self.bert(input_ids, attention_mask=attention_mask,output_hidden_states=False) # 是否输出所有encoder层的结果# shape (batch_size, hidden_size) pooler_output --> hidden_out[0]pred = self.linear(hidden_out.pooler_output)# 返回预测结果return pred
2.2 优化器使用Adam、损失函数使用交叉熵损失函数
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = MyModel().to(device)
opt = AdamW(model.parameters(), lr=args.learn_rate)
loss_fn = nn.CrossEntropyLoss()
三、训练模型
3.1 参数配置
def parsers():parser = argparse.ArgumentParser(description="Bert model of argparse")parser.add_argument("tx_date",nargs='?') #可选输入参数,计算日期parser.add_argument("--train_file", type=str, default=os.path.join("./data_all", "train.txt"))parser.add_argument("--dev_file", type=str, default=os.path.join("./data_all", "dev.txt"))parser.add_argument("--test_file", type=str, default=os.path.join("./data_all", "test.txt"))parser.add_argument("--classification", type=str, default=os.path.join("./data_all", "class.txt"))parser.add_argument("--bert_pred", type=str, default="./bert-base-chinese")parser.add_argument("--class_num", type=int, default=12)parser.add_argument("--max_len", type=int, default=38)parser.add_argument("--batch_size", type=int, default=32)parser.add_argument("--epochs", type=int, default=10)parser.add_argument("--learn_rate", type=float, default=1e-5)parser.add_argument("--num_filters", type=int, default=768)parser.add_argument("--save_model_best", type=str, default=os.path.join("model", "all_best_model.pth"))parser.add_argument("--save_model_last", type=str, default=os.path.join("model", "all_last_model.pth"))args = parser.parse_args()return args
3.2 模型训练
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
import torch.nn as nn
from sklearn.metrics import accuracy_score
import timeif __name__ == "__main__":start = time.time()args = parsers()device = "cuda:0" if torch.cuda.is_available() else "cpu"print("device:", device)train_text, train_label, max_len = read_data(args.train_file)dev_text, dev_label = read_data(args.dev_file)args.max_len = max_lentrain_dataset = MyDataset(train_text, train_label, args.max_len)train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)dev_dataset = MyDataset(dev_text, dev_label, args.max_len)dev_dataloader = DataLoader(dev_dataset, batch_size=args.batch_size, shuffle=False)model = MyModel().to(device)opt = AdamW(model.parameters(), lr=args.learn_rate)loss_fn = nn.CrossEntropyLoss()acc_max = float("-inf")for epoch in range(args.epochs):loss_sum, count = 0, 0model.train()for batch_index, (batch_text, batch_label) in enumerate(train_dataloader):batch_label = batch_label.to(device)pred = model(batch_text)loss = loss_fn(pred, batch_label)opt.zero_grad()loss.backward()opt.step()loss_sum += losscount += 1# 打印内容if len(train_dataloader) - batch_index <= len(train_dataloader) % 1000 and count == len(train_dataloader) % 1000:msg = "[{0}/{1:5d}]\tTrain_Loss:{2:.4f}"print(msg.format(epoch + 1, batch_index + 1, loss_sum / count))loss_sum, count = 0.0, 0if batch_index % 1000 == 999:msg = "[{0}/{1:5d}]\tTrain_Loss:{2:.4f}"print(msg.format(epoch + 1, batch_index + 1, loss_sum / count))loss_sum, count = 0.0, 0model.eval()all_pred, all_true = [], []with torch.no_grad():for batch_text, batch_label in dev_dataloader:batch_label = batch_label.to(device)pred = model(batch_text)pred = torch.argmax(pred, dim=1).cpu().numpy().tolist()label = batch_label.cpu().numpy().tolist()all_pred.extend(pred)all_true.extend(label)acc = accuracy_score(all_pred, all_true)print(f"dev acc:{acc:.4f}")if acc > acc_max:print(acc, acc_max)acc_max = acctorch.save(model.state_dict(), args.save_model_best)print(f"以保存最佳模型")torch.save(model.state_dict(), args.save_model_last)end = time.time()print(f"运行时间:{(end-start)/60%60:.4f} min")
模型保存为:
-rw-rw-r-- 1 gaoToby gaoToby 391M Dec 24 14:02 all_best_model.pth
-rw-rw-r-- 1 gaoToby gaoToby 391M Dec 24 14:02 all_last_model.pth
四、模型推理预测
准备预测文本文件,加载模型,进行文本的类别预测。
def text_class_name(pred):result = torch.argmax(pred, dim=1)print(torch.argmax(pred, dim=1).cpu().numpy().tolist())result = result.cpu().numpy().tolist()classification = open(args.classification, "r", encoding="utf-8").read().split("\n")classification_dict = dict(zip(range(len(classification)), classification))print(f"文本:{text}\t预测的类别为:{classification_dict[result[0]]}")if __name__ == "__main__":start = time.time()args = parsers()device = "cuda:0" if torch.cuda.is_available() else "cpu"model = load_model(device, args.save_model_best)texts = ["沈腾和马丽的新电影《独行月球》好看", "最近金融环境不太好,投资需谨慎"]print("模型预测结果:")for text in texts:x = process_text(text, args.bert_pred)with torch.no_grad():pred = model(x)text_class_name(pred)end = time.time()print(f"耗时为:{end - start} s")
以上,基本流程完成。当然模型还需要调优来改进预测效果的。
代码是实际跑通的,我训练和预测均使用的是GPU。如果是使用GPU做模型训练,再使用CPU做推理预测的情况,推理预测加载模型的时候注意修改下:
myModel.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
Done
相关文章:
Bert中文文本分类
这是一个经典的文本分类问题,使用google的预训练模型BERT中文版bert-base-chinese来做中文文本分类。可以先在Huggingface上下载预训练模型备用。https://huggingface.co/google-bert/bert-base-chinese/tree/main 我使用的训练环境是 pip install torch2.0.0; pi…...
【深度学习】Java DL4J基于 CNN 构建车辆识别与跟踪模型
🧑 博主简介:CSDN博客专家,历代文学网(PC端可以访问:https://literature.sinhy.com/#/?__c1000,移动端可微信小程序搜索“历代文学”)总架构师,15年工作经验,精通Java编…...
【C#】C#打印当前时间以及TimeSpan()介绍
1. C#打印当前时间 string currentDate DateTime.Now.ToString("yyyy-MM-dd HH:mm:ss.fff");Console.WriteLine(currentDate);2. TimeSpan()介绍 TimeSpan(long ticks)的单位是100ns //500ms new TimeSpan(10*1000*500);参考: C#-TimeSpan-计算时间差...
【Linux 网络 (五)】Tcp/Udp协议
Linux 网络 一前言二、Udp协议1)、Udp协议特点2)、Udp协议格式3)、Udp报文封装和解包过程4)、UDP的缓冲区 三、TCP协议1)、TCP协议特点2)、TCP协议格式1、4位首部长度、源端口、目的端口2、16位窗口大小3、…...
多旋翼无人机理论 | 四旋翼动力学数学模型与Matlab仿真
多旋翼无人机理论 | 四旋翼动力学数学模型与Matlab仿真 力的来源数学模型数学模型总结Matlab 仿真 力的来源 无人机的动力系统:电调-电机-螺旋桨 。 给人最直观的感受就是 电机带动螺旋桨转,产生升力。 螺旋桨旋转产生升力的原因,在很多年…...
Vue3项目中引入TailwindCSS(图文详情)
Vue3项目中引入TailwindCSS(图文详细) Tailwind CSS 是一个实用工具优先的 CSS 框架,提供丰富的低级类(如 text-center、bg-blue-500),允许开发者通过组合这些类快速构建自定义设计,而无需编写…...
【开源项目】数字孪生化工厂—开源工程及源码
飞渡科技数字孪生化工厂管理平台,基于自研孪生引擎,将物联网IOT、人工智能、大数据、云计算等技术应用于化工厂,为化工厂提供实时数据分析、工艺优化、设备运维等功能,助力提高生产效率以及提供安全保障。 通过可视化点位标注各厂…...
咨询团队如何通过轻量型工具优化项目管理和提高团队协作效率?
引言 在咨询行业,项目的复杂性和多样性往往意味着团队成员需要协同工作、迅速适应客户需求的变化并且在较短的时间内交付高质量的成果。对于咨询团队来说,选择一个适合的项目管理工具,不仅能够提高工作效率,还能促进团队的协作、…...
javaWeb开发
Java Web开发作为软件开发领域的一个重要分支,已经历经数十年的发展,并凭借其强大的跨平台能力、丰富的生态系统以及高度的安全性,成为构建企业级应用的首选技术之一。以下是对Java Web开发的详细解析: 一、Java Web开发的基本概…...
如何在 Vue 中处理 API 请求?
在 Vue.js 中处理 API 请求是构建动态、交互式 Web 应用程序的核心部分。为了有效地与后端服务器通信,Vue 生态系统提供了多种方式来发起和管理 API 请求。以下是几种常见的方法和最佳实践: 1. 使用 Axios Axios 是一个基于 Promise 的 HTTP 客户端&am…...
基于Debian的Linux发行版的包管理工具
基于Debian的Linux发行版中除了apt和apt-get之外,还有以下几种包管理工具: dpkg:这是Debian系发行版中最基础的包管理工具,专门用于安装、卸载和查询.deb包。与高级包管理器不同,dpkg不自动解决包的依赖关系࿰…...
2022年国家公考《申论》题(行政执法)
2022年国家公考《申论》题(行政执法) 材料一 新型冠状病毒肺炎疫情发生后,党中央、国务院出台了一系列支持企业发展的惠企政策。N市积极落实各项惠企政策,不断优化营商环境,推动区域经济高质量跨越式发展。 “当时…...
贪心算法(常见贪心模型)
常见贪心模型 简单排序模型 最小化战斗力差距 题目分析: #include <bits/stdc.h> using namespace std;const int N 1e5 10;int n; int a[N];int main() {// 请在此输入您的代码cin >> n;for (int i 1;i < n;i) cin >> a[i];sort(a1,a1n);…...
git自动压缩提交的脚本
可以将当前未提交的代码自动执行 git addgit commitgit squash Git 命令安装指南 1. 创建脚本目录 如果目录不存在,创建它: mkdir -p ~/.local/bin2. 创建脚本文件 vim ~/.local/bin/git-squash将完整的脚本代码复制到此文件中。 3. 设置脚本权限…...
Kinova在开源家庭服务机器人TidyBot++研究里大展身手
在科技日新月异的今天,机器人技术在家庭场景中的应用逐渐成为现实,改变着我们的生活方式。今天,我们将深入探讨一篇关于家用机器人研究的论文,剖析其中的创新成果, 论文引用链接:http://tidybot2.github.i…...
使用 Spring Boot 实现文件上传:从配置文件中动态读取上传路径
使用 Spring Boot 实现文件上传:从配置文件中动态读取上传路径 一、前言二、文件上传的基本概念三、环境准备1. 引入依赖2. 配置文件设置application.yml 配置示例:application.properties 配置示例: 四、编写文件上传功能代码1. 控制器类2. …...
《鸿蒙HarmonyOS应用开发从入门到精通(第2版)》学习笔记——HarmonyOS技术理念
1.2 技术理念 在万物智联时代重要机遇期,HarmonyOS结合移动生态发展的趋势,提出了三大技术理念(如下图3-1所示):一次开发,多端部署;可分可合,自由流转;统一生态…...
将多个 k8s yaml 配置文件合并为一个文件
如下bash脚本实现功能 “将多个k8s的yaml 配置文件” 合并为一个 yaml,使用 --- 分割文件配置。 创建文件 merge_yaml.sh ,内容如下: #!/bin/bash# 默认参数 input_patterns() # 匹配的文件模式数组 output_file"combined.yaml"…...
Linux 文件的特殊权限—Sticky Bit(SBIT)权限
本文为Ubuntu Linux操作系统- 第十九期~~ 其他特殊权限: 【SUID 权限】和【SGID 权限】 更多Linux 相关内容请点击👉【Linux专栏】~ 主页:【练小杰的CSDN】 文章目录 Sticky(SBIT)权限基本概念Sticky Bit 的表示方式举例 设置和取…...
MIPI D-PHY/C-PHY/M-PHY 高速串行接口标准
MIPI D-PHY、C-PHY和M-PHY都是MIPI联盟制定的高速串行接口标准。它们都具有低功耗、高速传输速率等特点,但各有侧重: ➢MIPI D-PHY:适用于手机与其他设备之间的数据传输。 ➢MIPI C-PHY:专为手机摄像头而设计。 ➢MIPI M-PHY&am…...
【飞机】倾转旋翼飞机齿轮箱建模与Matlab仿真(含非线性阻尼和立方摩擦效应)
✅作者简介:热爱科研的Matlab仿真开发者,擅长毕业设计辅导、数学建模、数据处理、建模仿真、程序设计、完整代码获取、论文复现及科研仿真。🍎 往期回顾关注个人主页:Matlab科研工作室👇 关注我领取海量matlab电子书和…...
GDBFrontend安全部署指南:保护调试会话的5个最佳实践
GDBFrontend安全部署指南:保护调试会话的5个最佳实践 【免费下载链接】gdb-frontend ☕ GDBFrontend is an easy, flexible and extensible gui debugger. Try it on https://debugme.dev 项目地址: https://gitcode.com/gh_mirrors/gd/gdb-frontend GDBFron…...
4大技术引擎破解魔兽争霸3现代适配难题
4大技术引擎破解魔兽争霸3现代适配难题 【免费下载链接】WarcraftHelper Warcraft III Helper , support 1.20e, 1.24e, 1.26a, 1.27a, 1.27b 项目地址: https://gitcode.com/gh_mirrors/wa/WarcraftHelper 当经典RTS游戏遇上现代硬件环境,总会面临兼容性的严…...
PyTorch 2.8 GPU算力优化部署教程:RTX 4090D显存利用率提升至92%
PyTorch 2.8 GPU算力优化部署教程:RTX 4090D显存利用率提升至92% 1. 环境准备与快速验证 在开始深度学习项目前,确保你的硬件配置符合以下要求: 显卡:NVIDIA RTX 4090D 24GB显存驱动版本:550.90.07或更高系统内存&a…...
SecGPT-14B提示工程:提升OpenClaw安全报告可读性的秘诀
SecGPT-14B提示工程:提升OpenClaw安全报告可读性的秘诀 1. 当安全报告遇上OpenClaw:我的真实痛点 上周五凌晨2点,我被OpenClaw的告警邮件惊醒——它发现我的个人服务器存在一个高危漏洞。但当我打开那份自动生成的安全报告时,眼…...
2026企业AI落地必看:避开3大坑,让你的智能体真正帮你赚钱!收藏这份实战指南
本文深入探讨了企业AI智能体落地的现实难题,包括数据基础薄弱、单体智能体处理复杂流程能力不足以及人机协同缺失三大痛点。作者通过分析30企业案例,提出了针对性的解决方案:建立RAG架构和OCR数据清洗以夯实数据基础;采用多智能体…...
OpenClaw技能开发入门:为Qwen3-14b_int4_awq扩展自定义功能
OpenClaw技能开发入门:为Qwen3-14b_int4_awq扩展自定义功能 1. 为什么需要自定义技能? 去年冬天,我花了整整两周时间手动整理公司项目的技术文档。每天重复着复制、粘贴、格式调整的机械操作,直到偶然发现OpenClaw这个开源自动化…...
Windows Cleaner实战指南:解决C盘空间不足和电脑卡顿的5个高效策略
Windows Cleaner实战指南:解决C盘空间不足和电脑卡顿的5个高效策略 【免费下载链接】WindowsCleaner Windows Cleaner——专治C盘爆红及各种不服! 项目地址: https://gitcode.com/gh_mirrors/wi/WindowsCleaner Windows Cleaner是一款专为Windows…...
别再只用箱线图了!用R语言vioplot绘制小提琴图的5个高级技巧与常见误区避坑
别再只用箱线图了!用R语言vioplot绘制小提琴图的5个高级技巧与常见误区避坑 当你已经能够熟练地用箱线图展示数据分布时,是否想过有一种更优雅、信息量更大的可视化方式?小提琴图(Violin Plot)正是这样一种工具&#x…...
基于RFM模型的电商用户价值分层画像分析
摘要本项目旨在通过Python对电商平台用户行为数据进行深度挖掘与分析,以构建用户画像为核心,实现对高价值用户、低价值用户及“白嫖党”的精准分层。项目基于RFM(Recency, Frequency, Monetary)模型理论,通过数据清洗、…...
