当前位置: 首页 > news >正文

F.binary_cross_entropy、nn.BCELoss、nn.BCEWithLogitsLoss与F.kl_div函数详细解读

提示:有关loss损失函数详细解读,并附源码!!!

文章目录

  • 前言
  • 一、F.binary_cross_entropy()函数解读
    • 1.函数表达
    • 2.函数运用
  • 二、nn.BCELoss()函数解读
    • 1.函数表达
    • 2.函数运用
  • 三、nn.BCEWithLogitsLoss()函数解读
    • 1.函数表达
    • 2.函数运用(logit探索)
    • 3.函数运用(pred探索)
  • 四、F.kl_div()函数解读


前言

最近我在构建蒸馏相关模型,我重温了一下交叉熵相关内容,也使用pytorch相关函数接口调用,我将对F.binary_cross_entropy()、nn.BCELoss()与nn.BCEWithLogitsLoss()函数做一个说明,同时也简单介绍相对熵的蒸馏F.kl_div()函数做一个介绍。

一、F.binary_cross_entropy()函数解读

1.函数表达

F.binary_cross_entropy(input: Tensor,  # 预测输入target: Tensor, # 标签weight: Optional[Tensor] = None, # 权重可选项size_average: Optional[bool] = None,  # 可选项,快被弃用了reduce: Optional[bool] = None,reduction: str = "mean",  # 默认均值或求和等形式
) -> Tensor:

该函数实际是交叉熵运算方式,其中input、target与权重有相同维度(batch,),其中表示可以是任何维度。同时,input为模型预测其每个元素取值范围在[0,1]间。

2.函数运用

假设输入input经过sigmoid或softmax等方式将其值转为[0,1]范围预测,target为one-hot标签(也可是教师的软标签形式),其应用代码如下:

import torch
import torch.nn.functional as F
def binary_cross_entropy():input = torch.tensor([[0.5, 1.0, 0.8], [0.2, 0.4, 0.6]])# s = nn.Sigmoid()# pred = s(input)target = torch.tensor([[0, 1.0, 0], [0, 0, 1.0]])weight = torch.tensor([[0.1, 0.9, 0.1],[0.1, 0.1, 0.9]])output_weight = F.binary_cross_entropy(input, target,weight=weight)  # input取值范围[0,1]output = F.binary_cross_entropy(input, target)  # input取值范围[0,1]print('预测数据:',input)print('标签数据:',target)print('\nbinary_cross_entropy-有权重:{}\t无权重:{}\n'.format(output_weight, output))

结果如下:

预测数据: tensor([[0.5000, 1.0000, 0.8000],[0.2000, 0.4000, 0.6000]])
标签数据: tensor([[0., 1., 0.],[0., 0., 1.]])binary_cross_entropy-有权重:0.12723307311534882	无权重:0.5912299752235413

二、nn.BCELoss()函数解读

1.函数表达

torch.nn.BCELoss(weight=None, size_average=None, reduce=None, reduction='mean')

参数说明:
weight :用于样本加权的权重张量。如果给定,则必须是一维张量,大小等于输入张量的大小。默认值为 None。
reduction :指定如何计算损失值。可选值为 ‘none’、‘mean’ 或 ‘sum’。默认值为 ‘mean’。

此为类,是对F.binary_cross_entropy()函数的调用,也是交叉熵运算方式,其中input、target与权重有相同维度(batch,),其中表示可以是任何维度。同时,input为模型预测其每个元素取值范围在[0,1]间。

2.函数运用

假设输入input经过sigmoid或softmax等方式将其值转为[0,1]范围预测,target为one-hot标签(也可是教师的软标签形式),其应用代码如下:

import torch
import torch.nn.functional as F
def bceloss():s = nn.Sigmoid()  # 输出是pred = torch.tensor([[0.5, 1.0, 0.8], [0.2, 0.4, 0.6]])# pred = s(pred)  # 一般会经过sigmoid或softmax方式将其预测转为[0,1]范围的值target = torch.tensor([[0, 1.0, 0], [0, 0, 1.0]])# nn.BCELoss输入的pred与target的形状必须相同,实际是交叉熵计算,target没有限制bce = nn.BCELoss(reduction='mean')  # size_average参数将被遗弃,使用reduction决定后续操作,有mean sumb = bce(pred, target)  # pred元素取值范围是[0,1]之间,否则会报错print('预测数据:',pred)print('标签数据:',target)print('\nbceloss:{}\n'.format(b))

