【论文复现】偏标记学习+图像分类

📝个人主页🌹:Eternity._
🌹🌹期待您的关注 🌹🌹


❀ 偏标记学习+图像分类
- 概述
- 算法原理
- 核心逻辑
- 效果演示
- 使用方式
- 参考文献
概述
本文复现论文 Progressive Identification of True Labels for Partial-Label Learning[1] 提出的偏标记学习方法。
随着深度神经网络的发展,机器学习任务对标注数据的需求不断增加。然而,大量的标注数据十分依赖人力资源与标注者的专业知识。弱监督学习可以有效缓解这一问题,因其不需要完全且准确的标注数据。该论文关注一个重要的弱监督学习问题——偏标记学习(Partial Label Learning),其中每个训练实例与一组候选标签相关联,但仅有一个标签是真实的。

该论文提出了一种渐进式真实标签识别方法,旨在训练过程中逐渐确定样本的真实标签。该论文所提出的方法获得了接近监督学习的性能,且与具体的网络结构、损失函数、随机优化算法无关。
本文所涉及的所有资源的获取方式:这里
算法原理
传统的监督学习常用交叉熵损失和随机梯度下降来优化深度神经网络。交叉熵损失定义如下:

其中, x 表示样本特征; [ y = [ y 1 , y 2 , … , y c ] ] [ \mathbf{y} = [y_1, y_2, \ldots, y_c] ] [y=[y1,y2,…,yc]]表示样本标签,其为独热码,即除了真实标签对应维度值为 1,其余为零; [ f i ( x ; θ ) ] [ f_i(x; \theta) ] [fi(x;θ)]表示模型预测样本 x 标签为 i 的概率。
该论文提出的方法使用一个软标签 [ y ^ = [ y ^ 1 , y ^ 2 , … , y ^ c ] ] [ \hat{y} = [\hat{y}_1, \hat{y}_2, \ldots, \hat{y}_c] ] [y^=[y^1,y^2,…,y^c]],其对任意 [ i ∈ [ 0 , c ] ] [ i \in [0, c] ] [i∈[0,c]]满足 [ ∑ i y ^ i = 1 且 0 ≤ y ^ i ≤ 1 ] [ \sum_{i} \hat{y}_i = 1 \quad \text{且} \quad 0 \leq \hat{y}_i \leq 1 ] [∑iy^i=1且0≤y^i≤1]为了使用该软标签,论文根据候选标签集 s 对软标签进行初始化:

为了渐进式地识别真实标签,算法在每次更新参数之前,根据预测结果为下轮训练使用的软标签赋值:

其中, [ I ( j ∈ s ) = { 1 当且仅当 j ∈ s 为真 0 否则 ] [ I(j \in s) = \begin{cases} 1 & \text{当且仅当 } j \in s \text{ 为真} \\ 0 & \text{否则} \end{cases} ] [I(j∈s)={10当且仅当 j∈s 为真否则]
核心逻辑
具体的核心逻辑如下所示:
import models
import datasets
import torch
from torch.utils.data import DataLoader
import numpy as np
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
import torchvision.transforms as transforms
from tqdm import tqdmdef CE_loss(probs, targets):"""交叉熵损失函数"""loss = -torch.sum(targets * torch.log(probs), dim = -1)loss_avg = torch.sum(loss)/probs.shape[0]return loss_avgclass Proden:def __init__(self, configs):self.configs = configsdef train(self, save = False):configs = self.configs# 读取数据集dataset_path = configs['dataset path']if configs['dataset'] == 'CIFAR-10':train_data, train_labels, test_data, test_labels = datasets.cifar10_read(dataset_path)train_dataset = datasets.Cifar(train_data, train_labels)test_dataset = datasets.Cifar(test_data, test_labels)output_dimension = 10elif configs['dataset'] == 'CIFAR-100':train_data, train_labels, test_data, test_labels = datasets.cifar100_read(dataset_path)train_dataset = datasets.Cifar(train_data, train_labels)test_dataset = datasets.Cifar(test_data, test_labels)output_dimension = 100# 生成偏标记partial_labels = datasets.generate_partial_labels(train_labels, configs['partial rate'])train_dataset.load_partial_labels(partial_labels)# 计算数据的均值和方差,用于模型输入的标准化mean = [np.mean(train_data[:, i, :, :]) for i in range(3)]std = [np.std(train_data[:, i, :, :]) for i in range(3)]normalize = transforms.Normalize(mean, std)# 设备:GPU或CPUdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 加载模型if configs['model'] == 'ResNet18':model = models.ResNet18(output_dimension = output_dimension).to(device)elif configs['model'] == 'ConvNet':model = models.ConvNet(output_dimension = output_dimension).to(device)# 设置学习率等超参数lr = configs['learning rate']weight_decay = configs['weight decay']momentum = configs['momentum']optimizer = optim.SGD(model.parameters(), lr = lr, weight_decay = weight_decay, momentum = momentum)lr_step = configs['learning rate decay step']lr_decay = configs['learning rate decay rate']lr_scheduler = StepLR(optimizer, step_size=lr_step, gamma=lr_decay)for epoch_id in range(configs['epoch count']):# 训练模型train_dataloader = DataLoader(train_dataset, batch_size = configs['batch size'], shuffle = True)model.train()for batch in tqdm(train_dataloader, desc='Training(Epoch %d)' % epoch_id, ascii=' 123456789#'):ids = batch['ids']# 标准化输入data = normalize(batch['data'].to(device))partial_labels = batch['partial_labels'].to(device)targets = batch['targets'].to(device)optimizer.zero_grad()# 计算预测概率logits = model(data)probs = F.softmax(logits, dim=-1)# 更新软标签with torch.no_grad():new_targets = F.normalize(probs * partial_labels, p=1, dim=-1)train_dataset.targets[ids] = new_targets.cpu().numpy()# 计算交叉熵损失loss = CE_loss(probs, targets)loss.backward()# 更新模型参数optimizer.step()# 调整学习率lr_scheduler.step()
以上代码仅作展示,更详细的代码文件请参见附件。
效果演示
我提前在 CIFAR-10[2] 数据集和 12 层的 ConvNet[3] 网络上训练了一份模型参数。为了测试其准确率,需要配置环境并运行main.py脚本,得到结果如下:

由图可见,该算法在测试集上获得了 89.8% 的准确率。
进一步地,测试训练出的模型在真实图片上的预测结果。在线部署模型后,将一张轮船的图片输入,可以得到输出的预测类型为 “Ship”:



网站提供了在线演示功能,使用者请输入一张小于1MB、类别为上述十个类别之一、长宽尽可能相等的JPG图像。
使用方式
解压附件压缩包并进入工作目录。如果是Linux系统,请使用如下命令:
unzip Proden-implemention.zip
cd Proden-implemention
代码的运行环境可通过如下命令进行配置:
pip install -r requirements.txt
运行如下命令以下载并解压数据集
bash download.sh
如果希望在本地训练模型,请运行如下命令:
python main.py -c [你的配置文件路径] -r [选择下者之一:"train"、"test"、"infer"]
如果希望在线部署,请运行如下命令:
python main-flask.py
参考文献
[1] Lv J, Xu M, Feng L, et al. Progressive identification of true labels for partial-label learning[C]//International conference on machine learning. PMLR, 2020: 6500-6510.
[2] Krizhevsky A, Hinton G. Learning multiple layers of features from tiny images[J]. 2009.
[3] Laine S, Aila T. Temporal ensembling for semi-supervised learning[J]. arXiv preprint arXiv:1610.02242, 2016.
编程未来,从这里启航!解锁无限创意,让每一行代码都成为你通往成功的阶梯,帮助更多人欣赏与学习!
更多内容详见:这里
相关文章:
【论文复现】偏标记学习+图像分类
📝个人主页🌹:Eternity._ 🌹🌹期待您的关注 🌹🌹 ❀ 偏标记学习图像分类 概述算法原理核心逻辑效果演示使用方式参考文献 概述 本文复现论文 Progressive Identification of True Labels for Pa…...
C嘎嘎探索篇:栈与队列的交响:C++中的结构艺术
C嘎嘎探索篇:栈与队列的交响:C中的结构艺术 前言: 小编在之前刚完成了C中栈和队列(stack和queue)的讲解,忘记的小伙伴可以去我上一篇文章看一眼的,今天小编将会带领大家吹奏栈和队列的交响&am…...
AIGC-----AIGC在虚拟现实中的应用前景
AIGC在虚拟现实中的应用前景 引言 随着人工智能生成内容(AIGC)的快速发展,虚拟现实(VR)技术的应用也迎来了新的契机。AIGC与VR的结合为创造沉浸式体验带来了全新的可能性,这种组合不仅极大地降低了VR内容的…...
Django 路由层
1. 路由基础概念 URLconf (URL 配置):Django 的路由系统是基于 urls.py 文件定义的。路径匹配:通过模式匹配 URL,并将请求传递给对应的视图处理函数。命名路由:每个路由可以定义一个名称,用于反向解析。 2. 基本路由配…...
《硬件架构的艺术》笔记(八):消抖技术
简介 在电子设备中两个金属触点随着触点的断开闭合便产生了多个信号,这就是抖动。 消抖是用来确保每一次断开或闭合触点时只有一个信号起作用的硬件设备或软件。(就是每次断开闭合只对应一个操作)。 抖动在某些模拟和逻辑电路中可能产生问…...
Spring 与 Spring MVC 与 Spring Boot三者之间的区别与联系
一.什么是Spring?它解决了什么问题? 1.1什么是Spring? Spring,一般指代的是Spring Framework 它是一个开源的应用程序框架,提供了一个简易的开发方式,通过这种开发方式,将避免那些可能致使代码…...
【算法】连通块问题(C/C++)
目录 连通块问题 解决思路 步骤: 初始化: DFS函数: 复杂度分析 代码实现(C) 题目链接:2060. 奶牛选美 - AcWing题库 解题思路: AC代码: 题目链接:687. 扫雷 -…...
如何选择黑白相机和彩色相机
我们在选择成像解决方案时黑白相机很容易被忽略,因为许多新相机提供鲜艳的颜色,鲜明的对比度和改进的弱光性能。然而,有许多应用,选择黑白相机将是更好的选择,因为他们产生更清晰的图像,更好的分辨率&#…...
Rust 力扣 - 740. 删除并获得点数
文章目录 题目描述题解思路题解代码题目链接 题目描述 题解思路 首先对于这题我们如果将所有点数装入一个切片f中,该切片f中的i号下标表示所有点数为i的点数之和 那么这题就转换成了打家劫舍这道题,也就是求选择了切片中某个下标的元素后,该…...
OpenCV从入门到精通实战(七)——探索图像处理:自定义滤波与OpenCV卷积核
本文主要介绍如何使用Python和OpenCV库通过卷积操作来应用不同的图像滤波效果。主要分为几个步骤:图像的读取与处理、自定义卷积函数的实现、不同卷积核的应用,以及结果的展示。 卷积 在图像处理中,卷积是一种重要的操作,它通过…...
Docker核心概念总结
本文只是对 Docker 的概念做了较为详细的介绍,并不涉及一些像 Docker 环境的安装以及 Docker 的一些常见操作和命令。 容器介绍 Docker 是世界领先的软件容器平台,所以想要搞懂 Docker 的概念我们必须先从容器开始说起。 什么是容器? 先来看看容器较为…...
环形缓冲区
什么是环形缓冲区 环形缓冲区,也称为循环缓冲区或环形队列,是一种特殊的FIFO(先进先出)数据结构。它使用一块固定大小的内存空间来缓存数据,并通过两个指针(读指针和写指针)来管理数据的读写。当任意一个指针到达缓冲区末尾时,会自动回绕到缓冲区开头,形成一个"环"。…...
jQuery-Word-Export 使用记录及完整修正文件下载 jquery.wordexport.js
参考资料: jQuery-Word-Export导出word_jquery.wordexport.js下载-CSDN博客 近期又需要自己做个 Html2Doc 的解决方案,因为客户又不想要 Html2pdf 的下载了,当初还给我费尽心思解决Html转pdf时中文输出的问题(html转pdf文件下载之…...
云服务器部署WebSocket项目
WebSocket是一种在单个TCP连接上进行全双工通信的协议,其设计的目的是在Web浏览器和Web服务器之间进行实时通信(实时Web) WebSocket协议的优点包括: 1. 更高效的网络利用率:与HTTP相比,WebSocket的握手只…...
C#+数据库 实现动态权限设置
将权限信息存储在数据库中,支持动态调整。根据用户所属的角色、特定的功能模块,动态加载权限” 1. 数据库设计 根据这种需求,可以通过以下表设计: 用户表 (Users):存储用户信息。角色表 (Roles):存储角色…...
(原创)Android Studio新老界面UI切换及老版本下载地址
前言 这两天下载了一个新版的Android Studio,发现整个界面都发生了很大改动: 新的界面的一些设置可参考一些博客: Android Studio新版UI常用设置 但是对于一些急着开发的小伙伴来说,没有时间去适应,那么怎么办呢&am…...
Ubuntu24虚拟机-gnome-boxes
推荐使用gnome-boxes, virtualbox构建失败,multipass需要开启防火墙 sudo apt install gnome-boxes创建完毕~...
k8s rainbond centos7/win10 -20241124
参考 https://www.rainbond.com/ 国内一站式云原生平台 对centos7环境支持不太行 [lighthouseVM-16-5-centos ~]$ curl -o install.sh https://get.rainbond.com && bash ./install.sh 2024-11-24 09:56:57 ERROR: Ops! Docker daemon is not running. Start docke…...
SpringBoot+Vue滑雪社区网站设计与实现
【1】系统介绍 研究背景 随着互联网技术的快速发展和冰雪运动的普及,滑雪作为一种受欢迎的冬季运动项目,吸引了越来越多的爱好者。与此同时,社交媒体和在线社区平台的兴起为滑雪爱好者提供了一个交流经验、分享心得、获取信息的重要渠道。滑…...
MySql.2
sql查询语句执行过程 SQL 查询语句的执行过程是一个复杂的过程,涉及多个步骤。以下是典型的关系数据库管理系统 (RDBMS) 中 SQL 查询语句的执行过程概述: 1. 客户端发送查询 用户通过 SQL 客户端或应用程序发送 SQL 查询语句给数据库服务器。 2. …...
华硕笔记本性能困境突破:G-Helper工具的全方位优化方案
华硕笔记本性能困境突破:G-Helper工具的全方位优化方案 【免费下载链接】g-helper Lightweight Armoury Crate alternative for Asus laptops. Control tool for ROG Zephyrus G14, G15, G16, M16, Flow X13, Flow X16, TUF, Strix, Scar and other models 项目地…...
家里装了 OpenClaw,在公司也能随时管理——Shield CLI 远程访问方案
家里装了 OpenClaw,在公司也能随时管理 OpenClaw 火到不用介绍了——GitHub 25 万 Star,一个能真正帮你干活的 AI Agent。很多人装在家里的 Windows 电脑上,配好了 API Key 和各种插件,用着很爽。但一到公司或者出门在外ÿ…...
ESP32 IDF环境下DHT11温湿度读取避坑指南:从时序图到数据拼接的完整解析
ESP32 IDF环境下DHT11温湿度读取避坑指南:从时序图到数据拼接的完整解析 在物联网设备开发中,温湿度传感器是最基础也最常用的环境感知元件之一。DHT11作为一款低成本、单总线数字输出的温湿度传感器,被广泛应用于各类嵌入式项目中。然而&…...
GB28181协议实战:WVP开源项目+ZLM流媒体服务联调配置详解
GB28181协议实战:WVP开源项目ZLM流媒体服务联调配置详解 在视频监控领域,GB28181协议作为国家标准协议,已经成为设备互联互通的重要基础。而将WVP(Web Video Platform)开源项目与ZLM(ZLMediaKit)…...
VLSI设计实战:手把手教你用SPICE模型搭建9种基础电路(附完整代码)
VLSI设计实战:手把手教你用SPICE模型搭建9种基础电路(附完整代码) 在集成电路设计的浩瀚宇宙中,SPICE模型就像工程师手中的瑞士军刀。我第一次接触SPICE仿真时,面对密密麻麻的网表文件完全不知所措——直到导师扔给我一…...
别只盯着显卡!CES上英伟达那个能装进口袋的AI超算,普通人怎么玩?
口袋里的AI革命:如何用英伟达Project DIGITS打造个人智能工作站 当大多数人还在讨论RTX 50系列显卡的游戏性能时,英伟达在CES 2025上悄悄展示了一个可能改变未来的小玩意——Project DIGITS。这个能装进口袋的AI超算,搭载GB10芯片,…...
3款工业调试开源工具让Modbus通讯诊断效率提升80%
3款工业调试开源工具让Modbus通讯诊断效率提升80% 【免费下载链接】OpenModScan Open ModScan is a Free Modbus Master (Client) Utility 项目地址: https://gitcode.com/gh_mirrors/op/OpenModScan 在工业自动化领域,Modbus协议作为设备间通讯的"通用…...
ollama-QwQ-32B量化部署:在4GB内存设备运行OpenClaw的配置
ollama-QwQ-32B量化部署:在4GB内存设备运行OpenClaw的配置 1. 为什么要在低配设备上折腾大模型? 去年冬天,我在树莓派上第一次尝试部署OpenClaw时,被现实狠狠教育了一顿——32GB内存的笔记本跑得飞起,换到4GB的树莓派…...
3步搞定Windows 11优化:用Win11Debloat让你的电脑更快更干净
3步搞定Windows 11优化:用Win11Debloat让你的电脑更快更干净 【免费下载链接】Win11Debloat 一个简单的PowerShell脚本,用于从Windows中移除预装的无用软件,禁用遥测,从Windows搜索中移除Bing,以及执行各种其他更改以简…...
Vue 3 + hls.js 实战:手把手教你打造一个能‘续命’的安防监控播放器
Vue 3 hls.js 打造安防级视频流播放器的"续命"秘籍 在安防监控、智慧城市等实时视频流应用场景中,网络抖动、服务中断、页面切换等问题常常导致视频播放中断,严重影响监控效果。本文将深入探讨如何基于Vue 3和hls.js构建一个具备"续命&q…...
