当前位置: 首页 > news >正文

F.binary_cross_entropy、nn.BCELoss、nn.BCEWithLogitsLoss与F.kl_div函数详细解读

提示:有关loss损失函数详细解读,并附源码!!!

文章目录

  • 前言
  • 一、F.binary_cross_entropy()函数解读
    • 1.函数表达
    • 2.函数运用
  • 二、nn.BCELoss()函数解读
    • 1.函数表达
    • 2.函数运用
  • 三、nn.BCEWithLogitsLoss()函数解读
    • 1.函数表达
    • 2.函数运用(logit探索)
    • 3.函数运用(pred探索)
  • 四、F.kl_div()函数解读


前言

最近我在构建蒸馏相关模型,我重温了一下交叉熵相关内容,也使用pytorch相关函数接口调用,我将对F.binary_cross_entropy()、nn.BCELoss()与nn.BCEWithLogitsLoss()函数做一个说明,同时也简单介绍相对熵的蒸馏F.kl_div()函数做一个介绍。

一、F.binary_cross_entropy()函数解读

1.函数表达

F.binary_cross_entropy(input: Tensor,  # 预测输入target: Tensor, # 标签weight: Optional[Tensor] = None, # 权重可选项size_average: Optional[bool] = None,  # 可选项,快被弃用了reduce: Optional[bool] = None,reduction: str = "mean",  # 默认均值或求和等形式
) -> Tensor:

该函数实际是交叉熵运算方式,其中input、target与权重有相同维度(batch,),其中表示可以是任何维度。同时,input为模型预测其每个元素取值范围在[0,1]间。

2.函数运用

假设输入input经过sigmoid或softmax等方式将其值转为[0,1]范围预测,target为one-hot标签(也可是教师的软标签形式),其应用代码如下:

import torch
import torch.nn.functional as F
def binary_cross_entropy():input = torch.tensor([[0.5, 1.0, 0.8], [0.2, 0.4, 0.6]])# s = nn.Sigmoid()# pred = s(input)target = torch.tensor([[0, 1.0, 0], [0, 0, 1.0]])weight = torch.tensor([[0.1, 0.9, 0.1],[0.1, 0.1, 0.9]])output_weight = F.binary_cross_entropy(input, target,weight=weight)  # input取值范围[0,1]output = F.binary_cross_entropy(input, target)  # input取值范围[0,1]print('预测数据:',input)print('标签数据:',target)print('\nbinary_cross_entropy-有权重:{}\t无权重:{}\n'.format(output_weight, output))

结果如下:

预测数据: tensor([[0.5000, 1.0000, 0.8000],[0.2000, 0.4000, 0.6000]])
标签数据: tensor([[0., 1., 0.],[0., 0., 1.]])binary_cross_entropy-有权重:0.12723307311534882	无权重:0.5912299752235413

二、nn.BCELoss()函数解读

1.函数表达

torch.nn.BCELoss(weight=None, size_average=None, reduce=None, reduction='mean')

参数说明:
weight :用于样本加权的权重张量。如果给定,则必须是一维张量,大小等于输入张量的大小。默认值为 None。
reduction :指定如何计算损失值。可选值为 ‘none’、‘mean’ 或 ‘sum’。默认值为 ‘mean’。

此为类,是对F.binary_cross_entropy()函数的调用,也是交叉熵运算方式,其中input、target与权重有相同维度(batch,),其中表示可以是任何维度。同时,input为模型预测其每个元素取值范围在[0,1]间。

2.函数运用

假设输入input经过sigmoid或softmax等方式将其值转为[0,1]范围预测,target为one-hot标签(也可是教师的软标签形式),其应用代码如下:

import torch
import torch.nn.functional as F
def bceloss():s = nn.Sigmoid()  # 输出是pred = torch.tensor([[0.5, 1.0, 0.8], [0.2, 0.4, 0.6]])# pred = s(pred)  # 一般会经过sigmoid或softmax方式将其预测转为[0,1]范围的值target = torch.tensor([[0, 1.0, 0], [0, 0, 1.0]])# nn.BCELoss输入的pred与target的形状必须相同,实际是交叉熵计算,target没有限制bce = nn.BCELoss(reduction='mean')  # size_average参数将被遗弃,使用reduction决定后续操作,有mean sumb = bce(pred, target)  # pred元素取值范围是[0,1]之间,否则会报错print('预测数据:',pred)print('标签数据:',target)print('\nbceloss:{}\n'.format(b))

