使用 PyTorch 实现标准卷积神经网络(CNN)
卷积神经网络(CNN)是深度学习中的重要组成部分,广泛应用于图像处理、语音识别、视频分析等任务。在这篇博客中,我们将使用 PyTorch 实现一个标准的卷积神经网络(CNN),并介绍各个部分的作用。
什么是卷积神经网络(CNN)?
卷积神经网络(CNN)是一种专门用于处理图像数据的深度学习模型,它通过卷积层提取图像的特征。CNN 由多个层次组成,其中包括卷积层(Conv2d)、池化层(MaxPool2d)、全连接层(Linear)、激活函数(ReLU)等。这些层级合作,使得模型能够从原始图像中自动学习到重要特征。
CNN 的核心组成部分
- 卷积层(Conv2d):用于提取输入图像的局部特征,通过多个卷积核对图像进行卷积运算。
- 激活函数(ReLU):增加非线性,使得模型能够学习更复杂的特征。
- 池化层(MaxPool2d):通过对特征图进行下采样来减少空间尺寸,降低计算复杂度,同时保留重要的特征。
- 全连接层(Linear):将卷积和池化后得到的特征图展平,送入全连接层进行分类或回归预测。
PyTorch 实现 CNN
下面是我们实现的标准卷积神经网络模型。它包含三个卷积层和两个全连接层,适用于图像分类任务,如 MNIST 数据集。
代码实现
import torch
import torch.nn as nn
import torch.nn.functional as Fclass CNN(nn.Module):def __init__(self):super(CNN, self).__init__()# 卷积层1: 输入1个通道(灰度图像),输出32个通道self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1)# 卷积层2: 输入32个通道,输出64个通道self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)# 卷积层3: 输入64个通道,输出128个通道self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)# 全连接层1: 输入128*7*7,输出1024个节点self.fc1 = nn.Linear(128 * 7 * 7, 1024)# 全连接层2: 输入1024个节点,输出10个节点(假设是10分类问题)self.fc2 = nn.Linear(1024, 10)# Dropout层: 避免过拟合self.dropout = nn.Dropout(0.5)def forward(self, x):# 第一层卷积 + ReLU 激活 + 最大池化x = F.relu(self.conv1(x))x = F.max_pool2d(x, 2, 2) # 使用2x2的最大池化# 第二层卷积 + ReLU 激活 + 最大池化x = F.relu(self.conv2(x))x = F.max_pool2d(x, 2, 2)# 第三层卷积 + ReLU 激活 + 最大池化x = F.relu(self.conv3(x))x = F.max_pool2d(x, 2, 2)# 展平层(将卷积后的特征图展平成1D向量)x = x.view(-1, 128 * 7 * 7) # -1代表自动推算batch size# 第一个全连接层 + ReLU 激活 + Dropoutx = F.relu(self.fc1(x))x = self.dropout(x)# 第二个全连接层(输出最终分类结果)x = self.fc2(x)return x# 创建CNN模型
model = CNN()# 打印模型架构
print(model)
代码解析
-
卷积层(Conv2d):
self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1):该层的输入为 1 个通道(灰度图像),输出 32 个通道,卷积核大小为 3x3,步幅为 1,填充为 1,保持输出特征图的大小与输入相同。- 后续的卷积层类似,只是输出通道数量逐渐增多。
-
激活函数(ReLU):
F.relu(self.conv1(x)):ReLU 激活函数将输入的负值转为 0,并保留正值,增加了模型的非线性。
-
池化层(MaxPool2d):
F.max_pool2d(x, 2, 2):使用 2x2 的池化窗口和步幅为 2 进行池化,将特征图尺寸缩小一半,减少计算复杂度。
-
展平(Flatten):
x = x.view(-1, 128 * 7 * 7):在经过卷积和池化操作后,我们将多维的特征图展平成一维向量,供全连接层输入。
-
Dropout:
self.dropout = nn.Dropout(0.5):Dropout 正则化技术在训练时随机丢弃一些神经元,防止过拟合。
-
全连接层(Linear):
self.fc1 = nn.Linear(128 * 7 * 7, 1024):第一个全连接层的输入是卷积后得到的特征,输出 1024 个节点。self.fc2 = nn.Linear(1024, 10):最后的全连接层将 1024 个节点压缩为 10 个输出,代表分类结果。
训练 CNN 模型
要训练该模型,我们需要加载一个数据集、定义损失函数和优化器,然后进行训练。以下是如何使用 MNIST 数据集进行训练的示例。
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms# 加载 MNIST 数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练循环
num_epochs = 5
for epoch in range(num_epochs):model.train()running_loss = 0.0for batch_idx, (data, target) in enumerate(train_loader):optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()running_loss += loss.item()if batch_idx % 100 == 99: # 每100个batch输出一次损失print(f'Epoch {epoch+1}, Batch {batch_idx+1}, Loss: {running_loss / 100:.4f}')running_loss = 0.0print("Finished Training")
训练过程说明
- 数据加载器(DataLoader):用于批量加载训练数据,支持数据的随机打乱(shuffle)。
- 损失函数(CrossEntropyLoss):用于多分类问题,计算预测和真实标签之间的交叉熵损失。
- 优化器(Adam):Adam 优化器自适应调整学习率,通常在深度学习中表现良好。
- 训练循环:每个 epoch 处理整个数据集,通过前向传播、计算损失、反向传播和优化步骤,更新网络参数。
总结
在这篇文章中,我们实现了一个标准的卷积神经网络(CNN),并使用 PyTorch 对其进行了定义和训练。通过使用卷积层、池化层和全连接层,模型能够自动学习图像的特征并进行分类。我们还介绍了如何训练模型、加载数据集以及使用常见的优化器和损失函数。希望这篇文章能帮助你理解 CNN 的基本架构及其实现方式!
相关文章:
使用 PyTorch 实现标准卷积神经网络(CNN)
卷积神经网络(CNN)是深度学习中的重要组成部分,广泛应用于图像处理、语音识别、视频分析等任务。在这篇博客中,我们将使用 PyTorch 实现一个标准的卷积神经网络(CNN),并介绍各个部分的作用。 什…...
开题报告——基于Spring Boot的垃圾分类预约回收系统
关于本科毕业设计(论文)开题报告的规定 为切实做好本科毕业设计(论文)的开题报告工作,保证论文质量,特作如下规定: 一、开题报告是本科毕业设计(论文)的必经过程,所有本科生在写作毕业设计(论文)之前都必须作开题报告。 二、开题报告主要检验学生对专业知识的驾驭能…...
YOLOv5 目标检测优化:降低误检与漏检
1. 引言 在目标检测任务中,误检(False Positive, FP)和漏检(False Negative, FN)是影响检测性能的两个主要问题。误检意味着模型检测到了不存在的目标,而漏检则指模型未能检测到真实存在的目标。本文将介绍…...
网络安全治理模型
0x02 知识点 安全的目标是提供 可用性 Avialability机密性 confidentiality完整性 Integrity真实性 Authenticity不可否认性 Nonrepudiation 安全治理是一个提供监督、问责和合规性的框架 信息安全系统 Information Security Management System ISMS 策略,工作程…...
网络原理-
文章目录 协议应用层传输层网络层 数据链路层 协议 在网络通信中,协议是非常重要的概念.协议就是一种约定. 在网络通信过程中,对协议进行了分层 接下来就按照顺序向大家介绍每一种核心的协议. 应用层 应用层是咱们程序员打交道最多的一层协议.应用层里有很多现成的协议,但…...
HTML/CSS中交集选择器
1.作用:选中同时符合多个条件的元素 交集就是或的意思 2.语法:选择器1选择器2选择器3......选择器n{} 3.举例: /* 选中:类名为beauty的p元素,此种写法用的非常的多 */p.beauty{color: red;}/* 选中:类名包含rich和beauty的元素 */.rich.beauty{color: blue;} 4.注意: 1.有标签…...
机器学习(1)安装Pytorch
1.安装命令 pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 2.安装过程Log: Looking in indexes: https://download.pytorch.org/whl/cu118 Co…...
Spring Boot过滤器链:从入门到精通
文章目录 一、过滤器链是什么?二、为什么需要过滤器链?三、Spring Boot中的过滤器链是如何工作的?(一)过滤器的生命周期(二)过滤器链的执行流程 四、如何在Spring Boot中定义自己的过滤器&#…...
vue3之echarts3D圆柱
vue3之echarts3D圆柱 效果: 版本 "echarts": "^5.1.2" 核心代码: <template><div ref"charts" class"charts"></div><svg><linearGradient id"labColor" x1"0&q…...
Redux中间件redux-thunk和redux-saga的具体区别是什么?
Redux 中间件是增强 Redux 功能的重要工具,redux-thunk 和 redux-saga 是两个常用的中间件,它们在处理异步操作和副作用时提供了不同的方式和理念。以下是两者的具体区别: 1. 概念与设计理念 redux-thunk 简洁:redux-thunk 是一…...
代码随想录算法训练营第四十三天| 动态规划06
322. 零钱兑换 如果求组合数就是外层for循环遍历物品,内层for遍历背包。 如果求排列数就是外层for遍历背包,内层for循环遍历物品。 这句话结合本题 大家要好好理解。 视频讲解:动态规划之完全背包,装满背包最少的物品件数是多少&…...
UI自动化教程 —— 元素定位技巧:精确找到你需要的页面元素
引言 在UI自动化测试中,准确地定位页面元素是至关重要的。无论是点击按钮、填写表单还是验证页面内容,都需要首先找到相应的页面元素。Playwright 提供了多种方法来实现这一点,包括使用CSS选择器和XPath进行元素定位,以及利用文本…...
MySQL六大日志的功能介绍。
前言 首先,MySQL的日志应该包括二进制日志(Binary Log)、错误日志(Error Log)、查询日志(General Query Log)、慢查询日志(Slow Query Log)、重做日志(Redo …...
二级指针略解【C语言】
以int** a为例 1.二级指针的声明 a 是一个指向 int*(指向整型的指针)的指针,即二级指针。 通俗的讲,a是一个指向指针的指针,对a解引用会是一个指针。 它可以用于操作动态分配的二维数组、指针数组或需要间接修改指针…...
鸿蒙状态管理概述
状态管理 状态管理之v1LocalStorageLocalStorageLink的框架行为LocalStorageProp的框架行为LocalStorage使用场景 AppStorageStorageLink的框架行为StorageProp的框架行为AppStorage的使用场景 PersistentStorageEnvironmentEnvironment内置参数 WatchWatch的使用场景 $$语法$$…...
【核心算法篇十三】《DeepSeek自监督学习:图像补全预训练方案》
引言:为什么自监督学习成为AI新宠? 在传统监督学习需要海量标注数据的困境下,自监督学习(Self-Supervised Learning)凭借无需人工标注的特性异军突起。想象一下,如果AI能像人类一样通过观察世界自我学习——这正是DeepSeek图像补全方案的技术哲学。根据,自监督学习通过…...
由浅入深学习大语言模型RLHF(PPO强化学习- v1浅浅的)
最近,随着DeepSeek的爆火,GRPO也走进了视野中。为了更好的学习GRPO,需要对PPO的强化学习有一个深入的理解,那么写一篇文章加深理解吧。纵观网上的文章,要么说PPO原理,各种复杂的公式看了就晕,要…...
网络安全三件套
一、在线安全的四个误解 Internet实际上是个有来有往的世界,你可以很轻松地连接到你喜爱的站点,而其他人,例如黑客也很方便地连接到你的机器。实际上,很多机器都因为自己很糟糕的在线安全设置无意间在机器和系统中留下了“…...
瑞芯微RV1126部署YOLOv8全流程:环境搭建、pt-onnx-rknn模型转换、C++推理代码、错误解决、优化、交叉编译第三方库
目录 1 环境搭建 2 交叉编译opencv 3 模型训练 4 模型转换 4.1 pt模型转onnx模型 4.2 onnx模型转rknn模型 4.2.1 安装rknn-toolkit 4.2.2 onn转成rknn模型 5 升级npu驱动 6 C++推理源码demo 6.1 原版demo 6.2 增加opencv读取图片的代码 7 交叉编译x264 ffmepg和op…...
【ISO 14229-1:2023 UDS诊断(会话控制0x10服务)测试用例CAPL代码全解析⑤】
ISO 14229-1:2023 UDS诊断【会话控制0x10服务】_TestCase05 作者:车端域控测试工程师 更新日期:2025年02月15日 关键词:UDS诊断、0x10服务、诊断会话控制、ECU测试、ISO 14229-1:2023 TC10-005测试用例 用例ID测试场景验证要点参考条款预期…...
HTML 语义化
目录 HTML 语义化HTML5 新特性HTML 语义化的好处语义化标签的使用场景最佳实践 HTML 语义化 HTML5 新特性 标准答案: 语义化标签: <header>:页头<nav>:导航<main>:主要内容<article>&#x…...
Lombok 的 @Data 注解失效,未生成 getter/setter 方法引发的HTTP 406 错误
HTTP 状态码 406 (Not Acceptable) 和 500 (Internal Server Error) 是两类完全不同的错误,它们的含义、原因和解决方法都有显著区别。以下是详细对比: 1. HTTP 406 (Not Acceptable) 含义: 客户端请求的内容类型与服务器支持的内容类型不匹…...
五年级数学知识边界总结思考-下册
目录 一、背景二、过程1.观察物体小学五年级下册“观察物体”知识点详解:由来、作用与意义**一、知识点核心内容****二、知识点的由来:从生活实践到数学抽象****三、知识的作用:解决实际问题的工具****四、学习的意义:培养核心素养…...
第一篇:Agent2Agent (A2A) 协议——协作式人工智能的黎明
AI 领域的快速发展正在催生一个新时代,智能代理(agents)不再是孤立的个体,而是能够像一个数字团队一样协作。然而,当前 AI 生态系统的碎片化阻碍了这一愿景的实现,导致了“AI 巴别塔问题”——不同代理之间…...
OPenCV CUDA模块图像处理-----对图像执行 均值漂移滤波(Mean Shift Filtering)函数meanShiftFiltering()
操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 在 GPU 上对图像执行 均值漂移滤波(Mean Shift Filtering),用于图像分割或平滑处理。 该函数将输入图像中的…...
【VLNs篇】07:NavRL—在动态环境中学习安全飞行
项目内容论文标题NavRL: 在动态环境中学习安全飞行 (NavRL: Learning Safe Flight in Dynamic Environments)核心问题解决无人机在包含静态和动态障碍物的复杂环境中进行安全、高效自主导航的挑战,克服传统方法和现有强化学习方法的局限性。核心算法基于近端策略优化…...
处理vxe-table 表尾数据是单独一个接口,表格tableData数据更新后,需要点击两下,表尾才是正确的
修改bug思路: 分别把 tabledata 和 表尾相关数据 console.log() 发现 更新数据先后顺序不对 settimeout延迟查询表格接口 ——测试可行 升级↑:async await 等接口返回后再开始下一个接口查询 ________________________________________________________…...
TSN交换机正在重构工业网络,PROFINET和EtherCAT会被取代吗?
在工业自动化持续演进的今天,通信网络的角色正变得愈发关键。 2025年6月6日,为期三天的华南国际工业博览会在深圳国际会展中心(宝安)圆满落幕。作为国内工业通信领域的技术型企业,光路科技(Fiberroad&…...
给网站添加live2d看板娘
给网站添加live2d看板娘 参考文献: stevenjoezhang/live2d-widget: 把萌萌哒的看板娘抱回家 (ノ≧∇≦)ノ | Live2D widget for web platformEikanya/Live2d-model: Live2d model collectionzenghongtu/live2d-model-assets 前言 网站环境如下,文章也主…...
[USACO23FEB] Bakery S
题目描述 Bessie 开了一家面包店! 在她的面包店里,Bessie 有一个烤箱,可以在 t C t_C tC 的时间内生产一块饼干或在 t M t_M tM 单位时间内生产一块松糕。 ( 1 ≤ t C , t M ≤ 10 9 ) (1 \le t_C,t_M \le 10^9) (1≤tC,tM≤109)。由于空间…...
