当前位置: 首页 > news >正文

使用 PyTorch 实现标准卷积神经网络(CNN)

卷积神经网络(CNN)是深度学习中的重要组成部分,广泛应用于图像处理、语音识别、视频分析等任务。在这篇博客中,我们将使用 PyTorch 实现一个标准的卷积神经网络(CNN),并介绍各个部分的作用。

什么是卷积神经网络(CNN)?

卷积神经网络(CNN)是一种专门用于处理图像数据的深度学习模型,它通过卷积层提取图像的特征。CNN 由多个层次组成,其中包括卷积层(Conv2d)、池化层(MaxPool2d)、全连接层(Linear)、激活函数(ReLU)等。这些层级合作,使得模型能够从原始图像中自动学习到重要特征。

CNN 的核心组成部分

  1. 卷积层(Conv2d):用于提取输入图像的局部特征,通过多个卷积核对图像进行卷积运算。
  2. 激活函数(ReLU):增加非线性,使得模型能够学习更复杂的特征。
  3. 池化层(MaxPool2d):通过对特征图进行下采样来减少空间尺寸,降低计算复杂度,同时保留重要的特征。
  4. 全连接层(Linear):将卷积和池化后得到的特征图展平,送入全连接层进行分类或回归预测。

PyTorch 实现 CNN

下面是我们实现的标准卷积神经网络模型。它包含三个卷积层和两个全连接层,适用于图像分类任务,如 MNIST 数据集。

代码实现

import torch
import torch.nn as nn
import torch.nn.functional as Fclass CNN(nn.Module):def __init__(self):super(CNN, self).__init__()# 卷积层1: 输入1个通道(灰度图像),输出32个通道self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1)# 卷积层2: 输入32个通道,输出64个通道self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)# 卷积层3: 输入64个通道,输出128个通道self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)# 全连接层1: 输入128*7*7,输出1024个节点self.fc1 = nn.Linear(128 * 7 * 7, 1024)# 全连接层2: 输入1024个节点,输出10个节点(假设是10分类问题)self.fc2 = nn.Linear(1024, 10)# Dropout层: 避免过拟合self.dropout = nn.Dropout(0.5)def forward(self, x):# 第一层卷积 + ReLU 激活 + 最大池化x = F.relu(self.conv1(x))x = F.max_pool2d(x, 2, 2)  # 使用2x2的最大池化# 第二层卷积 + ReLU 激活 + 最大池化x = F.relu(self.conv2(x))x = F.max_pool2d(x, 2, 2)# 第三层卷积 + ReLU 激活 + 最大池化x = F.relu(self.conv3(x))x = F.max_pool2d(x, 2, 2)# 展平层(将卷积后的特征图展平成1D向量)x = x.view(-1, 128 * 7 * 7)  # -1代表自动推算batch size# 第一个全连接层 + ReLU 激活 + Dropoutx = F.relu(self.fc1(x))x = self.dropout(x)# 第二个全连接层(输出最终分类结果)x = self.fc2(x)return x# 创建CNN模型
model = CNN()# 打印模型架构
print(model)

代码解析

  1. 卷积层(Conv2d)

    • self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1):该层的输入为 1 个通道(灰度图像),输出 32 个通道,卷积核大小为 3x3,步幅为 1,填充为 1,保持输出特征图的大小与输入相同。
    • 后续的卷积层类似,只是输出通道数量逐渐增多。
  2. 激活函数(ReLU)

    • F.relu(self.conv1(x)):ReLU 激活函数将输入的负值转为 0,并保留正值,增加了模型的非线性。
  3. 池化层(MaxPool2d)

    • F.max_pool2d(x, 2, 2):使用 2x2 的池化窗口和步幅为 2 进行池化,将特征图尺寸缩小一半,减少计算复杂度。
  4. 展平(Flatten)

    • x = x.view(-1, 128 * 7 * 7):在经过卷积和池化操作后,我们将多维的特征图展平成一维向量,供全连接层输入。
  5. Dropout

    • self.dropout = nn.Dropout(0.5):Dropout 正则化技术在训练时随机丢弃一些神经元,防止过拟合。
  6. 全连接层(Linear)

    • self.fc1 = nn.Linear(128 * 7 * 7, 1024):第一个全连接层的输入是卷积后得到的特征,输出 1024 个节点。
    • self.fc2 = nn.Linear(1024, 10):最后的全连接层将 1024 个节点压缩为 10 个输出,代表分类结果。

训练 CNN 模型

