当前位置: 首页 > news >正文

pytorch实现半监督学习

半监督学习(Semi-Supervised Learning,SSL)结合了有监督学习和无监督学习的特点,通常用于部分数据有标签、部分数据无标签的场景。其主要步骤如下:

1. 数据准备

  • 有标签数据(Labeled Data):数据集的一部分带有真实的类别标签。
  • 无标签数据(Unlabeled Data):数据集的另一部分没有标签,仅有特征信息。
  • 数据预处理:对数据进行清理、标准化、特征工程等处理,以保证数据质量。

2. 选择半监督学习方法

常见的半监督学习方法包括:

  • 基于生成模型(Generative Models):如高斯混合模型(GMM)、变分自编码器(VAE)。
  • 基于一致性正则化(Consistency Regularization):如 MixMatch、FixMatch,利用数据增强来约束模型预测一致性。
  • 基于伪标签(Pseudo-Labeling):先用模型预测无标签数据的类别,然后将高置信度的预测作为新标签加入训练。
  • 图神经网络(Graph-Based Methods):如 Label Propagation,通过构造数据之间的图结构传播标签信息。

3. 训练初始模型

  • 仅使用有标签数据训练一个初始模型。
  • 选择合适的损失函数,如交叉熵损失(Cross-Entropy Loss)或均方误差(MSE Loss)。
  • 训练过程中可以使用数据增强、正则化等优化策略。

4. 利用无标签数据增强训练

  • 伪标签方法:用初始模型对无标签数据进行预测,筛选高置信度样本,加入有标签数据训练。
  • 一致性正则化:对无标签数据进行不同变换,要求模型的预测结果一致。
  • 联合训练:构造有监督损失(Supervised Loss)和无监督损失(Unsupervised Loss),综合优化。

5. 模型迭代更新

  • 重新利用训练后的模型预测无标签数据,产生新的伪标签或调整模型参数。
  • 通过半监督策略不断优化模型,使其对无标签数据的预测更加稳定。

6. 评估和测试

  • 使用测试集(通常是有标签的数据)评估模型性能。
  • 选择合适的评估指标,如准确率(Accuracy)、F1-score、AUC-ROC 等。

7. 调优和部署

  • 根据实验结果调整超参数,如伪标签置信度阈值、学习率等。
  • 结合业务需求,将最终模型部署到实际应用中。

