当前位置: 首页 > news >正文

使用 BERT 和逻辑回归进行文本分类及示例验证

使用 BERT 和逻辑回归进行文本分类及示例验证

一、引言

在自然语言处理领域中,文本分类是一项至关重要的任务。本文将详细介绍如何结合 BERT 模型与逻辑回归算法来实现文本分类,并通过实际示例进行验证。

二、环境准备

为了运行本文中的代码,你需要安装以下库:

  • pandas:用于数据处理。
  • sklearn:包含机器学习算法。
  • torch:用于深度学习任务。
  • transformers:用于加载预训练语言模型。

三、代码实现

(一)读取数据集

首先,从 CSV 文件中读取数据集。假设该数据集包含两列,分别是content(文本内容)和labels(文本标签)。

import pandas as pd# 从 CSV 文件读取数据集
print("正在读取数据集...")
df = pd.read_csv('training_data.csv', encoding='utf-8-sig')
print("数据集读取完成,共包含 {} 条数据.".format(len(df)))

(二)分割数据集

接着,提取特征和目标,并将数据集分割为训练集和测试集。

# 提取特征和目标
X = df['content']
y = df['labels']# 分割数据集
print("正在分割数据集...")
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print("训练集大小: {}, 测试集大小: {}".format(len(X_train), len(X_test)))

(三)加载 BERT 模型和分词器

然后,加载 BERT 模型和分词器,以便将文本转化为特征向量。

import torch
from transformers import BertTokenizer, BertModel# 加载 BERT 模型和分词器
print("加载 BERT 模型和分词器...")
tokenizer = BertTokenizer.from_pretrained('D:\\bert-base-chinese')
model = BertModel.from_pretrained('D:\\bert-base-chinese')

(四)文本转化为特征向量

定义一个函数get_embeddings,用于将文本转化为特征向量。该函数利用 BERT 模型对文本进行编码,然后获取[CLS]标记的输出作为文本的特征向量。

# 文本转化为特征向量
def get_embeddings(texts):print("正在生成文本特征向量...")inputs = tokenizer(texts.tolist(), padding=True, truncation=True, return_tensors='pt')with torch.no_grad():outputs = model(**inputs)# 获取[CLS]标记的输出作为文本的特征向量return outputs.last_hidden_state[:, 0, :].numpy()

(五)训练分类模型

使用逻辑回归算法作为分类模型。先将训练集转化为 BERT 特征,然后训练分类模型。

from sklearn.linear_model import LogisticRegression# 转换训练集和测试集为 BERT 特征
X_train_bert = get_embeddings(X_train)
X_test_bert = get_embeddings(X_test)# 训练分类模型
print("正在训练分类模型...")
classifier = LogisticRegression(max_iter=1000)  # 使用逻辑回归
classifier.fit(X_train_bert, y_train)
print("模型训练完成.")

(六)预测

使用训练好的分类模型对测试集进行预测,并打印预测结果。

# 预测
print("正在进行预测...")
predictions = classifier.predict(X_test_bert)# 打印预测结果
print("预测结果:", predictions)

(七)示例数据验证

最后,添加一些示例数据进行验证。将示例数据转化为 BERT 特征,然后使用分类模型进行预测,并打印预测结果。

# 添加示例数据进行验证
sample_texts = ["音乐有助力放松大脑,心情愉悦。","热爱生活,享受人生",
]# 将示例数据转换为 BERT 特征
print("正在对示例数据进行预测...")
sample_embeddings = get_embeddings(pd.Series(sample_texts))
sample_predictions = classifier.predict(sample_embeddings)# 打印示例数据预测结果
for text, prediction in zip(sample_texts, sample_predictions):print(f"文本: \"{text}\" 预测标签: {prediction}")

四、总结

本文介绍了如何运用 BERT 和逻辑回归进行文本分类,并通过示例数据进行了验证。借助 BERT 模型学习到的文本上下文信息,能够显著提高文本分类的准确性。同时,逻辑回归算法的快速性使得我们可以高效地对大量文本进行分类。

