当前位置: 首页 > article >正文

【transformers.Trainer填坑】在自定义compute_metrics时logits和labels数据维度不一致问题

问题描述

我在使用 transformers.Trainer 训练我的模型时,我自定义了 compute_loss 函数和compute_metrics函数,我的模型是一个简单的二分类模型。

在自定义 compute_loss 时这样写的:

def compute_loss(self, model, inputs, return_outputs=False):"""重写 Trainer.compute_loss:1) 提取字典中的 images, bboxes, locs, labels 等2) 用 vision_encoder 先处理图像,得到特征3) 用下游 model 做预测4) 计算并返回 loss"""# 前向传播outputs, labels = model(**inputs)  # (bz, num_classes), or (bz*num_frames, num_classes)batch_size = inputs['labels'].shape[0]outputs = outputs.squeeze()  # (bz*num_frames)if batch_size == 1:outputs = outputs.unsqueeze(0)# 计算 lossloss = self.loss_func(outputs, labels.float())if self.state.global_step % 10 == 0 and self.state.global_step > 0:# 以50个step为间隔打印pred_probs = torch.sigmoid(outputs)preds = (pred_probs > 0.5).int()logger.info(f"[global_step={self.state.global_step}] preds={preds.tolist()} / labels={labels.tolist()} / loss={loss.item():.4f}")# compute metricaccuracy = accuracy_score(labels.cpu().numpy(), preds.cpu().numpy())precision = precision_score(labels.cpu().numpy(), preds.cpu().numpy())recall = recall_score(labels.cpu().numpy(), preds.cpu().numpy())logger.info(f"[global_step={self.state.global_step}] accuracy={accuracy:.4f} / precision={precision:.4f} / recall={recall:.4f}")# 返回 (loss, outputs) 或者只返回 lossreturn (loss, outputs) if return_outputs else loss

于是就出现了报错,像这样的:

File "/opt/conda/lib/python3.9/site-packages/transformers/trainer.py", line 3754, in predictoutput = eval_loop(File "/opt/conda/lib/python3.9/site-packages/transformers/trainer.py", line 3966, in evaluation_loopmetrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))File "/workspace/train/object_query/train.py", line 281, in compute_metricscorrect_num = preds == labels
ValueError: operands could not be broadcast together with shapes (11720,) (12104,)output = eval_loop(File "/opt/conda/lib/python3.9/site-packages/transformers/trainer.py", line 3966, in evaluation_loopmetrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))File "/workspace/train/object_query/train.py", line 281, in compute_metricscorrect_num = preds == labels
ValueError: operands could not be broadcast together with shapes (11720,) (12104,)

原因

该问题是 transformers.Trainer 内部有一段对outputs的操作造成的:

if isinstance(outputs, dict):logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
else:logits = outputs[1:]

这里当 outputs 不是字典时,会把第一个位置的元素offset掉。

解决

Refer to here
所以,我们应该在返回那里这样写:

return (loss, {"label": outputs}) if return_outputs else loss

相关文章:

【transformers.Trainer填坑】在自定义compute_metrics时logits和labels数据维度不一致问题

