当前位置: 首页 > news >正文

pytorch使用SVM实现文本分类

 人工智能例子汇总:AI常见的算法和例子-CSDN博客 

完整代码:

import torch
import torch.nn as nn
import torch.optim as optim
import jieba
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn import metrics# 1. 数据准备(中文文本)
texts = ["今天的足球比赛非常激烈,球队表现出色,最终赢得了比赛。","NBA比赛今天开打,球员们的表现非常精彩,球迷们热情高涨。","张艺谋的新电影上映了,票房成绩非常好,观众反响热烈。","娱乐圈最近又出了一些新闻,明星们的私生活成了大家讨论的焦点。","昨晚的篮球赛真是太精彩了,球员们的进攻和防守都非常强硬。","李宇春在最新的音乐会上演出了她的新歌,现场观众反应热烈。","今年的世界杯比赛激烈异常,球队之间的竞争越来越激烈。","最近的综艺节目非常火,明星嘉宾的表现让观众们大笑不已。"
]# 标签:0表示体育,1表示娱乐
labels = [0, 0, 1, 1, 0, 1, 0, 1]# 2. 数据预处理:中文分词和 TF-IDF 特征提取
def jieba_cut(text):return " ".join(jieba.cut(text))texts_cut = [jieba_cut(text) for text in texts]vectorizer = TfidfVectorizer(max_features=10000)
X_tfidf = vectorizer.fit_transform(texts_cut).toarray()
y = np.array(labels)# 3. 数据集分割为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X_tfidf, y, test_size=0.2, random_state=42)# 4. PyTorch 数据加载
class NewsGroupDataset(torch.utils.data.Dataset):def __init__(self, features, labels):self.features = torch.tensor(features, dtype=torch.float32)self.labels = torch.tensor(labels, dtype=torch.long)def __len__(self):return len(self.features)def __getitem__(self, idx):return self.features[idx], self.labels[idx]train_dataset = NewsGroupDataset(X_train, y_train)
test_dataset = NewsGroupDataset(X_test, y_test)train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=2, shuffle=False)# 5. 定义 SVM 模型(使用线性层)
class SVM(nn.Module):def __init__(self, input_dim, output_dim):super(SVM, self).__init__()self.fc = nn.Linear(input_dim, output_dim)def forward(self, x):return self.fc(x)# 6. 获取特征数并初始化模型
input_dim = X_tfidf.shape[1]  # 自动获取特征数
model = SVM(input_dim=input_dim, output_dim=2)  # 使用特征数量设置输入维度
criterion = nn.CrossEntropyLoss()  # 损失函数
optimizer = optim.SGD(model.parameters(), lr=0.01)  # 优化器,调整学习率# 7. 训练模型
num_epochs = 50  # 增加训练周期for epoch in range(num_epochs):model.train()running_loss = 0.0correct = 0total = 0for inputs, labels in train_loader:optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_loader)}, Accuracy: {100 * correct / total}%")# 8. 测试模型
model.eval()
correct = 0
total = 0with torch.no_grad():for inputs, labels in test_loader:outputs = model(inputs)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f"Test Accuracy: {100 * correct / total}%")# 9. 输出性能指标
y_pred = []
y_true = []with torch.no_grad():for inputs, labels in test_loader:outputs = model(inputs)_, predicted = torch.max(outputs.data, 1)y_pred.extend(predicted.numpy())y_true.extend(labels.numpy())print(metrics.classification_report(y_true, y_pred))# 10. 测试新样本
def predict(text, model, vectorizer):# 1. 进行分词text_cut = jieba_cut(text)# 2. 将文本转为 TF-IDF 特征向量text_tfidf = vectorizer.transform([text_cut]).toarray()# 3. 转换为 PyTorch 张量text_tensor = torch.tensor(text_tfidf, dtype=torch.float32)# 4. 模型预测model.eval()  # 设置模型为评估模式with torch.no_grad():output = model(text_tensor)_, predicted = torch.max(output.data, 1)# 5. 返回预测结果return predicted.item()# 测试一个新的中文文本
new_text = "今天的篮球比赛真是太精彩了,球员们的表现让大家都为之喝彩。"
predicted_label = predict(new_text, model, vectorizer)# 输出预测结果
if predicted_label == 0:print("预测类别: 体育")
else:print("预测类别: 娱乐")

