当前位置: 首页 > news >正文

OhemCrossEntropyLoss

1. Ohem Cross Entropy Loss 的定义

OhemCrossEntropyLoss 是一种用于深度学习中目标检测任务的损失函数,它是针对不平衡数据分布和困难样本训练的一种改进版本的交叉熵损失函数。Ohem 表示 “Online Hard Example Mining”,意为在线困难样本挖掘。在目标检测任务中,由于背景类样本通常远远多于目标类样本,导致了数据分布的不平衡问题,而且一些困难的样本对于网络的训练很有挑战性。OhemCrossEntropyLoss 就是为了解决这些问题而设计的。

这个损失函数的核心思想是在训练过程中只选择那些具有较高损失值的困难样本进行梯度更新,从而更加关注于难以分类的样本,有助于网络更好地适应这些样本,提高模型的性能。

数学上,OhemCrossEntropyLoss 的定义可以用以下公式表示:

OhemCrossEntropyLoss = − 1 N ∑ i = 1 N { log ( p target ) if  y target = 1 (目标类样本) log ( 1 − p target ) if  y target = 0 (背景类样本且损失高于阈值) 0 otherwise \text{OhemCrossEntropyLoss} = - \frac{1}{N} \sum_{i=1}^{N} \begin{cases} \text{log}(p_{\text{target}}) & \text{if } y_{\text{target}} = 1 \text{ (目标类样本)} \\ \text{log}(1 - p_{\text{target}}) & \text{if } y_{\text{target}} = 0 \text{ (背景类样本且损失高于阈值)} \\ 0 & \text{otherwise} \end{cases} OhemCrossEntropyLoss=N1i=1N log(ptarget)log(1ptarget)0if ytarget=1 (目标类样本)if ytarget=0 (背景类样本且损失高于阈值)otherwise

其中, N N N 是 Batch 中样本的数量, p target p_{\text{target}} ptarget 是模型预测目标类的概率, y target y_{\text{target}} ytarget 是真实标签(1 表示目标类,0 表示背景类),损失计算根据标签的情况进行不同的处理。背景类样本中损失值高于一个预定义的阈值的样本会被选中进行梯度更新,这样网络更关注于难以分类的样本,有助于提高性能。

需要注意的是,OhemCrossEntropyLoss 需要在训练过程中动态地筛选困难样本,所以相比于传统的交叉熵损失,它的计算相对复杂。但在处理不平衡数据和困难样本时,它能够提升模型的鲁棒性和泛化能力。

2. OHEM 步骤流程

  1. 给 OhemCE Loss 取一个阈值 thresh

    • 那么该像素点的预测概率 > 0.7,则该像素点可以看成是简单样本,不参与损失计算
    • 那么该像素点的预测概率 < 0.7,则该像素点可以看成是困难样本,参与损失计算
  2. 确定忽略的像素点值 lb_ignore:一般我们将背景的值设置为 255,即如果像素点值的大小是 255,那么就不参与损失计算。

  3. 设置最少计算的像素点个数 n_min:至少有 n_num 个像素点参与损失计算(不然网络有可能停止更新了)。

简单来说:OHEM CrossEntropy Loss 的目的是:挖掘困难样本;忽略简单样本

3. 代码实现

import random
import numpy as np
import torch
from torch.autograd import Variable
import torch.nn as nndef setup_seed(seed):torch.manual_seed(seed)torch.cuda.manual_seed_all(seed)np.random.seed(seed)random.seed(seed)class OhemCELoss(nn.Module):def __init__(self, thresh, lb_ignore=255, ignore_simple_sample_factor=16):"""Args:thresh: 阈值,超过该值则被算法简单样本 -> 不参与Loss计算lb_ignore: 忽略的像素值(一般255代表背景), 不参与损失的计算ignore_simple_sample_factor: 忽略简单样本的系数该系数越大,最少计算的像素点个数越少该系数越小,最少计算的像素点个数越多"""super(OhemCELoss, self).__init__()"""这里的 thresh 和 self.thresh 不是一回儿事儿①预测概率 > thresh -> 简单样本①预测概率 < thresh -> 困难样本②损失值 > self.thresh -> 困难样本②损失值 < self.thresh -> 简单①和②其实是一回儿事儿,但 thresh 和 self.thresh 不是一回儿事儿"""self.thresh = -torch.log(input=torch.tensor(thresh, requires_grad=False, dtype=torch.float))self.lb_ignore = lb_ignoreself.criteria = nn.CrossEntropyLoss(ignore_index=lb_ignore, reduction='none')self.ignore_simple_sample_factor = ignore_simple_sample_factor"""reduction 参数用于控制损失的计算方式和输出形式。它有三种可选的取值:1. 'none':当设置为 'none' 时,损失将会逐个样本计算,返回一个与输入张量相同形状的损失张量。这意味着输出的损失张量的形状与输入的标签张量相同,每个位置对应一个样本的损失值。2. 'mean':当设置为 'mean' 时,损失会对逐个样本计算的损失进行求均值,得到一个标量值。即计算所有样本的损失值的平均值。3. 'sum' : 当设置为 'sum'  时,损失会对逐个样本计算的损失进行求和,得到一个标量值。即计算所有样本的损失值的总和。在语义分割任务中,通常使用 ignore_index 参数来忽略某些特定标签,例如背景类别。当计算损失时,将会忽略这些特定标签的损失计算,以避免这些标签对损失的影响。如果设置了 ignore_index 参数,'none' 的 reduction 参数会很有用,因为它可以让你获取每个样本的损失,包括被忽略的样本。总之,reduction 参数允许在计算损失时控制输出形式,以满足不同的需求。"""def forward(self, logits, labels):# 1. 计算 n_min(至少算多少个像素点)n_min = labels[labels != self.lb_ignore].numel() // self.ignore_simple_sample_factor# 2. 使用 CrossEntropy 计算损失, 之后再将其展平loss = self.criteria(logits, labels).view(-1)# 3. 选出所有loss中大于self.thresh的像素点 -> 困难样本loss_hard = loss[loss > self.thresh]# 4. 如果总数小于 n_min, 那么肯定要保证有 n_min 个像素点的 lossif loss_hard.numel() < n_min:loss_hard, _ = loss.topk(n_min)# 5. 如果参与的像素点的个数 > n_min 个,那么这些点都参与计算loss_hard_mean = torch.mean(loss_hard)# 6. 返回损失的均值return loss_hard_meanif __name__ == "__main__":setup_seed(20)# 1. 生成预测值(假设我们有两个样本,每个样本有 3 个类别,高度和宽度均为 4)logits = Variable(torch.randn(2, 3, 4, 4))  # [N, C, H, W], s.t. C <-> num_classes# 2. 生成真实标签(每个样本的标签是一个 4x4 的图像)labels = Variable(torch.randint(low=0, high=3, size=(2, 4, 4)))  # [N, H, W]# 3. 初始化:创建 OhemCELoss 的实例,阈值设置为 0.7ohem_criterion = OhemCELoss(thresh=0.7, lb_ignore=255, ignore_simple_sample_factor=16)# 4. 计算 Ohem 损失loss = ohem_criterion(logits, labels)print(f"Ohem Loss: {loss.item()}")  # Ohem Loss: 1.3310734033584595

知识来源

  1. https://www.bilibili.com/video/BV12841117yo
  2. https://www.bilibili.com/video/BV1Um4y1L753

相关文章:

OhemCrossEntropyLoss

1. Ohem Cross Entropy Loss 的定义 OhemCrossEntropyLoss 是一种用于深度学习中目标检测任务的损失函数&#xff0c;它是针对不平衡数据分布和困难样本训练的一种改进版本的交叉熵损失函数。Ohem 表示 “Online Hard Example Mining”&#xff0c;意为在线困难样本挖掘。在目…...

prometheusalert区分告警到不同钉钉群

方法一 修改告警规则 - alert: cpu使用率大于88%expr: instance:node_cpu_utilization:ratio * 100 > 88for: 5mlabels:severity: criticallevel: 3kind: CpuUsageannotations:summary: "cpu使用率大于85%"description: "主机 {{ $labels.hostname }} 的cp…...

AUTOSAR规范与ECU软件开发(实践篇)3.2 ETAS AUTOSAR系统解决方案介绍(上)

1、ETAS AUTOSAR系统解决方案介绍 博世集团ETAS公司基于其强大的研发实力为用户提供了一套高效、 可靠的AUTOSAR系统解决方案&#xff0c; 该方案覆盖了软件架构设计、 应用层模型设计、 基础软件开发、 软件虚拟验证等各个方面&#xff0c; 如图3.5所示&#xff0c; 其中深色…...

【leetcode】第三章 哈希表part02

454.四数相加II public int fourSumCount(int[] nums1, int[] nums2, int[] nums3, int[] nums4) {HashMap<Integer,Integer> map new HashMap<>();// 统计频率for (int i 0; i < nums1.length; i) {for (int j 0; j < nums2.length; j) {int num nums1…...

【C语言】memset()函数

一.memset()函数简介 我们先来看一下cplusplus.com - The C Resources Network网站上memset()函数的基本信息&#xff1a; 1.函数功能 memset()函数的功能是:将一块内存空间的每个字节都设置为指定的值。 这个函数通常用于初始化一个内存空间&#xff0c;或者清空一个内存空间…...

C++中重载(overload)、重写(override,也叫做“覆盖”)和重定义(redefine,也叫作“隐藏”)的区别?

在C中&#xff0c;允许在同一作用域中的某个函数和运算符指定多个定义&#xff0c;分别称为函数重载和运算符重载。 重载声明是指一个与之前已经在该作用域内声明过的函数或方法具有相同名称的声明&#xff0c;但是它们的参数列表和定义&#xff08;实现&#xff09;不相同。 …...

将非受信数据作为参数传入,可能引起xml 注入,引起数据覆盖,这个问题咋解决

目录 1 解决 1 解决 当将非受信数据作为参数传入时&#xff0c;确实存在XML注入&#xff08;XML Injection&#xff09;的风险&#xff0c;攻击者可以通过构造恶意的XML数据来修改XML文档结构或执行意外的操作。为了解决这个问题&#xff0c;你可以采取以下措施&#xff1a; 输…...

设计模式-简单工厂模式

简单工厂模式又称为静态工厂模式&#xff0c;其实就是根据传入参数创建对应具体类的实例并返回实例对象&#xff0c;这些类通常继承至同一个父类&#xff0c;该模式专门定义了一个类来负责创建其他类的实例。 using System.Collections; using System.Collections.Generic; us…...

Maven框架SpringBootWeb简单入门

一、Maven ★ Maven:是Apache旗下的一个开源项目,是一款用于管理和构建java项目的工具。 官网:https://maven.apache.org/ ★ Maven的作用: 1. 依赖管理:方便快捷的管理项目依赖的资源(jar包),避免版本冲突问题。 2. 统一项目结构:提供标准、统一的项目结构。 …...

关于2023年8月19日PMP认证考试准考信下载通知

各位考生: 为保证参加2023年8月19日PMI项目管理资格认证考试的每位考生都能顺利进入考场参加考试&#xff0c;请完整阅读本通知内容。 一、关于准考信下载 为确保您顺利进入考场参加8月份考试&#xff0c;请及时登录本网站&#xff08;https://event.chinapmp.cn/&#xff09…...

html实现iphone同款开关

一、背景 想实现一个开关的按钮&#xff0c;来触发一些操作&#xff0c;网上找了总感觉看着别扭&#xff0c;忽然想到iphone的开关挺好&#xff0c;搞一个 二、代码实现 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8&qu…...

使用Vue和jsmind如何实现思维导图的历史版本控制和撤销/重做功能?

思维导图是一种流行的知识图谱工具&#xff0c;可以帮助我们更好地组织和理解复杂的思维关系。在开发基于Vue的思维导图应用时&#xff0c;实现历史版本控制和撤销/重做功能是非常有用的。以下为您介绍如何使用Vue和jsmind插件来实现这些功能。 安装依赖 首先&#xff0c;我们…...

【Vue-Router】路由元信息

路由元信息&#xff08;Route Meta Information&#xff09;是在路由配置中为每个路由定义的一组自定义数据。这些数据可以包含任何你希望在路由中传递和使用的信息&#xff0c;比如权限、页面标题、布局设置等。Vue Router 允许你在路由配置中定义元信息&#xff0c;然后在组件…...

vue 控件的四个角设置 父视图position:relative

父视图relative&#xff0c;子视图 absolute <div class"bg1"> <i class"topL"></i> <i class"topR"></i> <i class"bottomL"></i> <i class"bottomR"></i> <di…...

VM中linux虚拟机配置桥接模式(虚拟机与宿主机网络互通)

VM虚拟机配置桥接模式&#xff0c;可以让虚拟机和物理主机一样存在于局域网中&#xff0c;可以和主机相通&#xff0c;和互联网相通&#xff0c;和局域网中其它主机相通。 vmware为我们提供了三种网络工作模式&#xff0c;它们分别是&#xff1a;Bridged&#xff08;桥接模式&…...

7.Eclipse中改变编码方式及解决部分乱码问题

1、改变整个工作空间的编码方式&#xff1a; 点击Window->Preference->General->workplace&#xff0c;然后选择默认编码方式 2、改变某个项目的编码方式&#xff1a; 右键点击项目名->Properties>Resource&#xff0c;然后选择默认编码方式。 问题&#xff…...

grafana 的 ws websocket 连接不上的解决方式

使用了多层的代理方式&#xff0c;一层没有此问题 错误 WebSocket connection to ‘wss://ip地址/grafana01/api/live/ws’ failed: 日志报错 msg“Request Completed” methodGET path/api/live/ws status403 解决方式 # allowed_origins is a comma-separated list of o…...

多环境_部署项目

多环境&#xff1a; 指同一套项目代码在不同的阶段需要根据实际情况来调整配置并且部署到不同的机器上。 为什么需要&#xff1f; 1. 每个环境互不影响 2. 区分不同的阶段&#xff1a;开发 / 测试 / 生产 3. 对项目进行优化&#xff1a; 1. 本地日志级别 2. 精简依赖&a…...

go web框架 gin-gonic源码解读02————router

go web框架 gin-gonic源码解读02————router 本来想先写context&#xff0c;但是发现context能简单讲讲的东西不多&#xff0c;就准备直接和router合在一起讲好了 router是web服务的路由&#xff0c;是指讲来自客户端的http请求与服务器端的处理逻辑或者资源相映射的机制。&…...

【Java后端封装数据】常见后端封装数据的格式,用于返回给前端使用(109)

数据格式一&#xff1a;包装 List Map 返回&#xff0c;常用于数据展示&#xff1b; // Controller&#xff1a;public Result selectRegConfig(RequestBody String param) {try {Map<String, Object> paramMap JsonUtils.readValue(param, Map.class);return Result.su…...

零门槛NAS搭建:WinNAS如何让普通电脑秒变私有云?

一、核心优势&#xff1a;专为Windows用户设计的极简NAS WinNAS由深圳耘想存储科技开发&#xff0c;是一款收费低廉但功能全面的Windows NAS工具&#xff0c;主打“无学习成本部署” 。与其他NAS软件相比&#xff0c;其优势在于&#xff1a; 无需硬件改造&#xff1a;将任意W…...

简易版抽奖活动的设计技术方案

1.前言 本技术方案旨在设计一套完整且可靠的抽奖活动逻辑,确保抽奖活动能够公平、公正、公开地进行,同时满足高并发访问、数据安全存储与高效处理等需求,为用户提供流畅的抽奖体验,助力业务顺利开展。本方案将涵盖抽奖活动的整体架构设计、核心流程逻辑、关键功能实现以及…...

python/java环境配置

环境变量放一起 python&#xff1a; 1.首先下载Python Python下载地址&#xff1a;Download Python | Python.org downloads ---windows -- 64 2.安装Python 下面两个&#xff0c;然后自定义&#xff0c;全选 可以把前4个选上 3.环境配置 1&#xff09;搜高级系统设置 2…...

【SpringBoot】100、SpringBoot中使用自定义注解+AOP实现参数自动解密

在实际项目中,用户注册、登录、修改密码等操作,都涉及到参数传输安全问题。所以我们需要在前端对账户、密码等敏感信息加密传输,在后端接收到数据后能自动解密。 1、引入依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId...

Linux相关概念和易错知识点(42)(TCP的连接管理、可靠性、面临复杂网络的处理)

目录 1.TCP的连接管理机制&#xff08;1&#xff09;三次握手①握手过程②对握手过程的理解 &#xff08;2&#xff09;四次挥手&#xff08;3&#xff09;握手和挥手的触发&#xff08;4&#xff09;状态切换①挥手过程中状态的切换②握手过程中状态的切换 2.TCP的可靠性&…...

Leetcode 3577. Count the Number of Computer Unlocking Permutations

Leetcode 3577. Count the Number of Computer Unlocking Permutations 1. 解题思路2. 代码实现 题目链接&#xff1a;3577. Count the Number of Computer Unlocking Permutations 1. 解题思路 这一题其实就是一个脑筋急转弯&#xff0c;要想要能够将所有的电脑解锁&#x…...

UR 协作机器人「三剑客」:精密轻量担当(UR7e)、全能协作主力(UR12e)、重型任务专家(UR15)

UR协作机器人正以其卓越性能在现代制造业自动化中扮演重要角色。UR7e、UR12e和UR15通过创新技术和精准设计满足了不同行业的多样化需求。其中&#xff0c;UR15以其速度、精度及人工智能准备能力成为自动化领域的重要突破。UR7e和UR12e则在负载规格和市场定位上不断优化&#xf…...

【C++从零实现Json-Rpc框架】第六弹 —— 服务端模块划分

一、项目背景回顾 前五弹完成了Json-Rpc协议解析、请求处理、客户端调用等基础模块搭建。 本弹重点聚焦于服务端的模块划分与架构设计&#xff0c;提升代码结构的可维护性与扩展性。 二、服务端模块设计目标 高内聚低耦合&#xff1a;各模块职责清晰&#xff0c;便于独立开发…...

鸿蒙DevEco Studio HarmonyOS 5跑酷小游戏实现指南

1. 项目概述 本跑酷小游戏基于鸿蒙HarmonyOS 5开发&#xff0c;使用DevEco Studio作为开发工具&#xff0c;采用Java语言实现&#xff0c;包含角色控制、障碍物生成和分数计算系统。 2. 项目结构 /src/main/java/com/example/runner/├── MainAbilitySlice.java // 主界…...

SAP学习笔记 - 开发26 - 前端Fiori开发 OData V2 和 V4 的差异 (Deepseek整理)

上一章用到了V2 的概念&#xff0c;其实 Fiori当中还有 V4&#xff0c;咱们这一章来总结一下 V2 和 V4。 SAP学习笔记 - 开发25 - 前端Fiori开发 Remote OData Service(使用远端Odata服务)&#xff0c;代理中间件&#xff08;ui5-middleware-simpleproxy&#xff09;-CSDN博客…...