当前位置: 首页 > news >正文

构建并训练卷积神经网络(CNN)对CIFAR-10数据集进行分类

深度学习实践:构建并训练卷积神经网络(CNN)对CIFAR-10数据集进行分类

引言

在计算机视觉领域中,CIFAR-10数据集是一个经典的基准数据集,广泛用于图像分类任务。本文将介绍如何使用PyTorch框架构建一个简单的卷积神经网络(CNN),并在CIFAR-10数据集上进行训练和评估。通过本文,您将了解到数据预处理、模型定义、训练过程及结果可视化的完整流程。
在这里插入图片描述

数据预处理

首先,我们需要加载并预处理CIFAR-10数据集。CIFAR-10包含60000张32x32的彩色图像,分为10个类别,每个类别有6000张图像。我们使用torchvision库来轻松加载这些数据,并应用一些基本的变换,如归一化。

import torchvision
import torchvision.transforms as transformstransform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 归一化到[-1, 1]
])trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)
模型定义

接下来,我们定义一个简单的卷积神经网络。该网络包含三个卷积层,两个池化层,以及两个全连接层。

import torch.nn as nnclass ConvNet(nn.Module):def __init__(self):super(ConvNet, self).__init__()self.conv1 = nn.Conv2d(3, 32, 3, padding=1)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(32, 64, 3, padding=1)self.conv3 = nn.Conv2d(64, 64, 3, padding=1)self.fc1 = nn.Linear(64 * 8 * 8, 64)  # 考虑到池化层后的尺寸self.fc2 = nn.Linear(64, 10)def forward(self, x):x = self.pool(torch.relu(self.conv1(x)))x = self.pool(torch.relu(self.conv2(x)))x = torch.relu(self.conv3(x))x = x.view(-1, 64 * 8 * 8)  # flattenx = torch.relu(self.fc1(x))x = self.fc2(x)return xnet = ConvNet()
训练过程

我们使用Adam优化器和交叉熵损失函数来训练模型,并将模型训练10个epoch。训练过程中,我们记录每个epoch的平均损失。

import torch.optim as optimcriterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)num_epochs = 10
loss_history = []  # 记录每个epoch的平均损失
for epoch in range(num_epochs):running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = datainputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()if i % 100 == 99:print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 100}')running_loss = 0.0epoch_loss = running_loss / len(trainloader)loss_history.append(epoch_loss)print('Finished Training')
模型评估

训练完成后,我们在测试集上评估模型的性能,并计算准确率。

correct = 0
total = 0
with torch.no_grad():for data in testloader:images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()final_accuracy = 100 * correct / totalprint(f'Accuracy of the network on the 10000 test images: {final_accuracy} %')
结果可视化

最后,我们将训练过程中的损失和最终的准确率进行可视化,以便更直观地了解模型的训练效果。

