pytorch使用SVM实现文本分类
人工智能例子汇总:AI常见的算法和例子-CSDN博客
完整代码:
import torch
import torch.nn as nn
import torch.optim as optim
import jieba
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn import metrics# 1. 数据准备(中文文本)
texts = ["今天的足球比赛非常激烈,球队表现出色,最终赢得了比赛。","NBA比赛今天开打,球员们的表现非常精彩,球迷们热情高涨。","张艺谋的新电影上映了,票房成绩非常好,观众反响热烈。","娱乐圈最近又出了一些新闻,明星们的私生活成了大家讨论的焦点。","昨晚的篮球赛真是太精彩了,球员们的进攻和防守都非常强硬。","李宇春在最新的音乐会上演出了她的新歌,现场观众反应热烈。","今年的世界杯比赛激烈异常,球队之间的竞争越来越激烈。","最近的综艺节目非常火,明星嘉宾的表现让观众们大笑不已。"
]# 标签:0表示体育,1表示娱乐
labels = [0, 0, 1, 1, 0, 1, 0, 1]# 2. 数据预处理:中文分词和 TF-IDF 特征提取
def jieba_cut(text):return " ".join(jieba.cut(text))texts_cut = [jieba_cut(text) for text in texts]vectorizer = TfidfVectorizer(max_features=10000)
X_tfidf = vectorizer.fit_transform(texts_cut).toarray()
y = np.array(labels)# 3. 数据集分割为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X_tfidf, y, test_size=0.2, random_state=42)# 4. PyTorch 数据加载
class NewsGroupDataset(torch.utils.data.Dataset):def __init__(self, features, labels):self.features = torch.tensor(features, dtype=torch.float32)self.labels = torch.tensor(labels, dtype=torch.long)def __len__(self):return len(self.features)def __getitem__(self, idx):return self.features[idx], self.labels[idx]train_dataset = NewsGroupDataset(X_train, y_train)
test_dataset = NewsGroupDataset(X_test, y_test)train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=2, shuffle=False)# 5. 定义 SVM 模型(使用线性层)
class SVM(nn.Module):def __init__(self, input_dim, output_dim):super(SVM, self).__init__()self.fc = nn.Linear(input_dim, output_dim)def forward(self, x):return self.fc(x)# 6. 获取特征数并初始化模型
input_dim = X_tfidf.shape[1] # 自动获取特征数
model = SVM(input_dim=input_dim, output_dim=2) # 使用特征数量设置输入维度
criterion = nn.CrossEntropyLoss() # 损失函数
optimizer = optim.SGD(model.parameters(), lr=0.01) # 优化器,调整学习率# 7. 训练模型
num_epochs = 50 # 增加训练周期for epoch in range(num_epochs):model.train()running_loss = 0.0correct = 0total = 0for inputs, labels in train_loader:optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_loader)}, Accuracy: {100 * correct / total}%")# 8. 测试模型
model.eval()
correct = 0
total = 0with torch.no_grad():for inputs, labels in test_loader:outputs = model(inputs)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f"Test Accuracy: {100 * correct / total}%")# 9. 输出性能指标
y_pred = []
y_true = []with torch.no_grad():for inputs, labels in test_loader:outputs = model(inputs)_, predicted = torch.max(outputs.data, 1)y_pred.extend(predicted.numpy())y_true.extend(labels.numpy())print(metrics.classification_report(y_true, y_pred))# 10. 测试新样本
def predict(text, model, vectorizer):# 1. 进行分词text_cut = jieba_cut(text)# 2. 将文本转为 TF-IDF 特征向量text_tfidf = vectorizer.transform([text_cut]).toarray()# 3. 转换为 PyTorch 张量text_tensor = torch.tensor(text_tfidf, dtype=torch.float32)# 4. 模型预测model.eval() # 设置模型为评估模式with torch.no_grad():output = model(text_tensor)_, predicted = torch.max(output.data, 1)# 5. 返回预测结果return predicted.item()# 测试一个新的中文文本
new_text = "今天的篮球比赛真是太精彩了,球员们的表现让大家都为之喝彩。"
predicted_label = predict(new_text, model, vectorizer)# 输出预测结果
if predicted_label == 0:print("预测类别: 体育")
else:print("预测类别: 娱乐")
1. 数据准备
- 文本数据:我们定义了一个包含中文文本的列表,每条文本表示一个新闻或评论。
- 标签:为每条文本分配了一个标签,0 代表“体育”,1 代表“娱乐”。
2. 数据预处理
- 中文分词:使用
jieba
库对每条文本进行分词,并将分词后的结果连接成字符串。这是处理中文文本时的常见做法。 - TF-IDF 特征提取:使用
TfidfVectorizer
将文本转化为数值特征。TF-IDF 是一种常见的文本表示方式,能够衡量单词在文档中的重要性。
3. 数据集分割
- 使用
train_test_split
将数据分为训练集和测试集。80% 的数据用于训练,20% 用于测试。
4. PyTorch 数据加载
- 定义 Dataset 类:创建了一个自定义的
NewsGroupDataset
类,继承自torch.utils.data.Dataset
,用于将文本特征和标签封装为 PyTorch 可用的数据集格式。 - DataLoader:使用
DataLoader
将训练集和测试集数据进行批处理和加载。
5. 模型定义
- 定义了一个简单的线性 SVM 模型。实际上,使用了一个线性层 (
nn.Linear
) 来进行分类,输入是文本的 TF-IDF 特征,输出是两个类别(体育或娱乐)。 - 使用了
CrossEntropyLoss
作为损失函数,因为这是分类任务中常用的损失函数。 - 优化器使用了随机梯度下降(SGD),并设置了学习率为 0.01。
6. 训练过程
- 训练模型的过程包括:前向传播(计算输出),计算损失,反向传播(更新参数),并在每个 epoch 后输出损失和准确率。
- 每个 batch 的训练过程中,模型会通过计算损失并进行优化,逐步提升准确率。
7. 测试和评估
- 在测试过程中,将模型设置为评估模式 (
model.eval()
),并计算测试集上的准确率。通过比较预测标签与真实标签,计算正确的预测数量并输出准确率。 - 使用
classification_report
来输出精确度、召回率和 F1 分数等更多的评估指标。
8. 预测新样本
- 定义了一个
predict()
函数,用于预测新的文本样本的分类。 - 预测过程包括:对文本进行分词,转化为 TF-IDF 特征,传入模型进行前向传播,最后返回模型预测的标签。
9. 输出
- 代码的最后,会输出模型对新文本的预测结果,标明是属于体育类别还是娱乐类别。
关键技术点:
- 中文分词:使用
jieba
对中文文本进行分词处理,这对于中文文本的处理至关重要。 - TF-IDF:将文本转换为数值特征,便于模型处理。TF-IDF 是基于单词在文档中的出现频率及其在整个语料中的稀有度进行加权的。
- 模型训练与评估:通过多轮训练提升模型准确度,使用测试集来评估模型的泛化能力。
- PyTorch DataLoader:通过
DataLoader
高效地处理训练集和测试集,进行批处理和自动化管理。
相关文章:
pytorch使用SVM实现文本分类
人工智能例子汇总:AI常见的算法和例子-CSDN博客 完整代码: import torch import torch.nn as nn import torch.optim as optim import jieba import numpy as np from sklearn.model_selection import train_test_split from sklearn.feature_extract…...
安卓(android)读取手机通讯录【Android移动开发基础案例教程(第2版)黑马程序员】
一、实验目的(如果代码有错漏,可在代码地址查看) 1.熟悉内容提供者(Content Provider)的概念和作用。 2.掌握内容提供者的创建和使用方法。 4.掌握内容URI的结构和用途。 二、实验条件 1.熟悉内容提供者的工作原理。 2.掌握内容提供者访问其…...
【Qt】常用的容器
Qt提供了多个基于模板的容器类,这些容器类可用于存储指定类型的数据项。例如常用的字符串列表类 QStringList 可用来操作一个 QList<QString>列表。 Qt的容器类比标准模板库(standard template library,STL)中的容器类更轻巧、使用更安全且更易于使…...

