当前位置: 首页 > news >正文

Pytorch 文本情感分类案例

一共六个脚本,分别是:

        ①generateDictionary.py用于生成词典

        ②datasets.py定义了数据集加载的方法

        ③models.py定义了网络模型

        ④configs.py配置一些参数

        ⑤run_train.py训练模型

        ⑥run_test.py测试模型

数据集icon-default.png?t=N7T8https://download.csdn.net/download/Victor_Li_/88486959?spm=1001.2014.3001.5501停用词表icon-default.png?t=N7T8https://download.csdn.net/download/Victor_Li_/88486973?spm=1001.2014.3001.5501

generateDictionary.py如下

import jiebadata_path = "./weibo_senti_100k.csv"
data_stop_path = "./hit_stopwords.txt"
data_list = open(data_path,encoding='utf-8').readlines()[1:]
stops_word = open(data_stop_path,encoding='utf-8').readlines()
stops_word = [line.strip() for line in stops_word]
stops_word.append(" ")
stops_word.append("\n")voc_dict = {}
min_seq = 1
top_n = 1000
UNK = "UNK"
PAD = "PAD"
for item in data_list:label = item[0]content = item[2:].strip()seg_list = jieba.cut(content,cut_all=False)seg_res = []for seg_item in seg_list:if seg_item in stops_word:continueseg_res.append(seg_item)if seg_item in voc_dict.keys():voc_dict[seg_item] += 1else:voc_dict[seg_item] = 1# print(content)# print(seg_res)voc_list = sorted([_ for _ in voc_dict.items() if _[1] > min_seq],key=lambda x:x[1],reverse=True)[:top_n]voc_dict = {word_count[0]:idx for idx,word_count in enumerate(voc_list)}voc_dict.update({UNK:len(voc_dict),PAD:len(voc_dict)+1})ff = open("./dict","w")
for item in voc_dict.keys():ff.writelines("{},{}\n".format(item,voc_dict[item]))
ff.close()

datasets.py如下

from torch.utils.data import Dataset, DataLoader
import jieba
import numpy as npdef read_dict(voc_dict_path):voc_dict = {}with open(voc_dict_path, 'r') as f:for line in f:line = line.strip()if line == '':continueword, index = line.split(",")voc_dict[word] = int(index)return voc_dictdef load_data(data_path, data_stop_path,isTest):data_list = open(data_path, encoding='utf-8').readlines()[1:]stops_word = open(data_stop_path, encoding='utf-8').readlines()stops_word = [line.strip() for line in stops_word]stops_word.append(" ")stops_word.append("\n")voc_dict = {}data = []max_len_seq = 0for item in data_list:label = item[0]content = item[2:].strip()seg_list = jieba.cut(content, cut_all=False)seg_res = []for seg_item in seg_list:if seg_item in stops_word:continueseg_res.append(seg_item)if seg_item in voc_dict.keys():voc_dict[seg_item] += 1else:voc_dict[seg_item] = 1if len(seg_res) > max_len_seq:max_len_seq = len(seg_res)if isTest:data.append([label, seg_res,content])else:data.append([label, seg_res])return data, max_len_seqclass text_ClS(Dataset):def __init__(self, data_path, data_stop_path,voc_dict_path,isTest=False):self.isTest = isTestself.data_path = data_pathself.data_stop_path = data_stop_pathself.voc_dict = read_dict(voc_dict_path)self.data, self.max_len_seq = load_data(self.data_path, self.data_stop_path,isTest)np.random.shuffle(self.data)def __len__(self):return len(self.data)def __getitem__(self, item):data = self.data[item]label = int(data[0])word_list = data[1]if self.isTest:content = data[2]input_idx = []for word in word_list:if word in self.voc_dict.keys():input_idx.append(self.voc_dict[word])else:input_idx.append(self.voc_dict["UNK"])if len(input_idx) < self.max_len_seq:input_idx += [self.voc_dict["PAD"] for _ in range(self.max_len_seq - len(input_idx))]data = np.array(input_idx)if self.isTest:return label,data,contentelse:return label, datadef data_loader(dataset,config):return DataLoader(dataset,batch_size=config.batch_size,shuffle=config.is_shuffle,num_workers=4,pin_memory=True)

