《深度学习》——ResNet网络
文章目录
- ResNet网络
- ResNet网络实例
- 导入所需库
- 下载训练数据和测试数据
- 设置每个批次的样本个数
- 判断是否使用GPU
- 定义残差模块
- 定义ResNet网络
- 模型导入GPU
- 定义训练函数
- 定义测试函数
- 创建损失函数和优化器
- 训练测试数据
- 结果
ResNet网络
ResNet(Residual Network,残差网络)是深度学习领域中非常重要且具有影响力的一种卷积神经网络(CNN)架构,由何恺明等人于 2015 年提出,在图像识别、目标检测等诸多计算机视觉任务中取得了巨大成功。
1. 产生背景:在深度学习发展过程中,随着网络深度的增加,会出现梯度消失或梯度爆炸的问题,导 致网络难以训练。即使通过归一化等方法解决了梯度问题,还会面临退化问题,即网络深度增加时,模型的训 练误差和测试误差反而增大。ResNet 的提出就是为了解决深度神经网络中的退化问题。


- ResNet-18:是 ResNet 家族中相对较浅的网络,由 4 个残差块组构成,每个残差块组包含不同数量的残差块。它的结构简单,计算量相对较小,适合计算资源有限或对模型复杂度要求不高的场景,如一些小型图像数据集的分类任务。它在一些对实时性要求较高的应用中,如移动设备上的图像识别,也有一定的应用。
- ResNet-34:同样由 4 个残差块组组成,但相比 ResNet-18,它在某些残差块组中包含更多的残差块,网络深度更深,因此能够学习到更复杂的特征表示。它在中等规模的图像数据集上表现良好,在一些对模型性能有一定要求但又不过分追求极致精度的任务中较为常用。
- ResNet-50:是一个比较常用的 ResNet 模型,在许多计算机视觉任务中都有广泛应用。它使用了瓶颈结构(Bottleneck)的残差块,这种结构通过先降维、再卷积、最后升维的方式,在减少计算量的同时保持了模型的表达能力。该模型在图像分类、目标检测、语义分割等任务中,都能作为性能不错的骨干网络,为后续的任务提供有效的特征提取。
- ResNet-101:比 ResNet-50 的网络层数更多,拥有更强大的特征提取能力。它适用于大规模图像数据集和复杂的计算机视觉任务,如在大型目标检测数据集中,能够更好地捕捉目标的细节特征,提升检测的准确性。由于其深度和复杂度,在处理高分辨率图像或需要精细特征表示的任务时表现出色。
- ResNet-152:是 ResNet 系列中深度较深的网络,具有极高的特征提取能力。但由于其深度很大,计算量和参数量也相应增加,训练和推理所需的时间和资源较多。它通常用于对精度要求极高的场景,如学术研究中的图像识别挑战、大规模图像搜索引擎的图像特征提取等。
18层残差网络:

ResNet网络实例
项目需求:对手写数字进行识别。
数据集:此项目数据集来自MNIST 数据集由美国国家标准与技术研究所(NIST)整理而成,包含手写数字的图像,主要用于数字识别的训练和测试。该数据集被分为两部分:训练集和测试集。训练集包含 60,000 张图像,用于模型的学习和训练;测试集包含 10,000 张图像,用于评估训练好的模型在未见过的数据上的性能。
图像格式:数据集中的图像是灰度图像,即每个像素只有一个值表示其亮度,取值范围通常为 0(黑色)到 255(白色)。
图像尺寸:每张图像的尺寸为 28x28 像素,总共有 784 个像素点。
标签信息:每个图像都有一个对应的标签,标签是 0 到 9 之间的整数,表示图像中手写数字的值。
导入所需库
import torch
from torch import nn # 导入神经网络模块
from torch.utils.data import DataLoader # 数据包管理工具,打包数据
from torchvision import datasets # 封装了很对与图像相关的模型,数据集
from torchvision.transforms import ToTensor # 数据转换,张量,将其他类型的数据转换成tensor张量
import torch.nn.functional as F # 用于应用 ReLU 激活函数
下载训练数据和测试数据
'''下载训练数据集(包含训练集图片+标签)'''
training_data = datasets.MNIST( # 跳转到函数的内部源代码,pycharm 按下ctrl+鼠标点击root='data', # 表示下载的手写数字 到哪个路径。60000train=True, # 读取下载后的数据中的数据集download=True, # 如果你之前已经下载过了,就不用再下载了transform=ToTensor(), # 张量,图片是不能直接传入神经网络模型# 对于pytorch库能够识别的数据一般是tensor张量
)'''下载测试数据集(包含训练图片+标签)'''
test_data = datasets.MNIST(root='data',train=False,download=True,transform=ToTensor(), # Tensor是在深度学习中提出并广泛应用的数据类型,它与深度学习框架(如pytorch,TensorFlow)
) # numpy数组只能在cpu上运行。Tensor可以在GPU上运行,这在深度学习应用中可以显著提高计算速度。
print(len(training_data))
print(len(test_data))
设置每个批次的样本个数
train_dataloader = DataLoader(training_data, batch_size=64) # 建议用2的指数当作一个包的数量
test_dataloader = DataLoader(test_data, batch_size=64)
判断是否使用GPU
'''判断是否支持GPU'''
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f'Using {device} device')
定义残差模块
# 定义残差块类,继承自 nn.Module
class ResBlock(nn.Module):def __init__(self, channels_in):# 调用父类的构造函数super().__init__()# 定义第一个卷积层,输入通道数为 channels_in,输出通道数为 30,卷积核大小为 5,填充为 2self.conv1 = torch.nn.Conv2d(channels_in, 30, 5, padding=2)# 定义第二个卷积层,输入通道数为 30,输出通道数为 channels_in,卷积核大小为 3,填充为 1self.conv2 = torch.nn.Conv2d(30, channels_in, 3, padding=1)def forward(self, x):# 输入数据通过第一个卷积层out = self.conv1(x)# 经过第一个卷积层的输出再通过第二个卷积层out = self.conv2(out)# 将输入 x 与卷积输出 out 相加,并通过 ReLU 激活函数return F.relu(out + x)
定义ResNet网络
# 定义 ResNet 网络类,继承自 nn.Module
class ResNet(nn.Module):def __init__(self):# 调用父类的构造函数super().__init__()# 定义第一个卷积层,输入通道数为 1,输出通道数为 20,卷积核大小为 5self.conv1 = torch.nn.Conv2d(1, 20, 5)# 定义第二个卷积层,输入通道数为 20,输出通道数为 15,卷积核大小为 3self.conv2 = torch.nn.Conv2d(20, 15, 3)# 定义最大池化层,池化核大小为 2self.maxpool = torch.nn.MaxPool2d(2)# 定义第一个残差块,输入通道数为 20self.resblock1 = ResBlock(channels_in=20)# 定义第二个残差块,输入通道数为 15self.resblock2 = ResBlock(channels_in=15)# 定义全连接层,输入特征数为 375,输出特征数为 10self.full_c = torch.nn.Linear(375, 10)def forward(self, x):# 获取输入数据的批次大小size = x.shape[0]# 输入数据通过第一个卷积层,然后进行最大池化,最后通过 ReLU 激活函数x = F.relu(self.maxpool(self.conv1(x)))# 经过第一个卷积和池化的输出通过第一个残差块x = self.resblock1(x)# 经过第一个残差块的输出通过第二个卷积层,然后进行最大池化,最后通过 ReLU 激活函数x = F.relu(self.maxpool(self.conv2(x)))# 经过第二个卷积和池化的输出通过第二个残差块x = self.resblock2(x)# 将输出数据展平为一维向量x = x.view(size, -1)# 展平后的向量通过全连接层x = self.full_c(x)return x
模型导入GPU
model = ResNet().to(device)
定义训练函数
# 定义训练函数
def train(dataloader, model, loss_fn, optimizer):# 将模型设置为训练模式,这会影响一些层(如 Dropout、BatchNorm 等)的行为model.train()# 初始化批次编号batch_size_num = 1# 遍历数据加载器中的每个批次for x, y in dataloader:# 将输入数据和标签移动到指定设备(如 GPU)x, y = x.to(device), y.to(device)# 前向传播,计算模型的预测结果pred = model.forward(x)# 通过交叉熵损失函数计算预测结果与真实标签之间的损失值loss = loss_fn(pred, y)# 反向传播步骤:# 清零优化器中的梯度信息,防止梯度累积optimizer.zero_grad()# 反向传播计算每个参数的梯度loss.backward()# 根据计算得到的梯度更新模型的参数optimizer.step()# 从张量中提取损失值的标量loss_value = loss.item()# 每 100 个批次打印一次损失值if batch_size_num % 100 == 0:print(f'loss:{loss_value:7f} [number:{batch_size_num}]')# 批次编号加 1batch_size_num += 1
定义测试函数
# 定义测试函数
def test(dataloader, model, loss_fn):# 获取数据集的总样本数size = len(dataloader.dataset)# 获取数据加载器中的批次数量num_batches = len(dataloader)# 将模型设置为评估模式,这会影响一些层(如 Dropout、BatchNorm 等)的行为model.eval()# 初始化测试损失和正确预测的样本数test_loss, correct = 0, 0# 上下文管理器,关闭梯度计算,减少内存消耗with torch.no_grad():# 遍历数据加载器中的每个批次for x, y in dataloader:# 将输入数据和标签移动到指定设备(如 GPU)x, y = x.to(device), y.to(device)# 前向传播,计算模型的预测结果pred = model.forward(x)# 累加每个批次的损失值test_loss += loss_fn(pred, y).item()# 计算每个批次中预测正确的样本数并累加correct += (pred.argmax(1) == y).type(torch.float).sum().item()# 计算平均测试损失test_loss /= num_batches# 计算平均准确率correct /= size# 打印测试结果print(f'Test result: \n Accuracy:{(100 * correct)}%,Avg loss:{test_loss}')
创建损失函数和优化器
# 创建交叉熵损失函数对象
loss_fn = nn.CrossEntropyLoss()
# 创建 Adam 优化器,用于更新模型的参数
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size=3,gamma=0.1)
训练测试数据
# 定义训练的轮数
epochs = 26
# 开始训练循环
for t in range(epochs):print(f'epoch{t + 1}\n--------------------')# 调用训练函数进行一轮训练train(train_dataloader, model, loss_fn, optimizer)
print('Done!')
# 调用测试函数进行测试
test(test_dataloader, model, loss_fn)
结果

