Pytorch 第九回:卷积神经网络——ResNet模型
Pytorch 第九回:卷积神经网络——ResNet模型
本次开启深度学习第九回,基于Pytorch的ResNet卷积神经网络模型。这是分享的第四个卷积神经网络模型。该模型是基于解决因网络加深而出现的梯度消失和网络退化而进行设计的。接下来给大家分享具体思路。
本次学习,借助的平台是PyCharm 2024.1.3,python版本3.11 numpy版本是1.26.4,pytorch版本2.0.0+cu118,d2l的版本是1.0.3
文章目录
- Pytorch 第九回:卷积神经网络——ResNet模型
- 前言
- 1、残差块
- 2、ResNet模型
- 一、数据准备
- 二、模型准备
- 1、残差块定义
- 2、ResNet模型定义
- 模型训练
- 1、实例化ResNet模型
- 2、迭代训练模型
- 3、输出展示
- 总结
前言
讲述模型前,先讲述两个概念,统一下思路:
1、残差块
残差块是ResNet模型的基础架构,该架构允许输入特征跨层作用到块的输出端,从而加强了浅层特征的输出。其结构如下图所示:
如上图所示,这里的1*1卷积层,可以将输入特征直接作用到输出端,从而避免了浅层的梯度到达输出端时过小的问题。
2、ResNet模型
2015年,ResNet模型由微软研究院提出,并在ImageNet大规模视觉挑战赛中一举夺得了冠军。之前设计的神经网络,随着网络层的增多,并没有达到训练误差不断减少的预期,反而出现训练误差逐渐加大的现象,人们也称之为“网络退化”。ResNet模型通过加入了残差块的框架,使训练的深层神经网络更加有效。
闲言少叙,直接展示逻辑,先上引用:
import numpy as np
import torch
from torch import nn
from torchvision.datasets import CIFAR10
import time
from torch.utils.data import DataLoader
from d2l import torch as d2l
import torch.nn.functional as F
一、数据准备
如前几回一样,本次仍然采用CIFAR10数据集,因此不做重点解释(有兴趣的可以查看第六回内容),本回只展示代码:
def data_treating(x):x = x.resize((96, 96), 2) #x = np.array(x, dtype='float32') / 255x = (x - 0.5) / 0.5 #x = x.transpose((2, 0, 1)) #x = torch.from_numpy(x)return xtrain_set = CIFAR10('./data', train=True, transform=data_treating)
train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
test_set = CIFAR10('./data', train=False, transform=data_treating)
test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)
注:
在本回设计的模型中,数据最小为96 * 96。有兴趣的可以参考第八回小记自己进行计算。
二、模型准备
1、残差块定义
残差块的定义分为两部分,一部分是初始化函数,一部分是前向传播函数。需要注意的是add_conv1_1这个参数,用于控制是否有输入特征直接作用于输出端。代码如下所示:
class residual(nn.Module):def __init__(self, channel_in, channel_out, add_conv1_1=False, stride=1):super(residual, self).__init__()self.add_conv1_1 = add_conv1_1self.conv1 = nn.Conv2d(channel_in, channel_out, 3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(channel_out)self.conv2 = nn.Conv2d(channel_out, channel_out, 3, stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(channel_out)if self.add_conv1_1:self.conv3 = nn.Conv2d(channel_in, channel_out, 1, stride=stride)else:self.conv3 = Nonedef forward(self, x):y = self.conv1(x)y = F.relu(self.bn1(y), True)y = self.conv2(y)y = F.relu(self.bn2(y), True)if self.add_conv1_1:x = self.conv3(x)y = y + xreturn F.relu(y, True)
如代所示,码残差块中定义了两个33的卷积,每个卷积层后面接了一个规范化层和一个Relu激活函数(体现在传播函数中)。self.conv3中定义了一个11的卷积层,当add_conv1_1=True时,输入会经过self.conv3卷积层直接反馈到输出端(y = y + x)。
2、ResNet模型定义
本回中的ResNet模型,定义了五个网络块。第一个网络块单独定义,后四个网络块结构相似(有兴趣的可以建立一个标准模块,方便设计深层网络)。
class resnet(nn.Module):def __init__(self, in_channel, num_classes):super(resnet, self).__init__()self.block1 = nn.Sequential(nn.Conv2d(in_channel, 64, 7, 2),nn.BatchNorm2d(64, eps=1e-3),nn.ReLU(True),nn.MaxPool2d(3, 2, 1))self.block2 = nn.Sequential(residual(64, 64),residual(64, 64))self.block3 = nn.Sequential(residual(64, 128, True, stride=2),residual(128, 128))self.block4 = nn.Sequential(residual(128, 256, True, stride=2),residual(256, 256))self.block5 = nn.Sequential(residual(256, 512, True, stride=2),residual(512, 512),nn.AvgPool2d(3),nn.Flatten())self.classifier = nn.Linear(512, num_classes)def forward(self, x):x = self.block1(x)x = self.block2(x)x = self.block3(x)x = self.block4(x)x = self.block5(x)x = self.classifier(x)return x
注:
由于本回采用d2l.train_ch6()进行数据训练,里面集成了损失函数和优化器,因此不需要单独定义(在第八回小记中介绍了如何安装d2l库)。
模型训练
1、实例化ResNet模型
这里输入为3个通道,因为彩色图片有三个数据通道。输出为10,因为数据集有10个类别(数据集的介绍,在第六回中)。
classify_ResNet = resnet(3, 10)
2、迭代训练模型
本次训练采用d2l.train_ch6()函数,其参数有六个:第一个是模型,第二个是训练集,第三个是测试集,第四个是迭代次数(设定为20次),第五个是学习率(设定为0.01),第六个是进行训练的设备(设定为GPU训练)。
d2l.train_ch6(classify_ResNet, train_data, test_data, 20, 0.01, d2l.try_gpu())
3、输出展示
输出含有训练集精度、测试集精度和消耗的时间等(这里我进对源码进行了修改,可以查看第八回小记)。
epoch0, loss 1.472, train acc 0.466, test acc 0.526,consume time 228.0
epoch4, loss 0.424, train acc 0.857, test acc 0.678,consume time 1139.3
epoch8, loss 0.047, train acc 0.989, test acc 0.680,consume time 2050.4
epoch12, loss 0.006, train acc 0.999, test acc 0.741,consume time 2961.4
epoch16, loss 0.002, train acc 1.000, test acc 0.746,consume time 3872.2
epoch19, loss 0.002, train acc 1.000, test acc 0.744,consume time 4555.5
对比前三回,ResNet在训练集的精度上又有所提高。
总结
1、数据准备:准备CIFAR10数据集
2、模型准备:准备残差块,ResNEt模型
3、数据训练:实例化训练模型,采用train_ch6函数进行迭代训练。
相关文章:

Pytorch 第九回:卷积神经网络——ResNet模型
Pytorch 第九回:卷积神经网络——ResNet模型 本次开启深度学习第九回,基于Pytorch的ResNet卷积神经网络模型。这是分享的第四个卷积神经网络模型。该模型是基于解决因网络加深而出现的梯度消失和网络退化而进行设计的。接下来给大家分享具体思路。 本次…...

2025-3-9 一周总结
目前来看本学期上半程汇编语言,编译原理,数字电路和离散数学是相对重点的课程. 在汇编语言和编译原理这块,个人感觉黑书内知识点更多,细节更到位,体系更完整,可以在老师讲解之前进行预习 应当及时复习每天的内容.第一是看书,然后听课,在一天结束后保证自己的知识梳理完整,没有…...

如何在el-input搜索框组件的最后面,添加图标按钮?
1、问题描述 2、解决步骤 在el-input组件标签内,添加一个element-plus的自定义插槽, 在插槽里放一个图标按钮即可。 3、效果展示 结语 以上就是在搜索框组件的末尾添加搜索按钮的过程。 喜欢本篇文章的话,请关注本博主~~...

[项目]基于FreeRTOS的STM32四轴飞行器: 六.2.4g通信
基于FreeRTOS的STM32四轴飞行器: 六.2.4g通信 一.Si24Ri原理图二.Si24R1芯片手册解读三.驱动函数讲解五.移植2.4g通讯(飞控部分)六.移植2.4g通讯(遥控部分) 一.Si24Ri原理图 Si24R1芯片原理图如下: 右侧为晶振。 模块…...

Python爬取咸鱼Goodfish店铺所有商品接口的详细指南
在电商数据分析和市场研究中,爬取咸鱼店铺内的所有商品信息是一项极具价值的任务。通过调用咸鱼的goodfish.item_search_shop接口,可以获取指定店铺内的商品列表,包括商品标题、价格、图片链接、销量等详细信息。本文将详细介绍如何使用Pytho…...

【极光 Orbit•STC8A-8H】03. 小刀初试:点亮你的LED灯
【极光 Orbit•STC8H】03. 小刀初试:点亮你的 LED 灯 七律 点灯初探 单片方寸藏乾坤,LED明灭见真章。 端口配置定方向,寄存器值细推敲。 高低电平随心控,循环闪烁展锋芒。 嵌入式门初开启,从此代码手中扬。 摘要 …...

docker本地部署RagFlow
1.安装 克隆仓库 git clone https://github.com/infiniflow/ragflow.git构建预建的Docker映像并启动服务器 cd ragflow/docker chmod x ./entrypoint.sh docker compose -f docker-compose.yml -p ragflow up -d修改ragflow/docker/.env文件 #RAGFLOW_IMAGEinfiniflow/ragfl…...

STM32F4 UDP组播通信:填一填ST官方HAL库的坑
先说写作本文的原因,由于开项目开发中需要用到UDP组播接收的功能,但是ST官方没有提供合适的参考,使用STM32CubeMX生成的代码也是不能直接使用的,而我在网上找了一大圈,也没有一个能够直接解决的方案,deepse…...

基于python大数据的招聘数据可视化与推荐系统
博主介绍:资深开发工程师,从事互联网行业多年,熟悉各种主流语言,精通java、python、php、爬虫、web开发,已经做了多年的设计程序开发,开发过上千套设计程序,没有什么华丽的语言,只有…...

10. 【.NET 8 实战--孢子记账--从单体到微服务--转向微服务】--微服务基础工具与技术--Ocelot 网关--认证
在微服务架构中,通过在网关层实现身份认证、权限校验和数据加密,可以有效防范恶意攻击和非法访问,保障内部服务安全。采用JWT、OAuth等主流认证机制,使每次请求均经过严格验证,降低安全漏洞风险。同时,统一…...

DeepSeek 3FS:端到端无缓存的存储新范式
在 2025 年 2 月 28 日,DeepSeek 正式开源了其高性能分布式文件系统 3FS【1】,作为其开源周的压轴项目,3FS 一经发布便引发了技术圈的热烈讨论。它不仅继承了分布式存储的经典设计,还通过极简却高效的架构,展现了存储技…...

vue3组合式API怎么获取全局变量globalProperties
设置全局变量 main.ts app.config.globalProperties.$category { index: 0 } 获取全局变量 const { appContext } getCurrentInstance() as ComponentInternalInstance console.log(appContext.config.globalProperties.$category) 或是 const { proxy } getCurrentInstance…...

【YOLOv12改进trick】多尺度大核注意力机制MLKA模块引入YOLOv12,实现多尺度目标检测涨点,含创新点Python代码,方便发论文
🍋改进模块🍋:多尺度大核注意力机制(MLKA) 🍋解决问题🍋:MLKA模块结合多尺度、门控机制和空间注意力,显著增强卷积网络的模型表示能力。 🍋改进优势🍋:超分辨的MLKA模块对小目标和模糊目标涨点很明显 🍋适用场景🍋:小目标检测、模糊目标检测等 🍋思路…...

网络安全之端口扫描(一)
前置介绍 什么是DVWA? DVWA(Damn Vulnerable Web Application)是一个专门设计用于测试和提高Web应用程序安全技能的开源PHP/MySQL Web应用程序。它是一个具有多个安全漏洞的故意不安全的应用程序,供安全专业人员、渗透测试人员、…...

HCIE云计算学什么?怎么学?未来职业发展如何?
随着云计算成为IT行业发展的主流方向,HCIE云计算(华为认证云计算专家)作为华为认证体系中的高端认证之一,逐渐成为了许多网络工程师和IT从业者提升职业竞争力的重要途径。 那么,HCIE云计算究竟学什么内容,如…...

upload-labs文件上传
第一关 上传一个1.jpg的文件,在里面写好一句webshell 保留一个数据包,将其中截获的1.jpg改为1.php后重新发送 可以看到,已经成功上传 第二关 写一个webshell如图,为2.php 第二关在过滤tpye的属性,在上传2.php后使用b…...

操作系统控制台-健康守护我们的系统
引言基本准备体验功能健康守护系统诊断 收获提升结语 引言 阿里云操作系统控制平台作为新一代云端服务器中枢平台,通过创新交互模式重构主机管理体验。操作系统控制台提供了一系列管理功能,包括运维监控、智能助手、扩展插件管理以及订阅服务等。用户可以…...

财务会计域——合并报表系统设计
摘要 本文主要介绍了合并报表系统的设计,包括其背景、业务流程和系统架构设计。合并报表系统可自动化生成数据,减少人为错误,确保报表合规。其业务流程涵盖数据收集、标准化、合并调整、报表生成、审核及披露等环节。系统架构设计包括数据接…...

教务考试管理系统-Sprintboot vue
一、前言 1.1 实践目的和要求 本次实践的目的是为了帮助学生强化对实践涉及专业技术知识的理解,掌握专业领域中软件知识的应用方法,并了解软件工程在具体行业领域的发展趋势。通过培养学生利用软件工程方法分析、设计并完成具体行业软件开发的能力&…...

vue实现一个pdf在线预览,pdf选择文本并提取复制文字触发弹窗效果
[TOC] 一、文件预览 1、安装依赖包 这里安装了disjs-dist2.16版本,安装过程中报错缺少worker-loader npm i pdfjs-dist2.16.105 worker-loader3.0.8 2、模板部分 <template><div id"pdf-view"><canvas v-for"page in pdfPages&qu…...

【CSS 】Class Variance Authority CSS 类名管理工具库
1.背景、什么是 CVA? Class Variance Authority (CVA) 是一个用于管理 CSS 类名 的工具库,特别适合在 React 或 Vue 等前端框架中使用。它可以帮助你更轻松地处理组件的 样式变体(Variants),比如按钮的不同状态&#…...
自然语言处理:文本分类
介绍 大家好,我这个热衷于分享知识的博主又来啦!之前我们一起深入探讨了自然语言处理领域中非常重要的两个方法:朴素贝叶斯和逻辑斯谛回归。在探索的过程中,我们剖析了朴素贝叶斯如何基于概率原理和特征条件独立假设,…...

刷题记录 HOT100 贪心-2:45. 跳跃游戏 II
题目:45. 跳跃游戏 II 难度:中等 给定一个长度为 n 的 0 索引整数数组 nums。初始位置为 nums[0]。 每个元素 nums[i] 表示从索引 i 向后跳转的最大长度。换句话说,如果你在 nums[i] 处,你可以跳转到任意 nums[i j] 处: 0 &l…...

7.2 奇异值分解的基与矩阵
一、奇异值分解 奇异值分解(SVD)是线性代数的高光时刻。 A A A 是一个 m n m\times n mn 的矩阵,可以是方阵或者长方形矩阵,秩为 r r r。我们要对角化 A A A,但并不是把它化成 X − 1 A X X^{-1}A X X−1AX 的形…...

PDFMathTranslate安装使用
PDF全文翻译!!!! PDFMathTranslate安装使用 它是个啥 PDFMathTranslate 可能是一个用于 PDF 文件的数学公式翻译 工具。它可能包含以下功能: 提取 PDF 内的数学公式 将数学公式转换成 LaTeX 代码 翻译数学公式的内…...

STL之list的使用(超详解)
目录 一、list的介绍及使用 1.1 list的介绍 1.2 list的使用 1.2.1 list的构造 1.2.2 iterator的使用 1.2.3capacity(容量相关) 1.2.4 element access(元素访问) 1.2.5 modifiers(链表修改)…...

动态 SQL 的使用
目录 1、< if> 标签2、< trim> 标签3、< where> 标签4、< set> 标签5、< foreach> 标签 1、< if> 标签 < if test“条件语句”> xxxx < /if> 只有当条件语句满足条件,才会拼接 < if> 标签内容,因…...

【如何删除在 Linux 系统中的删除乱码文件】
如何删除在 Linux 系统中的删除乱码文件 1. 列出文件并找到乱码文件:2. 使用通配符(谨慎使用):3. 转义特殊字符:4. 使用 find 命令:5. 使用 inode 号删除文件:6. 图形界面文件管理器:…...

防火墙IPSec (无固定IP地址---一对多)
目录 前言 一、场景: 二、实现 1.拓扑图 2.配置思路 ①基础通信配置 ②PPPoE配置 ③总部的模版IPSec配置 ④分部的IPSec配置 ⑤NAT配置 3.具体配置 ①基础配置 ②详细配置和顺序 效果测试: ③PPPOE ①配置PPPoE ②策略放行 ③IPSec与NA…...

基于SpringBoot的智能问诊系统设计与隐私保护策略
通过SpringBoot框架,我们可以快速搭建一个智能问诊系统,为用户提供便捷的线上医疗服务。然而,在系统设计和实现过程中,如何保障用户的隐私和数据安全,始终是一个亟需关注的问题。本文将探讨基于SpringBoot的智能问诊系…...