当前位置: 首页 > article >正文

# 基于BERT的文本分类

基于BERT的文本分类项目的实现

一、项目背景

该文本分类项目主要是情感分析,二分类问题,以下是大致流程及部分代码示例:


二、数据集介绍

2.1 数据集基本信息

数据集自定义
类型二分类(正面/负面)
样本量训练集 + 验证集 + 测试集
文本长度平均x字(最大x字)
领域商品评论、影视评论
# 加载数据集
dataset = pd.read_csv('data/train.txt', sep='\t')
print(dataset['train'][0])
# 输出:{'text': '这个手机性价比超高,拍照效果惊艳!', 'label': 1}

2.2 数据分析

2.2.1 句子长度分布
import matplotlib.pyplot as pltdef analyze_length(texts):lengths = [len(t) for t in texts]plt.figure(figsize=(12,5))plt.hist(lengths, bins=30, range=(0,256), color='blue', alpha=0.7)plt.title("文本长度分布", fontsize=14)plt.xlabel("字符数")plt.ylabel("样本量")plt.show()analyze_length(dataset['train']['text'])
2.2.2 标签分布
import pandas as pdpd.Series(dataset['train']['label']).value_counts().plot(kind='pie',autopct='%1.1f%%',title='类别分布(0-负面 1-正面)'
)
plt.show()
2.2.3 类别平衡处理
from torch.utils.data import WeightedRandomSampler# 计算类别权重
labels = dataset['train']['label']
class_weights = 1 / torch.Tensor([len(labels)-sum(labels), sum(labels)])
sampler = WeightedRandomSampler(weights=[class_weights[label] for label in labels],num_samples=len(labels),replacement=True
)

三、数据处理

3.1 BERT分词器

from transformers import BertTokenizertokenizer = BertTokenizer.from_pretrained('bert-base-chinese')def collate_fn(batch):texts = [item['text'] for item in batch]labels = [item['label'] for item in batch]# BERT编码inputs = tokenizer(texts,padding=True,truncation=True,max_length=256,return_tensors='pt')return {'input_ids': inputs['input_ids'],'attention_mask': inputs['attention_mask'],'labels': torch.LongTensor(labels)}

3.2 数据加载器

from torch.utils.data import DataLoadertrain_loader = DataLoader(dataset['train'],batch_size=32,collate_fn=collate_fn,sampler=sampler
)val_loader = DataLoader(dataset['validation'],batch_size=32,collate_fn=collate_fn
)

四、模型构建

4.1 BERT分类模型

import torch.nn as nn
from transformers import BertModelclass BertClassifier(nn.Module):def __init__(self):super().__init__()self.bert = BertModel.from_pretrained('bert-base-chinese')self.dropout = nn.Dropout(0.1)self.fc = nn.Linear(768, 2)def forward(self, input_ids, attention_mask):outputs = self.bert(input_ids, attention_mask)pooled = self.dropout(outputs.pooler_output)return self.fc(pooled)

4.2 模型配置

import torchdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BertClassifier().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
criterion = nn.CrossEntropyLoss()

五、模型训练与验证

5.1 训练流程

from tqdm import tqdmdef train_epoch(model, loader):model.train()total_loss = 0for batch in tqdm(loader):optimizer.zero_grad()input_ids = batch['input_ids'].to(device)attention_mask = batch['attention_mask'].to(device)labels = batch['labels'].to(device)outputs = model(input_ids, attention_mask)loss = criterion(outputs, labels)loss.backward()optimizer.step()total_loss += loss.item()return total_loss / len(loader)

5.2 验证流程

def evaluate(model, loader):model.eval()correct = 0total = 0with torch.no_grad():for batch in loader:input_ids = batch['input_ids'].to(device)attention_mask = batch['attention_mask'].to(device)labels = batch['labels'].to(device)outputs = model(input_ids, attention_mask)preds = torch.argmax(outputs, dim=1)correct += (preds == labels).sum().item()total += len(labels)return correct / total

六、实验结果

6.1 评估指标