结果如下:

预测数据: tensor([[0.5000, 1.0000, 0.8000],[0.2000, 0.4000, 0.6000]])
标签数据: tensor([[0., 1., 0.],[0., 0., 1.]])bceloss:0.5912299752235413

可以看出该函数与上面无权重运行结果一致,实际是对上一个函数进行了类包装,其计算方式和上面函数完全一样。

三、nn.BCEWithLogitsLoss()函数解读

1.函数表达

torch.nn.BCEWithLogitsLoss(weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None)

参数说明:
weight:用于对每个样本的损失值进行加权。默认值为 None。
reduction:指定如何对每个 batch 的损失值进行降维。可选值为 ‘none’、‘mean’ 和 ‘sum’。默认值为 ‘mean’。
pos_weight:用于对正样本的损失值进行加权。可以用于处理样本不平衡的问题。例如,如果正样本比负样本少很多,可以设置 pos_weight 为一个较大的值,以提高正样本的权重。默认值为 None。

2.函数运用(logit探索)

假设输入input经过sigmoid或softmax等方式将其值转为[0,1]范围预测,target为one-hot标签(也可是教师的软标签形式),其应用代码如下:

import torch
import torch.nn.functional as F
def bce_logit_loss():s = nn.Sigmoid()  # 输出是pred = torch.tensor([[0.5, 1.0, 0.8], [0.2, 0.4, 0.6]])target = torch.tensor([[0, 1.0, 0], [0, 0, 1.0]])bce_logit = nn.BCEWithLogitsLoss(reduction='mean')b_logit = bce_logit(pred, target)  # pred元素取值范围是[0,1]之间,否则会报错pred = s(pred)# nn.BCELoss输入的pred与target的形状必须相同,实际是交叉熵计算,target没有限制bce = nn.BCELoss(reduction='mean')  # size_average参数将被遗弃,使用reduction决定后续操作,有mean sumb = bce(pred, target)  # pred元素取值范围是[0,1]之间,否则会报错print('预测数据:', pred)print('标签数据:', target)print('\nbceloss:{}\t bce_with_logit:{} \n'.format(b, b_logit))

结果如下:

预测数据: tensor([[0.6225, 0.7311, 0.6900],[0.5498, 0.5987, 0.6457]])
标签数据: tensor([[0., 1., 0.],[0., 0., 1.]])bceloss:0.7678468823432922	 bce_with_logit:0.7678468823432922

可以看出,nn.BCELoss只需多一个nn.Sigmoid()得到的结果和nn.BCEWithLogitsLoss是一致的,说明该类只是多了一个logit过程。

3.函数运用(pred探索)

import torch
import torch.nn.functional as F
def bce_logit_loss():pred = torch.tensor([[5, 1, 8.0], [2, 4, 6.0]])target = torch.tensor([[0, 1.0, 0], [0, 0, 1.0]])bce_logit = nn.BCEWithLogitsLoss(reduction='mean')b_logit = bce_logit(pred, target)  # pred元素取值范围是[0,1]之间,否则会报错print('预测数据:', pred)print('标签数据:', target)print(' bce_with_logit:{} \n'.format( b_logit))

结果如下:

预测数据: tensor([[5., 1., 8.],[2., 4., 6.]])
标签数据: tensor([[0., 1., 0.],[0., 0., 1.]])bce_with_logit:3.2446444034576416 

可以看出nn.BCEWithLogitsLoss的输入是可以为实数,它先进行sigmoid处理,将其输入变为[0,1]范围,在进行交叉熵运算,然上面nn.BCELoss与F.binary_cross_entropy则不行。

四、F.kl_div()函数解读

该函数为蒸馏模型使用的函数,我直接给出示列,如下:

