当前位置: 首页 > news >正文

Bert中文文本分类

这是一个经典的文本分类问题,使用google的预训练模型BERT中文版bert-base-chinese来做中文文本分类。可以先在Huggingface上下载预训练模型备用。https://huggingface.co/google-bert/bert-base-chinese/tree/main

我使用的训练环境是

pip install torch==2.0.0;
pip install transformers==4.30.2;
pip install gensim==4.3.3;
pip install huggingface-hub==0.15.1;
pip install modelscope==1.20.1;

一、准备训练数据

1.1 准备中文文本分类任务的训练数据

这里Demo数据如下:

各银行信用卡挂失费迥异 北京银行收费最高    0
莫泰酒店流拍 大摩叫价或降至6亿美元 4
乌兹别克斯坦议会立法院主席获连任   6
德媒披露鲁能引援关键人物 是他力荐德甲亚洲强人    7
辉立证券给予广汽集团持有评级 2
图文-业余希望赛海南站第二轮 球场的菠萝蜜  7
陆毅鲍蕾:近乎完美的爱情(组图)(2)    9
7000亿美元救市方案将成期市毒药  0
保诚启动210亿美元配股交易以融资收购AIG部门   2

分类class类别文件:

finance
realty
stocks
education
science
society
politics
sports
game
entertainment

1.2 数据读取和截断,使满足BERT模型输入

读取训练数据,对文本进行处理,如截取过长的文本、补齐较短的文本,加上起始标示、对文本进行编码、添加掩码、转为tensor等操作。

import os
from config import parsers
from transformers import BertTokenizer
from torch.utils.data import Dataset, DataLoader
import torchfrom transformers import AutoTokenizer, AutoModelForMaskedLMdef read_data(file):# 读取文件all_data = open(file, "r", encoding="utf-8").read().split("\n")# 得到所有文本、所有标签、句子的最大长度texts, labels, max_length = [], [], []for data in all_data:if data:text, label = data.split("\t")max_length.append(len(text))texts.append(text)labels.append(label)# 根据不同的数据集返回不同的内容if os.path.split(file)[1] == "train.txt":max_len = max(max_length)return texts, labels, max_lenreturn texts, labels,class MyDataset(Dataset):def __init__(self, texts, labels, max_length):self.all_text = textsself.all_label = labelsself.max_len = max_lengthself.tokenizer = BertTokenizer.from_pretrained(parsers().bert_pred)
#         self.tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")def __getitem__(self, index):# 取出一条数据并截断长度text = self.all_text[index][:self.max_len]label = self.all_label[index]# 分词text_id = self.tokenizer.tokenize(text)# 加上起始标志text_id = ["[CLS]"] + text_id# 编码token_id = self.tokenizer.convert_tokens_to_ids(text_id)# 掩码  -》mask = [1] * len(token_id) + [0] * (self.max_len + 2 - len(token_id))# 编码后  -》长度一致token_ids = token_id + [0] * (self.max_len + 2 - len(token_id))# str -》 intlabel = int(label)# 转化成tensortoken_ids = torch.tensor(token_ids)mask = torch.tensor(mask)label = torch.tensor(label)return (token_ids, mask), labeldef __len__(self):# 得到文本的长度return len(self.all_text)

将文本处理后,就可以使用torch.utils.data中自带的DataLoader模块来加载训练数据了。

二、微调BERT模型

我们是微调BERT模型,需要获取BERT最后一个隐藏层的输出作为输入到下一个全连接层。

至于选择BERT模型的哪个输出作为linear层的输入,可以通过实验尝试,或者遵循常理。

pooler_output:这是通过将最后一层的隐藏状态的第一个token(通常是[CLS] token)通过一个线性层和激活函数得到的输出,常用于分类任务。
last_hidden_state:这是模型所有层的最后一个隐藏状态的输出,包含了整个序列的上下文信息,适用于序列级别的任务。

简单调用下BERT模型,打印出来最后一层看下:

import torch
import time
import torch.nn as nn
from transformers import BertTokenizer
from transformers import BertModel
from transformers import AutoTokenizer, AutoModelForMaskedLMdef process_text(text, bert_pred):tokenizer = BertTokenizer.from_pretrained(bert_pred)token_id = tokenizer.convert_tokens_to_ids(["[CLS]"] + tokenizer.tokenize(text))mask = [1] * len(token_id) + [0] * (38 + 2 - len(token_id))token_ids = token_id + [0] * (38 + 2 - len(token_id))token_ids = torch.tensor(token_ids).unsqueeze(0)mask = torch.tensor(mask).unsqueeze(0)x = torch.stack([token_ids, mask])return xdevice = "cpu"
bert = BertModel.from_pretrained('./bert-base-chinese/')
texts = ["沈腾和马丽的电影《独行月球》挺好看"]
for text in texts:x = process_text(text, './bert-base-chinese/')input_ids, attention_mask = x[0].to(device), x[1].to(device)hidden_out = bert(input_ids, attention_mask=attention_mask,output_hidden_states=False) print(hidden_out)

 输出结果:

2.1 文本分类任务,选择使用pooler_output作为线性层的输入。

import torch.nn as nn
from transformers import BertModel
from transformers import AutoTokenizer, AutoModelForMaskedLM
from config import parsers
import torchclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.args = parsers()self.device = "cuda:0" if torch.cuda.is_available() else "cpu"  self.bert = BertModel.from_pretrained(self.args.bert_pred) # bert 模型进行微调for param in self.bert.parameters():param.requires_grad = True# 一个全连接层self.linear = nn.Linear(self.args.num_filters, self.args.class_num)def forward(self, x):input_ids, attention_mask = x[0].to(self.device), x[1].to(self.device)hidden_out = self.bert(input_ids, attention_mask=attention_mask,output_hidden_states=False)  # 是否输出所有encoder层的结果# shape (batch_size, hidden_size)  pooler_output -->  hidden_out[0]pred = self.linear(hidden_out.pooler_output)# 返回预测结果return pred

2.2 优化器使用Adam、损失函数使用交叉熵损失函数

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = MyModel().to(device)
opt = AdamW(model.parameters(), lr=args.learn_rate)
loss_fn = nn.CrossEntropyLoss()

三、训练模型

3.1 参数配置

def parsers():parser = argparse.ArgumentParser(description="Bert model of argparse")parser.add_argument("tx_date",nargs='?') #可选输入参数,计算日期parser.add_argument("--train_file", type=str, default=os.path.join("./data_all", "train.txt"))parser.add_argument("--dev_file", type=str, default=os.path.join("./data_all", "dev.txt"))parser.add_argument("--test_file", type=str, default=os.path.join("./data_all", "test.txt"))parser.add_argument("--classification", type=str, default=os.path.join("./data_all", "class.txt"))parser.add_argument("--bert_pred", type=str, default="./bert-base-chinese")parser.add_argument("--class_num", type=int, default=12)parser.add_argument("--max_len", type=int, default=38)parser.add_argument("--batch_size", type=int, default=32)parser.add_argument("--epochs", type=int, default=10)parser.add_argument("--learn_rate", type=float, default=1e-5)parser.add_argument("--num_filters", type=int, default=768)parser.add_argument("--save_model_best", type=str, default=os.path.join("model", "all_best_model.pth"))parser.add_argument("--save_model_last", type=str, default=os.path.join("model", "all_last_model.pth"))args = parser.parse_args()return args

3.2 模型训练

