当前位置: 首页 > news >正文

分类网络-类别不均衡问题之FocalLoss

在这里插入图片描述
有训练和测代码如下:(完整代码来自CNN从搭建到部署实战)
train.py

import torch
import torchvision
import time
import argparse
import importlib
from loss import FocalLossdef parse_args():parser = argparse.ArgumentParser('training')parser.add_argument('--batch_size', default=128, type=int, help='batch size in training')parser.add_argument('--num_epochs', default=5, type=int, help='number of epoch in training')parser.add_argument('--model',  default='lenet', help='model name [default: mlp]')return parser.parse_args()if __name__ == '__main__':args = parse_args()batch_size = args.batch_sizenum_epochs = args.num_epochsmodel = importlib.import_module('models.'+args.model) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')net = model.net.to(device)loss = torch.nn.CrossEntropyLoss()if args.model == 'mlp':optimizer = torch.optim.SGD(net.parameters(), lr=0.5)else:optimizer = torch.optim.Adam(net.parameters(), lr=0.001)train_path = r'./Datasets/mnist_png/training'test_path = r'./Datasets/mnist_png/testing'transform_list = [torchvision.transforms.Grayscale(num_output_channels=1), torchvision.transforms.ToTensor()]if args.model == 'alexnet' or args.model == 'vgg':transform_list.append(torchvision.transforms.Resize(size=224))if args.model == 'googlenet' or args.model == 'resnet':transform_list.append(torchvision.transforms.Resize(size=96))transform = torchvision.transforms.Compose(transform_list)train_dataset = torchvision.datasets.ImageFolder(train_path, transform=transform)test_dataset = torchvision.datasets.ImageFolder(test_path, transform=transform)train_iter = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)test_iter = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)for epoch in range(num_epochs):train_l, train_acc, test_acc, m, n, batch_count, start = 0.0, 0.0, 0.0, 0, 0, 0, time.time()for X, y in train_iter:X, y = X.to(device), y.to(device)y_hat = net(X)l = loss(y_hat, y)optimizer.zero_grad()l.backward()optimizer.step()train_l += l.cpu().item()train_acc += (y_hat.argmax(dim=1) == y).sum().cpu().item()m += y.shape[0]batch_count += 1with torch.no_grad():for X, y in test_iter:net.eval() # 评估模式test_acc += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().cpu().item()net.train() # 改回训练模式n += y.shape[0]print('epoch %d, loss %.6f, train acc %.3f, test acc %.3f, time %.1fs'% (epoch, train_l / batch_count, train_acc / m, test_acc / n, time.time() - start))torch.save(net, args.model+".pth")

test.py

import cv2
import torch
import argparse
import importlib
from pathlib import Path
import torchvision.transforms.functionaldef parse_args():parser = argparse.ArgumentParser('testing')parser.add_argument('--model',  default='lenet', help='model name [default: mlp]')return parser.parse_args()if __name__ == '__main__':args = parse_args()model = importlib.import_module('models.' + args.model) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')net = model.net.to(device)net = torch.load(args.model+'.pth')net.eval()with torch.no_grad():imgs_path = Path(r"./Datasets/mnist_png/testing/6/").glob("*")acc = 0count = 0for img_path in imgs_path:img = cv2.imread(str(img_path), 0)if args.model == 'alexnet' or args.model == 'vgg':  img = cv2.resize(img, (224,224))if args.model == 'googlenet' or args.model == 'resnet':img = cv2.resize(img, (96,96))img_tensor = torchvision.transforms.functional.to_tensor(img)img_tensor = torch.unsqueeze(img_tensor, 0)#print(net(img_tensor.to(device)).argmax(dim=1).item())if(net(img_tensor.to(device)).argmax(dim=1).item()==6):acc += 1count+=1print(acc/count)

数据集为mnist手写数字识别,其中训练集中数字0~9的数量分别为:0(5923张),1(6472张),2(5985张),3(6131张),4(5842张),5(5421张),6(5918张),7(6265张),8(5851张),9(5949张), 测试集中数字0~9的数量分别为:0(980张),1(1135张),2(1032张),3(1010张),4(982张),5(892张),6(958张),7(1028张),8(974张),9(1009张)。可见各个类别的数量基本上平衡。测试代码仅测试数字6的准确率,因为后面我们要改变训练集中数字6的数量来进行对比。为了节省时间,仅训练5个epoch。
训练结果:

epoch 0, loss 1.443379, train acc 0.529, test acc 0.877, time 23.4s
epoch 1, loss 0.314123, train acc 0.913, test acc 0.939, time 22.1s
epoch 2, loss 0.174050, train acc 0.949, test acc 0.960, time 21.9s
epoch 3, loss 0.122714, train acc 0.963, test acc 0.971, time 21.8s
epoch 4, loss 0.096798, train acc 0.971, test acc 0.975, time 21.8s