def kl_func():logits = torch.tensor([[0.5, 1.0, 0.8], [0.2, 0.4, 0.6]])probs = torch.nn.functional.softmax(logits, dim=1)  # 预测学生模型target_probs = torch.tensor([[0.3, 0.4, 0.3], [0.1, 0.5, 0.4]])  # 教师模型loss = F.kl_div(torch.log(probs), target_probs, reduction='batchmean')print('模型输出数据:', logits)print('预测数据:',probs)print('标签数据:',target_probs)print('\nkl_loss:{}\n'.format(loss))

输出结果:

模型输出数据: tensor([[0.5000, 1.0000, 0.8000],[0.2000, 0.4000, 0.6000]])
预测数据: tensor([[0.2501, 0.4123, 0.3376],[0.2693, 0.3289, 0.4018]])
标签数据: tensor([[0.3000, 0.4000, 0.3000],[0.1000, 0.5000, 0.4000]])kl_loss:0.057796258479356766

参考文章:点击这里

相关文章:

F.binary_cross_entropy、nn.BCELoss、nn.BCEWithLogitsLoss与F.kl_div函数详细解读

提示:有关loss损失函数详细解读,并附源码!!! 文章目录 前言一、F.binary_cross_entropy()函数解读1.函数表达2.函数运用 二、nn.BCELoss()函数解读1.函数表达2.函数运用 三、nn.BCEWithLogitsLoss()函数解读1.函数表达…...

后端接口性能优化分析

👏作者简介:大家好,我是爱吃芝士的土豆倪,24届校招生Java选手,很高兴认识大家📕系列专栏:Spring源码、JUC源码🔥如果感觉博主的文章还不错的话,请👍三连支持&…...

【ceph】ceph集群中使用多路径(Multipath)方法

本站以分享各种运维经验和运维所需要的技能为主 《python零基础入门》:python零基础入门学习 《python运维脚本》: python运维脚本实践 《shell》:shell学习 《terraform》持续更新中:terraform_Aws学习零基础入门到最佳实战 《k8…...

Xshell+Xftp通过代理的方式访问局域网内网服务器

最近在部署项目时遇到只有1台服务器拥有公网ip,其它服务器只有局域网ip,当然其它服务器可以正常访问网络,例如如下模型。之前访问其它几台服务器,都是先通过登录公网IP服务器,然后在Xshell里面执行ssh远程连接&#xf…...

对盒子中的材料进行计数

背景 在做AI算法分析项目的时候,有时候需要我们使用影像分析结合机器学习算法对某些材料盒中的材料进行数目计数,通过自己的分析,给出以下两种解决问题的思路。 1.图像处理方法对材料计数 要使用图像处理方式对盒子中的材料进行数目分析&a…...

科技驱动固定资产管理变革:RFID技术的前沿应用

在当今激烈竞争的商业环境中,企业固定资产管理面临挑战,而RFID技术正以其独特特性和功能性彻底改变资产管理方式。本文将深入探讨RFID技术在固定资产管理中的革命性作用,并解析其应用带来的创新和便利。 RFID技术概述: RFID系统作…...

Django路由层之有名分组和无名分组、反向解析、路由分发、伪静态的概念、名称空间、虚拟环境、Django1和Django2的区别

【1】无名分组 无名分组:就是把正则中小括号里噩匹配到的内容以位置参数的形式传递给视图函数 url(r^test/(\d)$,view.text) get请求的第一种方式: http://127.0.0.1:8000/test/?a1&b2 get请求的第二种方式: http://127.0.0.1:8000/test…...

【nlp】2.5 人名分类器实战项目(对比RNN、LSTM、GRU模型)

人名分类器实战项目 0 项目说明1 案例介绍2 案例步骤2.1 导入必备的工具包2.2 数据预处理2.2.1 获取常用的字符数量2.2.2 国家名种类数和个数2.2.3 读数据到python环境中2.2.4 构建数据源NameClassDataset2.2.5 构建迭代器遍历数据2.3 构建RNN及其变体模型2.3.1 构建RNN模型2.3…...

海康Visionmaster-环境配置:MFC 二次开发环境配置方法

1 新建 MFC 工程,拷贝 DLL:VM\VisionMaster4.0.0\Development\V4.0.0 \ComControl\bin\x64 下的所有拷贝到项目工程输出目录下,如下图所示,项目的输出路径是 Dll 文件夹。 2 通过配置 C目录和链接器的方式配置 VM 环境 2.1 C目录下添加附加…...