import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
import torch.nn as nn
from sklearn.metrics import accuracy_score
import timeif __name__ == "__main__":start = time.time()args = parsers()device = "cuda:0" if torch.cuda.is_available() else "cpu"print("device:", device)train_text, train_label, max_len = read_data(args.train_file)dev_text, dev_label = read_data(args.dev_file)args.max_len = max_lentrain_dataset = MyDataset(train_text, train_label, args.max_len)train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)dev_dataset = MyDataset(dev_text, dev_label, args.max_len)dev_dataloader = DataLoader(dev_dataset, batch_size=args.batch_size, shuffle=False)model = MyModel().to(device)opt = AdamW(model.parameters(), lr=args.learn_rate)loss_fn = nn.CrossEntropyLoss()acc_max = float("-inf")for epoch in range(args.epochs):loss_sum, count = 0, 0model.train()for batch_index, (batch_text, batch_label) in enumerate(train_dataloader):batch_label = batch_label.to(device)pred = model(batch_text)loss = loss_fn(pred, batch_label)opt.zero_grad()loss.backward()opt.step()loss_sum += losscount += 1# 打印内容if len(train_dataloader) - batch_index <= len(train_dataloader) % 1000 and count == len(train_dataloader) % 1000:msg = "[{0}/{1:5d}]\tTrain_Loss:{2:.4f}"print(msg.format(epoch + 1, batch_index + 1, loss_sum / count))loss_sum, count = 0.0, 0if batch_index % 1000 == 999:msg = "[{0}/{1:5d}]\tTrain_Loss:{2:.4f}"print(msg.format(epoch + 1, batch_index + 1, loss_sum / count))loss_sum, count = 0.0, 0model.eval()all_pred, all_true = [], []with torch.no_grad():for batch_text, batch_label in dev_dataloader:batch_label = batch_label.to(device)pred = model(batch_text)pred = torch.argmax(pred, dim=1).cpu().numpy().tolist()label = batch_label.cpu().numpy().tolist()all_pred.extend(pred)all_true.extend(label)acc = accuracy_score(all_pred, all_true)print(f"dev acc:{acc:.4f}")if acc > acc_max:print(acc, acc_max)acc_max = acctorch.save(model.state_dict(), args.save_model_best)print(f"以保存最佳模型")torch.save(model.state_dict(), args.save_model_last)end = time.time()print(f"运行时间:{(end-start)/60%60:.4f} min")

模型保存为:

-rw-rw-r--  1 gaoToby gaoToby 391M Dec 24 14:02 all_best_model.pth
-rw-rw-r--  1 gaoToby gaoToby 391M Dec 24 14:02 all_last_model.pth

四、模型推理预测

准备预测文本文件,加载模型,进行文本的类别预测。


def text_class_name(pred):result = torch.argmax(pred, dim=1)print(torch.argmax(pred, dim=1).cpu().numpy().tolist())result = result.cpu().numpy().tolist()classification = open(args.classification, "r", encoding="utf-8").read().split("\n")classification_dict = dict(zip(range(len(classification)), classification))print(f"文本:{text}\t预测的类别为:{classification_dict[result[0]]}")if __name__ == "__main__":start = time.time()args = parsers()device = "cuda:0" if torch.cuda.is_available() else "cpu"model = load_model(device, args.save_model_best)texts = ["沈腾和马丽的新电影《独行月球》好看", "最近金融环境不太好,投资需谨慎"]print("模型预测结果:")for text in texts:x = process_text(text, args.bert_pred)with torch.no_grad():pred = model(x)text_class_name(pred)end = time.time()print(f"耗时为:{end - start} s")

以上,基本流程完成。当然模型还需要调优来改进预测效果的。

代码是实际跑通的,我训练和预测均使用的是GPU。如果是使用GPU做模型训练,再使用CPU做推理预测的情况,推理预测加载模型的时候注意修改下:

 myModel.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

Done

相关文章:

Bert中文文本分类

这是一个经典的文本分类问题&#xff0c;使用google的预训练模型BERT中文版bert-base-chinese来做中文文本分类。可以先在Huggingface上下载预训练模型备用。https://huggingface.co/google-bert/bert-base-chinese/tree/main 我使用的训练环境是 pip install torch2.0.0; pi…...

【深度学习】Java DL4J基于 CNN 构建车辆识别与跟踪模型

&#x1f9d1; 博主简介&#xff1a;CSDN博客专家&#xff0c;历代文学网&#xff08;PC端可以访问&#xff1a;https://literature.sinhy.com/#/?__c1000&#xff0c;移动端可微信小程序搜索“历代文学”&#xff09;总架构师&#xff0c;15年工作经验&#xff0c;精通Java编…...

