【深度学习】(5)--搭建卷积神经网络
文章目录
- 搭建卷积神经网络
- 一、数据预处理
- 1. 下载数据集
- 2. 创建DataLoader(数据加载器)
- 二、搭建神经网络
- 三、训练数据
- 四、优化模型
- 总结
搭建卷积神经网络
一、数据预处理
1. 下载数据集
在PyTorch中,有许多封装了很多与图像相关的模型、数据集,那么如何获取数据集呢?
导入datasets模块:
from torchvision import datasets #封装了很多与图像相关的模型,数据集
以datasets模块中的MNIST数据集为例,包含70000张手写数字图像:60000张用于训练,10000张用于测试。图像是灰度的,28*28像素,并且居中的,以减少预处理和加快运行。
from torch.utils.data import DataLoader #数据包管理工具,打包数据
from torchvision import datasets #封装了很多与图像相关的模型,数据集
from torchvision.transforms import ToTensor # 数据转换,张量,将其他类型数据转换为tensor张量
"""-----下载训练集数据集-----"""
training_data = datasets.MNIST(root="data",train=True,# 取训练集download=True,transform=ToTensor(),# 张量,图片是不能直接传入神经网络模型的
) # 对于pytorch库能够识别的数据,一般是tensor张量"""-----下载测试集数据集-----"""
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor(),
)# numpy数组只能在CPU上运行,Tensor可以在GPU上运行,这在深度学习中可以显著提高计算速度
下载完成之后可在project栏查看。
2. 创建DataLoader(数据加载器)
在PyTorch中,创建DataLoader的主要作用是将数据集(Dataset)加载到模型中,以便进行训练或推理。DataLoader通过封装数据集,提供了一个高效、灵活的方式来处理数据。
DataLoader通过batch_size参数将数据集自动划分为多个小批次(batch),每一批次的放入模型训练,减少内存的使用,提高训练速度。
import torch
from torch.utils.data import DataLoader
"""
创建数据DataLoader(数据加载器)
batch_size:将数据集分成多份,每一份为batch_size(指定数值)个数据。
优点:减少内存的使用,提高训练速度
"""
train_dataloder = DataLoader(training_data,batch_size=64)# 64张图片为一个包
test_datalodar = DataLoader(test_data,batch_size=64)
# 查看打包好的数据
for x,y in test_datalodar: #x是表示打包好的每一个数据包print(f"Shape of x [N, C, H, W]:{x.shape}")print(f"Shape of y:{y.shape} {y.dtype}")break
-----------------------
Shape of x [N, C, H, W]:torch.Size([64, 1, 28, 28])
Shape of y:torch.Size([64]) torch.int64
二、搭建神经网络
注意:同普通的神经网络不同,卷积神经网络在传入图片时不需要将其展开,因为对图片进行卷积就是在原图上进行内积,不能展开。
卷积神经网络是由输入层、卷积层、激活层、池化层、全连接层、输出层组成。所以在结构上我们也同这样式的来,但是可以搭建多层卷积哦!
"""---判断当前设备是否支持GPU,其中mps是苹果m系列芯片的GPU"""
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu""""-----定义神经网络-----"""
class CNN(nn.Module):def __init__(self): # 输入大小(1,28,28)super(CNN,self).__init__()self.conv1 = nn.Sequential( # 将多个层组合在一起nn.Conv2d( # 2d一般用于图像,3d用于视频数据(多一个时间维度),1d一般用于结构化的序列数据in_channels=1, # 图像通道个数,1表示灰度图(确定卷积核 组中的个数)out_channels=16, # 要得到多少特征图,卷积核的个数kernel_size=5, # 卷积核大小stride=1, # 步长padding=2 # 边界填充大小), # 输出的特征图为(16,28,28)-->16个大小28*28的图像nn.ReLU(), # relu层,不会改变特征图的大小nn.MaxPool2d(kernel_size=2) # 进行池化操作(2*2区域),输出结果为(16,14,14))self.conv2 = nn.Sequential( # 输入(16,14,14)nn.Conv2d(16,32,5,1,2), # 输出(32*14*14)nn.ReLU(),nn.Conv2d(32,32,5,1,2), # 输出(32*14*14)nn.ReLU(),nn.MaxPool2d(2) # 输出(32,7,7))self.conv3 = nn.Sequential( # 输入(32,7,7)nn.Conv2d(32,64,5,1,2), # 输出(64,7,7)nn.ReLU())self.out = nn.Linear(64*7*7,10)def forward(self,x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x) # 输出(64,7,7)x = x.view(x.size(0),-1) # flatten 操作,结果为:(batch_size,64*7*7)output = self.out(x)return outputmodel = CNN().to(device)
三、训练数据
- optimizer优化器:
optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
- loss_fn损失函数:
在PyTorch中,**nn.CrossEntropyLoss()**是一个常用的损失函数,它结合了 nn.LogSoftmax() 和 nn.NLLLoss()(负对数似然损失)在一个单独的类中。
loss_fn = nn.CrossEntropyLoss()
- 训练集
from torch import nn #导入神经网络模块
def train(dataloader,model,loss_fn,optimizer):model.train()# 设置模型为训练模式batch_size_num =1# 迭代次数 for x,y in dataloader:x,y = x.to(device),y.to(device) # 将数据和标签发送到指定设备 pred = model.forward(x) # 前向传播 loss = loss_fn(pred,y) # 计算损失 optimizer.zero_grad() # 清除之前的梯度 loss.backward() # 反向传播 optimizer.step() # 更新模型参数 loss_value = loss.item() # 获取损失值if batch_size_num %200 == 0: # 每200次迭代打印一次损失 print(f"loss:{loss_value:>7f} [number:{batch_size_num}]")batch_size_num += 1
train(train_dataloder,model,loss_fn,optimizer)
------------------------
loss:0.158841 [number:200]
loss:0.242431 [number:400]
loss:0.173504 [number:600]
loss:0.020542 [number:800]
- 测试集
"""-----测试集-----"""
def test(dataloader,model,loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss,correct = 0,0with torch.no_grad():for x,y in dataloader:x,y = x.to(device),y.to(device)pred = model.forward(x)test_loss += loss_fn(pred,y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()a = (pred.argmax(1) == y)b = (pred.argmax(1) == y).type(torch.float)test_loss /= num_batchescorrect /= sizecorrect = round(correct, 4)print(f"Test result: \n Accuracy:{(100*correct)}%,Avg loss:{test_loss}")test(test_datalodar,model,loss_fn)
--------------------
Test result: Accuracy:98.11999999999999%,Avg loss:0.05511626677004996
四、优化模型
通过多次迭代,神经网络不断调整其内部参数(如权重和偏置),以最小化预测值与实际值之间的误差。这种优化过程使得神经网络能够更准确地处理输入数据,提高分类、回归等任务的性能。
epochs = 5
for t in range(epochs):print(f"Epoch {t+1} \n-------------------------")train(train_dataloder,model,loss_fn,optimizer)
print("Done!")
test(test_datalodar,model,loss_fn)
输出结果:
Epoch 1
-------------------------
loss:0.182339 [number:200]
loss:0.229839 [number:400]
loss:0.210450 [number:600]
loss:0.028532 [number:800]
Epoch 2
-------------------------
loss:0.066216 [number:200]
loss:0.149762 [number:400]
loss:0.084482 [number:600]
loss:0.003749 [number:800]
…………
Done!
Test result: Accuracy:98.99%,Avg loss:0.03138259953491878
总结
本篇介绍了如何搭建卷积神经网络,其主要的构造部分为卷积层、激活层以及池化层,可以搭建多层该部分对数据进行多次卷积、池化。
注意:同普通的神经网络不同,卷积神经网络在传入图片时不需要将其展开,因为对图片进行卷积就是在原图上进行内积,不能展开。
相关文章:

【深度学习】(5)--搭建卷积神经网络
文章目录 搭建卷积神经网络一、数据预处理1. 下载数据集2. 创建DataLoader(数据加载器) 二、搭建神经网络三、训练数据四、优化模型 总结 搭建卷积神经网络 一、数据预处理 1. 下载数据集 在PyTorch中,有许多封装了很多与图像相关的模型、…...
边学英语边学 Java|Synchronization in java
Why use Java Synchronization? Java Synchronization is used to make sure by some synchronization method that only one thread can access the resource at a given point in time. Java 同步用于确保通过某种同步方法,在给定的时间点只有一个线程可以访问资…...

k8s StorageClass 存储类
文章目录 一、概述1、StorageClass 对象定义2、StorageClass YAML 示例 二、StorageClass 字段1、provisioner(存储制备器)1.1、内置制备器1.2、第三方制备器 2、reclaimPolicy(回收策略)3、allowVolumeExpansion(允许…...

3D Slicer医学图像全自动AI分割组合拳-MONAIAuto3DSeg扩展
3D Slicer医学图像全自动AI分割组合拳-MONAIAuto3DSeg扩展 1 官网下载最新3D Slicer image computing platform | 3D Slicer 版本5.7 2 安装torch依赖包: 2.1 进入安装目录C:\Users\wangzhenlin\AppData\Local\slicer.org\Slicer 5.7.0-2024-09-21\bin࿰…...

分布式光伏的发电监控
国拥有丰富的清洁可再生能源资源储量,积极开发利用可再生能源,为解决当前化石能源短缺与环境污染严重的燃眉之急提供了有效途径[1]。但是可再生能源的利用和开发,可再生能源技术的发展和推广以及可再生能源资源对环境保护的正向影响ÿ…...

微信小程序----日期时间选择器(自定义时间精确到分秒)
目录 页面效果 代码实现 注意事项 页面效果 代码实现 js Component({/*** 组件的属性列表*/properties: {pickerShow: {type: Boolean,},config: Object,},/*** 组件的初始数据*/data: {pickerReady: false,// pickerShow:true// limitStartTime: new Date().getTime()-…...

3D生成技术再创新高:VAST发布Tripo 2.0,提升AI 3D生成新高度
随着《黑神话悟空》的爆火,3D游戏背后的AI 3D生成技术也逐渐受到更多的关注。虽然3D大模型的热度相较于语言模型和视频生成技术稍逊一筹,但全球的3D大模型玩家们却从未放慢脚步。无论是a16z支持的Yellow,还是李飞飞创立的World Labsÿ…...
ONNX Runtime学习之InferenceSession模块
ONNXRuntime库学习之InferenceSession(模块) 一、简介 onnxruntime.InferenceSession 是 ONNX Runtime 中用于加载和运行 ONNX 模型的核心模块。它提供了一种灵活的方式来在多种硬件设备(如 CPU、GPU)上执行 ONNX 模型推理。通过 InferenceSession&…...
【TS】TypeScript内置条件类型-ReturnType
ReturnType 在TypeScript中,ReturnType 是一个内置的条件类型(Conditional Type),它用于获取一个函数返回值的类型。这个工具类型非常有用,特别是当你需要引用某个函数的返回类型,但又不想直接写出那个具体…...

【c语言数据结构】超详细!模拟实现双向链表(初始化、销毁、头删、尾删、头插、尾插、指定位置插入与删除、查找数据、判断链表是否为空)
特点: 结构:指向前一结点指针数据指向后一结点指针由于循环,尾结点的下一结点next指向头结点(哨兵结点)空的双向链表只有自循环的哨兵结点(头结点) 模拟实现双向链表 LIST.h #define _CRT_…...

第十四届蓝桥杯嵌入式国赛
一. 前言 本篇博客主要讲述十四届蓝桥杯嵌入式的国赛题目,包括STM32CubeMx的相关配置以及相关功能实现代码以及我在做题过程中所遇到的一些问题和总结收获。如果有兴趣的伙伴还可以去做做其它届的真题,可去 蓝桥云课 上搜索历届真题即可。 二. 题目概述 …...
(k8s)kubernetes集群基于Containerd部署
资源列表 基础环境 一、基础环境准备 1.1、关闭Swap分区 1.2、添加hosts解析 1.3、桥接的IPv4流量传递给iptables的链 二、准备Containerd容器运行时 2.1、安装Containerd 2.2、配置Containerd 2.3、启动Containerd 三、部署Kubernetes集群 3.1、安装Kubeadm工具 3.2、…...

python内置模块pathlib.Path类操作目录和文件
python自带的pathlib模块提供了很多路径相关的功能,而pathlib.Path 是pathlib 模块中的一个核心类,它代表了文件系统中的一个路径,实现功能比如创建、删除、移动文件,读取和写入文件内容,遍历目录等。 Path 类跟os.pa…...

react开发环境搭建
文章目录 准备工作创建 React 项目使用 create-react-app 创建 React 项目使用 Vite 创建 React 项目启动项目效果安装出现的情况 react项目文件讲解1. 项目根目录2. 其他可能的目录和文件3. 配置文件 准备工作 Node.js 安装方法: 方式一:使用 NVM 安装…...
python 逻辑语句简记
什么语言都少不了逻辑处理语句的使用,python的逻辑处理语句有自身的使用特点,稍稍总结记录一下 一、断言 assert 条件 条件触发,程序执行中断 二、条件语句 if 条件: 执行内容 三、循环语句 while 条件: 循环体…...

8.进销存系统(基于springboot的进销存系统)
目录 1.系统的受众说明 2.开发技术与环境配置 2.1 SpringBoot框架 2.2 Java语言简介 2.3 MySQL环境配置 2.4 idea介绍 2.5 mysql数据库介绍 2.6 B/S架构 3.系统分析与设计 3.1 可行性分析 3.1.1 技术可行性 3.1.2 操作可行性 3.1.3经济可行性 3.4.1 数据库…...
深入理解主键回显:提升数据操作效率与准确性
在软件开发的世界中,主键回显是一个常常被提及但又容易被忽视其重要性的概念。今天,我们就来深入探讨一下主键回显的奥秘。 一、什么是主键回显? 在数据库设计中,主键是用于唯一标识表中每一行记录的字段。而主键回显࿰…...

springboot+阿里云物联网教程
需求背景 最近有一个项目,需要用到阿里云物联网,不是MQ。发现使用原来EMQX的代码去连接阿里云MQTT直接报错,试了很多种方案都不行。最终还是把错误分析和教程都整理一下。 需要注意的是,阿里云物联网平台和MQ不一样。方向别走偏了。 概念描述 EMQX和阿里云MQTT有什么区别…...

QT Creator cmake 自定义项目结构, 编译输出目录指定
1. 目的 将不同的源文件放到不同的目录下进行管理, 如下: build: 编译输出目录 include: 头文件目录 rsources: 资源文件目录 src: cpp文件目录 2. 创建完cmake工程后修改CMakeLists.txt 配置 注 : 这里头文件目录是include, 所以在includ…...
lunar无第三方依赖的公历、农历、法定节假日...日历工具库
文章目录 介绍maven示例示例(前后端)网址文档 介绍 lunar是一款无第三方依赖的公历(阳历)、农历(阴历、老黄历)、道历、佛历工具,支持星座、儒略日、干支、生肖、节气、节日、彭祖百忌、吉神(喜神/福神/财神/阳贵神/阴贵神)方位、胎神方位、…...
conda相比python好处
Conda 作为 Python 的环境和包管理工具,相比原生 Python 生态(如 pip 虚拟环境)有许多独特优势,尤其在多项目管理、依赖处理和跨平台兼容性等方面表现更优。以下是 Conda 的核心好处: 一、一站式环境管理:…...
React hook之useRef
React useRef 详解 useRef 是 React 提供的一个 Hook,用于在函数组件中创建可变的引用对象。它在 React 开发中有多种重要用途,下面我将全面详细地介绍它的特性和用法。 基本概念 1. 创建 ref const refContainer useRef(initialValue);initialValu…...

MongoDB学习和应用(高效的非关系型数据库)
一丶 MongoDB简介 对于社交类软件的功能,我们需要对它的功能特点进行分析: 数据量会随着用户数增大而增大读多写少价值较低非好友看不到其动态信息地理位置的查询… 针对以上特点进行分析各大存储工具: mysql:关系型数据库&am…...

汽车生产虚拟实训中的技能提升与生产优化
在制造业蓬勃发展的大背景下,虚拟教学实训宛如一颗璀璨的新星,正发挥着不可或缺且日益凸显的关键作用,源源不断地为企业的稳健前行与创新发展注入磅礴强大的动力。就以汽车制造企业这一极具代表性的行业主体为例,汽车生产线上各类…...
Axios请求超时重发机制
Axios 超时重新请求实现方案 在 Axios 中实现超时重新请求可以通过以下几种方式: 1. 使用拦截器实现自动重试 import axios from axios;// 创建axios实例 const instance axios.create();// 设置超时时间 instance.defaults.timeout 5000;// 最大重试次数 cons…...
uniapp中使用aixos 报错
问题: 在uniapp中使用aixos,运行后报如下错误: AxiosError: There is no suitable adapter to dispatch the request since : - adapter xhr is not supported by the environment - adapter http is not available in the build 解决方案&…...
06 Deep learning神经网络编程基础 激活函数 --吴恩达
深度学习激活函数详解 一、核心作用 引入非线性:使神经网络可学习复杂模式控制输出范围:如Sigmoid将输出限制在(0,1)梯度传递:影响反向传播的稳定性二、常见类型及数学表达 Sigmoid σ ( x ) = 1 1 +...

OPENCV形态学基础之二腐蚀
一.腐蚀的原理 (图1) 数学表达式:dst(x,y) erode(src(x,y)) min(x,y)src(xx,yy) 腐蚀也是图像形态学的基本功能之一,腐蚀跟膨胀属于反向操作,膨胀是把图像图像变大,而腐蚀就是把图像变小。腐蚀后的图像变小变暗淡。 腐蚀…...

CVE-2020-17519源码分析与漏洞复现(Flink 任意文件读取)
漏洞概览 漏洞名称:Apache Flink REST API 任意文件读取漏洞CVE编号:CVE-2020-17519CVSS评分:7.5影响版本:Apache Flink 1.11.0、1.11.1、1.11.2修复版本:≥ 1.11.3 或 ≥ 1.12.0漏洞类型:路径遍历&#x…...
SpringAI实战:ChatModel智能对话全解
一、引言:Spring AI 与 Chat Model 的核心价值 🚀 在 Java 生态中集成大模型能力,Spring AI 提供了高效的解决方案 🤖。其中 Chat Model 作为核心交互组件,通过标准化接口简化了与大语言模型(LLM࿰…...