CENet及多模态情感计算实战(论文复现)
CENet及多模态情感计算实战(论文复现)
本文所涉及所有资源均在传知代码平台可获取
文章目录
- CENet及多模态情感计算实战(论文复现)
- 概述
- 研究背景
- 主要贡献
- 论文思路
- 主要内容和网络架构
- 数据集介绍
- 性能对比
- 复现过程(重要)
- 演示结果
概述
本文对 “Cross-Modal Enhancement Network for Multimodal Sentiment Analysis” 论文进行讲解和手把手复现教学,解决当下热门的多模态情感计算问题,并展示在MOSI和MOSEI两个数据集上的效果
研究背景
情感分析在人工智能向情感智能发展中起着重要作用。早期的情感分析研究主要集中在分析单模态数据上,包括文本情感分析、图像情感分析、音频情感分析等。然而,人类的情感是通过人体的多种感官来传达的。因此,单模态情感分析忽略了人类情感的多维性。相比之下,多模态情感分析通过结合文本、视觉和音频等多模态数据来推断一个人的情感状态。与单模态情感分析相比,多模态数据包含多样化的情感信息,具有更高的预测精度。目前,多模态情感分析已被广泛应用于视频理解、人机交互、政治活动等领域。近年来,随着互联网和各种多媒体平台的快速发展,通过互联网表达情感的载体和方式也变得越来越多样化。这导致了多媒体数据的快速增长,为多模态情感分析提供了大量的数据源。下图展示了多模态在情感计算任务中的优势。

主要贡献
- 提出了一种跨模态增强网络,通过融入长范围非文本情感语境来增强预训练语言模型中的文本表示;
- 提出一种特征转换策略,通过减小文本模态和非文本模态的初始表示之间的分布差异,促进了不同模态的融合;
- 融合了新的预训练语言模型SentiLARE来提高模型对情感词的提取效率,从而提升对情感计算的准确性。
论文思路
作者提出的跨模态增强网络(CENet)模型通过将视觉和声学信息集成到语言模型中来增强文本表示。在基于transformer的预训练语言模型中嵌入跨模态增强(CE)模块,根据非对齐非文本数据中隐含的长程情感线索增强每个单词的表示。此外,针对声学模态和视觉模态,提出了一种特征转换策略,以减少语言模态和非语言模态的初始表示之间的分布差异,从而促进不同模态的融合。
主要内容和网络架构

通过该图,我们可以看出该模型主要有以下几部分组成:
- 非文本模态特征转化
- 跨模态增强
- 预训练语言模型输出
- 非文本模态转换
针对预训练语言模型,初始文本表示是基于词汇表的单词索引序列,而视觉和听觉的表示则是实值向量序列。为了缩小这些异质模态之间的初始分布差异,进而减少在融合过程中非文本特征和文本特征之间的分布差距,本文提出了一种将非文本向量转换为索引的特征转换策略。这种策略有助于促进文本表征与非文本情感语境的有机融合。
具体而言,特征转换策略利用无监督聚类算法分别构建了“声学词汇表”和“视觉词汇表”。通过查询这些非语言词汇表,可以将原始的非语言特征序列转换为索引序列。下图展示了特征转换过程的具体步骤。考虑到k-means方法具有计算复杂度低和实现简单等优点,作者选择使用k-means算法来学习非语言模态的词汇。

- 跨模态增强模块
本文提出的CE模块旨在将长程视觉和声学信息集成到预训练语言模型中,以增强文本的表示能力。CE模块的核心组件是跨模态嵌入单元,其结构如下图所示。该单元利用跨模态注意力机制捕捉长程非文本情感信息,并生成基于文本的非语言嵌入。嵌入层的参数可学习,用于将经过特征转换策略处理后得到的非文本索引向量映射到高维空间,然后生成文本模态对非文本模态的注意力权重矩阵。