1. 数据准备

  • 文本数据:我们定义了一个包含中文文本的列表,每条文本表示一个新闻或评论。
  • 标签:为每条文本分配了一个标签,0 代表“体育”,1 代表“娱乐”。

2. 数据预处理

  • 中文分词:使用 jieba 库对每条文本进行分词,并将分词后的结果连接成字符串。这是处理中文文本时的常见做法。
  • TF-IDF 特征提取:使用 TfidfVectorizer 将文本转化为数值特征。TF-IDF 是一种常见的文本表示方式,能够衡量单词在文档中的重要性。

3. 数据集分割

  • 使用 train_test_split 将数据分为训练集和测试集。80% 的数据用于训练,20% 用于测试。

4. PyTorch 数据加载

  • 定义 Dataset 类:创建了一个自定义的 NewsGroupDataset 类,继承自 torch.utils.data.Dataset,用于将文本特征和标签封装为 PyTorch 可用的数据集格式。
  • DataLoader:使用 DataLoader 将训练集和测试集数据进行批处理和加载。

5. 模型定义

  • 定义了一个简单的线性 SVM 模型。实际上,使用了一个线性层 (nn.Linear) 来进行分类,输入是文本的 TF-IDF 特征,输出是两个类别(体育或娱乐)。
  • 使用了 CrossEntropyLoss 作为损失函数,因为这是分类任务中常用的损失函数。
  • 优化器使用了随机梯度下降(SGD),并设置了学习率为 0.01。

6. 训练过程

  • 训练模型的过程包括:前向传播(计算输出),计算损失,反向传播(更新参数),并在每个 epoch 后输出损失和准确率。
  • 每个 batch 的训练过程中,模型会通过计算损失并进行优化,逐步提升准确率。

7. 测试和评估

  • 在测试过程中,将模型设置为评估模式 (model.eval()),并计算测试集上的准确率。通过比较预测标签与真实标签,计算正确的预测数量并输出准确率。
  • 使用 classification_report 来输出精确度、召回率和 F1 分数等更多的评估指标。

8. 预测新样本

  • 定义了一个 predict() 函数,用于预测新的文本样本的分类。
  • 预测过程包括:对文本进行分词,转化为 TF-IDF 特征,传入模型进行前向传播,最后返回模型预测的标签。

9. 输出

  • 代码的最后,会输出模型对新文本的预测结果,标明是属于体育类别还是娱乐类别。

关键技术点:

  • 中文分词:使用 jieba 对中文文本进行分词处理,这对于中文文本的处理至关重要。
  • TF-IDF:将文本转换为数值特征,便于模型处理。TF-IDF 是基于单词在文档中的出现频率及其在整个语料中的稀有度进行加权的。
  • 模型训练与评估:通过多轮训练提升模型准确度,使用测试集来评估模型的泛化能力。
  • PyTorch DataLoader:通过 DataLoader 高效地处理训练集和测试集,进行批处理和自动化管理。

相关文章:

pytorch使用SVM实现文本分类

人工智能例子汇总:AI常见的算法和例子-CSDN博客 完整代码: import torch import torch.nn as nn import torch.optim as optim import jieba import numpy as np from sklearn.model_selection import train_test_split from sklearn.feature_extract…...

安卓(android)读取手机通讯录【Android移动开发基础案例教程(第2版)黑马程序员】

一、实验目的(如果代码有错漏,可在代码地址查看) 1.熟悉内容提供者(Content Provider)的概念和作用。 2.掌握内容提供者的创建和使用方法。 4.掌握内容URI的结构和用途。 二、实验条件 1.熟悉内容提供者的工作原理。 2.掌握内容提供者访问其…...