五、完整代码

text_categorize_and_tag.py

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
import torch
from transformers import BertTokenizer, BertModel# 从CSV文件读取数据集
print("正在读取数据集...")
df = pd.read_csv('training_data.csv', encoding='utf-8-sig')
print("数据集读取完成,共包含 {} 条数据.".format(len(df)))# 提取特征和目标
X = df['content']
y = df['labels']# 分割数据集
print("正在分割数据集...")
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print("训练集大小: {}, 测试集大小: {}".format(len(X_train), len(X_test)))# 加载BERT模型和分词器
print("加载BERT模型和分词器...")
tokenizer = BertTokenizer.from_pretrained('D:\\bert-base-chinese')
model = BertModel.from_pretrained('D:\\bert-base-chinese')# 文本转化为特征向量
def get_embeddings(texts):print("正在生成文本特征向量...")inputs = tokenizer(texts.tolist(), padding=True, truncation=True, return_tensors='pt')with torch.no_grad():outputs = model(**inputs)# 获取[CLS]标记的输出作为文本的特征向量return outputs.last_hidden_state[:, 0, :].numpy()# 转换训练集和测试集为BERT特征
X_train_bert = get_embeddings(X_train)
X_test_bert = get_embeddings(X_test)# 训练分类模型
print("正在训练分类模型...")
classifier = LogisticRegression(max_iter=1000)  # 使用逻辑回归
classifier.fit(X_train_bert, y_train)
print("模型训练完成.")# 预测
print("正在进行预测...")
predictions = classifier.predict(X_test_bert)# 打印预测结果
print("预测结果:", predictions)# 添加示例数据进行验证
sample_texts = ["音乐有助力放松大脑,心情愉悦。","热爱生活,享受人生",
]# 将示例数据转换为BERT特征
print("正在对示例数据进行预测...")
sample_embeddings = get_embeddings(pd.Series(sample_texts))
sample_predictions = classifier.predict(sample_embeddings)# 打印示例数据预测结果
for text, prediction in zip(sample_texts, sample_predictions):print(f"文本: \"{text}\" 预测标签: {prediction}")

training_data.csv

content,labels
"Python 是一种广泛使用的高级编程语言。","编程"
"自然语言处理是人工智能领域的重要研究方向。","NLP"
"机器学习是分析数据的重要工具。","机器学习"
"数据科学结合了统计学和计算机科学。","数据科学"
"人工智能正在改变我们的生活方式。","人工智能"
"深度学习能够处理复杂的数据集。","机器学习"
"很多企业开始应用人工智能技术以提高效率。","人工智能"
"数据分析是理解客户行为的重要工具。","数据科学"
"编程不仅是技术,更是一种思维方式。","编程"
"算法在大数据时代发挥着重要作用。","数据科学"
"音乐可以影响人的情绪和认知。","音乐"
"学习音乐可以提高学生的创造力。","教育"
"现场音乐会可以提供独特的视听体验。","娱乐"
"教育科技正在变革传统的学习方式。","教育"
"学习一门乐器有助于提升专注力。","音乐"
"电影和电视节目是现代娱乐的重要部分。","娱乐"
"音乐治疗被广泛应用于心理健康。","音乐"
"在线教育平台为学习者提供灵活的选择。","教育"
"综艺节目为观众提供了丰富的娱乐内容。","娱乐"
"这是一篇关于机器学习的文章。","科技"
"我喜欢户外活动和旅游。","生活"
"COVID-19疫情对全球经济产生了深远的影响。","财经"
"人工智能正在改变我们的生活方式。","科技"
"旅游是一种能让人开阔视野的活动。","生活"
"金融科技让我们的投资变得更加智能。","财经"
"环境保护对我们的未来至关重要。","环保"

相关文章:

使用 BERT 和逻辑回归进行文本分类及示例验证