要训练该模型,我们需要加载一个数据集、定义损失函数和优化器,然后进行训练。以下是如何使用 MNIST 数据集进行训练的示例。

import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms# 加载 MNIST 数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练循环
num_epochs = 5
for epoch in range(num_epochs):model.train()running_loss = 0.0for batch_idx, (data, target) in enumerate(train_loader):optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()running_loss += loss.item()if batch_idx % 100 == 99:  # 每100个batch输出一次损失print(f'Epoch {epoch+1}, Batch {batch_idx+1}, Loss: {running_loss / 100:.4f}')running_loss = 0.0print("Finished Training")

训练过程说明

  • 数据加载器(DataLoader):用于批量加载训练数据,支持数据的随机打乱(shuffle)。
  • 损失函数(CrossEntropyLoss):用于多分类问题,计算预测和真实标签之间的交叉熵损失。
  • 优化器(Adam):Adam 优化器自适应调整学习率,通常在深度学习中表现良好。
  • 训练循环:每个 epoch 处理整个数据集,通过前向传播、计算损失、反向传播和优化步骤,更新网络参数。

总结

在这篇文章中,我们实现了一个标准的卷积神经网络(CNN),并使用 PyTorch 对其进行了定义和训练。通过使用卷积层、池化层和全连接层,模型能够自动学习图像的特征并进行分类。我们还介绍了如何训练模型、加载数据集以及使用常见的优化器和损失函数。希望这篇文章能帮助你理解 CNN 的基本架构及其实现方式!

相关文章:

使用 PyTorch 实现标准卷积神经网络(CNN)

卷积神经网络(CNN)是深度学习中的重要组成部分,广泛应用于图像处理、语音识别、视频分析等任务。在这篇博客中,我们将使用 PyTorch 实现一个标准的卷积神经网络(CNN),并介绍各个部分的作用。 什…...

开题报告——基于Spring Boot的垃圾分类预约回收系统

关于本科毕业设计(论文)开题报告的规定 为切实做好本科毕业设计(论文)的开题报告工作,保证论文质量,特作如下规定: 一、开题报告是本科毕业设计(论文)的必经过程,所有本科生在写作毕业设计(论文)之前都必须作开题报告。 二、开题报告主要检验学生对专业知识的驾驭能…...

YOLOv5 目标检测优化:降低误检与漏检

1. 引言 在目标检测任务中,误检(False Positive, FP)和漏检(False Negative, FN)是影响检测性能的两个主要问题。误检意味着模型检测到了不存在的目标,而漏检则指模型未能检测到真实存在的目标。本文将介绍…...

网络安全治理模型

0x02 知识点 安全的目标是提供 可用性 Avialability机密性 confidentiality完整性 Integrity真实性 Authenticity不可否认性 Nonrepudiation 安全治理是一个提供监督、问责和合规性的框架 信息安全系统 Information Security Management System ISMS 策略,工作程…...

网络原理-

文章目录 协议应用层传输层网络层 数据链路层 协议 在网络通信中,协议是非常重要的概念.协议就是一种约定. 在网络通信过程中,对协议进行了分层 接下来就按照顺序向大家介绍每一种核心的协议. 应用层 应用层是咱们程序员打交道最多的一层协议.应用层里有很多现成的协议,但…...

HTML/CSS中交集选择器

1.作用:选中同时符合多个条件的元素 交集就是或的意思 2.语法:选择器1选择器2选择器3......选择器n{} 3.举例: /* 选中:类名为beauty的p元素,此种写法用的非常的多 */p.beauty{color: red;}/* 选中:类名包含rich和beauty的元素 */.rich.beauty{color: blue;} 4.注意: 1.有标签…...

机器学习(1)安装Pytorch

1.安装命令 pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 2.安装过程Log: Looking in indexes: https://download.pytorch.org/whl/cu118 Co…...

Spring Boot过滤器链:从入门到精通

文章目录 一、过滤器链是什么?二、为什么需要过滤器链?三、Spring Boot中的过滤器链是如何工作的?(一)过滤器的生命周期(二)过滤器链的执行流程 四、如何在Spring Boot中定义自己的过滤器&#…...

vue3之echarts3D圆柱

vue3之echarts3D圆柱 效果&#xff1a; 版本 "echarts": "^5.1.2" 核心代码&#xff1a; <template><div ref"charts" class"charts"></div><svg><linearGradient id"labColor" x1"0&q…...

Redux中间件redux-thunk和redux-saga的具体区别是什么?

