当前位置: 首页 > news >正文

第N2周:中文文本分类-Pytorch实现

目录

  • 一、前言
  • 二、准备工作
  • 三、数据预处理
    • 1.加载数据
    • 2.构建词典
    • 3.生成数据批次和迭代器
  • 三、模型构建
    • 1. 搭建模型
    • 2. 初始化模型
    • 3. 定义训练与评估函数
  • 四、训练模型
    • 1. 拆分数据集并运行模型

一、前言

🍨 本文为🔗365天深度学习训练营 中的学习记录博客
🍖 原作者:K同学啊|接辅导、项目定制

● 难度:夯实基础⭐⭐
● 语言:Python3、Pytorch3
● 时间:4月23日-4月28日
🍺要求:
1、熟悉NLP的基础知识

二、准备工作

环境搭建
Python 3.8
pytorch == 1.8.1
torchtext == 0.9.1

三、数据预处理

1.加载数据

在这里插入图片描述

import torch
import torch.nn as nn
import os,PIL,pathlib,warningswarnings.filterwarnings("ignore")             #忽略警告信息# win10系统
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
import pandas as pd# 加载自定义中文数据
train_data = pd.read_csv('./data/train.csv', sep='\t', header=None)
train_data.head()
# 构造数据集迭代器
def coustom_data_iter(texts, labels):for x, y in zip(texts, labels):yield x, ytrain_iter = coustom_data_iter(train_data[0].values[:], train_data[1].values[:])

2.构建词典

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
# conda install jieba -y
import jieba# 中文分词方法
tokenizer = jieba.lcutdef yield_tokens(data_iter):for text,_ in data_iter:yield tokenizer(text)vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"]) # 设置默认索引,如果找不到单词,则会选择默认索引
vocab(['我','想','看','和平','精英','上','战神','必备','技巧','的','游戏','视频'])
label_name = list(set(train_data[1].values[:]))
print(label_name)
text_pipeline  = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: label_name.index(x)print(text_pipeline('我想看和平精英上战神必备技巧的游戏视频'))
print(label_pipeline('Video-Play'))

3.生成数据批次和迭代器

from torch.utils.data import DataLoaderdef collate_batch(batch):label_list, text_list, offsets = [], [], [0]for (_text,_label) in batch:# 标签列表label_list.append(label_pipeline(_label))# 文本列表processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)text_list.append(processed_text)# 偏移量,即语句的总词汇量offsets.append(processed_text.size(0))label_list = torch.tensor(label_list, dtype=torch.int64)text_list  = torch.cat(text_list)offsets    = torch.tensor(offsets[:-1]).cumsum(dim=0) #返回维度dim中输入元素的累计和return text_list.to(device),label_list.to(device), offsets.to(device)# 数据加载器,调用示例
dataloader = DataLoader(train_iter,batch_size=8,shuffle   =False,collate_fn=collate_batch)

三、模型构建

1. 搭建模型

