当前位置: 首页 > article >正文

每天五分钟,跟学pytorch_day3:构建和训练图像分类器

目标给模型输入一张纯色的图片模型识别并输出其颜色一、数据准备这里我们将使用经典的 CIFAR10 数据集它包含 10 个类别的彩色图像每个类别有 6000 张图像图像大小为 32x32 像素。①使用 torchvision 加载 CIFAR10 数据集import torch import torchvision import torchvision.transforms as transforms ## 数据预处理将图像转换为张量并进行标准化 transform transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) ## 下载训练集和测试集 trainset torchvision.datasets.CIFAR10(root./data, trainTrue, downloadTrue, transformtransform) trainloader torch.utils.data.DataLoader(trainset, batch_size4, shuffleTrue, num_workers2) testset torchvision.datasets.CIFAR10(root./data, trainFalse, downloadTrue, transformtransform) testloader torch.utils.data.DataLoader(testset, batch_size4, shuffleFalse, num_workers2) ## 定义类别名称 classes (plane, car, bird, cat, deer, dog, frog, horse, ship, truck)1数据预处理作用通过一系列操作将图像数据转换为适合神经网络输入的形式。transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])组合多个图像变换操作2数据集下载与加载训练集与测试集划分CIFAR-10 数据集被划分为训练集和测试集训练集用于训练模型测试集用于评估模型性能。torchvision.datasets.CIFAR10用于加载 CIFAR-10 数据集。root./data指定数据集的存储路径若数据集不存在则会下载到该目录。trainTrue表示加载训练集。downloadTrue若数据集不存在则自动下载。transformtransform对数据集应用之前定义的变换操作。torch.utils.data.DataLoader用于创建数据加载器将数据集打包成小批量数据方便模型训练。trainset和testset分别是训练集和测试集。batch_size4每次迭代加载 4 个样本。shuffleTrue在训练集中打乱数据顺序使模型在训练时看到更多样的数据分布。shuffleFalse在测试集中不打乱数据顺序便于评估模型性能。num_workers2指定使用 2 个子进程加载数据可加快数据加载速度。3类别名称定义classes (plane, car, bird, cat, deer, dog, frog, horse, ship, truck)定义了 CIFAR-10 数据集中 10 个类别的名称用于后续对预测结果的解释和分析。②数据可视化import matplotlib.pyplot as plt import numpy as np ## 定义一个函数用于显示图像 def imshow(img): img img / 2 0.5 # 反标准化 npimg img.numpy() plt.imshow(np.transpose(npimg, (1, 2, 0))) plt.show() ## 获取一批训练数据 dataiter iter(trainloader) images, labels next(dataiter) ## 显示图像 imshow(torchvision.utils.make_grid(images)) ## 打印标签 print( .join(f{classes[labels[j]]:5s} for j in range(4)))二、定义卷积神经网络模型颜色分类器核心①定义 CNN 架构熟悉流程以及神经网络细节可见往期文章点击链接即可跳转import torch.nn as nn import torch.nn.functional as F class Net(nn.Module): def __init__(self): super(Net, self).__init__() # 定义卷积层和池化层 self.conv1 nn.Conv2d(3, 6, 5) # 输入通道3输出通道6卷积核大小5 self.pool nn.MaxPool2d(2, 2) # 最大池化层窗口大小2步长2 self.conv2 nn.Conv2d(6, 16, 5) # 输入通道6输出通道16卷积核大小5 # 定义全连接层 self.fc1 nn.Linear(16 * 5 * 5, 120) self.fc2 nn.Linear(120, 84) self.fc3 nn.Linear(84, 10) def forward(self, x): # 前向传播过程 x self.pool(F.relu(self.conv1(x))) # 卷积 激活 池化 x self.pool(F.relu(self.conv2(x))) # 卷积 激活 池化 x x.view(-1, 16 * 5 * 5) # 展平操作 x F.relu(self.fc1(x)) # 全连接 激活 x F.relu(self.fc2(x)) # 全连接 激活 x self.fc3(x) # 输出层 return x net Net() print(net)卷积层self.conv1 nn.Conv2d(3, 6, 5)第一个卷积层输入通道数为 3对应 RGB 三通道图像输出通道数为 6卷积核大小为 5。该层用于提取图像的局部特征。self.pool nn.MaxPool2d(2, 2)最大池化层窗口大小为 2步长为 2用于降低特征图的空间维度减少计算量同时保留重要特征。self.conv2 nn.Conv2d(6, 16, 5)第二个卷积层输入通道数为 6输出通道数为 16卷积核大小为 5。进一步提取更抽象的特征。全连接层self.fc1 nn.Linear(16 * 5 * 5, 120)第一个全连接层输入大小为 16 * 5 * 5输出大小为 120。将卷积层提取的特征映射到更高维的空间进行更复杂的数据表示。self.fc2 nn.Linear(120, 84)第二个全连接层输入大小为 120输出大小为 84。继续对特征进行变换和整合。self.fc3 nn.Linear(84, 10)输出层输入大小为 84输出大小为 10对应 CIFAR - 10 数据集的 10 个类别。前向传播在forward方法中输入图像先经过第一个卷积层应用ReLU激活函数后进行最大池化然后重复该过程经过第二个卷积层。之后将特征图展平并依次通过三个全连接层前两个全连接层后都应用ReLU激活函数最后一层直接输出分类结果。三、定义损失函数和优化器模型训练的指引①损失函数 优化器损失函数用于衡量模型预测结果与真实标签之间的差距优化器则负责根据损失函数的梯度信息更新模型参数。损失函数优化器详解见往期文章import torch.optim as optim # 使用交叉熵损失函数和随机梯度下降优化器 criterion nn.CrossEntropyLoss() optimizer optim.SGD(net.parameters(), lr0.001, momentum0.9)criterion nn.CrossEntropyLoss()定义损失函数 使用交叉熵损失函数衡量模型输出与真实标签之间的差异。optimizer optim.SGD(net.parameters(), lr0.001, momentum0.9)定义优化器 使用随机梯度下降SGD优化器lr0.001表示学习率控制参数更新步长momentum0.9表示动量因子加速梯度下降过程。tips交叉熵损失函数是一种用于衡量两个概率分布之间差异的函数主要用于分类问题中的损失计算。详细解释可见大佬文章四、训练网络提升模型性能训练过程是模型学习数据特征、优化参数的关键环节。我们需要多次迭代训练数据逐步调整模型参数以降低损失函数的值。for epoch in range(2): # 遍历数据集多次 running_loss 0.0 for i, data in enumerate(trainloader, 0): # 获取输入数据和标签 inputs, labels data # 清空梯度缓存 optimizer.zero_grad() # 前向传播 反向传播 优化 outputs net(inputs) loss criterion(outputs, labels) loss.backward() optimizer.step() # 打印统计信息 running_loss loss.item() if i % 2000 1999: # 每 2000 个小批量打印一次 print(f[{epoch 1}, {i 1:5d}] loss: {running_loss / 2000:.3f}) running_loss 0.0 print(Finished Training)这段代码是训练神经网络的核心部分它实现了模型的训练过程在每个训练周期epoch中依次处理训练集中的每个小批量数据通过前向传播计算输出和损失再通过反向传播更新模型参数。遍历数据集多次-初始化损失统计变量-遍历每个小批量数据-获取输入数据和标签-清空梯度缓存-前向传播 反向传播 优化-打印统计信息打印统计信息running_loss loss.item()将当前小批量的损失值累加到running_loss中。if i % 2000 1999: # 每 2000 个小批量打印一次每处理 2000 个小批量后打印一次平均损失信息。print(f[{epoch 1}, {i 1:5d}] loss: {running_loss / 2000:.3f})打印 当前训练周期和小批量索引以及过去 2000 个小批量的平均损失值。running_loss 0.0重置running_loss以便开始累计下一个 2000 小批量的损失值。tips:如果电脑配备 GPU可以利用 GPU 加速模型训练过程显著提升训练速度。device torch.device(cuda:0 if torch.cuda.is_available() else cpu) net.to(device) ## 将输入数据和标签移动到 GPU 上 inputs, labels inputs.to(device), labels.to(device)五、保存和加载模型模型持久化与复用训练完成后我们可以将模型参数保存到文件中以便后续加载和使用。PATH ./cifar_net.pth torch.save(net.state_dict(), PATH)加载模型参数net Net() net.load_state_dict(torch.load(PATH))六、测试网络评估模型性能correct 0 total 0 with torch.no_grad(): # 在测试阶段不需要计算梯度 for data in testloader: images, labels data outputs net(images) _, predicted torch.max(outputs.data, 1) total labels.size(0) correct (predicted labels).sum().item() print(fAccuracy of the network on the 10000 test images: {100 * correct / total:.2f}%)禁用梯度计算在测试阶段模型参数已经训练完成不需要再计算梯度或更新参数。使用torch.no_grad()可以临时关闭梯度计算节省内存并加快计算速度。遍历测试数据遍历测试集testloader中的每个小批量数据。每个小批量数据包含images输入图像数据。labels图像对应的真实标签。模型预测outputs net(images)将输入图像images通过训练好的模型net进行前向传播得到模型的输出outputs。torch.max(outputs.data, 1)对模型输出进行处理返回每个样本预测概率最大的类别索引_表示最大概率值我们不关心。predicted表示模型预测的类别索引。统计预测结果total labels.size(0)累计当前小批量的样本数量到总样本数total中。(predicted labels)比较预测类别与真实标签返回一个布尔张量。.sum().item()统计预测正确的样本数量并转换为 Python 数字累加到correct中。打印准确率总结本文介绍了PyTorch训练图像分类器的核心步骤包括数据准备、网络定义、模型训练、性能评估以及GPU 加速等关键技术。引用文献资料PyTorch 神经网络_w3cschoolhttps://www.w3cschool.cn/pytorch/pytorch-w18e3be1.html交叉熵损失函数解析-CSDN博客https://blog.csdn.net/b1055077005/article/details/100152102?ops_request_miscelastic_search_miscrequest_id32f6dd1f5f7a13d1801a3dbf24560931biz_id0utm_mediumdistribute.pc_search_result.none-task-blog-2~all~top_positive~default-1-100152102-null-null.142^v102^pc_search_result_base3utm_term%E4%BA%A4%E5%8F%89%E7%86%B5%E6%8D%9F%E5%A4%B1%E5%87%BD%E6%95%B0spm1018.2226.3001.4187