在初始训练阶段,由于语言表征和非语言表征处于不同的特征空间,它们之间的相关性通常较低。因此,注意力权重矩阵中的元素可能较小。为了更有效地学习模型参数,研究者在应用softmax之前使用超参数来缩放这些注意力权重矩阵。
基于注意力权重矩阵,可以生成基于文本的非语言向量。将基于文本的声学嵌入和基于文本的视觉嵌入结合起来,形成非语言增强嵌入。最后,通过整合非语言增强嵌入来更新文本的表示。因此,CE模块的提出旨在为文本提供非语言上下文信息,通过增加非语言增强嵌入来调整文本表示,从而使其在语义上更加准确和丰富。
- 预训练语言模型输出
作者采用SentiLARE作为语言模型,其利用包括词性和单词情感极性在内的语言知识来学习情感感知的语言表示。CE模块被集成到预训练语言模型的第i层中。值得注意的是,任何基于Transformer的预训练语言模型都可以与CE模块集成。下面是作者根据SentiLARE的设置进行的步骤:
- 给定一个单词序列,首先通过Stanford Log-Linear词性(POS)标记器学习其词性序列,并通过SentiwordNet学习单词级情感极性序列。
- 然后,使用预训练语言模型的分词器获取词标索引序列。这个序列作为输入,产生一个初步的增强语言知识表示。
- 更新后的文本表示将作为第(i+1)层的输入,并通过SentiLARE中的剩余层进行处理。
- 每一层的输出将是具有视觉和听觉信息的文本主导的高级情感表示。
- 最后,将这些文本表示输入到分类头中,以获取情感强度。
因此,CE模块通过将非语言增强嵌入集成到预训练语言模型中,有助于生成更富有情感感知的语言表示。这种方法能够在文本表示中有效地整合视觉和听觉信息,从而提升情感分析等任务的性能。
数据集介绍
1. CMU-MOSI: 它是一个多模态数据集,包括文本、视觉和声学模态。它来自Youtube上的93个电影评论视频。这些视频被剪辑成2199个片段。每个片段都标注了[-3,3]范围内的情感强度。该数据集分为三个部分,训练集(1,284段)、验证集(229段)和测试集(686段)。
2. CMU-MOSEI: 它类似于CMU-MOSI,但规模更大。它包含了来自在线视频网站的23,453个注释视频片段,涵盖了250个不同的主题和1000个不同的演讲者。CMU-MOSEI中的样本被标记为[-3,3]范围内的情感强度和6种基本情绪。因此,CMU-MOSEI可用于情感分析和情感识别任务。
性能对比
有下图可以观察到,该论文提出的CENet与其他SOTA模型对比性能有明显提升:

