当前位置: 首页 > article >正文

用PyTorch和TorchText搞定AG_NEWS新闻分类:从数据加载到75%准确率的保姆级代码

用PyTorch和TorchText实现AG_NEWS新闻分类从零到75%准确率的完整指南当你第一次接触文本分类任务时可能会被数据处理和模型构建的复杂性吓到。本文将带你用PyTorch和TorchText从零开始构建一个新闻分类器无需任何先验知识只需跟着步骤操作就能达到75%的准确率。1. 环境准备与数据加载在开始之前确保你已经安装了最新版本的PyTorch和TorchText。可以通过以下命令安装pip install torch torchtextAG_NEWS是学术界常用的新闻分类数据集包含四个类别世界新闻、体育、商业和科技。让我们先加载这个数据集import torch import torch.nn as nn from torchtext.datasets import AG_NEWS from torchtext.data.utils import get_tokenizer from collections import Counter, OrderedDict # 创建数据目录并加载数据集 train_dataset, test_dataset AG_NEWS(root./data, split(train, test)) classes [World, Sports, Business, Sci/Tech]提示如果直接从GitHub下载数据集遇到问题可以尝试设置代理或更换网络环境。查看数据集前几个样本可以帮助我们理解数据结构for i, (label, text) in zip(range(3), train_dataset): print(f{classes[label-1]}: {text[:50]}...)2. 文本预处理与词表构建文本数据不能直接输入模型需要转换为数值形式。我们将使用以下步骤处理文本分词将句子拆分为单词或子词单元构建词表创建单词到索引的映射序列化将文本转换为数字序列# 使用基础英语分词器 tokenizer get_tokenizer(basic_english) # 构建词表 counter Counter() for (label, text) in train_dataset: counter.update(tokenizer(text)) # 按词频排序并创建词表 vocab torchtext.vocab.vocab( OrderedDict(sorted(counter.items(), keylambda x: x[1], reverseTrue)), min_freq1 ) vocab_size len(vocab) print(f词表大小: {vocab_size})3. 数据批处理与填充文本长度不一致是常见问题我们需要统一长度以便批量处理def padify(batch): # 将文本转换为索引序列 texts [vocab.lookup_indices(tokenizer(item[1])) for item in batch] # 获取当前批次最大长度 max_len max(map(len, texts)) # 对短文本进行填充 padded [torch.nn.functional.pad( torch.tensor(text), (0, max_len - len(text)), modeconstant, value0 ) for text in texts] labels torch.LongTensor([item[0]-1 for item in batch]) texts torch.stack(padded) return labels, texts # 创建数据加载器 train_loader torch.utils.data.DataLoader( list(train_dataset), # 转换为列表 batch_size32, collate_fnpadify, shuffleTrue )4. 模型构建与训练我们将使用简单的嵌入层全连接层的架构class NewsClassifier(nn.Module): def __init__(self, vocab_size, embed_dim, num_classes): super().__init__() self.embedding nn.Embedding(vocab_size, embed_dim) self.fc nn.Linear(embed_dim, num_classes) def forward(self, x): # 获取词嵌入 embedded self.embedding(x) # [batch_size, seq_len, embed_dim] # 平均池化 pooled torch.mean(embedded, dim1) # [batch_size, embed_dim] # 分类 return self.fc(pooled)训练函数实现def train_model(model, dataloader, epochs3, lr0.001): optimizer torch.optim.Adam(model.parameters(), lrlr) criterion nn.CrossEntropyLoss() model.train() for epoch in range(epochs): total_loss 0 correct 0 count 0 for labels, texts in dataloader: optimizer.zero_grad() outputs model(texts) loss criterion(outputs, labels) loss.backward() optimizer.step() total_loss loss.item() correct (outputs.argmax(1) labels).sum().item() count len(labels) print(fEpoch {epoch1}: Loss{total_loss/count:.4f}, Acc{correct/count:.2%})初始化并训练模型# 初始化模型 model NewsClassifier(vocab_size, 64, len(classes)) # 训练模型 train_model(model, train_loader, epochs3, lr0.001)5. 模型评估与优化训练完成后我们需要评估模型在测试集上的表现def evaluate(model, dataloader): model.eval() correct 0 total 0 with torch.no_grad(): for labels, texts in dataloader: outputs model(texts) correct (outputs.argmax(1) labels).sum().item() total len(labels) return correct / total test_loader torch.utils.data.DataLoader( list(test_dataset), batch_size32, collate_fnpadify ) accuracy evaluate(model, test_loader) print(f测试集准确率: {accuracy:.2%})要进一步提升性能可以考虑以下优化策略调整嵌入维度尝试32、64、128等不同维度使用预训练词向量如GloVe或FastText添加更多层在嵌入层后加入LSTM或Transformer层调整学习率尝试不同的学习率和学习率调度策略# 使用预训练词向量的示例 pretrained_embeds torchtext.vocab.GloVe(name6B, dim100) # 替换模型中的嵌入层 model.embedding.weight.data.copy_(pretrained_embeds.get_vecs_by_tokens(vocab.get_itos()))6. 实际应用与部署训练好的模型可以保存并用于实际预测# 保存模型 torch.save(model.state_dict(), news_classifier.pth) # 加载模型 loaded_model NewsClassifier(vocab_size, 64, len(classes)) loaded_model.load_state_dict(torch.load(news_classifier.pth)) # 预测函数 def predict(text, model, vocab, tokenizer): tokens tokenizer(text) indices vocab.lookup_indices(tokens) tensor torch.LongTensor(indices).unsqueeze(0) model.eval() with torch.no_grad(): output model(tensor) prob torch.softmax(output, dim1) pred_class classes[output.argmax().item()] return pred_class, prob.max().item() # 示例预测 sample_text Apple releases new iPhone with advanced AI features pred_class, confidence predict(sample_text, loaded_model, vocab, tokenizer) print(f预测类别: {pred_class}, 置信度: {confidence:.2%})7. 常见问题与解决方案在实际应用中可能会遇到以下问题内存不足减小批量大小使用更小的嵌入维度考虑使用更高效的优化器如Adagrad过拟合添加Dropout层使用L2正则化增加训练数据性能瓶颈使用DataLoader的num_workers参数并行加载数据考虑使用混合精度训练在GPU上训练# 添加Dropout的改进模型 class ImprovedNewsClassifier(nn.Module): def __init__(self, vocab_size, embed_dim, num_classes): super().__init__() self.embedding nn.Embedding(vocab_size, embed_dim) self.dropout nn.Dropout(0.5) self.fc nn.Linear(embed_dim, num_classes) def forward(self, x): embedded self.embedding(x) pooled torch.mean(embedded, dim1) pooled self.dropout(pooled) return self.fc(pooled)在项目实践中我发现以下几个技巧特别有用使用学习率调度器可以显著提升模型性能早停法(early stopping)能有效防止过拟合模型集成可以进一步提升准确率1-2个百分点