【C#】C#打印当前时间以及TimeSpan()介绍

1. C#打印当前时间 string currentDate DateTime.Now.ToString("yyyy-MM-dd HH:mm:ss.fff");Console.WriteLine(currentDate);2. TimeSpan()介绍 TimeSpan(long ticks)的单位是100ns //500ms new TimeSpan(10*1000*500);参考&#xff1a; C#-TimeSpan-计算时间差...

【Linux 网络 (五)】Tcp/Udp协议

Linux 网络 一前言二、Udp协议1&#xff09;、Udp协议特点2&#xff09;、Udp协议格式3&#xff09;、Udp报文封装和解包过程4&#xff09;、UDP的缓冲区 三、TCP协议1&#xff09;、TCP协议特点2&#xff09;、TCP协议格式1、4位首部长度、源端口、目的端口2、16位窗口大小3、…...

多旋翼无人机理论 | 四旋翼动力学数学模型与Matlab仿真

多旋翼无人机理论 | 四旋翼动力学数学模型与Matlab仿真 力的来源数学模型数学模型总结Matlab 仿真 力的来源 无人机的动力系统&#xff1a;电调-电机-螺旋桨 。 给人最直观的感受就是 电机带动螺旋桨转&#xff0c;产生升力。 螺旋桨旋转产生升力的原因&#xff0c;在很多年…...

Vue3项目中引入TailwindCSS(图文详情)

Vue3项目中引入TailwindCSS&#xff08;图文详细&#xff09; Tailwind CSS 是一个实用工具优先的 CSS 框架&#xff0c;提供丰富的低级类&#xff08;如 text-center、bg-blue-500&#xff09;&#xff0c;允许开发者通过组合这些类快速构建自定义设计&#xff0c;而无需编写…...

【开源项目】数字孪生化工厂—开源工程及源码

飞渡科技数字孪生化工厂管理平台&#xff0c;基于自研孪生引擎&#xff0c;将物联网IOT、人工智能、大数据、云计算等技术应用于化工厂&#xff0c;为化工厂提供实时数据分析、工艺优化、设备运维等功能&#xff0c;助力提高生产效率以及提供安全保障。 通过可视化点位标注各厂…...

咨询团队如何通过轻量型工具优化项目管理和提高团队协作效率?

引言 在咨询行业&#xff0c;项目的复杂性和多样性往往意味着团队成员需要协同工作、迅速适应客户需求的变化并且在较短的时间内交付高质量的成果。对于咨询团队来说&#xff0c;选择一个适合的项目管理工具&#xff0c;不仅能够提高工作效率&#xff0c;还能促进团队的协作、…...

javaWeb开发

Java Web开发作为软件开发领域的一个重要分支&#xff0c;已经历经数十年的发展&#xff0c;并凭借其强大的跨平台能力、丰富的生态系统以及高度的安全性&#xff0c;成为构建企业级应用的首选技术之一。以下是对Java Web开发的详细解析&#xff1a; 一、Java Web开发的基本概…...

如何在 Vue 中处理 API 请求?

在 Vue.js 中处理 API 请求是构建动态、交互式 Web 应用程序的核心部分。为了有效地与后端服务器通信&#xff0c;Vue 生态系统提供了多种方式来发起和管理 API 请求。以下是几种常见的方法和最佳实践&#xff1a; 1. 使用 Axios Axios 是一个基于 Promise 的 HTTP 客户端&am…...

基于Debian的Linux发行版的包管理工具

基于Debian的Linux发行版中除了apt和apt-get之外&#xff0c;还有以下几种包管理工具&#xff1a; dpkg&#xff1a;这是Debian系发行版中最基础的包管理工具&#xff0c;专门用于安装、卸载和查询.deb包。与高级包管理器不同&#xff0c;dpkg不自动解决包的依赖关系&#xff0…...

2022年国家公考《申论》题(行政执法)

2022年国家公考《申论》题&#xff08;行政执法&#xff09; 材料一 新型冠状病毒肺炎疫情发生后&#xff0c;党中央、国务院出台了一系列支持企业发展的惠企政策。N市积极落实各项惠企政策&#xff0c;不断优化营商环境&#xff0c;推动区域经济高质量跨越式发展。   “当时…...