使用 BERT 和逻辑回归进行文本分类及示例验证 一、引言 在自然语言处理领域中,文本分类是一项至关重要的任务。本文将详细介绍如何结合 BERT 模型与逻辑回归算法来实现文本分类,并通过实际示例进行验证。 二、环境准备 为了运行本文中的代码&#xf…...

【skywalking 】监控 Spring Cloud Gateway 数据

使用Spring Cloud 开发,用Skywalking 监控服务,但是Skywalking 默认是不支持 Spring Cloud Gateway 网关服务的,需要手动将 Gateway 的插件添加到 Skywalking 启动依赖 jar 中。 skywalking相关版本信息 jdk:17skywalking&#x…...

SpringWeb

SpringWeb SpringWeb 概述 SpringWeb 是 spring 框架中的一个模块,基于 Servlet API 构建的 web 框架. springWeb 是 Spring 为 web 层开发提供的一整套完备的解决方案。 在 web 层框架历经 Strust1,WebWork,Strust2 等诸多产品的历代更…...

嵌入式刷题(day21)

MySQL和sqlite的区别 MySQL和SQLite是两种常见的关系型数据库管理系统(RDBMS),但它们在特性、使用场景和架构方面有显著的区别: 1. 架构 MySQL:是一个基于服务器的数据库系统,遵循客户端-服务器架构。MySQL服务器运行在主机上,客户端通过网络连接并发送查询。它可以并…...

OpenAI 下一代旗舰模型现身?奥尔特曼亲自辟谣“猎户座“传闻

在人工智能领域最受瞩目的ChatGPT即将迎来两周岁之际,一场关于OpenAI新旗舰模型的传闻再次引发业界热议。然而,这场喧嚣很快就被OpenAI掌门人奥尔特曼亲自澄清。 事件源于科技媒体The Verge的一则报道。据多位知情人士透露,OpenAI可能会在11…...

【C++】STL初识