import matplotlib.pyplot as plt# 可视化损失
plt.plot(range(1, num_epochs + 1), loss_history)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss History')
plt.show()# 可视化准确率
plt.bar(1, final_accuracy, width=0.4, label='Final Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.title('Final Accuracy on Test Set')
plt.legend()
plt.show()
结论

本文介绍了如何使用PyTorch构建并训练一个简单的卷积神经网络对CIFAR-10数据集进行分类。通过数据预处理、模型定义、训练及结果可视化,我们完整地展示了深度学习项目的流程。希望本文能为您提供一些有用的参考和启发,帮助您在自己的深度学习项目中取得更好的成果。

相关文章:

构建并训练卷积神经网络(CNN)对CIFAR-10数据集进行分类

深度学习实践:构建并训练卷积神经网络(CNN)对CIFAR-10数据集进行分类 引言 在计算机视觉领域中,CIFAR-10数据集是一个经典的基准数据集,广泛用于图像分类任务。本文将介绍如何使用PyTorch框架构建一个简单的卷积神经…...

flowable 根据xml 字符串生成流程图

//获取xml InputStream stream repositoryService.getProcessModel(processDefinitionId); String result IOUtils.toString(stream, StandardCharsets.UTF_8); // 创建 XMLInputFactory XMLInputFactory factory XMLInputFactory.newInstance(); // 从字符…...

AI建模——AI生成3D内容算法产品介绍与模型免费下载

说明: 记录AI文生3D模型、图生3D模型的相关产品;记录其性能、功能、收费与免费方法 0.AI建模产品 Rodin MeshAnything Meshy 生成效果比较: Rodin效果最好、Meshy其次 1.Rodin 官网:gHyperHuman 支持:文生模型、…...

在Go中迅速使用RabbitMQ

文章目录 1 认识1.1 MQ分类1.2 安装1.3 基本流程 2 [Work模型](https://www.rabbitmq.com/tutorials/tutorial-two-go#preparation)3 交换机3.1 fanout3.2 direct3.3 [topic](https://www.rabbitmq.com/tutorials/tutorial-five-go) 4 Golang创建交换机/队列/Publish/Consume/B…...

Windows JDK安装详细教程

一、关于JDK 1.1 简介 Java是一种广泛使用的计算机编程语言,拥有跨平台、面向对象、泛型编程的特性,广泛应用于企业级Web应用开发和移动应用开发。 JDK(Java Development Kit)是用于开发 Java 应用程序的工具包。它由以下几个主要…...

Ribbon负载均衡底层原理

springcloude服务实例与服务实例之间发送请求,首先根据服务名注册到nacos,然后发送请求,nacos可以根据服务名找到对应的服务实例。 SpringCloudRibbon的底层采用了一个拦截器,拦截了openfeign发出的请求,对地址做了修…...

【C语言可变参数函数的使用与原理分析】

文章目录 1 前言2 实例2.1实例程序2.2程序执行结果2.3 程序分析 3 补充4 总结 1 前言 在编程过程中,有时会遇到需要定义参数数量不固定的函数的情况。 C语言提供了一种灵活的解决方案:变参函数。这种函数能够根据实际调用时的需求,接受任意…...

【笔记】Java EE应用开发环境配置(JDK+Maven+Tomcat+MySQL+IDEA)

一、安装JDK17 1.下载JDK17 https://download.oracle.com/java/17/archive/jdk-17.0.7_windows-x64_bin.zip 2.配置环境变量 下载后,解压到本地(目录中最好不要有中文或特殊字符) 打开【控制面板】-【系统和安全】-【系统】-【高级系统…...

一文讲懂扩散模型

一文讲懂扩散模型 扩散模型(Diffusion Models, DM)是近年来在计算机视觉、自然语言处理等领域取得显著进展的一种生成模型。其思想根源可以追溯到非平衡热力学,通过模拟数据的扩散和去噪过程来生成新的样本。以下将详细阐述扩散模型的基本原理…...

学习笔记八:基于Jenkins+k8s+Git+DockerHub等技术链构建企业级DevOps容器云平台

基于Jenkinsk8sGitDockerHub等技术链构建企业级DevOps容器云平台 测试jenkins的CI/CD在Jenkins中安装kubernetes插件安装blueocean插件配置jenkins连接到我们存在的k8s集群配置pod-template添加自己的dockerhub凭据测试通过Jenkins部署应用发布到k8s开发环境、测试环境、生产环…...

科研绘图系列:R语言柱状图分布(histogram plot)

文章目录 介绍加载R包读取数据画图介绍 柱状图(Bar Chart)是一种常用的数据可视化图表,用于展示和比较不同类别或组的数据。它通过在二维平面上绘制一系列垂直或水平的柱子来表示数据的大小,每个柱子的长度或高度代表一个数据点的数值。柱状图非常适合于展示分类数据的分布…...

vue3+ts封装类似于微信消息的组件

组件代码如下&#xff1a; <template><div:class"[voice-message, { sent: isSent, received: !isSent }]":style"{ backgroundColor: backgroundColor }"click"togglePlayback"><!-- isSent为false在左侧&#xff0c;为true在右…...

ES6 reduce方法详解:示例、应用场景与实用技巧

在JavaScript中&#xff0c;reduce 方法是一个非常强大的数组方法&#xff0c;它允许你将数组中的元素归并&#xff08;reduce&#xff09;为单个值。reduce 方法执行一个由你提供的reducer函数&#xff08;归并函数&#xff09;&#xff0c;将其结果汇总为单一的返回值。 一.…...

java后端保存的本地图片通过ip+端口直接访问

直接上代码吧 package com.ydx.emms.datapro.controller;import org.springframework.context.annotation.Configuration; import org.springframework.web.servlet.config.annotation.ResourceHandlerRegistry; import org.springframework.web.servlet.config.annotation.…...

2024 年高教社杯全国大学生数学建模竞赛B题4小问解题思路(第二版)

原文链接&#xff1a;https://www.cnblogs.com/qimoxuan/articles/18399415 问题 1&#xff1a;抽样检测方案设计 详细解题思路&#xff1a; 确定抽样检测目标&#xff1a;企业需要确定一个可接受的次品率上限&#xff08;标称值&#xff09;&#xff0c;以及在该次品率下&am…...

docker-nginx数据卷挂载

一、案例1-利用Nginx容器部署静态资源 1.1、需求: 创建Nginx容器&#xff0c; 修改nginx容器内的html目录下的index.html文件,查看变化将静态资源部署到nginx的html目录 1.2、修改html目录下的index.html文件,查看变化 因为docker运用得最小化系统环境&#xff0c;解决办法就…...

项目实战 ---- 商用落地视频搜索系统(8)---优化(2)---查询逻辑层优化

目录 背景 技术衡量与方案 一种可实现方案 可实现方案及设计描述 可能存在的问题 一种创新实现方案 方案的改良设计 策略公式 优化的实现 完整代码 代码解释 异常场景的考量 处理方式 运行注意事项 运行结果 结果优化对比与解释 背景 在项目实战 ---- 商用落地…...

山东大学机试试题合集

&#x1f370;&#x1f370;&#x1f370;高分篇已经涵盖了绝大多数的机试考点&#xff0c;由于临近预推免&#xff0c;各校的机试蜂拥而至&#xff0c;我们接下来先更一些各高校机试题合集&#xff0c;算是对前边学习成果的深入学习&#xff0c;也是对我们代码能力的锻炼。加油…...

餐厅食品留样管理系统小程序的设计

管理员账户功能包括&#xff1a;系统首页&#xff0c;个人中心&#xff0c;窗口负责人管理&#xff0c;窗口员工管理&#xff0c;冰柜管理&#xff0c;排班信息管理&#xff0c;留样食品管理&#xff0c;教育宣传管理&#xff0c;系统管理 微信端账号功能包括&#xff1a;系统…...

亚马逊运营:如何提高亚马逊销量和运营效率?

不少亚马逊卖家们为了扩大业务规模和提高销量&#xff0c;会创建多个卖家账户来同时运营多个亚马逊店铺。问题是&#xff0c;这种多店铺运营模式并非没有风险——亚马逊运营的一个重要方面就是账户的健康管理。一旦某个账户出现问题&#xff0c;亚马逊的算法就可能会启动关联检…...

Qwen3.5-9B-AWQ-4bit图文问答进阶:结合上下文的多图对比分析方法

Qwen3.5-9B-AWQ-4bit图文问答进阶&#xff1a;结合上下文的多图对比分析方法 1. 多图对比分析的价值与应用场景 在日常工作和生活中&#xff0c;我们经常需要比较和分析多张图片之间的异同。传统的人工对比方法耗时耗力&#xff0c;而借助Qwen3.5-9B-AWQ-4bit这样的多模态模型…...

快速入门:5步掌握OCR文字识别镜像,轻松提取图片文字

快速入门&#xff1a;5步掌握OCR文字识别镜像&#xff0c;轻松提取图片文字 1. 为什么选择这个OCR镜像 在日常工作和生活中&#xff0c;我们经常遇到需要从图片中提取文字的场景&#xff1a;扫描的文档、手机拍摄的发票、路牌标识等。传统手动输入不仅效率低下&#xff0c;还…...

位运算的技巧和演示

尝试理解并去总结...

终极Windows系统维护指南:使用Dism++轻松管理你的操作系统

终极Windows系统维护指南&#xff1a;使用Dism轻松管理你的操作系统 【免费下载链接】Dism-Multi-language Dism Multi-language Support & BUG Report 项目地址: https://gitcode.com/gh_mirrors/di/Dism-Multi-language Dism是一款强大的Windows系统维护工具&…...

OpenClaw 核心概念关系与配置指南

文章目录&#x1f3d7;️ 一、核心概念关系图&#x1f504; 二、核心概念关系详解1. Gateway&#xff08;网关&#xff09;- 控制中枢2. Agent&#xff08;智能体&#xff09;- 执行单元3. Skills&#xff08;技能&#xff09;- 功能模块4. Tools&#xff08;工具&#xff09;-…...

大模型学习第5天--python基础(练习题)

# 作业三&#xff1a;类型转换练习# 任务描述&#xff1a;# 编写一个程序&#xff0c;实现以下功能&#xff1a;# 1. 定义以下变量&#xff08;初始值都是字符串&#xff09;&#xff1a;# - 学号&#xff1a;"2024001"# - 数学成绩&#xff1a;"85"…...

Qwen-Audio歌唱语音识别效果展示

Qwen-Audio歌唱语音识别效果展示 1. 歌唱语音识别的独特挑战与突破 当我们在听一首歌时&#xff0c;大脑会自动分离出旋律、节奏、歌词和情感表达。但对AI模型来说&#xff0c;这却是个复杂得多的任务——它需要同时处理音高变化、节奏韵律、人声谐波特征&#xff0c;还要准确…...

图论(16)匈牙利算法与最优匹配算法实战解析

1. 匈牙利算法&#xff1a;偶图匹配的魔法棒 第一次听说匈牙利算法时&#xff0c;我误以为它和匈牙利这个国家有什么关系。后来才知道&#xff0c;这个算法之所以叫这个名字&#xff0c;是因为它基于匈牙利数学家Dnes Kőnig的定理。不过名字不重要&#xff0c;重要的是它确实像…...

ooderAgent 龙虾时代的统一认证体系

当 Agent 从"工具"进化为"伙伴"&#xff0c;账户体系如何重新定义人机协作的信任边界&#xff1f; ​ 协议版本&#xff1a;ooderAgent v1.0.0 | 发布日期&#xff1a;2026-04-08 | 维护团队&#xff1a;ooderAgent Team 一、引言&#xff1a;从 0.7.3 到 …...

OpenClaw+千问3.5-9B学习助手:自动整理笔记与生成习题

OpenClaw千问3.5-9B学习助手&#xff1a;自动整理笔记与生成习题 1. 为什么需要AI学习助手&#xff1f; 去年备考PMP证书时&#xff0c;我每天要处理上百页PDF讲义。最痛苦的莫过于手动整理重点和制作复习卡片——复制粘贴到半夜&#xff0c;第二天发现漏了关键图表&#xff…...