基于UKF-IMM无迹卡尔曼滤波与交互式多模型的轨迹跟踪算法matlab仿真,对比EKF-IMM和UKF
目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.本算法原理 5.完整程序 1.程序功能描述 基于UKF-IMM无迹卡尔曼滤波与交互式多模型的轨迹跟踪算法matlab仿真,对比EKF-IMM和UKF。 2.测试软件版本以及运行结果展示 MATLAB2022A版本运行 3.核心程序 .…...

分布式事务组件Seata简介与使用,搭配Nacos统一管理服务端和客户端配置
文章目录 一. Seata简介二. 官方文档三. Seata分布式事务代码实现0. 环境简介1. 添加undo_log表2. 添加依赖3. 添加配置4. 开启Seata事务管理5. 启动演示 四. Seata Server配置Nacos1. 修改配置类型2. 创建Nacos配置 五. Seata Client配置Nacos1. 增加Seata关联Nacos的配置2. 在…...
JavaScript常用的内置构造函数
JavaScript作为一种广泛应用的编程语言,提供了丰富的内置构造函数,帮助开发者处理不同类型的数据和操作。这些内置构造函数在创建和操作对象时非常有用。本文将详细介绍JavaScript中常用的内置构造函数及其用途。 常用内置构造函数概述 1. Object Obj…...

