当前位置: 首页 > news >正文

pytorch 加权CE_loss实现(语义分割中的类不平衡使用)

加权CE_loss和BCE_loss稍有不同

1.标签为long类型,BCE标签为float类型
2.当reduction为mean时计算每个像素点的损失的平均,BCE除以像素数得到平均值,CE除以像素对应的权重之和得到平均值。
在这里插入图片描述

参数配置torch.nn.CrossEntropyLoss(weight=None,size_average=None,ignore_index=-100,reduce=None,reduction=‘mean’,label_smoothing=0.0)

增加加权的CE_loss代码实现

# 总之, CrossEntropyLoss() = softmax + log + NLLLoss() = log_softmax + NLLLoss(), 具体等价应用如下:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import numpy as npclass CrossEntropyLoss2d(nn.Module):def __init__(self, weight=None):super(CrossEntropyLoss2d, self).__init__()self.nll_loss = nn.CrossEntropyLoss(weight, reduction='mean')def forward(self, preds, targets):return self.nll_loss(preds, targets)

语义分割类别计算

class CE_w_loss(nn.Module):def __init__(self,ignore_index=255):super(CE_w_loss, self).__init__()self.ignore_index = ignore_index# self.CE = nn.CrossEntropyLoss(ignore_index=self.ignore_index)def forward(self, outputs, targets):class_num = outputs.shape[1]# print("class_num :",class_num )# # 计算每个类别在整个 batch 中的像素数占比class_pixel_counts = torch.bincount(targets.flatten(), minlength=class_num)  # 假设有class_num个类别class_pixel_proportions = class_pixel_counts.float() / torch.numel(targets)# # 根据类别占比计算权重class_weights = 1.0 / (torch.log(1.02 + class_pixel_proportions)).double()  # 使用对数变换平衡权重# # print("class_weights :",class_weights)## 定义交叉熵损失函数,并使用动态计算的类别权重criterion = nn.CrossEntropyLoss(ignore_index=self.ignore_index,weight= class_weights)# 计算损失loss = criterion(outputs, targets)print(loss.item())  # 打印损失值return lossnp.random.seed(666)pred = np.ones((2, 5, 256,256))seg = np.ones((2, 5, 256, 256)) # 灰度label = np.ones((2, 256, 256))  # 灰度pred = torch.from_numpy(pred)seg = torch.from_numpy(seg).int()  # 灰度label = torch.from_numpy(label).long()ce = CE_w_loss()loss = ce(pred, label)print("loss:",loss.item())

调用库(手动设置权重)

import torch
import torch.nn as nn# 假设有一些模型输出和目标标签
model_output = torch.randn(3, 5)  # 假设有5个类别
target = torch.empty(3, dtype=torch.long).random_(5)# 定义权重
weights = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])# 定义交叉熵损失函数,并设置权重
criterion = nn.CrossEntropyLoss(weight=weights)# 计算损失
loss = criterion(model_output, target)
print(loss)

自适应计算权重

import torch
import torch.nn as nn
import numpy as np# 假设我们有一个包含10个样本的批次,每个样本属于4个类别之一
batch_size = 10
num_classes = 4# 随机生成未经过 softmax 的logits输出(网络的最后一层输出)
logits = torch.randn(batch_size, num_classes, requires_grad=True)# 真实的标签(每个样本的类别索引),例如 [0, 2, 1, 3, 0, 0, 1, 2, 3, 3]
labels = torch.tensor([0, 2, 1, 3, 0, 0, 1, 2, 3, 3])# 统计每个类别的频率
class_counts = torch.bincount(labels, minlength=num_classes).float()# 计算每个类别的权重,权重可以为类别频率的倒数
# 为了防止分母为零,这里加一个小的常数epsilon
epsilon = 1e-6
class_weights = 1.0 / (class_counts + epsilon)# 归一化权重,使其和为1
class_weights /= class_weights.sum()print('Class Counts:', class_counts)
print('Class Weights:', class_weights)# 创建带权重的交叉熵损失函数
criterion = nn.CrossEntropyLoss(weight=class_weights)# 计算损失值
loss = criterion(logits, labels)print('Logits:\n', logits)
print('Labels:\n', labels)
print('Weighted Cross-Entropy Loss:', loss.item())# 反向传播梯度
loss.backward()

