【深度学习】(5)--搭建卷积神经网络
文章目录
- 搭建卷积神经网络
- 一、数据预处理
- 1. 下载数据集
- 2. 创建DataLoader(数据加载器)
- 二、搭建神经网络
- 三、训练数据
- 四、优化模型
- 总结
搭建卷积神经网络
一、数据预处理
1. 下载数据集
在PyTorch中,有许多封装了很多与图像相关的模型、数据集,那么如何获取数据集呢?
导入datasets模块:
from torchvision import datasets #封装了很多与图像相关的模型,数据集
以datasets模块中的MNIST数据集为例,包含70000张手写数字图像:60000张用于训练,10000张用于测试。图像是灰度的,28*28像素,并且居中的,以减少预处理和加快运行。

from torch.utils.data import DataLoader #数据包管理工具,打包数据
from torchvision import datasets #封装了很多与图像相关的模型,数据集
from torchvision.transforms import ToTensor # 数据转换,张量,将其他类型数据转换为tensor张量
"""-----下载训练集数据集-----"""
training_data = datasets.MNIST(root="data",train=True,# 取训练集download=True,transform=ToTensor(),# 张量,图片是不能直接传入神经网络模型的
) # 对于pytorch库能够识别的数据,一般是tensor张量"""-----下载测试集数据集-----"""
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor(),
)# numpy数组只能在CPU上运行,Tensor可以在GPU上运行,这在深度学习中可以显著提高计算速度

下载完成之后可在project栏查看。
2. 创建DataLoader(数据加载器)
在PyTorch中,创建DataLoader的主要作用是将数据集(Dataset)加载到模型中,以便进行训练或推理。DataLoader通过封装数据集,提供了一个高效、灵活的方式来处理数据。
DataLoader通过batch_size参数将数据集自动划分为多个小批次(batch),每一批次的放入模型训练,减少内存的使用,提高训练速度。
import torch
from torch.utils.data import DataLoader
"""
创建数据DataLoader(数据加载器)
batch_size:将数据集分成多份,每一份为batch_size(指定数值)个数据。
优点:减少内存的使用,提高训练速度
"""
train_dataloder = DataLoader(training_data,batch_size=64)# 64张图片为一个包
test_datalodar = DataLoader(test_data,batch_size=64)
# 查看打包好的数据
for x,y in test_datalodar: #x是表示打包好的每一个数据包print(f"Shape of x [N, C, H, W]:{x.shape}")print(f"Shape of y:{y.shape} {y.dtype}")break
-----------------------
Shape of x [N, C, H, W]:torch.Size([64, 1, 28, 28])
Shape of y:torch.Size([64]) torch.int64
二、搭建神经网络

