构建并训练卷积神经网络(CNN)对CIFAR-10数据集进行分类
深度学习实践:构建并训练卷积神经网络(CNN)对CIFAR-10数据集进行分类
引言
在计算机视觉领域中,CIFAR-10数据集是一个经典的基准数据集,广泛用于图像分类任务。本文将介绍如何使用PyTorch框架构建一个简单的卷积神经网络(CNN),并在CIFAR-10数据集上进行训练和评估。通过本文,您将了解到数据预处理、模型定义、训练过程及结果可视化的完整流程。
数据预处理
首先,我们需要加载并预处理CIFAR-10数据集。CIFAR-10包含60000张32x32的彩色图像,分为10个类别,每个类别有6000张图像。我们使用torchvision
库来轻松加载这些数据,并应用一些基本的变换,如归一化。
import torchvision
import torchvision.transforms as transformstransform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 归一化到[-1, 1]
])trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)
模型定义
接下来,我们定义一个简单的卷积神经网络。该网络包含三个卷积层,两个池化层,以及两个全连接层。
import torch.nn as nnclass ConvNet(nn.Module):def __init__(self):super(ConvNet, self).__init__()self.conv1 = nn.Conv2d(3, 32, 3, padding=1)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(32, 64, 3, padding=1)self.conv3 = nn.Conv2d(64, 64, 3, padding=1)self.fc1 = nn.Linear(64 * 8 * 8, 64) # 考虑到池化层后的尺寸self.fc2 = nn.Linear(64, 10)def forward(self, x):x = self.pool(torch.relu(self.conv1(x)))x = self.pool(torch.relu(self.conv2(x)))x = torch.relu(self.conv3(x))x = x.view(-1, 64 * 8 * 8) # flattenx = torch.relu(self.fc1(x))x = self.fc2(x)return xnet = ConvNet()
训练过程
我们使用Adam优化器和交叉熵损失函数来训练模型,并将模型训练10个epoch。训练过程中,我们记录每个epoch的平均损失。
import torch.optim as optimcriterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)num_epochs = 10
loss_history = [] # 记录每个epoch的平均损失
for epoch in range(num_epochs):running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = datainputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()if i % 100 == 99:print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 100}')running_loss = 0.0epoch_loss = running_loss / len(trainloader)loss_history.append(epoch_loss)print('Finished Training')
模型评估
训练完成后,我们在测试集上评估模型的性能,并计算准确率。
correct = 0
total = 0
with torch.no_grad():for data in testloader:images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()final_accuracy = 100 * correct / totalprint(f'Accuracy of the network on the 10000 test images: {final_accuracy} %')
结果可视化
最后,我们将训练过程中的损失和最终的准确率进行可视化,以便更直观地了解模型的训练效果。
import matplotlib.pyplot as plt# 可视化损失
plt.plot(range(1, num_epochs + 1), loss_history)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss History')
plt.show()# 可视化准确率
plt.bar(1, final_accuracy, width=0.4, label='Final Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.title('Final Accuracy on Test Set')
plt.legend()
plt.show()
结论
本文介绍了如何使用PyTorch构建并训练一个简单的卷积神经网络对CIFAR-10数据集进行分类。通过数据预处理、模型定义、训练及结果可视化,我们完整地展示了深度学习项目的流程。希望本文能为您提供一些有用的参考和启发,帮助您在自己的深度学习项目中取得更好的成果。
相关文章:

构建并训练卷积神经网络(CNN)对CIFAR-10数据集进行分类
深度学习实践:构建并训练卷积神经网络(CNN)对CIFAR-10数据集进行分类 引言 在计算机视觉领域中,CIFAR-10数据集是一个经典的基准数据集,广泛用于图像分类任务。本文将介绍如何使用PyTorch框架构建一个简单的卷积神经…...
flowable 根据xml 字符串生成流程图
//获取xml InputStream stream repositoryService.getProcessModel(processDefinitionId); String result IOUtils.toString(stream, StandardCharsets.UTF_8); // 创建 XMLInputFactory XMLInputFactory factory XMLInputFactory.newInstance(); // 从字符…...

AI建模——AI生成3D内容算法产品介绍与模型免费下载
说明: 记录AI文生3D模型、图生3D模型的相关产品;记录其性能、功能、收费与免费方法 0.AI建模产品 Rodin MeshAnything Meshy 生成效果比较: Rodin效果最好、Meshy其次 1.Rodin 官网:gHyperHuman 支持:文生模型、…...

在Go中迅速使用RabbitMQ
文章目录 1 认识1.1 MQ分类1.2 安装1.3 基本流程 2 [Work模型](https://www.rabbitmq.com/tutorials/tutorial-two-go#preparation)3 交换机3.1 fanout3.2 direct3.3 [topic](https://www.rabbitmq.com/tutorials/tutorial-five-go) 4 Golang创建交换机/队列/Publish/Consume/B…...

Windows JDK安装详细教程
一、关于JDK 1.1 简介 Java是一种广泛使用的计算机编程语言,拥有跨平台、面向对象、泛型编程的特性,广泛应用于企业级Web应用开发和移动应用开发。 JDK(Java Development Kit)是用于开发 Java 应用程序的工具包。它由以下几个主要…...

Ribbon负载均衡底层原理
springcloude服务实例与服务实例之间发送请求,首先根据服务名注册到nacos,然后发送请求,nacos可以根据服务名找到对应的服务实例。 SpringCloudRibbon的底层采用了一个拦截器,拦截了openfeign发出的请求,对地址做了修…...

【C语言可变参数函数的使用与原理分析】
文章目录 1 前言2 实例2.1实例程序2.2程序执行结果2.3 程序分析 3 补充4 总结 1 前言 在编程过程中,有时会遇到需要定义参数数量不固定的函数的情况。 C语言提供了一种灵活的解决方案:变参函数。这种函数能够根据实际调用时的需求,接受任意…...

【笔记】Java EE应用开发环境配置(JDK+Maven+Tomcat+MySQL+IDEA)
一、安装JDK17 1.下载JDK17 https://download.oracle.com/java/17/archive/jdk-17.0.7_windows-x64_bin.zip 2.配置环境变量 下载后,解压到本地(目录中最好不要有中文或特殊字符) 打开【控制面板】-【系统和安全】-【系统】-【高级系统…...

一文讲懂扩散模型
一文讲懂扩散模型 扩散模型(Diffusion Models, DM)是近年来在计算机视觉、自然语言处理等领域取得显著进展的一种生成模型。其思想根源可以追溯到非平衡热力学,通过模拟数据的扩散和去噪过程来生成新的样本。以下将详细阐述扩散模型的基本原理…...

学习笔记八:基于Jenkins+k8s+Git+DockerHub等技术链构建企业级DevOps容器云平台
基于Jenkinsk8sGitDockerHub等技术链构建企业级DevOps容器云平台 测试jenkins的CI/CD在Jenkins中安装kubernetes插件安装blueocean插件配置jenkins连接到我们存在的k8s集群配置pod-template添加自己的dockerhub凭据测试通过Jenkins部署应用发布到k8s开发环境、测试环境、生产环…...

科研绘图系列:R语言柱状图分布(histogram plot)
文章目录 介绍加载R包读取数据画图介绍 柱状图(Bar Chart)是一种常用的数据可视化图表,用于展示和比较不同类别或组的数据。它通过在二维平面上绘制一系列垂直或水平的柱子来表示数据的大小,每个柱子的长度或高度代表一个数据点的数值。柱状图非常适合于展示分类数据的分布…...

vue3+ts封装类似于微信消息的组件
组件代码如下: <template><div:class"[voice-message, { sent: isSent, received: !isSent }]":style"{ backgroundColor: backgroundColor }"click"togglePlayback"><!-- isSent为false在左侧,为true在右…...
ES6 reduce方法详解:示例、应用场景与实用技巧
在JavaScript中,reduce 方法是一个非常强大的数组方法,它允许你将数组中的元素归并(reduce)为单个值。reduce 方法执行一个由你提供的reducer函数(归并函数),将其结果汇总为单一的返回值。 一.…...

java后端保存的本地图片通过ip+端口直接访问
直接上代码吧 package com.ydx.emms.datapro.controller;import org.springframework.context.annotation.Configuration; import org.springframework.web.servlet.config.annotation.ResourceHandlerRegistry; import org.springframework.web.servlet.config.annotation.…...

2024 年高教社杯全国大学生数学建模竞赛B题4小问解题思路(第二版)
原文链接:https://www.cnblogs.com/qimoxuan/articles/18399415 问题 1:抽样检测方案设计 详细解题思路: 确定抽样检测目标:企业需要确定一个可接受的次品率上限(标称值),以及在该次品率下&am…...

docker-nginx数据卷挂载
一、案例1-利用Nginx容器部署静态资源 1.1、需求: 创建Nginx容器, 修改nginx容器内的html目录下的index.html文件,查看变化将静态资源部署到nginx的html目录 1.2、修改html目录下的index.html文件,查看变化 因为docker运用得最小化系统环境,解决办法就…...
项目实战 ---- 商用落地视频搜索系统(8)---优化(2)---查询逻辑层优化
目录 背景 技术衡量与方案 一种可实现方案 可实现方案及设计描述 可能存在的问题 一种创新实现方案 方案的改良设计 策略公式 优化的实现 完整代码 代码解释 异常场景的考量 处理方式 运行注意事项 运行结果 结果优化对比与解释 背景 在项目实战 ---- 商用落地…...

山东大学机试试题合集
🍰🍰🍰高分篇已经涵盖了绝大多数的机试考点,由于临近预推免,各校的机试蜂拥而至,我们接下来先更一些各高校机试题合集,算是对前边学习成果的深入学习,也是对我们代码能力的锻炼。加油…...

餐厅食品留样管理系统小程序的设计
管理员账户功能包括:系统首页,个人中心,窗口负责人管理,窗口员工管理,冰柜管理,排班信息管理,留样食品管理,教育宣传管理,系统管理 微信端账号功能包括:系统…...
亚马逊运营:如何提高亚马逊销量和运营效率?
不少亚马逊卖家们为了扩大业务规模和提高销量,会创建多个卖家账户来同时运营多个亚马逊店铺。问题是,这种多店铺运营模式并非没有风险——亚马逊运营的一个重要方面就是账户的健康管理。一旦某个账户出现问题,亚马逊的算法就可能会启动关联检…...

使用分级同态加密防御梯度泄漏
抽象 联邦学习 (FL) 支持跨分布式客户端进行协作模型训练,而无需共享原始数据,这使其成为在互联和自动驾驶汽车 (CAV) 等领域保护隐私的机器学习的一种很有前途的方法。然而,最近的研究表明&…...

【大模型RAG】Docker 一键部署 Milvus 完整攻略
本文概要 Milvus 2.5 Stand-alone 版可通过 Docker 在几分钟内完成安装;只需暴露 19530(gRPC)与 9091(HTTP/WebUI)两个端口,即可让本地电脑通过 PyMilvus 或浏览器访问远程 Linux 服务器上的 Milvus。下面…...
基于Uniapp开发HarmonyOS 5.0旅游应用技术实践
一、技术选型背景 1.跨平台优势 Uniapp采用Vue.js框架,支持"一次开发,多端部署",可同步生成HarmonyOS、iOS、Android等多平台应用。 2.鸿蒙特性融合 HarmonyOS 5.0的分布式能力与原子化服务,为旅游应用带来…...
电脑插入多块移动硬盘后经常出现卡顿和蓝屏
当电脑在插入多块移动硬盘后频繁出现卡顿和蓝屏问题时,可能涉及硬件资源冲突、驱动兼容性、供电不足或系统设置等多方面原因。以下是逐步排查和解决方案: 1. 检查电源供电问题 问题原因:多块移动硬盘同时运行可能导致USB接口供电不足&#x…...

ios苹果系统,js 滑动屏幕、锚定无效
现象:window.addEventListener监听touch无效,划不动屏幕,但是代码逻辑都有执行到。 scrollIntoView也无效。 原因:这是因为 iOS 的触摸事件处理机制和 touch-action: none 的设置有关。ios有太多得交互动作,从而会影响…...

论文笔记——相干体技术在裂缝预测中的应用研究
目录 相关地震知识补充地震数据的认识地震几何属性 相干体算法定义基本原理第一代相干体技术:基于互相关的相干体技术(Correlation)第二代相干体技术:基于相似的相干体技术(Semblance)基于多道相似的相干体…...

论文阅读:LLM4Drive: A Survey of Large Language Models for Autonomous Driving
地址:LLM4Drive: A Survey of Large Language Models for Autonomous Driving 摘要翻译 自动驾驶技术作为推动交通和城市出行变革的催化剂,正从基于规则的系统向数据驱动策略转变。传统的模块化系统受限于级联模块间的累积误差和缺乏灵活性的预设规则。…...
加密通信 + 行为分析:运营商行业安全防御体系重构
在数字经济蓬勃发展的时代,运营商作为信息通信网络的核心枢纽,承载着海量用户数据与关键业务传输,其安全防御体系的可靠性直接关乎国家安全、社会稳定与企业发展。随着网络攻击手段的不断升级,传统安全防护体系逐渐暴露出局限性&a…...

PLC入门【4】基本指令2(SET RST)
04 基本指令2 PLC编程第四课基本指令(2) 1、运用上接课所学的基本指令完成个简单的实例编程。 2、学习SET--置位指令 3、RST--复位指令 打开软件(FX-TRN-BEG-C),从 文件 - 主画面,“B: 让我们学习基本的”- “B-3.控制优先程序”。 点击“梯形图编辑”…...

RKNN开发环境搭建2-RKNN Model Zoo 环境搭建
目录 1.简介2.环境搭建2.1 启动 docker 环境2.2 安装依赖工具2.3 下载 RKNN Model Zoo2.4 RKNN模型转化2.5编译C++1.简介 RKNN Model Zoo基于 RKNPU SDK 工具链开发, 提供了目前主流算法的部署例程. 例程包含导出RKNN模型, 使用 Python API, CAPI 推理 RKNN 模型的流程. 本…...