测试结果:

0.9780793319415448

现在将训练集中数字6的数量减少到59张(原来的1/100),来模拟某个类别的数据不平衡的情况。
训练结果:

epoch 0, loss 2.200247, train acc 0.131, test acc 0.373, time 20.8s
epoch 1, loss 0.579792, train acc 0.840, test acc 0.855, time 20.5s
epoch 2, loss 0.177890, train acc 0.950, test acc 0.872, time 20.3s
epoch 3, loss 0.128251, train acc 0.963, test acc 0.880, time 20.5s
epoch 4, loss 0.103937, train acc 0.969, test acc 0.888, time 20.7s

测试结果:

0.04801670146137787

可以看到,训练的准确率下降9%,而测试集直接下降了93%惨不忍睹。

引入FocalLoss模块:(参考https://github.com/QunBB/DeepLearning/blob/main/trick/unbalance/loss_pt.py)

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional, Unionclass FocalLoss(nn.Module):def __init__(self, alpha: Union[List[float], float], gamma: Optional[int] = 2, with_logits: Optional[bool] = True):""":param alpha: 每个类别的权重:param gamma::param with_logits: 是否经过softmax或者sigmoid"""super(FocalLoss, self).__init__()self.gamma = gammaself.alpha = torch.FloatTensor([alpha]) if isinstance(alpha, float) else torch.FloatTensor(alpha)self.smooth = 1e-8self.with_logits = with_logitsdef _binary_class(self, input, target):prob = torch.sigmoid(input) if self.with_logits else inputprob += self.smoothalpha = self.alpha.to(target.device)loss = -alpha * torch.pow(torch.sub(1.0, prob), self.gamma) * torch.log(prob)return lossdef _multiple_class(self, input, target):prob = F.softmax(input, dim=1) if self.with_logits else inputalpha = self.alpha.to(target.device)alpha = alpha.gather(0, target)target = target.view(-1, 1)prob = prob.gather(1, target).view(-1) + self.smooth  # avoid nanlogpt = torch.log(prob)loss = -alpha * torch.pow(torch.sub(1.0, prob), self.gamma) * logptreturn lossdef forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:""":param input: 维度为[bs, num_classes]:param target: 维度为[bs]:return:"""if len(input.shape) > 1 and input.shape[-1] != 1:loss = self._multiple_class(input, target)else:loss = self._binary_class(input, target)return loss.mean()

并将train.py的第26行修改成

    loss = FocalLoss([1, 1, 1, 1, 1, 1, 100, 1, 1, 1])

其中列表的数字代表10个类别的权重值。
训练结果:

epoch 0, loss 2.045273, train acc 0.137, test acc 0.467, time 20.7s
epoch 1, loss 0.510476, train acc 0.810, test acc 0.907, time 21.3s
epoch 2, loss 0.148246, train acc 0.922, test acc 0.941, time 21.1s
epoch 3, loss 0.099026, train acc 0.944, test acc 0.953, time 21.2s
epoch 4, loss 0.075481, train acc 0.954, test acc 0.959, time 21.3s

测试结果:

0.9196242171189979

对比看出,FocalLoss可以有效缓解类别不均衡问题(当然并不能完全消除,有足够平衡的高质量数据集肯定更好啦~)。

相关文章:

分类网络-类别不均衡问题之FocalLoss

有训练和测代码如下:(完整代码来自CNN从搭建到部署实战) train.py import torch import torchvision import time import argparse import importlib from loss import FocalLossdef parse_args():parser argparse.ArgumentParser(training)parser.add_argument(-…...

记录一下ComboBox在listview中的问题,后面再解决。

在listview的ComboBox,ViewModel类得不到ComboBox的 SelectedModeIndex 和 SelectionChanged事件。 问题描述: 1. 在listview中有ComboBox 2. 数据源类 InspectionInfo ,其中有ComboBox的绑定数据源 ModelList,代码如下&#…...

手写一个PrattParser基本运算解析器1: 编译原理概述

点击查看 基于Swift的PrattParser项目 编译原理概述 编译原理是我们每一个程序猿必须要了解的技能, 编译原理实际上并没有啥高深的技术, 我们如果在做业务开发, 也很少会用到编译开发的知识, 但是编译原理又是我们必备的基础知识之一. 所以我们需要对编译原理的内容有一个大概的…...

ZKP3.2 Programming ZKPs (Arkworks Zokrates)

ZKP学习笔记 ZK-Learning MOOC课程笔记 Lecture 3: Programming ZKPs (Guest Lecturers: Pratyush Mishra and Alex Ozdemir) 3.3 Using a library ( tutorial) R1CS Libraries A library in a host language (Eg: Rust, OCaml, C, Go, …)Key type: constraint system Mai…...

mysqld: File ‘./binlog.index‘ not found (OS errno 13 - Permission denied) 问题解决

问题背景 Centos7 安装Mysql 8后启动时遇到的问题,看了好几个博客方案无效,搞了半小时才找到正解,在此次进行记录。 在此假设你已经修改了对应目录的权限,比如配置的mysql data目录初始化后已经执行了chown -R mysql:mysql /XXX/…...

Python 环境构建最佳实践:Mamba + Conda + PIP

此前,我们单独介绍过 PIP 和 Conda,在后续的实际应用中,还是遇到了不少 Python 环境构建的问题,特别是在 Windows 系统上,最突出的表现是:虽然PIP的包依赖解析和下载都很快,但在 Windows 上经常会因为缺失底层依赖的程序库(例如某些dll文件)而导致 Python 程序启动时报…...

华为OD 最多团队(100分)【java】A卷+B卷

华为OD统一考试A卷+B卷 新题库说明 你收到的链接上面会标注A卷还是B卷。目前大部分收到的都是B卷。 B卷对应20022部分考题以及新出的题目,A卷对应的是新出的题目。 我将持续更新最新题目 获取更多免费题目可前往夸克网盘下载,请点击以下链接进入: 我用夸克网盘分享了「华为O…...

2023“龙芯杯”信创攻防赛 | 赛宁网安技术支持

2023年10月19日,为深入贯彻国家网络强国战略思想,宣传国家网络安全顶层设计,落实《网络安全法》《数据安全法》等法律法规。由大学生网络安全尖锋训练营主办,龙芯中科技术股份有限公司承办,山石网科通信技术股份有限公…...

代码随想录算法训练营第五十八天| 583. 两个字符串的删除操作 72. 编辑距离

今日学习的文章链接和视频链接 两个字符串的删除操作 https://programmercarl.com/0583.%E4%B8%A4%E4%B8%AA%E5%AD%97%E7%AC%A6%E4%B8%B2%E7%9A%84%E5%88%A0%E9%99%A4%E6%93%8D%E4%BD%9C.html 编辑距离 https://programmercarl.com/0072.%E7%BC%96%E8%BE%91%E8%B7%9D%E7%A6%BB…...

leetcode做题笔记191. 位1的个数

编写一个函数,输入是一个无符号整数(以二进制串的形式),返回其二进制表达式中数字位数为 1 的个数(也被称为汉明重量)。 提示: 请注意,在某些语言(如 Java)中…...

Git基本命令和使用

文章目录 1、Git本地库命令1.1、初始化本地库1.2、设置用户签名1.3、查看本地库状态1.4、将工作区的修改添加到暂存区1.5、将暂存区的修改提交到本地库1.6、历史版本1.7、取消commit1.8、取消暂存文件 2、分支操作2.1、查看分支2.2、创建分支2.3、分支合并时产生冲突 3、Gitee远…...

50springboot私人健身与教练预约管理系统

大家好✌!我是CZ淡陌。一名专注以理论为基础实战为主的技术博主,将再这里为大家分享优质的实战项目,本人在Java毕业设计领域有多年的经验,陆续会更新更多优质的Java实战项目,希望你能有所收获,少走一些弯路…...

测试Android webview 加载本地html

最近开发一个需要未联网功能的App, 不熟悉使用Java原生开发界面,于是想使用本地H5做界面,本文测试了使用本地html加载远程数据。直接上代码: MainActivity.java package com.alex.webviewlocal;import androidx.appcompat.app.AppCompatAct…...

ubuntu安装pgsql

ubuntu安装postgresSQL 官网地址: https://www.postgresql.org/download/ 1.安装 # 添加源 sudo sh -c echo "deb https://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" > /etc/apt/sources.list.d/pgdg.list # 安装数字签名 w…...

利用ArcGIS获取每一个冰川的中心位置经纬度坐标:要素转点和要素折点转点的区别

问题概述:下图是天山地区的冰川的分布,我们可以看到每一条冰川是一个面要素,要求得到每一个冰川(面要素)的中心经纬度坐标。 1.采用要素转点功能 选择工具箱的【数据管理工具】-【要素】-【要素转点】。完成之后再采用…...

数据结构中的七大排序(Java实现)

目录 一、直接插入排序 二、希尔排序 三、直接选择排序 四、堆排序 五、冒泡排序 六、快速排序 七、归并排序 一、直接插入排序 思想: 定义i下标之前的元素全部已经有序,遍历一遍要排序的数组,把i下标前的元素全部进行排序&#xff0…...

深度学习基础算法

算法 1.K近邻算法 机器学习--K-近邻算法(KNN)_k近邻-CSDN博客 2. 数据库样本: CIFAR-10 CIFAR-10数据集(介绍、下载读取、可视化显示、另存为图片)_cifar10数据集-CSDN博客...

LuatOS-SOC接口文档(air780E)-- ir - 红外遥控

ir.sendNEC(pin, addr, cmd, repeat, disablePWM)# 发送NEC数据 参数 传入值类型 解释 int 使用的GPIO引脚编号 int 用户码(大于0xff则采用Extended NEC模式) int 数据码 int 可选,引导码发送次数(110ms一次&#xff0…...

Java虚拟机常见面试题总结

梳理Java虚拟机相关的面试题,主要参考《深入理解Java虚拟机 JVM高级特性与最佳实践》(第2版, 周志明 著)一书,其余部分整合网络相关内容。注意,关于Java并发编程的面试题因为内容较多,单独整理。Java基础相关的面试题可以参考Java…...

NVIDIA NCCL 源码学习(十一)- ring allreduce

之前的章节里我们看到了nccl send/recv通信的过程,本节我们以ring allreduce为例看下集合通信的过程。整体执行流程和send/recv很像,所以对于相似的流程只做简单介绍,主要介绍ring allreduce自己特有内容。 单机 搜索ring 在nccl初始化的过…...

C# SqlSugar:依赖注入与仓储模式实践

C# SqlSugar:依赖注入与仓储模式实践 在 C# 的应用开发中,数据库操作是必不可少的环节。为了让数据访问层更加简洁、高效且易于维护,许多开发者会选择成熟的 ORM(对象关系映射)框架,SqlSugar 就是其中备受…...

UR 协作机器人「三剑客」:精密轻量担当(UR7e)、全能协作主力(UR12e)、重型任务专家(UR15)

UR协作机器人正以其卓越性能在现代制造业自动化中扮演重要角色。UR7e、UR12e和UR15通过创新技术和精准设计满足了不同行业的多样化需求。其中,UR15以其速度、精度及人工智能准备能力成为自动化领域的重要突破。UR7e和UR12e则在负载规格和市场定位上不断优化&#xf…...

C++.OpenGL (14/64)多光源(Multiple Lights)

多光源(Multiple Lights) 多光源渲染技术概览 #mermaid-svg-3L5e5gGn76TNh7Lq {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-3L5e5gGn76TNh7Lq .error-icon{fill:#552222;}#mermaid-svg-3L5e5gGn76TNh7Lq .erro…...

深度学习水论文:mamba+图像增强

🧀当前视觉领域对高效长序列建模需求激增,对Mamba图像增强这方向的研究自然也逐渐火热。原因在于其高效长程建模,以及动态计算优势,在图像质量提升和细节恢复方面有难以替代的作用。 🧀因此短时间内,就有不…...

Golang——6、指针和结构体

指针和结构体 1、指针1.1、指针地址和指针类型1.2、指针取值1.3、new和make 2、结构体2.1、type关键字的使用2.2、结构体的定义和初始化2.3、结构体方法和接收者2.4、给任意类型添加方法2.5、结构体的匿名字段2.6、嵌套结构体2.7、嵌套匿名结构体2.8、结构体的继承 3、结构体与…...

LRU 缓存机制详解与实现(Java版) + 力扣解决

📌 LRU 缓存机制详解与实现(Java版) 一、📖 问题背景 在日常开发中,我们经常会使用 缓存(Cache) 来提升性能。但由于内存有限,缓存不可能无限增长,于是需要策略决定&am…...

HTML前端开发:JavaScript 获取元素方法详解

作为前端开发者,高效获取 DOM 元素是必备技能。以下是 JS 中核心的获取元素方法,分为两大系列: 一、getElementBy... 系列 传统方法,直接通过 DOM 接口访问,返回动态集合(元素变化会实时更新)。…...

0x-3-Oracle 23 ai-sqlcl 25.1 集成安装-配置和优化

是不是受够了安装了oracle database之后sqlplus的简陋,无法删除无法上下翻页的苦恼。 可以安装readline和rlwrap插件的话,配置.bahs_profile后也能解决上下翻页这些,但是很多生产环境无法安装rpm包。 oracle提供了sqlcl免费许可&#xff0c…...

动态规划-1035.不相交的线-力扣(LeetCode)

一、题目解析 光看题目要求和例图,感觉这题好麻烦,直线不能相交啊,每个数字只属于一条连线啊等等,但我们结合题目所给的信息和例图的内容,这不就是最长公共子序列吗?,我们把最长公共子序列连线起…...

Netty自定义协议解析

目录 自定义协议设计 实现消息解码器 实现消息编码器 自定义消息对象 配置ChannelPipeline Netty提供了强大的编解码器抽象基类,这些基类能够帮助开发者快速实现自定义协议的解析。 自定义协议设计 在实现自定义协议解析之前,需要明确协议的具体格式。例如,一个简单的…...