相关文章:

用PyTorch和TorchText搞定AG_NEWS新闻分类:从数据加载到75%准确率的保姆级代码

用PyTorch和TorchText实现AG_NEWS新闻分类:从零到75%准确率的完整指南 当你第一次接触文本分类任务时,可能会被数据处理和模型构建的复杂性吓到。本文将带你用PyTorch和TorchText从零开始构建一个新闻分类器,无需任何先验知识,只需…...

3步解锁百度网盘SVIP特权:macOS用户必备的高速下载解决方案

3步解锁百度网盘SVIP特权:macOS用户必备的高速下载解决方案 【免费下载链接】BaiduNetdiskPlugin-macOS For macOS.百度网盘 破解SVIP、下载速度限制~ 项目地址: https://gitcode.com/gh_mirrors/ba/BaiduNetdiskPlugin-macOS 还在为百度网盘Mac客户端的龟速…...

XUnity.AutoTranslator实战指南:Unity游戏实时翻译解决方案与开发者实践指南

XUnity.AutoTranslator实战指南:Unity游戏实时翻译解决方案与开发者实践指南 【免费下载链接】XUnity.AutoTranslator 项目地址: https://gitcode.com/gh_mirrors/xu/XUnity.AutoTranslator 1. 游戏翻译的核心痛点与技术破局 游戏玩家和开发者常常面临三大…...

ModTheSpire终极指南:杀戮尖塔模组加载器完整使用教程

ModTheSpire终极指南:杀戮尖塔模组加载器完整使用教程 【免费下载链接】ModTheSpire External mod loader for Slay The Spire 项目地址: https://gitcode.com/gh_mirrors/mo/ModTheSpire ModTheSpire是一款专为《杀戮尖塔》设计的强大模组加载器&#xff0c…...

终极指南:如何使用XGP-save-extractor解锁Xbox Game Pass存档迁移自由

终极指南:如何使用XGP-save-extractor解锁Xbox Game Pass存档迁移自由 【免费下载链接】XGP-save-extractor Python script to extract savefiles out of Xbox Game Pass for PC games 项目地址: https://gitcode.com/gh_mirrors/xg/XGP-save-extractor XGP-…...

专业Steam创意工坊下载解决方案:WorkshopDL跨平台多引擎架构指南

专业Steam创意工坊下载解决方案:WorkshopDL跨平台多引擎架构指南 【免费下载链接】WorkshopDL WorkshopDL - The Best Steam Workshop Downloader 项目地址: https://gitcode.com/gh_mirrors/wo/WorkshopDL WorkshopDL是一款专为技术爱好者和进阶用户设计的跨…...