关键步骤:

  1. 初始化模型:首先使用有标签数据训练模型。
  2. 生成伪标签:用训练好的模型对无标签数据进行预测,生成伪标签。
  3. 结合有标签和伪标签数据进行训练:用带有标签和无标签(伪标签)数据一起训练模型。
  4. 迭代训练:不断迭代,使用更新的模型生成新的伪标签,进一步优化模型。
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt# 简化的神经网络模型
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(1, 8, kernel_size=3)  # 缩小卷积层的输出通道self.fc1 = nn.Linear(8 * 26 * 26, 10)  # 调整全连接层的输入和输出尺寸def forward(self, x):x = F.relu(self.conv1(x))x = x.view(x.size(0), -1)  # 展平x = self.fc1(x)return x# 自定义数据集
class CustomDataset(Dataset):def __init__(self, data, labels=None):self.data = dataself.labels = labelsdef __len__(self):return len(self.data)def __getitem__(self, idx):if self.labels is not None:return self.data[idx], self.labels[idx]else:return self.data[idx], -1  # 无标签数据# 半监督训练函数
def pseudo_labeling_training(model, labeled_loader, unlabeled_loader, optimizer, device, threshold=0.95):model.train()labeled_loss_value = 0pseudo_loss_value = 0for (labeled_data, labeled_labels), (unlabeled_data, _) in zip(labeled_loader, unlabeled_loader):labeled_data, labeled_labels = labeled_data.to(device), labeled_labels.to(device)unlabeled_data = unlabeled_data.to(device)# 1. 有标签数据训练optimizer.zero_grad()labeled_output = model(labeled_data)labeled_loss = F.cross_entropy(labeled_output, labeled_labels)labeled_loss.backward()# 2. 无标签数据伪标签生成unlabeled_output = model(unlabeled_data)probs = F.softmax(unlabeled_output, dim=1)max_probs, pseudo_labels = torch.max(probs, dim=1)# 伪标签置信度筛选pseudo_mask = max_probs > threshold  # 置信度大于阈值的数据作为伪标签if pseudo_mask.sum() > 0:pseudo_labels = pseudo_labels[pseudo_mask]unlabeled_data_pseudo = unlabeled_data[pseudo_mask]# 3. 使用伪标签数据进行训练(确保无标签数据参与反向传播)optimizer.zero_grad()  # 清除之前的梯度pseudo_output = model(unlabeled_data_pseudo)pseudo_loss = F.cross_entropy(pseudo_output, pseudo_labels)pseudo_loss.backward()  # 计算反向梯度optimizer.step()  # 更新模型参数# 累加损失用于展示labeled_loss_value += labeled_loss.item()if pseudo_mask.sum() > 0:pseudo_loss_value += pseudo_loss.item()return labeled_loss_value / len(labeled_loader), pseudo_loss_value / len(unlabeled_loader)# 模拟数据
num_labeled = 1000
num_unlabeled = 5000
data_dim = (1, 28, 28)  # 28x28 灰度图像
num_classes = 10labeled_data = torch.randn(num_labeled, *data_dim)
labeled_labels = torch.randint(0, num_classes, (num_labeled,))
unlabeled_data = torch.randn(num_unlabeled, *data_dim)labeled_dataset = CustomDataset(labeled_data, labeled_labels)
unlabeled_dataset = CustomDataset(unlabeled_data)labeled_loader = DataLoader(labeled_dataset, batch_size=32, shuffle=True)  # 缩小批量大小
unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=32, shuffle=True)  # 缩小批量大小# 模型、优化器和设备设置
device = torch.device("cpu")  # 临时使用 CPU
model = SimpleCNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练过程并记录损失
num_epochs = 10
labeled_losses = []
pseudo_losses = []for epoch in range(num_epochs):labeled_loss, pseudo_loss = pseudo_labeling_training(model, labeled_loader, unlabeled_loader, optimizer, device)labeled_losses.append(labeled_loss)pseudo_losses.append(pseudo_loss)print(f"Epoch [{epoch + 1}/{num_epochs}] | Labeled Loss: {labeled_loss:.4f} | Pseudo Loss: {pseudo_loss:.4f}")# 绘制损失曲线
plt.plot(range(num_epochs), labeled_losses, label='Labeled Loss')
plt.plot(range(num_epochs), pseudo_losses, label='Pseudo Label Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.title('Training Losses Over Epochs')
plt.show()# 展示伪标签生成效果(可视化一些样本的伪标签预测结果)
model.eval()
with torch.no_grad():sample_unlabeled_data = unlabeled_data[:10].to(device)output = model(sample_unlabeled_data)probs = F.softmax(output, dim=1)_, predicted_labels = torch.max(probs, dim=1)# 展示预测的标签print("Generated Pseudo Labels for Samples:")print(predicted_labels)# 假设这些是伪标签预测的图片fig, axes = plt.subplots(2, 5, figsize=(12, 5))for i, ax in enumerate(axes.flat):# 将tensor转换为NumPy数组img = sample_unlabeled_data[i].cpu().numpy().squeeze()  # 转为NumPy数组ax.imshow(img, cmap='gray')  # 使用灰度显示图像ax.set_title(f"Pred: {predicted_labels[i].item()}")ax.axis('off')plt.show()

相关文章:

pytorch实现半监督学习

半监督学习(Semi-Supervised Learning,SSL)结合了有监督学习和无监督学习的特点,通常用于部分数据有标签、部分数据无标签的场景。其主要步骤如下: 1. 数据准备 有标签数据(Labeled Data)&…...

我的毕设之路:(2)系统类型的论文写法

一般先进行毕设的设计与实现,再在现成毕设基础上进行描述形成文档,那么论文也就成形了。 1 需求分析:毕业设计根据开题报告和要求进行需求分析和功能确定,区分贴合主题的主要功能和拓展功能能,删除偏离无关紧要的功能…...

LosslessScaling-学习版[steam价值30元的游戏无损放大/补帧工具]

LosslessScaling 链接:https://pan.xunlei.com/s/VOHc-yZBgwBOoqtdZAv114ZTA1?pwdxiih# 解压后运行"A-绿化-解压后运行我.cmd"...

concurrent.futures.Future对象详解:利用线程池与进程池实现异步操作

concurrent.futures.Future对象详解:利用线程池与进程池实现异步操作 一、前言二、使用线程池三、使用进程池四、注意事项五、结语 一、前言 在现代编程中,异步操作已成为提升程序性能和响应速度的关键手段。Python的concurrent.futures模块为此提供了强…...

StarRocks 安装部署

StarRocks 安装部署 StarRocks端口: 官方《配置检查》有服务端口详细描述: https://docs.starrocks.io/zh/docs/deployment/environment_configurations/ StarRocks架构:https://docs.starrocks.io/zh/docs/introduction/Architecture/ Sta…...

Python Matplotlib库:从入门到精通

Python Matplotlib库:从入门到精通 在数据分析和科学计算领域,可视化是一项至关重要的技能。Matplotlib作为Python中最流行的绘图库之一,为我们提供了强大的绘图功能。本文将带你从Matplotlib的基础开始,逐步掌握其高级用法&…...

线程概念、操作

一、背景知识 1、地址空间进一步理解 在父子进程对同一变量进行修改时发生写时拷贝,这时候拷贝的基本单位是4KB,会将该变量所在的页框全拷贝一份,这是因为修改该变量很有可能会修改其周围的变量(局部性原理)&#xf…...

【PySide6拓展】QSoundEffect

文章目录 【PySide6拓展】QSoundEffect 音效播放类**基本概念****什么是 QSoundEffect?****QSoundEffect 的特点****安装 PySide6** **如何使用 QSoundEffect?****1. 播放音效****示例代码:播放音效** **代码解析****QSoundEffect 的高级用法…...

33【脚本解析语言】

脚本语言也叫解析语言 脚本一词,相信很多人都听过,那么什么是脚本语言,我们在开发时有一个调试功能,但是发布版是需要编译执行的,体积比较大,同时这使得我们每次更新都需要重新编译,客户再…...

【Unity】 HTFramework框架(五十九)快速开发编辑器工具(Assembly Viewer + ILSpy)

更新日期:2025年1月23日。 Github源码:[点我获取源码] Gitee源码:[点我获取源码] 索引 开发编辑器工具MouseRayTarget焦点视角Collider线框Assembly Viewer搜索程序集ILSpy反编译程序集搜索GizmosElement类找到Gizmos菜单找到Gizmos窗口分析A…...

如何解决TikTok网络不稳定的问题

TikTok是目前全球最受欢迎的短视频平台之一,凭借其丰富多彩的内容和社交功能吸引了数以亿计的用户。然而,尽管TikTok在世界范围内的使用情况不断增长,但不少用户在使用过程中仍然会遇到网络不稳定的问题。无论是在观看视频时遇到缓冲&#xf…...

告别页面刷新!如何使用AJAX和FormData优化Web表单提交

系列文章目录 01-从零开始学 HTML:构建网页的基本框架与技巧 02-HTML常见文本标签解析:从基础到进阶的全面指南 03-HTML从入门到精通:链接与图像标签全解析 04-HTML 列表标签全解析:无序与有序列表的深度应用 05-HTML表格标签全面…...

WireShark4.4.2浏览器网络调试指南:数据统计(八)

概述 Wireshark 是一款功能强大的开源网络协议分析软件,被广泛应用于网络调试和数据分析。随着互联网的发展,以及网络安全问题日益严峻,了解如何使用 Wireshark进行浏览器网络调试显得尤为重要。最新的 Wireshark4.4.2 提供了更加强大的功能…...

Hypium+python鸿蒙原生自动化安装配置

Hypiumpython自动化搭建 文章目录 Python安装pip源配置HDC安装Hypium安装DevEco Testing Hypium插件安装及使用方法​​​​​插件安装工程创建区域 Python安装 推荐从官网获取3.10版本,其他版本可能出现兼容性问题 Python下载地址 下载64/32bitwindows安装文件&am…...

2025创业思路和方向有哪些?

创业思路和方向是决定创业成功与否的关键因素。以下是一些基于找到的参考内容的创业思路和方向,旨在激发创业灵感: 一、技术创新与融合: 1、智能手机与云电视结合:开发集成智能手机功能的云电视,提供通讯、娱乐一体化体…...

实验五---控制系统的稳定性分析---自动控制原理实验课

一 实验目的 1、理解控制系统稳定性的概念 2、掌握多种判定系统稳定性的原理及方法 3、掌握使用Matlab软件进行控制系统的稳定性分析 二 实验仪器 计算机,MATLAB仿真软件 三 实验内容及步骤 1.计算系统闭环特征根,判别系统稳定性; 2.绘制系统…...

AttributeError: can‘t set attribute ‘lines‘

报错: ax p3.Axes3D(fig) ax.lines [] AttributeError: cant set attribute lines 总结下来,解决方案应包括: 1. 使用ax.clear()方法清除所有内容。 2. 逐个移除lines中的元素。 3. 检查matplotlib版本,确保没有已知的bug。…...

Day07:缓存-数据淘汰策略

Redis的数据淘汰策略有哪些 ? (key过期导致的) 在redis中提供了两种数据过期删除策略 第一种是惰性删除,在设置该key过期时间后,我们不去管它,当需要该key时,我们再检查其是否过期,如果过期&…...

基于聚类与相关性分析对马来西亚房价数据进行分析

碎碎念:由于最近太忙了,更新的比较慢,提前祝大家新春快乐,万事如意!本数据集的下载地址,读者可以自行下载。 1.项目背景 本项目旨在对马来西亚房地产市场进行初步的数据分析,探索各州的房产市…...

Java—工具类类使用

工具类的调用:工具类名.方法名 工具类的书写: 示例: 写一个遍历数组的工具类 import java.util.Arrays;public class ArrayUtil {private ArrayUtil() {} //用私有化构造方法不让外界创建关于它的对象//定义static静态方法,因…...

SDMatte高清人像抠图作品集:影视级海报与创意合成的幕后利器

SDMatte高清人像抠图作品集:影视级海报与创意合成的幕后利器 1. 开篇:当AI遇见专业级人像抠图 想象一下这样的场景:电影海报需要将主演从绿幕背景中完美剥离,电商广告要把模特无缝融入不同场景,艺术创作需要将人物与…...

【喜报】义翘神州再获CNAS认可,全面对标2025版药典新标准

义翘神州生物安全检测实验室近日成功通过中国合格评定国家认可委员会(CNAS)的扩项评审及定期监督评审,并已完成全部能力附表更新!这标志着实验室技术能力与质量管理体系持续符合ISO/IEC 17025:2017国际标准的严苛要求,…...

AI绘画模型训练完全指南:3大核心优势与零代码实践

AI绘画模型训练完全指南:3大核心优势与零代码实践 【免费下载链接】sd-trainer 项目地址: https://gitcode.com/gh_mirrors/sd/sd-trainer Stable Diffusion训练技术已成为AI绘画领域的核心能力,但传统训练流程复杂、配置繁琐,让许多…...

高效学挖漏洞!全网最全平台汇总 + 零基础到精通指南,一篇搞定所有

一、众测平台(国内) 名称网址漏洞盒子https://www.vulbox.com/火线安全平台https://www.huoxian.cn/漏洞银行https://www.bugbank.cn/360漏洞众包响应平台https://src.360.net/补天平台(奇安信)https://www.butian.net/春秋云测https://zhongce.ichunqi…...

KingbaseES V008R006C008B0014物理备份实战:sys_rman从配置到自动化的完整避坑指南

KingbaseES物理备份实战:从sys_rman配置到自动化运维的深度解析 凌晨三点,数据库告警铃声突然响起——某核心业务系统的KingbaseES实例因磁盘故障导致数据丢失。此时,一个配置得当的sys_rman物理备份系统将成为最后的救命稻草。不同于简单的操…...

拯救数字青春:GetQzonehistory让QQ空间记忆永久安家

拯救数字青春:GetQzonehistory让QQ空间记忆永久安家 【免费下载链接】GetQzonehistory 获取QQ空间发布的历史说说 项目地址: https://gitcode.com/GitHub_Trending/ge/GetQzonehistory 在这个信息爆炸的时代,我们的青春记忆正以数据形式储存在各大…...

4步完成Axure本地化设置:让新手轻松上手的中文界面方案

4步完成Axure本地化设置:让新手轻松上手的中文界面方案 【免费下载链接】axure-cn Chinese language file for Axure RP. Axure RP 简体中文语言包,不定期更新。支持 Axure 9、Axure 10。 项目地址: https://gitcode.com/gh_mirrors/ax/axure-cn …...

RevokeMsgPatcher 2.1:实用高效的微信QQ防撤回完整解决方案

RevokeMsgPatcher 2.1:实用高效的微信QQ防撤回完整解决方案 【免费下载链接】RevokeMsgPatcher :trollface: A hex editor for WeChat/QQ/TIM - PC版微信/QQ/TIM防撤回补丁(我已经看到了,撤回也没用了) 项目地址: https://gitco…...

无需配置环境!MinerU镜像一键部署,即刻体验智能文档解析

无需配置环境!MinerU镜像一键部署,即刻体验智能文档解析 1. 为什么选择智能文档解析? 在日常办公和学习中,我们经常需要处理各种文档资料:PDF报告、扫描合同、学术论文、财务报表等。传统方式要么需要手动输入&#…...

解决Word中MathType功能失效的VBA与注册表修复指南

1. 遇到MathType罢工?先别急着重装Office 最近帮同事处理Word文档时,发现他的MathType菜单全灰了,公式编辑功能完全瘫痪。这种情况在科研论文写作高峰期特别要命——你正赶着投稿 deadline,突然发现公式编辑器失灵了,…...