pytorch一致数据增强
分割任务对 image 做(某些)transform 时,要对 label(segmentation mask)也做对应的 transform,如 Resize、RandomRotation 等。如果对 image、label 分别用 transform 处理一遍,则涉及随机操作的可能不一致,如 RandomRotation 将 image 转了 a 度、却将 label 转了 b 度。
MONAI 有个 ArrayDataset 实现了这功能,思路是每次 transform 前都重置一次 random seed 先。对 monai 订制 transform 的方法不熟,torchvision.transforms 的订制接口比较简单,考虑基于 pytorch 实现。要改两个东西:
- 扩展 torchvison.transforms.Compose,使之支持多个输入(image、label);
- 一个 wrapper,扩展 transform,使之支持多输入。
思路也是重置 random seed,参考 [1-4]。
Code
to_multi:将处理单幅图的 transform 扩展成可处理多幅;MultiCompose:扩展 torchvision.transforms.Compose,可输入多幅图。内部调用to_multi扩展传入的 transforms。
import random, os
import numpy as np
import torchdef seed_everything(seed=42):random.seed(seed)os.environ['PYTHONHASHSEED'] = str(seed)np.random.seed(seed)torch.manual_seed(seed)torch.cuda.manual_seed(seed)torch.backends.cudnn.deterministic = Truetorch.backends.cudnn.benchmark = Truedef to_multi(trfm):"""wrap a transform to extend to multiple input with synchronised random seedInput:trfm: transformation function/object (custom or from torchvision.transforms)Output:_multi_transform: function"""# numpy.random.seed range error:# ValueError: Seed must be between 0 and 2**32 - 1min_seed = 0 # - 0x8000_0000_0000_0000max_seed = min(2**32 - 1, 0xffff_ffff_ffff_ffff)def _multi_transform(*images):"""images: [C, H, W]"""if len(images) == 1:return trfm(images[0])_seed = random.randint(min_seed, max_seed)res = []for img in images:seed_everything(_seed)res.append(trfm(img))return tuple(res)return _multi_transformclass MultiCompose:"""Extension of torchvision.transforms.Compose that accepts multiple input.Usage is the same as torchvision.transforms.Compose. This class will wrap inputtransforms with `to_multi` to support simultaneous multiple transformation.This can be useful when simultaneously transforming images & segmentation masks."""def __init__(self, transforms):"""transforms should be wrapped by `to_multi`"""self.transforms = [to_multi(t) for t in transforms]def __call__(self, *images):for t in self.transforms:images = t(*images)return images
test
测试一致性,用到预处理过的 verse’19 数据集、一些工具函数、一个订制 transform:
- verse’19 数据集及预处理见 iTomxy/data/verse;
digit_sort_key:数据文件排序用;get_palette、color_seg、blend_seg:可视化用;MyDataset:看其中__getitem__的 transform 用法,即同时传入 image 和 label;ResizeZoomPad:一个订制的 transform;
import os, os.path as osp, random
from glob import glob
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as Fdef digit_sort_key(s, num_pattern=re.compile('([0-9]+)')):"""natural sort,数据排序用"""return [int(text) for text in num_pattern.split(s) if text.isdigit()]def get_palette(n_classes, pil_format=True):"""创建调色盘,可视化用"""n = n_classespalette = [0] * (n * 3)for j in range(0, n):lab = jpalette[j * 3 + 0] = 0palette[j * 3 + 1] = 0palette[j * 3 + 2] = 0i = 0while lab:palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))i += 1lab >>= 3if pil_format:return paletteres = []for i in range(0, len(palette), 3):res.append(tuple(palette[i: i+3]))return resdef color_seg(label, n_classes=0):"""segmentation mask 上色,可视化用"""if n_classes < 1:n_classes = math.ceil(np.max(label)) + 1label_rgb = Image.fromarray(label.astype(np.int32)).convert("L")label_rgb.putpalette(get_palette(n_classes))return label_rgb.convert("RGB")def blend_seg(image, label, n_classes=0, alpha=0.7, rescale=False, transparent_bg=True, save_file=""):"""融合 image 和其 segmentation mask,可视化用"""if rescale:denom = image.max() - image.min()if 0 != denom:image = (image - image.min()) / denom * 255image = np.clip(image, 0, 255).astype(np.uint8)img_pil = Image.fromarray(image).convert("RGB")lab_pil = color_seg(label, n_classes)blended_image = Image.blend(img_pil, lab_pil, alpha)if transparent_bg:blended_image = Image.fromarray(np.where((0 == label)[:, :, np.newaxis],np.asarray(img_pil),np.asarray(blended_image)))if save_file:blended_image.save(save_file)return blended_imageclass MyDataset(torch.utils.data.Dataset):"""订制 dataset,看 __getitem__ 处 transform 的调法"""def __init__(self, image_list, label_list, transform=None):assert len(image_list) == len(label_list)self.image_list = image_listself.label_list = label_listself.transform = transformdef __len__(self):return len(self.image_list)def __getitem__(self, index):img = np.load(self.image_list[index]) # [h, w]lab = np.load(self.label_list[index])img = torch.from_numpy(img).unsqueeze(0).float() # -> [c=1, h, w]lab = torch.from_numpy(lab).unsqueeze(0).int()if self.transform is not None:img, lab = self.transform(img, lab) # 同时传入 image、labelreturn img, labclass ResizeZoomPad:"""订制 resize"""def __init__(self, size, interpolation="bilinear"):if isinstance(size, int):assert size > 0self.size = [size, size]elif isinstance(size, (tuple, list)):assert len(size) == 2 and size[0] > 0 and size[1] > 0self.size = sizeif isinstance(interpolation, str):assert interpolation.lower() in {"nearest", "bilinear", "bicubic", "box", "hamming", "lanczos"}interpolation = {"nearest": F.InterpolationMode.NEAREST,"bilinear": F.InterpolationMode.BILINEAR,"bicubic": F.InterpolationMode.BICUBIC,"box": F.InterpolationMode.BOX,"hamming": F.InterpolationMode.HAMMING,"lanczos": F.InterpolationMode.LANCZOS}[interpolation.lower()]self.interpolation = interpolationdef __call__(self, image):"""image: [C, H, W]"""scale_h, scale_w = float(self.size[0]) / image.size(1), float(self.size[1]) / image.size(2)scale = min(scale_h, scale_w)tmp_size = [ # clipping to ensure sizemin(int(image.size(1) * scale), self.size[0]),min(int(image.size(2) * scale), self.size[1])]image = F.resize(image, tmp_size, self.interpolation)assert image.size(1) <= self.size[0] and image.size(2) <= self.size[1]pad_h, pad_w = self.size[0] - image.size(1), self.size[1] - image.size(2)if pad_h > 0 or pad_w > 0:pad_left, pad_right = pad_w // 2, (pad_w + 1) // 2pad_top, pad_bottom = pad_h // 2, (pad_h + 1) // 2image = F.pad(image, (pad_left, pad_top, pad_right, pad_bottom))return image# 读数据文件
data_path = os.path.expanduser("~/data/verse/processed-verse19-npy-horizontal")
train_images, train_labels, val_images, val_labels = [], [], [], []
for d in os.listdir(osp.join(data_path, "training")):if d.endswith("_ct"):img_p = osp.join(data_path, "training", d)lab_p = osp.join(data_path, "training", d[:-3]+"_seg-vert_msk")assert osp.isdir(lab_p)train_labels.extend(glob(os.path.join(lab_p, "*.npy")))train_images.extend(glob(os.path.join(img_p, "*.npy")))
for d in os.listdir(osp.join(data_path, "validation")):if d.endswith("_ct"):img_p = osp.join(data_path, "validation", d)lab_p = osp.join(data_path, "validation", d[:-3]+"_seg-vert_msk")assert osp.isdir(lab_p)val_labels.extend(glob(os.path.join(lab_p, "*.npy")))val_images.extend(glob(os.path.join(img_p, "*.npy")))# 数据文件名排序
train_images = sorted(train_images, key=lambda f: digit_sort_key(os.path.basename(f)))
train_labels = sorted(train_labels, key=lambda f: digit_sort_key(os.path.basename(f)))
val_images = sorted(val_images, key=lambda f: digit_sort_key(os.path.basename(f)))
val_labels = sorted(val_labels, key=lambda f: digit_sort_key(os.path.basename(f)))# transform
# 用 MultiCompose,其内部调用 to_multi 将 transforms wrap 成支持多输入的
train_trans = MultiCompose([ResizeZoomPad((224, 256)),transforms.RandomRotation(30),
])# 测试:读数据,可试化 image 和 label
check_ds = MyDataset(train_images, train_labels, train_trans)
check_loader = torch.utils.data.DataLoader(check_ds, batch_size=10, shuffle=True)
for images, labels in check_loader:print(images.size(), labels.size())for i in range(images.size(0)):# print(i, end='\r')img = images[i][0].numpy()lab = labels[i][0].numpy()print(np.unique(lab))seg_img = blend_seg(img, lab)img = (255 * (img - img.min()) / (img.max() - img.min())).astype(np.uint8)img = np.asarray(Image.fromarray(img).convert("RGB"))lab = np.asarray(color_seg(lab))comb = np.concatenate([img, lab, seg_img], axis=1)Image.fromarray(comb).save(f"test-dataset-{i}.png")break
效果:

可见,image 和 label 转了同一个随机角度。
Limits
有些 augmentations 是只对 image 做而不对 label 做的,如 ColorJitter,这里没有考虑怎么处理。
References
- How to Set Random Seeds in PyTorch and Tensorflow
- ihoromi4/seed_everything.py
- Reproducibility
- What is the max seed you can set up?
相关文章:
pytorch一致数据增强
分割任务对 image 做(某些)transform 时,要对 label(segmentation mask)也做对应的 transform,如 Resize、RandomRotation 等。如果对 image、label 分别用 transform 处理一遍,则涉及随机操作的…...
MapReduce
1. MapReduce是什么?请简要说明它的工作原理。 MapReduce是一种编程模型,主要用于处理大规模数据集的并行运算,特别是非结构化数据。这个模型的核心思想是将大数据处理任务分解为两个主要步骤:Map和Reduce。用户只需实现map()和r…...
Spring Boot 快速入门
Spring Boot 快速入门 什么是Spring Boot Spring Boot是一个用于简化Spring应用开发的框架,它基于Spring框架,提供了自动配置、快速开发等特性,使得开发者可以更加便捷地构建独立的、生产级别的Spring应用。 开始使用Spring Boot 步骤一&a…...
什么是神经网络的非线性
大家好啊,我是董董灿。 最近在写《计算机视觉入门与调优》(右键,在新窗口中打开链接)的小册,其中一部分说到激活函数的时候,谈到了神经网络的非线性问题。 今天就一起来看看,为什么神经网络需…...
前端知识(十四)——浅谈用户体验测试的主要功能
用户体验(User Experience,简称UX)在现代软件和产品开发中变得愈发重要。为了确保产品能够满足用户期望,提高用户满意度,用户体验测试成为不可或缺的环节。本文将详细探讨用户体验测试的主要功能,以及它在产品开发过程中的重要性 …...
解决前端跨域问题,后端解决方法
Spring CloudVue前后端分离项目报错:Network Error;net::ERR_FAILED(请求跨越)-CSDN博客记录自用...
【网络奇缘系列】计算机网络|数据通信方式|数据传输方式
🌈个人主页: Aileen_0v0🔥系列专栏: 一见倾心,再见倾城 --- 计算机网络~💫个人格言:"没有罗马,那就自己创造罗马~" 这篇文章是关于计算机网络中数据通信的基础知识点, 从模型,术语再到数据通信方式&#…...
数组 注意事项
1.一维数组的初始化 int a[5]{1,2,3,4,5}; 合法 int a[5]{1,2,3}; 合法 int a[]{1,2,3,4,5}; 合法,后面决定前面的大小 int a[5]{1,2,3,4,5,6}; 不合法! 2.一维数组的定义 int a[5] 合法 int a[11] 合法 int a[1/24] 合法 int x5,a[x…...
day11 滑动窗口中的最大值
class MyQueue { //单调队列(从大到小)public:deque<int> que; // 使用deque来实现单调队列// 每次弹出的时候,比较当前要弹出的数值是否等于队列出口元素的数值,如果相等则弹出。// 同时pop之前判断队列当前是否为空。void…...
viple模拟器使用(五):Web 2D模拟器中实现两距离局部最优迷宫算法
关于两距离局部最优迷宫算法的原理本文不再赘述,详情请参考:viple模拟器使用(四),归纳总结为: 前方有路,则直行; 前方无路,则右转90度,标记右转完成ÿ…...
每日一道算法题 3(2023-12-11)
题目描述: VLAN是一种对局域网设备进行逻辑划分的技术,为了标识不同的VLAN,引入VLAN ID(1-4094之间的整数)的概念。 定义一个VLAN ID的资源池(下称VLAN资源池),资源池中连续的VLAN用开始VLAN-结束VLAN表示,不连续的用单…...
【Android】查看keystore的公钥和私钥
前言: 查看前准备好.keystore文件,安装并配置openssl、keytool。文件路径中不要有中文。 一、查看keystore的公钥: 1.从keystore中获取MD5证书 keytool -list -v -keystore gamekeyold.keystore 2.导出公钥文件 keytool -export -alias …...
ChatGPT的常识
什么是ChatGPT? ChatGPT是一个基于GPT模型的聊天机器人,GPT即“Generative Pre-training Transformer”,是一种预训练的语言模型。ChatGPT使用GPT-2和GPT-3两种模型来生成自然语言响应,从而与人类进行真实的对话。 ChatGPT的设计…...
Spring Boot中的事务是如何实现的?懂吗?
SpringBoot中的事务管理,用得好,能确保数据的一致性和完整性;用得不好,可能会给性能带来不小的影响哦。 基本使用 在SpringBoot中,事务的使用非常简洁。首先,得感谢Spring框架提供的Transactional注解&am…...
应用安全:JAVA反序列化漏洞之殇
应用安全:JAVA反序列化漏洞之殇 概述 序列化是让Java对象脱离Java运行环境的一种手段,可以有效的实现多平台之间的通信、对象持久化存储。Java 序列化是指把 Java 对象转换为字节序列的过程便于保存在内存、文件、数据库中,ObjectOutputStream类的 wri…...
基于以太坊的智能合约开发Solidity(函数继承篇)
参考教程:【实战篇】1、函数重载_哔哩哔哩_bilibili 1、函数重载: pragma solidity ^0.5.17;contract overLoadTest {//不带参数function test() public{}//带一个参数function test(address account) public{}//参数类型不同,虽然uint160可…...
【论文极速读】LVM,视觉大模型的GPT时刻?
【论文极速读】LVM,视觉大模型的GPT时刻? FesianXu 20231210 at Baidu Search Team 前言 这一周,LVM在arxiv上刚挂出不久,就被众多自媒体宣传为『视觉大模型的GPT时刻』,笔者抱着强烈的好奇心,在繁忙工作之…...
TS基础语法
前言: 因为在写前端的时候,发现很多UI组件的语法都已经开始使用TS语法,不学习TS根本看不到懂,所以简单的学一下TS语法。为了看UI组件的简单代码,不至于一脸懵。 一、安装node 对于windows来讲,node版本高…...
【基于NLP的微博情感分析:从数据爬取到情感洞察】
基于NLP的微博情感分析:从数据爬取到情感洞察 背景数据集技术选型功能实现创新点 今天我将分享一个基于NLP的微博情感分析项目,通过Python技术、NLP模型和Flask框架,对微博数据进行清洗、分词、可视化,并利用NLP和贝叶斯进行情感分…...
Ubuntu 18.04使用Qemu和GDB搭建运行内核的环境
安装busybox 参考博客: 使用GDBQEMU调试Linux内核环境搭建 一文教你如何使用GDBQemu调试Linux内核 ubuntu22.04搭建qemu环境测试内核 交叉编译busybox 编译busybox出现Library m is needed, can’t exclude it (yet)的解释 S3C2440 制作最新busybox文件系统 https:…...
基于FPGA的PID算法学习———实现PID比例控制算法
基于FPGA的PID算法学习 前言一、PID算法分析二、PID仿真分析1. PID代码2.PI代码3.P代码4.顶层5.测试文件6.仿真波形 总结 前言 学习内容:参考网站: PID算法控制 PID即:Proportional(比例)、Integral(积分&…...
使用分级同态加密防御梯度泄漏
抽象 联邦学习 (FL) 支持跨分布式客户端进行协作模型训练,而无需共享原始数据,这使其成为在互联和自动驾驶汽车 (CAV) 等领域保护隐私的机器学习的一种很有前途的方法。然而,最近的研究表明&…...
【机器视觉】单目测距——运动结构恢复
ps:图是随便找的,为了凑个封面 前言 在前面对光流法进行进一步改进,希望将2D光流推广至3D场景流时,发现2D转3D过程中存在尺度歧义问题,需要补全摄像头拍摄图像中缺失的深度信息,否则解空间不收敛…...
【快手拥抱开源】通过快手团队开源的 KwaiCoder-AutoThink-preview 解锁大语言模型的潜力
引言: 在人工智能快速发展的浪潮中,快手Kwaipilot团队推出的 KwaiCoder-AutoThink-preview 具有里程碑意义——这是首个公开的AutoThink大语言模型(LLM)。该模型代表着该领域的重大突破,通过独特方式融合思考与非思考…...
React19源码系列之 事件插件系统
事件类别 事件类型 定义 文档 Event Event 接口表示在 EventTarget 上出现的事件。 Event - Web API | MDN UIEvent UIEvent 接口表示简单的用户界面事件。 UIEvent - Web API | MDN KeyboardEvent KeyboardEvent 对象描述了用户与键盘的交互。 KeyboardEvent - Web…...
Web 架构之 CDN 加速原理与落地实践
文章目录 一、思维导图二、正文内容(一)CDN 基础概念1. 定义2. 组成部分 (二)CDN 加速原理1. 请求路由2. 内容缓存3. 内容更新 (三)CDN 落地实践1. 选择 CDN 服务商2. 配置 CDN3. 集成到 Web 架构 …...
【数据分析】R版IntelliGenes用于生物标志物发现的可解释机器学习
禁止商业或二改转载,仅供自学使用,侵权必究,如需截取部分内容请后台联系作者! 文章目录 介绍流程步骤1. 输入数据2. 特征选择3. 模型训练4. I-Genes 评分计算5. 输出结果 IntelliGenesR 安装包1. 特征选择2. 模型训练和评估3. I-Genes 评分计…...
uniapp 开发ios, xcode 提交app store connect 和 testflight内测
uniapp 中配置 配置manifest 文档:manifest.json 应用配置 | uni-app官网 hbuilderx中本地打包 下载IOS最新SDK 开发环境 | uni小程序SDK hbulderx 版本号:4.66 对应的sdk版本 4.66 两者必须一致 本地打包的资源导入到SDK 导入资源 | uni小程序SDK …...
Python 实现 Web 静态服务器(HTTP 协议)
目录 一、在本地启动 HTTP 服务器1. Windows 下安装 node.js1)下载安装包2)配置环境变量3)安装镜像4)node.js 的常用命令 2. 安装 http-server 服务3. 使用 http-server 开启服务1)使用 http-server2)详解 …...
《Docker》架构
文章目录 架构模式单机架构应用数据分离架构应用服务器集群架构读写分离/主从分离架构冷热分离架构垂直分库架构微服务架构容器编排架构什么是容器,docker,镜像,k8s 架构模式 单机架构 单机架构其实就是应用服务器和单机服务器都部署在同一…...