利用EXCEL中的VBA对同一文件夹下的多个数据文件进行特定提取

Sub CopyFilesBasedOnCriteria()Dim fso As ObjectDim sourceFolder As StringDim destinationFolder As String 设置源文件夹路径和目标文件夹路径sourceFolder "C:\\test\\全波段模拟_Nimbostratus cloud - 副本"destinationFolder "C:\\Desktop\\MOD02数据…...

FPGA时序约束(七)文献时序约束实验测试

系列文章目录 文章目录 系列文章目录前言文献1:时钟移位LogiclockDesign Partition封装用户编写的程序停掉singletap抓取单端口RAM的数据文献2:SRAM约束前言 之前学习了一些基本时序约束的类别,包括主时钟约束、虚拟时钟约束、输入输出约束、多周期约束等等,但大多都是纸上…...

【数据库开发】DataX开发环境的安装部署(Python、Java)

文章目录 1、简介1.1 DataX简介1.2 DataX功能1.3 支持的数据通道 2、DataX安装配置2.1 DataX2.2 Java2.3 Python 3、DataX Web安装配置3.1 mysql3.2 DataX Web3.2.1 简介3.2.2 架构图3.2.3 依赖环境3.2.4 安装 4、入门使用4.1 DataX自带打印示例测试4.2 DataX生成任务模板文件4…...

Flutter实践一:package组织

1.架构概览 为了降低Flutter工程里lib的复杂度,应尽量拆分一些代码成为独立的package。如图: 我们将通用的组件、领域模型、API、features、存储、repository等抽取成了单独的package。这时lib只剩下多国语言、基本的页面、路由等代码了: 这…...

SpringCloud微服务:Ribbon负载均衡

目录 负载均衡策略: 负载均衡的两种方式: 饥饿加载 1. Ribbon负载均衡规则 规则接口是IRule 默认实现是ZoneAvoidanceRule,根据zone选择服务列表,然后轮询 2.负载均衡自定义方式 代码方式:配置灵活,但修…...

【教程】大气化学在线耦合模式WRF/Chem

查看原文>>>区域气象-大气化学在线耦合模式(WRF/Chem)在大气环境领域实践 随着我国经济快速发展,我国面临着日益严重的大气污染问题。近年来,严重的大气污染问题已经明显影响国计民生,引起政府、学界和人们越…...

GDS 命令的使用 srvctl service TAF application continuity

文档中prim and stdy在同一台机器上,不同机器需要添加address list TAF ENABLED GLOBAL SERVICE in GDS ENVIRONMNET 12C. (Doc ID 2283193.1)​编辑To Bottom In this Document Goal Solution APPLIES TO: Oracle Database - Enterprise Edition - Version 12.1.…...

go 语言之 select

在 Go 语言中&#xff0c;select 是一种用于处理多个通道操作的控制结构。它可以用于在多个通道之间进行非阻塞的选择操作&#xff0c;从而实现并发控制和通信。 select 语句的基本语法如下&#xff1a; go select { case <-channel1:// 当 channel1 可读时执行的代码 cas…...

23款奔驰GLC260L升级小柏林音响 全新15个扬声器

2023年款奔驰GLC260 GLC300升级小柏林之声 3D音效系统 升级小柏林之声音响之后&#xff0c;全车一共有15个喇叭&#xff0c;1台功放&#xff0c;每一首音乐都能在车内掀起激情的音浪&#xff0c;感受纯粹的音乐享受&#xff0c;低频震撼澎湃&#xff0c;让你的心跳与音乐完美契…...

ssh 免密码登录

ssh 免密码登录 1. 原理 1.1 密码登录的通俗解释 把服务器当作一个凤凰社&#xff0c;每次进社公干都需要拿特别的门票入场&#xff0c;门票便是服务器上的账户密码&#xff1b; 1.2 免密登录 对于凤凰社的高级会员&#xff0c;会在社内存储一张高级会员身份&#xff08;id_rsa…...

小程序使用腾讯位置插件获取当前位置