models.py如下

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as npclass Model(nn.Module):def __init__(self,config):super(Model,self).__init__()self.embeding = nn.Embedding(config.n_vocab,config.embed_size,padding_idx=config.n_vocab - 1)self.lstm = nn.LSTM(config.embed_size,config.hidden_size,config.num_layers,batch_first=True,bidirectional=True,dropout=config.dropout)self.maxpool = nn.MaxPool1d(config.pad_size)self.fc = nn.Linear(config.hidden_size * 2 + config.embed_size,config.num_classes)self.softmax = nn.Softmax(dim=1)def forward(self,x):embed = self.embeding(x)out, _ = self.lstm(embed)out = torch.cat((embed, out), 2)out = F.relu(out)out = out.permute(0, 2, 1)out = self.maxpool(out).reshape(out.size()[0],-1)out = self.fc(out)out = self.softmax(out)return out

configs.py如下

import torch.typesclass Config():def __init__(self):self.n_vocab = 1002self.embed_size = 256self.hidden_size = 256self.num_layers = 5self.dropout = 0.8self.num_classes = 2self.pad_size = 32self.batch_size = 32self.is_shuffle = Trueself.learning_rate = 0.001self.num_epochs = 100self.devices = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

run_train.py如下

import torch
import torch.nn as nn
from torch import optim
from models import Model
from datasets import data_loader,text_ClS
from configs import Config
import time
import torch.multiprocessing as mpif __name__ == '__main__':mp.freeze_support()cfg = Config()data_path = "./weibo_senti_100k.csv"data_stop_path = "./hit_stopwords.txt"dict_path = "./dict"dataset = text_ClS(data_path, data_stop_path, dict_path)train_dataloader = data_loader(dataset,cfg)cfg.pad_size = dataset.max_len_seqmodel_text_cls = Model(cfg)model_text_cls.to(cfg.devices)loss_func = nn.CrossEntropyLoss()optimizer = optim.Adam(model_text_cls.parameters(), lr=cfg.learning_rate)scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)for epoch in range(cfg.num_epochs):running_loss = 0correct = 0total = 0epoch_start_time = time.time()for i,(labels,datas) in enumerate(train_dataloader):datas = datas.to(cfg.devices)labels = labels.to(cfg.devices)pred = model_text_cls.forward(datas)loss_val = loss_func(pred,labels)running_loss += loss_val.item()loss_val.backward()if ((i + 1) % 4 == 0) or (i + 1 == len(train_dataloader)):optimizer.step()optimizer.zero_grad()_, predicted = torch.max(pred.data, 1)correct += (predicted == labels).sum().item()total += labels.size(0)scheduler.step()accuracy_train = 100 * correct / totalepoch_end_time = time.time()epoch_time = epoch_end_time - epoch_start_timetain_loss = running_loss / len(train_dataloader)print("Epoch [{}/{}],Time: {:.4f}s,Loss: {:.4f},Acc: {:.2f}%".format(epoch + 1, cfg.num_epochs, epoch_time, tain_loss,accuracy_train))torch.save(model_text_cls.state_dict(),"./text_cls_model/text_cls_model{}.pth".format(epoch))

run_test.py如下

import torch
import torch.nn as nn
from torch import optim
from models import Model
from datasets import data_loader,text_ClS
from configs import Config
import time
import torch.multiprocessing as mpif __name__ == '__main__':mp.freeze_support()cfg = Config()data_path = "./test.csv"data_stop_path = "./hit_stopwords.txt"dict_path = "./dict"cfg.batch_size = 1dataset = text_ClS(data_path, data_stop_path, dict_path,isTest=True)dataloader = data_loader(dataset,cfg)cfg.pad_size = dataset.max_len_seqmodel_text_cls = Model(cfg)model_text_cls.load_state_dict(torch.load('./text_cls_model/text_cls_model0.pth'))model_text_cls.to(cfg.devices)classes_name = ['负面的','正面的']for i,(label,input,content) in enumerate(dataloader):label = label.to(cfg.devices)input = input.to(cfg.devices)pred = model_text_cls.forward(input)_, predicted = torch.max(pred.data, 1)print("内容:{}, 实际结果:{}, 预测结果:{}".format(content,classes_name[label],classes_name[predicted[0]]))

测试结果如下

相关文章:

Pytorch 文本情感分类案例

