【深度学习】(5)--搭建卷积神经网络
文章目录
- 搭建卷积神经网络
- 一、数据预处理
- 1. 下载数据集
- 2. 创建DataLoader(数据加载器)
- 二、搭建神经网络
- 三、训练数据
- 四、优化模型
- 总结
搭建卷积神经网络
一、数据预处理
1. 下载数据集
在PyTorch中,有许多封装了很多与图像相关的模型、数据集,那么如何获取数据集呢?
导入datasets模块:
from torchvision import datasets #封装了很多与图像相关的模型,数据集
以datasets模块中的MNIST数据集为例,包含70000张手写数字图像:60000张用于训练,10000张用于测试。图像是灰度的,28*28像素,并且居中的,以减少预处理和加快运行。

from torch.utils.data import DataLoader #数据包管理工具,打包数据
from torchvision import datasets #封装了很多与图像相关的模型,数据集
from torchvision.transforms import ToTensor # 数据转换,张量,将其他类型数据转换为tensor张量
"""-----下载训练集数据集-----"""
training_data = datasets.MNIST(root="data",train=True,# 取训练集download=True,transform=ToTensor(),# 张量,图片是不能直接传入神经网络模型的
) # 对于pytorch库能够识别的数据,一般是tensor张量"""-----下载测试集数据集-----"""
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor(),
)# numpy数组只能在CPU上运行,Tensor可以在GPU上运行,这在深度学习中可以显著提高计算速度

下载完成之后可在project栏查看。
2. 创建DataLoader(数据加载器)
在PyTorch中,创建DataLoader的主要作用是将数据集(Dataset)加载到模型中,以便进行训练或推理。DataLoader通过封装数据集,提供了一个高效、灵活的方式来处理数据。
DataLoader通过batch_size参数将数据集自动划分为多个小批次(batch),每一批次的放入模型训练,减少内存的使用,提高训练速度。
import torch
from torch.utils.data import DataLoader
"""
创建数据DataLoader(数据加载器)
batch_size:将数据集分成多份,每一份为batch_size(指定数值)个数据。
优点:减少内存的使用,提高训练速度
"""
train_dataloder = DataLoader(training_data,batch_size=64)# 64张图片为一个包
test_datalodar = DataLoader(test_data,batch_size=64)
# 查看打包好的数据
for x,y in test_datalodar: #x是表示打包好的每一个数据包print(f"Shape of x [N, C, H, W]:{x.shape}")print(f"Shape of y:{y.shape} {y.dtype}")break
-----------------------
Shape of x [N, C, H, W]:torch.Size([64, 1, 28, 28])
Shape of y:torch.Size([64]) torch.int64
二、搭建神经网络