注意:同普通的神经网络不同,卷积神经网络在传入图片时不需要将其展开,因为对图片进行卷积就是在原图上进行内积,不能展开。
卷积神经网络是由输入层、卷积层、激活层、池化层、全连接层、输出层组成。所以在结构上我们也同这样式的来,但是可以搭建多层卷积哦!
"""---判断当前设备是否支持GPU,其中mps是苹果m系列芯片的GPU"""
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu""""-----定义神经网络-----"""
class CNN(nn.Module):def __init__(self): # 输入大小(1,28,28)super(CNN,self).__init__()self.conv1 = nn.Sequential( # 将多个层组合在一起nn.Conv2d( # 2d一般用于图像,3d用于视频数据(多一个时间维度),1d一般用于结构化的序列数据in_channels=1, # 图像通道个数,1表示灰度图(确定卷积核 组中的个数)out_channels=16, # 要得到多少特征图,卷积核的个数kernel_size=5, # 卷积核大小stride=1, # 步长padding=2 # 边界填充大小), # 输出的特征图为(16,28,28)-->16个大小28*28的图像nn.ReLU(), # relu层,不会改变特征图的大小nn.MaxPool2d(kernel_size=2) # 进行池化操作(2*2区域),输出结果为(16,14,14))self.conv2 = nn.Sequential( # 输入(16,14,14)nn.Conv2d(16,32,5,1,2), # 输出(32*14*14)nn.ReLU(),nn.Conv2d(32,32,5,1,2), # 输出(32*14*14)nn.ReLU(),nn.MaxPool2d(2) # 输出(32,7,7))self.conv3 = nn.Sequential( # 输入(32,7,7)nn.Conv2d(32,64,5,1,2), # 输出(64,7,7)nn.ReLU())self.out = nn.Linear(64*7*7,10)def forward(self,x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x) # 输出(64,7,7)x = x.view(x.size(0),-1) # flatten 操作,结果为:(batch_size,64*7*7)output = self.out(x)return outputmodel = CNN().to(device)
三、训练数据
- optimizer优化器:
optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
- loss_fn损失函数:
在PyTorch中,**nn.CrossEntropyLoss()**是一个常用的损失函数,它结合了 nn.LogSoftmax() 和 nn.NLLLoss()(负对数似然损失)在一个单独的类中。
loss_fn = nn.CrossEntropyLoss()
- 训练集
from torch import nn #导入神经网络模块
def train(dataloader,model,loss_fn,optimizer):model.train()# 设置模型为训练模式batch_size_num =1# 迭代次数 for x,y in dataloader:x,y = x.to(device),y.to(device) # 将数据和标签发送到指定设备 pred = model.forward(x) # 前向传播 loss = loss_fn(pred,y) # 计算损失 optimizer.zero_grad() # 清除之前的梯度 loss.backward() # 反向传播 optimizer.step() # 更新模型参数 loss_value = loss.item() # 获取损失值if batch_size_num %200 == 0: # 每200次迭代打印一次损失 print(f"loss:{loss_value:>7f} [number:{batch_size_num}]")batch_size_num += 1
train(train_dataloder,model,loss_fn,optimizer)
------------------------
loss:0.158841 [number:200]
loss:0.242431 [number:400]
loss:0.173504 [number:600]
loss:0.020542 [number:800]
- 测试集
"""-----测试集-----"""
def test(dataloader,model,loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss,correct = 0,0with torch.no_grad():for x,y in dataloader:x,y = x.to(device),y.to(device)pred = model.forward(x)test_loss += loss_fn(pred,y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()a = (pred.argmax(1) == y)b = (pred.argmax(1) == y).type(torch.float)test_loss /= num_batchescorrect /= sizecorrect = round(correct, 4)print(f"Test result: \n Accuracy:{(100*correct)}%,Avg loss:{test_loss}")test(test_datalodar,model,loss_fn)
--------------------
Test result: Accuracy:98.11999999999999%,Avg loss:0.05511626677004996
四、优化模型
通过多次迭代,神经网络不断调整其内部参数(如权重和偏置),以最小化预测值与实际值之间的误差。这种优化过程使得神经网络能够更准确地处理输入数据,提高分类、回归等任务的性能。
epochs = 5
for t in range(epochs):print(f"Epoch {t+1} \n-------------------------")train(train_dataloder,model,loss_fn,optimizer)
print("Done!")
test(test_datalodar,model,loss_fn)
输出结果:
Epoch 1
-------------------------
loss:0.182339 [number:200]
loss:0.229839 [number:400]
loss:0.210450 [number:600]
loss:0.028532 [number:800]
Epoch 2
-------------------------
loss:0.066216 [number:200]
loss:0.149762 [number:400]
loss:0.084482 [number:600]
loss:0.003749 [number:800]
…………
Done!
Test result: Accuracy:98.99%,Avg loss:0.03138259953491878
总结
本篇介绍了如何搭建卷积神经网络,其主要的构造部分为卷积层、激活层以及池化层,可以搭建多层该部分对数据进行多次卷积、池化。
注意:同普通的神经网络不同,卷积神经网络在传入图片时不需要将其展开,因为对图片进行卷积就是在原图上进行内积,不能展开。
相关文章:
【深度学习】(5)--搭建卷积神经网络
文章目录 搭建卷积神经网络一、数据预处理1. 下载数据集2. 创建DataLoader(数据加载器) 二、搭建神经网络三、训练数据四、优化模型 总结 搭建卷积神经网络 一、数据预处理 1. 下载数据集 在PyTorch中,有许多封装了很多与图像相关的模型、…...
边学英语边学 Java|Synchronization in java
Why use Java Synchronization? Java Synchronization is used to make sure by some synchronization method that only one thread can access the resource at a given point in time. Java 同步用于确保通过某种同步方法,在给定的时间点只有一个线程可以访问资…...
k8s StorageClass 存储类
文章目录 一、概述1、StorageClass 对象定义2、StorageClass YAML 示例 二、StorageClass 字段1、provisioner(存储制备器)1.1、内置制备器1.2、第三方制备器 2、reclaimPolicy(回收策略)3、allowVolumeExpansion(允许…...
3D Slicer医学图像全自动AI分割组合拳-MONAIAuto3DSeg扩展
3D Slicer医学图像全自动AI分割组合拳-MONAIAuto3DSeg扩展 1 官网下载最新3D Slicer image computing platform | 3D Slicer 版本5.7 2 安装torch依赖包: 2.1 进入安装目录C:\Users\wangzhenlin\AppData\Local\slicer.org\Slicer 5.7.0-2024-09-21\bin࿰…...
分布式光伏的发电监控
国拥有丰富的清洁可再生能源资源储量,积极开发利用可再生能源,为解决当前化石能源短缺与环境污染严重的燃眉之急提供了有效途径[1]。但是可再生能源的利用和开发,可再生能源技术的发展和推广以及可再生能源资源对环境保护的正向影响ÿ…...
微信小程序----日期时间选择器(自定义时间精确到分秒)
目录 页面效果 代码实现 注意事项 页面效果 代码实现 js Component({/*** 组件的属性列表*/properties: {pickerShow: {type: Boolean,},config: Object,},/*** 组件的初始数据*/data: {pickerReady: false,// pickerShow:true// limitStartTime: new Date().getTime()-…...
3D生成技术再创新高:VAST发布Tripo 2.0,提升AI 3D生成新高度
随着《黑神话悟空》的爆火,3D游戏背后的AI 3D生成技术也逐渐受到更多的关注。虽然3D大模型的热度相较于语言模型和视频生成技术稍逊一筹,但全球的3D大模型玩家们却从未放慢脚步。无论是a16z支持的Yellow,还是李飞飞创立的World Labsÿ…...
ONNX Runtime学习之InferenceSession模块
ONNXRuntime库学习之InferenceSession(模块) 一、简介 onnxruntime.InferenceSession 是 ONNX Runtime 中用于加载和运行 ONNX 模型的核心模块。它提供了一种灵活的方式来在多种硬件设备(如 CPU、GPU)上执行 ONNX 模型推理。通过 InferenceSession&…...
【TS】TypeScript内置条件类型-ReturnType
ReturnType 在TypeScript中,ReturnType 是一个内置的条件类型(Conditional Type),它用于获取一个函数返回值的类型。这个工具类型非常有用,特别是当你需要引用某个函数的返回类型,但又不想直接写出那个具体…...
【c语言数据结构】超详细!模拟实现双向链表(初始化、销毁、头删、尾删、头插、尾插、指定位置插入与删除、查找数据、判断链表是否为空)
特点: 结构:指向前一结点指针数据指向后一结点指针由于循环,尾结点的下一结点next指向头结点(哨兵结点)空的双向链表只有自循环的哨兵结点(头结点) 模拟实现双向链表 LIST.h #define _CRT_…...
第十四届蓝桥杯嵌入式国赛
一. 前言 本篇博客主要讲述十四届蓝桥杯嵌入式的国赛题目,包括STM32CubeMx的相关配置以及相关功能实现代码以及我在做题过程中所遇到的一些问题和总结收获。如果有兴趣的伙伴还可以去做做其它届的真题,可去 蓝桥云课 上搜索历届真题即可。 二. 题目概述 …...
(k8s)kubernetes集群基于Containerd部署
资源列表 基础环境 一、基础环境准备 1.1、关闭Swap分区 1.2、添加hosts解析 1.3、桥接的IPv4流量传递给iptables的链 二、准备Containerd容器运行时 2.1、安装Containerd 2.2、配置Containerd 2.3、启动Containerd 三、部署Kubernetes集群 3.1、安装Kubeadm工具 3.2、…...
python内置模块pathlib.Path类操作目录和文件
python自带的pathlib模块提供了很多路径相关的功能,而pathlib.Path 是pathlib 模块中的一个核心类,它代表了文件系统中的一个路径,实现功能比如创建、删除、移动文件,读取和写入文件内容,遍历目录等。 Path 类跟os.pa…...
react开发环境搭建
文章目录 准备工作创建 React 项目使用 create-react-app 创建 React 项目使用 Vite 创建 React 项目启动项目效果安装出现的情况 react项目文件讲解1. 项目根目录2. 其他可能的目录和文件3. 配置文件 准备工作 Node.js 安装方法: 方式一:使用 NVM 安装…...
python 逻辑语句简记
什么语言都少不了逻辑处理语句的使用,python的逻辑处理语句有自身的使用特点,稍稍总结记录一下 一、断言 assert 条件 条件触发,程序执行中断 二、条件语句 if 条件: 执行内容 三、循环语句 while 条件: 循环体…...
8.进销存系统(基于springboot的进销存系统)
目录 1.系统的受众说明 2.开发技术与环境配置 2.1 SpringBoot框架 2.2 Java语言简介 2.3 MySQL环境配置 2.4 idea介绍 2.5 mysql数据库介绍 2.6 B/S架构 3.系统分析与设计 3.1 可行性分析 3.1.1 技术可行性 3.1.2 操作可行性 3.1.3经济可行性 3.4.1 数据库…...
深入理解主键回显:提升数据操作效率与准确性
在软件开发的世界中,主键回显是一个常常被提及但又容易被忽视其重要性的概念。今天,我们就来深入探讨一下主键回显的奥秘。 一、什么是主键回显? 在数据库设计中,主键是用于唯一标识表中每一行记录的字段。而主键回显࿰…...
springboot+阿里云物联网教程
需求背景 最近有一个项目,需要用到阿里云物联网,不是MQ。发现使用原来EMQX的代码去连接阿里云MQTT直接报错,试了很多种方案都不行。最终还是把错误分析和教程都整理一下。 需要注意的是,阿里云物联网平台和MQ不一样。方向别走偏了。 概念描述 EMQX和阿里云MQTT有什么区别…...
QT Creator cmake 自定义项目结构, 编译输出目录指定
1. 目的 将不同的源文件放到不同的目录下进行管理, 如下: build: 编译输出目录 include: 头文件目录 rsources: 资源文件目录 src: cpp文件目录 2. 创建完cmake工程后修改CMakeLists.txt 配置 注 : 这里头文件目录是include, 所以在includ…...
lunar无第三方依赖的公历、农历、法定节假日...日历工具库
文章目录 介绍maven示例示例(前后端)网址文档 介绍 lunar是一款无第三方依赖的公历(阳历)、农历(阴历、老黄历)、道历、佛历工具,支持星座、儒略日、干支、生肖、节气、节日、彭祖百忌、吉神(喜神/福神/财神/阳贵神/阴贵神)方位、胎神方位、…...
内存分配函数malloc kmalloc vmalloc
内存分配函数malloc kmalloc vmalloc malloc实现步骤: 1)请求大小调整:首先,malloc 需要调整用户请求的大小,以适应内部数据结构(例如,可能需要存储额外的元数据)。通常,这包括对齐调整,确保分配的内存地址满足特定硬件要求(如对齐到8字节或16字节边界)。 2)空闲…...
golang循环变量捕获问题
在 Go 语言中,当在循环中启动协程(goroutine)时,如果在协程闭包中直接引用循环变量,可能会遇到一个常见的陷阱 - 循环变量捕获问题。让我详细解释一下: 问题背景 看这个代码片段: fo…...
Auto-Coder使用GPT-4o完成:在用TabPFN这个模型构建一个预测未来3天涨跌的分类任务
通过akshare库,获取股票数据,并生成TabPFN这个模型 可以识别、处理的格式,写一个完整的预处理示例,并构建一个预测未来 3 天股价涨跌的分类任务 用TabPFN这个模型构建一个预测未来 3 天股价涨跌的分类任务,进行预测并输…...
大语言模型如何处理长文本?常用文本分割技术详解
为什么需要文本分割? 引言:为什么需要文本分割?一、基础文本分割方法1. 按段落分割(Paragraph Splitting)2. 按句子分割(Sentence Splitting)二、高级文本分割策略3. 重叠分割(Sliding Window)4. 递归分割(Recursive Splitting)三、生产级工具推荐5. 使用LangChain的…...
Psychopy音频的使用
Psychopy音频的使用 本文主要解决以下问题: 指定音频引擎与设备;播放音频文件 本文所使用的环境: Python3.10 numpy2.2.6 psychopy2025.1.1 psychtoolbox3.0.19.14 一、音频配置 Psychopy文档链接为Sound - for audio playback — Psy…...
dify打造数据可视化图表
一、概述 在日常工作和学习中,我们经常需要和数据打交道。无论是分析报告、项目展示,还是简单的数据洞察,一个清晰直观的图表,往往能胜过千言万语。 一款能让数据可视化变得超级简单的 MCP Server,由蚂蚁集团 AntV 团队…...
安卓基础(aar)
重新设置java21的环境,临时设置 $env:JAVA_HOME "D:\Android Studio\jbr" 查看当前环境变量 JAVA_HOME 的值 echo $env:JAVA_HOME 构建ARR文件 ./gradlew :private-lib:assembleRelease 目录是这样的: MyApp/ ├── app/ …...
保姆级教程:在无网络无显卡的Windows电脑的vscode本地部署deepseek
文章目录 1 前言2 部署流程2.1 准备工作2.2 Ollama2.2.1 使用有网络的电脑下载Ollama2.2.2 安装Ollama(有网络的电脑)2.2.3 安装Ollama(无网络的电脑)2.2.4 安装验证2.2.5 修改大模型安装位置2.2.6 下载Deepseek模型 2.3 将deepse…...
使用Spring AI和MCP协议构建图片搜索服务
目录 使用Spring AI和MCP协议构建图片搜索服务 引言 技术栈概览 项目架构设计 架构图 服务端开发 1. 创建Spring Boot项目 2. 实现图片搜索工具 3. 配置传输模式 Stdio模式(本地调用) SSE模式(远程调用) 4. 注册工具提…...
django blank 与 null的区别
1.blank blank控制表单验证时是否允许字段为空 2.null null控制数据库层面是否为空 但是,要注意以下几点: Django的表单验证与null无关:null参数控制的是数据库层面字段是否可以为NULL,而blank参数控制的是Django表单验证时字…...