Pixel Language Portal效果展示:实时翻译+st.balloons()庆祝动画+HP状态变化的沉浸式交互录屏

Pixel Language Portal效果展示:实时翻译st.balloons()庆祝动画HP状态变化的沉浸式交互录屏 1. 像素冒险工坊的诞生 在传统翻译工具千篇一律的界面中,Pixel Language Portal(像素语言跨维传送门)带来了全新的视觉冲击。这款基于…...

NVIDIA Profile Inspector终极指南:解锁显卡隐藏性能的完整方案

NVIDIA Profile Inspector终极指南:解锁显卡隐藏性能的完整方案 【免费下载链接】nvidiaProfileInspector 项目地址: https://gitcode.com/gh_mirrors/nv/nvidiaProfileInspector NVIDIA Profile Inspector是一款能够深度访问和修改NVIDIA显卡驱动配置的专业…...

TCC-G15散热控制实战指南:释放Dell游戏本性能潜力

TCC-G15散热控制实战指南:释放Dell游戏本性能潜力 【免费下载链接】tcc-g15 Thermal Control Center for Dell G15 - open source alternative to AWCC 项目地址: https://gitcode.com/gh_mirrors/tc/tcc-g15 一、问题发现:游戏本散热困境的技术根…...

利用快马ai快速构建can协议数据监控工具原型

利用快马AI快速构建CAN协议数据监控工具原型 最近在做一个汽车电子相关的项目,需要监控CAN总线上的数据。作为初学者,我对CAN协议的理解还停留在理论层面,实际开发时发现从零搭建解析工具非常耗时。好在发现了InsCode(快马)平台,…...

图像恢复新基准:从复杂到简约,NAFNet如何重塑设计范式

1. 图像恢复的困境与NAFNet的破局之道 每次看到老照片上的划痕或是手机拍糊的夜景,总让人忍不住想:要是能一键修复该多好。这正是图像恢复技术要解决的问题——让模糊、噪点、压缩失真等受损图像重获新生。但你可能不知道,这个领域正面临着一…...

突破窗口尺寸限制:WindowResizer如何重新定义Windows界面控制

突破窗口尺寸限制:WindowResizer如何重新定义Windows界面控制 【免费下载链接】WindowResizer 一个可以强制调整应用程序窗口大小的工具 项目地址: https://gitcode.com/gh_mirrors/wi/WindowResizer WindowResizer是一款专注于解决Windows窗口尺寸调整难题的…...

GraspNet环境配置与编译问题实战指南

1. GraspNet环境配置避坑指南 第一次接触GraspNet这个3D抓取检测框架时,我花了整整三天时间才把环境配好。现在回想起来,大部分时间都浪费在了一些完全可以避免的问题上。今天我就把这些经验总结出来,帮你少走弯路。 GraspNet对CUDA和cuDNN的…...

基于STM32CubeMX HAL库的RS485半双工通信实战指南

1. RS485通信基础与STM32开发环境搭建 第一次接触RS485通信时,我被它独特的半双工特性深深吸引。想象一下双向单车道的马路,车辆只能单向交替通行,这就是半双工的精髓。相比全双工需要两根数据线的设计,RS485仅用一对双绞线就能实…...

Simulink AUTOSAR实战:从模型信号到RTE接口的完整映射流程解析

Simulink AUTOSAR实战:从模型信号到RTE接口的完整映射流程解析 在汽车电子软件开发领域,AUTOSAR标准已经成为行业通用架构,而Simulink作为模型化开发的主流工具,如何实现两者无缝衔接是每个汽车软件工程师必须掌握的技能。本文将带…...

告别‘白边’!用HBuilderX给你的UniApp应用做个全屏SPA:安卓透明导航栏+iOS安全区域配置详解

全屏SPA美学:UniApp应用透明导航栏与安全区域配置实战指南 当你在手机上打开一个视频应用,最影响沉浸感的往往不是内容本身,而是那些挥之不去的系统UI元素——安卓底部的虚拟导航栏、iOS标志性的"刘海"安全区域。这些设计本意是为…...

开源优化工具提升BT下载速度实战指南

开源优化工具提升BT下载速度实战指南 【免费下载链接】trackerslist Updated list of public BitTorrent trackers 项目地址: https://gitcode.com/GitHub_Trending/tr/trackerslist 在数字资源获取的过程中,许多用户都曾遭遇过BT下载速度缓慢、进度停滞不前…...

卡证检测矫正模型实操手册:解决‘检测不到’‘矫正失真’‘误检多框’三大问题

卡证检测矫正模型实操手册:解决‘检测不到’‘矫正失真’‘误检多框’三大问题 你是不是也遇到过这样的烦恼?拍了一张身份证照片,想用程序自动识别,结果模型告诉你“没找到”;好不容易检测到了,矫正出来的…...

