当前位置: 首页 > news >正文

Pytorch 第九回:卷积神经网络——ResNet模型

Pytorch 第九回:卷积神经网络——ResNet模型

本次开启深度学习第九回,基于Pytorch的ResNet卷积神经网络模型。这是分享的第四个卷积神经网络模型。该模型是基于解决因网络加深而出现的梯度消失和网络退化而进行设计的。接下来给大家分享具体思路。
本次学习,借助的平台是PyCharm 2024.1.3,python版本3.11 numpy版本是1.26.4,pytorch版本2.0.0+cu118,d2l的版本是1.0.3

文章目录

  • Pytorch 第九回:卷积神经网络——ResNet模型
  • 前言
    • 1、残差块
    • 2、ResNet模型
  • 一、数据准备
  • 二、模型准备
    • 1、残差块定义
    • 2、ResNet模型定义
  • 模型训练
    • 1、实例化ResNet模型
    • 2、迭代训练模型
    • 3、输出展示
  • 总结


前言

讲述模型前,先讲述两个概念,统一下思路:

1、残差块

残差块是ResNet模型的基础架构,该架构允许输入特征跨层作用到块的输出端,从而加强了浅层特征的输出。其结构如下图所示:
在这里插入图片描述
如上图所示,这里的1*1卷积层,可以将输入特征直接作用到输出端,从而避免了浅层的梯度到达输出端时过小的问题。

2、ResNet模型

2015年,ResNet模型由微软研究院提出,并在ImageNet大规模视觉挑战赛中一举夺得了冠军。之前设计的神经网络,随着网络层的增多,并没有达到训练误差不断减少的预期,反而出现训练误差逐渐加大的现象,人们也称之为“网络退化”。ResNet模型通过加入了残差块的框架,使训练的深层神经网络更加有效。

闲言少叙,直接展示逻辑,先上引用:

import numpy as np
import torch
from torch import nn
from torchvision.datasets import CIFAR10
import time
from torch.utils.data import DataLoader
from d2l import torch as d2l
import torch.nn.functional as F

一、数据准备

如前几回一样,本次仍然采用CIFAR10数据集,因此不做重点解释(有兴趣的可以查看第六回内容),本回只展示代码:

def data_treating(x):x = x.resize((96, 96), 2)  #x = np.array(x, dtype='float32') / 255x = (x - 0.5) / 0.5  #x = x.transpose((2, 0, 1))  #x = torch.from_numpy(x)return xtrain_set = CIFAR10('./data', train=True, transform=data_treating)
train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
test_set = CIFAR10('./data', train=False, transform=data_treating)
test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)

注:
在本回设计的模型中,数据最小为96 * 96。有兴趣的可以参考第八回小记自己进行计算。

二、模型准备

1、残差块定义

残差块的定义分为两部分,一部分是初始化函数,一部分是前向传播函数。需要注意的是add_conv1_1这个参数,用于控制是否有输入特征直接作用于输出端。代码如下所示:

class residual(nn.Module):def __init__(self, channel_in, channel_out, add_conv1_1=False, stride=1):super(residual, self).__init__()self.add_conv1_1 = add_conv1_1self.conv1 = nn.Conv2d(channel_in, channel_out, 3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(channel_out)self.conv2 = nn.Conv2d(channel_out, channel_out, 3, stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(channel_out)if self.add_conv1_1:self.conv3 = nn.Conv2d(channel_in, channel_out, 1, stride=stride)else:self.conv3 = Nonedef forward(self, x):y = self.conv1(x)y = F.relu(self.bn1(y), True)y = self.conv2(y)y = F.relu(self.bn2(y), True)if self.add_conv1_1:x = self.conv3(x)y = y + xreturn F.relu(y, True)

如代所示,码残差块中定义了两个33的卷积,每个卷积层后面接了一个规范化层和一个Relu激活函数(体现在传播函数中)。self.conv3中定义了一个11的卷积层,当add_conv1_1=True时,输入会经过self.conv3卷积层直接反馈到输出端(y = y + x)。

2、ResNet模型定义

本回中的ResNet模型,定义了五个网络块。第一个网络块单独定义,后四个网络块结构相似(有兴趣的可以建立一个标准模块,方便设计深层网络)。

class resnet(nn.Module):def __init__(self, in_channel, num_classes):super(resnet, self).__init__()self.block1 = nn.Sequential(nn.Conv2d(in_channel, 64, 7, 2),nn.BatchNorm2d(64, eps=1e-3),nn.ReLU(True),nn.MaxPool2d(3, 2, 1))self.block2 = nn.Sequential(residual(64, 64),residual(64, 64))self.block3 = nn.Sequential(residual(64, 128, True, stride=2),residual(128, 128))self.block4 = nn.Sequential(residual(128, 256, True, stride=2),residual(256, 256))self.block5 = nn.Sequential(residual(256, 512, True, stride=2),residual(512, 512),nn.AvgPool2d(3),nn.Flatten())self.classifier = nn.Linear(512, num_classes)def forward(self, x):x = self.block1(x)x = self.block2(x)x = self.block3(x)x = self.block4(x)x = self.block5(x)x = self.classifier(x)return x

注:
由于本回采用d2l.train_ch6()进行数据训练,里面集成了损失函数和优化器,因此不需要单独定义(在第八回小记中介绍了如何安装d2l库)。

模型训练

1、实例化ResNet模型

这里输入为3个通道,因为彩色图片有三个数据通道。输出为10,因为数据集有10个类别(数据集的介绍,在第六回中)。

classify_ResNet = resnet(3, 10)

2、迭代训练模型

本次训练采用d2l.train_ch6()函数,其参数有六个:第一个是模型,第二个是训练集,第三个是测试集,第四个是迭代次数(设定为20次),第五个是学习率(设定为0.01),第六个是进行训练的设备(设定为GPU训练)。

d2l.train_ch6(classify_ResNet, train_data, test_data, 20, 0.01, d2l.try_gpu())

3、输出展示

输出含有训练集精度、测试集精度和消耗的时间等(这里我进对源码进行了修改,可以查看第八回小记)。

epoch0, loss 1.472, train acc 0.466, test acc 0.526,consume time 228.0
epoch4, loss 0.424, train acc 0.857, test acc 0.678,consume time 1139.3
epoch8, loss 0.047, train acc 0.989, test acc 0.680,consume time 2050.4
epoch12, loss 0.006, train acc 0.999, test acc 0.741,consume time 2961.4
epoch16, loss 0.002, train acc 1.000, test acc 0.746,consume time 3872.2
epoch19, loss 0.002, train acc 1.000, test acc 0.744,consume time 4555.5

对比前三回,ResNet在训练集的精度上又有所提高。

总结

1、数据准备:准备CIFAR10数据集
2、模型准备:准备残差块,ResNEt模型
3、数据训练:实例化训练模型,采用train_ch6函数进行迭代训练。

相关文章:

Pytorch 第九回:卷积神经网络——ResNet模型

Pytorch 第九回:卷积神经网络——ResNet模型 本次开启深度学习第九回,基于Pytorch的ResNet卷积神经网络模型。这是分享的第四个卷积神经网络模型。该模型是基于解决因网络加深而出现的梯度消失和网络退化而进行设计的。接下来给大家分享具体思路。 本次…...

2025-3-9 一周总结

目前来看本学期上半程汇编语言,编译原理,数字电路和离散数学是相对重点的课程. 在汇编语言和编译原理这块,个人感觉黑书内知识点更多,细节更到位,体系更完整,可以在老师讲解之前进行预习 应当及时复习每天的内容.第一是看书,然后听课,在一天结束后保证自己的知识梳理完整,没有…...

如何在el-input搜索框组件的最后面,添加图标按钮?

1、问题描述 2、解决步骤 在el-input组件标签内,添加一个element-plus的自定义插槽, 在插槽里放一个图标按钮即可。 3、效果展示 结语 以上就是在搜索框组件的末尾添加搜索按钮的过程。 喜欢本篇文章的话,请关注本博主~~...

[项目]基于FreeRTOS的STM32四轴飞行器: 六.2.4g通信

基于FreeRTOS的STM32四轴飞行器: 六.2.4g通信 一.Si24Ri原理图二.Si24R1芯片手册解读三.驱动函数讲解五.移植2.4g通讯(飞控部分)六.移植2.4g通讯(遥控部分) 一.Si24Ri原理图 Si24R1芯片原理图如下: 右侧为晶振。 模块…...

Python爬取咸鱼Goodfish店铺所有商品接口的详细指南

在电商数据分析和市场研究中,爬取咸鱼店铺内的所有商品信息是一项极具价值的任务。通过调用咸鱼的goodfish.item_search_shop接口,可以获取指定店铺内的商品列表,包括商品标题、价格、图片链接、销量等详细信息。本文将详细介绍如何使用Pytho…...

【极光 Orbit•STC8A-8H】03. 小刀初试:点亮你的LED灯

【极光 Orbit•STC8H】03. 小刀初试:点亮你的 LED 灯 七律 点灯初探 单片方寸藏乾坤,LED明灭见真章。 端口配置定方向,寄存器值细推敲。 高低电平随心控,循环闪烁展锋芒。 嵌入式门初开启,从此代码手中扬。 摘要 …...

docker本地部署RagFlow

1.安装 克隆仓库 git clone https://github.com/infiniflow/ragflow.git构建预建的Docker映像并启动服务器 cd ragflow/docker chmod x ./entrypoint.sh docker compose -f docker-compose.yml -p ragflow up -d修改ragflow/docker/.env文件 #RAGFLOW_IMAGEinfiniflow/ragfl…...

STM32F4 UDP组播通信:填一填ST官方HAL库的坑

先说写作本文的原因,由于开项目开发中需要用到UDP组播接收的功能,但是ST官方没有提供合适的参考,使用STM32CubeMX生成的代码也是不能直接使用的,而我在网上找了一大圈,也没有一个能够直接解决的方案,deepse…...

基于python大数据的招聘数据可视化与推荐系统

博主介绍:资深开发工程师,从事互联网行业多年,熟悉各种主流语言,精通java、python、php、爬虫、web开发,已经做了多年的设计程序开发,开发过上千套设计程序,没有什么华丽的语言,只有…...

10. 【.NET 8 实战--孢子记账--从单体到微服务--转向微服务】--微服务基础工具与技术--Ocelot 网关--认证

在微服务架构中,通过在网关层实现身份认证、权限校验和数据加密,可以有效防范恶意攻击和非法访问,保障内部服务安全。采用JWT、OAuth等主流认证机制,使每次请求均经过严格验证,降低安全漏洞风险。同时,统一…...

DeepSeek 3FS:端到端无缓存的存储新范式

在 2025 年 2 月 28 日,DeepSeek 正式开源了其高性能分布式文件系统 3FS【1】,作为其开源周的压轴项目,3FS 一经发布便引发了技术圈的热烈讨论。它不仅继承了分布式存储的经典设计,还通过极简却高效的架构,展现了存储技…...

vue3组合式API怎么获取全局变量globalProperties

设置全局变量 main.ts app.config.globalProperties.$category { index: 0 } 获取全局变量 const { appContext } getCurrentInstance() as ComponentInternalInstance console.log(appContext.config.globalProperties.$category) 或是 const { proxy } getCurrentInstance…...

【YOLOv12改进trick】多尺度大核注意力机制MLKA模块引入YOLOv12,实现多尺度目标检测涨点,含创新点Python代码,方便发论文

🍋改进模块🍋:多尺度大核注意力机制(MLKA) 🍋解决问题🍋:MLKA模块结合多尺度、门控机制和空间注意力,显著增强卷积网络的模型表示能力。 🍋改进优势🍋:超分辨的MLKA模块对小目标和模糊目标涨点很明显 🍋适用场景🍋:小目标检测、模糊目标检测等 🍋思路…...

网络安全之端口扫描(一)

前置介绍 什么是DVWA? DVWA(Damn Vulnerable Web Application)是一个专门设计用于测试和提高Web应用程序安全技能的开源PHP/MySQL Web应用程序。它是一个具有多个安全漏洞的故意不安全的应用程序,供安全专业人员、渗透测试人员、…...

HCIE云计算学什么?怎么学?未来职业发展如何?

随着云计算成为IT行业发展的主流方向,HCIE云计算(华为认证云计算专家)作为华为认证体系中的高端认证之一,逐渐成为了许多网络工程师和IT从业者提升职业竞争力的重要途径。 那么,HCIE云计算究竟学什么内容,如…...

upload-labs文件上传

第一关 上传一个1.jpg的文件,在里面写好一句webshell 保留一个数据包,将其中截获的1.jpg改为1.php后重新发送 可以看到,已经成功上传 第二关 写一个webshell如图,为2.php 第二关在过滤tpye的属性,在上传2.php后使用b…...

操作系统控制台-健康守护我们的系统

引言基本准备体验功能健康守护系统诊断 收获提升结语 引言 阿里云操作系统控制平台作为新一代云端服务器中枢平台,通过创新交互模式重构主机管理体验。操作系统控制台提供了一系列管理功能,包括运维监控、智能助手、扩展插件管理以及订阅服务等。用户可以…...

财务会计域——合并报表系统设计

摘要 本文主要介绍了合并报表系统的设计,包括其背景、业务流程和系统架构设计。合并报表系统可自动化生成数据,减少人为错误,确保报表合规。其业务流程涵盖数据收集、标准化、合并调整、报表生成、审核及披露等环节。系统架构设计包括数据接…...

教务考试管理系统-Sprintboot vue

一、前言 1.1 实践目的和要求 本次实践的目的是为了帮助学生强化对实践涉及专业技术知识的理解,掌握专业领域中软件知识的应用方法,并了解软件工程在具体行业领域的发展趋势。通过培养学生利用软件工程方法分析、设计并完成具体行业软件开发的能力&…...

vue实现一个pdf在线预览,pdf选择文本并提取复制文字触发弹窗效果

[TOC] 一、文件预览 1、安装依赖包 这里安装了disjs-dist2.16版本&#xff0c;安装过程中报错缺少worker-loader npm i pdfjs-dist2.16.105 worker-loader3.0.8 2、模板部分 <template><div id"pdf-view"><canvas v-for"page in pdfPages&qu…...

基于大模型的 UI 自动化系统

基于大模型的 UI 自动化系统 下面是一个完整的 Python 系统,利用大模型实现智能 UI 自动化,结合计算机视觉和自然语言处理技术,实现"看屏操作"的能力。 系统架构设计 #mermaid-svg-2gn2GRvh5WCP2ktF {font-family:"trebuchet ms",verdana,arial,sans-…...

学校招生小程序源码介绍

基于ThinkPHPFastAdminUniApp开发的学校招生小程序源码&#xff0c;专为学校招生场景量身打造&#xff0c;功能实用且操作便捷。 从技术架构来看&#xff0c;ThinkPHP提供稳定可靠的后台服务&#xff0c;FastAdmin加速开发流程&#xff0c;UniApp则保障小程序在多端有良好的兼…...

linux arm系统烧录

1、打开瑞芯微程序 2、按住linux arm 的 recover按键 插入电源 3、当瑞芯微检测到有设备 4、松开recover按键 5、选择升级固件 6、点击固件选择本地刷机的linux arm 镜像 7、点击升级 &#xff08;忘了有没有这步了 估计有&#xff09; 刷机程序 和 镜像 就不提供了。要刷的时…...

【单片机期末】单片机系统设计

主要内容&#xff1a;系统状态机&#xff0c;系统时基&#xff0c;系统需求分析&#xff0c;系统构建&#xff0c;系统状态流图 一、题目要求 二、绘制系统状态流图 题目&#xff1a;根据上述描述绘制系统状态流图&#xff0c;注明状态转移条件及方向。 三、利用定时器产生时…...

WEB3全栈开发——面试专业技能点P2智能合约开发(Solidity)

一、Solidity合约开发 下面是 Solidity 合约开发 的概念、代码示例及讲解&#xff0c;适合用作学习或写简历项目背景说明。 &#x1f9e0; 一、概念简介&#xff1a;Solidity 合约开发 Solidity 是一种专门为 以太坊&#xff08;Ethereum&#xff09;平台编写智能合约的高级编…...

MySQL中【正则表达式】用法

MySQL 中正则表达式通过 REGEXP 或 RLIKE 操作符实现&#xff08;两者等价&#xff09;&#xff0c;用于在 WHERE 子句中进行复杂的字符串模式匹配。以下是核心用法和示例&#xff1a; 一、基础语法 SELECT column_name FROM table_name WHERE column_name REGEXP pattern; …...

【论文阅读28】-CNN-BiLSTM-Attention-(2024)

本文把滑坡位移序列拆开、筛优质因子&#xff0c;再用 CNN-BiLSTM-Attention 来动态预测每个子序列&#xff0c;最后重构出总位移&#xff0c;预测效果超越传统模型。 文章目录 1 引言2 方法2.1 位移时间序列加性模型2.2 变分模态分解 (VMD) 具体步骤2.3.1 样本熵&#xff08;S…...

网站指纹识别

网站指纹识别 网站的最基本组成&#xff1a;服务器&#xff08;操作系统&#xff09;、中间件&#xff08;web容器&#xff09;、脚本语言、数据厍 为什么要了解这些&#xff1f;举个例子&#xff1a;发现了一个文件读取漏洞&#xff0c;我们需要读/etc/passwd&#xff0c;如…...

探索Selenium:自动化测试的神奇钥匙

目录 一、Selenium 是什么1.1 定义与概念1.2 发展历程1.3 功能概述 二、Selenium 工作原理剖析2.1 架构组成2.2 工作流程2.3 通信机制 三、Selenium 的优势3.1 跨浏览器与平台支持3.2 丰富的语言支持3.3 强大的社区支持 四、Selenium 的应用场景4.1 Web 应用自动化测试4.2 数据…...

API网关Kong的鉴权与限流:高并发场景下的核心实践

&#x1f525;「炎码工坊」技术弹药已装填&#xff01; 点击关注 → 解锁工业级干货【工具实测|项目避坑|源码燃烧指南】 引言 在微服务架构中&#xff0c;API网关承担着流量调度、安全防护和协议转换的核心职责。作为云原生时代的代表性网关&#xff0c;Kong凭借其插件化架构…...