一共六个脚本,分别是: ①generateDictionary.py用于生成词典 ②datasets.py定义了数据集加载的方法 ③models.py定义了网络模型 ④configs.py配置一些参数 ⑤run_train.py训练模型 ⑥run_test.py测试模型 数据集https://download.csdn.net/download/Victor_Li_/88486959?spm1…...

Flutter之GetX controller tag使用详解

本文主要介绍 GetX 依赖注入中 tag 的作用和使用详解。 作用 前面几篇文章介绍了 GetX 依赖注入的使用以及通过源码剖析了依赖注入的原理&#xff1a; •《Flutter应用框架搭建(一)GetX集成及使用详解》•《Flutter 通过源码一步一步剖析 Getx 依赖管理的实现》•《Flutter之…...

Kubernetes群集调度

调度约束 Kubernetes 是通过 List-Watch 的机制进行每个组件的协作&#xff0c;保持数据同步的&#xff0c;每个组件之间的设计实现了解耦。 用户是通过 kubectl 根据配置文件&#xff0c;向 APIServer 发送命令&#xff0c;在 Node 节点上面建立 Pod 和 Container。 APIServ…...

【总结】linux centos 7 开启网络白名单访问策略

目录 linux开启网络端口白名单访问策略开启白名单步骤补充说明 linux开启网络端口白名单访问策略 安全需要&#xff0c;被检测各种3306、9200、9300端口没有设置访问策略。需要整改。 对于linux来说&#xff0c;有两种方式可以开启防火墙 开启白名单步骤 场景一&#xff1a…...

2023-2024-1高级语言程序设计第1次月考

7-1-1 计算摄氏温度 给定一个华氏温度F&#xff0c;本题要求编写程序&#xff0c;计算对应的摄氏温度C。计算公式&#xff1a;C5(F−32)/9。题目保证输入与输出均在整型范围内。 输入格式: 输入在一行中给出一个华氏温度。 输出格式: 在一行中按照格式“Celsius C”输出对…...

目标检测:Proposal-Contrastive Pretraining for Object Detection from Fewer Data

论文作者&#xff1a;Quentin Bouniot,Romaric Audigier,Anglique Loesch,Amaury Habrard 作者单位&#xff1a;Universit Paris-Saclay; Universit Jean Monnet Saint-Etienne; Universitaire de France (IUF) 论文链接&#xff1a;http://arxiv.org/abs/2310.16835v1 内容…...

Cesium:CGCS2000坐标系的xyz坐标转换成WGS84坐标系的经纬高度,再转换到笛卡尔坐标系的xyz坐标

作者:CSDN @ _乐多_ 本文将介绍使用 Vue 、cesium、proj4 框架,实现将CGCS2000坐标系的xyz坐标转换成WGS84坐标系的经纬高度,再将WGS84坐标系的经纬高度转换到笛卡尔坐标系的xyz坐标的代码。并将输入和输出使用 Vue 前端框架展示了出来。代码即插即用。 网页效果如下图所示…...

【OpenCV实现图像:用Python生成图像特效,报错ValueError: too many values to unpack (expected 3)】

文章目录 概要读入图像改变单个通道黑白特效颜色反转将图像拆分成四个子部分 概要 Python是一种功能强大的编程语言&#xff0c;也是图像处理领域中常用的工具之一。通过使用Python的图像处理库&#xff08;例如Pillow、OpenCV等&#xff09;&#xff0c;开发者可以实现各种各…...

875. 爱吃香蕉的珂珂

题目描述 珂珂喜欢吃香蕉。这里有 n 堆香蕉&#xff0c;第 i 堆中有 piles[i] 根香蕉。警卫已经离开了&#xff0c;将在 h 小时后回来。 珂珂可以决定她吃香蕉的速度 k &#xff08;单位&#xff1a;根/小时&#xff09;。每个小时&#xff0c;她将会选择一堆香蕉&#xff0c…...

台灯太亮会导致近视吗?精选高品质的台灯

台灯相信很多家庭都会备上一台&#xff0c;用于办公、休闲或者给孩子学习使用&#xff0c;如果使用的台灯亮度过高的话&#xff0c;可能会对视力造成一定的影响&#xff0c;尤其是夜晚的时候。建议是选择带有亮度调节功能的台灯会比较好一点&#xff0c;可以自行根据周围环境的…...

