当前位置: 首页 > news >正文

使用 Bert 做文本分类,利用 Trainer 框架实现 二分类,事半功倍

简介

使用 AutoModelForSequenceClassification 导入Bert 模型。
很多教程都会自定义 损失函数,然后手动实现参数更新。
但本文不想手动微调,故使用 transformers 的 Trainer 自动微调。
人生苦短,我用框架,不仅可保证微调出的模型的效果,而且还省时间。

导包

import evaluate
import numpy as np
from datasets import load_dataset
from transformers import (AutoTokenizer,AutoModelForSequenceClassification,
)import torch
from torch import nnimport os
os.environ['HTTP_PROXY'] = 'http://127.0.0.1:7890'
os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890'# AG_News 英文分类数据集
# ds = load_dataset("fancyzhx/ag_news")## 中文分类数据集
ds = load_dataset("lansinuote/ChnSentiCorp")

数据集的详情如下:

DatasetDict({train: Dataset({features: ['text', 'label'],num_rows: 9600})validation: Dataset({features: ['text', 'label'],num_rows: 1200})test: Dataset({features: ['text', 'label'],num_rows: 1200})
})
ds["train"][0]
{'text': '选择珠江花园的原因就是方便,有电动扶梯直接到达海边,周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般,但还算整洁。 泳池在大堂的屋顶,因此很小,不过女儿倒是喜欢。 包的早餐是西式的,还算丰富。 服务吗,一般','label': 1}

加载 Bert 模型

model_name = "bert-base-chinese"tokenizer = AutoTokenizer.from_pretrained(model_name,trust_remote_code=True,
)bert = AutoModelForSequenceClassification.from_pretrained(model_name,trust_remote_code=True,num_labels=2,
)

如果你无法联网的话,使用本地huggingface模型:

bert = AutoModelForSequenceClassification.from_pretrained(model_name,trust_remote_code=True,revision="c30a6ed22ab4564dc1e3b2ecbf6e766b0611a33f",local_files_only=True,num_labels=2,
)

查看 bert 分类模型的网络结构:

bert

在这里插入图片描述

如上图所示,Bert 的分类模型:在原生的 Bert 模型后,加了一个Linear

下述是数据集转换函数:

def tokenize_func(item):global tokenizertokenized_inputs = tokenizer(item["text"],max_length=512,truncation=True,)return tokenized_inputs
tokenized_datasets = ds.map(tokenize_func,batched=True,
)

tokenized_datasets 的详情如下所示:

DatasetDict({train: Dataset({features: ['text', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],num_rows: 9600})validation: Dataset({features: ['text', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],num_rows: 1200})test: Dataset({features: ['text', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],num_rows: 1200})
})

Train

from transformers import TrainingArgumentsargs = TrainingArguments("ChnSentiCorp_text_cls",eval_steps=8,evaluation_strategy="steps",save_strategy="epoch",save_total_limit=3,learning_rate=2e-5,num_train_epochs=3,weight_decay=0.01,per_device_train_batch_size=32,per_device_eval_batch_size=16,logging_steps=8,save_safetensors=True,overwrite_output_dir=True,# load_best_model_at_end=True,
)

TrainingArguments 的参数解释点击查看下述文章:
LLM大模型之Trainer以及训练参数

from transformers import DataCollatorWithPaddingdata_collator = DataCollatorWithPadding(tokenizer=tokenizer)
from transformers import Trainertrainer = Trainer(model=bert,args=args,train_dataset=tokenized_datasets["train"],eval_dataset=tokenized_datasets["validation"],data_collator=data_collator,# compute_metrics=compute_metrics,tokenizer=tokenizer,
)
trainer.train()

训练过程,在终端可以看见,训练和验证的损失值变化。
在这里插入图片描述

如果安装了 wandb,并且在系统环境变量中,进行了设置。

训练过程和评估过程的记录会自动上传到wandb中。

wandb

若你想使用 wandb,自行进行安装;个人强烈推荐,一劳永逸,这样就无需自己绘图展示模型的训练过程了。

在模型训练的过程,进入 wandb https://wandb.ai/home 看看模型的现在的训练的过程。
在这里插入图片描述

