Pytorch 第九回:卷积神经网络——ResNet模型
Pytorch 第九回:卷积神经网络——ResNet模型
本次开启深度学习第九回,基于Pytorch的ResNet卷积神经网络模型。这是分享的第四个卷积神经网络模型。该模型是基于解决因网络加深而出现的梯度消失和网络退化而进行设计的。接下来给大家分享具体思路。
本次学习,借助的平台是PyCharm 2024.1.3,python版本3.11 numpy版本是1.26.4,pytorch版本2.0.0+cu118,d2l的版本是1.0.3
文章目录
- Pytorch 第九回:卷积神经网络——ResNet模型
- 前言
- 1、残差块
- 2、ResNet模型
- 一、数据准备
- 二、模型准备
- 1、残差块定义
- 2、ResNet模型定义
- 模型训练
- 1、实例化ResNet模型
- 2、迭代训练模型
- 3、输出展示
- 总结
前言
讲述模型前,先讲述两个概念,统一下思路:
1、残差块
残差块是ResNet模型的基础架构,该架构允许输入特征跨层作用到块的输出端,从而加强了浅层特征的输出。其结构如下图所示:

如上图所示,这里的1*1卷积层,可以将输入特征直接作用到输出端,从而避免了浅层的梯度到达输出端时过小的问题。
2、ResNet模型
2015年,ResNet模型由微软研究院提出,并在ImageNet大规模视觉挑战赛中一举夺得了冠军。之前设计的神经网络,随着网络层的增多,并没有达到训练误差不断减少的预期,反而出现训练误差逐渐加大的现象,人们也称之为“网络退化”。ResNet模型通过加入了残差块的框架,使训练的深层神经网络更加有效。
闲言少叙,直接展示逻辑,先上引用:
import numpy as np
import torch
from torch import nn
from torchvision.datasets import CIFAR10
import time
from torch.utils.data import DataLoader
from d2l import torch as d2l
import torch.nn.functional as F
一、数据准备
如前几回一样,本次仍然采用CIFAR10数据集,因此不做重点解释(有兴趣的可以查看第六回内容),本回只展示代码:
def data_treating(x):x = x.resize((96, 96), 2) #x = np.array(x, dtype='float32') / 255x = (x - 0.5) / 0.5 #x = x.transpose((2, 0, 1)) #x = torch.from_numpy(x)return xtrain_set = CIFAR10('./data', train=True, transform=data_treating)
train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
test_set = CIFAR10('./data', train=False, transform=data_treating)
test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)
注:
在本回设计的模型中,数据最小为96 * 96。有兴趣的可以参考第八回小记自己进行计算。
二、模型准备
1、残差块定义
残差块的定义分为两部分,一部分是初始化函数,一部分是前向传播函数。需要注意的是add_conv1_1这个参数,用于控制是否有输入特征直接作用于输出端。代码如下所示:
class residual(nn.Module):def __init__(self, channel_in, channel_out, add_conv1_1=False, stride=1):super(residual, self).__init__()self.add_conv1_1 = add_conv1_1self.conv1 = nn.Conv2d(channel_in, channel_out, 3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(channel_out)self.conv2 = nn.Conv2d(channel_out, channel_out, 3, stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(channel_out)if self.add_conv1_1:self.conv3 = nn.Conv2d(channel_in, channel_out, 1, stride=stride)else:self.conv3 = Nonedef forward(self, x):y = self.conv1(x)y = F.relu(self.bn1(y), True)y = self.conv2(y)y = F.relu(self.bn2(y), True)if self.add_conv1_1:x = self.conv3(x)y = y + xreturn F.relu(y, True)
如代所示,码残差块中定义了两个33的卷积,每个卷积层后面接了一个规范化层和一个Relu激活函数(体现在传播函数中)。self.conv3中定义了一个11的卷积层,当add_conv1_1=True时,输入会经过self.conv3卷积层直接反馈到输出端(y = y + x)。
2、ResNet模型定义
本回中的ResNet模型,定义了五个网络块。第一个网络块单独定义,后四个网络块结构相似(有兴趣的可以建立一个标准模块,方便设计深层网络)。
class resnet(nn.Module):def __init__(self, in_channel, num_classes):super(resnet, self).__init__()self.block1 = nn.Sequential(nn.Conv2d(in_channel, 64, 7, 2),nn.BatchNorm2d(64, eps=1e-3),nn.ReLU(True),nn.MaxPool2d(3, 2, 1))self.block2 = nn.Sequential(residual(64, 64),residual(64, 64))self.block3 = nn.Sequential(residual(64, 128, True, stride=2),residual(128, 128))self.block4 = nn.Sequential(residual(128, 256, True, stride=2),residual(256, 256))self.block5 = nn.Sequential(residual(256, 512, True, stride=2),residual(512, 512),nn.AvgPool2d(3),nn.Flatten())self.classifier = nn.Linear(512, num_classes)def forward(self, x):x = self.block1(x)x = self.block2(x)x = self.block3(x)x = self.block4(x)x = self.block5(x)x = self.classifier(x)return x
注:
由于本回采用d2l.train_ch6()进行数据训练,里面集成了损失函数和优化器,因此不需要单独定义(在第八回小记中介绍了如何安装d2l库)。
模型训练
1、实例化ResNet模型
这里输入为3个通道,因为彩色图片有三个数据通道。输出为10,因为数据集有10个类别(数据集的介绍,在第六回中)。
classify_ResNet = resnet(3, 10)
2、迭代训练模型
本次训练采用d2l.train_ch6()函数,其参数有六个:第一个是模型,第二个是训练集,第三个是测试集,第四个是迭代次数(设定为20次),第五个是学习率(设定为0.01),第六个是进行训练的设备(设定为GPU训练)。
d2l.train_ch6(classify_ResNet, train_data, test_data, 20, 0.01, d2l.try_gpu())
3、输出展示
输出含有训练集精度、测试集精度和消耗的时间等(这里我进对源码进行了修改,可以查看第八回小记)。
epoch0, loss 1.472, train acc 0.466, test acc 0.526,consume time 228.0
epoch4, loss 0.424, train acc 0.857, test acc 0.678,consume time 1139.3
epoch8, loss 0.047, train acc 0.989, test acc 0.680,consume time 2050.4
epoch12, loss 0.006, train acc 0.999, test acc 0.741,consume time 2961.4
epoch16, loss 0.002, train acc 1.000, test acc 0.746,consume time 3872.2
epoch19, loss 0.002, train acc 1.000, test acc 0.744,consume time 4555.5
对比前三回,ResNet在训练集的精度上又有所提高。
总结
1、数据准备:准备CIFAR10数据集
2、模型准备:准备残差块,ResNEt模型
3、数据训练:实例化训练模型,采用train_ch6函数进行迭代训练。
相关文章:
Pytorch 第九回:卷积神经网络——ResNet模型
Pytorch 第九回:卷积神经网络——ResNet模型 本次开启深度学习第九回,基于Pytorch的ResNet卷积神经网络模型。这是分享的第四个卷积神经网络模型。该模型是基于解决因网络加深而出现的梯度消失和网络退化而进行设计的。接下来给大家分享具体思路。 本次…...
2025-3-9 一周总结
目前来看本学期上半程汇编语言,编译原理,数字电路和离散数学是相对重点的课程. 在汇编语言和编译原理这块,个人感觉黑书内知识点更多,细节更到位,体系更完整,可以在老师讲解之前进行预习 应当及时复习每天的内容.第一是看书,然后听课,在一天结束后保证自己的知识梳理完整,没有…...
如何在el-input搜索框组件的最后面,添加图标按钮?
1、问题描述 2、解决步骤 在el-input组件标签内,添加一个element-plus的自定义插槽, 在插槽里放一个图标按钮即可。 3、效果展示 结语 以上就是在搜索框组件的末尾添加搜索按钮的过程。 喜欢本篇文章的话,请关注本博主~~...
[项目]基于FreeRTOS的STM32四轴飞行器: 六.2.4g通信
基于FreeRTOS的STM32四轴飞行器: 六.2.4g通信 一.Si24Ri原理图二.Si24R1芯片手册解读三.驱动函数讲解五.移植2.4g通讯(飞控部分)六.移植2.4g通讯(遥控部分) 一.Si24Ri原理图 Si24R1芯片原理图如下: 右侧为晶振。 模块…...
Python爬取咸鱼Goodfish店铺所有商品接口的详细指南
在电商数据分析和市场研究中,爬取咸鱼店铺内的所有商品信息是一项极具价值的任务。通过调用咸鱼的goodfish.item_search_shop接口,可以获取指定店铺内的商品列表,包括商品标题、价格、图片链接、销量等详细信息。本文将详细介绍如何使用Pytho…...
【极光 Orbit•STC8A-8H】03. 小刀初试:点亮你的LED灯
【极光 Orbit•STC8H】03. 小刀初试:点亮你的 LED 灯 七律 点灯初探 单片方寸藏乾坤,LED明灭见真章。 端口配置定方向,寄存器值细推敲。 高低电平随心控,循环闪烁展锋芒。 嵌入式门初开启,从此代码手中扬。 摘要 …...
docker本地部署RagFlow
1.安装 克隆仓库 git clone https://github.com/infiniflow/ragflow.git构建预建的Docker映像并启动服务器 cd ragflow/docker chmod x ./entrypoint.sh docker compose -f docker-compose.yml -p ragflow up -d修改ragflow/docker/.env文件 #RAGFLOW_IMAGEinfiniflow/ragfl…...
STM32F4 UDP组播通信:填一填ST官方HAL库的坑
先说写作本文的原因,由于开项目开发中需要用到UDP组播接收的功能,但是ST官方没有提供合适的参考,使用STM32CubeMX生成的代码也是不能直接使用的,而我在网上找了一大圈,也没有一个能够直接解决的方案,deepse…...
基于python大数据的招聘数据可视化与推荐系统
博主介绍:资深开发工程师,从事互联网行业多年,熟悉各种主流语言,精通java、python、php、爬虫、web开发,已经做了多年的设计程序开发,开发过上千套设计程序,没有什么华丽的语言,只有…...
10. 【.NET 8 实战--孢子记账--从单体到微服务--转向微服务】--微服务基础工具与技术--Ocelot 网关--认证
在微服务架构中,通过在网关层实现身份认证、权限校验和数据加密,可以有效防范恶意攻击和非法访问,保障内部服务安全。采用JWT、OAuth等主流认证机制,使每次请求均经过严格验证,降低安全漏洞风险。同时,统一…...
DeepSeek 3FS:端到端无缓存的存储新范式
在 2025 年 2 月 28 日,DeepSeek 正式开源了其高性能分布式文件系统 3FS【1】,作为其开源周的压轴项目,3FS 一经发布便引发了技术圈的热烈讨论。它不仅继承了分布式存储的经典设计,还通过极简却高效的架构,展现了存储技…...
vue3组合式API怎么获取全局变量globalProperties
设置全局变量 main.ts app.config.globalProperties.$category { index: 0 } 获取全局变量 const { appContext } getCurrentInstance() as ComponentInternalInstance console.log(appContext.config.globalProperties.$category) 或是 const { proxy } getCurrentInstance…...
【YOLOv12改进trick】多尺度大核注意力机制MLKA模块引入YOLOv12,实现多尺度目标检测涨点,含创新点Python代码,方便发论文
🍋改进模块🍋:多尺度大核注意力机制(MLKA) 🍋解决问题🍋:MLKA模块结合多尺度、门控机制和空间注意力,显著增强卷积网络的模型表示能力。 🍋改进优势🍋:超分辨的MLKA模块对小目标和模糊目标涨点很明显 🍋适用场景🍋:小目标检测、模糊目标检测等 🍋思路…...
网络安全之端口扫描(一)
前置介绍 什么是DVWA? DVWA(Damn Vulnerable Web Application)是一个专门设计用于测试和提高Web应用程序安全技能的开源PHP/MySQL Web应用程序。它是一个具有多个安全漏洞的故意不安全的应用程序,供安全专业人员、渗透测试人员、…...
HCIE云计算学什么?怎么学?未来职业发展如何?
随着云计算成为IT行业发展的主流方向,HCIE云计算(华为认证云计算专家)作为华为认证体系中的高端认证之一,逐渐成为了许多网络工程师和IT从业者提升职业竞争力的重要途径。 那么,HCIE云计算究竟学什么内容,如…...
upload-labs文件上传
第一关 上传一个1.jpg的文件,在里面写好一句webshell 保留一个数据包,将其中截获的1.jpg改为1.php后重新发送 可以看到,已经成功上传 第二关 写一个webshell如图,为2.php 第二关在过滤tpye的属性,在上传2.php后使用b…...
操作系统控制台-健康守护我们的系统
引言基本准备体验功能健康守护系统诊断 收获提升结语 引言 阿里云操作系统控制平台作为新一代云端服务器中枢平台,通过创新交互模式重构主机管理体验。操作系统控制台提供了一系列管理功能,包括运维监控、智能助手、扩展插件管理以及订阅服务等。用户可以…...
财务会计域——合并报表系统设计
摘要 本文主要介绍了合并报表系统的设计,包括其背景、业务流程和系统架构设计。合并报表系统可自动化生成数据,减少人为错误,确保报表合规。其业务流程涵盖数据收集、标准化、合并调整、报表生成、审核及披露等环节。系统架构设计包括数据接…...
教务考试管理系统-Sprintboot vue
一、前言 1.1 实践目的和要求 本次实践的目的是为了帮助学生强化对实践涉及专业技术知识的理解,掌握专业领域中软件知识的应用方法,并了解软件工程在具体行业领域的发展趋势。通过培养学生利用软件工程方法分析、设计并完成具体行业软件开发的能力&…...
vue实现一个pdf在线预览,pdf选择文本并提取复制文字触发弹窗效果
[TOC] 一、文件预览 1、安装依赖包 这里安装了disjs-dist2.16版本,安装过程中报错缺少worker-loader npm i pdfjs-dist2.16.105 worker-loader3.0.8 2、模板部分 <template><div id"pdf-view"><canvas v-for"page in pdfPages&qu…...
论文解读:交大港大上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化学习框架(二)
HoST框架核心实现方法详解 - 论文深度解读(第二部分) 《Learning Humanoid Standing-up Control across Diverse Postures》 系列文章: 论文深度解读 + 算法与代码分析(二) 作者机构: 上海AI Lab, 上海交通大学, 香港大学, 浙江大学, 香港中文大学 论文主题: 人形机器人…...
Opencv中的addweighted函数
一.addweighted函数作用 addweighted()是OpenCV库中用于图像处理的函数,主要功能是将两个输入图像(尺寸和类型相同)按照指定的权重进行加权叠加(图像融合),并添加一个标量值&#x…...
React19源码系列之 事件插件系统
事件类别 事件类型 定义 文档 Event Event 接口表示在 EventTarget 上出现的事件。 Event - Web API | MDN UIEvent UIEvent 接口表示简单的用户界面事件。 UIEvent - Web API | MDN KeyboardEvent KeyboardEvent 对象描述了用户与键盘的交互。 KeyboardEvent - Web…...
跨链模式:多链互操作架构与性能扩展方案
跨链模式:多链互操作架构与性能扩展方案 ——构建下一代区块链互联网的技术基石 一、跨链架构的核心范式演进 1. 分层协议栈:模块化解耦设计 现代跨链系统采用分层协议栈实现灵活扩展(H2Cross架构): 适配层…...
2025盘古石杯决赛【手机取证】
前言 第三届盘古石杯国际电子数据取证大赛决赛 最后一题没有解出来,实在找不到,希望有大佬教一下我。 还有就会议时间,我感觉不是图片时间,因为在电脑看到是其他时间用老会议系统开的会。 手机取证 1、分析鸿蒙手机检材&#x…...
SpringCloudGateway 自定义局部过滤器
场景: 将所有请求转化为同一路径请求(方便穿网配置)在请求头内标识原来路径,然后在将请求分发给不同服务 AllToOneGatewayFilterFactory import lombok.Getter; import lombok.Setter; import lombok.extern.slf4j.Slf4j; impor…...
基于matlab策略迭代和值迭代法的动态规划
经典的基于策略迭代和值迭代法的动态规划matlab代码,实现机器人的最优运输 Dynamic-Programming-master/Environment.pdf , 104724 Dynamic-Programming-master/README.md , 506 Dynamic-Programming-master/generalizedPolicyIteration.m , 1970 Dynamic-Programm…...
ABAP设计模式之---“简单设计原则(Simple Design)”
“Simple Design”(简单设计)是软件开发中的一个重要理念,倡导以最简单的方式实现软件功能,以确保代码清晰易懂、易维护,并在项目需求变化时能够快速适应。 其核心目标是避免复杂和过度设计,遵循“让事情保…...
代码随想录刷题day30
1、零钱兑换II 给你一个整数数组 coins 表示不同面额的硬币,另给一个整数 amount 表示总金额。 请你计算并返回可以凑成总金额的硬币组合数。如果任何硬币组合都无法凑出总金额,返回 0 。 假设每一种面额的硬币有无限个。 题目数据保证结果符合 32 位带…...
基于SpringBoot在线拍卖系统的设计和实现
摘 要 随着社会的发展,社会的各行各业都在利用信息化时代的优势。计算机的优势和普及使得各种信息系统的开发成为必需。 在线拍卖系统,主要的模块包括管理员;首页、个人中心、用户管理、商品类型管理、拍卖商品管理、历史竞拍管理、竞拍订单…...
