pytorch 加权CE_loss实现(语义分割中的类不平衡使用)
加权CE_loss和BCE_loss稍有不同
1.标签为long类型,BCE标签为float类型
2.当reduction为mean时计算每个像素点的损失的平均,BCE除以像素数得到平均值,CE除以像素对应的权重之和得到平均值。

参数配置torch.nn.CrossEntropyLoss(weight=None,size_average=None,ignore_index=-100,reduce=None,reduction=‘mean’,label_smoothing=0.0)
增加加权的CE_loss代码实现
# 总之, CrossEntropyLoss() = softmax + log + NLLLoss() = log_softmax + NLLLoss(), 具体等价应用如下:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import numpy as npclass CrossEntropyLoss2d(nn.Module):def __init__(self, weight=None):super(CrossEntropyLoss2d, self).__init__()self.nll_loss = nn.CrossEntropyLoss(weight, reduction='mean')def forward(self, preds, targets):return self.nll_loss(preds, targets)
语义分割类别计算
class CE_w_loss(nn.Module):def __init__(self,ignore_index=255):super(CE_w_loss, self).__init__()self.ignore_index = ignore_index# self.CE = nn.CrossEntropyLoss(ignore_index=self.ignore_index)def forward(self, outputs, targets):class_num = outputs.shape[1]# print("class_num :",class_num )# # 计算每个类别在整个 batch 中的像素数占比class_pixel_counts = torch.bincount(targets.flatten(), minlength=class_num) # 假设有class_num个类别class_pixel_proportions = class_pixel_counts.float() / torch.numel(targets)# # 根据类别占比计算权重class_weights = 1.0 / (torch.log(1.02 + class_pixel_proportions)).double() # 使用对数变换平衡权重# # print("class_weights :",class_weights)## 定义交叉熵损失函数,并使用动态计算的类别权重criterion = nn.CrossEntropyLoss(ignore_index=self.ignore_index,weight= class_weights)# 计算损失loss = criterion(outputs, targets)print(loss.item()) # 打印损失值return lossnp.random.seed(666)pred = np.ones((2, 5, 256,256))seg = np.ones((2, 5, 256, 256)) # 灰度label = np.ones((2, 256, 256)) # 灰度pred = torch.from_numpy(pred)seg = torch.from_numpy(seg).int() # 灰度label = torch.from_numpy(label).long()ce = CE_w_loss()loss = ce(pred, label)print("loss:",loss.item())
调用库(手动设置权重)
import torch
import torch.nn as nn# 假设有一些模型输出和目标标签
model_output = torch.randn(3, 5) # 假设有5个类别
target = torch.empty(3, dtype=torch.long).random_(5)# 定义权重
weights = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])# 定义交叉熵损失函数,并设置权重
criterion = nn.CrossEntropyLoss(weight=weights)# 计算损失
loss = criterion(model_output, target)
print(loss)
自适应计算权重
import torch
import torch.nn as nn
import numpy as np# 假设我们有一个包含10个样本的批次,每个样本属于4个类别之一
batch_size = 10
num_classes = 4# 随机生成未经过 softmax 的logits输出(网络的最后一层输出)
logits = torch.randn(batch_size, num_classes, requires_grad=True)# 真实的标签(每个样本的类别索引),例如 [0, 2, 1, 3, 0, 0, 1, 2, 3, 3]
labels = torch.tensor([0, 2, 1, 3, 0, 0, 1, 2, 3, 3])# 统计每个类别的频率
class_counts = torch.bincount(labels, minlength=num_classes).float()# 计算每个类别的权重,权重可以为类别频率的倒数
# 为了防止分母为零,这里加一个小的常数epsilon
epsilon = 1e-6
class_weights = 1.0 / (class_counts + epsilon)# 归一化权重,使其和为1
class_weights /= class_weights.sum()print('Class Counts:', class_counts)
print('Class Weights:', class_weights)# 创建带权重的交叉熵损失函数
criterion = nn.CrossEntropyLoss(weight=class_weights)# 计算损失值
loss = criterion(logits, labels)print('Logits:\n', logits)
print('Labels:\n', labels)
print('Weighted Cross-Entropy Loss:', loss.item())# 反向传播梯度
loss.backward()
报错
Weight=torch.from_numpy(np.array([0.1, 0.8, 1.0, 1.0])).float() 报错
Weight=torch.from_numpy(np.array([0.1, 0.8, 1.0, 1.0])).double() 正确
参考:[1]https://blog.csdn.net/CSDN_of_ding/article/details/111515226
[2] https://blog.csdn.net/qq_40306845/article/details/137651442
[3] https://www.zhihu.com/question/400443029/answer/2477658229
相关文章:
pytorch 加权CE_loss实现(语义分割中的类不平衡使用)
加权CE_loss和BCE_loss稍有不同 1.标签为long类型,BCE标签为float类型 2.当reduction为mean时计算每个像素点的损失的平均,BCE除以像素数得到平均值,CE除以像素对应的权重之和得到平均值。 参数配置torch.nn.CrossEntropyLoss(weightNone,…...
【iOS】UI——关于UIAlertController类(警告对话框)
目录 前言关于UIAlertController具体操作及代码实现总结 前言 在UI的警告对话框的学习中,我们发现UIAlertView在iOS 9中已经被废弃,我们找到UIAlertController来代替UIAlertView实现弹出框的功能,从而有了这篇关于UIAlertController的学习笔记…...
django支持https
测试环境,可以用django自带的证书 安装模块 sudo pip3 install django_sslserver服务端https启动 python3 manage.py runsslserver 127.0.0.1:8001https访问 https://127.0.0.1:8001/quota/api/XXX...
算法题day41(补5.27日卡:动态规划01)
一、动态规划基础知识:在动态规划中每一个状态一定是由上一个状态推导出来的。 动态规划五部曲: 1.确定dp数组 以及下标的含义 2.确定递推公式 3.dp数组如何初始化 4.确定遍历顺序 5.举例推导dp数组 debug方式:打印 二、刷题…...
【附带源码】机械臂MoveIt2极简教程(四)、第一个入门demo
系列文章目录 【附带源码】机械臂MoveIt2极简教程(一)、moveit2安装 【附带源码】机械臂MoveIt2极简教程(二)、move_group交互 【附带源码】机械臂MoveIt2极简教程(三)、URDF/SRDF介绍 【附带源码】机械臂MoveIt2极简教程(四)、第一个入门demo 目录 系列文章目录1. 创…...
基于蚁群算法的二维路径规划算法(matlab)
微♥关注“电击小子程高兴的MATLAB小屋”获得资料 一、理论基础 1、路径规划算法 路径规划算法是指在有障碍物的工作环境中寻找一条从起点到终点、无碰撞地绕过所有障碍物的运动路径。路径规划算法较多,大体上可分为全局路径规划算法和局部路径规划算法两大类。其…...
政务云参考技术架构
行业优势 总体架构 政务云平台技术框架图,由机房环境、基础设施层、支撑软件层及业务应用层组成,在运维、安全和运营体系的保障下,为政务云使用单位提供统一服务支撑。 功能架构 标准双区隔离 参照国家电子政务规范,打造符合标准的…...
android 13 aosp 预置so库
展讯对应的main.mk配置 device/sprd/qogirn**/ums***/product/***_native/main.mk $(call inherit-product-if-exists, vendor/***/build.mk)vendor/***/build.mk PRODUCT_PACKAGES \libtestvendor///Android.bp cc_prebuilt_library_shared{name:"libtest",srcs:…...
mongo篇---mongoDB Compass连接数据库
mongo篇—mongoDB Compass连接数据库 mongoDB笔记 – 第一条 一、mongoDB Compass连接远程数据库,配置URL。 URL: mongodb://username:passwordhost:port点击connect即可。 注意:host最好使用名称,防止出错连接超时。...
基于SOA海鸥优化算法的三维曲面最高点搜索matlab仿真
目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.本算法原理 5.完整程序 1.程序功能描述 基于SOA海鸥优化算法的三维曲面最高点搜索matlab仿真,输出收敛曲线以及三维曲面最高点搜索结果。 2.测试软件版本以及运行结果展示 MATLAB2022A版本…...
前端js解析websocket推送的gzip压缩json的Blob数据
主要依赖插件pako https://www.npmjs.com/package/pako 1、安装 npm install pako 2、使用, pako.inflate(reader.result, {to: "string"}) 解压后的string 对象,需要JSON.parse转成json this.ws.onmessage (evt) > {console.log("…...
【wiki知识库】06.文档管理接口的实现--SpringBoot后端部分
目录 一、🔥今日目标 二、🎈SpringBoot部分类的添加 1.调用MybatisGenerator 2.添加DocSaveParam 3.添加DocQueryVo 三、🚆后端新增接口 3.1添加DocController 3.1.1 /all/{ebokId} 3.1.2 /doc/save 3.1.3 /doc/delete/{idStr} …...
c,c++,go语言字符串的演进
#include <stdio.h> #include <string.h> int main() {char str[] {a,b,c,\0,d,d,d};printf("string:[%s], len:%d \n", str, strlen(str) );return 0; } string:[abc], len:3 c语言只有数组的概念,数组本身没有长度的概念,需…...
vue-cli 快速入门
vue-cli (目前向Vite发展) 介绍:Vue-cli 是Vue官方提供一个脚手架,用于快速生成一个Vue的项目模板。 Vue-cli提供了如下功能: 统一的目录结构 本地调试 热部署 单元测试 集成打包上线 依赖环境:NodeJ…...
机器人--矩阵运算
两个矩阵相乘的含义 P点在坐标系B中的坐标系PB,需要乘以B到A到变换矩阵TAB。 M点在B坐标系中的位姿MB,怎么计算M在A中的坐标系? 两个矩阵相乘 一个矩阵*另一个矩阵的逆矩阵...
期末复习【汇总】
期末复习【汇总】 前言版权推荐期末复习【汇总】最后 前言 2024-5-12 20:52:17 截止到今天,所有期末复习的汇总 以下内容源自《【创作模板】》 仅供学习交流使用 版权 禁止其他平台发布时删除以下此话 本文首次发布于CSDN平台 作者是CSDN日星月云 博客主页是ht…...
【IM即时通讯】MQTT协议的详解(3)- CONNACK Packet
【IM即时通讯】MQTT协议的详解(3)- CONNACK Packet 文章目录 【IM即时通讯】MQTT协议的详解(3)- CONNACK Packet前言说明一、固定同步详解、可变头部详解总结 前言 关于所有的类型的数据示例已经在上面一篇博客说完: …...
Linux - 深入理解/proc虚拟文件系统:从基础到高级
文章目录 Linux /proc虚拟文件系统/proc/self使用 /proc/self 的优势/proc/self 的使用案例案例1:获取当前进程的状态信息案例2:获取当前进程的命令行参数案例3:获取当前进程的内存映射案例4:获取当前进程的文件描述符 /proc中进程…...
Django DeleteView视图
Django 的 DeleteView 是一个基于类的视图,用于处理对象的删除操作。 1,添加视图函数 Test/app3/views.py from django.shortcuts import render# Create your views here. from .models import Bookfrom django.views.generic import ListView class B…...
代码杂谈 之 pyspark如何做相似度计算
在 PySpark 中,计算 DataFrame 两列向量的差可以通过使用 UDF(用户自定义函数)和 Vector 类型完成。这里有一个示例,展示了如何使用 PySpark 的 pyspark.ml.linalg.Vectorspyspark.sql.functions.udf 来实现这一功能:…...
如何用LDBlockShow高效绘制连锁不平衡热图:从入门到精通的完整指南
如何用LDBlockShow高效绘制连锁不平衡热图:从入门到精通的完整指南 【免费下载链接】LDBlockShow LDBlockShow: a fast and convenient tool for visualizing linkage disequilibrium and haplotype blocks based on VCF files 项目地址: https://gitcode.com/gh_…...
小波散射网络:从理论优势到小样本图像分类实践
1. 小波散射网络为什么值得关注 第一次听说小波散射网络时,我和大多数搞机器学习的朋友反应一样:"这玩意儿和普通卷积神经网络(CNN)有什么区别?"直到去年接手一个医疗影像项目,手头只有200张标注…...
长期使用Token Plan套餐在Taotoken平台带来的月度成本控制体验
🚀 告别海外账号与网络限制!稳定直连全球优质大模型,限时半价接入中。 👉 点击领取海量免费额度 长期使用Token Plan套餐在Taotoken平台带来的月度成本控制体验 对于个人开发者或小型团队而言,在探索和集成大模型能力…...
3个步骤让AMD显卡也能运行CUDA程序:ZLUDA终极指南
3个步骤让AMD显卡也能运行CUDA程序:ZLUDA终极指南 【免费下载链接】ZLUDA CUDA on non-NVIDIA GPUs 项目地址: https://gitcode.com/GitHub_Trending/zl/ZLUDA 你是否曾经因为手头只有AMD显卡,却想运行那些需要CUDA加速的深度学习框架而感到无奈&…...
Wi-Fi卸载技术解析:从运营商策略到用户体验的深度实践
1. 项目概述:当“大哥”开始管理你的Wi-Fi十年前,一篇发表在EE Times上的文章提出了一个在今天看来依然尖锐的问题:智能手机用户使用Wi-Fi是件好事吗?这甚至上升到了“人权”层面——每个有手机的人是否都应该有权访问Wi-Fi&#…...
Latte文本到视频生成实战:打造个性化AI视频的终极指南
Latte文本到视频生成实战:打造个性化AI视频的终极指南 【免费下载链接】Latte [TMLR 2025] Latte: Latent Diffusion Transformer for Video Generation. 项目地址: https://gitcode.com/gh_mirrors/la/Latte Latte是一款基于TMLR 2025研究成果的文本到视频…...
AI提示词工程:用Claude+Cursor构建高效创意工作流
1. 项目概述:当创意遇上AI,一个提示词库如何改变工作流如果你是一位创意工作者——无论是设计师、插画师、文案策划还是视频创作者,最近几个月,你的工作流里可能多了一个新伙伴:Claude。这个由Anthropic推出的AI助手&a…...
【实战指南】Ubuntu SSH服务配置与XShell/Xftp高效连接全解析
1. 为什么需要SSH远程连接Ubuntu? 作为开发者或运维人员,我们经常需要管理远程服务器。想象一下,你正在咖啡馆用Windows笔记本,突然需要紧急修改线上Ubuntu服务器的配置——这时候SSH就是你的救命稻草。它就像一把安全钥匙&#x…...
智能网联单轨捷运编组协同控制【附仿真】
✨ 长期致力于跨座式单轨车辆、单轨捷运系统、智能编组运行、协同避撞、协同控制研究工作,擅长数据搜集与处理、建模仿真、程序编写、仿真设计。 ✅ 专业定制毕设、代码 ✅ 如需沟通交流,点击《获取方式》 (1)融合车距与速度的多层…...
使用Helm Chart在Kubernetes部署高可用authentik身份认证中心
1. 项目概述:为什么我们需要一个身份认证的“中央厨房”?在云原生和微服务架构大行其道的今天,一个典型的应用系统可能由几十甚至上百个独立的服务组成。每个服务都需要处理用户登录、权限验证、单点登录(SSO)这些基础…...
