当前位置: 首页 > news >正文

生动理解深度学习精度提升利器——测试时增强(TTA)

测试时增强(Test-Time Augmentation,TTA)是一种在深度学习模型的测试阶段应用数据增强的技术手段。它是通过对测试样本进行多次随机变换或扰动,产生多个增强的样本,并使用这些样本进行预测的多数投票或平均来得出最终预测结果。

为了直观理解TTA执行的过程,这里我绘制了流程示意图如下所示:

TTA的过程如下:

  1. 数据增强:

    • 在测试时,对每个测试样本应用随机的变换或扰动操作,生成多个增强样本。
    • 常用的数据增强操作包括随机翻转、随机旋转、随机裁剪、随机缩放等。这些操作可以增加样本的多样性,模拟真实世界中的不确定性和变化。
  2. 多次预测:

    • 使用训练好的模型对生成的增强样本进行多次预测。
    • 对于每个增强样本,都会得到一个预测结果。
  3. 预测结果集成:

    • 对多次预测的结果进行集成,常用的集成方式有多数投票和平均。
    • 对于分类任务,多数投票即选择预测结果中出现次数最多的类别作为最终的预测类别。对于回归任务,平均即将多次预测结果进行平均。

接下来针对性地对比分析下使用TTA带来的优点和缺点:

优点:

  • 提高鲁棒性:通过应用数据增强,TTA可以增加样本的多样性和泛化能力,提高模型在面对未见过的输入分布和未知变化时的鲁棒性。
  • 提高准确性:通过多次预测和集成,TTA可以减少预测结果的随机性和偶然误差,提高最终预测结果的稳定性和准确性。
  • 模型评估和排名:TTA可以改变模型预测的不确定性,使得模型评估更可靠,能够更好地对不同模型进行性能排名。

缺点:

  • 计算开销:生成和预测多个增强样本会增加计算量。特别是在大型模型和复杂任务中,可能导致推理时间的显著增加,限制了TTA的实际应用。
  • 可能造成过拟合:对于已包含在训练数据中的变换或扰动,如果在测试时反复应用,可能会导致模型对这些特定样本的过拟合,从而影响模型的泛化能力。

TTA是一种常用的技术手段,通过应用数据增强和集成预测结果,可以提高深度学习模型在测试阶段的性能和鲁棒性。然而,TTA的应用需要平衡计算开销和预测准确性,并谨慎处理可能导致模型过拟合的问题。根据具体任务和需求,可以灵活选择合适的增强操作和集成策略来使用TTA。

下面是demo代码实现,如下所示:

import numpy as np
import torch
import torchvision.transforms as transformsdef test_time_augmentation(model, image, n_augmentations):# 定义数据增强的变换transform = transforms.Compose([transforms.ToTensor(),# 在此添加你需要的任何其他数据增强操作])# 存储多次预测结果的列表predictions = []# 对图像应用多次增强和预测for _ in range(n_augmentations):augmented_image = transform(image)augmented_image = augmented_image.unsqueeze(0)  # 增加一个维度作为批次with torch.no_grad():# 切换模型为评估模式,确保不执行梯度计算model.eval()# 使用增强的图像进行预测output = model(augmented_image)_, predicted = torch.max(output.data, 1)predictions.append(predicted.item())# 执行多数投票并返回最终预测结果final_prediction = np.bincount(predictions).argmax()return final_prediction

在前文鸟类细粒度识别项目实验中测试发现,应用TTA技术后,对应的评估指标上有明显的涨点,但是很明显地可以发现:在整个测试过程中资源消耗增加明显,且耗时显著增长,这也是TTA无法避免的劣势,在对精度要求较高的场景下可以有限考虑引入TTA,但是对于计算时耗要求较高的场景则不推荐使用TTA。

开源社区里面也有一些优秀的实现,这里推荐一个,地址在这里,如下所示:

目前有将近1k的star量,还是蛮不错的。

安装方法如下所示:

pip安装:
pip install ttach源码安装:
pip install git+https://github.com/qubvel/ttach
        Input|           # input batch of images / / /|\ \ \      # apply augmentations (flips, rotation, scale, etc.)| | | | | | |     # pass augmented batches through model| | | | | | |     # reverse transformations for each batch of masks/labels\ \ \ / / /      # merge predictions (mean, max, gmean, etc.)|           # output batch of masks/labelsOutput

目前支持分割、分类、关键点检测三种任务,实例使用如下所示:

Segmentation model wrapping [docstring]:
import ttach as tta
tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(), merge_mode='mean')Classification model wrapping [docstring]:
tta_model = tta.ClassificationTTAWrapper(model, tta.aliases.five_crop_transform())Keypoints model wrapping [docstring]:
tta_model = tta.KeypointsTTAWrapper(model, tta.aliases.flip_transform(), scaled=True)
data transforms 实例实现如下所示:
# defined 2 * 2 * 3 * 3 = 36 augmentations !
transforms = tta.Compose([tta.HorizontalFlip(),tta.Rotate90(angles=[0, 180]),tta.Scale(scales=[1, 2, 4]),tta.Multiply(factors=[0.9, 1, 1.1]),        ]
)tta_model = tta.SegmentationTTAWrapper(model, transforms)

Custom model (multi-input / multi-output)实现如下所示:

# Example how to process ONE batch on images with TTA
# Here `image`/`mask` are 4D tensors (B, C, H, W), `label` is 2D tensor (B, N)for transformer in transforms: # custom transforms or e.g. tta.aliases.d4_transform() # augment imageaugmented_image = transformer.augment_image(image)# pass to modelmodel_output = model(augmented_image, another_input_data)# reverse augmentation for mask and labeldeaug_mask = transformer.deaugment_mask(model_output['mask'])deaug_label = transformer.deaugment_label(model_output['label'])# save resultslabels.append(deaug_mask)masks.append(deaug_label)# reduce results as you want, e.g mean/max/min
label = mean(labels)
mask = mean(masks)

Transforms详情如下所示:

TransformParametersValues
HorizontalFlip--
VerticalFlip--
Rotate90anglesList[0, 90, 180, 270]
Scalescales
interpolation
List[float]
"nearest"/"linear"
Resizesizes
original_size
interpolation
List[Tuple[int, int]]
Tuple[int,int]
"nearest"/"linear"
AddvaluesList[float]
MultiplyfactorsList[float]
FiveCropscrop_height
crop_width
int
int

支持的结果融合方法如下:

mean
gmean (geometric mean)
sum
max
min
tsharpen (temperature sharpen with t=0.5)

相关文章:

生动理解深度学习精度提升利器——测试时增强(TTA)

测试时增强(Test-Time Augmentation,TTA)是一种在深度学习模型的测试阶段应用数据增强的技术手段。它是通过对测试样本进行多次随机变换或扰动,产生多个增强的样本,并使用这些样本进行预测的多数投票或平均来得出最终预…...

Redis基础知识(四):使用redis-cli命令测试状态

文章目录 测试Redis服务是否启动查看Redis数据库运行状态 Redis是一款开源的高性能键值数据库,具有快速、灵活、高效、稳定的特点,广泛应用于互联网领域。在开发过程中,我们需要通过测试Redis的状态来保证其正常运行,这就需要使用…...

【web开发】4、JavaScript与jQuery

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 一、JavaScript与jQuery二、JavaScript常用的基本功能1.插入位置2.注释3.变量4.数组5.滚动字符 三、jQuery常用的基本功能1.引入jQuery2.寻找标签3.val、text、appe…...

关于el-date-picker组件修改输入框以及下拉框的样式

因为业务需求,从element plus直接拿过来的组件样式和整体风格不搭,所以要修改样式,直接deep修改根本不生效,最后才发现el-date-picker组件有一个popper-class属性,通过这个属性我们就能够修改下拉框的样式,…...

