PyTorch 系列教程:使用CNN实现图像分类
图像分类是计算机视觉领域的一项基本任务,也是深度学习技术的一个常见应用。近年来,卷积神经网络(cnn)和PyTorch库的结合由于其易用性和鲁棒性已经成为执行图像分类的流行选择。
理解卷积神经网络(cnn)
卷积神经网络是一类深度神经网络,对分析视觉图像特别有效。他们利用多层构建一个可以直接从图像中识别模式的模型。这些模型对于图像识别和分类等任务特别有用,因为它们不需要手动提取特征。
cnn的关键组成部分
- 卷积层:这些层对输入应用卷积操作,将结果传递给下一层。每个过滤器(或核)可以捕获不同的特征,如边缘、角或其他模式。
- 池化层:这些层减少了表示的空间大小,以减少参数的数量并加快计算速度。池化层简化了后续层的处理。
- 完全连接层:在这些层中,神经元与前一层的所有激活具有完全连接,就像传统的神经网络一样。它们有助于对前一层识别的对象进行分类。

使用PyTorch进行图像分类
PyTorch是开源的深度学习库,提供了极大的灵活性和多功能性。研究人员和从业人员广泛使用它来轻松有效地实现尖端的机器学习模型。
设置PyTorch
首先,确保在开发环境中安装了PyTorch。你可以通过pip安装它:
pip install torch torchvision
用PyTorch创建简单的CNN示例
下面是如何定义简单的CNN来使用PyTorch对图像进行分类的示例。
import torch
import torch.nn as nn
import torch.nn.functional as F# 定义CNN模型(修复了变量引用问题)
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5) # 第一个卷积层:3输入通道,6输出通道,5x5卷积核self.pool = nn.MaxPool2d(2, 2) # 最大池化层:2x2窗口,步长2self.conv2 = nn.Conv2d(6, 16, 5) # 第二个卷积层:6输入通道,16输出通道,5x5卷积核self.fc1 = nn.Linear(16 * 5 * 5, 120)# 全连接层1:400输入 -> 120输出self.fc2 = nn.Linear(120, 84) # 全连接层2:120输入 -> 84输出self.fc3 = nn.Linear(84, 10) # 输出层:84输入 -> 10类 logitsdef forward(self, x):# 输入形状:[batch_size, 3, 32, 32]x = self.pool(F.relu(self.conv1(x))) # -> [batch, 6, 14, 14](池化后尺寸减半)x = self.pool(F.relu(self.conv2(x))) # -> [batch, 16, 5, 5] x = x.view(-1, 16 * 5 * 5) # 展平为一维向量:16 * 5 * 5=400x = F.relu(self.fc1(x)) # -> [batch, 120]x = F.relu(self.fc2(x)) # -> [batch, 84]x = self.fc3(x) # -> [batch, 10](未应用softmax,配合CrossEntropyLoss使用)return x
这个特殊的网络接受一个输入图像,通过两组卷积和池化层,然后是三个完全连接的层。根据数据集的复杂性和大小调整网络的架构和超参数。
模型定义:
SimpleCNN继承自nn.Module- 使用两个卷积层提取特征,三个全连接层进行分类
- 最终输出未应用 softmax,而是直接输出 logits(与
CrossEntropyLoss配合使用)
训练网络
对于训练,你需要一个数据集。PyTorch通过torchvision包提供了用于数据加载和预处理的实用程序。
import torchvision.transforms as transforms
import torchvision
from torch.utils.data import DataLoader# 初始化模型、损失函数和优化器
net = SimpleCNN() # 实例化模型
criterion = nn.CrossEntropyLoss() # 使用交叉熵损失函数(自动处理softmax)
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, # 学习率momentum=0.9) # 动量参数# 数据预处理和加载
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 加载CIFAR-10训练集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, # 自动下载数据集transform=transform
)trainloader = DataLoader(trainset, batch_size=4, # 每个batch包含4张图像shuffle=True) # 打乱数据顺序
模型配置:
- 损失函数:
CrossEntropyLoss(自动包含 softmax 和 log_softmax) - 优化器:SGD with momentum,学习率 0.001
数据加载:
-
使用
torchvision.datasets.CIFAR10加载数据集 -
batch_size:4(根据 GPU 内存调整,CIFAR-10 建议 batch size ≥ 32)
-
transforms.Compose定义数据预处理流程:ToTensor():将图像转换为 PyTorch TensorNormalize():标准化图像像素值到 [-1, 1]
加载数据后,训练过程包括通过数据集进行多次迭代,使用反向传播和合适的损失函数:
# 训练循环
for epoch in range(2): # 进行2个epoch的训练running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = data# 前向传播outputs = net(inputs)loss = criterion(outputs, labels)# 反向传播和优化optimizer.zero_grad() # 清空梯度loss.backward() # 计算梯度optimizer.step() # 更新参数running_loss += loss.item()# 每2000个batch打印一次if i % 2000 == 1999:avg_loss = running_loss / 2000print(f'Epoch [{epoch+1}/{2}], Batch [{i+1}/2000], Loss: {avg_loss:.3f}')running_loss = 0.0print("训练完成!")
训练循环:
- epoch:完整遍历数据集一次
- batch:数据加载器中的一个批次
- 梯度清零:每次反向传播前需要清空梯度
- 损失计算:
outputs的形状为[batch_size, 10],labels为整数标签
完整代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision
from torch.utils.data import DataLoader# 定义CNN模型(修复了变量引用问题)
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5) # 第一个卷积层:3输入通道,6输出通道,5x5卷积核self.pool = nn.MaxPool2d(2, 2) # 最大池化层:2x2窗口,步长2self.conv2 = nn.Conv2d(6, 16, 5) # 第二个卷积层:6输入通道,16输出通道,5x5卷积核self.fc1 = nn.Linear(16 * 5 * 5, 120)# 全连接层1:400输入 -> 120输出self.fc2 = nn.Linear(120, 84) # 全连接层2:120输入 -> 84输出self.fc3 = nn.Linear(84, 10) # 输出层:84输入 -> 10类 logitsdef forward(self, x):# 输入形状:[batch_size, 3, 32, 32]x = self.pool(F.relu(self.conv1(x))) # -> [batch, 6, 14, 14](池化后尺寸减半)x = self.pool(F.relu(self.conv2(x))) # -> [batch, 16, 5, 5] x = x.view(-1, 16 * 5 * 5) # 展平为一维向量:16 * 5 * 5=400x = F.relu(self.fc1(x)) # -> [batch, 120]x = F.relu(self.fc2(x)) # -> [batch, 84]x = self.fc3(x) # -> [batch, 10](未应用softmax,配合CrossEntropyLoss使用)return x# 初始化模型、损失函数和优化器
net = SimpleCNN() # 实例化模型
criterion = nn.CrossEntropyLoss() # 使用交叉熵损失函数(自动处理softmax)
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, # 学习率momentum=0.9) # 动量参数# 数据预处理和加载
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# 加载CIFAR-10训练集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, # 自动下载数据集transform=transform
)
trainloader = DataLoader(trainset, batch_size=4, # 每个batch包含4张图像shuffle=True) # 打乱数据顺序# 训练循环
for epoch in range(2): # 进行2个epoch的训练running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = data# 前向传播outputs = net(inputs)loss = criterion(outputs, labels)# 反向传播和优化optimizer.zero_grad() # 清空梯度loss.backward() # 计算梯度optimizer.step() # 更新参数running_loss += loss.item()# 每2000个batch打印一次if i % 2000 == 1999:avg_loss = running_loss / 2000print(f'Epoch [{epoch+1}/{2}], Batch [{i+1}/2000], Loss: {avg_loss:.3f}')running_loss = 0.0print("训练完成!")
最后总结
通过PyTorch和卷积神经网络,你可以有效地处理图像分类任务。借助PyTorch的灵活性,可以根据特定的数据集和应用程序构建、训练和微调模型。示例代码仅为理论过程,实际项目中还有大量优化空间。
相关文章:
PyTorch 系列教程:使用CNN实现图像分类
图像分类是计算机视觉领域的一项基本任务,也是深度学习技术的一个常见应用。近年来,卷积神经网络(cnn)和PyTorch库的结合由于其易用性和鲁棒性已经成为执行图像分类的流行选择。 理解卷积神经网络(cnn) 卷…...
Docker下ARM64架构的源码编译Qt5.15.1,并移植到开发板上
Docker下ARM64架构的源码编译Qt5.15.1,并移植到开发板上 1、环境介绍 QT版本:5.15.1 待移植环境: jetson nano 系列开发板 aarch64架构(arm64) 编译环境: 虚拟机Ubuntu18.04(x86_64) 2、…...
Java 大视界 -- Java 大数据中的数据可视化大屏设计与开发实战(127)
💖亲爱的朋友们,热烈欢迎来到 青云交的博客!能与诸位在此相逢,我倍感荣幸。在这飞速更迭的时代,我们都渴望一方心灵净土,而 我的博客 正是这样温暖的所在。这里为你呈上趣味与实用兼具的知识,也…...
starrocks批量启停脚本
#!/bin/bash # 定义 StarRocks 安装目录 STARROCKS_HOME"/path/to/starrocks" # 定义 FE 和 BE 节点列表 FE_NODES("fe_node1_ip" "fe_node2_ip" "fe_node3_ip") BE_NODES("be_node1_ip" "be_node2_ip" "be_…...
「Unity3D」UGUI将元素固定在,距离屏幕边缘的某个比例,以及保持元素自身比例
在不同分辨率的屏幕下,UI元素按照自身像素大小,会发生位置与比例的变化,本文仅利用锚点(Anchors)使用,来实现UI元素,固定在某个比例距离的屏幕边缘。 首先,将元素的锚点设置为中心&…...
4.3 数组和集合的初始及赋值
版权声明:本文为博主原创文章,转载请在显著位置标明本文出处以及作者网名,未经作者允许不得用于商业目的 版权声明:本文为博主原创文章,转载请在显著位置标明本文出处以及作者网名,未经作者允许不得用于商…...
Deep research深度研究:ChatGPT/ Gemini/ Perplexity/ Grok哪家最强?(实测对比分析)
目前推出深度研究和深度检索的AI大模型有四家: OpenAI和Gemini 的deep research,以及Perplexity 和Grok的deep search,都能生成带参考文献引用的主题报告。 致力于“几分钟之内生成一份完整的主题调研报告,解决人力几小时甚至几天…...
关于sqlalchemy的ORM的使用
关于sqlalchemy的ORM的使用 二、创建表三、使用数据表、查询记录三、批量插入数据四、关于with...as...:的使用 二、创建表 使用Mapped来映射字段 from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker,Mapped,mapped_columnBa…...
【leetcode hot 100 148】排序序列
解法一:(双重循环)第一个循环head,逐步将head的node加入有序列表;第二个循环在有序列表中找到合适的位置,插入node。 /*** Definition for singly-linked list.* public class ListNode {* int val;* …...
3-001:MySQL 中的回表是什么?
1. 什么是回表? 回表(Back to Table) 指的是 在使用非聚簇索引(辅助索引)查询时,MySQL 需要 先通过索引找到主键 ID,然后再回到主键索引(聚簇索引)查询完整数据…...
单片机设计暖脚器研究
标题:单片机设计暖脚器研究 内容:1.摘要 本文聚焦于基于单片机设计暖脚器的研究。背景方面,在寒冷季节,暖脚器能有效改善脚部寒冷状况,提升人们的舒适度,但传统暖脚器存在功能单一、温控不准确等问题。目的是设计一款智能、高效且…...
【Linux】在VMWare中安装Ubuntu操作系统(2025最新_Ubuntu 24.04.2)#VMware安装Ubuntu实战分享#
今天田辛老师为大家带来一篇关于在VMWare虚拟机上安装Ubuntu系统的详细教程。无论是学习、开发还是测试,虚拟机都是一个非常实用的工具,它允许我们在同一台物理机上运行多个操作系统。Ubuntu作为一款开源、免费且用户友好的Linux发行版,深受广…...
AutoGen学习笔记系列(十三)Advanced - Logging
这篇文章瞄的是AutoGen官方教学文档 Advanced 章节中的 Logging 篇章,介绍了怎样在使用过程中添加日志信息,其实就是使用了python自带的日志库 logging。 官网链接:https://microsoft.github.io/autogen/stable/user-guide/agentchat-user-g…...
scrcpy pc机远程 无线 控制android app 查看调试log
背景: 公司的安卓机,是那种大屏幕的连接usb外设的。不好挪动,占地方,不能直接连接pc机上的android stduio来调试。 所以从网上找了一个python adb.exe控制器,可以局域网内远程控制开发的app,并在android stduio上看…...
UE5.5 Niagara发射器更新属性
发射器属性 在 Niagara 里,Emitter 负责控制粒子生成的规则和行为。不同的 Emitter 属性决定了如何发射粒子、粒子如何模拟、计算方式等。 发射器 本地空间(Local Space) 控制粒子是否跟随发射器(Emitter)移动。 ✅…...
深度剖析Redis:双写一致性问题及解决方案全景解析
在高并发场景下,缓存与数据库的双写一致性是每个开发者必须直面的核心挑战。本文通过5大解决方案,带你彻底攻克这一技术难关! 一、问题全景图:当缓存遇到数据库 1.1 典型问题场景 // 典型问题代码示例 public void updateProduc…...
MongoDB备份与还原
备份恢复工具介绍 1)mongoexport/mongoimport 2)mongodump/mongorestore 备份工具区别 mongoexport/mongoimport 导入/导出的是JSON格式或者CSV格式 mongodump/mongorestore 导入/导出的是BSON格式。二进制方式,速度快 1)…...
计算机:基于深度学习的Web应用安全漏洞检测与扫描
目录 前言 课题背景和意义 实现技术思路 一、算法理论基础 1.1 网络爬虫 1.2 漏洞检测 二、 数据集 三、实验及结果分析 3.1 实验环境搭建 3.2 模型训练 最后 前言 📅大四是整个大学期间最忙碌的时光,一边要忙着备考或实习为毕业后面临的就业升学做准备,…...
postgresql14编译安装脚本
#!/bin/bash####################################readme################################### #先上传postgresql源码包,再配置yum源,然后执行脚本 #备份官方yum源配置文件: #cp /etc/yum.repos.d/CentOS-Base.repo /etc/yum.repos.d/CentOS…...
Java 大视界 -- Java 大数据在智能安防视频摘要与检索技术中的应用(128)
💖亲爱的朋友们,热烈欢迎来到 青云交的博客!能与诸位在此相逢,我倍感荣幸。在这飞速更迭的时代,我们都渴望一方心灵净土,而 我的博客 正是这样温暖的所在。这里为你呈上趣味与实用兼具的知识,也…...
部署项目至服务器:响应时间太长,无法访问此页面?
在我们部署项目到服务器上的时候,一顿操作猛如虎,打开页面..... 这里记录一下这种情况是怎么回事。一般就是服务器上的安全组没有放行端口。 因为我是用宝塔进行项目部署的。所以遇到这种情况,要去操作两边(宝塔and服务器所属平台…...
如何搭建一个适配微信小程序,h5,app的uni-app项目
在vscode搭建 uni-app 项目(Vue 3 Vite Pinia uView Plus) 一、环境准备 1. 安装 Node.js 确保已安装 Node.js(需≥14版本),可通过以下命令检查版本: node -v2. 安装 VSCode 从 VSCode 官网 下载并…...
【数据结构】List介绍
目录 1. 什么是List 2. 常见接口介绍 3. List的使用 1. 什么是List 在集合框架中,List是一个接口,继承自Collection。此时extends意为拓展 Collection也是一个接口,该接口中规范了后序容器中常用的一些方法,具体如下所示&…...
vs2022用git插件重置--删除更改(--hard)后恢复删除的内容
1、先到项目工程中打开需要恢复的分支。 2、进入代码管理根目录文件夹。 3、在根目录文件夹点右键,点git bash here 正常情况下如果git目录权限足够,是可以如上图所示显示当前分支和当前目录的。 在git权限不足的情况下会出现如下提示: …...
【C++】【数据结构】链表与线性表
线性表和链表优缺点及适用场景 线性表(以数组为例) 优点:随机访问效率高,可通过下标直接访问元素,时间复杂度为 O (1);存储密度大,内存连续存储,空间利用率高。缺点:插入…...
vscode接入DeepSeek 免费送2000 万 Tokens 解决DeepSeek无法充值问题
1. 在vscode中安装插件 Cline 2.打开硅基流动官网 3. 注册并登陆,邀请码 WpcqcXMs 4.登录后新建秘钥 5. 在vscode中配置cline (1) API Provider 选择 OpenAI Compatible ; (2) Base URL设置为 https://api.siliconflow.cn](https://api.siliconfl…...
【MySQL】用户管理和权限
欢迎拜访:雾里看山-CSDN博客 本篇主题:【MySQL】用户管理和权限 发布时间:2025.3.12 隶属专栏:MySQL 目录 引言用户用户信息创建用户语法案例 修改用户密码语法案例 删除用户语法案例 权限权限列表查看和刷新用户的权限给用户授权…...
3ds Max 快捷键分类指南(按功能划分)
以下整理了 3ds Max 常用快捷键,按核心功能模块分类,适用于 建模、动画、渲染 等全流程操作。 一、视图操作 快捷键功能Alt W最大化当前视图G隐藏/显示栅格F3线框/实体显示切换F4显示边面(实体线框)Z聚焦选中对象到视图中心Ctrl…...
npm、pnpm、cnpm、yarn、npx之间的区别
文章目录 区别特点pnpmyarncnpm 关键解读如何选择代码示例安装依赖运行命令 区别 特性npmyarnpnpmcnpmnpx核心定位Node.js 默认包管理增强稳定性与性能高效存储与严格隔离国内镜像加速工具临时执行包命令依赖存储方式扁平化 node_modules扁平化 lock 文件全局硬链接 符号链接…...
指令微调 (Instruction Tuning) 与 Prompt 工程
引言 预训练语言模型 (PLMs) 在通用语言能力方面展现出强大的潜力。然而,如何有效地引导 PLMs 遵循人类指令, 并输出符合人类意图的响应, 成为释放 PLMs 价值的关键挑战。 指令微调 (Instruction Tuning) 和 Prompt 工程 (Prompt Engineerin…...
