当前位置: 首页 > news >正文

分类网络-类别不均衡问题之FocalLoss

在这里插入图片描述
有训练和测代码如下:(完整代码来自CNN从搭建到部署实战)
train.py

import torch
import torchvision
import time
import argparse
import importlib
from loss import FocalLossdef parse_args():parser = argparse.ArgumentParser('training')parser.add_argument('--batch_size', default=128, type=int, help='batch size in training')parser.add_argument('--num_epochs', default=5, type=int, help='number of epoch in training')parser.add_argument('--model',  default='lenet', help='model name [default: mlp]')return parser.parse_args()if __name__ == '__main__':args = parse_args()batch_size = args.batch_sizenum_epochs = args.num_epochsmodel = importlib.import_module('models.'+args.model) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')net = model.net.to(device)loss = torch.nn.CrossEntropyLoss()if args.model == 'mlp':optimizer = torch.optim.SGD(net.parameters(), lr=0.5)else:optimizer = torch.optim.Adam(net.parameters(), lr=0.001)train_path = r'./Datasets/mnist_png/training'test_path = r'./Datasets/mnist_png/testing'transform_list = [torchvision.transforms.Grayscale(num_output_channels=1), torchvision.transforms.ToTensor()]if args.model == 'alexnet' or args.model == 'vgg':transform_list.append(torchvision.transforms.Resize(size=224))if args.model == 'googlenet' or args.model == 'resnet':transform_list.append(torchvision.transforms.Resize(size=96))transform = torchvision.transforms.Compose(transform_list)train_dataset = torchvision.datasets.ImageFolder(train_path, transform=transform)test_dataset = torchvision.datasets.ImageFolder(test_path, transform=transform)train_iter = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)test_iter = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)for epoch in range(num_epochs):train_l, train_acc, test_acc, m, n, batch_count, start = 0.0, 0.0, 0.0, 0, 0, 0, time.time()for X, y in train_iter:X, y = X.to(device), y.to(device)y_hat = net(X)l = loss(y_hat, y)optimizer.zero_grad()l.backward()optimizer.step()train_l += l.cpu().item()train_acc += (y_hat.argmax(dim=1) == y).sum().cpu().item()m += y.shape[0]batch_count += 1with torch.no_grad():for X, y in test_iter:net.eval() # 评估模式test_acc += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().cpu().item()net.train() # 改回训练模式n += y.shape[0]print('epoch %d, loss %.6f, train acc %.3f, test acc %.3f, time %.1fs'% (epoch, train_l / batch_count, train_acc / m, test_acc / n, time.time() - start))torch.save(net, args.model+".pth")

test.py

import cv2
import torch
import argparse
import importlib
from pathlib import Path
import torchvision.transforms.functionaldef parse_args():parser = argparse.ArgumentParser('testing')parser.add_argument('--model',  default='lenet', help='model name [default: mlp]')return parser.parse_args()if __name__ == '__main__':args = parse_args()model = importlib.import_module('models.' + args.model) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')net = model.net.to(device)net = torch.load(args.model+'.pth')net.eval()with torch.no_grad():imgs_path = Path(r"./Datasets/mnist_png/testing/6/").glob("*")acc = 0count = 0for img_path in imgs_path:img = cv2.imread(str(img_path), 0)if args.model == 'alexnet' or args.model == 'vgg':  img = cv2.resize(img, (224,224))if args.model == 'googlenet' or args.model == 'resnet':img = cv2.resize(img, (96,96))img_tensor = torchvision.transforms.functional.to_tensor(img)img_tensor = torch.unsqueeze(img_tensor, 0)#print(net(img_tensor.to(device)).argmax(dim=1).item())if(net(img_tensor.to(device)).argmax(dim=1).item()==6):acc += 1count+=1print(acc/count)

数据集为mnist手写数字识别,其中训练集中数字0~9的数量分别为:0(5923张),1(6472张),2(5985张),3(6131张),4(5842张),5(5421张),6(5918张),7(6265张),8(5851张),9(5949张), 测试集中数字0~9的数量分别为:0(980张),1(1135张),2(1032张),3(1010张),4(982张),5(892张),6(958张),7(1028张),8(974张),9(1009张)。可见各个类别的数量基本上平衡。测试代码仅测试数字6的准确率,因为后面我们要改变训练集中数字6的数量来进行对比。为了节省时间,仅训练5个epoch。
训练结果:

epoch 0, loss 1.443379, train acc 0.529, test acc 0.877, time 23.4s
epoch 1, loss 0.314123, train acc 0.913, test acc 0.939, time 22.1s
epoch 2, loss 0.174050, train acc 0.949, test acc 0.960, time 21.9s
epoch 3, loss 0.122714, train acc 0.963, test acc 0.971, time 21.8s
epoch 4, loss 0.096798, train acc 0.971, test acc 0.975, time 21.8s