贪心算法(常见贪心模型)

常见贪心模型 简单排序模型 最小化战斗力差距 题目分析&#xff1a; #include <bits/stdc.h> using namespace std;const int N 1e5 10;int n; int a[N];int main() {// 请在此输入您的代码cin >> n;for (int i 1;i < n;i) cin >> a[i];sort(a1,a1n);…...

git自动压缩提交的脚本

可以将当前未提交的代码自动执行 git addgit commitgit squash Git 命令安装指南 1. 创建脚本目录 如果目录不存在&#xff0c;创建它&#xff1a; mkdir -p ~/.local/bin2. 创建脚本文件 vim ~/.local/bin/git-squash将完整的脚本代码复制到此文件中。 3. 设置脚本权限…...

Kinova在开源家庭服务机器人TidyBot++研究里大展身手

在科技日新月异的今天&#xff0c;机器人技术在家庭场景中的应用逐渐成为现实&#xff0c;改变着我们的生活方式。今天&#xff0c;我们将深入探讨一篇关于家用机器人研究的论文&#xff0c;剖析其中的创新成果&#xff0c; 论文引用链接&#xff1a;http://tidybot2.github.i…...

使用 Spring Boot 实现文件上传:从配置文件中动态读取上传路径

使用 Spring Boot 实现文件上传&#xff1a;从配置文件中动态读取上传路径 一、前言二、文件上传的基本概念三、环境准备1. 引入依赖2. 配置文件设置application.yml 配置示例&#xff1a;application.properties 配置示例&#xff1a; 四、编写文件上传功能代码1. 控制器类2. …...

《鸿蒙HarmonyOS应用开发从入门到精通(第2版)》学习笔记——HarmonyOS技术理念

1.2 技术理念 在万物智联时代重要机遇期&#xff0c;HarmonyOS结合移动生态发展的趋势&#xff0c;提出了三大技术理念&#xff08;如下图3-1所示&#xff09;&#xff1a;一次开发&#xff0c;多端部署&#xff1b;可分可合&#xff0c;自由流转&#xff1b;统一生态&#xf…...

将多个 k8s yaml 配置文件合并为一个文件

如下bash脚本实现功能 “将多个k8s的yaml 配置文件” 合并为一个 yaml&#xff0c;使用 --- 分割文件配置。 创建文件 merge_yaml.sh &#xff0c;内容如下&#xff1a; #!/bin/bash# 默认参数 input_patterns() # 匹配的文件模式数组 output_file"combined.yaml"…...

Linux 文件的特殊权限—Sticky Bit(SBIT)权限

本文为Ubuntu Linux操作系统- 第十九期~~ 其他特殊权限: 【SUID 权限】和【SGID 权限】 更多Linux 相关内容请点击&#x1f449;【Linux专栏】~ 主页&#xff1a;【练小杰的CSDN】 文章目录 Sticky&#xff08;SBIT&#xff09;权限基本概念Sticky Bit 的表示方式举例 设置和取…...

MIPI D-PHY/C-PHY/M-PHY 高速串行接口标准

MIPI D-PHY、C-PHY和M-PHY都是MIPI联盟制定的高速串行接口标准。它们都具有低功耗、高速传输速率等特点&#xff0c;但各有侧重&#xff1a; ➢MIPI D-PHY&#xff1a;适用于手机与其他设备之间的数据传输。 ➢MIPI C-PHY&#xff1a;专为手机摄像头而设计。 ➢MIPI M-PHY&am…...

Leetcode 3577. Count the Number of Computer Unlocking Permutations

Leetcode 3577. Count the Number of Computer Unlocking Permutations 1. 解题思路2. 代码实现 题目链接&#xff1a;3577. Count the Number of Computer Unlocking Permutations 1. 解题思路 这一题其实就是一个脑筋急转弯&#xff0c;要想要能够将所有的电脑解锁&#x…...

跨链模式:多链互操作架构与性能扩展方案