Epoch训练Loss验证准确率测试准确率
# 绘制混淆矩阵
from sklearn.metrics import confusion_matrix
import seaborn as snsdef plot_confusion_matrix(loader):y_true = []y_pred = []model.eval()with torch.no_grad():for batch in loader:input_ids = batch['input_ids'].to(device)attention_mask = batch['attention_mask'].to(device)labels = batch['labels'].to(device)outputs = model(input_ids, attention_mask)preds = torch.argmax(outputs, dim=1)y_true.extend(labels.cpu().numpy())y_pred.extend(preds.cpu().numpy())cm = confusion_matrix(y_true, y_pred)sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')plt.title('混淆矩阵')plt.xlabel('预测标签')plt.ylabel('真实标签')plt.show()plot_confusion_matrix(test_loader)

6.2 学习曲线

# 记录训练过程
from torch.utils.tensorboard import SummaryWriterwriter = SummaryWriter()
for epoch in range(3):train_loss = train_epoch(model, train_loader)val_acc = evaluate(model, val_loader)writer.add_scalar('Loss/Train', train_loss, epoch)writer.add_scalar('Accuracy/Validation', val_acc, epoch)

七、流程架构图

原始文本
分词编码
BERT特征提取
全连接分类
损失计算
反向传播
模型评估

相关文章:

# 基于BERT的文本分类

基于BERT的文本分类项目的实现 一、项目背景 该文本分类项目主要是情感分析,二分类问题,以下是大致流程及部分代码示例: 二、数据集介绍 2.1 数据集基本信息 数据集自定义类型二分类(正面/负面)样本量训练集 验证…...

C++学习之服务器EPOLL模型、处理客户端请求、向客户端回复数、向客户端发送文件

目录 1.启动epoll模型 2.和客户端建立新连接 3.接受客户端Http请求数据 4.代码回顾从接受的数据中读出请求行 5.请求行解析 6.正则表达式以及匹配 7.解析请求行以及后续处理 8.对path处理说明 9.如何回复响应数据 10.对文件对应content-type如何查询 11.服务器处理流…...

BUUCTF-web刷题篇(17)

26.BabyUpload 源码:https://github.com/imaginiso/GXY_CTF/tree/master/Web/babyupload 查看题目源码: 写着:SetHandler application/x-httpd-php 通过源码可以看出这道文件上传题目主要还是考察.htaccess配置文件的特性,倘若…...

国网B接口协议调阅实时视频接口流程详解以及检索失败原因(电网B接口)

文章目录 一、B接口协议调阅实时视频接口介绍B.6.1 接口描述B.6.2 接口流程B.6.3 接口参数B.6.3.1 SIP头字段B.6.3.2 SIP响应码B.6.3.3 SDP参数定义B.6.3.4 RTP动态Payload定义 B.6.4 消息示例B.6.4.1 调阅实时视频请求B.6.4.2 调阅实时视频请求响应 二、B接口调阅实时视频失败…...

windows11下pytorch(cpu)安装

