使用 BERT 和逻辑回归进行文本分类及示例验证
使用 BERT 和逻辑回归进行文本分类及示例验证
一、引言
在自然语言处理领域中,文本分类是一项至关重要的任务。本文将详细介绍如何结合 BERT 模型与逻辑回归算法来实现文本分类,并通过实际示例进行验证。
二、环境准备
为了运行本文中的代码,你需要安装以下库:
pandas:用于数据处理。sklearn:包含机器学习算法。torch:用于深度学习任务。transformers:用于加载预训练语言模型。
三、代码实现
(一)读取数据集
首先,从 CSV 文件中读取数据集。假设该数据集包含两列,分别是content(文本内容)和labels(文本标签)。
import pandas as pd# 从 CSV 文件读取数据集
print("正在读取数据集...")
df = pd.read_csv('training_data.csv', encoding='utf-8-sig')
print("数据集读取完成,共包含 {} 条数据.".format(len(df)))
(二)分割数据集
接着,提取特征和目标,并将数据集分割为训练集和测试集。
# 提取特征和目标
X = df['content']
y = df['labels']# 分割数据集
print("正在分割数据集...")
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print("训练集大小: {}, 测试集大小: {}".format(len(X_train), len(X_test)))
(三)加载 BERT 模型和分词器
然后,加载 BERT 模型和分词器,以便将文本转化为特征向量。
import torch
from transformers import BertTokenizer, BertModel# 加载 BERT 模型和分词器
print("加载 BERT 模型和分词器...")
tokenizer = BertTokenizer.from_pretrained('D:\\bert-base-chinese')
model = BertModel.from_pretrained('D:\\bert-base-chinese')
(四)文本转化为特征向量
定义一个函数get_embeddings,用于将文本转化为特征向量。该函数利用 BERT 模型对文本进行编码,然后获取[CLS]标记的输出作为文本的特征向量。
# 文本转化为特征向量
def get_embeddings(texts):print("正在生成文本特征向量...")inputs = tokenizer(texts.tolist(), padding=True, truncation=True, return_tensors='pt')with torch.no_grad():outputs = model(**inputs)# 获取[CLS]标记的输出作为文本的特征向量return outputs.last_hidden_state[:, 0, :].numpy()
(五)训练分类模型
使用逻辑回归算法作为分类模型。先将训练集转化为 BERT 特征,然后训练分类模型。
from sklearn.linear_model import LogisticRegression# 转换训练集和测试集为 BERT 特征
X_train_bert = get_embeddings(X_train)
X_test_bert = get_embeddings(X_test)# 训练分类模型
print("正在训练分类模型...")
classifier = LogisticRegression(max_iter=1000) # 使用逻辑回归
classifier.fit(X_train_bert, y_train)
print("模型训练完成.")
(六)预测
使用训练好的分类模型对测试集进行预测,并打印预测结果。
# 预测
print("正在进行预测...")
predictions = classifier.predict(X_test_bert)# 打印预测结果
print("预测结果:", predictions)
(七)示例数据验证
最后,添加一些示例数据进行验证。将示例数据转化为 BERT 特征,然后使用分类模型进行预测,并打印预测结果。
# 添加示例数据进行验证
sample_texts = ["音乐有助力放松大脑,心情愉悦。","热爱生活,享受人生",
]# 将示例数据转换为 BERT 特征
print("正在对示例数据进行预测...")
sample_embeddings = get_embeddings(pd.Series(sample_texts))
sample_predictions = classifier.predict(sample_embeddings)# 打印示例数据预测结果
for text, prediction in zip(sample_texts, sample_predictions):print(f"文本: \"{text}\" 预测标签: {prediction}")
四、总结
本文介绍了如何运用 BERT 和逻辑回归进行文本分类,并通过示例数据进行了验证。借助 BERT 模型学习到的文本上下文信息,能够显著提高文本分类的准确性。同时,逻辑回归算法的快速性使得我们可以高效地对大量文本进行分类。
五、完整代码
text_categorize_and_tag.py
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
import torch
from transformers import BertTokenizer, BertModel# 从CSV文件读取数据集
print("正在读取数据集...")
df = pd.read_csv('training_data.csv', encoding='utf-8-sig')
print("数据集读取完成,共包含 {} 条数据.".format(len(df)))# 提取特征和目标
X = df['content']
y = df['labels']# 分割数据集
print("正在分割数据集...")
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print("训练集大小: {}, 测试集大小: {}".format(len(X_train), len(X_test)))# 加载BERT模型和分词器
print("加载BERT模型和分词器...")
tokenizer = BertTokenizer.from_pretrained('D:\\bert-base-chinese')
model = BertModel.from_pretrained('D:\\bert-base-chinese')# 文本转化为特征向量
def get_embeddings(texts):print("正在生成文本特征向量...")inputs = tokenizer(texts.tolist(), padding=True, truncation=True, return_tensors='pt')with torch.no_grad():outputs = model(**inputs)# 获取[CLS]标记的输出作为文本的特征向量return outputs.last_hidden_state[:, 0, :].numpy()# 转换训练集和测试集为BERT特征
X_train_bert = get_embeddings(X_train)
X_test_bert = get_embeddings(X_test)# 训练分类模型
print("正在训练分类模型...")
classifier = LogisticRegression(max_iter=1000) # 使用逻辑回归
classifier.fit(X_train_bert, y_train)
print("模型训练完成.")# 预测
print("正在进行预测...")
predictions = classifier.predict(X_test_bert)# 打印预测结果
print("预测结果:", predictions)# 添加示例数据进行验证
sample_texts = ["音乐有助力放松大脑,心情愉悦。","热爱生活,享受人生",
]# 将示例数据转换为BERT特征
print("正在对示例数据进行预测...")
sample_embeddings = get_embeddings(pd.Series(sample_texts))
sample_predictions = classifier.predict(sample_embeddings)# 打印示例数据预测结果
for text, prediction in zip(sample_texts, sample_predictions):print(f"文本: \"{text}\" 预测标签: {prediction}")
training_data.csv
content,labels
"Python 是一种广泛使用的高级编程语言。","编程"
"自然语言处理是人工智能领域的重要研究方向。","NLP"
"机器学习是分析数据的重要工具。","机器学习"
"数据科学结合了统计学和计算机科学。","数据科学"
"人工智能正在改变我们的生活方式。","人工智能"
"深度学习能够处理复杂的数据集。","机器学习"
"很多企业开始应用人工智能技术以提高效率。","人工智能"
"数据分析是理解客户行为的重要工具。","数据科学"
"编程不仅是技术,更是一种思维方式。","编程"
"算法在大数据时代发挥着重要作用。","数据科学"
"音乐可以影响人的情绪和认知。","音乐"
"学习音乐可以提高学生的创造力。","教育"
"现场音乐会可以提供独特的视听体验。","娱乐"
"教育科技正在变革传统的学习方式。","教育"
"学习一门乐器有助于提升专注力。","音乐"
"电影和电视节目是现代娱乐的重要部分。","娱乐"
"音乐治疗被广泛应用于心理健康。","音乐"
"在线教育平台为学习者提供灵活的选择。","教育"
"综艺节目为观众提供了丰富的娱乐内容。","娱乐"
"这是一篇关于机器学习的文章。","科技"
"我喜欢户外活动和旅游。","生活"
"COVID-19疫情对全球经济产生了深远的影响。","财经"
"人工智能正在改变我们的生活方式。","科技"
"旅游是一种能让人开阔视野的活动。","生活"
"金融科技让我们的投资变得更加智能。","财经"
"环境保护对我们的未来至关重要。","环保"
相关文章:
使用 BERT 和逻辑回归进行文本分类及示例验证
使用 BERT 和逻辑回归进行文本分类及示例验证 一、引言 在自然语言处理领域中,文本分类是一项至关重要的任务。本文将详细介绍如何结合 BERT 模型与逻辑回归算法来实现文本分类,并通过实际示例进行验证。 二、环境准备 为了运行本文中的代码…...
【skywalking 】监控 Spring Cloud Gateway 数据
使用Spring Cloud 开发,用Skywalking 监控服务,但是Skywalking 默认是不支持 Spring Cloud Gateway 网关服务的,需要手动将 Gateway 的插件添加到 Skywalking 启动依赖 jar 中。 skywalking相关版本信息 jdk:17skywalking&#x…...
SpringWeb
SpringWeb SpringWeb 概述 SpringWeb 是 spring 框架中的一个模块,基于 Servlet API 构建的 web 框架. springWeb 是 Spring 为 web 层开发提供的一整套完备的解决方案。 在 web 层框架历经 Strust1,WebWork,Strust2 等诸多产品的历代更…...
嵌入式刷题(day21)
MySQL和sqlite的区别 MySQL和SQLite是两种常见的关系型数据库管理系统(RDBMS),但它们在特性、使用场景和架构方面有显著的区别: 1. 架构 MySQL:是一个基于服务器的数据库系统,遵循客户端-服务器架构。MySQL服务器运行在主机上,客户端通过网络连接并发送查询。它可以并…...
OpenAI 下一代旗舰模型现身?奥尔特曼亲自辟谣“猎户座“传闻
在人工智能领域最受瞩目的ChatGPT即将迎来两周岁之际,一场关于OpenAI新旗舰模型的传闻再次引发业界热议。然而,这场喧嚣很快就被OpenAI掌门人奥尔特曼亲自澄清。 事件源于科技媒体The Verge的一则报道。据多位知情人士透露,OpenAI可能会在11…...
【C++】STL初识
【C】STL初识 文章目录 【C】STL初识前言一、STL基本概念二、STL六大组件简介三、STL三大组件四、初识STL总结 前言 本篇文章将讲到STL基本概念,STL六大组件简介,STL三大组件,初识STL。 一、STL基本概念 STL(Standard Template Library,标准…...
框架篇补充(东西多 需要重新看网课)
什么是AOP 面向切面编程 降低耦合 提高代码的复用 Spring的bean的生命周期 实例化bean 赋值 初始化bean 使用bean 销毁bean SpringMVC的执行流程 Springboot自动装配原理 实际上就是为了从spring.factories文件中 获取到对应的需要 进行自动装配的类 并生成相应的Bean…...
合约门合同全生命周期管理系统:企业合同管理的数字化转型之道
合约门合同全生命周期管理系统:企业合同管理的数字化转型之道 1. 引言 在现代企业中,合同管理已经不再是简单的文件存储和审批流程,而是企业合规性、风险管理和业务流程的关键环节之一。随着企业规模的扩大和合同数量的增加,传统…...
等保测评与风险管理:识别、评估和缓解潜在的安全威胁
在信息化时代,数据已成为企业最宝贵的资产之一,而信息安全则成为守护这份资产免受侵害的重中之重。等保测评(信息安全等级保护测评)作为保障信息系统安全的重要手段,其核心在于通过科学、规范、专业的评估手段…...
Golang Agent 可观测性的全面升级与新特性介绍
作者:张海彬(古琦) 背景 自 2024 年 6 月 26 日,ARMS 发布了针对 Golang 应用的可观测性监控功能以来,阿里云 ARMS 团队与程序语言与编译器团队一直致力于不断优化和提升该系统的各项功能,旨在为开发者提…...
SpringBoot的开篇 特点 初始化 ioc 配置文件
文章目录 前言SpringBoot发展历程SpringBoot前置准备SpringBoot特点 SpringBoot项目初始化项目启动Springboot的核心概念IOC概念介绍Bean对象通过注解扫描包 例子配置文件 前言 SpringBoot发展历程 最初,Spring框架的使用需要大量的XML配置,这使得开发…...
docker 可用镜像服务地址(2024.10.25亲测可用)
1.错误 Error response from daemon: Get “https://registry-1.docker.io/v2/” 原因:镜像服务器地址不可用。 2.可用地址 编辑daemon.json: vi /etc/docker/daemon.json内容修改如下: {"registry-mirrors": ["https://…...
【SQL实验】表的更新和简单查询
完整代码在文章末尾 在上次实验创建的educ数据库基础上,用SQL语句为student表、course表和sc表中添加以下记录 【SQL实验】数据库、表、模式的SQL语句操作_创建一个名为educ数据库,要求如下: (下面三个表中属性的数据类型需要自己设计合适-CSDN博客在这篇博文中已经…...
【C++】 string的了解及使用
标准库中的string类 在使用string类时,必须包含#include头文件以及using namespace std; string类的常用接口说明 C中string为我们提供了丰富的接口来供我们使用 – string接口文档 这里我们只介绍一些常见的接口 string类对象的常见构造 #include <iostrea…...
【K8S】kubernetes-dashboard.yaml
https://raw.githubusercontent.com/kubernetes/dashboard/v3.0.0-alpha0/charts/kubernetes-dashboard.yaml 以下链接的内容: 由于国内访问不了,找到一些方法下载了这个文件内容, 部署是mages 对象的镜像 WEB docker.io/kubernetesui/dash…...
远程root用户访问服务器中的MySQL8
一、Ubuntu下的MySQL8安装 在Ubuntu系统中安装MySQL 8.0可以通过以下步骤进行1. 更新包管理工具的仓库列表: sudo apt update 2. 安装MySQL 8.0,root用户默认没有密码: sudo apt install mysql-server sudo apt install mysql-client 【…...
解释一下 Java 中的静态变量(Static Variable)和静态方法(Static Method)?
今天来和大家深入探讨一下 Java 中的静态变量和静态方法,并通过一些具体的例子来理解它们在实际开发中的应用。 静态变量(Static Variable) 静态变量,也称为类变量,是在类的层次上共享的变量。这意味着无论创建了多少…...
【Linux】————磁盘与文件系统
作者主页: 作者主页 本篇博客专栏:Linux 创作时间 :2024年10月17日 一、磁盘的物理结构 磁盘的物理结构如图所示: 其中具体的物理存储结构如下: 磁盘中存储的基本单位为扇区,一个扇区的大小一般为512字…...
平衡控制——直立环——速度环
目录 平衡控制原理 平衡控制模型 平衡控制中基于模型设计与自动代码生成技术 速度环应用原理 速度控制模型 平衡控制原理 下图是一个单摆模型,对其进行受力分析如图。 在重力作用下,单摆受到和角度成正比,运动方向相反的回复力。而且在空气中运动的单摆,由于受…...
面试简要介绍hashMap
jdk8之前,hashmap采用的数据结构是数组链表,jdk8之后采用的数据结构是数组链表/红黑树。hashmap的数据以键值对的形式存在,如果两个元素的hash值相同,就会发生hash冲突,被放到同一个链表上--->如何解决hash冲突---&…...
3大远程管理痛点解决方案:MobaXterm中文版实现一站式终端效率革命
3大远程管理痛点解决方案:MobaXterm中文版实现一站式终端效率革命 【免费下载链接】Mobaxterm-Chinese Mobaxterm simplified Chinese version. Mobaxterm 的简体中文版. 项目地址: https://gitcode.com/gh_mirrors/mo/Mobaxterm-Chinese 远程服务器管理面临…...
Hurley:C#到裸机C的语义重铸编译器
1. 这不是代码转换器,而是一台“语义重铸机”你有没有试过把一段写得工整、泛型丰富、LINQ链式调用如行云流水的C#代码,硬生生塞进一个只认int main()和malloc的嵌入式环境?我去年在给某款国产工业PLC做边缘协议适配时就撞上了这堵墙…...
大麦网自动抢票神器:5分钟配置,告别抢票焦虑的终极指南
大麦网自动抢票神器:5分钟配置,告别抢票焦虑的终极指南 【免费下载链接】ticket-purchase 大麦自动抢票,支持人员、城市、日期场次、价格选择 项目地址: https://gitcode.com/GitHub_Trending/ti/ticket-purchase 还在为心仪演唱会门票…...
大麦抢票终极指南:告别手速焦虑,轻松锁定心仪演出门票
大麦抢票终极指南:告别手速焦虑,轻松锁定心仪演出门票 【免费下载链接】ticket-purchase 大麦自动抢票,支持人员、城市、日期场次、价格选择 项目地址: https://gitcode.com/GitHub_Trending/ti/ticket-purchase 面对热门演唱会门票&q…...
深入解析CPU L1/L2缓存:原理、性能影响与编程优化实战
1. 项目概述:从“快”字说起做性能调优或者写高性能代码的朋友,对“缓存”这个词一定不陌生。我们总在说,把数据放进缓存里,访问就快了。但缓存本身,尤其是离CPU核心最近的一级缓存(L1 Cache)和…...
windows下vs 2015 libtorrent库的配置,vs2015下-boost-openssl-libtorrent的配置
libtorrent依赖OpenSSL和boost库,首先要编译Openssl和boost库。 1、安装ActivePerl,下载地址:网上找。 安装完后配置环境变量(一般安装成功后,环境变量就已经配置好了,如果没有配置自己配置环境变量): …...
【2026年世界模型最全综述】:从开山之作到Sora与Genie 3
论文信息 标题:Understanding World or Predicting Future? A Comprehensive Surveyof World Models会议:ACM Computing Surveys 2026(计算机领域顶级综述期刊)单位:清华大学FIB-Lab代码:https://github.c…...
AssetStudio Unity资源提取终极指南:精准解析SerializedFile与AssetBundle
1. 为什么AssetStudio是Unity资源提取的“第一把刀”——不是因为它最强,而是因为它最准你有没有遇到过这样的场景:刚下载一个热门Unity手游的APK,兴致勃勃地解包,结果在assets/bin/Data/Managed/目录下看到一堆Assembly-CSharp.d…...
5分钟实现OBS多平台同步直播:obs-multi-rtmp插件完全指南
5分钟实现OBS多平台同步直播:obs-multi-rtmp插件完全指南 【免费下载链接】obs-multi-rtmp OBS複数サイト同時配信プラグイン 项目地址: https://gitcode.com/gh_mirrors/ob/obs-multi-rtmp 你是否厌倦了在不同直播平台间来回切换的繁琐操作?obs-…...
智能电表:解锁智能照明精细化能耗管控新密码
摘要随着双碳政策深度落地与智慧楼宇数字化升级,智能照明已成为商业园区、市政道路、综合体的标配设施。传统机械式电表仅具备基础电量统计功能,存在数据滞后、精度不足、无分区计量、无异常监测等短板,无法适配现代照明多回路、多场景、长时…...
