当前位置: 首页 > news >正文

构建并训练卷积神经网络(CNN)对CIFAR-10数据集进行分类

深度学习实践:构建并训练卷积神经网络(CNN)对CIFAR-10数据集进行分类

引言

在计算机视觉领域中,CIFAR-10数据集是一个经典的基准数据集,广泛用于图像分类任务。本文将介绍如何使用PyTorch框架构建一个简单的卷积神经网络(CNN),并在CIFAR-10数据集上进行训练和评估。通过本文,您将了解到数据预处理、模型定义、训练过程及结果可视化的完整流程。
在这里插入图片描述

数据预处理

首先,我们需要加载并预处理CIFAR-10数据集。CIFAR-10包含60000张32x32的彩色图像,分为10个类别,每个类别有6000张图像。我们使用torchvision库来轻松加载这些数据,并应用一些基本的变换,如归一化。

import torchvision
import torchvision.transforms as transformstransform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 归一化到[-1, 1]
])trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)
模型定义

接下来,我们定义一个简单的卷积神经网络。该网络包含三个卷积层,两个池化层,以及两个全连接层。

import torch.nn as nnclass ConvNet(nn.Module):def __init__(self):super(ConvNet, self).__init__()self.conv1 = nn.Conv2d(3, 32, 3, padding=1)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(32, 64, 3, padding=1)self.conv3 = nn.Conv2d(64, 64, 3, padding=1)self.fc1 = nn.Linear(64 * 8 * 8, 64)  # 考虑到池化层后的尺寸self.fc2 = nn.Linear(64, 10)def forward(self, x):x = self.pool(torch.relu(self.conv1(x)))x = self.pool(torch.relu(self.conv2(x)))x = torch.relu(self.conv3(x))x = x.view(-1, 64 * 8 * 8)  # flattenx = torch.relu(self.fc1(x))x = self.fc2(x)return xnet = ConvNet()
训练过程

我们使用Adam优化器和交叉熵损失函数来训练模型,并将模型训练10个epoch。训练过程中,我们记录每个epoch的平均损失。

import torch.optim as optimcriterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)num_epochs = 10
loss_history = []  # 记录每个epoch的平均损失
for epoch in range(num_epochs):running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = datainputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()if i % 100 == 99:print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 100}')running_loss = 0.0epoch_loss = running_loss / len(trainloader)loss_history.append(epoch_loss)print('Finished Training')
模型评估

训练完成后,我们在测试集上评估模型的性能,并计算准确率。

correct = 0
total = 0
with torch.no_grad():for data in testloader:images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()final_accuracy = 100 * correct / totalprint(f'Accuracy of the network on the 10000 test images: {final_accuracy} %')
结果可视化

最后,我们将训练过程中的损失和最终的准确率进行可视化,以便更直观地了解模型的训练效果。