报错

Weight=torch.from_numpy(np.array([0.1, 0.8, 1.0, 1.0])).float() 报错
Weight=torch.from_numpy(np.array([0.1, 0.8, 1.0, 1.0])).double() 正确

参考:[1]https://blog.csdn.net/CSDN_of_ding/article/details/111515226
[2] https://blog.csdn.net/qq_40306845/article/details/137651442
[3] https://www.zhihu.com/question/400443029/answer/2477658229

相关文章:

pytorch 加权CE_loss实现(语义分割中的类不平衡使用)

加权CE_loss和BCE_loss稍有不同 1.标签为long类型,BCE标签为float类型 2.当reduction为mean时计算每个像素点的损失的平均,BCE除以像素数得到平均值,CE除以像素对应的权重之和得到平均值。 参数配置torch.nn.CrossEntropyLoss(weightNone,…...

【iOS】UI——关于UIAlertController类(警告对话框)

目录 前言关于UIAlertController具体操作及代码实现总结 前言 在UI的警告对话框的学习中,我们发现UIAlertView在iOS 9中已经被废弃,我们找到UIAlertController来代替UIAlertView实现弹出框的功能,从而有了这篇关于UIAlertController的学习笔记…...

django支持https

测试环境,可以用django自带的证书 安装模块 sudo pip3 install django_sslserver服务端https启动 python3 manage.py runsslserver 127.0.0.1:8001https访问 https://127.0.0.1:8001/quota/api/XXX...

算法题day41(补5.27日卡:动态规划01)

一、动态规划基础知识:在动态规划中每一个状态一定是由上一个状态推导出来的。 动态规划五部曲: 1.确定dp数组 以及下标的含义 2.确定递推公式 3.dp数组如何初始化 4.确定遍历顺序 5.举例推导dp数组 debug方式:打印 二、刷题&#xf…...

【附带源码】机械臂MoveIt2极简教程(四)、第一个入门demo

系列文章目录 【附带源码】机械臂MoveIt2极简教程(一)、moveit2安装 【附带源码】机械臂MoveIt2极简教程(二)、move_group交互 【附带源码】机械臂MoveIt2极简教程(三)、URDF/SRDF介绍 【附带源码】机械臂MoveIt2极简教程(四)、第一个入门demo 目录 系列文章目录1. 创…...

基于蚁群算法的二维路径规划算法(matlab)

微♥关注“电击小子程高兴的MATLAB小屋”获得资料 一、理论基础 1、路径规划算法 路径规划算法是指在有障碍物的工作环境中寻找一条从起点到终点、无碰撞地绕过所有障碍物的运动路径。路径规划算法较多,大体上可分为全局路径规划算法和局部路径规划算法两大类。其…...

政务云参考技术架构

行业优势 总体架构 政务云平台技术框架图,由机房环境、基础设施层、支撑软件层及业务应用层组成,在运维、安全和运营体系的保障下,为政务云使用单位提供统一服务支撑。 功能架构 标准双区隔离 参照国家电子政务规范,打造符合标准的…...

android 13 aosp 预置so库

展讯对应的main.mk配置 device/sprd/qogirn**/ums***/product/***_native/main.mk $(call inherit-product-if-exists, vendor/***/build.mk)vendor/***/build.mk PRODUCT_PACKAGES \libtestvendor///Android.bp cc_prebuilt_library_shared{name:"libtest",srcs:…...

mongo篇---mongoDB Compass连接数据库

mongo篇—mongoDB Compass连接数据库 mongoDB笔记 – 第一条 一、mongoDB Compass连接远程数据库,配置URL。 URL: mongodb://username:passwordhost:port点击connect即可。 注意:host最好使用名称,防止出错连接超时。...

基于SOA海鸥优化算法的三维曲面最高点搜索matlab仿真

