构建并训练卷积神经网络(CNN)对CIFAR-10数据集进行分类
深度学习实践:构建并训练卷积神经网络(CNN)对CIFAR-10数据集进行分类
引言
在计算机视觉领域中,CIFAR-10数据集是一个经典的基准数据集,广泛用于图像分类任务。本文将介绍如何使用PyTorch框架构建一个简单的卷积神经网络(CNN),并在CIFAR-10数据集上进行训练和评估。通过本文,您将了解到数据预处理、模型定义、训练过程及结果可视化的完整流程。

数据预处理
首先,我们需要加载并预处理CIFAR-10数据集。CIFAR-10包含60000张32x32的彩色图像,分为10个类别,每个类别有6000张图像。我们使用torchvision库来轻松加载这些数据,并应用一些基本的变换,如归一化。
import torchvision
import torchvision.transforms as transformstransform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 归一化到[-1, 1]
])trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)
模型定义
接下来,我们定义一个简单的卷积神经网络。该网络包含三个卷积层,两个池化层,以及两个全连接层。
import torch.nn as nnclass ConvNet(nn.Module):def __init__(self):super(ConvNet, self).__init__()self.conv1 = nn.Conv2d(3, 32, 3, padding=1)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(32, 64, 3, padding=1)self.conv3 = nn.Conv2d(64, 64, 3, padding=1)self.fc1 = nn.Linear(64 * 8 * 8, 64) # 考虑到池化层后的尺寸self.fc2 = nn.Linear(64, 10)def forward(self, x):x = self.pool(torch.relu(self.conv1(x)))x = self.pool(torch.relu(self.conv2(x)))x = torch.relu(self.conv3(x))x = x.view(-1, 64 * 8 * 8) # flattenx = torch.relu(self.fc1(x))x = self.fc2(x)return xnet = ConvNet()
训练过程
我们使用Adam优化器和交叉熵损失函数来训练模型,并将模型训练10个epoch。训练过程中,我们记录每个epoch的平均损失。
import torch.optim as optimcriterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)num_epochs = 10
loss_history = [] # 记录每个epoch的平均损失
for epoch in range(num_epochs):running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = datainputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()if i % 100 == 99:print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 100}')running_loss = 0.0epoch_loss = running_loss / len(trainloader)loss_history.append(epoch_loss)print('Finished Training')
模型评估
训练完成后,我们在测试集上评估模型的性能,并计算准确率。
correct = 0
total = 0
with torch.no_grad():for data in testloader:images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()final_accuracy = 100 * correct / totalprint(f'Accuracy of the network on the 10000 test images: {final_accuracy} %')
结果可视化
最后,我们将训练过程中的损失和最终的准确率进行可视化,以便更直观地了解模型的训练效果。
import matplotlib.pyplot as plt# 可视化损失
plt.plot(range(1, num_epochs + 1), loss_history)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss History')
plt.show()# 可视化准确率
plt.bar(1, final_accuracy, width=0.4, label='Final Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.title('Final Accuracy on Test Set')
plt.legend()
plt.show()
结论
本文介绍了如何使用PyTorch构建并训练一个简单的卷积神经网络对CIFAR-10数据集进行分类。通过数据预处理、模型定义、训练及结果可视化,我们完整地展示了深度学习项目的流程。希望本文能为您提供一些有用的参考和启发,帮助您在自己的深度学习项目中取得更好的成果。
相关文章:
构建并训练卷积神经网络(CNN)对CIFAR-10数据集进行分类
深度学习实践:构建并训练卷积神经网络(CNN)对CIFAR-10数据集进行分类 引言 在计算机视觉领域中,CIFAR-10数据集是一个经典的基准数据集,广泛用于图像分类任务。本文将介绍如何使用PyTorch框架构建一个简单的卷积神经…...
flowable 根据xml 字符串生成流程图
//获取xml InputStream stream repositoryService.getProcessModel(processDefinitionId); String result IOUtils.toString(stream, StandardCharsets.UTF_8); // 创建 XMLInputFactory XMLInputFactory factory XMLInputFactory.newInstance(); // 从字符…...
AI建模——AI生成3D内容算法产品介绍与模型免费下载
说明: 记录AI文生3D模型、图生3D模型的相关产品;记录其性能、功能、收费与免费方法 0.AI建模产品 Rodin MeshAnything Meshy 生成效果比较: Rodin效果最好、Meshy其次 1.Rodin 官网:gHyperHuman 支持:文生模型、…...
在Go中迅速使用RabbitMQ
文章目录 1 认识1.1 MQ分类1.2 安装1.3 基本流程 2 [Work模型](https://www.rabbitmq.com/tutorials/tutorial-two-go#preparation)3 交换机3.1 fanout3.2 direct3.3 [topic](https://www.rabbitmq.com/tutorials/tutorial-five-go) 4 Golang创建交换机/队列/Publish/Consume/B…...
Windows JDK安装详细教程
一、关于JDK 1.1 简介 Java是一种广泛使用的计算机编程语言,拥有跨平台、面向对象、泛型编程的特性,广泛应用于企业级Web应用开发和移动应用开发。 JDK(Java Development Kit)是用于开发 Java 应用程序的工具包。它由以下几个主要…...
Ribbon负载均衡底层原理
springcloude服务实例与服务实例之间发送请求,首先根据服务名注册到nacos,然后发送请求,nacos可以根据服务名找到对应的服务实例。 SpringCloudRibbon的底层采用了一个拦截器,拦截了openfeign发出的请求,对地址做了修…...
【C语言可变参数函数的使用与原理分析】
文章目录 1 前言2 实例2.1实例程序2.2程序执行结果2.3 程序分析 3 补充4 总结 1 前言 在编程过程中,有时会遇到需要定义参数数量不固定的函数的情况。 C语言提供了一种灵活的解决方案:变参函数。这种函数能够根据实际调用时的需求,接受任意…...
【笔记】Java EE应用开发环境配置(JDK+Maven+Tomcat+MySQL+IDEA)
一、安装JDK17 1.下载JDK17 https://download.oracle.com/java/17/archive/jdk-17.0.7_windows-x64_bin.zip 2.配置环境变量 下载后,解压到本地(目录中最好不要有中文或特殊字符) 打开【控制面板】-【系统和安全】-【系统】-【高级系统…...
一文讲懂扩散模型
一文讲懂扩散模型 扩散模型(Diffusion Models, DM)是近年来在计算机视觉、自然语言处理等领域取得显著进展的一种生成模型。其思想根源可以追溯到非平衡热力学,通过模拟数据的扩散和去噪过程来生成新的样本。以下将详细阐述扩散模型的基本原理…...
学习笔记八:基于Jenkins+k8s+Git+DockerHub等技术链构建企业级DevOps容器云平台
基于Jenkinsk8sGitDockerHub等技术链构建企业级DevOps容器云平台 测试jenkins的CI/CD在Jenkins中安装kubernetes插件安装blueocean插件配置jenkins连接到我们存在的k8s集群配置pod-template添加自己的dockerhub凭据测试通过Jenkins部署应用发布到k8s开发环境、测试环境、生产环…...
科研绘图系列:R语言柱状图分布(histogram plot)
文章目录 介绍加载R包读取数据画图介绍 柱状图(Bar Chart)是一种常用的数据可视化图表,用于展示和比较不同类别或组的数据。它通过在二维平面上绘制一系列垂直或水平的柱子来表示数据的大小,每个柱子的长度或高度代表一个数据点的数值。柱状图非常适合于展示分类数据的分布…...
vue3+ts封装类似于微信消息的组件
组件代码如下: <template><div:class"[voice-message, { sent: isSent, received: !isSent }]":style"{ backgroundColor: backgroundColor }"click"togglePlayback"><!-- isSent为false在左侧,为true在右…...
ES6 reduce方法详解:示例、应用场景与实用技巧
在JavaScript中,reduce 方法是一个非常强大的数组方法,它允许你将数组中的元素归并(reduce)为单个值。reduce 方法执行一个由你提供的reducer函数(归并函数),将其结果汇总为单一的返回值。 一.…...
java后端保存的本地图片通过ip+端口直接访问
直接上代码吧 package com.ydx.emms.datapro.controller;import org.springframework.context.annotation.Configuration; import org.springframework.web.servlet.config.annotation.ResourceHandlerRegistry; import org.springframework.web.servlet.config.annotation.…...
2024 年高教社杯全国大学生数学建模竞赛B题4小问解题思路(第二版)
原文链接:https://www.cnblogs.com/qimoxuan/articles/18399415 问题 1:抽样检测方案设计 详细解题思路: 确定抽样检测目标:企业需要确定一个可接受的次品率上限(标称值),以及在该次品率下&am…...
docker-nginx数据卷挂载
一、案例1-利用Nginx容器部署静态资源 1.1、需求: 创建Nginx容器, 修改nginx容器内的html目录下的index.html文件,查看变化将静态资源部署到nginx的html目录 1.2、修改html目录下的index.html文件,查看变化 因为docker运用得最小化系统环境,解决办法就…...
项目实战 ---- 商用落地视频搜索系统(8)---优化(2)---查询逻辑层优化
目录 背景 技术衡量与方案 一种可实现方案 可实现方案及设计描述 可能存在的问题 一种创新实现方案 方案的改良设计 策略公式 优化的实现 完整代码 代码解释 异常场景的考量 处理方式 运行注意事项 运行结果 结果优化对比与解释 背景 在项目实战 ---- 商用落地…...
山东大学机试试题合集
🍰🍰🍰高分篇已经涵盖了绝大多数的机试考点,由于临近预推免,各校的机试蜂拥而至,我们接下来先更一些各高校机试题合集,算是对前边学习成果的深入学习,也是对我们代码能力的锻炼。加油…...
餐厅食品留样管理系统小程序的设计
管理员账户功能包括:系统首页,个人中心,窗口负责人管理,窗口员工管理,冰柜管理,排班信息管理,留样食品管理,教育宣传管理,系统管理 微信端账号功能包括:系统…...
亚马逊运营:如何提高亚马逊销量和运营效率?
不少亚马逊卖家们为了扩大业务规模和提高销量,会创建多个卖家账户来同时运营多个亚马逊店铺。问题是,这种多店铺运营模式并非没有风险——亚马逊运营的一个重要方面就是账户的健康管理。一旦某个账户出现问题,亚马逊的算法就可能会启动关联检…...
HTML 语义化
目录 HTML 语义化HTML5 新特性HTML 语义化的好处语义化标签的使用场景最佳实践 HTML 语义化 HTML5 新特性 标准答案: 语义化标签: <header>:页头<nav>:导航<main>:主要内容<article>&#x…...
OpenLayers 可视化之热力图
注:当前使用的是 ol 5.3.0 版本,天地图使用的key请到天地图官网申请,并替换为自己的key 热力图(Heatmap)又叫热点图,是一种通过特殊高亮显示事物密度分布、变化趋势的数据可视化技术。采用颜色的深浅来显示…...
【力扣数据库知识手册笔记】索引
索引 索引的优缺点 优点1. 通过创建唯一性索引,可以保证数据库表中每一行数据的唯一性。2. 可以加快数据的检索速度(创建索引的主要原因)。3. 可以加速表和表之间的连接,实现数据的参考完整性。4. 可以在查询过程中,…...
iPhone密码忘记了办?iPhoneUnlocker,iPhone解锁工具Aiseesoft iPhone Unlocker 高级注册版分享
平时用 iPhone 的时候,难免会碰到解锁的麻烦事。比如密码忘了、人脸识别 / 指纹识别突然不灵,或者买了二手 iPhone 却被原来的 iCloud 账号锁住,这时候就需要靠谱的解锁工具来帮忙了。Aiseesoft iPhone Unlocker 就是专门解决这些问题的软件&…...
页面渲染流程与性能优化
页面渲染流程与性能优化详解(完整版) 一、现代浏览器渲染流程(详细说明) 1. 构建DOM树 浏览器接收到HTML文档后,会逐步解析并构建DOM(Document Object Model)树。具体过程如下: (…...
关于 WASM:1. WASM 基础原理
一、WASM 简介 1.1 WebAssembly 是什么? WebAssembly(WASM) 是一种能在现代浏览器中高效运行的二进制指令格式,它不是传统的编程语言,而是一种 低级字节码格式,可由高级语言(如 C、C、Rust&am…...
大学生职业发展与就业创业指导教学评价
这里是引用 作为软工2203/2204班的学生,我们非常感谢您在《大学生职业发展与就业创业指导》课程中的悉心教导。这门课程对我们即将面临实习和就业的工科学生来说至关重要,而您认真负责的教学态度,让课程的每一部分都充满了实用价值。 尤其让我…...
C#学习第29天:表达式树(Expression Trees)
目录 什么是表达式树? 核心概念 1.表达式树的构建 2. 表达式树与Lambda表达式 3.解析和访问表达式树 4.动态条件查询 表达式树的优势 1.动态构建查询 2.LINQ 提供程序支持: 3.性能优化 4.元数据处理 5.代码转换和重写 适用场景 代码复杂性…...
为什么要创建 Vue 实例
核心原因:Vue 需要一个「控制中心」来驱动整个应用 你可以把 Vue 实例想象成你应用的**「大脑」或「引擎」。它负责协调模板、数据、逻辑和行为,将它们变成一个活的、可交互的应用**。没有这个实例,你的代码只是一堆静态的 HTML、JavaScript 变量和函数,无法「活」起来。 …...
[特殊字符] 手撸 Redis 互斥锁那些坑
📖 手撸 Redis 互斥锁那些坑 最近搞业务遇到高并发下同一个 key 的互斥操作,想实现分布式环境下的互斥锁。于是私下顺手手撸了个基于 Redis 的简单互斥锁,也顺便跟 Redisson 的 RLock 机制对比了下,记录一波,别踩我踩过…...