JAVA红娘交友小程序实现原理及开源uniapp代码片段

JAVA红娘交友小程序实现原理后端架构设计基于Spring Boot框架搭建RESTful API服务,采用Maven进行依赖管理。核心模块包括用户认证模块、匹配算法模块、即时通讯模块和数据持久化模块。数据库设计使用MySQL关系型数据库,主要表结构包括:用户表…...

技术指南|USB接口全解析:从Type-A到Type-C的演变与应用

1. USB接口的前世今生:从Type-A到Type-C的进化之路 记得我第一次接触电脑时,那个蓝色的USB接口让我印象深刻。当时只知道它叫"USB",后来才知道那是Type-A接口。20多年过去,USB接口已经经历了翻天覆地的变化。从最初的T…...

数字记忆守护者:WeChatMsg让微信聊天记录成为永恒的时光胶囊

数字记忆守护者:WeChatMsg让微信聊天记录成为永恒的时光胶囊 【免费下载链接】WeChatMsg 提取微信聊天记录,将其导出成HTML、Word、CSV文档永久保存,对聊天记录进行分析生成年度聊天报告 项目地址: https://gitcode.com/GitHub_Trending/we…...

利用SoftEther实现跨平台虚拟私有网络部署指南

1. SoftEther简介与核心优势 如果你正在寻找一款能同时在Windows、Linux、Mac、Android和iOS上运行的虚拟私有网络解决方案,SoftEther绝对值得深入了解。这个源自日本筑波大学的开源项目,经过多年发展已经成为支持协议最全面的跨平台工具之一。我第一次…...

Qwen3-VL-8B在软件测试中的应用:自动生成测试用例与缺陷报告截图分析

Qwen3-VL-8B在软件测试中的应用:自动生成测试用例与缺陷报告截图分析 最近和几个做软件测试的朋友聊天,大家普遍都在吐槽一件事:写测试用例和缺陷报告太费时间了。尤其是现在敏捷开发节奏快,版本迭代频繁,测试人员不仅…...

突破硬件壁垒:开源驱动技术如何解锁跨系统硬件潜能

突破硬件壁垒:开源驱动技术如何解锁跨系统硬件潜能 【免费下载链接】DFRDisplayKm Windows infrastructure support for Apple DFR (Touch Bar) 项目地址: https://gitcode.com/gh_mirrors/df/DFRDisplayKm 副标题:从驱动开发到功能实现——让专属…...

老旧Mac终极重生指南:OpenCore Legacy Patcher完整教程

老旧Mac终极重生指南:OpenCore Legacy Patcher完整教程 【免费下载链接】OpenCore-Legacy-Patcher Experience macOS just like before 项目地址: https://gitcode.com/GitHub_Trending/op/OpenCore-Legacy-Patcher OpenCore Legacy Patcher是一款强大的开源…...

抖音直播回放下载工具全解析:技术原理与跨领域应用指南

抖音直播回放下载工具全解析:技术原理与跨领域应用指南 【免费下载链接】douyin-downloader A practical Douyin downloader for both single-item and profile batch downloads, with progress display, retries, SQLite deduplication, and browser fallback supp…...

解锁TranslucentTB:4种高效实现Windows任务栏透明化的方法

解锁TranslucentTB:4种高效实现Windows任务栏透明化的方法 【免费下载链接】TranslucentTB A lightweight utility that makes the Windows taskbar translucent/transparent. 项目地址: https://gitcode.com/gh_mirrors/tr/TranslucentTB 任务栏作为Windows…...

如何在VMware上运行macOS虚拟机:终极Unlocker完整指南

如何在VMware上运行macOS虚拟机:终极Unlocker完整指南 【免费下载链接】unlocker VMware Workstation macOS 项目地址: https://gitcode.com/gh_mirrors/unloc/unlocker 你是不是一直想在Windows或Linux电脑上体验macOS系统,却被VMware的限制挡在…...

下一代神经机器翻译质量评估框架:COMET的革命性架构与智能评估范式

下一代神经机器翻译质量评估框架:COMET的革命性架构与智能评估范式 【免费下载链接】COMET A Neural Framework for MT Evaluation 项目地址: https://gitcode.com/gh_mirrors/com/COMET COMET(A Neural Framework for MT Evaluation&#xff09…...

DS4Windows进阶指南:让PlayStation手柄在PC平台发挥极致性能

DS4Windows进阶指南:让PlayStation手柄在PC平台发挥极致性能 【免费下载链接】DS4Windows Like those other ds4tools, but sexier 项目地址: https://gitcode.com/gh_mirrors/ds/DS4Windows DS4Windows是一款开源工具,专为解决PlayStation手柄在…...