当前位置: 首页 > news >正文

F.binary_cross_entropy、nn.BCELoss、nn.BCEWithLogitsLoss与F.kl_div函数详细解读

提示:有关loss损失函数详细解读,并附源码!!!

文章目录

  • 前言
  • 一、F.binary_cross_entropy()函数解读
    • 1.函数表达
    • 2.函数运用
  • 二、nn.BCELoss()函数解读
    • 1.函数表达
    • 2.函数运用
  • 三、nn.BCEWithLogitsLoss()函数解读
    • 1.函数表达
    • 2.函数运用(logit探索)
    • 3.函数运用(pred探索)
  • 四、F.kl_div()函数解读


前言

最近我在构建蒸馏相关模型,我重温了一下交叉熵相关内容,也使用pytorch相关函数接口调用,我将对F.binary_cross_entropy()、nn.BCELoss()与nn.BCEWithLogitsLoss()函数做一个说明,同时也简单介绍相对熵的蒸馏F.kl_div()函数做一个介绍。

一、F.binary_cross_entropy()函数解读

1.函数表达

F.binary_cross_entropy(input: Tensor,  # 预测输入target: Tensor, # 标签weight: Optional[Tensor] = None, # 权重可选项size_average: Optional[bool] = None,  # 可选项,快被弃用了reduce: Optional[bool] = None,reduction: str = "mean",  # 默认均值或求和等形式
) -> Tensor:

该函数实际是交叉熵运算方式,其中input、target与权重有相同维度(batch,),其中表示可以是任何维度。同时,input为模型预测其每个元素取值范围在[0,1]间。

2.函数运用

假设输入input经过sigmoid或softmax等方式将其值转为[0,1]范围预测,target为one-hot标签(也可是教师的软标签形式),其应用代码如下:

import torch
import torch.nn.functional as F
def binary_cross_entropy():input = torch.tensor([[0.5, 1.0, 0.8], [0.2, 0.4, 0.6]])# s = nn.Sigmoid()# pred = s(input)target = torch.tensor([[0, 1.0, 0], [0, 0, 1.0]])weight = torch.tensor([[0.1, 0.9, 0.1],[0.1, 0.1, 0.9]])output_weight = F.binary_cross_entropy(input, target,weight=weight)  # input取值范围[0,1]output = F.binary_cross_entropy(input, target)  # input取值范围[0,1]print('预测数据:',input)print('标签数据:',target)print('\nbinary_cross_entropy-有权重:{}\t无权重:{}\n'.format(output_weight, output))

结果如下:

预测数据: tensor([[0.5000, 1.0000, 0.8000],[0.2000, 0.4000, 0.6000]])
标签数据: tensor([[0., 1., 0.],[0., 0., 1.]])binary_cross_entropy-有权重:0.12723307311534882	无权重:0.5912299752235413

二、nn.BCELoss()函数解读

1.函数表达

torch.nn.BCELoss(weight=None, size_average=None, reduce=None, reduction='mean')

参数说明:
weight :用于样本加权的权重张量。如果给定,则必须是一维张量,大小等于输入张量的大小。默认值为 None。
reduction :指定如何计算损失值。可选值为 ‘none’、‘mean’ 或 ‘sum’。默认值为 ‘mean’。

此为类,是对F.binary_cross_entropy()函数的调用,也是交叉熵运算方式,其中input、target与权重有相同维度(batch,),其中表示可以是任何维度。同时,input为模型预测其每个元素取值范围在[0,1]间。

2.函数运用

假设输入input经过sigmoid或softmax等方式将其值转为[0,1]范围预测,target为one-hot标签(也可是教师的软标签形式),其应用代码如下:

import torch
import torch.nn.functional as F
def bceloss():s = nn.Sigmoid()  # 输出是pred = torch.tensor([[0.5, 1.0, 0.8], [0.2, 0.4, 0.6]])# pred = s(pred)  # 一般会经过sigmoid或softmax方式将其预测转为[0,1]范围的值target = torch.tensor([[0, 1.0, 0], [0, 0, 1.0]])# nn.BCELoss输入的pred与target的形状必须相同,实际是交叉熵计算,target没有限制bce = nn.BCELoss(reduction='mean')  # size_average参数将被遗弃,使用reduction决定后续操作,有mean sumb = bce(pred, target)  # pred元素取值范围是[0,1]之间,否则会报错print('预测数据:',pred)print('标签数据:',target)print('\nbceloss:{}\n'.format(b))

