当前位置: 首页 > news >正文

Pytorch 第九回:卷积神经网络——ResNet模型

Pytorch 第九回:卷积神经网络——ResNet模型

本次开启深度学习第九回,基于Pytorch的ResNet卷积神经网络模型。这是分享的第四个卷积神经网络模型。该模型是基于解决因网络加深而出现的梯度消失和网络退化而进行设计的。接下来给大家分享具体思路。
本次学习,借助的平台是PyCharm 2024.1.3,python版本3.11 numpy版本是1.26.4,pytorch版本2.0.0+cu118,d2l的版本是1.0.3

文章目录

  • Pytorch 第九回:卷积神经网络——ResNet模型
  • 前言
    • 1、残差块
    • 2、ResNet模型
  • 一、数据准备
  • 二、模型准备
    • 1、残差块定义
    • 2、ResNet模型定义
  • 模型训练
    • 1、实例化ResNet模型
    • 2、迭代训练模型
    • 3、输出展示
  • 总结


前言

讲述模型前,先讲述两个概念,统一下思路:

1、残差块

残差块是ResNet模型的基础架构,该架构允许输入特征跨层作用到块的输出端,从而加强了浅层特征的输出。其结构如下图所示:
在这里插入图片描述
如上图所示,这里的1*1卷积层,可以将输入特征直接作用到输出端,从而避免了浅层的梯度到达输出端时过小的问题。

2、ResNet模型

2015年,ResNet模型由微软研究院提出,并在ImageNet大规模视觉挑战赛中一举夺得了冠军。之前设计的神经网络,随着网络层的增多,并没有达到训练误差不断减少的预期,反而出现训练误差逐渐加大的现象,人们也称之为“网络退化”。ResNet模型通过加入了残差块的框架,使训练的深层神经网络更加有效。

闲言少叙,直接展示逻辑,先上引用:

import numpy as np
import torch
from torch import nn
from torchvision.datasets import CIFAR10
import time
from torch.utils.data import DataLoader
from d2l import torch as d2l
import torch.nn.functional as F

一、数据准备

如前几回一样,本次仍然采用CIFAR10数据集,因此不做重点解释(有兴趣的可以查看第六回内容),本回只展示代码:

def data_treating(x):x = x.resize((96, 96), 2)  #x = np.array(x, dtype='float32') / 255x = (x - 0.5) / 0.5  #x = x.transpose((2, 0, 1))  #x = torch.from_numpy(x)return xtrain_set = CIFAR10('./data', train=True, transform=data_treating)
train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
test_set = CIFAR10('./data', train=False, transform=data_treating)
test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)

注:
在本回设计的模型中,数据最小为96 * 96。有兴趣的可以参考第八回小记自己进行计算。

二、模型准备

1、残差块定义

残差块的定义分为两部分,一部分是初始化函数,一部分是前向传播函数。需要注意的是add_conv1_1这个参数,用于控制是否有输入特征直接作用于输出端。代码如下所示:

class residual(nn.Module):def __init__(self, channel_in, channel_out, add_conv1_1=False, stride=1):super(residual, self).__init__()self.add_conv1_1 = add_conv1_1self.conv1 = nn.Conv2d(channel_in, channel_out, 3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(channel_out)self.conv2 = nn.Conv2d(channel_out, channel_out, 3, stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(channel_out)if self.add_conv1_1:self.conv3 = nn.Conv2d(channel_in, channel_out, 1, stride=stride)else:self.conv3 = Nonedef forward(self, x):y = self.conv1(x)y = F.relu(self.bn1(y), True)y = self.conv2(y)y = F.relu(self.bn2(y), True)if self.add_conv1_1:x = self.conv3(x)y = y + xreturn F.relu(y, True)

如代所示,码残差块中定义了两个33的卷积,每个卷积层后面接了一个规范化层和一个Relu激活函数(体现在传播函数中)。self.conv3中定义了一个11的卷积层,当add_conv1_1=True时,输入会经过self.conv3卷积层直接反馈到输出端(y = y + x)。

2、ResNet模型定义

本回中的ResNet模型,定义了五个网络块。第一个网络块单独定义,后四个网络块结构相似(有兴趣的可以建立一个标准模块,方便设计深层网络)。

class resnet(nn.Module):def __init__(self, in_channel, num_classes):super(resnet, self).__init__()self.block1 = nn.Sequential(nn.Conv2d(in_channel, 64, 7, 2),nn.BatchNorm2d(64, eps=1e-3),nn.ReLU(True),nn.MaxPool2d(3, 2, 1))self.block2 = nn.Sequential(residual(64, 64),residual(64, 64))self.block3 = nn.Sequential(residual(64, 128, True, stride=2),residual(128, 128))self.block4 = nn.Sequential(residual(128, 256, True, stride=2),residual(256, 256))self.block5 = nn.Sequential(residual(256, 512, True, stride=2),residual(512, 512),nn.AvgPool2d(3),nn.Flatten())self.classifier = nn.Linear(512, num_classes)def forward(self, x):x = self.block1(x)x = self.block2(x)x = self.block3(x)x = self.block4(x)x = self.block5(x)x = self.classifier(x)return x

注:
由于本回采用d2l.train_ch6()进行数据训练,里面集成了损失函数和优化器,因此不需要单独定义(在第八回小记中介绍了如何安装d2l库)。

模型训练

1、实例化ResNet模型

这里输入为3个通道,因为彩色图片有三个数据通道。输出为10,因为数据集有10个类别(数据集的介绍,在第六回中)。

classify_ResNet = resnet(3, 10)

2、迭代训练模型

本次训练采用d2l.train_ch6()函数,其参数有六个:第一个是模型,第二个是训练集,第三个是测试集,第四个是迭代次数(设定为20次),第五个是学习率(设定为0.01),第六个是进行训练的设备(设定为GPU训练)。

d2l.train_ch6(classify_ResNet, train_data, test_data, 20, 0.01, d2l.try_gpu())

3、输出展示

输出含有训练集精度、测试集精度和消耗的时间等(这里我进对源码进行了修改,可以查看第八回小记)。

epoch0, loss 1.472, train acc 0.466, test acc 0.526,consume time 228.0
epoch4, loss 0.424, train acc 0.857, test acc 0.678,consume time 1139.3
epoch8, loss 0.047, train acc 0.989, test acc 0.680,consume time 2050.4
epoch12, loss 0.006, train acc 0.999, test acc 0.741,consume time 2961.4
epoch16, loss 0.002, train acc 1.000, test acc 0.746,consume time 3872.2
epoch19, loss 0.002, train acc 1.000, test acc 0.744,consume time 4555.5

对比前三回,ResNet在训练集的精度上又有所提高。

总结

1、数据准备:准备CIFAR10数据集
2、模型准备:准备残差块,ResNEt模型
3、数据训练:实例化训练模型,采用train_ch6函数进行迭代训练。

相关文章:

Pytorch 第九回:卷积神经网络——ResNet模型

Pytorch 第九回:卷积神经网络——ResNet模型 本次开启深度学习第九回,基于Pytorch的ResNet卷积神经网络模型。这是分享的第四个卷积神经网络模型。该模型是基于解决因网络加深而出现的梯度消失和网络退化而进行设计的。接下来给大家分享具体思路。 本次…...

2025-3-9 一周总结

目前来看本学期上半程汇编语言,编译原理,数字电路和离散数学是相对重点的课程. 在汇编语言和编译原理这块,个人感觉黑书内知识点更多,细节更到位,体系更完整,可以在老师讲解之前进行预习 应当及时复习每天的内容.第一是看书,然后听课,在一天结束后保证自己的知识梳理完整,没有…...

如何在el-input搜索框组件的最后面,添加图标按钮?

1、问题描述 2、解决步骤 在el-input组件标签内,添加一个element-plus的自定义插槽, 在插槽里放一个图标按钮即可。 3、效果展示 结语 以上就是在搜索框组件的末尾添加搜索按钮的过程。 喜欢本篇文章的话,请关注本博主~~...

[项目]基于FreeRTOS的STM32四轴飞行器: 六.2.4g通信

基于FreeRTOS的STM32四轴飞行器: 六.2.4g通信 一.Si24Ri原理图二.Si24R1芯片手册解读三.驱动函数讲解五.移植2.4g通讯(飞控部分)六.移植2.4g通讯(遥控部分) 一.Si24Ri原理图 Si24R1芯片原理图如下: 右侧为晶振。 模块…...

Python爬取咸鱼Goodfish店铺所有商品接口的详细指南

在电商数据分析和市场研究中,爬取咸鱼店铺内的所有商品信息是一项极具价值的任务。通过调用咸鱼的goodfish.item_search_shop接口,可以获取指定店铺内的商品列表,包括商品标题、价格、图片链接、销量等详细信息。本文将详细介绍如何使用Pytho…...

【极光 Orbit•STC8A-8H】03. 小刀初试:点亮你的LED灯

【极光 Orbit•STC8H】03. 小刀初试:点亮你的 LED 灯 七律 点灯初探 单片方寸藏乾坤,LED明灭见真章。 端口配置定方向,寄存器值细推敲。 高低电平随心控,循环闪烁展锋芒。 嵌入式门初开启,从此代码手中扬。 摘要 …...

docker本地部署RagFlow

1.安装 克隆仓库 git clone https://github.com/infiniflow/ragflow.git构建预建的Docker映像并启动服务器 cd ragflow/docker chmod x ./entrypoint.sh docker compose -f docker-compose.yml -p ragflow up -d修改ragflow/docker/.env文件 #RAGFLOW_IMAGEinfiniflow/ragfl…...

STM32F4 UDP组播通信:填一填ST官方HAL库的坑

先说写作本文的原因,由于开项目开发中需要用到UDP组播接收的功能,但是ST官方没有提供合适的参考,使用STM32CubeMX生成的代码也是不能直接使用的,而我在网上找了一大圈,也没有一个能够直接解决的方案,deepse…...

基于python大数据的招聘数据可视化与推荐系统

博主介绍:资深开发工程师,从事互联网行业多年,熟悉各种主流语言,精通java、python、php、爬虫、web开发,已经做了多年的设计程序开发,开发过上千套设计程序,没有什么华丽的语言,只有…...

10. 【.NET 8 实战--孢子记账--从单体到微服务--转向微服务】--微服务基础工具与技术--Ocelot 网关--认证

在微服务架构中,通过在网关层实现身份认证、权限校验和数据加密,可以有效防范恶意攻击和非法访问,保障内部服务安全。采用JWT、OAuth等主流认证机制,使每次请求均经过严格验证,降低安全漏洞风险。同时,统一…...

DeepSeek 3FS:端到端无缓存的存储新范式

在 2025 年 2 月 28 日,DeepSeek 正式开源了其高性能分布式文件系统 3FS【1】,作为其开源周的压轴项目,3FS 一经发布便引发了技术圈的热烈讨论。它不仅继承了分布式存储的经典设计,还通过极简却高效的架构,展现了存储技…...

vue3组合式API怎么获取全局变量globalProperties

设置全局变量 main.ts app.config.globalProperties.$category { index: 0 } 获取全局变量 const { appContext } getCurrentInstance() as ComponentInternalInstance console.log(appContext.config.globalProperties.$category) 或是 const { proxy } getCurrentInstance…...

【YOLOv12改进trick】多尺度大核注意力机制MLKA模块引入YOLOv12,实现多尺度目标检测涨点,含创新点Python代码,方便发论文

🍋改进模块🍋:多尺度大核注意力机制(MLKA) 🍋解决问题🍋:MLKA模块结合多尺度、门控机制和空间注意力,显著增强卷积网络的模型表示能力。 🍋改进优势🍋:超分辨的MLKA模块对小目标和模糊目标涨点很明显 🍋适用场景🍋:小目标检测、模糊目标检测等 🍋思路…...

网络安全之端口扫描(一)

前置介绍 什么是DVWA? DVWA(Damn Vulnerable Web Application)是一个专门设计用于测试和提高Web应用程序安全技能的开源PHP/MySQL Web应用程序。它是一个具有多个安全漏洞的故意不安全的应用程序,供安全专业人员、渗透测试人员、…...

HCIE云计算学什么?怎么学?未来职业发展如何?

随着云计算成为IT行业发展的主流方向,HCIE云计算(华为认证云计算专家)作为华为认证体系中的高端认证之一,逐渐成为了许多网络工程师和IT从业者提升职业竞争力的重要途径。 那么,HCIE云计算究竟学什么内容,如…...

upload-labs文件上传

第一关 上传一个1.jpg的文件,在里面写好一句webshell 保留一个数据包,将其中截获的1.jpg改为1.php后重新发送 可以看到,已经成功上传 第二关 写一个webshell如图,为2.php 第二关在过滤tpye的属性,在上传2.php后使用b…...

操作系统控制台-健康守护我们的系统

引言基本准备体验功能健康守护系统诊断 收获提升结语 引言 阿里云操作系统控制平台作为新一代云端服务器中枢平台,通过创新交互模式重构主机管理体验。操作系统控制台提供了一系列管理功能,包括运维监控、智能助手、扩展插件管理以及订阅服务等。用户可以…...

财务会计域——合并报表系统设计

摘要 本文主要介绍了合并报表系统的设计,包括其背景、业务流程和系统架构设计。合并报表系统可自动化生成数据,减少人为错误,确保报表合规。其业务流程涵盖数据收集、标准化、合并调整、报表生成、审核及披露等环节。系统架构设计包括数据接…...

教务考试管理系统-Sprintboot vue

一、前言 1.1 实践目的和要求 本次实践的目的是为了帮助学生强化对实践涉及专业技术知识的理解,掌握专业领域中软件知识的应用方法,并了解软件工程在具体行业领域的发展趋势。通过培养学生利用软件工程方法分析、设计并完成具体行业软件开发的能力&…...

vue实现一个pdf在线预览,pdf选择文本并提取复制文字触发弹窗效果

[TOC] 一、文件预览 1、安装依赖包 这里安装了disjs-dist2.16版本&#xff0c;安装过程中报错缺少worker-loader npm i pdfjs-dist2.16.105 worker-loader3.0.8 2、模板部分 <template><div id"pdf-view"><canvas v-for"page in pdfPages&qu…...

Mplus路径系数差异比较实战:两种方法详解与选择指南

Mplus路径系数差异比较实战&#xff1a;两种方法详解与选择指南 在结构方程模型分析中&#xff0c;研究者常常需要比较不同路径系数或中介效应是否存在显著差异。比如&#xff0c;你可能想知道性别对工作满意度的直接影响是否显著大于其对组织承诺的影响&#xff0c;或者比较两…...

如何在ComfyUI中玩转WanVideo:从零到一的视频生成魔法

如何在ComfyUI中玩转WanVideo&#xff1a;从零到一的视频生成魔法 【免费下载链接】ComfyUI-WanVideoWrapper 项目地址: https://gitcode.com/GitHub_Trending/co/ComfyUI-WanVideoWrapper 你是否曾经想过&#xff0c;如果能像搭积木一样轻松创作视频该有多好&#xff…...

避坑指南:用STM32CubeMX配置SPI驱动MAX7219数码管的几个关键细节

STM32CubeMX实战&#xff1a;避开MAX7219数码管驱动的5个致命配置误区 第一次用STM32CubeMX配置SPI驱动MAX7219数码管时&#xff0c;我盯着屏幕上闪烁不定的数字差点崩溃——明明按照教程一步步操作&#xff0c;为什么显示总是错乱&#xff1f;后来才发现&#xff0c;那些看似简…...

Qwen3-0.6B-FP8辅助Java八股文学习:智能抽题与答案要点生成

Qwen3-0.6B-FP8辅助Java八股文学习&#xff1a;智能抽题与答案要点生成 1. 引言&#xff1a;当面试备考遇上AI 准备Java面试&#xff0c;尤其是那些经典的“八股文”题目&#xff0c;对很多程序员来说是个既熟悉又头疼的过程。你可能也经历过&#xff1a;面对厚厚的面试宝典&…...

通义千问3-Reranker-0.6B入门指南:app.py核心逻辑解析+自定义路由扩展

通义千问3-Reranker-0.6B入门指南&#xff1a;app.py核心逻辑解析自定义路由扩展 1. 引言 如果你正在寻找一个既轻量又强大的中文重排序模型&#xff0c;那么通义千问3-Reranker-0.6B绝对值得你花时间了解一下。这个只有6亿参数的模型&#xff0c;在文本检索和排序任务上的表…...

好用的电脑软件总结

总目录&#xff1a;Software_resource 下面为子目录&#xff1a; Software&#xff1a;软件安装的位置 InstallPackage&#xff1a;安装包 SoftLink&#xff1a;快捷方式 一 科研 1 阅读软件 (1) 科研论文相关 Zotero 个人感觉最好用的文献阅读软件Citavi 文献阅读软件小绿…...

F3D:为什么这款极简3D查看器能让你彻底告别传统软件的臃肿?

F3D&#xff1a;为什么这款极简3D查看器能让你彻底告别传统软件的臃肿&#xff1f; 【免费下载链接】f3d Fast and minimalist 3D viewer. 项目地址: https://gitcode.com/GitHub_Trending/f3/f3d 在3D设计、工程可视化和科研数据分析的日常工作中&#xff0c;你是否曾因…...

ReactPy虚拟DOM终极指南:Python如何高效更新网页内容

ReactPy虚拟DOM终极指南&#xff1a;Python如何高效更新网页内容 【免费下载链接】reactpy Its React, but in Python 项目地址: https://gitcode.com/gh_mirrors/re/reactpy ReactPy作为Python领域的创新框架&#xff0c;让开发者能够使用Python语法构建交互式Web界面&…...

3大核心功能+2套实战流程:零基础掌握FreeCAD开源3D建模

3大核心功能2套实战流程&#xff1a;零基础掌握FreeCAD开源3D建模 【免费下载链接】FreeCAD This is the official source code of FreeCAD, a free and opensource multiplatform 3D parametric modeler. 项目地址: https://gitcode.com/GitHub_Trending/fr/freecad 3D…...

保姆级教程:在PX4 1.13.1固件下,从零开始编写一个自定义控制模块(附完整代码)

PX4 1.13.1固件下自定义控制模块开发全流程指南 当你第一次打开PX4的源码目录&#xff0c;面对层层嵌套的文件夹和复杂的编译系统&#xff0c;是否感到无从下手&#xff1f;作为一款开源的无人机飞控系统&#xff0c;PX4的强大之处在于其高度模块化的设计&#xff0c;允许开发者…...