25寒假算法刷题 | Day1 | LeetCode 240. 搜索二维矩阵 II,148. 排序链表
目录 240. 搜索二维矩阵 II题目描述题解 148. 排序链表题目描述题解 240. 搜索二维矩阵 II 点此跳转题目链接 题目描述 编写一个高效的算法来搜索 m x n 矩阵 matrix 中的一个目标值 target 。该矩阵具有以下特性: 每行的元素从左到右升序排列。每列的元素从上到…...

MQTT知识
MQTT协议 MQTT 是一种基于发布/订阅模式的轻量级消息传输协议,专门针对低带宽和不稳定网络环境的物联网应用而设计,可以用极少的代码为联网设备提供实时可靠的消息服务。MQTT 协议广泛应用于物联网、移动互联网、智能硬件、车联网、智慧城市、远程医疗、…...

【机器学习与数据挖掘实战】案例11:基于灰色预测和SVR的企业所得税预测分析
【作者主页】Francek Chen 【专栏介绍】 ⌈ ⌈ ⌈机器学习与数据挖掘实战 ⌋ ⌋ ⌋ 机器学习是人工智能的一个分支,专注于让计算机系统通过数据学习和改进。它利用统计和计算方法,使模型能够从数据中自动提取特征并做出预测或决策。数据挖掘则是从大型数据集中发现模式、关联…...

新一代搜索引擎,是 ES 的15倍?
Manticore Search介绍 Manticore Search 是一个使用 C 开发的高性能搜索引擎,创建于 2017 年,其前身是 Sphinx Search 。Manticore Search 充分利用了 Sphinx,显着改进了它的功能,修复了数百个错误,几乎完全重写了代码…...
使用 Context API 管理临时状态,避免 Redux/Zustand 的持久化陷阱
在开发 React Native 应用时,我们经常需要管理全局状态,比如用户信息、主题设置、网络状态等。而对于某些临时状态,例如 数据同步进行中的状态 (isSyncing),我们应该选择什么方式来管理它? 在项目开发过程中ÿ…...

PyTorch框架——基于深度学习YOLOv8神经网络学生课堂行为检测识别系统
基于YOLOv8深度学习的学生课堂行为检测识别系统,其能识别三种学生课堂行为:names: [举手, 读书, 写字] 具体图片见如下: 第一步:YOLOv8介绍 YOLOv8 是 ultralytics 公司在 2023 年 1月 10 号开源的 YOLOv5 的下一个重大更新版本…...

word2vec 实战应用介绍
Word2Vec 是一种由 Google 在 2013 年推出的重要词嵌入模型,通过将单词映射为低维向量,实现了对自然语言处理任务的高效支持。其核心思想是利用深度学习技术,通过训练大量文本数据,将单词表示为稠密的向量形式,从而捕捉单词之间的语义和语法关系。以下是关于 Word2Vec 实战…...
C# 操作符重载对象详解
.NET学习资料 .NET学习资料 .NET学习资料 一、操作符重载的概念 在 C# 中,操作符重载允许我们为自定义的类或结构体定义操作符的行为。通常,我们熟悉的操作符,如加法()、减法(-)、乘法&#…...

