当前位置: 首页 > article >正文

使用PyTorch实现ResNet:从残差块到完整模型训练

ResNet(残差网络)是深度学习中的经典模型,通过引入残差连接解决了深层网络训练中的梯度消失问题。本文将从残差块的定义开始,逐步实现一个ResNet模型,并在Fashion MNIST数据集上进行训练和测试。


1. 残差块(Residual Block)实现

残差块通过跳跃连接(Shortcut Connection)将输入直接传递到输出,缓解了深层网络的训练难题。以下是残差块的PyTorch实现:

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2lclass Residual(nn.Module):def __init__(self, input_channels, num_channels, use_1x1conv=False, strides=1):super().__init__()self.conv1 = nn.Conv2d(input_channels, num_channels, kernel_size=3, padding=1, stride=strides)self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)self.bn1 = nn.BatchNorm2d(num_channels)self.bn2 = nn.BatchNorm2d(num_channels)if use_1x1conv:self.conv3 = nn.Conv2d(input_channels, num_channels, kernel_size=1, stride=strides)else:self.conv3 = Noneself.relu = nn.ReLU(inplace=True)def forward(self, x):y = F.relu(self.bn1(self.conv1(x)))y = self.bn2(self.conv2(y))if self.conv3:x = self.conv3(x)y += xreturn F.relu(y)

代码解析

  • use_1x1conv:当输入和输出通道数不一致时,使用1x1卷积调整通道数。

  • strides:控制特征图下采样的步长。

  • 残差相加后再次使用ReLU激活,增强非线性表达能力。


2. 构建ResNet模型

ResNet由多个残差块堆叠而成,以下代码构建了一个简化版ResNet-18:

# 初始卷积层
b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),nn.BatchNorm2d(64),nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)def resnet_block(input_channels, num_channels, num_residuals, first_block=False):blk = []for i in range(num_residuals):if i == 0 and not first_block:  # 第一个块需下采样blk.append(Residual(input_channels, num_channels, use_1x1conv=True, strides=2))else:blk.append(Residual(num_channels, num_channels))return blk# 堆叠残差块
b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
b3 = nn.Sequential(*resnet_block(64, 128, 2))
b4 = nn.Sequential(*resnet_block(128, 256, 2))
b5 = nn.Sequential(*resnet_block(256, 512, 2))# 完整网络结构
net = nn.Sequential(b1, b2, b3, b4, b5,nn.AdaptiveAvgPool2d((1, 1)),nn.Flatten(),nn.Linear(512, 10)
)

模型结构说明

  • AdaptiveAvgPool2d:自适应平均池化,将特征图尺寸统一为1x1。

  • Flatten:展平特征用于全连接层分类。


3. 数据加载与预处理

使用Fashion MNIST数据集,批量大小为256:

train_data, test_data = d2l.load_data_fashion_mnist(batch_size=256)

4. 模型训练与测试

设置训练参数:10个epoch,学习率0.05,并使用GPU加速:

d2l.train_ch6(net, train_data, test_data, num_epochs=10, lr=0.05, device=d2l.try_gpu())

训练结果

loss 0.124, train acc 0.952, test acc 0.860
4921.4 examples/sec on cuda:0

5. 结果可视化

训练过程中损失和准确率变化如下图所示:

分析

  • 训练准确率(紫色虚线)迅速上升并稳定在95%以上。

  • 测试准确率(绿色点线)达到86%,表明模型具有良好的泛化能力。

  • 损失值(蓝色实线)持续下降,未出现过拟合。


6. 完整代码

整合所有代码片段(需安装d2l库):

# 残差块定义、模型构建、训练代码见上文

7. 总结

本文实现了ResNet的核心组件——残差块,并构建了一个简化版ResNet模型。通过实验验证,模型在Fashion MNIST数据集上表现良好。读者可尝试调整网络深度或超参数以进一步提升性能。

改进方向

  • 增加残差块数量构建更深的ResNet(如ResNet-34/50)。

  • 使用数据增强策略提升泛化能力。

  • 尝试不同的优化器和学习率调度策略。


注意事项

  • 确保已安装PyTorch和d2l库。

  • GPU环境可显著加速训练,若使用CPU需调整批量大小。


希望本文能帮助您理解ResNet的实现细节!如有疑问,欢迎在评论区留言讨论。

相关文章:

使用PyTorch实现ResNet:从残差块到完整模型训练

ResNet(残差网络)是深度学习中的经典模型,通过引入残差连接解决了深层网络训练中的梯度消失问题。本文将从残差块的定义开始,逐步实现一个ResNet模型,并在Fashion MNIST数据集上进行训练和测试。 1. 残差块&#xff08…...

Scala相关知识学习总结5

1、多维数组 定义: val arr Array.ofDim[Double](3,4) 表示二维数组中有三个一维数组,每个一维数组有四个元素。 2、列表 List 不可变 List:默认不可变,可创建有序且可重复的列表,可使用:从右向左增加数据&#xf…...

Day1:前端项目uni-app壁纸实战