问题描述 我在使用 transformers.Trainer 训练我的模型时,我自定义了 compute_loss 函数和compute_metrics函数,我的模型是一个简单的二分类模型。 在自定义 compute_loss 时这样写的: def compute_loss(self, model, inputs, return_outp…...

Django创建超管用户

在 Django 中创建超级用户(superuser)可以通过命令行工具 createsuperuser 完成。以下是具体步骤: 1. 确保已进行数据库迁移 在创建超级用户前,确保已执行数据库迁移: python manage.py migrate 2. 创建超级用户 …...

vue3实战-----集成sass

vue3实战-----集成sass 1.安装2.使用3.全局样式文件中不能使用变量 1.安装 在使用scss之前需要安装sass和sass-loader两个插件。 2.使用 安装好之后就可以在组件中使用scss了。需要加上lang“scss”。 注意:scss中变量用$,less中变量用。 3.全局样式文件中不能使用变量 …...

二分查找sql时间盲注,布尔盲注

目录 一:基础知识引导 数据库:information_schema里面记录着数据库的所有元信息 二,布尔盲注,时间盲注 (1)布尔盲注案例(以sqli-labs第八关为例): (2&am…...

计算机网络-MPLS转发原理

在上一篇关于 MPLS 基础的文章中,我们了解了 MPLS 的基本概念、术语以及它在网络中的重要性。今天,我们将深入探讨 MPLS 转发的原理与流程,帮助大家更好地理解 MPLS 是如何在实际网络中工作的。 一、MPLS 转发概述 MPLS 转发的本质是将数据…...

【设计模式】【行为型模式】职责链模式(Chain of Responsibility)

👋hi,我不是一名外包公司的员工,也不会偷吃茶水间的零食,我的梦想是能写高端CRUD 🔥 2025本人正在沉淀中… 博客更新速度 👍 欢迎点赞、收藏、关注,跟上我的更新节奏 🎵 当你的天空突…...

【H5自适应】高端科技类pbootcms网站模板 – 三级栏目、下载与招聘功能支持

(H5自适应)高端大气的科技类pbootcms网站模板 带三级栏目、下载和招聘功能 后台地址:您的域名/admin.php 后台账号:admin 后台密码:123456 为了提升系统安全,请将后台文件admin.php的文件名修改一下。修改之后,后台…...

【Java 面试 八股文】框架篇

框架篇 1. Spring框架中的单例bean是线程安全的吗?2. 什么是AOP?3. 你们项目中有没有使用到AOP?4. Spring中的事务是如何实现的?5. Spring中事务失效的场景有哪些?6. Spring的bean的生命周期?7. Spring中的…...

原型模式详解(Java)

原型模式(Prototype Pattern),作为一种极具代表性的创建型设计模式,其核心思想在于通过复制,亦即克隆现有的对象,来达成创建新对象的目的,而非依赖传统的构造函数途径。这一模式巧妙地基于现有对…...

TCP拥塞控制机制

TCP拥塞控制机制是TCP协议中至关重要的一部分,用于防止网络出现拥塞,保证网络的高效、稳定运行 拥塞控制的基本概念 拥塞:在计算机网络中,拥塞是指当网络中存在过多的分组时,网络性能下降的现象,如延迟增…...

自动化UI测试 | 什么是测试驱动开发(TDD)和行为驱动开发(BDD)?有何区别?

TDD(测试驱动开发)和BDD(行为驱动开发)是两种独特的软件开发技术,它们在测试的内容和方式上有所不同。尽管名称相似,但服务于不同的目的。 什么是TDD? TDD代表测试驱动开发。它是一个过程&…...

DeepSeek 助力 Vue 开发:打造丝滑的进度条

前言:哈喽,大家好,今天给大家分享一篇文章!并提供具体代码帮助大家深入理解,彻底掌握!创作不易,如果能帮助到大家或者给大家一些灵感和启发,欢迎收藏关注哦 💕 目录 Deep…...

一场始于 Selector Error 的拯救行动:企查查数据采集故障排查记

时间轴呈现事故进程 17:00:开发人员小李正在尝试利用 Python 爬虫从企查查(https://www.qcc.com)抓取公司工商信息。原本一切正常,但突然发现信息采集失败,程序抛出大量选择器错误。17:15:小李发现&#x…...

微信服务号推送消息

这里如果 没有 就需要点新的功能去申请一下 申请成功之后就可以设置模版消息 推送到用户接受的页面是 需要后端调用接口 传递token 发送给客户...

24电子信息类研究生复试面试问题汇总 电子信息类专业知识问题最全!电子信息复试全流程攻略 电子信息考研复试真题汇总

你是不是在为电子信息考研复试焦虑?害怕被老师问到刁钻问题、担心专业面答不上来?别慌!作为复试面试92分逆袭上岸的学姐,今天手把手教你拆解电子信息类复试通关密码!看完这篇,让你面试现场直接开大&#xf…...

嵌入式EasyRTC实时通话支持海思hi3516cv610,编译器arm-v01c02-linux-musleabi-gcc

EasyRTC已经完美支持海思hi3516cv610,编译器arm-v01c02-linux-musleabi-gcc,总体SDK大小控制在680K以内(预计还能压缩100K上下): EasyRTC在hi3516cv610芯片上能双向通话、发送文字以及二进制指令,总体运行…...

如何搭建Wi-Fi CVE漏洞测试环境:详细步骤与设备配置

引言 随着Wi-Fi技术的普及,Wi-Fi网络成为了现代通信的重要组成部分。然而,Wi-Fi网络的安全性始终是一个备受关注的话题。通过漏洞扫描和安全测试,网络管理员可以及早发现并修复Wi-Fi设备中存在的安全隐患。本篇文章将详细介绍如何搭建Wi-Fi …...

sqlalchemy 使用fetchmany 报错 KeyError 或 AttributeError

问题 我遇到的问题是 AttributeError: Could not locate column in row for column xxxx 解决 首先看你定义的模型类是否缺失了相关的字段 Column XXX not found._clould not locate column in row for column-CSDN博客 其次 rows result.fetchmany(1000) for (row,) i…...

计算机视觉中图像的基础认知

一、图像/视频的基本属性 在计算机视觉中,图像和视频的本质是多维数值矩阵。图像或视频数据的一些基本属性。 宽度(W) 和 高度(H) 定义了图像的像素分辨率,单位通常是像素。例如,一张 1920x10…...

Docker Desktop WebAPI《1》

方法1 》》生成 的文档不要动, 》》执行 Container(Dockerfile) 会生成镜像文件和容器 》》生成的镜像和容器 在 Docker Desktop 中可以查看 用VS 的 Container Dockerfile 调试 但把这个调试工工具 停止,WebAPi就不能访问了 …...

ELK安装部署同步mysql数据

ELK 安装部署指南 ELK 是 Elasticsearch、Logstash 和 Kibana 的简称,用于日志收集、存储、分析和可视化。 1. 安装 Elasticsearch Elasticsearch 是一个分布式搜索和分析引擎。 1.1 下载并安装 访问 Elasticsearch 官网 下载最新版本。 解压并安装: tar…...

《OpenCV》——特征提取与匹配方法

特征提取 特征提取是从原始数据中提取出能够代表数据本质特征和关键信息的过程,在很多领域都有广泛应用。原始数据往往包含大量的冗余信息,特征提取的目的是去除这些冗余,提取出最具代表性、最能区分不同类别或模式的特征,从而降…...

Oracle DBA 诊断及统计工具-2

Oracle 数据表空间和索引表空间的资源分配比例总结 在 Oracle 数据库中,数据表空间和索引表空间并没有固定的资源分配比例,其分配需要综合考虑多种因素,以下是详细分析不同场景下的分配建议以及具体的分配思路。 影响分配比例的因素 数据读…...

如何使用DHTMLX Scheduler的拖放功能,在 JS 日程安排日历中创建一组相同的事件

DHTMLX Scheduler 是一个全面的调度解决方案,涵盖了与规划事件相关的广泛需求。假设您在我们的 Scheduler 文档中找不到任何功能,并且希望在我们的 Scheduler 文档中看到您的项目。在这种情况下,很可能可以使用自定义解决方案来实现此类功能。…...

​矩阵元素的“鞍点”​

题意: 一个矩阵元素的“鞍点”是指该位置上的元素值在该行上最大、在该列上最小。 本题要求编写程序,求一个给定的n阶方阵的鞍点。 输入格式: 输入第一行给出一个正整数n(1≤n≤6)。随后n行,每行给出n个整数…...

Qt的isVisible ()函数介绍和判断窗口是否在当前界面显示

1、现象:当Qt的窗口最小化时,isVisible值一定是true,这是正常的。 解释:在Qt中,当你点击窗口的最小化按钮时,Qt内部不会自动调用 hide() 方或 setVisible(false) 来隐藏窗口。相反,它会改变窗口…...

Unity-Mirror网络框架-从入门到精通之LagCompensation示例

文章目录 前言什么是滞后补偿Lag Compensation示例延迟补偿原理ServerCubeClientCubeCapture2DSnapshot3D补充LagCompensation.cs 独立算法滞后补偿器组件注意:算法最小示例前言 在现代游戏开发中,网络功能日益成为提升游戏体验的关键组成部分。本系列文章将为读者提供对Mir…...

Jenkins 通过 Execute Shell 执行 shell 脚本 七

Jenkins 通过 Execute Shell 执行 shell 脚本 七 一、创建 .sh 文件 项目目录下新建 .sh 文件 jenkins-script\shell\ci_android_master.sh添加 Execute Shell 模块 在 Command 中添加 # 获取 .sh 路径 CI_ANDROID_MASTER_PATH"${WORKSPACE}/jenkins-script/shell/…...

PyCharm 批量替换

选择替换的内容 1. 打开全局替换窗口 有两种方式可以打开全局替换窗口: 快捷键方式: 在 Windows 或 Linux 系统下,按下 Ctrl Shift R。在 Mac 系统下,按下 Command Shift R。菜单操作方式:点击菜单栏中的 Edit&…...

Linux-文件基本操作

1.基本概念 文件: 一组相关数据的集合 文件名: 01.sh //文件名 2.linux下的文件类型 b block 块设备文件 eg: 硬盘 c character 字符设备文件 eg: 鼠标,键盘 d directory 目录文件 eg: 文件夹 - regular 常规文件…...