目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.本算法原理 5.完整程序 1.程序功能描述 基于SOA海鸥优化算法的三维曲面最高点搜索matlab仿真,输出收敛曲线以及三维曲面最高点搜索结果。 2.测试软件版本以及运行结果展示 MATLAB2022A版本…...

前端js解析websocket推送的gzip压缩json的Blob数据

主要依赖插件pako https://www.npmjs.com/package/pako 1、安装 npm install pako 2、使用, pako.inflate(reader.result, {to: "string"}) 解压后的string 对象,需要JSON.parse转成json this.ws.onmessage (evt) > {console.log("…...

【wiki知识库】06.文档管理接口的实现--SpringBoot后端部分

目录 一、🔥今日目标 二、🎈SpringBoot部分类的添加 1.调用MybatisGenerator 2.添加DocSaveParam 3.添加DocQueryVo 三、🚆后端新增接口 3.1添加DocController 3.1.1 /all/{ebokId} 3.1.2 /doc/save 3.1.3 /doc/delete/{idStr} …...

c,c++,go语言字符串的演进

#include <stdio.h> #include <string.h> int main() {char str[] {a,b,c,\0,d,d,d};printf("string:[%s], len:%d \n", str, strlen(str) );return 0; } string:[abc], len:3 c语言只有数组的概念&#xff0c;数组本身没有长度的概念&#xff0c;需…...

vue-cli 快速入门

vue-cli &#xff08;目前向Vite发展&#xff09; 介绍&#xff1a;Vue-cli 是Vue官方提供一个脚手架&#xff0c;用于快速生成一个Vue的项目模板。 Vue-cli提供了如下功能&#xff1a; 统一的目录结构 本地调试 热部署 单元测试 集成打包上线 依赖环境&#xff1a;NodeJ…...

机器人--矩阵运算

两个矩阵相乘的含义 P点在坐标系B中的坐标系PB&#xff0c;需要乘以B到A到变换矩阵TAB。 M点在B坐标系中的位姿MB&#xff0c;怎么计算M在A中的坐标系&#xff1f; 两个矩阵相乘 一个矩阵*另一个矩阵的逆矩阵...

期末复习【汇总】

期末复习【汇总】 前言版权推荐期末复习【汇总】最后 前言 2024-5-12 20:52:17 截止到今天&#xff0c;所有期末复习的汇总 以下内容源自《【创作模板】》 仅供学习交流使用 版权 禁止其他平台发布时删除以下此话 本文首次发布于CSDN平台 作者是CSDN日星月云 博客主页是ht…...

【IM即时通讯】MQTT协议的详解(3)- CONNACK Packet

【IM即时通讯】MQTT协议的详解&#xff08;3&#xff09;- CONNACK Packet 文章目录 【IM即时通讯】MQTT协议的详解&#xff08;3&#xff09;- CONNACK Packet前言说明一、固定同步详解、可变头部详解总结 前言 关于所有的类型的数据示例已经在上面一篇博客说完&#xff1a; …...

Linux - 深入理解/proc虚拟文件系统:从基础到高级

文章目录 Linux /proc虚拟文件系统/proc/self使用 /proc/self 的优势/proc/self 的使用案例案例1&#xff1a;获取当前进程的状态信息案例2&#xff1a;获取当前进程的命令行参数案例3&#xff1a;获取当前进程的内存映射案例4&#xff1a;获取当前进程的文件描述符 /proc中进程…...

Django DeleteView视图

Django 的 DeleteView 是一个基于类的视图&#xff0c;用于处理对象的删除操作。 1&#xff0c;添加视图函数 Test/app3/views.py from django.shortcuts import render# Create your views here. from .models import Bookfrom django.views.generic import ListView class B…...

代码杂谈 之 pyspark如何做相似度计算

在 PySpark 中&#xff0c;计算 DataFrame 两列向量的差可以通过使用 UDF&#xff08;用户自定义函数&#xff09;和 Vector 类型完成。这里有一个示例&#xff0c;展示了如何使用 PySpark 的 pyspark.ml.linalg.Vectorspyspark.sql.functions.udf 来实现这一功能&#xff1a…...

利用最小二乘法找圆心和半径

