当前位置: 首页 > news >正文

Pytorch 第九回:卷积神经网络——ResNet模型

Pytorch 第九回:卷积神经网络——ResNet模型

本次开启深度学习第九回,基于Pytorch的ResNet卷积神经网络模型。这是分享的第四个卷积神经网络模型。该模型是基于解决因网络加深而出现的梯度消失和网络退化而进行设计的。接下来给大家分享具体思路。
本次学习,借助的平台是PyCharm 2024.1.3,python版本3.11 numpy版本是1.26.4,pytorch版本2.0.0+cu118,d2l的版本是1.0.3

文章目录

  • Pytorch 第九回:卷积神经网络——ResNet模型
  • 前言
    • 1、残差块
    • 2、ResNet模型
  • 一、数据准备
  • 二、模型准备
    • 1、残差块定义
    • 2、ResNet模型定义
  • 模型训练
    • 1、实例化ResNet模型
    • 2、迭代训练模型
    • 3、输出展示
  • 总结


前言

讲述模型前,先讲述两个概念,统一下思路:

1、残差块

残差块是ResNet模型的基础架构,该架构允许输入特征跨层作用到块的输出端,从而加强了浅层特征的输出。其结构如下图所示:
在这里插入图片描述
如上图所示,这里的1*1卷积层,可以将输入特征直接作用到输出端,从而避免了浅层的梯度到达输出端时过小的问题。

2、ResNet模型

2015年,ResNet模型由微软研究院提出,并在ImageNet大规模视觉挑战赛中一举夺得了冠军。之前设计的神经网络,随着网络层的增多,并没有达到训练误差不断减少的预期,反而出现训练误差逐渐加大的现象,人们也称之为“网络退化”。ResNet模型通过加入了残差块的框架,使训练的深层神经网络更加有效。

闲言少叙,直接展示逻辑,先上引用:

import numpy as np
import torch
from torch import nn
from torchvision.datasets import CIFAR10
import time
from torch.utils.data import DataLoader
from d2l import torch as d2l
import torch.nn.functional as F

一、数据准备

如前几回一样,本次仍然采用CIFAR10数据集,因此不做重点解释(有兴趣的可以查看第六回内容),本回只展示代码:

def data_treating(x):x = x.resize((96, 96), 2)  #x = np.array(x, dtype='float32') / 255x = (x - 0.5) / 0.5  #x = x.transpose((2, 0, 1))  #x = torch.from_numpy(x)return xtrain_set = CIFAR10('./data', train=True, transform=data_treating)
train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
test_set = CIFAR10('./data', train=False, transform=data_treating)
test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)

注:
在本回设计的模型中,数据最小为96 * 96。有兴趣的可以参考第八回小记自己进行计算。

二、模型准备

1、残差块定义

残差块的定义分为两部分,一部分是初始化函数,一部分是前向传播函数。需要注意的是add_conv1_1这个参数,用于控制是否有输入特征直接作用于输出端。代码如下所示:

class residual(nn.Module):def __init__(self, channel_in, channel_out, add_conv1_1=False, stride=1):super(residual, self).__init__()self.add_conv1_1 = add_conv1_1self.conv1 = nn.Conv2d(channel_in, channel_out, 3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(channel_out)self.conv2 = nn.Conv2d(channel_out, channel_out, 3, stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(channel_out)if self.add_conv1_1:self.conv3 = nn.Conv2d(channel_in, channel_out, 1, stride=stride)else:self.conv3 = Nonedef forward(self, x):y = self.conv1(x)y = F.relu(self.bn1(y), True)y = self.conv2(y)y = F.relu(self.bn2(y), True)if self.add_conv1_1:x = self.conv3(x)y = y + xreturn F.relu(y, True)

如代所示,码残差块中定义了两个33的卷积,每个卷积层后面接了一个规范化层和一个Relu激活函数(体现在传播函数中)。self.conv3中定义了一个11的卷积层,当add_conv1_1=True时,输入会经过self.conv3卷积层直接反馈到输出端(y = y + x)。

2、ResNet模型定义

本回中的ResNet模型,定义了五个网络块。第一个网络块单独定义,后四个网络块结构相似(有兴趣的可以建立一个标准模块,方便设计深层网络)。

class resnet(nn.Module):def __init__(self, in_channel, num_classes):super(resnet, self).__init__()self.block1 = nn.Sequential(nn.Conv2d(in_channel, 64, 7, 2),nn.BatchNorm2d(64, eps=1e-3),nn.ReLU(True),nn.MaxPool2d(3, 2, 1))self.block2 = nn.Sequential(residual(64, 64),residual(64, 64))self.block3 = nn.Sequential(residual(64, 128, True, stride=2),residual(128, 128))self.block4 = nn.Sequential(residual(128, 256, True, stride=2),residual(256, 256))self.block5 = nn.Sequential(residual(256, 512, True, stride=2),residual(512, 512),nn.AvgPool2d(3),nn.Flatten())self.classifier = nn.Linear(512, num_classes)def forward(self, x):x = self.block1(x)x = self.block2(x)x = self.block3(x)x = self.block4(x)x = self.block5(x)x = self.classifier(x)return x

注:
由于本回采用d2l.train_ch6()进行数据训练,里面集成了损失函数和优化器,因此不需要单独定义(在第八回小记中介绍了如何安装d2l库)。

模型训练

1、实例化ResNet模型

这里输入为3个通道,因为彩色图片有三个数据通道。输出为10,因为数据集有10个类别(数据集的介绍,在第六回中)。

classify_ResNet = resnet(3, 10)

2、迭代训练模型

本次训练采用d2l.train_ch6()函数,其参数有六个:第一个是模型,第二个是训练集,第三个是测试集,第四个是迭代次数(设定为20次),第五个是学习率(设定为0.01),第六个是进行训练的设备(设定为GPU训练)。

d2l.train_ch6(classify_ResNet, train_data, test_data, 20, 0.01, d2l.try_gpu())

3、输出展示

输出含有训练集精度、测试集精度和消耗的时间等(这里我进对源码进行了修改,可以查看第八回小记)。

epoch0, loss 1.472, train acc 0.466, test acc 0.526,consume time 228.0
epoch4, loss 0.424, train acc 0.857, test acc 0.678,consume time 1139.3
epoch8, loss 0.047, train acc 0.989, test acc 0.680,consume time 2050.4
epoch12, loss 0.006, train acc 0.999, test acc 0.741,consume time 2961.4
epoch16, loss 0.002, train acc 1.000, test acc 0.746,consume time 3872.2
epoch19, loss 0.002, train acc 1.000, test acc 0.744,consume time 4555.5

对比前三回,ResNet在训练集的精度上又有所提高。

总结

1、数据准备:准备CIFAR10数据集
2、模型准备:准备残差块,ResNEt模型
3、数据训练:实例化训练模型,采用train_ch6函数进行迭代训练。

相关文章:

Pytorch 第九回:卷积神经网络——ResNet模型

Pytorch 第九回:卷积神经网络——ResNet模型 本次开启深度学习第九回,基于Pytorch的ResNet卷积神经网络模型。这是分享的第四个卷积神经网络模型。该模型是基于解决因网络加深而出现的梯度消失和网络退化而进行设计的。接下来给大家分享具体思路。 本次…...

2025-3-9 一周总结

目前来看本学期上半程汇编语言,编译原理,数字电路和离散数学是相对重点的课程. 在汇编语言和编译原理这块,个人感觉黑书内知识点更多,细节更到位,体系更完整,可以在老师讲解之前进行预习 应当及时复习每天的内容.第一是看书,然后听课,在一天结束后保证自己的知识梳理完整,没有…...

如何在el-input搜索框组件的最后面,添加图标按钮?

1、问题描述 2、解决步骤 在el-input组件标签内,添加一个element-plus的自定义插槽, 在插槽里放一个图标按钮即可。 3、效果展示 结语 以上就是在搜索框组件的末尾添加搜索按钮的过程。 喜欢本篇文章的话,请关注本博主~~...

[项目]基于FreeRTOS的STM32四轴飞行器: 六.2.4g通信

基于FreeRTOS的STM32四轴飞行器: 六.2.4g通信 一.Si24Ri原理图二.Si24R1芯片手册解读三.驱动函数讲解五.移植2.4g通讯(飞控部分)六.移植2.4g通讯(遥控部分) 一.Si24Ri原理图 Si24R1芯片原理图如下: 右侧为晶振。 模块…...

Python爬取咸鱼Goodfish店铺所有商品接口的详细指南

在电商数据分析和市场研究中,爬取咸鱼店铺内的所有商品信息是一项极具价值的任务。通过调用咸鱼的goodfish.item_search_shop接口,可以获取指定店铺内的商品列表,包括商品标题、价格、图片链接、销量等详细信息。本文将详细介绍如何使用Pytho…...

【极光 Orbit•STC8A-8H】03. 小刀初试:点亮你的LED灯

【极光 Orbit•STC8H】03. 小刀初试:点亮你的 LED 灯 七律 点灯初探 单片方寸藏乾坤,LED明灭见真章。 端口配置定方向,寄存器值细推敲。 高低电平随心控,循环闪烁展锋芒。 嵌入式门初开启,从此代码手中扬。 摘要 …...

docker本地部署RagFlow

1.安装 克隆仓库 git clone https://github.com/infiniflow/ragflow.git构建预建的Docker映像并启动服务器 cd ragflow/docker chmod x ./entrypoint.sh docker compose -f docker-compose.yml -p ragflow up -d修改ragflow/docker/.env文件 #RAGFLOW_IMAGEinfiniflow/ragfl…...

STM32F4 UDP组播通信:填一填ST官方HAL库的坑

先说写作本文的原因,由于开项目开发中需要用到UDP组播接收的功能,但是ST官方没有提供合适的参考,使用STM32CubeMX生成的代码也是不能直接使用的,而我在网上找了一大圈,也没有一个能够直接解决的方案,deepse…...

基于python大数据的招聘数据可视化与推荐系统

博主介绍:资深开发工程师,从事互联网行业多年,熟悉各种主流语言,精通java、python、php、爬虫、web开发,已经做了多年的设计程序开发,开发过上千套设计程序,没有什么华丽的语言,只有…...

10. 【.NET 8 实战--孢子记账--从单体到微服务--转向微服务】--微服务基础工具与技术--Ocelot 网关--认证

在微服务架构中,通过在网关层实现身份认证、权限校验和数据加密,可以有效防范恶意攻击和非法访问,保障内部服务安全。采用JWT、OAuth等主流认证机制,使每次请求均经过严格验证,降低安全漏洞风险。同时,统一…...

DeepSeek 3FS:端到端无缓存的存储新范式

在 2025 年 2 月 28 日,DeepSeek 正式开源了其高性能分布式文件系统 3FS【1】,作为其开源周的压轴项目,3FS 一经发布便引发了技术圈的热烈讨论。它不仅继承了分布式存储的经典设计,还通过极简却高效的架构,展现了存储技…...

vue3组合式API怎么获取全局变量globalProperties

设置全局变量 main.ts app.config.globalProperties.$category { index: 0 } 获取全局变量 const { appContext } getCurrentInstance() as ComponentInternalInstance console.log(appContext.config.globalProperties.$category) 或是 const { proxy } getCurrentInstance…...

【YOLOv12改进trick】多尺度大核注意力机制MLKA模块引入YOLOv12,实现多尺度目标检测涨点,含创新点Python代码,方便发论文

🍋改进模块🍋:多尺度大核注意力机制(MLKA) 🍋解决问题🍋:MLKA模块结合多尺度、门控机制和空间注意力,显著增强卷积网络的模型表示能力。 🍋改进优势🍋:超分辨的MLKA模块对小目标和模糊目标涨点很明显 🍋适用场景🍋:小目标检测、模糊目标检测等 🍋思路…...

网络安全之端口扫描(一)

前置介绍 什么是DVWA? DVWA(Damn Vulnerable Web Application)是一个专门设计用于测试和提高Web应用程序安全技能的开源PHP/MySQL Web应用程序。它是一个具有多个安全漏洞的故意不安全的应用程序,供安全专业人员、渗透测试人员、…...

HCIE云计算学什么?怎么学?未来职业发展如何?

随着云计算成为IT行业发展的主流方向,HCIE云计算(华为认证云计算专家)作为华为认证体系中的高端认证之一,逐渐成为了许多网络工程师和IT从业者提升职业竞争力的重要途径。 那么,HCIE云计算究竟学什么内容,如…...

upload-labs文件上传

第一关 上传一个1.jpg的文件,在里面写好一句webshell 保留一个数据包,将其中截获的1.jpg改为1.php后重新发送 可以看到,已经成功上传 第二关 写一个webshell如图,为2.php 第二关在过滤tpye的属性,在上传2.php后使用b…...

操作系统控制台-健康守护我们的系统

引言基本准备体验功能健康守护系统诊断 收获提升结语 引言 阿里云操作系统控制平台作为新一代云端服务器中枢平台,通过创新交互模式重构主机管理体验。操作系统控制台提供了一系列管理功能,包括运维监控、智能助手、扩展插件管理以及订阅服务等。用户可以…...

财务会计域——合并报表系统设计

摘要 本文主要介绍了合并报表系统的设计,包括其背景、业务流程和系统架构设计。合并报表系统可自动化生成数据,减少人为错误,确保报表合规。其业务流程涵盖数据收集、标准化、合并调整、报表生成、审核及披露等环节。系统架构设计包括数据接…...

教务考试管理系统-Sprintboot vue

一、前言 1.1 实践目的和要求 本次实践的目的是为了帮助学生强化对实践涉及专业技术知识的理解,掌握专业领域中软件知识的应用方法,并了解软件工程在具体行业领域的发展趋势。通过培养学生利用软件工程方法分析、设计并完成具体行业软件开发的能力&…...

vue实现一个pdf在线预览,pdf选择文本并提取复制文字触发弹窗效果

[TOC] 一、文件预览 1、安装依赖包 这里安装了disjs-dist2.16版本&#xff0c;安装过程中报错缺少worker-loader npm i pdfjs-dist2.16.105 worker-loader3.0.8 2、模板部分 <template><div id"pdf-view"><canvas v-for"page in pdfPages&qu…...

Python|GIF 解析与构建(5):手搓截屏和帧率控制

目录 Python&#xff5c;GIF 解析与构建&#xff08;5&#xff09;&#xff1a;手搓截屏和帧率控制 一、引言 二、技术实现&#xff1a;手搓截屏模块 2.1 核心原理 2.2 代码解析&#xff1a;ScreenshotData类 2.2.1 截图函数&#xff1a;capture_screen 三、技术实现&…...

【Linux】shell脚本忽略错误继续执行

在 shell 脚本中&#xff0c;可以使用 set -e 命令来设置脚本在遇到错误时退出执行。如果你希望脚本忽略错误并继续执行&#xff0c;可以在脚本开头添加 set e 命令来取消该设置。 举例1 #!/bin/bash# 取消 set -e 的设置 set e# 执行命令&#xff0c;并忽略错误 rm somefile…...

【JVM】- 内存结构

引言 JVM&#xff1a;Java Virtual Machine 定义&#xff1a;Java虚拟机&#xff0c;Java二进制字节码的运行环境好处&#xff1a; 一次编写&#xff0c;到处运行自动内存管理&#xff0c;垃圾回收的功能数组下标越界检查&#xff08;会抛异常&#xff0c;不会覆盖到其他代码…...

【第二十一章 SDIO接口(SDIO)】

第二十一章 SDIO接口 目录 第二十一章 SDIO接口(SDIO) 1 SDIO 主要功能 2 SDIO 总线拓扑 3 SDIO 功能描述 3.1 SDIO 适配器 3.2 SDIOAHB 接口 4 卡功能描述 4.1 卡识别模式 4.2 卡复位 4.3 操作电压范围确认 4.4 卡识别过程 4.5 写数据块 4.6 读数据块 4.7 数据流…...

电脑插入多块移动硬盘后经常出现卡顿和蓝屏

当电脑在插入多块移动硬盘后频繁出现卡顿和蓝屏问题时&#xff0c;可能涉及硬件资源冲突、驱动兼容性、供电不足或系统设置等多方面原因。以下是逐步排查和解决方案&#xff1a; 1. 检查电源供电问题 问题原因&#xff1a;多块移动硬盘同时运行可能导致USB接口供电不足&#x…...

根据万维钢·精英日课6的内容,使用AI(2025)可以参考以下方法:

根据万维钢精英日课6的内容&#xff0c;使用AI&#xff08;2025&#xff09;可以参考以下方法&#xff1a; 四个洞见 模型已经比人聪明&#xff1a;以ChatGPT o3为代表的AI非常强大&#xff0c;能运用高级理论解释道理、引用最新学术论文&#xff0c;生成对顶尖科学家都有用的…...

站群服务器的应用场景都有哪些?

站群服务器主要是为了多个网站的托管和管理所设计的&#xff0c;可以通过集中管理和高效资源的分配&#xff0c;来支持多个独立的网站同时运行&#xff0c;让每一个网站都可以分配到独立的IP地址&#xff0c;避免出现IP关联的风险&#xff0c;用户还可以通过控制面板进行管理功…...

Git常用命令完全指南:从入门到精通

Git常用命令完全指南&#xff1a;从入门到精通 一、基础配置命令 1. 用户信息配置 # 设置全局用户名 git config --global user.name "你的名字"# 设置全局邮箱 git config --global user.email "你的邮箱example.com"# 查看所有配置 git config --list…...

如何应对敏捷转型中的团队阻力

应对敏捷转型中的团队阻力需要明确沟通敏捷转型目的、提升团队参与感、提供充分的培训与支持、逐步推进敏捷实践、建立清晰的奖励和反馈机制。其中&#xff0c;明确沟通敏捷转型目的尤为关键&#xff0c;团队成员只有清晰理解转型背后的原因和利益&#xff0c;才能降低对变化的…...

Oracle11g安装包

Oracle 11g安装包 适用于windows系统&#xff0c;64位 下载路径 oracle 11g 安装包...