当前位置: 首页 > news >正文

深度学习实战26-(Pytorch)搭建TextCNN实现多标签文本分类的任务

大家好,我是微学AI,今天给大家介绍一下深度学习实战26-(Pytorch)搭建TextCNN实现多标签文本分类的任务,TextCNN是一种用于文本分类的深度学习模型,它基于卷积神经网络(Convolutional Neural Networks, CNN)实现。TextCNN的主要思想是使用卷积操作从文本中提取有用的特征,并使用这些特征来预测文本的类别。

TextCNN将文本看作是一个一维的时序数据,将每个单词嵌入到一个向量空间中,形成一个词向量序列。然后,TextCNN通过堆叠一些卷积层和池化层来提取关键特征,并将其转换成一个固定大小的向量。最后,该向量将被送到一个全连接层进行分类。TextCNN的优点在于它可以非常有效地捕捉文本中的局部和全局特征,从而提高分类精度。此外,TextCNN的训练速度相对较快,具有较好的可扩展性.

TextCNN做多标签分类

1.库包导入

import os
import re
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, precision_score, recall_score
from collections import Counter

 2.定义参数

max_length = 20
batch_size = 32
embedding_dim = 100
num_filters = 100
filter_sizes = [2, 3, 4]
num_classes = 4
learning_rate = 0.001
num_epochs = 2000
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

3. 数据集处理函数


def load_data(file_path):df = pd.read_csv(file_path,encoding='gbk')texts = df['text'].tolist()labels = df['label'].apply(lambda x: x.split("-")).tolist()return texts, labelsdef preprocess_text(text):text = re.sub(r'[^\w\s]', '', text)return text.strip().lower().split()def build_vocab(texts, max_size=10000):word_counts = Counter()for text in texts:word_counts.update(preprocess_text(text))vocab = {"<PAD>": 0, "<UNK>": 1}for i, (word, count) in enumerate(word_counts.most_common(max_size - 2)):vocab[word] = i + 2return vocabdef encode_text(text, vocab):tokens = preprocess_text(text)return [vocab.get(token, vocab["<UNK>"]) for token in tokens]def pad_text(encoded_text, max_length):return encoded_text[:max_length] + [0] * max(0, max_length - len(encoded_text))def encode_label(labels, label_set):encoded_labels = []for label in labels:encoded_label = [0] * len(label_set)for l in label:if l in label_set:encoded_label[label_set.index(l)] = 1encoded_labels.append(encoded_label)return encoded_labelsclass TextDataset(Dataset):def __init__(self, texts, labels):self.texts = textsself.labels = labelsdef __len__(self):return len(self.texts)def __getitem__(self, index):return torch.tensor(self.texts[index], dtype=torch.long), torch.tensor(self.labels[index], dtype=torch.float32)texts, labels = load_data("data_qa.csv")
vocab = build_vocab(texts)
label_set = ["人工智能", "卷积神经网络", "大数据",'ChatGPT']encoded_texts = [pad_text(encode_text(text, vocab), max_length) for text in texts]
encoded_labels = encode_label(labels, label_set)X_train, X_test, y_train, y_test = train_test_split(encoded_texts, encoded_labels, test_size=0.2, random_state=42)
#print(X_train,y_train)train_dataset = TextDataset(X_train, y_train)
test_dataset = TextDataset(X_test, y_test)train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

数据集样例:

textlabel
人工智能如何影响进出口贸易——基于国家层面数据的实证检验人工智能
生成式人工智能——ChatGPT的变革影响、风险挑战及应对策略人工智能-ChatGPT
人工智能与人的自由全面发展关系探究——基于马克思劳动解放思想人工智能
中学生人工智能技术使用持续性行为意向影响因素研究 人工智能
人工智能技术在航天装备领域应用探讨 人工智能
人工智能赋能教育的伦理省思 人工智能
人工智能的神话:ChatGPT与超越的数字劳动“主体”之辨 人工智能-ChatGPT
人工智能(ChatGPT)对社科类研究生教育的挑战与机遇 人工智能-ChatGPT
人工智能助推教育变革的现实图景——教师对ChatGPT的应对策略分析 人工智能-ChatGPT
智能入场与民主之殇:人工智能时代民主政治的风险与挑战 人工智能
国内人工智能写作的研究现状分析及启示 人工智能
人工智能监管:理论、模式与趋势 人工智能
“新一代人工智能技术ChatGPT的应用与规制”笔谈 人工智能-ChatGPT
ChatGPT新一代人工智能技术发展的经济和社会影响 人工智能-ChatGPT
ChatGPT赋能劳动教育的图景展现及其实践策略 人工智能-ChatGPT
人工智能聊天机器人—基于ChatGPT、Microsoft Bing视角分析 人工智能-ChatGPT
拜登政府对华人工智能产业的打压与中国因应 人工智能
人工智能技术在现代农业机械中的应用研究人工智能
人工智能对中国制造业创新的影响研究—来自工业机器人应用的证据 人工智能
人工智能技术在电子产品设计中的应用人工智能
ChatGPT等智能内容生成与新闻出版业面临的智能变革人工智能-ChatGPT
基于卷积神经网络的农作物智能图像识别分类研究人工智能-卷积神经网络
基于卷积神经网络的图像分类改进方法研究人工智能-卷积神经网络

 