结果如下:

预测数据: tensor([[0.5000, 1.0000, 0.8000],[0.2000, 0.4000, 0.6000]])
标签数据: tensor([[0., 1., 0.],[0., 0., 1.]])bceloss:0.5912299752235413

可以看出该函数与上面无权重运行结果一致,实际是对上一个函数进行了类包装,其计算方式和上面函数完全一样。

三、nn.BCEWithLogitsLoss()函数解读

1.函数表达

torch.nn.BCEWithLogitsLoss(weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None)

参数说明:
weight:用于对每个样本的损失值进行加权。默认值为 None。
reduction:指定如何对每个 batch 的损失值进行降维。可选值为 ‘none’、‘mean’ 和 ‘sum’。默认值为 ‘mean’。
pos_weight:用于对正样本的损失值进行加权。可以用于处理样本不平衡的问题。例如,如果正样本比负样本少很多,可以设置 pos_weight 为一个较大的值,以提高正样本的权重。默认值为 None。

2.函数运用(logit探索)

假设输入input经过sigmoid或softmax等方式将其值转为[0,1]范围预测,target为one-hot标签(也可是教师的软标签形式),其应用代码如下:

import torch
import torch.nn.functional as F
def bce_logit_loss():s = nn.Sigmoid()  # 输出是pred = torch.tensor([[0.5, 1.0, 0.8], [0.2, 0.4, 0.6]])target = torch.tensor([[0, 1.0, 0], [0, 0, 1.0]])bce_logit = nn.BCEWithLogitsLoss(reduction='mean')b_logit = bce_logit(pred, target)  # pred元素取值范围是[0,1]之间,否则会报错pred = s(pred)# nn.BCELoss输入的pred与target的形状必须相同,实际是交叉熵计算,target没有限制bce = nn.BCELoss(reduction='mean')  # size_average参数将被遗弃,使用reduction决定后续操作,有mean sumb = bce(pred, target)  # pred元素取值范围是[0,1]之间,否则会报错print('预测数据:', pred)print('标签数据:', target)print('\nbceloss:{}\t bce_with_logit:{} \n'.format(b, b_logit))

结果如下:

预测数据: tensor([[0.6225, 0.7311, 0.6900],[0.5498, 0.5987, 0.6457]])
标签数据: tensor([[0., 1., 0.],[0., 0., 1.]])bceloss:0.7678468823432922	 bce_with_logit:0.7678468823432922

可以看出,nn.BCELoss只需多一个nn.Sigmoid()得到的结果和nn.BCEWithLogitsLoss是一致的,说明该类只是多了一个logit过程。

3.函数运用(pred探索)

import torch
import torch.nn.functional as F
def bce_logit_loss():pred = torch.tensor([[5, 1, 8.0], [2, 4, 6.0]])target = torch.tensor([[0, 1.0, 0], [0, 0, 1.0]])bce_logit = nn.BCEWithLogitsLoss(reduction='mean')b_logit = bce_logit(pred, target)  # pred元素取值范围是[0,1]之间,否则会报错print('预测数据:', pred)print('标签数据:', target)print(' bce_with_logit:{} \n'.format( b_logit))

结果如下:

预测数据: tensor([[5., 1., 8.],[2., 4., 6.]])
标签数据: tensor([[0., 1., 0.],[0., 0., 1.]])bce_with_logit:3.2446444034576416 

可以看出nn.BCEWithLogitsLoss的输入是可以为实数,它先进行sigmoid处理,将其输入变为[0,1]范围,在进行交叉熵运算,然上面nn.BCELoss与F.binary_cross_entropy则不行。

四、F.kl_div()函数解读

该函数为蒸馏模型使用的函数,我直接给出示列,如下:

def kl_func():logits = torch.tensor([[0.5, 1.0, 0.8], [0.2, 0.4, 0.6]])probs = torch.nn.functional.softmax(logits, dim=1)  # 预测学生模型target_probs = torch.tensor([[0.3, 0.4, 0.3], [0.1, 0.5, 0.4]])  # 教师模型loss = F.kl_div(torch.log(probs), target_probs, reduction='batchmean')print('模型输出数据:', logits)print('预测数据:',probs)print('标签数据:',target_probs)print('\nkl_loss:{}\n'.format(loss))