Redux 中间件是增强 Redux 功能的重要工具&#xff0c;redux-thunk 和 redux-saga 是两个常用的中间件&#xff0c;它们在处理异步操作和副作用时提供了不同的方式和理念。以下是两者的具体区别&#xff1a; 1. 概念与设计理念 redux-thunk 简洁&#xff1a;redux-thunk 是一…...

代码随想录算法训练营第四十三天| 动态规划06

322. 零钱兑换 如果求组合数就是外层for循环遍历物品&#xff0c;内层for遍历背包。 如果求排列数就是外层for遍历背包&#xff0c;内层for循环遍历物品。 这句话结合本题 大家要好好理解。 视频讲解&#xff1a;动态规划之完全背包&#xff0c;装满背包最少的物品件数是多少&…...

UI自动化教程 —— 元素定位技巧:精确找到你需要的页面元素

引言 在UI自动化测试中&#xff0c;准确地定位页面元素是至关重要的。无论是点击按钮、填写表单还是验证页面内容&#xff0c;都需要首先找到相应的页面元素。Playwright 提供了多种方法来实现这一点&#xff0c;包括使用CSS选择器和XPath进行元素定位&#xff0c;以及利用文本…...

MySQL六大日志的功能介绍。

前言 首先&#xff0c;MySQL的日志应该包括二进制日志&#xff08;Binary Log&#xff09;、错误日志&#xff08;Error Log&#xff09;、查询日志&#xff08;General Query Log&#xff09;、慢查询日志&#xff08;Slow Query Log&#xff09;、重做日志&#xff08;Redo …...

二级指针略解【C语言】

以int** a为例 1.二级指针的声明 a 是一个指向 int*&#xff08;指向整型的指针&#xff09;的指针&#xff0c;即二级指针。 通俗的讲&#xff0c;a是一个指向指针的指针&#xff0c;对a解引用会是一个指针。 它可以用于操作动态分配的二维数组、指针数组或需要间接修改指针…...

鸿蒙状态管理概述

状态管理 状态管理之v1LocalStorageLocalStorageLink的框架行为LocalStorageProp的框架行为LocalStorage使用场景 AppStorageStorageLink的框架行为StorageProp的框架行为AppStorage的使用场景 PersistentStorageEnvironmentEnvironment内置参数 WatchWatch的使用场景 $$语法$$…...

【核心算法篇十三】《DeepSeek自监督学习:图像补全预训练方案》

引言:为什么自监督学习成为AI新宠? 在传统监督学习需要海量标注数据的困境下,自监督学习(Self-Supervised Learning)凭借无需人工标注的特性异军突起。想象一下,如果AI能像人类一样通过观察世界自我学习——这正是DeepSeek图像补全方案的技术哲学。根据,自监督学习通过…...

由浅入深学习大语言模型RLHF(PPO强化学习- v1浅浅的)

最近&#xff0c;随着DeepSeek的爆火&#xff0c;GRPO也走进了视野中。为了更好的学习GRPO&#xff0c;需要对PPO的强化学习有一个深入的理解&#xff0c;那么写一篇文章加深理解吧。纵观网上的文章&#xff0c;要么说PPO原理&#xff0c;各种复杂的公式看了就晕&#xff0c;要…...

网络安全三件套

一、在线安全的四个误解     Internet实际上是个有来有往的世界&#xff0c;你可以很轻松地连接到你喜爱的站点&#xff0c;而其他人&#xff0c;例如黑客也很方便地连接到你的机器。实际上&#xff0c;很多机器都因为自己很糟糕的在线安全设置无意间在机器和系统中留下了“…...

瑞芯微RV1126部署YOLOv8全流程:环境搭建、pt-onnx-rknn模型转换、C++推理代码、错误解决、优化、交叉编译第三方库

目录 1 环境搭建 2 交叉编译opencv 3 模型训练 4 模型转换 4.1 pt模型转onnx模型 4.2 onnx模型转rknn模型 4.2.1 安装rknn-toolkit 4.2.2 onn转成rknn模型 5 升级npu驱动 6 C++推理源码demo 6.1 原版demo 6.2 增加opencv读取图片的代码 7 交叉编译x264 ffmepg和op…...

【ISO 14229-1:2023 UDS诊断(会话控制0x10服务)测试用例CAPL代码全解析⑤】

ISO 14229-1:2023 UDS诊断【会话控制0x10服务】_TestCase05 作者&#xff1a;车端域控测试工程师 更新日期&#xff1a;2025年02月15日 关键词&#xff1a;UDS诊断、0x10服务、诊断会话控制、ECU测试、ISO 14229-1:2023 TC10-005测试用例 用例ID测试场景验证要点参考条款预期…...

