传知代码-图神经网络长对话理解(论文复现)
代码以及视频讲解
本文所涉及所有资源均在传知代码平台可获取
概述
情感识别是人类对话理解的关键任务。随着多模态数据的概念,如语言、声音和面部表情,任务变得更加具有挑战性。作为典型解决方案,利用全局和局部上下文信息来预测对话中每个单个句子(即话语)的情感标签。具体来说,全局表示可以通过对话级别的跨模态交互建模来捕获。局部表示通常是通过发言者的时间信息或情感转变来推断的,这忽略了话语级别的重要因素。此外,大多数现有方法在统一输入中使用多模态的融合特征,而不利用模态特定的表示。针对这些问题,我们提出了一种名为“关系时序图神经网络与辅助跨模态交互(CORECT)”的新型神经网络框架,它以模态特定的方式有效捕获了对话级别的跨模态交互和话语级别的时序依赖,用于对话理解。大量实验证明了CORECT的有效性,通过在IEMOCAP和CMUMOSEI数据集上取得了多模态ERC任务的最新成果。
模型整体架构
特征提取
文本采用transformerde方式进行编码
音频,视频都采用全连接的方式进行编码
通过添加相应的讲话者嵌入来增强技术增强
关系时序图卷积网络(RT-GCN)
解读:RT-GCN旨在通过利用话语之间以及话语与其模态之间的多模态图来捕获对话中每个话语的局部上下文信息,关系时序图在一个模块中同时实现了上下文信息,与模态之间的信息的传递。对话中情感识别需要跨模态学习到信息,同时也需要学习上下文的信息,整合成一个模块的作用将两部分并行处理,降低模型的复杂程度,降低训练成本,降低训练难度。
建图方式,模态与模态之间有边相连,对话之间有边相连:
建图之后,用图transformer融合不同模态,以及不同语句的信息,得到处理之后特征向量:
两两交叉模态特征交互
跨模态的异质性经常提高了分析人类语言的难度。利用跨模态交互可能有助于揭示跨模态之间的“不对齐”特性和长期依赖关系。受到这一思想的启发(Tsai等人,2019),我们将配对的跨模态特征交互(P-CM)方法设计到我们提出的用于对话理解的框架中。
线性分类器
最后就是根据提取出来的特征进行情感分类了:
代码修改
这是对话中多模态情感识别(视觉,音频,文本)在数据集IEMOCAP目前为止的SOTA。在离线系统已经取得了相当不错的表现。(离线系统的意思是,是一段已经录制好的视频,而不是事实录制如线上开会)
但是却存在一个问题,输入的数据是已经给定的一个视频,分析某一句话的情感状态的时候,论文的方法使用了过去的信息,也使用了未来的信息,这样会在工业界实时应用场景存在一定的问题。
比如在开线上会议,需要检测开会双方的情绪,不可能用未来将要说的话预测现在的情绪。因为未来的话都还没被说话者说出来,此时,就不能参考到未来的语句来预测现在语句的情感信息。但是原文的方法在数据结构图的构建的时候,连接上了未来语句和现在语句的边,用图神经网络学习了之间的关联。
因此,修改建图方式,不考虑未来的情感信息,重新训练网络,得到了还可以接受的效果,精度大概在82%左右,原文的精度在84%左右,2%精度的牺牲解决了是否能实时的问题其实是值得的。
演示效果
核心逻辑
在这里可以粘贴您的核心代码逻辑:
# start#模型核心部分import torch
import torch.nn as nn
import torch.nn.functional as Ffrom .Classifier import Classifier
from .UnimodalEncoder import UnimodalEncoder
from .CrossmodalNet import CrossmodalNet
from .GraphModel import GraphModel
from .functions import multi_concat, feature_packing
import corectlog = corect.utils.get_logger()class CORECT(nn.Module):def __init__(self, args):super(CORECT, self).__init__()self.args = argsself.wp = args.wpself.wf = args.wfself.modalities = args.modalitiesself.n_modals = len(self.modalities)self.use_speaker = args.use_speakerg_dim = args.hidden_sizeh_dim = args.hidden_sizeic_dim = 0if not args.no_gnn:ic_dim = h_dim * self.n_modalsif not args.use_graph_transformer and (args.gcn_conv == "gat_gcn" or args.gcn_conv == "gcn_gat"):ic_dim = ic_dim * 2if args.use_graph_transformer:ic_dim *= args.graph_transformer_nheadsif args.use_crossmodal and self.n_modals > 1:ic_dim += h_dim * self.n_modals * (self.n_modals - 1)if self.args.no_gnn and (not self.args.use_crossmodal or self.n_modals == 1):ic_dim = h_dim * self.n_modalsa_dim = args.dataset_embedding_dims[args.dataset]['a']t_dim = args.dataset_embedding_dims[args.dataset]['t']v_dim = args.dataset_embedding_dims[args.dataset]['v']dataset_label_dict = {"iemocap": {"hap": 0, "sad": 1, "neu": 2, "ang": 3, "exc": 4, "fru": 5},"iemocap_4": {"hap": 0, "sad": 1, "neu": 2, "ang": 3},"mosei": {"Negative": 0, "Positive": 1},}dataset_speaker_dict = {"iemocap": 2,"iemocap_4": 2,"mosei":1,}tag_size = len(dataset_label_dict[args.dataset])self.n_speakers = dataset_speaker_dict[args.dataset]self.wp = args.wpself.wf = args.wfself.device = args.deviceself.encoder = UnimodalEncoder(a_dim, t_dim, v_dim, g_dim, args)self.speaker_embedding = nn.Embedding(self.n_speakers, g_dim)print(f"{args.dataset} speakers: {self.n_speakers}")if not args.no_gnn:self.graph_model = GraphModel(g_dim, h_dim, h_dim, self.device, args)print('CORECT --> Use GNN')if args.use_crossmodal and self.n_modals > 1:self.crossmodal = CrossmodalNet(g_dim, args)print('CORECT --> Use Crossmodal')elif self.n_modals == 1:print('CORECT --> Crossmodal not available when number of modalitiy is 1')self.clf = Classifier(ic_dim, h_dim, tag_size, args)self.rlog = {}def represent(self, data):# Encoding multimodal featurea = data['audio_tensor'] if 'a' in self.modalities else Nonet = data['text_tensor'] if 't' in self.modalities else Nonev = data['visual_tensor'] if 'v' in self.modalities else Nonea, t, v = self.encoder(a, t, v, data['text_len_tensor'])# Speaker embeddingif self.use_speaker:emb = self.speaker_embedding(data['speaker_tensor'])a = a + emb if a != None else Nonet = t + emb if t != None else Nonev = v + emb if v != None else None# Graph constructmultimodal_features = []if a != None:multimodal_features.append(a)if t != None:multimodal_features.append(t)if v != None:multimodal_features.append(v)out_encode = feature_packing(multimodal_features, data['text_len_tensor'])out_encode = multi_concat(out_encode, data['text_len_tensor'], self.n_modals)out = []if not self.args.no_gnn:out_graph = self.graph_model(multimodal_features, data['text_len_tensor'])out.append(out_graph)if self.args.use_crossmodal and self.n_modals > 1:out_cr = self.crossmodal(multimodal_features)out_cr = out_cr.permute(1, 0, 2)lengths = data['text_len_tensor']batch_size = lengths.size(0)cr_feat = []for j in range(batch_size):cur_len = lengths[j].item()cr_feat.append(out_cr[j,:cur_len])cr_feat = torch.cat(cr_feat, dim=0).to(self.device)out.append(cr_feat)if self.args.no_gnn and (not self.args.use_crossmodal or self.n_modals == 1):out = out_encodeelse:out = torch.cat(out, dim=-1)return outdef forward(self, data):graph_out = self.represent(data)out = self.clf(graph_out, data["text_len_tensor"])return outdef get_loss(self, data):graph_out = self.represent(data)loss = self.clf.get_loss(graph_out, data["label_tensor"], data["text_len_tensor"])return lossdef get_log(self):return self.rlog#图神经网络
import torch
import torch.nn as nn
from torch_geometric.nn import RGCNConv, TransformerConvimport corectclass GNN(nn.Module):def __init__(self, g_dim, h1_dim, h2_dim, num_relations, num_modals, args):super(GNN, self).__init__()self.args = argsself.num_modals = num_modalsif args.gcn_conv == "rgcn":print("GNN --> Use RGCN")self.conv1 = RGCNConv(g_dim, h1_dim, num_relations)if args.use_graph_transformer:print("GNN --> Use Graph Transformer")in_dim = h1_dimself.conv2 = TransformerConv(in_dim, h2_dim, heads=args.graph_transformer_nheads, concat=True)self.bn = nn.BatchNorm1d(h2_dim * args.graph_transformer_nheads)def forward(self, node_features, node_type, edge_index, edge_type):if self.args.gcn_conv == "rgcn":x = self.conv1(node_features, edge_index, edge_type)if self.args.use_graph_transformer:x = nn.functional.leaky_relu(self.bn(self.conv2(x, edge_index)))return x
使用方式&部署方式
首先建议安装conda,因为想要复现深度学习的代码,github上不同项目的环境差别太大,同时处理多个项目的时候很麻烦,在这里就不做conda安装的教程了,请自行学习。
安装pytorch:
请到pytorch官网找安装命令,尽量不要直接pip install
https://pytorch.org/get-started/previous-versions/
给大家直接对着我安装版本来下载,因为图神经网络的包版本要求很苛刻,版本对应不上很容易报错:
只要环境配置好了,找到这个文件,里面的代码粘贴到终端运行即可
温馨提示
1.数据集和已训练好的模型都在.md文件中有百度网盘链接,直接下载放到指定文件夹即可
2.注意,训练出来的模型是有硬件要求的,我是用cpu进行训练的,模型只能在cpu跑,如果想在gpu上跑,请进行重新训练
3.如果有朋友希望用苹果的gpu进行训练,虽然现在pytorch框架已经支持mps(mac版本的cuda可以这么理解)训练,但是很遗憾,图神经网络的包还不支持,不过不用担心,这个模型的训练量很小,我全程都是苹果笔记本完成训练的。
源码下载
相关文章:

