【论文复现】偏标记学习+图像分类

📝个人主页🌹:Eternity._
🌹🌹期待您的关注 🌹🌹


❀ 偏标记学习+图像分类
- 概述
- 算法原理
- 核心逻辑
- 效果演示
- 使用方式
- 参考文献
概述
本文复现论文 Progressive Identification of True Labels for Partial-Label Learning[1] 提出的偏标记学习方法。
随着深度神经网络的发展,机器学习任务对标注数据的需求不断增加。然而,大量的标注数据十分依赖人力资源与标注者的专业知识。弱监督学习可以有效缓解这一问题,因其不需要完全且准确的标注数据。该论文关注一个重要的弱监督学习问题——偏标记学习(Partial Label Learning),其中每个训练实例与一组候选标签相关联,但仅有一个标签是真实的。

该论文提出了一种渐进式真实标签识别方法,旨在训练过程中逐渐确定样本的真实标签。该论文所提出的方法获得了接近监督学习的性能,且与具体的网络结构、损失函数、随机优化算法无关。
本文所涉及的所有资源的获取方式:这里
算法原理
传统的监督学习常用交叉熵损失和随机梯度下降来优化深度神经网络。交叉熵损失定义如下:

其中, x 表示样本特征; [ y = [ y 1 , y 2 , … , y c ] ] [ \mathbf{y} = [y_1, y_2, \ldots, y_c] ] [y=[y1,y2,…,yc]]表示样本标签,其为独热码,即除了真实标签对应维度值为 1,其余为零; [ f i ( x ; θ ) ] [ f_i(x; \theta) ] [fi(x;θ)]表示模型预测样本 x 标签为 i 的概率。
该论文提出的方法使用一个软标签 [ y ^ = [ y ^ 1 , y ^ 2 , … , y ^ c ] ] [ \hat{y} = [\hat{y}_1, \hat{y}_2, \ldots, \hat{y}_c] ] [y^=[y^1,y^2,…,y^c]],其对任意 [ i ∈ [ 0 , c ] ] [ i \in [0, c] ] [i∈[0,c]]满足 [ ∑ i y ^ i = 1 且 0 ≤ y ^ i ≤ 1 ] [ \sum_{i} \hat{y}_i = 1 \quad \text{且} \quad 0 \leq \hat{y}_i \leq 1 ] [∑iy^i=1且0≤y^i≤1]为了使用该软标签,论文根据候选标签集 s 对软标签进行初始化:

为了渐进式地识别真实标签,算法在每次更新参数之前,根据预测结果为下轮训练使用的软标签赋值:

