当前位置: 首页 > news >正文

深度学习之“制作自定义数据”--torch.utils.data.DataLoader重写构造方法。

深度学习之“制作自定义数据”–torch.utils.data.DataLoader重写构造方法。

前言:

​ 本文讲述重写torch.utils.data.DataLoader类的构造方法对自定义图片制作类似MNIST数据集格式(image, label),用于自己的Pytorch神经网络模型运行,代码已整理打包上传网盘,文末下载。tensor数据格式(N,C,H,W)

  • N:Batch,批处理大小,表示一个batch中的图像数量

  • C:Channel,通道数,表示一张图像中的通道数

  • H:Height,高度,表示图像垂直维度的像素数

  • W:Width,宽度,表示图像水平维度的像素数

  • 例如下图输出一个批次的训练集数据就是一批次64张图片(N),3维通道数(C),一张图片高度32像素(H),一张图片宽度32像素(W)

在这里插入图片描述

步骤一

​ 对图片整理分类(python代码os库进行对文件夹创建和图片的移动到文件夹),以文件夹名为图片的种类名,如下图所示:

在这里插入图片描述

步骤二

​ 对所有种类文件夹进行遍历读入,将每个(图片的文件路径 )和(对应的标签)写入到txt文本中,结果为trian.txt 和 test.txt,作为训练集合测试集的数据准备。代码为CreateDataset01.py