结果如下:

预测数据: tensor([[0.5000, 1.0000, 0.8000],[0.2000, 0.4000, 0.6000]])
标签数据: tensor([[0., 1., 0.],[0., 0., 1.]])bceloss:0.5912299752235413

可以看出该函数与上面无权重运行结果一致,实际是对上一个函数进行了类包装,其计算方式和上面函数完全一样。

三、nn.BCEWithLogitsLoss()函数解读

1.函数表达

torch.nn.BCEWithLogitsLoss(weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None)

参数说明:
weight:用于对每个样本的损失值进行加权。默认值为 None。
reduction:指定如何对每个 batch 的损失值进行降维。可选值为 ‘none’、‘mean’ 和 ‘sum’。默认值为 ‘mean’。
pos_weight:用于对正样本的损失值进行加权。可以用于处理样本不平衡的问题。例如,如果正样本比负样本少很多,可以设置 pos_weight 为一个较大的值,以提高正样本的权重。默认值为 None。

2.函数运用(logit探索)

假设输入input经过sigmoid或softmax等方式将其值转为[0,1]范围预测,target为one-hot标签(也可是教师的软标签形式),其应用代码如下:

import torch
import torch.nn.functional as F
def bce_logit_loss():s = nn.Sigmoid()  # 输出是pred = torch.tensor([[0.5, 1.0, 0.8], [0.2, 0.4, 0.6]])target = torch.tensor([[0, 1.0, 0], [0, 0, 1.0]])bce_logit = nn.BCEWithLogitsLoss(reduction='mean')b_logit = bce_logit(pred, target)  # pred元素取值范围是[0,1]之间,否则会报错pred = s(pred)# nn.BCELoss输入的pred与target的形状必须相同,实际是交叉熵计算,target没有限制bce = nn.BCELoss(reduction='mean')  # size_average参数将被遗弃,使用reduction决定后续操作,有mean sumb = bce(pred, target)  # pred元素取值范围是[0,1]之间,否则会报错print('预测数据:', pred)print('标签数据:', target)print('\nbceloss:{}\t bce_with_logit:{} \n'.format(b, b_logit))

结果如下:

预测数据: tensor([[0.6225, 0.7311, 0.6900],[0.5498, 0.5987, 0.6457]])
标签数据: tensor([[0., 1., 0.],[0., 0., 1.]])bceloss:0.7678468823432922	 bce_with_logit:0.7678468823432922

可以看出,nn.BCELoss只需多一个nn.Sigmoid()得到的结果和nn.BCEWithLogitsLoss是一致的,说明该类只是多了一个logit过程。

3.函数运用(pred探索)

import torch
import torch.nn.functional as F
def bce_logit_loss():pred = torch.tensor([[5, 1, 8.0], [2, 4, 6.0]])target = torch.tensor([[0, 1.0, 0], [0, 0, 1.0]])bce_logit = nn.BCEWithLogitsLoss(reduction='mean')b_logit = bce_logit(pred, target)  # pred元素取值范围是[0,1]之间,否则会报错print('预测数据:', pred)print('标签数据:', target)print(' bce_with_logit:{} \n'.format( b_logit))

结果如下:

预测数据: tensor([[5., 1., 8.],[2., 4., 6.]])
标签数据: tensor([[0., 1., 0.],[0., 0., 1.]])bce_with_logit:3.2446444034576416 

可以看出nn.BCEWithLogitsLoss的输入是可以为实数,它先进行sigmoid处理,将其输入变为[0,1]范围,在进行交叉熵运算,然上面nn.BCELoss与F.binary_cross_entropy则不行。

四、F.kl_div()函数解读

该函数为蒸馏模型使用的函数,我直接给出示列,如下:

def kl_func():logits = torch.tensor([[0.5, 1.0, 0.8], [0.2, 0.4, 0.6]])probs = torch.nn.functional.softmax(logits, dim=1)  # 预测学生模型target_probs = torch.tensor([[0.3, 0.4, 0.3], [0.1, 0.5, 0.4]])  # 教师模型loss = F.kl_div(torch.log(probs), target_probs, reduction='batchmean')print('模型输出数据:', logits)print('预测数据:',probs)print('标签数据:',target_probs)print('\nkl_loss:{}\n'.format(loss))

输出结果:

模型输出数据: tensor([[0.5000, 1.0000, 0.8000],[0.2000, 0.4000, 0.6000]])
预测数据: tensor([[0.2501, 0.4123, 0.3376],[0.2693, 0.3289, 0.4018]])
标签数据: tensor([[0.3000, 0.4000, 0.3000],[0.1000, 0.5000, 0.4000]])kl_loss:0.057796258479356766

参考文章:点击这里

相关文章:

F.binary_cross_entropy、nn.BCELoss、nn.BCEWithLogitsLoss与F.kl_div函数详细解读

提示:有关loss损失函数详细解读,并附源码!!! 文章目录 前言一、F.binary_cross_entropy()函数解读1.函数表达2.函数运用 二、nn.BCELoss()函数解读1.函数表达2.函数运用 三、nn.BCEWithLogitsLoss()函数解读1.函数表达…...

后端接口性能优化分析

👏作者简介:大家好,我是爱吃芝士的土豆倪,24届校招生Java选手,很高兴认识大家📕系列专栏:Spring源码、JUC源码🔥如果感觉博主的文章还不错的话,请👍三连支持&…...

【ceph】ceph集群中使用多路径(Multipath)方法

本站以分享各种运维经验和运维所需要的技能为主 《python零基础入门》:python零基础入门学习 《python运维脚本》: python运维脚本实践 《shell》:shell学习 《terraform》持续更新中:terraform_Aws学习零基础入门到最佳实战 《k8…...

Xshell+Xftp通过代理的方式访问局域网内网服务器

最近在部署项目时遇到只有1台服务器拥有公网ip,其它服务器只有局域网ip,当然其它服务器可以正常访问网络,例如如下模型。之前访问其它几台服务器,都是先通过登录公网IP服务器,然后在Xshell里面执行ssh远程连接&#xf…...

对盒子中的材料进行计数

背景 在做AI算法分析项目的时候,有时候需要我们使用影像分析结合机器学习算法对某些材料盒中的材料进行数目计数,通过自己的分析,给出以下两种解决问题的思路。 1.图像处理方法对材料计数 要使用图像处理方式对盒子中的材料进行数目分析&a…...

科技驱动固定资产管理变革:RFID技术的前沿应用

在当今激烈竞争的商业环境中,企业固定资产管理面临挑战,而RFID技术正以其独特特性和功能性彻底改变资产管理方式。本文将深入探讨RFID技术在固定资产管理中的革命性作用,并解析其应用带来的创新和便利。 RFID技术概述: RFID系统作…...

Django路由层之有名分组和无名分组、反向解析、路由分发、伪静态的概念、名称空间、虚拟环境、Django1和Django2的区别

【1】无名分组 无名分组:就是把正则中小括号里噩匹配到的内容以位置参数的形式传递给视图函数 url(r^test/(\d)$,view.text) get请求的第一种方式: http://127.0.0.1:8000/test/?a1&b2 get请求的第二种方式: http://127.0.0.1:8000/test…...

【nlp】2.5 人名分类器实战项目(对比RNN、LSTM、GRU模型)

人名分类器实战项目 0 项目说明1 案例介绍2 案例步骤2.1 导入必备的工具包2.2 数据预处理2.2.1 获取常用的字符数量2.2.2 国家名种类数和个数2.2.3 读数据到python环境中2.2.4 构建数据源NameClassDataset2.2.5 构建迭代器遍历数据2.3 构建RNN及其变体模型2.3.1 构建RNN模型2.3…...

海康Visionmaster-环境配置:MFC 二次开发环境配置方法

1 新建 MFC 工程,拷贝 DLL:VM\VisionMaster4.0.0\Development\V4.0.0 \ComControl\bin\x64 下的所有拷贝到项目工程输出目录下,如下图所示,项目的输出路径是 Dll 文件夹。 2 通过配置 C目录和链接器的方式配置 VM 环境 2.1 C目录下添加附加…...