传知代码-图神经网络长对话理解(论文复现)
代码以及视频讲解 本文所涉及所有资源均在传知代码平台可获取 概述 情感识别是人类对话理解的关键任务。随着多模态数据的概念,如语言、声音和面部表情,任务变得更加具有挑战性。作为典型解决方案,利用全局和局部上下文信息来预测对话中每…...

部署前端项目
常见部署方式有:静态托管服务、服务器部署 1. 静态托管服务 使用平台部署代码,比如 GitHub。 | 创建一个仓库,仓库名一般是 yourGithubName.github.io。 | 将打包后的静态文件文件上传到仓库。 | 在“Settings”(选项࿰…...

使用POI实现Excel文件的读取(超详细)
目录 一 导入poi相关的maven坐标 二 实现创建并且写入文件 2.1实现步骤 2.2实现代码 2.3效果展示 编辑 2.4注意 三 实现从Excel文件中读取数据 3.1实现步骤 3.2实现代码 3.3结果展示 一 导入poi相关的maven坐标 <!-- Apache poi --><dependency><gro…...
Debezium系列之:记录一次数据库某张表部分数据未同步到hive表的原因
Debezium系列之:记录一次数据库某张表部分数据未同步到hive表的原因 一、背景二、查找数据丢失流程三、数据丢失原因四、解决方法一、背景 反馈mysql数据库中某张表的数据没有同步到hive中,现在需要排查定位下原因数据丢失一般常见需求排查的方向: 数据是否采集到hdfs上采集…...

爆破器材期刊
《爆破器材》简介 《爆破器材》自1958年创刊以来,深受广大读者喜爱,是中国兵工学会主办的中央级技术刊物,在国内外公开发行,近几年已发行到10个国家和地区。《爆破器材》杂志被美国著名检索机构《化学文摘》(CA&a…...
Nginx Websocket 协议配置支持
前后分离的 Web 架构应用,在开发环境启动是可以直接连接支持 websocket 协议,因为没有中间件做转发处理。 当我们对前端进行编译后,通过 nginx 反向代理访问时,需要在nginx 配置文件中增加一些特定的头信息,让服务端识…...
【生成式对抗网络】GANs在数据生成、艺术创作,以及在增强现实和虚拟现实中的应用
一、GANs在数据生成中的应用 生成对抗网络(Generative Adversarial Networks, GANs)在数据生成领域具有显著的应用价值。GANs通过生成器(Generator)和判别器(Discriminator)两个相互竞争的神经网络&#x…...
大模型面试(三)
这次是某家公司的一个电话面试,问的过程还比较简单直接。 问:我们在大模型开源项目的应用上遇到了什么困难? 这个。。有两个困难,一个是RAG的优化,一开始RAG是比较慢的,而且召回率不高; 后来…...
pycharm中快捷键汇总
Pycarm指令汇总 Ctrl鼠标 单击,能直接查看其用法 Ctrl/ 快速注释 CtrlC 在pycharm的terminal中可以停止运行, 其他的地方可以复制。 CtrlV 粘贴 CtrlA 全选 CtrlP 查看()中需要填写什么参数 Altenter 自动不补全所需要的库...
TCP/IP协议族结构和协议
TCP/IP协议族是互联网及许多其他网络的基础,它由一系列相互关联的协议组成,用于实现网络通信。TCP/IP协议族采用ARPANET参考模型,大致可以分为四个层次:链路层、网络层、传输层和应用层。每个层次都有特定的协议和功能,确保数据能够从一个网络设备传输到另一个网络设备。 …...
大模型一些概念的理解 - 线性层、前向传播、后向传播
文章目录 前言一、线性层1. 什么是线性层?2. 通俗解释3. 示例 二、前向传播1. 什么是前向传播?2. 通俗解释3. 示例 三、后向传播1. 什么是后向传播?2. 通俗解释3. 具体步骤 四、示例五、在 PyTorch 中的后向传播 前言 最近提问里有问到一些名…...

AWS 云安全性:检测 SSH 暴力攻击
由于开源、低成本、可靠性和灵活性等优势,云基础设施主要由基于linux的机器主导,然而,它们也不能幸免于黑客的攻击,从而影响云的安全性。攻击Linux机器最流行的方法之一是通过SSH通道。 什么是 SSH 安全外壳协议(Sec…...

7.9数据结构
思维导图 作业 doubleloop.h #ifndef __DOUBLELOOP_H__ #define __DOUBLELOOP_H__#include <stdio.h> #include <stdlib.h>typedef int datatype; typedef struct node {union{int len;datatype data;};struct node *pri;//前驱指针struct node *next;//后继指针…...
Python 文件操作:打开数据处理的大门
在 Python 的学习之旅中,文件操作是一个非常实用且必不可少的技能。不论是数据分析还是日常的数据处理,良好的文件操作技巧都能让你的编程之路更加顺畅。今天,我将带你走进 Python 文件操作的世界,不仅教你如何读写文件࿰…...

单对以太网连接器多场景应用
单对以太网连接器应用场景概述 单对以太网(Single Pair Ethernet,简称SPE)作为一种新兴的以太网技术,以其独特的优势在多个领域得到了广泛的应用。SPE通过单对电缆进行数据传输,支持高速数据传输,同时还能…...

Python pip的更新问题
你是否也出现了更新pip的情况 1、提示更新pip版本 pip install --upgrade pip2、更新操作,我操作了 pip install --upgrade pip更新了,等啊等。。。 然后就是连接超时,安装失败 3、我不信,我就要更新,我还要使用镜…...

[Linux][Shell][Shell基础] -- [Shebang][特殊符号][变量][父子Shell]详细讲解
目录 0.前置知识1.Shebang2.Linux特殊符号整理3.变量4.环境变量5.父子shell0.概念1.创建进程列表(创建子shell执行命令) 6.内置命令 vs 外置命令 0.前置知识 #用于注释shell脚本语⾔属于⼀种弱类型语⾔:⽆需声明变量类型,直接定义使⽤shell三剑客&#…...
DS200CVMAG1AEB处理器 控制器 模块
DS200CVMAG1AEB特征: 高性能:采用先进的控制算法和高功率IGBT器件,可提供高电流和精确的运动控制。 高精度:采用高分辨率编码器和位置环路技术,位置精度可达0.1μm,适用于各种精密机械应用,如数…...

阈值分割后配合Connection算子和箭头工具快速知道区域的ID并选择指定区域
代码 dev_close_window () read_image (Image, E:/机器视觉学习/海康视觉平台/二期VM视觉学习/二期VM视觉学习/机器视觉程序/标定相机找圆心和焊头修正相机找圆心之算法软件/标定相机找圆心和焊头修正相机找圆心之算法软件/03 标定相机找圆心/S2/1号机/1.bmp) get_image_size …...

【work】AI八股-神经网络相关
Deep-Learning-Interview-Book/docs/深度学习.md at master amusi/Deep-Learning-Interview-Book GitHub 网上相关总结: 小菜鸡写一写基础深度学习的问题(复制大佬的,自己复习用) - 知乎 (zhihu.com) CV面试问题准备持续更新贴 …...

深度学习在微纳光子学中的应用
深度学习在微纳光子学中的主要应用方向 深度学习与微纳光子学的结合主要集中在以下几个方向: 逆向设计 通过神经网络快速预测微纳结构的光学响应,替代传统耗时的数值模拟方法。例如设计超表面、光子晶体等结构。 特征提取与优化 从复杂的光学数据中自…...

微软PowerBI考试 PL300-选择 Power BI 模型框架【附练习数据】
微软PowerBI考试 PL300-选择 Power BI 模型框架 20 多年来,Microsoft 持续对企业商业智能 (BI) 进行大量投资。 Azure Analysis Services (AAS) 和 SQL Server Analysis Services (SSAS) 基于无数企业使用的成熟的 BI 数据建模技术。 同样的技术也是 Power BI 数据…...
《Playwright:微软的自动化测试工具详解》
Playwright 简介:声明内容来自网络,将内容拼接整理出来的文档 Playwright 是微软开发的自动化测试工具,支持 Chrome、Firefox、Safari 等主流浏览器,提供多语言 API(Python、JavaScript、Java、.NET)。它的特点包括&a…...
FastAPI 教程:从入门到实践
FastAPI 是一个现代、快速(高性能)的 Web 框架,用于构建 API,支持 Python 3.6。它基于标准 Python 类型提示,易于学习且功能强大。以下是一个完整的 FastAPI 入门教程,涵盖从环境搭建到创建并运行一个简单的…...

UE5 学习系列(三)创建和移动物体
这篇博客是该系列的第三篇,是在之前两篇博客的基础上展开,主要介绍如何在操作界面中创建和拖动物体,这篇博客跟随的视频链接如下: B 站视频:s03-创建和移动物体 如果你不打算开之前的博客并且对UE5 比较熟的话按照以…...
Nginx server_name 配置说明
Nginx 是一个高性能的反向代理和负载均衡服务器,其核心配置之一是 server 块中的 server_name 指令。server_name 决定了 Nginx 如何根据客户端请求的 Host 头匹配对应的虚拟主机(Virtual Host)。 1. 简介 Nginx 使用 server_name 指令来确定…...

深入解析C++中的extern关键字:跨文件共享变量与函数的终极指南
🚀 C extern 关键字深度解析:跨文件编程的终极指南 📅 更新时间:2025年6月5日 🏷️ 标签:C | extern关键字 | 多文件编程 | 链接与声明 | 现代C 文章目录 前言🔥一、extern 是什么?&…...
css3笔记 (1) 自用
outline: none 用于移除元素获得焦点时默认的轮廓线 broder:0 用于移除边框 font-size:0 用于设置字体不显示 list-style: none 消除<li> 标签默认样式 margin: xx auto 版心居中 width:100% 通栏 vertical-align 作用于行内元素 / 表格单元格ÿ…...

蓝桥杯3498 01串的熵
问题描述 对于一个长度为 23333333的 01 串, 如果其信息熵为 11625907.5798, 且 0 出现次数比 1 少, 那么这个 01 串中 0 出现了多少次? #include<iostream> #include<cmath> using namespace std;int n 23333333;int main() {//枚举 0 出现的次数//因…...
Python ROS2【机器人中间件框架】 简介
销量过万TEEIS德国护膝夏天用薄款 优惠券冠生园 百花蜂蜜428g 挤压瓶纯蜂蜜巨奇严选 鞋子除臭剂360ml 多芬身体磨砂膏280g健70%-75%酒精消毒棉片湿巾1418cm 80片/袋3袋大包清洁食品用消毒 优惠券AIMORNY52朵红玫瑰永生香皂花同城配送非鲜花七夕情人节生日礼物送女友 热卖妙洁棉…...