输出结果:

模型输出数据: tensor([[0.5000, 1.0000, 0.8000],[0.2000, 0.4000, 0.6000]])
预测数据: tensor([[0.2501, 0.4123, 0.3376],[0.2693, 0.3289, 0.4018]])
标签数据: tensor([[0.3000, 0.4000, 0.3000],[0.1000, 0.5000, 0.4000]])kl_loss:0.057796258479356766

参考文章:点击这里

相关文章:

F.binary_cross_entropy、nn.BCELoss、nn.BCEWithLogitsLoss与F.kl_div函数详细解读

提示:有关loss损失函数详细解读,并附源码!!! 文章目录 前言一、F.binary_cross_entropy()函数解读1.函数表达2.函数运用 二、nn.BCELoss()函数解读1.函数表达2.函数运用 三、nn.BCEWithLogitsLoss()函数解读1.函数表达…...

后端接口性能优化分析

👏作者简介:大家好,我是爱吃芝士的土豆倪,24届校招生Java选手,很高兴认识大家📕系列专栏:Spring源码、JUC源码🔥如果感觉博主的文章还不错的话,请👍三连支持&…...

【ceph】ceph集群中使用多路径(Multipath)方法

本站以分享各种运维经验和运维所需要的技能为主 《python零基础入门》:python零基础入门学习 《python运维脚本》: python运维脚本实践 《shell》:shell学习 《terraform》持续更新中:terraform_Aws学习零基础入门到最佳实战 《k8…...

Xshell+Xftp通过代理的方式访问局域网内网服务器

最近在部署项目时遇到只有1台服务器拥有公网ip,其它服务器只有局域网ip,当然其它服务器可以正常访问网络,例如如下模型。之前访问其它几台服务器,都是先通过登录公网IP服务器,然后在Xshell里面执行ssh远程连接&#xf…...

对盒子中的材料进行计数

背景 在做AI算法分析项目的时候,有时候需要我们使用影像分析结合机器学习算法对某些材料盒中的材料进行数目计数,通过自己的分析,给出以下两种解决问题的思路。 1.图像处理方法对材料计数 要使用图像处理方式对盒子中的材料进行数目分析&a…...

科技驱动固定资产管理变革:RFID技术的前沿应用

在当今激烈竞争的商业环境中,企业固定资产管理面临挑战,而RFID技术正以其独特特性和功能性彻底改变资产管理方式。本文将深入探讨RFID技术在固定资产管理中的革命性作用,并解析其应用带来的创新和便利。 RFID技术概述: RFID系统作…...

Django路由层之有名分组和无名分组、反向解析、路由分发、伪静态的概念、名称空间、虚拟环境、Django1和Django2的区别

【1】无名分组 无名分组:就是把正则中小括号里噩匹配到的内容以位置参数的形式传递给视图函数 url(r^test/(\d)$,view.text) get请求的第一种方式: http://127.0.0.1:8000/test/?a1&b2 get请求的第二种方式: http://127.0.0.1:8000/test…...

【nlp】2.5 人名分类器实战项目(对比RNN、LSTM、GRU模型)

人名分类器实战项目 0 项目说明1 案例介绍2 案例步骤2.1 导入必备的工具包2.2 数据预处理2.2.1 获取常用的字符数量2.2.2 国家名种类数和个数2.2.3 读数据到python环境中2.2.4 构建数据源NameClassDataset2.2.5 构建迭代器遍历数据2.3 构建RNN及其变体模型2.3.1 构建RNN模型2.3…...

海康Visionmaster-环境配置:MFC 二次开发环境配置方法

1 新建 MFC 工程,拷贝 DLL:VM\VisionMaster4.0.0\Development\V4.0.0 \ComControl\bin\x64 下的所有拷贝到项目工程输出目录下,如下图所示,项目的输出路径是 Dll 文件夹。 2 通过配置 C目录和链接器的方式配置 VM 环境 2.1 C目录下添加附加…...

利用EXCEL中的VBA对同一文件夹下的多个数据文件进行特定提取