利用EXCEL中的VBA对同一文件夹下的多个数据文件进行特定提取

Sub CopyFilesBasedOnCriteria()Dim fso As ObjectDim sourceFolder As StringDim destinationFolder As String 设置源文件夹路径和目标文件夹路径sourceFolder "C:\\test\\全波段模拟_Nimbostratus cloud - 副本"destinationFolder "C:\\Desktop\\MOD02数据…...

FPGA时序约束(七)文献时序约束实验测试

系列文章目录 文章目录 系列文章目录前言文献1:时钟移位LogiclockDesign Partition封装用户编写的程序停掉singletap抓取单端口RAM的数据文献2:SRAM约束前言 之前学习了一些基本时序约束的类别,包括主时钟约束、虚拟时钟约束、输入输出约束、多周期约束等等,但大多都是纸上…...

【数据库开发】DataX开发环境的安装部署(Python、Java)

文章目录 1、简介1.1 DataX简介1.2 DataX功能1.3 支持的数据通道 2、DataX安装配置2.1 DataX2.2 Java2.3 Python 3、DataX Web安装配置3.1 mysql3.2 DataX Web3.2.1 简介3.2.2 架构图3.2.3 依赖环境3.2.4 安装 4、入门使用4.1 DataX自带打印示例测试4.2 DataX生成任务模板文件4…...

Flutter实践一:package组织

1.架构概览 为了降低Flutter工程里lib的复杂度,应尽量拆分一些代码成为独立的package。如图: 我们将通用的组件、领域模型、API、features、存储、repository等抽取成了单独的package。这时lib只剩下多国语言、基本的页面、路由等代码了: 这…...

SpringCloud微服务:Ribbon负载均衡

目录 负载均衡策略: 负载均衡的两种方式: 饥饿加载 1. Ribbon负载均衡规则 规则接口是IRule 默认实现是ZoneAvoidanceRule,根据zone选择服务列表,然后轮询 2.负载均衡自定义方式 代码方式:配置灵活,但修…...

【教程】大气化学在线耦合模式WRF/Chem

查看原文>>>区域气象-大气化学在线耦合模式(WRF/Chem)在大气环境领域实践 随着我国经济快速发展,我国面临着日益严重的大气污染问题。近年来,严重的大气污染问题已经明显影响国计民生,引起政府、学界和人们越…...

GDS 命令的使用 srvctl service TAF application continuity

文档中prim and stdy在同一台机器上,不同机器需要添加address list TAF ENABLED GLOBAL SERVICE in GDS ENVIRONMNET 12C. (Doc ID 2283193.1)​编辑To Bottom In this Document Goal Solution APPLIES TO: Oracle Database - Enterprise Edition - Version 12.1.…...

go 语言之 select

在 Go 语言中&#xff0c;select 是一种用于处理多个通道操作的控制结构。它可以用于在多个通道之间进行非阻塞的选择操作&#xff0c;从而实现并发控制和通信。 select 语句的基本语法如下&#xff1a; go select { case <-channel1:// 当 channel1 可读时执行的代码 cas…...

23款奔驰GLC260L升级小柏林音响 全新15个扬声器

2023年款奔驰GLC260 GLC300升级小柏林之声 3D音效系统 升级小柏林之声音响之后&#xff0c;全车一共有15个喇叭&#xff0c;1台功放&#xff0c;每一首音乐都能在车内掀起激情的音浪&#xff0c;感受纯粹的音乐享受&#xff0c;低频震撼澎湃&#xff0c;让你的心跳与音乐完美契…...

ssh 免密码登录

ssh 免密码登录 1. 原理 1.1 密码登录的通俗解释 把服务器当作一个凤凰社&#xff0c;每次进社公干都需要拿特别的门票入场&#xff0c;门票便是服务器上的账户密码&#xff1b; 1.2 免密登录 对于凤凰社的高级会员&#xff0c;会在社内存储一张高级会员身份&#xff08;id_rsa…...

小程序使用腾讯位置插件获取当前位置

