当前位置: 首页 > news >正文

【pytorch】使用mixup技术扩充数据集进行训练

目录

  • 1.mixup技术简介
  • 2.pytorch实现代码,以图片分类为例

1.mixup技术简介

在这里插入图片描述

mixup是一种数据增强技术,它可以通过将多组不同数据集的样本进行线性组合,生成新的样本,从而扩充数据集。mixup的核心原理是将两个不同的图片按照一定的比例进行线性组合,生成新的样本,新样本的标签也是进行线性组合得到。比如,对于两个样本x1和x2,它们的标签分别为y1和y2,那么mixup生成的新样本x’和标签y’如下:

x’ = λx1 + (1-λ)x2
y’ = λy1 + (1-λ)y2

其中,λ为0到1之间的一个随机数,它表示x1和x2在新样本中的权重。

本文中,使用mixup扩充数据集后的损失函数,为:
loss = λ * criterion(outputs, targets_a) + (1 - λ) * criterion(outputs, targets_b)

即由两张图片融合后新图片的损失为,分别计算原先两图片与各自标签的损失值之和。

mixup也可以增加数据集的多样性,从而降低模型的方差,提高模型的鲁棒性。
总之,mixup是一种非常实用的数据增强技术,它可以用于各种机器学习任务中,可以有效地防止过拟合,并且可以提高模型的泛化能力。在实际应用中,mixup可以帮助我们更好地利用有限的数据集,并且提高模型的性能。

2.pytorch实现代码,以图片分类为例

