当前位置: 首页 > news >正文

【文本分类】bert二分类

import os
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
from sklearn.metrics import accuracy_score, classification_report
from tqdm import tqdm# 自定义数据集
class CustomDataset(Dataset):def __init__(self, texts, labels, tokenizer, max_length=128):self.texts = textsself.labels = labelsself.tokenizer = tokenizerself.max_length = max_lengthdef __len__(self):return len(self.texts)def __getitem__(self, idx):text = self.texts[idx]label = self.labels[idx]encoding = self.tokenizer(text,max_length=self.max_length,padding="max_length",truncation=True,return_tensors="pt")return {"input_ids": encoding["input_ids"].squeeze(0),"attention_mask": encoding["attention_mask"].squeeze(0),"label": torch.tensor(label, dtype=torch.long)}# 训练函数
def train_model(model, train_loader, optimizer, device, num_epochs=3):model.train()for epoch in range(num_epochs):total_loss = 0for batch in tqdm(train_loader, desc=f"Training Epoch {epoch + 1}/{num_epochs}"):input_ids = batch["input_ids"].to(device)attention_mask = batch["attention_mask"].to(device)labels = batch["label"].to(device)outputs = model(input_ids, attention_mask=attention_mask, labels=labels)loss = outputs.losstotal_loss += loss.item()optimizer.zero_grad()loss.backward()optimizer.step()print(f"Epoch {epoch + 1} Loss: {total_loss / len(train_loader)}")# 评估函数
def evaluate_model(model, val_loader, device):model.eval()predictions, true_labels = [], []with torch.no_grad():for batch in val_loader:input_ids = batch["input_ids"].to(device)attention_mask = batch["attention_mask"].to(device)labels = batch["label"].to(device)outputs = model(input_ids, attention_mask=attention_mask)logits = outputs.logitspreds = torch.argmax(logits, dim=1).cpu().numpy()predictions.extend(preds)true_labels.extend(labels.cpu().numpy())accuracy = accuracy_score(true_labels, predictions)report = classification_report(true_labels, predictions)print(f"Validation Accuracy: {accuracy}")print("Classification Report:")print(report)# 模型保存函数
def save_model(model, tokenizer, output_dir):os.makedirs(output_dir, exist_ok=True)model.save_pretrained(output_dir)tokenizer.save_pretrained(output_dir)print(f"Model saved to {output_dir}")# 模型加载函数
def load_model(output_dir, device):tokenizer = BertTokenizer.from_pretrained(output_dir)model = BertForSequenceClassification.from_pretrained(output_dir)model.to(device)print(f"Model loaded from {output_dir}")return model, tokenizer# 推理预测函数
def predict(texts, model, tokenizer, device, max_length=128):model.eval()encodings = tokenizer(texts,max_length=max_length,padding="max_length",truncation=True,return_tensors="pt")input_ids = encodings["input_ids"].to(device)attention_mask = encodings["attention_mask"].to(device)with torch.no_grad():outputs = model(input_ids, attention_mask=attention_mask)logits = outputs.logitsprobabilities = torch.softmax(logits, dim=1).cpu().numpy()predictions = torch.argmax(logits, dim=1).cpu().numpy()return predictions, probabilities# 主函数
def main():# 配置参数config = {"train_batch_size": 16,"val_batch_size": 16,"learning_rate": 5e-5,"num_epochs": 5,"max_length": 128,"device_id": 7,  # 指定 GPU ID"model_dir": "model","local_model_path": "roberta_tiny_model",  # 指定本地模型路径,如果为 None 则使用预训练模型"pretrained_model_name": "uer/chinese_roberta_L-12_H-128",  # 预训练模型名称}# 设置设备device = torch.device(f"cuda:{config['device_id']}" if torch.cuda.is_available() else "cpu")print(f"Using device: {device}")# 加载分词器和模型tokenizer = BertTokenizer.from_pretrained(config["local_model_path"])model = BertForSequenceClassification.from_pretrained(config["local_model_path"], num_labels=2)model.to(device)# 示例数据train_texts = ["This is a great product!", "I hate this service."]train_labels = [1, 0]val_texts = ["Awesome experience.", "Terrible product."]val_labels = [1, 0]# 创建数据集和数据加载器train_dataset = CustomDataset(train_texts, train_labels, tokenizer, config["max_length"])val_dataset = CustomDataset(val_texts, val_labels, tokenizer, config["max_length"])train_loader = DataLoader(train_dataset, batch_size=config["train_batch_size"], shuffle=True)val_loader = DataLoader(val_dataset, batch_size=config["val_batch_size"])# 定义优化器optimizer = AdamW(model.parameters(), lr=config["learning_rate"])# 训练模型train_model(model, train_loader, optimizer, device, num_epochs=config["num_epochs"])# 评估模型evaluate_model(model, val_loader, device)# 保存模型save_model(model, tokenizer, config["model_dir"])# 加载模型loaded_model, loaded_tokenizer = load_model(config["model_dir"], "cpu")# 推理预测new_texts = ["I love this!", "It's the worst."]predictions, probabilities = predict(new_texts, loaded_model, loaded_tokenizer,  "cpu")for text, pred, prob in zip(new_texts, predictions, probabilities):print(f"Text: {text}")print(f"Predicted Label: {pred} (Probability: {prob})")if __name__ == "__main__":main()