uni-app官网下载HBuilder。 uni-app快速上手 | uni-app官网 点击HBuilder 安装 新建项目 工具——插件安装 安装uni-app(vue3) 我们先来准备一下: 先在wallpaper下新建目录 我已经建过了 同样,再在common下建images和style目录&…...

光谱相机的光谱数据采集原理

光谱相机的光谱数据采集原理基于‌分光技术‌和‌光电信号转换‌,通过将入射光按波长分解并记录各波段的强度信息,最终生成包含空间和光谱维度的数据立方体。以下是详细原理分解: ‌1. 分光技术:将复合光分解为单色光‌ 光谱相机…...

《算法笔记》10.3小节——图算法专题->图的遍历 问题 A: 第一题

题目描述 该题的目的是要你统计图的连通分支数。 输入 每个输入文件包含若干行,每行两个整数i,j,表示节点i和j之间存在一条边。 输出 输出每个图的联通分支数。 样例输入 1 4 4 3 5 5样例输出 2 分析: 由于题目没给出范围&#xff0…...

python中的{}

注意,如果要创建空集合,只能使用 set() 函数实现。因为直接使用一对 {},Python 解释器会将其视为一个空字典。 Python中集合set和字典dict的用法区别_python创建set变量和dict区别-CSDN博客...

宏碁笔记本电脑擎7PRO搭载的 NVIDIA RTX 5080 显卡安装pytorch

宏碁笔记本电脑擎7PRO搭载的 NVIDIA RTX 5080 显卡是一款高性能移动 GPU,基于 NVIDIA 最新的 Blackwell 架构设计,通过修正架构(Blackwell)、显存类型与带宽(GDDR7、960GB/s)、Tensor Core 与 RT Core 全面…...

html+css+js 实现一个贪吃蛇小游戏

目录 游戏简介 游戏功能与特点 如何玩转贪吃蛇 游戏设计与实现 HTML结构 JavaScript核心实现 代码结构: 效果 关于“其他游戏” 游戏简介 贪吃蛇是一款经典的单人小游戏,玩家通过控制蛇的移动,吃掉食物来增加长度,避免撞…...

淘宝按图搜索商品(拍立淘)API接口解析

以下是关于淘宝按图搜索商品(拍立淘)API的深度解析指南,结合官方文档和开发者经验整理,包含调用方法、参数详解、返回结果解析及常见问题处理: 一、API核心接口说明 1. 接口名称 官方接口:taobao.image.…...

Python爬虫生成CSV文件的完整流程

引言 在当今数据驱动的时代,网络爬虫已成为获取互联网数据的重要工具。Python凭借其丰富的库生态系统和简洁的语法,成为了爬虫开发的首选语言。本文将详细介绍使用Python爬虫从网页抓取数据并生成CSV文件的完整流程,包括环境准备、网页请求、…...

21.OpenCV获取图像轮廓信息

OpenCV获取图像轮廓信息 在计算机视觉领域,识别和分析图像中的对象形状是一项基本任务。OpenCV 库提供了一个强大的工具——轮廓检测(Contour Detection),它能够帮助我们精确地定位对象的边界。这篇博文将带你入门 OpenCV 的轮廓…...

医学图像分割效率大幅提升!U-Net架构升级,助力精度提升5%!

在医学图像分割领域,U-Net模型及其变体的创新应用正在带来显著的性能提升和效率优化。最新研究显示,通过引入结构化状态空间模型(SSM)和轻量级LSTM(xLSTM)等技术,VMAXL-UNet模型在多个医学图像数…...

智能设备运行监控系统

在工业 4.0 与智能制造浪潮下,设备运行效率与稳定性成为企业竞争力的核心要素。然而,传统设备管理模式面临数据采集分散、状态分析滞后、维护成本高昂等痛点。为破解这些难题,设备运行监控系统应运而生,通过融合智能传感、5G 通信…...

详细分析单例模式

目录 1.单例模式的定义 2.单例模式的实现方式 1.饿汉模式 2.懒汉模式 (1)线程不安全的问题怎么解决? (2)直接对整个getInstance方法代码块加锁吗? (3)那对if语句加锁不就行了吗…...

Windwos的DNS解析命令nslookup

nslookup 解析dns的命令 有两种使用方式,交互式&命令行方式。 交互式 C:\Users\Administrator>nslookup 默认服务器: UnKnown Address: fe80::52f7:edff:fe28:35de> www.baidu.com 服务器: UnKnown Address: fe80::52f7:edff:fe28:35de非权威应答:…...