注意:同普通的神经网络不同,卷积神经网络在传入图片时不需要将其展开,因为对图片进行卷积就是在原图上进行内积,不能展开。
卷积神经网络是由输入层、卷积层、激活层、池化层、全连接层、输出层组成。所以在结构上我们也同这样式的来,但是可以搭建多层卷积哦!
"""---判断当前设备是否支持GPU,其中mps是苹果m系列芯片的GPU"""
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu""""-----定义神经网络-----"""
class CNN(nn.Module):def __init__(self): # 输入大小(1,28,28)super(CNN,self).__init__()self.conv1 = nn.Sequential( # 将多个层组合在一起nn.Conv2d( # 2d一般用于图像,3d用于视频数据(多一个时间维度),1d一般用于结构化的序列数据in_channels=1, # 图像通道个数,1表示灰度图(确定卷积核 组中的个数)out_channels=16, # 要得到多少特征图,卷积核的个数kernel_size=5, # 卷积核大小stride=1, # 步长padding=2 # 边界填充大小), # 输出的特征图为(16,28,28)-->16个大小28*28的图像nn.ReLU(), # relu层,不会改变特征图的大小nn.MaxPool2d(kernel_size=2) # 进行池化操作(2*2区域),输出结果为(16,14,14))self.conv2 = nn.Sequential( # 输入(16,14,14)nn.Conv2d(16,32,5,1,2), # 输出(32*14*14)nn.ReLU(),nn.Conv2d(32,32,5,1,2), # 输出(32*14*14)nn.ReLU(),nn.MaxPool2d(2) # 输出(32,7,7))self.conv3 = nn.Sequential( # 输入(32,7,7)nn.Conv2d(32,64,5,1,2), # 输出(64,7,7)nn.ReLU())self.out = nn.Linear(64*7*7,10)def forward(self,x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x) # 输出(64,7,7)x = x.view(x.size(0),-1) # flatten 操作,结果为:(batch_size,64*7*7)output = self.out(x)return outputmodel = CNN().to(device)
三、训练数据
- optimizer优化器:
optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
- loss_fn损失函数:
在PyTorch中,**nn.CrossEntropyLoss()**是一个常用的损失函数,它结合了 nn.LogSoftmax() 和 nn.NLLLoss()(负对数似然损失)在一个单独的类中。
loss_fn = nn.CrossEntropyLoss()
- 训练集
from torch import nn #导入神经网络模块
def train(dataloader,model,loss_fn,optimizer):model.train()# 设置模型为训练模式batch_size_num =1# 迭代次数 for x,y in dataloader:x,y = x.to(device),y.to(device) # 将数据和标签发送到指定设备 pred = model.forward(x) # 前向传播 loss = loss_fn(pred,y) # 计算损失 optimizer.zero_grad() # 清除之前的梯度 loss.backward() # 反向传播 optimizer.step() # 更新模型参数 loss_value = loss.item() # 获取损失值if batch_size_num %200 == 0: # 每200次迭代打印一次损失 print(f"loss:{loss_value:>7f} [number:{batch_size_num}]")batch_size_num += 1
train(train_dataloder,model,loss_fn,optimizer)
------------------------
loss:0.158841 [number:200]
loss:0.242431 [number:400]
loss:0.173504 [number:600]
loss:0.020542 [number:800]
- 测试集
"""-----测试集-----"""
def test(dataloader,model,loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss,correct = 0,0with torch.no_grad():for x,y in dataloader:x,y = x.to(device),y.to(device)pred = model.forward(x)test_loss += loss_fn(pred,y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()a = (pred.argmax(1) == y)b = (pred.argmax(1) == y).type(torch.float)test_loss /= num_batchescorrect /= sizecorrect = round(correct, 4)print(f"Test result: \n Accuracy:{(100*correct)}%,Avg loss:{test_loss}")test(test_datalodar,model,loss_fn)
--------------------
Test result: Accuracy:98.11999999999999%,Avg loss:0.05511626677004996
四、优化模型
通过多次迭代,神经网络不断调整其内部参数(如权重和偏置),以最小化预测值与实际值之间的误差。这种优化过程使得神经网络能够更准确地处理输入数据,提高分类、回归等任务的性能。
epochs = 5
for t in range(epochs):print(f"Epoch {t+1} \n-------------------------")train(train_dataloder,model,loss_fn,optimizer)
print("Done!")
test(test_datalodar,model,loss_fn)
输出结果:
Epoch 1
-------------------------
loss:0.182339 [number:200]
loss:0.229839 [number:400]
loss:0.210450 [number:600]
loss:0.028532 [number:800]
Epoch 2
-------------------------
loss:0.066216 [number:200]
loss:0.149762 [number:400]
loss:0.084482 [number:600]
loss:0.003749 [number:800]
…………
Done!
Test result: Accuracy:98.99%,Avg loss:0.03138259953491878
总结
本篇介绍了如何搭建卷积神经网络,其主要的构造部分为卷积层、激活层以及池化层,可以搭建多层该部分对数据进行多次卷积、池化。
注意:同普通的神经网络不同,卷积神经网络在传入图片时不需要将其展开,因为对图片进行卷积就是在原图上进行内积,不能展开。
相关文章:
【深度学习】(5)--搭建卷积神经网络
文章目录 搭建卷积神经网络一、数据预处理1. 下载数据集2. 创建DataLoader(数据加载器) 二、搭建神经网络三、训练数据四、优化模型 总结 搭建卷积神经网络 一、数据预处理 1. 下载数据集 在PyTorch中,有许多封装了很多与图像相关的模型、…...
边学英语边学 Java|Synchronization in java
Why use Java Synchronization? Java Synchronization is used to make sure by some synchronization method that only one thread can access the resource at a given point in time. Java 同步用于确保通过某种同步方法,在给定的时间点只有一个线程可以访问资…...
k8s StorageClass 存储类
文章目录 一、概述1、StorageClass 对象定义2、StorageClass YAML 示例 二、StorageClass 字段1、provisioner(存储制备器)1.1、内置制备器1.2、第三方制备器 2、reclaimPolicy(回收策略)3、allowVolumeExpansion(允许…...
3D Slicer医学图像全自动AI分割组合拳-MONAIAuto3DSeg扩展
3D Slicer医学图像全自动AI分割组合拳-MONAIAuto3DSeg扩展 1 官网下载最新3D Slicer image computing platform | 3D Slicer 版本5.7 2 安装torch依赖包: 2.1 进入安装目录C:\Users\wangzhenlin\AppData\Local\slicer.org\Slicer 5.7.0-2024-09-21\bin࿰…...
分布式光伏的发电监控
国拥有丰富的清洁可再生能源资源储量,积极开发利用可再生能源,为解决当前化石能源短缺与环境污染严重的燃眉之急提供了有效途径[1]。但是可再生能源的利用和开发,可再生能源技术的发展和推广以及可再生能源资源对环境保护的正向影响ÿ…...
微信小程序----日期时间选择器(自定义时间精确到分秒)
目录 页面效果 代码实现 注意事项 页面效果 代码实现 js Component({/*** 组件的属性列表*/properties: {pickerShow: {type: Boolean,},config: Object,},/*** 组件的初始数据*/data: {pickerReady: false,// pickerShow:true// limitStartTime: new Date().getTime()-…...
3D生成技术再创新高:VAST发布Tripo 2.0,提升AI 3D生成新高度
随着《黑神话悟空》的爆火,3D游戏背后的AI 3D生成技术也逐渐受到更多的关注。虽然3D大模型的热度相较于语言模型和视频生成技术稍逊一筹,但全球的3D大模型玩家们却从未放慢脚步。无论是a16z支持的Yellow,还是李飞飞创立的World Labsÿ…...
ONNX Runtime学习之InferenceSession模块
ONNXRuntime库学习之InferenceSession(模块) 一、简介 onnxruntime.InferenceSession 是 ONNX Runtime 中用于加载和运行 ONNX 模型的核心模块。它提供了一种灵活的方式来在多种硬件设备(如 CPU、GPU)上执行 ONNX 模型推理。通过 InferenceSession&…...
【TS】TypeScript内置条件类型-ReturnType
ReturnType 在TypeScript中,ReturnType 是一个内置的条件类型(Conditional Type),它用于获取一个函数返回值的类型。这个工具类型非常有用,特别是当你需要引用某个函数的返回类型,但又不想直接写出那个具体…...
【c语言数据结构】超详细!模拟实现双向链表(初始化、销毁、头删、尾删、头插、尾插、指定位置插入与删除、查找数据、判断链表是否为空)
特点: 结构:指向前一结点指针数据指向后一结点指针由于循环,尾结点的下一结点next指向头结点(哨兵结点)空的双向链表只有自循环的哨兵结点(头结点) 模拟实现双向链表 LIST.h #define _CRT_…...
第十四届蓝桥杯嵌入式国赛
一. 前言 本篇博客主要讲述十四届蓝桥杯嵌入式的国赛题目,包括STM32CubeMx的相关配置以及相关功能实现代码以及我在做题过程中所遇到的一些问题和总结收获。如果有兴趣的伙伴还可以去做做其它届的真题,可去 蓝桥云课 上搜索历届真题即可。 二. 题目概述 …...
(k8s)kubernetes集群基于Containerd部署
资源列表 基础环境 一、基础环境准备 1.1、关闭Swap分区 1.2、添加hosts解析 1.3、桥接的IPv4流量传递给iptables的链 二、准备Containerd容器运行时 2.1、安装Containerd 2.2、配置Containerd 2.3、启动Containerd 三、部署Kubernetes集群 3.1、安装Kubeadm工具 3.2、…...
python内置模块pathlib.Path类操作目录和文件
python自带的pathlib模块提供了很多路径相关的功能,而pathlib.Path 是pathlib 模块中的一个核心类,它代表了文件系统中的一个路径,实现功能比如创建、删除、移动文件,读取和写入文件内容,遍历目录等。 Path 类跟os.pa…...
react开发环境搭建
文章目录 准备工作创建 React 项目使用 create-react-app 创建 React 项目使用 Vite 创建 React 项目启动项目效果安装出现的情况 react项目文件讲解1. 项目根目录2. 其他可能的目录和文件3. 配置文件 准备工作 Node.js 安装方法: 方式一:使用 NVM 安装…...
python 逻辑语句简记
什么语言都少不了逻辑处理语句的使用,python的逻辑处理语句有自身的使用特点,稍稍总结记录一下 一、断言 assert 条件 条件触发,程序执行中断 二、条件语句 if 条件: 执行内容 三、循环语句 while 条件: 循环体…...
8.进销存系统(基于springboot的进销存系统)
目录 1.系统的受众说明 2.开发技术与环境配置 2.1 SpringBoot框架 2.2 Java语言简介 2.3 MySQL环境配置 2.4 idea介绍 2.5 mysql数据库介绍 2.6 B/S架构 3.系统分析与设计 3.1 可行性分析 3.1.1 技术可行性 3.1.2 操作可行性 3.1.3经济可行性 3.4.1 数据库…...
深入理解主键回显:提升数据操作效率与准确性
在软件开发的世界中,主键回显是一个常常被提及但又容易被忽视其重要性的概念。今天,我们就来深入探讨一下主键回显的奥秘。 一、什么是主键回显? 在数据库设计中,主键是用于唯一标识表中每一行记录的字段。而主键回显࿰…...
springboot+阿里云物联网教程
需求背景 最近有一个项目,需要用到阿里云物联网,不是MQ。发现使用原来EMQX的代码去连接阿里云MQTT直接报错,试了很多种方案都不行。最终还是把错误分析和教程都整理一下。 需要注意的是,阿里云物联网平台和MQ不一样。方向别走偏了。 概念描述 EMQX和阿里云MQTT有什么区别…...
QT Creator cmake 自定义项目结构, 编译输出目录指定
1. 目的 将不同的源文件放到不同的目录下进行管理, 如下: build: 编译输出目录 include: 头文件目录 rsources: 资源文件目录 src: cpp文件目录 2. 创建完cmake工程后修改CMakeLists.txt 配置 注 : 这里头文件目录是include, 所以在includ…...
lunar无第三方依赖的公历、农历、法定节假日...日历工具库
文章目录 介绍maven示例示例(前后端)网址文档 介绍 lunar是一款无第三方依赖的公历(阳历)、农历(阴历、老黄历)、道历、佛历工具,支持星座、儒略日、干支、生肖、节气、节日、彭祖百忌、吉神(喜神/福神/财神/阳贵神/阴贵神)方位、胎神方位、…...
别再死记硬背了!用这5个真实运维脚本,搞定90%的Shell面试题
5个实战Shell脚本:从面试题到真实运维场景的蜕变 在技术面试中,Shell脚本能力往往是区分普通候选人和优秀候选人的关键指标。但死记硬背面试题答案的时代已经过去,现代企业更看重候选人解决实际问题的能力。本文将带你通过5个真实运维场景中的…...
从混淆矩阵到Kappa系数:实战解析土地利用分类精度评估全流程
1. 土地利用分类精度评估入门指南 当你完成了一张精美的土地利用分类图,最常被问到的问题往往是:"这个结果到底有多准?"作为从业多年的GIS分析师,我见过太多人只关注分类过程却忽视精度验证,最后在项目汇报时…...
嵌入式开发代码版本比较工具与技巧
1. 嵌入式开发中的代码版本差异查看方法在嵌入式开发过程中,代码版本管理是每个工程师必须掌握的核心技能。随着项目迭代和功能更新,我们经常需要比较不同版本代码之间的差异,无论是为了代码审查、问题排查还是版本合并。作为一名嵌入式开发者…...
利用华为云MaaS与OpenTiny NEXT构建智能电商后台:从传统操作到AI驱动的自动化升级
1. 传统电商后台的痛点与AI转型机遇 电商后台管理系统一直是运营人员的"战场",每天面对商品上下架、库存调整、数据统计等重复性工作。记得三年前我参与过一个母婴电商项目,运营团队每天要手动处理上百个商品信息更新,高峰期经常加…...
MoveIt2的KDL插件不好用?手把手教你自定义关节权重,优化机械臂运动优先级
MoveIt2关节权重调优实战:如何让冗余机械臂按你的想法运动 当机械臂的第七个关节开始不受控制地乱转,而前三个关节却几乎不动时,大多数工程师的第一反应是"这IK算法有问题"。但真相往往是:算法没问题,只是它…...
避坑指南:从零搭建Anaconda+CUDA+PyTorch+Pycharm深度学习环境
1. 深度学习环境配置全景图 刚接触深度学习的新手往往会在环境配置这一步卡住好几天。我见过太多人在Anaconda、CUDA、PyTorch的版本兼容性问题上来回折腾,最后连代码都没开始写就放弃了。其实只要理解这四个核心组件的关系,配置过程就会变得清晰很多。 …...
GLM-4-9B-Chat-1M惊艳效果:复杂SQL代码库跨文件依赖关系可视化
GLM-4-9B-Chat-1M惊艳效果:复杂SQL代码库跨文件依赖关系可视化 1. 项目背景与核心价值 当你面对一个包含数百个SQL文件的大型数据仓库项目时,最头疼的问题是什么?我相信很多开发者和数据工程师都会说:理不清的表依赖关系。 传统…...
为什么conda装不上opencv-python?深入解析conda与pip的包管理差异
为什么conda装不上opencv-python?深入解析conda与pip的包管理差异 在Python生态系统中,conda和pip是最常用的两种包管理工具。许多开发者习惯使用conda创建和管理虚拟环境,但在安装某些特定包如opencv-python时,却常常遇到"P…...
Bypass Paywalls Clean 3大突破策略:2024浏览器扩展技术指南
Bypass Paywalls Clean 3大突破策略:2024浏览器扩展技术指南 【免费下载链接】bypass-paywalls-chrome-clean 项目地址: https://gitcode.com/GitHub_Trending/by/bypass-paywalls-chrome-clean 当你在撰写行业分析报告时,是否曾因关键数据被付费…...
3大核心功能揭秘:CELLxGENE如何让单细胞数据分析变得如此简单
3大核心功能揭秘:CELLxGENE如何让单细胞数据分析变得如此简单 【免费下载链接】cellxgene An interactive explorer for single-cell transcriptomics data 项目地址: https://gitcode.com/gh_mirrors/ce/cellxgene 在单细胞转录组学研究中,数据分…...