相关文章:

【文本分类】bert二分类

import os import torch from torch.utils.data import DataLoader, Dataset from transformers import BertTokenizer, BertForSequenceClassification, AdamW from sklearn.metrics import accuracy_score, classification_report from tqdm import tqdm# 自定义数据集 class…...

单例模式-如何保证全局唯一性?

以下是几种实现单例模式并保证全局唯一性的方法: 1. 饿汉式单例模式 class Singleton { private:// 私有构造函数,防止外部创建对象Singleton() {}// 静态成员变量,存储单例对象static Singleton instance; public:// 公有静态成员函数&…...

设计模式学习笔记——结构型模式

文章目录 适配器模式 Adapter适用场景UML 桥接模式 Bridge适用场景UML 组合模式 Composite装饰模式 Decorator外观模式 Facade享元模式 Flyweight代理模式 Proxy 适配器模式 Adapter 适用场景 希望使用某个类, 但是其接口与其他代码不兼容时, 可以使用…...

WEB攻防-通用漏洞_文件上传_黑白盒审计流程

目录 前置知识点 Finecms-CMS文件上传 ​编辑 Cuppa-Cms文件上传 Metinfo-CMS 文件上传 前置知识点 思路: 黑盒就是寻找一切存在文件上传的功能应用 1 、个人用户中心是否存在文件上传功能 2 、后台管理系统是否存在文件上传功能 3 、字典目录扫描探针文件上传构…...

RabbitMQ基本介绍及简单上手

(一)什么是MQ MQ(message queue)本质上是队列,满足先入先出,只不过队列中存放的内容是消息而已,那什么是消息呢? 消息可以是字符串,json也可以是一些复杂对象 我们应用场…...

服务器证书不受信任是什么问题?

用户在访问某些网站时,可能会遇到“服务器证书不受信任”的警告。这一问题不仅影响用户的浏览体验,更可能对网站的信誉和安全性产生深远影响。那么服务器证书不受信任是什么问题呢? 服务器证书的基本概念 服务器证书是由证书颁发机构(CA)签…...

spring mvc源码学习笔记之十

前面的文章介绍了用 WebApplicationInitializer 或者 AbstractAnnotationConfigDispatcherServletInitializer 来代替 web.xml 。 我们学 java web 的时候就知道,servlet 容器会自动加载 web.xml。 那么,疑问就来了,WebApplicationInitialize…...

Ubuntu 下载安装 elasticsearch7.17.9

参考 https://blog.csdn.net/qq_26039331/article/details/115024218 https://blog.csdn.net/mengo1234/article/details/104989382 过程 来到 Es 的版本发布列表页面:https://www.elastic.co/downloads/past-releases#elasticsearch 根据自己的系统以及要安装的…...

Qt笔记:网络编程Tcp

一、铺垫 1.以下只是告诉诸位怎样去构建服务器与客户端;客户端这样构建肯定没问题;但是服务端不可能这样写,因为他是布置在Linux上的,纯数据类处理服务器,根本不可能用Qt写;这在Qt的http类中就表明了&…...

C++单例模式跨DLL调用问题梳理

问题案例: 假设有这样一个单例模式的代码 //test.h header class Test { public:static Test &instance() {static Test ins;return ins;}void foo(); };void testFoo();//test.cpp source #include "test.h"void Test::foo() {printf("%p\n&q…...

oracle闪回版本查询

闪回版本查询(Flashback Versions Query)是Oracle数据库提供的一种功能,允许用户查看某个表在特定时间范围内的所有版本。这对于审计和调试数据修改问题非常有用。通过闪回版本查询,你可以了解表中的数据在某个时间段内的变化历史…...

C#用winform窗口程序操作服务+不显示Form窗体,只显示右下角托盘图标+开机时自启动程序【附带项目地址】

服务的文章在:https://blog.csdn.net/weixin_43768573/article/details/144957941 一、用winform窗口程序操作服务 1、点击“创建新项目”,选择“Windows 服务(.NET Framework)” 2、给项目命名 3、右击项目->添加->新建项,选择“应用程序清单文件(仅限Windo…...

UOS系统和windows系统wps文档显示差异问题解决

最近在使用UOS系统的过程中,发现了一个很有意思的现象。就是在UOS系统上编辑的文档,发到windows系统上,会出现两个文档显示差异很大的情况,文档都是使用一样的wps软件打开的。到底是什么原因导致这种现象的呢?该如何解…...

JS中函数基础知识之查漏补缺(写给小白的学习笔记)

函数 函数是ECMAScript中 最有意思的部分之一, 主要是因为函数实际上是对象.-- 每个函数 都是Function类型的实例,Function也有属性和方法. 因为函数是对象,所以函数名就是指向函数对象的指针. 常用的定义函数的语法: ①函数声明 ②函数表达式 ③箭头函数 function sum (n…...

蓝桥杯训练

1对于一个字母矩阵,我们称矩阵中的一个递增序列是指在矩阵中找到两个字母,它们在同一行,同一列,或者在同一 45 度的斜线上,这两个字母从左向右看、或者从上向下看是递增的。 例如,如下矩阵中 LANN QIAO有…...

前端学习DAY33(外边距的折叠)

垂直外边距的重叠 在网页中相邻的垂直方向的外边距,会发生外边距的重叠 兄弟元素 兄弟元素之间的相邻外边距会取(绝对值)最大值,而不是取和,谁大取谁 特殊情况:如果相邻的外边距一正一负,则取两…...

asp.net core mvc的 ViewBag , ViewData , Module ,TempData

在 ASP.NET MVC 和 ASP.NET Core MVC 中,ViewBag 和 ViewData 是两种用于将数据从控制器传递到视图(View)的常用方法。它们都允许控制器将动态数据传递给视图,但它们的实现方式有所不同。关于 Module,它通常指的是某种…...

Linux驱动学习之第二个驱动程序(LED点亮关闭驱动程序-分层设计思想,使其能适应不同的板子-驱动程序模块为多个源文件怎么写Makefile)

目录 看这篇博文前请先掌握下面这些博文中的知识需要的PDF资料完整源代码board_fire_imx6ull-pro.c中的代码leddrv.c中的代码ledtest.c中的代码 程序设计思想和文件结构实现分层思想的具体方法概述具体实现分析定义结构体led_operations用来集合各个单板硬件层面操作LED的函数定…...

手写@EnableTransactionalManagement

定义一个注解,用于标注于方法上,标志着此方法是一个事务方法。 Target({ElementType.METHOD,ElementType.TYPE}) Retention(RetentionPolicy.RUNTIME) public interface MyTransaction {}定义一个开启事务功能的注解 Component Import(TransActionBean…...

【Vue】:解决动态更新 <video> 标签 src 属性后视频未刷新的问题

问题描述 在 Vue.js 项目&#xff0c;当尝试动态更新 <video> 标签的 <source> 元素 src 属性来切换视频时&#xff0c;遇到了一个问题&#xff1a;即使 src 属性已更改&#xff0c;浏览器仍显示旧视频。具体表现为用户选择新视频后&#xff0c;视频区域继续显示之…...

装饰模式(Decorator Pattern)重构java邮件发奖系统实战

前言 现在我们有个如下的需求&#xff0c;设计一个邮件发奖的小系统&#xff0c; 需求 1.数据验证 → 2. 敏感信息加密 → 3. 日志记录 → 4. 实际发送邮件 装饰器模式&#xff08;Decorator Pattern&#xff09;允许向一个现有的对象添加新的功能&#xff0c;同时又不改变其…...

Vue3 + Element Plus + TypeScript中el-transfer穿梭框组件使用详解及示例

使用详解 Element Plus 的 el-transfer 组件是一个强大的穿梭框组件&#xff0c;常用于在两个集合之间进行数据转移&#xff0c;如权限分配、数据选择等场景。下面我将详细介绍其用法并提供一个完整示例。 核心特性与用法 基本属性 v-model&#xff1a;绑定右侧列表的值&…...

(二)TensorRT-LLM | 模型导出(v0.20.0rc3)

0. 概述 上一节 对安装和使用有个基本介绍。根据这个 issue 的描述&#xff0c;后续 TensorRT-LLM 团队可能更专注于更新和维护 pytorch backend。但 tensorrt backend 作为先前一直开发的工作&#xff0c;其中包含了大量可以学习的地方。本文主要看看它导出模型的部分&#x…...

理解 MCP 工作流:使用 Ollama 和 LangChain 构建本地 MCP 客户端

&#x1f31f; 什么是 MCP&#xff1f; 模型控制协议 (MCP) 是一种创新的协议&#xff0c;旨在无缝连接 AI 模型与应用程序。 MCP 是一个开源协议&#xff0c;它标准化了我们的 LLM 应用程序连接所需工具和数据源并与之协作的方式。 可以把它想象成你的 AI 模型 和想要使用它…...

多模态商品数据接口:融合图像、语音与文字的下一代商品详情体验

一、多模态商品数据接口的技术架构 &#xff08;一&#xff09;多模态数据融合引擎 跨模态语义对齐 通过Transformer架构实现图像、语音、文字的语义关联。例如&#xff0c;当用户上传一张“蓝色连衣裙”的图片时&#xff0c;接口可自动提取图像中的颜色&#xff08;RGB值&…...

GitHub 趋势日报 (2025年06月08日)

&#x1f4ca; 由 TrendForge 系统生成 | &#x1f310; https://trendforge.devlive.org/ &#x1f310; 本日报中的项目描述已自动翻译为中文 &#x1f4c8; 今日获星趋势图 今日获星趋势图 884 cognee 566 dify 414 HumanSystemOptimization 414 omni-tools 321 note-gen …...

leetcodeSQL解题:3564. 季节性销售分析

leetcodeSQL解题&#xff1a;3564. 季节性销售分析 题目&#xff1a; 表&#xff1a;sales ---------------------- | Column Name | Type | ---------------------- | sale_id | int | | product_id | int | | sale_date | date | | quantity | int | | price | decimal | -…...

Ascend NPU上适配Step-Audio模型

1 概述 1.1 简述 Step-Audio 是业界首个集语音理解与生成控制一体化的产品级开源实时语音对话系统&#xff0c;支持多语言对话&#xff08;如 中文&#xff0c;英文&#xff0c;日语&#xff09;&#xff0c;语音情感&#xff08;如 开心&#xff0c;悲伤&#xff09;&#x…...

什么?连接服务器也能可视化显示界面?:基于X11 Forwarding + CentOS + MobaXterm实战指南

文章目录 什么是X11?环境准备实战步骤1️⃣ 服务器端配置(CentOS)2️⃣ 客户端配置(MobaXterm)3️⃣ 验证X11 Forwarding4️⃣ 运行自定义GUI程序(Python示例)5️⃣ 成功效果![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/55aefaea8a9f477e86d065227851fe3d.pn…...

有限自动机到正规文法转换器v1.0

1 项目简介 这是一个功能强大的有限自动机&#xff08;Finite Automaton, FA&#xff09;到正规文法&#xff08;Regular Grammar&#xff09;转换器&#xff0c;它配备了一个直观且完整的图形用户界面&#xff0c;使用户能够轻松地进行操作和观察。该程序基于编译原理中的经典…...