在这里插入图片描述

上图是在 wandb 网站看到的图,横轴是 epoch ,纵轴是 loss。
蓝色折线是在验证集上的损失,橙色折线是在训练集上的损失。

可以很直观的看到,在训练集上的loss 小于 在验证集上的 loss。

predict

训练完成的模型,使用 predict 方法,在测试集上预测。

predictions = trainer.predict(tokenized_datasets["test"])
preds = np.argmax(predictions.predictions, axis=-1)
preds

输出结果:

array([1, 0, 0, ..., 1, 1, 0])

预测结果评估

def eval_data(data):predictions = trainer.predict(data)preds = np.argmax(predictions.predictions, axis=-1)metric = evaluate.load("glue", "mrpc")return metric.compute(predictions=preds, references=predictions.label_ids)
eval_data(tokenized_datasets["test"])

输出结果:

{'accuracy': 0.9475, 'f1': 0.9478908188585607}

总结

总体上看,本文做了一下数据集的处理,大模型的微调过程、模型权重报错、日志记录,这些过程全部由 transformers 的 Trainer 自动进行。

用好 框架, 事半功倍。当然前提是已经掌握了基础的手动参数微调。

参考资料

  • huggingface 使用 Trainer API 微调模型

相关文章:

使用 Bert 做文本分类,利用 Trainer 框架实现 二分类,事半功倍

简介 使用 AutoModelForSequenceClassification 导入Bert 模型。 很多教程都会自定义 损失函数,然后手动实现参数更新。 但本文不想手动微调,故使用 transformers 的 Trainer 自动微调。 人生苦短,我用框架,不仅可保证微调出的模…...

Obsidian git sync error / Obsidian git 同步失敗

Issue: commit due to empty commit message Solution 添加commit資訊,確保不留空白 我的設置:auto-backup: {{hostname}}/{{date}}/...

谷歌英文SEO外链如何做?

做英文SEO外链涉及多种策略和技巧,目标是提升目标网站的排名和流量,Google的搜索算法在不断演变,但外链一直是搜索引擎优化中重要的一环。有效的外链建设能够显著提升网站的SEO数据效果。关键在于创建一个多元化且自然的外链结构。不能仅仅依…...

vue使用Export2Excel导出表格

安装插件 npm install xlsx xlsx-style file-saver npm install node-polyfill-webpack-plugin (如果不安装的话后面使用会报错) 添加相关配置 在vue.config.js文件 const NodePolyfillPlugin require("node-polyfill-webpack-plugin") module.exports defineCon…...

Linux环境变量 本地变量 命令行参数

并行和并发 并行 多个进程在多个 CPU 下分别,同时进行运行。 并发 多个进程在一个 CPU 采用进程切换的方式,在一段时间内,让多个进程都得以推进,称之为并发。 CPU 中的寄存器扮演什么角色? 寄存器:cpu 内的寄存器里面保存的是进程…...

向量数据库Faiss的搭建与使用

1. 什么是Faiss? Faiss是由Facebook AI Research团队开发的一个库,旨在高效地进行大规模向量相似性搜索。它不仅支持CPU,还能利用GPU进行加速,非常适合处理大量高维数据。Faiss提供了多种索引类型,以适应不同的需求&a…...

微信小程序接入客服功能

前言 用户可使用小程序客服消息功能,与小程序的客服人员进行沟通。客服功能主要用于在小程序内 用户与客服直接沟通用,本篇介绍客服功能的基础开发以及进阶功能的使用,另外介绍多种客服的对接方式。 更多介绍请查看客服消息使用指南 客服视…...

mysql开启远程访问

个人建议mysql可以用宝塔自动下载安装。 远程访问, 1.关闭防火墙,确保ip能ping通 2.ping端口确定数据库能ping通 3.本地先连上去命令行修改远程访问权限。 mysql -u root -p use mysql; select user,host from user; select host from user where u…...

【NLP自然语言处理】文本处理的基本方法

目录 🍔什么是分词 🍔中文分词工具jieba 2.1 jieba的基本特点 2.2 jieba的功能 2.3 jieba的安装及使用 🍔什么是命名实体识别 🍔什么是词性标注 🍔小结 学习目标 🍀 了解什么是分词, 词性标注, 命名…...