1.小程序后台 设置-第三方设置-插件管理-添加插件 2.进入网站 腾讯位置服务 设置对应的额度 mapPickerPlugin(res) {const key ""; //使用在腾讯位置服务申请的keyconst referer "铅锂运营"; //调用插件的app的名称const category "生活服务,娱…...

零基础学Python怎么学习?我来告诉你

对于IT新手来说&#xff0c;零基础学Python的话&#xff0c;之后可选择的职业方向非常多。Python全栈和爬虫一直以来都是市场的最火的就业岗位之一&#xff0c;它们的薪资回报也算是开发岗里面的顶级了。而且随着大数据和人工智能时代的到来&#xff0c;数据处理和人工智能行业…...

开源软件 FFmpeg 生成模型使用图片数据集

本篇文章聊聊&#xff0c;成就了无数视频软件公司、无数在线视频网站、无数 CDN 云服务厂商的开源软件 ffmpeg。 分享下如何使用它将各种视频或电影文件&#xff0c;转换成上万张图片数据集、壁纸集合&#xff0c;来让下一篇文章中的模型程序“有米下锅”&#xff0c;这个方法…...

Linux Shell 通配符 / glob 模式

1、概念 glob 模式&#xff08;globbing&#xff09;也被称之为 shell 通配符&#xff0c;名字的起源来自于 Unix V6 中的 /etc/glob &#xff08;详见 man 文档&#xff09;。glob 是一种特殊的模式匹配&#xff0c;最常见的是通配符拓展&#xff0c;也可以将 glob 模式设为精…...

深入了解域名与SSL证书的关系

在如今数字化的世界里&#xff0c;网络安全成为我们关注的重要议题之一。为了确保数据在网络上传输的安全性&#xff0c;我们通常会采取各种安全措施&#xff0c;其中最常用的就是SSL证书。然而&#xff0c;很多人并不了解SSL证书是如何与域名相互关联的。 首先&#xff0c;我…...

计算属性与watch的区别,fetch与axios在vue中的异步请求,单文本组件使用,使用vite创建vue项目,组件的使用方法

7.计算属性 7-1计算属性-有缓存 模板中的表达式虽然很方便,但是只能做简单的逻辑操作,如果在模版中写太多的js逻辑,会使得模板过于臃肿,不利于维护,因此我们推荐使用计算属性来解决复杂的逻辑 <!DOCTYPE html> <html lang"en"> <head><meta …...

2023.11.14 hivesql的容器,数组与映射

目录 https://blog.csdn.net/m0_49956154/article/details/134365327?spm1001.2014.3001.5501https://blog.csdn.net/m0_49956154/article/details/134365327?spm1001.2014.3001.5501 8.hive的复杂类型 9.array类型: 又叫数组类型,存储同类型的单数据的集合 10.struct类型…...

Android Glide照片宫格RecyclerView,点击SharedElement共享元素动画查看大图,Kotlin(1)

Android Glide照片宫格RecyclerView&#xff0c;点击SharedElement共享元素动画查看大图&#xff0c;Kotlin&#xff08;1&#xff09; <uses-permission android:name"android.permission.READ_EXTERNAL_STORAGE" /><uses-permission android:name"an…...

SELinux零知识学习八、SELinux策略语言之客体类别和许可(2)

接前一篇文章&#xff1a;SELinux零知识学习七、SELinux策略语言之客体类别和许可&#xff08;1&#xff09; 一、SELinux策略语言之客体类别和许可 2. 在SELinux策略中定义客体类别 SELinux策略中必须包括所有SELinux内核支持的客体类别和许可的声明&#xff0c;以及其它客体…...

deepstream-测试发送AMQP

1. 安装库 * glib 2.0 ---------- sudo apt-get install libglib2.0 libglib2.0-dev Install rabbitmq-c library -------------------------- sudo apt-get install librabbitmq-dev If you plan to have AMQP broker installed on your local machine ------------------…...

LLMs可以遵循简单的规则吗?

由于大型语言模型在现实世界中的责任越来越大&#xff0c;因此如何以可靠的方式指定和约束这些系统的行为很重要。一些开发人员希望为模型设置显式规则&#xff0c;例如“不生成滥用内容”&#xff0c;但这种方式可能会被特殊技术规避。评估LLM在面对对抗性输入时遵循开发人员提…...