import matplotlib.pyplot as plt
import torchvision
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np
import torch.nn.functional as F
from PIL import Image
import os
import cv2
# 加载resnet18模型
resnet18 = models.resnet18(pretrained=False)# 获取resnet18最后一层输出,输出为512维,最后一层本来是用作 分类的,原始网络分为1000类
# 用 softmax函数或者 fully connected 函数,但是用 nn.identtiy() 函数把最后一层替换掉,相当于得到分类之前的特征!
#Identity模块,它将输入直接传递给输出,而不会对输入进行任何变换。
resnet18.fc = nn.Identity()
# 构建新的网络,将resnet18的输出作为输入
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.resnet18 = resnet18self.fc1 = nn.Linear(512, 256)self.fc2 = nn.Linear(256, 64)self.fc3 = nn.Linear(64, 10)self.fc4 = nn.Linear(10, 2)self.softmax = nn.Softmax(dim=1)def forward(self, x):x = self.resnet18(x)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = F.relu(self.fc3(x))x = F.relu(self.fc4(x))x = self.softmax(x)x=x.view(-1,2)return x
#使用mixup时的数据集融合器,输入数据集的输入(一批图片),以及对应标签,返回线性相加后的图片
def mixup_data(x, y, alpha=1.0):#随机生成一个 beta 分布的参数 lam,用于生成随机的线性组合,以实现 mixup 数据扩充。lam = np.random.beta(alpha, alpha)#生成一个随机的序列,用于将输入数据进行 shuffle。batch_size = x.size()[0]index = torch.randperm(batch_size)#得到混合后的新图片mixed_x = lam * x + (1 - lam) * x[index, :]#得到混图对应的两类标签y_a, y_b = y, y[index]return mixed_x, y_a, y_b, lam# 实例化网络
net = Net()
# 将模型放入GPU
net = net.cuda()# 定义损失函数
criterion = nn.CrossEntropyLoss()
# 定义优化器,添加l2正则化
optimizer = torch.optim.Adam(net.parameters(), lr=0.0005,weight_decay=0)# 加载数据集
# 创建一个transform对象
def rgb2bgr(image):image = np.array(image)[:, :, ::-1]image=Image.fromarray(np.uint8(image))return imagetransform = transforms.Compose([transforms.ColorJitter(brightness=1, contrast=1, saturation=1, hue=0.5),# rgb转bgrtorchvision.transforms.Lambda(rgb2bgr),torchvision.transforms.Resize(112),# 入的图片为PIL image 或者 numpy.nadrry格式的图片,其shape为(HxWxC)数值范围在[0,255],转换之后shape为(CxHxw),数值范围在[0,1]transforms.ToTensor(),# 进行归一化和标准化,Imagenet数据集的均值和方差为:mean=(0.485, 0.456, 0.406),std=(0.229, 0.224, 0.225),# 因为这是在百万张图像上计算而得的,所以我们通常见到在训练过程中使用它们做标准化。transforms.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229]),# #这行代码表示使用transforms.RandomErasing函数,以概率p=1,在图像上随机选择一个尺寸为scale=(0.02, 0.33),长宽比为ratio=(1, 1)的区域,# #进行随机像素值的遮盖,只能对tensor操作:transforms.RandomErasing(p=0.1, scale=(0.02, 0.2), ratio=(1, 1), value='random')
])train_dataset = torchvision.datasets.ImageFolder(r'D:\eyeDataSet\train', transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
test_dataset = torchvision.datasets.ImageFolder(r'D:\eyeDataSet\test', transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=True)## 绘制训练及测试集迭代曲线
# 记录训练集准确率
train_acc = []
# 记录测试集准确率
test_acc = []
for epoch in range(50):running_loss = 0.0#[(0, data1), (1, data2), (2, data3), ...]for i, data in enumerate(train_loader, 0):# 获取输入inputs, labels = datainputs, labels = inputs.cuda(), labels.cuda()# mixup扩充数据集inputs, targets_a, targets_b, lam = mixup_data(inputs, labels, alpha=1.0)# 梯度清零optimizer.zero_grad()# forward + backwardoutputs = net(inputs)#这里对应调整了使用mixup后的数据集的lossloss = lam * criterion(outputs, targets_a) + (1 - lam) * criterion(outputs, targets_b)loss.backward()# 更新参数optimizer.step()# 打印log信息# loss 是一个scalar,需要使用loss.item()来获取数值,不能使用loss[0]running_loss += loss.item()if i % 10 == 9: # 每200个batch打印一下训练状态print('[%d, %5d] loss: %.3f' \% (epoch+1, i+1, running_loss / 2000))running_loss = 0.0# 在每次训练完成后,使用测试集进行测试correct = 0total = 0with torch.no_grad():for i2,data2 in enumerate(test_loader):#控制测试集的数量if i2>5:breakimages, labels = data2images, labels = images.cuda(), labels.cuda()outputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()test_acc.append(100 * correct / total)print('Accuracy of the network on the test images: %.3f,now max acc is %.3f' % (100 * correct / total,max(test_acc)))# 保存测试集上准确率最高的模型if 100 * correct / total == max(test_acc):if not os.path.exists(r'./result'):os.makedirs(r'./result')if max(test_acc)>93:savename="./result/bestmodel"+"%.3f"%max(test_acc)+".pth"torch.save(net.state_dict(), savename)print("最大准确度:",max(test_acc))

相关文章:

【pytorch】使用mixup技术扩充数据集进行训练

目录1.mixup技术简介2.pytorch实现代码,以图片分类为例1.mixup技术简介 mixup是一种数据增强技术,它可以通过将多组不同数据集的样本进行线性组合,生成新的样本,从而扩充数据集。mixup的核心原理是将两个不同的图片按照一定的比例…...

面向对象设计模式:创建型模式之单例模式

1. 单例模式,Singleton Pattern 1.1 Definition 定义 单例模式是确保类有且仅有一个实例的创建型模式,其提供了获取类唯一实例(全局指针)的方法。 单例模式类提供了一种访问其唯一的对象的方式,可以直接访问&#xf…...

IsADirectoryError: [Errno 21] Is a directory: ‘.‘

项目场景: 基于YOLOv5的室内场景识别 工具:colab 问题描述 Traceback (most recent call last): File “train.py”, line 630, in main(opt) File “train.py”, line 494, in main d torch.load(last, map_location‘cpu’)[‘opt’] File “/usr/…...

判断三角面片与空间中球体是否相交

文章目录一、问题描述二、解题思路​ 在做项目时遇到了一个数学问题,即,如何判断给定一个三角面片与空间中某个球体有相交部分?这个问题看似简单,实际处理起来需要一些方法和手段。一、问题描述 已知空间中球体的球心位置center&a…...

继承下的缺省参数值和访问说明符

前言 本文将介绍 C 继承体系下,函数缺省参数的绑定和函数访问说明符的绑定。这些奇怪的问题实际上不应在我们的代码中出现,但它们能帮助我们理解 C 的动态绑定和静态绑定,也能帮助我们更好的通过面试。 缺省参数值 先来看一段代码&#xf…...

Spring核心模块—— BeanFactoryPostProcessorBeanPostProcessor(后处理器)

后置处理器前言Spring的后处理器BeanFactoryPostProcessor(工厂后处理器)执行节点作用基本信息经典场景子接口——BeanDefinitiRegistryPostProcessor基本介绍用途具体原理例子——注册BeanDefinition使用Spring的BeanFactoryPostProcessor扩展点完成自定…...

产品新人如何培养产品思维?

什么是产品思维?其实很难定义,不同人有不同的定义。有的人定义为以用户为中心打磨一个完美体验的产品;有的定义为从需求调研到需求上线各个步骤需要思考的点,等等。本文想讨论的产品思维是:怎么去发现问题,…...

「兔了个兔」CSS如此之美,看我如何实现可爱兔兔LOADING页面(万字详解附源码)

💂作者简介: THUNDER王,一名热爱财税和SAP ABAP编程以及热爱分享的博主。目前于江西师范大学会计学专业大二本科在读,同时任汉硕云(广东)科技有限公司ABAP开发顾问。在学习工作中,我通常使用偏后…...

【Java】阻塞队列 BlcokingQueue 原理、与等待唤醒机制condition/await/singal的关系、多线程安全总结

在实习过程中使用阻塞队列对while sleep 轮询机制进行了改造,提升了发送接收的效率,这里做一点点总结。 自从Java 1.5之后,在java.util.concurrent包下提供了若干个阻塞队列,BlcokingQueue继承了Queue接口,是线程安全…...

【水下图像增强】Enhancing Underwater Imagery using Generative Adversarial Networks

原始题目Enhancing Underwater Imagery using Generative Adversarial Networks中文名称使用 GAN 增强水下图像发表时间2018年1月11日平台ICRA 2018来源University of Minnesota, Minneapolis MN文章链接https://arxiv.org/abs/1801.04011开源代码官方:https://gith…...

Maven专题总结—详细版

第一章 为什么使用Maven 获取jar包 使用Maven之前,自行在网络中下载jar包,效率较低。如【谷歌、百度、CSDN…】使用Maven之后,统一在一个地址下载资源jar包【阿里云镜像服务器等…】 添加jar包 使用Maven之前,将jar复制到项目工程…...

华为OD机试真题Java实现【字符串加密】真题+解题思路+代码(20222023)

字符串加密 题目 给你一串未加密的字符串str, 通过对字符串的每一个字母进行改变来实现加密, 加密方式是在每一个字母str[i]偏移特定数组元素a[i]的量, 数组a前三位已经赋值:a[0]=1,a[1]=2,a[2]=4。 当i>=3时,数组元素a[i]=a[i-1]+a[i-2]+a[i-3], 例如:原文 abcde …...

「Python 基础」函数与高阶函数

文章目录1. 函数调用函数定义函数函数的参数递归函数2. 高阶函数map/reducefiltersorted3. 函数式编程返回函数匿名函数装饰器偏函数1. 函数 函数是一种重复代码的抽象方式,Python 内建支持的一种封装; 调用函数 调用一个函数,需要知道函数…...

DIV内容滚动,文字符滚动标签marquee兼容稳定不卡

marquee(文字滚动)标签 marquee简介 <marquee>标签,是成对出现的标签,首标签<marquee>和尾标签</marquee>之间的内容就是滚动内容。 <marquee>标签的属性主要有behavior、bgcolor、direction、width、height、hspace、vspace、loop、scrollamount、scr…...

SpringBoot_第五章(Web和原理分析)

目录 1&#xff1a;静态资源 1.1&#xff1a;静态资源访问 1.2&#xff1a;静态资源源码解析-到WebMvcAutoConfiguration 2&#xff1a;Rest请求绑定&#xff08;设置put和delete&#xff09; 2.1&#xff1a;代码实例 2.2&#xff1a;源码分析到-WebMvcAutoConfiguratio…...

4-2 Linux进程和内存概念

文章目录前言进程状态进程优先级内存模型进程内存关系前言 进程是一个其中运行着一个或多个线程的地址空间和这些线程所需要的系统资源。一般来说&#xff0c;Linux系统会在进程之间共享程序代码和系统函数库&#xff0c;所以在任何时刻内存中都只有代码的一份拷贝。 进程状态…...

【微信小程序】计算器案例

&#x1f3c6;今日学习目标&#xff1a;第二十一期——计算器案例 ✨个人主页&#xff1a;颜颜yan_的个人主页 ⏰预计时间&#xff1a;30分钟 &#x1f389;专栏系列&#xff1a;我的第一个微信小程序 计算器前言实现效果实现步骤wxmlwxssjs数字按钮事件处理函数计算按钮处理事…...

408 计算机基础复试笔记 —— 更新中

计算机组成原理 计算机系统概述 问题一、冯诺依曼机基本思想 存储程序&#xff1a;程序和数据都存储在同一个内存中&#xff0c;计算机可以根据指令集执行存储在内存中的程序。这使得程序具有高度灵活性和可重用性。指令流水线&#xff1a;将指令分成若干阶段&#xff0c;每…...

找出最大数-课后程序(Python程序开发案例教程-黑马程序员编著-第二章-课后作业)

实例6&#xff1a;找出最大数 “脑力大乱斗”休闲益智游戏的关卡中&#xff0c;有一个题目是找出最大数。本实例要求编写程序&#xff0c;实现从输入的任意三个数中找出最大数的功能。 实例分析 对于3个数比较大小&#xff0c;我们可以首先先对两个数的大小进行比较&#xff…...

Java——N叉树的层序遍历

题目链接 leetcode在线oj题——N叉树的层序遍历 题目描述 给定一个 N 叉树&#xff0c;返回其节点值的层序遍历。&#xff08;即从左到右&#xff0c;逐层遍历&#xff09;。 树的序列化输入是用层序遍历&#xff0c;每组子节点都由 null 值分隔&#xff08;参见示例&…...

(LeetCode 每日一题) 3442. 奇偶频次间的最大差值 I (哈希、字符串)

题目&#xff1a;3442. 奇偶频次间的最大差值 I 思路 &#xff1a;哈希&#xff0c;时间复杂度0(n)。 用哈希表来记录每个字符串中字符的分布情况&#xff0c;哈希表这里用数组即可实现。 C版本&#xff1a; class Solution { public:int maxDifference(string s) {int a[26]…...

前端倒计时误差!

提示:记录工作中遇到的需求及解决办法 文章目录 前言一、误差从何而来?二、五大解决方案1. 动态校准法(基础版)2. Web Worker 计时3. 服务器时间同步4. Performance API 高精度计时5. 页面可见性API优化三、生产环境最佳实践四、终极解决方案架构前言 前几天听说公司某个项…...

(二)TensorRT-LLM | 模型导出(v0.20.0rc3)

0. 概述 上一节 对安装和使用有个基本介绍。根据这个 issue 的描述&#xff0c;后续 TensorRT-LLM 团队可能更专注于更新和维护 pytorch backend。但 tensorrt backend 作为先前一直开发的工作&#xff0c;其中包含了大量可以学习的地方。本文主要看看它导出模型的部分&#x…...

基于服务器使用 apt 安装、配置 Nginx

&#x1f9fe; 一、查看可安装的 Nginx 版本 首先&#xff0c;你可以运行以下命令查看可用版本&#xff1a; apt-cache madison nginx-core输出示例&#xff1a; nginx-core | 1.18.0-6ubuntu14.6 | http://archive.ubuntu.com/ubuntu focal-updates/main amd64 Packages ng…...

从深圳崛起的“机器之眼”:赴港乐动机器人的万亿赛道赶考路

进入2025年以来&#xff0c;尽管围绕人形机器人、具身智能等机器人赛道的质疑声不断&#xff0c;但全球市场热度依然高涨&#xff0c;入局者持续增加。 以国内市场为例&#xff0c;天眼查专业版数据显示&#xff0c;截至5月底&#xff0c;我国现存在业、存续状态的机器人相关企…...

ESP32读取DHT11温湿度数据

芯片&#xff1a;ESP32 环境&#xff1a;Arduino 一、安装DHT11传感器库 红框的库&#xff0c;别安装错了 二、代码 注意&#xff0c;DATA口要连接在D15上 #include "DHT.h" // 包含DHT库#define DHTPIN 15 // 定义DHT11数据引脚连接到ESP32的GPIO15 #define D…...

oracle与MySQL数据库之间数据同步的技术要点

Oracle与MySQL数据库之间的数据同步是一个涉及多个技术要点的复杂任务。由于Oracle和MySQL的架构差异&#xff0c;它们的数据同步要求既要保持数据的准确性和一致性&#xff0c;又要处理好性能问题。以下是一些主要的技术要点&#xff1a; 数据结构差异 数据类型差异&#xff…...

MySQL 8.0 OCP 英文题库解析(十三)

Oracle 为庆祝 MySQL 30 周年&#xff0c;截止到 2025.07.31 之前。所有人均可以免费考取原价245美元的MySQL OCP 认证。 从今天开始&#xff0c;将英文题库免费公布出来&#xff0c;并进行解析&#xff0c;帮助大家在一个月之内轻松通过OCP认证。 本期公布试题111~120 试题1…...

成都鼎讯硬核科技!雷达目标与干扰模拟器,以卓越性能制胜电磁频谱战

在现代战争中&#xff0c;电磁频谱已成为继陆、海、空、天之后的 “第五维战场”&#xff0c;雷达作为电磁频谱领域的关键装备&#xff0c;其干扰与抗干扰能力的较量&#xff0c;直接影响着战争的胜负走向。由成都鼎讯科技匠心打造的雷达目标与干扰模拟器&#xff0c;凭借数字射…...

在web-view 加载的本地及远程HTML中调用uniapp的API及网页和vue页面是如何通讯的?

uni-app 中 Web-view 与 Vue 页面的通讯机制详解 一、Web-view 简介 Web-view 是 uni-app 提供的一个重要组件&#xff0c;用于在原生应用中加载 HTML 页面&#xff1a; 支持加载本地 HTML 文件支持加载远程 HTML 页面实现 Web 与原生的双向通讯可用于嵌入第三方网页或 H5 应…...