#include <iostream> #include <vector> #include <cmath> #include <Eigen/Dense> // 需安装Eigen库用于矩阵运算 // 定义点结构 struct Point { double x, y; Point(double x_, double y_) : x(x_), y(y_) {} }; // 最小二乘法求圆心和半径 …...

Vim 调用外部命令学习笔记

Vim 外部命令集成完全指南 文章目录 Vim 外部命令集成完全指南核心概念理解命令语法解析语法对比 常用外部命令详解文本排序与去重文本筛选与搜索高级 grep 搜索技巧文本替换与编辑字符处理高级文本处理编程语言处理其他实用命令 范围操作示例指定行范围处理复合命令示例 实用技…...

谷歌浏览器插件

项目中有时候会用到插件 sync-cookie-extension1.0.0&#xff1a;开发环境同步测试 cookie 至 localhost&#xff0c;便于本地请求服务携带 cookie 参考地址&#xff1a;https://juejin.cn/post/7139354571712757767 里面有源码下载下来&#xff0c;加在到扩展即可使用FeHelp…...

生成xcframework

打包 XCFramework 的方法 XCFramework 是苹果推出的一种多平台二进制分发格式&#xff0c;可以包含多个架构和平台的代码。打包 XCFramework 通常用于分发库或框架。 使用 Xcode 命令行工具打包 通过 xcodebuild 命令可以打包 XCFramework。确保项目已经配置好需要支持的平台…...

手游刚开服就被攻击怎么办?如何防御DDoS?

开服初期是手游最脆弱的阶段&#xff0c;极易成为DDoS攻击的目标。一旦遭遇攻击&#xff0c;可能导致服务器瘫痪、玩家流失&#xff0c;甚至造成巨大经济损失。本文为开发者提供一套简洁有效的应急与防御方案&#xff0c;帮助快速应对并构建长期防护体系。 一、遭遇攻击的紧急应…...

基于uniapp+WebSocket实现聊天对话、消息监听、消息推送、聊天室等功能,多端兼容

基于 ​UniApp + WebSocket​实现多端兼容的实时通讯系统,涵盖WebSocket连接建立、消息收发机制、多端兼容性配置、消息实时监听等功能,适配​微信小程序、H5、Android、iOS等终端 目录 技术选型分析WebSocket协议优势UniApp跨平台特性WebSocket 基础实现连接管理消息收发连接…...

理解 MCP 工作流:使用 Ollama 和 LangChain 构建本地 MCP 客户端

&#x1f31f; 什么是 MCP&#xff1f; 模型控制协议 (MCP) 是一种创新的协议&#xff0c;旨在无缝连接 AI 模型与应用程序。 MCP 是一个开源协议&#xff0c;它标准化了我们的 LLM 应用程序连接所需工具和数据源并与之协作的方式。 可以把它想象成你的 AI 模型 和想要使用它…...

Golang dig框架与GraphQL的完美结合

将 Go 的 Dig 依赖注入框架与 GraphQL 结合使用&#xff0c;可以显著提升应用程序的可维护性、可测试性以及灵活性。 Dig 是一个强大的依赖注入容器&#xff0c;能够帮助开发者更好地管理复杂的依赖关系&#xff0c;而 GraphQL 则是一种用于 API 的查询语言&#xff0c;能够提…...

React19源码系列之 事件插件系统

事件类别 事件类型 定义 文档 Event Event 接口表示在 EventTarget 上出现的事件。 Event - Web API | MDN UIEvent UIEvent 接口表示简单的用户界面事件。 UIEvent - Web API | MDN KeyboardEvent KeyboardEvent 对象描述了用户与键盘的交互。 KeyboardEvent - Web…...

Linux-07 ubuntu 的 chrome 启动不了

文章目录 问题原因解决步骤一、卸载旧版chrome二、重新安装chorme三、启动不了&#xff0c;报错如下四、启动不了&#xff0c;解决如下 总结 问题原因 在应用中可以看到chrome&#xff0c;但是打不开(说明&#xff1a;原来的ubuntu系统出问题了&#xff0c;这个是备用的硬盘&a…...