复现过程(重要)
1. 数据集准备
下载MOSI和MOSEI数据集已提取好的特征文件(.pkl)。把它放在"./dataset”目录。2. 下载预训练语言模型
下载SentiLARE语言模型文件,然后将它们放入"/pretrained-model / sentilare_model”目录。3. 下载需要的包
pip install -r requirements.txt4. 搭建CENet模块
利用pytorch框架对CENet模块进行搭建:class CE(nn.Module):def __init__(self, beta_shift_a=0.5, beta_shift_v=0.5, dropout_prob=0.2):super(CE, self).__init__()self.visual_embedding = nn.Embedding(label_size + 1, TEXT_DIM, padding_idx=label_size)self.acoustic_embedding = nn.Embedding(label_size + 1, TEXT_DIM, padding_idx=label_size)self.hv = SelfAttention(TEXT_DIM)self.ha = SelfAttention(TEXT_DIM)self.cat_connect = nn.Linear(2 * TEXT_DIM, TEXT_DIM)def forward(self, text_embedding, visual=None, acoustic=None, visual_ids=None, acoustic_ids=None):visual_ = self.visual_embedding(visual_ids)acoustic_ = self.acoustic_embedding(acoustic_ids)visual_ = self.hv(text_embedding, visual_)acoustic_ = self.ha(text_embedding, acoustic_) visual_acoustic = torch.cat((visual_, acoustic_), dim=-1)shift = self.cat_connect(visual_acoustic)embedding_shift = shift + text_embeddingreturn embedding_shift5. 将CE模块与预训练语言模型融合
class BertEncoder(nn.Module):def __init__(self, config):super(BertEncoder, self).__init__()self.output_attentions = config.output_attentionsself.output_hidden_states = config.output_hidden_statesself.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])self.CE = CE()def forward(self, hidden_states, visual=None, acoustic=None, visual_ids=None, acoustic_ids=None, attention_mask=None, head_mask=None):all_hidden_states = ()all_attentions = ()for i, layer_module in enumerate(self.layer):if self.output_hidden_states:all_hidden_states = all_hidden_states + (hidden_states,)if i == ROBERTA_INJECTION_INDEX:hidden_states = self.CE(hidden_states, visual=visual, acoustic=acoustic, visual_ids=visual_ids, acoustic_ids=acoustic_ids)6. 训练代码编写
定义一个整体的训练过程 train()函数,它负责训练模型多个 epoch,并在每个 epoch 结束后评估模型在验证集和测试集上的性能,并记录相关的指标和损失;并在训练最后一轮输出所有测试集id,true label 和 predicted label。
def train(args,model,train_dataloader,validation_dataloader,test_data_loader,optimizer,scheduler, ):valid_losses = []test_accuracies = []for epoch_i in range(int(args.n_epochs)):train_loss, train_pre, train_label = train_epoch(args, model, train_dataloader, optimizer, scheduler)valid_loss, valid_pre, valid_label = evaluate_epoch(args, model, validation_dataloader)test_loss, test_pre, test_label = evaluate_epoch(args, model, test_data_loader)train_acc, train_mae, train_corr, train_f_score = score_model(train_pre, train_label)test_acc, test_mae, test_corr, test_f_score = score_model(test_pre, test_label)non0_test_acc, _, _, non0_test_f_score = score_model(test_pre, test_label, use_zero=True)valid_acc, valid_mae, valid_corr, valid_f_score = score_model(valid_pre, valid_label)print("epoch:{}, train_loss:{}, train_acc:{}, valid_loss:{}, valid_acc:{}, test_loss:{}, test_acc:{}".format(epoch_i, train_loss, train_acc, valid_loss, valid_acc, test_loss, test_acc))valid_losses.append(valid_loss)test_accuracies.append(test_acc)wandb.log(({"train_loss": train_loss,"valid_loss": valid_loss,"train_acc": train_acc,"train_corr": train_corr,"valid_acc": valid_acc,"valid_corr": valid_corr,"test_loss": test_loss,"test_acc": test_acc,"test_mae": test_mae,"test_corr": test_corr,"test_f_score": test_f_score,"non0_test_acc": non0_test_acc,"non0_test_f_score": non0_test_f_score,"best_valid_loss": min(valid_losses),"best_test_acc": max(test_accuracies),}))# 输出测试集的 id、真实标签和预测标签with torch.no_grad():for step, batch in enumerate(test_data_loader):batch = tuple(t.to(DEVICE) for t in batch)input_ids, visual_ids, acoustic_ids, pos_ids, senti_ids, polarity_ids, visual, acoustic, input_mask, segment_ids, label_ids = batchvisual = torch.squeeze(visual, 1)outputs = model(input_ids,visual,acoustic,visual_ids,acoustic_ids,pos_ids, senti_ids, polarity_ids,token_type_ids=segment_ids,attention_mask=input_mask,labels=None,)logits = outputs[0]logits = logits.detach().cpu().numpy()label_ids = label_ids.detach().cpu().numpy()logits = np.squeeze(logits).tolist()label_ids = np.squeeze(label_ids).tolist()# 假设您从 label_ids 中获取 idsids = [f"sample_{idx}" for idx in range(len(label_ids))] # 这里是示例,您可以根据实际情况生成合适的 ids# 输出所有测试样本的 id、真实标签和预测标签值for i in range(len(ids)):print(f"id: {ids[i]}, true label: {label_ids[i]}, predicted label: {logits[i]}")6. 开始训练+测试
python train.py7. 输出结果
此外,为了方便直观地查看模型性能,我在最后一层训练结束后将所有测试集视频的clip id、真实标签和预测标签依次进行输出;并且结合wandb库自动保存结果可视化;结果在后续章节展示。
演示结果
训练过程
模型性能结果
接下来是我们自己补充的每个测试集的真实标签和预测标签
可视化
文章代码资源点击附件获取
相关文章:
CENet及多模态情感计算实战(论文复现)
CENet及多模态情感计算实战(论文复现) 本文所涉及所有资源均在传知代码平台可获取 文章目录 CENet及多模态情感计算实战(论文复现)概述研究背景主要贡献论文思路主要内容和网络架构数据集介绍性能对比复现过程(重要&am…...
备战秋招60天算法挑战,Day34
题目链接: https://leetcode.cn/problems/coin-change/ 视频题解: https://www.bilibili.com/video/BV1qsvDeHEkg/ LeetCode 322.零钱兑换 题目描述 给你一个整数数组coins,表示不同面额的硬币;以及一个整数amount,表…...
vue实现评论滚动效果
vue插件实现滚动效果 一、安装组件 官网地址:https://chenxuan0000.github.io/vue-seamless-scroll/ 1、vue2安装 npm install vue-seamless-scroll --savevue3安装 npm install vue3-seamless-scroll --save二、组件引入 <template><div v-if"…...
iphone13 不升级IOS使用广电卡
iPhone13的信号📶,15系统刷高版本iPCC,本帖以后不再更新!!! 自从知道可以通过刷iPCC的方式改善信号(不更新iOS大版本的情况下),尝试了各种版本。 我自己用下来总结 - 移动联通48、49、50 &…...
网络地址转换
文章目录 1. NAT使用环境2. NAT的优缺点3. NAT的三种类型4. NAT工作原理5. 配置示例6. 常用排错命令 1. NAT使用环境 需要连接到互联网,但主机没有全局唯一的IP地址;更换的ISP的要求对网络进行重新编址;需要合并两个使用相同编址方案的内联网…...
【python】 @property属性详解 and mysql的sqlalchemy的原生sql
【python】 property属性详解 一文搞懂python中常用的装饰器(classmethod、property、staticmethod、abstractmethod…) sqlalchemy的原生sql from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker# 数据库连接字符串 DATAB…...
西门子WinCC开发笔记(一):winCC西门子组态软件介绍、安装
文为原创文章,转载请注明原文出处 本文章博客地址:https://hpzwl.blog.csdn.net/article/details/142060535 长沙红胖子Qt(长沙创微智科)博文大全:开发技术集合(包含Qt实用技术、树莓派、三维、OpenCV、Op…...
如何在5个步骤中编写更好的ChatGPT提示
ChatGPT是一个风靡全球的生成式人工智能 (AI) 工具。虽然它有可能编造一些东西,但是通过精心设计提示,可以确保获得最佳结果。在这篇文章中,我们将探讨如何做到这一点。 在本文中,我将向你展示如何编写提示,激励驱动C…...
最小堆最大堆
文章目录 最小堆、最大堆最小堆(Min-Heap)最大堆(Max-Heap)堆的主要操作及时间复杂度堆的常见应用堆的数组表示大根堆--堆排序 最小堆、最大堆 最小堆(Min-Heap)和最大堆(Max-Heap)…...
华为 HCIP-Datacom H12-821 题库 (10)
有需要题库的可以看主页置顶 V群进行学习交流 1.缺省情况下,BGP 对等体邻接关系的保持时间是多少秒? A、120 秒 B、60 秒 C、10 秒 D、180 秒 答案:D 解析: BGP 存活消息每隔 60 秒发一次,保持时间“180 秒” 2.缺省…...
如何利用命令模式实现一个手游后端架构?
命令模式的原理解读 命令模式的英文翻译是 Command Design Pattern。在 GoF 的《设计模式》一书中,它是这么定义的: The command pattern encapsulates a request as an object, thereby letting us parameterize other objects with different reques…...
ThreadLocal 释放的方式有哪些
ThreadLocal基础概念:IT-BLOG-CN ThreadLocal是Java中用于在同一个线程中存储和隔离变量的一种机制。通常情况下,我们使用ThreadLocal来存储线程独有的变量,并在任务完成后通过remove方法清理这些变量,以防止内存泄漏。然而&…...
监控-zabbix
1运维监控 是指对计算机系统、网络、服务器等关键IT基础设施进行实时监控,以确保系统的稳定运行和及时发现潜在问题 2老监控框架(不会用但需要知道) Cacti: Cacti是一款基于PHP、MySQL开发的网络流量监测图形分析工具。主要监…...
设计模式 解释器模式(Interpreter Pattern)
文章目录 解释器模式简绍解释器模式的结构优缺点UML图具体代码实现Context 数据实体类,可以包含一些方法Abstract Expression 创建接口方法Terminal Expression 对数据简单处理Non-Terminal Expression 同样实现抽象接口方法Client(客户端) 调…...
Linux echo命令讲解及与重定向符搭配使用方法,tail命令及日志监听方式详解
echo echo具有回声,回响的意思,在linux系统中echo一般可以输出指定字符或用于命令执行 echo命令的用法为 echo 输出字符串 或 echo 命令 若参数为字符串则进行字符串输出,注意若字符串中含空格最好将其用引号括起,防止echo命…...
Linux网络:总结协议拓展
1. TCP/IP四层模型总结 2. 网络协议拓展 DNS协议(地址解析协议) TCP/IP使用IP地址和端口号来确定网络中一台主机的一个程序。 但是这样标定不方便记忆,于是开始引出主机名(字符串),使用hosts文件来描述…...
去除恢复出厂设置中UI文字显示
文章目录 需求场景 一、代码跟踪与分析在线文字搜索RK平台本地源码搜索实际测试验证代码推理 二、实现方案三、延伸知识四、知识总结 需求 需求:去除恢复出厂设置中UI文字显示 场景 Android 相关产品各种方向旋转、强制横竖屏等需求,导致在恢复出厂设…...
《高校教育管理》
《高校教育管理》为中文社会科学引文索引(CSSCI)来源期刊、北大中文核心期刊、RCCSE中国核心学术期刊、人大“复印报刊资料”重要转载来源期刊,是江苏大学主办,中国高等教育管理研究会协办的全国性高等教育管理专业期刊。 ISSN 1…...
全国计算机二级考试C语言篇4——选择题
运算符与表达式 1.赋值的正确写法 赋值操作是一个很常见的操作,但是赋值操作也有一些需要注意的地方。赋值操作是将一个表达式的值赋给一个变量的过程。在C语言中,赋值操作符是""。结合性从右到左,不控制求值顺序。 下面是几种C语言…...
数据结构————哈希表
哈希表(Hash table),也被称为散列表,是一种根据关键值(Key value)而直接进行访问的数据结构。它通过把关键值映射到表中的一个位置来访问记录,从而加快查找的速度。这个映射函数被称为散列函数或…...
接口测试中缓存处理策略
在接口测试中,缓存处理策略是一个关键环节,直接影响测试结果的准确性和可靠性。合理的缓存处理策略能够确保测试环境的一致性,避免因缓存数据导致的测试偏差。以下是接口测试中常见的缓存处理策略及其详细说明: 一、缓存处理的核…...
【第二十一章 SDIO接口(SDIO)】
第二十一章 SDIO接口 目录 第二十一章 SDIO接口(SDIO) 1 SDIO 主要功能 2 SDIO 总线拓扑 3 SDIO 功能描述 3.1 SDIO 适配器 3.2 SDIOAHB 接口 4 卡功能描述 4.1 卡识别模式 4.2 卡复位 4.3 操作电压范围确认 4.4 卡识别过程 4.5 写数据块 4.6 读数据块 4.7 数据流…...
dedecms 织梦自定义表单留言增加ajax验证码功能
增加ajax功能模块,用户不点击提交按钮,只要输入框失去焦点,就会提前提示验证码是否正确。 一,模板上增加验证码 <input name"vdcode"id"vdcode" placeholder"请输入验证码" type"text&quo…...
MMaDA: Multimodal Large Diffusion Language Models
CODE : https://github.com/Gen-Verse/MMaDA Abstract 我们介绍了一种新型的多模态扩散基础模型MMaDA,它被设计用于在文本推理、多模态理解和文本到图像生成等不同领域实现卓越的性能。该方法的特点是三个关键创新:(i) MMaDA采用统一的扩散架构…...
江苏艾立泰跨国资源接力:废料变黄金的绿色供应链革命
在华东塑料包装行业面临限塑令深度调整的背景下,江苏艾立泰以一场跨国资源接力的创新实践,重新定义了绿色供应链的边界。 跨国回收网络:废料变黄金的全球棋局 艾立泰在欧洲、东南亚建立再生塑料回收点,将海外废弃包装箱通过标准…...
tree 树组件大数据卡顿问题优化
问题背景 项目中有用到树组件用来做文件目录,但是由于这个树组件的节点越来越多,导致页面在滚动这个树组件的时候浏览器就很容易卡死。这种问题基本上都是因为dom节点太多,导致的浏览器卡顿,这里很明显就需要用到虚拟列表的技术&…...
Spring数据访问模块设计
前面我们已经完成了IoC和web模块的设计,聪明的码友立马就知道了,该到数据访问模块了,要不就这俩玩个6啊,查库势在必行,至此,它来了。 一、核心设计理念 1、痛点在哪 应用离不开数据(数据库、No…...
Maven 概述、安装、配置、仓库、私服详解
目录 1、Maven 概述 1.1 Maven 的定义 1.2 Maven 解决的问题 1.3 Maven 的核心特性与优势 2、Maven 安装 2.1 下载 Maven 2.2 安装配置 Maven 2.3 测试安装 2.4 修改 Maven 本地仓库的默认路径 3、Maven 配置 3.1 配置本地仓库 3.2 配置 JDK 3.3 IDEA 配置本地 Ma…...
技术栈RabbitMq的介绍和使用
目录 1. 什么是消息队列?2. 消息队列的优点3. RabbitMQ 消息队列概述4. RabbitMQ 安装5. Exchange 四种类型5.1 direct 精准匹配5.2 fanout 广播5.3 topic 正则匹配 6. RabbitMQ 队列模式6.1 简单队列模式6.2 工作队列模式6.3 发布/订阅模式6.4 路由模式6.5 主题模式…...
android13 app的触摸问题定位分析流程
一、知识点 一般来说,触摸问题都是app层面出问题,我们可以在ViewRootImpl.java添加log的方式定位;如果是touchableRegion的计算问题,就会相对比较麻烦了,需要通过adb shell dumpsys input > input.log指令,且通过打印堆栈的方式,逐步定位问题,并找到修改方案。 问题…...