Sub CopyFilesBasedOnCriteria()Dim fso As ObjectDim sourceFolder As StringDim destinationFolder As String 设置源文件夹路径和目标文件夹路径sourceFolder "C:\\test\\全波段模拟_Nimbostratus cloud - 副本"destinationFolder "C:\\Desktop\\MOD02数据…...

FPGA时序约束(七)文献时序约束实验测试

系列文章目录 文章目录 系列文章目录前言文献1:时钟移位LogiclockDesign Partition封装用户编写的程序停掉singletap抓取单端口RAM的数据文献2:SRAM约束前言 之前学习了一些基本时序约束的类别,包括主时钟约束、虚拟时钟约束、输入输出约束、多周期约束等等,但大多都是纸上…...

【数据库开发】DataX开发环境的安装部署(Python、Java)

文章目录 1、简介1.1 DataX简介1.2 DataX功能1.3 支持的数据通道 2、DataX安装配置2.1 DataX2.2 Java2.3 Python 3、DataX Web安装配置3.1 mysql3.2 DataX Web3.2.1 简介3.2.2 架构图3.2.3 依赖环境3.2.4 安装 4、入门使用4.1 DataX自带打印示例测试4.2 DataX生成任务模板文件4…...

Flutter实践一:package组织

1.架构概览 为了降低Flutter工程里lib的复杂度,应尽量拆分一些代码成为独立的package。如图: 我们将通用的组件、领域模型、API、features、存储、repository等抽取成了单独的package。这时lib只剩下多国语言、基本的页面、路由等代码了: 这…...

SpringCloud微服务:Ribbon负载均衡

目录 负载均衡策略: 负载均衡的两种方式: 饥饿加载 1. Ribbon负载均衡规则 规则接口是IRule 默认实现是ZoneAvoidanceRule,根据zone选择服务列表,然后轮询 2.负载均衡自定义方式 代码方式:配置灵活,但修…...

【教程】大气化学在线耦合模式WRF/Chem

查看原文>>>区域气象-大气化学在线耦合模式(WRF/Chem)在大气环境领域实践 随着我国经济快速发展,我国面临着日益严重的大气污染问题。近年来,严重的大气污染问题已经明显影响国计民生,引起政府、学界和人们越…...

GDS 命令的使用 srvctl service TAF application continuity

文档中prim and stdy在同一台机器上,不同机器需要添加address list TAF ENABLED GLOBAL SERVICE in GDS ENVIRONMNET 12C. (Doc ID 2283193.1)​编辑To Bottom In this Document Goal Solution APPLIES TO: Oracle Database - Enterprise Edition - Version 12.1.…...

go 语言之 select

在 Go 语言中&#xff0c;select 是一种用于处理多个通道操作的控制结构。它可以用于在多个通道之间进行非阻塞的选择操作&#xff0c;从而实现并发控制和通信。 select 语句的基本语法如下&#xff1a; go select { case <-channel1:// 当 channel1 可读时执行的代码 cas…...

23款奔驰GLC260L升级小柏林音响 全新15个扬声器

2023年款奔驰GLC260 GLC300升级小柏林之声 3D音效系统 升级小柏林之声音响之后&#xff0c;全车一共有15个喇叭&#xff0c;1台功放&#xff0c;每一首音乐都能在车内掀起激情的音浪&#xff0c;感受纯粹的音乐享受&#xff0c;低频震撼澎湃&#xff0c;让你的心跳与音乐完美契…...

ssh 免密码登录

ssh 免密码登录 1. 原理 1.1 密码登录的通俗解释 把服务器当作一个凤凰社&#xff0c;每次进社公干都需要拿特别的门票入场&#xff0c;门票便是服务器上的账户密码&#xff1b; 1.2 免密登录 对于凤凰社的高级会员&#xff0c;会在社内存储一张高级会员身份&#xff08;id_rsa…...

小程序使用腾讯位置插件获取当前位置

