pytorch使用SVM实现文本分类
人工智能例子汇总:AI常见的算法和例子-CSDN博客
完整代码:
import torch
import torch.nn as nn
import torch.optim as optim
import jieba
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn import metrics# 1. 数据准备(中文文本)
texts = ["今天的足球比赛非常激烈,球队表现出色,最终赢得了比赛。","NBA比赛今天开打,球员们的表现非常精彩,球迷们热情高涨。","张艺谋的新电影上映了,票房成绩非常好,观众反响热烈。","娱乐圈最近又出了一些新闻,明星们的私生活成了大家讨论的焦点。","昨晚的篮球赛真是太精彩了,球员们的进攻和防守都非常强硬。","李宇春在最新的音乐会上演出了她的新歌,现场观众反应热烈。","今年的世界杯比赛激烈异常,球队之间的竞争越来越激烈。","最近的综艺节目非常火,明星嘉宾的表现让观众们大笑不已。"
]# 标签:0表示体育,1表示娱乐
labels = [0, 0, 1, 1, 0, 1, 0, 1]# 2. 数据预处理:中文分词和 TF-IDF 特征提取
def jieba_cut(text):return " ".join(jieba.cut(text))texts_cut = [jieba_cut(text) for text in texts]vectorizer = TfidfVectorizer(max_features=10000)
X_tfidf = vectorizer.fit_transform(texts_cut).toarray()
y = np.array(labels)# 3. 数据集分割为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X_tfidf, y, test_size=0.2, random_state=42)# 4. PyTorch 数据加载
class NewsGroupDataset(torch.utils.data.Dataset):def __init__(self, features, labels):self.features = torch.tensor(features, dtype=torch.float32)self.labels = torch.tensor(labels, dtype=torch.long)def __len__(self):return len(self.features)def __getitem__(self, idx):return self.features[idx], self.labels[idx]train_dataset = NewsGroupDataset(X_train, y_train)
test_dataset = NewsGroupDataset(X_test, y_test)train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=2, shuffle=False)# 5. 定义 SVM 模型(使用线性层)
class SVM(nn.Module):def __init__(self, input_dim, output_dim):super(SVM, self).__init__()self.fc = nn.Linear(input_dim, output_dim)def forward(self, x):return self.fc(x)# 6. 获取特征数并初始化模型
input_dim = X_tfidf.shape[1] # 自动获取特征数
model = SVM(input_dim=input_dim, output_dim=2) # 使用特征数量设置输入维度
criterion = nn.CrossEntropyLoss() # 损失函数
optimizer = optim.SGD(model.parameters(), lr=0.01) # 优化器,调整学习率# 7. 训练模型
num_epochs = 50 # 增加训练周期for epoch in range(num_epochs):model.train()running_loss = 0.0correct = 0total = 0for inputs, labels in train_loader:optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_loader)}, Accuracy: {100 * correct / total}%")# 8. 测试模型
model.eval()
correct = 0
total = 0with torch.no_grad():for inputs, labels in test_loader:outputs = model(inputs)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f"Test Accuracy: {100 * correct / total}%")# 9. 输出性能指标
y_pred = []
y_true = []with torch.no_grad():for inputs, labels in test_loader:outputs = model(inputs)_, predicted = torch.max(outputs.data, 1)y_pred.extend(predicted.numpy())y_true.extend(labels.numpy())print(metrics.classification_report(y_true, y_pred))# 10. 测试新样本
def predict(text, model, vectorizer):# 1. 进行分词text_cut = jieba_cut(text)# 2. 将文本转为 TF-IDF 特征向量text_tfidf = vectorizer.transform([text_cut]).toarray()# 3. 转换为 PyTorch 张量text_tensor = torch.tensor(text_tfidf, dtype=torch.float32)# 4. 模型预测model.eval() # 设置模型为评估模式with torch.no_grad():output = model(text_tensor)_, predicted = torch.max(output.data, 1)# 5. 返回预测结果return predicted.item()# 测试一个新的中文文本
new_text = "今天的篮球比赛真是太精彩了,球员们的表现让大家都为之喝彩。"
predicted_label = predict(new_text, model, vectorizer)# 输出预测结果
if predicted_label == 0:print("预测类别: 体育")
else:print("预测类别: 娱乐")
1. 数据准备
- 文本数据:我们定义了一个包含中文文本的列表,每条文本表示一个新闻或评论。
- 标签:为每条文本分配了一个标签,0 代表“体育”,1 代表“娱乐”。
2. 数据预处理
- 中文分词:使用
jieba库对每条文本进行分词,并将分词后的结果连接成字符串。这是处理中文文本时的常见做法。 - TF-IDF 特征提取:使用
TfidfVectorizer将文本转化为数值特征。TF-IDF 是一种常见的文本表示方式,能够衡量单词在文档中的重要性。
3. 数据集分割
- 使用
train_test_split将数据分为训练集和测试集。80% 的数据用于训练,20% 用于测试。
4. PyTorch 数据加载
- 定义 Dataset 类:创建了一个自定义的
NewsGroupDataset类,继承自torch.utils.data.Dataset,用于将文本特征和标签封装为 PyTorch 可用的数据集格式。 - DataLoader:使用
DataLoader将训练集和测试集数据进行批处理和加载。
5. 模型定义
- 定义了一个简单的线性 SVM 模型。实际上,使用了一个线性层 (
nn.Linear) 来进行分类,输入是文本的 TF-IDF 特征,输出是两个类别(体育或娱乐)。 - 使用了
CrossEntropyLoss作为损失函数,因为这是分类任务中常用的损失函数。 - 优化器使用了随机梯度下降(SGD),并设置了学习率为 0.01。
6. 训练过程
- 训练模型的过程包括:前向传播(计算输出),计算损失,反向传播(更新参数),并在每个 epoch 后输出损失和准确率。
- 每个 batch 的训练过程中,模型会通过计算损失并进行优化,逐步提升准确率。
7. 测试和评估
- 在测试过程中,将模型设置为评估模式 (
model.eval()),并计算测试集上的准确率。通过比较预测标签与真实标签,计算正确的预测数量并输出准确率。 - 使用
classification_report来输出精确度、召回率和 F1 分数等更多的评估指标。
8. 预测新样本
- 定义了一个
predict()函数,用于预测新的文本样本的分类。 - 预测过程包括:对文本进行分词,转化为 TF-IDF 特征,传入模型进行前向传播,最后返回模型预测的标签。
9. 输出
- 代码的最后,会输出模型对新文本的预测结果,标明是属于体育类别还是娱乐类别。
关键技术点:
- 中文分词:使用
jieba对中文文本进行分词处理,这对于中文文本的处理至关重要。 - TF-IDF:将文本转换为数值特征,便于模型处理。TF-IDF 是基于单词在文档中的出现频率及其在整个语料中的稀有度进行加权的。
- 模型训练与评估:通过多轮训练提升模型准确度,使用测试集来评估模型的泛化能力。
- PyTorch DataLoader:通过
DataLoader高效地处理训练集和测试集,进行批处理和自动化管理。
相关文章:
pytorch使用SVM实现文本分类
人工智能例子汇总:AI常见的算法和例子-CSDN博客 完整代码: import torch import torch.nn as nn import torch.optim as optim import jieba import numpy as np from sklearn.model_selection import train_test_split from sklearn.feature_extract…...
一文速览DeepSeek-R1的本地部署——可联网、可实现本地知识库问答:包括671B满血版和各个蒸馏版的部署
前言 自从deepseek R1发布之后「详见《一文速览DeepSeek R1:如何通过纯RL训练大模型的推理能力以比肩甚至超越OpenAI o1(含Kimi K1.5的解读)》」,deepseek便爆火 爆火以后便应了“人红是非多”那句话,不但遭受各种大规模攻击,即便…...
Kubernetes学习之包管理工具(Helm)
一、基础知识 1.如果我们需要开发微服务架构的应用,组成应用的服务可能很多,使用原始的组织和管理方式就会非常臃肿和繁琐以及较难管理,此时我们需要一个更高层次的工具将这些配置组织起来。 2.helm架构: chart:一个应用的信息集合…...
2024美团春招硬件开发笔试真题及答案解析
目录 一、选择题 1、在 Linux,有一个名为 file 的文件,内容如下所示: 2、在 Linux 中,关于虚拟内存相关的说法正确的是() 3、AT89S52单片机中,在外部中断响应的期间,中断请求标志位查询占用了()。 4、下列关于8051单片机的结构与功能,说法不正确的是()? 5、…...
MyBatis-Plus速成指南:通用枚举 多数据源
通用枚举: 概述: 表中有些字段值是固定的,例如性别(男或女),此时我们可以使用 MyBatis-Plus 的通用枚举来实现 数据库表添加字段: 创建通用枚举类型: Getter public enum SexEnum {MALE(1, "男"…...
Android项目中使用Eclipse导出jar文件
2014年3月24日 天气晴朗 关于打包Android组件肯定是有用到的,比如开发了一个模块,为了更好的复用,我们可能会将它打包成jar文件方便其他项目引用。这个很好理解,也很简单。网上有一堆关于用Eclipse将Android项目打包成jar文件的&…...
网络安全学习 day4
防火墙的安全策略 规则--策略 条件 --- 检查报文的依据,防火墙将报文中携带的信息与条件逐一进行对比, 以此来判断报文是否是 匹配的 。不同的匹配条件之间属于 “ 与 ” 关系;相同的匹配条件中不同的参数信息之间的关系为 “ 或 ” 关系。…...
【SSM】Spring + SpringMVC + Mybatis
SSM课程,以下为该课程的笔记 bean:IOC容器创建的对象 P12 bean的生命周期 在bean中定义init()和destroy()方法,然后在xml中配置方法名,让bean对象能找到对应的生命周期方法。 或通过实现接口的方式定义声明周期方法。 P13 sett…...
智慧园区综合管理系统如何实现多个维度的高效管理与安全风险控制
内容概要 在当前快速发展的城市环境中,智慧园区综合管理系统正在成为各类园区管理的重要工具,无论是工业园、产业园、物流园,还是写字楼与公寓,都在积极寻求如何提升管理效率和保障安全。通过快鲸智慧园区管理系统,用…...
【协议详解】卫星通信5G IoT NTN SIB33-NB 信令详解
一、SIB33信令概述 在5G非地面网络(NTN)中,卫星的高速移动性和广域覆盖特性使得地面设备(UE)需要频繁切换卫星以维持连接。SIB32提供了UE预测当前服务的卫星覆盖信息,SystemInformationBlockType33&#x…...
《LLM大语言模型深度探索与实践:构建智能应用的新范式,融合代理与数据库的高级整合》
文章目录 Langchain的定义Langchain的组成三个核心组件实现整个核心组成部分 为什么要使用LangchainLangchain的底层原理Langchain实战操作LangSmithLangChain调用LLM安装openAI库-国内镜像源代码运行结果小结 使用Langchain的提示模板部署Langchain程序安装langserve代码请求格…...
Debian 10 中 Linux 4.19 内核在 x86_64 架构上对中断嵌套的支持情况
一、中断嵌套的定义与原理 中断嵌套是指在一个中断处理程序(ISR)正在执行的过程中,另一个更高优先级的中断请求到来,系统暂停当前中断处理程序,转而处理新的高优先级中断。处理完高优先级中断后,系统返回到原来的中断处理程序继续执行。这种机制允许系统更高效地响应紧急…...
【Envi遥感图像处理】010:归一化植被指数NDVI计算方法
文章目录 一、NDVI简介二、NDVI计算方法1. NDVI工具2. 波段运算三、注意事项1. 计算结果为一片黑2. 计算结果超出范围一、NDVI简介 归一化植被指数,是反映农作物长势和营养信息的重要参数之一,应用于遥感影像。NDVI是通过植被在近红外波段(NIR)和红光波段(R)的反射率差异…...
优选算法合集————双指针(专题二)
好久都没给大家带来算法专题啦,今天给大家带来滑动窗口专题的训练 题目一:长度最小的子数组 题目描述: 给定一个含有 n 个正整数的数组和一个正整数 target 。 找出该数组中满足其和 ≥ target 的长度最小的 连续子数组 [numsl, numsl1, …...
基于微信小程序的私家车位共享系统设计与实现(LW+源码+讲解)
专注于大学生项目实战开发,讲解,毕业答疑辅导,欢迎高校老师/同行前辈交流合作✌。 技术范围:SpringBoot、Vue、SSM、HLMT、小程序、Jsp、PHP、Nodejs、Python、爬虫、数据可视化、安卓app、大数据、物联网、机器学习等设计与开发。 主要内容:…...
糖化之前,为什么要进行麦芽粉碎?
糖化的目的是将麦芽中的淀粉转化为可发酵性的糖分,而糖化之前,进行麦芽粉碎是确保糖化效果的关键步骤。本文天泰将阐述麦芽粉碎的重要性及其对酿造过程的影响。 一、麦芽粉碎的目的 增加酶的作用面积:麦芽中的淀粉和蛋白质等物质需要通过酶…...
PAT甲级1052、Linked LIst Sorting
题目 A linked list consists of a series of structures, which are not necessarily adjacent in memory. We assume that each structure contains an integer key and a Next pointer to the next structure. Now given a linked list, you are supposed to sort the stru…...
半导体器件与物理篇6 MESFET
金属-半导体接触 MESFET与MOSFET的相同点:它们的电压电流特性相似。都有源漏栅三极,强反型,漏极加正向电压,也会经历线性区、夹断点、饱和区三个阶段。 MESFET与MOSFET的不同点:在器件的栅电极部分,MESFE…...
BES2700源码解析之系统初始化
一 概述 bes2700凭借着超高的性能,超低的功耗,在可穿戴领域有着广泛的应用。笔者使用该芯片做了一些产品解决方案,发现该芯片的性能十分强大。这里做个系列的源码解析。 二 源码解析 1.GPIO和led灯的初始化: tgt_hardware_setup(…...
deepseek 本地化部署和小模型微调
安装ollama 因为本人gpu卡的机器系统是centos 7, 直接使用ollama会报 所以ollama使用镜像方式进行部署, 拉取镜像ollama/ollama 启动命令 docker run -d --privileged -v ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama 查看ollama 是否启动…...
socket实现HTTP请求,参考HttpURLConnection源码解析
背景 有台服务器,网卡绑定有2个ip地址,分别为: A:192.168.111.201 B:192.168.111.202 在这台服务器请求目标地址 C:192.168.111.203 时必须使用B作为源地址才能访问目标地址C,在这台服务器默认…...
3、C#基于.net framework的应用开发实战编程 - 实现(三、三) - 编程手把手系列文章...
三、 实现; 三.三、编写应用程序; 此文主要是实现应用的主要编码工作。 1、 分层; 此例子主要分为UI、Helper、DAL等层。UI负责便签的界面显示;Helper主要是链接UI和数据库操作的中间层;DAL为对数据库的操…...
Ubuntu下Tkinter绑定数字小键盘上的回车键(PySide6类似)
设计了一个tkinter程序,在Win下绑定回车键,直接绑定"<Return>"就可以使用主键盘和小键盘的回车键直接“提交”,到了ubuntu下就不行了。经过搜索,发现ubuntu下主键盘和数字小键盘的回车键,名称不一样。…...
基础笔记|splice()的用法
一、三种用法 splice(index, 0, element) 插入 元素,不删除任何元素。splice(index, deleteCount) 删除 deleteCount 个元素。splice(index, deleteCount, element1, element2, ...) 替换 元素,即删除 deleteCount 个元素,同时插入新的元素。…...
Java BIO详解
一、简介 1.1 BIO概述 BIO(Blocking I/O),即同步阻塞IO(传统IO)。 BIO 全称是 Blocking IO,同步阻塞式IO,是JDK1.4之前的传统IO模型,就是传统的 java.io 包下面的代码实现。 服务…...
Haproxy+keepalived高可用集群,haproxy宕机的解决方案
Haproxykeepalived高可用集群,允许keepalived宕机,允许后端真实服务器宕机,但是不允许haproxy宕机, 所以下面就是解决方案 keepalived配置高可用检测脚本 ,master和backup都要添加 配置脚本 # vim /etc/keepalived…...
98,【6】 buuctf web [ISITDTU 2019]EasyPHP
进入靶场 代码 <?php // 高亮显示当前 PHP 文件的源代码,通常用于调试或展示代码,方便用户查看代码逻辑 highlight_file(__FILE__);// 从 GET 请求中获取名为 _ 的参数值,并赋值给变量 $_ // 符号用于抑制可能出现的错误信息ÿ…...
九. Redis 持久化-RDB(详细讲解说明,一个配置一个说明分析,步步讲解到位)
九. Redis 持久化-RDB(详细讲解说明,一个配置一个说明分析,步步讲解到位) 文章目录 九. Redis 持久化-RDB(详细讲解说明,一个配置一个说明分析,步步讲解到位)1. RDB 概述2. RDB 持久化执行流程3. RDB 的详细配置4. RDB 备份&恢…...
小程序越来越智能化,作为设计师要如何进行创新设计
一、用户体验至上 (一)简洁高效的界面设计 小程序的特点之一是轻便快捷,用户期望能够在最短的时间内找到所需功能并完成操作。因此,设计师应致力于打造简洁高效的界面。避免过多的装饰元素和复杂的布局,采用清晰的导航…...
(done) MIT6.S081 2023 学习笔记 (Day7: LAB6 Multithreading)
网页:https://pdos.csail.mit.edu/6.S081/2023/labs/thread.html (任务1教会了你如何用 C 语言调用汇编,编译后链接即可) 任务1:Uthread: switching between threads (完成) 在这个练习中,你将设计一个用户级线程系统中的上下文切…...