服务器报错:xxx/libc.so.6: version `GLIBC_2.32‘ not found

/lib/x86_64-linux-gnu/libc.so.6: version GLIBC_2.32 not found (required by ./aima-sim-app-main) 解决思路 根据错误信息,您的应用程序 aima-sim-app-main 和 libmujoco.so.3.1.6 库依赖于较新的 GNU C Library (glibc) 版本(如 GLIBC_2.32, GLIBC…...

Flutter之页面布局一

目录: 1、页面布局一2、无状态组件StatelessWidget和有状态组件StatefulWidget2.1、无状态组件示例2.2、有状态组件示例2.3、在 widget 之间共享状态1、使用 widget 构造函数2、使用 InheritedWidget3、使用回调 3、布局小组件3.1、布置单个 Widget3.2、容器3.3、垂…...

架构思维: 数据一致性的两种场景深度解读

文章目录 Pre案例数据一致性问题的两种场景第一种场景:实时数据不一致不要紧,保证数据最终一致性就行第二种场景:必须保证实时一致性 最终一致性方案实时一致性方案TCC 模式Seata 中 AT 模式的自动回滚一阶段二阶段-回滚二阶段-提交 Pre 架构…...

大数据knox网关API

我们过去访问大数据组件,如sparkui,hdfs的页面,以及yarn上面看信息是很麻烦的一件事。要记每个端口号,比如50070,8090,8088,4007,如果换到另一个集群,不同版本&#xff0…...

UI测试(2)

1、HTML 是用来描述网页的一种语言。 指的是超文本标记语言 (Hyper Text Markup Language) &#xff0c;HTML 不是一种编程语言&#xff0c;而是一种标记语言 (markup language) 负责定义页面呈现的内容&#xff1a;标签语言&#xff1a;<标签名>标签值<标签名>&am…...

【Tauri2】015——前端的事件、方法和invoke函数

目录 前言 正文 准备 关键url 获取所有命令 切换主题set_theme 设置大小 获得版本version 名字name 监听窗口移动 前言 【Tauri2】005——tauri::command属性与invoke函数-CSDN博客https://blog.csdn.net/qq_63401240/article/details/146581991?spm1001.2014.3001.…...

密码学基础——分组密码的运行模式

前面的文章中文我们已经知道了分组密码是一种对称密钥密码体制&#xff0c;其工作原理可以概括为将明文消息分割成固定长度的分组&#xff0c;然后对每个分组分别进行加密处理。 下面介绍分组密码的运行模式 1.电码本模式&#xff08;ECB&#xff09; 2.密码分组链接模式&…...

Android SELinux权限使用

Android SELinux权限使用 一、SELinux开关 adb在线修改seLinux(也可以改配置文件彻底关闭) $ getenforce; //获取当前seLinux状态,Enforcing(表示已打开),Permissive(表示已关闭) $ setenforce 1; //打开seLinux $ setenforce 0; //关闭seLinux二、命令查看sel…...

Python----计算机视觉处理(Opencv:道路检测完整版:透视变换,提取车道线,车道线拟合,车道线显示,)

Python----计算机视觉处理&#xff08;Opencv:道路检测之道路透视变换) Python----计算机视觉处理&#xff08;Opencv:道路检测之提取车道线&#xff09; Python----计算机视觉处理&#xff08;Opencv:道路检测之车道线拟合&#xff09; Python----计算机视觉处理&#xff0…...

基于飞桨框架3.0本地DeepSeek-R1蒸馏版部署实战

深度学习框架与大模型技术的融合正推动人工智能应用的新一轮变革。百度飞桨&#xff08;PaddlePaddle&#xff09;作为国内首个自主研发、开源开放的深度学习平台&#xff0c;近期推出的3.0版本针对大模型时代的开发痛点进行了系统性革新。其核心创新包括“动静统一自动并行”&…...

docker初始环境搭建(docker、Docker Compose、portainer)

docker、Docker Compose和portainer的安装部署、使用 docker、Docker Compose和portainer的安装部署、使用一.安装docker1.失败的做法2.首先卸载旧版本&#xff08;没安装则下一步&#xff09;3.配置下载的yum来源&#xff0c;不然yum search搜不到4.安装启动docker5.替换国内源…...

开源RuoYi AI助手平台的未来趋势

近年来&#xff0c;人工智能技术的迅猛发展已经深刻地改变了我们的生活和工作方式。 无论是海外的GPT、Claude等国际知名AI助手&#xff0c;还是国内的DeepSeek、Kimi、Qwen等本土化解决方案&#xff0c;都为用户提供了前所未有的便利。然而&#xff0c;对于那些希望构建属于自…...

element-ui自制树形穿梭框

1、需求 由于业务特殊需求&#xff0c;想要element穿梭框功能&#xff0c;数据是二级树形结构&#xff0c;选中左边数据穿梭到右边后&#xff0c;左边数据不变。多次选中左边相同数据进行穿梭操作&#xff0c;右边数据会多次增加相同的数据。右边数据穿梭回左边时&#xff0c;…...

Linux系统学习Day04 阻塞特性,文件状态及文件夹查询

知识点4【文件的阻塞特性】 文件描述符 默认为 阻塞 的 比如&#xff1a;我们读取文件数据的时候&#xff0c;如果文件缓冲区没有数据&#xff0c;就需要等待数据的到来&#xff0c;这就是阻塞 当然写入的时候&#xff0c;如果发现缓冲区是满的&#xff0c;也需要等待刷新缓…...

Module模块化

导出&#xff1a;export关键字 export var color "red"; 重命名导出 在模块中使用as用导出名称表示本地名称。 import { add } from "./05-module-out.js"; 导入&#xff1a; import关键字 导入单个绑定 import { sum } from "./05-module-out.js&…...