Scala函数和闭包

1. 函数 1.1 函数与方法 Scala 中函数与方法的区别非常小&#xff0c;如果函数作为某个对象的成员&#xff0c;这样的函数被称为方法&#xff0c;否则就是一个正常的函数。 // 定义方法 def multi1(x:Int) {x * x} // 定义函数 val multi2 (x: Int) > {x * x}println(mult…...

LeetCode----1935. 可以输入的最大单词数

题目 键盘出现了一些故障,有些字母键无法正常工作。而键盘上所有其他键都能够正常工作。 给你一个由若干单词组成的字符串 text ,单词间由单个空格组成(不含前导和尾随空格);另有一个字符串 brokenLetters ,由所有已损坏的不同字母键组成,返回你可以使用此键盘完全输入…...

学习笔记三十:K8S配置管理中心Secret实现加密数据配置管理

K8S配置管理中心Secret实现加密数据配置管理 Secret概述secret三种可选参数:Secret类型 使用Secret通过环境变量引入Secret通过volume挂载Secret创建Secret创建yaml文件将Secret挂载到Volume中 Secret概述 Configmap一般是用来存放明文数据的&#xff0c;如配置文件&#xff0…...

关于uviewui修改主题及在uniapp中的应用

在uview使用过程中遇到很多不方便的地方&#xff0c;记录下来 修改主题颜色 给UI框架换个主题色基础方法是覆盖原有色&#xff08;但这个方法比较笨&#xff0c;处理起来也不干净利索&#xff09;&#xff0c;所以换个思路改变基础色值变量&#xff0c;步骤主要分为2部分&…...

使用QEMU模拟启动uboot

uboot的相关知识&#xff0c;可以参考&#xff1a;uboot基本概念。 一、环境配置 WSL: ubutu20.04 模拟开发板&#xff1a;vexpress-a9 uboot版本&#xff1a;u-boot-2023.10 二、安装QEMU 2.1、安装sudo apt install qemu2.2、查看支持哪些开发板qemu-system-arm -M help结…...

学习数据结构和算法之前,你需要知道什么?

最快的学习方法是什么&#xff1f;计算机基础支持有哪些&#xff1f;学习数据结构和算法应该如何思考&#xff1f;如何成长&#xff1f;为什么要学习数据结构和算法&#xff1f; 最快的学习方法是什么&#xff1f; 实践。 计算机基础支持有哪些&#xff1f; 数据结构和算法。…...

16. 机器学习 - 决策树

Hi&#xff0c;你好。我是茶桁。 在上一节课讲SVM之后&#xff0c;再给大家将一个新的分类模型「决策树」。我们直接开始正题。 决策树 我们从一个例子开始&#xff0c;来看下面这张图&#xff1a; 假设我们的x1 ~ x4是特征&#xff0c;y是最终的决定&#xff0c;打比方说是…...

将多余的内存,当作虚拟内存。修改edge缓存路径到虚拟内存中

一、下载工具&#xff0c;把内存映射成硬盘 软媒内存盘 v1.1.3.0 软媒内存盘下载-软媒内存盘 v1.1.3.0 - 下载吧 (xiazaiba.com) 二、映射edge的缓存路径 到新建的虚拟硬盘中 mklink /D "C:\Users\Administrator\AppData\Local\Microsoft\Edge\User Data" "V:\…...

【从0到1设计一个网关】过滤器链的实现---实现负载均衡过滤器

文章目录 什么是过滤器?编写负载均衡过滤器负载均衡的定义与实现负载均衡算法设计实现效果演示链接 自研网关整合Nacos,实现服务注册和配置变更 源码链接 什么是过滤器? 再前面的几个章节中我们已经实现了将我们的网关服务注册到注册中心,并且成功的从配置中心拉取了配置…...

科技云报道:打造生成式AI应用,什么才是关键?

科技云报道原创。 生成式AI作为当前人工智能的前沿领域&#xff0c;全球多家科技企业都在加大生成式AI的研发投入力度。 随着技术、产品及应用等方面不断推出重要成果&#xff0c;如今有更多的行业用户在思考该如何将生成式AI应用落地。 但开发生成式AI应用是一个充满挑战的…...

网络编程(Modbus进阶)