import matplotlib.pyplot as plt# 可视化损失
plt.plot(range(1, num_epochs + 1), loss_history)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss History')
plt.show()# 可视化准确率
plt.bar(1, final_accuracy, width=0.4, label='Final Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.title('Final Accuracy on Test Set')
plt.legend()
plt.show()
结论

本文介绍了如何使用PyTorch构建并训练一个简单的卷积神经网络对CIFAR-10数据集进行分类。通过数据预处理、模型定义、训练及结果可视化,我们完整地展示了深度学习项目的流程。希望本文能为您提供一些有用的参考和启发,帮助您在自己的深度学习项目中取得更好的成果。

相关文章:

构建并训练卷积神经网络(CNN)对CIFAR-10数据集进行分类

深度学习实践:构建并训练卷积神经网络(CNN)对CIFAR-10数据集进行分类 引言 在计算机视觉领域中,CIFAR-10数据集是一个经典的基准数据集,广泛用于图像分类任务。本文将介绍如何使用PyTorch框架构建一个简单的卷积神经…...

flowable 根据xml 字符串生成流程图

//获取xml InputStream stream repositoryService.getProcessModel(processDefinitionId); String result IOUtils.toString(stream, StandardCharsets.UTF_8); // 创建 XMLInputFactory XMLInputFactory factory XMLInputFactory.newInstance(); // 从字符…...

AI建模——AI生成3D内容算法产品介绍与模型免费下载

说明: 记录AI文生3D模型、图生3D模型的相关产品;记录其性能、功能、收费与免费方法 0.AI建模产品 Rodin MeshAnything Meshy 生成效果比较: Rodin效果最好、Meshy其次 1.Rodin 官网:gHyperHuman 支持:文生模型、…...

在Go中迅速使用RabbitMQ

文章目录 1 认识1.1 MQ分类1.2 安装1.3 基本流程 2 [Work模型](https://www.rabbitmq.com/tutorials/tutorial-two-go#preparation)3 交换机3.1 fanout3.2 direct3.3 [topic](https://www.rabbitmq.com/tutorials/tutorial-five-go) 4 Golang创建交换机/队列/Publish/Consume/B…...

Windows JDK安装详细教程

一、关于JDK 1.1 简介 Java是一种广泛使用的计算机编程语言,拥有跨平台、面向对象、泛型编程的特性,广泛应用于企业级Web应用开发和移动应用开发。 JDK(Java Development Kit)是用于开发 Java 应用程序的工具包。它由以下几个主要…...

Ribbon负载均衡底层原理

springcloude服务实例与服务实例之间发送请求,首先根据服务名注册到nacos,然后发送请求,nacos可以根据服务名找到对应的服务实例。 SpringCloudRibbon的底层采用了一个拦截器,拦截了openfeign发出的请求,对地址做了修…...

【C语言可变参数函数的使用与原理分析】

文章目录 1 前言2 实例2.1实例程序2.2程序执行结果2.3 程序分析 3 补充4 总结 1 前言 在编程过程中,有时会遇到需要定义参数数量不固定的函数的情况。 C语言提供了一种灵活的解决方案:变参函数。这种函数能够根据实际调用时的需求,接受任意…...

【笔记】Java EE应用开发环境配置(JDK+Maven+Tomcat+MySQL+IDEA)

一、安装JDK17 1.下载JDK17 https://download.oracle.com/java/17/archive/jdk-17.0.7_windows-x64_bin.zip 2.配置环境变量 下载后,解压到本地(目录中最好不要有中文或特殊字符) 打开【控制面板】-【系统和安全】-【系统】-【高级系统…...

一文讲懂扩散模型

一文讲懂扩散模型 扩散模型(Diffusion Models, DM)是近年来在计算机视觉、自然语言处理等领域取得显著进展的一种生成模型。其思想根源可以追溯到非平衡热力学,通过模拟数据的扩散和去噪过程来生成新的样本。以下将详细阐述扩散模型的基本原理…...

学习笔记八:基于Jenkins+k8s+Git+DockerHub等技术链构建企业级DevOps容器云平台

基于Jenkinsk8sGitDockerHub等技术链构建企业级DevOps容器云平台 测试jenkins的CI/CD在Jenkins中安装kubernetes插件安装blueocean插件配置jenkins连接到我们存在的k8s集群配置pod-template添加自己的dockerhub凭据测试通过Jenkins部署应用发布到k8s开发环境、测试环境、生产环…...

科研绘图系列:R语言柱状图分布(histogram plot)

文章目录 介绍加载R包读取数据画图介绍 柱状图(Bar Chart)是一种常用的数据可视化图表,用于展示和比较不同类别或组的数据。它通过在二维平面上绘制一系列垂直或水平的柱子来表示数据的大小,每个柱子的长度或高度代表一个数据点的数值。柱状图非常适合于展示分类数据的分布…...

vue3+ts封装类似于微信消息的组件

组件代码如下&#xff1a; <template><div:class"[voice-message, { sent: isSent, received: !isSent }]":style"{ backgroundColor: backgroundColor }"click"togglePlayback"><!-- isSent为false在左侧&#xff0c;为true在右…...

ES6 reduce方法详解:示例、应用场景与实用技巧

在JavaScript中&#xff0c;reduce 方法是一个非常强大的数组方法&#xff0c;它允许你将数组中的元素归并&#xff08;reduce&#xff09;为单个值。reduce 方法执行一个由你提供的reducer函数&#xff08;归并函数&#xff09;&#xff0c;将其结果汇总为单一的返回值。 一.…...

java后端保存的本地图片通过ip+端口直接访问

直接上代码吧 package com.ydx.emms.datapro.controller;import org.springframework.context.annotation.Configuration; import org.springframework.web.servlet.config.annotation.ResourceHandlerRegistry; import org.springframework.web.servlet.config.annotation.…...

2024 年高教社杯全国大学生数学建模竞赛B题4小问解题思路(第二版)

原文链接&#xff1a;https://www.cnblogs.com/qimoxuan/articles/18399415 问题 1&#xff1a;抽样检测方案设计 详细解题思路&#xff1a; 确定抽样检测目标&#xff1a;企业需要确定一个可接受的次品率上限&#xff08;标称值&#xff09;&#xff0c;以及在该次品率下&am…...

docker-nginx数据卷挂载

一、案例1-利用Nginx容器部署静态资源 1.1、需求: 创建Nginx容器&#xff0c; 修改nginx容器内的html目录下的index.html文件,查看变化将静态资源部署到nginx的html目录 1.2、修改html目录下的index.html文件,查看变化 因为docker运用得最小化系统环境&#xff0c;解决办法就…...

项目实战 ---- 商用落地视频搜索系统(8)---优化(2)---查询逻辑层优化

目录 背景 技术衡量与方案 一种可实现方案 可实现方案及设计描述 可能存在的问题 一种创新实现方案 方案的改良设计 策略公式 优化的实现 完整代码 代码解释 异常场景的考量 处理方式 运行注意事项 运行结果 结果优化对比与解释 背景 在项目实战 ---- 商用落地…...

山东大学机试试题合集

&#x1f370;&#x1f370;&#x1f370;高分篇已经涵盖了绝大多数的机试考点&#xff0c;由于临近预推免&#xff0c;各校的机试蜂拥而至&#xff0c;我们接下来先更一些各高校机试题合集&#xff0c;算是对前边学习成果的深入学习&#xff0c;也是对我们代码能力的锻炼。加油…...

餐厅食品留样管理系统小程序的设计

管理员账户功能包括&#xff1a;系统首页&#xff0c;个人中心&#xff0c;窗口负责人管理&#xff0c;窗口员工管理&#xff0c;冰柜管理&#xff0c;排班信息管理&#xff0c;留样食品管理&#xff0c;教育宣传管理&#xff0c;系统管理 微信端账号功能包括&#xff1a;系统…...

亚马逊运营:如何提高亚马逊销量和运营效率?

不少亚马逊卖家们为了扩大业务规模和提高销量&#xff0c;会创建多个卖家账户来同时运营多个亚马逊店铺。问题是&#xff0c;这种多店铺运营模式并非没有风险——亚马逊运营的一个重要方面就是账户的健康管理。一旦某个账户出现问题&#xff0c;亚马逊的算法就可能会启动关联检…...

测试微信模版消息推送

进入“开发接口管理”--“公众平台测试账号”&#xff0c;无需申请公众账号、可在测试账号中体验并测试微信公众平台所有高级接口。 获取access_token: 自定义模版消息&#xff1a; 关注测试号&#xff1a;扫二维码关注测试号。 发送模版消息&#xff1a; import requests da…...

Linux 文件类型,目录与路径,文件与目录管理

文件类型 后面的字符表示文件类型标志 普通文件&#xff1a;-&#xff08;纯文本文件&#xff0c;二进制文件&#xff0c;数据格式文件&#xff09; 如文本文件、图片、程序文件等。 目录文件&#xff1a;d&#xff08;directory&#xff09; 用来存放其他文件或子目录。 设备…...

React Native 导航系统实战(React Navigation)

导航系统实战&#xff08;React Navigation&#xff09; React Navigation 是 React Native 应用中最常用的导航库之一&#xff0c;它提供了多种导航模式&#xff0c;如堆栈导航&#xff08;Stack Navigator&#xff09;、标签导航&#xff08;Tab Navigator&#xff09;和抽屉…...

YSYX学习记录(八)

C语言&#xff0c;练习0&#xff1a; 先创建一个文件夹&#xff0c;我用的是物理机&#xff1a; 安装build-essential 练习1&#xff1a; 我注释掉了 #include <stdio.h> 出现下面错误 在你的文本编辑器中打开ex1文件&#xff0c;随机修改或删除一部分&#xff0c;之后…...

Spring Boot面试题精选汇总

&#x1f91f;致敬读者 &#x1f7e9;感谢阅读&#x1f7e6;笑口常开&#x1f7ea;生日快乐⬛早点睡觉 &#x1f4d8;博主相关 &#x1f7e7;博主信息&#x1f7e8;博客首页&#x1f7eb;专栏推荐&#x1f7e5;活动信息 文章目录 Spring Boot面试题精选汇总⚙️ **一、核心概…...

Android 之 kotlin 语言学习笔记三(Kotlin-Java 互操作)

参考官方文档&#xff1a;https://developer.android.google.cn/kotlin/interop?hlzh-cn 一、Java&#xff08;供 Kotlin 使用&#xff09; 1、不得使用硬关键字 不要使用 Kotlin 的任何硬关键字作为方法的名称 或字段。允许使用 Kotlin 的软关键字、修饰符关键字和特殊标识…...

Reasoning over Uncertain Text by Generative Large Language Models

https://ojs.aaai.org/index.php/AAAI/article/view/34674/36829https://ojs.aaai.org/index.php/AAAI/article/view/34674/36829 1. 概述 文本中的不确定性在许多语境中传达,从日常对话到特定领域的文档(例如医学文档)(Heritage 2013;Landmark、Gulbrandsen 和 Svenevei…...

面向无人机海岸带生态系统监测的语义分割基准数据集

描述&#xff1a;海岸带生态系统的监测是维护生态平衡和可持续发展的重要任务。语义分割技术在遥感影像中的应用为海岸带生态系统的精准监测提供了有效手段。然而&#xff0c;目前该领域仍面临一个挑战&#xff0c;即缺乏公开的专门面向海岸带生态系统的语义分割基准数据集。受…...

JavaScript 数据类型详解

JavaScript 数据类型详解 JavaScript 数据类型分为 原始类型&#xff08;Primitive&#xff09; 和 对象类型&#xff08;Object&#xff09; 两大类&#xff0c;共 8 种&#xff08;ES11&#xff09;&#xff1a; 一、原始类型&#xff08;7种&#xff09; 1. undefined 定…...

【 java 虚拟机知识 第一篇 】

目录 1.内存模型 1.1.JVM内存模型的介绍 1.2.堆和栈的区别 1.3.栈的存储细节 1.4.堆的部分 1.5.程序计数器的作用 1.6.方法区的内容 1.7.字符串池 1.8.引用类型 1.9.内存泄漏与内存溢出 1.10.会出现内存溢出的结构 1.内存模型 1.1.JVM内存模型的介绍 内存模型主要分…...