其中, [ I ( j ∈ s ) = { 1 当且仅当 j ∈ s 为真 0 否则 ] [ I(j \in s) = \begin{cases} 1 & \text{当且仅当 } j \in s \text{ 为真} \\ 0 & \text{否则} \end{cases} ] [I(j∈s)={10当且仅当 j∈s 为真否则]
核心逻辑
具体的核心逻辑如下所示:
import models
import datasets
import torch
from torch.utils.data import DataLoader
import numpy as np
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
import torchvision.transforms as transforms
from tqdm import tqdmdef CE_loss(probs, targets):"""交叉熵损失函数"""loss = -torch.sum(targets * torch.log(probs), dim = -1)loss_avg = torch.sum(loss)/probs.shape[0]return loss_avgclass Proden:def __init__(self, configs):self.configs = configsdef train(self, save = False):configs = self.configs# 读取数据集dataset_path = configs['dataset path']if configs['dataset'] == 'CIFAR-10':train_data, train_labels, test_data, test_labels = datasets.cifar10_read(dataset_path)train_dataset = datasets.Cifar(train_data, train_labels)test_dataset = datasets.Cifar(test_data, test_labels)output_dimension = 10elif configs['dataset'] == 'CIFAR-100':train_data, train_labels, test_data, test_labels = datasets.cifar100_read(dataset_path)train_dataset = datasets.Cifar(train_data, train_labels)test_dataset = datasets.Cifar(test_data, test_labels)output_dimension = 100# 生成偏标记partial_labels = datasets.generate_partial_labels(train_labels, configs['partial rate'])train_dataset.load_partial_labels(partial_labels)# 计算数据的均值和方差,用于模型输入的标准化mean = [np.mean(train_data[:, i, :, :]) for i in range(3)]std = [np.std(train_data[:, i, :, :]) for i in range(3)]normalize = transforms.Normalize(mean, std)# 设备:GPU或CPUdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 加载模型if configs['model'] == 'ResNet18':model = models.ResNet18(output_dimension = output_dimension).to(device)elif configs['model'] == 'ConvNet':model = models.ConvNet(output_dimension = output_dimension).to(device)# 设置学习率等超参数lr = configs['learning rate']weight_decay = configs['weight decay']momentum = configs['momentum']optimizer = optim.SGD(model.parameters(), lr = lr, weight_decay = weight_decay, momentum = momentum)lr_step = configs['learning rate decay step']lr_decay = configs['learning rate decay rate']lr_scheduler = StepLR(optimizer, step_size=lr_step, gamma=lr_decay)for epoch_id in range(configs['epoch count']):# 训练模型train_dataloader = DataLoader(train_dataset, batch_size = configs['batch size'], shuffle = True)model.train()for batch in tqdm(train_dataloader, desc='Training(Epoch %d)' % epoch_id, ascii=' 123456789#'):ids = batch['ids']# 标准化输入data = normalize(batch['data'].to(device))partial_labels = batch['partial_labels'].to(device)targets = batch['targets'].to(device)optimizer.zero_grad()# 计算预测概率logits = model(data)probs = F.softmax(logits, dim=-1)# 更新软标签with torch.no_grad():new_targets = F.normalize(probs * partial_labels, p=1, dim=-1)train_dataset.targets[ids] = new_targets.cpu().numpy()# 计算交叉熵损失loss = CE_loss(probs, targets)loss.backward()# 更新模型参数optimizer.step()# 调整学习率lr_scheduler.step()
以上代码仅作展示,更详细的代码文件请参见附件。
效果演示
我提前在 CIFAR-10[2] 数据集和 12 层的 ConvNet[3] 网络上训练了一份模型参数。为了测试其准确率,需要配置环境并运行main.py脚本,得到结果如下:

由图可见,该算法在测试集上获得了 89.8% 的准确率。
进一步地,测试训练出的模型在真实图片上的预测结果。在线部署模型后,将一张轮船的图片输入,可以得到输出的预测类型为 “Ship”:



网站提供了在线演示功能,使用者请输入一张小于1MB、类别为上述十个类别之一、长宽尽可能相等的JPG图像。
使用方式
解压附件压缩包并进入工作目录。如果是Linux系统,请使用如下命令:
unzip Proden-implemention.zip
cd Proden-implemention
代码的运行环境可通过如下命令进行配置:
pip install -r requirements.txt
运行如下命令以下载并解压数据集
bash download.sh
如果希望在本地训练模型,请运行如下命令:
python main.py -c [你的配置文件路径] -r [选择下者之一:"train"、"test"、"infer"]
如果希望在线部署,请运行如下命令:
python main-flask.py
参考文献
[1] Lv J, Xu M, Feng L, et al. Progressive identification of true labels for partial-label learning[C]//International conference on machine learning. PMLR, 2020: 6500-6510.
[2] Krizhevsky A, Hinton G. Learning multiple layers of features from tiny images[J]. 2009.
[3] Laine S, Aila T. Temporal ensembling for semi-supervised learning[J]. arXiv preprint arXiv:1610.02242, 2016.
编程未来,从这里启航!解锁无限创意,让每一行代码都成为你通往成功的阶梯,帮助更多人欣赏与学习!
更多内容详见:这里
相关文章:
【论文复现】偏标记学习+图像分类
📝个人主页🌹:Eternity._ 🌹🌹期待您的关注 🌹🌹 ❀ 偏标记学习图像分类 概述算法原理核心逻辑效果演示使用方式参考文献 概述 本文复现论文 Progressive Identification of True Labels for Pa…...
C嘎嘎探索篇:栈与队列的交响:C++中的结构艺术
C嘎嘎探索篇:栈与队列的交响:C中的结构艺术 前言: 小编在之前刚完成了C中栈和队列(stack和queue)的讲解,忘记的小伙伴可以去我上一篇文章看一眼的,今天小编将会带领大家吹奏栈和队列的交响&am…...
AIGC-----AIGC在虚拟现实中的应用前景
AIGC在虚拟现实中的应用前景 引言 随着人工智能生成内容(AIGC)的快速发展,虚拟现实(VR)技术的应用也迎来了新的契机。AIGC与VR的结合为创造沉浸式体验带来了全新的可能性,这种组合不仅极大地降低了VR内容的…...
Django 路由层
1. 路由基础概念 URLconf (URL 配置):Django 的路由系统是基于 urls.py 文件定义的。路径匹配:通过模式匹配 URL,并将请求传递给对应的视图处理函数。命名路由:每个路由可以定义一个名称,用于反向解析。 2. 基本路由配…...
《硬件架构的艺术》笔记(八):消抖技术
简介 在电子设备中两个金属触点随着触点的断开闭合便产生了多个信号,这就是抖动。 消抖是用来确保每一次断开或闭合触点时只有一个信号起作用的硬件设备或软件。(就是每次断开闭合只对应一个操作)。 抖动在某些模拟和逻辑电路中可能产生问…...
Spring 与 Spring MVC 与 Spring Boot三者之间的区别与联系
一.什么是Spring?它解决了什么问题? 1.1什么是Spring? Spring,一般指代的是Spring Framework 它是一个开源的应用程序框架,提供了一个简易的开发方式,通过这种开发方式,将避免那些可能致使代码…...
【算法】连通块问题(C/C++)
目录 连通块问题 解决思路 步骤: 初始化: DFS函数: 复杂度分析 代码实现(C) 题目链接:2060. 奶牛选美 - AcWing题库 解题思路: AC代码: 题目链接:687. 扫雷 -…...
如何选择黑白相机和彩色相机
我们在选择成像解决方案时黑白相机很容易被忽略,因为许多新相机提供鲜艳的颜色,鲜明的对比度和改进的弱光性能。然而,有许多应用,选择黑白相机将是更好的选择,因为他们产生更清晰的图像,更好的分辨率&#…...
Rust 力扣 - 740. 删除并获得点数
文章目录 题目描述题解思路题解代码题目链接 题目描述 题解思路 首先对于这题我们如果将所有点数装入一个切片f中,该切片f中的i号下标表示所有点数为i的点数之和 那么这题就转换成了打家劫舍这道题,也就是求选择了切片中某个下标的元素后,该…...
OpenCV从入门到精通实战(七)——探索图像处理:自定义滤波与OpenCV卷积核
本文主要介绍如何使用Python和OpenCV库通过卷积操作来应用不同的图像滤波效果。主要分为几个步骤:图像的读取与处理、自定义卷积函数的实现、不同卷积核的应用,以及结果的展示。 卷积 在图像处理中,卷积是一种重要的操作,它通过…...
Docker核心概念总结
本文只是对 Docker 的概念做了较为详细的介绍,并不涉及一些像 Docker 环境的安装以及 Docker 的一些常见操作和命令。 容器介绍 Docker 是世界领先的软件容器平台,所以想要搞懂 Docker 的概念我们必须先从容器开始说起。 什么是容器? 先来看看容器较为…...
环形缓冲区
什么是环形缓冲区 环形缓冲区,也称为循环缓冲区或环形队列,是一种特殊的FIFO(先进先出)数据结构。它使用一块固定大小的内存空间来缓存数据,并通过两个指针(读指针和写指针)来管理数据的读写。当任意一个指针到达缓冲区末尾时,会自动回绕到缓冲区开头,形成一个"环"。…...
jQuery-Word-Export 使用记录及完整修正文件下载 jquery.wordexport.js
参考资料: jQuery-Word-Export导出word_jquery.wordexport.js下载-CSDN博客 近期又需要自己做个 Html2Doc 的解决方案,因为客户又不想要 Html2pdf 的下载了,当初还给我费尽心思解决Html转pdf时中文输出的问题(html转pdf文件下载之…...
云服务器部署WebSocket项目
WebSocket是一种在单个TCP连接上进行全双工通信的协议,其设计的目的是在Web浏览器和Web服务器之间进行实时通信(实时Web) WebSocket协议的优点包括: 1. 更高效的网络利用率:与HTTP相比,WebSocket的握手只…...
C#+数据库 实现动态权限设置
将权限信息存储在数据库中,支持动态调整。根据用户所属的角色、特定的功能模块,动态加载权限” 1. 数据库设计 根据这种需求,可以通过以下表设计: 用户表 (Users):存储用户信息。角色表 (Roles):存储角色…...
(原创)Android Studio新老界面UI切换及老版本下载地址
前言 这两天下载了一个新版的Android Studio,发现整个界面都发生了很大改动: 新的界面的一些设置可参考一些博客: Android Studio新版UI常用设置 但是对于一些急着开发的小伙伴来说,没有时间去适应,那么怎么办呢&am…...
Ubuntu24虚拟机-gnome-boxes
推荐使用gnome-boxes, virtualbox构建失败,multipass需要开启防火墙 sudo apt install gnome-boxes创建完毕~...
k8s rainbond centos7/win10 -20241124
参考 https://www.rainbond.com/ 国内一站式云原生平台 对centos7环境支持不太行 [lighthouseVM-16-5-centos ~]$ curl -o install.sh https://get.rainbond.com && bash ./install.sh 2024-11-24 09:56:57 ERROR: Ops! Docker daemon is not running. Start docke…...
SpringBoot+Vue滑雪社区网站设计与实现
【1】系统介绍 研究背景 随着互联网技术的快速发展和冰雪运动的普及,滑雪作为一种受欢迎的冬季运动项目,吸引了越来越多的爱好者。与此同时,社交媒体和在线社区平台的兴起为滑雪爱好者提供了一个交流经验、分享心得、获取信息的重要渠道。滑…...
MySql.2
sql查询语句执行过程 SQL 查询语句的执行过程是一个复杂的过程,涉及多个步骤。以下是典型的关系数据库管理系统 (RDBMS) 中 SQL 查询语句的执行过程概述: 1. 客户端发送查询 用户通过 SQL 客户端或应用程序发送 SQL 查询语句给数据库服务器。 2. …...
【力扣数据库知识手册笔记】索引
索引 索引的优缺点 优点1. 通过创建唯一性索引,可以保证数据库表中每一行数据的唯一性。2. 可以加快数据的检索速度(创建索引的主要原因)。3. 可以加速表和表之间的连接,实现数据的参考完整性。4. 可以在查询过程中,…...
【大模型RAG】Docker 一键部署 Milvus 完整攻略
本文概要 Milvus 2.5 Stand-alone 版可通过 Docker 在几分钟内完成安装;只需暴露 19530(gRPC)与 9091(HTTP/WebUI)两个端口,即可让本地电脑通过 PyMilvus 或浏览器访问远程 Linux 服务器上的 Milvus。下面…...
条件运算符
C中的三目运算符(也称条件运算符,英文:ternary operator)是一种简洁的条件选择语句,语法如下: 条件表达式 ? 表达式1 : 表达式2• 如果“条件表达式”为true,则整个表达式的结果为“表达式1”…...
基础测试工具使用经验
背景 vtune,perf, nsight system等基础测试工具,都是用过的,但是没有记录,都逐渐忘了。所以写这篇博客总结记录一下,只要以后发现新的用法,就记得来编辑补充一下 perf 比较基础的用法: 先改这…...
CRMEB 框架中 PHP 上传扩展开发:涵盖本地上传及阿里云 OSS、腾讯云 COS、七牛云
目前已有本地上传、阿里云OSS上传、腾讯云COS上传、七牛云上传扩展 扩展入口文件 文件目录 crmeb\services\upload\Upload.php namespace crmeb\services\upload;use crmeb\basic\BaseManager; use think\facade\Config;/*** Class Upload* package crmeb\services\upload* …...
MySQL中【正则表达式】用法
MySQL 中正则表达式通过 REGEXP 或 RLIKE 操作符实现(两者等价),用于在 WHERE 子句中进行复杂的字符串模式匹配。以下是核心用法和示例: 一、基础语法 SELECT column_name FROM table_name WHERE column_name REGEXP pattern; …...
uniapp中使用aixos 报错
问题: 在uniapp中使用aixos,运行后报如下错误: AxiosError: There is no suitable adapter to dispatch the request since : - adapter xhr is not supported by the environment - adapter http is not available in the build 解决方案&…...
使用Matplotlib创建炫酷的3D散点图:数据可视化的新维度
文章目录 基础实现代码代码解析进阶技巧1. 自定义点的大小和颜色2. 添加图例和样式美化3. 真实数据应用示例实用技巧与注意事项完整示例(带样式)应用场景在数据科学和可视化领域,三维图形能为我们提供更丰富的数据洞察。本文将手把手教你如何使用Python的Matplotlib库创建引…...
Yolov8 目标检测蒸馏学习记录
yolov8系列模型蒸馏基本流程,代码下载:这里本人提交了一个demo:djdll/Yolov8_Distillation: Yolov8轻量化_蒸馏代码实现 在轻量化模型设计中,**知识蒸馏(Knowledge Distillation)**被广泛应用,作为提升模型…...
基于TurtleBot3在Gazebo地图实现机器人远程控制
1. TurtleBot3环境配置 # 下载TurtleBot3核心包 mkdir -p ~/catkin_ws/src cd ~/catkin_ws/src git clone -b noetic-devel https://github.com/ROBOTIS-GIT/turtlebot3.git git clone -b noetic https://github.com/ROBOTIS-GIT/turtlebot3_msgs.git git clone -b noetic-dev…...