from torch import nnclass TextClassificationModel(nn.Module):def __init__(self, vocab_size, embed_dim, num_class):super(TextClassificationModel, self).__init__()self.embedding = nn.EmbeddingBag(vocab_size,   # 词典大小embed_dim,    # 嵌入的维度sparse=False) # self.fc = nn.Linear(embed_dim, num_class)self.init_weights()def init_weights(self):initrange = 0.5self.embedding.weight.data.uniform_(-initrange, initrange) # 初始化权重self.fc.weight.data.uniform_(-initrange, initrange)        self.fc.bias.data.zero_()                                  # 偏置值归零def forward(self, text, offsets):embedded = self.embedding(text, offsets)return self.fc(embedded)

2. 初始化模型

num_class  = len(label_name)
vocab_size = len(vocab)
em_size    = 64
model      = TextClassificationModel(vocab_size, em_size, num_class).to(device)

3. 定义训练与评估函数

import timedef train(dataloader):model.train()  # 切换为训练模式total_acc, train_loss, total_count = 0, 0, 0log_interval = 50start_time   = time.time()for idx, (text,label,offsets) in enumerate(dataloader):predicted_label = model(text, offsets)optimizer.zero_grad()                    # grad属性归零loss = criterion(predicted_label, label) # 计算网络输出和真实值之间的差距,label为真实值loss.backward()                          # 反向传播torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1) # 梯度裁剪optimizer.step()  # 每一步自动更新# 记录acc与losstotal_acc   += (predicted_label.argmax(1) == label).sum().item()train_loss  += loss.item()total_count += label.size(0)if idx % log_interval == 0 and idx > 0:elapsed = time.time() - start_timeprint('| epoch {:1d} | {:4d}/{:4d} batches ''| train_acc {:4.3f} train_loss {:4.5f}'.format(epoch, idx, len(dataloader),total_acc/total_count, train_loss/total_count))total_acc, train_loss, total_count = 0, 0, 0start_time = time.time()def evaluate(dataloader):model.eval()  # 切换为测试模式total_acc, train_loss, total_count = 0, 0, 0with torch.no_grad():for idx, (text,label,offsets) in enumerate(dataloader):predicted_label = model(text, offsets)loss = criterion(predicted_label, label)  # 计算loss值# 记录测试数据total_acc   += (predicted_label.argmax(1) == label).sum().item()train_loss  += loss.item()total_count += label.size(0)return total_acc/total_count, train_loss/total_count

四、训练模型

1. 拆分数据集并运行模型

from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
# 超参数
EPOCHS     = 10 # epoch
LR         = 5  # 学习率
BATCH_SIZE = 64 # batch size for trainingcriterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
total_accu = None# 构建数据集
train_iter = coustom_data_iter(train_data[0].values[:], train_data[1].values[:])
train_dataset = to_map_style_dataset(train_iter)split_train_, split_valid_ = random_split(train_dataset,[int(len(train_dataset)*0.8),int(len(train_dataset)*0.2)])train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,shuffle=True, collate_fn=collate_batch)valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,shuffle=True, collate_fn=collate_batch)for epoch in range(1, EPOCHS + 1):epoch_start_time = time.time()train(train_dataloader)val_acc, val_loss = evaluate(valid_dataloader)# 获取当前的学习率lr = optimizer.state_dict()['param_groups'][0]['lr']if total_accu is not None and total_accu > val_acc:scheduler.step()else:total_accu = val_accprint('-' * 69)print('| epoch {:1d} | time: {:4.2f}s | ''valid_acc {:4.3f} valid_loss {:4.3f} | lr {:4.6f}'.format(epoch,time.time() - epoch_start_time,val_acc,val_loss,lr))print('-' * 69)

相关文章:

第N2周:中文文本分类-Pytorch实现

目录 一、前言二、准备工作三、数据预处理1.加载数据2.构建词典3.生成数据批次和迭代器 三、模型构建1. 搭建模型2. 初始化模型3. 定义训练与评估函数 四、训练模型1. 拆分数据集并运行模型 一、前言 &#x1f368; 本文为&#x1f517;365天深度学习训练营 中的学习记录博客 …...

Salesforce许可证和版本有什么区别,购买帐号时应该如何选择?

Salesforce许可证分配给特定用户&#xff0c;授予他们访问Salesforce产品和功能的权限。Salesforce版本和许可证是不同的概念&#xff0c;但极易混淆。 Salesforce版本&#xff1a;这是对组织购买的Salesforce产品和功能的访问权限。大致可分为Essentials、Professional、Ente…...

接口测试怎么做?全网最详细从接口测试到接口自动化详解,看这篇就够了...

目录&#xff1a;导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结&#xff08;尾部小惊喜&#xff09; 前言 抛出一个问题&…...

DataStore入门及在项目中的使用

首先给个官网的的地址&#xff1a;应用架构&#xff1a;数据层 - DataStore - Android 开发者 | Android Developers 小伙伴们可以直接看官网的资料&#xff0c;本篇文章是对官网的部分细节进行补充 一、为什么要使用DataStore 代替SharedPreferences SharedPreferences&a…...

用Python爬取中国各省GDP数据

介绍 在数据分析和经济研究中&#xff0c;了解中国各省份的GDP数据是非常重要的。然而&#xff0c;手动收集这些数据可能是一项繁琐且费时的任务。幸运的是&#xff0c;Python提供了一些强大的工具和库&#xff0c;使我们能够自动化地从互联网上爬取数据。本文将介绍如何使用P…...

深度学习-第T5周——运动鞋品牌识别

深度学习-第T5周——运动鞋品牌识别 深度学习-第T5周——运动鞋品牌识别一、前言二、我的环境三、前期工作1、导入数据集2、查看图片数目3、查看数据 四、数据预处理1、 加载数据1、设置图片格式2、划分训练集3、划分验证集4、查看标签 2、数据可视化3、检查数据4、配置数据集 …...

自媒体的孔雀效应:插根鸡毛还是专业才华?

自媒体时代&#xff0c;让许多原本默默无闻的人找到了表达自己的平台。有人声称&#xff0c;现在这个时代&#xff0c;“随便什么人身上插根鸡毛就可以当孔雀了”。可是&#xff0c;事实真的如此吗&#xff1f; 首先&#xff0c;我们不能否认的是&#xff0c;自媒体确实为大众提…...

Linux系统优化