跨链模式&#xff1a;多链互操作架构与性能扩展方案 ——构建下一代区块链互联网的技术基石 一、跨链架构的核心范式演进 1. 分层协议栈&#xff1a;模块化解耦设计 现代跨链系统采用分层协议栈实现灵活扩展&#xff08;H2Cross架构&#xff09;&#xff1a; 适配层&#xf…...

汇编常见指令

汇编常见指令 一、数据传送指令 指令功能示例说明MOV数据传送MOV EAX, 10将立即数 10 送入 EAXMOV [EBX], EAX将 EAX 值存入 EBX 指向的内存LEA加载有效地址LEA EAX, [EBX4]将 EBX4 的地址存入 EAX&#xff08;不访问内存&#xff09;XCHG交换数据XCHG EAX, EBX交换 EAX 和 EB…...

什么?连接服务器也能可视化显示界面?:基于X11 Forwarding + CentOS + MobaXterm实战指南

文章目录 什么是X11?环境准备实战步骤1️⃣ 服务器端配置(CentOS)2️⃣ 客户端配置(MobaXterm)3️⃣ 验证X11 Forwarding4️⃣ 运行自定义GUI程序(Python示例)5️⃣ 成功效果![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/55aefaea8a9f477e86d065227851fe3d.pn…...

【开发技术】.Net使用FFmpeg视频特定帧上绘制内容

目录 一、目的 二、解决方案 2.1 什么是FFmpeg 2.2 FFmpeg主要功能 2.3 使用Xabe.FFmpeg调用FFmpeg功能 2.4 使用 FFmpeg 的 drawbox 滤镜来绘制 ROI 三、总结 一、目的 当前市场上有很多目标检测智能识别的相关算法&#xff0c;当前调用一个医疗行业的AI识别算法后返回…...

在web-view 加载的本地及远程HTML中调用uniapp的API及网页和vue页面是如何通讯的?

uni-app 中 Web-view 与 Vue 页面的通讯机制详解 一、Web-view 简介 Web-view 是 uni-app 提供的一个重要组件&#xff0c;用于在原生应用中加载 HTML 页面&#xff1a; 支持加载本地 HTML 文件支持加载远程 HTML 页面实现 Web 与原生的双向通讯可用于嵌入第三方网页或 H5 应…...

基于Java Swing的电子通讯录设计与实现:附系统托盘功能代码详解

JAVASQL电子通讯录带系统托盘 一、系统概述 本电子通讯录系统采用Java Swing开发桌面应用&#xff0c;结合SQLite数据库实现联系人管理功能&#xff0c;并集成系统托盘功能提升用户体验。系统支持联系人的增删改查、分组管理、搜索过滤等功能&#xff0c;同时可以最小化到系统…...

PHP 8.5 即将发布:管道操作符、强力调试

前不久&#xff0c;PHP宣布了即将在 2025 年 11 月 20 日 正式发布的 PHP 8.5&#xff01;作为 PHP 语言的又一次重要迭代&#xff0c;PHP 8.5 承诺带来一系列旨在提升代码可读性、健壮性以及开发者效率的改进。而更令人兴奋的是&#xff0c;借助强大的本地开发环境 ServBay&am…...

comfyui 工作流中 图生视频 如何增加视频的长度到5秒

comfyUI 工作流怎么可以生成更长的视频。除了硬件显存要求之外还有别的方法吗&#xff1f; 在ComfyUI中实现图生视频并延长到5秒&#xff0c;需要结合多个扩展和技巧。以下是完整解决方案&#xff1a; 核心工作流配置&#xff08;24fps下5秒120帧&#xff09; #mermaid-svg-yP…...

五子棋测试用例

一.项目背景 1.1 项目简介 传统棋类文化的推广 五子棋是一种古老的棋类游戏&#xff0c;有着深厚的文化底蕴。通过将五子棋制作成网页游戏&#xff0c;可以让更多的人了解和接触到这一传统棋类文化。无论是国内还是国外的玩家&#xff0c;都可以通过网页五子棋感受到东方棋类…...