1.小程序后台 设置-第三方设置-插件管理-添加插件 2.进入网站 腾讯位置服务 设置对应的额度 mapPickerPlugin(res) {const key ""; //使用在腾讯位置服务申请的keyconst referer "铅锂运营"; //调用插件的app的名称const category "生活服务,娱…...

基于算法竞赛的c++编程(28)结构体的进阶应用

结构体的嵌套与复杂数据组织 在C中&#xff0c;结构体可以嵌套使用&#xff0c;形成更复杂的数据结构。例如&#xff0c;可以通过嵌套结构体描述多层级数据关系&#xff1a; struct Address {string city;string street;int zipCode; };struct Employee {string name;int id;…...

AI-调查研究-01-正念冥想有用吗?对健康的影响及科学指南

点一下关注吧&#xff01;&#xff01;&#xff01;非常感谢&#xff01;&#xff01;持续更新&#xff01;&#xff01;&#xff01; &#x1f680; AI篇持续更新中&#xff01;&#xff08;长期更新&#xff09; 目前2025年06月05日更新到&#xff1a; AI炼丹日志-28 - Aud…...

CVPR 2025 MIMO: 支持视觉指代和像素grounding 的医学视觉语言模型

CVPR 2025 | MIMO&#xff1a;支持视觉指代和像素对齐的医学视觉语言模型 论文信息 标题&#xff1a;MIMO: A medical vision language model with visual referring multimodal input and pixel grounding multimodal output作者&#xff1a;Yanyuan Chen, Dexuan Xu, Yu Hu…...

04-初识css

一、css样式引入 1.1.内部样式 <div style"width: 100px;"></div>1.2.外部样式 1.2.1.外部样式1 <style>.aa {width: 100px;} </style> <div class"aa"></div>1.2.2.外部样式2 <!-- rel内表面引入的是style样…...

c#开发AI模型对话

AI模型 前面已经介绍了一般AI模型本地部署&#xff0c;直接调用现成的模型数据。这里主要讲述讲接口集成到我们自己的程序中使用方式。 微软提供了ML.NET来开发和使用AI模型&#xff0c;但是目前国内可能使用不多&#xff0c;至少实践例子很少看见。开发训练模型就不介绍了&am…...

Redis的发布订阅模式与专业的 MQ(如 Kafka, RabbitMQ)相比,优缺点是什么?适用于哪些场景?

Redis 的发布订阅&#xff08;Pub/Sub&#xff09;模式与专业的 MQ&#xff08;Message Queue&#xff09;如 Kafka、RabbitMQ 进行比较&#xff0c;核心的权衡点在于&#xff1a;简单与速度 vs. 可靠与功能。 下面我们详细展开对比。 Redis Pub/Sub 的核心特点 它是一个发后…...

招商蛇口 | 执笔CID,启幕低密生活新境

作为中国城市生长的力量&#xff0c;招商蛇口以“美好生活承载者”为使命&#xff0c;深耕全球111座城市&#xff0c;以央企担当匠造时代理想人居。从深圳湾的开拓基因到西安高新CID的战略落子&#xff0c;招商蛇口始终与城市发展同频共振&#xff0c;以建筑诠释对土地与生活的…...

【C++特殊工具与技术】优化内存分配(一):C++中的内存分配

目录 一、C 内存的基本概念​ 1.1 内存的物理与逻辑结构​ 1.2 C 程序的内存区域划分​ 二、栈内存分配​ 2.1 栈内存的特点​ 2.2 栈内存分配示例​ 三、堆内存分配​ 3.1 new和delete操作符​ 4.2 内存泄漏与悬空指针问题​ 4.3 new和delete的重载​ 四、智能指针…...

CRMEB 中 PHP 短信扩展开发:涵盖一号通、阿里云、腾讯云、创蓝

目前已有一号通短信、阿里云短信、腾讯云短信扩展 扩展入口文件 文件目录 crmeb\services\sms\Sms.php 默认驱动类型为&#xff1a;一号通 namespace crmeb\services\sms;use crmeb\basic\BaseManager; use crmeb\services\AccessTokenServeService; use crmeb\services\sms\…...

Golang——7、包与接口详解

包与接口详解 1、Golang包详解1.1、Golang中包的定义和介绍1.2、Golang包管理工具go mod1.3、Golang中自定义包1.4、Golang中使用第三包1.5、init函数 2、接口详解2.1、接口的定义2.2、空接口2.3、类型断言2.4、结构体值接收者和指针接收者实现接口的区别2.5、一个结构体实现多…...