python学opencv|读取图像(五十四)使用cv2.blur()函数实现图像像素均值处理
【1】引言 前序学习进程中,对图像的操作均基于各个像素点上的BGR值不同而展开。 对于彩色图像,每个像素点上的BGR值为三个整数,因为是三通道图像;对于灰度图像,各个像素上的BGR值是一个整数,因为这是单通…...
CNN的各种知识点(四): 非极大值抑制(Non-Maximum Suppression, NMS)
非极大值抑制(Non-Maximum Suppression, NMS) 1. 非极大值抑制(Non-Maximum Suppression, NMS)概念:算法步骤:具体例子:PyTorch实现: 总结: 1. 非极大值抑制(…...
虚幻基础16:locomotion direction
locomotion locomotion:角色运动系统的总称:包含移动、奔跑、跳跃、转向等。 locomotion direction 玩家输入 玩家输入:通常代表玩家想要的移动方向。 direction 可以计算当前朝向与移动方向的Δ。从而实现朝向与移动(玩家输入)方向的分…...

C++游戏开发实战:从引擎架构到物理碰撞
📝个人主页🌹:一ge科研小菜鸡-CSDN博客 🌹🌹期待您的关注 🌹🌹 1. 引言 C 是游戏开发中最受欢迎的编程语言之一,因其高性能、低延迟和强大的底层控制能力,被广泛用于游戏…...

代理模式——C++实现
目录 1. 代理模式简介 2. 代码示例 1. 代理模式简介 代理模式是一种行为型模式。 代理模式的定义:由于某些原因需要给某对象提供一个代理以控制该对象的访问。这时,访问对象不适合或者不能直接访问引用目标对象,代理对象作为访问对象和目标…...
什么情况下,C#需要手动进行资源分配和释放?什么又是非托管资源?
扩展:如何使用C#的using语句释放资源?什么是IDisposable接口?与垃圾回收有什么关系?-CSDN博客 托管资源的回收有GC自动触发,而非托管资源需要手动释放。 在 C# 中,非托管资源是指那些不由 CLR(…...
ES6从入门到精通:前言
ES6简介 ES6(ECMAScript 2015)是JavaScript语言的重大更新,引入了许多新特性,包括语法糖、新数据类型、模块化支持等,显著提升了开发效率和代码可维护性。 核心知识点概览 变量声明 let 和 const 取代 var…...
从零实现富文本编辑器#5-编辑器选区模型的状态结构表达
先前我们总结了浏览器选区模型的交互策略,并且实现了基本的选区操作,还调研了自绘选区的实现。那么相对的,我们还需要设计编辑器的选区表达,也可以称为模型选区。编辑器中应用变更时的操作范围,就是以模型选区为基准来…...

python/java环境配置
环境变量放一起 python: 1.首先下载Python Python下载地址:Download Python | Python.org downloads ---windows -- 64 2.安装Python 下面两个,然后自定义,全选 可以把前4个选上 3.环境配置 1)搜高级系统设置 2…...

(二)TensorRT-LLM | 模型导出(v0.20.0rc3)
0. 概述 上一节 对安装和使用有个基本介绍。根据这个 issue 的描述,后续 TensorRT-LLM 团队可能更专注于更新和维护 pytorch backend。但 tensorrt backend 作为先前一直开发的工作,其中包含了大量可以学习的地方。本文主要看看它导出模型的部分&#x…...
渲染学进阶内容——模型
最近在写模组的时候发现渲染器里面离不开模型的定义,在渲染的第二篇文章中简单的讲解了一下关于模型部分的内容,其实不管是方块还是方块实体,都离不开模型的内容 🧱 一、CubeListBuilder 功能解析 CubeListBuilder 是 Minecraft Java 版模型系统的核心构建器,用于动态创…...

Nuxt.js 中的路由配置详解
Nuxt.js 通过其内置的路由系统简化了应用的路由配置,使得开发者可以轻松地管理页面导航和 URL 结构。路由配置主要涉及页面组件的组织、动态路由的设置以及路由元信息的配置。 自动路由生成 Nuxt.js 会根据 pages 目录下的文件结构自动生成路由配置。每个文件都会对…...

论文浅尝 | 基于判别指令微调生成式大语言模型的知识图谱补全方法(ISWC2024)
笔记整理:刘治强,浙江大学硕士生,研究方向为知识图谱表示学习,大语言模型 论文链接:http://arxiv.org/abs/2407.16127 发表会议:ISWC 2024 1. 动机 传统的知识图谱补全(KGC)模型通过…...
【AI学习】三、AI算法中的向量
在人工智能(AI)算法中,向量(Vector)是一种将现实世界中的数据(如图像、文本、音频等)转化为计算机可处理的数值型特征表示的工具。它是连接人类认知(如语义、视觉特征)与…...
Python如何给视频添加音频和字幕
在Python中,给视频添加音频和字幕可以使用电影文件处理库MoviePy和字幕处理库Subtitles。下面将详细介绍如何使用这些库来实现视频的音频和字幕添加,包括必要的代码示例和详细解释。 环境准备 在开始之前,需要安装以下Python库:…...
是否存在路径(FIFOBB算法)
题目描述 一个具有 n 个顶点e条边的无向图,该图顶点的编号依次为0到n-1且不存在顶点与自身相连的边。请使用FIFOBB算法编写程序,确定是否存在从顶点 source到顶点 destination的路径。 输入 第一行两个整数,分别表示n 和 e 的值(1…...