当前位置: 首页 > news >正文

BERT文本分类(PyTorch和Transformers)畅用七个模型架构

(PyTorch)BERT文本分类:七种模型架构

🌟 1. 介绍

使用BERT完成文本分类任务(如情感分析,新闻文本分类等等)对于NLPer已经是很基础的工作了!虽说已迈入LLM时代,但是BERT这类较小的传统预训练模型的作用同样不可小觑,常常在某些场景下达到出乎意料的效果。本文介绍了一个利用BERT完成文本分类的开源项目,项目源码 Bert-Text-Classification。项目包含了7条基线,每条基线运行起来所需的显存都很小,用来作baseline非常的不错。

为方便下载:夸克网盘链接,大家也可直接进入GitHub地址访问项目(同时给作者点个🌟哈哈哈)
另外,当前很多的BERT文本分类项目都是很老的包了,可能在配置环境时会出现问题。但本项目使用的包在当前主流的服务器上肯定可以配置成功!按照本项目的依赖文件去安装肯定不会出问题!

🔍 2. 模型架构解析

模型类型模型描述适用场景
BertOrigin标准BERT实现基线对比实验
BertATT动态注意力机制增强长文本关键信息提取
BertLSTMRNN时序建模能力融合序列依赖性强的任务
BertCNN局部特征提取器设计短文本模式识别
BertDPCNN深度金字塔卷积架构层次化特征学习
BertRCNNRNN-CNN混合架构上下文敏感的分类任务
BertCNNPlus全局-局部特征融合机制需要综合语义理解的场景

🛠 3. 快速开始

话不多说,直接开始复现步骤(作者已按照以下步骤完成复现,若遇到问题,欢迎在评论区提出,或者在GitHub的issue中提出,issue中提出问题会有邮件通知作者)

3.1 环境要求

  • Python 3.8
  • 依赖安装:pip install -r requirements.txt
# 创建conda环境
conda create -n bert2text python=3.8
# 安装依赖
cd xx/xx/Bert-Text-Classification
pip install -r requirements.txt

3.2 数据准备

  • 文本分类数据集,包含训练集、验证集、测试集
  • 单条数据格式:label \t text,如:
  • 0	我爱北京天安门
    1	这个电影真好看
    
  • 数据集存放路径:data/,数据集文件名:train.tsvdev.tsvtest.tsv

源码中的数据集因版权,只保留了前100条数据,若需要情感二分类SST2的完整数据集,可通过链接 SST2 访问下载。若是自己的数据集,请自行按照以上格式修改(注意分类类别需自行转换为数字标签)!比如:0对应“消极”,1对应“积极"

3.3 模型训练和评估

0️⃣ 修改相应的运行文件的配置信息,如run_STT2.py

# 数据集路径
data_dir = "./data/SST2"
# 标签列表
label_list = ["0", "1"] # SST2数据集标签列表(SST2是情感二分类数据集)
# 所需运行的模型名称列表
model_name_list = ["BertOrigin", "BertATT", "BertCNN", "BertCNNPlus", "BertDPCNN", "BertRCNN", "BertLSTM"] # 运行七条基线模型
# 模型保存路径、缓存保存路径、日志保存路径
output_dir = "./sst2_output/"
cache_dir = "./sst2_cache/"
log_dir = "./sst2_log/"
# BERT预训练模型路径,中文数据集使用"bert-base-chinese",英文数据集使用"bert-base-uncased"
model_name_or_path = "XXXX/XXX/bert-base-uncased"

需要运行几条基线,就在model_name_list 中保留哪几个模型的名称。中文文本分类需要使用中文版的BERT。

1️⃣ 运行训练脚本:

# 注意修改执行的数据集脚本名称,如运行SST2数据集时,run.sh中应该是python3 run_SST2.py
CUDA_VISIBLE_DEVICES=0 bash run.sh# run.sh中的参数说明:
max_seq_length:句子截断长度
num_train_epochs:训练轮数
do_train:是否训练
gpu_ids:使用的GPU编号,注意单卡训练时,gpu_ids为0
gradient_accumulation_steps:梯度累积步数
print_step:打印训练信息的步数(验证频率)
early_stop:早停步数,即验证集准确率连续early_stop次不再提升时,停止训练。当设置很大时,相当于关闭了早停功能。
train_batch_size:训练批次大小