# -*- coding: utf-8 -*-
# @Time : 2023/1/26/026 18:48
# @Author : LeeSheel
# @File : CreateDataset01.py
# @Project : 深度学习'''
生成训练集和测试集,保存在txt文件中本地电脑,只选取出3000张图片为训练集进行模型运行数据
'''import os
import random
train_ratio = 0.6
test_ratio = 1-train_ratio
train_list, test_list = [],[]  #创建两个个列表,里面存放  图片路径+‘\t’+图片标签
data_list = []rootdata = r"D:\FreeDesk\大创项目\手写藏文字母识别\手写藏文字母数据\总数据"for root,dirs,files in os.walk(rootdata):# print(root)# print(dirs)# print(files)#拼接每个图片的绝对文件路径:for i in range(int(len(files)*train_ratio)):# print(files[i])#输出的是每个图片的名称# print(root+"---"+files[i])  #shu输出每个每个图片的文件夹路径----图片名称# print(os.path.join(root, files[i]))  #拼接路径,# print(str(root).split("/")[-1])   #dui对root进行字符串切割,获得最后一个元素,代表每个图片的标签。class_flag = str(root).split("\\")[-1]  #biaoqain标签data = os.path.join(root, files[i]) + '\t' + str(class_flag) + '\n'train_list.append(data)for i in range(int(len(files) * train_ratio),len(files)):# print(i)class_flag = str(root).split("\\")[-1]  # biaoqain标签# print(class_flag)# print(files[i])data = os.path.join(root, files[i]) + '\t' + str(class_flag) + '\n'test_list.append(data)# print(train_list)
random.shuffle(train_list)
random.shuffle(test_list)with open('train.txt','w',encoding='UTF-8') as f:for train_img in train_list:f.write(str(train_img))with open('test.txt','w',encoding='UTF-8') as f:for test_img in test_list:f.write(test_img)## 随机抽取3000个作为本地train.txt   以及1000个作为本地test.txt# from random import sample
#
# print(sample(train_list, 30000)) # 随机抽取5个元素
# local_train_list = sample(train_list, 30000)
# print("dsdfsdfs")
# print(len(local_train_list))
# local_test_list = sample(test_list, 10000)
#
# with open('localtrain.txt','w',encoding='UTF-8') as f:
#     for train_img in local_train_list:
#             f.write(str(train_img))
#
# with open('localtest.txt','w',encoding='UTF-8') as f:
#     for test_img in local_test_list:
#         f.write(test_img)

得到txt结果:(文件路径与标签以空格隔开):

在这里插入图片描述

步骤三

​ 将步骤二得到的train.txt 和 test.txt 转化为train_loader 和 test_loader,重写LoadData类的构造方法,将train.txt文本转为train_dataset ,将test.txt转为test_dataset,最后再使用torch.utils.data.DataLoader()进行转为train_loader 和 test_loader: 就可以用于调用模型训练了。

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=64,shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=64,shuffle=True)

重写LoadData类的构造方法代码(这里的transforms.Normalize()图像标准化,可以使用下文的python代码求出mean和std,填入标准化数值。),步骤三代码为 CreateDataloader02.py

# -*- coding: utf-8 -*-
# @Time : 2023/1/26/026 18:56
# @Author : LeeSheel
# @File : CreateDataloader02.py
# @Project : 深度学习
import torch
from PIL import Image
import torchvision.transforms as transforms
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from torch.utils.data import Datasetclass LoadData(Dataset):def __init__(self, txt_path, train_flag=True):self.imgs_info = self.get_images(txt_path)self.train_flag = train_flagself.train_tf = transforms.Compose([# 随机旋转图片transforms.RandomHorizontalFlip(),# 将图片尺寸resize到32x32transforms.Resize((32, 32)),# 将图片转化为Tensor格式transforms.ToTensor(),# 正则化(当模型出现过拟合的情况时,用来降低模型的复杂度)transforms.Normalize((0.96934927, 0.9696228, 0.9695143), (0.124204025, 0.12326231, 0.12356147))  # 图像标准化])self.val_tf = transforms.Compose([# 将图片尺寸resize到32x32transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize((0.96934927, 0.9696228, 0.9695143), (0.124204025, 0.12326231, 0.12356147))])def get_images(self, txt_path):with open(txt_path, 'r', encoding='utf-8') as f:imgs_info = f.readlines()imgs_info = list(map(lambda x:x.strip().split('\t'), imgs_info))return imgs_infodef __getitem__(self, index):img_path, label = self.imgs_info[index]img = Image.open(img_path)img = img.convert('RGB')if self.train_flag:img = self.train_tf(img)else:img = self.val_tf(img)label = int(label)return img, labeldef __len__(self):return len(self.imgs_info)train_dataset = LoadData("train.txt", True)print("训练接数据个数:", len(train_dataset))
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=64,shuffle=True)
for image, label in train_loader:print(image.shape)print(image)# img = transform_BZ(image)# print(img)print(label)breaktest_dataset = LoadData("test.txt", False)
print("测试集数据个数:", len(test_dataset))
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=64,shuffle=True)

求图片标准化transforms.Normalize()参数 代码

# -*- coding: utf-8 -*-
# @Time : 2023/1/31/031 18:18
# @Author : LeeSheel
# @File : 计算std和mea.py
# @Project : 深度学习
import numpy as np
import cv2
import os# img_h, img_w = 32, 32
img_h, img_w = 32, 32  # 经过处理后你的图片的尺寸大小
means, stdevs = [], []
img_list = []imgs_path = "D:\\0"  # 数据集的路径采用绝对引用
imgs_path_list = os.listdir(imgs_path)len_ = len(imgs_path_list)
i = 0
for item in imgs_path_list:img = cv2.imread(os.path.join(imgs_path, item))img = cv2.resize(img, (img_w, img_h))img = img[:, :, :, np.newaxis]img_list.append(img)i += 1print(i, '/', len_)imgs = np.concatenate(img_list, axis=3)
imgs = imgs.astype(np.float32) / 255.for i in range(3):pixels = imgs[:, :, i, :].ravel()  # 拉成一行means.append(np.mean(pixels))stdevs.append(np.std(pixels))# BGR --> RGB , CV读取的需要转换,PIL读取的不用转换
means.reverse()
stdevs.reverse()print("normMean = {}".format(means))
print("normStd = {}".format(stdevs))

代码下载:

链接:https://pan.baidu.com/s/1fa_gdLYXagu65P2uYpepqA?pwd=xx78
提取码:xx78

在这里插入图片描述

相关文章:

深度学习之“制作自定义数据”--torch.utils.data.DataLoader重写构造方法。

深度学习之“制作自定义数据”–torch.utils.data.DataLoader重写构造方法。 前言: ​ 本文讲述重写torch.utils.data.DataLoader类的构造方法,对自定义图片制作类似MNIST数据集格式(image, label),用于自己的Pytorc…...

#G. 求约数个数之六

我们先求到区间[1..b]之间的所有约数之和于是结果就等于 [1..b]之间的所有约数之和减去[1..a-1]之间的约数之和很明显这两个问题是同性质的问题,只是右端点不同罢了.明显对于1到N之间的数字,其约数范围也为1到N这个范围内。于是我们可以枚举约数L,当然这…...

如何为Java文件代码签名及添加时间戳?

Java是一种流行的编程语言,大多数组织都使用它来开发业务应用程序。由于其高使用率,攻击者总是试图找到其中的漏洞并基于它利用软件。为了防止此类攻击, 为 Java 文件(.jar)进行代码签名并添加时间戳,可以防…...

Xamarin.Forsm for Android 显示 PDF

背景 某些情况下,需要让用户阅读下发的文件,特别是红头文件,这些文件一般都是使用PDF格式下发,这种文件有很重要的一点就是不能更改。这时候就需要使用原文件进行展示。 Xamarin.Forms Android 中的 WebView 控件是不能直接显示的…...

RK3399平台开发系列讲解(LED子系统篇)LED子系统详解

🚀返回专栏总目录 文章目录 一、设备树编写二、LED子系统2.1、用户态2.2、内核驱动三、驱动代码3.1、平台设备驱动的注册3.2、平台设备驱动的probe四、使用方法沉淀、分享、成长,让自己和他人都能有所收获!😄 📢本篇将详细介绍LED子系统。 一、设备树编写 节点属性添加…...

LeetCode 432. 全 O(1) 的数据结构

LeetCode 432. 全 O(1) 的数据结构 难度:hard\color{red}{hard}hard 题目描述 请你设计一个用于存储字符串计数的数据结构,并能够返回计数最小和最大的字符串。 实现 AllOneAllOneAllOne 类: AllOne()AllOne()AllOne() 初始化数据结构的对…...

再析jvm

前言 希望自己每一次学习都有不同的理解 文章目录前言1. jvm的组成取消永久代使用元空间原因2. 运行时数据区3. 堆栈区别队列和栈,队列先进先出,栈先进后出从栈顶弹出4. GC、内存溢出、垃圾回收4.1 如何确定引用是否会被回收4.1.1 Java中的引用类型4.1.…...

社招前端二面面试题总结

代码输出结果 var A {n: 4399}; var B function(){this.n 9999}; var C function(){var n 8888}; B.prototype A; C.prototype A; var b new B(); var c new C(); A.n console.log(b.n); console.log(c.n);输出结果:9999 4400 解析: conso…...

人人能读懂redux原理剖析

一、Redux是什么? 众所周知,Redux最早运用于React框架中,是一个全局状态管理器。Redux解决了在开发过程中数据无限层层传递而引发的一系列问题,因此我们有必要来了解一下Redux到底是如何实现的? 二、Redux的核心思想…...

uniCloud云开发----7、uniapp通过uni-swiper-dot实现轮播图

uniapp通过uni-swiper-dot实现轮播图前言效果图1、官网实现的效果2、需求中使用到的效果图官网提供的效果图源码1、html部分2、js部分3、css部分根据需求调整轮播图前言 uni-swiper-dot.文档 uni-swiper-dot 轮播图指示点 - DCloud 插件市场 本次展示根据需求制作的和官网用到…...

IM即时通讯构建企业协同生态链

在当今互联网信息飞速发展的时代,随着企业对协同办公要求的提高,协同办公的定义提升到了智能化办公的范畴。大多企业都非常重视构建连接用户、员工和合作伙伴的生态平台,利用即时通讯软件解决企业内部的工作沟通、信息传递和知识共享等问题。…...

Python实现构建gan模型, 输入一个矩阵和两个参数值,输出一个矩阵

构建一个GAN模型,使用Python实现,该模型将接受一个矩阵和两个参数值作为输入,并输出另一个矩阵。GAN(生成对抗网络)是一种深度学习模型,由生成器和判别器两部分组成,可以用于生成具有一定规律性的数据,如图像或音频。 # 定义生成器 def make_generator(noise_dim, dat…...

开学准备哪些电容笔?ipad触控笔推荐平价

在现代,数码产品的发展受到高技术的驱动。不管是在工作上,还是在学习上,大的显示屏可以使图像更加清晰。Ipad将成为我们日常生活中不可或缺的一部分,无论现在或将来。如果ipad配上一款方便操作的电容笔,将极大地提高我…...

放下和拿起 解放自己

放下太难,从过去中解放自己 工作这么久了,第一次不拿包上班,真爽 人的成长都是在碰撞和摸索中产生的,通过摸索,知道自己能力的边界和欲望的边界以及身体的边界,这三个决定了 你能做什么 你能享受什么&…...

100%BIM学员的疑惑:不会CAD可以学Revit吗?

在新一轮科技创新和产业变革中,信息化与建筑业的融合发展已成为建筑业发展的方向,将对建筑业发展带来战略性和全局性的影响。 建筑业是传统产业,推动建筑业科技创新,加快推进信息化发展,激发创新活力,培育…...

经常会采坑的javascript原型应试题

一. 前言 原型和原型链在面试中历来备受重视,经常被提及。说难可能也不太难,但要真正完全理解,吃透它,还是要多下功夫的。 下面为大家简单阐述我对原型和原型链的理解,若是觉得有说的不对的地方&#xff…...

完全背包—动态规划

一、背包问题概述 如图,完全背包与01背包的区别只有一点:01背包中每个物品只能取一个而完全背包中每个物品可以取无数个。解决完全背包问题必须首先弄明白01背包,不清楚的可以看我的这篇文章01背包—动态规划。 二、例题 重量价值物品0115物…...

消息队列MQ介绍

消息队列技术是分布式应用间交换信息的一种技术。消息队列可驻留在内存或磁盘上,队列存储消息直到它们被应用程序读走。通过消息队列,应用程序可独立地执行--它们不需要知道彼此的位置、或在继续执行前不需要等待接收程序接收此消息。 消息中间件概述 消息队列技术是…...

C语言进阶(八)—— 链表

1. 链表基本概念1.1 什么是链表链表是一种常用的数据结构,它通过指针将一些列数据结点,连接成一个数据链。相对于数组,链表具有更好的动态性(非顺序存储)。数据域用来存储数据,指针域用于建立与下一个结点的…...

手工测试用例就是自动化测试脚本——使用ruby 1.9新特性进行自动化脚本的编写

昨天因为要装watir-webdriver的原因将用了快一年的ruby1.8.6升级到了1.9。由于1.9是原生支持unicode编码,所以我们可以使用中文进行自动化脚本的编写工作。 做了简单的封装后,我们可以实现如下的自动化测试代码。请注意,这些代码是可以正确运…...

云原生核心技术 (7/12): K8s 核心概念白话解读(上):Pod 和 Deployment 究竟是什么?

大家好,欢迎来到《云原生核心技术》系列的第七篇! 在上一篇,我们成功地使用 Minikube 或 kind 在自己的电脑上搭建起了一个迷你但功能完备的 Kubernetes 集群。现在,我们就像一个拥有了一块崭新数字土地的农场主,是时…...

【WiFi帧结构】

文章目录 帧结构MAC头部管理帧 帧结构 Wi-Fi的帧分为三部分组成:MAC头部frame bodyFCS,其中MAC是固定格式的,frame body是可变长度。 MAC头部有frame control,duration,address1,address2,addre…...

IT供电系统绝缘监测及故障定位解决方案

随着新能源的快速发展,光伏电站、储能系统及充电设备已广泛应用于现代能源网络。在光伏领域,IT供电系统凭借其持续供电性好、安全性高等优势成为光伏首选,但在长期运行中,例如老化、潮湿、隐裂、机械损伤等问题会影响光伏板绝缘层…...

学习STC51单片机32(芯片为STC89C52RCRC)OLED显示屏2

每日一言 今天的每一份坚持,都是在为未来积攒底气。 案例:OLED显示一个A 这边观察到一个点,怎么雪花了就是都是乱七八糟的占满了屏幕。。 解释 : 如果代码里信号切换太快(比如 SDA 刚变,SCL 立刻变&#…...

【数据分析】R版IntelliGenes用于生物标志物发现的可解释机器学习

禁止商业或二改转载,仅供自学使用,侵权必究,如需截取部分内容请后台联系作者! 文章目录 介绍流程步骤1. 输入数据2. 特征选择3. 模型训练4. I-Genes 评分计算5. 输出结果 IntelliGenesR 安装包1. 特征选择2. 模型训练和评估3. I-Genes 评分计…...

LangChain知识库管理后端接口:数据库操作详解—— 构建本地知识库系统的基础《二》

这段 Python 代码是一个完整的 知识库数据库操作模块,用于对本地知识库系统中的知识库进行增删改查(CRUD)操作。它基于 SQLAlchemy ORM 框架 和一个自定义的装饰器 with_session 实现数据库会话管理。 📘 一、整体功能概述 该模块…...

基于SpringBoot在线拍卖系统的设计和实现

摘 要 随着社会的发展,社会的各行各业都在利用信息化时代的优势。计算机的优势和普及使得各种信息系统的开发成为必需。 在线拍卖系统,主要的模块包括管理员;首页、个人中心、用户管理、商品类型管理、拍卖商品管理、历史竞拍管理、竞拍订单…...

Web中间件--tomcat学习

Web中间件–tomcat Java虚拟机详解 什么是JAVA虚拟机 Java虚拟机是一个抽象的计算机,它可以执行Java字节码。Java虚拟机是Java平台的一部分,Java平台由Java语言、Java API和Java虚拟机组成。Java虚拟机的主要作用是将Java字节码转换为机器代码&#x…...

AI语音助手的Python实现

引言 语音助手(如小爱同学、Siri)通过语音识别、自然语言处理(NLP)和语音合成技术,为用户提供直观、高效的交互体验。随着人工智能的普及,Python开发者可以利用开源库和AI模型,快速构建自定义语音助手。本文由浅入深,详细介绍如何使用Python开发AI语音助手,涵盖基础功…...

如何通过git命令查看项目连接的仓库地址?

要通过 Git 命令查看项目连接的仓库地址,您可以使用以下几种方法: 1. 查看所有远程仓库地址 使用 git remote -v 命令,它会显示项目中配置的所有远程仓库及其对应的 URL: git remote -v输出示例: origin https://…...