【Qt】常用的容器

Qt提供了多个基于模板的容器类&#xff0c;这些容器类可用于存储指定类型的数据项。例如常用的字符串列表类 QStringList 可用来操作一个 QList<QString>列表。 Qt的容器类比标准模板库(standard template library&#xff0c;STL)中的容器类更轻巧、使用更安全且更易于使…...

基于UKF-IMM无迹卡尔曼滤波与交互式多模型的轨迹跟踪算法matlab仿真,对比EKF-IMM和UKF

目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.本算法原理 5.完整程序 1.程序功能描述 基于UKF-IMM无迹卡尔曼滤波与交互式多模型的轨迹跟踪算法matlab仿真,对比EKF-IMM和UKF。 2.测试软件版本以及运行结果展示 MATLAB2022A版本运行 3.核心程序 .…...

分布式事务组件Seata简介与使用,搭配Nacos统一管理服务端和客户端配置

文章目录 一. Seata简介二. 官方文档三. Seata分布式事务代码实现0. 环境简介1. 添加undo_log表2. 添加依赖3. 添加配置4. 开启Seata事务管理5. 启动演示 四. Seata Server配置Nacos1. 修改配置类型2. 创建Nacos配置 五. Seata Client配置Nacos1. 增加Seata关联Nacos的配置2. 在…...

JavaScript常用的内置构造函数

JavaScript作为一种广泛应用的编程语言&#xff0c;提供了丰富的内置构造函数&#xff0c;帮助开发者处理不同类型的数据和操作。这些内置构造函数在创建和操作对象时非常有用。本文将详细介绍JavaScript中常用的内置构造函数及其用途。 常用内置构造函数概述 1. Object Obj…...

25寒假算法刷题 | Day1 | LeetCode 240. 搜索二维矩阵 II,148. 排序链表

目录 240. 搜索二维矩阵 II题目描述题解 148. 排序链表题目描述题解 240. 搜索二维矩阵 II 点此跳转题目链接 题目描述 编写一个高效的算法来搜索 m x n 矩阵 matrix 中的一个目标值 target 。该矩阵具有以下特性&#xff1a; 每行的元素从左到右升序排列。每列的元素从上到…...

MQTT知识

MQTT协议 MQTT 是一种基于发布/订阅模式的轻量级消息传输协议&#xff0c;专门针对低带宽和不稳定网络环境的物联网应用而设计&#xff0c;可以用极少的代码为联网设备提供实时可靠的消息服务。MQTT 协议广泛应用于物联网、移动互联网、智能硬件、车联网、智慧城市、远程医疗、…...

【机器学习与数据挖掘实战】案例11:基于灰色预测和SVR的企业所得税预测分析

【作者主页】Francek Chen 【专栏介绍】 ⌈ ⌈ ⌈机器学习与数据挖掘实战 ⌋ ⌋ ⌋ 机器学习是人工智能的一个分支,专注于让计算机系统通过数据学习和改进。它利用统计和计算方法,使模型能够从数据中自动提取特征并做出预测或决策。数据挖掘则是从大型数据集中发现模式、关联…...

新一代搜索引擎,是 ES 的15倍?

Manticore Search介绍 Manticore Search 是一个使用 C 开发的高性能搜索引擎&#xff0c;创建于 2017 年&#xff0c;其前身是 Sphinx Search 。Manticore Search 充分利用了 Sphinx&#xff0c;显着改进了它的功能&#xff0c;修复了数百个错误&#xff0c;几乎完全重写了代码…...

使用 Context API 管理临时状态,避免 Redux/Zustand 的持久化陷阱

在开发 React Native 应用时&#xff0c;我们经常需要管理全局状态&#xff0c;比如用户信息、主题设置、网络状态等。而对于某些临时状态&#xff0c;例如 数据同步进行中的状态 (isSyncing)&#xff0c;我们应该选择什么方式来管理它&#xff1f; 在项目开发过程中&#xff…...