一、系统启动流程 1.centos6 centos6开机启动流程&#xff0c;传送门 2.centos7启动流程 二、系统启动运行级别 2.1 什么是运行级别 运行级别&#xff1a;指操作系统当前正在运行的功能级别&#xff1b; [rootweb01 ~]# ll /usr/lib/systemd/system lrwxrwxrwx. 1 root root…...

Java笔记_22(反射和动态代理)

Java笔记_22 一、反射1.1、反射的概述1.2、获取class对象的三种方式1.3、反射获取构造方法1.4、反射获取成员变量1.5、反射获取成员方法1.6、综合练习1.6.1、保存信息1.6.2、跟配置文件结合动态创建 一、反射 1.1、反射的概述 什么是反射? 反射允许对成员变量&#xff0c;成…...

前端web入门-HTML-day01

(创作不易&#xff0c;感谢有你&#xff0c;你的支持&#xff0c;就是我前行的最大动力&#xff0c;如果看完对你有帮助&#xff0c;请留下您的足迹&#xff09; 目录 HTML初体验 HTML 定义 标签语法 总结&#xff1a; HTML 基本骨架 基础知识&#xff1a; 总结&#…...

创建一个Go项目

创建一个Go项目 1.创建项目 package mainfunc main() {println("你好啊&#xff0c;简单点了&#xff01;") }如果是本地的话可以采用go run 项目名的方式。 可以采用go run --work 项目名的方式&#xff0c;此时可以展示日志信息。 如果是只编译的话 go build 项…...

从 Spring 的创建到 Bean 对象的存储、读取

目录 创建 Spring 项目&#xff1a; 1.创建一个 Maven 项目&#xff1a; 2.添加 Spring 框架支持&#xff1a; 3.配置资源文件&#xff1a; 4.添加启动类&#xff1a; Bean 对象的使用&#xff1a; 1.存储 Bean 对象&#xff1a; 1.1 创建 Bean&#xff1a; 1.2 存储 B…...

【一文吃透归并排序】基本归并·原地归并·自然归并 C++

目录 1 引入情境基本归并排序实现 C 2 原地归并排序2-1 死板的解法2-2 原地工作区2-3 链表归并排序 3 自底向上归并排序4 两路自然归并排序4-1 形式化描述4-2 代码实现 1 引入情境 归并思想&#xff1a;假设有两队小孩&#xff0c;都是从矮到高排序&#xff0c;现在通过一扇门后…...

读《Spring Boot 3核心技术与最佳实践》有感

我是谁&#xff1f; &#x1f468;‍&#x1f393;作者&#xff1a;bug菌 ✏️博客&#xff1a;CSDN、掘金、infoQ、51CTO等 &#x1f389;简介&#xff1a;CSDN/阿里云/华为云/51CTO博客专家&#xff0c;C站历届博客之星Top50&#xff0c;掘金/InfoQ/51CTO等社区优质创作者&am…...

板子短路了?

有段时间没更新了&#xff0c;主要是最近有点忙&#xff0c;当然也因为有点“懒”。 做这行业的都知道&#xff0c;下半年都是比较忙的&#xff0c;相信大家也是&#xff01; 相信做硬件的小伙伴们&#xff0c;遇到过短路的板子已经不计其数了。 短路带来的危害&#xff1a;…...

一行代码绘制高分SCI限制立方图

一、概述 Restricted cubic splines (RCS)是一种基于样条函数的非参数化模型&#xff0c;它可以可靠地拟合非线性关系&#xff0c;可以自适应地调整分割结点。在统计学和机器学习领域&#xff0c;RCS通常用来对连续型自变量进行建模&#xff0c;并在解释自变量与响应变量的关系…...

spring 容器结构/机制debug分析--Spring 学习的核心内容和几个重要概念--IOC 的开发模式--综合解图

目录 Spring Spring 学习的核心内容 解读上图: Spring 几个重要概念 ● 传统的开发模式 解读上图 ● IOC 的开发模式 解读上图 代码示例—入门 xml代码 注意事项和细节 1、说明 2、解释一下类加载路径 3、debug 看看 spring 容器结构/机制 综合解图 Spring Spr…...

excel实战小测第四

【项目背景】 本项目为某招聘网站部分招聘信息&#xff0c;要求对“数据分析师”岗位进行招聘需求分析&#xff0c;通过对城市、行业、学历要求、薪资待遇等不同方向进行相关性分析&#xff0c;加深对数据分析行业的了解。 结合企业真实招聘信息&#xff0c;可以帮助有意转向数…...

什么是SpringBoot自动配置

概述&#xff1a; 现在的Java面试基本都会问到你知道什么是Springboot的自动配置。为什么面试官要问这样的问题&#xff0c;主要是在于看你有没有对Springboot的原理有没有深入的了解&#xff0c;有没有看过Springboot的源码&#xff0c;这是区别普通程序员与高级程序员最好的…...

基于IC5000烧录器使用winIDEA烧写+调试程序(S32K324的软件烧写与调试)

目录 一、iSYSTEM简介二、如何使用iSYSTEM winIDEA烧写调试程序2.1 打开winIDEA&#xff1a;2.2 新建一个Workspace;2.3 硬件配置:2.4 选择CPU芯片型号&#xff1a;2.5 加载烧写文件&#xff1a;2.6 开始烧录程序&#xff1a;2.7 程序调试Debug&#xff1a;2.7.1 运行程序&…...

linux之kylin系统nginx的安装

一、nginx的作用 1.可做高性能的web服务器 直接处理静态资源&#xff08;HTML/CSS/图片等&#xff09;&#xff0c;响应速度远超传统服务器类似apache支持高并发连接 2.反向代理服务器 隐藏后端服务器IP地址&#xff0c;提高安全性 3.负载均衡服务器 支持多种策略分发流量…...

利用ngx_stream_return_module构建简易 TCP/UDP 响应网关

一、模块概述 ngx_stream_return_module 提供了一个极简的指令&#xff1a; return <value>;在收到客户端连接后&#xff0c;立即将 <value> 写回并关闭连接。<value> 支持内嵌文本和内置变量&#xff08;如 $time_iso8601、$remote_addr 等&#xff09;&a…...

Java 语言特性(面试系列1)

一、面向对象编程 1. 封装&#xff08;Encapsulation&#xff09; 定义&#xff1a;将数据&#xff08;属性&#xff09;和操作数据的方法绑定在一起&#xff0c;通过访问控制符&#xff08;private、protected、public&#xff09;隐藏内部实现细节。示例&#xff1a; public …...

Python爬虫实战:研究feedparser库相关技术

1. 引言 1.1 研究背景与意义 在当今信息爆炸的时代,互联网上存在着海量的信息资源。RSS(Really Simple Syndication)作为一种标准化的信息聚合技术,被广泛用于网站内容的发布和订阅。通过 RSS,用户可以方便地获取网站更新的内容,而无需频繁访问各个网站。 然而,互联网…...

页面渲染流程与性能优化

页面渲染流程与性能优化详解&#xff08;完整版&#xff09; 一、现代浏览器渲染流程&#xff08;详细说明&#xff09; 1. 构建DOM树 浏览器接收到HTML文档后&#xff0c;会逐步解析并构建DOM&#xff08;Document Object Model&#xff09;树。具体过程如下&#xff1a; (…...

Android Bitmap治理全解析:从加载优化到泄漏防控的全生命周期管理

引言 Bitmap&#xff08;位图&#xff09;是Android应用内存占用的“头号杀手”。一张1080P&#xff08;1920x1080&#xff09;的图片以ARGB_8888格式加载时&#xff0c;内存占用高达8MB&#xff08;192010804字节&#xff09;。据统计&#xff0c;超过60%的应用OOM崩溃与Bitm…...

【VLNs篇】07:NavRL—在动态环境中学习安全飞行

项目内容论文标题NavRL: 在动态环境中学习安全飞行 (NavRL: Learning Safe Flight in Dynamic Environments)核心问题解决无人机在包含静态和动态障碍物的复杂环境中进行安全、高效自主导航的挑战&#xff0c;克服传统方法和现有强化学习方法的局限性。核心算法基于近端策略优化…...

三分算法与DeepSeek辅助证明是单峰函数

前置 单峰函数有唯一的最大值&#xff0c;最大值左侧的数值严格单调递增&#xff0c;最大值右侧的数值严格单调递减。 单谷函数有唯一的最小值&#xff0c;最小值左侧的数值严格单调递减&#xff0c;最小值右侧的数值严格单调递增。 三分的本质 三分和二分一样都是通过不断缩…...

日常一水C

多态 言简意赅&#xff1a;就是一个对象面对同一事件时做出的不同反应 而之前的继承中说过&#xff0c;当子类和父类的函数名相同时&#xff0c;会隐藏父类的同名函数转而调用子类的同名函数&#xff0c;如果要调用父类的同名函数&#xff0c;那么就需要对父类进行引用&#…...

ubuntu系统文件误删(/lib/x86_64-linux-gnu/libc.so.6)修复方案 [成功解决]

报错信息&#xff1a;libc.so.6: cannot open shared object file: No such file or directory&#xff1a; #ls, ln, sudo...命令都不能用 error while loading shared libraries: libc.so.6: cannot open shared object file: No such file or directory重启后报错信息&…...