这里设置多标签,用“-”符号隔开多个标签。

4.构建模型

class TextCNN(nn.Module):def __init__(self, vocab_size, embedding_dim, num_filters, filter_sizes, num_classes, dropout=0.5):super(TextCNN, self).__init__()self.embedding = nn.Embedding(vocab_size, embedding_dim)self.convs = nn.ModuleList([nn.Conv2d(1, num_filters, (fs, embedding_dim)) for fs in filter_sizes])self.dropout = nn.Dropout(dropout)self.fc = nn.Linear(num_filters * len(filter_sizes), num_classes)def forward(self, x):x = self.embedding(x)x= x.unsqueeze(1)x = [torch.relu(conv(x)).squeeze(3) for conv in self.convs]x = [torch.max_pool1d(i, i.size(2)).squeeze(2) for i in x]x = torch.cat(x, 1)x = self.dropout(x)logits = self.fc(x)return torch.sigmoid(logits)

5.模型训练

def train_epoch(model, dataloader, criterion, optimizer, device):model.train()running_loss = 0.0correct_preds = 0  # 记录正确预测的数量total_preds = 0  # 记录总的预测数量for inputs, targets in dataloader:inputs, targets = inputs.to(device), targets.to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, targets)loss.backward()optimizer.step()running_loss += loss.item()# 计算正确预测的数量predicted_labels = torch.argmax(outputs, dim=1)targets = torch.argmax(targets, dim=1)correct_preds += (predicted_labels == targets).sum().item()total_preds += len(targets)accuracy = correct_preds / total_preds  # 计算准确率return running_loss / len(dataloader), accuracy  # 返回平均损失和准确率def evaluate(model, dataloader, device):model.eval()preds = []targets = []with torch.no_grad():for inputs, target in dataloader:inputs = inputs.to(device)outputs = model(inputs)preds.extend(outputs.cpu().numpy())targets.extend(target.numpy())return np.array(preds), np.array(targets)def calculate_metrics(preds, targets, threshold=0.5):preds = (preds > threshold).astype(int)f1 = f1_score(targets, preds, average="micro")precision = precision_score(targets, preds, average="micro")recall = recall_score(targets, preds, average="micro")return {"f1": f1, "precision": precision, "recall": recall}model = TextCNN(len(vocab), embedding_dim, num_filters, filter_sizes, num_classes).to(device)
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)for epoch in range(num_epochs):if epoch % 20==0:train_loss,accuracy = train_epoch(model, train_loader, criterion, optimizer, device)print(f"Epoch: {epoch + 1}, Train Loss: {train_loss:.4f}, Train Accuracy: {accuracy:.4f}")preds, targets = evaluate(model, test_loader, device)metrics = calculate_metrics(preds, targets)print(f"Epoch: {epoch + 1}, F1: {metrics['f1']:.4f}, Precision: {metrics['precision']:.4f}, Recall: {metrics['recall']:.4f}")
...
Epoch: 1821, Train Loss: 0.0055, Train Accuracy: 0.8837
Epoch: 1821, F1: 0.9429, Precision: 0.9429, Recall: 0.9429
Epoch: 1841, Train Loss: 0.0064, Train Accuracy: 0.9070
Epoch: 1841, F1: 0.9429, Precision: 0.9429, Recall: 0.9429
Epoch: 1861, Train Loss: 0.0047, Train Accuracy: 0.8837
Epoch: 1861, F1: 0.9429, Precision: 0.9429, Recall: 0.9429
Epoch: 1881, Train Loss: 0.0058, Train Accuracy: 0.8605
Epoch: 1881, F1: 0.9429, Precision: 0.9429, Recall: 0.9429
Epoch: 1901, Train Loss: 0.0064, Train Accuracy: 0.8488
Epoch: 1901, F1: 0.9429, Precision: 0.9429, Recall: 0.9429
Epoch: 1921, Train Loss: 0.0062, Train Accuracy: 0.8140
Epoch: 1921, F1: 0.9429, Precision: 0.9429, Recall: 0.9429
Epoch: 1941, Train Loss: 0.0059, Train Accuracy: 0.8953
Epoch: 1941, F1: 0.9429, Precision: 0.9429, Recall: 0.9429
Epoch: 1961, Train Loss: 0.0053, Train Accuracy: 0.8488
Epoch: 1961, F1: 0.9429, Precision: 0.9429, Recall: 0.9429
Epoch: 1981, Train Loss: 0.0055, Train Accuracy: 0.8488
Epoch: 1981, F1: 0.9429, Precision: 0.9429, Recall: 0.9429