思维导图 Modbus RTU&#xff08;先学一点理论&#xff09; 概念 Modbus RTU 是工业自动化领域 最广泛应用的串行通信协议&#xff0c;由 Modicon 公司&#xff08;现施耐德电气&#xff09;于 1979 年推出。它以 高效率、强健性、易实现的特点成为工业控制系统的通信标准。 包…...

RestClient

什么是RestClient RestClient 是 Elasticsearch 官方提供的 Java 低级 REST 客户端&#xff0c;它允许HTTP与Elasticsearch 集群通信&#xff0c;而无需处理 JSON 序列化/反序列化等底层细节。它是 Elasticsearch Java API 客户端的基础。 RestClient 主要特点 轻量级&#xff…...

【Axure高保真原型】引导弹窗

今天和大家中分享引导弹窗的原型模板&#xff0c;载入页面后&#xff0c;会显示引导弹窗&#xff0c;适用于引导用户使用页面&#xff0c;点击完成后&#xff0c;会显示下一个引导弹窗&#xff0c;直至最后一个引导弹窗完成后进入首页。具体效果可以点击下方视频观看或打开下方…...

Xshell远程连接Kali(默认 | 私钥)Note版

前言:xshell远程连接&#xff0c;私钥连接和常规默认连接 任务一 开启ssh服务 service ssh status //查看ssh服务状态 service ssh start //开启ssh服务 update-rc.d ssh enable //开启自启动ssh服务 任务二 修改配置文件 vi /etc/ssh/ssh_config //第一…...

以下是对华为 HarmonyOS NETX 5属性动画(ArkTS)文档的结构化整理,通过层级标题、表格和代码块提升可读性:

一、属性动画概述NETX 作用&#xff1a;实现组件通用属性的渐变过渡效果&#xff0c;提升用户体验。支持属性&#xff1a;width、height、backgroundColor、opacity、scale、rotate、translate等。注意事项&#xff1a; 布局类属性&#xff08;如宽高&#xff09;变化时&#…...

Redis相关知识总结(缓存雪崩,缓存穿透,缓存击穿,Redis实现分布式锁,如何保持数据库和缓存一致)

文章目录 1.什么是Redis&#xff1f;2.为什么要使用redis作为mysql的缓存&#xff1f;3.什么是缓存雪崩、缓存穿透、缓存击穿&#xff1f;3.1缓存雪崩3.1.1 大量缓存同时过期3.1.2 Redis宕机 3.2 缓存击穿3.3 缓存穿透3.4 总结 4. 数据库和缓存如何保持一致性5. Redis实现分布式…...

聊聊 Pulsar:Producer 源码解析

一、前言 Apache Pulsar 是一个企业级的开源分布式消息传递平台&#xff0c;以其高性能、可扩展性和存储计算分离架构在消息队列和流处理领域独树一帜。在 Pulsar 的核心架构中&#xff0c;Producer&#xff08;生产者&#xff09; 是连接客户端应用与消息队列的第一步。生产者…...

CentOS下的分布式内存计算Spark环境部署

一、Spark 核心架构与应用场景 1.1 分布式计算引擎的核心优势 Spark 是基于内存的分布式计算框架&#xff0c;相比 MapReduce 具有以下核心优势&#xff1a; 内存计算&#xff1a;数据可常驻内存&#xff0c;迭代计算性能提升 10-100 倍&#xff08;文档段落&#xff1a;3-79…...

【机器视觉】单目测距——运动结构恢复

ps&#xff1a;图是随便找的&#xff0c;为了凑个封面 前言 在前面对光流法进行进一步改进&#xff0c;希望将2D光流推广至3D场景流时&#xff0c;发现2D转3D过程中存在尺度歧义问题&#xff0c;需要补全摄像头拍摄图像中缺失的深度信息&#xff0c;否则解空间不收敛&#xf…...

【SQL学习笔记1】增删改查+多表连接全解析(内附SQL免费在线练习工具)

可以使用Sqliteviz这个网站免费编写sql语句&#xff0c;它能够让用户直接在浏览器内练习SQL的语法&#xff0c;不需要安装任何软件。 链接如下&#xff1a; sqliteviz 注意&#xff1a; 在转写SQL语法时&#xff0c;关键字之间有一个特定的顺序&#xff0c;这个顺序会影响到…...