PyTorch框架——基于深度学习YOLOv8神经网络学生课堂行为检测识别系统

基于YOLOv8深度学习的学生课堂行为检测识别系统&#xff0c;其能识别三种学生课堂行为&#xff1a;names: [举手, 读书, 写字] 具体图片见如下&#xff1a; 第一步&#xff1a;YOLOv8介绍 YOLOv8 是 ultralytics 公司在 2023 年 1月 10 号开源的 YOLOv5 的下一个重大更新版本…...

word2vec 实战应用介绍

Word2Vec 是一种由 Google 在 2013 年推出的重要词嵌入模型,通过将单词映射为低维向量,实现了对自然语言处理任务的高效支持。其核心思想是利用深度学习技术,通过训练大量文本数据,将单词表示为稠密的向量形式,从而捕捉单词之间的语义和语法关系。以下是关于 Word2Vec 实战…...

C# 操作符重载对象详解

.NET学习资料 .NET学习资料 .NET学习资料 一、操作符重载的概念 在 C# 中&#xff0c;操作符重载允许我们为自定义的类或结构体定义操作符的行为。通常&#xff0c;我们熟悉的操作符&#xff0c;如加法&#xff08;&#xff09;、减法&#xff08;-&#xff09;、乘法&#…...

python学opencv|读取图像(五十四)使用cv2.blur()函数实现图像像素均值处理

【1】引言 前序学习进程中&#xff0c;对图像的操作均基于各个像素点上的BGR值不同而展开。 对于彩色图像&#xff0c;每个像素点上的BGR值为三个整数&#xff0c;因为是三通道图像&#xff1b;对于灰度图像&#xff0c;各个像素上的BGR值是一个整数&#xff0c;因为这是单通…...

CNN的各种知识点(四): 非极大值抑制(Non-Maximum Suppression, NMS)

非极大值抑制&#xff08;Non-Maximum Suppression, NMS&#xff09; 1. 非极大值抑制&#xff08;Non-Maximum Suppression, NMS&#xff09;概念&#xff1a;算法步骤&#xff1a;具体例子&#xff1a;PyTorch实现&#xff1a; 总结&#xff1a; 1. 非极大值抑制&#xff08;…...

虚幻基础16:locomotion direction

locomotion locomotion&#xff1a;角色运动系统的总称&#xff1a;包含移动、奔跑、跳跃、转向等。 locomotion direction 玩家输入 玩家输入&#xff1a;通常代表玩家想要的移动方向。 direction 可以计算当前朝向与移动方向的Δ。从而实现朝向与移动(玩家输入)方向的分…...

C++游戏开发实战:从引擎架构到物理碰撞

&#x1f4dd;个人主页&#x1f339;&#xff1a;一ge科研小菜鸡-CSDN博客 &#x1f339;&#x1f339;期待您的关注 &#x1f339;&#x1f339; 1. 引言 C 是游戏开发中最受欢迎的编程语言之一&#xff0c;因其高性能、低延迟和强大的底层控制能力&#xff0c;被广泛用于游戏…...

代理模式——C++实现

目录 1. 代理模式简介 2. 代码示例 1. 代理模式简介 代理模式是一种行为型模式。 代理模式的定义&#xff1a;由于某些原因需要给某对象提供一个代理以控制该对象的访问。这时&#xff0c;访问对象不适合或者不能直接访问引用目标对象&#xff0c;代理对象作为访问对象和目标…...

什么情况下,C#需要手动进行资源分配和释放?什么又是非托管资源?

扩展&#xff1a;如何使用C#的using语句释放资源&#xff1f;什么是IDisposable接口&#xff1f;与垃圾回收有什么关系&#xff1f;-CSDN博客 托管资源的回收有GC自动触发&#xff0c;而非托管资源需要手动释放。 在 C# 中&#xff0c;非托管资源是指那些不由 CLR&#xff08;…...

【Linux】C语言执行shell指令