大家可以利用自己的数据集进行训练,按照格式修改即可

相关文章:

深度学习实战26-(Pytorch)搭建TextCNN实现多标签文本分类的任务

大家好&#xff0c;我是微学AI&#xff0c;今天给大家介绍一下深度学习实战26-(Pytorch)搭建TextCNN实现多标签文本分类的任务&#xff0c;TextCNN是一种用于文本分类的深度学习模型&#xff0c;它基于卷积神经网络(Convolutional Neural Networks, CNN)实现。TextCNN的主要思想…...

还在精神内耗?还在焦虑?可以看看这个

作为一个即将毕业的本科生&#xff0c;总是会不由自主的焦虑。因为不考研&#xff0c;所以显得和同学们格格不入&#xff0c;每天都在进行精神内耗&#xff0c;但是我不经意间看到了一个东西-《邓宁克鲁格效应》 上述的四个阶段刻画出了一条典型的“大师养成之路”。但大师毕竟…...

Event Camera (事件相机)

1.传统相机的缺点 1.随着计算机视觉领域的不断发展&#xff0c;目标检测的算法也越来越多样化&#xff0c;特别是近些年深度学习在计算机视觉领域的进步&#xff0c;已经产生了很多优秀的目标检测方法&#xff0c;这些基于帧的方法对于图片的质量有一定的要求&#xff0c;比如合…...

藏经阁(七)有源蜂鸣器和无源蜂鸣器 解析

文章目录 特征区别场景选型实战应用 特征 有源蜂鸣器特征&#xff1a; 又被称为直流蜂鸣器包含了一个多谐振荡器只要额定直流电压可以在两端发出声音具有驱动控制简单价格略高 无源蜂鸣器特征&#xff1a; 又被称为交流蜂鸣器内部没有振荡器需要在两端施加特定频率的方波电…...

配置FTP/TFTP协议的ASPF

在多通道协议和NAT的应用中&#xff0c;ASPF是重要的辅助功能。通过配置ASPF功能&#xff0c;实现内网正常对外提供FTP和TFTP服务&#xff0c;同时还可避免内网用户在访问外网Web服务器时下载危险控件。 组网需求 如图1所示&#xff0c;FW部署在某公司的出口&#xff0c;公司提…...

泛型基本说明

使用传统方法的问题分析 不能对加入到集合ArrayList中的数据类型进行约束&#xff08;不安全&#xff09;遍历的时候&#xff0c;需要进行类型转换&#xff0c;如果集合中的数据量较大&#xff0c;对效率有影响。泛型的好处 编译时&#xff0c;检查添加元素的类型&#xff0c;提…...

干洗店洗鞋下店预约小程序开发多少钱

干洗店小程序是一种便捷的移动应用程序&#xff0c;能够帮助用户快捷、轻松地处理干洗、洗衣和清洗等服务。随着智能手机普及和人们生活节奏的不断加快&#xff0c;越来越多人选择使用干洗店小程序来满足自己的日常衣物清洗需求。那干洗店小程序怎么弄&#xff0c;洗衣预约小程…...

用Python实现批量翻译文档文件

文件名批量翻译需要用到编程语言和相应的翻译 API&#xff0c;下面以 Python 和 Google 翻译 API 为例&#xff0c;介绍具体的实现步骤&#xff1a; 安装必要的 Python 库 使用 Python 代码进行文件名翻译需要先安装两个库&#xff1a;googletrans 和 os。 pip install goog…...

机器视觉公司,在玩一局玩不起的游戏

导语 有个著名咨询公司曾经预测过&#xff1a;未来只有两种公司&#xff0c;是人工智能的和不赚钱的。 它可能没想到&#xff0c;还有第三种——不赚钱的AI公司。 去年我们报道过“正在消失的机器视觉公司”&#xff0c;昔日的“AI 四小龙”&#xff08; 商汤、旷视、云从、依图…...

