当前位置: 首页 > news >正文

构建并训练卷积神经网络(CNN)对CIFAR-10数据集进行分类

深度学习实践:构建并训练卷积神经网络(CNN)对CIFAR-10数据集进行分类

引言

在计算机视觉领域中,CIFAR-10数据集是一个经典的基准数据集,广泛用于图像分类任务。本文将介绍如何使用PyTorch框架构建一个简单的卷积神经网络(CNN),并在CIFAR-10数据集上进行训练和评估。通过本文,您将了解到数据预处理、模型定义、训练过程及结果可视化的完整流程。
在这里插入图片描述

数据预处理

首先,我们需要加载并预处理CIFAR-10数据集。CIFAR-10包含60000张32x32的彩色图像,分为10个类别,每个类别有6000张图像。我们使用torchvision库来轻松加载这些数据,并应用一些基本的变换,如归一化。

import torchvision
import torchvision.transforms as transformstransform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 归一化到[-1, 1]
])trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)
模型定义

接下来,我们定义一个简单的卷积神经网络。该网络包含三个卷积层,两个池化层,以及两个全连接层。

import torch.nn as nnclass ConvNet(nn.Module):def __init__(self):super(ConvNet, self).__init__()self.conv1 = nn.Conv2d(3, 32, 3, padding=1)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(32, 64, 3, padding=1)self.conv3 = nn.Conv2d(64, 64, 3, padding=1)self.fc1 = nn.Linear(64 * 8 * 8, 64)  # 考虑到池化层后的尺寸self.fc2 = nn.Linear(64, 10)def forward(self, x):x = self.pool(torch.relu(self.conv1(x)))x = self.pool(torch.relu(self.conv2(x)))x = torch.relu(self.conv3(x))x = x.view(-1, 64 * 8 * 8)  # flattenx = torch.relu(self.fc1(x))x = self.fc2(x)return xnet = ConvNet()
训练过程

我们使用Adam优化器和交叉熵损失函数来训练模型,并将模型训练10个epoch。训练过程中,我们记录每个epoch的平均损失。

import torch.optim as optimcriterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)num_epochs = 10
loss_history = []  # 记录每个epoch的平均损失
for epoch in range(num_epochs):running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = datainputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()if i % 100 == 99:print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 100}')running_loss = 0.0epoch_loss = running_loss / len(trainloader)loss_history.append(epoch_loss)print('Finished Training')
模型评估

训练完成后,我们在测试集上评估模型的性能,并计算准确率。

correct = 0
total = 0
with torch.no_grad():for data in testloader:images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()final_accuracy = 100 * correct / totalprint(f'Accuracy of the network on the 10000 test images: {final_accuracy} %')
结果可视化

最后,我们将训练过程中的损失和最终的准确率进行可视化,以便更直观地了解模型的训练效果。