相关文章:

每天五分钟,跟学pytorch_day3:构建和训练图像分类器

目标:给模型输入一张纯色的图片,模型识别并输出其颜色 一、数据准备: 这里我们将使用经典的 CIFAR10 数据集,它包含 10 个类别的彩色图像,每个类别有 6000 张图像,图像大小为 32x32 像素。 ①使用 torch…...

MySQL 三层 B+ 树能存多少数据?

这是一个非常经典且常被问到的 MySQL 面试题。要计算 MySQL 三层 B 树能存多少数据,我们需要拆解 B 树的结构、页(Page)的大小、索引项的大小以及数据行的平均大小。 结论先行: 在默认配置下(页大小 16KB,主…...

军工领域OA系统怎样高效转存Word图文到网页端?

企业网站Word/公众号内容导入功能集成方案 一、需求分析与技术调研 1.1 需求分解 作为浙江某软件公司的前端工程师,我近期接到一个企业后台管理系统的功能升级需求,主要包含两个核心功能: Word粘贴功能:从Word直接复制内容到编…...

RPA-Python与Dependabot集成:依赖更新自动化的完整指南

RPA-Python与Dependabot集成:依赖更新自动化的完整指南 【免费下载链接】RPA-Python Python package for doing RPA 项目地址: https://gitcode.com/gh_mirrors/rp/RPA-Python 在Python机器人流程自动化(RPA)领域,RPA-Pyth…...

如何实现网页编辑器无缝导入Word文档内容?

河南软件工程大三狗的CMS升级记:从Word粘贴到Latex公式,99元预算的极限操作! 一、项目背景:穷学生的倔强 作为一枚即将毕业的大三狗,自己撸了个CMS新闻管理系统,但后台编辑器太挫——从Word复制内容粘贴进…...

学之思xzs系统核心代码解析:试卷生成模块的设计与实现

学之思xzs系统核心代码解析:试卷生成模块的设计与实现 【免费下载链接】xzs 在线考试系统 项目地址: https://gitcode.com/gh_mirrors/xz/xzs 学之思xzs在线考试系统是一个功能强大的开源考试平台,其核心功能之一就是智能试卷生成模块。本文将深入…...

MangoHud项目管理指南:如何高效使用GitHub Projects进行协作开发

MangoHud项目管理指南:如何高效使用GitHub Projects进行协作开发 【免费下载链接】MangoHud A Vulkan and OpenGL overlay for monitoring FPS, temperatures, CPU/GPU load and more. Discord: https://discordapp.com/invite/Gj5YmBb 项目地址: https://gitcode…...

Python实战:用LDA模型分析文本主题演化(附完整代码与避坑指南)

Python实战:用LDA模型追踪文本主题演化全流程 文本数据中隐藏的主题演化规律往往蕴含着宝贵的信息价值。作为数据分析师和Python开发者,掌握LDA主题建模技术并能够分析主题随时间的演变趋势,是一项极具实用价值的技能。本文将完整呈现从数据…...

Terraform工作流自动化:使用Terratest实现完整测试

Terraform工作流自动化:使用Terratest实现完整测试 【免费下载链接】terratest Terratest is a Go library that makes it easier to write automated tests for your infrastructure code. 项目地址: https://gitcode.com/gh_mirrors/te/terratest 在现代D…...

保姆级教程:用YOLOv8n搞定数字仪表盘检测,附390张数据集与完整代码

工业视觉实战:YOLOv8n数字仪表盘检测全流程解析 数字仪表盘在电力、化工、制造等行业中广泛应用,传统人工读数方式效率低下且容易出错。本文将手把手教你从零开始构建一个基于YOLOv8n的数字仪表盘检测系统,包含390张标注数据集的处理技巧和完…...

机械狗在复杂环境中的SLAM导航突破:从实验室到现实世界的跨越

1. 机械狗SLAM导航的技术挑战与现实痛点 第一次带着机械狗去建筑工地测试时,我亲眼看着这个价值几十万的"高科技产物"在碎石堆前突然死机——激光雷达被扬尘干扰,视觉系统因强光过曝,四条腿僵在原地不断发出错误警报。这个尴尬场景…...

BootstrapBlazor水波纹按钮:打造令人惊艳的点击交互效果

BootstrapBlazor水波纹按钮:打造令人惊艳的点击交互效果 【免费下载链接】BootstrapBlazor 项目地址: https://gitcode.com/gh_mirrors/bo/BootstrapBlazor BootstrapBlazor是一款功能强大的Blazor UI组件库,提供了丰富的界面元素和交互效果。其…...

军工嵌入式C固件逆向攻防全景图(2024最新版):从符号剥离到IR层语义混淆,92%的商用工具已失效

第一章:军工嵌入式C固件逆向攻防态势总览军工嵌入式系统普遍采用高度定制化的C语言固件,运行于ARM Cortex-M、PowerPC 405/74xx或SPARC LEON等专用处理器平台,其二进制分发形态(如裸机BIN、SREC、Intel HEX)与封闭调试…...

SwinIR智能安全:公共安全图像的目标识别优化

SwinIR智能安全:公共安全图像的目标识别优化 【免费下载链接】SwinIR SwinIR: Image Restoration Using Swin Transformer (official repository) 项目地址: https://gitcode.com/gh_mirrors/sw/SwinIR 在公共安全领域,图像的清晰度直接影响目标识…...

Splitflap传感器PCB设计与制造:从原理图到PCB布局最佳实践

Splitflap传感器PCB设计与制造:从原理图到PCB布局最佳实践 【免费下载链接】splitflap DIY split-flap display 项目地址: https://gitcode.com/gh_mirrors/sp/splitflap DIY split-flap显示器的传感器PCB设计是实现精确位置检测的关键技术。霍尔效应传感器P…...

云计算基础Day07:计划任务、软件包管理、本地YUM仓库

Linux核心操作知识总结(计划任务、软件包管理、本地YUM仓库) 本文基于Red Hat/RockyLinux系统,详细讲解了计划任务crontab、RPM包基础管理、本地YUM仓库搭建与使用三大核心操作,同时修正实操细节偏差、补充企业级运维场景的注意事…...

guacamole-server核心架构解析:深入理解libguac库和guacd守护进程

guacamole-server核心架构解析:深入理解libguac库和guacd守护进程 【免费下载链接】guacamole-server Mirror of Apache Guacamole Server 项目地址: https://gitcode.com/gh_mirrors/gu/guacamole-server guacamole-server是Apache Guacamole项目的核心组件…...

阿里小云KWS模型在AR/VR设备中的语音交互方案

阿里小云KWS模型在AR/VR设备中的语音交互方案 1. 引言 戴上AR眼镜或VR头显,眼前是令人惊叹的虚拟世界,但当你想要切换场景或调整设置时,却不得不摘下设备去找按钮或手柄——这样的体验是不是很熟悉?传统的AR/VR交互方式&#xf…...

深入go-json内部:操作码序列与虚拟机的完美结合

深入go-json内部:操作码序列与虚拟机的完美结合 【免费下载链接】go-json Fast JSON encoder/decoder compatible with encoding/json for Go 项目地址: https://gitcode.com/gh_mirrors/go/go-json go-json作为一款高性能的JSON编解码库,其核心优…...

特征值可视化指南:用Matplotlib动态演示PCA降维全过程

特征值可视化指南:用Matplotlib动态演示PCA降维全过程 在数据科学领域,理解高维数据的结构是一项基础但关键的能力。主成分分析(PCA)作为最常用的降维技术之一,其核心数学原理却常常让初学者望而生畏——特征值、特征向…...

如何通过API批量重命名ONLYOFFICE Docs文档标签:终极指南

如何通过API批量重命名ONLYOFFICE Docs文档标签:终极指南 【免费下载链接】DocumentServer ONLYOFFICE Docs is a free collaborative online office suite comprising viewers and editors for texts, spreadsheets and presentations, forms and PDF, fully compa…...

Transformer在图像恢复中的实战应用:AdaIR频率挖掘与调制技术解析

Transformer在图像恢复中的实战突破:频率域自适应修复技术详解 1. 频率域视角下的图像退化本质 当我们用手机在雨天拍摄照片时,那些恼人的雨滴条纹;在雾天远眺时,景物仿佛被蒙上了一层薄纱;或是夜间拍摄时画面出现的颗…...

多 agents 飞书群内通讯配置实战,根因 + 可复现配置 + 防坑清单

如果你也在用下龙虾openclaw,添加多个机器人到一个群里,统一指挥和调度,那么你大概率遇到过这个极其典型的线上诡异现象: 结果却是:A 机器人正常收消息、正常回复B 机器人像完全“失明”,毫无反应 很多人第一反应会怀…...

Flexprice订阅管理详解:如何处理升级、降级和暂停的完整流程

Flexprice订阅管理详解:如何处理升级、降级和暂停的完整流程 【免费下载链接】flexprice 🌟Open source pricing and billing infrastructure to support any pricing model, from usage-based to subscription and everything in between.👨…...

5分钟掌握TIDAL音乐下载:tidal-dl-ng完整使用指南

5分钟掌握TIDAL音乐下载:tidal-dl-ng完整使用指南 【免费下载链接】tidal-dl-ng TIDAL Media Downloader Next Generation! Up to HiRes / TIDAL MAX 24-bit, 192 kHz. 项目地址: https://gitcode.com/gh_mirrors/ti/tidal-dl-ng tidal-dl-ng是一款强大的TID…...

Mapus企业级应用场景:从团队协作到商业决策支持的完整指南

Mapus企业级应用场景:从团队协作到商业决策支持的完整指南 【免费下载链接】mapus A map tool with real-time collaboration 🗺️ 项目地址: https://gitcode.com/gh_mirrors/ma/mapus Mapus是一款开源的实时协作地图工具,专为团队协…...

隐私计算实践:OpenClaw本地化Qwen3-32B处理加密数据

隐私计算实践:OpenClaw本地化Qwen3-32B处理加密数据 1. 为什么需要本地化隐私计算 去年我在处理一批医疗调研数据时遇到了一个棘手问题:数据包含敏感个人信息,但需要AI辅助进行统计分析。当时尝试过几个云端方案,要么无法满足合…...

C#数据持久化新思路:除了Json和XML,试试康耐视CogSerializer存对象到文件

C#数据持久化新思路:探索CogSerializer在复杂对象序列化中的独特价值 在C#开发中,数据持久化是一个永恒的话题。当我们谈论序列化时,Json和XML往往是开发者最先想到的方案。Json.NET和XmlSerializer确实能解决大部分场景下的需求,…...

【真能降AI】速降AIGC,降重!标价即卖价,全网最低!维普、知网、万方等一键降AIGC率,逻辑清晰,语义通顺,只需稍改错别字和标点。

【真能降AI】速降AIGC,降重!标价即卖价,全网最低!维普、知网、万方等一键降AIGC率,逻辑清晰,语义通顺,只需稍改错别字和标点。 降AI人工服务,维普、知网专用,不限字数。依…...

MangoHud与AI游戏助手:性能优化建议生成

MangoHud与AI游戏助手:性能优化建议生成 【免费下载链接】MangoHud A Vulkan and OpenGL overlay for monitoring FPS, temperatures, CPU/GPU load and more. Discord: https://discordapp.com/invite/Gj5YmBb 项目地址: https://gitcode.com/gh_mirrors/ma/Mang…...