Zephyr 消息队列

文章目录 简介数据结构k_msgq 定义消息队列发送消息k_msgq_put 接收消息k_msgq_get wait_q 的双重身份清理消息队列k_msgq_cleanup 重置消息队列k_msgq_purge 读取数据k_msgq_peekk_msgq_peek_at 缓冲区容量k_msgq_num_free_getk_msgq_num_used_get 简介 message queue 用于中…...

Jenkins自动化部署实例讲解

文章目录 前言实例讲解基本环境全局工具配置创建任务任务配置源码管理构建步骤&#xff08;Build Steps&#xff09;第一步&#xff1a;调用Maven第二步&#xff1a;执行shell启动容器 后记 前言 你平常在做自己的项目时&#xff0c;是否有过部署项目太麻烦的想法&#xff1f;…...

RK356X 解除UVC摄像头预览分辨率1080P限制

平台 RK3566 Android 11 概述 UVC&#xff1a; USB video class&#xff08;又称为USB video device class or UVC&#xff09;就是USB device class视频产品在不需要安装任何的驱动程序下即插即用&#xff0c;包括摄像头、数字摄影机、模拟视频转换器、电视卡及静态视频相机…...

English Learning - L2-14 英音地道语音语调 重音技巧 2023.04.10 周一

English Learning - L2-14 英音地道语音语调 重音技巧 2023.04.10 周一 课前热身重音日常表达节奏单词全部重读的句子间隔时间非重读单词代词和缩约词助动词声临其境语调预习 课前热身 学习目标 重音 重弱突出&#xff0c;重音突出核心表达的意思 重音是落在重读单词上&…...

3.6 n维随机变量

学习目标&#xff1a; 学习n维随机变量需要掌握一定的数学知识&#xff0c;包括多元微积分、线性代数和概率论等。要学习n维随机变量&#xff0c;我会采取以下步骤&#xff1a; 复习相关的数学知识&#xff1a;首先&#xff0c;我会复习多元微积分、线性代数和概率论的基本知…...

JavaSE学习进阶day06_02 Set集合和Set接口

第二章 Set系列集合和Set接口 Set集合概述&#xff1a;前面学习了Collection集合下的List集合&#xff0c;现在继续学习它的另一个分支&#xff0c;Set集合。 set系列集合的特点&#xff1a; Set接口&#xff1a; java.util.Set接口和java.util.List接口一样&#xff0c;同样…...

基于matlab分析卫星星座对通信链路的干扰

一、前言 此示例说明如何分析从中地球轨道 &#xff08;MEO&#xff09; 中的卫星星座到位于太平洋的地面站的下行链路上的干扰。干扰星座由低地球轨道&#xff08;LEO&#xff09;的40颗卫星组成。此示例确定下行链路闭合的时间、载波噪声加干扰比以及链路裕量。 此示例需要卫…...

Python中的异常——概述和基本语法

Python中的异常——概述和基本语法 摘要&#xff1a;Python中的异常是指在程序运行时发生的错误情况&#xff0c;包括但不限于除数为0、访问未定义变量、数据类型错误等。异常处理机制是Python提供的一种解决这些错误的方法&#xff0c;我们可以使用try/except语句来捕获异常并…...

Tomcat 部署与优化

1. Tomcat概述 Tomcat是Java语言开发的&#xff0c;Tomcat服务器是一个免费的开放源代码的Web应用服务器&#xff0c;是Apache软件基金会的Jakarta项目中的一个核心项目&#xff0c;由Apache、Sun和其他一些公司及个人 共同开发而成。Tomcat属于轻量级应用服务器&#xff0c;在…...

多模态之论文笔记ViLT

文章目录 ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision一. 简介1.1 摘要1.2 文本编码器&#xff0c;图像编码器&#xff0c;特征交互复杂度分析1.2 特征交互方式分析1.3 图像特征提取分析 二. 方法 Vision-and-Language Transformer2.1.方…...

微服务架构下认证和鉴权理解

认证和鉴权 从单体应用到微服务架构&#xff0c;优势很多&#xff0c;但是并不是代表着就没有一点缺点了。 微服务架构&#xff0c;意味着每个服务都是松散耦合的。因此&#xff0c;作为软件工程师和架构师&#xff0c;我们在分布式架构中面临着安全挑战。微服务对外开放的端…...

Learn Claude Code Agent 开发 | 2、插拔式工具系统:扩展功能不修改核心循环

