小白的进阶之路系列之四----人工智能从初步到精通pytorch自定义数据集下
本篇涵盖的内容
在之前的文章中,我们已经讨论了如何获取数据,转换数据以及如何准备自定义数据集,本篇文章将涵盖更加深入的问题,希望通过详细的代码示例,帮助大家了解PyTorch自定义数据集是如何应对各种复杂实际情况中,数据处理的。
更加详细的,我们将讨论下面一些内容:
主题 | 内容 |
---|---|
7 Model 0:没有数据增强的TinyVGG | 到这个阶段,我们已经准备好了数据,让我们建立一个能够拟合数据的模型。我们还将创建一些训练和测试函数来训练和评估我们的模型。 |
8 探索损失曲线 | 损失曲线是观察你的模型如何训练/改进的好方法。它们也是一种很好的方法来判断你的模型是过拟合还是欠拟合。 |
9 Model 1:带数据增强功能的TinyVGG | 到目前为止,我们已经尝试了一个没有数据增强的模型? |
10 比较模型结果 | 让我们比较不同模型的损失曲线,看看哪个表现更好,并讨论一些改进性能的选项。 |
11 对自定义图像进行预测 | 我们的模型是在披萨、牛排和寿司图像的数据集上训练的。在本节中,我们将介绍如何使用我们训练好的模型来预测现有数据集之外的图像。 |
7 Model 0:没有数据增强的TinyVGG
好了,我们已经看到了如何把数据从文件夹里的图像变成变换后的张量。
现在让我们构建一个计算机视觉模型,看看我们是否可以将图像分类为披萨、牛排或寿司。
首先,我们将从一个简单的变换开始,仅将图像大小调整为(64,64)并将它们转换为张量。
7.1 为模型0创建转换和加载数据
# Create simple transform
simple_transform = transforms.Compose([ transforms.Resize((64, 64)),transforms.ToTensor(),
])
很好,现在我们有了一个简单的变换,让我们
-
加载数据,首先使用torchvision.datasets.ImageFolder()将每个训练和测试文件夹转换为Dataset
-
然后使用torch.utils.data.DataLoader())转换为数据加载器。
-
我们将把batch_size=32和num_workers设置为机器上尽可能多的cpu(这取决于您使用的机器)。
# 1. Load and transform data
from torchvision import datasets
train_data_simple = datasets.ImageFolder(root=train_dir, transform=simple_transform)
test_data_simple = datasets.ImageFolder(root=test_dir, transform=simple_transform)# 2. Turn data into DataLoaders
import os
from torch.utils.data import DataLoader# Setup batch size and number of workers
BATCH_SIZE = 32
NUM_WORKERS = os.cpu_count()
print(f"Creating DataLoader's with batch size {BATCH_SIZE} and {NUM_WORKERS} workers.")# Create DataLoader's
train_dataloader_simple = DataLoader(train_data_simple, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)test_dataloader_simple = DataLoader(test_data_simple, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)print(train_dataloader_simple, test_dataloader_simple)
输出为:
Creating DataLoader's with batch size 32 and 16 workers.
<torch.utils.data.dataloader.DataLoader object at 0x0000024974F734D0> <torch.utils.data.dataloader.DataLoader object at 0x0000024974F07A80>
很好dataloader已经创建好了,现在让我们设立模型。
7.2创建TinyVGG模型类
在上一篇文章中,我们使用了来自CNN解释器网站的TinyVGG模型。
让我们重新创建相同的模型,只不过这次我们将使用彩色图像而不是灰度图像(对于RGB像素,in_channels=3而不是in_channels=1)。
class TinyVGG(nn.Module):"""Model architecture copying TinyVGG from: https://poloclub.github.io/cnn-explainer/"""def __init__(self, input_shape: int, hidden_units: int, output_shape: int) -> None:super().__init__()self.conv_block_1 = nn.Sequential(nn.Conv2d(in_channels=input_shape, out_channels=hidden_units, kernel_size=3, # how big is the square that's going over the image?stride=1, # defaultpadding=1), # options = "valid" (no padding) or "same" (output has same shape as input) or int for specific number nn.ReLU(),nn.Conv2d(in_channels=hidden_units, out_channels=hidden_units,kernel_size=3,stride=1,padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2,stride=2) # default stride value is same as kernel_size)self.conv_block_2 = nn.Sequential(nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(2))self.classifier = nn.Sequential(nn.Flatten(),# Where did this in_features shape come from? # It's because each layer of our network compresses and changes the shape of our input data.nn.Linear(in_features=hidden_units*16*16,out_features=output_shape))def forward(self, x: torch.Tensor):x = self.conv_block_1(x)# print(x.shape)x = self.conv_block_2(x)# print(x.shape)x = self.classifier(x)# print(x.shape)return x# return self.classifier(self.conv_block_2(self.conv_block_1(x))) # <- leverage the benefits of operator fusiontorch.manual_seed(42)
model_0 = TinyVGG(input_shape=3, # number of color channels (3 for RGB) hidden_units=10, output_shape=len(train_data.classes)).to(device)
print(model_0)
输出为:
TinyVGG((conv_block_1): Sequential((0): Conv2d(3, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU()(2): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU()(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(conv_block_2): Sequential((0): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU()(2): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU()(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(classifier): Sequential((0): Flatten(start_dim=1, end_dim=-1)(1): Linear(in_features=2560, out_features=3, bias=True)</
相关文章:
小白的进阶之路系列之四----人工智能从初步到精通pytorch自定义数据集下
本篇涵盖的内容 在之前的文章中,我们已经讨论了如何获取数据,转换数据以及如何准备自定义数据集,本篇文章将涵盖更加深入的问题,希望通过详细的代码示例,帮助大家了解PyTorch自定义数据集是如何应对各种复杂实际情况中,数据处理的。 更加详细的,我们将讨论下面一些内容…...
安卓添加设备节点权限和selinux访问权限
# 1 修改设备节点权限及配置属性设置节点值 ## 1.1 修改设备节点权限 ### 1.1.1 不会手动卸载的节点 在system/core/rootdir/init.rc中添加节点权限 在on boot下面添加 chown system system /sys/kernel/usb/host chmod 0664 /sys/kernel/usb/host ### 1.1.2 支持热插拔的…...

谷歌Stitch:AI赋能UI设计,免费高效新利器
在AI技术日新月异的今天,各大科技巨头都在不断刷新我们对智能工具的认知。最近,谷歌在其年度I/O开发者大会期间,除了那些聚光灯下的重磅发布,还悄然上线了一款令人惊喜的AI工具——Stitch。这是一款全新的、完全免费的AI驱动UI&am…...

运营商地址和ip属地一样吗?怎么样更改ip属地地址
在互联网时代,IP属地和运营商地址是两个经常被提及的概念,但它们是否相同?如何更改IP属地地址?这些问题困扰着许多网民。本文将深入探讨这两个概念的区别,并详细介绍更改IP属地地址的方法。 一、运营商地址和IP属地一…...

在QT中,利用charts库绘制FFT图形
第1章 添加charts库 1.1 .pro工程添加chart库 1.1.1 在.pro工程里面添加charts库 1.1.2 在需要使用的地方添加这两个库函数,顺序一点不要搞错,先添加.pro,否则编译器会找不到这两个.h文件。 第2章 Charts关键绘图函数 2.1 QChart 类 QChart 是…...
ChatGPT + 知网 + 知乎,如何高效整合信息写出一篇专业内容?
——写作,不是闭门造车,而是高效聚合 🧠 为什么“信息整合力”才是AI时代的核心写作能力? 现在的写作,不缺工具,也不缺资料,缺的是: 把 scattered info 变成 structured idea 的能力…...

流媒体协议分析:流媒体传输的基石
在流媒体传输过程中,协议的选择至关重要,它决定了数据如何封装、传输和解析,直接影响着视频的播放质量和用户体验。本文将深入分析几种常见的流媒体传输协议,探讨它们的特点、应用场景及优缺点。 协议分类概述 流媒体传输协议根据…...

vscode中让文件夹一直保持展开不折叠
vscode中让文件夹一直保持展开不折叠 问题 很多小伙伴使用vscode发现空文件夹会折叠显示, 让人看起来非常难受, 如下图 解决办法 首先打开设置->setting, 搜索compact Folders, 去掉勾选即可, 如下图所示 效果如下 看起来非常爽 ! ! !...

JAVA-springboot整合Mybatis
SpringBoot从入门到精通-第15章 MyBatis框架 学习MyBatis心路历程 2022年学习java基础时候,想着怎么使用java代码操作数据库,咨询了项目上开发W同事,没有引用框架,操作数据库很麻烦,就帮我写好多行代码,就…...

深度学习pycharm debug
深度学习中,Debug 是定位并解决代码逻辑错误(如张量维度不匹配)、训练异常(如 Loss 波动)、数据问题(如标签错误)的关键手段,通过打印维度、可视化梯度等方法确保模型正常运行、优化…...

MicroPython+L298N+ESP32控制电机转速
要使用MicroPython控制L298N电机驱动板来控制电机的转速,你可以通过PWM(脉冲宽度调制)信号来调节电机速度。L298N是一个双H桥驱动器,可以同时控制两个电机的正反转和速度。 硬件准备: 1. L298N 电机控制板 2. ESP32…...
Hive的存储格式如何优化?
Hive的存储格式对查询性能、存储成本和数据处理效率有显著影响。以下是主流存储格式的特点、选择标准和优化方法: 一、主流存储格式对比 特性ORC(Optimized Row Columnar)ParquetTextFile(默认)SequenceFile数据布局…...

在部署了一台mysql5.7的机器上部署mysql8.0.35
在已部署 MySQL 5.7 的机器上部署 MySQL 8.0.35 的完整指南 在同一台服务器上部署多个 MySQL 版本需要谨慎规划,避免端口冲突和数据混淆。以下是详细的部署步骤: 一、规划配置 端口分配 MySQL 5.7:使用默认端口 3306MySQL 8.0.35࿱…...
OpenCV CUDA模块结构分析与形状描述符------在 GPU 上计算图像的原始矩(spatial moments)函数spatialMoments()
操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 该函数用于在 GPU 上计算图像的原始矩(spatial moments)。这些矩可用于描述图像中物体的形状特征,如面积、质…...

QT入门学习(一)---新建工程与、信号与槽
一: 新建QT项目 二:QT文件构成 2.1 first.pro 项目管理文件,下面来看代码解析 QT core guigreaterThan(QT_MAJOR_VERSION, 4): QT widgetsCONFIG c11TARGET main# The following define makes your compiler emit warnings if you use # any Qt feature …...

UE5.4.4+Rider2024.3.7开发环境配置
文章目录 一、UE5安装 安装有两种方式一种的源码编译安装、一种是EPIC安装,推荐后者,只需要注册一个EPIC账号就可以一键安装。 二、C环境安装 1.下载VisualStudioSetup 下载链接如下下载 Visual Studio Tools - 免费安装 Windows、Mac、Linux 选择社…...

Windows环境下PHP,在PowerShell控制台输出中文乱码
解决方法: 以管理员运行PowerShell , 输入: chcp 65001 重启控制台;然后就正常输出中文;...
第2篇:数据库连接池原理与自定义连接池开发实践
2.1 什么是数据库连接池? 数据库连接池(Connection Pool)是一种用于管理数据库连接对象的复用机制。它的主要目标是: 减少频繁创建/销毁连接的开销 提高系统对数据库资源的使用效率 支持连接复用、并发控制和连接健康检查 连接…...

性能优化 - 理论篇:性能优化的七类技术手段
文章目录 Pre引言性能优化的七类技术手段性能优化策略一览表1. 复用优化2. 计算优化2.1 并行执行2.2 变同步为异步2.3 惰性加载 3. 结果集优化3.1 数据格式与协议选择3.2 字段精简与按需返回3.3 批量处理与分页3.4 索引与位图加速 4. 资源冲突优化4.1 锁的分类与特点4.2 无锁与…...

华为IP(7)
端口隔离技术 产生的背景 1.以太交换网络中为了实现报文之间的二层隔离,用户通常将不同的端口加入不同的VLAN,实现二层广播域的隔离。 2.大型网络中,业务需求种类繁多,只通过VLAN实现二层隔离,会浪费有限的VLAN资源…...

AIGC与影视制作:技术革命、产业重构与未来图景
文章目录 一、AIGC技术全景:从算法突破到产业赋能1. **技术底座:多模态大模型的进化路径**2. **核心算法:从生成对抗网络到扩散模型的迭代** 二、AIGC在影视制作全流程中的深度应用1. **剧本创作:从“灵感枯竭”到“创意井喷”**2…...
spring-cloud-alibaba-sentinel-gateway
Spring Cloud Alibaba Sentinel Gateway 是阿里巴巴开源组件 Sentinel 与 Spring Cloud Gateway 的整合模块,主要用于在微服务架构中对网关层的流量进行控制、保护和监控。以下是它的详细说明: 一. 核心用途 网关层流量治理:在 API 网关&…...

Cursor 玩转 腾讯地图 MCP Server
腾讯地图WebService API 服务简介 腾讯地图WebService API 是基于HTTPS/HTTP协议构建的标准化地理数据服务接口。该接口支持跨平台调用,开发者可使用任意客户端、服务器端技术及编程语言,遵循API规范发起HTTPS请求,获取地理信息服务…...
【HarmonyOS 5】 ArkUI-X开发中的常见问题及解决方案
一、跨平台编译与适配问题 1. 平台特定API不兼容 问题现象:使用Router模块的replaceUrl或startAbility等鸿蒙专属API时,编译跨平台工程报错cant support crossplatform application。 解决方案: 改用ohos.router的跨平台封装API&a…...

2025年中国电商618年中大促策略分析:存量博弈与生态重构
图片来源:Photo by Samuel Regan-Asante on Unsplash 中国电商行业正经历一场从「增量扩张」到「存量深耕」的深刻转型。 随着网络购物用户规模突破9.74亿、线上消费渗透率逼近30%的临界点,传统流量红利逐渐消退,行业竞争已从「切蛋糕」转向…...

Deepseek给出的8255显示例程
#include <stdio.h> #include <conio.h> #include <dos.h>// 定义8255端口地址 (根据原理图译码确定) #define PORT_8255_A 0x8000 // PA端口地址 #define PORT_8255_B 0x8001 // PB端口地址 #define PORT_8255_C 0x8002 // PC端口地址 #define PORT_8255…...
React+Antd全局加载遮罩工具
下面是全局加载遮罩工具,功能:提供show和showWithDelay/hide方法用于显示/延时显示/隐藏遮罩,它还提供loading属性返回是否正在loading。通常用于耗时较长的操作,比如远端api调用。 如何用它,下面是个例子,…...
Qt OpenGL 光照实现
Qt 中使用 OpenGL 实现光照效果主要基于 OpenGL 的光照模型和着色器编程。以下是 Qt OpenGL 光照实现的核心原理: 一. 光照模型基础 OpenGL 使用 Phong 光照模型,包含三个主要光照分量: 环境光(Ambient):场景中的基础光照,没有方向性 漫反射光(Diffuse):与表面法线和光…...

智汇云舟携最新无人机2D地图快速重建技术亮相广西国际矿业展览会
5月22至25日,广西国际矿业展览会(以下简称 “矿业展”)在南宁国际会展中心成功举办。智汇云舟与合作伙伴广西空驭数智信息技术有限公司携无人机 2D地图快速重建技术,以及视频孪生智慧矿山解决方案参会,为矿山行业数字化…...
Rust: CString、CStr和String、str
在FFI与C交互中,少不了与C中字符串交互。在Rust中,有 各种String存在的意义: OsString:因为要与操作系统等复杂的世界交互; 因为Rust世界中的Strings 始终是有效的 UTF-8。对于非 UTF-8 字符串,可以用到OsString。 CSt…...