JSCPC f ( 期望dp

#include <bits/stdc.h> using namespace std; using VI vector<int>; double dp[2000010]; int n; string s; //可能要特判 b 1的情况 //有 a 个 材料 ,每 b 个 合成一个&#xff0c;俩种方案&#xff0c; //1 . 双倍产出 p //2 . 返还材料 q int a,b; double …...

Django(10)-项目实战-对发布会管理系统进行测试并获取测试覆盖率

在发布会签到系统中使用django开发了发布会签到系统, 本文对该系统进行测试。 django.test django.test是Django框架中的一个模块,提供了用于编写和运行测试的工具和类。 django.test模块包含了一些用于测试的类和函数,如: TestCase:这是一个基类,用于编写Django测试用…...

ABB机器人10106故障报警(维修时间提醒)的处理方法

ABB机器人10106故障报警&#xff08;维修时间提醒&#xff09;的处理方法 故障原因&#xff1a; ABB机器人智能周期保养维护提醒&#xff0c;用于提示用户对机器人进行必要的保养和检修。 处理方法&#xff1a; 完成对应的保养和检修后&#xff0c;要进行一个操作&#xf…...

性能测试 —— 吞吐量和并发量的关系? 有什么区别?

吞吐量&#xff08;Throughput&#xff09;和并发量&#xff08;Concurrency&#xff09;是性能测试中常用的两个指标&#xff0c;它们描述了系统处理能力的不同方面。 吞吐量&#xff08;Throughput&#xff09; 是指系统在单位时间内能够处理的请求数量或事务数量。它常用于…...

Fastjson反序列化漏洞

文章目录 一、概念二、Fastjson-历史漏洞三、漏洞原理四、Fastjson特征五、Fastjson1.2.47漏洞复现1.搭建环境2.漏洞验证&#xff08;利用 dnslog&#xff09;3.漏洞利用1)Fastjson反弹shell2)启动HTTP服务器3)启动LDAP服务4)启动shell反弹监听5)Burp发送反弹shell 一、概念 啥…...

AI 帮我写代码——Amazon CodeWhisperer 初体验

文章作者&#xff1a;游凯超 人工智能的突破和变革正在深刻地改变我们的生活。从智能手机到自动驾驶汽车&#xff0c;AI 的应用已经深入到我们生活的方方面面。而在编程领域&#xff0c;AI 的崭新尝试正在开启一场革命。Amazon CodeWhisperer&#xff0c;作为亚马逊云科技的一款…...

实训笔记9.1

实训笔记9.1 9.1笔记一、项目开发流程一共分为七个阶段1.1 数据产生阶段1.2 数据采集存储阶段1.3 数据清洗预处理阶段1.4 数据统计分析阶段1.5 数据迁移导出阶段1.6 数据可视化阶段 二、项目的数据产生阶段三、项目的数据采集存储阶段四、项目数据清洗预处理的实现4.1 清洗预处…...

汽车SOA架构

文章目录 一、汽车SOA架构的基本概念二、汽车SOA架构的优势三、从设计、开发和测试方面介绍汽车SOA架构四、SOA技术在汽车行业的应用 汽车SOA架构是指汽车软件架构采用面向服务的架构&#xff08;Service-Oriented Architecture&#xff0c;简称SOA&#xff09;的设计模式。SOA…...

L1-017 到底有多二 C++解法

题目 一个整数“犯二的程度”定义为该数字中包含2的个数与其位数的比值。如果这个数是负数&#xff0c;则程度增加0.5倍&#xff1b;如果还是个偶数&#xff0c;则再增加1倍。例如数字-13142223336是个11位数&#xff0c;其中有3个2&#xff0c;并且是负数&#xff0c;也是偶数…...

motionface respeak视频一键对口型

语音驱动视频唇部动作和视频对口型是两项不同的技术&#xff0c;但是它们都涉及到将语音转化为视觉效果。 语音驱动视频唇部动作&#xff08;语音唇同步&#xff09;&#xff1a; 语音驱动视频唇部动作是一种人工智能技术&#xff0c;它可以将语音转化为实时视频唇部动作。这…...

LeetCode——顺时针打印矩形

题目地址 力扣&#xff08;LeetCode&#xff09;官网 - 全球极客挚爱的技术成长平台 题目解析 按照顺时针一次遍历&#xff0c;遍历外外层遍历里层。 代码如下 class Solution { public:vector<int> spiralOrder(vector<vector<int>>& matrix) {if(…...

C语言课程作业

本科期间c语言课程作业代码整理&#xff1a; Josephus链表实现 Josephus 层序遍历树 二叉树的恢复 哈夫曼树 链表的合并 中缀表达式 链接&#xff1a;https://pan.baidu.com/s/1Q7d-LONauNLi7nJS_h0jtw?pwdswit 提取码&#xff1a;swit...

Yolov8魔术师:卷积变体大作战,涨点创新对比实验,提供CVPR2023、ICCV2023等改进方案

&#x1f4a1;&#x1f4a1;&#x1f4a1;本文独家改进&#xff1a;提供各种卷积变体DCNV3、DCNV2、ODConv、SCConv、PConv、DynamicSnakeConvolution、DAT&#xff0c;引入CVPR2023、ICCV2023等改进方案&#xff0c;为Yolov8创新保驾护航&#xff0c;提供各种科研对比实验 &am…...

基于小波神经网络的空气质量预测,基于小波神经网络的PM2.5预测,基于ANN的PM2.5预测

目标 背影 BP神经网络的原理 BP神经网络的定义 BP神经网络的基本结构 BP神经网络的神经元 BP神经网络的激活函数, BP神经网络的传递函数 小波神经网络(以小波基为传递函数的BP神经网络) 代码链接:基于小波神经网络的PM2.5预测,ann神经网络pm2.5预测资源-CSDN文库 https:/…...

Vue / Vue CLI / Vue Router / Vuex / Element UI

Vue Vue是一种流行的JavaScript前端框架&#xff0c;用于构建用户界面 它被设计为易于学习和使用&#xff0c;并且具有响应式的数据绑定和组件化的架构 Vue具有简洁的语法和灵活的功能&#xff0c;可以帮助开发人员构建高效、可扩展的Web应用程序 它也有一个大型的生态系统和活…...

Lesson4-2:OpenCV图像特征提取与描述---Harris和Shi-Tomas算法

学习目标 理解Harris和Shi-Tomasi算法的原理能够利用Harris和Shi-Tomasi进行角点检测 1 Harris角点检测 1.1 原理 H a r r i s Harris Harris角点检测的思想是通过图像的局部的小窗口观察图像&#xff0c;角点的特征是窗口沿任意方向移动都会导致图像灰度的明显变化&#xff…...

相机Camera日志实例分析之二:相机Camx【专业模式开启直方图拍照】单帧流程日志详解

【关注我&#xff0c;后续持续新增专题博文&#xff0c;谢谢&#xff01;&#xff01;&#xff01;】 上一篇我们讲了&#xff1a; 这一篇我们开始讲&#xff1a; 目录 一、场景操作步骤 二、日志基础关键字分级如下 三、场景日志如下&#xff1a; 一、场景操作步骤 操作步…...

Java-41 深入浅出 Spring - 声明式事务的支持 事务配置 XML模式 XML+注解模式

点一下关注吧&#xff01;&#xff01;&#xff01;非常感谢&#xff01;&#xff01;持续更新&#xff01;&#xff01;&#xff01; &#x1f680; AI篇持续更新中&#xff01;&#xff08;长期更新&#xff09; 目前2025年06月05日更新到&#xff1a; AI炼丹日志-28 - Aud…...

SAP学习笔记 - 开发26 - 前端Fiori开发 OData V2 和 V4 的差异 (Deepseek整理)

上一章用到了V2 的概念&#xff0c;其实 Fiori当中还有 V4&#xff0c;咱们这一章来总结一下 V2 和 V4。 SAP学习笔记 - 开发25 - 前端Fiori开发 Remote OData Service(使用远端Odata服务)&#xff0c;代理中间件&#xff08;ui5-middleware-simpleproxy&#xff09;-CSDN博客…...

Reasoning over Uncertain Text by Generative Large Language Models

https://ojs.aaai.org/index.php/AAAI/article/view/34674/36829https://ojs.aaai.org/index.php/AAAI/article/view/34674/36829 1. 概述 文本中的不确定性在许多语境中传达,从日常对话到特定领域的文档(例如医学文档)(Heritage 2013;Landmark、Gulbrandsen 和 Svenevei…...

免费数学几何作图web平台

光锐软件免费数学工具&#xff0c;maths,数学制图&#xff0c;数学作图&#xff0c;几何作图&#xff0c;几何&#xff0c;AR开发,AR教育,增强现实,软件公司,XR,MR,VR,虚拟仿真,虚拟现实,混合现实,教育科技产品,职业模拟培训,高保真VR场景,结构互动课件,元宇宙http://xaglare.c…...

医疗AI模型可解释性编程研究:基于SHAP、LIME与Anchor

1 医疗树模型与可解释人工智能基础 医疗领域的人工智能应用正迅速从理论研究转向临床实践,在这一过程中,模型可解释性已成为确保AI系统被医疗专业人员接受和信任的关键因素。基于树模型的集成算法(如RandomForest、XGBoost、LightGBM)因其卓越的预测性能和相对良好的解释性…...

代理服务器-LVS的3种模式与调度算法

作者介绍&#xff1a;简历上没有一个精通的运维工程师。请点击上方的蓝色《运维小路》关注我&#xff0c;下面的思维导图也是预计更新的内容和当前进度(不定时更新)。 我们上一章介绍了Web服务器&#xff0c;其中以Nginx为主&#xff0c;本章我们来讲解几个代理软件&#xff1a…...

dvwa11——XSS(Reflected)

LOW 分析源码&#xff1a;无过滤 和上一关一样&#xff0c;这一关在输入框内输入&#xff0c;成功回显 <script>alert(relee);</script> MEDIUM 分析源码&#xff0c;是把<script>替换成了空格&#xff0c;但没有禁用大写 改大写即可&#xff0c;注意函数…...

Xcode 16.2 版本 pod init 报错

Xcode 版本升级到 16.2 后&#xff0c;项目执行 pod init 报错&#xff1b; ### Error RuntimeError - PBXGroup attempted to initialize an object with unknown ISA PBXFileSystemSynchronizedRootGroup from attributes: {"isa">"PBXFileSystemSynchron…...

【自然语言处理】大模型时代的数据标注(主动学习)

文章目录 A 论文出处B 背景B.1 背景介绍B.2 问题提出B.3 创新点 C 模型结构D 实验设计E 个人总结 A 论文出处 论文题目&#xff1a;FreeAL: Towards Human-Free Active Learning in the Era of Large Language Models发表情况&#xff1a;2023-EMNLP作者单位&#xff1a;浙江大…...