第一次跑可以直接使用默认配置,看是否能运行成功。若爆显存可修改max_seq_length和train_batch_size大小。

2️⃣ 测试模型:

# 训练完成后,模型保存在`output_dir`目录下,如`./sst2_output/BertOrigin/`,包含模型文件、词表文件、配置文件等。
# 移除`run.sh`中的`do_train`参数,运行测试脚本
CUDA_VISIBLE_DEVICES=0 bash run.sh

在源码的设置中,其实每训练一个epoch后,都在测试集上进行了预测,同时记录了对应的结果。所以做不做步骤2其实都可以。保存的结果在./sst2_output/BertXXX/BertXXX/metric_info_for_test.json中。记录的指标包含:

{"epoch": 1,"auc": 0.5,"accuracy": 66.69859514687101,"P_macro_avg": 44.97381885071993,"R_macro_avg": 45.84622325341583,"F1_macro_avg": 44.36622804917629,"P_weighted_avg": 63.133033572230936,"R_weighted_avg": 66.69859514687101,"F1_weighted_avg": 63.44510333347591}

注意,若分类类别大于2,则AUC指标设置为0.5不变。

3.4 训练监控

可通过TensorBoard查看训练日志,日志路径示例:./sst2_log/
tensorboard的使用在此处就不再赘述了,若有需要可自行搜索教程!

相关文章:

BERT文本分类(PyTorch和Transformers)畅用七个模型架构

(PyTorch)BERT文本分类:七种模型架构 🌟 1. 介绍 使用BERT完成文本分类任务(如情感分析,新闻文本分类等等)对于NLPer已经是很基础的工作了!虽说已迈入LLM时代,但是BERT…...

两步在 Vite 中配置 Tailwindcss

第一步:安装依赖 npm i -D tailwindcss tailwindcss/vite第二步:引入 tailwindcss 更改配置 // src/main.js import tailwindcss/index// vite.config.js import vue from vitejs/plugin-vue import tailwindcss from tailwindcss/viteexport default …...

【vmware虚拟机安装教程】