在C语言中执行Shell指令 在C语言中&#xff0c;有几种方法可以执行Shell指令&#xff1a; 1. 使用system()函数 这是最简单的方法&#xff0c;包含在stdlib.h头文件中&#xff1a; #include <stdlib.h>int main() {system("ls -l"); // 执行ls -l命令retu…...

Debian系统简介

目录 Debian系统介绍 Debian版本介绍 Debian软件源介绍 软件包管理工具dpkg dpkg核心指令详解 安装软件包 卸载软件包 查询软件包状态 验证软件包完整性 手动处理依赖关系 dpkg vs apt Debian系统介绍 Debian 和 Ubuntu 都是基于 Debian内核 的 Linux 发行版&#xff…...

条件运算符

C中的三目运算符&#xff08;也称条件运算符&#xff0c;英文&#xff1a;ternary operator&#xff09;是一种简洁的条件选择语句&#xff0c;语法如下&#xff1a; 条件表达式 ? 表达式1 : 表达式2• 如果“条件表达式”为true&#xff0c;则整个表达式的结果为“表达式1”…...

家政维修平台实战20:权限设计

目录 1 获取工人信息2 搭建工人入口3 权限判断总结 目前我们已经搭建好了基础的用户体系&#xff0c;主要是分成几个表&#xff0c;用户表我们是记录用户的基础信息&#xff0c;包括手机、昵称、头像。而工人和员工各有各的表。那么就有一个问题&#xff0c;不同的角色&#xf…...

数据链路层的主要功能是什么

数据链路层&#xff08;OSI模型第2层&#xff09;的核心功能是在相邻网络节点&#xff08;如交换机、主机&#xff09;间提供可靠的数据帧传输服务&#xff0c;主要职责包括&#xff1a; &#x1f511; 核心功能详解&#xff1a; 帧封装与解封装 封装&#xff1a; 将网络层下发…...

浅谈不同二分算法的查找情况

二分算法原理比较简单&#xff0c;但是实际的算法模板却有很多&#xff0c;这一切都源于二分查找问题中的复杂情况和二分算法的边界处理&#xff0c;以下是博主对一些二分算法查找的情况分析。 需要说明的是&#xff0c;以下二分算法都是基于有序序列为升序有序的情况&#xf…...

Android Bitmap治理全解析:从加载优化到泄漏防控的全生命周期管理

引言 Bitmap&#xff08;位图&#xff09;是Android应用内存占用的“头号杀手”。一张1080P&#xff08;1920x1080&#xff09;的图片以ARGB_8888格式加载时&#xff0c;内存占用高达8MB&#xff08;192010804字节&#xff09;。据统计&#xff0c;超过60%的应用OOM崩溃与Bitm…...

智能分布式爬虫的数据处理流水线优化:基于深度强化学习的数据质量控制

在数字化浪潮席卷全球的今天&#xff0c;数据已成为企业和研究机构的核心资产。智能分布式爬虫作为高效的数据采集工具&#xff0c;在大规模数据获取中发挥着关键作用。然而&#xff0c;传统的数据处理流水线在面对复杂多变的网络环境和海量异构数据时&#xff0c;常出现数据质…...

管理学院权限管理系统开发总结

文章目录 &#x1f393; 管理学院权限管理系统开发总结 - 现代化Web应用实践之路&#x1f4dd; 项目概述&#x1f3d7;️ 技术架构设计后端技术栈前端技术栈 &#x1f4a1; 核心功能特性1. 用户管理模块2. 权限管理系统3. 统计报表功能4. 用户体验优化 &#x1f5c4;️ 数据库设…...

react菜单,动态绑定点击事件,菜单分离出去单独的js文件,Ant框架

1、菜单文件treeTop.js // 顶部菜单 import { AppstoreOutlined, SettingOutlined } from ant-design/icons; // 定义菜单项数据 const treeTop [{label: Docker管理,key: 1,icon: <AppstoreOutlined />,url:"/docker/index"},{label: 权限管理,key: 2,icon:…...