【C】STL初识 文章目录 【C】STL初识前言一、STL基本概念二、STL六大组件简介三、STL三大组件四、初识STL总结 前言 本篇文章将讲到STL基本概念,STL六大组件简介,STL三大组件,初识STL。 一、STL基本概念 STL(Standard Template Library,标准…...

框架篇补充(东西多 需要重新看网课)

什么是AOP 面向切面编程 降低耦合 提高代码的复用 Spring的bean的生命周期 实例化bean 赋值 初始化bean 使用bean 销毁bean SpringMVC的执行流程 Springboot自动装配原理 实际上就是为了从spring.factories文件中 获取到对应的需要 进行自动装配的类 并生成相应的Bean…...

合约门合同全生命周期管理系统:企业合同管理的数字化转型之道

合约门合同全生命周期管理系统:企业合同管理的数字化转型之道 1. 引言 在现代企业中,合同管理已经不再是简单的文件存储和审批流程,而是企业合规性、风险管理和业务流程的关键环节之一。随着企业规模的扩大和合同数量的增加,传统…...

等保测评与风险管理:识别、评估和缓解潜在的安全威胁

在信息化时代,数据已成为企业最宝贵的资产之一,而信息安全则成为守护这份资产免受侵害的重中之重。等保测评(信息安全等级保护测评)作为保障信息系统安全的重要手段,其核心在于通过科学、规范、专业的评估手段&#xf…...

Golang Agent 可观测性的全面升级与新特性介绍

作者:张海彬(古琦) 背景 自 2024 年 6 月 26 日,ARMS 发布了针对 Golang 应用的可观测性监控功能以来,阿里云 ARMS 团队与程序语言与编译器团队一直致力于不断优化和提升该系统的各项功能,旨在为开发者提…...

SpringBoot的开篇 特点 初始化 ioc 配置文件

文章目录 前言SpringBoot发展历程SpringBoot前置准备SpringBoot特点 SpringBoot项目初始化项目启动Springboot的核心概念IOC概念介绍Bean对象通过注解扫描包 例子配置文件 前言 SpringBoot发展历程 最初,Spring框架的使用需要大量的XML配置,这使得开发…...

docker 可用镜像服务地址(2024.10.25亲测可用)

1.错误 Error response from daemon: Get “https://registry-1.docker.io/v2/” 原因:镜像服务器地址不可用。 2.可用地址 编辑daemon.json: vi /etc/docker/daemon.json内容修改如下: {"registry-mirrors": ["https://…...

【SQL实验】表的更新和简单查询

完整代码在文章末尾 在上次实验创建的educ数据库基础上,用SQL语句为student表、course表和sc表中添加以下记录 【SQL实验】数据库、表、模式的SQL语句操作_创建一个名为educ数据库,要求如下: (下面三个表中属性的数据类型需要自己设计合适-CSDN博客在这篇博文中已经…...

【C++】 string的了解及使用

标准库中的string类 在使用string类时&#xff0c;必须包含#include头文件以及using namespace std; string类的常用接口说明 C中string为我们提供了丰富的接口来供我们使用 – string接口文档 这里我们只介绍一些常见的接口 string类对象的常见构造 #include <iostrea…...

【K8S】kubernetes-dashboard.yaml

https://raw.githubusercontent.com/kubernetes/dashboard/v3.0.0-alpha0/charts/kubernetes-dashboard.yaml 以下链接的内容&#xff1a; 由于国内访问不了&#xff0c;找到一些方法下载了这个文件内容&#xff0c; 部署是mages 对象的镜像 WEB docker.io/kubernetesui/dash…...

远程root用户访问服务器中的MySQL8

一、Ubuntu下的MySQL8安装 在Ubuntu系统中安装MySQL 8.0可以通过以下步骤进行1. 更新包管理工具的仓库列表&#xff1a; sudo apt update 2. 安装MySQL 8.0&#xff0c;root用户默认没有密码&#xff1a; sudo apt install mysql-server sudo apt install mysql-client 【…...

解释一下 Java 中的静态变量(Static Variable)和静态方法(Static Method)?

今天来和大家深入探讨一下 Java 中的静态变量和静态方法&#xff0c;并通过一些具体的例子来理解它们在实际开发中的应用。 静态变量&#xff08;Static Variable&#xff09; 静态变量&#xff0c;也称为类变量&#xff0c;是在类的层次上共享的变量。这意味着无论创建了多少…...

【Linux】————磁盘与文件系统

作者主页&#xff1a; 作者主页 本篇博客专栏&#xff1a;Linux 创作时间 &#xff1a;2024年10月17日 一、磁盘的物理结构 磁盘的物理结构如图所示&#xff1a; 其中具体的物理存储结构如下&#xff1a; 磁盘中存储的基本单位为扇区&#xff0c;一个扇区的大小一般为512字…...

平衡控制——直立环——速度环

目录 平衡控制原理 平衡控制模型 平衡控制中基于模型设计与自动代码生成技术 速度环应用原理 速度控制模型 平衡控制原理 下图是一个单摆模型&#xff0c;对其进行受力分析如图。 在重力作用下,单摆受到和角度成正比,运动方向相反的回复力。而且在空气中运动的单摆,由于受…...

面试简要介绍hashMap

jdk8之前&#xff0c;hashmap采用的数据结构是数组链表&#xff0c;jdk8之后采用的数据结构是数组链表/红黑树。hashmap的数据以键值对的形式存在&#xff0c;如果两个元素的hash值相同&#xff0c;就会发生hash冲突&#xff0c;被放到同一个链表上--->如何解决hash冲突---&…...

大话软工笔记—需求分析概述

需求分析&#xff0c;就是要对需求调研收集到的资料信息逐个地进行拆分、研究&#xff0c;从大量的不确定“需求”中确定出哪些需求最终要转换为确定的“功能需求”。 需求分析的作用非常重要&#xff0c;后续设计的依据主要来自于需求分析的成果&#xff0c;包括: 项目的目的…...

大模型多显卡多服务器并行计算方法与实践指南

一、分布式训练概述 大规模语言模型的训练通常需要分布式计算技术,以解决单机资源不足的问题。分布式训练主要分为两种模式: 数据并行:将数据分片到不同设备,每个设备拥有完整的模型副本 模型并行:将模型分割到不同设备,每个设备处理部分模型计算 现代大模型训练通常结合…...

【Android】Android 开发 ADB 常用指令

查看当前连接的设备 adb devices 连接设备 adb connect 设备IP 断开已连接的设备 adb disconnect 设备IP 安装应用 adb install 安装包的路径 卸载应用 adb uninstall 应用包名 查看已安装的应用包名 adb shell pm list packages 查看已安装的第三方应用包名 adb shell pm list…...

毫米波雷达基础理论(3D+4D)

3D、4D毫米波雷达基础知识及厂商选型 PreView : https://mp.weixin.qq.com/s/bQkju4r6med7I3TBGJI_bQ 1. FMCW毫米波雷达基础知识 主要参考博文&#xff1a; 一文入门汽车毫米波雷达基本原理 &#xff1a;https://mp.weixin.qq.com/s/_EN7A5lKcz2Eh8dLnjE19w 毫米波雷达基础…...

安卓基础(Java 和 Gradle 版本)

1. 设置项目的 JDK 版本 方法1&#xff1a;通过 Project Structure File → Project Structure... (或按 CtrlAltShiftS) 左侧选择 SDK Location 在 Gradle Settings 部分&#xff0c;设置 Gradle JDK 方法2&#xff1a;通过 Settings File → Settings... (或 CtrlAltS)…...

用鸿蒙HarmonyOS5实现中国象棋小游戏的过程

下面是一个基于鸿蒙OS (HarmonyOS) 的中国象棋小游戏的实现代码。这个实现使用Java语言和鸿蒙的Ability框架。 1. 项目结构 /src/main/java/com/example/chinesechess/├── MainAbilitySlice.java // 主界面逻辑├── ChessView.java // 游戏视图和逻辑├──…...

渗透实战PortSwigger靶场:lab13存储型DOM XSS详解

进来是需要留言的&#xff0c;先用做简单的 html 标签测试 发现面的</h1>不见了 数据包中找到了一个loadCommentsWithVulnerableEscapeHtml.js 他是把用户输入的<>进行 html 编码&#xff0c;输入的<>当成字符串处理回显到页面中&#xff0c;看来只是把用户输…...

二维FDTD算法仿真

二维FDTD算法仿真&#xff0c;并带完全匹配层&#xff0c;输入波形为高斯波、平面波 FDTD_二维/FDTD.zip , 6075 FDTD_二维/FDTD_31.m , 1029 FDTD_二维/FDTD_32.m , 2806 FDTD_二维/FDTD_33.m , 3782 FDTD_二维/FDTD_34.m , 4182 FDTD_二维/FDTD_35.m , 4793...

java高级——高阶函数、如何定义一个函数式接口类似stream流的filter

java高级——高阶函数、stream流 前情提要文章介绍一、函数伊始1.1 合格的函数1.2 有形的函数2. 函数对象2.1 函数对象——行为参数化2.2 函数对象——延迟执行 二、 函数编程语法1. 函数对象表现形式1.1 Lambda表达式1.2 方法引用&#xff08;Math::max&#xff09; 2 函数接口…...

鸿蒙HarmonyOS 5军旗小游戏实现指南

1. 项目概述 本军旗小游戏基于鸿蒙HarmonyOS 5开发&#xff0c;采用DevEco Studio实现&#xff0c;包含完整的游戏逻辑和UI界面。 2. 项目结构 /src/main/java/com/example/militarychess/├── MainAbilitySlice.java // 主界面├── GameView.java // 游戏核…...