Learn Claude Code Agent 开发 | 2、插拔式工具系统&#xff1a;扩展功能不修改核心循环 整体概述 多工具分发核心实现是基础智能体循环的直接扩展&#xff0c;核心思想就是&#xff1a; “加一个工具, 只加一个 handler” – 循环不用动, 新工具注册进 dispatch map 就行。 …...

HoloPart:当3D模型学会自我解剖,深度学习的“X光眼“如何看透一切

HoloPart&#xff1a;当3D模型学会自我解剖&#xff0c;深度学习的"X光眼"如何看透一切 【免费下载链接】HoloPart Generative 3D Part Amodal Segmentation 项目地址: https://gitcode.com/gh_mirrors/ho/HoloPart 你是否曾对着一个复杂的3D模型感到困惑——…...

PowerPaint-V1 Gradio与VSCode集成开发:图像修复插件开发指南

PowerPaint-V1 Gradio与VSCode集成开发&#xff1a;图像修复插件开发指南 1. 开发环境准备 开始之前&#xff0c;我们需要准备好开发环境。VSCode作为代码编辑器&#xff0c;配合Python环境&#xff0c;可以让你更高效地开发PowerPaint-V1的图像修复插件。 首先确保你的系统…...

Captura视频质量优化终极指南:先降噪后锐化的完美工作流

Captura视频质量优化终极指南&#xff1a;先降噪后锐化的完美工作流 【免费下载链接】Captura Capture Screen, Audio, Cursor, Mouse Clicks and Keystrokes 项目地址: https://gitcode.com/gh_mirrors/ca/Captura Captura是一款功能强大的屏幕录制工具&#xff0c;支持…...

为什么你的Python 3.14 JIT始终未触发?揭开__pycache__/jit_profile.bin隐藏机制与企业级profile引导策略(仅3家头部云厂商公开的冷启动预热方案)

第一章&#xff1a;Python 3.14 JIT 编译器的演进逻辑与企业级定位Python 3.14 引入的原生 JIT&#xff08;Just-In-Time&#xff09;编译器并非对 CPython 的简单性能补丁&#xff0c;而是基于多年运行时分析与生产环境反馈重构的执行引擎。其核心演进逻辑聚焦于“渐进式优化”…...

Laf云平台终极灾备指南:如何实现多区域部署与智能故障转移

Laf云平台终极灾备指南&#xff1a;如何实现多区域部署与智能故障转移 【免费下载链接】laf labring/laf: 是一个用于 PHP 的轻量级 AJAX 库&#xff0c;可以方便地在 PHP 应用中实现 AJAX 通信。适合对 PHP、AJAX 库和想要实现 PHP AJAX 通信的开发者。 项目地址: https://g…...

Meta Manus vs OpenClaw:2026年AI Agent之战,谁才是你的最佳选择?

## 引言2026年AI Agent市场迎来爆发式增长&#xff0c;预计到2034年将达到1400亿美元规模。在这个赛道上&#xff0c;Meta的Manus和开源项目OpenClaw成为最受关注的两大竞争者。本文将深入分析两者的差异&#xff0c;帮助你做出最佳选择。## Meta Manus&#xff1a;巨头的入场#…...

Cloudflare邮件路由的隐藏玩法:一个域名无限别名,管理不同网站注册,再也不怕信息泄露

Cloudflare邮件路由的隐私管理艺术&#xff1a;用无限别名打造数字身份防火墙 在个人信息如同裸奔的数字时代&#xff0c;每次网站注册都是一次隐私赌博。你是否经历过这样的困扰&#xff1f;某个小众论坛注册三个月后&#xff0c;主邮箱突然涌入大量赌博邮件&#xff1b;双十一…...

Tesla Dashcam:3步搞定特斯拉行车记录视频合并的专业工具

Tesla Dashcam&#xff1a;3步搞定特斯拉行车记录视频合并的专业工具 【免费下载链接】tesla_dashcam Convert Tesla dash cam movie files into one movie 项目地址: https://gitcode.com/gh_mirrors/te/tesla_dashcam 还在为特斯拉行车记录仪生成的零散视频文件而烦恼…...

AI生成内容检测新思路:除了红绿词表,我们还能用哪些方法识别ChatGPT写的文章?

AI生成内容检测技术全景&#xff1a;超越红绿词表的七种实战方法 当ChatGPT生成的论文摘要通过学术评审、AI撰写的新闻稿被主流媒体刊发时&#xff0c;内容真实性的边界正在变得模糊。某高校教授最近向我展示了一份学生作业——文笔流畅的哲学论述&#xff0c;最终被证实完全由…...