如何通过卷积神经网络(CNN)有效地提取图像的局部特征,并在CIFAR-10数据集上实现高精度的分类?
目录
1. CNN 提取图像局部特征的原理
2. 在 CIFAR - 10 数据集上实现高精度分类的步骤
2.1 数据准备
2.2 构建 CNN 模型
2.3 定义损失函数和优化器
2.4 训练模型
2.5 测试模型
3. 提高分类精度的技巧
卷积神经网络(Convolutional Neural Network, CNN)是专门为处理具有网格结构数据(如图像)而设计的深度学习模型,能够有效地提取图像的局部特征。下面将详细介绍如何通过 CNN 提取图像局部特征,并在 CIFAR - 10 数据集上实现高精度分类,同时给出基于 PyTorch 的示例代码。
1. CNN 提取图像局部特征的原理
- 卷积层:卷积层是 CNN 的核心组件,它通过使用多个卷积核(滤波器)在图像上滑动进行卷积操作。每个卷积核可以看作是一个小的矩阵,用于检测图像中的特定局部特征,如边缘、纹理等。卷积操作会生成一个特征图,特征图上的每个元素表示卷积核在对应位置检测到的特征强度。
- 局部连接:CNN 中的神经元只与输入图像的局部区域相连,而不是像全连接网络那样与所有输入神经元相连。这种局部连接方式使得网络能够专注于提取图像的局部特征,减少了参数数量,提高了计算效率。
- 权值共享:在卷积层中,同一个卷积核在整个图像上共享一组权重。这意味着卷积核在不同位置检测到的特征是相同的,进一步减少了参数数量,同时增强了网络对平移不变性的学习能力。
- 池化层:池化层通常紧跟在卷积层之后,用于对特征图进行下采样,减少特征图的尺寸,降低计算量,同时增强特征的鲁棒性。常见的池化操作有最大池化和平均池化。
2. 在 CIFAR - 10 数据集上实现高精度分类的步骤
2.1 数据准备
CIFAR - 10 数据集包含 10 个不同类别的 60000 张 32x32 彩色图像,其中训练集 50000 张,测试集 10000 张。可以使用 PyTorch 的torchvision库来加载和预处理数据。
import torch
import torchvision
import torchvision.transforms as transforms# 定义数据预处理步骤
transform = transforms.Compose([transforms.RandomCrop(32, padding=4), # 随机裁剪transforms.RandomHorizontalFlip(), # 随机水平翻转transforms.ToTensor(), # 转换为张量transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 归一化
])# 加载训练集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,shuffle=True, num_workers=2)# 加载测试集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,shuffle=False, num_workers=2)classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')
2.2 构建 CNN 模型
可以构建一个简单的 CNN 模型,包含卷积层、池化层和全连接层。
import torch.nn as nn
import torch.nn.functional as Fclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)self.pool = nn.MaxPool2d(2, 2)self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)self.fc1 = nn.Linear(128 * 8 * 8, 512)self.fc2 = nn.Linear(512, 10)def forward(self, x):x = F.relu(self.conv1(x))x = F.relu(self.conv2(x))x = self.pool(x)x = F.relu(self.conv3(x))x = F.relu(self.conv4(x))x = self.pool(x)x = x.view(-1, 128 * 8 * 8)x = F.relu(self.fc1(x))x = self.fc2(x)return xnet = Net()
2.3 定义损失函数和优化器
使用交叉熵损失函数和随机梯度下降(SGD)优化器。
import torch.optim as optimcriterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
2.4 训练模型
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)for epoch in range(20): # 训练20个epochrunning_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = data[0].to(device), data[1].to(device)optimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()if i % 200 == 199:print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 200:.3f}')running_loss = 0.0print('Finished Training')
2.5 测试模型
correct = 0
total = 0
with torch.no_grad():for data in testloader:images, labels = data[0].to(device), data[1].to(device)outputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Accuracy of the network on the 10000 test images: {100 * correct / total}%')
3. 提高分类精度的技巧
- 数据增强:通过随机裁剪、翻转、旋转等操作增加训练数据的多样性,提高模型的泛化能力。
- 更深的网络结构:可以使用更复杂的 CNN 架构,如 ResNet、VGG 等,这些网络通过引入残差连接、批量归一化等技术,能够更好地学习图像特征。
- 学习率调整:在训练过程中动态调整学习率,如使用学习率衰减策略,使模型在训练初期快速收敛,后期更精细地调整参数。
- 正则化:使用 L1 或 L2 正则化、Dropout 等技术防止模型过拟合。
通过以上步骤和技巧,可以有效地利用 CNN 提取图像的局部特征,并在 CIFAR - 10 数据集上实现高精度的分类。
相关文章:
如何通过卷积神经网络(CNN)有效地提取图像的局部特征,并在CIFAR-10数据集上实现高精度的分类?
目录 1. CNN 提取图像局部特征的原理 2. 在 CIFAR - 10 数据集上实现高精度分类的步骤 2.1 数据准备 2.2 构建 CNN 模型 2.3 定义损失函数和优化器 2.4 训练模型 2.5 测试模型 3. 提高分类精度的技巧 卷积神经网络(Convolutional Neural Network, CNN&#…...
Redis的持久化-RDBAOF
文章目录 一、 RDB1. 触发机制2. 流程说明3. RDB 文件的处理4. RDB 的优缺点 二、AOF1. 使用 AOF2. 命令写⼊3. 文件同步4. 重写机制5 启动时数据恢复 一、 RDB RDB 持久化是把当前进程数据生成快照保存到硬盘的过程,触发 RDB 持久化过程分为手动触发和自动触发。 …...
Redis 的几个热点知识
前言 Redis 是一款内存级的数据库,凭借其卓越的性能,几乎成为每位开发者的标配工具。 虽然 Redis 包含大量需要掌握的知识,但其中的热点知识并不多。今天,『知行』就和大家分享一些 Redis 中的热点知识。 Redis 数据结构 Redis…...
Rust WebAssembly 入门教程
一、开发环境搭建 1. 基础工具安装 # 安装 Rust curl --proto https --tlsv1.2 -sSf https://sh.rustup.rs | sh# 安装 wasm-pack cargo install wasm-pack# 安装开发服务器 cargo install basic-http-server# 安装文件监听工具 cargo install cargo-watch2. VSCode 插件安装…...
靶场之路-VulnHub-DC-6 nmap提权、kali爆破、shell反连
靶场之路-VulnHub-DC-6 一、信息收集 1、扫描靶机ip 2、指纹扫描 这里扫的我有点懵,这里只有两个端口,感觉是要扫扫目录了 nmap -sS -sV 192.168.122.128 PORT STATE SERVICE VERSION 22/tcp open ssh OpenSSH 7.4p1 Debian 10deb9u6 (protoc…...
机器视觉开发教程——封装Halcon通用模板匹配工具【含免费教程源码】
目录 引言前期准备Step1 设计可序列化的输入输出集合【不支持多线程】Step2 设计程序框架1、抽象层【IProcess】2、父类【HAlgorithm】3、子类【HFindModelTool】 Step3 设计UI结果展示 引言 通过仿照VisionPro软件二次开发Halcon的模板匹配工具,便于在客户端软件中…...
Android 中 ConstrantLayout 与 RelativeLayout 区别
ConstraintLayout 和 RelativeLayout 都是 Android 开发中常用的布局容器,它们都可以用于构建复杂的用户界面,但在功能、性能、使用方式等方面存在一些区别,下面为你详细介绍: 1. 布局原理 RelativeLayout:RelativeL…...
【3DMAX室内设计】2D转3D平面图插件2Dto3D使用方法
【一键筑梦】革新性2Dto3D插件,轻松实现2D平面图向3D空间的华丽蜕变。这款专为3DMAX室内设计师设计的神器,集一键式墙体、门、窗自动生成功能于一身,能够将2D图形无缝转化为3D网格对象(3D平面图、鸟瞰图),一…...
vscode 查看3d
目录 1. vscode-3d-preview obj查看ok 2. vscode-obj-viewer 没找到这个插件: 3. 3D Viewer for Vscode 查看obj失败 1. vscode-3d-preview obj查看ok 可以查看obj 显示过程:开始是绿屏,过了1到2秒,后来就正常看了。 2. vsc…...
自动驾驶---不依赖地图的大模型轨迹预测
1 前言 早期传统自动驾驶方案通常依赖高精地图(HD Map)提供道路结构、车道线、交通规则等信息,可参考博客《自动驾驶---方案从有图迈进无图》,本质上还是存在问题: 数据依赖性高:地图构建成本昂贵…...
perl初试
我手头有一个脚本,用于从blastp序列比对的结果文件中,进行文本处理, 获取序列比对最优的hit记录 #!/usr/bin/perl -w use strict;my ($blast_out) ARGV; my $usage "This script is to get the best hit from blast output file wit…...
VS Code C++ 开发环境配置
VS Code 是当前非常流行的开发工具. 本文讲述如何配置 VS Code 作为 C开发环境. 本文将按照如下步骤来介绍如何配置 VS Code 作为 C开发环境. 安装编译器安装插件配置工作区 第一个步骤的具体操作会因为系统不同或者方案不同而有不同的选择. 环境要求 首先需要立即 VS Code…...
Web Snapshot 网页截图 模块代码详解
本文将详细解析 Web Snapshot 模块的实现原理和关键代码。这个模块主要用于捕获网页完整截图,特别优化了对动态加载内容的处理。 1. 模块概述 snapshot.py 是一个功能完整的网页截图工具,它使用 Selenium 和 Chrome WebDriver 来模拟真实浏览器行为&am…...
Java TCP 通信:实现简单的 Echo 服务器与客户端
TCP(Transmission Control Protocol)是一种面向连接的、可靠的传输层协议。与 UDP 不同,TCP 保证了数据的顺序、可靠性和完整性,适用于需要可靠传输的应用场景,如文件传输、网页浏览等。本文将基于 Java 实现一个简单的…...
Windows 10 下 SIBR Core (i.e. 3DGS SIBR Viewers) 的编译
本文针对在 Windows 10 上从源码编译安装3DGS (3D Gaussian Splatting)的Viewers 即SIBR Core及外部依赖库extlibs(预编译的版本直接在页面https://sibr.gitlabpages.inria.fr/download.html下载) ,参考SIBR 的官方网站…...
JavaWeb-HttpServletRequest请求域接口
文章目录 HttpServletRequest请求域接口HttpServletRequest请求域接口简介关于请求域和应用域的区别 请求域接口中的相关方法获取前端请求参数(getParameter系列方法)存储请求域名参数(Attribute系列方法)获取客户端的相关地址信息获取项目的根路径 关于转发和重定向的细致剖析…...
【C++】switch 语句编译报错:error: jump to case label
/home/share/mcrockit_3588/prj_linux/../source/rkvpss.cpp: In member function ‘virtual u32 CRkVpss::Control(u32, void*, u32)’: /home/share/mcrockit_3588/prj_linux/../source/rkvpss.cpp:242:8: error: jump to case label242 | case emRkComCmd_DBG_SaveInput:|…...
防火墙虚拟系统实验
拓扑图 需求一 安全策略要求: 1、只存在一个公网IP地址,公司内网所有部门都需要借用同一个接口访问外网 2、财务部禁止访问Internet,研发部门只有部分员工可以访问Internet,行政部门全部可以访问互联网 3、为三个部门的虚拟系统分…...
点云滤波方法:特点、作用及使用场景
点云滤波是点云数据预处理的重要步骤,目的是去除噪声点、离群点等异常数据,平滑点云或提取特定频段特征,为后续的特征提取、配准、曲面重建、可视化等高阶应用打下良好基础。以下是点云中几种常见滤波方法的特点、作用及使用场景:…...
Gradle 配置 Lombok 项目并发布到私有 Maven 仓库的完整指南
Gradle 配置 Lombok 项目并发布到私有 Maven 仓库的完整指南 在 Java 项目开发中,使用 Lombok 可以极大地减少样板代码(如 getter/setter 方法、构造器等),提高开发效率。然而,当使用 Gradle 构建工具并将项目发布到私…...
ArcGIS Pro 基于基站数据生成基站扇区地图
在当今数字化的时代,地理信息系统(GIS)在各个领域都发挥着至关重要的作用。 ArcGIS Pro作为一款功能强大的GIS软件,为用户提供了丰富的工具和功能,使得数据处理、地图制作和空间分析变得更加高效和便捷。 本文将为您…...
【Python · Pytorch】Conda介绍 DGL-cuda安装
本文仅涉及DGL库介绍与cuda配置,不包含神经网络及其训练测试。 起因:博主电脑安装了 CUDA 12.4 版本,但DGL疑似没有版本支持该CUDA版本。随即想到可利用Conda创建CUDA12.1版本的虚拟环境。 1. Conda环境 1.1 Conda环境简介 Conda࿱…...
Spring AI:开启Java开发的智能新时代
目录 一、引言二、什么是 Spring AI2.1 Spring AI 的背景2.2 Spring AI 的目标 三、Spring AI 的核心组件3.1 数据处理3.2 模型训练3.3 模型部署3.4 模型监控 四、Spring AI 的核心功能4.1 支持的模型提供商与类型4.2 便携 API 与同步、流式 API 选项4.3 将 AI 模型输出映射到 …...
leetcode:2965. 找出缺失和重复的数字(python3解法)
难度:简单 给你一个下标从 0 开始的二维整数矩阵 grid,大小为 n * n ,其中的值在 [1, n2] 范围内。除了 a 出现 两次,b 缺失 之外,每个整数都 恰好出现一次 。 任务是找出重复的数字a 和缺失的数字 b 。 返回一个下标从…...
Android U 分屏——SystemUI侧处理
WMShell相关的dump命令 手机分屏启动应用后运行命令:adb shell dumpsys activity service SystemUIService WMShell 我们可以找到其中分屏的部分,如下图所示: 分屏的组成 简图 分屏是由上分屏(SideStage)、下分屏(MainStage)以及分割线组…...
面试基础---MySQL 事务隔离级别与 MVCC 深度解析
MySQL 事务隔离级别与 MVCC 深度解析:原理、实践与源码分析 引言 在高并发的互联网应用中,数据库事务的隔离级别是保证数据一致性和并发性能的关键。MySQL 通过多版本并发控制(MVCC)机制实现了不同的事务隔离级别。本文将深入探…...
第十二届蓝桥杯大学A组java省赛答案整理
货物摆放 题目描述 小蓝有一个超大的仓库,可以摆放很多货物。 现在,小蓝有 nn 箱货物要摆放在仓库,每箱货物都是规则的正方体。小蓝规定了长、宽、高三个互相垂直的方向,每箱货物的边都必须严格平行于长、宽、高。 小蓝希望所…...
浅浅初识AI、AI大模型、AGI
前记:这里只是简单了解,后面有时间会专门来扩展和深入。 当前,人工智能(AI)及其细分领域(如AI算法工程师、自然语言处理NLP、通用人工智能AGI)的就业前景呈现高速增长态势,市场需求…...
flink集成tidb cdc
Flink TiDB CDC 详解 1. TiDB CDC 简介 1.1 TiDB CDC 的核心概念 TiDB CDC 是 TiDB 提供的变更数据捕获工具,能够实时捕获 TiDB 集群中的数据变更(如 INSERT、UPDATE、DELETE 操作),并将这些变更以事件流的形式输出。TiDB CDC 的…...
【flutter】TextField输入框工具栏文本为英文解决(不用安装插件版本
输入框长按选项菜单复制、粘贴、剪切、全选部分默认为英文,对于只需要对此部分做中文本地化,不需要考虑其他语言及全局本地化的项目,可以直接自定义一个本地化代理方法进行覆盖,不需要额外下载插件 // 自定义本地化代理 class _C…...