import matplotlib.pyplot as plt# 可视化损失
plt.plot(range(1, num_epochs + 1), loss_history)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss History')
plt.show()# 可视化准确率
plt.bar(1, final_accuracy, width=0.4, label='Final Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.title('Final Accuracy on Test Set')
plt.legend()
plt.show()
结论

本文介绍了如何使用PyTorch构建并训练一个简单的卷积神经网络对CIFAR-10数据集进行分类。通过数据预处理、模型定义、训练及结果可视化,我们完整地展示了深度学习项目的流程。希望本文能为您提供一些有用的参考和启发,帮助您在自己的深度学习项目中取得更好的成果。

相关文章:

构建并训练卷积神经网络(CNN)对CIFAR-10数据集进行分类

深度学习实践:构建并训练卷积神经网络(CNN)对CIFAR-10数据集进行分类 引言 在计算机视觉领域中,CIFAR-10数据集是一个经典的基准数据集,广泛用于图像分类任务。本文将介绍如何使用PyTorch框架构建一个简单的卷积神经…...

flowable 根据xml 字符串生成流程图

//获取xml InputStream stream repositoryService.getProcessModel(processDefinitionId); String result IOUtils.toString(stream, StandardCharsets.UTF_8); // 创建 XMLInputFactory XMLInputFactory factory XMLInputFactory.newInstance(); // 从字符…...

AI建模——AI生成3D内容算法产品介绍与模型免费下载

说明: 记录AI文生3D模型、图生3D模型的相关产品;记录其性能、功能、收费与免费方法 0.AI建模产品 Rodin MeshAnything Meshy 生成效果比较: Rodin效果最好、Meshy其次 1.Rodin 官网:gHyperHuman 支持:文生模型、…...

在Go中迅速使用RabbitMQ

文章目录 1 认识1.1 MQ分类1.2 安装1.3 基本流程 2 [Work模型](https://www.rabbitmq.com/tutorials/tutorial-two-go#preparation)3 交换机3.1 fanout3.2 direct3.3 [topic](https://www.rabbitmq.com/tutorials/tutorial-five-go) 4 Golang创建交换机/队列/Publish/Consume/B…...

Windows JDK安装详细教程

一、关于JDK 1.1 简介 Java是一种广泛使用的计算机编程语言,拥有跨平台、面向对象、泛型编程的特性,广泛应用于企业级Web应用开发和移动应用开发。 JDK(Java Development Kit)是用于开发 Java 应用程序的工具包。它由以下几个主要…...

Ribbon负载均衡底层原理

springcloude服务实例与服务实例之间发送请求,首先根据服务名注册到nacos,然后发送请求,nacos可以根据服务名找到对应的服务实例。 SpringCloudRibbon的底层采用了一个拦截器,拦截了openfeign发出的请求,对地址做了修…...

【C语言可变参数函数的使用与原理分析】

文章目录 1 前言2 实例2.1实例程序2.2程序执行结果2.3 程序分析 3 补充4 总结 1 前言 在编程过程中,有时会遇到需要定义参数数量不固定的函数的情况。 C语言提供了一种灵活的解决方案:变参函数。这种函数能够根据实际调用时的需求,接受任意…...

【笔记】Java EE应用开发环境配置(JDK+Maven+Tomcat+MySQL+IDEA)

一、安装JDK17 1.下载JDK17 https://download.oracle.com/java/17/archive/jdk-17.0.7_windows-x64_bin.zip 2.配置环境变量 下载后,解压到本地(目录中最好不要有中文或特殊字符) 打开【控制面板】-【系统和安全】-【系统】-【高级系统…...

一文讲懂扩散模型

一文讲懂扩散模型 扩散模型(Diffusion Models, DM)是近年来在计算机视觉、自然语言处理等领域取得显著进展的一种生成模型。其思想根源可以追溯到非平衡热力学,通过模拟数据的扩散和去噪过程来生成新的样本。以下将详细阐述扩散模型的基本原理…...

学习笔记八:基于Jenkins+k8s+Git+DockerHub等技术链构建企业级DevOps容器云平台

基于Jenkinsk8sGitDockerHub等技术链构建企业级DevOps容器云平台 测试jenkins的CI/CD在Jenkins中安装kubernetes插件安装blueocean插件配置jenkins连接到我们存在的k8s集群配置pod-template添加自己的dockerhub凭据测试通过Jenkins部署应用发布到k8s开发环境、测试环境、生产环…...

科研绘图系列:R语言柱状图分布(histogram plot)

文章目录 介绍加载R包读取数据画图介绍 柱状图(Bar Chart)是一种常用的数据可视化图表,用于展示和比较不同类别或组的数据。它通过在二维平面上绘制一系列垂直或水平的柱子来表示数据的大小,每个柱子的长度或高度代表一个数据点的数值。柱状图非常适合于展示分类数据的分布…...

vue3+ts封装类似于微信消息的组件

组件代码如下&#xff1a; <template><div:class"[voice-message, { sent: isSent, received: !isSent }]":style"{ backgroundColor: backgroundColor }"click"togglePlayback"><!-- isSent为false在左侧&#xff0c;为true在右…...

ES6 reduce方法详解:示例、应用场景与实用技巧

在JavaScript中&#xff0c;reduce 方法是一个非常强大的数组方法&#xff0c;它允许你将数组中的元素归并&#xff08;reduce&#xff09;为单个值。reduce 方法执行一个由你提供的reducer函数&#xff08;归并函数&#xff09;&#xff0c;将其结果汇总为单一的返回值。 一.…...

java后端保存的本地图片通过ip+端口直接访问

直接上代码吧 package com.ydx.emms.datapro.controller;import org.springframework.context.annotation.Configuration; import org.springframework.web.servlet.config.annotation.ResourceHandlerRegistry; import org.springframework.web.servlet.config.annotation.…...

2024 年高教社杯全国大学生数学建模竞赛B题4小问解题思路(第二版)

原文链接&#xff1a;https://www.cnblogs.com/qimoxuan/articles/18399415 问题 1&#xff1a;抽样检测方案设计 详细解题思路&#xff1a; 确定抽样检测目标&#xff1a;企业需要确定一个可接受的次品率上限&#xff08;标称值&#xff09;&#xff0c;以及在该次品率下&am…...

docker-nginx数据卷挂载

一、案例1-利用Nginx容器部署静态资源 1.1、需求: 创建Nginx容器&#xff0c; 修改nginx容器内的html目录下的index.html文件,查看变化将静态资源部署到nginx的html目录 1.2、修改html目录下的index.html文件,查看变化 因为docker运用得最小化系统环境&#xff0c;解决办法就…...

项目实战 ---- 商用落地视频搜索系统(8)---优化(2)---查询逻辑层优化

目录 背景 技术衡量与方案 一种可实现方案 可实现方案及设计描述 可能存在的问题 一种创新实现方案 方案的改良设计 策略公式 优化的实现 完整代码 代码解释 异常场景的考量 处理方式 运行注意事项 运行结果 结果优化对比与解释 背景 在项目实战 ---- 商用落地…...

山东大学机试试题合集

&#x1f370;&#x1f370;&#x1f370;高分篇已经涵盖了绝大多数的机试考点&#xff0c;由于临近预推免&#xff0c;各校的机试蜂拥而至&#xff0c;我们接下来先更一些各高校机试题合集&#xff0c;算是对前边学习成果的深入学习&#xff0c;也是对我们代码能力的锻炼。加油…...

餐厅食品留样管理系统小程序的设计

管理员账户功能包括&#xff1a;系统首页&#xff0c;个人中心&#xff0c;窗口负责人管理&#xff0c;窗口员工管理&#xff0c;冰柜管理&#xff0c;排班信息管理&#xff0c;留样食品管理&#xff0c;教育宣传管理&#xff0c;系统管理 微信端账号功能包括&#xff1a;系统…...

亚马逊运营:如何提高亚马逊销量和运营效率?

不少亚马逊卖家们为了扩大业务规模和提高销量&#xff0c;会创建多个卖家账户来同时运营多个亚马逊店铺。问题是&#xff0c;这种多店铺运营模式并非没有风险——亚马逊运营的一个重要方面就是账户的健康管理。一旦某个账户出现问题&#xff0c;亚马逊的算法就可能会启动关联检…...

Docker 离线安装指南

参考文章 1、确认操作系统类型及内核版本 Docker依赖于Linux内核的一些特性&#xff0c;不同版本的Docker对内核版本有不同要求。例如&#xff0c;Docker 17.06及之后的版本通常需要Linux内核3.10及以上版本&#xff0c;Docker17.09及更高版本对应Linux内核4.9.x及更高版本。…...

设计模式和设计原则回顾

设计模式和设计原则回顾 23种设计模式是设计原则的完美体现,设计原则设计原则是设计模式的理论基石, 设计模式 在经典的设计模式分类中(如《设计模式:可复用面向对象软件的基础》一书中),总共有23种设计模式,分为三大类: 一、创建型模式(5种) 1. 单例模式(Sing…...

linux之kylin系统nginx的安装

一、nginx的作用 1.可做高性能的web服务器 直接处理静态资源&#xff08;HTML/CSS/图片等&#xff09;&#xff0c;响应速度远超传统服务器类似apache支持高并发连接 2.反向代理服务器 隐藏后端服务器IP地址&#xff0c;提高安全性 3.负载均衡服务器 支持多种策略分发流量…...

css实现圆环展示百分比,根据值动态展示所占比例

代码如下 <view class""><view class"circle-chart"><view v-if"!!num" class"pie-item" :style"{background: conic-gradient(var(--one-color) 0%,#E9E6F1 ${num}%),}"></view><view v-else …...

Admin.Net中的消息通信SignalR解释

定义集线器接口 IOnlineUserHub public interface IOnlineUserHub {/// 在线用户列表Task OnlineUserList(OnlineUserList context);/// 强制下线Task ForceOffline(object context);/// 发布站内消息Task PublicNotice(SysNotice context);/// 接收消息Task ReceiveMessage(…...

【JVM】- 内存结构

引言 JVM&#xff1a;Java Virtual Machine 定义&#xff1a;Java虚拟机&#xff0c;Java二进制字节码的运行环境好处&#xff1a; 一次编写&#xff0c;到处运行自动内存管理&#xff0c;垃圾回收的功能数组下标越界检查&#xff08;会抛异常&#xff0c;不会覆盖到其他代码…...

【快手拥抱开源】通过快手团队开源的 KwaiCoder-AutoThink-preview 解锁大语言模型的潜力

引言&#xff1a; 在人工智能快速发展的浪潮中&#xff0c;快手Kwaipilot团队推出的 KwaiCoder-AutoThink-preview 具有里程碑意义——这是首个公开的AutoThink大语言模型&#xff08;LLM&#xff09;。该模型代表着该领域的重大突破&#xff0c;通过独特方式融合思考与非思考…...

多模态商品数据接口:融合图像、语音与文字的下一代商品详情体验

一、多模态商品数据接口的技术架构 &#xff08;一&#xff09;多模态数据融合引擎 跨模态语义对齐 通过Transformer架构实现图像、语音、文字的语义关联。例如&#xff0c;当用户上传一张“蓝色连衣裙”的图片时&#xff0c;接口可自动提取图像中的颜色&#xff08;RGB值&…...

Linux云原生安全:零信任架构与机密计算

Linux云原生安全&#xff1a;零信任架构与机密计算 构建坚不可摧的云原生防御体系 引言&#xff1a;云原生安全的范式革命 随着云原生技术的普及&#xff0c;安全边界正在从传统的网络边界向工作负载内部转移。Gartner预测&#xff0c;到2025年&#xff0c;零信任架构将成为超…...

3403. 从盒子中找出字典序最大的字符串 I

3403. 从盒子中找出字典序最大的字符串 I 题目链接&#xff1a;3403. 从盒子中找出字典序最大的字符串 I 代码如下&#xff1a; class Solution { public:string answerString(string word, int numFriends) {if (numFriends 1) {return word;}string res;for (int i 0;i &…...