先装anaconda 见最下方 Pytorch 官网:PyTorch 找到下图(不要求版本一样)(我的电脑是集显(有navdia的装gpu),装cpu) 查看已有环境列表 创建环境 conda create –n 虚拟环境名字(…...

NVR接入录像回放平台用EasyCVR打造地下车库安防:大型商居安全优选方案

一、背景分析 随着居民生活品质的提升,大型商业建筑和住宅小区纷纷配套建设地下停车库。但是地下车库盗窃、失火、恶意毁坏车辆、外部人员随意进出等事件频发,部署视频监控系统成为保障地下车库的安全关键举措。 目前,很多商业和住宅都会在…...

玻璃期货数据下载与分析:Python金融实战分享

期货数据下载与分析:Python实战分享 引言 在金融市场中,期货分析是一项重要的工作,而获取准确且及时的数据是进行有效分析的基础。今天,我们将深入探讨一段使用Python编写的代码,该代码用于从郑州商品交易所&#xf…...

excel常见错误包括(#N/A、#VALUE!、#REF!、#DIV/0!、#NUM!、#NAME?、#NULL! )

目录 1. #N/A2. #VALUE!3. #REF!4. #DIV/0!5. #NUM!6. #NAME?7. #NULL!8.图表总结 在 Excel 中,可能会遇到以下常见的错误值,每个都有特定的含义和成因: 1. #N/A 含义: 表示“Not Available”(不可用)。…...

乾元通渠道商中标川藏铁路西藏救援队应急救援装备项目

乾元通渠道商中标川藏铁路西藏救援队应急救援装备项目,项目内通信指挥车基于最新一代应急指挥车解决方案打造,配合乾元通自研的车载多链路聚合路由及系统,主要用途为保障应急通讯,满足任务执行时指挥协调、通信联络及数据传输的要…...

数学知识——矩阵乘法

使用矩阵快速幂优化递推问题 对于一个递推问题,如递推式的每一项系数都为常数,我们可以使用矩阵快速幂来对算法进行优化。 一般形式为: F n F 1 A n − 1 F_nF_1A^{n-1} Fn​F1​An−1 由于递推式的每一项系数都为常数,因此对…...

左右开弓策略思路

一、策略概述 本策略是一种基于多种技术指标的复杂交易策略,包括自定义指标计算、过滤平滑处理以及交易信号生成。 该策略通过不同的交易平台代码段实现,旨在通过分析历史价格数据来预测未来价格走势,并据此生成交易信号。 二、主要技术指标…...

将jar包制作成deb一键安装包

文章目录 准备环境准备deb包结构构建Deb包测试安装常用操作命令 本文介绍如何将java运行环境、jar程序一起打包成一个deb格式的安装包,创建桌面图标,通过点击图标可使用系统自带浏览器快捷访问web服务的URL,同时注册服务并配置好开机自启。 准…...

Java 常用安全框架的 授权模型 对比分析,涵盖 RBAC、ABAC、ACL、基于权限/角色 等模型,结合框架实现方式、适用场景和优缺点进行详细说明

以下是 Java 常用安全框架的 授权模型 对比分析,涵盖 RBAC、ABAC、ACL、基于权限/角色 等模型,结合框架实现方式、适用场景和优缺点进行详细说明: 1. 授权模型类型与定义 模型名称定义特点RBAC(基于角色的访问控制)通…...

aws平台练习

注册 AWS 账户 访问 AWS 官方网站,点击“免费注册”按钮,按照提示完成账户注册: 提供电子邮件地址、密码和电话号码。 验证身份(可能需要手机验证码)。 设置 billing 信息。 2. 登录 AWS 管理控制台 使用注册的邮箱和…...

力扣DAY40-45 | 热100 | 二叉树:直径、层次遍历、有序数组->二叉搜索树、验证二叉搜索树、二叉搜索树中第K小的元素、右视图

前言 简单、中等 √ 好久没更了,感觉二叉树来回就那些。有点变懒要警醒,不能止步于笨方法!! 二叉树的直径 我的题解 遍历每个节点,左节点最大深度右节点最大深度当前节点当前节点为中心的直径。如果左节点深度更大…...

【MYSQL从入门到精通】数据类型及建表

一些基础操作语句 1.使用客户端工具连接数据库服务器:mysql -uroot -p 2.查看所有数据库:show databases; 3.创建属于自己的数据库: create database 数据库名;create database if not exists 数据库名; 强烈建议大家在建立数据库时指定编…...

【动态规划】 深入动态规划—两个数组的dp问题

文章目录 前言例题一、最长公共子序列二、不相交的线三、不同的子序列四、通配符匹配五、交错字符串六、两个字符串的最小ASCII删除和七、最长重复子数组 结语 前言 问题本质 它主要围绕着给定的两个数组展开,旨在通过对这两个数组元素间关系的分析,找出…...

结合大语言模型整理叙述并生成思维导图的思路

楔子 我比较喜欢长篇大论。这在代理律师界被视为一种禁忌。 我高中一年级的时候因为入学成绩好(所在县榜眼名次),直接被所在班的班主任任命为班长。我其实不喜欢这个岗位。因为老师一来就要提前注意到,要及时喊“起立”、英语课…...

Kotlin学习

kotlin android 开源,Kotlin开源项目集合_晚安 呼-华为开发者空间 干货来袭,推荐几款开源的Kotlin的Android项目 https://zhuanlan.zhihu.com/p/536789267 【已解决】 ubuntu apt-get update连不上dl.google.com_为什么不能ping谷歌-CSDN博客...

若依前后端分离版本从mysql切换到postgresql数据库

一、修改依赖&#xff1a; 修改admin模块pom.xml中的依赖,屏蔽或删除mysql依赖&#xff0c;增加postgresql依赖。 <!-- Mysql驱动包 --> <!--<dependency><groupId>mysql</groupId><artifactId>mysql-connector-java</artifactId> &l…...

【力扣hot100题】(073)数组中的第K个最大元素

花了两天时间搞明白答案的快速排序和堆排序。 两种都写了一遍&#xff0c;感觉堆排序更简单很多。 两种都记录一下&#xff0c;包括具体方法和易错点。 快速排序 class Solution { public:vector<int> nums;int quicksort(int left,int right,int k){if(leftright) r…...

【AAOS】【源码分析】CarAudioService(二)-- 功能介绍

汽车音频是 Android 汽车操作系统 (AAOS) 的一项功能,允许车辆播放信息娱乐声音,例如媒体、导航和通信。AAOS 不负责具有严格可用性和时间要求的铃声和警告,因为这些声音通常由车辆的硬件处理。将汽车音频服务集成在汽车中,彻底改变了驾驶体验,为驾驶员和乘客提供了音乐、…...

mapbox基础,加载F4Map二维地图

👨‍⚕️ 主页: gis分享者 👨‍⚕️ 感谢各位大佬 点赞👍 收藏⭐ 留言📝 加关注✅! 👨‍⚕️ 收录于专栏:mapbox 从入门到精通 文章目录 一、🍀前言1.1 ☘️mapboxgl.Map 地图对象1.2 ☘️mapboxgl.Map style属性二、🍀F4Map 简介2.1 ☘️技术特点2.2 ☘️核…...

Android:Android Studio右侧Gradle没有assembleRelease等选项

旧版as是“Do not build Gradle task list during Gradle sync” 操作这个选项。 参考这篇文章&#xff1a;Android Studio Gradle中没有Task任务&#xff0c;没有Assemble任务&#xff0c;不能方便导出aar包_gradle 没有task-CSDN博客 在as2024版本中&#xff0c;打开Setting…...

DRAM CRC:让DDR5内存数据更靠谱

DRAM(动态随机存取存储器)是电脑内存的核心部件,负责存储和传输数据。如果数据在传输中出错,后果可能很严重,比如程序崩溃或者数据损坏。为了解决这个问题,DDR5内存引入了一个新功能,叫DRAM CRC(循环冗余校验)。简单来说,它是用来检查读写数据有没有问题的工具。 下面…...

RAI Toolbox详解

RAI Toolbox详解 摘要 RAI Toolbox是一个综合性的工具集&#xff0c;旨在帮助开发者和AI系统利益相关者更负责任地开发和监控AI系统&#xff0c;并做出更好的数据驱动决策。本文将详细介绍RAI Toolbox的功能、使用场景以及与类似AI项目的对比&#xff0c;帮助读者全面了解RAI…...

心率测量-arduino+matlab

参考&#xff1a;【教程】教你玩转Stduino之手指心跳检测模块 - 知乎 (zhihu.com) 1 原理 心跳检测模块&#xff0c;由一个红外线发射LED和红外接收器构成。手指心跳监测模块能够测量脉搏&#xff0c;是这样工作的&#xff1a;当手指放在发射器与接收器之间&#xff0c;红外发射…...

H3C的MSTP+VRRP高可靠性组网技术(MSTP单域)

以下内容纯为博主分享自己的想法和理解&#xff0c;如有错误轻喷 MSTP多生成树协议可以基于不同实例实现不同VLAN之间的负载分担 VRRP虚拟路由器冗余协议可以提高网关的可靠性防止单点故障的可能 在以前这两种协议通常一起搭配组网&#xff0c;来提高网络的可靠性和稳定性&a…...

字符串替换 (模拟)神奇数 (数学)DNA序列 (固定长度的滑动窗口)

⭐️个人主页&#xff1a;小羊 ⭐️所属专栏&#xff1a;每日两三题 很荣幸您能阅读我的文章&#xff0c;诚请评论指点&#xff0c;欢迎欢迎 ~ 目录 字符串替换 &#xff08;模拟&#xff09;神奇数 &#xff08;数学&#xff09;DNA序列 &#xff08;固定长度的滑动窗口&am…...

Centos7下安装hive详细步骤

在Centos 7系统上安装Hive的步骤如下&#xff1a; 下载Hive&#xff1a;首先&#xff0c;在Apache Hive的官方网站上下载最新版本的Hive压缩包&#xff0c;地址为&#xff1a;https://hive.apache.org/downloads.html。选择合适的版本并下载。 解压Hive压缩包&#xff1a;将下…...