uniapp使用defineExpose暴露和onMounted访问

defineExpose作用 暴露方法和数据 允许从模板或其他组件访问当前组件内部的方法和数据。明确指定哪些方法和数据可以被外部访问,从而避免不必要的暴露。 增强安全性 通过显式声明哪些方法和数据可以被外部访问,防止意外修改内部状态。提高组件的安全性&a…...

怎么使用matplotlib绘制一个从-2π到2π的sin(x)的折线图-学习篇

首先:如果你的环境中没有安装matplotlib,使用以下命令可以直接安装 pip install matplotlib如何画一个这样的折线图呢?往下看 想要画一个简单的sin(x)在-2π到2π的折线图,我们要拆分成以下步骤: 先导入相关的库文…...

【Java毕业设计】基于SpringBoot+Vue+uniapp的农产品商城系统

文章目录 一、系统架构1、后端:SpringBoot、Mybatis2、前端:Vue、ElementUI4、小程序:uniapp3、数据库:MySQL 二、系统功能三、系统展示1、小程序2、后台管理系统 一、系统架构 1、后端:SpringBoot、Mybatis 2、前端…...

C++ | Leetcode C++题解之第390题消除游戏

题目: 题解: class Solution { public:int lastRemaining(int n) {int a1 1;int k 0, cnt n, step 1;while (cnt > 1) {if (k % 2 0) { // 正向a1 a1 step;} else { // 反向a1 (cnt % 2 0) ? a1 : a1 step;}k;cnt cnt >> 1;step …...

echarts进度

echarts图表集 const data[{ value: 10.09,name:制梁进度, color: #86C58C,state: }, { value: 66.00,name:架梁进, color: #C6A381 ,state:正常}, { value: 33.07,name:下部进度, color: #669BDA,state:正常 }, ];// const textStyle { "color": "#CED6C8&…...

PostgreSQL16.4搭建一主一从集群

PostgreSQL搭建一主一从集群的过程主要涉及到基础环境准备、PostgreSQL安装、主从节点配置以及同步验证等步骤。以下是一个详细的搭建过程: 一、基础环境准备 创建虚拟机: 准备两台虚拟机,分别作为主节点和从节点。为每台虚拟机分配独立的IP…...

Spring01——Spring简介、Spring Framework架构、Spring核心概念、IOC入门案例、DI入门案例

为什么要学 spring技术是JavaEE开发必备技能,企业开发技术选型命中率>90%专业角度 简化开发:降低企业开发的复杂度框架整合:高效整合其他技术,提高开发与运行效率 学什么 简化开发 IOCAOP 事务处理 框架整合 MyBatis 怎…...

深度学习|模型推理:端到端任务处理

引言 深度学习的崛起推动了人工智能领域的诸多技术突破,尤其是在处理复杂数据与任务的能力方面。模型推理作为深度学习的核心环节,决定了模型在真实应用场景中的表现。而端到端任务处理(End-to-End Task Processing)作为深度学习的一种重要范式,通过从输入到输出的直接映…...

【深度学习 Pytorch】2024年最新版本PyTorch学习指南

引言 2024年,深度学习技术在各个领域取得了显著的进展,而PyTorch作为深度学习领域的主流框架之一,凭借其易用性、灵活性和强大的社区支持,受到了广大研究者和开发者的喜爱。本文将为您带来一份2024年最新版本的PyTorch学习指南&a…...

第 1 章:原生 AJAX

原生AJAX 1. AJAX 简介 AJAX 全称为 Asynchronous JavaScript And XML,就是异步的 JS 和 XML。通过 AJAX 可以在浏览器中向服务器发送异步请求,最大的优势:无刷新获取数据。AJAX 不是新的编程语言,而是一种将现有的标准组合在一…...

【代码随想录|贪心part04以后——重叠区间】

代代码随想录|贪心part04以后——重叠区间 一、part041、452.用最少数量的箭引爆气球2、435. 无重叠区间2、763.划分字母区间3、56. 合并区间4、738.单调递增的数字总结python 一、part04 1、452.用最少数量的箭引爆气球 452. 用最少数量的箭引爆气球 class Solution:def f…...

DeepSeek 赋能智慧能源:微电网优化调度的智能革新路径

目录 一、智慧能源微电网优化调度概述1.1 智慧能源微电网概念1.2 优化调度的重要性1.3 目前面临的挑战 二、DeepSeek 技术探秘2.1 DeepSeek 技术原理2.2 DeepSeek 独特优势2.3 DeepSeek 在 AI 领域地位 三、DeepSeek 在微电网优化调度中的应用剖析3.1 数据处理与分析3.2 预测与…...

【CSS position 属性】static、relative、fixed、absolute 、sticky详细介绍,多层嵌套定位示例

文章目录 ★ position 的五种类型及基本用法 ★ 一、position 属性概述 二、position 的五种类型详解(初学者版) 1. static(默认值) 2. relative(相对定位) 3. absolute(绝对定位) 4. fixed(固定定位) 5. sticky(粘性定位) 三、定位元素的层级关系(z-i…...

【单片机期末】单片机系统设计

主要内容:系统状态机,系统时基,系统需求分析,系统构建,系统状态流图 一、题目要求 二、绘制系统状态流图 题目:根据上述描述绘制系统状态流图,注明状态转移条件及方向。 三、利用定时器产生时…...

IT供电系统绝缘监测及故障定位解决方案

随着新能源的快速发展,光伏电站、储能系统及充电设备已广泛应用于现代能源网络。在光伏领域,IT供电系统凭借其持续供电性好、安全性高等优势成为光伏首选,但在长期运行中,例如老化、潮湿、隐裂、机械损伤等问题会影响光伏板绝缘层…...

python报错No module named ‘tensorflow.keras‘

是由于不同版本的tensorflow下的keras所在的路径不同,结合所安装的tensorflow的目录结构修改from语句即可。 原语句: from tensorflow.keras.layers import Conv1D, MaxPooling1D, LSTM, Dense 修改后: from tensorflow.python.keras.lay…...

使用Matplotlib创建炫酷的3D散点图:数据可视化的新维度

文章目录 基础实现代码代码解析进阶技巧1. 自定义点的大小和颜色2. 添加图例和样式美化3. 真实数据应用示例实用技巧与注意事项完整示例(带样式)应用场景在数据科学和可视化领域,三维图形能为我们提供更丰富的数据洞察。本文将手把手教你如何使用Python的Matplotlib库创建引…...

PHP 8.5 即将发布:管道操作符、强力调试

前不久,PHP宣布了即将在 2025 年 11 月 20 日 正式发布的 PHP 8.5!作为 PHP 语言的又一次重要迭代,PHP 8.5 承诺带来一系列旨在提升代码可读性、健壮性以及开发者效率的改进。而更令人兴奋的是,借助强大的本地开发环境 ServBay&am…...

Spring Boot + MyBatis 集成支付宝支付流程

Spring Boot MyBatis 集成支付宝支付流程 核心流程 商户系统生成订单调用支付宝创建预支付订单用户跳转支付宝完成支付支付宝异步通知支付结果商户处理支付结果更新订单状态支付宝同步跳转回商户页面 代码实现示例&#xff08;电脑网站支付&#xff09; 1. 添加依赖 <!…...

Vue 3 + WebSocket 实战:公司通知实时推送功能详解

&#x1f4e2; Vue 3 WebSocket 实战&#xff1a;公司通知实时推送功能详解 &#x1f4cc; 收藏 点赞 关注&#xff0c;项目中要用到推送功能时就不怕找不到了&#xff01; 实时通知是企业系统中常见的功能&#xff0c;比如&#xff1a;管理员发布通知后&#xff0c;所有用户…...

基于单片机的宠物屋智能系统设计与实现(论文+源码)

本设计基于单片机的宠物屋智能系统核心是实现对宠物生活环境及状态的智能管理。系统以单片机为中枢&#xff0c;连接红外测温传感器&#xff0c;可实时精准捕捉宠物体温变化&#xff0c;以便及时发现健康异常&#xff1b;水位检测传感器时刻监测饮用水余量&#xff0c;防止宠物…...