相关文章:
《深度学习》——ResNet网络
文章目录 ResNet网络ResNet网络实例导入所需库下载训练数据和测试数据设置每个批次的样本个数判断是否使用GPU定义残差模块定义ResNet网络模型导入GPU定义训练函数定义测试函数创建损失函数和优化器训练测试数据结果 ResNet网络 ResNet(Residual Network࿰…...
【Windows软件 - HeidiSQL】导出数据库
HeidSQL导出数据库 软件信息 具体操作 示例文件 选项分析 选项(1) 结果(1) -- -------------------------------------------------------- -- 主机: 127.0.0.1 -- 服务器版本: …...
FFmpeg 全面知识大纲梳理
1. FFmpeg 简介 FFmpeg 是什么: 一个开源的多媒体处理框架,用于处理音频、视频和流媒体。支持多种格式和编解码器。提供命令行工具和库(如 libavcodec, libavformat, libavfilter 等)。主要功能: 格式转换编解码流媒体处理音视频剪辑、合并、分离添加滤镜、特效压缩与优化…...
【达梦数据库】dblink连接[SqlServer/Mysql]报错处理
目录 背景问题1:无法测试以ODBC数据源方式访问的外部链接!问题分析&原因解决方法 问题2:DBLINK连接丢失问题分析&原因解决方法 问题3:DBIINK远程服务器获取对象[xxx]失败,错误洋情[[FreeTDS][SQL Server]Could not find stored proce…...
基于 Spring Boot 的社区居民健康管理系统部署说明书
目录 1 系统概述 2 准备资料 3 系统安装与部署 3.1 数据库部署 3.1.1 MySQL 的部署 3.1.2 Navicat 的部署 3.2 服务器部署 3.3 客户端部署 4 系统配置与优化 5 其他 基于 Spring Boot 的社区居民健康管理系统部署说明书 1 系统概述 本系统主要运用了 Spri…...
量化噪声介绍
量化噪声是在将模拟信号转换为数字信号的量化过程中产生的噪声。以下为你详细介绍: 1. 量化的基本概念 在模拟信号数字化过程中,采样是对模拟信号在时间上进行离散化,而量化则是对采样值在幅度上进行离散化。由于模拟信号的取值是连续的&am…...
java断点调试(debug)
在开发中,新手程序员在查找错误时, 这时老程序员就会温馨提示,可以用断点调试,一步一步的看源码执行的过程,从而发现错误所在。 重要提示: 断点调试过程是运行状态,是以对象的运行类型来执行的 断点调试介绍 断点调试是…...
最新智能优化算法:牛优化( Ox Optimizer,OX)算法求解经典23个函数测试集,MATLAB代码
一、牛优化算法 牛优化( OX Optimizer,OX)算法由 AhmadK.AlHwaitat 与 andHussamN.Fakhouri于2024年提出,该算法的设计灵感来源于公牛的行为特性。公牛以其巨大的力量而闻名,能够承载沉重的负担并进行远距离运输。这种…...
Redis7——基础篇(四)
前言:此篇文章系本人学习过程中记录下来的笔记,里面难免会有不少欠缺的地方,诚心期待大家多多给予指教。 基础篇: Redis(一)Redis(二)Redis(三) 接上期内容&…...
Git备忘录(三)
设置用户信息: git config --global user.name “itcast” git config --global user.email “ helloitcast.cn” 查看配置信息 git config --global user.name git config --global user.email $ git init $ git remote add origin gitgitee.com:XXX/avas.git $ git pull or…...
MySQL 之INDEX 索引(Index Index of MySQL)
MySQL 之INDEX 索引 1.4 INDEX 索引 1.4.1 索引介绍 索引:是排序的快速查找的特殊数据结构,定义作为查找条件的字段上,又称为键 key,索引通过存储引擎实现。 优点 大大加快数据的检索速度; 创建唯一性索引,保证数…...
Linux基础24-C语言之分支结构Ⅰ【入门级】
分支结构 问题抛出 我们在程序设计中往往会遇到如下问题,比如下面的函数计算: 也就是我们必须要通过一个条件的结果来选择下一步的操作,算法上属于一个分支结构,处于严重实现分支结构主要使用if语句。 条件判断 根据某个条件成…...
LeetCode47
LeetCode47 目录 题目描述示例思路分析代码段代码逐行讲解复杂度分析总结的知识点整合总结 题目描述 给定一个可包含重复数字的整数数组 nums,按任意顺序返回所有不重复的全排列。 示例 示例 1 输入: nums [1, 1, 2]输出: [[1, 1, 2],[1, 2, 1],[2, 1, 1] ]…...
C++中std::condition_variable_any、std::lock_guard 和 std::unique_
1、背景 在 C 多线程编程中,同步 和 互斥 是至关重要的概念。C 标准库提供了多种同步机制,其中 std::condition_variable_any、std::lock_guard 和 std::unique_lock 是经常被用到的工具。本文将详细介绍这三者的用途、区别、适用场景,并通过…...
详解AbstractQueuedSynchronizer(AQS)源码
引言 上篇文章讲解了CountDownLatch源码,底层是继承了AQS基类调用父类和重写父类方法实现的,本文将简介AQS源码和架构设计,帮助我们更深入理解多线程实战。 源码架构 1. 状态变量 state AQS 使用一个 int 类型的变量 state 来表示同步状态…...
【Unity动画】导入动画资源到项目中,Animator播放角色动画片段,角色会跟随着动画播放移动。
导入动画资源到项目中,Animator播放角色动画片段,角色会跟随着动画播放移动,但我只想要角色在原地播放动画。比如:播放一个角色Run动画,希望角色在原地奔跑,而不是产生了移动距离。 问题排查: 1.是否勾选…...
图解循环神经网络(RNN)
目录 1.循环神经网络介绍 2.网络结构 3.结构分类 4.模型工作原理 5.模型工作示例 6.总结 1.循环神经网络介绍 RNN(Recurrent Neural Network,循环神经网络)是一种专门用于处理序列数据的神经网络结构。与传统的神经网络不同,…...
【数据结构】(9) 优先级队列(堆)
一、优先级队列 优先级队列不同于队列,队列是先进先出,优先级队列是优先级最高的先出。一般有两种操作:返回最高优先级对象,添加一个新对象。 二、堆 2.1、什么是堆 堆也是一种数据结构,是一棵完全二叉树,…...
4、IP查找工具-Angry IP Scanner
在前序文章中,提到了多种IP查找方法,可能回存在不同场景需要使用不同的查找命令,有些不容易记忆,本文将介绍一个比较优秀的IP查找工具,可以应用在连接树莓派或查找IP的其他场景中。供大家参考。 Angry IP Scanner下载…...
【Linux】命令操作、打jar包、项目部署
阿华代码,不是逆风,就是我疯 你们的点赞收藏是我前进最大的动力!! 希望本文内容能够帮助到你!! 目录 一:Xshell下载 1:镜像设置 二:阿里云设置镜像Ubuntu 三…...
瑞萨RA-T系列芯片ADCGPT功能模块的配合使用
在马达或电源工程中,往往需要采集多路AD信号,且这些信号的优先级和采样时机不相同。本篇介绍在使用RA-T系列芯片建立马达或电源工程时,如何根据需求来设置主要功能模块ADC&GPT,包括采样通道打包和分组,GPT触发启动…...
python爬虫系列课程1:初识爬虫
python爬虫系列课程1:初识爬虫 一、爬虫的概念二、通用爬虫和自定义爬虫的区别三、开发语言四、爬虫流程一、爬虫的概念 网络爬虫(又被称为网页蜘蛛、网络机器人)就是模拟浏览器发送网络请求,接收请求响应,一种按照一定的规则,自动抓取互联网信息的程序。原则上,只要是…...
【笔记】Huggingface Transformers 库加载预训练模型的 4 种方式
Transformers 库加载预训练模型的 4 种方式 Hugging Face Transformers 库提供了 4 种核心代码范式用于加载预训练大语言模型(LLM),具体分类如下: 通用模型加载(无任务头) 使用 AutoModel 加载基础架构&a…...
Unity Shader学习6:多盏平行光+点光源 ( 逐像素 ) 前向渲染 (Built-In)
0 、分析 在前向渲染中,对于逐像素光源来说,①ForwardBase中只计算一个平行光,其他的光都是在FowardAdd中计算的,所以为了能够渲染出其他的光照,需要在第二个Pass中再来一遍光照计算。 而有所区别的操作是࿰…...
tailwindcss学习01
系列教程 01 入门 02 vue中接入 入门 # 注意使用cmd不要powershell npm init -y # 如果没有npx则安装 npm install -g npx npm install -D tailwindcss3.4.17 --registry http://registry.npm.taobao.org npx tailwindcss init修改tailwind.config.js /** type {import(tai…...
DIN:引入注意力机制的深度学习推荐系统,
实验和完整代码 完整代码实现和jupyter运行:https://github.com/Myolive-Lin/RecSys--deep-learning-recommendation-system/tree/main 引言 在电商与广告推荐场景中,用户兴趣的多样性和动态变化是核心挑战。传统推荐模型(如Embedding &…...
【前端】如何安装配置WebStorm软件?
文章目录 前言一、前端开发工具WebStorm和VS Code对比二、官网下载三、安装1、开始安装2、选择安装路径3、安装选项4、选择开始菜单文件夹5、安装成功 四、启动WebStorm五、登录授权六、开始使用 前言 WebStorm 是一款由 JetBrains 公司开发的专业集成开发环境(IDE…...
【Golang学习之旅】Go 语言微服务架构实践(gRPC、Kafka、Docker、K8s)
文章目录 1. 前言:为什么选择Go语言构建微服务架构1.1 微服务架构的兴趣与挑战1.2 为什么选择Go语言构建微服务架构 2. Go语言简介2.1 Go 语言的特点与应用2.2 Go 语言的生态系统 3. 微服务架构中的 gRPC 实践3.1 什么是 gRPC?3.2 gRPC 在 Go 语言中的实…...
Spring核心思想之—AOP(面向切面编程)
目录 一 .AOP概述 二. Spring AOP 使用 2.1 引入AOP依赖 2.2 编写AOP程序 三. Spring AOP详情 3.1 切点(Pointcut) 3.2 连接点(Join Point) 3.3通知(Advice) 3.4切面(Aspect) 3.5通知 3.6 PointCut (公共切点)…...
使用 Openpyxl 操作 Excel 文件详解
文章目录 安装安装Python3安装 openpyxl 基础操作1. 引入2. 创建工作簿和工作表3. 写入数据4. 保存工作簿5. 加载已存在的Excel6. 读取单元格的值7. 选择工作表 样式和格式化1. 引入2. 设置字体3. 设置边框4. 填充5. 设置数字格式6. 数据验证7. 公式操作 性能优化1. read_only/…...