测试结果:

0.9780793319415448

现在将训练集中数字6的数量减少到59张(原来的1/100),来模拟某个类别的数据不平衡的情况。
训练结果:

epoch 0, loss 2.200247, train acc 0.131, test acc 0.373, time 20.8s
epoch 1, loss 0.579792, train acc 0.840, test acc 0.855, time 20.5s
epoch 2, loss 0.177890, train acc 0.950, test acc 0.872, time 20.3s
epoch 3, loss 0.128251, train acc 0.963, test acc 0.880, time 20.5s
epoch 4, loss 0.103937, train acc 0.969, test acc 0.888, time 20.7s

测试结果:

0.04801670146137787

可以看到,训练的准确率下降9%,而测试集直接下降了93%惨不忍睹。

引入FocalLoss模块:(参考https://github.com/QunBB/DeepLearning/blob/main/trick/unbalance/loss_pt.py)

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional, Unionclass FocalLoss(nn.Module):def __init__(self, alpha: Union[List[float], float], gamma: Optional[int] = 2, with_logits: Optional[bool] = True):""":param alpha: 每个类别的权重:param gamma::param with_logits: 是否经过softmax或者sigmoid"""super(FocalLoss, self).__init__()self.gamma = gammaself.alpha = torch.FloatTensor([alpha]) if isinstance(alpha, float) else torch.FloatTensor(alpha)self.smooth = 1e-8self.with_logits = with_logitsdef _binary_class(self, input, target):prob = torch.sigmoid(input) if self.with_logits else inputprob += self.smoothalpha = self.alpha.to(target.device)loss = -alpha * torch.pow(torch.sub(1.0, prob), self.gamma) * torch.log(prob)return lossdef _multiple_class(self, input, target):prob = F.softmax(input, dim=1) if self.with_logits else inputalpha = self.alpha.to(target.device)alpha = alpha.gather(0, target)target = target.view(-1, 1)prob = prob.gather(1, target).view(-1) + self.smooth  # avoid nanlogpt = torch.log(prob)loss = -alpha * torch.pow(torch.sub(1.0, prob), self.gamma) * logptreturn lossdef forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:""":param input: 维度为[bs, num_classes]:param target: 维度为[bs]:return:"""if len(input.shape) > 1 and input.shape[-1] != 1:loss = self._multiple_class(input, target)else:loss = self._binary_class(input, target)return loss.mean()

并将train.py的第26行修改成

    loss = FocalLoss([1, 1, 1, 1, 1, 1, 100, 1, 1, 1])

其中列表的数字代表10个类别的权重值。
训练结果:

epoch 0, loss 2.045273, train acc 0.137, test acc 0.467, time 20.7s
epoch 1, loss 0.510476, train acc 0.810, test acc 0.907, time 21.3s
epoch 2, loss 0.148246, train acc 0.922, test acc 0.941, time 21.1s
epoch 3, loss 0.099026, train acc 0.944, test acc 0.953, time 21.2s
epoch 4, loss 0.075481, train acc 0.954, test acc 0.959, time 21.3s

测试结果:

0.9196242171189979

对比看出,FocalLoss可以有效缓解类别不均衡问题(当然并不能完全消除,有足够平衡的高质量数据集肯定更好啦~)。

相关文章:

分类网络-类别不均衡问题之FocalLoss

有训练和测代码如下:(完整代码来自CNN从搭建到部署实战) train.py import torch import torchvision import time import argparse import importlib from loss import FocalLossdef parse_args():parser argparse.ArgumentParser(training)parser.add_argument(-…...

记录一下ComboBox在listview中的问题,后面再解决。

在listview的ComboBox,ViewModel类得不到ComboBox的 SelectedModeIndex 和 SelectionChanged事件。 问题描述: 1. 在listview中有ComboBox 2. 数据源类 InspectionInfo ,其中有ComboBox的绑定数据源 ModelList,代码如下&#…...

手写一个PrattParser基本运算解析器1: 编译原理概述

点击查看 基于Swift的PrattParser项目 编译原理概述 编译原理是我们每一个程序猿必须要了解的技能, 编译原理实际上并没有啥高深的技术, 我们如果在做业务开发, 也很少会用到编译开发的知识, 但是编译原理又是我们必备的基础知识之一. 所以我们需要对编译原理的内容有一个大概的…...

ZKP3.2 Programming ZKPs (Arkworks Zokrates)

ZKP学习笔记 ZK-Learning MOOC课程笔记 Lecture 3: Programming ZKPs (Guest Lecturers: Pratyush Mishra and Alex Ozdemir) 3.3 Using a library ( tutorial) R1CS Libraries A library in a host language (Eg: Rust, OCaml, C, Go, …)Key type: constraint system Mai…...

mysqld: File ‘./binlog.index‘ not found (OS errno 13 - Permission denied) 问题解决

问题背景 Centos7 安装Mysql 8后启动时遇到的问题,看了好几个博客方案无效,搞了半小时才找到正解,在此次进行记录。 在此假设你已经修改了对应目录的权限,比如配置的mysql data目录初始化后已经执行了chown -R mysql:mysql /XXX/…...

Python 环境构建最佳实践:Mamba + Conda + PIP

此前,我们单独介绍过 PIP 和 Conda,在后续的实际应用中,还是遇到了不少 Python 环境构建的问题,特别是在 Windows 系统上,最突出的表现是:虽然PIP的包依赖解析和下载都很快,但在 Windows 上经常会因为缺失底层依赖的程序库(例如某些dll文件)而导致 Python 程序启动时报…...

华为OD 最多团队(100分)【java】A卷+B卷

华为OD统一考试A卷+B卷 新题库说明 你收到的链接上面会标注A卷还是B卷。目前大部分收到的都是B卷。 B卷对应20022部分考题以及新出的题目,A卷对应的是新出的题目。 我将持续更新最新题目 获取更多免费题目可前往夸克网盘下载,请点击以下链接进入: 我用夸克网盘分享了「华为O…...

2023“龙芯杯”信创攻防赛 | 赛宁网安技术支持

2023年10月19日,为深入贯彻国家网络强国战略思想,宣传国家网络安全顶层设计,落实《网络安全法》《数据安全法》等法律法规。由大学生网络安全尖锋训练营主办,龙芯中科技术股份有限公司承办,山石网科通信技术股份有限公…...

代码随想录算法训练营第五十八天| 583. 两个字符串的删除操作 72. 编辑距离

今日学习的文章链接和视频链接 两个字符串的删除操作 https://programmercarl.com/0583.%E4%B8%A4%E4%B8%AA%E5%AD%97%E7%AC%A6%E4%B8%B2%E7%9A%84%E5%88%A0%E9%99%A4%E6%93%8D%E4%BD%9C.html 编辑距离 https://programmercarl.com/0072.%E7%BC%96%E8%BE%91%E8%B7%9D%E7%A6%BB…...

leetcode做题笔记191. 位1的个数

编写一个函数,输入是一个无符号整数(以二进制串的形式),返回其二进制表达式中数字位数为 1 的个数(也被称为汉明重量)。 提示: 请注意,在某些语言(如 Java)中…...

Git基本命令和使用

文章目录 1、Git本地库命令1.1、初始化本地库1.2、设置用户签名1.3、查看本地库状态1.4、将工作区的修改添加到暂存区1.5、将暂存区的修改提交到本地库1.6、历史版本1.7、取消commit1.8、取消暂存文件 2、分支操作2.1、查看分支2.2、创建分支2.3、分支合并时产生冲突 3、Gitee远…...

50springboot私人健身与教练预约管理系统

大家好✌!我是CZ淡陌。一名专注以理论为基础实战为主的技术博主,将再这里为大家分享优质的实战项目,本人在Java毕业设计领域有多年的经验,陆续会更新更多优质的Java实战项目,希望你能有所收获,少走一些弯路…...

测试Android webview 加载本地html

最近开发一个需要未联网功能的App, 不熟悉使用Java原生开发界面,于是想使用本地H5做界面,本文测试了使用本地html加载远程数据。直接上代码: MainActivity.java package com.alex.webviewlocal;import androidx.appcompat.app.AppCompatAct…...

ubuntu安装pgsql

ubuntu安装postgresSQL 官网地址: https://www.postgresql.org/download/ 1.安装 # 添加源 sudo sh -c echo "deb https://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" > /etc/apt/sources.list.d/pgdg.list # 安装数字签名 w…...

利用ArcGIS获取每一个冰川的中心位置经纬度坐标:要素转点和要素折点转点的区别

问题概述:下图是天山地区的冰川的分布,我们可以看到每一条冰川是一个面要素,要求得到每一个冰川(面要素)的中心经纬度坐标。 1.采用要素转点功能 选择工具箱的【数据管理工具】-【要素】-【要素转点】。完成之后再采用…...

数据结构中的七大排序(Java实现)

目录 一、直接插入排序 二、希尔排序 三、直接选择排序 四、堆排序 五、冒泡排序 六、快速排序 七、归并排序 一、直接插入排序 思想: 定义i下标之前的元素全部已经有序,遍历一遍要排序的数组,把i下标前的元素全部进行排序&#xff0…...

深度学习基础算法

算法 1.K近邻算法 机器学习--K-近邻算法(KNN)_k近邻-CSDN博客 2. 数据库样本: CIFAR-10 CIFAR-10数据集(介绍、下载读取、可视化显示、另存为图片)_cifar10数据集-CSDN博客...

LuatOS-SOC接口文档(air780E)-- ir - 红外遥控

ir.sendNEC(pin, addr, cmd, repeat, disablePWM)# 发送NEC数据 参数 传入值类型 解释 int 使用的GPIO引脚编号 int 用户码(大于0xff则采用Extended NEC模式) int 数据码 int 可选,引导码发送次数(110ms一次&#xff0…...

Java虚拟机常见面试题总结

梳理Java虚拟机相关的面试题,主要参考《深入理解Java虚拟机 JVM高级特性与最佳实践》(第2版, 周志明 著)一书,其余部分整合网络相关内容。注意,关于Java并发编程的面试题因为内容较多,单独整理。Java基础相关的面试题可以参考Java…...

NVIDIA NCCL 源码学习(十一)- ring allreduce

之前的章节里我们看到了nccl send/recv通信的过程,本节我们以ring allreduce为例看下集合通信的过程。整体执行流程和send/recv很像,所以对于相似的流程只做简单介绍,主要介绍ring allreduce自己特有内容。 单机 搜索ring 在nccl初始化的过…...

Java 语言特性(面试系列1)

一、面向对象编程 1. 封装(Encapsulation) 定义:将数据(属性)和操作数据的方法绑定在一起,通过访问控制符(private、protected、public)隐藏内部实现细节。示例: public …...

DeepSeek 赋能智慧能源:微电网优化调度的智能革新路径

目录 一、智慧能源微电网优化调度概述1.1 智慧能源微电网概念1.2 优化调度的重要性1.3 目前面临的挑战 二、DeepSeek 技术探秘2.1 DeepSeek 技术原理2.2 DeepSeek 独特优势2.3 DeepSeek 在 AI 领域地位 三、DeepSeek 在微电网优化调度中的应用剖析3.1 数据处理与分析3.2 预测与…...

FFmpeg 低延迟同屏方案

引言 在实时互动需求激增的当下,无论是在线教育中的师生同屏演示、远程办公的屏幕共享协作,还是游戏直播的画面实时传输,低延迟同屏已成为保障用户体验的核心指标。FFmpeg 作为一款功能强大的多媒体框架,凭借其灵活的编解码、数据…...

Cilium动手实验室: 精通之旅---20.Isovalent Enterprise for Cilium: Zero Trust Visibility

Cilium动手实验室: 精通之旅---20.Isovalent Enterprise for Cilium: Zero Trust Visibility 1. 实验室环境1.1 实验室环境1.2 小测试 2. The Endor System2.1 部署应用2.2 检查现有策略 3. Cilium 策略实体3.1 创建 allow-all 网络策略3.2 在 Hubble CLI 中验证网络策略源3.3 …...

服务器硬防的应用场景都有哪些?

服务器硬防是指一种通过硬件设备层面的安全措施来防御服务器系统受到网络攻击的方式,避免服务器受到各种恶意攻击和网络威胁,那么,服务器硬防通常都会应用在哪些场景当中呢? 硬防服务器中一般会配备入侵检测系统和预防系统&#x…...

定时器任务——若依源码分析

分析util包下面的工具类schedule utils: ScheduleUtils 是若依中用于与 Quartz 框架交互的工具类,封装了定时任务的 创建、更新、暂停、删除等核心逻辑。 createScheduleJob createScheduleJob 用于将任务注册到 Quartz,先构建任务的 JobD…...

如何为服务器生成TLS证书

TLS(Transport Layer Security)证书是确保网络通信安全的重要手段,它通过加密技术保护传输的数据不被窃听和篡改。在服务器上配置TLS证书,可以使用户通过HTTPS协议安全地访问您的网站。本文将详细介绍如何在服务器上生成一个TLS证…...

css的定位(position)详解:相对定位 绝对定位 固定定位

在 CSS 中,元素的定位通过 position 属性控制,共有 5 种定位模式:static(静态定位)、relative(相对定位)、absolute(绝对定位)、fixed(固定定位)和…...

【OSG学习笔记】Day 16: 骨骼动画与蒙皮(osgAnimation)

骨骼动画基础 骨骼动画是 3D 计算机图形中常用的技术,它通过以下两个主要组件实现角色动画。 骨骼系统 (Skeleton):由层级结构的骨头组成,类似于人体骨骼蒙皮 (Mesh Skinning):将模型网格顶点绑定到骨骼上,使骨骼移动…...

Unity | AmplifyShaderEditor插件基础(第七集:平面波动shader)

目录 一、👋🏻前言 二、😈sinx波动的基本原理 三、😈波动起来 1.sinx节点介绍 2.vertexPosition 3.集成Vector3 a.节点Append b.连起来 4.波动起来 a.波动的原理 b.时间节点 c.sinx的处理 四、🌊波动优化…...