1.小程序后台 设置-第三方设置-插件管理-添加插件 2.进入网站 腾讯位置服务 设置对应的额度 mapPickerPlugin(res) {const key ""; //使用在腾讯位置服务申请的keyconst referer "铅锂运营"; //调用插件的app的名称const category "生活服务,娱…...

C++_核心编程_多态案例二-制作饮品

#include <iostream> #include <string> using namespace std;/*制作饮品的大致流程为&#xff1a;煮水 - 冲泡 - 倒入杯中 - 加入辅料 利用多态技术实现本案例&#xff0c;提供抽象制作饮品基类&#xff0c;提供子类制作咖啡和茶叶*//*基类*/ class AbstractDr…...

【Linux】shell脚本忽略错误继续执行

在 shell 脚本中&#xff0c;可以使用 set -e 命令来设置脚本在遇到错误时退出执行。如果你希望脚本忽略错误并继续执行&#xff0c;可以在脚本开头添加 set e 命令来取消该设置。 举例1 #!/bin/bash# 取消 set -e 的设置 set e# 执行命令&#xff0c;并忽略错误 rm somefile…...

RocketMQ延迟消息机制

两种延迟消息 RocketMQ中提供了两种延迟消息机制 指定固定的延迟级别 通过在Message中设定一个MessageDelayLevel参数&#xff0c;对应18个预设的延迟级别指定时间点的延迟级别 通过在Message中设定一个DeliverTimeMS指定一个Long类型表示的具体时间点。到了时间点后&#xf…...

C++.OpenGL (10/64)基础光照(Basic Lighting)

基础光照(Basic Lighting) 冯氏光照模型(Phong Lighting Model) #mermaid-svg-GLdskXwWINxNGHso {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-GLdskXwWINxNGHso .error-icon{fill:#552222;}#mermaid-svg-GLd…...

JDK 17 新特性

#JDK 17 新特性 /**************** 文本块 *****************/ python/scala中早就支持&#xff0c;不稀奇 String json “”" { “name”: “Java”, “version”: 17 } “”"; /**************** Switch 语句 -> 表达式 *****************/ 挺好的&#xff…...

安宝特方案丨船舶智造的“AR+AI+作业标准化管理解决方案”(装配)

船舶制造装配管理现状&#xff1a;装配工作依赖人工经验&#xff0c;装配工人凭借长期实践积累的操作技巧完成零部件组装。企业通常制定了装配作业指导书&#xff0c;但在实际执行中&#xff0c;工人对指导书的理解和遵循程度参差不齐。 船舶装配过程中的挑战与需求 挑战 (1…...

【分享】推荐一些办公小工具

1、PDF 在线转换 https://smallpdf.com/cn/pdf-tools 推荐理由&#xff1a;大部分的转换软件需要收费&#xff0c;要么功能不齐全&#xff0c;而开会员又用不了几次浪费钱&#xff0c;借用别人的又不安全。 这个网站它不需要登录或下载安装。而且提供的免费功能就能满足日常…...

手机平板能效生态设计指令EU 2023/1670标准解读

手机平板能效生态设计指令EU 2023/1670标准解读 以下是针对欧盟《手机和平板电脑生态设计法规》(EU) 2023/1670 的核心解读&#xff0c;综合法规核心要求、最新修正及企业合规要点&#xff1a; 一、法规背景与目标 生效与强制时间 发布于2023年8月31日&#xff08;OJ公报&…...

十九、【用户管理与权限 - 篇一】后端基础:用户列表与角色模型的初步构建

【用户管理与权限 - 篇一】后端基础:用户列表与角色模型的初步构建 前言准备工作第一部分:回顾 Django 内置的 `User` 模型第二部分:设计并创建 `Role` 和 `UserProfile` 模型第三部分:创建 Serializers第四部分:创建 ViewSets第五部分:注册 API 路由第六部分:后端初步测…...

消防一体化安全管控平台:构建消防“一张图”和APP统一管理

在城市的某个角落&#xff0c;一场突如其来的火灾打破了平静。熊熊烈火迅速蔓延&#xff0c;滚滚浓烟弥漫开来&#xff0c;周围群众的生命财产安全受到严重威胁。就在这千钧一发之际&#xff0c;消防救援队伍迅速行动&#xff0c;而豪越科技消防一体化安全管控平台构建的消防“…...