以下是在VMware Workstation Pro上安装虚拟机的详细教程: 准备工作 下载VMware Workstation Pro 访问VMware官网下载并安装VMware Workstation Pro(支持Windows和Linux系统)。安装完成后,确保已激活软件(试用版或正式…...

文字转语音(三)FreeTTS实现

项目中有相关的功能,就简单研究了一下。 说明 FreeTTS 是一个基于 Java 的开源文本转语音(TTS)引擎,旨在将文字内容转换为自然语音输出。 FreeTTS 适合对 英文语音质量要求低、预算有限且需要离线运行 的场景,但若需…...

string类详解(上)

文章目录 目录1. STL简介1.1 什么是STL1.2 STL的版本1.3 STL的六大组件 2. 为什么学习string类3. 标准库中的string类3.1 string类3.2 string类的常用接口说明 目录 STL简介为什么学习string类标准库中的string类string类的模拟实现现代版写法的String类写时拷贝 1. STL简介 …...

Visual Studio Code使用ai大模型编成

1、在Visual Studio Code搜索安装roo code 2、去https://openrouter.ai/settings/keys官网申请个免费的配置使用...

外贸跨境订货系统流程设计、功能列表及源码输出

在全球化的商业环境下,外贸跨境订货系统对于企业拓展国际市场、提升运营效率至关重要。该系统旨在为外贸企业提供一个便捷、高效、安全的订货平台,实现商品展示、订单管理、物流跟踪等功能,满足跨境业务的多样化需求。以下将详细阐述外贸订货…...

TraeAi上手体验

一、Trae介绍 由于MarsCode 在国内由于规定限制,无法使用 Claude 3.5 Sonnet 模型,字节跳动选择在海外推出 Trae,官网:https://www.trae.ai/。 二、安装 1.下载安装Trae-Setup-x64.exe 2.注册登录 安装完成后,点击登…...

深入解析 vLLM:高性能 LLM 服务框架的架构之美(一)原理与解析

修改内容时间2.4.1处理请求的流程,引用更好的流程图2025.02.11首发2025.02.08 深入解析 vLLM:高性能 LLM 服务框架的架构之美(一)原理与解析 深入解析 vLLM:高性能 LLM 服务框架的架构之美(二)…...

thingboard告警信息格式美化

原始报警json内容: { "severity": "CRITICAL","acknowledged": false,"cleared": false,"assigneeId": null,"startTs": 1739801102349,"endTs": 1739801102349,"ackTs": 0,&quo…...

redis解决高并发看门狗策略

当一个业务执行时间超过自己设定的锁释放时间,那么会导致有其他线程进入,从而抢到同一个票,所有需要使用看门狗策略,其实就是开一个守护线程,让守护线程去监控key,如果到时间了还未结束,就会将这个key重新s…...

Python函数的函数名250217

函数名其实就是一个变量,这个变量就是代指函数而已函数也可以被哈希,所以函数名也可以当作集合中的元素,也可作为字典的key值 # 将函数作为字典中的值,可以避免写大量的if...else语句 def fun1():return 123 def fun2():return 4…...

Unity 获取独立显卡数量

获取独立显卡数量 导入插件包打开Demo 运行看控制台日志 public class GetGraphicCountDemo : MonoBehaviour{public int count;// Start is called before the first frame updatevoid Start(){count this.GetIndependentGraphicsDeviceCount();}}...

JAVA生产环境(IDEA)排查死锁

使用 IntelliJ IDEA 排查死锁 IntelliJ IDEA 提供了强大的工具来帮助开发者排查死锁问题。以下是具体的排查步骤: 1. 编写并运行代码 首先,我们编写一个可能导致死锁的示例代码: public class DeadlockExample {private static final Obj…...

如何正确安装Stable Diffusion Web UI以及对应的xFormers

本文是我总结的步骤,验证了几次保证是对的。因为正确的安装 Stable Diffusion Web UI 以及对应的 xFormers 实在是太麻烦了,官方和网上的步骤都是残缺和分散的,加上国内网络速度不理想,所以需要一些额外步骤,之前研究出…...

机器学习_14 随机森林知识点总结

随机森林(Random Forest)是一种强大的集成学习算法,广泛应用于分类和回归任务。它通过构建多棵决策树并综合它们的预测结果,显著提高了模型的稳定性和准确性。今天,我们就来深入探讨随机森林的原理、实现和应用。 一、…...

机器学习基本篇

文章目录 1 基本概念2 基本流程2.0 数据获取2.1 预处理2.1.0 认识数据认识问题2.1.1 不平衡标签的处理a.随机过采样方法 ROS,random over-samplingb. SMOTE synthetic minority Over-Sampling Technique2.2 缺失值处理2.3 数据清洗2.3.0离散特征编码2.3.1 连续特征处理归一化标…...

vue2.x与vue3.x生命周期的比较

vue2.x 生命周期图示: new Vue() | v Init Events & Lifecycle | v beforeCreate | v created | v beforeMount | v mounted | v beforeUpdate (when data changes) | v updated | v beforeDestroy (when vm.…...

接口测试及常用接口测试工具(Postman/Jmeter)

🍅 点击文末小卡片 ,免费获取软件测试全套资料,资料在手,涨薪更快 首先,什么是接口呢? 接口一般来说有两种,一种是程序内部的接口,一种是系统对外的接口。 系统对外的接口&#xf…...

[论文阅读] SeeSR: Towards Semantics-Aware Real-World Image Super-Resolution

文章目录 一、前言二、主要贡献三、Introduction四、Methodology4.1 Motivation :4.2Framework Overview.** 一、前言 通信作者是香港理工大学 & OPPO研究所的张磊教授,也是图像超分ISR的一个大牛了。 论文如下 SeeSR: Towards Semantics-Aware Rea…...

iOS 26 携众系统重磅更新,但“苹果智能”仍与国行无缘

美国西海岸的夏天,再次被苹果点燃。一年一度的全球开发者大会 WWDC25 如期而至,这不仅是开发者的盛宴,更是全球数亿苹果用户翘首以盼的科技春晚。今年,苹果依旧为我们带来了全家桶式的系统更新,包括 iOS 26、iPadOS 26…...

8k长序列建模,蛋白质语言模型Prot42仅利用目标蛋白序列即可生成高亲和力结合剂

蛋白质结合剂(如抗体、抑制肽)在疾病诊断、成像分析及靶向药物递送等关键场景中发挥着不可替代的作用。传统上,高特异性蛋白质结合剂的开发高度依赖噬菌体展示、定向进化等实验技术,但这类方法普遍面临资源消耗巨大、研发周期冗长…...

CMake基础:构建流程详解

目录 1.CMake构建过程的基本流程 2.CMake构建的具体步骤 2.1.创建构建目录 2.2.使用 CMake 生成构建文件 2.3.编译和构建 2.4.清理构建文件 2.5.重新配置和构建 3.跨平台构建示例 4.工具链与交叉编译 5.CMake构建后的项目结构解析 5.1.CMake构建后的目录结构 5.2.构…...

CentOS下的分布式内存计算Spark环境部署

一、Spark 核心架构与应用场景 1.1 分布式计算引擎的核心优势 Spark 是基于内存的分布式计算框架,相比 MapReduce 具有以下核心优势: 内存计算:数据可常驻内存,迭代计算性能提升 10-100 倍(文档段落:3-79…...

大语言模型如何处理长文本?常用文本分割技术详解

为什么需要文本分割? 引言:为什么需要文本分割?一、基础文本分割方法1. 按段落分割(Paragraph Splitting)2. 按句子分割(Sentence Splitting)二、高级文本分割策略3. 重叠分割(Sliding Window)4. 递归分割(Recursive Splitting)三、生产级工具推荐5. 使用LangChain的…...

Axios请求超时重发机制

Axios 超时重新请求实现方案 在 Axios 中实现超时重新请求可以通过以下几种方式: 1. 使用拦截器实现自动重试 import axios from axios;// 创建axios实例 const instance axios.create();// 设置超时时间 instance.defaults.timeout 5000;// 最大重试次数 cons…...

Unit 1 深度强化学习简介

Deep RL Course ——Unit 1 Introduction 从理论和实践层面深入学习深度强化学习。学会使用知名的深度强化学习库,例如 Stable Baselines3、RL Baselines3 Zoo、Sample Factory 和 CleanRL。在独特的环境中训练智能体,比如 SnowballFight、Huggy the Do…...

ArcGIS Pro制作水平横向图例+多级标注

今天介绍下载ArcGIS Pro中如何设置水平横向图例。 之前我们介绍了ArcGIS的横向图例制作:ArcGIS横向、多列图例、顺序重排、符号居中、批量更改图例符号等等(ArcGIS出图图例8大技巧),那这次我们看看ArcGIS Pro如何更加快捷的操作。…...

【Redis】笔记|第8节|大厂高并发缓存架构实战与优化

缓存架构 代码结构 代码详情 功能点: 多级缓存,先查本地缓存,再查Redis,最后才查数据库热点数据重建逻辑使用分布式锁,二次查询更新缓存采用读写锁提升性能采用Redis的发布订阅机制通知所有实例更新本地缓存适用读多…...

GitHub 趋势日报 (2025年06月06日)

📊 由 TrendForge 系统生成 | 🌐 https://trendforge.devlive.org/ 🌐 本日报中的项目描述已自动翻译为中文 📈 今日获星趋势图 今日获星趋势图 590 cognee 551 onlook 399 project-based-learning 348 build-your-own-x 320 ne…...