基于VirtualLab Fusion的光学检测与精密成像(光学检测、精密成像、显微镜系统)课程

基于VirtualLab Fusion的光学检测与精密成像&#xff08;光学检测、精密成像、显微镜系统&#xff09;课程时长&#xff1a;2天/城市授课地点&#xff1a;上海本课程聚焦于利用VirtualLab Fusion先进的光之数字模型平台&#xff0c;解决光学检测与精密成像系统的核心设计挑战。…...

LeetCode 二分图判定题解

LeetCode 二分图判定题解 题目描述 二分图是一种特殊的图&#xff0c;它的顶点可以被分为两个不相交的集合&#xff0c;使得图中的每条边都连接不同集合中的顶点。 示例&#xff1a; 对于以下图&#xff1a;A -- B| |C -- D这是一个二分图&#xff0c;因为可以将顶点分为两个…...

2026国产大模型API价格战再升级:DeepSeek V4把行业打进“厘时代”,谁还扛得住?

2026年的国产大模型市场&#xff0c;正在发生一件足够改变行业格局的大事&#xff1a; 不是谁参数最大。 不是谁榜单第一。 而是——DeepSeek V4用极致低价&#xff0c;把整个行业的商业逻辑重新改写了。 当主流厂商还在讨论模型性能、上下文长度、多模态能力时&#xff0c;Dee…...

【机械制图及CAD实战(一)】专栏简介

《机械制图》是为工科学生提供的技术基础课&#xff0c;旨在培养他们绘制和阅读机械图样的能力&#xff0c;为后续专业学习奠定基础。 它以几何学和投影理论为基础&#xff0c;教授学生掌握国家标准、图样绘制与读图方法、标准件知识以及零件图和装配图的绘制。课程目标是培养学…...

超越G代码:深入LinuxCNC的HAL层,像搭积木一样自定义你的数控逻辑(附Python联动案例)

超越G代码&#xff1a;深入LinuxCNC的HAL层&#xff0c;像搭积木一样自定义你的数控逻辑&#xff08;附Python联动案例&#xff09; 当大多数CNC开发者还在G代码的海洋中挣扎时&#xff0c;少数先行者已经发现了LinuxCNC中隐藏的"魔法工具箱"——硬件抽象层(HAL)。这…...

AI代理与Jina工具实现智能网页抓取方案

1. 项目概述这个标题描述了一个相当有趣的AI应用场景&#xff1a;AI代理如何利用Jina的URL转Markdown工具&#xff0c;在KaibanJS框架中实现更智能化的网页抓取方案。作为一名长期从事自动化工具开发的工程师&#xff0c;我最近在实际项目中深度应用了这套技术栈&#xff0c;发…...

告别手动分页!用z-paging在uni-app里5分钟搞定列表加载(附完整配置流程)

告别手动分页&#xff01;用z-paging在uni-app里5分钟搞定列表加载&#xff08;附完整配置流程&#xff09; 每次开发uni-app的列表页&#xff0c;最头疼的就是处理分页逻辑。下拉刷新要重置数据、上拉加载要拼接数组、空状态要手动判断...这些重复劳动不仅浪费时间&#xff0c…...

LinkSwift:八大网盘直链解析工具,重塑你的下载体验

LinkSwift&#xff1a;八大网盘直链解析工具&#xff0c;重塑你的下载体验 【免费下载链接】Online-disk-direct-link-download-assistant 一个基于 JavaScript 的网盘文件下载地址获取工具。基于【网盘直链下载助手】修改 &#xff0c;支持 百度网盘 / 阿里云盘 / 中国移动云盘…...

开源配置管理库opencode-config:轻量级、强类型、动态刷新的Java配置解决方案

1. 项目概述&#xff1a;一个开源配置管理库的诞生与价值在软件开发中&#xff0c;配置管理是个老生常谈却又常谈常新的问题。从单体应用时代写在application.properties里的几行键值对&#xff0c;到微服务架构下动辄上百个服务的环境变量、数据库连接串、第三方API密钥&#…...

跨模态注意力机制在视频理解中的应用与优化

1. 跨模态注意力机制的技术解析跨模态注意力机制&#xff08;Cross-Attention&#xff09;作为连接视觉与语言模态的核心技术&#xff0c;其工作原理类似于人类大脑处理多感官信息的方式。当我们在观看视频时&#xff0c;视觉皮层和语言中枢会协同工作——这正是跨模态注意力在…...